"""
desispec.scripts.specex
=======================
Run PSF estimation.
"""
from __future__ import print_function, absolute_import, division
import sys
import os
import re
import time
import argparse
import numpy as np
import ctypes as ct
from ctypes.util import find_library
from astropy.io import fits
from desiutil.log import get_logger
from desispec.io.util import get_tempfilename
def parse(options=None):
    """
    Parse command-line options for the specex PSF fit.

    Args:
        options: optional list of argument strings; None means use sys.argv.

    Returns:
        argparse.Namespace with the parsed options.
    """
    p = argparse.ArgumentParser(description="Estimate the PSF for "
                                "one frame with specex")
    p.add_argument("--input-image", type=str, required=True,
                   help="input image")
    p.add_argument("--input-psf", type=str, required=False,
                   help="input psf file")
    p.add_argument("-o", "--output-psf", type=str, required=True,
                   help="output psf file")
    p.add_argument("--bundlesize", type=int, required=False, default=25,
                   help="number of spectra per bundle")
    p.add_argument("-s", "--specmin", type=int, required=False, default=0,
                   help="first spectrum to extract")
    p.add_argument("-n", "--nspec", type=int, required=False, default=500,
                   help="number of spectra to extract")
    p.add_argument("--extra", type=str, required=False, default=None,
                   help="quoted string of arbitrary options to pass to "
                   "specex_desi_psf_fit")
    p.add_argument("--debug", action='store_true',
                   help="debug mode")
    p.add_argument("--broken-fibers", type=str, required=False, default=None,
                   help="comma separated list of broken fibers")
    p.add_argument("--disable-merge", action='store_true',
                   help="disable merging fiber bundles")
    return p.parse_args(options)
