"""
desispec.util
=============
Utility functions for desispec.
"""
from __future__ import absolute_import, division, print_function
import argparse
import inspect
import os
import sys
import errno
import time
import datetime
import collections
import numbers
import datetime
import textwrap
import numpy as np
import subprocess as sp
from desiutil.log import get_logger, INFO
[docs]
def is_robust_mode():
    """
    Report whether the pipeline should run in "robust" mode.

    Robust mode is requested with the environment variable
    ``$DESI_SPECTRO_ROBUST`` set to ``True``, ``Yes`` or ``1``
    (compared case-insensitively).  It is intended to be used e.g. to decide
    to preproc anyway even if a dark model can't be found.

    Returns:
        bool: True if ``$DESI_SPECTRO_ROBUST`` holds an enabling value,
        False if it is unset or holds anything else.
    """
    setting = os.environ.get('DESI_SPECTRO_ROBUST')
    if setting is None:
        return False
    return setting.upper() in ('YES', 'TRUE', '1')
[docs]
def runcmd(cmd, args=None, expandargs=False, inputs=None, outputs=None, comm=None, clobber=False, check_return=False):
    """
    Runs a command (function or script), checking for inputs and outputs

    Args:
        cmd : function object or command string to run with subprocess.call()

    Options:
        args : list of args to pass to the function or script
        expandargs: call function with ``cmd(*args)`` instead of ``cmd(args)``
        inputs : list of filename inputs that must exist before running;
            None (default) is treated as an empty list
        outputs : list of output filenames that should be created;
            None (default) is treated as an empty list
        clobber : if True, run even if outputs already exist
        comm : MPI communicator to pass to cmd(..., comm=comm)
        check_return : if True, check return value of function and require 0 or None for success

    Returns:
        (result, success)

    Notes:
        * If any inputs are missing, don't run cmd and return (None, False).
        * If outputs exist and have timestamps after all inputs,
          don't run cmd and return (None, True).
        * If spawned as a script, return (returncode, (returncode==0)).
        * If function raises an exception, return (exception, False).
          ``BaseException`` is caught, so this includes e.g. ``SystemExit``.
        * If function returns result but outputs are missing, return (result, False).
        * If check_return is True and function result != (0 or None), return (result, False)
        * If function returns result and all outputs are present, return (result, True).
        * With ``comm``, file existence/timestamp checks are done on rank 0
          and broadcast to the other ranks, and success requires every rank
          to succeed.
    """
    log = get_logger()

    #- None means "no files"; avoids mutable default arguments
    inputs = [] if inputs is None else inputs
    outputs = [] if outputs is None else outputs

    if comm is None:
        rank = 0
        size = 1
    else:
        size = comm.Get_size()
        rank = comm.Get_rank()
        if rank == 0:
            log.info('runcmd parallel with {} ranks'.format(size))

    #- construct log string of what will run
    cmd_callable = callable(cmd)
    if args is None:
        args = tuple()
    elif cmd_callable and not expandargs:
        #- wrap so that cmd(*args) below becomes cmd(original_args)
        args = (args,)

    if cmd_callable:
        funcname = cmd.__module__ + '.' + cmd.__name__
        if expandargs:
            cmdstr = f'{funcname}{tuple(args)}'
        else:
            argstr = ', '.join([str(tmp) for tmp in args])
            cmdstr = f'{funcname}({argstr})'
    else:
        cmdstr = cmd + ' ' + ' '.join(args)

    #- Check that inputs exist, and timestamp of latest input file
    missing_inputs = False
    input_time = 0
    if rank == 0:
        for x in inputs:
            if not os.path.exists(x):
                log.error(f"missing input {x}")
                missing_inputs = True
            else:
                input_time = max(input_time, os.stat(x).st_mtime)

    if comm is not None:
        missing_inputs = comm.bcast(missing_inputs, root=0)
        input_time = comm.bcast(input_time, root=0)

    if missing_inputs:
        if rank == 0:
            log.critical(f"FAILED missing required inputs: {cmdstr}")
        return None, False  #- results=None, success=False

    #- Check if outputs already exist and that their timestamp is after
    #- the last input timestamp
    already_done = (not clobber) and (len(outputs) > 0)
    if rank == 0 and not clobber:
        for x in outputs:
            if not os.path.exists(x):
                already_done = False
                break
            if len(inputs) > 0 and os.stat(x).st_mtime < input_time:
                already_done = False
                break

    if comm is not None:
        already_done = comm.bcast(already_done, root=0)

    if already_done:
        if rank == 0:
            log.info("SKIPPING: {}".format(cmdstr))
        return None, True  #- results=None, success=True

    #- Green light to go; print input/output info
    #- Use log.level to decide verbosity, but avoid long prefixes
    if rank == 0:
        log.info(time.asctime())
        log.info("RUNNING: {}".format(cmdstr))
        if log.level <= INFO:
            if len(inputs) > 0:
                print("  Inputs")
                for x in inputs:
                    print("   ", x)
            if len(outputs) > 0:
                print("  Outputs")
                for x in outputs:
                    print("   ", x)

    #- run command
    success = True
    result = None
    try:
        if cmd_callable:
            if comm is None:
                result = cmd(*args)
            else:
                result = cmd(*args, comm=comm)
            if check_return and result not in (0, None):
                success = False
        else:
            #- not a callable function, spawn as script
            result = sp.call(cmdstr, shell=True)
            success = (result == 0)
    except BaseException as e:
        #- BaseException on purpose: anything raised by cmd is reported
        #- back to the caller as (exception, False) instead of propagating
        caller = inspect.stack()[1]
        log.critical(f'FAILED rank {rank} exception while running {cmdstr} called from line {caller.lineno} in {caller.filename}')
        result = e
        success = False
        if rank == 0:
            import traceback
            for line in traceback.format_exception(*sys.exc_info()):
                log.error(f'{line.strip()}')

    #- success only if all succeed
    if comm is not None:
        success = np.all(comm.gather(success, root=0))
        success = comm.bcast(success, root=0)

    if not success:
        if rank == 0:
            log.critical(f"FAILED {cmdstr}")
        return result, False

    #- Check for outputs
    outputs_present = True
    if rank == 0:
        for x in outputs:
            if not os.path.exists(x):
                log.error("missing output {}".format(x))
                outputs_present = False

    if comm is not None:
        outputs_present = comm.bcast(outputs_present, root=0)

    if not outputs_present:
        if rank == 0:
            log.critical("FAILED missing outputs {}".format(cmdstr))
        return result, False

    if rank == 0:
        log.info("SUCCESS: {}".format(cmdstr))
    return result, True
