Source code for indra.util.aws

import re

import boto3
import logging
import requests
from datetime import datetime, timezone, timedelta

from botocore import UNSIGNED
from botocore.client import Config

from time import sleep
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter

from indra import get_config, has_config
from indra.util.nested_dict import NestedDict

# Module-level logger through which the functions below report progress.
logger = logging.getLogger(__name__)


def get_ids(job_list):
    """Extract the 'jobId' entry of every job dict in `job_list`.

    Returns None when `job_list` is None, otherwise a list of the ids in the
    same order as the input.
    """
    if job_list is not None:
        return list(map(lambda job_dict: job_dict['jobId'], job_list))
    return None


def kill_all(job_queue, reason='None given', states=None, kill_list=None):
    """Terminates/cancels all jobs on the specified queue.

    Parameters
    ----------
    job_queue : str
        The name of the Batch job queue on which you wish to terminate/cancel
        jobs.
    reason : str
        Provide a reason for the kill that will be recorded with the job's
        record on AWS. If None, 'None given' is used.
    states : None or list[str]
        A list of job states to remove. Possible states are 'STARTING',
        'RUNNABLE', and 'RUNNING'. If None, all jobs in all states will be
        ended (modulo the `kill_list` below).
    kill_list : None or list[dict]
        A list of job dictionaries (as returned by the submit function) that
        you specifically wish to kill. All other jobs on the queue will be
        ignored. If None, all jobs on the queue will be ended (modulo the
        above).

    Returns
    -------
    killed_ids : list[str]
        A list of the job ids for jobs that were killed.
    """
    # Default is all states.
    if states is None:
        states = ['STARTING', 'RUNNABLE', 'RUNNING']
    # AWS Batch requires a string reason; callers (e.g. the CLI) may pass None.
    if reason is None:
        reason = 'None given'

    # Resolve the ids to kill once, as a set for cheap membership tests.
    ids_to_kill = get_ids(kill_list)
    if ids_to_kill is not None:
        ids_to_kill = set(ids_to_kill)

    # Get batch client
    batch = boto3.client('batch')

    killed_ids = []
    for status in states:
        # Gather every page of results before killing anything, so that the
        # state changes caused by the kills cannot interfere with paging.
        job_summaries = []
        list_kwargs = {'jobQueue': job_queue, 'jobStatus': status}
        while True:
            resp = batch.list_jobs(**list_kwargs)
            job_summaries.extend(resp.get('jobSummaryList') or [])
            next_token = resp.get('nextToken')
            if not next_token:
                break
            list_kwargs['nextToken'] = next_token

        for job in job_summaries:
            # Check if this is one of the specified jobs, if any specified.
            if ids_to_kill is not None and job['jobId'] not in ids_to_kill:
                continue

            # Jobs that have reached STARTING or RUNNING are not affected by
            # cancel_job; they must be terminated instead.
            if status in ('STARTING', 'RUNNING'):
                logger.info('Terminating {jobName} ({jobId})'.format(**job))
                batch.terminate_job(jobId=job['jobId'], reason=reason)
            else:
                logger.info('Canceling {jobName} ({jobId})'.format(**job))
                batch.cancel_job(jobId=job['jobId'], reason=reason)

            # Record the id of the killed job (as documented above).
            killed_ids.append(job['jobId'])
    return killed_ids
def tag_instance(instance_id, **tags):
    """Tag a single ec2 instance, plus all of its volumes, with new tags.

    Tags whose key or value is falsy are dropped. Tags whose key already
    exists on the instance are also dropped, so existing values are never
    overwritten. If nothing is left to add, only a log message is emitted.
    """
    logger.debug("Got request to add tags %s to instance %s."
                 % (str(tags), instance_id))
    instance = boto3.resource('ec2').Instance(instance_id)

    # Keep only the tags that have both a key and a value.
    new_tags = {}
    for tag_key, tag_value in tags.items():
        if tag_value and tag_key:
            new_tags[tag_key] = tag_value

    # Never overwrite a tag that the instance already carries.
    if instance.tags is not None:
        existing_tags = {tag.get('Key'): tag.get('Value')
                         for tag in instance.tags}
        logger.debug("Ignoring existing tags; %s" % str(existing_tags))
        for existing_key in existing_tags:
            new_tags.pop(existing_key, None)

    tag_list = [{'Key': tag_key, 'Value': tag_value}
                for tag_key, tag_value in new_tags.items()]
    if not tag_list:
        logger.info('No new tags from: %s' % str(tags))
        return

    logger.info('Adding project tags "%s" to instance %s'
                % (new_tags, instance_id))
    instance.create_tags(Tags=tag_list)
    # Propagate the same tags to every volume attached to the instance.
    for page in instance.volumes.all().pages():
        for volume in page:
            volume.create_tags(Tags=tag_list)
    return
