PageRenderTime 52ms CodeModel.GetById 20ms RepoModel.GetById 1ms app.codeStats 0ms

/src/nupic/swarming/hypersearch/utils.py

https://gitlab.com/aw1231/nupic
Python | 755 lines | 720 code | 7 blank | 28 comment | 1 complexity | c8aa28161dbb3433d877a466a94ec6ea MD5 | raw file
  1. # ----------------------------------------------------------------------
  2. # Numenta Platform for Intelligent Computing (NuPIC)
  3. # Copyright (C) 2013, Numenta, Inc. Unless you have an agreement
  4. # with Numenta, Inc., for a separate license for this software code, the
  5. # following terms and conditions apply:
  6. #
  7. # This program is free software: you can redistribute it and/or modify
  8. # it under the terms of the GNU Affero Public License version 3 as
  9. # published by the Free Software Foundation.
  10. #
  11. # This program is distributed in the hope that it will be useful,
  12. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  13. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  14. # See the GNU Affero Public License for more details.
  15. #
  16. # You should have received a copy of the GNU Affero Public License
  17. # along with this program. If not, see http://www.gnu.org/licenses.
  18. #
  19. # http://numenta.org/licenses/
  20. # ----------------------------------------------------------------------
  21. import copy
  22. import json
  23. import os
  24. import sys
  25. import tempfile
  26. import logging
  27. import re
  28. import traceback
  29. import StringIO
  30. from collections import namedtuple
  31. import pprint
  32. import shutil
  33. import types
  34. import signal
  35. import uuid
  36. import validictory
  37. from nupic.database.ClientJobsDAO import (
  38. ClientJobsDAO, InvalidConnectionException)
  39. # TODO: Note the function 'rUpdate' is also duplicated in the
  40. # nupic.data.dictutils module -- we will eventually want to change this
  41. # TODO: 'ValidationError', 'validate', 'loadJSONValueFromFile' duplicated in
  42. # nupic.data.jsonhelpers -- will want to remove later
  43. class JobFailException(Exception):
  44. """ If a model raises this exception, then the runModelXXX code will
  45. mark the job as canceled so that all other workers exit immediately, and mark
  46. the job as failed.
  47. """
  48. pass
  49. def getCopyrightHead():
  50. return """# ----------------------------------------------------------------------
  51. # Numenta Platform for Intelligent Computing (NuPIC)
  52. # Copyright (C) 2013, Numenta, Inc. Unless you have an agreement
  53. # with Numenta, Inc., for a separate license for this software code, the
  54. # following terms and conditions apply:
  55. #
  56. # This program is free software: you can redistribute it and/or modify
  57. # it under the terms of the GNU Affero Public License version 3 as
  58. # published by the Free Software Foundation.
  59. #
  60. # This program is distributed in the hope that it will be useful,
  61. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  62. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
  63. # See the GNU Affero Public License for more details.
  64. #
  65. # You should have received a copy of the GNU Affero Public License
  66. # along with this program. If not, see http://www.gnu.org/licenses.
  67. #
  68. # http://numenta.org/licenses/
  69. # ----------------------------------------------------------------------
  70. """
  71. def _paramsFileHead():
  72. """
  73. This is the first portion of every sub-experiment params file we generate. Between
  74. the head and the tail are the experiment specific options.
  75. """
  76. str = getCopyrightHead() + \
  77. """
  78. ## This file defines parameters for a prediction experiment.
  79. ###############################################################################
  80. # IMPORTANT!!!
  81. # This params file is dynamically generated by the RunExperimentPermutations
  82. # script. Any changes made manually will be over-written the next time
  83. # RunExperimentPermutations is run!!!
  84. ###############################################################################
  85. from nupic.frameworks.opf.expdescriptionhelpers import importBaseDescription
  86. # the sub-experiment configuration
  87. config ={
  88. """
  89. return str
  90. def _paramsFileTail():
  91. """
  92. This is the tail of every params file we generate. Between the head and the tail
  93. are the experiment specific options.
  94. """
  95. str = \
  96. """
  97. }
  98. mod = importBaseDescription('base.py', config)
  99. locals().update(mod.__dict__)
  100. """
  101. return str
  102. def _appendReportKeys(keys, prefix, results):
  103. """
  104. Generate a set of possible report keys for an experiment's results.
  105. A report key is a string of key names separated by colons, each key being one
  106. level deeper into the experiment results dict. For example, 'key1:key2'.
  107. This routine is called recursively to build keys that are multiple levels
  108. deep from the results dict.
  109. Parameters:
  110. -----------------------------------------------------------
  111. keys: Set of report keys accumulated so far
  112. prefix: prefix formed so far, this is the colon separated list of key
  113. names that led up to the dict passed in results
  114. results: dictionary of results at this level.
  115. """
  116. allKeys = results.keys()
  117. allKeys.sort()
  118. for key in allKeys:
  119. if hasattr(results[key], 'keys'):
  120. _appendReportKeys(keys, "%s%s:" % (prefix, key), results[key])
  121. else:
  122. keys.add("%s%s" % (prefix, key))
  123. class _BadKeyError(Exception):
  124. """ If a model raises this exception, then the runModelXXX code will
  125. mark the job as canceled so that all other workers exit immediately, and mark
  126. the job as failed.
  127. """
  128. pass
  129. def _matchReportKeys(reportKeyREs=[], allReportKeys=[]):
  130. """
  131. Extract all items from the 'allKeys' list whose key matches one of the regular
  132. expressions passed in 'reportKeys'.
  133. Parameters:
  134. ----------------------------------------------------------------------------
  135. reportKeyREs: List of regular expressions
  136. allReportKeys: List of all keys
  137. retval: list of keys from allReportKeys that match the regular expressions
  138. in 'reportKeyREs'
  139. If an invalid regular expression was included in 'reportKeys',
  140. then BadKeyError() is raised
  141. """
  142. matchingReportKeys = []
  143. # Extract the report items of interest
  144. for keyRE in reportKeyREs:
  145. # Find all keys that match this regular expression
  146. matchObj = re.compile(keyRE)
  147. found = False
  148. for keyName in allReportKeys:
  149. match = matchObj.match(keyName)
  150. if match and match.end() == len(keyName):
  151. matchingReportKeys.append(keyName)
  152. found = True
  153. if not found:
  154. raise _BadKeyError(keyRE)
  155. return matchingReportKeys
  156. def _getReportItem(itemName, results):
  157. """
  158. Get a specific item by name out of the results dict.
  159. The format of itemName is a string of dictionary keys separated by colons,
  160. each key being one level deeper into the results dict. For example,
  161. 'key1:key2' would fetch results['key1']['key2'].
  162. If itemName is not found in results, then None is returned
  163. """
  164. subKeys = itemName.split(':')
  165. subResults = results
  166. for subKey in subKeys:
  167. subResults = subResults[subKey]
  168. return subResults
  169. def filterResults(allResults, reportKeys, optimizeKey=None):
  170. """ Given the complete set of results generated by an experiment (passed in
  171. 'results'), filter out and return only the ones the caller wants, as
  172. specified through 'reportKeys' and 'optimizeKey'.
  173. A report key is a string of key names separated by colons, each key being one
  174. level deeper into the experiment results dict. For example, 'key1:key2'.
  175. Parameters:
  176. -------------------------------------------------------------------------
  177. results: dict of all results generated by an experiment
  178. reportKeys: list of items from the results dict to include in
  179. the report. These can be regular expressions.
  180. optimizeKey: Which report item, if any, we will be optimizing for. This can
  181. also be a regular expression, but is an error if it matches
  182. more than one key from the experiment's results.
  183. retval: (reportDict, optimizeDict)
  184. reportDict: a dictionary of the metrics named by desiredReportKeys
  185. optimizeDict: A dictionary containing 1 item: the full name and
  186. value of the metric identified by the optimizeKey
  187. """
  188. # Init return values
  189. optimizeDict = dict()
  190. # Get all available report key names for this experiment
  191. allReportKeys = set()
  192. _appendReportKeys(keys=allReportKeys, prefix='', results=allResults)
  193. #----------------------------------------------------------------------------
  194. # Extract the report items that match the regular expressions passed in reportKeys
  195. matchingKeys = _matchReportKeys(reportKeys, allReportKeys)
  196. # Extract the values of the desired items
  197. reportDict = dict()
  198. for keyName in matchingKeys:
  199. value = _getReportItem(keyName, allResults)
  200. reportDict[keyName] = value
  201. # -------------------------------------------------------------------------
  202. # Extract the report item that matches the regular expression passed in
  203. # optimizeKey
  204. if optimizeKey is not None:
  205. matchingKeys = _matchReportKeys([optimizeKey], allReportKeys)
  206. if len(matchingKeys) == 0:
  207. raise _BadKeyError(optimizeKey)
  208. elif len(matchingKeys) > 1:
  209. raise _BadOptimizeKeyError(optimizeKey, matchingKeys)
  210. optimizeKeyFullName = matchingKeys[0]
  211. # Get the value of the optimize metric
  212. value = _getReportItem(optimizeKeyFullName, allResults)
  213. optimizeDict[optimizeKeyFullName] = value
  214. reportDict[optimizeKeyFullName] = value
  215. # Return info
  216. return(reportDict, optimizeDict)
