diff options
Diffstat (limited to 'utils/cat-grid-jobs')
-rwxr-xr-x | utils/cat-grid-jobs | 281 |
1 files changed, 195 insertions, 86 deletions
diff --git a/utils/cat-grid-jobs b/utils/cat-grid-jobs index 3873cd3..26258f5 100755 --- a/utils/cat-grid-jobs +++ b/utils/cat-grid-jobs @@ -14,27 +14,95 @@ # You should have received a copy of the GNU General Public License along with # this program. If not, see <https://www.gnu.org/licenses/>. """ -Script to combine the fit results from jobs submitted to the grid. +Script to combine the fit results from jobs submitted to the grid. It's +expected to be run from a cron job: -This script first runs zdab-cat on the zdab file to get the data cleaning words -and SNOMAN fitter results for every event in the file. It then adds any fit -results from the other files listed on the command line and prints the results -as YAML to stdout. + PATH=/usr/bin:$HOME/local/bin + SDDM_DATA=$HOME/sddm/src + DQXX_DIR=$HOME/dqxx -Example: - - $ cat-grid-jobs ~/mc_atm_nu_no_osc_genie_010000_0.mcds ~/grid_job_results/*.txt > output.txt + 0 * * * * module load hdf5; module load py-h5py; module load zlib; cat-grid-jobs --loglevel debug --logfile cat.log --output-dir $HOME/fit_results +The script will loop through all entries in the database and try to combine the +fit results into a single output file. """ from __future__ import print_function, division -import yaml -try: - from yaml import CLoader as Loader, CDumper as Dumper -except ImportError: - from yaml import Loader, Dumper import os import sys +import numpy as np +from datetime import datetime +import h5py +from os.path import join, split +from subprocess import check_call + +DEBUG = 0 +VERBOSE = 1 +NOTICE = 2 +WARNING = 3 + +class Logger(object): + """ + Simple logger class that I wrote for the SNO+ DAQ. Very easy to use: + + log = Logger() + log.set_logfile("test.log") + log.notice("blah") + log.warn("foo") + + The log file format is taken from the Redis log file format which is really + nice since it shows the exact time and severity of each log message. 
+ """ + def __init__(self): + self.logfile = sys.stdout + # by default, we log everything + self.verbosity = DEBUG + + def set_verbosity(self, level): + if isinstance(level, int): + self.verbosity = level + elif isinstance(level, basestring): + if level == 'debug': + self.verbosity = DEBUG + elif level == 'verbose': + self.verbosity = VERBOSE + elif level == 'notice': + self.verbosity = NOTICE + elif level == 'warning': + self.verbosity = WARNING + else: + raise ValueError("unknown loglevel '%s'" % level) + else: + raise TypeError("level must be a string or integer") + + def set_logfile(self, filename): + self.logfile = open(filename, 'a') + + def debug(self, msg): + self.log(DEBUG, msg) + + def verbose(self, msg): + self.log(VERBOSE, msg) + + def notice(self, msg): + self.log(NOTICE, msg) + + def warn(self, msg): + self.log(WARNING, msg) + + def log(self, level, msg): + if level < self.verbosity: + return + + c = '.-*#' + pid = os.getpid() + now = datetime.now() + buf = now.strftime('%d %b %H:%M:%S.%f')[:-3] + + self.logfile.write('%d:%s %c %s\n' % (pid, buf, c[level], msg)) + self.logfile.flush() + +log = Logger() # Check that a given file can be accessed with the correct mode. 
def splitext(path):
    """
    Like os.path.splitext() except it returns the full extension if the
    filename has multiple extensions, for example:

        splitext('foo.tar.gz') -> 'foo', '.tar.gz'
    """
    full_root, full_ext = os.path.splitext(path)
    while True:
        root, ext = os.path.splitext(full_root)
        if ext:
            # accumulate extensions right-to-left until none remain
            full_ext = ext + full_ext
            full_root = root
        else:
            break

    return full_root, full_ext

def cat_grid_jobs(conn, output_dir):
    """
    Combine per-job fit results for every (filename, uuid) entry in the state
    database `conn` into one HDF5 file per entry under `output_dir`.

    For each entry the original zdab is first run through zdab-cat to get the
    full event list, then the 'fits' dataset from every SUCCESS grid job is
    appended to the output file. Entries whose output already holds at least
    as many fits as there are result files are skipped.
    """
    # which() is the PATH-search helper defined earlier in this module
    zdab_cat = which("zdab-cat")

    if zdab_cat is None:
        # fix: Logger.warn() takes only a message; the old call passed
        # file=sys.stderr (a print() keyword) and raised TypeError
        log.warn("couldn't find zdab-cat in path!")
        return

    c = conn.cursor()

    results = c.execute('SELECT filename, uuid FROM state').fetchall()

    # a (filename, uuid) pair may appear once per job; process each pair once
    unique_results = set(results)

    for filename, uuid in unique_results:
        head, tail = split(filename)
        root, ext = splitext(tail)

        # First, find all hdf5 result files for the successful jobs
        fit_results = c.execute("SELECT submit_file FROM state WHERE state = 'SUCCESS' AND filename = ? AND uuid = ?", (filename, uuid)).fetchall()
        fit_results = [row[0] for row in fit_results]
        # each submit file has a sibling <name>.hdf5 with the actual fits
        fit_results = ['%s.hdf5' % splitext(submit_file)[0] for submit_file in fit_results]

        if len(fit_results) == 0:
            log.debug("No fit results found for %s (%s)" % (tail, uuid))
            continue

        output = join(output_dir,"%s_%s_fit_results.hdf5" % (root,uuid))

        if os.path.exists(output):
            with h5py.File(output,"a") as fout:
                if 'fits' in fout:
                    total_fits = fout['fits'].shape[0]

                    # already have at least one fit row per result file;
                    # assume this entry was fully processed before
                    if total_fits >= len(fit_results):
                        log.debug("skipping %s because there are already %i fit results" % (tail,len(fit_results)))
                        continue

        # First we get the full event list along with the data cleaning word, FTP
        # position, FTK, and RSP energy from the original zdab and then add the fit
        # results.
        #
        # Note: We send stderr to /dev/null since there can be a lot of warnings
        # about PMT types and fit results
        with open(os.devnull, 'w') as f:
            log.debug("zdab-cat %s -o %s" % (filename,output))
            check_call([zdab_cat,filename,"-o",output],stderr=f)

        total_events = 0
        events_with_fit = 0
        total_fits = 0

        with h5py.File(output,"a") as fout:
            total_events = fout['ev'].shape[0]
            # fix: the inner loop used to reuse the name `filename` (and
            # `head`/`tail`), clobbering the outer loop variables so the
            # final log message reported the wrong file
            for fit_filename in fit_results:
                fit_head, fit_tail = split(fit_filename)
                # fix: open read-only explicitly; relying on h5py's default
                # mode is deprecated and the fit files are never written here
                with h5py.File(fit_filename, "r") as f:
                    if 'git_sha1' not in f.attrs:
                        log.warn("No git sha1 found for %s. Skipping..." % fit_tail)
                        continue
                    # Check to see if the git sha1s match
                    if fout.attrs['git_sha1'] != f.attrs['git_sha1']:
                        log.debug("git_sha1 is %s for current version but %s for %s" % (fout.attrs['git_sha1'],f.attrs['git_sha1'],fit_tail))
                    # get fits which match up with the events
                    valid_fits = f['fits'][np.isin(f['fits'][:][['run','gtid']],fout['ev'][:][['run','gtid']])]
                    # Add the fit results
                    fout['fits'].resize((fout['fits'].shape[0]+valid_fits.shape[0],))
                    fout['fits'][-valid_fits.shape[0]:] = valid_fits
                    events_with_fit += len(np.unique(valid_fits[['run','gtid']]))
                    # fix: read the fields through [:] like the valid_fits
                    # line above; multi-field indexing directly on a dataset
                    # is deprecated in recent h5py
                    total_fits += len(np.unique(f['fits'][:][['run','gtid']]))

        # Log the number of fit results that were added so a mismatch between
        # the zdab and the fit results is easy to spot in the log
        log.notice("%s_%s: added %i/%i fit results to a total of %i events" % (filename, uuid, events_with_fit, total_fits, total_events))

if __name__ == '__main__':
    import argparse
    import sqlite3

    parser = argparse.ArgumentParser("concatenate fit results from grid jobs into a single file")
    parser.add_argument("--db", type=str, help="database file", default=None)
    parser.add_argument('--loglevel',
                        help="logging level (debug, verbose, notice, warning)",
                        default='notice')
    parser.add_argument('--logfile', default=None,
                        help="filename for log file")
    parser.add_argument('--output-dir', default=None,
                        help="output directory for fit results")
    args = parser.parse_args()

    log.set_verbosity(args.loglevel)

    if args.logfile:
        log.set_logfile(args.logfile)

    home = os.path.expanduser("~")

    # default database and output directory live in $HOME
    if args.db is None:
        args.db = join(home,'state.db')

    if args.output_dir is None:
        args.output_dir = home
    else:
        if not os.path.exists(args.output_dir):
            log.debug("mkdir %s" % args.output_dir)
            os.mkdir(args.output_dir)

    conn = sqlite3.connect(args.db)

    cat_grid_jobs(conn, args.output_dir)
    conn.close()