def main(args=None, comm=None):
    """
    Fit the PSF of one CCD image with specex, one fiber bundle at a time.

    Args:
        args: argparse.Namespace from parse(), or a list of option strings
            (anything that is not a Namespace is passed through parse()).
        comm: optional MPI communicator; bundles are distributed over its
            ranks and rank 0 merges the per-bundle PSF files at the end.

    Returns:
        failcount (int) when --disable-merge is set; otherwise None.

    Raises:
        RuntimeError: if any bundle fit failed, if the output filename does
            not end in .fits, or if the final merge failed.
    """
    if not isinstance(args, argparse.Namespace):
        args = parse(args)
    log = get_logger()
    #- only import when running, to avoid requiring specex install for import
    from specex.specex import run_specex
    imgfile = args.input_image
    outfile = args.output_psf
    nproc = 1
    rank = 0
    if comm is not None:
        nproc = comm.size
        rank = comm.rank
    hdr=None
    if rank == 0 :
        # only rank 0 reads the image header; it is broadcast below
        hdr = fits.getheader(imgfile)
    if comm is not None:
        hdr = comm.bcast(hdr, root=0)
    #- Locate line list in $SPECEXDATA or specex/data
    if 'SPECEXDATA' in os.environ:
        specexdata = os.environ['SPECEXDATA']
    else:
        from importlib import resources
        specexdata = resources.files('specex').joinpath('data')
    lamp_lines_file = os.path.join(specexdata,'specex_linelist_desi.txt')
    if args.input_psf is not None:
        inpsffile = args.input_psf
    else:
        # fall back to the calibration PSF matching this image's header
        from desispec.calibfinder import findcalibfile
        inpsffile = findcalibfile([hdr,], 'PSF')
    optarray = []
    if args.extra is not None:
        optarray = args.extra.split()
    specmin = int(args.specmin)
    nspec = int(args.nspec)
    bundlesize = int(args.bundlesize)
    specmax = specmin + nspec
    # Now we divide our spectra into bundles
    checkbundles = set()
    checkbundles.update(np.floor_divide(np.arange(specmin, specmax),
                        bundlesize*np.ones(nspec)).astype(int))
    bundles = sorted(checkbundles)
    nbundle = len(bundles)
    # per-bundle first spectrum (bspecmin) and spectrum count (bnspec),
    # clipped to the requested [specmin, specmax) range
    bspecmin = {}
    bnspec = {}
    for b in bundles:
        if specmin > b * bundlesize:
            bspecmin[b] = specmin
        else:
            bspecmin[b] = b * bundlesize
        if (b+1) * bundlesize > specmax:
            bnspec[b] = specmax - bspecmin[b]
        else:
            bnspec[b] = (b+1) * bundlesize - bspecmin[b]
    # Now we assign bundles to processes
    mynbundle = int(nbundle / nproc)
    leftover = nbundle % nproc
    if rank < leftover:
        # the first 'leftover' ranks each take one extra bundle
        mynbundle += 1
        myfirstbundle = bundles[0] + rank * mynbundle
    else:
        myfirstbundle = bundles[0] + ((mynbundle + 1) * leftover) + \
            (mynbundle * (rank - leftover))
    if rank == 0:
        # Print parameters
        log.info("specex: using {} processes".format(nproc))
        log.info("specex: input image = {}".format(imgfile))
        log.info("specex: input PSF = {}".format(inpsffile))
        log.info("specex: output = {}".format(outfile))
        log.info("specex: bundlesize = {}".format(bundlesize))
        log.info("specex: specmin = {}".format(specmin))
        log.info("specex: specmax = {}".format(specmax))
        if args.broken_fibers :
            log.info("specex: broken fibers = {}".format(args.broken_fibers))
    # get the root output file
    outpat = re.compile(r'(.*)\.fits')
    outmat = outpat.match(outfile)
    if outmat is None:
        raise RuntimeError("specex output file should have .fits extension")
    outroot = outmat.group(1)
    outdir = os.path.dirname(outroot)
    if rank == 0:
        if outdir != "" :
            if not os.path.isdir(outdir):
                # NOTE(review): no barrier after this rank-0 mkdir, yet other
                # ranks write their per-bundle files into outdir -- confirm
                os.makedirs(outdir)
    cam = hdr["camera"].lower().strip()
    band = cam[0]
    failcount = 0
    # fit each of this rank's bundles with one desi_psf_fit call
    for b in range(myfirstbundle, myfirstbundle+mynbundle):
        outbundle = "{}_{:02d}".format(outroot, b)
        outbundlefits = "{}.fits".format(outbundle)
        com = ['desi_psf_fit']
        com.extend(['-a', imgfile])
        com.extend(['--in-psf', inpsffile])
        com.extend(['--out-psf', outbundlefits])
        com.extend(['--lamp-lines', lamp_lines_file])
        com.extend(['--first-bundle', "{}".format(b)])
        com.extend(['--last-bundle', "{}".format(b)])
        com.extend(['--first-fiber', "{}".format(bspecmin[b])])
        com.extend(['--last-fiber', "{}".format(bspecmin[b]+bnspec[b]-1)])
        # z cameras get a higher wavelength Legendre degree and a
        # continuum fit
        if band == "z" :
            com.extend(['--legendre-deg-wave', "{}".format(3)])
            com.extend(['--fit-continuum'])
        else :
            com.extend(['--legendre-deg-wave', "{}".format(1)])
        if args.broken_fibers :
            com.extend(['--broken-fibers', "{}".format(args.broken_fibers)])
        if args.debug :
            com.extend(['--debug'])
        com.extend(optarray)
        log.info("proc {} calling {}".format(rank, " ".join(com)))
        retval = run_specex(com)
        if retval != 0:
            comstr = " ".join(com)
            log.error("desi_psf_fit on process {} failed with return "
                      "value {} running {}".format(rank, retval, comstr))
            failcount += 1
        else:
            log.info(f"proc {rank} succeeded generating {outbundlefits}")
    # NOTE(review): this early return happens before the MPI allreduce, so
    # each rank returns its own (unreduced) failcount; it also appears to
    # make the later "if args.disable_merge" branch unreachable -- confirm
    if args.disable_merge:
        return failcount
    if comm is not None:
        from mpi4py import MPI
        failcount = comm.allreduce(failcount, op=MPI.SUM)
    if failcount > 0:
        # all processes throw
        raise RuntimeError("some bundles failed desi_psf_fit")
    if rank == 0:
        outfits = "{}.fits".format(outroot)
        inputs = [ "{}_{:02d}.fits".format(outroot, x) for x in bundles ]
        if args.disable_merge :
            log.info("don't merge")
        else :
            #- Empirically it appears that files written by one rank sometimes
            #- aren't fully buffer-flushed and closed before getting here,
            #- despite the MPI allreduce barrier.  Pause to let I/O catch up.
            log.info('5 sec pause before merging')
            sys.stdout.flush()
            time.sleep(5.)
            try:
                merge_psf(inputs,outfits)
            except Exception as e:
                log.error(e)
                log.error("merging failed for {}".format(outfits))
                failcount += 1
            log.info('done merging')
            if failcount == 0:
                # only remove the per-bundle files if the merge was good
                for f in inputs :
                    if os.path.isfile(f):
                        os.remove(f)
    if comm is not None:
        failcount = comm.bcast(failcount, root=0)
    if failcount > 0:
        # all processes throw
        raise RuntimeError("merging of per-bundle files failed")
    return
