PageRenderTime 296ms CodeModel.GetById 19ms RepoModel.GetById 0ms app.codeStats 1ms

/Automation/psi_ops.py

https://bitbucket.org/psiphon/psiphon-circumvention-system/
Python | 1413 lines | 1290 code | 66 blank | 57 comment | 71 complexity | 45d505d0687721a857f468601a6066ad MD5 | raw file
Possible License(s): GPL-3.0
  1. #!/usr/bin/python
  2. #
  3. # Copyright (c) 2011, Psiphon Inc.
  4. # All rights reserved.
  5. #
  6. # This program is free software: you can redistribute it and/or modify
  7. # it under the terms of the GNU General Public License as published by
  8. # the Free Software Foundation, either version 3 of the License, or
  9. # (at your option) any later version.
  10. #
  11. # This program is distributed in the hope that it will be useful,
  12. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  13. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  14. # GNU General Public License for more details.
  15. #
  16. # You should have received a copy of the GNU General Public License
  17. # along with this program. If not, see <http://www.gnu.org/licenses/>.
  18. #
  19. import os
  20. import time
  21. import datetime
  22. import pprint
  23. import json
  24. import textwrap
  25. import itertools
  26. import binascii
  27. import base64
  28. import jsonpickle
  29. import tempfile
  30. import pprint
  31. import struct
  32. import socket
  33. import random
  34. import optparse
  35. from pkg_resources import parse_version
  36. import psi_utils
  37. import psi_ops_cms
  38. # Modules available only on the automation server
  39. try:
  40. import psi_ssh
  41. except ImportError as error:
  42. print error
  43. try:
  44. import psi_linode
  45. except ImportError as error:
  46. print error
  47. try:
  48. import psi_elastichosts
  49. except ImportError as error:
  50. print error
  51. try:
  52. import psi_templates
  53. except ImportError as error:
  54. print error
  55. try:
  56. import psi_ops_s3
  57. except ImportError as error:
  58. print error
  59. try:
  60. import psi_ops_install
  61. except ImportError as error:
  62. print error
  63. try:
  64. import psi_ops_deploy
  65. except ImportError as error:
  66. print error
  67. try:
  68. import psi_ops_build
  69. except ImportError as error:
  70. print error
  71. try:
  72. import psi_ops_test
  73. except ImportError as error:
  74. print error
  75. try:
  76. import psi_ops_twitter
  77. except ImportError as error:
  78. print error
  79. try:
  80. import psi_routes
  81. except ImportError as error:
  82. print error
  83. # Modules available only on the node server
  84. try:
  85. import GeoIP
  86. except ImportError:
  87. pass
  88. # NOTE: update compartmentalize() functions when adding fields
  89. PropagationChannel = psi_utils.recordtype(
  90. 'PropagationChannel',
  91. 'id, name, propagation_mechanism_types')
  92. PropagationMechanism = psi_utils.recordtype(
  93. 'PropagationMechanism',
  94. 'type')
  95. TwitterPropagationAccount = psi_utils.recordtype(
  96. 'TwitterPropagationAccount',
  97. 'name, consumer_key, consumer_secret, access_token_key, access_token_secret')
  98. EmailPropagationAccount = psi_utils.recordtype(
  99. 'EmailPropagationAccount',
  100. 'email_address')
  101. Sponsor = psi_utils.recordtype(
  102. 'Sponsor',
  103. 'id, name, banner, home_pages, campaigns, page_view_regexes, https_request_regexes')
  104. SponsorHomePage = psi_utils.recordtype(
  105. 'SponsorHomePage',
  106. 'region, url')
  107. SponsorCampaign = psi_utils.recordtype(
  108. 'SponsorCampaign',
  109. 'propagation_channel_id, propagation_mechanism_type, account, s3_bucket_name')
  110. SponsorRegex = psi_utils.recordtype(
  111. 'SponsorRegex',
  112. 'regex, replace')
  113. Host = psi_utils.recordtype(
  114. 'Host',
  115. 'id, provider, provider_id, ip_address, ssh_port, ssh_username, ssh_password, ssh_host_key, '+
  116. 'stats_ssh_username, stats_ssh_password')
  117. Server = psi_utils.recordtype(
  118. 'Server',
  119. 'id, host_id, ip_address, egress_ip_address, '+
  120. 'propagation_channel_id, is_embedded, discovery_date_range, '+
  121. 'web_server_port, web_server_secret, web_server_certificate, web_server_private_key, '+
  122. 'ssh_port, ssh_username, ssh_password, ssh_host_key, ssh_obfuscated_port, ssh_obfuscated_key',
  123. default=None)
  124. ClientVersion = psi_utils.recordtype(
  125. 'ClientVersion',
  126. 'version, description')
  127. AwsAccount = psi_utils.recordtype(
  128. 'AwsAccount',
  129. 'access_id, secret_key',
  130. default=None)
  131. ProviderRank = psi_utils.recordtype(
  132. 'ProviderRank',
  133. 'provider, rank',
  134. default=None)
  135. ProviderRank.provider_values = ('linode', 'elastichosts')
  136. LinodeAccount = psi_utils.recordtype(
  137. 'LinodeAccount',
  138. 'api_key, base_id, base_ip_address, base_ssh_port, '+
  139. 'base_root_password, base_stats_username, base_host_public_key, '+
  140. 'base_known_hosts_entry, base_rsa_private_key, base_rsa_public_key, '+
  141. 'base_tarball_path',
  142. default=None)
  143. ElasticHostsAccount = psi_utils.recordtype(
  144. 'ElasticHostsAccount',
  145. 'zone, uuid, api_key, base_drive_id, cpu, mem, base_host_public_key, '+
  146. 'root_username, base_root_password, base_ssh_port, stats_username, rank',
  147. default=None)
  148. ElasticHostsAccount.zone_values = ('ELASTICHOSTS_US1', # sat-p
  149. 'ELASTICHOSTS_UK1', # lon-p
  150. 'ELASTICHOSTS_UK2') # lon-b
  151. EmailServerAccount = psi_utils.recordtype(
  152. 'EmailServerAccount',
  153. 'ip_address, ssh_port, ssh_username, ssh_pkey, ssh_host_key, '+
  154. 'config_file_path',
  155. default=None)
  156. StatsServerAccount = psi_utils.recordtype(
  157. 'StatsServerAccount',
  158. 'ip_address, ssh_port, ssh_username, ssh_password, ssh_host_key',
  159. default=None)
  160. SpeedTestURL = psi_utils.recordtype(
  161. 'SpeedTestURL',
  162. 'server_address, server_port, request_path')
  163. class PsiphonNetwork(psi_ops_cms.PersistentObject):
  164. def __init__(self):
  165. super(PsiphonNetwork, self).__init__()
  166. # TODO: what is this __version for?
  167. self.__version = '1.0'
  168. self.__sponsors = {}
  169. self.__propagation_mechanisms = {
  170. 'twitter' : PropagationMechanism('twitter'),
  171. 'email-autoresponder' : PropagationMechanism('email-autoresponder'),
  172. 'static-download' : PropagationMechanism('static-download')
  173. }
  174. self.__propagation_channels = {}
  175. self.__hosts = {}
  176. self.__servers = {}
  177. self.__client_versions = []
  178. self.__email_server_account = EmailServerAccount()
  179. self.__stats_server_account = StatsServerAccount()
  180. self.__aws_account = AwsAccount()
  181. self.__provider_ranks = []
  182. self.__linode_account = LinodeAccount()
  183. self.__elastichosts_accounts = []
  184. self.__deploy_implementation_required_for_hosts = set()
  185. self.__deploy_data_required_for_all = False
  186. self.__deploy_builds_required_for_campaigns = set()
  187. self.__deploy_stats_config_required = False
  188. self.__deploy_email_config_required = False
  189. self.__speed_test_urls = []
  190. class_version = '0.5'
  191. def upgrade(self):
  192. if cmp(parse_version(self.version), parse_version('0.1')) < 0:
  193. self.__provider_ranks = []
  194. self.__elastichosts_accounts = []
  195. self.version = '0.1'
  196. if cmp(parse_version(self.version), parse_version('0.2')) < 0:
  197. for server in self.__servers.itervalues():
  198. server.ssh_obfuscated_port = None
  199. server.ssh_obfuscated_key = None
  200. self.version = '0.2'
  201. if cmp(parse_version(self.version), parse_version('0.3')) < 0:
  202. for host in self.__hosts.itervalues():
  203. host.provider = None
  204. self.version = '0.3'
  205. if cmp(parse_version(self.version), parse_version('0.4')) < 0:
  206. for sponsor in self.__sponsors.itervalues():
  207. sponsor.page_view_regexes = []
  208. sponsor.https_request_regexes = []
  209. self.version = '0.4'
  210. if cmp(parse_version(self.version), parse_version('0.5')) < 0:
  211. self.__speed_test_urls = []
  212. self.version = '0.5'
  213. def show_status(self):
  214. # NOTE: verbose mode prints credentials to stdout
  215. print textwrap.dedent('''
  216. Sponsors: %d
  217. Channels: %d
  218. Twitter Campaigns: %d
  219. Email Campaigns: %d
  220. Hosts: %d
  221. Servers: %d
  222. Email Server: %s
  223. Stats Server: %s
  224. Client Version: %s %s
  225. AWS Account: %s
  226. Provider Ranks: %s
  227. Linode Account: %s
  228. ElasticHosts Account: %s
  229. Deploys Pending: Host Implementations %d
  230. Host Data %s
  231. Campaign Builds %d
  232. Stats Server Config %s
  233. Email Server Config %s
  234. ''') % (
  235. len(self.__sponsors),
  236. len(self.__propagation_channels),
  237. sum([len(filter(lambda x:x.propagation_mechanism_type == 'twitter', sponsor.campaigns))
  238. for sponsor in self.__sponsors.itervalues()]),
  239. sum([len(filter(lambda x:x.propagation_mechanism_type == 'email-autoresponder', sponsor.campaigns))
  240. for sponsor in self.__sponsors.itervalues()]),
  241. len(self.__hosts),
  242. len(self.__servers),
  243. self.__email_server_account.ip_address if self.__email_server_account else 'None',
  244. self.__stats_server_account.ip_address if self.__stats_server_account else 'None',
  245. self.__client_versions[-1].version if self.__client_versions else 'None',
  246. self.__client_versions[-1].description if self.__client_versions else '',
  247. 'Configured' if self.__aws_account.access_id else 'None',
  248. 'Configured' if self.__provider_ranks else 'None',
  249. 'Configured' if self.__linode_account.api_key else 'None',
  250. 'Configured' if self.__elastichosts_accounts else 'None',
  251. len(self.__deploy_implementation_required_for_hosts),
  252. 'Yes' if self.__deploy_data_required_for_all else 'No',
  253. len(self.__deploy_builds_required_for_campaigns),
  254. 'Yes' if self.__deploy_stats_config_required else 'No',
  255. 'Yes' if self.__deploy_email_config_required else 'No')
  256. def __show_logs(self, obj):
  257. for timestamp, message in obj.get_logs():
  258. print '%s: %s' % (timestamp.isoformat(), message)
  259. print ''
  260. def show_sponsors(self):
  261. for s in self.__sponsors.itervalues():
  262. self.show_sponsor(s.name)
  263. def show_sponsor(self, sponsor_name):
  264. s = self.__get_sponsor_by_name(sponsor_name)
  265. print textwrap.dedent('''
  266. ID: %(id)s
  267. Name: %(name)s
  268. Home Pages: %(home_pages)s
  269. Page View Regexes: %(page_view_regexes)s
  270. HTTPS Request Regexes: %(https_request_regexes)s
  271. Campaigns: %(campaigns)s
  272. ''') % {
  273. 'id': s.id,
  274. 'name': s.name,
  275. 'home_pages': '\n '.join(['%s: %s' % (region.ljust(5) if region else 'All',
  276. '\n '.join([h.url for h in home_pages]))
  277. for region, home_pages in sorted(s.home_pages.items())]),
  278. 'page_view_regexes': '\n '.join(['%s -> %s' % (page_view_regex.regex, page_view_regex.replace)
  279. for page_view_regex in s.page_view_regexes]),
  280. 'https_request_regexes': '\n '.join(['%s -> %s' % (https_request_regex.regex, https_request_regex.replace)
  281. for https_request_regex in s.https_request_regexes]),
  282. 'campaigns': '\n '.join(['%s %s %s %s' % (
  283. self.__propagation_channels[c.propagation_channel_id].name,
  284. c.propagation_mechanism_type,
  285. c.account[0] if c.account else 'None',
  286. c.s3_bucket_name)
  287. for c in s.campaigns])
  288. }
  289. self.__show_logs(s)
  290. def show_campaigns_on_propagation_channel(self, propagation_channel_name):
  291. propagation_channel = self.__get_propagation_channel_by_name(propagation_channel_name)
  292. for sponsor in self.__sponsors.itervalues():
  293. for campaign in sponsor.campaigns:
  294. if campaign.propagation_channel_id == propagation_channel.id:
  295. print textwrap.dedent('''
  296. Sponsor: %s
  297. Propagation Mechanism: %s
  298. Account: %s
  299. Bucket Name: %s''') % (
  300. sponsor.name,
  301. campaign.propagation_mechanism_type,
  302. campaign.account[0] if campaign.account else 'None',
  303. campaign.s3_bucket_name)
  304. def show_propagation_channels(self):
  305. for p in self.__propagation_channels.itervalues():
  306. self.show_propagation_channel(p.name)
  307. def show_propagation_channel(self, propagation_channel_name, now=None):
  308. if now == None:
  309. now = datetime.datetime.now()
  310. p = self.__get_propagation_channel_by_name(propagation_channel_name)
  311. embedded_servers = [server.id for server in self.__servers.itervalues()
  312. if server.propagation_channel_id == p.id and server.is_embedded]
  313. old_propagation_servers = [server.id for server in self.__servers.itervalues()
  314. if server.propagation_channel_id == p.id and
  315. not server.is_embedded and not server.discovery_date_range]
  316. current_discovery_servers = ['%s - %s : %s' % (server.discovery_date_range[0].isoformat(),
  317. server.discovery_date_range[1].isoformat(),
  318. server.id)
  319. for server in self.__servers.itervalues()
  320. if server.propagation_channel_id == p.id and server.discovery_date_range and
  321. (server.discovery_date_range[0] <= now < server.discovery_date_range[1])]
  322. current_discovery_servers.sort()
  323. future_discovery_servers = ['%s - %s : %s' % (server.discovery_date_range[0].isoformat(),
  324. server.discovery_date_range[1].isoformat(),
  325. server.id)
  326. for server in self.__servers.itervalues()
  327. if server.propagation_channel_id == p.id and server.discovery_date_range and
  328. server.discovery_date_range[0] > now]
  329. future_discovery_servers.sort()
  330. old_discovery_servers = ['%s - %s : %s' % (server.discovery_date_range[0].isoformat(),
  331. server.discovery_date_range[1].isoformat(),
  332. server.id)
  333. for server in self.__servers.itervalues()
  334. if server.propagation_channel_id == p.id and server.discovery_date_range and
  335. now >= server.discovery_date_range[1]]
  336. old_discovery_servers.sort()
  337. print textwrap.dedent('''
  338. ID: %s
  339. Name: %s
  340. Propagation Mechanisms: %s
  341. Embedded Servers: %s
  342. Discovery Servers: %s
  343. Future Discovery Servers: %s
  344. Old Propagation Servers: %s
  345. Old Discovery Servers: %s
  346. ''') % (
  347. p.id,
  348. p.name,
  349. '\n '.join(p.propagation_mechanism_types),
  350. '\n '.join(embedded_servers),
  351. '\n '.join(current_discovery_servers),
  352. '\n '.join(future_discovery_servers),
  353. '\n '.join(old_propagation_servers),
  354. '\n '.join(old_discovery_servers))
  355. self.__show_logs(p)
  356. def show_servers(self):
  357. for s in self.__servers.itervalues():
  358. self.show_server(s.id)
  359. def show_servers_on_host(self, host_id):
  360. for s in self.__servers.itervalues():
  361. if s.host_id == host_id:
  362. self.show_server(s.id)
  363. def show_server(self, server_id):
  364. s = self.__servers[server_id]
  365. print textwrap.dedent('''
  366. Server: %s
  367. Host: %s %s/%s
  368. IP Address: %s
  369. Propagation Channel: %s
  370. Is Embedded: %s
  371. Discovery Date Range: %s
  372. ''') % (
  373. s.id,
  374. s.host_id,
  375. self.__hosts[s.host_id].ssh_username,
  376. self.__hosts[s.host_id].ssh_password,
  377. s.ip_address,
  378. self.__propagation_channels[s.propagation_channel_id].name if s.propagation_channel_id else 'None',
  379. s.is_embedded,
  380. ('%s - %s' % (s.discovery_date_range[0].isoformat(),
  381. s.discovery_date_range[1].isoformat())) if s.discovery_date_range else 'None')
  382. self.__show_logs(s)
  383. def show_provider_ranks(self):
  384. for r in self.__provider_ranks:
  385. print textwrap.dedent('''
  386. Provider: %s
  387. Rank: %s
  388. ''') % (r.provider, r.rank)
  389. def __generate_id(self):
  390. count = 16
  391. chars = '0123456789ABCDEF'
  392. return ''.join([chars[ord(os.urandom(1))%len(chars)] for i in range(count)])
  393. def __get_propagation_channel_by_name(self, name):
  394. return filter(lambda x:x.name == name,
  395. self.__propagation_channels.itervalues())[0]
  396. def add_propagation_channel(self, name, propagation_mechanism_types):
  397. assert(self.is_locked)
  398. self.import_propagation_channel(self.__generate_id(), name, propagation_mechanism_types)
  399. def import_propagation_channel(self, id, name, propagation_mechanism_types):
  400. assert(self.is_locked)
  401. for type in propagation_mechanism_types: assert(type in self.__propagation_mechanisms)
  402. propagation_channel = PropagationChannel(id, name, propagation_mechanism_types)
  403. assert(id not in self.__propagation_channels)
  404. assert(not filter(lambda x:x.name == name, self.__propagation_channels.itervalues()))
  405. self.__propagation_channels[id] = propagation_channel
  406. def __get_sponsor_by_name(self, name):
  407. return filter(lambda x:x.name == name,
  408. self.__sponsors.itervalues())[0]
  409. def add_sponsor(self, name):
  410. assert(self.is_locked)
  411. self.import_sponsor(self.__generate_id(), name)
  412. def import_sponsor(self, id, name):
  413. assert(self.is_locked)
  414. sponsor = Sponsor(id, name, None, {}, [], [], [])
  415. assert(id not in self.__sponsors)
  416. assert(not filter(lambda x:x.name == name, self.__sponsors.itervalues()))
  417. self.__sponsors[id] = sponsor
  418. def set_sponsor_banner(self, name, banner_filename):
  419. assert(self.is_locked)
  420. with open(banner_filename, 'rb') as file:
  421. banner = base64.b64encode(file.read())
  422. sponsor = self.__get_sponsor_by_name(name)
  423. sponsor.banner = banner
  424. sponsor.log('set banner')
  425. for campaign in sponsor.campaigns:
  426. self.__deploy_builds_required_for_campaigns.add(
  427. (campaign.propagation_channel_id, sponsor.id))
  428. campaign.log('marked for build and publish (new banner)')
  429. def add_sponsor_email_campaign(self, sponsor_name, propagation_channel_name, email_account):
  430. assert(self.is_locked)
  431. sponsor = self.__get_sponsor_by_name(sponsor_name)
  432. propagation_channel = self.__get_propagation_channel_by_name(propagation_channel_name)
  433. propagation_mechanism_type = 'email-autoresponder'
  434. assert(propagation_mechanism_type in propagation_channel.propagation_mechanism_types)
  435. # TODO: assert(email_account not in ...)
  436. campaign = SponsorCampaign(propagation_channel.id,
  437. propagation_mechanism_type,
  438. EmailPropagationAccount(email_account),
  439. None)
  440. if campaign not in sponsor.campaigns:
  441. sponsor.campaigns.append(campaign)
  442. sponsor.log('add email campaign %s' % (email_account,))
  443. self.__deploy_builds_required_for_campaigns.add(
  444. (campaign.propagation_channel_id, sponsor.id))
  445. campaign.log('marked for build and publish (new campaign)')
  446. def add_sponsor_twitter_campaign(self, sponsor_name,
  447. propagation_channel_name,
  448. twitter_account_name,
  449. twitter_account_consumer_key,
  450. twitter_account_consumer_secret,
  451. twitter_account_access_token_key,
  452. twitter_account_access_token_secret):
  453. assert(self.is_locked)
  454. sponsor = self.__get_sponsor_by_name(sponsor_name)
  455. propagation_channel = self.__get_propagation_channel_by_name(propagation_channel_name)
  456. propagation_mechanism_type = 'twitter'
  457. assert(propagation_mechanism_type in propagation_channel.propagation_mechanism_types)
  458. campaign = SponsorCampaign(propagation_channel.id,
  459. propagation_mechanism_type,
  460. TwitterPropagationAccount(
  461. twitter_account_name,
  462. twitter_account_consumer_key,
  463. twitter_account_consumer_secret,
  464. twitter_account_access_token_key,
  465. twitter_account_access_token_secret),
  466. None)
  467. if campaign not in sponsor.campaigns:
  468. sponsor.campaigns.append(campaign)
  469. sponsor.log('add twitter campaign %s' % (twitter_account_name,))
  470. self.__deploy_builds_required_for_campaigns.add(
  471. (campaign.propagation_channel_id, sponsor.id))
  472. campaign.log('marked for build and publish (new campaign)')
  473. def add_sponsor_static_download_campaign(self, sponsor_name, propagation_channel_name):
  474. assert(self.is_locked)
  475. sponsor = self.__get_sponsor_by_name(sponsor_name)
  476. propagation_channel = self.__get_propagation_channel_by_name(propagation_channel_name)
  477. propagation_mechanism_type = 'static-download'
  478. assert(propagation_mechanism_type in propagation_channel.propagation_mechanism_types)
  479. campaign = SponsorCampaign(propagation_channel.id,
  480. propagation_mechanism_type,
  481. None,
  482. None)
  483. if campaign not in sponsor.campaigns:
  484. sponsor.campaigns.append(campaign)
  485. sponsor.log('add static download campaign')
  486. self.__deploy_builds_required_for_campaigns.add(
  487. (campaign.propagation_channel_id, sponsor.id))
  488. campaign.log('marked for build and publish (new campaign)')
  489. def set_sponsor_campaign_s3_bucket_name(self, sponsor_name, propagation_channel_name, account, s3_bucket_name):
  490. assert(self.is_locked)
  491. sponsor = self.__get_sponsor_by_name(sponsor_name)
  492. propagation_channel = self.__get_propagation_channel_by_name(propagation_channel_name)
  493. for campaign in sponsor.campaigns:
  494. if (campaign.propagation_channel_id == propagation_channel.id and
  495. campaign.account[0] == account):
  496. campaign.s3_bucket_name = s3_bucket_name
  497. campaign.log('set campaign s3 bucket name to %s' % (s3_bucket_name,))
  498. self.__deploy_builds_required_for_campaigns.add(
  499. (campaign.propagation_channel_id, sponsor.id))
  500. campaign.log('marked for build and publish (modified campaign)')
  501. def set_sponsor_home_page(self, sponsor_name, region, url):
  502. assert(self.is_locked)
  503. sponsor = self.__get_sponsor_by_name(sponsor_name)
  504. home_page = SponsorHomePage(region, url)
  505. if region not in sponsor.home_pages:
  506. sponsor.home_pages[region] = []
  507. if home_page not in sponsor.home_pages[region]:
  508. sponsor.home_pages[region].append(home_page)
  509. sponsor.log('set home page %s for %s' % (url, region if region else 'All'))
  510. self.__deploy_data_required_for_all = True
  511. sponsor.log('marked all hosts for data deployment')
  512. def remove_sponsor_home_page(self, sponsor_name, region, url):
  513. assert(self.is_locked)
  514. sponsor = self.__get_sponsor_by_name(sponsor_name)
  515. home_page = SponsorHomePage(region, url)
  516. if (region in sponsor.home_pages
  517. and home_page in sponsor.home_pages[region]):
  518. sponsor.home_pages[region].remove(home_page)
  519. sponsor.log('deleted home page %s for %s' % (url, region))
  520. self.__deploy_data_required_for_all = True
  521. sponsor.log('marked all hosts for data deployment')
  522. def set_sponsor_page_view_regex(self, sponsor_name, regex, replace):
  523. assert(self.is_locked)
  524. sponsor = self.__get_sponsor_by_name(sponsor_name)
  525. if not [rx for rx in sponsor.page_view_regexes if rx.regex == regex]:
  526. sponsor.page_view_regexes.append(SponsorRegex(regex, replace))
  527. sponsor.log('set page view regex %s; replace %s' % (regex, replace))
  528. self.__deploy_data_required_for_all = True
  529. sponsor.log('marked all hosts for data deployment')
  530. def remove_sponsor_page_view_regex(self, sponsor_name, regex):
  531. '''
  532. Note that the regex part of the regex+replace pair is unique, so only
  533. it has to be passed in when removing.
  534. '''
  535. assert(self.is_locked)
  536. sponsor = self.__get_sponsor_by_name(sponsor_name)
  537. match = [sponsor.page_view_regexes.pop(idx)
  538. for (idx, rx)
  539. in enumerate(sponsor.page_view_regexes)
  540. if rx.regex == regex]
  541. if match:
  542. sponsor.page_view_regexes.remove(regex)
  543. sponsor.log('deleted page view regex %s' % regex)
  544. self.__deploy_data_required_for_all = True
  545. sponsor.log('marked all hosts for data deployment')
  546. def set_sponsor_https_request_regex(self, sponsor_name, regex, replace):
  547. assert(self.is_locked)
  548. sponsor = self.__get_sponsor_by_name(sponsor_name)
  549. if not [rx for rx in sponsor.https_request_regexes if rx.regex == regex]:
  550. sponsor.https_request_regexes.append(SponsorRegex(regex, replace))
  551. sponsor.log('set https request regex %s; replace %s' % (regex, replace))
  552. self.__deploy_data_required_for_all = True
  553. sponsor.log('marked all hosts for data deployment')
  554. def remove_sponsor_https_request_regex(self, sponsor_name, regex):
  555. '''
  556. Note that the regex part of the regex+replace pair is unique, so only
  557. it has to be passed in when removing.
  558. '''
  559. assert(self.is_locked)
  560. sponsor = self.__get_sponsor_by_name(sponsor_name)
  561. match = [sponsor.https_request_regexes.pop(idx)
  562. for (idx, rx)
  563. in enumerate(sponsor.https_request_regexes)
  564. if rx.regex == regex]
  565. if match:
  566. sponsor.https_request_regexes.remove(regex)
  567. sponsor.log('deleted https request regex %s' % regex)
  568. self.__deploy_data_required_for_all = True
  569. sponsor.log('marked all hosts for data deployment')
  570. def set_sponsor_name(self, sponsor_name, new_sponsor_name):
  571. assert(self.is_locked)
  572. assert(not filter(lambda x:x.name == new_sponsor_name, self.__sponsors.itervalues()))
  573. sponsor = self.__get_sponsor_by_name(sponsor_name)
  574. sponsor.name = (new_sponsor_name)
  575. self.__deploy_stats_config_required = True
  576. sponsor.log('set sponsor name from \'%s\' to \'%s\'' % (sponsor_name, new_sponsor_name))
  577. def get_server_by_ip_address(self, ip_address):
  578. servers = filter(lambda x:x.ip_address == ip_address, self.__servers.itervalues())
  579. if len(servers) == 1:
  580. return servers[0]
  581. return None
  582. def import_host(self, id, provider, provider_id, ip_address, ssh_port, ssh_username, ssh_password, ssh_host_key,
  583. stats_ssh_username, stats_ssh_password):
  584. assert(self.is_locked)
  585. host = Host(
  586. id,
  587. provider,
  588. provider_id,
  589. ip_address,
  590. ssh_port,
  591. ssh_username,
  592. ssh_password,
  593. ssh_host_key,
  594. stats_ssh_username,
  595. stats_ssh_password)
  596. assert(host.id not in self.__hosts)
  597. self.__hosts[host.id] = host
  598. def import_server(self, server_id, host_id, ip_address, egress_ip_address, propagation_channel_id,
  599. is_embedded, discovery_date_range, web_server_port, web_server_secret,
  600. web_server_certificate, web_server_private_key, ssh_port, ssh_username,
  601. ssh_password, ssh_host_key):
  602. assert(self.is_locked)
  603. server = Server(
  604. server_id,
  605. host_id,
  606. ip_address,
  607. egress_ip_address,
  608. propagation_channel_id,
  609. is_embedded,
  610. discovery_date_range,
  611. web_server_port,
  612. web_server_secret,
  613. web_server_certificate,
  614. web_server_private_key,
  615. ssh_port,
  616. ssh_username,
  617. ssh_password,
  618. ssh_host_key)
  619. assert(server.id not in self.__servers)
  620. self.__servers[server.id] = server
    def add_servers(self, count, propagation_channel_name, discovery_date_range, replace_others=True):
        '''
        Provision `count` new servers for the named propagation channel,
        install and deploy them, then run a full deploy. Each new server gets
        its own new host.

        discovery_date_range: None creates embedded ("propagation") servers.
        Otherwise it is a date range, indexed [0] (start) and [1] (end)
        elsewhere in this class, and the new servers are discovery servers.
        replace_others: when true, first stop embedding the channel's current
        embedded servers (for embedded additions), or stop discovering its
        current discovery servers (for discovery additions).

        The provider for each server is drawn with _weighted_random_choice
        over the provider ranks. Only 'linode' and 'elastichosts' are
        handled; any other provider raises ValueError.
        '''
        # NOTE(review): this copy of the file lost its indentation. The
        # for-loop below is reconstructed as ending after the per-server
        # self.save(); confirm against upstream.
        assert(self.is_locked)
        propagation_channel = self.__get_propagation_channel_by_name(propagation_channel_name)
        # Embedded servers (aka "propagation servers") are embedded in client
        # builds, where as discovery servers are only revealed when clients
        # connect to a server.
        is_embedded_server = (discovery_date_range is None)
        if replace_others:
            # If we are creating new propagation servers, stop embedding the old ones
            # (they are still active, but not embedded in builds or discovered)
            if is_embedded_server:
                for old_server in self.__servers.itervalues():
                    if (old_server.propagation_channel_id == propagation_channel.id and
                        old_server.is_embedded):
                        old_server.is_embedded = False
                        old_server.log('unembedded')
            # If we are creating new discovery servers, stop discovering existing ones
            else:
                self.__replace_propagation_channel_discovery_servers(propagation_channel.id)
        for _ in range(count):
            provider = self._weighted_random_choice(self.__provider_ranks).provider
            # This is pretty dirty. We should use some proper OO technique.
            provider_launch_new_server = None
            provider_account = None
            if provider.lower() == 'linode':
                provider_launch_new_server = psi_linode.launch_new_server
                provider_account = self.__linode_account
            elif provider.lower() == 'elastichosts':
                provider_launch_new_server = psi_elastichosts.ElasticHosts().launch_new_server
                # ElasticHosts has several accounts; pick one by weight too.
                provider_account = self._weighted_random_choice(self.__elastichosts_accounts)
            else:
                raise ValueError('bad provider value: %s' % provider)
            print 'starting %s process (up to 20 minutes)...' % provider
            # Create a new cloud VPS; its return value is unpacked positionally
            # into the Host recordtype fields.
            server_info = provider_launch_new_server(provider_account)
            host = Host(*server_info)
            host.provider = provider.lower()
            # NOTE: jsonpickle will serialize references to discovery_date_range, which can't be
            # resolved when unpickling, if discovery_date_range is used directly.
            # So create a copy instead.
            discovery = self.__copy_date_range(discovery_date_range) if discovery_date_range else None
            server = Server(
                None,                   # id -- presumably assigned by install_host below, since it is used as a dict key afterwards; confirm
                host.id,                # host_id
                host.ip_address,        # ip_address
                host.ip_address,        # egress_ip_address
                propagation_channel.id, # propagation_channel_id
                is_embedded_server,     # is_embedded
                discovery,              # discovery_date_range
                '8080',                 # web_server_port
                None,                   # web_server_secret
                None,                   # web_server_certificate
                None,                   # web_server_private_key
                '22',                   # ssh_port
                None,                   # ssh_username
                None,                   # ssh_password
                None,                   # ssh_host_key
                '995')                  # ssh_obfuscated_port (ssh_obfuscated_key is left to the recordtype default)
            # Install Psiphon 3 and generate configuration values
            # Here, we're assuming one server/IP address per host
            existing_server_ids = [existing_server.id for existing_server in self.__servers.itervalues()]
            psi_ops_install.install_host(host, [server], existing_server_ids)
            host.log('install')
            # Update database
            # Add new server (we also add a host; here, the host and server are
            # one-to-one, but legacy networks have many servers per host and we
            # retain support for this in the data model and general functionality)
            # Note: this must be done before deploy_data otherwise the deployed
            # data will not include this host and server
            assert(host.id not in self.__hosts)
            self.__hosts[host.id] = host
            assert(server.id not in self.__servers)
            self.__servers[server.id] = server
            # Deploy will upload web server source database data and client builds
            # (Only deploying for the new host, not broadcasting info yet...)
            psi_ops_deploy.deploy_implementation(host)
            psi_ops_deploy.deploy_data(
                host,
                self.__compartmentalize_data_for_host(host.id))
            psi_ops_deploy.deploy_routes(host)
            host.log('initial deployment')
            self.test_server(server.id, test_vpn=False, test_ssh=False)
            # Persist after each server so a later failure keeps earlier ones.
            self.save()
        # Every host needs refreshed data and the stats server needs the new
        # host(s); the actual pushes happen in self.deploy() below.
        self.__deploy_data_required_for_all = True
        self.__deploy_stats_config_required = True
        # Unless the node is reserved for discovery, release it through
        # the campaigns associated with the propagation channel
        # TODO: recover from partially complete state...
        if is_embedded_server:
            for sponsor in self.__sponsors.itervalues():
                for campaign in sponsor.campaigns:
                    if campaign.propagation_channel_id == propagation_channel.id:
                        self.__deploy_builds_required_for_campaigns.add(
                            (campaign.propagation_channel_id, sponsor.id))
                        campaign.log('marked for build and publish (new embedded server)')
        # Ensure new server configuration is saved to CMS before deploying new
        # server info to the network
        # TODO: add need-save flag
        self.save()
        # This deploy will broadcast server info, propagate builds, and update
        # the stats and email server
        self.deploy()
  def remove_host(self, host_id):
      '''
      Destroy a host at its provider, then drop it and every server it
      carries from the network database and save.

      Only Linode-hosted machines are supported: any other provider raises
      ValueError before anything is touched.
      '''
      assert(self.is_locked)
      doomed_host = self.__hosts[host_id]
      if doomed_host.provider != 'linode':
          raise ValueError('can\'t remove host from provider %s' % doomed_host.provider)
      # Tear the machine down through the provider's API first
      psi_linode.remove_server(self.__linode_account, doomed_host.provider_id)
      # Delete the host and its servers from the DB
      doomed_server_ids = [candidate.id for candidate in self.__servers.itervalues()
                           if candidate.host_id == doomed_host.id]
      for doomed_server_id in doomed_server_ids:
          self.__servers.pop(doomed_server_id)
      self.__hosts.pop(doomed_host.id)
      # Drop any pending implementation deploy for this host; the stats
      # config must be regenerated without it.
      self.__deploy_implementation_required_for_hosts.discard(doomed_host.id)
      self.__deploy_stats_config_required = True
      # NOTE: If the host was discoverable (now or in the future), host data
      # should be updated; if it was embedded, new campaign builds are
      # needed. Neither is flagged here.
      self.save()
  def reinstall_host(self, host_id):
      '''
      Re-run installation on one host, then push it the server
      implementation and its compartmentalized data, and log 'reinstall'.

      NOTE: if the client version has been incremented but a full deploy()
      has not yet run, the data pushed here is not safe: it specifies a new
      client version that is not yet available on the servers, which sends
      clients into an infinite download loop.
      '''
      assert(self.is_locked)
      host = self.__hosts[host_id]
      hosted_servers = []
      every_server_id = []
      for candidate in self.__servers.itervalues():
          every_server_id.append(candidate.id)
          if candidate.host_id == host_id:
              hosted_servers.append(candidate)
      psi_ops_install.install_host(host, hosted_servers, every_server_id)
      psi_ops_deploy.deploy_implementation(host)
      # Installation may have generated new data (see the NOTE above about
      # client-version safety)
      host_data = self.__compartmentalize_data_for_host(host.id)
      psi_ops_deploy.deploy_data(host, host_data)
      host.log('reinstall')
  def reinstall_hosts(self):
      '''Run reinstall_host() for every host in the network, one after another.'''
      assert(self.is_locked)
      pending_ids = (entry.id for entry in self.__hosts.itervalues())
      for pending_id in pending_ids:
          self.reinstall_host(pending_id)
  def set_servers_propagation_channel_and_discovery_date_range(self, server_names, propagation_channel_name, discovery_date_range, replace_others=True):
      '''
      Move the given servers to a propagation channel and give each one its
      own copy of a discovery date range.

      server_names: keys into the server table.
      discovery_date_range: (start, end) datetimes; copied per server via
          __copy_date_range (truncated to the minute).
      replace_others: when true, first cut short the discovery range of the
          channel's servers currently being discovered.
      Always flags a data deploy to all hosts.
      '''
      assert(self.is_locked)
      channel = self.__get_propagation_channel_by_name(propagation_channel_name)
      if replace_others:
          self.__replace_propagation_channel_discovery_servers(channel.id)
      for name in server_names:
          target = self.__servers[name]
          target.propagation_channel_id = channel.id
          # A fresh copy per server; presumably so jsonpickle never has to
          # serialize shared datetime references (see the NOTE in
          # __replace_propagation_channel_discovery_servers)
          target.discovery_date_range = self.__copy_date_range(discovery_date_range)
          window_start, window_end = target.discovery_date_range
          target.log('propagation channel set to %s' % (channel.id,))
          target.log('discovery_date_range set to %s - %s' % (window_start.isoformat(),
                                                              window_end.isoformat()))
      self.__deploy_data_required_for_all = True
  def __copy_date_range(self, date_range):
      '''
      Return a new (start, end) tuple of freshly constructed datetimes equal
      to date_range[0] and date_range[1], truncated to the minute. Seconds,
      microseconds and tzinfo are not carried over.
      '''
      def fresh(moment):
          return datetime.datetime(moment.year, moment.month, moment.day,
                                   moment.hour, moment.minute)
      return (fresh(date_range[0]), fresh(date_range[1]))
  def __replace_propagation_channel_discovery_servers(self, propagation_channel_id):
      '''
      For every server of the given propagation channel whose discovery
      range contains today (midnight), end that range at today and log
      'replaced' on the server.
      '''
      assert(self.is_locked)
      now = datetime.datetime.now()
      for candidate in self.__servers.itervalues():
          # NOTE: "today" is deliberately rebuilt for every server. A single
          # shared object would make jsonpickle serialize references to it
          # (for all but the first server), and those are not unpickle-able.
          today = datetime.datetime(now.year, now.month, now.day)
          if candidate.propagation_channel_id != propagation_channel_id:
              continue
          window = candidate.discovery_date_range
          if not window:
              continue
          if not (window[0] <= today < window[1]):
              continue
          candidate.discovery_date_range = (window[0], today)
          candidate.log('replaced')
  def _weighted_random_choice(self, choices):
      '''
      Pick one element of choices at random, weighted by its "rank".

      choices: a non-empty iterable of objects with an integer "rank"
          attribute. An element's chance of being picked is its rank divided
          by the sum of all ranks. The iterable is read only once, so
          generators and other one-shot iterators are accepted.
      Returns the single chosen element.
      Raises ValueError if choices is empty or the ranks do not sum to a
      positive number.
      '''
      # Materialize first: the ranks are summed and then walked a second
      # time. With a one-shot iterator the first pass used to exhaust it,
      # leaving the loop empty and the result unbound.
      choices = list(choices)
      if not choices:
          raise ValueError('choices must not be empty')
      rank_total = sum(choice.rank for choice in choices)
      if rank_total <= 0:
          # random.randrange would fail here with a less helpful message
          raise ValueError('choices must have a positive total rank')
      rand = random.randrange(rank_total)
      rank_accum = 0
      for choice in choices:
          rank_accum += choice.rank
          if rank_accum > rand:
              return choice
      # Only reachable when some ranks are negative; keep the historical
      # behaviour of returning the last element.
      return choices[-1]
  def build(self, propagation_channel_name, sponsor_name, test=False):
      '''
      Build a client for one propagation channel / sponsor pair and return
      the result of psi_ops_build.build_client (deploy() treats it as the
      build filename).

      build_client receives the channel and sponsor ids, the decoded sponsor
      banner (stored base64-encoded), the channel's embedded server list and
      the latest client version. test is passed straight through.
      A sponsor that uses one propagation channel for several campaigns
      needs only one build for that pair.
      '''
      channel = self.__get_propagation_channel_by_name(propagation_channel_name)
      sponsor = self.__get_sponsor_by_name(sponsor_name)
      latest_version = self.__client_versions[-1].version
      # Without a client IP this yields the embedded list; the expected
      # egress addresses (second element) are not needed for a build.
      embedded_servers = self.__get_encoded_server_list(channel.id)[0]
      banner = base64.b64decode(sponsor.banner)
      return psi_ops_build.build_client(
          channel.id,
          sponsor.id,
          banner,
          embedded_servers,
          latest_version,
          test)
  def deploy(self):
      '''
      Carry out all pending deployment work recorded in the deploy flags,
      calling self.save() after each stage completes (presumably so partial
      progress is persisted if a later stage fails).
      '''
      # Deploy as required:
      #
      # - Implementation to flagged hosts
      # - Builds for required channels and sponsors
      # - Publish, tweet
      # - Data to all hosts
      # - Email and stats server config
      #
      # NOTE: Order is important. Hosts get new implementation before
      # new data, in case schema has changed; deploy builds before
      # deploying new data so an upgrade is available when it's needed
      assert(self.is_locked)
      # Host implementation
      hosts = [self.__hosts[host_id] for host_id in self.__deploy_implementation_required_for_hosts]
      # NOTE(review): called even when no host is flagged (empty list)
      psi_ops_deploy.deploy_implementation_to_hosts(hosts)
      if len(self.__deploy_implementation_required_for_hosts) > 0:
          self.__deploy_implementation_required_for_hosts.clear()
          self.save()
      # Build
      # Iterate over a copy: each (channel, sponsor) target is removed from
      # the set, and state saved, as soon as it has been built and published.
      for target in self.__deploy_builds_required_for_campaigns.copy():
          propagation_channel_id, sponsor_id = target
          propagation_channel = self.__propagation_channels[propagation_channel_id]
          sponsor = self.__sponsors[sponsor_id]
          # Build and upload to hosts
          build_filename = self.build(propagation_channel.name, sponsor.name)
          # Upload client builds
          # We only upload the builds for Propagation Channel IDs that need to be known for the host.
          # UPDATE: Now we copy all builds. We know that this breaks compartmentalization.
          # However, we do not want to prevent an upgrade in the case where a user has
          # downloaded from multiple propagation channels, and might therefore be connecting
          # to a server from one propagation channel using a build from a different one.
          psi_ops_deploy.deploy_build_to_hosts(self.__hosts.itervalues(), build_filename)
          # Publish to propagation mechanisms: only this sponsor's campaigns on
          # this propagation channel
          for campaign in filter(lambda x:x.propagation_channel_id == propagation_channel_id, sponsor.campaigns):
              # A campaign's S3 bucket is created on first publish and its
              # download is updated on every later publish
              if campaign.s3_bucket_name:
                  psi_ops_s3.update_s3_download(self.__aws_account, build_filename, campaign.s3_bucket_name)
                  campaign.log('updated s3 bucket %s' % (campaign.s3_bucket_name,))
              else:
                  campaign.s3_bucket_name = psi_ops_s3.publish_s3_download(self.__aws_account, build_filename)
                  campaign.log('created s3 bucket %s' % (campaign.s3_bucket_name,))
              if campaign.propagation_mechanism_type == 'twitter':
                  message = psi_templates.get_tweet_message(campaign.s3_bucket_name)
                  psi_ops_twitter.tweet(campaign.account, message)
                  campaign.log('tweeted')
              elif campaign.propagation_mechanism_type == 'email-autoresponder':
                  # The email config is pushed once, in the stage below; only
                  # the campaign that first schedules it logs the event
                  if not self.__deploy_email_config_required:
                      self.__deploy_email_config_required = True
                      campaign.log('email push scheduled')
          self.__deploy_builds_required_for_campaigns.remove(target)
          self.save()
      # Host data
      if self.__deploy_data_required_for_all:
          host_and_data_list = []
          for host in self.__hosts.itervalues():
              host_and_data_list.append(dict(host=host, data=self.__compartmentalize_data_for_host(host.id)))
          psi_ops_deploy.deploy_data_to_hosts(host_and_data_list)
          self.__deploy_data_required_for_all = False
          self.save()
      # Email and stats server configs
      if self.__deploy_stats_config_required:
          self.push_stats_config()
          self.__deploy_stats_config_required = False
          self.save()
      if self.__deploy_email_config_required:
          self.push_email_config()
          self.__deploy_email_config_required = False
          self.save()
  def update_routes(self):
      '''
      Call psi_routes.make_routes(), then deploy routes to every host.
      The lock is required because the deploy step writes host logs.
      '''
      assert(self.is_locked)
      psi_routes.make_routes()
      every_host = self.__hosts.values()
      psi_ops_deploy.deploy_routes_to_hosts(every_host)
  def push_stats_config(self):
      '''
      Hand the output of __compartmentalize_data_for_stats_server() to
      psi_ops_cms.import_document and log 'pushed' on the stats server
      account.

      The data goes through a temporary file, which is always closed and
      deleted afterwards, even when writing or importing fails.
      '''
      assert(self.is_locked)
      print('push stats config...')
      temp_file = tempfile.NamedTemporaryFile(delete=False)
      try:
          temp_file.write(self.__compartmentalize_data_for_stats_server())
          temp_file.close()
          psi_ops_cms.import_document(temp_file.name, True)
          self.__stats_server_account.log('pushed')
      finally:
          # close() is a no-op when already closed. It matters when write()
          # raised: otherwise the handle leaks and (on some platforms) the
          # file cannot be removed.
          temp_file.close()
          try:
              os.remove(temp_file.name)
          except OSError:
              # Best-effort cleanup only; don't mask the original outcome.
              pass
  def push_email_config(self):
      '''
      Regenerate the entire email autoresponder config and upload it to the
      email server over SSH.

      The config is JSON mapping each email-autoresponder campaign's account
      email address to a response: 'body' holds plain-text and HTML content
      built from the campaign's S3 bucket name, and 'attachment_bucket' holds
      that bucket name. Campaigns without an S3 bucket yet are skipped.
      Currently the whole file is regenerated for any change.
      '''
      assert(self.is_locked)
      print('push email config...')
      emails = {}
      for sponsor in self.__sponsors.itervalues():
          for campaign in sponsor.campaigns:
              if (campaign.propagation_mechanism_type == 'email-autoresponder' and
                      campaign.s3_bucket_name is not None):
                  emails[campaign.account.email_address] = {
                      'body': [
                          ['plain', psi_templates.get_plaintext_email_content(campaign.s3_bucket_name)],
                          ['html', psi_templates.get_html_email_content(campaign.s3_bucket_name)],
                      ],
                      'attachment_bucket': campaign.s3_bucket_name,
                  }
                  campaign.log('configuring email')
      temp_file = tempfile.NamedTemporaryFile(delete=False)
      try:
          temp_file.write(json.dumps(emails))
          temp_file.close()
          account = self.__email_server_account
          ssh = psi_ssh.SSH(
              account.ip_address,
              account.ssh_port,
              account.ssh_username,
              None,
              account.ssh_host_key,
              ssh_pkey=account.ssh_pkey)
          ssh.put_file(
              temp_file.name,
              account.config_file_path)
          account.log('pushed')
      finally:
          # Close before removing: if write() raised, the handle would
          # otherwise leak. close() is a no-op when already closed.
          temp_file.close()
          try:
              os.remove(temp_file.name)
          except OSError:
              # Best-effort cleanup only; don't mask the original outcome.
              pass
  def add_server_version(self):
      '''
      Flag every host for redeployment of the server implementation. The
      push itself happens in the next deploy().
      '''
      assert(self.is_locked)
      pending_hosts = self.__deploy_implementation_required_for_hosts
      for each_host in self.__hosts.itervalues():
          pending_hosts.add(each_host.id)
          each_host.log('marked for implementation deployment')
  def add_client_version(self, description):
      '''
      Record a new client version so that clients are offered an upgrade.
      The new version is one higher than the latest, or "1" for the first.

      Every campaign of every sponsor is marked for rebuild and publish,
      and a data deploy to all hosts is flagged. deploy() performs both.
      '''
      assert(self.is_locked)
      versions = self.__client_versions
      next_version = int(versions[-1].version) + 1 if versions else 1
      versions.append(ClientVersion(str(next_version), description))
      pending_builds = self.__deploy_builds_required_for_campaigns
      for sponsor in self.__sponsors.itervalues():
          for campaign in sponsor.campaigns:
              pending_builds.add((campaign.propagation_channel_id, sponsor.id))
              campaign.log('marked for build and publish (upgraded client)')
      # Hosts need fresh data as well so the new version is advertised for
      # auto-update
      self.__deploy_data_required_for_all = True
  def get_server_entry(self, server_id):
      '''
      Return the hex-encoded server entry for the server whose id is
      server_id. Raises IndexError when no server has that id.
      '''
      matches = [candidate for candidate in self.__servers.itervalues()
                 if candidate.id == server_id]
      return self.__get_encoded_server_entry(matches[0])
  def deploy_implementation_and_data_for_host_with_server(self, server_id):
      '''
      Push the server implementation, then the host's compartmentalized
      data, to the host that carries server server_id.
      Raises IndexError if the server, or its host, is unknown.

      NOTE(review): unlike reinstall_host this does not assert
      self.is_locked; confirm whether the deploy calls require the lock.
      '''
      server = [s for s in self.__servers.itervalues() if s.id == server_id][0]
      host = [h for h in self.__hosts.itervalues() if h.id == server.host_id][0]
      psi_ops_deploy.deploy_implementation(host)
      host_data = self.__compartmentalize_data_for_host(host.id)
      psi_ops_deploy.deploy_data(host, host_data)
  def set_aws_account(self, access_id, secret_key):
      '''Store new AWS credentials (access id and secret key) on the AWS account record.'''
      assert(self.is_locked)
      credentials = dict(access_id=access_id, secret_key=secret_key)
      psi_utils.update_recordtype(self.__aws_account, **credentials)
  def upsert_provider_rank(self, provider, rank):
      '''
      Insert or update the ProviderRank entry keyed by provider.

      provider: must be one of ProviderRank.provider_values, otherwise
          ValueError is raised.
      rank: the higher the score, the more the provider will be preferred
          when providers are being randomly selected among.
      '''
      assert(self.is_locked)
      if provider not in ProviderRank.provider_values:
          raise ValueError('bad provider value: %s' % provider)
      entry = next((existing for existing in self.__provider_ranks
                    if existing.provider == provider), None)
      if entry is None:
          entry = ProviderRank()
          self.__provider_ranks.append(entry)
      psi_utils.update_recordtype(entry, provider=provider, rank=rank)
  def set_linode_account(self, api_key, base_id, base_ip_address, base_ssh_port,
                         base_root_password, base_stats_username, base_host_public_key,
                         base_known_hosts_entry, base_rsa_private_key, base_rsa_public_key,
                         base_tarball_path):
      '''Update the stored Linode account record with all of the given values.'''
      assert(self.is_locked)
      settings = {
          'api_key': api_key,
          'base_id': base_id,
          'base_ip_address': base_ip_address,
          'base_ssh_port': base_ssh_port,
          'base_root_password': base_root_password,
          'base_stats_username': base_stats_username,
          'base_host_public_key': base_host_public_key,
          'base_known_hosts_entry': base_known_hosts_entry,
          'base_rsa_private_key': base_rsa_private_key,
          'base_rsa_public_key': base_rsa_public_key,
          'base_tarball_path': base_tarball_path,
      }
      psi_utils.update_recordtype(self.__linode_account, **settings)
  def upsert_elastichosts_account(self, zone, uuid, api_key, base_drive_id,
                                  cpu, mem, base_host_public_key, root_username,
                                  base_root_password, base_ssh_port, stats_username, rank):
      '''
      Insert or update an ElasticHosts account entry, keyed by (zone, uuid).

      zone: must be one of ElasticHostsAccount.zone_values, otherwise
          ValueError is raised.
      Every other argument passed as None leaves that field as it is. For a
      new entry, that means the ElasticHostsAccount() default is kept.
      rank: the higher the score, the more the account will be preferred
          when the ElasticHosts accounts are being randomly selected among.
      '''
      assert(self.is_locked)
      if zone not in ElasticHostsAccount.zone_values:
          raise ValueError('bad zone value: %s' % zone)
      acct = next((candidate for candidate in self.__elastichosts_accounts
                   if candidate.zone == zone and candidate.uuid == uuid), None)
      if acct is None:
          acct = ElasticHostsAccount()
          self.__elastichosts_accounts.append(acct)

      def given_or_current(value, field):
          # None means "not supplied": fall back to the entry's current value
          return getattr(acct, field) if value is None else value

      psi_utils.update_recordtype(
          acct,
          zone=zone, uuid=uuid,
          api_key=given_or_current(api_key, 'api_key'),
          base_drive_id=given_or_current(base_drive_id, 'base_drive_id'),
          cpu=given_or_current(cpu, 'cpu'),
          mem=given_or_current(mem, 'mem'),
          base_host_public_key=given_or_current(base_host_public_key, 'base_host_public_key'),
          root_username=given_or_current(root_username, 'root_username'),
          base_root_password=given_or_current(base_root_password, 'base_root_password'),
          base_ssh_port=given_or_current(base_ssh_port, 'base_ssh_port'),
          stats_username=given_or_current(stats_username, 'stats_username'),
          rank=given_or_current(rank, 'rank'))
  def set_email_server_account(self, ip_address, ssh_port,
                               ssh_username, ssh_pkey, ssh_host_key,
                               config_file_path):
      '''
      Update the email server account record: SSH connection details and
      the remote path that push_email_config() uploads the config to.
      '''
      assert(self.is_locked)
      connection = dict(ip_address=ip_address, ssh_port=ssh_port,
                        ssh_username=ssh_username, ssh_pkey=ssh_pkey,
                        ssh_host_key=ssh_host_key)
      psi_utils.update_recordtype(
          self.__email_server_account,
          config_file_path=config_file_path,
          **connection)
  def set_stats_server_account(self, ip_address, ssh_port,
                               ssh_username, ssh_password, ssh_host_key):
      '''Update the stats server account record with the given SSH connection details.'''
      assert(self.is_locked)
      connection = dict(ip_address=ip_address, ssh_port=ssh_port,
                        ssh_username=ssh_username, ssh_password=ssh_password,
                        ssh_host_key=ssh_host_key)
      psi_utils.update_recordtype(self.__stats_server_account, **connection)
  def add_speed_test_url(self, server_address, server_port, request_path):
      '''
      Register a speed test URL unless an entry with the same (address,
      port, path) already exists. Only a newly added entry flags a data
      deploy to all hosts.
      '''
      assert(self.is_locked)
      wanted = (server_address, server_port, request_path)
      already_known = any(
          (url.server_address, url.server_port, url.request_path) == wanted
          for url in self.__speed_test_urls)
      if already_known:
          return
      self.__speed_test_urls.append(SpeedTestURL(server_address, server_port, request_path))
      self.__deploy_data_required_for_all = True
  def __get_encoded_server_entry(self, server):
      '''
      Return the server entry given to clients: binascii.hexlify of
      "<ip_address> <web_server_port> <web_server_secret> <web_server_certificate>".

      Raises AssertionError if any of those four fields has length 1 or
      less. These checks are explicit raises rather than assert statements
      so they still run under "python -O"; the exception type is unchanged.
      '''
      # Double-check that we're not giving out blank server credentials.
      # This has happened in the past when following manual build steps.
      for field in ('ip_address', 'web_server_port',
                    'web_server_secret', 'web_server_certificate'):
          if len(getattr(server, field)) <= 1:
              raise AssertionError('blank or truncated server field: %s' % (field,))
      return binascii.hexlify('%s %s %s %s' % (
          server.ip_address,
          server.web_server_port,
          server.web_server_secret,
          server.web_server_certificate))
  def __get_encoded_server_list(self, propagation_channel_id,
                                client_ip_address=None, event_logger=None, discovery_date=None):
      '''
      Return (encoded_server_entries, egress_ip_addresses) for a channel.

      Without client_ip_address (embedded/propagation case): every server of
      the channel with is_embedded set.
      With client_ip_address (discovery case): among the channel's servers
      whose discovery_date_range contains discovery_date (default: now),
      exactly one is chosen by taking the client IPv4 address, as an
      unsigned 32-bit integer, modulo the number of matches.
      NOTE: when no server matches in the discovery case the result is
      ([], None), i.e. the second element is None rather than [].

      event_logger, if given, is called with each returned server's
      ip_address.
      '''
      if client_ip_address:
          # Discovery case
          if not discovery_date:
              discovery_date = datetime.datetime.now()

          def discoverable(candidate):
              window = candidate.discovery_date_range
              return window is not None and window[0] <= discovery_date < window[1]

          chosen = [candidate for candidate in self.__servers.itervalues()
                    if candidate.propagation_channel_id == propagation_channel_id and
                    discoverable(candidate)]
          # One IP-address bucket per matching server, so the client gets
          # just the server in its own bucket.
          # NOTE: with many servers we could return more than one per bucket,
          # but mixing in another strategy would beat discovering extra
          # servers for no additional "effort".
          if not chosen:
              return ([], None)
          client_number = struct.unpack('!L', socket.inet_aton(client_ip_address))[0]
          chosen = [chosen[client_number % len(chosen)]]
      else:
          # Embedded (propagation) case
          chosen = [candidate for candidate in self.__servers.itervalues()
                    if candidate.propagation_channel_id == propagation_channel_id and
                    candidate.is_embedded]
      if event_logger:
          # Used by the web server to log each server IP address disclosed
          for disclosed in chosen:
              event_logger(disclosed.ip_address)
      encoded_entries = [self.__get_encoded_server_entry(s) for s in chosen]
      egress_addresses = [s.egress_ip_address for s in chosen]
      return (encoded_entries, egress_addresses)
  def get_region(self, client_ip_address):
      '''
      Return the country code for client_ip_address, or the string 'None'
      when it cannot be determined, including when the GeoIP module isn't
      installed.

      The commercial city database at /usr/local/share/GeoIP/GeoIPCity.dat
      is used if that file exists; otherwise GeoIP.new(...) is queried.
      '''
      city_db_filename = '/usr/local/share/GeoIP/GeoIPCity.dat'
      try:
          if not os.path.isfile(city_db_filename):
              region = GeoIP.new(GeoIP.GEOIP_MEMORY_CACHE).country_code_by_name(client_ip_address)
          else:
              city_db = GeoIP.open(city_db_filename, GeoIP.GEOIP_MEMORY_CACHE)
              record = city_db.record_by_name(client_ip_address)
              region = record['country_code'] if record else None
      except NameError:
          # The GeoIP module isn't installed, so the name GeoIP is undefined
          return 'None'
      return 'None' if region is None else region
  def __get_sponsor_home_pages(self, sponsor_id, client_ip_address, region=None):
      '''
      Web server support function (fails gracefully): return the sponsor's
      home page URLs for the client's region.

      region defaults to get_region(client_ip_address). The region's pages
      are used when present and non-empty; otherwise the default region
      'None' is used if the sponsor has one; otherwise the result is [].
      An unknown sponsor_id also yields [].
      '''
      if sponsor_id not in self.__sponsors:
          return []
      pages_by_region = self.__sponsors[sponsor_id].home_pages
      if not region:
          region = self.get_region(client_ip_address)

      def urls_for(key):
          return [page.url for page in pages_by_region[key]]

      chosen = urls_for(region) if region in pages_by_region else []
      if not chosen and 'None' in pages_by_region:
          # Lookup failed or no page for this region: use the default
          chosen = urls_for('None')
      return chosen
  def _get_sponsor_page_view_regexes(self, sponsor_id):
      '''
      Web server support function (fails gracefully): return the sponsor's
      page_view_regexes, or [] when sponsor_id is unknown.
      '''
      sponsors = self.__sponsors
      return sponsors[sponsor_id].page_view_regexes if sponsor_id in sponsors else []
  def _get_sponsor_https_request_regexes(self, sponsor_id):
      '''
      Web server support function (fails gracefully): return the sponsor's
      https_request_regexes, or [] when sponsor_id is unknown.
      '''
      sponsors = self.__sponsors
      return sponsors[sponsor_id].https_request_regexes if sponsor_id in sponsors else []
  def __check_upgrade(self, client_version):
      '''
      Return the latest client version string if it is numerically greater
      than client_version; otherwise (including when no versions are
      recorded) return None.
      Assumes the client versions list is in ascending version order.
      '''
      known_versions = self.__client_versions
      if not known_versions:
          return None
      newest = known_versions[-1].version
      return newest if int(newest) > int(client_version) else None
  1186. def handshake(self, server_ip_address, client_ip_address,
  1187. propagation_channel_id, sponsor_id, client_version, event_logger=None):
  1188. # Handshake output is a series of Name:Value lines returned to the client
  1189. output = []
  1190. # Give client a set of landing pages to open when connection established
  1191. homepage_urls = self.__get_sponsor_home_pages(sponsor_id, client_ip_address)
  1192. for homepage_url in homepage_urls:
  1193. output.append('Homepage: %s' % (homepage_url,))
  1194. # Tell client if an upgrade is available
  1195. upgrade_client_version = self.__check_upgrade(client_version)
  1196. if upgrade_client_version:
  1197. output.append('Upgrade: %s' % (upgrade_client_version,))
  1198. # Discovery
  1199. encoded_server_list, expected_egress_ip_addresses = \
  1200. self.__get_encoded_server_list(
  1201. propagation_channel_id,
  1202. client_ip_address,
  1203. event_logger=event_logger)
  1204. for encoded_server_entry in encoded_server_list:
  1205. output.append('Server: %s' % (encoded_server_entry,))
  1206. # VPN relay protocol info
  1207. # Note: this is added in the handshake handler in psi_web
  1208. # output.append(psi_psk.set_psk(self.server_ip_address))
  1209. # SSH relay protocol info
  1210. #
  1211. # SSH Session ID is a randomly generated unique ID used for
  1212. # client-side session duration reporting
  1213. #
  1214. server = filter(lambda x : x.ip_address == server_ip_address,
  1215. self.__servers.itervalues())[0]
  1216. if server.ssh_host_key:
  1217. output.append('SSHPort: %s' % (server.ssh_port,))
  1218. output.append('SSHUsername: %s' % (server.ssh_username,))
  1219. output.append('SSHPassword: %s' % (server.ssh_password,))
  1220. key_type, host_key = server.ssh_host_key.split(' ')
  1221. assert(key_type == 'ssh-rsa')
  1222. output.append('SSHHostKey: %s' % (host_key,))
  1223. output.append('SSHSessionID: %s' % (binascii.hexlify(os.urandom(8)),))
  1224. # Obfuscated SSH fields are optional
  1225. if server.ssh_obfuscated_port:
  1226. output.append('SSHObfuscatedPort: %s' % (server.ssh_obfuscated_port,))
  1227. output.append('SSHObfuscatedKey: %s' % (server.ssh_obfuscated_key,))
  1228. # Additional Configuration
  1229. # Extra config is JSON-encoded.
  1230. # Give client a set of regexes indicating which pages should have individual stats
  1231. config = {}
  1232. config['page_view_regexes'] = []
  1233. for sponsor_regex in self._get_sponsor_page_view_regexes(sponsor_id):
  1234. config['page_view_regexes'].append({
  1235. 'regex' : sponsor_regex.regex,
  1236. 'replace' : sponsor_regex.replace
  1237. })
  1238. config['https_request_regexes'] = []
  1239. for sponsor_regex in self._get_sponsor_https_request_regexes(sponsor_id):
  1240. config['https_request_regexes'].appe