[docs]
def mpi_count_failures(num_cmd, num_err, comm=None):
    """
    Sum num_cmd and num_err across MPI ranks

    Args:
        num_cmd (int): number of commands run
        num_err (int): number of failures

    Options:
        comm: mpi4py communicator

    Returns:
        sum(num_cmd), sum(num_err) summed across all MPI ranks

    If ``comm`` is None, returns input num_cmd, num_err unchanged.
    With a communicator, a count of None is treated as 0.
    """
    if comm is None:
        return num_cmd, num_err

    #- a rank with nothing to report contributes zero to the totals
    local_cmd = 0 if num_cmd is None else num_cmd
    local_err = 0 if num_err is None else num_err

    #- collect on rank 0 and sum there ...
    total_cmd = np.sum(comm.gather(local_cmd, root=0))
    total_err = np.sum(comm.gather(local_err, root=0))

    #- ... then share the rank 0 totals with every rank
    total_cmd = comm.bcast(total_cmd, root=0)
    total_err = comm.bcast(total_err, root=0)

    return total_cmd, total_err
[docs]
def sprun(com, capture=False, input=None):
    """Run a command with subprocess and handle errors.

    This runs a command and either returns or prints the lines of STDOUT.
    Any contents of STDERR are logged.  If an OSError is raised while
    launching or talking to the child process (e.g. command not found),
    that is logged.  If another exception is raised, its traceback is logged.
    In both error cases the function still returns, with return code -1
    and no STDOUT lines unless they were obtained before the failure.

    Args:
        com (list): the command to run.
        capture (bool): if True, return the stdout contents instead of
            printing them.
        input (str): the string data (can include embedded newlines) to write
            to the STDIN of the child process.

    Returns:
        tuple(int, (list)): the return code and optionally the lines of STDOUT
        from the child process.  If ``capture`` is False only the integer
        return code is returned.
    """
    import traceback
    log = get_logger()

    #- only attach a pipe to STDIN if there is something to send
    stdin = sp.PIPE if input is not None else None

    out = None
    ret = -1
    try:
        with sp.Popen(com, stdin=stdin, stdout=sp.PIPE, stderr=sp.PIPE, universal_newlines=True) as p:
            #- communicate(input=None) is the same as communicate()
            out, err = p.communicate(input=input)
            for line in err.splitlines():
                log.info("STDERR: {}".format(line))
            ret = p.returncode
    except OSError as e:
        log.error("OSError: {}".format(e.errno))
        log.error("OSError: {}".format(e.strerror))
        log.error("OSError: {}".format(e.filename))
    except Exception:
        for line in traceback.format_exception(*sys.exc_info()):
            log.error("exception: {}".format(line))

    #- out is still None if the child could not be run; the error has
    #- already been logged so report "no output" rather than crashing
    stdout_lines = [] if out is None else out.splitlines()

    if capture:
        return ret, stdout_lines

    for line in stdout_lines:
        print(line)
    return ret