def run(comm,cmds,cameras):
    """
    Run PSF fits with specex on a set of ccd images in parallel using the run method
    of the desispec.workflow.schedule.Schedule (Schedule) class.

    Args:
        comm: MPI communicator containing all processes available for work and
            scheduling (usually MPI_COMM_WORLD); at least 21 processes should
            be available, one for scheduling and (group_size=) 20 to fit all
            bundles for a given ccd image. Otherwise there is no constraint on
            the number of ranks available, but (comm.Get_size()-1)%group_size
            will be unused, since every job is assigned exactly group_size=20
            ranks. The variable group_size is set at the number of bundles on
            a ccd, and there is currently no support for any other number, due
            to the way merging of bundles is currently done.
        cmds: dictionary keyed by a camera string (e.g. 'b0', 'r1', ...) with
            values being the 'desi_compute_psf ...' string that one would run
            on the command line.
        cameras: list of camera strings identifying the entries in cmds to be run
            as jobs in parallel jobs, one entry per ccd image to be fit.
            Processes assigned to cameras not present as keys in cmds will
            write a message to the log instead of running a PSF fit.

    The function first defines the procedure to call specex for a given ccd image
    with the "fitframe" inline function, passes the fitframe function
    to the Schedule initialization method, and then calls the run method of the
    Schedule class to call fitframe len(cameras) times, each with group_size = 20
    processes.
    """
    from desispec.workflow.schedule import Schedule
    from desiutil.log import get_logger, DEBUG, INFO
    log = get_logger()
    # fixed at the number of bundles per ccd (see docstring)
    group_size = 20
    # reverse to do b cameras last since they take least time
    cameras = sorted(cameras, reverse=True)
    def fitframe(groupcomm,worldcomm,job):
        '''
        Run PSF fit with specex on all bundles for a single ccd image

        Args:
            groupcomm: job-specific MPI communicator
            worldcomm: world MPI communicator
            job: job index corresponding to position in list of cmds entries

        This is an inline function for use by desispec.workflow.schedule.Schedule,
        i.e. via the lines

            sc = Schedule(fitframe,comm=comm,njobs=len(cameras),group_size=group_size)
            sc.run()

        immediately after this inline function definition.

        This function uses the external variables group_size, cmds, and cameras. In
        particular, the list of camera strings (cameras) provides the mapping of the
        job index (job) to the commands (cmds) that specify the arguments
        to the specex.parse method, i.e.

            camera = cameras[job]
            ...
            cmdargs = cmds[camera].split()[1:]
            cmdargs = parse(cmdargs)
            ...

        From the point of view of the Schedule.run method, it is running fitframe
        njobs = len(cameras) times, each time using group_size processes with a new
        value of job in the range 0 to len(cameras)-1.
        '''
        error_count = 0
        grouprank = groupcomm.Get_rank()
        worldrank = worldcomm.Get_rank()
        camera = cameras[job]
        if not camera in cmds:
            log.info(f'nothing to do for camera {camera} on MPI group rank '+
                     f'{grouprank} and world rank {worldrank}')
        else:
            # strip the executable name; the rest are options for parse()
            cmdargs = cmds[camera].split()[1:]
            cmdargs = parse(cmdargs)
            if grouprank == 0:
                t0 = time.time()
                timestamp = time.asctime()
                log.info(f'MPI ranks {worldrank}-{worldrank+group_size-1}'
                         f' fitting PSF for {camera} in job {job} at {timestamp}')
            try:
                # all ranks of the group participate in the fit
                main(cmdargs, comm=groupcomm)
            except Exception as e:
                if grouprank == 0:
                    log.error(f'FAILED: MPI ranks {worldrank}-{worldrank+group_size-1}'+
                              f' on camera {camera}')
                    log.error('FAILED: {}'.format(cmds[camera]))
                    log.error(e)
                error_count += 1
            if grouprank == 0:
                specex_time = time.time() - t0
                log.info(f'specex fit for {camera} took {specex_time:.1f} seconds')
        return error_count
    sc = Schedule(fitframe,comm=comm,njobs=len(cameras),group_size=group_size)
    return sc.run()
