/scripts/launch.py
Python | 300 lines | 279 code | 16 blank | 5 comment | 6 complexity | 9db2db3367357c96f7507967b89b8a58 MD5 | raw file
Possible License(s): MIT
- import datetime
- import json
- import os
- import click
# Region name -> AMI id used for both master and worker instances.
# main() indexes this by --region_name and raises KeyError for regions
# not listed here. NOTE(review): the boot scripts below edit
# /etc/redis/redis.conf and `systemctl restart redis`, so the AMI is
# presumably expected to have redis preinstalled — confirm when filling in.
AMI_MAP = {
    "us-west-1": "FILL IN YOUR AMI HERE",
}
def highlight(x):
    """Echo *x* to the terminal in green.

    Anything that is not already a string is rendered as key-sorted,
    2-space-indented JSON before printing.
    """
    text = x if isinstance(x, str) else json.dumps(x, sort_keys=True, indent=2)
    click.secho(text, fg='green')
def upload_archive(exp_name, archive_excludes, s3_bucket):
    """Tar this package, upload the archive to S3, and return a presigned URL.

    Parameters
    ----------
    exp_name : str
        Experiment name; embedded in the S3 object key.
    archive_excludes : sequence of str
        Patterns passed to ``tar --exclude``.
    s3_bucket : str
        Destination prefix joined into the remote path (used with ``aws s3 cp``).

    Returns
    -------
    str
        Presigned URL for the uploaded archive, valid for 30 days.

    Raises
    ------
    subprocess.CalledProcessError
        If tar or any aws CLI invocation exits non-zero.
    """
    import hashlib, os.path as osp, subprocess, sys, tempfile, uuid

    # Locate the package root relative to this file; the assert enforces the
    # expected layout <pkg_parent_dir>/<pkg_subdir>/scripts/launch.py, which
    # the tar -C/-h arguments below depend on.
    thisfile_dir = osp.dirname(osp.abspath(__file__))
    pkg_parent_dir = osp.abspath(osp.join(thisfile_dir, '..', '..'))
    pkg_subdir = osp.basename(osp.abspath(osp.join(thisfile_dir, '..')))
    assert osp.abspath(__file__) == osp.join(pkg_parent_dir, pkg_subdir, 'scripts', 'launch.py'), 'You moved me!'

    # FIX: use the context-manager form so the temporary directory is removed
    # even if tar or the aws CLI fails (the original held the object in a local
    # and relied on its finalizer for cleanup).
    with tempfile.TemporaryDirectory() as tmpdir:
        local_archive_path = osp.join(tmpdir, '{}.tar.gz'.format(uuid.uuid4()))

        # Build the tar command: archive pkg_subdir from its parent, honoring
        # exclude patterns; -h follows symlinks into the archive.
        tar_cmd = ["tar", "-zcvf", local_archive_path, "-C", pkg_parent_dir]
        for pattern in archive_excludes:
            tar_cmd += ["--exclude", pattern]
        tar_cmd += ["-h", pkg_subdir]
        highlight(" ".join(tar_cmd))

        if sys.platform == 'darwin':
            # Prevent Mac tar from adding ._* files
            env = os.environ.copy()
            env['COPYFILE_DISABLE'] = '1'
            subprocess.check_call(tar_cmd, env=env)
        else:
            subprocess.check_call(tar_cmd)

        # Content-addressed remote key: re-launching identical code maps to
        # the same S3 object.
        with open(local_archive_path, 'rb') as f:
            archive_hash = hashlib.sha224(f.read()).hexdigest()
        remote_archive_path = '{}/{}_{}.tar.gz'.format(s3_bucket, exp_name, archive_hash)

        # Upload, then presign so instances can download without credentials.
        upload_cmd = ["aws", "s3", "cp", local_archive_path, remote_archive_path]
        highlight(" ".join(upload_cmd))
        subprocess.check_call(upload_cmd)

        presign_cmd = ["aws", "s3", "presign", remote_archive_path,
                       "--expires-in", str(60 * 60 * 24 * 30)]
        highlight(" ".join(presign_cmd))
        remote_url = subprocess.check_output(presign_cmd).decode("utf-8").strip()
        return remote_url
def make_disable_hyperthreading_script():
    """Return a shell snippet that takes every hyperthread sibling CPU offline.

    Technique from https://forums.aws.amazon.com/message.jspa?messageID=189757
    """
    snippet = """
# disable hyperthreading
# https://forums.aws.amazon.com/message.jspa?messageID=189757
for cpunum in $(
cat /sys/devices/system/cpu/cpu*/topology/thread_siblings_list |
sed 's/-/,/g' | cut -s -d, -f2- | tr ',' '\n' | sort -un); do
echo 0 > /sys/devices/system/cpu/cpu$cpunum/online
done
"""
    return snippet
