
/examples/4forums/run_debates_4forums.py

https://gitlab.com/purdueNlp/DRaiL
# -*- coding: utf-8 -*-
import os
import argparse
import random
import json
import logging.config
import numpy as np
import torch
from sklearn.metrics import classification_report, accuracy_score, f1_score
from drail.model.argument import ArgumentType
from drail.learn.global_learner import GlobalLearner
from drail.learn.local_learner import LocalLearner
from drail.learn.joint_learner import JointLearner
from drail.inference.randomized.purepython.infer_4forums import RInf4Forums  # pure python
# from drail.inference.randomized.infer_4forums import RInf4Forums  # C++
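
# A sketch of one possible invocation; every path and the issue name below are
# illustrative placeholders, not files shipped with this example:
#   python run_debates_4forums.py -d data/ -r stance.dr -c config.json \
#       -f fe/ -n net/ --savedir models/ --issue abortion \
#       --logging_config logging.json --bert -m global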


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('-g', '--gpu', help='gpu index', dest='gpu_index', type=int, default=None)
    parser.add_argument('-d', '--dir', help='directory', dest='dir', type=str, required=True)
    parser.add_argument('-r', '--rule', help='rule file', dest='rules', type=str, required=True)
    parser.add_argument('-c', '--config', help='config file', dest='config', type=str, required=True)
    parser.add_argument('-f', '--fedir', help='fe directory', dest='fedir', type=str, required=True)
    parser.add_argument('-n', '--netdir', help='net directory', dest='nedir', type=str, required=True)
    parser.add_argument('-m', help='mode: [global|local|joint]', dest='mode', type=str, default='local')
    parser.add_argument("--debug", help='debug mode', default=False, action="store_true")
    parser.add_argument("--lrate", help="learning rate", type=float, default=0.01)
    parser.add_argument("--savedir", help="directory to save model", type=str, required=True)
    parser.add_argument('--continue_from_checkpoint', help='continue from checkpoint in savedir', default=False, action='store_true')
    parser.add_argument("--bert", action="store_true", default=False)
    parser.add_argument("--bert_tiny", action="store_true", default=False)
    parser.add_argument("--wordembed", action="store_true", default=False)
    parser.add_argument('--train_only', help='run training only', default=False, action='store_true')
    parser.add_argument("--infer_only", help="inference only", default=False, action="store_true")
    parser.add_argument('--issue', type=str, required=True)
    parser.add_argument('--delta', help='loss augmented inference', dest='delta', action='store_true', default=False)
    parser.add_argument('-p', '--pkldir', help='pkl directory', dest='pkldir', type=str, default=None)
    parser.add_argument("--constraints", help="use constraints", default=False, action="store_true")
    parser.add_argument('--randomized', help='randomized inference on training', dest='rand_inf', default=False, action='store_true')
    parser.add_argument('--ad3', help='use ad3 solver for training', default=False, action='store_true')
    parser.add_argument('--ad3_only', help='use ad3 solver rather than ilp', default=False, action='store_true')
    parser.add_argument('--loss', help='mode: [crf|hinge|hinge_smooth]', dest='lossfn', type=str, default="joint")
    parser.add_argument('--we_file', help='word embedding file', dest='we_file', type=str, required=False)
    parser.add_argument('--logging_config', help='logging configuration file', type=str, required=True)
    parser.add_argument('--author_constraints', help='activate author constraints for randomized inferencer', default=False, action='store_true')
    parser.add_argument('--local_init', help='initializes the scoring tree local optimal during inference', default=False, action='store_true')
    parser.add_argument('--no_randinf_constraints', help='deactivates constraints on randomized inference', default=False, action='store_true')
    parser.add_argument('--start', help='start fold index', dest='start_fold_index', type=int, default=0)
    parser.add_argument('--end', help='end fold index', dest='end_fold_index', type=int, default=5)
    parser.add_argument('--drop_scores', help='drop predicted scores', dest='drop_scores', default=False, action='store_true')
    args = parser.parse_args()
    return args


def train(folds, avoid):
    ret = []
    for j in range(0, len(folds)):
        if j not in avoid:
            ret += folds[j]
    return ret
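
# Toy illustration of the helper above (hypothetical folds): train() gathers
# every fold except the ones in `avoid`, e.g.
#   train([['p1'], ['p2'], ['p3'], ['p4'], ['p5']], avoid=[0, 1])
#   -> ['p3', 'p4', 'p5']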


def main():
    realpath = os.path.dirname(os.path.realpath(__file__)) + '/'
    # Fixed resources
    if args.wordembed and not args.bert:
        POSTS_F = realpath + "data/posts_preprocessed.json"
        WORD2IDX_F = realpath + "data/word2idx.json"
        WORD_EMB_F = args.we_file
        optimizer = "SGD"
    elif args.bert and not args.wordembed:
        POSTS_F = realpath + "data/posts_bert.json"
        WORD2IDX_F = None
        WORD_EMB_F = None
        optimizer = "AdamW"
    elif args.bert_tiny and not args.wordembed:
        POSTS_F = realpath + "data/posts_bert_tiny.json"
        WORD2IDX_F = None
        WORD_EMB_F = None
        optimizer = "AdamW"
    else:
        logger.error("Choose only one: --wordembed or (--bert | --bert_tiny)")
        exit(-1)
    FOLDS = json.load(open(realpath + "data/chang_folds.json"))
    # Seed for reproducibility
    np.random.seed(1234)
    random.seed(1234)
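    # Solver selection: --randomized takes precedence at training time, then
    # AD3 if either --ad3 or --ad3_only is set, and exact ILP otherwise. At
    # test time ILP is used unless --ad3_only forces AD3 there as well.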
    train_inference_solver = "randomized" if args.rand_inf else "AD3" if args.ad3_only or args.ad3 else "ILP"
    test_inference_solver = "ILP" if not args.ad3_only else "AD3"
    logger.info("train inference solver: {}".format(train_inference_solver))
    logger.info("test inference solver: {}".format(test_inference_solver))
    logger.info("author-constraints: {}".format(args.author_constraints))
    logger.info("local-initialization: {}".format(args.local_init))
    logger.info("no randomized inference constraints: {}".format(args.no_randinf_constraints))
    # Select which GPU to use
    if args.gpu_index is not None:
        torch.cuda.set_device(args.gpu_index)
    if args.mode == "global":
        learner = GlobalLearner(learning_rate=args.lrate, use_gpu=(args.gpu_index is not None),
                                gpu_index=args.gpu_index, rand_inf=args.rand_inf,
                                inferencer=RInf4Forums, ad3=args.ad3, ad3_only=args.ad3_only,
                                loss_fn=args.lossfn, local_init=args.local_init,
                                no_randinf_constraints=args.no_randinf_constraints)
        learner.set_4forums_issue(args.issue, args.author_constraints)
    elif args.mode == "joint":
        learner = JointLearner()
    else:
        learner = LocalLearner()
    learner.compile_rules(args.rules)
    db = learner.create_dataset(os.path.join(args.dir, args.issue))
    stance_macros, stance_accuracies = [], []
    disagr_macros, disagr_accuracies = [], []
    for i in range(args.start_fold_index, args.end_fold_index):
        logger.info("Fold {}".format(i))
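        # 5-fold cross-validation over the splits in chang_folds.json: fold i
        # is held out for testing, fold (i + 1) % 5 serves as the dev set, and
        # the remaining folds are merged for training.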
        dev_fold = (i + 1) % 5
        test_posts = FOLDS[args.issue][i]
        dev_posts = FOLDS[args.issue][dev_fold]
        train_posts = train(FOLDS[args.issue], [i, dev_fold])
        if args.debug:
            test_posts = test_posts[:30]
            dev_posts = dev_posts[:30]
            train_posts = train_posts[:30]
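        # Each filter tuple below appears to read as (predicate, filter name,
        # argument to filter on, allowed values); isDummy keeps a small sample
        # of training posts, used when building the feature extractors.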
        db.add_filters(filters=[
            ("InThread", "isTrain", "postId_2", train_posts),
            ("InThread", "isDev", "postId_2", dev_posts),
            ("InThread", "isTest", "postId_2", test_posts),
            ("InThread", "isDummy", "postId_2", train_posts[:10])
        ])
        logger.info("{} train posts, {} dev posts, {} test posts".format(len(train_posts), len(dev_posts), len(test_posts)))
        learner.build_feature_extractors(db,
                                         word_emb_f=WORD_EMB_F,
                                         data_f=POSTS_F,
                                         word2idx_f=WORD2IDX_F,
                                         use_wordembed=args.wordembed,
                                         debug=args.debug,
                                         femodule_path=args.fedir,
                                         filters=[("InThread", "isDummy", 1)])
        if not os.path.isdir(args.savedir):
            os.mkdir(args.savedir)
        learner.set_savedir(os.path.join(args.savedir, "f{0}".format(i)))
        learner.build_models(db, args.config, netmodules_path=args.nedir)
        pklpath = None
        if args.pkldir is not None:
            pklpath = os.path.join(args.pkldir, "f{0}".format(i))
        if args.mode == "global":
            learner.extract_data(
                db,
                train_filters=[("InThread", "isTrain", 1)],
                dev_filters=[("InThread", "isDev", 1)],
                test_filters=[("InThread", "isTest", 1)],
                extract_train=not args.infer_only,
                extract_dev=not args.infer_only,
                extract_test=True,
                pickledir=pklpath,
                from_pickle=(args.pkldir is not None),
                extract_constraints=args.constraints)
            # Passing accum_loss=True to accumulate gradients before backprop
            res, heads = learner.train(
                db,
                train_filters=[("InThread", "isTrain", 1)],
                dev_filters=[("InThread", "isDev", 1)],
                test_filters=[("InThread", "isTest", 1)],
                opt_predicates=set(['HasStance', 'Disagree']),
                loss_augmented_inference=args.delta,
                continue_from_checkpoint=args.continue_from_checkpoint,
                inference_only=args.infer_only,
                scale_data=False,
                weight_classes=True,
                accum_loss=True,
                patience=5,
                optimizer=optimizer)
        else:
            if args.continue_from_checkpoint:
                learner.init_models(scale_data=False)
            if not args.infer_only:
                learner.train(db,
                              train_filters=[("InThread", "isTrain", 1)],
                              dev_filters=[("InThread", "isDev", 1)],
                              test_filters=[("InThread", "isTest", 1)],
                              scale_data=False,
                              optimizer=optimizer)
            if not args.train_only:
                learner.extract_data(db, extract_test=True, extract_dev=True, extract_train=True,
                                     test_filters=[("InThread", "isTest", 1)],
                                     dev_filters=[("InThread", "isDev", 1)],
                                     train_filters=[("InThread", "isTrain", 1)],
                                     pickledir=pklpath,
                                     from_pickle=(args.pkldir is not None),
                                     extract_constraints=args.constraints)
                res, heads = learner.predict(None, fold='test', get_predicates=True, scale_data=False)
            if args.drop_scores:
                learner.drop_scores(None, fold='test', output='{0}_f{1}_test_scores.csv'.format(args.issue, i))
                learner.drop_scores(None, fold='dev', output='{0}_f{1}_dev_scores.csv'.format(args.issue, i))
                learner.drop_scores(None, fold='train', output='{0}_f{1}_train_scores.csv'.format(args.issue, i))
        if not args.train_only:
            if 'HasStance' in res.metrics:
                y_gold = res.metrics['HasStance']['gold_data']
                y_pred = res.metrics['HasStance']['pred_data']
                logger.info("HasStance")
                labels = sorted(set(y_gold))  # sorted for a deterministic label order
                if 'other' in labels:
                    labels.remove('other')
                logger.info('\n' + classification_report(y_gold, y_pred, digits=4))
                acc_score = accuracy_score(y_gold, y_pred)
                f1_macro = f1_score(y_gold, y_pred, average='macro', labels=labels)
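                # Accuracy restricted to the two stance labels, so gold 'other'
                # posts do not affect the score. Illustrative toy case:
                # gold = [pro, con, other], pred = [pro, pro, other] gives
                # (TP + TN) / (P + N) = (0 + 1) / (1 + 1) = 0.5.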
                TP, TN, P, N = 0, 0, 0, 0
                for l, label in enumerate(y_gold):
                    if label == labels[0]:
                        P += 1
                        if y_pred[l] == label:
                            TP += 1
                    if label == labels[1]:
                        N += 1
                        if y_pred[l] == label:
                            TN += 1
                accuracy_no_other = (TP + TN) / float(P + N)
                logger.info("TEST Acc: {}".format(acc_score))
                logger.info("TEST Acc without 'other': {}".format(accuracy_no_other))
                logger.info("TEST F1 Macro without 'other': {}".format(f1_macro))
                stance_macros.append(f1_macro)
                stance_accuracies.append(accuracy_no_other)
            if 'Disagree' in res.metrics:
                y_gold = res.metrics['Disagree']['gold_data']
                y_pred = res.metrics['Disagree']['pred_data']
                logger.info("Disagree")
                logger.info('\n' + classification_report(y_gold, y_pred, digits=4))
                acc_score = accuracy_score(y_gold, y_pred)
                f1_macro = f1_score(y_gold, y_pred, average='macro')
                logger.info("TEST Acc: {}".format(acc_score))
                logger.info("TEST F1 Macro: {}".format(f1_macro))
                disagr_macros.append(f1_macro)
                disagr_accuracies.append(acc_score)
            if 'Agree' in res.metrics:
                y_gold = res.metrics['Agree']['gold_data']
                y_pred = res.metrics['Agree']['pred_data']
                logger.info("Agree")
                logger.info('\n' + classification_report(y_gold, y_pred, digits=4))
                acc_score = accuracy_score(y_gold, y_pred)
                f1_macro = f1_score(y_gold, y_pred, average='macro')
                logger.info("TEST Acc: {}".format(acc_score))
                logger.info("TEST F1 Macro: {}".format(f1_macro))
        if not args.infer_only and args.mode == 'global':
            # pure inference time during training
            logger.info('\ninference encoding time:\t{0}\n'.format(learner.train_metrics.metrics['encoding_time']) +
                        'inference optimiz. time:\t{0}\n'.format(learner.train_metrics.metrics['solving_time']) +
                        'total inference time:\t{0}'.format(learner.train_metrics.total_time()))
            learner.reset_train_metrics()
        learner.reset_metrics()
        # exit()
    logger.info("--------------------------------")
    logger.info("HasStance F1 macro (no 'other'): {}".format(round(np.mean(stance_macros), 5)))
    logger.info("HasStance Accuracy (no 'other'): {}".format(round(np.mean(stance_accuracies), 5)))
    logger.info("HasStance F1 macro per fold (no 'other'): {}".format(stance_macros))
    logger.info("HasStance Accuracy per fold (no 'other'): {}".format(stance_accuracies))
    logger.info("--------------------------------")
    logger.info("Disagreement F1 macro: {}".format(round(np.mean(disagr_macros), 5)))
    logger.info("Disagreement Accuracy: {}".format(round(np.mean(disagr_accuracies), 5)))
    logger.info("Disagreement F1 macro per fold: {}".format(disagr_macros))
    logger.info("Disagreement Accuracy per fold: {}".format(disagr_accuracies))


if __name__ == "__main__":
    args = parse_arguments()  # parsed here so args is visible as a global in main()
    if args.logging_config:
        logger = logging.getLogger()
        logging.config.dictConfig(json.load(open(args.logging_config)))
    main()