/examples/4forums/run_debates_4forums.py
Python | 300 lines | 248 code | 44 blank | 8 comment | 44 complexity | cc9beee286817f40dd4aef4f0649752e MD5 | raw file
- # -*- coding: utf-8 -*-
- import os
- import argparse
- import random
- import json
- import logging.config
- import numpy as np
- import torch
- from sklearn.metrics import *
- from drail.model.argument import ArgumentType
- from drail.learn.global_learner import GlobalLearner
- from drail.learn.local_learner import LocalLearner
- from drail.learn.joint_learner import JointLearner
- from drail.inference.randomized.purepython.infer_4forums import RInf4Forums # pure python
- # from drail.inference.randomized.infer_4forums import RInf4Forums # C++
def parse_arguments():
    """Build the command-line parser for the 4forums experiments and parse argv."""
    ap = argparse.ArgumentParser()
    # Hardware / experiment wiring.
    ap.add_argument('-g', '--gpu', help='gpu index', dest='gpu_index', type=int, default=None)
    ap.add_argument('-d', '--dir', help='directory', dest='dir', type=str, required=True)
    ap.add_argument('-r', '--rule', help='rule file', dest='rules', type=str, required=True)
    ap.add_argument('-c', '--config', help='config file', dest='config', type=str, required=True)
    ap.add_argument('-f', '--fedir', help='fe directory', dest='fedir', type=str, required=True)
    ap.add_argument('-n', '--netdir', help='net directory', dest='nedir', type=str, required=True)
    # Learning mode and optimization knobs.
    ap.add_argument('-m', help='mode: [global|local|joint]', dest='mode', type=str, default='local')
    ap.add_argument("--debug", help='debug mode', default=False, action="store_true")
    ap.add_argument("--lrate", help="learning rate", type=float, default=0.01)
    ap.add_argument("--savedir", help="directory to save model", type=str, required=True)
    ap.add_argument('--continue_from_checkpoint', help='continue from checkpoint in savedir', default=False, action='store_true')
    # Input-encoding selection (exactly one expected downstream).
    ap.add_argument("--bert", action="store_true", default=False)
    ap.add_argument("--bert_tiny", action="store_true", default=False)
    ap.add_argument("--wordembed", action="store_true", default=False)
    # Train / inference toggles.
    ap.add_argument('--train_only', help='run training only', default=False, action='store_true')
    ap.add_argument("--infer_only", help="inference only", default=False, action="store_true")
    ap.add_argument('--issue', type=str, required=True)
    ap.add_argument('--delta', help='loss augmented inference', dest='delta', action='store_true', default=False)
    ap.add_argument('-p', '--pkldir', help='pkl directory', dest='pkldir', type=str, default=None)
    ap.add_argument("--constraints", help="use constraints", default=False, action="store_true")
    # Inference solver configuration.
    ap.add_argument('--randomized', help='randomized inference on training', dest='rand_inf', default=False, action='store_true')
    ap.add_argument('--ad3', help='use ad3 solver for training', default=False, action='store_true')
    ap.add_argument('--ad3_only', help='use ad3 solver rather than ilp', default=False, action='store_true')
    ap.add_argument('--loss', help='mode: [crf|hinge|hinge_smooth]', dest='lossfn', type=str, default="joint")
    ap.add_argument('--we_file', help='word embedding file', dest='we_file', type=str, required=False)
    ap.add_argument('--logging_config', help='logging configuration file', type=str, required=True)
    ap.add_argument('--author_constraints', help='activate author constraints for randomized inferencer', default=False, action='store_true')
    ap.add_argument('--local_init', help='initializes the scoring tree local optimal during inference', default=False, action='store_true')
    ap.add_argument('--no_randinf_constraints', help='deactivates constraints on randomized inference', default=False, action='store_true')
    # Fold range and score dumping.
    ap.add_argument('--start', help='start fold index', dest='start_fold_index', type=int, default=0)
    ap.add_argument('--end', help='end fold index', dest='end_fold_index', type=int, default=5)
    ap.add_argument('--drop_scores', help='drop predicted scores', dest='drop_scores', default=False, action='store_true')
    return ap.parse_args()