[docs]
def pid_exists( pid ):
    """Check whether pid exists in the current process table.

    **UNIX only.**  Should work the same as psutil.pid_exists().

    Args:
        pid (int): A process ID.

    Returns:
        pid_exists (bool): ``True`` if the process exists in the current process table.

    Raises:
        ValueError: if ``pid`` is 0.
        OSError: if ``os.kill`` fails for a reason other than ESRCH or EPERM.
    """
    if pid == 0:
        # According to "man 2 kill" PID 0 refers to every process
        # in the process group of the calling process.
        # On certain systems 0 is a valid PID but we have no way
        # to know that in a portable fashion.
        raise ValueError('invalid PID 0')

    if pid < 0:
        return False

    try:
        # signal 0 performs the existence/permission checks without
        # actually delivering a signal
        os.kill(pid, 0)
    except OSError as err:
        if err.errno == errno.EPERM:
            # EPERM clearly means there's a process to deny access to
            return True
        if err.errno == errno.ESRCH:
            # ESRCH == No such process
            return False
        # According to "man 2 kill" possible error values are
        # (EINVAL, EPERM, ESRCH)
        raise

    return True
[docs]
def option_list(opts):
    """Convert key, value pairs into command-line options.

    Each key becomes ``--key``.  Boolean values produce a bare flag when True
    and nothing at all when False.  Floats are written in exponential notation
    with 14 decimal places, the items of a list or tuple are appended
    individually as-is, and any other value is converted to a string.

    Parameters
    ----------
    opts : dict-like
        Convert a dictionary into command-line options.

    Returns
    -------
    :class:`list`
        A list of command-line options.
    """
    optlist = []
    for key, val in opts.items():
        flag = f"--{key}"

        # booleans are pure switches: present if True, absent if False
        if isinstance(val, bool):
            if val:
                optlist.append(flag)
            continue

        optlist.append(flag)
        if isinstance(val, float):
            optlist.append(f"{val:.14e}")
        elif isinstance(val, (list, tuple)):
            optlist.extend(val)
        else:
            optlist.append(f"{val}")

    return optlist