def _quoteAndEscape(string):
  """
  string: input string (ascii or unicode)
  Returns: a quoted string with characters that are represented in python via
  escape sequences converted to those escape sequences

  Note: pprint.pformat() of a string yields its repr() — a quoted, escaped
  literal suitable for embedding in generated source code (used when writing
  keys into the generated description.py).
  Python 2 only: types.StringTypes is (str, unicode).
  """
  # Exact-type check; unlike isinstance(), this rejects str/unicode
  # subclasses — presumably intentional, but unverified from here.
  assert type(string) in types.StringTypes
  return pprint.pformat(string)
  225. def _handleModelRunnerException(jobID, modelID, jobsDAO, experimentDir, logger,
  226. e):
  227. """ Perform standard handling of an exception that occurs while running
  228. a model.
  229. Parameters:
  230. -------------------------------------------------------------------------
  231. jobID: ID for this hypersearch job in the jobs table
  232. modelID: model ID
  233. jobsDAO: ClientJobsDAO instance
  234. experimentDir: directory containing the experiment
  235. logger: the logger to use
  236. e: the exception that occurred
  237. retval: (completionReason, completionMsg)
  238. """
  239. msg = StringIO.StringIO()
  240. print >>msg, "Exception occurred while running model %s: %r (%s)" % (
  241. modelID, e, type(e))
  242. traceback.print_exc(None, msg)
  243. completionReason = jobsDAO.CMPL_REASON_ERROR
  244. completionMsg = msg.getvalue()
  245. logger.error(completionMsg)
  246. # Write results to the model database for the error case. Ignore
  247. # InvalidConnectionException, as this is usually caused by orphaned models
  248. #
  249. # TODO: do we really want to set numRecords to 0? Last updated value might
  250. # be useful for debugging
  251. if type(e) is not InvalidConnectionException:
  252. jobsDAO.modelUpdateResults(modelID, results=None, numRecords=0)
  253. # TODO: Make sure this wasn't the best model in job. If so, set the best
  254. # appropriately
  255. # If this was an exception that should mark the job as failed, do that
  256. # now.
  257. if type(e) == JobFailException:
  258. workerCmpReason = jobsDAO.jobGetFields(jobID,
  259. ['workerCompletionReason'])[0]
  260. if workerCmpReason == ClientJobsDAO.CMPL_REASON_SUCCESS:
  261. jobsDAO.jobSetFields(jobID, fields=dict(
  262. cancel=True,
  263. workerCompletionReason = ClientJobsDAO.CMPL_REASON_ERROR,
  264. workerCompletionMsg = ": ".join(str(i) for i in e.args)),
  265. useConnectionID=False,
  266. ignoreUnchanged=True)
  267. return (completionReason, completionMsg)
  268. def runModelGivenBaseAndParams(modelID, jobID, baseDescription, params,
  269. predictedField, reportKeys, optimizeKey, jobsDAO,
  270. modelCheckpointGUID, logLevel=None, predictionCacheMaxRecords=None):
  271. """ This creates an experiment directory with a base.py description file
  272. created from 'baseDescription' and a description.py generated from the
  273. given params dict and then runs the experiment.
  274. Parameters:
  275. -------------------------------------------------------------------------
  276. modelID: ID for this model in the models table
  277. jobID: ID for this hypersearch job in the jobs table
  278. baseDescription: Contents of a description.py with the base experiment
  279. description
  280. params: Dictionary of specific parameters to override within
  281. the baseDescriptionFile.
  282. predictedField: Name of the input field for which this model is being
  283. optimized
  284. reportKeys: Which metrics of the experiment to store into the
  285. results dict of the model's database entry
  286. optimizeKey: Which metric we are optimizing for
  287. jobsDAO Jobs data access object - the interface to the
  288. jobs database which has the model's table.
  289. modelCheckpointGUID: A persistent, globally-unique identifier for
  290. constructing the model checkpoint key
  291. logLevel: override logging level to this value, if not None
  292. retval: (completionReason, completionMsg)
  293. """
  294. from nupic.swarming.ModelRunner import OPFModelRunner
  295. # The logger for this method
  296. logger = logging.getLogger('com.numenta.nupic.hypersearch.utils')
  297. # --------------------------------------------------------------------------
  298. # Create a temp directory for the experiment and the description files
  299. experimentDir = tempfile.mkdtemp()
  300. try:
  301. logger.info("Using experiment directory: %s" % (experimentDir))
  302. # Create the decription.py from the overrides in params
  303. paramsFilePath = os.path.join(experimentDir, 'description.py')
  304. paramsFile = open(paramsFilePath, 'wb')
  305. paramsFile.write(_paramsFileHead())
  306. items = params.items()
  307. items.sort()
  308. for (key,value) in items:
  309. quotedKey = _quoteAndEscape(key)
  310. if isinstance(value, basestring):
  311. paramsFile.write(" %s : '%s',\n" % (quotedKey , value))
  312. else:
  313. paramsFile.write(" %s : %s,\n" % (quotedKey , value))
  314. paramsFile.write(_paramsFileTail())
  315. paramsFile.close()
  316. # Write out the base description
  317. baseParamsFile = open(os.path.join(experimentDir, 'base.py'), 'wb')
  318. baseParamsFile.write(baseDescription)
  319. baseParamsFile.close()
  320. # Store the experiment's sub-description file into the model table
  321. # for reference
  322. fd = open(paramsFilePath)
  323. expDescription = fd.read()
  324. fd.close()
  325. jobsDAO.modelSetFields(modelID, {'genDescription': expDescription})
  326. # Run the experiment now
  327. try:
  328. runner = OPFModelRunner(
  329. modelID=modelID,
  330. jobID=jobID,
  331. predictedField=predictedField,
  332. experimentDir=experimentDir,
  333. reportKeyPatterns=reportKeys,
  334. optimizeKeyPattern=optimizeKey,
  335. jobsDAO=jobsDAO,
  336. modelCheckpointGUID=modelCheckpointGUID,
  337. logLevel=logLevel,
  338. predictionCacheMaxRecords=predictionCacheMaxRecords)
  339. signal.signal(signal.SIGINT, runner.handleWarningSignal)
  340. (completionReason, completionMsg) = runner.run()
  341. except InvalidConnectionException:
  342. raise
  343. except Exception, e:
  344. (completionReason, completionMsg) = _handleModelRunnerException(jobID,
  345. modelID, jobsDAO, experimentDir, logger, e)
  346. finally:
  347. # delete our temporary directory tree
  348. shutil.rmtree(experimentDir)
  349. signal.signal(signal.SIGINT, signal.default_int_handler)
  350. # Return completion reason and msg
  351. return (completionReason, completionMsg)
  352. def runDummyModel(modelID, jobID, params, predictedField, reportKeys,
  353. optimizeKey, jobsDAO, modelCheckpointGUID, logLevel=None, predictionCacheMaxRecords=None):
  354. from nupic.swarming.DummyModelRunner import OPFDummyModelRunner
  355. # The logger for this method
  356. logger = logging.getLogger('com.numenta.nupic.hypersearch.utils')
  357. # Run the experiment now
  358. try:
  359. if type(params) is bool:
  360. params = {}
  361. runner = OPFDummyModelRunner(modelID=modelID,
  362. jobID=jobID,
  363. params=params,
  364. predictedField=predictedField,
  365. reportKeyPatterns=reportKeys,
  366. optimizeKeyPattern=optimizeKey,
  367. jobsDAO=jobsDAO,
  368. modelCheckpointGUID=modelCheckpointGUID,
  369. logLevel=logLevel,
  370. predictionCacheMaxRecords=predictionCacheMaxRecords)
  371. (completionReason, completionMsg) = runner.run()
  372. # The dummy model runner will call sys.exit(1) if
  373. # NTA_TEST_sysExitFirstNModels is set and the number of models in the
  374. # models table is <= NTA_TEST_sysExitFirstNModels
  375. except SystemExit:
  376. sys.exit(1)
  377. except InvalidConnectionException:
  378. raise
  379. except Exception, e:
  380. (completionReason, completionMsg) = _handleModelRunnerException(jobID,
  381. modelID, jobsDAO, "NA",
  382. logger, e)
  383. # Return completion reason and msg
  384. return (completionReason, completionMsg)
