PageRenderTime 52ms CodeModel.GetById 28ms RepoModel.GetById 1ms app.codeStats 0ms

/examples/argument_mining/neuro/nn_triple_sequence.py

https://gitlab.com/purdueNlp/DRaiL
Python | 302 lines | 213 code | 64 blank | 25 comment | 39 complexity | ba5419c3122a38e592938728b1d7ed2f MD5 | raw file
  1. import torch
  2. import numpy as np
  3. from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
  4. import torch.nn.functional as F
  5. from drail.neuro.nn_model import NeuralNetworks
  6. class TripleSequenceNet(NeuralNetworks):
  7. def __init__(self, config, nn_id, use_gpu, output_dim):
  8. super(TripleSequenceNet, self).__init__(config, nn_id)
  9. self.use_gpu = use_gpu
  10. self.output_dim = output_dim
  11. #print "output_dim", output_dim
  12. def build_architecture(self, rule_template, fe, shared_params={}):
  13. self.minibatch_size = self.config['batch_size']
  14. self.embedding_layers, emb_dims = \
  15. self._embedding_inputs(rule_template, fe)
  16. if "shared_lstm_1" in self.config:
  17. name = self.config["shared_lstm_1"]
  18. self.n_input_sequence_1 = shared_params[name]["nin"]
  19. self.n_hidden_sequence_1 = shared_params[name]["nout"]
  20. self.sequence_lstm_1 = shared_params[name]["lstm"]
  21. else:
  22. self.n_input_sequence_1 = self.config['n_input_sequence_1']
  23. self.n_hidden_sequence_1 = self.config["n_hidden_sequence_1"]
  24. # LSTM for the sequence
  25. self.sequence_lstm_1 =\
  26. torch.nn.LSTM(input_size=self.n_input_sequence_1,
  27. hidden_size=self.n_hidden_sequence_1,
  28. bidirectional=True,
  29. batch_first=True)
  30. if "shared_reln_1" in self.config:
  31. name = self.config["shared_reln_1"]
  32. self.layer2hidden_1 = shared_params[name]["layer"]
  33. self.n_hidden_layer_1 = shared_params[name]["nout"]
  34. else:
  35. self.n_hidden_layer_1 = self.config["n_hidden_layer_1"]
  36. self.n_input_vector_1 = 0
  37. if 'n_input_vector_1' in self.config:
  38. self.n_input_vector_1 = self.config['n_input_vector_1']
  39. self.layer2hidden_1 =\
  40. torch.nn.Linear(self.n_hidden_sequence_1*2 + self.n_input_vector_1, self.n_hidden_layer_1)
  41. if "shared_lstm_2" in self.config:
  42. name = self.config["shared_lstm_2"]
  43. self.n_input_sequence_2 = shared_params[name]["nin"]
  44. self.n_hidden_sequence_2 = shared_params[name]["nout"]
  45. self.sequence_lstm_2 = shared_params[name]["lstm"]
  46. else:
  47. self.n_input_sequence_2 = self.config['n_input_sequence_2']
  48. self.n_hidden_sequence_2 = self.config["n_hidden_sequence_2"]
  49. # LSTM for the sequence
  50. self.sequence_lstm_2 =\
  51. torch.nn.LSTM(input_size=self.n_input_sequence_2,
  52. hidden_size=self.n_hidden_sequence_2,
  53. bidirectional=True,
  54. batch_first=True)
  55. if "shared_reln_2" in self.config:
  56. name = self.config["shared_reln_2"]
  57. self.layer2hidden_2 = shared_params[name]["layer"]
  58. self.n_hidden_layer_2 = shared_params[name]["nout"]
  59. else:
  60. self.n_hidden_layer_2 = self.config["n_hidden_layer_2"]
  61. self.n_input_vector_2 = 0
  62. if 'n_input_vector_2' in self.config:
  63. self.n_input_vector_2 = self.config['n_input_vector_2']
  64. self.layer2hidden_2 =\
  65. torch.nn.Linear(self.n_hidden_sequence_2*2 + self.n_input_vector_2, self.n_hidden_layer_2)
  66. if "shared_lstm_3" in self.config:
  67. name = self.config["shared_lstm_3"]
  68. self.n_input_sequence_3 = shared_params[name]["nin"]
  69. self.n_hidden_sequence_3 = shared_params[name]["nout"]
  70. self.sequence_lstm_3 = shared_params[name]["lstm"]
  71. else:
  72. self.n_input_sequence_3 = self.config['n_input_sequence_3']
  73. self.n_hidden_sequence_3 = self.config["n_hidden_sequence_3"]
  74. # LSTM for the sequence
  75. self.sequence_lstm_3 =\
  76. torch.nn.LSTM(input_size=self.n_input_sequence_3,
  77. hidden_size=self.n_hidden_sequence_3,
  78. bidirectional=True,
  79. batch_first=True)
  80. if "shared_reln_3" in self.config:
  81. name = self.config["shared_reln_3"]
  82. self.layer2hidden_3 = shared_params[name]["layer"]
  83. self.n_hidden_layer_3 = shared_params[name]["nout"]
  84. else:
  85. self.n_hidden_layer_3 = self.config["n_hidden_layer_3"]
  86. self.n_input_vector_3 = 0
  87. if 'n_input_vector_3' in self.config:
  88. self.n_input_vector_3 = self.config['n_input_vector_3']
  89. self.layer2hidden_3 =\
  90. torch.nn.Linear(self.n_hidden_sequence_3*2 + self.n_input_vector_3, self.n_hidden_layer_3)
  91. self.n_extra = 0
  92. if "shared_extra" in self.config:
  93. name = self.config["shared_extra"]
  94. self.extra_layer = shared_params[name]["layer"]
  95. self.n_extra = self.config["n_extra"]
  96. self.n_input_reln = 0
  97. if "n_reln_input" in self.config:
  98. self.n_input_reln = self.config["n_reln_input"]
  99. self.n_hidden_concat = self.config['n_hidden_concat']
  100. self.concat2hidden = torch.nn.Linear(self.n_hidden_layer_1 + self.n_hidden_layer_2 + self.n_hidden_layer_3 + self.n_input_reln + self.n_extra * 2, self.n_hidden_concat)
  101. self.dropout_layer = torch.nn.Dropout(p=self.config['dropout_rate'])
  102. self.hidden2label =\
  103. torch.nn.Linear(self.n_hidden_concat, self.output_dim)
  104. if self.use_gpu:
  105. self.sequence_lstm_1 = self.sequence_lstm_1.cuda()
  106. self.sequence_lstm_2 = self.sequence_lstm_2.cuda()
  107. self.sequence_lstm_3 = self.sequence_lstm_3.cuda()
  108. self.layer2hidden_1 = self.layer2hidden_1.cuda()
  109. self.layer2hidden_2 = self.layer2hidden_2.cuda()
  110. self.layer2hidden_3 = self.layer2hidden_3.cuda()
  111. self.concat2hidden = self.concat2hidden.cuda()
  112. self.dropout_layer = self.dropout_layer.cuda()
  113. self.hidden2label = self.hidden2label.cuda()
  114. self.hidden_bilstm = self.init_hidden_bilstm()
  115. def init_hidden_bilstm(self):
  116. var1 = torch.autograd.Variable(torch.zeros(2, self.minibatch_size,
  117. self.n_hidden_sequence_1))
  118. var2 = torch.autograd.Variable(torch.zeros(2, self.minibatch_size,
  119. self.n_hidden_sequence_2))
  120. if self.use_gpu:
  121. var1 = var1.cuda()
  122. var2 = var2.cuda()
  123. return (var1, var2)
  124. def _run_sequence(self, seqs, has_embedding_layer, sequence_lstm, n_input_sequence, key):
  125. self.minibatch_size = len(seqs)
  126. # get the length of each seq in your batch
  127. seq_lengths = self._get_long_tensor(list(map(len, seqs)))
  128. # dump padding everywhere, and place seqs on the left.
  129. # NOTE: you only need a tensor as big as your longest sequence
  130. max_seq_len = seq_lengths.max()
  131. # Sort according to lengths
  132. seq_len_sorted, sorted_idx = seq_lengths.sort(descending=True)
  133. if has_embedding_layer:
  134. tensor_seq = torch.zeros((len(seqs), max_seq_len)).long()
  135. if self.use_gpu:
  136. tensor_seq = tensor_seq.cuda()
  137. for idx, (seq, seqlen) in enumerate(zip(seqs, seq_lengths)):
  138. tensor_seq[idx, :seqlen] = self._get_long_tensor(seq)
  139. else:
  140. tensor_seq = torch.zeros((len(seqs), max_seq_len, n_input_sequence)).float()
  141. if self.use_gpu:
  142. tensor_seq = tensor_seq.cuda()
  143. for idx, (seq, seqlen) in enumerate(zip(seqs, seq_lengths)):
  144. tensor_seq[idx, :seqlen] = self._get_float_tensor(seq)
  145. # sort inputs
  146. tensor_seq = tensor_seq[sorted_idx]
  147. var_seq = self._get_variable(tensor_seq)
  148. seq_lengths = self._get_variable(seq_lengths)
  149. if has_embedding_layer:
  150. var_seq = self.embedding_layers[key](var_seq)
  151. # pack padded sequences
  152. packed_input_seq = pack_padded_sequence(var_seq, list(seq_len_sorted.data), batch_first=True)
  153. # run lstm over sequence
  154. self.hidden_bilstm = self.init_hidden_bilstm()
  155. packed_output, self.hidden_bilstm = \
  156. sequence_lstm(packed_input_seq, self.hidden_bilstm)
  157. # unpack the output
  158. unpacked_output, _ = pad_packed_sequence(packed_output, batch_first=True)
  159. # Reverse sorting
  160. unpacked_output = torch.zeros_like(unpacked_output).scatter_(0, sorted_idx.unsqueeze(1).unsqueeze(1).expand(-1, unpacked_output.shape[1], unpacked_output.shape[2]), unpacked_output)
  161. # extract last timestep, since doing [-1] would get the padded zeros
  162. '''
  163. idx = (seq_lengths - 1).view(-1, 1).expand(
  164. unpacked_output.size(0), unpacked_output.size(2)).unsqueeze(1)
  165. lstm_output = unpacked_output.gather(1, idx).squeeze()
  166. if len(list(lstm_output.size())) == 1:
  167. lstm_output = lstm_output.unsqueeze(0)
  168. '''
  169. # Do global avg pooling over timesteps
  170. #lstm_output, _ = torch.max(unpacked_output, dim=1)
  171. lstm_output = torch.mean(unpacked_output, dim=1)
  172. return lstm_output
  173. def forward(self, x):
  174. input_index = 0
  175. has_embedding_layer = len(x['embedding']) > 0
  176. if has_embedding_layer:
  177. all_keys = list(x['embedding'].keys())
  178. key_one = all_keys[0]
  179. seqs_one = x['embedding'][key_one]
  180. key_two = all_keys[1]
  181. seqs_two = x['embedding'][key_two]
  182. key_three = all_keys[2]
  183. seqs_three = x['embedding'][key_three]
  184. else:
  185. seqs_one = [elem[input_index] for elem in x['input']]
  186. input_index += 1
  187. seqs_two = [elem[input_index] for elem in x['input']]
  188. input_index += 1
  189. seqs_three = [elem[input_index] for elem in x['input']]
  190. input_index += 1
  191. lstm_output_one = self._run_sequence(seqs_one, has_embedding_layer, self.sequence_lstm_1, self.n_input_sequence_1, key_one)
  192. lstm_output_two = self._run_sequence(seqs_two, has_embedding_layer, self.sequence_lstm_2, self.n_input_sequence_2, key_two)
  193. lstm_output_three = self._run_sequence(seqs_three, has_embedding_layer, self.sequence_lstm_3, self.n_input_sequence_3, key_three)
  194. # now add the extra features
  195. extra_feats_one = [elem[input_index] for elem in x['input']]
  196. input_index += 1
  197. extra_feats_one = self._get_float_tensor(extra_feats_one)
  198. extra_feats_one = self._get_variable(extra_feats_one)
  199. lstm_output_one = torch.cat([lstm_output_one, extra_feats_one], 1)
  200. extra_feats_two = [elem[input_index] for elem in x['input']]
  201. input_index += 1
  202. extra_feats_two = self._get_float_tensor(extra_feats_two)
  203. extra_feats_two = self._get_variable(extra_feats_two)
  204. lstm_output_two = torch.cat([lstm_output_two, extra_feats_two], 1)
  205. extra_feats_three = [elem[input_index] for elem in x['input']]
  206. input_index += 1
  207. extra_feats_three = self._get_float_tensor(extra_feats_three)
  208. extra_feats_three = self._get_variable(extra_feats_three)
  209. lstm_output_three = torch.cat([lstm_output_three, extra_feats_three], 1)
  210. layer_one = self.layer2hidden_1(lstm_output_one)
  211. layer_one = F.relu(layer_one)
  212. layer_two = self.layer2hidden_2(lstm_output_two)
  213. layer_two = F.relu(layer_two)
  214. layer_three = self.layer2hidden_3(lstm_output_three)
  215. layer_three = F.relu(layer_three)
  216. layer_one = self.dropout_layer(layer_one)
  217. layer_two = self.dropout_layer(layer_two)
  218. layer_three = self.dropout_layer(layer_three)
  219. extra_feats_reln = x['vector']
  220. extra_feats_reln = self._get_float_tensor(extra_feats_reln)
  221. extra_feats_reln = self._get_variable(extra_feats_reln)
  222. if "shared_extra" in self.config:
  223. extra_pred_1 = self.extra_layer(layer_one)
  224. extra_pred_2 = self.extra_layer(layer_two)
  225. extra_pred_3 = self.extra_layer(layer_three)
  226. concat = torch.cat([layer_one, layer_two, layer_three, extra_feats_reln, extra_pred_1, extra_pred_2, extra_pred_3], 1)
  227. else:
  228. concat = torch.cat([layer_one, layer_two, layer_three, extra_feats_reln], 1)
  229. concat = self.concat2hidden(concat)
  230. concat = F.relu(concat)
  231. concat = self.dropout_layer(concat)
  232. logits = self.hidden2label(concat)
  233. if 'output' not in self.config or self.config['output'] == "softmax":
  234. probas = F.softmax(logits, dim=1)
  235. elif self.config['output'] == 'sigmoid':
  236. probas = F.sigmoid(logits)
  237. return logits, probas