[docs]
def mask32(mask):
    '''
    Return an input mask as unsigned 32-bit

    Args:
        mask: numpy array of boolean or integer dtype (any byte order)

    Returns:
        array with dtype 'u4'.  For native 32-bit input this is a view of
        the input; otherwise it is a converted copy.

    Raises ValueError if 64-bit input can't be cast to 32-bit without losing
    info (i.e. if it contains values > 2**32-1 when interpreted as unsigned),
    or if the input dtype is not boolean or integer.
    '''
    dtype = mask.dtype
    is_integer = dtype.kind in ('i', 'u')

    if is_integer and dtype.itemsize == 4:
        #- native byte order can be reinterpreted in place;
        #- otherwise convert, which also swaps to native order
        if dtype.isnative:
            return mask.view('u4')
        else:
            return mask.astype('u4')

    elif is_integer and dtype.itemsize == 8:
        #- Interpret the bits as unsigned in both cases so that negative
        #- signed values (high bits set) are caught by the range check
        if dtype.isnative:
            mask64 = mask.view('u8')
        else:
            mask64 = mask.astype('u8')

        if np.any(mask64 > 2**32-1):
            raise ValueError("mask with values above 2**32-1 can't be cast to 32-bit")

        return np.asarray(mask, dtype='u4')

    elif dtype.kind == 'b' or (is_integer and dtype.itemsize in (1, 2)):
        #- bool and 8/16-bit integers are converted directly
        return np.asarray(mask, dtype='u4')

    else:
        raise ValueError("Can't cast dtype {} to unsigned 32-bit".format(mask.dtype))
[docs]
def night2ymd(night):
    """
    parse night YEARMMDD into tuple of integers (year, month, day)

    Args:
        night (int or str): night in YEARMMDD form

    Returns:
        tuple (year, month, day) of ints

    Raises:
        ValueError: if night is not 8 characters long, is not numeric,
            or has month outside 1-12 or day outside 1-31.
    """
    night = str(night)  # support both int and str input
    if len(night) != 8:
        raise ValueError(f'invalid YEARMMDD night string {night=}')

    #- slice out YEAR, MM, DD in that order
    year, month, day = (int(night[i:j]) for i, j in ((0, 4), (4, 6), (6, 8)))

    if not (1 <= month <= 12):
        raise ValueError('MM month should be 1-12, not {}'.format(month))

    if not (1 <= day <= 31):
        raise ValueError('DD day should be 1-31, not {}'.format(day))

    return (year, month, day)
[docs]
def night2dateobj(night):
    """
    parse night YEARMMDD into a datetime.date object

    Args:
        night (int or str): night in YEARMMDD form

    Returns:
        datetime.date for that year, month, day
    """
    ymd = night2ymd(night)
    return datetime.date(*ymd)
[docs]
def dateobj2night(dateobj):
    """
    Convert datetime.date object into YEARMMDD int

    Args:
        dateobj: object with ``year``, ``month`` and ``day`` attributes,
            e.g. a datetime.date

    Returns:
        int YEARMMDD
    """
    nightstr = ymd2night(dateobj.year, dateobj.month, dateobj.day)
    return int(nightstr)
[docs]
def difference_nights(firstnight, secondnight):
    """
    parse two YEARMMDD nights (ints or strings) and determine the number of
    days between them, returning secondnight-firstnight

    Args:
        firstnight (int or str): YEARMMDD night to subtract
        secondnight (int or str): YEARMMDD night to subtract from

    Returns:
        int number of days; negative if secondnight is before firstnight
    """
    first, second = (night2dateobj(str(n)) for n in (firstnight, secondnight))
    return (second - first).days
[docs]
def ymd2night(year, month, day):
    """
    convert year, month, day integers into canonical YEARMMDD night string

    Args:
        year (int): year, zero-padded to at least 4 digits
        month (int): month, zero-padded to at least 2 digits
        day (int): day, zero-padded to at least 2 digits

    Returns:
        str YEARMMDD
    """
    return f"{year:04d}{month:02d}{day:02d}"
