/AUC/AUC_plot.py
Python | 228 lines | 148 code | 38 blank | 42 comment | 15 complexity | 612fcd80870bd3e3c1d188f798170b24 MD5 | raw file
- import pandas as pd
- import numpy as np
- import scipy.stats as st
- import sys
- import matplotlib.pyplot as plt
- import os
- import seaborn as sns
- from itertools import product
- from collections import defaultdict
# Mann-Whitney tables for n1, n2 < 20 (presumably the source of FILE_T):
# http://www.real-statistics.com/statistics-tables/mann-whitney-table/
FILE_T = './U_statistic.csv'
# Simulated critical-AUC results. An older value, './AUC_symulacje_full.csv'
# ("symulacje" = simulations), was assigned here and immediately overwritten;
# that dead assignment has been dropped.
FILE_S = './result.csv'
# Minimum number of rows (paradigms) a patient round needs to be kept by
# get_auc_max() for diagnoses other than UWS / CONTROL_ADULT /
# CONTROL_CHILDREN; it is also embedded in main()'s output file name.
min_per_round = 2
# Hard-coded Mann-Whitney U tables (10 values per line).
# NOTE(review): their exact meaning is not documented in this file. The
# u_stat_* suffix presumably encodes the significance level (0.1, 0.05,
# 0.025) and entry i presumably corresponds to n1 = n2 = i + 1 -- confirm
# against the table source before relying on them.
U_stat_diag = [
    0.0, 0.0, 0.0, 2.0, 5.0, 8.0, 12.0, 16.0, 22.0, 28.0,
    35.0, 43.0, 52.0, 62.0, 73.0, 84.0, 97.0, 110.0, 124.0, 139.0,
    155.0, 172.0, 190.0, 208.0, 228.0, 248.0, 269.0, 292.0, 315.0, 339.0,
    364.0, 389.0, 416.0, 444.0, 472.0, 502.0, 532.0, 564.0, 596.0, 629.0,
    663.0, 698.0, 734.0, 771.0, 809.0, 847.0, 887.0, 927.0, 969.0, 1011.0,
    1055.0, 1099.0, 1144.0, 1190.0, 1237.0, 1285.0, 1334.0, 1384.0, 1435.0, 1486.0,
]
u_stat_01 = [
    1.0, 4.0, 7.0, 12.0, 19.0, 26.0, 35.0, 44.0, 55.0, 67.0,
    80.0, 94.0, 110.0, 126.0, 144.0, 162.0, 182.0, 203.0, 225.0, 248.0,
    272.0, 297.0, 323.0, 350.0, 379.0, 408.0, 439.0, 470.0, 503.0, 537.0,
    572.0, 608.0, 645.0, 683.0, 722.0, 762.0, 803.0, 846.0, 889.0, 934.0,
    979.0, 1026.0, 1073.0, 1122.0, 1172.0, 1222.0, 1274.0, 1327.0, 1381.0, 1436.0,
    1492.0, 1550.0, 1608.0, 1667.0, 1727.0, 1789.0, 1851.0, 1915.0, 1979.0, 2045.0,
    2111.0, 2179.0, 2248.0, 2317.0, 2388.0, 2460.0, 2533.0, 2607.0, 2682.0, 2758.0,
    2835.0, 2913.0, 2992.0, 3073.0, 3154.0, 3236.0, 3320.0, 3404.0, 3490.0, 3576.0,
    3664.0, 3752.0, 3842.0, 3932.0, 4024.0, 4117.0, 4211.0, 4306.0, 4402.0, 4498.0,
    4596.0, 4695.0, 4796.0, 4897.0, 4999.0, 5102.0, 5206.0, 5311.0, 5418.0, 5525.0,
    5633.0, 5743.0, 5853.0, 5965.0, 6077.0, 6191.0, 6305.0, 6421.0, 6538.0, 6656.0,
    6774.0, 6894.0, 7015.0, 7137.0, 7260.0, 7384.0, 7509.0, 7635.0, 7762.0, 7890.0,
    8019.0, 8149.0, 8280.0, 8412.0, 8546.0, 8680.0, 8815.0, 8952.0, 9089.0, 9228.0,
    9367.0, 9508.0, 9649.0, 9792.0, 9935.0, 10080.0, 10226.0, 10372.0, 10520.0, 10669.0,
    10819.0, 10970.0, 11121.0, 11274.0, 11428.0, 11583.0, 11739.0, 11896.0, 12054.0, 12213.0,
]
u_stat_005 = [
    1.0, 4.0, 8.0, 14.0, 20.0, 28.0, 37.0, 48.0, 59.0, 72.0,
    86.0, 101.0, 117.0, 134.0, 152.0, 172.0, 192.0, 214.0, 237.0, 261.0,
    286.0, 312.0, 339.0, 368.0, 397.0, 428.0, 460.0, 492.0, 526.0, 561.0,
    597.0, 635.0, 673.0, 712.0, 753.0, 794.0, 837.0, 880.0, 925.0, 971.0,
    1018.0, 1066.0, 1115.0, 1165.0, 1216.0, 1269.0, 1322.0, 1377.0, 1432.0, 1489.0,
    1546.0, 1605.0, 1665.0, 1726.0, 1788.0, 1851.0, 1915.0, 1980.0, 2046.0, 2114.0,
    2182.0, 2251.0, 2322.0, 2393.0, 2466.0, 2540.0, 2614.0, 2690.0, 2767.0, 2845.0,
    2924.0, 3004.0, 3085.0, 3167.0, 3250.0, 3335.0, 3420.0, 3506.0, 3594.0, 3682.0,
    3772.0, 3862.0, 3954.0, 4047.0, 4140.0, 4235.0, 4331.0, 4428.0, 4526.0, 4625.0,
    4725.0, 4826.0, 4929.0, 5032.0, 5136.0, 5241.0, 5348.0, 5455.0, 5564.0, 5673.0,
    5784.0, 5896.0, 6008.0, 6122.0, 6237.0, 6353.0, 6470.0, 6588.0, 6707.0, 6827.0,
    6948.0, 7070.0, 7193.0, 7317.0, 7443.0, 7569.0, 7696.0, 7825.0, 7954.0, 8085.0,
    8216.0, 8349.0, 8483.0, 8617.0, 8753.0, 8890.0, 9028.0, 9167.0, 9306.0, 9447.0,
    9589.0, 9732.0, 9877.0, 10022.0, 10168.0, 10315.0, 10463.0, 10613.0, 10763.0, 10915.0,
    11067.0, 11220.0, 11375.0, 11531.0, 11687.0, 11845.0, 12004.0, 12163.0, 12324.0, 12486.0,
]
u_stat_0025 = [
    1.0, 4.0, 9.0, 15.0, 22.0, 30.0, 40.0, 50.0, 63.0, 76.0,
    90.0, 106.0, 123.0, 140.0, 160.0, 180.0, 201.0, 224.0, 247.0, 272.0,
    298.0, 325.0, 353.0, 383.0, 413.0, 445.0, 478.0, 511.0, 546.0, 582.0,
    619.0, 658.0, 697.0, 737.0, 779.0, 822.0, 865.0, 910.0, 956.0, 1003.0,
    1052.0, 1101.0, 1151.0, 1203.0, 1255.0, 1309.0, 1363.0, 1419.0, 1476.0, 1534.0,
    1593.0, 1653.0, 1714.0, 1777.0, 1840.0, 1904.0, 1970.0, 2037.0, 2104.0, 2173.0,
    2243.0, 2314.0, 2386.0, 2459.0, 2533.0, 2608.0, 2685.0, 2762.0, 2840.0, 2920.0,
    3000.0, 3082.0, 3165.0, 3249.0, 3334.0, 3419.0, 3506.0, 3595.0, 3684.0, 3774.0,
    3865.0, 3957.0, 4051.0, 4145.0, 4241.0, 4338.0, 4435.0, 4534.0, 4634.0, 4735.0,
    4837.0, 4940.0, 5044.0, 5149.0, 5255.0, 5362.0, 5470.0, 5580.0, 5690.0, 5802.0,
    5914.0, 6028.0, 6142.0, 6258.0, 6375.0, 6493.0, 6612.0, 6732.0, 6853.0, 6975.0,
    7098.0, 7222.0, 7347.0, 7474.0, 7601.0, 7729.0, 7859.0, 7989.0, 8121.0, 8253.0,
    8387.0, 8522.0, 8658.0, 8795.0, 8932.0, 9071.0, 9211.0, 9352.0, 9495.0, 9638.0,
    9782.0, 9927.0, 10074.0, 10221.0, 10369.0, 10519.0, 10669.0, 10821.0, 10974.0, 11127.0,
    11282.0, 11438.0, 11595.0, 11752.0, 11911.0, 12071.0, 12232.0, 12394.0, 12558.0, 12722.0,
]
def AUC_T(U, N, p, n_min=20, n_max=100):
    """Turn U-statistic values into AUC thresholds for equal group sizes.

    For each pair ``(u, n)`` from ``U`` and ``N`` (n targets and n
    non-targets) the value ``(n**2 - u) / n**2`` is produced; NaN entries of
    ``U`` stay NaN. The curve is then extended for every integer
    ``n_min <= n <= n_max`` using the normal approximation of U without tie
    correction: mean ``n**2 / 2``, sd ``sqrt(n**2 * (2n + 1) / 12)``, taking
    the upper ``1 - p`` quantile and dividing by ``n**2``.

    NOTE(review): the tabulated ``U`` values are presumably the critical
    values at level ``p`` -- confirm against the U table file.

    Args:
        U: sequence of U statistic values (NaN where unavailable).
        N: sequence of group sizes, paired element-wise with ``U``.
        p: significance level. ``norm.ppf(1 - p)`` is a ONE-sided
            quantile; pass ``p / 2`` for a two-sided threshold.
        n_min, n_max: inclusive range of n for the normal approximation
            (defaults keep the original 20..100).

    Returns:
        ``(N_out, R_out)``: two lists, the tabulated points first, followed
        by the normal-approximation points.
    """
    R_out = []
    N_out = []
    for u, n in zip(U, N):
        # float() prevents integer floor division under Python 2 / numpy ints.
        R_out.append(np.nan if np.isnan(u) else (n**2 - u) / float(n**2))
        N_out.append(n)
    # One-sided upper quantile; independent of n, so computed once.
    z = st.norm.ppf(1 - p)
    for n in range(n_min, n_max + 1):
        n_sq = float(n * n)
        meanrank = n_sq / 2.0  # float: n**2 / 2 truncates for odd n in Python 2
        sd = np.sqrt(n_sq * (2 * n + 1) / 12.0)  # no tie correction
        R_out.append((z * sd + meanrank) / n_sq)
        N_out.append(n)
    return N_out, R_out
