PageRenderTime 69ms CodeModel.GetById 30ms RepoModel.GetById 0ms app.codeStats 1ms

/scripts/launch.py

https://bitbucket.org/JoseETeixeira/evolution-strategies-starter
Python | 300 lines | 279 code | 16 blank | 5 comment | 6 complexity | 9db2db3367357c96f7507967b89b8a58 MD5 | raw file
Possible License(s): MIT
  1. import datetime
  2. import json
  3. import os
  4. import click
# Region -> AMI id for the machine image used by both master and workers.
# The value below is a placeholder: build/register your own AMI and fill it
# in (and add entries for any other regions you launch in) before running.
AMI_MAP = {
    "us-west-1": "FILL IN YOUR AMI HERE",
}
  8. def highlight(x):
  9. if not isinstance(x, str):
  10. x = json.dumps(x, sort_keys=True, indent=2)
  11. click.secho(x, fg='green')
  12. def upload_archive(exp_name, archive_excludes, s3_bucket):
  13. import hashlib, os.path as osp, subprocess, tempfile, uuid, sys
  14. # Archive this package
  15. thisfile_dir = osp.dirname(osp.abspath(__file__))
  16. pkg_parent_dir = osp.abspath(osp.join(thisfile_dir, '..', '..'))
  17. pkg_subdir = osp.basename(osp.abspath(osp.join(thisfile_dir, '..')))
  18. assert osp.abspath(__file__) == osp.join(pkg_parent_dir, pkg_subdir, 'scripts', 'launch.py'), 'You moved me!'
  19. # Run tar
  20. tmpdir = tempfile.TemporaryDirectory()
  21. local_archive_path = osp.join(tmpdir.name, '{}.tar.gz'.format(uuid.uuid4()))
  22. tar_cmd = ["tar", "-zcvf", local_archive_path, "-C", pkg_parent_dir]
  23. for pattern in archive_excludes:
  24. tar_cmd += ["--exclude", pattern]
  25. tar_cmd += ["-h", pkg_subdir]
  26. highlight(" ".join(tar_cmd))
  27. if sys.platform == 'darwin':
  28. # Prevent Mac tar from adding ._* files
  29. env = os.environ.copy()
  30. env['COPYFILE_DISABLE'] = '1'
  31. subprocess.check_call(tar_cmd, env=env)
  32. else:
  33. subprocess.check_call(tar_cmd)
  34. # Construct remote path to place the archive on S3
  35. with open(local_archive_path, 'rb') as f:
  36. archive_hash = hashlib.sha224(f.read()).hexdigest()
  37. remote_archive_path = '{}/{}_{}.tar.gz'.format(s3_bucket, exp_name, archive_hash)
  38. # Upload
  39. upload_cmd = ["aws", "s3", "cp", local_archive_path, remote_archive_path]
  40. highlight(" ".join(upload_cmd))
  41. subprocess.check_call(upload_cmd)
  42. presign_cmd = ["aws", "s3", "presign", remote_archive_path, "--expires-in", str(60 * 60 * 24 * 30)]
  43. highlight(" ".join(presign_cmd))
  44. remote_url = subprocess.check_output(presign_cmd).decode("utf-8").strip()
  45. return remote_url
  46. def make_disable_hyperthreading_script():
  47. return """
  48. # disable hyperthreading
  49. # https://forums.aws.amazon.com/message.jspa?messageID=189757
  50. for cpunum in $(
  51. cat /sys/devices/system/cpu/cpu*/topology/thread_siblings_list |
  52. sed 's/-/,/g' | cut -s -d, -f2- | tr ',' '\n' | sort -un); do
  53. echo 0 > /sys/devices/system/cpu/cpu$cpunum/online
  54. done
  55. """
  56. def make_download_and_run_script(code_url, cmd):
  57. return """su -l ubuntu <<'EOF'
  58. set -x
  59. cd ~
  60. wget --quiet "{code_url}" -O code.tar.gz
  61. tar xvaf code.tar.gz
  62. rm code.tar.gz
  63. cd es-distributed
  64. {cmd}
  65. EOF
  66. """.format(code_url=code_url, cmd=cmd)
