Source code for desispec.pipeline.run

#
# See top-level LICENSE.rst file for Copyright information
#
# -*- coding: utf-8 -*-
"""
desispec.pipeline.run
=====================

Tools for running the pipeline.
"""

from __future__ import absolute_import, division, print_function

import os
import sys
import time
import random
import signal

import numpy as np

from desiutil.log import get_logger

from .. import io

from ..parallel import (dist_uniform, dist_discrete, dist_discrete_all,
    stdouterr_redirected)

from .prod import load_prod

from .db import check_tasks

from .scriptgen import parse_job_env

from .plan import compute_worker_tasks, worker_times


#- TimeoutError and timeout handler to prevent runaway tasks
class TimeoutError(Exception):
    pass


def _timeout_handler(signum, frame):
    raise TimeoutError('Timeout at {}'.format(time.asctime()))
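
# A minimal, self-contained sketch (not part of the original module) of the
# SIGALRM pattern used by run_task() below: install the handler, arm the alarm
# for the allowed run time, and always disarm it and restore the previous
# handler afterwards.  The 5-second limit and the sleep() stand-in are purely
# illustrative.
def _example_timeout_guard():
    """Illustrative only: guard a slow operation with a 5 second alarm."""
    old_handler = signal.signal(signal.SIGALRM, _timeout_handler)
    signal.alarm(5)                 # raise TimeoutError after 5 seconds
    try:
        time.sleep(10)              # stand-in for a long-running task
    except TimeoutError:
        print("operation timed out")
    finally:
        signal.alarm(0)             # disarm any pending alarm
        signal.signal(signal.SIGALRM, old_handler)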
def run_task(name, opts, comm=None, logfile=None, db=None):
    """Run a single task.

    Based on the name of the task, call the appropriate run function for that
    task.  Log output to the specified file.  Run using the specified MPI
    communicator and optionally update state to the specified database.

    Note:  This function DOES NOT check the database or filesystem to see if
    the task has been completed or if its dependencies exist.  It assumes that
    some higher-level code has done that if necessary.

    Args:
        name (str): the name of this task.
        opts (dict): options to use for this task.
        comm (mpi4py.MPI.Comm): optional MPI communicator.
        logfile (str): output log file.  If None, do not redirect output to a
            file.
        db (pipeline.db.DB): The optional database to update.

    Returns:
        int: the total number of processes that failed.

    """
    from .tasks.base import task_classes, task_type

    log = get_logger()

    ttype = task_type(name)

    nproc = 1
    rank = 0
    if comm is not None:
        nproc = comm.size
        rank = comm.rank

    if rank == 0:
        if (logfile is not None) and os.path.isfile(logfile):
            os.remove(logfile)
        # Mark task as in progress
        if db is not None:
            task_classes[ttype].state_set(db=db, name=name, state="running")

    failcount = 0

    #- Set timeout alarm to avoid runaway tasks
    old_sighandler = signal.signal(signal.SIGALRM, _timeout_handler)
    expected_run_time = task_classes[ttype].run_time(name, procs=nproc, db=db)

    # Are we running on a slower/faster node than default timing?
    timefactor = float(os.getenv("DESI_PIPE_RUN_TIMEFACTOR", default=1.0))
    expected_run_time *= timefactor

    signal.alarm(int(expected_run_time * 60))
    if rank == 0:
        log.info("Running {} with timeout {:.1f} min".format(
            name, expected_run_time))

    task_start_time = time.time()
    try:
        if logfile is None:
            # No redirection
            if db is None:
                failcount = task_classes[ttype].run(name, opts, comm=comm)
            else:
                failcount = task_classes[ttype].run_and_update(db, name, opts,
                    comm=comm)
        else:
            #- time jitter so that we don't open all log files simultaneously
            time.sleep(2 * random.random())
            with stdouterr_redirected(to=logfile, comm=comm):
                if db is None:
                    failcount = task_classes[ttype].run(name, opts, comm=comm)
                else:
                    failcount = task_classes[ttype].run_and_update(db, name,
                        opts, comm=comm)
    except TimeoutError:
        dt = time.time() - task_start_time
        if rank == 0:
            log.error("Task {} timed out after {:.1f} sec".format(name, dt))
            if db is not None:
                task_classes[ttype].state_set(db, name, "failed")
        failcount = nproc
    finally:
        #- Reset timeout alarm whether we finished cleanly or not
        signal.alarm(0)

    #- Restore previous signal handler
    signal.signal(signal.SIGALRM, old_sighandler)
    if rank == 0:
        log.debug("Finished with task {} sigalarm reset".format(name))

    log.debug("Task {} returning failcount {}".format(name, failcount))
    return failcount
def run_task_simple(name, opts, comm=None):
    """Run a single task with no DB or log redirection.

    This is a wrapper around run_task() for use without a database and with
    no log redirection.  See documentation for that function.

    Args:
        name (str): the name of this task.
        opts (dict): options to use for this task.
        comm (mpi4py.MPI.Comm): optional MPI communicator.

    Returns:
        int: the total number of processes that failed.

    """
    return run_task(name, opts, comm=comm, logfile=None, db=None)
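
# A minimal usage sketch (not part of the original module).  It assumes a
# valid task name produced by the pipeline planning/database code (task names
# encode the task type), and a per-task options dictionary that would normally
# come from the production options.yaml file.
def _example_run_one_task(name, opts=None):
    """Illustrative only: run a single task serially with no database."""
    if opts is None:
        opts = {}   # per-task options, normally read from options.yaml
    # Equivalent to run_task(name, opts, comm=None, logfile=None, db=None);
    # pass a logfile and/or db to run_task() directly for log redirection and
    # database state tracking.
    return run_task_simple(name, opts, comm=None)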
def run_dist(tasktype, tasklist, db, nproc, procs_per_node, force=False):
    """Compute the runtime distribution of tasks.

    For a given number of processes, parse job environment variables and
    compute the number of workers to use and the remaining tasks to process.
    Divide the processes into groups, and associate some (or all) of those
    groups to workers.  Assign tasks to these groups of processes.  Some
    groups may have zero tasks if there are more groups than workers needed.

    Returns:
        tuple: The (groupsize, groups, tasks, dist) information.  Groupsize
            is the number of processes per group.  Groups is a list of tuples
            (one per process) giving the group number and rank within the
            group.  The tasks are a sorted list of tasks containing the
            subset of the inputs that needs to be run.  The dist is a list
            (one entry per group) containing the indices of tasks assigned to
            each group.

    """
    from .tasks.base import task_classes, task_type

    log = get_logger()

    runtasks = None
    ntask = None
    ndone = None

    log.info("Distributing {} {} tasks".format(len(tasklist), tasktype))

    if force:
        # Run everything
        runtasks = tasklist
        ntask = len(runtasks)
        ndone = 0
        log.info("Forcibly running {} tasks regardless of state".format(ntask))
    else:
        # Actually check which things need to be run.
        states = check_tasks(tasklist, db=db)
        runtasks = [x for x in tasklist if states[x] == "ready"]
        ntask = len(runtasks)
        ndone = len([x for x in tasklist if states[x] == "done"])
        log.info(
            "Found {} tasks ready to run and {} tasks done"
            .format(ntask, ndone)
        )

    # Query the environment for DESI runtime variables set in
    # pipeline-generated slurm scripts and use default values if
    # they are not found.  Then compute the number of workers and the
    # distribution of tasks in a way that is identical to what was
    # done during job planning.

    job_env = parse_job_env()

    tfactor = 1.0
    if "timefactor" in job_env:
        tfactor = job_env["timefactor"]
        log.info("Using timefactor {}".format(tfactor))
    else:
        log.warning(
            "DESI_PIPE_RUN_TIMEFACTOR not found in environment, using 1.0."
        )

    startup = 0.0
    if "startup" in job_env:
        startup = job_env["startup"]
        log.info("Using worker startup of {} minutes".format(startup))
    else:
        log.warning(
            "DESI_PIPE_RUN_STARTUP not found in environment, using 0.0."
        )

    worker_size = 0
    if "workersize" in job_env:
        worker_size = job_env["workersize"]
        log.info("Found worker size of {} from environment".format(worker_size))
    else:
        # We have no information from the planning, so fall back to using the
        # default for this task type or else one node as the worker size.
        worker_size = task_classes[tasktype].run_max_procs()
        if worker_size == 0:
            worker_size = procs_per_node
        log.warning(
            "DESI_PIPE_RUN_WORKER_SIZE not found in environment, using {}."
            .format(worker_size)
        )

    nworker = 0
    if "workers" in job_env:
        nworker = job_env["workers"]
        log.info("Found {} workers from environment".format(nworker))
    else:
        # We have no information from the planning
        nworker = nproc // worker_size
        if nworker == 0:
            nworker = 1
        log.warning(
            "DESI_PIPE_RUN_WORKERS not found in environment, using {}."
            .format(nworker)
        )

    if nworker > nproc:
        msg = (
            "Number of workers ({}) larger than number of procs ({}). "
            "This should never happen and means that the job script may "
            "have been changed by hand.".format(nworker, nproc)
        )
        raise RuntimeError(msg)

    # A "group" of processes is identical in size to the worker_size above.
    # However, there may be more process groups than workers.  This can happen
    # if we reduced the number of workers due to some tasks being completed,
    # or if there is a "partial" process group remaining when the worker size
    # does not evenly divide into the total number of processes.  We compute
    # the process group information here so that the calling code can use it
    # directly if splitting the communicator.

    ngroup = nproc // worker_size
    if ngroup * worker_size < nproc:
        # We have a leftover partial process group
        ngroup += 1

    groups = [(x // worker_size, x % worker_size) for x in range(nproc)]

    # Compute the task distribution

    if ntask == 0:
        # All tasks are done!
        return worker_size, groups, list(), [(-1, 0) for x in range(ngroup)]

    if nworker > len(runtasks):
        # The number of workers set at job planning time is larger
        # than the number of tasks that remain to be done.  Reduce
        # the number of workers.
        log.info(
            "Job has {} workers but only {} tasks to run. "
            "Reducing number of workers to match."
            .format(nworker, len(runtasks))
        )
        nworker = len(runtasks)

    (worktasks, worktimes, workdist) = compute_worker_tasks(
        tasktype, runtasks, tfactor, nworker, worker_size,
        startup=startup, db=db)

    # Compute the times for each worker, just for information
    workertimes, workermin, workermax = worker_times(
        worktimes, workdist, startup=startup)
    log.info(
        "{} workers have times ranging from {} to {} minutes"
        .format(nworker, workermin, workermax)
    )

    dist = list()
    for g in range(ngroup):
        if g < nworker:
            # This process group is being used as a worker.  Assign it the
            # tasks.
            dist.append(workdist[g])
        else:
            # This process group is idle (not acting as a worker) or contains
            # the leftover processes to make a whole number of nodes.
            dist.append([])

    return worker_size, groups, worktasks, dist
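
# A minimal sketch (not part of the original module) showing how the values
# returned by run_dist() fit together.  The process counts are hypothetical,
# and "db" may be None, in which case task readiness is checked on the
# filesystem by check_tasks().
def _example_inspect_distribution(tasktype, tasklist, db=None):
    """Illustrative only: report which tasks each process group would run."""
    nproc = 8               # hypothetical total number of MPI processes
    procs_per_node = 4      # hypothetical processes per node
    groupsize, groups, worktasks, dist = run_dist(
        tasktype, tasklist, db, nproc, procs_per_node)
    # groups[p] is (group index, rank within group) for each process p, and
    # dist[g] holds the indices into worktasks assigned to process group g.
    if len(worktasks) == 0:
        print("all tasks are already done")
        return dist
    for g, indices in enumerate(dist):
        names = [worktasks[t] for t in indices]
        print("group {} ({} procs): {} task(s)".format(
            g, groupsize, len(names)))
    return dist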
def run_task_list(tasktype, tasklist, opts, comm=None, db=None, force=False):
    """Run a collection of tasks of the same type.

    This function requires that the DESI environment variables are set to
    point to the current production directory.

    This function first takes the communicator and uses the maximum processes
    per task to split the communicator and form groups of processes of the
    desired size.  It then takes the list of tasks and uses their relative
    run time estimates to assign tasks to the process groups.  Each process
    group loops over its assigned tasks.

    If the database is not specified, no state tracking will be done and the
    filesystem will be checked as needed to determine the current state.

    Only tasks that are ready to run (based on the filesystem checks or the
    database) will actually be attempted.

    Args:
        tasktype (str): the pipeline step to process.
        tasklist (list): the list of tasks.  All tasks should be of type
            "tasktype" above.
        opts (dict): the global options (for example, as read from the
            production options.yaml file).
        comm (mpi4py.Comm): the full communicator to use for the whole set of
            tasks.
        db (pipeline.db.DB): The optional database to update.
        force (bool): If True, ignore database and filesystem state and just
            run the tasks regardless.

    Returns:
        tuple: the number of ready tasks, number that are done, and the
            number that failed.

    """
    from .tasks.base import task_classes, task_type

    log = get_logger()

    nproc = 1
    rank = 0
    if comm is not None:
        nproc = comm.size
        rank = comm.rank

    # Compute the number of processes that share a node.
    procs_per_node = 1
    if comm is not None:
        import mpi4py.MPI as MPI
        nodecomm = comm.Split_type(MPI.COMM_TYPE_SHARED, 0)
        procs_per_node = nodecomm.size

    # Total number of input tasks
    ntask = len(tasklist)

    # Get the options for this task type.
    options = opts[tasktype]

    # Get the tasks that still need to be done.
    groupsize = None
    groups = None
    worktasks = None
    dist = None
    if rank == 0:
        groupsize, groups, worktasks, dist = run_dist(
            tasktype, tasklist, db, nproc, procs_per_node, force=force
        )

    comm_group = None
    comm_rank = comm

    if comm is not None:
        groupsize = comm.bcast(groupsize, root=0)
        groups = comm.bcast(groups, root=0)
        worktasks = comm.bcast(worktasks, root=0)
        dist = comm.bcast(dist, root=0)

        # Determine if we need to split the communicator.  Are any processes
        # in a group larger than one?
        largest_rank = np.max([x[1] for x in groups])
        if largest_rank > 0:
            comm_group = comm.Split(color=groups[rank][0], key=groups[rank][1])
            comm_rank = comm.Split(color=groups[rank][1], key=groups[rank][0])

    # How many original tasks did we have and how many were done?
    ntask = len(tasklist)
    ndone = ntask - len(worktasks)

    # every group goes and does its tasks...

    rundir = io.get_pipe_rundir()
    logdir = os.path.join(rundir, io.get_pipe_logdir())

    group = groups[rank][0]
    group_rank = groups[rank][1]

    ## group_firsttask = dist[group][0]
    ## group_ntask = dist[group][1]
    group_ntask = len(dist[group])

    failcount = 0
    group_failcount = 0

    if group_ntask > 0:
        if group_rank == 0:
            log.debug(
                "Group {}, running {} tasks".format(group, len(dist[group]))
            )
        for t in dist[group]:
            # For this task, determine the output log file.  If the task has
            # the "night" key in its name, then use that subdirectory.
            # Otherwise, if it has the "pixel" key, use the appropriate
            # subdirectory.
            tt = task_type(worktasks[t])
            fields = task_classes[tt].name_split(worktasks[t])
            tasklog = None
            if "night" in fields:
                tasklogdir = os.path.join(logdir, io.get_pipe_nightdir(),
                    "{:08d}".format(fields["night"]))
                # (this directory should have been made during the prod update)
                tasklog = os.path.join(tasklogdir,
                    "{}.log".format(worktasks[t]))
            elif "pixel" in fields:
                tasklogdir = os.path.join(logdir, "healpix",
                    io.healpix_subdirectory(fields["nside"], fields["pixel"]))
                # When creating this directory, there MIGHT be conflicts from
                # multiple processes working on pixels in the same
                # sub-directories...
                try:
                    if not os.path.isdir(os.path.dirname(tasklogdir)):
                        os.makedirs(os.path.dirname(tasklogdir))
                except FileExistsError:
                    pass
                try:
                    if not os.path.isdir(tasklogdir):
                        os.makedirs(tasklogdir)
                except FileExistsError:
                    pass
                tasklog = os.path.join(tasklogdir,
                    "{}.log".format(worktasks[t]))

            failedprocs = run_task(worktasks[t], options, comm=comm_group,
                logfile=tasklog, db=db)

            if failedprocs > 0:
                group_failcount += 1
                log.debug("{} failed; group_failcount now {}".format(
                    worktasks[t], group_failcount))

    failcount = group_failcount

    # Every process in each group has the fail count for the tasks assigned to
    # its group.  To get the total onto all processes, we just have to do an
    # allreduce across the rank communicator.
    if comm_rank is not None:
        failcount = comm_rank.allreduce(failcount)

    if rank == 0:
        log.debug("Tasks done; {} failed".format(failcount))

    if db is not None and rank == 0:
        # postprocess the successful tasks
        log.debug("postprocess the successful tasks")
        states = db.get_states(worktasks)
        log.debug("states={}".format(states))
        log.debug("runtasks={}".format(worktasks))
        with db.cursor() as cur:
            for name in worktasks:
                if states[name] == "done":
                    log.debug("postprocessing {}".format(name))
                    task_classes[tasktype].postprocessing(db, name, cur)

    return ntask, ndone, failcount
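
# A minimal MPI driver sketch (not part of the original module).  It assumes
# that the DESI production environment is already set up, that "opts" is keyed
# by task type like the production options.yaml, and that "tasklist" was built
# elsewhere (for example from the pipeline database).  Launch under MPI with
# e.g. "mpirun -np N python script.py"; without mpi4py it falls back to serial
# execution.
def _example_run_tasks_mpi(tasktype, tasklist, opts, db=None):
    """Illustrative only: run a task list across an MPI communicator."""
    comm = None
    try:
        from mpi4py import MPI
        comm = MPI.COMM_WORLD
    except ImportError:
        pass    # no MPI available; run serially
    ntask, ndone, nfail = run_task_list(tasktype, tasklist, opts,
                                        comm=comm, db=db)
    if comm is None or comm.rank == 0:
        print("{} input tasks, {} already done, {} failed".format(
            ntask, ndone, nfail))
    return nfail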
def run_task_list_db(tasktype, tasklist, comm=None):
    """Run a list of tasks using the pipeline DB and options.

    This is a wrapper around run_task_list which uses the production database
    and global options file.

    Args:
        tasktype (str): the pipeline step to process.
        tasklist (list): the list of tasks.  All tasks should be of type
            "tasktype" above.
        comm (mpi4py.Comm): the full communicator to use for the whole set of
            tasks.

    Returns:
        tuple: the number of ready tasks, the number that are done, and the
            number that failed.

    """
    (db, opts) = load_prod("w")

    return run_task_list(tasktype, tasklist, opts, comm=comm, db=db)
def dry_run(tasktype, tasklist, opts, procs, procs_per_node, db=None,
            launch="mpirun -np", force=False):
    """Compute the distribution of tasks and equivalent commands.

    This function takes arguments similar to run_task_list(), except that it
    simulates the data distribution and the commands that would be run given
    the specified number of processes and processes per node.

    This can be used to debug issues with the runtime concurrency or the
    actual options that will be passed to the underlying main() entry points
    for each task.

    This function requires that the DESI environment variables are set to
    point to the current production directory.

    Only tasks that are ready to run (based on the filesystem checks or the
    database) will actually be attempted.

    NOTE: Since this function is just informative and for interactive use,
    we print information directly to STDOUT rather than logging.

    Args:
        tasktype (str): the pipeline step to process.
        tasklist (list): the list of tasks.  All tasks should be of type
            "tasktype" above.
        opts (dict): the global options (for example, as read from the
            production options.yaml file).
        procs (int): the number of processes to simulate.
        procs_per_node (int): the number of processes per node to simulate.
        db (pipeline.db.DB): The optional database to update.
        launch (str): The launching command for a job.  This is just a
            convenience and is prepended to each command before the number of
            processes.
        force (bool): If True, ignore database and filesystem state and just
            run the tasks regardless.

    Returns:
        Nothing.

    """
    from .tasks.base import task_classes, task_type

    log = get_logger()

    prefix = "DRYRUN: "

    # Get the options for this task type.
    options = dict()
    if tasktype in opts:
        options = opts[tasktype]

    # Get the tasks that still need to be done.
    groupsize, groups, worktasks, dist = run_dist(
        tasktype, tasklist, db, procs, procs_per_node, force=force
    )

    # Go through the tasks
    rundir = io.get_pipe_rundir()
    logdir = os.path.join(rundir, io.get_pipe_logdir())

    for group, group_rank in groups:
        ## group_firsttask = dist[group][0]
        ## group_ntask = dist[group][1]
        group_ntask = len(dist[group])
        if group_ntask == 0:
            continue

        for t in dist[group]:
            # For this task, determine the output log file.  If the task has
            # the "night" key in its name, then use that subdirectory.
            # Otherwise, if it has the "pixel" key, use the appropriate
            # subdirectory.
            tt = task_type(worktasks[t])
            fields = task_classes[tt].name_split(worktasks[t])
            tasklog = None
            if "night" in fields:
                tasklogdir = os.path.join(logdir, io.get_pipe_nightdir(),
                    "{:08d}".format(fields["night"]))
                # (this directory should have been made during the prod update)
                tasklog = os.path.join(tasklogdir,
                    "{}.log".format(worktasks[t]))
            elif "pixel" in fields:
                tasklogdir = os.path.join(logdir, "healpix",
                    io.healpix_subdirectory(fields["nside"], fields["pixel"]))
                tasklog = os.path.join(tasklogdir,
                    "{}.log".format(worktasks[t]))

            com = task_classes[tt].run_cli(worktasks[t], options, groupsize,
                launch=launch, log=tasklog, db=db)
            print("{} {}".format(prefix, com))
            sys.stdout.flush()

    print("{}".format(prefix))
    sys.stdout.flush()

    return
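
# A minimal sketch (not part of the original module) of previewing the
# commands that a job would run.  The process counts and the "srun -n" launch
# string are hypothetical; the production database and options are loaded the
# same way run_task_list_db() does above.
def _example_dry_run(tasktype, tasklist):
    """Illustrative only: print the commands a 64-process job would execute."""
    db, opts = load_prod("w")
    dry_run(tasktype, tasklist, opts, procs=64, procs_per_node=32,
            db=db, launch="srun -n")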