PageRenderTime 31ms CodeModel.GetById 19ms RepoModel.GetById 1ms app.codeStats 0ms

/pyklip/kpp/detection/ROC.py

https://bitbucket.org/douglase/pyklip
Python | 279 lines | 226 code | 11 blank | 42 comment | 7 complexity | ccd4740b4cf51d1c0f8e0f4ac6192b9e MD5 | raw file
Possible License(s): BSD-3-Clause
  1. __author__ = 'jruffio'
  2. from pyklip.kpp.utils.kppSuperClass import KPPSuperClass
  3. from pyklip.kpp.utils.oi import *
  4. class ROC(KPPSuperClass):
  5. """
  6. Class for calculating the ROC curve for a dataset.
  7. """
  8. def __init__(self,read_func,filename,filename_detec,
  9. mute=None,
  10. overwrite = False,
  11. detec_distance = None,
  12. ignore_distance = None,
  13. OI_list_folder = None,
  14. threshold_sampling = None,
  15. IWA = None,
  16. OWA = None,
  17. pix2as=None):
  18. """
  19. Define the general parameters of the ROC calculation.
  20. Args:
  21. read_func: lambda function treturning a instrument object where the only input should be a list of filenames
  22. to read.
  23. For e.g.:
  24. read_func = lambda filenames:GPI.GPIData(filenames,recalc_centers=False,recalc_wvs=False,highpass=False)
  25. filename: Filename of the file containing the simulated or real planets.
  26. It should be the complete path unless inputDir is used in initialize().
  27. It can include wild characters. The files will be reduced as given by glob.glob().
  28. filename_detec: Filename of the .csv file with the list of blobs in the image as produced by the class
  29. pyklip.kpp.detection.detection.Detection.
  30. mute: If True prevent printed log outputs.
  31. overwrite: Boolean indicating whether or not files should be overwritten if they exist.
  32. See check_existence().
  33. detec_distance: Distance in pixel between a candidate a true planet to claim the detection. (default 2 pixels)
  34. ignore_distance: Distance in pixel from a true planet up to which detections are ignored. (default 10 pixels)
  35. OI_list_folder: List of Object of Interest (OI) that should be masked from any standard deviation
  36. calculation. See the online documentation for instructions on how to define it.
  37. threshold_sampling: Sampling used for the detection threshold (default np.linspace(0.0,20,200))
  38. IWA: inner working angle in pixels.
  39. OWA: outer working angle in pixels.
  40. pix2as: Platescale (arcsec per pixel).
  41. """
  42. # allocate super class
  43. super(ROC, self).__init__(read_func,filename,
  44. folderName = None,
  45. mute=mute,
  46. N_threads=None,
  47. label=None,
  48. overwrite = overwrite)
  49. if detec_distance is None:
  50. self.detec_distance = 2
  51. else:
  52. self.detec_distance = detec_distance
  53. if ignore_distance is None:
  54. self.ignore_distance = 10
  55. else:
  56. self.ignore_distance = ignore_distance
  57. if threshold_sampling is None:
  58. self.threshold_sampling = np.linspace(0.0,20,200)
  59. else:
  60. self.threshold_sampling = threshold_sampling
  61. self.filename_detec = filename_detec
  62. self.OI_list_folder = OI_list_folder
  63. self.IWA = IWA
  64. self.OWA = OWA
  65. self.pix2as = pix2as
  66. def initialize(self,inputDir = None,
  67. outputDir = None,
  68. folderName = None):
  69. """
  70. Read the files using read_func (see the class __init__ function).
  71. Can be called several time to process all the files matching the filename.
  72. Also define the output filename (if it were to be saved) such that check_existence() can be used.
  73. Args:
  74. inputDir: If defined it allows filename to not include the whole path and just the filename.
  75. Files will be read from inputDir.
  76. If inputDir is None then filename is assumed to have the absolute path.
  77. outputDir: Directory where to create the folder containing the outputs.
  78. A kpop folder will be created to save the data. Convention is:
  79. self.outputDir = outputDir+os.path.sep+"kpop_"+label+os.path.sep+folderName
  80. folderName: Name of the folder containing the outputs. It will be located in outputDir+os.path.sep+"kpop_"+label
  81. Default folder name is "default_out".
  82. A nice convention is to have one folder per spectral template.
  83. If the file read has been created with KPOP, folderName is automatically defined from that
  84. file.
  85. Return: True if all the files matching the filename (with wildcards) have been processed. False otherwise.
  86. """
  87. if not self.mute:
  88. print("~~ INITializing "+self.__class__.__name__+" ~~")
  89. # The super class already read the fits file
  90. init_out = super(ROC, self).initialize(inputDir = inputDir,
  91. outputDir = outputDir,
  92. folderName = folderName,
  93. label=None)
  94. try:
  95. self.folderName = self.prihdr["KPPFOLDN"]+os.path.sep
  96. except:
  97. try:
  98. self.folderName = self.exthdr["METFOLDN"]+os.path.sep
  99. print("/!\ CAUTION: Reading deprecated data.")
  100. except:
  101. try:
  102. self.folderName = self.exthdr["STAFOLDN"]+os.path.sep
  103. print("/!\ CAUTION: Reading deprecated data.")
  104. except:
  105. self.folderName = None
  106. # Check file existence and define filename_path
  107. if self.inputDir is None or os.path.isabs(self.filename_detec):
  108. try:
  109. self.filename_detec_path = os.path.abspath(glob(self.filename_detec)[self.id_matching_file])
  110. self.N_matching_files = len(glob(self.filename_detec))
  111. except:
  112. raise Exception("File "+self.filename_detec+"doesn't exist.")
  113. else:
  114. try:
  115. self.filename_detec_path = os.path.abspath(glob(self.inputDir+os.path.sep+self.filename_detec)[self.id_matching_file])
  116. self.N_matching_files = len(glob(self.inputDir+os.path.sep+self.filename_detec))
  117. except:
  118. raise Exception("File "+self.inputDir+os.path.sep+self.filename_detec+" doesn't exist.")
  119. with open(self.filename_detec_path, 'rb') as csvfile:
  120. reader = csv.reader(csvfile, delimiter=';')
  121. csv_as_list = list(reader)
  122. self.detec_table_labels = csv_as_list[0]
  123. self.detec_table = np.array(csv_as_list[1::], dtype='string').astype(np.float)
  124. if not self.mute:
  125. print("Opened: "+self.filename_detec_path)
  126. self.N_detec = self.detec_table.shape[0]
  127. self.val_id = self.detec_table_labels.index("value")
  128. self.x_id = self.detec_table_labels.index("x")
  129. self.y_id = self.detec_table_labels.index("y")
  130. file_ext_ind = os.path.basename(self.filename_detec_path)[::-1].find(".")
  131. self.prefix = os.path.basename(self.filename_detec_path)[:-(file_ext_ind+1)]
  132. self.suffix = "ROC"
  133. return init_out
  134. def check_existence(self):
  135. """
  136. Return whether or not a filename of the processed data can be found.
  137. If overwrite is True, the output is always false.
  138. Return: boolean
  139. """
  140. if self.folderName is not None:
  141. myname = self.outputDir+os.path.sep+self.folderName+os.path.sep+self.prefix+'-'+self.suffix+'.csv'
  142. else:
  143. myname = self.outputDir+os.path.sep+self.prefix+'-'+self.suffix+'.csv'
  144. file_exist = (len(glob(myname)) >= 1)
  145. if file_exist and not self.mute:
  146. print("Output already exist: "+myname)
  147. if self.overwrite and not self.mute:
  148. print("Overwriting is turned ON!")
  149. return file_exist and not self.overwrite
  150. def calculate(self):
  151. """
  152. Calculate the number of false positives and the number true positives.
  153. :return: FPR,TPR table
  154. """
  155. if not self.mute:
  156. print("~~ Calculating "+self.__class__.__name__+" with parameters " + self.suffix+" ~~")
  157. if self.OI_list_folder is not None:
  158. try:
  159. MJDOBS = self.prihdr.header['MJD-OBS']
  160. except:
  161. raise ValueError("Could not find MJDOBS. Probably because non GPI data. Code needs to be improved")
  162. x_real_object_list,y_real_object_list = \
  163. get_pos_known_objects(self.fakeinfohdr,self.star_name,self.pix2as,center=self.center[0],
  164. MJDOBS=MJDOBS,OI_list_folder=self.OI_list_folder,
  165. xy = True,ignore_fakes=True,IWA=self.IWA,OWA=self.OWA)
  166. row_object_list,col_object_list = get_pos_known_objects(self.fakeinfohdr,self.star_name,self.pix2as,center=self.center[0],
  167. IWA=self.IWA,OWA=self.OWA)
  168. self.false_detec_proba_vec = []
  169. # Loop over all the local maxima stored in the detec csv file
  170. for k in range(self.N_detec):
  171. val_criter = self.detec_table[k,self.val_id]
  172. x_pos = self.detec_table[k,self.x_id]
  173. y_pos = self.detec_table[k,self.y_id]
  174. #remove the detection if it is a real object
  175. if self.OI_list_folder is not None:
  176. reject = False
  177. for x_real_object,y_real_object in zip(x_real_object_list,y_real_object_list):
  178. if (x_pos-x_real_object)**2+(y_pos-y_real_object)**2 < self.ignore_distance**2:
  179. reject = True
  180. break
  181. if reject:
  182. continue
  183. if self.IWA is not None:
  184. if np.sqrt( (x_pos)**2+(y_pos)**2) < self.IWA:
  185. continue
  186. if self.OWA is not None:
  187. if np.sqrt( (x_pos)**2+(y_pos)**2) > self.OWA:
  188. continue
  189. self.false_detec_proba_vec.append(val_criter)
  190. self.true_detec_proba_vec = [self.image[np.round(row_real_object),np.round(col_real_object)] \
  191. for row_real_object,col_real_object in zip(row_object_list,col_object_list)]
  192. self.true_detec_proba_vec = np.array(self.true_detec_proba_vec)[np.where(~np.isnan(self.true_detec_proba_vec))]
  193. self.N_false_pos = np.zeros(self.threshold_sampling.shape)
  194. self.N_true_detec = np.zeros(self.threshold_sampling.shape)
  195. for id,threshold_it in enumerate(self.threshold_sampling):
  196. self.N_false_pos[id] = np.sum(self.false_detec_proba_vec >= threshold_it)
  197. self.N_true_detec[id] = np.sum(self.true_detec_proba_vec >= threshold_it)
  198. return zip(self.threshold_sampling,self.N_false_pos,self.N_true_detec)
  199. def save(self):
  200. """
  201. Save the processed files as:
  202. #user_outputDir#+os.path.sep+self.prefix+'-'+self.suffix+'.fits'
  203. or if self.label and self.folderName are not None:
  204. #user_outputDir#+os.path.sep+"kpop_"+self.label+os.path.sep+self.folderName+os.path.sep+self.prefix+'-'+self.suffix+'.fits'
  205. :return: None
  206. """
  207. if self.folderName is not None:
  208. if not os.path.exists(self.outputDir+os.path.sep+self.folderName):
  209. os.makedirs(self.outputDir+os.path.sep+self.folderName)
  210. myname = self.outputDir+os.path.sep+self.folderName+os.path.sep+self.prefix+'-'+self.suffix+'.csv'
  211. else:
  212. if not os.path.exists(self.outputDir):
  213. os.makedirs(self.outputDir)
  214. myname = self.outputDir+os.path.sep+self.prefix+'-'+self.suffix+'.csv'
  215. if not self.mute:
  216. print("Saving: "+myname)
  217. with open(myname, 'w+') as csvfile:
  218. csvwriter = csv.writer(csvfile, delimiter=';')
  219. csvwriter.writerows([["value","N false pos","N true pos"]])
  220. csvwriter.writerows(zip(self.threshold_sampling,self.N_false_pos,self.N_true_detec))
  221. return None
  222. def load(self):
  223. """
  224. :return: None
  225. """
  226. return None