/examples/argument_mining/neuro/nn_triple_sequence.py
Python | 302 lines | 213 code | 64 blank | 25 comment | 39 complexity | ba5419c3122a38e592938728b1d7ed2f MD5 | raw file
- import torch
- import numpy as np
- from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
- import torch.nn.functional as F
- from drail.neuro.nn_model import NeuralNetworks
class TripleSequenceNet(NeuralNetworks):
    """Classifier over three input sequences, each encoded by its own BiLSTM.

    Every sequence is run through a bidirectional LSTM and mean-pooled over
    time; the pooled vector is concatenated with a per-sequence feature
    vector and projected by a ReLU layer.  The three projections, the
    relation-level vector ``x['vector']`` and (optionally) the outputs of a
    shared "extra" layer applied to each projection are concatenated, passed
    through another ReLU layer and finally mapped to ``output_dim`` logits.

    LSTMs, per-sequence projection layers and the extra layer can be shared
    with other networks through ``shared_params`` (see
    ``build_architecture``).
    """

    def __init__(self, config, nn_id, use_gpu, output_dim):
        """Store the run-time options; layers are created in build_architecture.

        Args:
            config: mapping of hyper-parameters, forwarded to the base class.
            nn_id: network identifier, forwarded to the base class.
            use_gpu: when true, layers and freshly created tensors are moved
                to CUDA.
            output_dim: number of output units of the final linear layer.
        """
        super(TripleSequenceNet, self).__init__(config, nn_id)
        self.use_gpu = use_gpu
        self.output_dim = output_dim

    def _build_sequence_branch(self, idx, shared_params):
        """Create (or fetch shared) BiLSTM and projection layer for branch idx.

        Sets ``n_input_sequence_<idx>``, ``n_hidden_sequence_<idx>``,
        ``sequence_lstm_<idx>``, ``n_hidden_layer_<idx>`` and
        ``layer2hidden_<idx>``; ``n_input_vector_<idx>`` is only set when the
        projection layer is created locally.
        """
        lstm_key = "shared_lstm_%d" % idx
        if lstm_key in self.config:
            shared = shared_params[self.config[lstm_key]]
            n_input = shared["nin"]
            n_hidden = shared["nout"]
            lstm = shared["lstm"]
        else:
            n_input = self.config["n_input_sequence_%d" % idx]
            n_hidden = self.config["n_hidden_sequence_%d" % idx]
            lstm = torch.nn.LSTM(input_size=n_input,
                                 hidden_size=n_hidden,
                                 bidirectional=True,
                                 batch_first=True)
        setattr(self, "n_input_sequence_%d" % idx, n_input)
        setattr(self, "n_hidden_sequence_%d" % idx, n_hidden)
        setattr(self, "sequence_lstm_%d" % idx, lstm)

        reln_key = "shared_reln_%d" % idx
        if reln_key in self.config:
            shared = shared_params[self.config[reln_key]]
            layer = shared["layer"]
            n_layer = shared["nout"]
        else:
            n_layer = self.config["n_hidden_layer_%d" % idx]
            vector_key = "n_input_vector_%d" % idx
            n_vector = self.config[vector_key] if vector_key in self.config else 0
            setattr(self, vector_key, n_vector)
            # The BiLSTM output is 2 * hidden wide (forward + backward).
            layer = torch.nn.Linear(n_hidden * 2 + n_vector, n_layer)
        setattr(self, "layer2hidden_%d" % idx, layer)
        setattr(self, "n_hidden_layer_%d" % idx, n_layer)

    def build_architecture(self, rule_template, fe, shared_params=None):
        """Instantiate every layer of the network.

        Args:
            rule_template: forwarded to ``_embedding_inputs``.
            fe: forwarded to ``_embedding_inputs``.
            shared_params: optional mapping ``name -> dict`` holding layers
                shared across networks.  LSTM entries are read through the
                keys ``"nin"``, ``"nout"`` and ``"lstm"``; linear entries
                through ``"layer"`` and ``"nout"``.  Which entries are used
                is selected by the ``shared_*`` keys of ``self.config``.
        """
        if shared_params is None:
            shared_params = {}

        self.minibatch_size = self.config['batch_size']
        self.embedding_layers, _ = self._embedding_inputs(rule_template, fe)

        # Branches are built in order so layers are created in the same
        # sequence as before (LSTM k, then projection k).
        for idx in (1, 2, 3):
            self._build_sequence_branch(idx, shared_params)

        self.n_extra = 0
        extra_width = 0
        if "shared_extra" in self.config:
            name = self.config["shared_extra"]
            self.extra_layer = shared_params[name]["layer"]
            self.n_extra = self.config["n_extra"]
            # forward() concatenates THREE outputs of extra_layer, so the
            # width is taken from the layer itself when it exposes it.  The
            # previous hard-coded ``n_extra * 2`` is kept only as a fallback.
            out_features = getattr(self.extra_layer, "out_features", None)
            if out_features is None:
                extra_width = self.n_extra * 2
            else:
                extra_width = out_features * 3

        self.n_input_reln = 0
        if "n_reln_input" in self.config:
            self.n_input_reln = self.config["n_reln_input"]

        self.n_hidden_concat = self.config['n_hidden_concat']
        concat_width = (self.n_hidden_layer_1 + self.n_hidden_layer_2 +
                        self.n_hidden_layer_3 + self.n_input_reln +
                        extra_width)
        self.concat2hidden = torch.nn.Linear(concat_width, self.n_hidden_concat)
        self.dropout_layer = torch.nn.Dropout(p=self.config['dropout_rate'])
        self.hidden2label = torch.nn.Linear(self.n_hidden_concat, self.output_dim)

        if self.use_gpu:
            self.sequence_lstm_1 = self.sequence_lstm_1.cuda()
            self.sequence_lstm_2 = self.sequence_lstm_2.cuda()
            self.sequence_lstm_3 = self.sequence_lstm_3.cuda()
            self.layer2hidden_1 = self.layer2hidden_1.cuda()
            self.layer2hidden_2 = self.layer2hidden_2.cuda()
            self.layer2hidden_3 = self.layer2hidden_3.cuda()
            if "shared_extra" in self.config:
                # Moved like every other layer used in forward().
                self.extra_layer = self.extra_layer.cuda()
            self.concat2hidden = self.concat2hidden.cuda()
            self.dropout_layer = self.dropout_layer.cuda()
            self.hidden2label = self.hidden2label.cuda()

        self.hidden_bilstm = self.init_hidden_bilstm()

    def init_hidden_bilstm(self, n_hidden=None):
        """Return a zero ``(h_0, c_0)`` state for a single-layer BiLSTM.

        Args:
            n_hidden: hidden size of the LSTM the state is meant for.
                Defaults to ``self.n_hidden_sequence_1``.

        Returns:
            Tuple of two tensors of shape ``(2, minibatch_size, n_hidden)``.
            Both members share the same hidden size; previously the cell
            state was sized with ``n_hidden_sequence_2`` regardless of the
            LSTM in use, which only worked when all hidden sizes were equal.
        """
        if n_hidden is None:
            n_hidden = self.n_hidden_sequence_1
        h_0 = torch.autograd.Variable(
            torch.zeros(2, self.minibatch_size, n_hidden))
        c_0 = torch.autograd.Variable(
            torch.zeros(2, self.minibatch_size, n_hidden))
        if self.use_gpu:
            h_0 = h_0.cuda()
            c_0 = c_0.cuda()
        return (h_0, c_0)

    def _run_sequence(self, seqs, has_embedding_layer, sequence_lstm, n_input_sequence, key):
        """Encode a batch of variable-length sequences with ``sequence_lstm``.

        Args:
            seqs: list of sequences (one per batch element).  With an
                embedding layer each sequence holds indices; otherwise each
                timestep is a vector of ``n_input_sequence`` values.
            has_embedding_layer: whether ``seqs`` must be embedded through
                ``self.embedding_layers[key]`` first.
            sequence_lstm: the LSTM to run.
            n_input_sequence: per-timestep input width (used only without an
                embedding layer).
            key: embedding-layer key; ignored without an embedding layer.

        Returns:
            Tensor with one row per input sequence (in the original batch
            order): the LSTM outputs averaged over the time dimension.

        Side effects:
            Updates ``self.minibatch_size`` and ``self.hidden_bilstm``.
        """
        self.minibatch_size = len(seqs)

        seq_lengths = self._get_long_tensor(list(map(len, seqs)))
        # Only pad up to the longest sequence of this batch.
        max_seq_len = int(seq_lengths.max())
        # pack_padded_sequence expects sequences sorted by decreasing length.
        seq_len_sorted, sorted_idx = seq_lengths.sort(descending=True)

        if has_embedding_layer:
            tensor_seq = torch.zeros((len(seqs), max_seq_len)).long()
            to_tensor = self._get_long_tensor
        else:
            tensor_seq = torch.zeros(
                (len(seqs), max_seq_len, n_input_sequence)).float()
            to_tensor = self._get_float_tensor
        if self.use_gpu:
            tensor_seq = tensor_seq.cuda()
        # Left-align every sequence; the remainder stays zero-padded.
        for row, seq in enumerate(seqs):
            tensor_seq[row, :len(seq)] = to_tensor(seq)

        tensor_seq = tensor_seq[sorted_idx]
        var_seq = self._get_variable(tensor_seq)
        if has_embedding_layer:
            var_seq = self.embedding_layers[key](var_seq)

        # Lengths are handed over as plain Python ints.
        packed_input_seq = pack_padded_sequence(
            var_seq, seq_len_sorted.data.tolist(), batch_first=True)

        # Fresh zero state sized for the LSTM actually being run.
        self.hidden_bilstm = self.init_hidden_bilstm(
            getattr(sequence_lstm, "hidden_size", None))
        packed_output, self.hidden_bilstm = \
            sequence_lstm(packed_input_seq, self.hidden_bilstm)

        unpacked_output, _ = pad_packed_sequence(packed_output, batch_first=True)

        # Undo the length sort: row i of the sorted output goes back to
        # position sorted_idx[i].
        restore_idx = sorted_idx.unsqueeze(1).unsqueeze(1).expand(
            -1, unpacked_output.shape[1], unpacked_output.shape[2])
        unpacked_output = torch.zeros_like(unpacked_output).scatter_(
            0, restore_idx, unpacked_output)

        # Average pooling over timesteps.
        # NOTE(review): the mean runs over the padded length, so zero padding
        # of shorter sequences is included in the average.
        lstm_output = torch.mean(unpacked_output, dim=1)
        return lstm_output

    def forward(self, x):
        """Compute logits and probabilities for a batch.

        Args:
            x: mapping with the keys ``'embedding'`` (mapping of key ->
                batch of index sequences; its first three keys are used, may
                be empty), ``'input'`` (per-example lists holding, in order,
                the three raw sequences when no embeddings are used followed
                by the three per-sequence feature vectors) and ``'vector'``
                (relation-level feature vectors).

        Returns:
            ``(logits, probas)`` where ``probas`` is the softmax (default) or
            sigmoid of ``logits``, selected by ``config['output']``.

        Raises:
            ValueError: if ``config['output']`` is neither ``"softmax"`` nor
                ``"sigmoid"``.
        """
        input_index = 0
        has_embedding_layer = len(x['embedding']) > 0

        if has_embedding_layer:
            all_keys = list(x['embedding'].keys())
            keys = [all_keys[0], all_keys[1], all_keys[2]]
            seq_batches = [x['embedding'][k] for k in keys]
        else:
            # No embedding keys in this mode; _run_sequence ignores `key`.
            keys = [None, None, None]
            seq_batches = []
            for _ in range(3):
                seq_batches.append([elem[input_index] for elem in x['input']])
                input_index += 1

        lstms = (self.sequence_lstm_1, self.sequence_lstm_2, self.sequence_lstm_3)
        n_inputs = (self.n_input_sequence_1, self.n_input_sequence_2,
                    self.n_input_sequence_3)
        encoded = [
            self._run_sequence(batch, has_embedding_layer, lstm, n_input, key)
            for batch, lstm, n_input, key in zip(seq_batches, lstms, n_inputs, keys)
        ]

        # Append the per-sequence feature vector to each pooled encoding.
        with_extras = []
        for pooled in encoded:
            extra_feats = [elem[input_index] for elem in x['input']]
            input_index += 1
            extra_feats = self._get_variable(self._get_float_tensor(extra_feats))
            with_extras.append(torch.cat([pooled, extra_feats], 1))

        projections = (self.layer2hidden_1, self.layer2hidden_2, self.layer2hidden_3)
        activated = [F.relu(layer(feats))
                     for layer, feats in zip(projections, with_extras)]
        layer_one, layer_two, layer_three = [self.dropout_layer(h) for h in activated]

        extra_feats_reln = self._get_variable(self._get_float_tensor(x['vector']))

        pieces = [layer_one, layer_two, layer_three, extra_feats_reln]
        if "shared_extra" in self.config:
            pieces.extend([self.extra_layer(layer_one),
                           self.extra_layer(layer_two),
                           self.extra_layer(layer_three)])
        concat = torch.cat(pieces, 1)

        concat = self.concat2hidden(concat)
        concat = F.relu(concat)
        concat = self.dropout_layer(concat)
        logits = self.hidden2label(concat)

        output_kind = self.config['output'] if 'output' in self.config else "softmax"
        if output_kind == "softmax":
            probas = F.softmax(logits, dim=1)
        elif output_kind == 'sigmoid':
            probas = torch.sigmoid(logits)
        else:
            raise ValueError(
                "unsupported config['output']: %r (expected 'softmax' or 'sigmoid')"
                % (output_kind,))
        return logits, probas