def tag_myself(project='cwc', **other_tags):
    """Function run when indra is used in an EC2 instance to apply tags.

    Looks up this machine's instance id from the EC2 instance metadata
    endpoint and passes it, with `project` and `other_tags`, to
    `tag_instance`. If the endpoint cannot be reached or does not answer
    successfully, a warning is logged and nothing is tagged.
    """
    base_url = "http://169.254.169.254"
    try:
        # Without a timeout, this request can hang for a long time when the
        # link-local address is unreachable (e.g. when not running on EC2).
        resp = requests.get(base_url + "/latest/meta-data/instance-id",
                            timeout=5)
    except requests.exceptions.RequestException:
        logger.warning("Could not connect to service. Note this should only "
                       "be run from within a batch job.")
        return
    # Refuse to use an error page (e.g. a 401 from an IMDSv2-only instance)
    # as the instance id.
    if resp.status_code != 200:
        logger.warning("Could not get the instance id from the metadata "
                       "service (status %s); not tagging." % resp.status_code)
        return
    instance_id = resp.text
    tag_instance(instance_id, project=project, **other_tags)
    return
def get_batch_command(command_list, project=None, purpose=None):
    """Get the command appropriate for running something on batch.

    The words of `command_list` are space-joined and passed as a single
    argument to this module's `run_in_batch` CLI task. If no `project` is
    given and a DEFAULT_AWS_PROJECT config value exists, that value is used.
    `--project` and `--purpose` options are appended only when truthy.
    """
    joined_command = ' '.join(command_list)
    if not project and has_config('DEFAULT_AWS_PROJECT'):
        project = get_config('DEFAULT_AWS_PROJECT')

    extra_options = []
    if project:
        extra_options.extend(['--project', project])
    if purpose:
        extra_options.extend(['--purpose', purpose])

    base = ['python3', '-m', 'indra.util.aws', 'run_in_batch', joined_command]
    return base + extra_options
def run_in_batch(command_list, project, purpose):
    """Tag the current instance, then run `command_list` as a subprocess.

    The command's output is framed in the log by begin/end banners.
    Returns the subprocess's exit code.
    """
    import subprocess

    def _banner(label, width):
        # A newline-delimited line of '=' characters framing `label`.
        rule = '=' * width
        return '\n' + rule + ' ' + label + ' ' + rule + '\n'

    tag_myself(project, purpose=purpose)
    logger.info("Running command list: %s" % str(command_list))
    logger.info(_banner('Begin Primary Command Output', 20))
    exit_code = subprocess.call(command_list)
    logger.info(_banner('End Primary Command Output', 21))
    return exit_code
def get_jobs(job_queue='run_reach_queue', job_status='RUNNING'):
    """Returns a list of dicts with jobName and jobId for each job with the
    given status.

    Follows the `nextToken` returned by `list_jobs`, so jobs on every page of
    results are included, not just those on the first page.

    Parameters
    ----------
    job_queue : str
        Name of the AWS Batch job queue to query.
    job_status : str
        Only jobs in this status are listed.

    Returns
    -------
    list[dict]
        The job summaries as returned by AWS Batch (empty if none).
    """
    batch = boto3.client('batch')
    kwargs = {'jobQueue': job_queue, 'jobStatus': job_status}
    job_summaries = []
    while True:
        resp = batch.list_jobs(**kwargs)
        job_summaries.extend(resp.get('jobSummaryList') or [])
        next_token = resp.get('nextToken')
        if not next_token:
            break
        kwargs['nextToken'] = next_token
    return job_summaries
