"""
desispec.util
=============
Utility functions for desispec.
"""
from __future__ import absolute_import, division, print_function
import argparse
import inspect
import os
import sys
import errno
import time
import datetime
import collections
import numbers
import datetime
import numpy as np
import subprocess as sp
from desiutil.log import get_logger, INFO
[docs]def runcmd(cmd, args=None, expandargs=False, inputs=[], outputs=[], comm=None, clobber=False):
"""
Runs a command (function or script), checking for inputs and outputs
Args:
cmd : function object or command string to run with subprocess.call()
Options:
args : list of args to pass to the function or script
expandargs: call function with ``cmd(*args)`` instead of ``cmd(args)``
inputs : list of filename inputs that must exist before running
outputs : list of output filenames that should be created
clobber : if True, run even if outputs already exist
comm : MPI communicator to pass to cmd(..., comm=comm)
Returns:
(result, success)
Notes:
* If any inputs are missing, don't run cmd and return (None, False).
* If outputs exist and have timestamps after all inputs,
don't run cmd and return (None, True).
* If spawned as a script, return (returncode, (returncode==0)).
* If function raises an exception, return (exception, False).
* If function returns result but outputs are missing, return (result, False).
* If function returns result and all outputs are present, return (result, True).
"""
log = get_logger()
if comm is None:
rank = 0
size = 1
else:
from mpi4py import MPI
size = comm.Get_size()
rank = comm.Get_rank()
if rank == 0:
log.info('runcmd parallel with {} ranks'.format(size))
#- construct log string of what will run
cmd_callable = isinstance(cmd, collections.abc.Callable)
if args is None:
args = tuple()
elif cmd_callable and not expandargs:
args = (args,)
if cmd_callable:
funcname = cmd.__module__ + '.' + cmd.__name__
if expandargs:
cmdstr = f'{funcname}{tuple(args)}'
else:
argstr = ', '.join([str(tmp) for tmp in args])
cmdstr = f'{funcname}({argstr})'
else:
cmdstr = cmd + ' ' + ' '.join(args)
#- Check that inputs exist, and timestamp of latest input file
missing_inputs = False
input_time = 0
if rank == 0:
for x in inputs:
if not os.path.exists(x):
log.error(f"missing input {x}")
missing_inputs = True
else:
input_time = max(input_time, os.stat(x).st_mtime)
if comm is not None:
missing_inputs = comm.bcast(missing_inputs, root=0)
input_time = comm.bcast(input_time, root=0)
if missing_inputs:
if rank == 0:
log.critical(f"FAILED missing required inputs: {cmdstr}")
return None, False #- results=None, success=False
#- Check if outputs already exist and that their timestamp is after
#- the last input timestamp
already_done = (not clobber) and (len(outputs) > 0)
if rank == 0 and not clobber:
for x in outputs:
if not os.path.exists(x):
already_done = False
break
if len(inputs)>0 and os.stat(x).st_mtime < input_time:
already_done = False
break
if comm is not None:
already_done = comm.bcast(already_done, root=0)
if already_done:
if rank == 0:
log.info("SKIPPING: {}".format(cmdstr, rank))
return None, True #- results=None, success=True
#- Green light to go; print input/output info
#- Use log.level to decide verbosity, but avoid long prefixes
if rank == 0:
log.info(time.asctime())
log.info("RUNNING: {}".format(cmdstr))
if log.level <= INFO:
if len(inputs) > 0:
print(" Inputs")
for x in inputs:
print(" ", x)
if len(outputs) > 0:
print(" Outputs")
for x in outputs:
print(" ", x)
#- run command
success = True
result = None
try:
if cmd_callable:
if comm is None:
result = cmd(*args)
else:
result = cmd(*args, comm=comm)
else:
result = sp.call(cmdstr, shell=True)
success = (result == 0)
except (BaseException, Exception) as e:
frame,filename,line_number,function_name,lines,index = inspect.stack()[1]
log.critical(f'FAILED rank {rank} exception while running {cmdstr} called from line {line_number} in {filename}')
result = e
success = False
if rank == 0:
import traceback
lines = traceback.format_exception(*sys.exc_info())
for line in lines:
line = line.strip()
log.error(f'{line}')
#- success only if all succeed
if comm is not None:
success = np.all(comm.gather(success, root=0))
success = comm.bcast(success, root=0)
if not success:
if rank == 0:
log.critical(f"FAILED {cmdstr}")
return result, False
#- Check for outputs
outputs_present = True
if rank == 0:
for x in outputs:
if not os.path.exists(x):
log.error("missing output {}".format(rank,x))
outputs_present = False
if comm is not None:
outputs_present = comm.bcast(outputs_present, root=0)
if outputs_present:
if rank == 0:
log.info("SUCCESS: {}".format(cmdstr))
return result, True
else:
log.critical("FAILED missing outputs {}".format(cmdstr))
return result, False
#- Backstop: we shouldn't have gotten here (should have returned)
log.error(f'should not have gotten here')
return None, False
[docs]def mpi_count_failures(num_cmd, num_err, comm=None):
"""
Sum num_cmd and num_err across MPI ranks
Args:
num_cmd (int): number of commands run
num_err (int): number of failures
Options:
comm: mpi4py communicator
Returns:
sum(num_cmd), sum(num_err) summed across all MPI ranks
If ``comm`` is None, returns input num_cmd, num_err
"""
if comm is None:
return num_cmd, num_err
rank = comm.rank
size = comm.size
if num_cmd is None:
num_cmd = 0
if num_err is None:
num_err = 0
num_cmd_all = np.sum(comm.gather(num_cmd, root=0))
num_err_all = np.sum(comm.gather(num_err, root=0))
num_cmd_all = comm.bcast(num_cmd_all, root=0)
num_err_all = comm.bcast(num_err_all, root=0)
return num_cmd_all, num_err_all
[docs]def sprun(com, capture=False, input=None):
"""Run a command with subprocess and handle errors.
This runs a command and returns the lines of STDOUT as a list.
Any contents of STDERR are logged. If an OSError is raised by
the child process, that is also logged. If another exception is
raised by the child process, the traceback from the child process
is printed.
Args:
com (list): the command to run.
capture (bool): if True, return the stdout contents.
input (str): the string data (can include embedded newlines) to write
to the STDIN of the child process.
Returns:
tuple(int, (list)): the return code and optionally the lines of STDOUT
from the child process.
"""
import traceback
log = get_logger()
stdin = None
if input is not None:
stdin = sp.PIPE
out = None
err = None
ret = -1
try:
with sp.Popen(com, stdin=stdin, stdout=sp.PIPE, stderr=sp.PIPE, universal_newlines=True) as p:
if input is None:
out, err = p.communicate()
else:
out, err = p.communicate(input=input)
for line in err.splitlines():
log.info("STDERR: {}".format(line))
ret = p.returncode
except OSError as e:
log.error("OSError: {}".format(e.errno))
log.error("OSError: {}".format(e.strerror))
log.error("OSError: {}".format(e.filename))
except:
exc_type, exc_value, exc_traceback = sys.exc_info()
lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
for line in lines:
log.error("exception: {}".format(line))
if capture:
return ret, out.splitlines()
else:
for line in out.splitlines():
print(line)
return ret
[docs]def pid_exists( pid ):
"""Check whether pid exists in the current process table.
**UNIX only.** Should work the same as psutil.pid_exists().
Args:
pid (int): A process ID.
Returns:
pid_exists (bool): ``True`` if the process exists in the current process table.
"""
if pid < 0:
return False
if pid == 0:
# According to "man 2 kill" PID 0 refers to every process
# in the process group of the calling process.
# On certain systems 0 is a valid PID but we have no way
# to know that in a portable fashion.
raise ValueError('invalid PID 0')
try:
os.kill(pid, 0)
except OSError as err:
if err.errno == errno.ESRCH:
# ESRCH == No such process
return False
elif err.errno == errno.EPERM:
# EPERM clearly means there's a process to deny access to
return True
else:
# According to "man 2 kill" possible error values are
# (EINVAL, EPERM, ESRCH)
raise
else:
return True
[docs]def option_list(opts):
"""Convert key, value pairs into command-line options.
Parameters
----------
opts : dict-like
Convert a dictionary into command-line options.
Returns
-------
:class:`list`
A list of command-line options.
"""
optlist = []
for key, val in opts.items():
keystr = "--{}".format(key)
if isinstance(val, bool):
if val:
optlist.append(keystr)
else:
optlist.append(keystr)
if isinstance(val, float):
optlist.append("{:.14e}".format(val))
elif isinstance(val, (list, tuple)):
optlist.extend(val)
else:
optlist.append("{}".format(val))
return optlist
[docs]def mask32(mask):
'''
Return an input mask as unsigned 32-bit
Raises ValueError if 64-bit input can't be cast to 32-bit without losing
info (i.e. if it contains values > 2**32-1)
'''
if mask.dtype in (
np.dtype('i4'), np.dtype('u4'),
np.dtype('>i4'), np.dtype('>u4'),
np.dtype('<i4'), np.dtype('<u4'),
):
if mask.dtype.isnative:
return mask.view('u4')
else:
return mask.astype('u4')
elif mask.dtype in (
np.dtype('i8'), np.dtype('u8'),
np.dtype('>i8'), np.dtype('>u8'),
np.dtype('<i8'), np.dtype('<u8'),
):
if mask.dtype.isnative:
mask64 = mask.view('u8')
else:
mask64 = mask.astype('i8')
if np.any(mask64 > 2**32-1):
raise ValueError("mask with values above 2**32-1 can't be cast to 32-bit")
return np.asarray(mask, dtype='u4')
elif mask.dtype in (
np.dtype('bool'), np.dtype('bool8'),
np.dtype('i2'), np.dtype('u2'),
np.dtype('>i2'), np.dtype('>u2'),
np.dtype('<i2'), np.dtype('<u2'),
np.dtype('i1'), np.dtype('u1'),
np.dtype('>i1'), np.dtype('>u1'),
np.dtype('<i1'), np.dtype('<u1'),
):
return np.asarray(mask, dtype='u4')
else:
raise ValueError("Can't cast dtype {} to unsigned 32-bit".format(mask.dtype))
[docs]def night2ymd(night):
"""
parse night YEARMMDD string into tuple of integers (year, month, day)
"""
assert isinstance(night, str), 'night is not a string'
assert len(night) == 8, f'invalid YEARMMDD night string {night=}'
year = int(night[0:4])
month = int(night[4:6])
day = int(night[6:8])
if month < 1 or 12 < month:
raise ValueError('MM month should be 1-12, not {}'.format(month))
if day < 1 or 31 < day:
raise ValueError('DD day should be 1-31, not {}'.format(day))
return (year, month, day)
[docs]def night2dateobj(night):
"""
parse night YEARMMDD string into a datetime.date object
"""
year, mm, dd = night2ymd(night)
return datetime.date(year=year, month=mm, day=dd)
[docs]def difference_nights(firstnight, secondnight):
"""
parse two YEARMMDD nights (ints or strings) and determine the number of
days between them, returning secondnight-firstnight
"""
dt1 = night2dateobj(str(firstnight))
dt2 = night2dateobj(str(secondnight))
difference = dt2 - dt1
return difference.days
[docs]def ymd2night(year, month, day):
"""
convert year, month, day integers into cannonical YEARMMDD night string
"""
return "{:04d}{:02d}{:02d}".format(year, month, day)
[docs]def mjd2night(mjd):
"""
Convert MJD to YEARMMDD int night of KPNO sunset
"""
from astropy.time import Time
night = int(Time(mjd - 7/24. - 12/24., format='mjd').strftime('%Y%m%d'))
return night
[docs]def dateobs2night(dateobs):
"""
Convert DATE-OBS ISO8601 UTC string to YEARMMDD int night of KPNO sunset
"""
# use astropy to flexibily handle multiple valid ISO8601 variants
from astropy.time import Time
try:
mjd = Time(dateobs).mjd
except ValueError:
#- only use optional dependency dateutil if needed;
#- it can handle some ISO8601 timezone variants that astropy can't
from dateutil.parser import isoparser
mjd = Time(isoparser().isoparse(dateobs))
return mjd2night(mjd)
[docs]def parse_keyval(keyval):
"""
Parse "key=val" string -> (key,val) tuple with int/float/str/bool val
Args:
keyval (str): "key=value" string
Returns (key, value) tuple where value has been promoted from string
into int/float/bool if possible.
value="True" or "False" becomes boolean True/False, but all other forms
like "T"/"F" or "true"/"false" remain strings.
0 and 1 become ints, not bool.
"""
key, value_string = keyval.split('=', maxsplit=1)
try:
value = int(value_string)
except ValueError:
try:
value = float(value_string)
except ValueError:
if value_string.strip() == 'True':
value = True
elif value_string.strip() == 'False':
value = False
else:
value = value_string
return (key, value)
[docs]def combine_ivar(ivar1, ivar2):
"""
Returns the combined inverse variance of two inputs, making sure not to
divide by 0 in the process.
ivar1 and ivar2 may be scalar or ndarray but must have the same dimensions
"""
iv1 = np.atleast_1d(ivar1) #- handle list, tuple, ndarray, and scalar input
iv2 = np.atleast_1d(ivar2)
assert np.all(iv1 >= 0), 'ivar1 has negative elements'
assert np.all(iv2 >= 0), 'ivar2 has negative elements'
assert iv1.shape == iv2.shape, 'shape mismatch {} vs. {}'.format(iv1.shape, iv2.shape)
ii = (iv1 > 0) & (iv2 > 0)
ivar = np.zeros(iv1.shape)
ivar[ii] = 1.0 / (1.0/iv1[ii] + 1.0/iv2[ii])
#- Convert back to python float if input was scalar
if isinstance(ivar1, (float, numbers.Integral)):
return float(ivar)
#- If input was 0-dim numpy array, convert back to 0-di
elif ivar1.ndim == 0:
return np.asarray(ivar[0])
else:
return ivar
_matplotlib_backend = None
[docs]def set_backend(backend='agg'):
"""
Set matplotlib to use a batch-friendly backend
This function is safe to call multiple times without tripping on a
previously set backend (which remains set)
"""
global _matplotlib_backend
if _matplotlib_backend is None:
_matplotlib_backend = backend
import matplotlib
matplotlib.use(_matplotlib_backend)
return
[docs]def healpix_degrade_fixed(nside, pixel):
"""
Degrade a NEST ordered healpix pixel with a fixed ratio.
This degrades the pixel to a lower nside value that is
fixed to half the healpix "factor".
Args:
nside (int): a valid NSIDE value.
pixel (int): the NESTED pixel index.
Returns (tuple):
a tuple of ints, where the first value is the new
NSIDE and the second value is the degraded pixel
index.
"""
factor = int(np.log2(nside))
subfactor = factor // 2
subnside = 2**subfactor
subpixel = pixel >> (factor - subfactor)
return (subnside, subpixel)
[docs]def parse_int_args(arg_string, include_end=False) :
"""
Short func that parses a string containing a comma separated list of
integers, which can include ":" or ".." or "-" labeled ranges
Args:
arg_string (str) : list of integers or integer ranges
Options:
include_end (bool): if True, include end-value in ranges
Returns (array 1-D):
1D numpy array listing all of the integers given in the list,
including enumerations of ranges given.
Note: this follows python-style ranges, i,e, 1:5 or 1..5 returns 1,2,3,4
unless `include_end` is True, which then returns 1,2,3,4,5
"""
if arg_string is None :
return np.array([], dtype=int)
else:
arg_string = str(arg_string)
if len(arg_string.strip(' \t'))==0:
return np.array([])
if include_end:
pad = 1
else:
pad = 0
fibers=[]
log = get_logger()
for sub in arg_string.split(',') :
sub = sub.replace(' ','')
if sub.isdigit() :
fibers.append(int(sub))
continue
match = False
for symbol in [':','..','-']:
if not match and symbol in sub:
tmp = sub.split(symbol)
if (len(tmp) == 2) and tmp[0].isdigit() and tmp[1].isdigit() :
match = True
for f in range(int(tmp[0]),int(tmp[1])+pad) :
fibers.append(f)
if not match:
msg = "parsing error. Didn't understand {}".format(sub)
log.error(msg)
raise ValueError(msg)
return np.array(fibers)
[docs]def parse_fibers(fiber_string, include_end=False) :
"""
Short func that parses a string containing a comma separated list of
integers, which can include ":" or ".." or "-" labeled ranges
Args:
fiber_string (str) : list of integers or integer ranges
Options:
include_end (bool): if True, include end-value in ranges
Returns (array 1-D):
1D numpy array listing all of the integers given in the list,
including enumerations of ranges given.
Note: this follows python-style ranges, i,e, 1:5 or 1..5 returns 1, 2, 3, 4
unless `include_end` is True, which then returns 1,2,3,4,5
"""
return parse_int_args(fiber_string, include_end)
[docs]def ordered_unique(ar, return_index=False):
"""Find the unique elements of an array in the order they first appear
Like numpy.unique, but preserves original order instead of sorting
Args:
ar: array-like data to find unique elements
Options:
return_index: if True also return indices in ar where items first appear
"""
ar = np.asarray(ar)
unique, sortedidx = np.unique(ar, return_index=True)
ii = np.argsort(sortedidx)
indices = sortedidx[ii]
unique = ar[indices]
if return_index:
return unique, indices
else:
return unique
#- Not yet used, but a snippet of code that might be useful
#- e.g. for mapping TARGETID to the rows in which they appear
[docs]def itemindices(a):
"""
Return dict[key] -> list of indices i where a[i] == key
Args:
a : array-like of hashable values
Return dict[key] -> list of indices i where a[i] == key
The dict keys are inserted in the order that they first appear in a,
and the value lists of indices are sorted
e.g. itemindices([10,30,20,30]) -> {10: [0], 30: [1, 3], 20: [2]}
"""
#- there is probably a more efficient way of doing this, but this code
#- can map 100k targetids in <50ms which is sufficient
idmap = dict()
for i, x in enumerate(a):
if x not in idmap:
idmap[x] = [i,]
else:
idmap[x].append(i)
return idmap
[docs]def argmatch(a, b):
"""
Returns indices ii such that a[ii] == b
Args:
a: array-like
b: array-like
Returns indices ii such that a[ii] == b
Both `a` and `b` are allowed to have repeats, and `a` values can be a
superset of `b`, but `b` cannot contain values that are not in `a`
because then no indices `ii` could result in `a[ii] == b`.
Related: desitarget.geomask.match_to which is similar, but doesn't allow
duplicates in `b`.
"""
a = np.asarray(a)
b = np.asarray(b)
ii = np.argsort(a)
jj = np.argsort(b)
kk = np.searchsorted(a[ii], b[jj])
try:
match_indices = ii[kk[np.argsort(jj)]]
except IndexError:
#- if b has elements not in a, that can fail;
#- only do expensive check if needed
bad_b = np.isin(b, a, invert=True)
if np.any(bad_b):
raise ValueError(f'b contains values not in a; impossible to match {set(b[bad_b])} to {a=}')
else:
#- this should not occur
raise RuntimeError(f'argmatch failure for unknown reason {a=}, {b=}')
if a.dtype.names is None:
if np.any(a[match_indices] != b):
#- this should not occur
raise RuntimeError(f'argmatch failure for unknown reason {a=} {match_indices=} {a[match_indices]=} != {b=}')
else:
#- compare column by column, since comparing the whole table will fail on harmless
#- dtype mismatches e.g. int32 vs. int64 even if all values are the same
for col in a.dtype.names:
if np.any(a[col][match_indices] != b[col]):
raise RuntimeError(f'argmatch failure for unknown reason {a[col]=} {match_indices=} {a[col][match_indices]=} != {b[col]=}')
return match_indices