def train(folds, avoid):
    """Concatenate every fold whose index is not listed in ``avoid``.

    Used to assemble the training split from the k folds, skipping the
    test and dev fold indices.
    """
    return [post
            for idx, fold in enumerate(folds)
            if idx not in avoid
            for post in fold]
def main():
    """Run cross-validated training/inference for the 4forums debates task.

    Reads the module-level globals set in the ``__main__`` guard: ``args``
    (parsed command-line options) and ``logger``. For each fold it wires a
    learner (global/joint/local per ``args.mode``), extracts data, optionally
    trains, runs inference, and logs stance and (dis)agreement metrics;
    per-fold scores are averaged at the end.

    Fixes vs. the previous revision:
      * the tiny-BERT posts file is now anchored to the script directory
        (it was resolved against the CWD, unlike the sibling branches);
      * the folds JSON file is opened with a context manager (no fd leak).
    """
    realpath = os.path.dirname(os.path.realpath(__file__)) + '/'

    # Fixed resources -- exactly one input-encoding mode must be selected.
    if args.wordembed and not args.bert:
        POSTS_F = realpath + "data/posts_preprocessed.json"
        WORD2IDX_F = realpath + "data/word2idx.json"
        WORD_EMB_F = args.we_file
        optimizer = "SGD"
    elif args.bert and not args.wordembed:
        POSTS_F = realpath + "data/posts_bert.json"
        WORD2IDX_F = None; WORD_EMB_F = None
        optimizer = "AdamW"
    elif args.bert_tiny and not args.wordembed:
        # FIX: anchor to the script directory like the other two branches.
        POSTS_F = realpath + "data/posts_bert_tiny.json"
        WORD2IDX_F = None; WORD_EMB_F = None
        optimizer = "AdamW"
    else:
        logger.error("Choose only one: --wordembed or (--bert | --bert_tiny)")
        exit(-1)

    # FIX: deterministic close of the folds file.
    with open(realpath + "data/chang_folds.json") as folds_f:
        FOLDS = json.load(folds_f)

    # Seed for reproducibility.
    np.random.seed(1234)
    random.seed(1234)

    train_inference_solver = "randomized" if args.rand_inf else "AD3" if args.ad3_only or args.ad3 else "ILP"
    test_inference_solver = "ILP" if not args.ad3_only else "AD3"

    logger.info("train inference solver: {}".format(train_inference_solver))
    logger.info("test inference solver: {}".format(test_inference_solver))
    logger.info("author-constraints: {}".format(args.author_constraints))
    logger.info("local-initialization: {}".format(args.local_init))
    logger.info("no randomized inference constraints: {}".format(args.no_randinf_constraints))

    # Select what gpu to use
    if args.gpu_index is not None:
        torch.cuda.set_device(args.gpu_index)

    if args.mode == "global":
        learner = GlobalLearner(learning_rate=args.lrate, use_gpu=(args.gpu_index is not None), gpu_index=args.gpu_index,
                                rand_inf=args.rand_inf, inferencer=RInf4Forums, ad3=args.ad3, ad3_only=args.ad3_only,
                                loss_fn=args.lossfn, local_init=args.local_init, no_randinf_constraints=args.no_randinf_constraints)
        learner.set_4forums_issue(args.issue, args.author_constraints)
    elif args.mode == "joint":
        learner = JointLearner()
    else:
        learner = LocalLearner()

    learner.compile_rules(args.rules)
    db = learner.create_dataset(os.path.join(args.dir, args.issue))

    # Per-fold metric accumulators, averaged after the loop.
    stance_macros, stance_accuracies = [], []
    disagr_macros, disagr_accuracies = [], []

    for i in range(args.start_fold_index, args.end_fold_index):
        logger.info("Fold {}".format(i))
        # Fold i is test; the next fold (mod 5) is dev; the rest is train.
        dev_fold = (i + 1) % 5
        test_posts = FOLDS[args.issue][i]
        dev_posts = FOLDS[args.issue][dev_fold]
        train_posts = train(FOLDS[args.issue], [i, dev_fold])

        if args.debug:
            # Tiny subsets for a quick end-to-end smoke run.
            test_posts = test_posts[:30]
            dev_posts = dev_posts[:30]
            train_posts = train_posts[:30]

        db.add_filters(filters=[
            ("InThread", "isTrain", "postId_2", train_posts),
            ("InThread", "isDev", "postId_2", dev_posts),
            ("InThread", "isTest", "postId_2", test_posts),
            ("InThread", "isDummy", "postId_2", train_posts[:10])
        ])
        logger.info("{} train posts, {} dev posts, {} test posts".format(len(train_posts), len(dev_posts), len(test_posts)))

        learner.build_feature_extractors(db,
                                         word_emb_f=WORD_EMB_F,
                                         data_f=POSTS_F,
                                         word2idx_f=WORD2IDX_F,
                                         use_wordembed=args.wordembed,
                                         debug=args.debug,
                                         femodule_path=args.fedir,
                                         filters=[("InThread", "isDummy", 1)])

        if not os.path.isdir(args.savedir):
            os.mkdir(args.savedir)
        learner.set_savedir(os.path.join(args.savedir, "f{0}".format(i)))
        learner.build_models(db, args.config, netmodules_path=args.nedir)

        pklpath = None
        if args.pkldir is not None:
            pklpath = os.path.join(args.pkldir, "f{0}".format(i))

        if args.mode == "global":
            learner.extract_data(
                db,
                train_filters=[("InThread", "isTrain", 1)],
                dev_filters=[("InThread", "isDev", 1)],
                test_filters=[("InThread", "isTest", 1)],
                extract_train=not args.infer_only,
                extract_dev=not args.infer_only,
                extract_test=True,
                pickledir=pklpath,
                from_pickle=(args.pkldir is not None),
                extract_constraints=args.constraints)
            # Passing accum_loss=True to accumulate gradients before backprop
            res, heads = learner.train(
                db,
                train_filters=[("InThread", "isTrain", 1)],
                dev_filters=[("InThread", "isDev", 1)],
                test_filters=[("InThread", "isTest", 1)],
                opt_predicates=set(['HasStance', 'Disagree']),
                loss_augmented_inference=args.delta,
                continue_from_checkpoint=args.continue_from_checkpoint,
                inference_only=args.infer_only,
                scale_data=False,
                weight_classes=True,
                accum_loss=True,
                patience=5,
                optimizer=optimizer)
        else:
            if args.continue_from_checkpoint:
                learner.init_models(scale_data=False)
            if not args.infer_only:
                learner.train(db,
                              train_filters=[("InThread", "isTrain", 1)],
                              dev_filters=[("InThread", "isDev", 1)],
                              test_filters=[("InThread", "isTest", 1)],
                              scale_data=False,
                              optimizer=optimizer)
            if not args.train_only:
                learner.extract_data(db, extract_test=True, extract_dev=True, extract_train=True,
                                     test_filters=[("InThread", "isTest", 1)],
                                     dev_filters=[("InThread", "isDev", 1)],
                                     train_filters=[("InThread", "isTrain", 1)],
                                     pickledir=pklpath,
                                     from_pickle=(args.pkldir is not None),
                                     extract_constraints=args.constraints)
                res, heads = learner.predict(None, fold='test', get_predicates=True, scale_data=False)

        if args.drop_scores:
            learner.drop_scores(None, fold='test', output='{0}_f{1}_test_scores.csv'.format(args.issue, i))
            learner.drop_scores(None, fold='dev', output='{0}_f{1}_dev_scores.csv'.format(args.issue, i))
            learner.drop_scores(None, fold='train', output='{0}_f{1}_train_scores.csv'.format(args.issue, i))

        if not args.train_only:
            if 'HasStance' in res.metrics:
                y_gold = res.metrics['HasStance']['gold_data']
                y_pred = res.metrics['HasStance']['pred_data']
                logger.info("HasStance")
                # Evaluate macro-F1 over the two stance labels only,
                # excluding the 'other' class.
                labels = list(set(y_gold))
                if 'other' in labels:
                    labels.remove('other')
                logger.info('\n' + classification_report(y_gold, y_pred, digits=4))
                acc_score = accuracy_score(y_gold, y_pred)
                f1_macro = f1_score(y_gold, y_pred, average='macro', labels=labels)
                # Accuracy restricted to instances whose gold label is one of
                # the two stance labels (i.e. not 'other').
                TP, TN, P, N = 0, 0, 0, 0
                for l, label in enumerate(y_gold):
                    if label == labels[0]:
                        P += 1
                        if y_pred[l] == label:
                            TP += 1
                    if label == labels[1]:
                        N += 1
                        if y_pred[l] == label:
                            TN += 1
                accuracy_no_other = (TP + TN) / float(P + N)
                logger.info("TEST Acc: {}".format(acc_score))
                logger.info("TEST Acc without 'other': {}".format(accuracy_no_other))
                logger.info("TEST F1 Macro without 'other': {}".format(f1_macro))
                stance_macros.append(f1_macro)
                stance_accuracies.append(accuracy_no_other)

            if 'Disagree' in res.metrics:
                y_gold = res.metrics['Disagree']['gold_data']
                y_pred = res.metrics['Disagree']['pred_data']
                logger.info("Disagree")
                logger.info('\n' + classification_report(y_gold, y_pred, digits=4))
                acc_score = accuracy_score(y_gold, y_pred)
                f1_macro = f1_score(y_gold, y_pred, average='macro')
                logger.info("TEST Acc: {}".format(acc_score))
                logger.info("TEST F1 Macro: {}".format(f1_macro))
                disagr_macros.append(f1_macro)
                disagr_accuracies.append(acc_score)

            if 'Agree' in res.metrics:
                y_gold = res.metrics['Agree']['gold_data']
                y_pred = res.metrics['Agree']['pred_data']
                logger.info("Agree")
                logger.info('\n' + classification_report(y_gold, y_pred, digits=4))
                acc_score = accuracy_score(y_gold, y_pred)
                f1_macro = f1_score(y_gold, y_pred, average='macro')
                logger.info("TEST Acc: {}".format(acc_score))
                logger.info("TEST F1 Macro: {}".format(f1_macro))

        if not args.infer_only and args.mode == 'global':
            # pure inference time during training
            logger.info('\ninference encoding time:\t{0}\n'.format(learner.train_metrics.metrics['encoding_time']) +
                        'inference optimiz. time:\t{0}\n'.format(learner.train_metrics.metrics['solving_time']) +
                        'total inference time: \t{0}'.format(learner.train_metrics.total_time()))
            learner.reset_train_metrics()

        # Clear per-fold prediction metrics before the next fold.
        learner.reset_metrics()

    logger.info("--------------------------------")
    logger.info("HasStance F1 macro (no 'other'): {}".format(round(np.mean(stance_macros), 5)))
    logger.info("HasStance Accuracy (no 'other'): {}".format(round(np.mean(stance_accuracies), 5)))
    logger.info("HasStance F1 macro (no 'other'): {}".format(str(stance_macros)))
    logger.info("HasStance Accuracy (no 'other'): {}".format(str(stance_accuracies)))
    logger.info("--------------------------------")
    logger.info("Disagreement F1 macro (no 'other'): {}".format(round(np.mean(disagr_macros), 5)))
    logger.info("Disagreement Accuracy (no 'other'): {}".format(round(np.mean(disagr_accuracies), 5)))
    logger.info("Disagreement F1 macro (no 'other'): {}".format(str(disagr_macros)))
    logger.info("Disagreement Accuracy (no 'other'): {}".format(str(disagr_accuracies)))
if __name__ == "__main__":
    # `args` and `logger` are created at module level on purpose: main()
    # (and anything it calls) reads them as globals.
    args = parse_arguments()  # put args here so it's global
    # --logging_config is a required argument, so this branch always runs
    # and `logger` is always bound before main() is called.
    if args.logging_config:
        logger = logging.getLogger()
        # NOTE(review): the config file handle is never closed explicitly;
        # also dictConfig is applied after getLogger(), which is fine since
        # getLogger() returns the root logger object that dictConfig mutates.
        logging.config.dictConfig(json.load(open(args.logging_config)))
    main()