
/src/train_predict_keras4.py

https://gitlab.com/tianzhou2011/talkingdata
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function

import argparse
import logging
import os
import time

import numpy as np

from keras.callbacks import EarlyStopping
from keras.layers.advanced_activations import PReLU
from keras.layers.core import Dense, Dropout
from keras.models import Sequential
from keras.utils import np_utils
from sklearn.metrics import log_loss

from kaggler.data_io import load_data
from const import N_CLASS, SEED

np.random.seed(SEED)

def batch_generator(X, y, batch_size, shuffle):
    """Chenglong's code for fitting from a generator.

    (https://www.kaggle.com/c/talkingdata-mobile-user-demographics/forums/t/22567/neural-network-for-sparse-matrices)
    """
    number_of_batches = np.ceil(X.shape[0] / batch_size)
    counter = 0
    sample_index = np.arange(X.shape[0])
    if shuffle:
        np.random.shuffle(sample_index)
    while True:
        batch_index = sample_index[batch_size * counter:batch_size * (counter + 1)]
        X_batch = X[batch_index, :]
        y_batch = y[batch_index]
        counter += 1
        yield X_batch, y_batch
        if counter == number_of_batches:
            # Reshuffle at the end of each epoch so batches differ across epochs.
            if shuffle:
                np.random.shuffle(sample_index)
            counter = 0

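# A minimal sketch of how the generator above is consumed (X, Y, and the batch
# size here are illustrative placeholders, not values from this repo):
#
#   gen = batch_generator(X, Y, batch_size=16, shuffle=True)
#   X_batch, y_batch = next(gen)  # one (features, labels) mini-batch
#
# Keras 1.x keeps pulling batches from the infinite while-loop, so the
# generator never raises StopIteration; epoch length is bounded by the
# samples_per_epoch argument passed to fit_generator.
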
def batch_generatorp(X, batch_size, shuffle):
    # Prediction-time generator: yields feature batches only. Note that
    # `shuffle` is intentionally unused; batches must stay in order so that
    # predictions line up with the row indices of X.
    number_of_batches = np.ceil(X.shape[0] / batch_size)
    counter = 0
    sample_index = np.arange(X.shape[0])
    while True:
        batch_index = sample_index[batch_size * counter:batch_size * (counter + 1)]
        X_batch = X[batch_index, :]
        counter += 1
        yield X_batch
        if counter == number_of_batches:
            counter = 0

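# The training loop below pairs this with Keras 1.x predict_generator, e.g.
# (illustrative only; `model` and `X_val` are placeholders):
#
#   preds = model.predict_generator(
#       generator=batch_generatorp(X_val, batch_size, False),
#       val_samples=X_val.shape[0])
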
def baseline_model(nb_classes, dims, hiddens=2, neurons=512, dropout=0.5):
    # Create the model: an input layer followed by `hiddens` blocks of
    # Dense + PReLU + Dropout, with the layer width and dropout rate halved
    # for each successive block.
    model = Sequential()
    model.add(Dense(neurons, input_dim=dims, init='normal'))
    model.add(PReLU())
    model.add(Dropout(dropout))
    for i in range(hiddens):
        model.add(Dense(neurons // (2 ** i), init='normal'))
        model.add(PReLU())
        model.add(Dropout(dropout / (2 ** i)))
    model.add(Dense(nb_classes, init='normal', activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adadelta',
                  metrics=['accuracy'])
    return model

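# With the defaults (hiddens=2, neurons=512, dropout=0.5) the stack is:
#   Input(dims) -> Dense(512) -> PReLU -> Dropout(0.5)
#               -> Dense(512) -> PReLU -> Dropout(0.5)
#               -> Dense(256) -> PReLU -> Dropout(0.25)
#               -> Dense(nb_classes, softmax)
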
def train_predict(train_file, test_file, predict_valid_file, predict_test_file,
                  cv_id_file, n_est=100, hiddens=2, neurons=512, dropout=0.5,
                  batch=16, n_stop=2, n_fold=5):
    feature_name = os.path.basename(train_file)[:-8]
    model_name = 'keras4_{}_{}_{}_{}_{}_{}_{}'.format(
        n_est, hiddens, neurons, dropout, batch, n_stop, feature_name
    )

    logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                        level=logging.DEBUG,
                        filename='{}.log'.format(model_name))

    logging.info('Loading training and test data...')
    X, y = load_data(train_file, dense=True)
    Y = np_utils.to_categorical(y)
    X_tst, _ = load_data(test_file, dense=True)

    nb_classes = Y.shape[1]
    dims = X.shape[1]
    logging.info('{} classes, {} dims'.format(nb_classes, dims))

    logging.info('Loading CV Ids')
    cv_id = np.loadtxt(cv_id_file)

    P_val = np.zeros_like(Y)
    P_tst = np.zeros((X_tst.shape[0], nb_classes))
    for i in range(1, n_fold + 1):
        i_trn = np.where(cv_id != i)[0]
        i_val = np.where(cv_id == i)[0]

        logging.info('Training model #{}'.format(i))
        clf = baseline_model(nb_classes, dims, hiddens, neurons, dropout)
        if i == 1:
            # Use early stopping on the first fold to find the best number of
            # epochs, then train the remaining folds for exactly that many.
            early_stopping = EarlyStopping(monitor='val_loss', patience=n_stop)
            h = clf.fit_generator(generator=batch_generator(X[i_trn], Y[i_trn], batch, True),
                                  nb_epoch=n_est,
                                  samples_per_epoch=len(i_trn),
                                  validation_data=(X[i_val], Y[i_val]),
                                  verbose=2,
                                  callbacks=[early_stopping])
            val_losses = h.history['val_loss']
            n_best = val_losses.index(min(val_losses)) + 1
            logging.info('best epoch={}'.format(n_best))
        else:
            clf.fit_generator(generator=batch_generator(X[i_trn], Y[i_trn], batch, True),
                              nb_epoch=n_best,
                              samples_per_epoch=len(i_trn),
                              validation_data=(X[i_val], Y[i_val]),
                              verbose=2)

        P_val[i_val] = clf.predict_generator(generator=batch_generatorp(X[i_val], batch, False),
                                             val_samples=X[i_val].shape[0])
        logging.info('CV #{} Log Loss: {:.6f}'.format(i, log_loss(Y[i_val], P_val[i_val])))

        P_tst += clf.predict_generator(generator=batch_generatorp(X_tst, batch, False),
                                       val_samples=X_tst.shape[0]) / n_fold

    logging.info('Saving normalized validation predictions...')
    logging.info('CV Log Loss: {:.6f}'.format(log_loss(Y, P_val)))
    np.savetxt(predict_valid_file, P_val, fmt='%.6f', delimiter=',')

    logging.info('Saving normalized test predictions...')
    np.savetxt(predict_test_file, P_tst, fmt='%.6f', delimiter=',')

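# Both saved files are plain CSV probability matrices of shape
# (n_samples, nb_classes): predict_valid_file holds the out-of-fold
# predictions for the training set, and predict_test_file holds test
# predictions averaged over the n_fold models.
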
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-file', required=True, dest='train_file')
    parser.add_argument('--test-file', required=True, dest='test_file')
    parser.add_argument('--predict-valid-file', required=True,
                        dest='predict_valid_file')
    parser.add_argument('--predict-test-file', required=True,
                        dest='predict_test_file')
    parser.add_argument('--n-est', default=10, type=int, dest='n_est')
    parser.add_argument('--batch-size', default=64, type=int,
                        dest='batch_size')
    parser.add_argument('--hiddens', default=2, type=int)
    parser.add_argument('--neurons', default=512, type=int)
    parser.add_argument('--dropout', default=0.5, type=float)
    parser.add_argument('--early-stopping', default=2, type=int, dest='n_stop')
    parser.add_argument('--cv-id', required=True, dest='cv_id_file')
    args = parser.parse_args()

    start = time.time()
    train_predict(train_file=args.train_file,
                  test_file=args.test_file,
                  predict_valid_file=args.predict_valid_file,
                  predict_test_file=args.predict_test_file,
                  cv_id_file=args.cv_id_file,
                  n_est=args.n_est,
                  neurons=args.neurons,
                  dropout=args.dropout,
                  batch=args.batch_size,
                  hiddens=args.hiddens,
                  n_stop=args.n_stop)
    logging.info('finished ({:.2f} min elapsed)'.format((time.time() - start) /
                                                        60))
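
# Example invocation (all paths are illustrative placeholders, not files
# shipped with this repo; note that feature_name strips the last 8 characters
# of the training file name, e.g. a '.trn.sps' suffix):
#
#   python train_predict_keras4.py \
#       --train-file build/feature.trn.sps \
#       --test-file build/feature.tst.sps \
#       --predict-valid-file build/keras4.val.yht \
#       --predict-test-file build/keras4.tst.yht \
#       --cv-id build/cv_id.txt \
#       --n-est 100 --batch-size 64 --early-stopping 2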