# Matches s3 locations written as 's3:<bucket>/<key or prefix>', and also the
# conventional URL form 's3://<bucket>/<key or prefix>'. Group 1 is the
# bucket name (which may contain dots), group 2 is the key/prefix. Anything
# that does not match is treated by callers as a local file path.
s3_path_patt = re.compile(r'^s3:(?://)?([-a-zA-Z0-9_.]+)/(.*?)$')
class JobLog(object):
    """Gets the Cloudwatch log associated with the given job.

    Parameters
    ----------
    job_info : dict
        dict containing entries for 'jobName' and 'jobId', e.g., as returned
        by get_jobs()
    log_group_name : string
        Name of the log group; defaults to '/aws/batch/job'
    verbose : bool
        If True, log each line as it is retrieved by `get_lines`.
        Default: False
    append_dumps : bool
        Default behavior of `dump`. If True, each dump to s3 is stored as a
        new numbered part ('<prefix>/part_<n>') and each dump to a local file
        is appended to that file; if False, the target is overwritten.
        Default: True

    Attributes
    ----------
    lines : list[str]
        The event messages in the log, with the earliest events listed
        first, each formatted as '<timestamp>: <message>' and
        newline-terminated.
    """
    _suffix_base = '/part_'

    def __init__(self, job_info, log_group_name='/aws/batch/job',
                 verbose=False, append_dumps=True):
        self.job_name = job_info['jobName']
        self.job_id = job_info['jobId']
        self.logs_client = boto3.client('logs')
        self.verbose = verbose
        self.log_group_name = log_group_name

        batch = boto3.client('batch')
        resp = batch.describe_jobs(jobs=[self.job_id])
        job_desc = resp['jobs'][0]
        container = job_desc['container']
        # Prefer the log stream name that AWS reports for the container; fall
        # back to reconstructing the default '<jobdef>/default/<task id>'.
        log_stream_name = container.get('logStreamName')
        if log_stream_name is None:
            job_def_name = \
                job_desc['jobDefinition'].split('/')[-1].split(':')[0]
            task_arn_id = container['taskArn'].split('/')[-1]
            log_stream_name = '%s/default/%s' % (job_def_name, task_arn_id)
        self.log_stream_name = log_stream_name

        self.latest_timestamp = None
        self.lines = []
        self.nextToken = None
        self.__len = 0
        self.append = append_dumps
        return

    def __len__(self):
        """Number of events fetched by `get_lines` (loaded lines excluded)."""
        return self.__len

    def clear_lines(self):
        """Discard the lines held in memory; the fetch position is kept."""
        self.lines = []

    def _get_part_num(self, key, prefix):
        """Return the part number of a '<prefix>/part_<n>' key, else None."""
        suffix = key[len(prefix):]
        if not suffix.startswith(self._suffix_base):
            return None
        try:
            return int(suffix[len(self._suffix_base):])
        except ValueError:
            return None

    def dump(self, out_file, append=None):
        """Dump the logs in their entirety to the specified file.

        Parameters
        ----------
        out_file : str
            A local file path, or an s3 location matched by the module-level
            `s3_path_patt` (e.g. 's3:<bucket>/<prefix>').
        append : bool or None
            Overrides the instance's `append` default for this call. When
            True, s3 dumps go to the next free '<prefix>/part_<n>' key and
            local dumps are appended to the file; when False, the target
            ('<prefix>' on s3, or the local file) is overwritten.
        """
        if append is None:
            append = self.append
        elif append != self.append:
            logger.info("Overriding default append behavior. This could muddy "
                        "future loads.")

        m = s3_path_patt.match(out_file)
        if m is not None:
            # If the user wants the files on s3...
            bucket, prefix = m.groups()
            s3 = boto3.client('s3')

            if append:
                # Find the largest part number among the existing parts, so
                # this dump becomes the next part instead of overwriting one.
                max_num = 0
                for key in iter_s3_keys(s3, bucket, prefix, do_retry=False):
                    num = self._get_part_num(key, prefix)
                    if num is not None and num > max_num:
                        max_num = num
                key = prefix + self._suffix_base + str(max_num + 1)
            else:
                key = prefix
            s3.put_object(Bucket=bucket, Key=key, Body=self.dumps())
        else:
            # Otherwise, if they want them locally...
            with open(out_file, 'a' if append else 'w') as f:
                f.writelines(self.lines)
        return

    def load(self, out_file):
        """Load the log lines from the cached files.

        Previously dumped lines are read from `out_file` (local path or s3
        location, as for `dump`) and prepended to `lines`. For s3 with append
        enabled, all '<prefix>/part_<n>' objects are read in numeric order of
        n; parts without a numeric suffix are skipped.
        """
        m = s3_path_patt.match(out_file)
        if m is not None:
            bucket, prefix = m.groups()
            s3 = boto3.client('s3')
            if self.append:
                # Order parts numerically: a plain string sort would place
                # part_10 before part_2.
                parts = []
                for key in iter_s3_keys(s3, bucket, prefix):
                    num = self._get_part_num(key, prefix)
                    if num is not None:
                        parts.append((num, key))
                prior_line_bytes = []
                for _, key in sorted(parts):
                    res = s3.get_object(Bucket=bucket, Key=key)
                    prior_line_bytes += res['Body'].read().splitlines()
            else:
                res = s3.get_object(Bucket=bucket, Key=prefix)
                prior_line_bytes = res['Body'].read().splitlines()
            prior_lines = [s.decode('utf-8') + '\n' for s in prior_line_bytes]
        else:
            with open(out_file, 'r') as f:
                prior_lines = f.readlines()
        self.lines = prior_lines + self.lines
        return

    def dumps(self):
        """Return all lines currently held, joined into one string."""
        return ''.join(self.lines)

    def get_lines(self):
        """Fetch log events from CloudWatch and append them to `lines`.

        Paging resumes from the stored `nextToken`, so a repeated call only
        retrieves events after those already fetched.
        """
        kwargs = {'logGroupName': self.log_group_name,
                  'logStreamName': self.log_stream_name,
                  'startFromHead': True}
        while True:
            if self.nextToken is not None:
                kwargs['nextToken'] = self.nextToken
            response = self.logs_client.get_log_events(**kwargs)
            # If we've gotten all the events already, the nextForwardToken for
            # this call will be the same as the last one
            if response.get('nextForwardToken') == self.nextToken:
                break
            for evt in response.get('events') or []:
                line = '%s: %s\n' % (evt['timestamp'], evt['message'])
                self.lines.append(line)
                # Event timestamps are epoch milliseconds; keep the most
                # recent one as a naive datetime expressed in UTC.
                self.latest_timestamp = (
                    datetime.fromtimestamp(evt['timestamp'] / 1000,
                                           tz=timezone.utc)
                    .replace(tzinfo=None))
                self.__len += 1
                if self.verbose:
                    logger.info('%d %s' % (len(self.lines), line))
            self.nextToken = response.get('nextForwardToken')
        return