def make_download_and_run_script(code_url, cmd):
    """Return a shell snippet that, as the ubuntu user, downloads the code
    archive from *code_url*, unpacks it in ~, and runs *cmd* from the
    es-distributed directory.
    """
    template = (
        "su -l ubuntu <<'EOF'\n"
        "set -x\n"
        "cd ~\n"
        'wget --quiet "{code_url}" -O code.tar.gz\n'
        "tar xvaf code.tar.gz\n"
        "rm code.tar.gz\n"
        "cd es-distributed\n"
        "{cmd}\n"
        "EOF\n"
    )
    return template.format(code_url=code_url, cmd=cmd)
def make_master_script(code_url, exp_str):
    """Build the EC2 user-data script for the master instance.

    Writes *exp_str* (a JSON experiment spec) to ~/experiment.json and starts
    the es_distributed master against a redis unix socket. Redis is configured
    with snapshots disabled and a unix socket added (TCP stays enabled so
    workers/relays can connect). All output goes to /home/ubuntu/user_data.log.
    """
    # Note: the trailing backslashes are Python line continuations inside the
    # string literal, so the master invocation is emitted as a single line.
    master_cmd = """
cat > ~/experiment.json <<< '{exp_str}'
python -m es_distributed.main master \
--master_socket_path /var/run/redis/redis.sock \
--log_dir ~ \
--exp_file ~/experiment.json
""".format(exp_str=exp_str)
    # %-formatting (not str.format) because the shell script contains literal
    # { and } braces.
    template = """#!/bin/bash
{
set -x
%s
# Disable redis snapshots
echo 'save ""' >> /etc/redis/redis.conf
# Make the unix domain socket available for the master client
# (TCP is still enabled for workers/relays)
echo "unixsocket /var/run/redis/redis.sock" >> /etc/redis/redis.conf
echo "unixsocketperm 777" >> /etc/redis/redis.conf
mkdir -p /var/run/redis
chown ubuntu:ubuntu /var/run/redis
systemctl restart redis
%s
} >> /home/ubuntu/user_data.log 2>&1
"""
    return template % (make_disable_hyperthreading_script(),
                       make_download_and_run_script(code_url, master_cmd))
def make_worker_script(code_url, master_private_ip):
    """Build the EC2 user-data script for a worker instance.

    Workers run single-threaded BLAS (MKL/OpenBLAS/OMP pinned to 1 thread) and
    connect to the master at *master_private_ip*; the local redis relay is
    switched from TCP to a unix domain socket. All output goes to
    /home/ubuntu/user_data.log.
    """
    worker_cmd = ("MKL_NUM_THREADS=1 OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 "
                  "python -m es_distributed.main workers "
                  "--master_host {} "
                  "--relay_socket_path /var/run/redis/redis.sock").format(master_private_ip)
    # %-formatting (not str.format) because the shell script contains literal
    # { and } braces.
    template = """#!/bin/bash
{
set -x
%s
# Disable redis snapshots
echo 'save ""' >> /etc/redis/redis.conf
# Make redis use a unix domain socket and disable TCP sockets
sed -ie "s/port 6379/port 0/" /etc/redis/redis.conf
echo "unixsocket /var/run/redis/redis.sock" >> /etc/redis/redis.conf
echo "unixsocketperm 777" >> /etc/redis/redis.conf
mkdir -p /var/run/redis
chown ubuntu:ubuntu /var/run/redis
systemctl restart redis
%s
} >> /home/ubuntu/user_data.log 2>&1
"""
    return template % (make_disable_hyperthreading_script(),
                       make_download_and_run_script(code_url, worker_cmd))