[docs]
def mjd2night(mjd):
    """
    Convert MJD to YEARMMDD int night of KPNO sunset

    Args:
        mjd (float): Modified Julian Date

    Returns:
        int YEARMMDD night
    """
    from astropy.time import Time

    #- Shift by 7 hours and then by another 12 hours before taking the
    #- calendar date; presumably UTC -> KPNO local time, plus rolling the
    #- night over at local noon instead of midnight (verify if it matters)
    shifted = Time(mjd - 7/24. - 12/24., format='mjd')
    return int(shifted.strftime('%Y%m%d'))
[docs]
def dateobs2night(dateobs):
    """
    Convert DATE-OBS ISO8601 UTC string to YEARMMDD int night of KPNO sunset

    Args:
        dateobs (str): ISO8601 date/time string

    Returns:
        int YEARMMDD night, as calculated by ``mjd2night``
    """
    # use astropy to flexibly handle multiple valid ISO8601 variants
    from astropy.time import Time

    try:
        obs_mjd = Time(dateobs).mjd
    except ValueError:
        #- only use optional dependency dateutil if needed;
        #- it can handle some ISO8601 timezone variants that astropy can't
        from dateutil.parser import isoparser
        parsed = isoparser().isoparse(dateobs)
        obs_mjd = Time(parsed).mjd

    return mjd2night(obs_mjd)
[docs]
def parse_keyval(keyval):
    """
    Parse "key=val" string -> (key,val) tuple with int/float/str/bool val

    Args:
        keyval (str): "key=value" string; only the first "=" separates the
            key from the value

    Returns (key, value) tuple where value has been promoted from string
    into int/float/bool if possible.

    value="True" or "False" becomes boolean True/False, but all other forms
    like "T"/"F" or "true"/"false" remain strings.
    0 and 1 become ints, not bool.
    """
    key, text = keyval.split('=', maxsplit=1)

    #- try numeric conversions from most to least specific
    for convert in (int, float):
        try:
            return (key, convert(text))
        except ValueError:
            continue

    #- not a number: exact "True"/"False" (ignoring surrounding whitespace)
    #- become bool, anything else is returned as the original string
    literals = {'True': True, 'False': False}
    return (key, literals.get(text.strip(), text))
[docs]
def combine_ivar(ivar1, ivar2):
    """
    Returns the combined inverse variance of two inputs, making sure not to
    divide by 0 in the process.

    The result is 1/(1/ivar1 + 1/ivar2) where both inputs are > 0,
    and 0 wherever either input is 0.

    Args:
        ivar1: scalar, list, tuple, or ndarray of inverse variances
        ivar2: same, with the same dimensions as ivar1

    Returns:
        combined inverse variance: a python float if ivar1 is a python
        float or integer, a 0-dim ndarray if ivar1 is any other 0-dim input,
        otherwise an ndarray with the shape of the inputs.

    Raises:
        AssertionError: if either input has negative elements or the
            shapes don't match.
    """
    iv1 = np.atleast_1d(ivar1)  #- handle list, tuple, ndarray, and scalar input
    iv2 = np.atleast_1d(ivar2)
    assert np.all(iv1 >= 0), 'ivar1 has negative elements'
    assert np.all(iv2 >= 0), 'ivar2 has negative elements'
    assert iv1.shape == iv2.shape, 'shape mismatch {} vs. {}'.format(iv1.shape, iv2.shape)

    #- only combine where both are non-zero; everything else stays 0
    ii = (iv1 > 0) & (iv2 > 0)
    ivar = np.zeros(iv1.shape)
    ivar[ii] = 1.0 / (1.0/iv1[ii] + 1.0/iv2[ii])

    #- Convert back to python float if input was scalar
    if isinstance(ivar1, (float, numbers.Integral)):
        # Fix "Conversion of an array with ndim > 0 to a scalar is deprecated"
        return float(ivar[0])
    #- If input was 0-dim numpy array, convert back to 0-dim.
    #- np.ndim is used because list and tuple inputs have no .ndim attribute
    elif np.ndim(ivar1) == 0:
        return np.asarray(ivar[0])
    else:
        return ivar