def dump_logs(job_queue='run_reach_queue', job_status='RUNNING'):
    """Write logs for all jobs with the given status to files.

    Each job's log is written to '<jobName>_<jobId>.log' in the current
    working directory, replacing any earlier dump for that job.

    Parameters
    ----------
    job_queue : str
        Name of the AWS Batch job queue to query.
    job_status : str
        Only jobs in this status have their logs dumped.
    """
    jobs = get_jobs(job_queue, job_status)
    if not jobs:
        logger.info("No %s jobs found on queue %s." % (job_status, job_queue))
        return
    for job in jobs:
        # Each JobLog fetches the whole log from the start, so overwrite the
        # file rather than appending to an earlier dump.
        log = JobLog(job, append_dumps=False)
        log.get_lines()
        log.dump('{jobName}_{jobId}.log'.format(**job))
def get_date_from_str(date_str):
    """Get a utc datetime object from a string of format %Y-%m-%d-%H-%M-%S

    Missing trailing fields are filled in: month and day default to 01, and
    hours, minutes and seconds default to 0.

    Parameters
    ----------
    date_str : str
        A string of the format %Y(-%m-%d-%H-%M-%S). The string is assumed to
        represent a UTC time.

    Returns
    -------
    datetime.datetime
        A timezone-aware datetime in UTC.
    """
    # Fill-in values for month, day, hour, minute and second, in order.
    defaults = ['01', '01', '0', '0', '0']
    pieces = date_str.split('-')
    if 1 <= len(pieces) < 6:
        pieces += defaults[len(pieces) - 1:]
        date_str = '-'.join(pieces)
    parsed = datetime.strptime(date_str, '%Y-%m-%d-%H-%M-%S')
    return parsed.replace(tzinfo=timezone.utc)
