PageRenderTime 24ms CodeModel.GetById 24ms RepoModel.GetById 0ms app.codeStats 0ms

/tensorflow/python/training/server_lib.py

https://gitlab.com/admin-github-cloud/tensorflow
Python | 345 lines | 295 code | 7 blank | 43 comment | 8 complexity | 458603252068c2ef6a6b728da1d6b3db MD5 | raw file
  1. # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """A Python interface for creating TensorFlow servers."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import six # pylint: disable=unused-import
  20. from tensorflow.core.protobuf import tensorflow_server_pb2
  21. from tensorflow.python import pywrap_tensorflow
  22. from tensorflow.python.framework import errors
  23. from tensorflow.python.util import compat
  24. def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
  25. config):
  26. """Creates a `tf.train.ServerDef` protocol buffer.
  27. Args:
  28. server_or_cluster_def: A `tf.train.ServerDef` or
  29. `tf.train.ClusterDef` protocol buffer, or a
  30. `tf.train.ClusterSpec` object, describing the server to be
  31. defined and/or the cluster of which it is a member.
  32. job_name: (Optional.) Specifies the name of the job of which the server
  33. is a member. Defaults to the value in `server_or_cluster_def`, if
  34. specified.
  35. task_index: (Optional.) Specifies the task index of the server in its job.
  36. Defaults to the value in `server_or_cluster_def`, if specified. Otherwise
  37. defaults to 0 if the server's job has only one task.
  38. protocol: (Optional.) Specifies the protocol to be used by the server.
  39. Acceptable values include `"grpc"`. Defaults to the value in
  40. `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
  41. config: (Options.) A `tf.ConfigProto` that specifies default configuration
  42. options for all sessions that run on this server.
  43. Returns:
  44. A `tf.train.ServerDef`.
  45. Raises:
  46. TypeError: If the arguments do not have the appropriate type.
  47. ValueError: If an argument is not specified and cannot be inferred.
  48. """
  49. server_def = tensorflow_server_pb2.ServerDef()
  50. if isinstance(server_or_cluster_def, tensorflow_server_pb2.ServerDef):
  51. server_def.MergeFrom(server_or_cluster_def)
  52. if job_name is not None:
  53. server_def.job_name = job_name
  54. if task_index is not None:
  55. server_def.task_index = task_index
  56. if protocol is not None:
  57. server_def.protocol = protocol
  58. if config is not None:
  59. server_def.default_session_config.MergeFrom(config)
  60. else:
  61. try:
  62. cluster_spec = ClusterSpec(server_or_cluster_def)
  63. except TypeError:
  64. raise TypeError("Could not convert `server_or_cluster_def` to a "
  65. "`tf.train.ServerDef` or `tf.train.ClusterSpec`.")
  66. if job_name is None:
  67. if len(cluster_spec.jobs) == 1:
  68. job_name = cluster_spec.jobs[0]
  69. else:
  70. raise ValueError("Must specify an explicit `job_name`.")
  71. if task_index is None:
  72. if len(cluster_spec.job_tasks(job_name)) == 1:
  73. task_index = 0
  74. else:
  75. raise ValueError("Must specify an explicit `task_index`.")
  76. if protocol is None:
  77. protocol = "grpc"
  78. server_def = tensorflow_server_pb2.ServerDef(
  79. cluster=cluster_spec.as_cluster_def(),
  80. job_name=job_name, task_index=task_index, protocol=protocol)
  81. if config is not None:
  82. server_def.default_session_config.MergeFrom(config)
  83. return server_def
  84. class Server(object):
  85. """An in-process TensorFlow server, for use in distributed training.
  86. A `tf.train.Server` instance encapsulates a set of devices and a
  87. [`tf.Session`](../../api_docs/python/client.md#Session) target that
  88. can participate in distributed training. A server belongs to a
  89. cluster (specified by a [`tf.train.ClusterSpec`](#ClusterSpec)), and
  90. corresponds to a particular task in a named job. The server can
  91. communicate with any other server in the same cluster.
  92. @@__init__
  93. @@create_local_server
  94. @@target
  95. @@server_def
  96. @@start
  97. @@join
  98. """
  99. def __init__(self,
  100. server_or_cluster_def,
  101. job_name=None,
  102. task_index=None,
  103. protocol=None,
  104. config=None,
  105. start=True):
  106. """Creates a new server with the given definition.
  107. The `job_name`, `task_index`, and `protocol` arguments are optional, and
  108. override any information provided in `server_or_cluster_def`.
  109. Args:
  110. server_or_cluster_def: A `tf.train.ServerDef` or
  111. `tf.train.ClusterDef` protocol buffer, or a
  112. `tf.train.ClusterSpec` object, describing the server to be
  113. created and/or the cluster of which it is a member.
  114. job_name: (Optional.) Specifies the name of the job of which the server
  115. is a member. Defaults to the value in `server_or_cluster_def`, if
  116. specified.
  117. task_index: (Optional.) Specifies the task index of the server in its
  118. job. Defaults to the value in `server_or_cluster_def`, if specified.
  119. Otherwise defaults to 0 if the server's job has only one task.
  120. protocol: (Optional.) Specifies the protocol to be used by the server.
  121. Acceptable values include `"grpc"`. Defaults to the value in
  122. `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
  123. config: (Options.) A `tf.ConfigProto` that specifies default
  124. configuration options for all sessions that run on this server.
  125. start: (Optional.) Boolean, indicating whether to start the server
  126. after creating it. Defaults to `True`.
  127. Raises:
  128. tf.errors.OpError: Or one of its subclasses if an error occurs while
  129. creating the TensorFlow server.
  130. """
  131. self._server_def = _make_server_def(server_or_cluster_def,
  132. job_name, task_index, protocol, config)
  133. with errors.raise_exception_on_not_ok_status() as status:
  134. self._server = pywrap_tensorflow.PyServer_New(
  135. self._server_def.SerializeToString(), status)
  136. if start:
  137. self.start()
  138. def start(self):
  139. """Starts this server.
  140. Raises:
  141. tf.errors.OpError: Or one of its subclasses if an error occurs while
  142. starting the TensorFlow server.
  143. """
  144. with errors.raise_exception_on_not_ok_status() as status:
  145. pywrap_tensorflow.PyServer_Start(self._server, status)
  146. def join(self):
  147. """Blocks until the server has shut down.
  148. This method currently blocks forever.
  149. Raises:
  150. tf.errors.OpError: Or one of its subclasses if an error occurs while
  151. joining the TensorFlow server.
  152. """
  153. with errors.raise_exception_on_not_ok_status() as status:
  154. pywrap_tensorflow.PyServer_Join(self._server, status)
  155. @property
  156. def server_def(self):
  157. """Returns the `tf.train.ServerDef` for this server.
  158. Returns:
  159. A `tf.train.ServerDef` protocol buffer that describes the configuration
  160. of this server.
  161. """
  162. return self._server_def
  163. @property
  164. def target(self):
  165. """Returns the target for a `tf.Session` to connect to this server.
  166. To create a
  167. [`tf.Session`](../../api_docs/python/client.md#Session) that
  168. connects to this server, use the following snippet:
  169. ```python
  170. server = tf.train.Server(...)
  171. with tf.Session(server.target):
  172. # ...
  173. ```
  174. Returns:
  175. A string containing a session target for this server.
  176. """
  177. return self._server.target()
  178. @staticmethod
  179. def create_local_server(config=None, start=True):
  180. """Creates a new single-process cluster running on the local host.
  181. This method is a convenience wrapper for creating a
  182. `tf.train.Server` with a `tf.train.ServerDef` that specifies a
  183. single-process cluster containing a single task in a job called
  184. `"local"`.
  185. Args:
  186. config: (Options.) A `tf.ConfigProto` that specifies default
  187. configuration options for all sessions that run on this server.
  188. start: (Optional.) Boolean, indicating whether to start the server after
  189. creating it. Defaults to `True`.
  190. Returns:
  191. A local `tf.train.Server`.
  192. """
  193. # Specifying port 0 means that the OS will choose a free port for the
  194. # server.
  195. return Server({"local": ["localhost:0"]}, protocol="grpc", config=config,
  196. start=start)
  197. class ClusterSpec(object):
  198. """Represents a cluster as a set of "tasks", organized into "jobs".
  199. A `tf.train.ClusterSpec` represents the set of processes that
  200. participate in a distributed TensorFlow computation. Every
  201. [`tf.train.Server`](#Server) is constructed in a particular cluster.
  202. To create a cluster with two jobs and five tasks, you specify the
  203. mapping from job names to lists of network addresses (typically
  204. hostname-port pairs).
  205. ```
  206. cluster = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222",
  207. "worker1.example.com:2222",
  208. "worker2.example.com:2222"],
  209. "ps": ["ps0.example.com:2222",
  210. "ps1.example.com:2222"]})
  211. ```
  212. @@as_cluster_def
  213. @@as_dict
  214. """
  215. def __init__(self, cluster):
  216. """Creates a `ClusterSpec`.
  217. Args:
  218. cluster: A dictionary mapping one or more job names to lists of network
  219. addresses, or a `tf.train.ClusterDef` protocol buffer.
  220. Raises:
  221. TypeError: If `cluster` is not a dictionary mapping strings to lists
  222. of strings, and not a `tf.train.ClusterDef` protobuf.
  223. """
  224. if isinstance(cluster, dict):
  225. self._cluster_spec = cluster
  226. self._make_cluster_def()
  227. elif isinstance(cluster, tensorflow_server_pb2.ClusterDef):
  228. self._cluster_def = cluster
  229. self._cluster_spec = {}
  230. for job_def in self._cluster_def.job:
  231. self._cluster_spec[job_def.name] = [t for t in job_def.tasks.values()]
  232. elif isinstance(cluster, ClusterSpec):
  233. self._cluster_def = tensorflow_server_pb2.ClusterDef()
  234. self._cluster_def.MergeFrom(cluster.as_cluster_def())
  235. self._cluster_spec = {}
  236. for job_def in self._cluster_def.job:
  237. self._cluster_spec[job_def.name] = [t for t in job_def.tasks.values()]
  238. else:
  239. raise TypeError("`cluster` must be a dictionary mapping one or more "
  240. "job names to lists of network addresses, or a "
  241. "`ClusterDef` protocol buffer")
  242. def as_dict(self):
  243. """Returns a dictionary from job names to lists of network addresses."""
  244. return self._cluster_spec
  245. def as_cluster_def(self):
  246. """Returns a `tf.train.ClusterDef` protocol buffer based on this cluster."""
  247. return self._cluster_def
  248. @property
  249. def jobs(self):
  250. """Returns a list of job names in this cluster.
  251. Returns:
  252. A list of strings, corresponding to the names of jobs in this cluster.
  253. """
  254. return list(self._cluster_spec.keys())
  255. def job_tasks(self, job_name):
  256. """Returns a list of tasks in the given job.
  257. Args:
  258. job_name: The string name of a job in this cluster.
  259. Returns:
  260. A list of strings, corresponding to the network addresses of tasks in
  261. the given job, ordered by task index.
  262. Raises:
  263. ValueError: If `job_name` does not name a job in this cluster.
  264. """
  265. try:
  266. return [task for task in self._cluster_spec[job_name]]
  267. except IndexError:
  268. raise ValueError("No such job in cluster: %r" % job_name)
  269. def _make_cluster_def(self):
  270. """Creates a `tf.train.ClusterDef` based on the given `cluster_spec`.
  271. Raises:
  272. TypeError: If `cluster_spec` is not a dictionary mapping strings to lists
  273. of strings.
  274. """
  275. self._cluster_def = tensorflow_server_pb2.ClusterDef()
  276. # NOTE(mrry): Sort by job_name to produce deterministic protobufs.
  277. for job_name, task_list in sorted(self._cluster_spec.items()):
  278. try:
  279. job_name = compat.as_bytes(job_name)
  280. except TypeError:
  281. raise TypeError("Job name %r must be bytes or unicode" % job_name)
  282. job_def = self._cluster_def.job.add()
  283. job_def.name = job_name
  284. for i, task_address in enumerate(task_list):
  285. try:
  286. task_address = compat.as_bytes(task_address)
  287. except TypeError:
  288. raise TypeError(
  289. "Task address %r must be bytes or unicode" % task_address)
  290. job_def.tasks[i] = task_address