PageRenderTime 83ms CodeModel.GetById 33ms RepoModel.GetById 1ms app.codeStats 0ms

/tensorflow/python/training/server_lib.py

https://gitlab.com/hrishikeshvganu/tensorflow
Python | 321 lines | 273 code | 7 blank | 41 comment | 7 complexity | 7406ebaeb498578f058fee862220548f MD5 | raw file
  1. # Copyright 2015 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """A Python interface for creating TensorFlow servers."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import six # pylint: disable=unused-import
  20. from tensorflow.core.protobuf import tensorflow_server_pb2
  21. from tensorflow.python import pywrap_tensorflow
  22. from tensorflow.python.framework import errors
  23. from tensorflow.python.util import compat
  24. def _make_server_def(server_or_cluster_def, job_name, task_index, protocol):
  25. """Creates a `tf.train.ServerDef` protocol buffer.
  26. Args:
  27. server_or_cluster_def: A `tf.train.ServerDef` or
  28. `tf.train.ClusterDef` protocol buffer, or a
  29. `tf.train.ClusterSpec` object, describing the server to be
  30. defined and/or the cluster of which it is a member.
  31. job_name: (Optional.) Specifies the name of the job of which the server
  32. is a member. Defaults to the value in `server_or_cluster_def`, if
  33. specified.
  34. task_index: (Optional.) Specifies the task index of the server in its job.
  35. Defaults to the value in `server_or_cluster_def`, if specified. Otherwise
  36. defaults to 0 if the server's job has only one task.
  37. protocol: (Optional.) Specifies the protocol to be used by the server.
  38. Acceptable values include `"grpc"`. Defaults to the value in
  39. `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
  40. Returns:
  41. A `tf.train.ServerDef`.
  42. Raises:
  43. TypeError: If the arguments do not have the appropriate type.
  44. ValueError: If an argument is not specified and cannot be inferred.
  45. """
  46. server_def = tensorflow_server_pb2.ServerDef()
  47. if isinstance(server_or_cluster_def, tensorflow_server_pb2.ServerDef):
  48. server_def.MergeFrom(server_or_cluster_def)
  49. if job_name is not None:
  50. server_def.job_name = job_name
  51. if task_index is not None:
  52. server_def.task_index = task_index
  53. if protocol is not None:
  54. server_def.protocol = protocol
  55. else:
  56. try:
  57. cluster_spec = ClusterSpec(server_or_cluster_def)
  58. except TypeError:
  59. raise TypeError("Could not convert `server_or_cluster_def` to a "
  60. "`tf.train.ServerDef` or `tf.train.ClusterSpec`.")
  61. if job_name is None:
  62. if len(cluster_spec.jobs) == 1:
  63. job_name = cluster_spec.jobs[0]
  64. else:
  65. raise ValueError("Must specify an explicit `job_name`.")
  66. if task_index is None:
  67. if len(cluster_spec.job_tasks(job_name)) == 1:
  68. task_index = 0
  69. else:
  70. raise ValueError("Must specify an explicit `task_index`.")
  71. if protocol is None:
  72. protocol = "grpc"
  73. server_def = tensorflow_server_pb2.ServerDef(
  74. cluster=cluster_spec.as_cluster_def(),
  75. job_name=job_name, task_index=task_index, protocol=protocol)
  76. return server_def
  77. class Server(object):
  78. """An in-process TensorFlow server, for use in distributed training.
  79. A `tf.train.Server` instance encapsulates a set of devices and a
  80. [`tf.Session`](../../api_docs/python/client.md#Session) target that
  81. can participate in distributed training. A server belongs to a
  82. cluster (specified by a [`tf.train.ClusterSpec`](#ClusterSpec)), and
  83. corresponds to a particular task in a named job. The server can
  84. communicate with any other server in the same cluster.
  85. @@__init__
  86. @@create_local_server
  87. @@target
  88. @@start
  89. @@join
  90. """
  91. def __init__(self,
  92. server_or_cluster_def,
  93. job_name=None,
  94. task_index=None,
  95. protocol=None,
  96. start=True):
  97. """Creates a new server with the given definition.
  98. The `job_name`, `task_index`, and `protocol` arguments are optional, and
  99. override any information provided in `server_or_cluster_def`.
  100. Args:
  101. server_or_cluster_def: A `tf.train.ServerDef` or
  102. `tf.train.ClusterDef` protocol buffer, or a
  103. `tf.train.ClusterSpec` object, describing the server to be
  104. created and/or the cluster of which it is a member.
  105. job_name: (Optional.) Specifies the name of the job of which the server
  106. is a member. Defaults to the value in `server_or_cluster_def`, if
  107. specified.
  108. task_index: (Optional.) Specifies the task index of the server in its
  109. job. Defaults to the value in `server_or_cluster_def`, if specified.
  110. Otherwise defaults to 0 if the server's job has only one task.
  111. protocol: (Optional.) Specifies the protocol to be used by the server.
  112. Acceptable values include `"grpc"`. Defaults to the value in
  113. `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
  114. start: (Optional.) Boolean, indicating whether to start the server
  115. after creating it. Defaults to `True`.
  116. Raises:
  117. tf.errors.OpError: Or one of its subclasses if an error occurs while
  118. creating the TensorFlow server.
  119. """
  120. server_def = _make_server_def(server_or_cluster_def,
  121. job_name, task_index, protocol)
  122. with errors.raise_exception_on_not_ok_status() as status:
  123. self._server = pywrap_tensorflow.PyServer_New(
  124. server_def.SerializeToString(), status)
  125. if start:
  126. self.start()
  127. def start(self):
  128. """Starts this server.
  129. Raises:
  130. tf.errors.OpError: Or one of its subclasses if an error occurs while
  131. starting the TensorFlow server.
  132. """
  133. with errors.raise_exception_on_not_ok_status() as status:
  134. pywrap_tensorflow.PyServer_Start(self._server, status)
  135. def join(self):
  136. """Blocks until the server has shut down.
  137. This method currently blocks forever.
  138. Raises:
  139. tf.errors.OpError: Or one of its subclasses if an error occurs while
  140. joining the TensorFlow server.
  141. """
  142. with errors.raise_exception_on_not_ok_status() as status:
  143. pywrap_tensorflow.PyServer_Join(self._server, status)
  144. @property
  145. def target(self):
  146. """Returns the target for a `tf.Session` to connect to this server.
  147. To create a
  148. [`tf.Session`](../../api_docs/python/client.md#Session) that
  149. connects to this server, use the following snippet:
  150. ```python
  151. server = tf.train.Server(...)
  152. with tf.Session(server.target):
  153. # ...
  154. ```
  155. Returns:
  156. A string containing a session target for this server.
  157. """
  158. return self._server.target()
  159. @staticmethod
  160. def create_local_server(start=True):
  161. """Creates a new single-process cluster running on the local host.
  162. This method is a convenience wrapper for creating a
  163. `tf.train.Server` with a `tf.train.ServerDef` that specifies a
  164. single-process cluster containing a single task in a job called
  165. `"local"`.
  166. Args:
  167. start: (Optional.) Boolean, indicating whether to start the server after
  168. creating it. Defaults to `True`.
  169. Returns:
  170. A local `tf.train.Server`.
  171. """
  172. # Specifying port 0 means that the OS will choose a free port for the
  173. # server.
  174. return Server({"local": ["localhost:0"]}, protocol="grpc", start=start)
  175. class ClusterSpec(object):
  176. """Represents a cluster as a set of "tasks", organized into "jobs".
  177. A `tf.train.ClusterSpec` represents the set of processes that
  178. participate in a distributed TensorFlow computation. Every
  179. [`tf.train.Server`](#Server) is constructed in a particular cluster.
  180. To create a cluster with two jobs and five tasks, you specify the
  181. mapping from job names to lists of network addresses (typically
  182. hostname-port pairs).
  183. ```
  184. cluster = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222",
  185. "worker1.example.com:2222",
  186. "worker2.example.com:2222"],
  187. "ps": ["ps0.example.com:2222",
  188. "ps1.example.com:2222"]})
  189. ```
  190. @@as_cluster_def
  191. @@as_dict
  192. """
  193. def __init__(self, cluster):
  194. """Creates a `ClusterSpec`.
  195. Args:
  196. cluster: A dictionary mapping one or more job names to lists of network
  197. addresses, or a `tf.train.ClusterDef` protocol buffer.
  198. Raises:
  199. TypeError: If `cluster` is not a dictionary mapping strings to lists
  200. of strings, and not a `tf.train.ClusterDef` protobuf.
  201. """
  202. if isinstance(cluster, dict):
  203. self._cluster_spec = cluster
  204. self._make_cluster_def()
  205. elif isinstance(cluster, tensorflow_server_pb2.ClusterDef):
  206. self._cluster_def = cluster
  207. self._cluster_spec = {}
  208. for job_def in self._cluster_def.job:
  209. self._cluster_spec[job_def.name] = [t for t in job_def.tasks.values()]
  210. elif isinstance(cluster, ClusterSpec):
  211. self._cluster_def = tensorflow_server_pb2.ClusterDef()
  212. self._cluster_def.MergeFrom(cluster.as_cluster_def())
  213. self._cluster_spec = {}
  214. for job_def in self._cluster_def.job:
  215. self._cluster_spec[job_def.name] = [t for t in job_def.tasks.values()]
  216. else:
  217. raise TypeError("`cluster` must be a dictionary mapping one or more "
  218. "job names to lists of network addresses, or a "
  219. "`ClusterDef` protocol buffer")
  220. def as_dict(self):
  221. """Returns a dictionary from job names to lists of network addresses."""
  222. return self._cluster_spec
  223. def as_cluster_def(self):
  224. """Returns a `tf.train.ClusterDef` protocol buffer based on this cluster."""
  225. return self._cluster_def
  226. @property
  227. def jobs(self):
  228. """Returns a list of job names in this cluster.
  229. Returns:
  230. A list of strings, corresponding to the names of jobs in this cluster.
  231. """
  232. return list(self._cluster_spec.keys())
  233. def job_tasks(self, job_name):
  234. """Returns a list of tasks in the given job.
  235. Args:
  236. job_name: The string name of a job in this cluster.
  237. Returns:
  238. A list of strings, corresponding to the network addresses of tasks in
  239. the given job, ordered by task index.
  240. Raises:
  241. ValueError: If `job_name` does not name a job in this cluster.
  242. """
  243. try:
  244. return [task for task in self._cluster_spec[job_name]]
  245. except IndexError:
  246. raise ValueError("No such job in cluster: %r" % job_name)
  247. def _make_cluster_def(self):
  248. """Creates a `tf.train.ClusterDef` based on the given `cluster_spec`.
  249. Raises:
  250. TypeError: If `cluster_spec` is not a dictionary mapping strings to lists
  251. of strings.
  252. """
  253. self._cluster_def = tensorflow_server_pb2.ClusterDef()
  254. # NOTE(mrry): Sort by job_name to produce deterministic protobufs.
  255. for job_name, task_list in sorted(self._cluster_spec.items()):
  256. try:
  257. job_name = compat.as_bytes(job_name)
  258. except TypeError:
  259. raise TypeError("Job name %r must be bytes or unicode" % job_name)
  260. job_def = self._cluster_def.job.add()
  261. job_def.name = job_name
  262. for i, task_address in enumerate(task_list):
  263. try:
  264. task_address = compat.as_bytes(task_address)
  265. except TypeError:
  266. raise TypeError(
  267. "Task address %r must be bytes or unicode" % task_address)
  268. job_def.tasks[i] = task_address