def make_master_script(code_url, exp_str):
    """Build the EC2 user-data (bash) script for the master instance.

    The script disables hyperthreading, configures redis to also listen on a
    unix domain socket (TCP stays enabled for workers/relays), then downloads
    the code archive and starts the master process.

    Args:
        code_url: presigned URL of the code archive (see upload_archive).
        exp_str: experiment config as a JSON string; written to
            ~/experiment.json on the instance.

    Returns:
        The complete user-data script as a string.
    """
    # NOTE: the trailing backslashes below are consumed by Python as string
    # line continuations, so the rendered shell command is one long line.
    cmd = """
cat > ~/experiment.json <<< '{exp_str}'
python -m es_distributed.main master \
--master_socket_path /var/run/redis/redis.sock \
--log_dir ~ \
--exp_file ~/experiment.json
""".format(exp_str=exp_str)
    # %-formatting (not str.format) because the template contains literal
    # bash braces { } that str.format would misinterpret as placeholders.
    return """#!/bin/bash
{
set -x
%s
# Disable redis snapshots
echo 'save ""' >> /etc/redis/redis.conf
# Make the unix domain socket available for the master client
# (TCP is still enabled for workers/relays)
echo "unixsocket /var/run/redis/redis.sock" >> /etc/redis/redis.conf
echo "unixsocketperm 777" >> /etc/redis/redis.conf
mkdir -p /var/run/redis
chown ubuntu:ubuntu /var/run/redis
systemctl restart redis
%s
} >> /home/ubuntu/user_data.log 2>&1
""" % (make_disable_hyperthreading_script(), make_download_and_run_script(code_url, cmd))
def make_worker_script(code_url, master_private_ip):
    """Build the EC2 user-data (bash) script for a worker instance.

    The worker process is launched with MKL/OpenBLAS/OMP pinned to one thread
    each, connects to the master at *master_private_ip*, and relays through a
    local redis unix socket (the local redis TCP port is sed'ed to 0, i.e.
    disabled).

    Args:
        code_url: presigned URL of the code archive (see upload_archive).
        master_private_ip: private IP of the master instance, passed to the
            worker process as --master_host.

    Returns:
        The complete user-data script as a string.
    """
    cmd = ("MKL_NUM_THREADS=1 OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 "
           "python -m es_distributed.main workers "
           "--master_host {} "
           "--relay_socket_path /var/run/redis/redis.sock").format(master_private_ip)
    # %-formatting (not str.format) because the template contains literal
    # bash braces { } that str.format would misinterpret as placeholders.
    return """#!/bin/bash
{
set -x
%s
# Disable redis snapshots
echo 'save ""' >> /etc/redis/redis.conf
# Make redis use a unix domain socket and disable TCP sockets
sed -ie "s/port 6379/port 0/" /etc/redis/redis.conf
echo "unixsocket /var/run/redis/redis.sock" >> /etc/redis/redis.conf
echo "unixsocketperm 777" >> /etc/redis/redis.conf
mkdir -p /var/run/redis
chown ubuntu:ubuntu /var/run/redis
systemctl restart redis
%s
} >> /home/ubuntu/user_data.log 2>&1
""" % (make_disable_hyperthreading_script(), make_download_and_run_script(code_url, cmd))
@click.command()
@click.argument('exp_files', nargs=-1, type=click.Path(), required=True)
@click.option('--key_name', default=lambda: os.environ["KEY_NAME"])
@click.option('--aws_access_key_id', default=os.environ.get("AWS_ACCESS_KEY", None))
@click.option('--aws_secret_access_key', default=os.environ.get("AWS_ACCESS_SECRET", None))
@click.option('--archive_excludes', default=(".git", "__pycache__", ".idea", "scratch"))
@click.option('--s3_bucket')
@click.option('--spot_price')
@click.option('--region_name')
@click.option('--zone')
@click.option('--cluster_size', type=int, default=1)
@click.option('--spot_master', is_flag=True, help='Use a spot instance as the master')
@click.option('--master_instance_type')
@click.option('--worker_instance_type')
@click.option('--security_group')
@click.option('--yes', is_flag=True, help='Skip confirmation prompt')
def main(exp_files,
         key_name,
         aws_access_key_id,
         aws_secret_access_key,
         archive_excludes,
         s3_bucket,
         spot_price,
         region_name,
         zone,
         cluster_size,
         spot_master,
         master_instance_type,
         worker_instance_type,
         security_group,
         yes
         ):
    """Launch one EC2 cluster per experiment JSON file.

    For each exp_file this uploads a code archive to S3, launches a master
    instance (on-demand, or spot when --spot_master is set), then creates a
    launch configuration plus an auto scaling group of --cluster_size spot
    workers whose user-data points at the master's private IP.
    """
    highlight('Launching:')
    highlight(locals())
    import boto3
    # One EC2 resource and one autoscaling client, reused for all experiments.
    ec2 = boto3.resource(
        "ec2",
        region_name=region_name,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key
    )
    as_client = boto3.client(
        'autoscaling',
        region_name=region_name,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key
    )
    for i_exp_file, exp_file in enumerate(exp_files):
        with open(exp_file, 'r') as f:
            exp = json.loads(f.read())
        highlight('Experiment [{}/{}]:'.format(i_exp_file + 1, len(exp_files)))
        highlight(exp)
        if not yes:
            click.confirm('Continue?', abort=True)
        exp_prefix = exp['exp_prefix']
        exp_str = json.dumps(exp)
        # Timestamped name identifies this launch; reused for the S3 archive
        # key, instance tags, and the launch-config / ASG names below.
        exp_name = '{}_{}'.format(exp_prefix, datetime.datetime.now().strftime('%Y%m%d-%H%M%S'))
        code_url = upload_archive(exp_name, archive_excludes, s3_bucket)
        highlight("code_url: " + code_url)
        image_id = AMI_MAP[region_name]
        highlight('Using AMI: {}'.format(image_id))
        if spot_master:
            import base64
            # Spot launch specifications take user data base64-encoded.
            requests = ec2.meta.client.request_spot_instances(
                SpotPrice=spot_price,
                InstanceCount=1,
                LaunchSpecification=dict(
                    ImageId=image_id,
                    KeyName=key_name,
                    InstanceType=master_instance_type,
                    EbsOptimized=True,
                    SecurityGroups=[security_group],
                    Placement=dict(
                        AvailabilityZone=zone,
                    ),
                    UserData=base64.b64encode(make_master_script(code_url, exp_str).encode()).decode()
                )
            )['SpotInstanceRequests']
            assert len(requests) == 1
            request_id = requests[0]['SpotInstanceRequestId']
            # Wait for fulfillment: blocks until AWS grants the spot request,
            # then resolve the request to a concrete instance.
            highlight('Waiting for spot request {} to be fulfilled'.format(request_id))
            ec2.meta.client.get_waiter('spot_instance_request_fulfilled').wait(SpotInstanceRequestIds=[request_id])
            req = ec2.meta.client.describe_spot_instance_requests(SpotInstanceRequestIds=[request_id])
            master_instance_id = req['SpotInstanceRequests'][0]['InstanceId']
            master_instance = ec2.Instance(master_instance_id)
        else:
            # On-demand master: user data is passed as plain text here.
            master_instance = ec2.create_instances(
                ImageId=image_id,
                KeyName=key_name,
                InstanceType=master_instance_type,
                EbsOptimized=True,
                SecurityGroups=[security_group],
                MinCount=1,
                MaxCount=1,
                Placement=dict(
                    AvailabilityZone=zone,
                ),
                UserData=make_master_script(code_url, exp_str)
            )[0]
        master_instance.create_tags(
            Tags=[
                dict(Key="Name", Value=exp_name + "-master"),
                dict(Key="es_dist_role", Value="master"),
                dict(Key="exp_prefix", Value=exp_prefix),
                dict(Key="exp_name", Value=exp_name),
            ]
        )
        highlight("Master created. IP: %s" % master_instance.public_ip_address)
        # Workers are always spot instances, managed by an auto scaling group
        # whose min == max == desired == cluster_size (fixed-size group).
        config_resp = as_client.create_launch_configuration(
            ImageId=image_id,
            KeyName=key_name,
            InstanceType=worker_instance_type,
            EbsOptimized=True,
            SecurityGroups=[security_group],
            LaunchConfigurationName=exp_name,
            UserData=make_worker_script(code_url, master_instance.private_ip_address),
            SpotPrice=spot_price,
        )
        assert config_resp["ResponseMetadata"]["HTTPStatusCode"] == 200
        asg_resp = as_client.create_auto_scaling_group(
            AutoScalingGroupName=exp_name,
            LaunchConfigurationName=exp_name,
            MinSize=cluster_size,
            MaxSize=cluster_size,
            DesiredCapacity=cluster_size,
            AvailabilityZones=[zone],
            Tags=[
                dict(Key="Name", Value=exp_name + "-worker"),
                dict(Key="es_dist_role", Value="worker"),
                dict(Key="exp_prefix", Value=exp_prefix),
                dict(Key="exp_name", Value=exp_name),
            ]
            # todo: also try placement group to see if there is increased networking performance
        )
        assert asg_resp["ResponseMetadata"]["HTTPStatusCode"] == 200
        highlight("Scaling group created")
        highlight("%s launched successfully." % exp_name)
        highlight("Manage at %s" % (
            "https://%s.console.aws.amazon.com/ec2/v2/home?region=%s#Instances:sort=tag:Name" % (
                region_name, region_name)
        ))
if __name__ == '__main__':
    # click parses command-line arguments and dispatches into main().
    main()