#- Matplotlib backend name recorded by the first call to set_backend();
#- None until set_backend() has been called
_matplotlib_backend = None
[docs]
def set_backend(backend='agg'):
    """
    Set matplotlib to use a batch-friendly backend

    This function is safe to call multiple times without tripping on a
    previously set backend (which remains set): only the first call has any
    effect, later calls return without touching matplotlib.

    Options:
        backend (str): name of the matplotlib backend to use
    """
    global _matplotlib_backend

    #- a backend was already chosen by an earlier call; leave it alone
    if _matplotlib_backend is not None:
        return

    _matplotlib_backend = backend
    import matplotlib
    matplotlib.use(_matplotlib_backend)
[docs]
def healpix_degrade_fixed(nside, pixel):
    """
    Degrade a NEST ordered healpix pixel with a fixed ratio.

    This degrades the pixel to a lower nside value that is
    fixed to half the healpix "factor", i.e. half of log2(nside)
    rounded down.

    Args:
        nside (int): a valid NSIDE value.
        pixel (int): the NESTED pixel index.

    Returns (tuple):
        a tuple of ints, where the first value is the new
        NSIDE and the second value is the degraded pixel
        index.
    """
    order = int(np.log2(nside))
    suborder = order // 2

    #- drop the low bits of the pixel index by the difference in "factor"
    shift = order - suborder
    return (2**suborder, pixel >> shift)
[docs]
def parse_int_args(arg_string, include_end=False) :
    """
    Short func that parses a string containing a comma separated list of
    integers, which can include ":" or ".." or "-" labeled ranges

    Args:
        arg_string (str) : list of integers or integer ranges; non-string
            input (e.g. a single int) is converted with ``str()``

    Options:
        include_end (bool): if True, include end-value in ranges

    Returns (array 1-D):
        1D numpy array listing all of the integers given in the list,
        including enumerations of ranges given.  None, an empty/blank
        string, or input that expands to no values returns an empty
        integer array.

    Raises:
        ValueError: if an entry is neither an integer nor a valid range

    Note: this follows python-style ranges, i,e, 1:5 or 1..5 returns 1,2,3,4
    unless `include_end` is True, which then returns 1,2,3,4,5
    """
    if arg_string is None:
        return np.array([], dtype=int)

    arg_string = str(arg_string)
    if len(arg_string.strip(' \t')) == 0:
        return np.array([], dtype=int)

    pad = 1 if include_end else 0

    values = []
    for sub in arg_string.split(','):
        sub = sub.replace(' ', '')
        if sub.isdigit():
            values.append(int(sub))
            continue

        #- try each range separator in turn; the first one that splits
        #- the entry into exactly two integers wins
        match = False
        for symbol in (':', '..', '-'):
            if symbol not in sub:
                continue
            tmp = sub.split(symbol)
            if (len(tmp) == 2) and tmp[0].isdigit() and tmp[1].isdigit():
                values.extend(range(int(tmp[0]), int(tmp[1]) + pad))
                match = True
                break

        if not match:
            msg = "parsing error. Didn't understand {}".format(sub)
            #- logger is only needed on this error path
            log = get_logger()
            log.error(msg)
            raise ValueError(msg)

    #- keep an integer dtype even if the ranges expanded to nothing,
    #- since np.array([]) would default to float
    if len(values) == 0:
        return np.array([], dtype=int)

    return np.array(values)