def iter_s3_keys(s3, bucket, prefix, date_cutoff=None, after=True,
                 with_dt=False, do_retry=True):
    """Iterate over the keys in an s3 bucket given a prefix

    Results are paged through with `list_objects` and its `Marker`, so more
    than one response's worth of keys can be yielded.

    Parameters
    ----------
    s3 : boto3.client.S3
        A boto3.client.S3 instance
    bucket : str
        The name of the bucket to list objects in
    prefix : str
        The prefix filtering of the objects for list
    date_cutoff : str|datetime.datetime
        A datestring of format %Y(-%m-%d-%H-%M-%S) or a datetime.datetime
        object. The date is assumed to be in UTC. By default no filtering
        is done. Default: None.
    after : bool
        If True, only return objects after the given date cutoff.
        Otherwise, return objects before. Default: True
    with_dt : bool
        If True, yield a tuple (key, datetime.datetime(LastModified)) of
        the s3 Key and the object's LastModified date as a
        datetime.datetime object, only yield s3 key otherwise.
        Default: False.
    do_retry : bool
        If True, and no contents appear, try again in case there was simply
        a brief lag. If False, do not retry, and just accept the
        "directory" is empty.

    Yields
    ------
    key | (key, datetime.datetime)
        Each s3 key, or (key, LastModified) tuple when `with_dt` is True.
    """
    if date_cutoff:
        if not isinstance(date_cutoff, datetime):
            date_cutoff = get_date_from_str(date_cutoff)
        # Normalize the cutoff to an aware datetime in UTC.
        if date_cutoff.utcoffset() is None:
            date_cutoff = date_cutoff.replace(tzinfo=timezone.utc)
        if date_cutoff.utcoffset() != timedelta():
            date_cutoff = date_cutoff.astimezone(timezone.utc)

    def _passes_filter(entry):
        # No cutoff at all: everything passes. A falsy, non-None cutoff
        # lets nothing through.
        if date_cutoff is None:
            return True
        if not date_cutoff:
            return False
        modified = entry['LastModified']
        return modified > date_cutoff if after else modified < date_cutoff

    marker = None
    more_pages = True
    while more_pages:
        # Get the (next) batch of contents.
        request = {'Bucket': bucket, 'Prefix': prefix}
        if marker:
            request['Marker'] = marker
        resp = s3.list_objects(**request)

        contents = resp.get('Contents')
        # Handle case where no contents are found.
        if not contents:
            if not do_retry:
                logger.info("No contents found for \"%s\"." % prefix)
                break
            logger.info("Prefix \"%s\" does not seem to have children. "
                        "Retrying once." % prefix)
            do_retry = False
            sleep(0.1)
            continue

        for entry in contents:
            if entry['Key'] == marker or not _passes_filter(entry):
                continue
            if with_dt:
                yield (entry['Key'], entry['LastModified'])
            else:
                yield entry['Key']

        more_pages = resp['IsTruncated']
        # Resume the listing after the last key of this page.
        marker = contents[-1]['Key']
def rename_s3_prefix(s3, bucket, old_prefix, new_prefix):
    """Change an s3 prefix within the same bucket.

    Every object listed under `old_prefix` is copied (with its metadata and
    tags) to the key obtained by replacing that leading `old_prefix` with
    `new_prefix`. All the original objects are then deleted.

    Parameters
    ----------
    s3 : boto3.client.S3
        A boto3.client.S3 instance
    bucket : str
        The bucket holding both the old and the new keys.
    old_prefix : str
        The key prefix to rename.
    new_prefix : str
        The replacement for `old_prefix`.
    """
    # List everything up front so that objects created under `new_prefix`
    # can never appear in the listing being iterated (possible when
    # `new_prefix` itself starts with `old_prefix`).
    old_keys = list(iter_s3_keys(s3, bucket, old_prefix))

    to_delete = []
    for key in old_keys:
        # Replace only the leading prefix; str.replace would also rewrite
        # any later occurrence of `old_prefix` inside the key.
        new_key = new_prefix + key[len(old_prefix):]
        # NOTE(review): copy_object handles single objects up to 5 GB; larger
        # objects would need a multipart copy.
        s3.copy_object(Bucket=bucket, Key=new_key,
                       CopySource={'Bucket': bucket, 'Key': key},
                       MetadataDirective='COPY',
                       TaggingDirective='COPY')

        # Keep track of the objects that will need to be deleted (the old keys)
        to_delete.append({'Key': key})

        # delete_objects accepts at most 1000 keys per request.
        if len(to_delete) >= 1000:
            s3.delete_objects(Bucket=bucket, Delete={'Objects': to_delete})
            to_delete = []

    # Get any stragglers. S3 rejects a delete request containing no keys.
    if to_delete:
        s3.delete_objects(Bucket=bucket, Delete={'Objects': to_delete})
    return