def compatible(head1, head2) :
    """
    Return bool for whether two FITS headers are compatible for merging PSFs

    Args:
        head1, head2: FITS headers (dict-like) to compare.

    Returns:
        True when every checked keyword agrees between the two headers;
        False otherwise (the first mismatch is logged as a warning).
    """
    log = get_logger()
    checked_keys = ("PSFTYPE", "NPIX_X", "NPIX_Y", "HSIZEX", "HSIZEY",
                    "FIBERMIN", "FIBERMAX", "NPARAMS", "LEGDEG",
                    "GHDEGX", "GHDEGY")
    for key in checked_keys:
        if head1[key] == head2[key]:
            continue
        log.warning("different {} : {}, {}".format(key, head1[key], head2[key]))
        return False
    return True
def merge_psf(inputs, output):
    """
    Merge individual per-bundle PSF files into full PSF

    Args:
        inputs: list of input PSF filenames
        output: output filename

    The first input is used as the template; for each subsequent input,
    the traces, PSF coefficients and per-bundle fit-quality keywords of
    its successfully-fit fibers are copied into the template, which is
    then written to `output`.  Assumes all inputs share the same HDU
    layout (PSF/XTRACE/YTRACE) and fiber count.
    """
    log = get_logger()
    npsf = len(inputs)
    log.info("Will merge {} PSFs in {}".format(npsf,output))
    # we will add/change data to the first PSF
    psf_hdulist=fits.open(inputs[0])
    for input_filename in inputs[1:] :
        log.info("merging {} into {}".format(input_filename,inputs[0]))
        other_psf_hdulist=fits.open(input_filename)
        # look at what fibers where actually fit
        i=np.where(other_psf_hdulist["PSF"].data["PARAM"]=="STATUS")[0][0]
        status_of_fibers = \
            other_psf_hdulist["PSF"].data["COEFF"][i][:,0].astype(int)
        # STATUS==0 marks a successfully fit fiber in this bundle file
        selected_fibers = np.where(status_of_fibers==0)[0]
        log.info("fitted fibers in PSF {} = {}".format(input_filename,
                                                       selected_fibers))
        if selected_fibers.size == 0 :
            log.warning("no fiber with status=0 found in {}".format(
                input_filename))
            other_psf_hdulist.close()
            continue
        # copy xtrace and ytrace
        psf_hdulist["XTRACE"].data[selected_fibers] = \
            other_psf_hdulist["XTRACE"].data[selected_fibers]
        psf_hdulist["YTRACE"].data[selected_fibers] = \
            other_psf_hdulist["YTRACE"].data[selected_fibers]
        # copy parameters
        parameters = psf_hdulist["PSF"].data["PARAM"]
        for param in parameters :
            # match rows by PARAM name in case row order differs
            i0=np.where(psf_hdulist["PSF"].data["PARAM"]==param)[0][0]
            i1=np.where(other_psf_hdulist["PSF"].data["PARAM"]==param)[0][0]
            psf_hdulist["PSF"].data["COEFF"][i0][selected_fibers] = \
                other_psf_hdulist["PSF"].data["COEFF"][i1][selected_fibers]
        # copy bundle chi2
        i = np.where(other_psf_hdulist["PSF"].data["PARAM"]=="BUNDLE")[0][0]
        bundles = np.unique(other_psf_hdulist["PSF"].data["COEFF"][i]\
                            [selected_fibers,0].astype(int))
        log.info("fitted bundles in PSF {} = {}".format(input_filename,
                                                        bundles))
        for b in bundles :
            # propagate the per-bundle fit-quality header keywords
            for key in [ "B{:02d}RCHI2".format(b), "B{:02d}NDATA".format(b),
                         "B{:02d}NPAR".format(b) ]:
                psf_hdulist["PSF"].header[key] = \
                    other_psf_hdulist["PSF"].header[key]
        # close file
        other_psf_hdulist.close()
    # write atomically: temp file then rename
    tmpfile = get_tempfilename(output)
    psf_hdulist.writeto(tmpfile, overwrite=True)
    os.rename(tmpfile, output)
    log.info("Wrote PSF {}".format(output))
    return