# Colour of the theoretical critical-AUC curve per p-value column.
_THEORY_COLORS = {'0.01': 'b', '0.05': 'gray'}
# Per p-value: (simulated critical-AUC column in the simulation CSV, colour).
# auc95 <-> p=0.05 is the pairing used so far; auc90 <-> p=0.1 comes from an
# earlier, disabled variant of this plot.
_SIMULATION_COLUMNS = {'0.05': ('auc95', 'blue'), '0.1': ('auc90', 'red')}
# Colours cycled over diagnoses when no main diagnosis is highlighted.
_DIAGNOSIS_PALETTE = ['g', 'y', 'r', 'c', 'm', 'k']
# Legend labels for raw diagnosis codes.
_DIAGNOSIS_LABELS = {'CONTROL_ADULT': 'control adult',
                     'CONTROL_CHILDREN': 'control children'}
# Database points with fewer averaged targets than this are not drawn.
_MIN_N_PLOTTED = 20


def _plot_critical_curves(ax, AUC_T_data, AUC_S_data, p):
    """Plot the theoretical and (if known) simulated critical AUC for ``p``."""
    N_values, AUC_values = AUC_T(AUC_T_data[p].values,
                                 AUC_T_data['n'].values, float(p))
    ax.plot(N_values, AUC_values, color=_THEORY_COLORS.get(p, 'gray'),
            label='theoretical p={}'.format(p), alpha=0.75)
    if p in _SIMULATION_COLUMNS:
        column, colour = _SIMULATION_COLUMNS[p]
        sim = AUC_S_data[['n_targets', column]].sort_values('n_targets').values
        ax.plot(sim[:, 0], sim[:, 1], color=colour,
                label='simulation p={}'.format(p), alpha=0.75)


