/CHAID/column.py

https://github.com/Rambatino/CHAID · Python · 309 lines · 196 code · 48 blank · 65 comment · 60 complexity · 59891d432f6286f6b4dcc1bb5181bd0a MD5 · raw file

  1. import numpy as np
  2. from math import isnan
  3. from itertools import combinations
  4. from .mapping_dict import MappingDict
  5. def is_sorted(ndarr, nan_val=None):
  6. store = []
  7. for arr in ndarr:
  8. if arr == [] or len(arr) == 1: continue
  9. if nan_val is not None and nan_val in arr:
  10. arr.remove(nan_val)
  11. store.append(arr[-1] - arr[0] == len(arr) - 1)
  12. return all(store)
  13. class Column(object):
  14. """
  15. A numpy array with metadata
  16. Parameters
  17. ----------
  18. arr : iterable object
  19. The numpy array
  20. metadata : dict
  21. The substitutions of the vector
  22. missing_id : string
  23. An identifier for the missing value to be associated
  24. substitute : bool
  25. Whether the objects in the given array need to be substitued for
  26. integers
  27. """
  28. def __init__(self, arr=None, metadata=None, missing_id='<missing>',
  29. substitute=True, weights=None, name=None):
  30. self.metadata = dict(metadata or {})
  31. self.arr = np.array(arr)
  32. self._missing_id = missing_id
  33. self.weights = weights
  34. self.name = name
  35. def __iter__(self):
  36. return iter(self.arr)
  37. def __getitem__(self, key):
  38. raise NotImplementedError
  39. def __setitem__(self, key, value):
  40. raise NotImplementedError
  41. def possible_groupings(self):
  42. raise NotImplementedError
  43. @property
  44. def type(self):
  45. """
  46. Returns a string representing the type
  47. """
  48. raise NotImplementedError
  49. def deep_copy(self):
  50. """
  51. Returns a deep copy
  52. """
  53. raise NotImplementedError
  54. def bell_set(self, collection, ordinal=False):
  55. """
  56. Calculates the Bell set
  57. """
  58. if len(collection) == 1:
  59. yield [ collection ]
  60. return
  61. first = collection[0]
  62. for smaller in self.bell_set(collection[1:]):
  63. for n, subset in enumerate(smaller):
  64. if not ordinal or (ordinal and is_sorted(smaller[:n] + [[ first ] + subset] + smaller[n+1:], self._nan)):
  65. yield smaller[:n] + [[ first ] + subset] + smaller[n+1:]
  66. if not ordinal or (ordinal and is_sorted([ [ first ] ] + smaller, self._nan)):
  67. yield [ [ first ] ] + smaller
  68. class NominalColumn(Column):
  69. """
  70. A column containing numerical values that are unrelated to
  71. one another (i.e. do not follow a progression)
  72. """
  73. def __init__(self, arr=None, metadata=None, missing_id='<missing>',
  74. substitute=True, weights=None, name=None):
  75. super(self.__class__, self).__init__(arr, metadata=metadata, missing_id=missing_id, weights=weights, name=name)
  76. if substitute and metadata is None:
  77. self.substitute_values(arr)
  78. self._groupings = MappingDict()
  79. for x in np.unique(self.arr):
  80. self._groupings[x] = [x]
  81. def deep_copy(self):
  82. """
  83. Returns a deep copy.
  84. """
  85. return NominalColumn(self.arr, metadata=self.metadata, name=self.name,
  86. missing_id=self._missing_id, substitute=False, weights=self.weights)
  87. def substitute_values(self, vect):
  88. """
  89. Internal method to substitute integers into the vector, and construct
  90. metadata to convert back to the original vector.
  91. np.nan is always given -1, all other objects are given integers in
  92. order of apperence.
  93. Parameters
  94. ----------
  95. vect : np.array
  96. the vector in which to substitute values in
  97. """
  98. try:
  99. unique = np.unique(vect)
  100. except:
  101. unique = set(vect)
  102. unique = [
  103. x for x in unique if not isinstance(x, float) or not isnan(x)
  104. ]
  105. arr = np.copy(vect)
  106. for new_id, value in enumerate(unique):
  107. np.place(arr, arr==value, new_id)
  108. self.metadata[new_id] = value
  109. arr = arr.astype(np.float)
  110. np.place(arr, np.isnan(arr), -1)
  111. self.arr = arr
  112. if -1 in arr:
  113. self.metadata[-1] = self._missing_id
  114. def __getitem__(self, key):
  115. new_weights = None if self.weights is None else self.weights[key]
  116. return NominalColumn(self.arr[key], metadata=self.metadata, substitute=False, weights=new_weights, name=self.name)
  117. def __setitem__(self, key, value):
  118. self.arr[key] = value
  119. return self
  120. def groups(self):
  121. return list(self._groupings.values())
  122. def possible_groupings(self):
  123. return combinations(self._groupings.keys(), 2)
  124. def all_combinations(self):
  125. bell_set = self.bell_set(sorted(list(self._groupings.keys())))
  126. next(bell_set)
  127. return bell_set
  128. def group(self, x, y):
  129. self._groupings[x] += self._groupings[y]
  130. del self._groupings[y]
  131. self.arr[self.arr == y] = x
  132. @property
  133. def type(self):
  134. """
  135. Returns a string representing the type
  136. """
  137. return 'nominal'
  138. class OrdinalColumn(Column):
  139. """
  140. A column containing integer values that have an order
  141. """
  142. def __init__(self, arr=None, metadata=None, missing_id='<missing>',
  143. groupings=None, substitute=True, weights=None, name=None):
  144. super(self.__class__, self).__init__(arr, metadata, missing_id=missing_id, weights=weights, name=name)
  145. self._nan = np.array([np.nan]).astype(int)[0]
  146. if substitute and metadata is None:
  147. self.arr, self.orig_type = self.substitute_values(self.arr)
  148. elif substitute and metadata and not np.issubdtype(self.arr.dtype, np.integer):
  149. # custom metadata has been passed in from external source, and must be converted to int
  150. self.arr = self.arr.astype(int)
  151. self.metadata = { int(k):v for k, v in metadata.items() }
  152. self.metadata[self._nan] = missing_id
  153. self._groupings = {}
  154. if groupings is None:
  155. for x in np.unique(self.arr):
  156. self._groupings[x] = [x, x + 1, False]
  157. else:
  158. for x in np.unique(self.arr):
  159. self._groupings[x] = list(groupings[x])
  160. self._possible_groups = None
  161. def substitute_values(self, vect):
  162. if not np.issubdtype(vect.dtype, np.integer):
  163. uniq = set(vect)
  164. uniq_floats = np.array(list(uniq), dtype=float)
  165. uniq_ints = uniq_floats.astype(int)
  166. nan = self._missing_id
  167. self.metadata = {
  168. new: nan if isnan(as_float) else old
  169. for old, as_float, new in zip(uniq, uniq_floats, uniq_ints)
  170. }
  171. self.arr = self.arr.astype(float)
  172. return self.arr.astype(int), self.arr.dtype.type
  173. def deep_copy(self):
  174. """
  175. Returns a deep copy.
  176. """
  177. return OrdinalColumn(self.arr, metadata=self.metadata, name=self.name,
  178. missing_id=self._missing_id, substitute=True,
  179. groupings=self._groupings, weights=self.weights)
  180. def __getitem__(self, key):
  181. new_weights = None if self.weights is None else self.weights[key]
  182. return OrdinalColumn(self.arr[key], metadata=self.metadata, name=self.name,
  183. missing_id=self._missing_id, substitute=True,
  184. groupings=self._groupings, weights=new_weights)
  185. def __setitem__(self, key, value):
  186. self.arr[key] = value
  187. return self
  188. def groups(self):
  189. vals = self._groupings.values()
  190. return [
  191. [x for x in range(minmax[0], minmax[1])] + ([self._nan] if minmax[2] else [])
  192. for minmax in vals
  193. ]
  194. def possible_groupings(self):
  195. if self._possible_groups is None:
  196. ranges = sorted(self._groupings.items())
  197. candidates = zip(ranges[0:], ranges[1:])
  198. self._possible_groups = [
  199. (k1, k2) for (k1, minmax1), (k2, minmax2) in candidates
  200. if minmax1[1] == minmax2[0]
  201. ]
  202. if self._nan in self.arr:
  203. self._possible_groups += [
  204. (key, self._nan) for key in self._groupings.keys() if key != self._nan
  205. ]
  206. return self._possible_groups.__iter__()
  207. def all_combinations(self):
  208. bell_set = self.bell_set(sorted(list(self._groupings.keys())), True)
  209. next(bell_set)
  210. return bell_set
  211. def group(self, x, y):
  212. self._possible_groups = None
  213. if y != self._nan:
  214. x = int(x)
  215. y = int(y)
  216. x_max = self._groupings[x][1]
  217. y_min = self._groupings[y][0]
  218. if y_min >= x_max:
  219. self._groupings[x][1] = self._groupings[y][1]
  220. else:
  221. self._groupings[x][0] = y_min
  222. self._groupings[x][2] = self._groupings[x][2] or self._groupings[y][2]
  223. else:
  224. self._groupings[x][2] = True
  225. del self._groupings[y]
  226. self.arr[self.arr == y] = x
  227. @property
  228. def type(self):
  229. """
  230. Returns a string representing the type
  231. """
  232. return 'ordinal'
  233. class ContinuousColumn(Column):
  234. """
  235. A column containing numerical values on a continuous scale
  236. """
  237. def __init__(self, arr=None, metadata=None, missing_id='<missing>',
  238. weights=None):
  239. if not np.issubdtype(arr.dtype, np.number):
  240. raise ValueError('Must only pass numerical values to create continuous column')
  241. super(self.__class__, self).__init__(np.nan_to_num(arr), metadata, missing_id=missing_id, weights=weights)
  242. def deep_copy(self):
  243. """
  244. Returns a deep copy.
  245. """
  246. return ContinuousColumn(self.arr, metadata=self.metadata, missing_id=self._missing_id, weights=self.weights)
  247. def __getitem__(self, key):
  248. new_weights = None if self.weights is None else self.weights[key]
  249. return ContinuousColumn(self.arr[key], metadata=self.metadata, weights=new_weights)
  250. def __setitem__(self, key, value):
  251. self.arr[key] = value
  252. return self
  253. @property
  254. def type(self):
  255. """
  256. Returns a string representing the type
  257. """
  258. return 'continuous'