def mean_psf(inputs, output):
    """
    Average multiple input PSF files into an output PSF file

    Per fiber bundle, only inputs whose reduced chi2 is below the median
    (over all inputs and bundles) plus one are averaged; a mean rather
    than a median is used deliberately (see comment in the merging loop).
    Inputs with a different wavelength range than the first one have
    their Legendre coefficients refit onto the reference range.

    Args:
        inputs: list of input PSF files
        output: output filename

    Exits the process (sys.exit) if the first input has PSFVER<3 or if
    the inputs are not compatible with each other.
    """
    # imported once here so legval/legfit are in scope both for the PSF
    # coefficient refit and for the trace refit further down (previously
    # they were imported inside a conditional branch only, so the trace
    # refit could hit a NameError)
    from numpy.polynomial.legendre import legval, legfit
    log = get_logger()
    npsf = len(inputs)
    log.info("Will compute the average of {} PSFs".format(npsf))
    refhead = None
    tables = []
    xtrace = []
    ytrace = []
    wavemins = []
    wavemaxs = []
    hdulist = None
    bundle_rchi2 = []
    for infile in inputs :
        log.info("Adding {}".format(infile))
        if not os.path.isfile(infile) :
            log.warning("missing {}".format(infile))
            continue
        psf = fits.open(infile)
        if refhead is None :
            # first valid input defines the reference header and is reused
            # as the output HDUList
            hdulist = psf
            refhead = psf["PSF"].header
            PSFVER = int(refhead["PSFVER"])
            if PSFVER < 3 :
                log.error("ERROR NEED PSFVER>=3")
                sys.exit(1)
        else :
            if not compatible(psf["PSF"].header, refhead) :
                log.error("psfs {} and {} are not compatible".format(inputs[0],
                                                                     infile))
                sys.exit(12)
        tables.append(psf["PSF"].data)
        wavemins.append(psf["PSF"].header["WAVEMIN"])
        wavemaxs.append(psf["PSF"].header["WAVEMAX"])
        if "XTRACE" in psf :
            xtrace.append(psf["XTRACE"].data)
        if "YTRACE" in psf :
            ytrace.append(psf["YTRACE"].data)
        # collect per-bundle reduced chi2 stored as B00RCHI2, B01RCHI2, ...
        rchi2 = []
        b = 0
        while "B{:02d}RCHI2".format(b) in psf["PSF"].header :
            rchi2.append(psf["PSF"].header["B{:02d}RCHI2".format(b)])
            b += 1
        bundle_rchi2.append(np.array(rchi2))
    npsf = len(tables)
    bundle_rchi2 = np.array(bundle_rchi2)
    log.debug("bundle_rchi2= {}".format(str(bundle_rchi2)))
    # bundles with rchi2 above this threshold are vetoed from the average
    median_bundle_rchi2 = np.median(bundle_rchi2)
    rchi2_threshold = median_bundle_rchi2 + 1.
    log.debug("median chi2={} threshold={}".format(median_bundle_rchi2,
                                                   rchi2_threshold))
    WAVEMIN = refhead["WAVEMIN"]
    WAVEMAX = refhead["WAVEMAX"]
    # map bundle id -> fiber indices, from the BUNDLE row of the first table
    fibers_in_bundle = {}
    i = np.where(tables[0]["PARAM"] == "BUNDLE")[0][0]
    bundle_of_fibers = tables[0]["COEFF"][i][:, 0].astype(int)
    bundles = np.unique(bundle_of_fibers)
    for b in bundles :
        fibers_in_bundle[b] = np.where(bundle_of_fibers == b)[0]
    for entry in range(tables[0].size) :
        PARAM = tables[0][entry]["PARAM"]
        log.info("Averaging '{}' coefficients".format(PARAM))
        coeff = [tables[0][entry]["COEFF"]]
        npar = coeff[0][1].size
        for p in range(1, npsf) :
            if wavemins[p] == WAVEMIN and wavemaxs[p] == WAVEMAX :
                coeff.append(tables[p][entry]["COEFF"])
            else :
                # different wavelength range: refit the Legendre coefficients
                # on the reference [WAVEMIN, WAVEMAX] domain
                log.info("need to refit legendre polynomial ...")
                icoeff = tables[p][entry]["COEFF"]
                ocoeff = np.zeros(icoeff.shape)
                # need to reshape legpol
                iu = np.linspace(-1, 1, npar+3)
                iwavemin = wavemins[p]
                iwavemax = wavemaxs[p]
                wave = (iu+1.)/2.*(iwavemax-iwavemin)+iwavemin
                ou = (wave-WAVEMIN)/(WAVEMAX-WAVEMIN)*2.-1.
                for f in range(icoeff.shape[0]) :
                    val = legval(iu, icoeff[f])
                    ocoeff[f] = legfit(ou, val, deg=npar-1)
                coeff.append(ocoeff)
        coeff = np.array(coeff)
        output_rchi2 = np.zeros((bundle_rchi2.shape[1]))
        output_coeff = np.zeros(tables[0][entry]["COEFF"].shape)
        # now merge, using rchi2 as selection score
        for bundle in fibers_in_bundle.keys() :
            ok = np.where(bundle_rchi2[:, bundle] < rchi2_threshold)[0]
            if entry == 0 :
                log.info("for fiber bundle {}, {} valid PSFs".format(bundle,
                                                                     ok.size))
            # We finally resorted to use a mean instead of a median here for two reasons.
            # First, there is already a vetting of PSF bundles with good chi2 above
            # that protects us from bad fits (we only expect outliers because of bad fits because of cosmic rays,
            # not a glitch in hardware). Second, some of the PSF parameters have large correlations,
            # which mean that two pairs of parameter values, like (p_a_i,p_b_i) and (p_a_j,p_b_j) (with a,b param
            # indexes and i,j exposure indices) may give similar PSFs despite large noise in individual parameters
            # but a median could decide to select a pair like (p_a_i,p_b_j) that could lead to a PSF inconsistent
            # with data. Using a mean instead of a median protects us from this situation.
            if ok.size >= 2 :  # use mean
                log.debug("bundle #{} : use mean".format(bundle))
                for f in fibers_in_bundle[bundle] :
                    output_coeff[f] = np.mean(coeff[ok, f], axis=0)
                output_rchi2[bundle] = np.mean(bundle_rchi2[ok, bundle])
            elif ok.size == 1 :  # copy
                log.debug("bundle #{} : use only one psf ".format(bundle))
                for f in fibers_in_bundle[bundle] :
                    output_coeff[f] = coeff[ok[0], f]
                output_rchi2[bundle] = bundle_rchi2[ok[0], bundle]
            else :  # we have a problem here, take the smallest rchi2
                log.debug("bundle #{} : take smallest chi2 ".format(bundle))
                imin = np.argmin(bundle_rchi2[:, bundle])
                for f in fibers_in_bundle[bundle] :
                    output_coeff[f] = coeff[imin, f]
                output_rchi2[bundle] = bundle_rchi2[imin, bundle]
        # now copy this in output table
        hdulist["PSF"].data["COEFF"][entry] = output_coeff
        # change bundle chi2
        for bundle in range(output_rchi2.size) :
            hdulist["PSF"].header["B{:02d}RCHI2".format(bundle)] = \
                output_rchi2[bundle]
        # alter other keys in header
        hdulist["PSF"].header["EXPID"] = 0.  # it's a mix, need to add the expids
    if len(xtrace) > 0 :
        # also average the traces, refitting onto the reference wavelength
        # range where needed
        xtrace = np.array(xtrace)
        ytrace = np.array(ytrace)
        npar = xtrace.shape[2]  # assume all have same npar
        for p in range(xtrace.shape[0]) :
            if wavemins[p] == WAVEMIN and wavemaxs[p] == WAVEMAX :
                continue
            # need to reshape legpol
            iu = np.linspace(-1, 1, npar+3)
            iwavemin = wavemins[p]
            iwavemax = wavemaxs[p]
            wave = (iu+1.)/2.*(iwavemax-iwavemin)+iwavemin
            ou = (wave-WAVEMIN)/(WAVEMAX-WAVEMIN)*2.-1.
            # bug fix: iterate over the fibers of this trace array
            # (xtrace.shape[1]); the previous code reused the leftover
            # 'icoeff' variable from the coefficient loop above, which is
            # undefined when no coefficient refit happened
            for f in range(xtrace.shape[1]) :
                val = legval(iu, xtrace[p][f])
                xtrace[p][f] = legfit(ou, val, deg=npar-1)
                val = legval(iu, ytrace[p][f])
                ytrace[p][f] = legfit(ou, val, deg=npar-1)
        hdulist["xtrace"].data = np.mean(xtrace, axis=0)
        hdulist["ytrace"].data = np.mean(ytrace, axis=0)
    # record the provenance of the average in each HDU header
    for hdu in ["XTRACE", "YTRACE", "PSF"] :
        if hdu in hdulist :
            for infile in inputs :
                hdulist[hdu].header["comment"] = "inc {}".format(infile)
    # save output PSF atomically (temp file then rename)
    tmpfile = get_tempfilename(output)
    hdulist.writeto(tmpfile, overwrite=True)
    os.rename(tmpfile, output)
    log.info("wrote {}".format(output))
    return