def _plot_diagnoses(ax, database, main_diagnosis, auc_max):
    """Scatter the AUC values of every diagnosis (N >= _MIN_N_PLOTTED only)."""
    for i, diag in enumerate(sorted(set(database['diagnosis'].values))):
        if auc_max:
            n_values, auc_values = get_auc_max(database, diag)
        else:
            rows = database[database['diagnosis'] == diag]
            auc_values = rows['AUC'].values
            n_values = rows['N_avr_targets'].values

        if main_diagnosis is None:
            # Modulo: more diagnoses than colours reuse the palette instead
            # of raising IndexError.
            colour = _DIAGNOSIS_PALETTE[i % len(_DIAGNOSIS_PALETTE)]
        elif diag == main_diagnosis:
            colour = 'k'
        else:
            colour = 'lightgray'

        mask = n_values >= _MIN_N_PLOTTED
        ax.plot(n_values[mask], auc_values[mask], 'o', color=colour,
                label='{}'.format(_DIAGNOSIS_LABELS.get(diag, diag)),
                alpha=0.75)


def _add_legend(ax, main_diagnosis, p_data):
    """Add the legend; compact proxy entries when a diagnosis is highlighted."""
    if main_diagnosis is None:
        ax.legend(loc=0)
        return
    import matplotlib.lines as mlines
    handles = []
    for p in p_data:
        handles.append(mlines.Line2D([], [], color=_THEORY_COLORS.get(p, 'gray'),
                                     label='theoretical p={}'.format(p)))
        if p in _SIMULATION_COLUMNS:
            # The proxy must look like the plotted simulation curve (a line).
            handles.append(mlines.Line2D([], [], color=_SIMULATION_COLUMNS[p][1],
                                         label='simulation p={}'.format(p)))
    handles.append(mlines.Line2D(
        [], [], color='k', marker='o', linestyle='',
        label='{}'.format(_DIAGNOSIS_LABELS.get(main_diagnosis, main_diagnosis))))
    handles.append(mlines.Line2D([], [], color='lightgray', marker='o',
                                 linestyle='', label='others'))
    ax.legend(handles=handles, loc=0)


