PageRenderTime 808ms CodeModel.GetById 28ms RepoModel.GetById 0ms app.codeStats 0ms

/AUC/AUC_plot.py

https://gitlab.com/BudzikFUW/budzik_analiza
Python | 228 lines | 148 code | 38 blank | 42 comment | 15 complexity | 612fcd80870bd3e3c1d188f798170b24 MD5 | raw file
  1. import pandas as pd
  2. import numpy as np
  3. import scipy.stats as st
  4. import sys
  5. import matplotlib.pyplot as plt
  6. import os
  7. import seaborn as sns
  8. from itertools import product
  9. from collections import defaultdict
  10. # dabele dla n1, n2 <20 http://www.real-statistics.com/statistics-tables/mann-whitney-table/
  11. FILE_T = './U_statistic.csv'
  12. FILE_S = './AUC_symulacje_full.csv'
  13. FILE_S = './result.csv'
  14. min_per_round=2
  15. U_stat_diag = [ 0., 0., 0., 2., 5., 8., 12., 16.,
  16. 22., 28., 35., 43., 52., 62., 73., 84.,
  17. 97., 110., 124., 139., 155., 172., 190., 208.,
  18. 228., 248., 269., 292., 315., 339., 364., 389.,
  19. 416., 444., 472., 502., 532., 564., 596., 629.,
  20. 663., 698., 734., 771., 809., 847., 887., 927.,
  21. 969., 1011., 1055., 1099., 1144., 1190., 1237., 1285.,
  22. 1334., 1384., 1435., 1486.]
  23. u_stat_01 = [1.0, 4.0, 7.0, 12.0, 19.0, 26.0, 35.0, 44.0, 55.0, 67.0, 80.0, 94.0, 110.0, 126.0, 144.0, 162.0, 182.0, 203.0, 225.0, 248.0, 272.0, 297.0, 323.0, 350.0, 379.0, 408.0, 439.0, 470.0, 503.0, 537.0, 572.0, 608.0, 645.0, 683.0, 722.0, 762.0, 803.0, 846.0, 889.0, 934.0, 979.0, 1026.0, 1073.0, 1122.0, 1172.0, 1222.0, 1274.0, 1327.0, 1381.0, 1436.0, 1492.0, 1550.0, 1608.0, 1667.0, 1727.0, 1789.0, 1851.0, 1915.0, 1979.0, 2045.0, 2111.0, 2179.0, 2248.0, 2317.0, 2388.0, 2460.0, 2533.0, 2607.0, 2682.0, 2758.0, 2835.0, 2913.0, 2992.0, 3073.0, 3154.0, 3236.0, 3320.0, 3404.0, 3490.0, 3576.0, 3664.0, 3752.0, 3842.0, 3932.0, 4024.0, 4117.0, 4211.0, 4306.0, 4402.0, 4498.0, 4596.0, 4695.0, 4796.0, 4897.0, 4999.0, 5102.0, 5206.0, 5311.0, 5418.0, 5525.0, 5633.0, 5743.0, 5853.0, 5965.0, 6077.0, 6191.0, 6305.0, 6421.0, 6538.0, 6656.0, 6774.0, 6894.0, 7015.0, 7137.0, 7260.0, 7384.0, 7509.0, 7635.0, 7762.0, 7890.0, 8019.0, 8149.0, 8280.0, 8412.0, 8546.0, 8680.0, 8815.0, 8952.0, 9089.0, 9228.0, 9367.0, 9508.0, 9649.0, 9792.0, 9935.0, 10080.0, 10226.0, 10372.0, 10520.0, 10669.0, 10819.0, 10970.0, 11121.0, 11274.0, 11428.0, 11583.0, 11739.0, 11896.0, 12054.0, 12213.0]
  24. u_stat_005 = [1.0, 4.0, 8.0, 14.0, 20.0, 28.0, 37.0, 48.0, 59.0, 72.0, 86.0, 101.0, 117.0, 134.0, 152.0, 172.0, 192.0, 214.0, 237.0, 261.0, 286.0, 312.0, 339.0, 368.0, 397.0, 428.0, 460.0, 492.0, 526.0, 561.0, 597.0, 635.0, 673.0, 712.0, 753.0, 794.0, 837.0, 880.0, 925.0, 971.0, 1018.0, 1066.0, 1115.0, 1165.0, 1216.0, 1269.0, 1322.0, 1377.0, 1432.0, 1489.0, 1546.0, 1605.0, 1665.0, 1726.0, 1788.0, 1851.0, 1915.0, 1980.0, 2046.0, 2114.0, 2182.0, 2251.0, 2322.0, 2393.0, 2466.0, 2540.0, 2614.0, 2690.0, 2767.0, 2845.0, 2924.0, 3004.0, 3085.0, 3167.0, 3250.0, 3335.0, 3420.0, 3506.0, 3594.0, 3682.0, 3772.0, 3862.0, 3954.0, 4047.0, 4140.0, 4235.0, 4331.0, 4428.0, 4526.0, 4625.0, 4725.0, 4826.0, 4929.0, 5032.0, 5136.0, 5241.0, 5348.0, 5455.0, 5564.0, 5673.0, 5784.0, 5896.0, 6008.0, 6122.0, 6237.0, 6353.0, 6470.0, 6588.0, 6707.0, 6827.0, 6948.0, 7070.0, 7193.0, 7317.0, 7443.0, 7569.0, 7696.0, 7825.0, 7954.0, 8085.0, 8216.0, 8349.0, 8483.0, 8617.0, 8753.0, 8890.0, 9028.0, 9167.0, 9306.0, 9447.0, 9589.0, 9732.0, 9877.0, 10022.0, 10168.0, 10315.0, 10463.0, 10613.0, 10763.0, 10915.0, 11067.0, 11220.0, 11375.0, 11531.0, 11687.0, 11845.0, 12004.0, 12163.0, 12324.0, 12486.0]
  25. u_stat_0025 = [1.0, 4.0, 9.0, 15.0, 22.0, 30.0, 40.0, 50.0, 63.0, 76.0, 90.0, 106.0, 123.0, 140.0, 160.0, 180.0, 201.0, 224.0, 247.0, 272.0, 298.0, 325.0, 353.0, 383.0, 413.0, 445.0, 478.0, 511.0, 546.0, 582.0, 619.0, 658.0, 697.0, 737.0, 779.0, 822.0, 865.0, 910.0, 956.0, 1003.0, 1052.0, 1101.0, 1151.0, 1203.0, 1255.0, 1309.0, 1363.0, 1419.0, 1476.0, 1534.0, 1593.0, 1653.0, 1714.0, 1777.0, 1840.0, 1904.0, 1970.0, 2037.0, 2104.0, 2173.0, 2243.0, 2314.0, 2386.0, 2459.0, 2533.0, 2608.0, 2685.0, 2762.0, 2840.0, 2920.0, 3000.0, 3082.0, 3165.0, 3249.0, 3334.0, 3419.0, 3506.0, 3595.0, 3684.0, 3774.0, 3865.0, 3957.0, 4051.0, 4145.0, 4241.0, 4338.0, 4435.0, 4534.0, 4634.0, 4735.0, 4837.0, 4940.0, 5044.0, 5149.0, 5255.0, 5362.0, 5470.0, 5580.0, 5690.0, 5802.0, 5914.0, 6028.0, 6142.0, 6258.0, 6375.0, 6493.0, 6612.0, 6732.0, 6853.0, 6975.0, 7098.0, 7222.0, 7347.0, 7474.0, 7601.0, 7729.0, 7859.0, 7989.0, 8121.0, 8253.0, 8387.0, 8522.0, 8658.0, 8795.0, 8932.0, 9071.0, 9211.0, 9352.0, 9495.0, 9638.0, 9782.0, 9927.0, 10074.0, 10221.0, 10369.0, 10519.0, 10669.0, 10821.0, 10974.0, 11127.0, 11282.0, 11438.0, 11595.0, 11752.0, 11911.0, 12071.0, 12232.0, 12394.0, 12558.0, 12722.0]
  26. def AUC_T(U, N, p):
  27. # U wartsc statystyki
  28. # N liczba T i NT przed walidacja
  29. # p = p_value
  30. R_out = []
  31. N_out = []
  32. for u, n in zip(U, N):
  33. if np.isnan(u):
  34. R_out.append(np.nan)
  35. N_out.append(n)
  36. else:
  37. R_out.append((n**2 - u)/ n**2)
  38. N_out.append(n)
  39. for n in range(20, 101):
  40. meanrank = n**2 / 2
  41. sd = np.sqrt(n**2 * (2*n+1) / 12.0) # bez tiecorrection, ale powinno byc ok
  42. z = st.norm.ppf(1-p) # test dwustronny
  43. u = z*sd + meanrank
  44. AUC = u / n**2
  45. R_out.append(AUC)
  46. N_out.append(n)
  47. return N_out, R_out
  48. def main(file_T, file_S, file_DATABASE, NAME, main_diagnosis = None, auc_max=False):
  49. AUC_T_data = pd.read_csv(file_T, sep=',')
  50. #p_data = AUC_T_data.columns.values[1:]
  51. p_data = ['0.05']
  52. print p_data
  53. AUC_S_data = pd.read_csv(file_S, sep=',')
  54. AUC_DATABASE = pd.read_csv(file_DATABASE, sep=',')
  55. f = plt.figure(figsize=(10,5), dpi=600)
  56. ax = f.add_subplot(1, 1, 1)
  57. color = {'0.01':'b', '0.05':'gray'}
  58. for ind, p in enumerate(p_data,1):
  59. # f = plt.figure()
  60. # ax = f.add_subplot(1, 1, 1)
  61. #AUC_TEORETYCZNE
  62. data_U = AUC_T_data[p].values
  63. data_n = AUC_T_data['n'].values
  64. N_values, AUC_values = AUC_T(data_U, data_n, float(p))
  65. ax.plot(N_values, AUC_values, color = color[p] ,label = 'theoretical p={}'.format(p), alpha=0.75)
  66. #~ xx= []
  67. #~ yy =[]
  68. #~ for nr, i in enumerate(u_stat_01):
  69. #~ yy.append(i/(nr+1)**2)
  70. #~ xx.append(nr+1)
  71. #~
  72. #~ ax.plot(xx, yy, label='theory 0.1')
  73. #~
  74. #~ u_stat_005
  75. #~
  76. #~ xx= []
  77. #~ yy =[]
  78. #~ for nr, i in enumerate(u_stat_005):
  79. #~ yy.append(i/(nr+1)**2)
  80. #~ xx.append(nr+1)
  81. #~
  82. #~ ax.plot(xx, yy, label='theory 0.05')
  83. #~
  84. #~ xx= []
  85. #~ yy =[]
  86. #~ for nr, i in enumerate(u_stat_0025):
  87. #~ yy.append(i/(nr+1)**2)
  88. #~ xx.append(nr+1)
  89. #~
  90. #~ ax.plot(xx, yy, label='theory 0.025')
  91. #AUC SYMULACJE
  92. data_for_AUC = AUC_S_data[['n_targets', 'auc95' ]].sort_values('n_targets').values
  93. data_AUC = data_for_AUC[:,1]
  94. data_n = data_for_AUC[:,0]
  95. ax.plot(data_n, data_AUC, color = 'blue', label = 'simulation p={}'.format(p), alpha=0.75)
  96. #~ data_for_AUC = AUC_S_data[['n_targets', 'auc90' ]].sort_values('n_targets').values
  97. #~ data_AUC = data_for_AUC[:,1]
  98. #~ data_n = data_for_AUC[:,0]
  99. #~ ax.plot(data_n, data_AUC, color = 'red', label = 'AUC crit. simulation p={}'.format(0.1), alpha=0.75)
  100. color = ['g', 'y', 'r', 'c', 'm', 'k']
  101. for i, diag in enumerate(sorted(list(set(AUC_DATABASE['diagnosis'].values)))):
  102. if auc_max:
  103. data_AUC_n, data_AUC = get_auc_max(AUC_DATABASE, diag)
  104. else:
  105. data_AUC = AUC_DATABASE[AUC_DATABASE['diagnosis']==diag]['AUC'].values
  106. data_AUC_n = AUC_DATABASE[AUC_DATABASE['diagnosis']==diag]['N_avr_targets'].values
  107. if main_diagnosis is None:
  108. color__ = color[i]
  109. else:
  110. if diag == main_diagnosis:
  111. color__ = 'k'
  112. else:
  113. color__ = 'lightgray'
  114. if diag=='CONTROL_ADULT':
  115. diag = 'control adult'
  116. if diag=='CONTROL_CHILDREN':
  117. diag = 'control children'
  118. # DATA FILTERING:
  119. mask = data_AUC_n>=20
  120. data_AUC_n = data_AUC_n[mask]
  121. data_AUC = data_AUC[mask]
  122. ax.plot(data_AUC_n, data_AUC, 'o', color = color__, label = '{}'.format(diag,), alpha=0.75)
  123. ax.set_xlim(0, 100)
  124. ax.set_ylim(0.0, 1)
  125. ax.set_xlabel("N1 = N2 = N target and non-target example count")
  126. ax.set_ylabel("AUC")
  127. # plt.title("P = {} one-tailed".format(p))
  128. # plt.savefig('./AUC_{}.png'.format(p))
  129. ############################
  130. if main_diagnosis:
  131. import matplotlib.lines as mlines
  132. theory_line = mlines.Line2D([], [], color='gray',
  133. label='theoretical p=0.05')
  134. sim_line = mlines.Line2D([], [], color='gray', marker='>', linestyle='',
  135. label='simulation p=0.05')
  136. main_diag_line = mlines.Line2D([], [], color='k', marker='o', linestyle='',
  137. label='{}'.format(main_diagnosis))
  138. other_diag_line = mlines.Line2D([], [], color='lightgray', marker='o', linestyle='',
  139. label='others')
  140. plt.legend(handles=[theory_line, sim_line, main_diag_line, other_diag_line], loc=0)
  141. ######################################################################
  142. else:
  143. plt.legend(loc=0)
  144. plt.savefig('./AUC_{}_{}_max{}_min_per_round_{}.png'.format(main_diagnosis,NAME, auc_max, min_per_round), dpi=600)
  145. def get_auc_max(d, diag):
  146. d['crs_group'] = d.diagnosis
  147. # d.crs_group.loc[(d.crs_group=='MCS-') | (d.crs_group=='MCS+')] = 'MCS'
  148. patients = d["id"].unique()
  149. d_max = pd.DataFrame(columns=d.columns)
  150. for p in patients:
  151. rounds = d["round"].loc[(d["id"]==p)].unique()
  152. for r in rounds:
  153. p_vals = d["AUC"].loc[(d["id"]==p) & (d["round"]==r)]
  154. temp = d.loc[p_vals.argmax()]
  155. print('hh', r, p, len(p_vals), diag)
  156. to_append = temp.append(pd.Series({'nb_paradigms': len(p_vals)}))
  157. if len(p_vals)>=min_per_round or diag=='UWS' or diag=='CONTROL_ADULT' or diag=='CONTROL_CHILDREN':
  158. d_max = d_max.append(to_append, ignore_index=True)
  159. data_AUC = d_max[d_max['diagnosis']==diag]['AUC'].values
  160. data_AUC_n = d_max[d_max['diagnosis']==diag]['N_avr_targets'].values
  161. return data_AUC_n, data_AUC
  162. if __name__ == '__main__':
  163. database_AUC = [ ["AUC_downsamping_database_1.csv", "downsamping_1"],
  164. ]
  165. AUC_S_data = pd.read_csv(FILE_S, sep=',')
  166. AUC_DATABASE = pd.read_csv('AUC_downsamping_database_1.csv', sep=',')
  167. MAIN_DIAGNOSES = set(AUC_DATABASE['diagnosis'].values)
  168. for file_, name in database_AUC:
  169. data = main(FILE_T, FILE_S, file_, name, None, True)
  170. data = main(FILE_T, FILE_S, file_, name, None, False)
  171. #for main_diagnosis in MAIN_DIAGNOSES:
  172. # for file_, name in database_AUC:
  173. #
  174. # data = main(FILE_T, FILE_S, file_, name, main_diagnosis, True)