# Passed as parameter to PeriodicActivityMgr
#
# repeating: True if the activity is a repeating activity, False if one-shot
# period:    period of activity's execution (number of "ticks")
# cb:        a callable to call upon expiration of period; will be called
#            as cb()
PeriodicActivityRequest = namedtuple("PeriodicActivityRequest",
                                     ("repeating", "period", "cb"))
  393. class PeriodicActivityMgr(object):
  394. """
  395. TODO: move to shared script so that we can share it with run_opf_experiment
  396. """
  397. # iteratorHolder: a list holding one iterator; we use a list so that we can
  398. # replace the iterator for repeating activities (a tuple would not
  399. # allow it if the field was an imutable value)
  400. Activity = namedtuple("Activity", ("repeating",
  401. "period",
  402. "cb",
  403. "iteratorHolder"))
  404. def __init__(self, requestedActivities):
  405. """
  406. requestedActivities: a sequence of PeriodicActivityRequest elements
  407. """
  408. self.__activities = []
  409. for req in requestedActivities:
  410. act = self.Activity(repeating=req.repeating,
  411. period=req.period,
  412. cb=req.cb,
  413. iteratorHolder=[iter(xrange(req.period))])
  414. self.__activities.append(act)
  415. return
  416. def tick(self):
  417. """ Activity tick handler; services all activities
  418. Returns: True if controlling iterator says it's okay to keep going;
  419. False to stop
  420. """
  421. # Run activities whose time has come
  422. for act in self.__activities:
  423. if not act.iteratorHolder[0]:
  424. continue
  425. try:
  426. next(act.iteratorHolder[0])
  427. except StopIteration:
  428. act.cb()
  429. if act.repeating:
  430. act.iteratorHolder[0] = iter(xrange(act.period))
  431. else:
  432. act.iteratorHolder[0] = None
  433. return True
  434. def generatePersistentJobGUID():
  435. """Generates a "persistentJobGUID" value.
  436. Parameters:
  437. ----------------------------------------------------------------------
  438. retval: A persistentJobGUID value
  439. """
  440. return "JOB_UUID1-" + str(uuid.uuid1())
  441. def identityConversion(value, _keys):
  442. return value
  443. def rCopy(d, f=identityConversion, discardNoneKeys=True, deepCopy=True):
  444. """Recursively copies a dict and returns the result.
  445. Args:
  446. d: The dict to copy.
  447. f: A function to apply to values when copying that takes the value and the
  448. list of keys from the root of the dict to the value and returns a value
  449. for the new dict.
  450. discardNoneKeys: If True, discard key-value pairs when f returns None for
  451. the value.
  452. deepCopy: If True, all values in returned dict are true copies (not the
  453. same object).
  454. Returns:
  455. A new dict with keys and values from d replaced with the result of f.
  456. """
  457. # Optionally deep copy the dict.
  458. if deepCopy:
  459. d = copy.deepcopy(d)
  460. newDict = {}
  461. toCopy = [(k, v, newDict, ()) for k, v in d.iteritems()]
  462. while len(toCopy) > 0:
  463. k, v, d, prevKeys = toCopy.pop()
  464. prevKeys = prevKeys + (k,)
  465. if isinstance(v, dict):
  466. d[k] = dict()
  467. toCopy[0:0] = [(innerK, innerV, d[k], prevKeys)
  468. for innerK, innerV in v.iteritems()]
  469. else:
  470. #print k, v, prevKeys
  471. newV = f(v, prevKeys)
  472. if not discardNoneKeys or newV is not None:
  473. d[k] = newV
  474. return newDict
  475. def rApply(d, f):
  476. """Recursively applies f to the values in dict d.
  477. Args:
  478. d: The dict to recurse over.
  479. f: A function to apply to values in d that takes the value and a list of
  480. keys from the root of the dict to the value.
  481. """
  482. remainingDicts = [(d, ())]
  483. while len(remainingDicts) > 0:
  484. current, prevKeys = remainingDicts.pop()
  485. for k, v in current.iteritems():
  486. keys = prevKeys + (k,)
  487. if isinstance(v, dict):
  488. remainingDicts.insert(0, (v, keys))
  489. else:
  490. f(v, keys)
  491. def clippedObj(obj, maxElementSize=64):
  492. """
  493. Return a clipped version of obj suitable for printing, This
  494. is useful when generating log messages by printing data structures, but
  495. don't want the message to be too long.
  496. If passed in a dict, list, or namedtuple, each element of the structure's
  497. string representation will be limited to 'maxElementSize' characters. This
  498. will return a new object where the string representation of each element
  499. has been truncated to fit within maxElementSize.
  500. """
  501. # Is it a named tuple?
  502. if hasattr(obj, '_asdict'):
  503. obj = obj._asdict()
  504. # Printing a dict?
  505. if isinstance(obj, dict):
  506. objOut = dict()
  507. for key,val in obj.iteritems():
  508. objOut[key] = clippedObj(val)
  509. # Printing a list?
  510. elif hasattr(obj, '__iter__'):
  511. objOut = []
  512. for val in obj:
  513. objOut.append(clippedObj(val))
  514. # Some other object
  515. else:
  516. objOut = str(obj)
  517. if len(objOut) > maxElementSize:
  518. objOut = objOut[0:maxElementSize] + '...'
  519. return objOut