def main(file_T, file_S, file_DATABASE, NAME, main_diagnosis=None, auc_max=False):
    """Plot critical-AUC curves against per-diagnosis AUC values; save a PNG.

    On a single axis this draws:
      * the theoretical critical-AUC curve (``AUC_T``) for each p-value used;
      * the simulated critical-AUC curve from ``file_S`` for p-values with a
        known simulation column;
      * one scatter series per diagnosis from ``file_DATABASE`` (points with
        ``N_avr_targets >= 20`` only).

    Args:
        file_T: CSV with an ``n`` column and one U column per p-value
            (column names such as ``'0.05'``).
        file_S: CSV with ``n_targets`` and simulated critical-AUC columns
            (``auc95``; ``auc90`` for p = 0.1).
        file_DATABASE: CSV with ``diagnosis``, ``AUC`` and ``N_avr_targets``
            (plus ``id`` and ``round`` when ``auc_max`` is true).
        NAME: tag inserted into the output file name.
        main_diagnosis: if not None, that diagnosis is drawn in black, all
            others in light gray, with a compact legend.
        auc_max: if true, only the best AUC per patient round is plotted
            (see ``get_auc_max``).

    Side effects:
        Prints the p-values used and writes
        ``./AUC_<main_diagnosis>_<NAME>_max<auc_max>_min_per_round_<min_per_round>.png``.
    """
    AUC_T_data = pd.read_csv(file_T, sep=',')
    # Only p = 0.05 is plotted; AUC_T_data.columns.values[1:] would select
    # every p-value column of the U table instead.
    p_data = ['0.05']
    print(p_data)
    AUC_S_data = pd.read_csv(file_S, sep=',')
    AUC_DATABASE = pd.read_csv(file_DATABASE, sep=',')

    f = plt.figure(figsize=(10, 5), dpi=600)
    try:
        ax = f.add_subplot(1, 1, 1)
        for p in p_data:
            _plot_critical_curves(ax, AUC_T_data, AUC_S_data, p)
        # Disabled: overlay the hard-coded tables u_stat_01 / u_stat_005 /
        # u_stat_0025, plotting entry i at n = i + 1 as value / n**2, e.g.
        #   ns = np.arange(1, len(u_stat_005) + 1)
        #   ax.plot(ns, np.asarray(u_stat_005) / ns**2, label='theory 0.05')

        # Diagnoses are drawn once, not once per p-value.
        _plot_diagnoses(ax, AUC_DATABASE, main_diagnosis, auc_max)
        ax.set_xlim(0, 100)
        ax.set_ylim(0.0, 1)
        ax.set_xlabel("N1 = N2 = N target and non-target example count")
        ax.set_ylabel("AUC")
        _add_legend(ax, main_diagnosis, p_data)
        f.savefig('./AUC_{}_{}_max{}_min_per_round_{}.png'.format(
            main_diagnosis, NAME, auc_max, min_per_round), dpi=600)
    finally:
        # main() is called repeatedly from the script; free each figure.
        plt.close(f)