[docs]
def parse_fibers(fiber_string, include_end=False) :
    """
    Short func that parses a string containing a comma separated list of
    integers, which can include ":" or ".." or "-" labeled ranges

    This is a thin wrapper that forwards its arguments to ``parse_int_args``.

    Args:
        fiber_string (str) : list of integers or integer ranges

    Options:
        include_end (bool): if True, include end-value in ranges

    Returns (array 1-D):
        1D numpy array listing all of the integers given in the list,
        including enumerations of ranges given.

    Note: this follows python-style ranges, i,e, 1:5 or 1..5 returns 1, 2, 3, 4
    unless `include_end` is True, which then returns 1,2,3,4,5
    """
    fibers = parse_int_args(fiber_string, include_end=include_end)
    return fibers
[docs]
def parse_nights(nights_string, include_end=False) :
    """
    Short func that parses a string containing a comma separated list of
    YYYYMMDD, which can include ":" or ".." or "-" labeled ranges

    Args:
        nights_string (str) : list of integers or integer ranges

    Options:
        include_end (bool): if True, include end-value in ranges

    Returns (array 1-D):
        1D numpy array listing all of the integers given in the list,
        including enumerations of ranges given.  If no valid nights remain,
        an empty integer array is returned.

    Note: this follows python-style ranges, i.e. 20250101-20250103
    returns [20250101, 20250102] unless `include_end` is True,
    which then returns [20250101, 20250102, 20250103].

    Ranges are first expanded as plain integers and then every value that
    is not a valid calendar date is dropped, so ranges that span month and
    year boundaries are handled correctly, including leap-years.
    """
    candidates = parse_int_args(nights_string, include_end)

    # now keep only valid YYYYMMDD
    values = []
    for value in candidates:
        try:
            #- basic checks on month/day ranges and ability to convert to datetime.date
            night2dateobj(value)
        except ValueError:
            #- not a real date (e.g. 20250132 from an integer range); skip it
            continue
        values.append(value)

    #- np.array([]) would be float; keep integer dtype for empty results
    if len(values) == 0:
        return np.array([], dtype=int)

    return np.array(values)
[docs]
def get_night_range(night, before, after):
    """
    Generate an array of YEARMMDD ints for a range of nights before and after a given night.

    Args:
        night (int): reference YEARMMDD night
        before (int): number of nights before `night` to include
        after (int): number of nights after `night` to include

    Returns:
        array of YEARMMDD night integers

    Example: get_night_range(20250501, 2, 3) -> [20250429, 20250430, 20250501, 20250502, 20250503, 20250504],
    i.e. two nights before, the night requested, and 3 nights after.
    """
    reference = night2dateobj(night)

    #- do the arithmetic on date objects so month/year boundaries are right
    start = dateobj2night(reference - datetime.timedelta(days=before))
    stop = dateobj2night(reference + datetime.timedelta(days=after))

    #- expand the inclusive range start..stop into the valid nights
    return parse_nights('{}:{}'.format(start, stop), include_end=True)
[docs]
def ordered_unique(ar, return_index=False):
    """Find the unique elements of an array in the order they first appear

    Like numpy.unique, but preserves original order instead of sorting

    Args:
        ar: array-like data to find unique elements

    Options:
        return_index: if True also return indices in ar where items first appear

    Returns:
        array of unique values in order of first appearance, or the tuple
        (unique values, indices) if return_index is True
    """
    ar = np.asarray(ar)

    #- np.unique gives the index of the first occurrence of each value;
    #- putting those indices in increasing order recovers appearance order
    _, first_seen = np.unique(ar, return_index=True)
    indices = np.sort(first_seen)
    unique = ar[indices]

    return (unique, indices) if return_index else unique
#- Not yet used, but a snippet of code that might be useful
#- e.g. for mapping TARGETID to the rows in which they appear
[docs]
def itemindices(a):
    """
    Return dict[key] -> list of indices i where a[i] == key

    Args:
        a : array-like of hashable values

    Return dict[key] -> list of indices i where a[i] == key

    The dict keys are inserted in the order that they first appear in a,
    and the value lists of indices are sorted

    e.g. itemindices([10,30,20,30]) -> {10: [0], 30: [1, 3], 20: [2]}
    """
    #- there is probably a more efficient way of doing this, but this code
    #- can map 100k targetids in <50ms which is sufficient
    idmap = {}
    for index, item in enumerate(a):
        #- setdefault creates the list the first time an item is seen
        idmap.setdefault(item, []).append(index)

    return idmap