@click.command()
@click.argument('exp_files', nargs=-1, type=click.Path(), required=True)
@click.option('--key_name', default=lambda: os.environ["KEY_NAME"])
@click.option('--aws_access_key_id', default=os.environ.get("AWS_ACCESS_KEY", None))
@click.option('--aws_secret_access_key', default=os.environ.get("AWS_ACCESS_SECRET", None))
@click.option('--archive_excludes', default=(".git", "__pycache__", ".idea", "scratch"))
@click.option('--s3_bucket')
@click.option('--spot_price')
@click.option('--region_name')
@click.option('--zone')
@click.option('--cluster_size', type=int, default=1)
@click.option('--spot_master', is_flag=True, help='Use a spot instance as the master')
@click.option('--master_instance_type')
@click.option('--worker_instance_type')
@click.option('--security_group')
@click.option('--yes', is_flag=True, help='Skip confirmation prompt')
def main(exp_files,
         key_name,
         aws_access_key_id,
         aws_secret_access_key,
         archive_excludes,
         s3_bucket,
         spot_price,
         region_name,
         zone,
         cluster_size,
         spot_master,
         master_instance_type,
         worker_instance_type,
         security_group,
         yes
         ):
    """Launch one cluster per experiment JSON file.

    For each file in EXP_FILES: upload the code archive to S3, start a master
    instance (on-demand by default, spot with --spot_master), then create a
    spot launch configuration plus an autoscaling group of --cluster_size
    workers pointed at the master's private IP. Tags every resource with the
    experiment prefix/name for later identification.
    """
    highlight('Launching:')
    highlight(locals())
    # Deferred import: boto3 is only needed once we actually launch.
    import boto3
    ec2 = boto3.resource(
        "ec2",
        region_name=region_name,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key
    )
    as_client = boto3.client(
        'autoscaling',
        region_name=region_name,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key
    )
    for i_exp_file, exp_file in enumerate(exp_files):
        with open(exp_file, 'r') as f:
            exp = json.loads(f.read())
        highlight('Experiment [{}/{}]:'.format(i_exp_file + 1, len(exp_files)))
        highlight(exp)
        # One confirmation per experiment unless --yes was given.
        if not yes:
            click.confirm('Continue?', abort=True)
        exp_prefix = exp['exp_prefix']
        exp_str = json.dumps(exp)
        # Timestamp keeps repeated launches of the same prefix distinct; this
        # name doubles as the launch-config and ASG name below, which must be
        # unique per region.
        exp_name = '{}_{}'.format(exp_prefix, datetime.datetime.now().strftime('%Y%m%d-%H%M%S'))
        code_url = upload_archive(exp_name, archive_excludes, s3_bucket)
        highlight("code_url: " + code_url)
        # Raises KeyError if --region_name has no entry in AMI_MAP.
        image_id = AMI_MAP[region_name]
        highlight('Using AMI: {}'.format(image_id))
        if spot_master:
            import base64
            # The spot-request API requires base64-encoded user data (unlike
            # create_instances, which takes the raw string).
            requests = ec2.meta.client.request_spot_instances(
                SpotPrice=spot_price,
                InstanceCount=1,
                LaunchSpecification=dict(
                    ImageId=image_id,
                    KeyName=key_name,
                    InstanceType=master_instance_type,
                    EbsOptimized=True,
                    SecurityGroups=[security_group],
                    Placement=dict(
                        AvailabilityZone=zone,
                    ),
                    UserData=base64.b64encode(make_master_script(code_url, exp_str).encode()).decode()
                )
            )['SpotInstanceRequests']
            assert len(requests) == 1
            request_id = requests[0]['SpotInstanceRequestId']
            # Wait for fulfillment
            highlight('Waiting for spot request {} to be fulfilled'.format(request_id))
            ec2.meta.client.get_waiter('spot_instance_request_fulfilled').wait(SpotInstanceRequestIds=[request_id])
            # Once fulfilled, re-describe the request to get the instance id.
            req = ec2.meta.client.describe_spot_instance_requests(SpotInstanceRequestIds=[request_id])
            master_instance_id = req['SpotInstanceRequests'][0]['InstanceId']
            master_instance = ec2.Instance(master_instance_id)
        else:
            master_instance = ec2.create_instances(
                ImageId=image_id,
                KeyName=key_name,
                InstanceType=master_instance_type,
                EbsOptimized=True,
                SecurityGroups=[security_group],
                MinCount=1,
                MaxCount=1,
                Placement=dict(
                    AvailabilityZone=zone,
                ),
                UserData=make_master_script(code_url, exp_str)
            )[0]
        master_instance.create_tags(
            Tags=[
                dict(Key="Name", Value=exp_name + "-master"),
                dict(Key="es_dist_role", Value="master"),
                dict(Key="exp_prefix", Value=exp_prefix),
                dict(Key="exp_name", Value=exp_name),
            ]
        )
        highlight("Master created. IP: %s" % master_instance.public_ip_address)
        # Workers are always spot instances, managed by an autoscaling group
        # so terminated instances get replaced automatically.
        config_resp = as_client.create_launch_configuration(
            ImageId=image_id,
            KeyName=key_name,
            InstanceType=worker_instance_type,
            EbsOptimized=True,
            SecurityGroups=[security_group],
            LaunchConfigurationName=exp_name,
            UserData=make_worker_script(code_url, master_instance.private_ip_address),
            SpotPrice=spot_price,
        )
        assert config_resp["ResponseMetadata"]["HTTPStatusCode"] == 200
        # Min == Max == Desired pins the group at exactly cluster_size workers.
        asg_resp = as_client.create_auto_scaling_group(
            AutoScalingGroupName=exp_name,
            LaunchConfigurationName=exp_name,
            MinSize=cluster_size,
            MaxSize=cluster_size,
            DesiredCapacity=cluster_size,
            AvailabilityZones=[zone],
            Tags=[
                dict(Key="Name", Value=exp_name + "-worker"),
                dict(Key="es_dist_role", Value="worker"),
                dict(Key="exp_prefix", Value=exp_prefix),
                dict(Key="exp_name", Value=exp_name),
            ]
            # todo: also try placement group to see if there is increased networking performance
        )
        assert asg_resp["ResponseMetadata"]["HTTPStatusCode"] == 200
        highlight("Scaling group created")
        highlight("%s launched successfully." % exp_name)
        highlight("Manage at %s" % (
            "https://%s.console.aws.amazon.com/ec2/v2/home?region=%s#Instances:sort=tag:Name" % (
                region_name, region_name)
        ))
# Script entry point: click parses argv and invokes main().
if __name__ == '__main__':
    main()