import re
import boto3
import logging
import requests
from datetime import datetime, timezone, timedelta
from botocore import UNSIGNED
from botocore.client import Config
from time import sleep
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from indra import get_config, has_config
from indra.util.nested_dict import NestedDict
logger = logging.getLogger(__name__)
def get_ids(job_list):
    """Return the list of job ids from a list of job dicts, if given."""
if job_list is None:
return None
return [job['jobId'] for job in job_list]
def kill_all(job_queue, reason='None given', states=None, kill_list=None):
"""Terminates/cancels all jobs on the specified queue.
Parameters
----------
job_queue : str
The name of the Batch job queue on which you wish to terminate/cancel
jobs.
reason : str
Provide a reason for the kill that will be recorded with the job's
record on AWS.
    states : None or list[str]
        A list of job states in which jobs should be ended. Possible states
        are 'STARTING', 'RUNNABLE', and 'RUNNING'. If None, jobs in all three
        of these states will be ended (modulo the `kill_list` below).
kill_list : None or list[dict]
A list of job dictionaries (as returned by the submit function) that
you specifically wish to kill. All other jobs on the queue will be
ignored. If None, all jobs on the queue will be ended (modulo the
above).
Returns
-------
killed_ids : list[str]
A list of the job ids for jobs that were killed.
"""
# Default is all states.
if states is None:
states = ['STARTING', 'RUNNABLE', 'RUNNING']
# Get batch client
batch = boto3.client('batch')
    # Get the ids of the jobs to kill, if any were specified.
    ids_to_kill = get_ids(kill_list)
    # Get all jobs in the given states, and terminate/cancel them.
    killed_ids = []
    for status in states:
        resp = batch.list_jobs(jobQueue=job_queue, jobStatus=status)
        job_list = resp.get('jobSummaryList')
        if job_list is None:
            continue
        for job in job_list:
            # Skip jobs not among those specified, if a kill_list was given.
            if ids_to_kill is not None and job['jobId'] not in ids_to_kill:
                continue
            # End the job.
            if status == 'RUNNING':
                logger.info('Terminating {jobName} ({jobId})'.format(**job))
                batch.terminate_job(jobId=job['jobId'], reason=reason)
            else:
                logger.info('Canceling {jobName} ({jobId})'.format(**job))
                batch.cancel_job(jobId=job['jobId'], reason=reason)
            # Record the id of the job that was ended.
            killed_ids.append(job['jobId'])
return killed_ids
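# Usage sketch (the queue name below is hypothetical): cancel every RUNNABLE
# job on a queue while leaving jobs that are already running untouched.
#
#     >>> kill_all('run_reach_queue', reason='manual cleanup',
#     ...          states=['RUNNABLE'])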
def tag_instance(instance_id, **tags):
"""Tag a single ec2 instance."""
logger.debug("Got request to add tags %s to instance %s."
% (str(tags), instance_id))
ec2 = boto3.resource('ec2')
instance = ec2.Instance(instance_id)
    # Filter out entries with empty keys or values and reformat to the list
    # of {'Key': ..., 'Value': ...} dicts that boto3 expects.
    tag_list = [{'Key': k, 'Value': v} for k, v in tags.items() if k and v]
if tag_list:
logger.info('Adding project tags "%s" to instance %s'
% (str(tag_list), instance_id))
instance.create_tags(Tags=tag_list)
vols = instance.volumes.all()
for page in vols.pages():
for vol in page:
vol.create_tags(Tags=tag_list)
else:
logger.info('No new tags from: %s' % str(tags))
    instance_tags = {tag.get('Key'): tag.get('Value')
                     for tag in (instance.tags or [])}
logger.info('Updated instance tags: %s' % instance_tags)
return
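# Usage sketch (the instance id and tag values are hypothetical): tag an
# instance and its attached volumes with a project name.
#
#     >>> tag_instance('i-0123456789abcdef0', project='aske',
#     ...              purpose='reading')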
def tag_myself(project='aske', **other_tags):
"""Function run when indra is used in an EC2 instance to apply tags."""
base_url = "http://169.254.169.254"
try:
resp = requests.get(base_url + "/latest/meta-data/instance-id")
except requests.exceptions.ConnectionError:
logger.warning("Could not connect to service. Note this should only "
"be run from within a batch job.")
return
instance_id = resp.text
tag_instance(instance_id, project=project, **other_tags)
return
def get_batch_command(command_list, project=None, purpose=None):
"""Get the command appropriate for running something on batch."""
command_str = ' '.join(command_list)
ret = ['python3', '-m', 'indra.util.aws', 'run_in_batch', command_str]
if not project and has_config('DEFAULT_AWS_PROJECT'):
project = get_config('DEFAULT_AWS_PROJECT')
if project:
ret += ['--project', project]
if purpose:
ret += ['--purpose', purpose]
return ret
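# Usage sketch (the script name is hypothetical): the returned command
# re-enters this module via `run_in_batch`, which tags the instance before
# running the wrapped command.
#
#     >>> get_batch_command(['python3', 'my_script.py', '--opt'],
#     ...                   project='aske', purpose='testing')
#     ['python3', '-m', 'indra.util.aws', 'run_in_batch',
#      'python3 my_script.py --opt', '--project', 'aske', '--purpose', 'testing']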
def run_in_batch(command_list, project, purpose):
    """Tag this instance with the project and purpose, then run the command."""
    from subprocess import call
tag_myself(project, purpose=purpose)
logger.info("Running command list: %s" % str(command_list))
logger.info('\n'+20*'='+' Begin Primary Command Output '+20*'='+'\n')
ret_code = call(command_list)
logger.info('\n'+21*'='+' End Primary Command Output '+21*'='+'\n')
return ret_code
def get_jobs(job_queue='run_reach_queue', job_status='RUNNING'):
"""Returns a list of dicts with jobName and jobId for each job with the
given status."""
batch = boto3.client('batch')
jobs = batch.list_jobs(jobQueue=job_queue, jobStatus=job_status)
return jobs.get('jobSummaryList')
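# Pattern for s3 output locations of the form 's3:bucket-name/key-prefix'
# (note the single colon rather than the usual 's3://'), used by JobLog.dump
# and JobLog.load. The two groups capture the bucket and the key prefix.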
s3_path_patt = re.compile('^s3:([-a-zA-Z0-9_]+)/(.*?)$')
class JobLog(object):
    """Gets the CloudWatch log associated with the given job.
Parameters
----------
job_info : dict
dict containing entries for 'jobName' and 'jobId', e.g., as returned
by get_jobs()
log_group_name : string
Name of the log group; defaults to '/aws/batch/job'
    verbose : bool
        If True, log each line as it is retrieved. Default: False.
    append_dumps : bool
        If True, `dump` writes new numbered 'part' objects to s3 rather than
        overwriting an existing dump, and `load` reads any such parts back.
        Default: True.

    Log lines are accumulated in the `lines` attribute by `get_lines`, with
    the earliest events listed first.
    """
_suffix_base = '/part_'
def __init__(self, job_info, log_group_name='/aws/batch/job',
verbose=False, append_dumps=True):
self.job_name = job_info['jobName']
self.job_id = job_info['jobId']
self.logs_client = boto3.client('logs')
self.verbose = verbose
self.log_group_name = log_group_name
batch = boto3.client('batch')
resp = batch.describe_jobs(jobs=[self.job_id])
job_desc = resp['jobs'][0]
job_def_name = job_desc['jobDefinition'].split('/')[-1].split(':')[0]
task_arn_id = job_desc['container']['taskArn'].split('/')[-1]
self.log_stream_name = '%s/default/%s' % (job_def_name, task_arn_id)
self.latest_timestamp = None
self.lines = []
self.nextToken = None
self.__len = 0
self.append = append_dumps
return
def __len__(self):
return self.__len
def clear_lines(self):
self.lines = []
    def dump(self, out_file, append=None):
"""Dump the logs in their entirety to the specified file."""
if append is None:
append = self.append
elif append != self.append:
logger.info("Overriding default append behavior. This could muddy "
"future loads.")
m = s3_path_patt.match(out_file)
if m is not None:
# If the user wants the files on s3...
bucket, prefix = m.groups()
s3 = boto3.client('s3')
# Find the largest part number among the current suffixes
if append:
max_num = 0
for key in iter_s3_keys(s3, bucket, prefix, do_retry=False):
if key[len(prefix):].startswith(self._suffix_base):
num = int(key[len(prefix + self._suffix_base):])
                        if num > max_num:
                            max_num = num
# Create the new suffix, and dump the lines to s3.
new_suffix = self._suffix_base + str(max_num + 1)
key = prefix + new_suffix
else:
key = prefix
s3.put_object(Bucket=bucket, Key=key, Body=self.dumps())
else:
# Otherwise, if they want them locally...
            with open(out_file, 'a' if append else 'w') as f:
for line in self.lines:
f.write(line)
return
    def load(self, out_file):
"""Load the log lines from the cached files."""
m = s3_path_patt.match(out_file)
if m is not None:
bucket, prefix = m.groups()
s3 = boto3.client('s3')
            if self.append:
                prior_line_bytes = []
                # Load the 'part' keys in numeric order (lexicographic order
                # would put e.g. part_10 before part_2).
                part_keys = [key for key in iter_s3_keys(s3, bucket, prefix)
                             if key[len(prefix):].startswith(self._suffix_base)]
                part_keys.sort(
                    key=lambda k: int(k[len(prefix + self._suffix_base):]))
                for key in part_keys:
                    res = s3.get_object(Bucket=bucket, Key=key)
                    prior_line_bytes += res['Body'].read().splitlines()
else:
res = s3.get_object(Bucket=bucket, Key=prefix)
prior_line_bytes = res['Body'].read().splitlines()
prior_lines = [s.decode('utf-8') + '\n'
for s in prior_line_bytes]
else:
with open(out_file, 'r') as f:
prior_lines = f.readlines()
self.lines = prior_lines + self.lines
return
def dumps(self):
return ''.join(self.lines)
    def get_lines(self):
        """Fetch any new log events from CloudWatch and append them to lines."""
kwargs = {'logGroupName': self.log_group_name,
'logStreamName': self.log_stream_name,
'startFromHead': True}
while True:
if self.nextToken is not None:
kwargs['nextToken'] = self.nextToken
response = self.logs_client.get_log_events(**kwargs)
# If we've gotten all the events already, the nextForwardToken for
# this call will be the same as the last one
if response.get('nextForwardToken') == self.nextToken:
break
else:
events = response.get('events')
if events:
for evt in events:
line = '%s: %s\n' % (evt['timestamp'], evt['message'])
self.lines.append(line)
self.latest_timestamp = \
(datetime.fromtimestamp(evt['timestamp']/1000)
.astimezone(timezone.utc)
.replace(tzinfo=None))
self.__len += 1
if self.verbose:
logger.info('%d %s' % (len(self.lines), line))
self.nextToken = response.get('nextForwardToken')
return
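# Usage sketch (queue and bucket names are hypothetical): fetch the log for a
# running job and dump it, either locally or to s3.
#
#     >>> job = get_jobs('run_reach_queue', 'RUNNING')[0]
#     >>> log = JobLog(job)
#     >>> log.get_lines()
#     >>> log.dump('s3:my-log-bucket/logs/%s' % job['jobName'])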
def dump_logs(job_queue='run_reach_queue', job_status='RUNNING'):
    """Write logs for all jobs with the given status to files."""
jobs = get_jobs(job_queue, job_status)
for job in jobs:
log = JobLog(job)
log.get_lines()
log.dump('{jobName}_{jobId}.log'.format(**job))
def get_date_from_str(date_str):
    """Get a UTC datetime object from a string of format %Y(-%m-%d-%H-%M-%S).
Parameters
----------
date_str : str
A string of the format %Y(-%m-%d-%H-%M-%S). The string is assumed
to represent a UTC time.
Returns
-------
datetime.datetime
"""
date_format = '%Y-%m-%d-%H-%M-%S'
# Pad date_str specifying less than full format
if 1 <= len(date_str.split('-')) < 6:
# Add Jan if not present
if len(date_str.split('-')) == 1:
date_str += '-01'
# Add day after month if not present
if len(date_str.split('-')) == 2:
date_str += '-01'
# Pad with 0 hours, 0 minutes and 0 seconds
while len(date_str.split('-')) < 6:
date_str += '-0'
return datetime.strptime(
date_str, date_format).replace(
tzinfo=timezone.utc)
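# Example of the padding behavior (a sketch): a partial date string is filled
# out to the full format before parsing.
#
#     >>> get_date_from_str('2020-06')
#     datetime.datetime(2020, 6, 1, 0, 0, tzinfo=datetime.timezone.utc)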
def iter_s3_keys(s3, bucket, prefix, date_cutoff=None, after=True,
with_dt=False, do_retry=True):
"""Iterate over the keys in an s3 bucket given a prefix
Parameters
----------
s3 : boto3.client.S3
A boto3.client.S3 instance
bucket : str
The name of the bucket to list objects in
    prefix : str
        The prefix used to filter the objects that are listed.
date_cutoff : str|datetime.datetime
A datestring of format %Y(-%m-%d-%H-%M-%S) or a datetime.datetime
object. The date is assumed to be in UTC. By default no filtering
is done. Default: None.
after : bool
If True, only return objects after the given date cutoff.
Otherwise, return objects before. Default: True
with_dt : bool
If True, yield a tuple (key, datetime.datetime(LastModified)) of
the s3 Key and the object's LastModified date as a
datetime.datetime object, only yield s3 key otherwise.
Default: False.
do_retry : bool
If True, and no contents appear, try again in case there was simply a
brief lag. If False, do not retry, and just accept the "directory" is
empty.
Returns
-------
iterator[key]|iterator[(key, datetime.datetime)]
An iterator over s3 keys or (key, LastModified) tuples.
"""
    if date_cutoff:
        if not isinstance(date_cutoff, datetime):
            date_cutoff = get_date_from_str(date_cutoff)
# Check timezone info
if date_cutoff.utcoffset() is None:
date_cutoff = date_cutoff.replace(tzinfo=timezone.utc)
if date_cutoff.utcoffset() != timedelta():
date_cutoff = date_cutoff.astimezone(timezone.utc)
is_truncated = True
marker = None
while is_truncated:
# Get the (next) batch of contents.
if marker:
resp = s3.list_objects(Bucket=bucket, Prefix=prefix, Marker=marker)
else:
resp = s3.list_objects(Bucket=bucket, Prefix=prefix)
# Handle case where no contents are found.
if not resp.get('Contents'):
if do_retry:
logger.info("Prefix \"%s\" does not seem to have children. "
"Retrying once." % prefix)
do_retry = False
sleep(0.1)
continue
else:
logger.info("No contents found for \"%s\"." % prefix)
break
        # Filter by time.
        for entry in resp['Contents']:
            if entry['Key'] != marker:
                if date_cutoff is None \
                        or (after and entry['LastModified'] > date_cutoff) \
                        or (not after
                            and entry['LastModified'] < date_cutoff):
                    yield (entry['Key'], entry['LastModified']) if with_dt \
                        else entry['Key']
is_truncated = resp['IsTruncated']
marker = entry['Key']
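# Usage sketch (bucket and prefix are hypothetical): list keys modified after
# a given date, together with their LastModified timestamps.
#
#     >>> s3 = boto3.client('s3')
#     >>> for key, last_mod in iter_s3_keys(s3, 'my-bucket', 'reading/',
#     ...                                   date_cutoff='2020-01-01',
#     ...                                   with_dt=True):
#     ...     print(key, last_mod)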
def rename_s3_prefix(s3, bucket, old_prefix, new_prefix):
"""Change an s3 prefix within the same bucket."""
to_delete = []
for key in iter_s3_keys(s3, bucket, old_prefix):
# Copy the object to the new key (with prefix replaced)
new_key = key.replace(old_prefix, new_prefix)
s3.copy_object(Bucket=bucket, Key=new_key,
CopySource={'Bucket': bucket, 'Key': key},
MetadataDirective='COPY',
TaggingDirective='COPY')
# Keep track of the objects that will need to be deleted (the old keys)
to_delete.append({'Key': key})
# Delete objects in maximum batches of 1000.
if len(to_delete) >= 1000:
s3.delete_objects(Bucket=bucket,
Delete={'Objects': to_delete[:1000]})
del to_delete[:1000]
    # Get any stragglers, if there are any left.
    if to_delete:
        s3.delete_objects(Bucket=bucket,
                          Delete={'Objects': to_delete})
return
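# Usage sketch (bucket and prefixes are hypothetical): "move" all objects
# under one prefix to another by copying them and deleting the originals.
#
#     >>> s3 = boto3.client('s3')
#     >>> rename_s3_prefix(s3, 'my-bucket', 'old/location/', 'new/location/')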
def get_s3_file_tree(s3, bucket, prefix, date_cutoff=None, after=True,
with_dt=False):
"""Overcome s3 response limit and return NestedDict tree of paths.
The NestedDict object also allows the user to search by the ends of a path.
    The tree mimics a file directory structure, with the leaf nodes holding
    the full unbroken key. For example, 'path/to/file.txt' would be retrieved
    by ret['path']['to']['file.txt']['key'].
The NestedDict object returned also has the capability to get paths that
lead to a certain value. So if you wanted all paths that lead to something
called 'file.txt', you could use
ret.get_paths('file.txt')
For more details, see the NestedDict docs.
Parameters
----------
s3 : boto3.client.S3
A boto3.client.S3 instance
bucket : str
The name of the bucket to list objects in
    prefix : str
        The prefix used to filter the objects that are listed.
date_cutoff : str|datetime.datetime
A datestring of format %Y(-%m-%d-%H-%M-%S) or a datetime.datetime
object. The date is assumed to be in UTC. By default no filtering
is done. Default: None.
after : bool
If True, only return objects after the given date cutoff.
Otherwise, return objects before. Default: True
with_dt : bool
If True, yield a tuple (key, datetime.datetime(LastModified)) of
the s3 Key and the object's LastModified date as a
datetime.datetime object, only yield s3 key otherwise.
Default: False.
Returns
-------
NestedDict
        A file tree represented as a NestedDict.
"""
file_tree = NestedDict()
pref_path = prefix.split('/')[:-1] # avoid the trailing empty str.
for k in iter_s3_keys(s3, bucket, prefix, date_cutoff, after, with_dt):
if with_dt:
key, dt = k
else:
key, dt = k, None
full_path = key.split('/')
relevant_path = full_path[len(pref_path):]
curr = file_tree
for step in relevant_path:
curr = curr[step]
curr['key'] = k
return file_tree
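# Usage sketch (bucket, prefix, and file names are hypothetical): build the
# tree and look up the full key of a nested file.
#
#     >>> s3 = boto3.client('s3')
#     >>> tree = get_s3_file_tree(s3, 'my-bucket', 'papers/')
#     >>> key = tree['json']['example.json']['key']
#     >>> paths = tree.get_paths('example.json')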
def get_s3_client(unsigned=True):
"""Return a boto3 S3 client with optional unsigned config.
Parameters
----------
unsigned : Optional[bool]
If True, the client will be using unsigned mode in which public
resources can be accessed without credentials. Default: True
Returns
-------
botocore.client.S3
A client object to AWS S3.
"""
if unsigned:
return boto3.client('s3', config=Config(signature_version=UNSIGNED))
else:
return boto3.client('s3')
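# Usage sketch (bucket and key are hypothetical): an unsigned client can read
# public objects without any AWS credentials configured.
#
#     >>> s3 = get_s3_client(unsigned=True)
#     >>> obj = s3.get_object(Bucket='some-public-bucket', Key='some/key')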
if __name__ == '__main__':
parser = ArgumentParser(
'aws.py',
        description=('Use some of INDRA\'s aws tools. For more specific help, '
                     'select one of the tasks below with the `-h` option.')
)
subparsers = parser.add_subparsers(title='Task')
subparsers.required = True
subparsers.dest = 'task'
# Create parent parser classes for second layer of options
parent_run_parser = ArgumentParser(add_help=False)
parent_run_parser.add_argument(
'command',
help=('Enter the command as a single string to be run as if in a '
'batch environment.')
)
parent_run_parser.add_argument(
'--project', '-P',
default='aske',
help='Give a name for the project.'
)
parent_run_parser.add_argument(
'--purpose', '-p',
help='Give the task some meaning.'
)
parent_kill_parser = ArgumentParser(add_help=False)
parent_kill_parser.add_argument(
'queue_name',
help='Select the batch queue in which all jobs should be terminated.'
)
    parent_kill_parser.add_argument(
        '--reason', '-R',
        default='None given',
        help='Give a reason for killing all the jobs.'
    )
    # Make the run_in_batch parser.
run_parser = subparsers.add_parser(
'run_in_batch',
parents=[parent_run_parser],
        description=('This should be called to run any command within an aws '
                     'batch job instance.'),
formatter_class=ArgumentDefaultsHelpFormatter
)
    # Make the kill_all parser.
kill_parser = subparsers.add_parser(
'kill_all',
parents=[parent_kill_parser],
description='Kill all the jobs running in a given queue.',
formatter_class=ArgumentDefaultsHelpFormatter
)
args = parser.parse_args()
if args.task == 'run_in_batch':
ret_code = run_in_batch(args.command.split(), args.project,
args.purpose)
if ret_code == 0:
logger.info('Job ended well.')
else:
logger.error('Job failed!')
import sys
sys.exit(ret_code)
elif args.task == 'kill_all':
kill_all(args.queue_name, args.reason)