class ValidationError(validictory.ValidationError):
  """Raised by validate() when a value fails json schema validation; wraps
  validictory.ValidationError so callers need not import validictory."""
  pass
  522. def validate(value, **kwds):
  523. """ Validate a python value against json schema:
  524. validate(value, schemaPath)
  525. validate(value, schemaDict)
  526. value: python object to validate against the schema
  527. The json schema may be specified either as a path of the file containing
  528. the json schema or as a python dictionary using one of the
  529. following keywords as arguments:
  530. schemaPath: Path of file containing the json schema object.
  531. schemaDict: Python dictionary containing the json schema object
  532. Returns: nothing
  533. Raises:
  534. ValidationError when value fails json validation
  535. """
  536. assert len(kwds.keys()) >= 1
  537. assert 'schemaPath' in kwds or 'schemaDict' in kwds
  538. schemaDict = None
  539. if 'schemaPath' in kwds:
  540. schemaPath = kwds.pop('schemaPath')
  541. schemaDict = loadJsonValueFromFile(schemaPath)
  542. elif 'schemaDict' in kwds:
  543. schemaDict = kwds.pop('schemaDict')
  544. try:
  545. validictory.validate(value, schemaDict, **kwds)
  546. except validictory.ValidationError as e:
  547. raise ValidationError(e)
  548. def loadJsonValueFromFile(inputFilePath):
  549. """ Loads a json value from a file and converts it to the corresponding python
  550. object.
  551. inputFilePath:
  552. Path of the json file;
  553. Returns:
  554. python value that represents the loaded json value
  555. """
  556. with open(inputFilePath) as fileObj:
  557. value = json.load(fileObj)
  558. return value
  559. def sortedJSONDumpS(obj):
  560. """
  561. Return a JSON representation of obj with sorted keys on any embedded dicts.
  562. This insures that the same object will always be represented by the same
  563. string even if it contains dicts (where the sort order of the keys is
  564. normally undefined).
  565. """
  566. itemStrs = []
  567. if isinstance(obj, dict):
  568. items = obj.items()
  569. items.sort()
  570. for key, value in items:
  571. itemStrs.append('%s: %s' % (json.dumps(key), sortedJSONDumpS(value)))
  572. return '{%s}' % (', '.join(itemStrs))
  573. elif hasattr(obj, '__iter__'):
  574. for val in obj:
  575. itemStrs.append(sortedJSONDumpS(val))
  576. return '[%s]' % (', '.join(itemStrs))
  577. else:
  578. return json.dumps(obj)