def get_s3_file_tree(s3, bucket, prefix, date_cutoff=None, after=True,
                     with_dt=False):
    """Overcome s3 response limit and return NestedDict tree of paths.

    The keys listed under `prefix` are arranged like a directory tree: each
    '/'-separated component of a key, after the components of `prefix` up to
    its last '/', becomes one level of nesting. The innermost node stores the
    listed item under 'key'. For example, 'path/to/file.txt' would be
    retrieved by

        ret['path']['to']['file.txt']['key']

    The NestedDict object returned also has the capability to get paths that
    lead to a certain value. So if you wanted all paths that lead to
    something called 'file.txt', you could use

        ret.get_paths('file.txt')

    For more details, see the NestedDict docs.

    Parameters
    ----------
    s3 : boto3.client.S3
        A boto3.client.S3 instance
    bucket : str
        The name of the bucket to list objects in
    prefix : str
        The prefix filtering of the objects for list
    date_cutoff : str|datetime.datetime
        A datestring of format %Y(-%m-%d-%H-%M-%S) or a datetime.datetime
        object. The date is assumed to be in UTC. By default no filtering
        is done. Default: None.
    after : bool
        If True, only return objects after the given date cutoff.
        Otherwise, return objects before. Default: True
    with_dt : bool
        If True, the 'key' entry of each leaf holds the tuple
        (key, datetime.datetime(LastModified)) instead of just the s3 key.
        Default: False.

    Returns
    -------
    NestedDict
        A file tree represented as an NestedDict
    """
    file_tree = NestedDict()
    # Number of prefix components to skip; the piece after the prefix's last
    # '/' (empty when the prefix ends in '/') is not skipped.
    n_skip = len(prefix.split('/')) - 1
    for item in iter_s3_keys(s3, bucket, prefix, date_cutoff, after, with_dt):
        if with_dt:
            key, _ = item
        else:
            key = item
        node = file_tree
        for component in key.split('/')[n_skip:]:
            node = node[component]
        node['key'] = item
    return file_tree
def get_s3_client(unsigned=True):
    """Return a boto3 S3 client with optional unsigned config.

    Parameters
    ----------
    unsigned : Optional[bool]
        If True, the client will be using unsigned mode in which public
        resources can be accessed without credentials. Default: True

    Returns
    -------
    botocore.client.S3
        A client object to AWS S3.
    """
    client_kwargs = {}
    if unsigned:
        # Anonymous requests: no credentials are needed or sent.
        client_kwargs['config'] = Config(signature_version=UNSIGNED)
    return boto3.client('s3', **client_kwargs)
if __name__ == '__main__':
    import sys

    parser = ArgumentParser(
        'aws.py',
        description=('Use some of INDRA\'s aws tools. For more specific help, '
                     'select one of the Methods with the `-h` option.')
    )
    subparsers = parser.add_subparsers(title='Task', dest='task')
    subparsers.required = True

    # Create parent parser classes for second layer of options
    parent_run_parser = ArgumentParser(add_help=False)
    parent_run_parser.add_argument(
        'command',
        help=('Enter the command as a single string to be run as if in a '
              'batch environment.')
    )
    parent_run_parser.add_argument(
        '--project', '-P',
        default='cwc',
        help='Give a name for the project.'
    )
    parent_run_parser.add_argument(
        '--purpose', '-p',
        help='Give the task some meaning.'
    )

    parent_kill_parser = ArgumentParser(add_help=False)
    parent_kill_parser.add_argument(
        'queue_name',
        help='Select the batch queue in which all jobs should be terminated.'
    )
    # AWS Batch requires a reason string, so never pass None through.
    parent_kill_parser.add_argument(
        '--reason', '-R',
        default='None given',
        help='Give a reason for killing all the jobs.'
    )

    # Sub-command that runs a command inside a batch job.
    run_parser = subparsers.add_parser(
        'run_in_batch',
        parents=[parent_run_parser],
        description=('This should be called to run any command within an aws '
                     'batch job instance.'),
        formatter_class=ArgumentDefaultsHelpFormatter
    )

    # Sub-command that kills the jobs on a queue.
    kill_parser = subparsers.add_parser(
        'kill_all',
        parents=[parent_kill_parser],
        description='Kill all the jobs running in a given queue.',
        formatter_class=ArgumentDefaultsHelpFormatter
    )

    args = parser.parse_args()

    if args.task == 'run_in_batch':
        ret_code = run_in_batch(args.command.split(), args.project,
                                args.purpose)
        # Compare by value: `is` on ints relies on interpreter caching.
        if ret_code == 0:
            logger.info('Job ended well.')
        else:
            logger.error('Job failed!')
        sys.exit(ret_code)
    elif args.task == 'kill_all':
        kill_all(args.queue_name, args.reason)