[docs]
def argmatch(a, b):
    """
    Returns indices ii such that a[ii] == b

    Args:
        a: array-like
        b: array-like

    Returns indices ii such that a[ii] == b

    Raises:
        ValueError: if `b` contains values that are not in `a`
        RuntimeError: if the match fails the final consistency check for
            any other reason (this should not occur)

    Both `a` and `b` are allowed to have repeats, and `a` values can be a
    superset of `b`, but `b` cannot contain values that are not in `a`
    because then no indices `ii` could result in `a[ii] == b`.

    Related: desitarget.geomask.match_to which is similar, but doesn't allow
    duplicates in `b`.
    """
    a = np.asarray(a)
    b = np.asarray(b)

    def _raise_if_b_not_in_a():
        #- expensive check, only called after something already went wrong
        bad_b = np.isin(b, a, invert=True)
        if np.any(bad_b):
            raise ValueError(f'b contains values not in a; impossible to match {set(b[bad_b])} to {a=}')

    ii = np.argsort(a)
    jj = np.argsort(b)

    #- positions of sorted b within sorted a
    kk = np.searchsorted(a[ii], b[jj])

    try:
        #- argsort(jj) undoes the sorting of b; ii maps back to unsorted a
        match_indices = ii[kk[np.argsort(jj)]]
    except IndexError:
        #- if b has elements larger than everything in a, that can fail
        _raise_if_b_not_in_a()
        #- this should not occur
        raise RuntimeError(f'argmatch failure for unknown reason {a=}, {b=}')

    if a.dtype.names is None:
        if np.any(a[match_indices] != b):
            #- values of b missing from a but not beyond its maximum still
            #- get an in-range index from searchsorted; report those
            #- the same way as the out-of-range case
            _raise_if_b_not_in_a()
            #- this should not occur
            raise RuntimeError(f'argmatch failure for unknown reason {a=} {match_indices=} {a[match_indices]=} != {b=}')
    else:
        #- compare column by column, since comparing the whole table will fail on harmless
        #- dtype mismatches e.g. int32 vs. int64 even if all values are the same
        for col in a.dtype.names:
            if np.any(a[col][match_indices] != b[col]):
                raise RuntimeError(f'argmatch failure for unknown reason {a[col]=} {match_indices=} {a[col][match_indices]=} != {b[col]=}')

    return match_indices
[docs]
def wrap_long_logs(text, width=120):
    """
    Wraps long log messages to a specified width.

    The text is only re-wrapped and returned; nothing is logged here.

    Args:
        text: The text message to wrap.
        width: The maximum width of each line in the log message.

    Returns:
        str: The wrapped log message, with lines joined by newlines.
    """
    wrapped_lines = textwrap.wrap(text, width=width)
    return "\n".join(wrapped_lines)
[docs]
def convert_to_pandas(data, columns):
    """
    Convert table-like data to pandas DataFrame with specified columns.

    Args:
        data: astropy Table, numpy structured array, or pandas DataFrame.
        columns: List of column names to select from the input data.

    Returns:
        A pandas DataFrame containing only the specified columns.

    Raises:
        ValueError: if data is not one of the supported types.
    """
    from astropy.table import Table
    import pandas as pd

    if isinstance(data, Table):
        return data[columns].to_pandas()

    if isinstance(data, pd.DataFrame):
        return data[columns]

    if isinstance(data, np.ndarray):
        subset = data[columns]
        # ensure native byte order for pandas
        native_dtype = subset.dtype.newbyteorder('=')
        return pd.DataFrame(subset.astype(native_dtype))

    raise ValueError(f'Unexpected {type(data)=}, should be astropy Table, numpy structured array, or pandas DataFrame')