/example/image-classification/train_mnist.py

https://gitlab.com/alvinahmadov2/mxnet · Python · 122 lines · 101 code · 10 blank · 11 comment · 10 complexity · 4faae946ac24bdfbea23e6349847a075 MD5 · raw file

  1. import find_mxnet
  2. import mxnet as mx
  3. import argparse
  4. import os, sys
  5. import train_model
  6. parser = argparse.ArgumentParser(description='train an image classifer on mnist')
  7. parser.add_argument('--network', type=str, default='mlp',
  8. choices = ['mlp', 'lenet'],
  9. help = 'the cnn to use')
  10. parser.add_argument('--data-dir', type=str, default='mnist/',
  11. help='the input data directory')
  12. parser.add_argument('--gpus', type=str,
  13. help='the gpus will be used, e.g "0,1,2,3"')
  14. parser.add_argument('--num-examples', type=int, default=60000,
  15. help='the number of training examples')
  16. parser.add_argument('--batch-size', type=int, default=128,
  17. help='the batch size')
  18. parser.add_argument('--lr', type=float, default=.1,
  19. help='the initial learning rate')
  20. parser.add_argument('--model-prefix', type=str,
  21. help='the prefix of the model to load/save')
  22. parser.add_argument('--num-epochs', type=int, default=10,
  23. help='the number of training epochs')
  24. parser.add_argument('--load-epoch', type=int,
  25. help="load the model on an epoch using the model-prefix")
  26. parser.add_argument('--kv-store', type=str, default='local',
  27. help='the kvstore type')
  28. parser.add_argument('--lr-factor', type=float, default=1,
  29. help='times the lr with a factor for every lr-factor-epoch epoch')
  30. parser.add_argument('--lr-factor-epoch', type=float, default=1,
  31. help='the number of epoch to factor the lr, could be .5')
  32. args = parser.parse_args()
  33. def _download(data_dir):
  34. if not os.path.isdir(data_dir):
  35. os.system("mkdir " + data_dir)
  36. os.chdir(data_dir)
  37. if (not os.path.exists('train-images-idx3-ubyte')) or \
  38. (not os.path.exists('train-labels-idx1-ubyte')) or \
  39. (not os.path.exists('t10k-images-idx3-ubyte')) or \
  40. (not os.path.exists('t10k-labels-idx1-ubyte')):
  41. os.system("wget http://webdocs.cs.ualberta.ca/~bx3/data/mnist.zip")
  42. os.system("unzip -u mnist.zip; rm mnist.zip")
  43. os.chdir("..")
  44. def get_mlp():
  45. """
  46. multi-layer perceptron
  47. """
  48. data = mx.symbol.Variable('data')
  49. fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
  50. act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
  51. fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
  52. act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
  53. fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10)
  54. mlp = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax')
  55. return mlp
  56. def get_lenet():
  57. """
  58. LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick
  59. Haffner. "Gradient-based learning applied to document recognition."
  60. Proceedings of the IEEE (1998)
  61. """
  62. data = mx.symbol.Variable('data')
  63. # first conv
  64. conv1 = mx.symbol.Convolution(data=data, kernel=(5,5), num_filter=20)
  65. tanh1 = mx.symbol.Activation(data=conv1, act_type="tanh")
  66. pool1 = mx.symbol.Pooling(data=tanh1, pool_type="max",
  67. kernel=(2,2), stride=(2,2))
  68. # second conv
  69. conv2 = mx.symbol.Convolution(data=pool1, kernel=(5,5), num_filter=50)
  70. tanh2 = mx.symbol.Activation(data=conv2, act_type="tanh")
  71. pool2 = mx.symbol.Pooling(data=tanh2, pool_type="max",
  72. kernel=(2,2), stride=(2,2))
  73. # first fullc
  74. flatten = mx.symbol.Flatten(data=pool2)
  75. fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)
  76. tanh3 = mx.symbol.Activation(data=fc1, act_type="tanh")
  77. # second fullc
  78. fc2 = mx.symbol.FullyConnected(data=tanh3, num_hidden=10)
  79. # loss
  80. lenet = mx.symbol.SoftmaxOutput(data=fc2, name='softmax')
  81. return lenet
  82. if args.network == 'mlp':
  83. data_shape = (784, )
  84. net = get_mlp()
  85. else:
  86. data_shape = (1, 28, 28)
  87. net = get_lenet()
  88. def get_iterator(args, kv):
  89. data_dir = args.data_dir
  90. if '://' not in args.data_dir:
  91. _download(args.data_dir)
  92. flat = False if len(data_shape) == 3 else True
  93. train = mx.io.MNISTIter(
  94. image = data_dir + "train-images-idx3-ubyte",
  95. label = data_dir + "train-labels-idx1-ubyte",
  96. input_shape = data_shape,
  97. batch_size = args.batch_size,
  98. shuffle = True,
  99. flat = flat,
  100. num_parts = kv.num_workers,
  101. part_index = kv.rank)
  102. val = mx.io.MNISTIter(
  103. image = data_dir + "t10k-images-idx3-ubyte",
  104. label = data_dir + "t10k-labels-idx1-ubyte",
  105. input_shape = data_shape,
  106. batch_size = args.batch_size,
  107. flat = flat,
  108. num_parts = kv.num_workers,
  109. part_index = kv.rank)
  110. return (train, val)
# Train: pass the parsed options, the selected network symbol and the
# iterator factory (called later with args and a kvstore) to fit() from the
# local train_model module, whose implementation is not visible here.
train_model.fit(args, net, get_iterator)