/master_run/6_Product_Prediction/predict_product.py

https://github.com/jnwei/neural_reaction_fingerprint
Python | 124 lines | 73 code | 25 blank | 26 comment | 21 complexity | 3d56de63885f1c06c9a529138d22a260 MD5 | raw file
'''
Predict the product of each reaction and compare it with the known answer.

Intended mostly for exam questions; generated questions should behave the same way.
Products are derived from the results of the reaction-type prediction vector.

To consider: decide how to handle Markovnikov reactions.
  - Either: label the alkenes as before so that the correct Markovnikov side is chosen.
      : This gives the algorithm extra information that helps it get the product
        right -- do we want it to be that successful?
  - Or: remove the Si label from the reaction SMARTS and leave it up to probability.
'''
  10. import pickle as pkl
  11. import numpy as np
  12. from rdkit import Chem, DataStructs
  13. from rdkit.Chem.Fingerprints import FingerprintMols
  14. import copy
  15. from rdkit.Chem import AllChem
  16. from rxns_lib import Full_rxn_dict
  17. from neuralfingerprint.parse_data import split_smiles_triples
  18. from neuralfingerprint.toolkit import MarkMarkovAtom, GetDoubleBondAtoms, get_molecule_smi
  19. def returnRxnObjfromPred(rxn_num):
  20. # helper function to parse dictionary for reaction type.
  21. # rxn_num is a string, not '0' (NR separtely handled)
  22. rxn_obj = Full_rxn_dict[str(rxn_num)]
  23. return rxn_obj
  24. def returnProductsfromRxnObj(rxn_obj, rct1_smi, rct2_smi):
  25. # run reaction from the rct smi, generate the products molecules
  26. rct1 = Chem.MolFromSmiles(rct1_smi)
  27. if rct2_smi == '[Nd]':
  28. prods = rxn_obj.RunReactants((rct1,))
  29. else:
  30. rct2 = Chem.MolFromSmiles(rct2_smi)
  31. prods = rxn_obj.RunReactants((rct1,rct2))
  32. if len(prods) != 0:
  33. prod_smi_list = [Chem.MolToSmiles(prod_mol) for prod_mol in prods[0]]
  34. else:
  35. if rct2_smi == '[Nd]' : prod_smi_list = [rct1_smi]
  36. else: prod_smi_list = [rct1_smi, rct2_smi]
  37. return '.'.join(prod_smi_list)
  38. def returnProductsfromMarkRxnObj(rxn_obj, rct1_smi, rct2_smi):
  39. # Run reaction using Mark labelled reaction options
  40. alk = Chem.MolFromSmiles(rct1_smi)
  41. # Labeling mark Alkene:
  42. alk2 = copy.deepcopy(alk)
  43. double_bond_list = GetDoubleBondAtoms(alk2)
  44. at1_id, at2_id = double_bond_list[0]
  45. Mark_label_alk = MarkMarkovAtom(alk2, at1_id, at2_id)
  46. if rct2_smi != '[Nd]':
  47. rct2 = Chem.MolFromSmiles(rct2_smi)
  48. prods = rxn_obj.RunReactants((Mark_label_alk,rct2))
  49. else:
  50. try:
  51. prods = rxn_obj.RunReactants((Mark_label_alk,))
  52. except:
  53. print 'invalid number of reactants'
  54. prods = []
  55. if len(prods) != 0:
  56. prod_smi_list = [Chem.MolToSmiles(prod_mol) for prod_mol in prods[0]]
  57. else:
  58. if rct2_smi == '[Nd]' : prod_smi_list = [rct1_smi]
  59. else: prod_smi_list = [rct1_smi, rct2_smi]
  60. return '.'.join(prod_smi_list)
  61. def returnProductsfromPred(rxn_num, rct1_smi, rct2_smi):
  62. # Wrapper function for writing out products given various options
  63. # @output : a string of all the products (with the dots)
  64. # NR
  65. if rxn_num == '0': return '[Nd]'
  66. # Markovnikov reactions
  67. # List of Markovnikov reactions
  68. Mark_list = ['5', '6', '7', '8', '9', '12', '17']
  69. if rxn_num in Mark_list:
  70. rxn_obj = returnRxnObjfromPred(rxn_num)
  71. return returnProductsfromMarkRxnObj(rxn_obj, rct1_smi, rct2_smi)
  72. # Everything else
  73. return returnProductsfromRxnObj(returnRxnObjfromPred(rxn_num), rct1_smi, rct2_smi)
  74. def tanimotoComparison(pred_prod_list, true_prod_list):
  75. # Return the tanimoto score
  76. pred_mol = Chem.MolFromSmiles(pred_prod_list)
  77. answer_mol = Chem.MolFromSmiles(true_prod_list)
  78. pred_fps = FingerprintMols.FingerprintMol(pred_mol)
  79. answer_fps = FingerprintMols.FingerprintMol(answer_mol)
  80. return DataStructs.FingerprintSimilarity(pred_fps, answer_fps)
  81. if __name__== '__main__':
  82. # Get predictions, calculated elsewhere
  83. vec_pred = pkl.load(open('../results/class_3_1_neural1_200each_Wade_prob8_47.dat'))
  84. with open('../../data/test_question/prob8_47.cf.txt') as probf:
  85. rxn_smis = probf.readlines()
  86. with open('../../data/test_question/prob8_47.ans_smi.txt') as ansf:
  87. ans_rxn_smis = ansf.readlines()
  88. test_input_1, _, _ = split_smiles_triples(rxn_smis)
  89. for ii in range(np.shape(vec_pred)[0]):
  90. pred_type = np.argmax( vec_pred[ii,:])
  91. #print pred_type
  92. # get reactants for product prediction
  93. rct1_smi, rct2_smi, _ = test_input_1[ii]
  94. prods = returnProductsfromPred(str(pred_type), rct1_smi, rct2_smi)
  95. print ans_rxn_smis[ii], prods
  96. print 'Similarity score: ', tanimotoComparison(prods, ans_rxn_smis[ii])
  97. print '\n'