PageRenderTime 25ms CodeModel.GetById 20ms RepoModel.GetById 0ms app.codeStats 0ms

/tensorflow/examples/learn/wide_n_deep_tutorial.py

https://gitlab.com/github-cloud-corporation/tensorflow
Python | 212 lines | 193 code | 4 blank | 15 comment | 1 complexity | 3f180c53950dc81efcead891cbf523fc MD5 | raw file
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Example code for TensorFlow Wide & Deep Tutorial using TF.Learn API."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tempfile
  20. import urllib
  21. import pandas as pd
  22. import tensorflow as tf
  23. flags = tf.app.flags
  24. FLAGS = flags.FLAGS
  25. flags.DEFINE_string("model_dir", "", "Base directory for output models.")
  26. flags.DEFINE_string("model_type", "wide_n_deep",
  27. "Valid model types: {'wide', 'deep', 'wide_n_deep'}.")
  28. flags.DEFINE_integer("train_steps", 200, "Number of training steps.")
  29. flags.DEFINE_string(
  30. "train_data",
  31. "",
  32. "Path to the training data.")
  33. flags.DEFINE_string(
  34. "test_data",
  35. "",
  36. "Path to the test data.")
  37. COLUMNS = ["age", "workclass", "fnlwgt", "education", "education_num",
  38. "marital_status", "occupation", "relationship", "race", "gender",
  39. "capital_gain", "capital_loss", "hours_per_week", "native_country",
  40. "income_bracket"]
  41. LABEL_COLUMN = "label"
  42. CATEGORICAL_COLUMNS = ["workclass", "education", "marital_status", "occupation",
  43. "relationship", "race", "gender", "native_country"]
  44. CONTINUOUS_COLUMNS = ["age", "education_num", "capital_gain", "capital_loss",
  45. "hours_per_week"]
  46. def maybe_download():
  47. """May be downloads training data and returns train and test file names."""
  48. if FLAGS.train_data:
  49. train_file_name = FLAGS.train_data
  50. else:
  51. train_file = tempfile.NamedTemporaryFile(delete=False)
  52. urllib.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data", train_file.name) # pylint: disable=line-too-long
  53. train_file_name = train_file.name
  54. train_file.close()
  55. print("Training data is downloaded to %s" % train_file_name)
  56. if FLAGS.test_data:
  57. test_file_name = FLAGS.test_data
  58. else:
  59. test_file = tempfile.NamedTemporaryFile(delete=False)
  60. urllib.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test", test_file.name) # pylint: disable=line-too-long
  61. test_file_name = test_file.name
  62. test_file.close()
  63. print("Test data is downloaded to %s" % test_file_name)
  64. return train_file_name, test_file_name
  65. def build_estimator(model_dir):
  66. """Build an estimator."""
  67. # Sparse base columns.
  68. gender = tf.contrib.layers.sparse_column_with_keys(column_name="gender",
  69. keys=["female", "male"])
  70. race = tf.contrib.layers.sparse_column_with_keys(column_name="race",
  71. keys=["Amer-Indian-Eskimo",
  72. "Asian-Pac-Islander",
  73. "Black", "Other",
  74. "White"])
  75. education = tf.contrib.layers.sparse_column_with_hash_bucket(
  76. "education", hash_bucket_size=1000)
  77. marital_status = tf.contrib.layers.sparse_column_with_hash_bucket(
  78. "marital_status", hash_bucket_size=100)
  79. relationship = tf.contrib.layers.sparse_column_with_hash_bucket(
  80. "relationship", hash_bucket_size=100)
  81. workclass = tf.contrib.layers.sparse_column_with_hash_bucket(
  82. "workclass", hash_bucket_size=100)
  83. occupation = tf.contrib.layers.sparse_column_with_hash_bucket(
  84. "occupation", hash_bucket_size=1000)
  85. native_country = tf.contrib.layers.sparse_column_with_hash_bucket(
  86. "native_country", hash_bucket_size=1000)
  87. # Continuous base columns.
  88. age = tf.contrib.layers.real_valued_column("age")
  89. education_num = tf.contrib.layers.real_valued_column("education_num")
  90. capital_gain = tf.contrib.layers.real_valued_column("capital_gain")
  91. capital_loss = tf.contrib.layers.real_valued_column("capital_loss")
  92. hours_per_week = tf.contrib.layers.real_valued_column("hours_per_week")
  93. # Transformations.
  94. age_buckets = tf.contrib.layers.bucketized_column(age,
  95. boundaries=[
  96. 18, 25, 30, 35, 40, 45,
  97. 50, 55, 60, 65
  98. ])
  99. # Wide columns and deep columns.
  100. wide_columns = [gender, native_country, education, occupation, workclass,
  101. marital_status, relationship, age_buckets,
  102. tf.contrib.layers.crossed_column([education, occupation],
  103. hash_bucket_size=int(1e4)),
  104. tf.contrib.layers.crossed_column(
  105. [age_buckets, race, occupation],
  106. hash_bucket_size=int(1e6)),
  107. tf.contrib.layers.crossed_column([native_country, occupation],
  108. hash_bucket_size=int(1e4))]
  109. deep_columns = [
  110. tf.contrib.layers.embedding_column(workclass, dimension=8),
  111. tf.contrib.layers.embedding_column(education, dimension=8),
  112. tf.contrib.layers.embedding_column(marital_status,
  113. dimension=8),
  114. tf.contrib.layers.embedding_column(gender, dimension=8),
  115. tf.contrib.layers.embedding_column(relationship, dimension=8),
  116. tf.contrib.layers.embedding_column(race, dimension=8),
  117. tf.contrib.layers.embedding_column(native_country,
  118. dimension=8),
  119. tf.contrib.layers.embedding_column(occupation, dimension=8),
  120. age,
  121. education_num,
  122. capital_gain,
  123. capital_loss,
  124. hours_per_week,
  125. ]
  126. if FLAGS.model_type == "wide":
  127. m = tf.contrib.learn.LinearClassifier(model_dir=model_dir,
  128. feature_columns=wide_columns)
  129. elif FLAGS.model_type == "deep":
  130. m = tf.contrib.learn.DNNClassifier(model_dir=model_dir,
  131. feature_columns=deep_columns,
  132. hidden_units=[100, 50])
  133. else:
  134. m = tf.contrib.learn.DNNLinearCombinedClassifier(
  135. model_dir=model_dir,
  136. linear_feature_columns=wide_columns,
  137. dnn_feature_columns=deep_columns,
  138. dnn_hidden_units=[100, 50])
  139. return m
  140. def input_fn(df):
  141. """Input builder function."""
  142. # Creates a dictionary mapping from each continuous feature column name (k) to
  143. # the values of that column stored in a constant Tensor.
  144. continuous_cols = {k: tf.constant(df[k].values) for k in CONTINUOUS_COLUMNS}
  145. # Creates a dictionary mapping from each categorical feature column name (k)
  146. # to the values of that column stored in a tf.SparseTensor.
  147. categorical_cols = {k: tf.SparseTensor(
  148. indices=[[i, 0] for i in range(df[k].size)],
  149. values=df[k].values,
  150. shape=[df[k].size, 1])
  151. for k in CATEGORICAL_COLUMNS}
  152. # Merges the two dictionaries into one.
  153. feature_cols = dict(continuous_cols)
  154. feature_cols.update(categorical_cols)
  155. # Converts the label column into a constant Tensor.
  156. label = tf.constant(df[LABEL_COLUMN].values)
  157. # Returns the feature columns and the label.
  158. return feature_cols, label
  159. def train_and_eval():
  160. """Train and evaluate the model."""
  161. train_file_name, test_file_name = maybe_download()
  162. df_train = pd.read_csv(
  163. tf.gfile.Open(train_file_name),
  164. names=COLUMNS,
  165. skipinitialspace=True)
  166. df_test = pd.read_csv(
  167. tf.gfile.Open(test_file_name),
  168. names=COLUMNS,
  169. skipinitialspace=True,
  170. skiprows=1)
  171. df_train[LABEL_COLUMN] = (
  172. df_train["income_bracket"].apply(lambda x: ">50K" in x)).astype(int)
  173. df_test[LABEL_COLUMN] = (
  174. df_test["income_bracket"].apply(lambda x: ">50K" in x)).astype(int)
  175. model_dir = tempfile.mkdtemp() if not FLAGS.model_dir else FLAGS.model_dir
  176. print("model directory = %s" % model_dir)
  177. m = build_estimator(model_dir)
  178. m.fit(input_fn=lambda: input_fn(df_train), steps=FLAGS.train_steps)
  179. results = m.evaluate(input_fn=lambda: input_fn(df_test), steps=1)
  180. for key in sorted(results):
  181. print("%s: %s" % (key, results[key]))
  182. def main(_):
  183. train_and_eval()
  184. if __name__ == "__main__":
  185. tf.app.run()