def get_auc_max(d, diag, min_paradigms=None):
    """Return (N, AUC) of the best-AUC row per patient round for one diagnosis.

    Rows of ``d`` are grouped by (``id``, ``round``). From each group the row
    with the highest ``AUC`` is taken (first one on ties; NaN ignored;
    groups whose AUCs are all NaN are skipped). A group is kept only if it
    has at least ``min_paradigms`` rows, unless ``diag`` is 'UWS',
    'CONTROL_ADULT' or 'CONTROL_CHILDREN' (then every group is kept).
    Finally only the kept rows whose ``diagnosis`` equals ``diag`` are
    returned.

    Args:
        d: DataFrame with ``id``, ``round``, ``AUC``, ``diagnosis`` and
            ``N_avr_targets`` columns. It is not modified.
        diag: diagnosis label to return.
        min_paradigms: minimum number of rows per group; ``None`` uses the
            module-level ``min_per_round``.

    Returns:
        ``(data_AUC_n, data_AUC)``: arrays of ``N_avr_targets`` and ``AUC``.
    """
    threshold = min_per_round if min_paradigms is None else min_paradigms
    keep_all = diag in ('UWS', 'CONTROL_ADULT', 'CONTROL_CHILDREN')

    best_rows = []
    for _, group in d.groupby(['id', 'round'], sort=False):
        aucs = group['AUC']
        if aucs.isnull().all():
            # Nothing to pick; idxmax would fail or return NaN here.
            continue
        if keep_all or len(group) >= threshold:
            # idxmax returns the index *label* that .loc needs; Series.argmax
            # returns a position since pandas 1.0.
            best_rows.append(group.loc[aucs.idxmax()])

    # Build the frame once; DataFrame.append is quadratic and was removed in
    # pandas 2.0. columns= keeps the schema even when no row was kept.
    d_max = pd.DataFrame(best_rows, columns=d.columns)
    selected = d_max[d_max['diagnosis'] == diag]
    return selected['N_avr_targets'].values, selected['AUC'].values
if __name__ == '__main__':
    # Each entry: (AUC database CSV, tag used in the output file names).
    database_AUC = [
        ["AUC_downsamping_database_1.csv", "downsamping_1"],
    ]

    # NOTE(review): these are not used by the active loop below;
    # MAIN_DIAGNOSES only feeds the disabled per-diagnosis run at the end.
    AUC_S_data = pd.read_csv(FILE_S, sep=',')
    AUC_DATABASE = pd.read_csv('AUC_downsamping_database_1.csv', sep=',')
    MAIN_DIAGNOSES = set(AUC_DATABASE['diagnosis'].values)

    for file_, name in database_AUC:
        # Per-round-maximum variant first, then all AUC values.
        for use_auc_max in (True, False):
            data = main(FILE_T, FILE_S, file_, name, None, use_auc_max)

    # Disabled: one figure per highlighted diagnosis.
    # for main_diagnosis in MAIN_DIAGNOSES:
    #     for file_, name in database_AUC:
    #         data = main(FILE_T, FILE_S, file_, name, main_diagnosis, True)