# /pandas/core/groupby.py
# from https://github.com/ajcr/pandas
# Possible license(s): BSD-3-Clause, Apache-2.0

import types
from functools import wraps
import numpy as np
import datetime
import collections

from pandas.compat import (
    zip, builtins, range, long, lzip,
    OrderedDict, callable
)
from pandas import compat

from pandas.core.base import PandasObject
from pandas.core.categorical import Categorical
from pandas.core.frame import DataFrame
from pandas.core.generic import NDFrame
from pandas.core.index import Index, MultiIndex, _ensure_index, _union_indexes
from pandas.core.internals import BlockManager, make_block
from pandas.core.series import Series
from pandas.core.panel import Panel
from pandas.util.decorators import cache_readonly, Appender
import pandas.core.algorithms as algos
import pandas.core.common as com
from pandas.core.common import (_possibly_downcast_to_dtype, isnull,
                                notnull, _DATELIKE_DTYPES, is_numeric_dtype,
                                is_timedelta64_dtype, is_datetime64_dtype)

from pandas import _np_version_under1p7
import pandas.lib as lib
from pandas.lib import Timestamp
import pandas.tslib as tslib
import pandas.algos as _algos
import pandas.hashtable as _hash

_agg_doc = """Aggregate using input function or dict of {column -> function}

Parameters
----------
arg : function or dict
    Function to use for aggregating groups. If a function, must either
    work when passed a DataFrame or when passed to DataFrame.apply. If
    passed a dict, the keys must be DataFrame column names.

Notes
-----
Numpy functions mean/median/prod/sum/std/var are special cased so the
default behavior is applying the function along axis=0
(e.g., np.mean(arr_2d, axis=0)) as opposed to
mimicking the default Numpy behavior (e.g., np.mean(arr_2d)).

Returns
-------
aggregated : DataFrame
"""

# special case to prevent duplicate plots when catching exceptions when
# forwarding methods from NDFrames
_plotting_methods = frozenset(['plot', 'boxplot', 'hist'])

_common_apply_whitelist = frozenset([
    'last', 'first',
    'head', 'tail', 'median',
    'mean', 'sum', 'min', 'max',
    'cumsum', 'cumprod', 'cummin', 'cummax', 'cumcount',
    'resample',
    'describe',
    'rank', 'quantile', 'count',
    'fillna',
    'mad',
    'any', 'all',
    'irow', 'take',
    'idxmax', 'idxmin',
    'shift', 'tshift',
    'ffill', 'bfill',
    'pct_change', 'skew',
    'corr', 'cov', 'diff',
]) | _plotting_methods

_series_apply_whitelist = \
    (_common_apply_whitelist - set(['boxplot'])) | \
    frozenset(['dtype', 'value_counts', 'unique', 'nunique',
               'nlargest', 'nsmallest'])

_dataframe_apply_whitelist = \
    _common_apply_whitelist | frozenset(['dtypes', 'corrwith'])


class GroupByError(Exception):
    pass


class DataError(GroupByError):
    pass


class SpecificationError(GroupByError):
    pass


def _groupby_function(name, alias, npfunc, numeric_only=True,
                      _convert=False):
    def f(self):
        self._set_selection_from_grouper()
        try:
            return self._cython_agg_general(alias, numeric_only=numeric_only)
        except AssertionError as e:
            raise SpecificationError(str(e))
        except Exception:
            result = self.aggregate(lambda x: npfunc(x, axis=self.axis))
            if _convert:
                result = result.convert_objects()
            return result

    f.__doc__ = "Compute %s of group values" % name
    f.__name__ = name

    return f
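
# A minimal sketch of how this factory is used (see the GroupBy class body
# below, where e.g. ``sum = _groupby_function('sum', 'add', np.sum)``): the
# generated method first tries the Cython kernel named by ``alias`` and only
# falls back to a Python-level aggregate with the numpy function on failure.
# Illustrative, assuming a DataFrame df with columns 'A' and 'B':
#
#   >>> df = DataFrame({'A': [1, 1, 2], 'B': [1., 2., 3.]})
#   >>> df.groupby('A').sum()   # hits the Cython 'add' kernel
#   # -> B is 3.0 for group 1 and 3.0 for group 2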


def _first_compat(x, axis=0):
    def _first(x):
        x = np.asarray(x)
        x = x[notnull(x)]
        if len(x) == 0:
            return np.nan
        return x[0]

    if isinstance(x, DataFrame):
        return x.apply(_first, axis=axis)
    else:
        return _first(x)


def _last_compat(x, axis=0):
    def _last(x):
        x = np.asarray(x)
        x = x[notnull(x)]
        if len(x) == 0:
            return np.nan
        return x[-1]

    if isinstance(x, DataFrame):
        return x.apply(_last, axis=axis)
    else:
        return _last(x)


def _count_compat(x, axis=0):
    return x.size


class Grouper(object):
    """
    A Grouper allows the user to specify a groupby instruction for a target
    object.

    This specification will select a column via the key parameter, or if the
    level and/or axis parameters are given, a level of the index of the
    target object.

    These are local specifications and will override 'global' settings,
    that is the axis and level parameters which are passed to the groupby
    itself.

    Parameters
    ----------
    key : string, defaults to None
        groupby key, which selects the grouping column of the target
    level : name/number, defaults to None
        the level for the target index
    freq : string / frequency object, defaults to None
        This will groupby the specified frequency if the target selection
        (via key or level) is a datetime-like object
    axis : number/name of the axis, defaults to None
    sort : boolean, default to False
        whether to sort the resulting labels

    additional kwargs to control time-like groupers (when freq is passed)

    closed : closed end of interval; left or right
    label : interval boundary to use for labeling; left or right
    convention : {'start', 'end', 'e', 's'}
        If grouper is PeriodIndex

    Returns
    -------
    A specification for a groupby instruction

    Examples
    --------
    >>> df.groupby(Grouper(key='A'))  # syntactic sugar for df.groupby('A')
    >>> df.groupby(Grouper(key='date', freq='60s'))  # specify a resample
    ...                                              # on the column 'date'
    >>> df.groupby(Grouper(level='date', freq='60s', axis=1))
    ...     # specify a resample on the level 'date' on the columns axis
    ...     # with a frequency of 60s
    """

    def __new__(cls, *args, **kwargs):
        if kwargs.get('freq') is not None:
            from pandas.tseries.resample import TimeGrouper
            cls = TimeGrouper

        return super(Grouper, cls).__new__(cls)
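
    # note the dispatch above: passing a ``freq`` keyword turns the Grouper
    # into a TimeGrouper before __init__ runs. Illustrative sketch:
    #
    #   >>> g = Grouper(key='date', freq='60s')
    #   >>> type(g).__name__
    #   'TimeGrouper'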

    def __init__(self, key=None, level=None, freq=None, axis=None, sort=False):
        self.key = key
        self.level = level
        self.freq = freq
        self.axis = axis
        self.sort = sort

        self.grouper = None
        self.obj = None
        self.indexer = None
        self.binner = None

    @property
    def ax(self):
        return self.grouper

    def _get_grouper(self, obj):
        """
        Parameters
        ----------
        obj : the subject object

        Returns
        -------
        a tuple of binner, grouper, obj (possibly sorted)
        """
        self._set_grouper(obj)
        return self.binner, self.grouper, self.obj

    def _set_grouper(self, obj, sort=False):
        """
        given an object and the specifications, setup the internal grouper
        for this particular specification

        Parameters
        ----------
        obj : the subject object
        """
        if self.key is not None and self.level is not None:
            raise ValueError("The Grouper cannot specify both a key and a level!")

        # the key must be a valid info item
        if self.key is not None:
            key = self.key
            if key not in obj._info_axis:
                raise KeyError("The grouper name {0} is not found".format(key))
            ax = Index(obj[key], name=key)

        else:
            ax = obj._get_axis(self.axis)
            if self.level is not None:
                level = self.level

                # if a level is given it must be a mi level or
                # equivalent to the axis name
                if isinstance(ax, MultiIndex):
                    if isinstance(level, compat.string_types):
                        if obj.index.name != level:
                            raise ValueError('level name %s is not the name of the '
                                             'index' % level)
                    elif level > 0:
                        raise ValueError('level > 0 only valid with MultiIndex')

                    ax = Index(ax.get_level_values(level), name=level)

                else:
                    if not (level == 0 or level == ax.name):
                        raise ValueError("The grouper level {0} is not valid".format(level))

        # possibly sort
        if (self.sort or sort) and not ax.is_monotonic:
            indexer = self.indexer = ax.argsort(kind='quicksort')
            ax = ax.take(indexer)
            obj = obj.take(indexer, axis=self.axis, convert=False, is_copy=False)

        self.obj = obj
        self.grouper = ax
        return self.grouper

    def _get_binner_for_grouping(self, obj):
        raise NotImplementedError

    @property
    def groups(self):
        return self.grouper.groups


class GroupBy(PandasObject):
    """
    Class for grouping and aggregating relational data. See aggregate,
    transform, and apply functions on this object.

    It's easiest to use obj.groupby(...) to use GroupBy, but you can also do:

    ::

        grouped = groupby(obj, ...)

    Parameters
    ----------
    obj : pandas object
    axis : int, default 0
    level : int, default None
        Level of MultiIndex
    groupings : list of Grouping objects
        Most users should ignore this
    exclusions : array-like, optional
        List of columns to exclude
    name : string
        Most users should ignore this

    Notes
    -----
    After grouping, see aggregate, apply, and transform functions. Here are
    some other brief notes about usage. When grouping by multiple groups, the
    result index will be a MultiIndex (hierarchical) by default.

    Iteration produces (key, group) tuples, i.e. chunking the data by group.
    So you can write code like:

    ::

        grouped = obj.groupby(keys, axis=axis)
        for key, group in grouped:
            # do something with the data

    Function calls on GroupBy, if not specially implemented, "dispatch" to the
    grouped data. So if you group a DataFrame and wish to invoke the std()
    method on each group, you can simply do:

    ::

        df.groupby(mapper).std()

    rather than

    ::

        df.groupby(mapper).aggregate(np.std)

    You can pass arguments to these "wrapped" functions, too.

    See the online documentation for full exposition on these topics and much
    more

    Returns
    -------
    **Attributes**
    groups : dict
        {group name -> group labels}
    len(grouped) : int
        Number of groups
    """
    _apply_whitelist = _common_apply_whitelist
    _internal_names = ['_cache']
    _internal_names_set = set(_internal_names)
    _group_selection = None

    def __init__(self, obj, keys=None, axis=0, level=None,
                 grouper=None, exclusions=None, selection=None, as_index=True,
                 sort=True, group_keys=True, squeeze=False):
        self._selection = selection

        if isinstance(obj, NDFrame):
            obj._consolidate_inplace()

        self.level = level

        if not as_index:
            if not isinstance(obj, DataFrame):
                raise TypeError('as_index=False only valid with DataFrame')
            if axis != 0:
                raise ValueError('as_index=False only valid for axis=0')

        self.as_index = as_index
        self.keys = keys
        self.sort = sort
        self.group_keys = group_keys
        self.squeeze = squeeze

        if grouper is None:
            grouper, exclusions, obj = _get_grouper(obj, keys, axis=axis,
                                                    level=level, sort=sort)

        self.obj = obj
        self.axis = obj._get_axis_number(axis)
        self.grouper = grouper
        self.exclusions = set(exclusions) if exclusions else set()

    def __len__(self):
        return len(self.indices)

    def __unicode__(self):
        # TODO: Better unicode/repr for GroupBy object
        return object.__repr__(self)

    @property
    def groups(self):
        """ dict {group name -> group labels} """
        return self.grouper.groups

    @property
    def ngroups(self):
        return self.grouper.ngroups

    @property
    def indices(self):
        """ dict {group name -> group indices} """
        return self.grouper.indices

    def _get_index(self, name):
        """ safe get index, translate keys for datelike to underlying repr """

        def convert(key, s):
            # possibly convert to the actual key types
            # in the indices, could be a Timestamp or a np.datetime64
            if isinstance(s, (Timestamp, datetime.datetime)):
                return Timestamp(key)
            elif isinstance(s, np.datetime64):
                return Timestamp(key).asm8
            return key

        sample = next(iter(self.indices))
        if isinstance(sample, tuple):
            if not isinstance(name, tuple):
                raise ValueError("must supply a tuple to get_group with multiple grouping keys")
            if not len(name) == len(sample):
                raise ValueError("must supply a same-length tuple to get_group with multiple grouping keys")

            name = tuple([convert(n, k) for n, k in zip(name, sample)])

        else:
            name = convert(name, sample)

        return self.indices[name]

    @property
    def name(self):
        if self._selection is None:
            return None  # 'result'
        else:
            return self._selection

    @property
    def _selection_list(self):
        if not isinstance(self._selection, (list, tuple, Series, np.ndarray)):
            return [self._selection]
        return self._selection

    @cache_readonly
    def _selected_obj(self):
        if self._selection is None or isinstance(self.obj, Series):
            if self._group_selection is not None:
                return self.obj[self._group_selection]
            return self.obj
        else:
            return self.obj[self._selection]

    def _set_selection_from_grouper(self):
        """ we may need to create a selection if we have non-level groupers """
        grp = self.grouper
        if self.as_index and getattr(grp, 'groupings', None) is not None:
            ax = self.obj._info_axis
            groupers = [g.name for g in grp.groupings
                        if g.level is None and g.name is not None
                        and g.name in ax]
            if len(groupers):
                self._group_selection = (ax - Index(groupers)).tolist()

    def _local_dir(self):
        return sorted(set(self.obj._local_dir() + list(self._apply_whitelist)))

    def __getattr__(self, attr):
        if attr in self._internal_names_set:
            return object.__getattribute__(self, attr)
        if attr in self.obj:
            return self[attr]

        if hasattr(self.obj, attr):
            return self._make_wrapper(attr)

        raise AttributeError("%r object has no attribute %r" %
                             (type(self).__name__, attr))

    def __getitem__(self, key):
        raise NotImplementedError('Not implemented: %s' % key)

    def _make_wrapper(self, name):
        if name not in self._apply_whitelist:
            is_callable = callable(getattr(self._selected_obj, name, None))
            kind = ' callable ' if is_callable else ' '
            msg = ("Cannot access{0}attribute {1!r} of {2!r} objects, try "
                   "using the 'apply' method".format(kind, name,
                                                     type(self).__name__))
            raise AttributeError(msg)

        # need to setup the selection,
        # as keys are not passed directly but via the grouper
        self._set_selection_from_grouper()

        f = getattr(self._selected_obj, name)
        if not isinstance(f, types.MethodType):
            return self.apply(lambda self: getattr(self, name))

        f = getattr(type(self._selected_obj), name)

        def wrapper(*args, **kwargs):
            # a little trickery for aggregation functions that need an axis
            # argument
            kwargs_with_axis = kwargs.copy()
            if 'axis' not in kwargs_with_axis:
                kwargs_with_axis['axis'] = self.axis

            def curried_with_axis(x):
                return f(x, *args, **kwargs_with_axis)

            def curried(x):
                return f(x, *args, **kwargs)

            # preserve the name so we can detect it when calling plot methods,
            # to avoid duplicates
            curried.__name__ = curried_with_axis.__name__ = name

            # special case otherwise extra plots are created when catching the
            # exception below
            if name in _plotting_methods:
                return self.apply(curried)

            try:
                return self.apply(curried_with_axis)
            except Exception:
                try:
                    return self.apply(curried)
                except Exception:
                    # related to : GH3688
                    # try item-by-item
                    # this can be called recursively, so we need to raise
                    # ValueError if we don't have this method, to signal to
                    # aggregate to mark this column as an error
                    try:
                        return self._aggregate_item_by_item(name, *args, **kwargs)
                    except AttributeError:
                        raise ValueError

        return wrapper
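
    # _make_wrapper is what backs whitelisted "dispatch" calls such as
    # df.groupby('A').fillna(0): the underlying method is curried per group
    # (with the groupby axis injected when the signature accepts one) and
    # run through apply. A hedged sketch, assuming a column 'A' to group on:
    #
    #   >>> df.groupby('A').fillna(0)   # roughly equivalent to
    #   >>> df.groupby('A').apply(lambda g: g.fillna(0))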

    def get_group(self, name, obj=None):
        """
        Constructs NDFrame from group with provided name

        Parameters
        ----------
        name : object
            the name of the group to get as a DataFrame
        obj : NDFrame, default None
            the NDFrame to take the DataFrame out of.  If
            it is None, the object groupby was called on will
            be used

        Returns
        -------
        group : type of obj
        """
        if obj is None:
            obj = self._selected_obj

        inds = self._get_index(name)
        return obj.take(inds, axis=self.axis, convert=False)

    def __iter__(self):
        """
        Groupby iterator

        Returns
        -------
        Generator yielding sequence of (name, subsetted object)
        for each group
        """
        return self.grouper.get_iterator(self.obj, axis=self.axis)

    def apply(self, func, *args, **kwargs):
        """
        Apply function and combine results together in an intelligent way.
        The split-apply-combine combination rules attempt to be as common
        sense based as possible. For example:

        case 1:
            group DataFrame
            apply aggregation function (f(chunk) -> Series)
            yield DataFrame, with group axis having group labels

        case 2:
            group DataFrame
            apply transform function (f(chunk) -> DataFrame with same indexes)
            yield DataFrame with resulting chunks glued together

        case 3:
            group Series
            apply function with f(chunk) -> DataFrame
            yield DataFrame with result of chunks glued together

        Parameters
        ----------
        func : function

        Notes
        -----
        See online documentation for full exposition on how to use apply.

        In the current implementation apply calls func twice on the
        first group to decide whether it can take a fast or slow code
        path. This can lead to unexpected behavior if func has
        side-effects, as they will take effect twice for the first
        group.

        See also
        --------
        aggregate, transform

        Returns
        -------
        applied : type depending on grouped object and function
        """
        func = _intercept_function(func)

        @wraps(func)
        def f(g):
            return func(g, *args, **kwargs)

        return self._python_apply_general(f)

    def _python_apply_general(self, f):
        keys, values, mutated = self.grouper.apply(f, self._selected_obj,
                                                   self.axis)

        return self._wrap_applied_output(keys, values,
                                         not_indexed_same=mutated)

    def aggregate(self, func, *args, **kwargs):
        raise NotImplementedError

    @Appender(_agg_doc)
    def agg(self, func, *args, **kwargs):
        return self.aggregate(func, *args, **kwargs)

    def _iterate_slices(self):
        yield self.name, self._selected_obj

    def transform(self, func, *args, **kwargs):
        raise NotImplementedError

    def mean(self):
        """
        Compute mean of groups, excluding missing values

        For multiple groupings, the result index will be a MultiIndex
        """
        try:
            return self._cython_agg_general('mean')
        except GroupByError:
            raise
        except Exception:  # pragma: no cover
            self._set_selection_from_grouper()
            f = lambda x: x.mean(axis=self.axis)
            return self._python_agg_general(f)

    def median(self):
        """
        Compute median of groups, excluding missing values

        For multiple groupings, the result index will be a MultiIndex
        """
        try:
            return self._cython_agg_general('median')
        except GroupByError:
            raise
        except Exception:  # pragma: no cover
            self._set_selection_from_grouper()

            def f(x):
                if isinstance(x, np.ndarray):
                    x = Series(x)
                return x.median(axis=self.axis)
            return self._python_agg_general(f)

    def std(self, ddof=1):
        """
        Compute standard deviation of groups, excluding missing values

        For multiple groupings, the result index will be a MultiIndex
        """
        # TODO: implement at Cython level?
        return np.sqrt(self.var(ddof=ddof))

    def var(self, ddof=1):
        """
        Compute variance of groups, excluding missing values

        For multiple groupings, the result index will be a MultiIndex
        """
        if ddof == 1:
            return self._cython_agg_general('var')
        else:
            self._set_selection_from_grouper()
            f = lambda x: x.var(ddof=ddof)
            return self._python_agg_general(f)

    def sem(self, ddof=1):
        """
        Compute standard error of the mean of groups, excluding missing values

        For multiple groupings, the result index will be a MultiIndex
        """
        return self.std(ddof=ddof) / np.sqrt(self.count())

    def size(self):
        """
        Compute group sizes
        """
        return self.grouper.size()

    sum = _groupby_function('sum', 'add', np.sum)
    prod = _groupby_function('prod', 'prod', np.prod)
    min = _groupby_function('min', 'min', np.min, numeric_only=False)
    max = _groupby_function('max', 'max', np.max, numeric_only=False)
    first = _groupby_function('first', 'first', _first_compat,
                              numeric_only=False, _convert=True)
    last = _groupby_function('last', 'last', _last_compat, numeric_only=False,
                             _convert=True)
    _count = _groupby_function('_count', 'count', _count_compat,
                               numeric_only=False)

    def count(self, axis=0):
        return self._count().astype('int64')

    def ohlc(self):
        """
        Compute open, high, low and close values of a group, excluding
        missing values

        For multiple groupings, the result index will be a MultiIndex
        """
        return self._apply_to_column_groupbys(
            lambda x: x._cython_agg_general('ohlc'))

    def nth(self, n, dropna=None):
        """
        Take the nth row from each group.

        If dropna, will take the nth non-null row, dropna is either
        Truthy (if a Series) or 'all', 'any' (if a DataFrame); this is
        equivalent to calling dropna(how=dropna) before the groupby.

        Examples
        --------
        >>> df = DataFrame([[1, np.nan], [1, 4], [5, 6]], columns=['A', 'B'])
        >>> g = df.groupby('A')
        >>> g.nth(0)
           A   B
        0  1 NaN
        2  5   6
        >>> g.nth(1)
           A  B
        1  1  4
        >>> g.nth(-1)
           A  B
        1  1  4
        2  5  6
        >>> g.nth(0, dropna='any')
           B
        A
        1  4
        5  6
        >>> g.nth(1, dropna='any')  # NaNs denote group exhausted when using dropna
            B
        A
        1 NaN
        5 NaN
        """
        self._set_selection_from_grouper()
        if not dropna:  # good choice
            m = self.grouper._max_groupsize
            if n >= m or n < -m:
                return self._selected_obj.loc[[]]

            rng = np.zeros(m, dtype=bool)
            if n >= 0:
                rng[n] = True
                is_nth = self._cumcount_array(rng)
            else:
                rng[- n - 1] = True
                is_nth = self._cumcount_array(rng, ascending=False)

            result = self._selected_obj[is_nth]

            # the result index
            if self.as_index:
                ax = self.obj._info_axis
                names = self.grouper.names
                if all([n in ax for n in names]):
                    result.index = Index(self.obj[names][is_nth].values.ravel()).set_names(names)
                elif self._group_selection is not None:
                    result.index = self.obj._get_axis(self.axis)[is_nth]

                result = result.sort_index()

            return result

        if (isinstance(self._selected_obj, DataFrame)
           and dropna not in ['any', 'all']):
            # Note: when agg-ing picker doesn't raise this, just returns NaN
            raise ValueError("For a DataFrame groupby, dropna must be "
                             "either None, 'any' or 'all', "
                             "(was passed %s)." % (dropna))

        # old behaviour, but with all and any support for DataFrames.
        max_len = n if n >= 0 else - 1 - n

        def picker(x):
            x = x.dropna(how=dropna)  # Note: how is ignored if Series
            if len(x) <= max_len:
                return np.nan
            else:
                return x.iloc[n]

        return self.agg(picker)

    def cumcount(self, **kwargs):
        """
        Number each item in each group from 0 to the length of that group - 1.

        Essentially this is equivalent to

        >>> self.apply(lambda x: Series(np.arange(len(x)), x.index))

        Parameters
        ----------
        ascending : bool, default True
            If False, number in reverse, from length of group - 1 to 0.

        Example
        -------

        >>> df = pd.DataFrame([['a'], ['a'], ['a'], ['b'], ['b'], ['a']],
        ...                   columns=['A'])
        >>> df
           A
        0  a
        1  a
        2  a
        3  b
        4  b
        5  a
        >>> df.groupby('A').cumcount()
        0    0
        1    1
        2    2
        3    0
        4    1
        5    3
        dtype: int64
        >>> df.groupby('A').cumcount(ascending=False)
        0    3
        1    2
        2    1
        3    1
        4    0
        5    0
        dtype: int64
        """
        self._set_selection_from_grouper()
        ascending = kwargs.pop('ascending', True)

        index = self._selected_obj.index
        cumcounts = self._cumcount_array(ascending=ascending)
        return Series(cumcounts, index)

    def head(self, n=5):
        """
        Returns first n rows of each group.

        Essentially equivalent to ``.apply(lambda x: x.head(n))``,
        except ignores as_index flag.

        Example
        -------

        >>> df = DataFrame([[1, 2], [1, 4], [5, 6]],
                           columns=['A', 'B'])
        >>> df.groupby('A', as_index=False).head(1)
           A  B
        0  1  2
        2  5  6
        >>> df.groupby('A').head(1)
           A  B
        0  1  2
        2  5  6
        """
        obj = self._selected_obj
        in_head = self._cumcount_array() < n
        head = obj[in_head]
        return head

    def tail(self, n=5):
        """
        Returns last n rows of each group

        Essentially equivalent to ``.apply(lambda x: x.tail(n))``,
        except ignores as_index flag.

        Example
        -------

        >>> df = DataFrame([[1, 2], [1, 4], [5, 6]],
                           columns=['A', 'B'])
        >>> df.groupby('A', as_index=False).tail(1)
           A  B
        1  1  4
        2  5  6
        >>> df.groupby('A').head(1)
           A  B
        0  1  2
        2  5  6
        """
        obj = self._selected_obj
        rng = np.arange(0, -self.grouper._max_groupsize, -1, dtype='int64')
        in_tail = self._cumcount_array(rng, ascending=False) > -n
        tail = obj[in_tail]
        return tail

    def _cumcount_array(self, arr=None, **kwargs):
        """
        arr is where cumcount gets its values from
        """
        ascending = kwargs.pop('ascending', True)

        if arr is None:
            arr = np.arange(self.grouper._max_groupsize, dtype='int64')

        len_index = len(self._selected_obj.index)
        cumcounts = np.empty(len_index, dtype=arr.dtype)

        if ascending:
            for v in self.indices.values():
                cumcounts[v] = arr[:len(v)]
        else:
            for v in self.indices.values():
                cumcounts[v] = arr[len(v) - 1::-1]
        return cumcounts
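
    # Worked example of the mechanics above (hypothetical indices): with
    # group positions {0, 2, 3} and {1, 4} and the default
    # arr = [0, 1, 2, ...], ascending cumcounts come out as [0, 0, 1, 2, 1];
    # with ascending=False each group's slice is reversed instead, giving
    # [2, 1, 1, 0, 0].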

    def _index_with_as_index(self, b):
        """
        Take boolean mask of index to be returned from apply, if as_index=True
        """
        # TODO perf, it feels like this should already be somewhere...
        from itertools import chain
        original = self._selected_obj.index
        gp = self.grouper
        levels = chain((gp.levels[i][gp.labels[i][b]]
                        for i in range(len(gp.groupings))),
                       (original.get_level_values(i)[b]
                        for i in range(original.nlevels)))
        new = MultiIndex.from_arrays(list(levels))
        new.names = gp.names + original.names
        return new

    def _try_cast(self, result, obj):
        """
        try to cast the result to our obj original type,
        we may have roundtripped thru object in the mean-time
        """
        if obj.ndim > 1:
            dtype = obj.values.dtype
        else:
            dtype = obj.dtype

        if not np.isscalar(result):
            result = _possibly_downcast_to_dtype(result, dtype)

        return result

    def _cython_agg_general(self, how, numeric_only=True):
        output = {}
        for name, obj in self._iterate_slices():
            is_numeric = is_numeric_dtype(obj.dtype)
            if numeric_only and not is_numeric:
                continue

            try:
                result, names = self.grouper.aggregate(obj.values, how)
            except AssertionError as e:
                raise GroupByError(str(e))
            output[name] = self._try_cast(result, obj)

        if len(output) == 0:
            raise DataError('No numeric types to aggregate')

        return self._wrap_aggregated_output(output, names)

    def _python_agg_general(self, func, *args, **kwargs):
        func = _intercept_function(func)
        f = lambda x: func(x, *args, **kwargs)

        # iterate through "columns" ex exclusions to populate output dict
        output = {}
        for name, obj in self._iterate_slices():
            try:
                result, counts = self.grouper.agg_series(obj, f)
                output[name] = self._try_cast(result, obj)
            except TypeError:
                continue

        if len(output) == 0:
            return self._python_apply_general(f)

        if self.grouper._filter_empty_groups:

            mask = counts.ravel() > 0
            for name, result in compat.iteritems(output):

                # since we are masking, make sure that we have a float object
                values = result
                if is_numeric_dtype(values.dtype):
                    values = com.ensure_float(values)

                output[name] = self._try_cast(values[mask], result)

        return self._wrap_aggregated_output(output)

    def _wrap_applied_output(self, *args, **kwargs):
        raise NotImplementedError

    def _concat_objects(self, keys, values, not_indexed_same=False):
        from pandas.tools.merge import concat

        if not not_indexed_same:
            result = concat(values, axis=self.axis)
            ax = self._selected_obj._get_axis(self.axis)

            if isinstance(result, Series):
                result = result.reindex(ax)
            else:
                result = result.reindex_axis(ax, axis=self.axis)
        elif self.group_keys:
            if self.as_index:

                # possible MI return case
                group_keys = keys
                group_levels = self.grouper.levels
                group_names = self.grouper.names
                result = concat(values, axis=self.axis, keys=group_keys,
                                levels=group_levels, names=group_names)
            else:

                # GH5610, returns a MI, with the first level being a
                # range index
                keys = list(range(len(values)))
                result = concat(values, axis=self.axis, keys=keys)
        else:
            result = concat(values, axis=self.axis)

        return result

    def _apply_filter(self, indices, dropna):
        if len(indices) == 0:
            indices = []
        else:
            indices = np.sort(np.concatenate(indices))
        if dropna:
            filtered = self._selected_obj.take(indices)
        else:
            mask = np.empty(len(self._selected_obj.index), dtype=bool)
            mask.fill(False)
            mask[indices.astype(int)] = True
            # mask fails to broadcast when passed to where; broadcast manually.
            mask = np.tile(mask, list(self._selected_obj.shape[1:]) + [1]).T
            filtered = self._selected_obj.where(mask)  # Fill with NaNs.
        return filtered


@Appender(GroupBy.__doc__)
def groupby(obj, by, **kwds):
    if isinstance(obj, Series):
        klass = SeriesGroupBy
    elif isinstance(obj, DataFrame):
        klass = DataFrameGroupBy
    else:  # pragma: no cover
        raise TypeError('invalid type: %s' % type(obj))

    return klass(obj, by, **kwds)
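
# Module-level convenience: a sketch of the dispatch above, assuming a
# DataFrame df with a column 'A':
#
#   >>> from pandas.core.groupby import groupby
#   >>> grouped = groupby(df, 'A')      # same as df.groupby('A')
#   >>> type(grouped).__name__
#   'DataFrameGroupBy'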


def _get_axes(group):
    if isinstance(group, Series):
        return [group.index]
    else:
        return group.axes


def _is_indexed_like(obj, axes):
    if isinstance(obj, Series):
        if len(axes) > 1:
            return False
        return obj.index.equals(axes[0])
    elif isinstance(obj, DataFrame):
        return obj.index.equals(axes[0])

    return False


class BaseGrouper(object):
    """
    This is an internal Grouper class, which actually holds the generated
    groups
    """

    def __init__(self, axis, groupings, sort=True, group_keys=True):
        self.axis = axis
        self.groupings = groupings
        self.sort = sort
        self.group_keys = group_keys
        self.compressed = True

    @property
    def shape(self):
        return tuple(ping.ngroups for ping in self.groupings)

    def __iter__(self):
        return iter(self.indices)

    @property
    def nkeys(self):
        return len(self.groupings)

    def get_iterator(self, data, axis=0):
        """
        Groupby iterator

        Returns
        -------
        Generator yielding sequence of (name, subsetted object)
        for each group
        """
        splitter = self._get_splitter(data, axis=axis)
        keys = self._get_group_keys()
        for key, (i, group) in zip(keys, splitter):
            yield key, group

    def _get_splitter(self, data, axis=0):
        comp_ids, _, ngroups = self.group_info
        return get_splitter(data, comp_ids, ngroups, axis=axis)

    def _get_group_keys(self):
        if len(self.groupings) == 1:
            return self.levels[0]
        else:
            comp_ids, _, ngroups = self.group_info
            # provide "flattened" iterator for multi-group setting
            mapper = _KeyMapper(comp_ids, ngroups, self.labels, self.levels)
            return [mapper.get_key(i) for i in range(ngroups)]

    def apply(self, f, data, axis=0):
        mutated = False
        splitter = self._get_splitter(data, axis=axis)
        group_keys = self._get_group_keys()

        # oh boy
        if (f.__name__ not in _plotting_methods and
                hasattr(splitter, 'fast_apply') and axis == 0):
            try:
                values, mutated = splitter.fast_apply(f, group_keys)
                return group_keys, values, mutated
            except lib.InvalidApply:
                # we detect a mutation of some kind
                # so take slow path
                pass
            except Exception:
                # fall through to the pure-python path below; a genuine
                # error will surface again there
                pass

        result_values = []
        for key, (i, group) in zip(group_keys, splitter):
            object.__setattr__(group, 'name', key)

            # group might be modified
            group_axes = _get_axes(group)
            res = f(group)
            if not _is_indexed_like(res, group_axes):
                mutated = True
            result_values.append(res)

        return group_keys, result_values, mutated

    @cache_readonly
    def indices(self):
        """ dict {group name -> group indices} """
        if len(self.groupings) == 1:
            return self.groupings[0].indices
        else:
            label_list = [ping.labels for ping in self.groupings]
            keys = [ping.group_index for ping in self.groupings]
            return _get_indices_dict(label_list, keys)

    @property
    def labels(self):
        return [ping.labels for ping in self.groupings]

    @property
    def levels(self):
        return [ping.group_index for ping in self.groupings]

    @property
    def names(self):
        return [ping.name for ping in self.groupings]

    def size(self):
        """
        Compute group sizes
        """
        # TODO: better impl
        labels, _, ngroups = self.group_info
        bin_counts = algos.value_counts(labels, sort=False)
        bin_counts = bin_counts.reindex(np.arange(ngroups))
        bin_counts.index = self.result_index
        return bin_counts

    @cache_readonly
    def _max_groupsize(self):
        '''
        Compute size of largest group
        '''
        # For many items in each group this is much faster than
        # self.size().max(), in worst case marginally slower
        if self.indices:
            return max(len(v) for v in self.indices.values())
        else:
            return 0

    @cache_readonly
    def groups(self):
        """ dict {group name -> group labels} """
        if len(self.groupings) == 1:
            return self.groupings[0].groups
        else:
            to_groupby = lzip(*(ping.grouper for ping in self.groupings))
            to_groupby = Index(to_groupby)
            return self.axis.groupby(to_groupby.values)

    @cache_readonly
    def group_info(self):
        comp_ids, obs_group_ids = self._get_compressed_labels()

        ngroups = len(obs_group_ids)
        comp_ids = com._ensure_int64(comp_ids)
        return comp_ids, obs_group_ids, ngroups

    def _get_compressed_labels(self):
        all_labels = [ping.labels for ping in self.groupings]
        if self._overflow_possible:
            tups = lib.fast_zip(all_labels)
            labs, uniques = algos.factorize(tups)

            if self.sort:
                uniques, labs = _reorder_by_uniques(uniques, labs)

            return labs, uniques
        else:
            if len(all_labels) > 1:
                group_index = get_group_index(all_labels, self.shape)
                comp_ids, obs_group_ids = _compress_group_index(group_index)
            else:
                ping = self.groupings[0]
                comp_ids = ping.labels
                obs_group_ids = np.arange(len(ping.group_index))
                self.compressed = False
                self._filter_empty_groups = False

            return comp_ids, obs_group_ids
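
    # Worked sketch of the compressed-label scheme above, assuming two
    # groupings with labels [0, 1, 0] and [1, 0, 1] and shape (2, 2):
    # get_group_index combines them row-major into [0*2 + 1, 1*2 + 0,
    # 0*2 + 1] = [1, 2, 1], which _compress_group_index then factorizes
    # into comp_ids [0, 1, 0] over the observed group ids [1, 2].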

    @cache_readonly
    def _overflow_possible(self):
        return _int64_overflow_possible(self.shape)

    @cache_readonly
    def ngroups(self):
        return len(self.result_index)

    @cache_readonly
    def result_index(self):
        recons = self.get_group_levels()
        return MultiIndex.from_arrays(recons, names=self.names)

    def get_group_levels(self):
        obs_ids = self.group_info[1]

        if not self.compressed and len(self.groupings) == 1:
            return [self.groupings[0].group_index]

        if self._overflow_possible:
            recons_labels = [np.array(x) for x in zip(*obs_ids)]
        else:
            recons_labels = decons_group_index(obs_ids, self.shape)

        name_list = []
        for ping, labels in zip(self.groupings, recons_labels):
            labels = com._ensure_platform_int(labels)
            name_list.append(ping.group_index.take(labels))

        return name_list

    #------------------------------------------------------------
    # Aggregation functions

    _cython_functions = {
        'add': 'group_add',
        'prod': 'group_prod',
        'min': 'group_min',
        'max': 'group_max',
        'mean': 'group_mean',
        'median': {
            'name': 'group_median'
        },
        'var': 'group_var',
        'first': {
            'name': 'group_nth',
            'f': lambda func, a, b, c, d: func(a, b, c, d, 1)
        },
        'last': 'group_last',
        'count': 'group_count',
    }

    _cython_arity = {
        'ohlc': 4,  # OHLC
    }

    _name_functions = {}

    _filter_empty_groups = True

    def _get_aggregate_function(self, how, values):

        dtype_str = values.dtype.name

        def get_func(fname):
            # find the function for the given dtype, or fall back to the
            # object variant, or return a generic
            for dt in [dtype_str, 'object']:
                f = getattr(_algos, "%s_%s" % (fname, dt), None)
                if f is not None:
                    return f
            return getattr(_algos, fname, None)

        ftype = self._cython_functions[how]

        if isinstance(ftype, dict):
            func = afunc = get_func(ftype['name'])

            # a sub-function
            f = ftype.get('f')
            if f is not None:

                def wrapper(*args, **kwargs):
                    return f(afunc, *args, **kwargs)

                # need to curry our sub-function
                func = wrapper

        else:
            func = get_func(ftype)

        if func is None:
            raise NotImplementedError("function is not implemented for this "
                                      "dtype: [how->%s,dtype->%s]" %
                                      (how, dtype_str))
        return func, dtype_str

    def aggregate(self, values, how, axis=0):
        arity = self._cython_arity.get(how, 1)

        vdim = values.ndim
        swapped = False
        if vdim == 1:
            values = values[:, None]
            out_shape = (self.ngroups, arity)
        else:
            if axis > 0:
                swapped = True
                values = values.swapaxes(0, axis)
            if arity > 1:
                raise NotImplementedError
            out_shape = (self.ngroups,) + values.shape[1:]

        if is_numeric_dtype(values.dtype):
            values = com.ensure_float(values)
            is_numeric = True
            out_dtype = 'f%d' % values.dtype.itemsize
        else:
            is_numeric = issubclass(values.dtype.type, (np.datetime64,
                                                        np.timedelta64))
            if is_numeric:
                out_dtype = 'float64'
                values = values.view('int64')
            else:
                out_dtype = 'object'
                values = values.astype(object)

        # will be filled in Cython function
        result = np.empty(out_shape, dtype=out_dtype)
        result.fill(np.nan)
        counts = np.zeros(self.ngroups, dtype=np.int64)

        result = self._aggregate(result, counts, values, how, is_numeric)

        if self._filter_empty_groups:
            if result.ndim == 2:
                try:
                    result = lib.row_bool_subset(
                        result, (counts > 0).view(np.uint8))
                except ValueError:
                    result = lib.row_bool_subset_object(
                        result, (counts > 0).view(np.uint8))
            else:
                result = result[counts > 0]

        if vdim == 1 and arity == 1:
            result = result[:, 0]

        if how in self._name_functions:
            # TODO
            names = self._name_functions[how]()
        else:
            names = None

        if swapped:
            result = result.swapaxes(0, axis)

        return result, names

    def _aggregate(self, result, counts, values, how, is_numeric):
        agg_func, dtype = self._get_aggregate_function(how, values)

        comp_ids, _, ngroups = self.group_info
        if values.ndim > 3:
            # punting for now
            raise NotImplementedError
        elif values.ndim > 2:
            for i, chunk in enumerate(values.transpose(2, 0, 1)):
                chunk = chunk.squeeze()
                agg_func(result[:, :, i], counts, chunk, comp_ids)
        else:
            agg_func(result, counts, values, comp_ids)

        return result

    def agg_series(self, obj, func):
        try:
            return self._aggregate_series_fast(obj, func)
        except Exception:
            return self._aggregate_series_pure_python(obj, func)

    def _aggregate_series_fast(self, obj, func):
        func = _intercept_function(func)

        if obj.index._has_complex_internals:
            raise TypeError('Incompatible index for Cython grouper')

        group_index, _, ngroups = self.group_info

        # avoids object / Series creation overhead
        dummy = obj._get_values(slice(None, 0)).to_dense()
        indexer = _algos.groupsort_indexer(group_index, ngroups)[0]
        obj = obj.take(indexer, convert=False)
        group_index = com.take_nd(group_index, indexer, allow_fill=False)
        grouper = lib.SeriesGrouper(obj, func, group_index, ngroups,
                                    dummy)
        result, counts = grouper.get_result()
        return result, counts

    def _aggregate_series_pure_python(self, obj, func):

        group_index, _, ngroups = self.group_info

        counts = np.zeros(ngroups, dtype=int)
        result = None

        splitter = get_splitter(obj, group_index, ngroups, axis=self.axis)

        for label, group in splitter:
            res = func(group)
            if result is None:
                if (isinstance(res, (Series, np.ndarray)) or
                        isinstance(res, list)):
                    raise ValueError('Function does not reduce')
                result = np.empty(ngroups, dtype='O')

            counts[label] = group.shape[0]
            result[label] = res

        result = lib.maybe_convert_objects(result, try_float=0)
        return result, counts
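
    # Note the 'Function does not reduce' guard above: an aggregation fed
    # through agg_series must map each group to a scalar. Illustrative
    # sketch (hypothetical grouped Series):
    #
    #   >>> grouped.agg(lambda x: x.values)  # array per group -> raises,
    #   ...                                  # since the result is not a
    #   ...                                  # scalar reduction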


def generate_bins_generic(values, binner, closed):
    """
    Generate bin edge offsets and bin labels for one array using another array
    which has bin edge values. Both arrays must be sorted.

    Parameters
    ----------
    values : array of values
    binner : a comparable array of values representing bins into which to bin
        the first array. Note, 'values' end-points must fall within 'binner'
        end-points.
    closed : which end of bin is closed; left (default), right

    Returns
    -------
    bins : array of offsets (into 'values' argument) of bins.
        Zero and last edge are excluded in result, so for instance the first
        bin is values[0:bin[0]] and the last is values[bin[-1]:]
    """
    lenidx = len(values)
    lenbin = len(binner)

    if lenidx <= 0 or lenbin <= 0:
        raise ValueError("Invalid length for values or for binner")

    # check binner fits data
    if values[0] < binner[0]:
        raise ValueError("Values falls before first bin")

    if values[lenidx - 1] > binner[lenbin - 1]:
        raise ValueError("Values falls after last bin")

    bins = np.empty(lenbin - 1, dtype=np.int64)

    j = 0   # index into values
    bc = 0  # bin count

    # linear scan, presume nothing about values/binner except that it fits ok
    for i in range(0, lenbin - 1):
        r_bin = binner[i + 1]

        # count values in current bin, advance to next bin
        while j < lenidx and (values[j] < r_bin or
                              (closed == 'right' and values[j] == r_bin)):
            j += 1

        bins[bc] = j
        bc += 1

    return bins
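
# Worked example (values checked against the linear scan above):
#
#   >>> values = np.array([1, 2, 3, 6, 7])
#   >>> binner = np.array([0, 3, 8])
#   >>> generate_bins_generic(values, binner, 'left')
#   array([2, 5])
#
# i.e. values[0:2] fall in the first bin and values[2:] in the second;
# with closed='right' the 3 would instead land in the first bin, giving
# array([3, 5]).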


class BinGrouper(BaseGrouper):

    def __init__(self, bins, binlabels, filter_empty=False):
        self.bins = com._ensure_int64(bins)
        self.binlabels = _ensure_index(binlabels)
        self._filter_empty_groups = filter_empty

    @cache_readonly
    def groups(self):
        """ dict {group name -> group labels} """

        # this is mainly for compat
        # GH 3881
        result = {}
        for key, value in zip(self.binlabels, self.bins):
            if key is not tslib.NaT:
                result[key] = value
        return result

    @property
    def nkeys(self):
        return 1

    def get_iterator(self, data, axis=0):
        """
        Groupby iterator

        Returns
        -------
        Generator yielding sequence of (name, subsetted object)
        for each group
        """
        if isinstance(data, NDFrame):
            slicer = lambda start, edge: data._slice(slice(start, edge), axis=axis)
            length = len(data.axes[axis])
        else:
            slicer = lambda start, edge: data[slice(start, edge)]
            length = len(data)

        start = 0
        for edge, label in zip(self.bins, self.binlabels):
            if label is not tslib.NaT:
                yield label, slicer(start, edge)
            start = edge

        if start < length:
            yield self.binlabels[-1], slicer(start, None)

    def apply(self, f, data, axis=0):
        result_keys = []
        result_values = []
        mutated = False
        for key, group in self.get_iterator(data, axis=axis):
            object.__setattr__(group, 'name', key)

            # group might be modified
            group_axes = _get_axes(group)
            res = f(group)
            if not _is_indexed_like(res, group_axes):
                mutated = True

            result_keys.append(key)
            result_values.append(res)

        return result_keys, result_values, mutated

    @cache_readonly
    def indices(self):
        indices = collections.defaultdict(list)

        i = 0
        for label, bin in zip(self.binlabels, self.bins):
            if label is not tslib.NaT and i < bin:
                indices[label] = list(range(i, bin))
            i = bin
        return indices

    @cache_readonly
    def ngroups(self):
        return len(self.binlabels)

    @cache_readonly
    def result_index(self):
        mask = self.binlabels.asi8 == tslib.iNaT
        return self.binlabels[~mask]

    @property
    def levels(self):
        return [self.binlabels]

    @property
    def names(self):
        return [self.binlabels.name]

    #----------------------------------------------------------------------
    # cython aggregation

    _cython_functions = {
        'add': 'group_add_bin',
        'prod': 'group_prod_bin',
        'mean': 'group_mean_bin',
        'min': 'group_min_bin',
        'max': 'group_max_bin',
        'var': 'group_var_bin',
        'ohlc': 'group_ohlc',
        'first': {
            'name': 'group_nth_bin',
            'f': lambda func, a, b, c, d: func(a, b, c, d, 1)
        },
        'last': 'group_last_bin',
        'count': 'group_count_bin',
    }

    _name_functions = {
        'ohlc': lambda *args: ['open', 'high', 'low', 'close']
    }

    _filter_empty_groups = True

    def _aggregate(self, result, counts, values, how, is_numeric=True):
        agg_func, dtype = self._get_aggregate_function(how, values)

        if values.ndim > 3:
            # punting for now
            raise NotImplementedError
        elif values.ndim > 2:
            for i, chunk in enumerate(values.transpose(2, 0, 1)):
                agg_func(result[:, :, i], counts, chunk, self.bins)
        else:
            agg_func(result, counts, values, self.bins)

        return result

    def agg_series(self, obj, func):
        dummy = obj[:0]
        grouper = lib.SeriesBinGrouper(obj, func, self.bins, dummy)
        return grouper.get_result()


class Grouping(object):
    """
    Holds the grouping information for a single key

    Parameters
    ----------
    index : Index
    grouper :
    obj :
    name :
    level :

    Returns
    -------
    **Attributes**:
      * indices : dict of {group -> index_list}
      * labels : ndarray, group labels
      * ids : mapping of label -> group
      * counts : array of group counts
      * group_index : unique groups
      * groups : dict of {group -> label_list}
    """

    def __init__(self, index, grouper=None, obj=None, name=None, level=None,
                 sort=True):
        self.name = name
        self.level = level
        self.grouper = _convert_grouper(index, grouper)
        self.index = index
        self.sort = sort
        self.obj = obj

        # right place for this?
        if isinstance(grouper, (Series, Index)) and name is None:
            self.name = grouper.name

        if isinstance(grouper, MultiIndex):
            self.grouper = grouper.values

        # pre-computed
        self._was_factor = False
        self._should_compress = True
        # we have a single grouper which may be a myriad of things,
        # some of which are dependent on the passed-in level
        #
        if level is not None:
            if not isinstance(level, int):
                if level not in index.names:
                    raise AssertionError('Level %s not in index' % str(level))
                level = index.names.index(level)

            inds = index.labels[level]
            level_index = index.levels[level]

            if self.name is None:
                self.name = index.names[level]

            # XXX complete hack
            if grouper is not None:
                level_values = index.levels[level].take(inds)
                self.grouper = level_values.map(self.grouper)
            else:
                self._was_factor = True

                # all levels may not be observed
                labels, uniques = algos.factorize(inds, sort=True)

                if len(uniques) > 0 and uniques[0] == -1:
                    # handle NAs
                    mask = inds != -1
                    ok_labels, uniques = algos.factorize(inds[mask], sort=True)

                    labels = np.empty(len(inds), dtype=inds.dtype)
                    labels[mask] = ok_labels
                    labels[~mask] = -1

                if len(uniques) < len(level_index):
                    level_index = level_index.take(uniques)

                self._labels = labels
                self._group_index = level_index
                self.grouper = level_index.take(labels)
        else:
            if isinstance(self.grouper, (list, tuple)):
                self.grouper = com._asarray_tuplesafe(self.grouper)

            # a passed Categorical
            elif isinstance(self.grouper, Categorical):
                factor = self.grouper
                self._was_factor = True

                # Is there any way to avoid this?
                self.grouper = np.asarray(factor)

                self._labels = factor.labels
                self._group_index = factor.levels
                if self.name is None:
                    self.name = factor.name

            # a passed Grouper like
            elif isinstance(self.grouper, Grouper):

                # get the new grouper
                grouper = self.grouper._get_binner_for_grouping(self.obj)
                self.obj = self.grouper.obj
                self.grouper = grouper
                if self.name is None:
                    self.name = grouper.name

            # no level passed
            if not isinstance(self.grouper, (Series, np.ndarray)):
                self.grouper = self.index.map(self.grouper)
                if not (hasattr(self.grouper, "__len__") and
                        len(self.grouper) == len(self.index)):
                    errmsg = ('Grouper result violates len(labels) == '
                              'len(data)\nresult: %s' %
                              com.pprint_thing(self.grouper))
                    self.grouper = None  # Try for sanity
                    raise AssertionError(errmsg)

        # if we have a date/time-like grouper, make sure that we have
        # Timestamps like
        if getattr(self.grouper, 'dtype', None) is not None:
            if is_datetime64_dtype(self.grouper):
                from pandas import to_datetime
                self.grouper = to_datetime(self.grouper)
            elif is_timedelta64_dtype(self.grouper):
                from pandas import to_timedelta
                self.grouper = to_timedelta(self.grouper)

    def __repr__(self):
        return 'Grouping(%s)' % self.name

    def __iter__(self):
        return iter(self.indices)

    _labels = None
    _group_index = None

    @property
    def ngroups(self):
        return len(self.group_index)

    @cache_readonly
    def indices(self):
        return _groupby_indices(self.grouper)

    @property
    def labels(self):
        if self._labels is None:
            self._make_labels()
        return self._labels

    @property
    def group_index(self):
        if self._group_index is None:
            self._make_labels()
        return self._group_index

    def _make_labels(self):
        if self._was_factor:  # pragma: no cover
            raise Exception('Should not call this method when grouping by level')
        else:
            labels, uniques = algos.factorize(self.grouper, sort=self.sort)
            uniques = Index(uniques, name=self.name)
            self._labels = labels
            self._group_index = uniques

    _groups = None

    @property
    def groups(self):
        if self._groups is None:
            self._groups = self.index.groupby(self.grouper)
        return self._groups


def _get_grouper(obj, key=None, axis=0, level=None, sort=True):
    """
    create and return a BaseGrouper, which is an internal
    mapping of how to create the grouper indexers.
    This may be composed of multiple Grouping objects, indicating
    multiple groupers

    Groupers are ultimately index mappings. They can originate as:
    index mappings, keys to columns, functions, or Groupers

    Groupers enable local references to axis, level, and sort, while
    the passed in axis, level, and sort are 'global'.

    This routine tries to figure out what the passed in references
    are and then creates a Grouping for each one, combined into
    a BaseGrouper.
    """
    group_axis = obj._get_axis(axis)

    # validate that the passed level is compatible with the passed
    # axis of the object
    if level is not None:
        if not isinstance(group_axis, MultiIndex):
            if isinstance(level, compat.string_types):
                if obj.index.name != level:
                    raise ValueError('level name %s is not the name of the '
                                     'index' % level)
            elif level > 0:
                raise ValueError('level > 0 only valid with MultiIndex')

            level = None
            key = group_axis

    # a passed in Grouper, directly convert
    if isinstance(key, Grouper):
        binner, grouper, obj = key._get_grouper(obj)
        if key.key is None:
            return grouper, [], obj
        else:
            return grouper, set([key.key]), obj

    # already have a BaseGrouper, just return it
    elif isinstance(key, BaseGrouper):
        return key, [], obj

    if not isinstance(key, (tuple, list)):
        keys = [key]
    else:
        keys = key

    # what are we after, exactly?
    match_axis_length = len(keys) == len(group_axis)
    any_callable = any(callable(g) or isinstance(g, dict) for g in keys)
    any_arraylike = any(isinstance(g, (list, tuple, Series, np.ndarray))
                        for g in keys)

    try:
        if isinstance(obj, DataFrame):
            all_in_columns = all(g in obj.columns for g in keys)
        else:
            all_in_columns = False
    except Exception:
        all_in_columns = False

    if (not any_callable and not all_in_columns
            and not any_arraylike and match_axis_length
            and level is None):
        keys = [com._asarray_tuplesafe(keys)]
  1556. if isinstance(level, (tuple, list)):
  1557. if key is None:
  1558. keys = [None] * len(level)
  1559. levels = level
  1560. else:
  1561. levels = [level] * len(keys)
  1562. groupings = []
  1563. exclusions = []
  1564. for i, (gpr, level) in enumerate(zip(keys, levels)):
  1565. name = None
  1566. try:
  1567. obj._data.items.get_loc(gpr)
  1568. in_axis = True
  1569. except Exception:
  1570. in_axis = False
  1571. if _is_label_like(gpr) or in_axis:
  1572. exclusions.append(gpr)
  1573. name = gpr
  1574. gpr = obj[gpr]
  1575. if isinstance(gpr, Categorical) and len(gpr) != len(obj):
  1576. errmsg = "Categorical grouper must have len(grouper) == len(data)"
  1577. raise AssertionError(errmsg)
  1578. ping = Grouping(group_axis, gpr, obj=obj, name=name, level=level, sort=sort)
  1579. groupings.append(ping)
  1580. if len(groupings) == 0:
  1581. raise ValueError('No group keys passed!')
  1582. # create the internals grouper
  1583. grouper = BaseGrouper(group_axis, groupings, sort=sort)
  1584. return grouper, exclusions, obj
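# An illustrative sketch of how _get_grouper resolves keys; the frame below
# is hypothetical and results also depend on the surrounding classes.
#
# >>> df = DataFrame({'A': ['x', 'x', 'y'], 'B': [1, 2, 3]})
# >>> grouper, exclusions, obj = _get_grouper(df, key='A')
#
# 'A' is a column label, so it is looked up on the axis, recorded in
# `exclusions`, and wrapped in a Grouping; an equal-length array-like key
# (e.g. np.array([0, 0, 1])) would instead be used directly as labels.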
  1585. def _is_label_like(val):
  1586. return isinstance(val, compat.string_types) or np.isscalar(val)
  1587. def _convert_grouper(axis, grouper):
  1588. if isinstance(grouper, dict):
  1589. return grouper.get
  1590. elif isinstance(grouper, Series):
  1591. if grouper.index.equals(axis):
  1592. return grouper.values
  1593. else:
  1594. return grouper.reindex(axis).values
  1595. elif isinstance(grouper, (list, Series, np.ndarray)):
  1596. if len(grouper) != len(axis):
  1597. raise AssertionError('Grouper and axis must be same length')
  1598. return grouper
  1599. else:
  1600. return grouper
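# Example of the dict branch above: a dict grouper is reduced to its .get
# method, so axis labels missing from the mapping group as None.
#
# >>> mapping = {'a': 'g1', 'b': 'g2'}
# >>> f = _convert_grouper(None, mapping)
# >>> f('a'), f('zzz')
# ('g1', None)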
  1601. class SeriesGroupBy(GroupBy):
  1602. _apply_whitelist = _series_apply_whitelist
  1603. def aggregate(self, func_or_funcs, *args, **kwargs):
  1604. """
1605. Apply an aggregation function or functions to groups, most likely
1606. yielding a Series, but in some cases a DataFrame, depending on the
1607. output of the aggregation function
  1608. Parameters
  1609. ----------
  1610. func_or_funcs : function or list / dict of functions
  1611. List/dict of functions will produce DataFrame with column names
  1612. determined by the function names themselves (list) or the keys in
  1613. the dict
  1614. Notes
  1615. -----
  1616. agg is an alias for aggregate. Use it.
  1617. Examples
  1618. --------
  1619. >>> series
  1620. bar 1.0
  1621. baz 2.0
  1622. qot 3.0
  1623. qux 4.0
  1624. >>> mapper = lambda x: x[0] # first letter
  1625. >>> grouped = series.groupby(mapper)
  1626. >>> grouped.aggregate(np.sum)
  1627. b 3.0
  1628. q 7.0
  1629. >>> grouped.aggregate([np.sum, np.mean, np.std])
  1630. mean std sum
  1631. b 1.5 0.5 3
  1632. q 3.5 0.5 7
  1633. >>> grouped.agg({'result' : lambda x: x.mean() / x.std(),
  1634. ... 'total' : np.sum})
  1635. result total
  1636. b 2.121 3
  1637. q 4.95 7
  1638. See also
  1639. --------
  1640. apply, transform
  1641. Returns
  1642. -------
  1643. Series or DataFrame
  1644. """
  1645. if isinstance(func_or_funcs, compat.string_types):
  1646. return getattr(self, func_or_funcs)(*args, **kwargs)
  1647. if hasattr(func_or_funcs, '__iter__'):
  1648. ret = self._aggregate_multiple_funcs(func_or_funcs)
  1649. else:
  1650. cyfunc = _intercept_cython(func_or_funcs)
  1651. if cyfunc and not args and not kwargs:
  1652. return getattr(self, cyfunc)()
  1653. if self.grouper.nkeys > 1:
  1654. return self._python_agg_general(func_or_funcs, *args, **kwargs)
  1655. try:
  1656. return self._python_agg_general(func_or_funcs, *args, **kwargs)
  1657. except Exception:
  1658. result = self._aggregate_named(func_or_funcs, *args, **kwargs)
  1659. index = Index(sorted(result), name=self.grouper.names[0])
  1660. ret = Series(result, index=index)
  1661. if not self.as_index: # pragma: no cover
1662. print('Warning, ignoring as_index=False')
  1663. return ret
  1664. def _aggregate_multiple_funcs(self, arg):
  1665. if isinstance(arg, dict):
  1666. columns = list(arg.keys())
  1667. arg = list(arg.items())
  1668. elif any(isinstance(x, (tuple, list)) for x in arg):
  1669. arg = [(x, x) if not isinstance(x, (tuple, list)) else x
  1670. for x in arg]
  1671. # indicated column order
  1672. columns = lzip(*arg)[0]
  1673. else:
  1674. # list of functions / function names
  1675. columns = []
  1676. for f in arg:
  1677. if isinstance(f, compat.string_types):
  1678. columns.append(f)
  1679. else:
  1680. columns.append(f.__name__)
  1681. arg = lzip(columns, arg)
  1682. results = {}
  1683. for name, func in arg:
  1684. if name in results:
  1685. raise SpecificationError('Function names must be unique, '
  1686. 'found multiple named %s' % name)
  1687. results[name] = self.aggregate(func)
  1688. return DataFrame(results, columns=columns)
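# Naming rules above, sketched with a hypothetical Series `s`: callables are
# keyed by __name__, so repeated names (e.g. two lambdas) are rejected.
#
# >>> s.groupby(level=0).agg([np.sum, np.mean])   # columns 'sum', 'mean'
# >>> s.groupby(level=0).agg([lambda x: x.max(), lambda x: x.min()])
# SpecificationError: Function names must be unique, found multiple named <lambda>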
  1689. def _wrap_aggregated_output(self, output, names=None):
  1690. # sort of a kludge
  1691. output = output[self.name]
  1692. index = self.grouper.result_index
  1693. if names is not None:
  1694. return DataFrame(output, index=index, columns=names)
  1695. else:
  1696. return Series(output, index=index, name=self.name)
  1697. def _wrap_applied_output(self, keys, values, not_indexed_same=False):
  1698. if len(keys) == 0:
  1699. # GH #6265
  1700. return Series([], name=self.name)
  1701. def _get_index():
  1702. if self.grouper.nkeys > 1:
  1703. index = MultiIndex.from_tuples(keys, names=self.grouper.names)
  1704. else:
  1705. index = Index(keys, name=self.grouper.names[0])
  1706. return index
  1707. if isinstance(values[0], dict):
  1708. # GH #823
  1709. index = _get_index()
  1710. return DataFrame(values, index=index).stack()
  1711. if isinstance(values[0], (Series, dict)):
  1712. return self._concat_objects(keys, values,
  1713. not_indexed_same=not_indexed_same)
  1714. elif isinstance(values[0], DataFrame):
  1715. # possible that Series -> DataFrame by applied function
  1716. return self._concat_objects(keys, values,
  1717. not_indexed_same=not_indexed_same)
  1718. else:
  1719. # GH #6265
  1720. return Series(values, index=_get_index(), name=self.name)
  1721. def _aggregate_named(self, func, *args, **kwargs):
  1722. result = {}
  1723. for name, group in self:
  1724. group.name = name
  1725. output = func(group, *args, **kwargs)
  1726. if isinstance(output, (Series, np.ndarray)):
  1727. raise Exception('Must produce aggregated value')
  1728. result[name] = self._try_cast(output, group)
  1729. return result
  1730. def transform(self, func, *args, **kwargs):
  1731. """
  1732. Call function producing a like-indexed Series on each group and return
  1733. a Series with the transformed values
  1734. Parameters
  1735. ----------
  1736. func : function
  1737. To apply to each group. Should return a Series with the same index
  1738. Examples
  1739. --------
  1740. >>> grouped.transform(lambda x: (x - x.mean()) / x.std())
  1741. Returns
  1742. -------
  1743. transformed : Series
  1744. """
  1745. dtype = self._selected_obj.dtype
  1746. if isinstance(func, compat.string_types):
  1747. wrapper = lambda x: getattr(x, func)(*args, **kwargs)
  1748. else:
  1749. wrapper = lambda x: func(x, *args, **kwargs)
  1750. result = self._selected_obj.values.copy()
  1751. for i, (name, group) in enumerate(self):
  1752. object.__setattr__(group, 'name', name)
  1753. res = wrapper(group)
  1754. if hasattr(res, 'values'):
  1755. res = res.values
  1756. # may need to astype
  1757. try:
  1758. common_type = np.common_type(np.array(res), result)
  1759. if common_type != result.dtype:
  1760. result = result.astype(common_type)
1761. except Exception:
  1762. pass
  1763. indexer = self._get_index(name)
  1764. result[indexer] = res
  1765. result = _possibly_downcast_to_dtype(result, dtype)
  1766. return self._selected_obj.__class__(result,
  1767. index=self._selected_obj.index,
  1768. name=self._selected_obj.name)
  1769. def filter(self, func, dropna=True, *args, **kwargs):
  1770. """
  1771. Return a copy of a Series excluding elements from groups that
  1772. do not satisfy the boolean criterion specified by func.
  1773. Parameters
  1774. ----------
  1775. func : function
  1776. To apply to each group. Should return True or False.
1777. dropna : Drop groups that do not pass the filter; True by default.
1778. If False, groups that evaluate False are filled with NaNs.
1779. Examples
1780. --------
  1781. >>> grouped.filter(lambda x: x.mean() > 0)
  1782. Returns
  1783. -------
  1784. filtered : Series
  1785. """
  1786. if isinstance(func, compat.string_types):
  1787. wrapper = lambda x: getattr(x, func)(*args, **kwargs)
  1788. else:
  1789. wrapper = lambda x: func(x, *args, **kwargs)
  1790. # Interpret np.nan as False.
  1791. def true_and_notnull(x, *args, **kwargs):
  1792. b = wrapper(x, *args, **kwargs)
  1793. return b and notnull(b)
  1794. try:
  1795. indices = [self._get_index(name) if true_and_notnull(group) else []
  1796. for name, group in self]
  1797. except ValueError:
  1798. raise TypeError("the filter must return a boolean result")
  1799. except TypeError:
  1800. raise TypeError("the filter must return a boolean result")
  1801. filtered = self._apply_filter(indices, dropna)
  1802. return filtered
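# Sketch of the NaN handling in true_and_notnull above: a criterion that
# evaluates to NaN counts as False, so the group is dropped.
#
# >>> s = Series([1.0, 2.0, np.nan])
# >>> s.groupby([0, 0, 1]).filter(lambda x: x.mean() > 1)
# expected: the all-NaN group has mean NaN -> dropped; group 0 is kept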
  1803. def _apply_to_column_groupbys(self, func):
  1804. """ return a pass thru """
  1805. return func(self)
  1806. class NDFrameGroupBy(GroupBy):
  1807. def _iterate_slices(self):
  1808. if self.axis == 0:
  1809. # kludge
  1810. if self._selection is None:
  1811. slice_axis = self.obj.columns
  1812. else:
  1813. slice_axis = self._selection_list
  1814. slicer = lambda x: self.obj[x]
  1815. else:
  1816. slice_axis = self.obj.index
  1817. slicer = self.obj.xs
  1818. for val in slice_axis:
  1819. if val in self.exclusions:
  1820. continue
  1821. yield val, slicer(val)
  1822. def _cython_agg_general(self, how, numeric_only=True):
  1823. new_items, new_blocks = self._cython_agg_blocks(how, numeric_only=numeric_only)
  1824. return self._wrap_agged_blocks(new_items, new_blocks)
  1825. def _wrap_agged_blocks(self, items, blocks):
  1826. obj = self._obj_with_exclusions
  1827. new_axes = list(obj._data.axes)
  1828. # more kludge
  1829. if self.axis == 0:
  1830. new_axes[0], new_axes[1] = new_axes[1], self.grouper.result_index
  1831. else:
  1832. new_axes[self.axis] = self.grouper.result_index
  1833. # Make sure block manager integrity check passes.
  1834. assert new_axes[0].equals(items)
  1835. new_axes[0] = items
  1836. mgr = BlockManager(blocks, new_axes)
  1837. new_obj = type(obj)(mgr)
  1838. return self._post_process_cython_aggregate(new_obj)
  1839. _block_agg_axis = 0
  1840. def _cython_agg_blocks(self, how, numeric_only=True):
  1841. data, agg_axis = self._get_data_to_aggregate()
  1842. new_blocks = []
  1843. if numeric_only:
  1844. data = data.get_numeric_data(copy=False)
  1845. for block in data.blocks:
  1846. values = block._try_operate(block.values)
  1847. if block.is_numeric:
  1848. values = com.ensure_float(values)
  1849. result, _ = self.grouper.aggregate(values, how, axis=agg_axis)
  1850. # see if we can cast the block back to the original dtype
  1851. result = block._try_coerce_and_cast_result(result)
  1852. newb = make_block(result, placement=block.mgr_locs)
  1853. new_blocks.append(newb)
  1854. if len(new_blocks) == 0:
  1855. raise DataError('No numeric types to aggregate')
  1856. return data.items, new_blocks
  1857. def _get_data_to_aggregate(self):
  1858. obj = self._obj_with_exclusions
  1859. if self.axis == 0:
  1860. return obj.swapaxes(0, 1)._data, 1
  1861. else:
  1862. return obj._data, self.axis
  1863. def _post_process_cython_aggregate(self, obj):
  1864. # undoing kludge from below
  1865. if self.axis == 0:
  1866. obj = obj.swapaxes(0, 1)
  1867. return obj
  1868. @cache_readonly
  1869. def _obj_with_exclusions(self):
  1870. if self._selection is not None:
  1871. return self.obj.reindex(columns=self._selection_list)
  1872. if len(self.exclusions) > 0:
  1873. return self.obj.drop(self.exclusions, axis=1)
  1874. else:
  1875. return self.obj
  1876. @Appender(_agg_doc)
  1877. def aggregate(self, arg, *args, **kwargs):
  1878. if isinstance(arg, compat.string_types):
  1879. return getattr(self, arg)(*args, **kwargs)
  1880. result = OrderedDict()
  1881. if isinstance(arg, dict):
  1882. if self.axis != 0: # pragma: no cover
  1883. raise ValueError('Can only pass dict with axis=0')
  1884. obj = self._selected_obj
  1885. if any(isinstance(x, (list, tuple, dict)) for x in arg.values()):
  1886. new_arg = OrderedDict()
  1887. for k, v in compat.iteritems(arg):
  1888. if not isinstance(v, (tuple, list, dict)):
  1889. new_arg[k] = [v]
  1890. else:
  1891. new_arg[k] = v
  1892. arg = new_arg
  1893. keys = []
  1894. if self._selection is not None:
  1895. subset = obj
  1896. if isinstance(subset, DataFrame):
  1897. raise NotImplementedError
  1898. for fname, agg_how in compat.iteritems(arg):
  1899. colg = SeriesGroupBy(subset, selection=self._selection,
  1900. grouper=self.grouper)
  1901. result[fname] = colg.aggregate(agg_how)
  1902. keys.append(fname)
  1903. else:
  1904. for col, agg_how in compat.iteritems(arg):
  1905. colg = SeriesGroupBy(obj[col], selection=col,
  1906. grouper=self.grouper)
  1907. result[col] = colg.aggregate(agg_how)
  1908. keys.append(col)
  1909. if isinstance(list(result.values())[0], DataFrame):
  1910. from pandas.tools.merge import concat
  1911. result = concat([result[k] for k in keys], keys=keys, axis=1)
  1912. else:
  1913. result = DataFrame(result)
  1914. elif isinstance(arg, list):
  1915. return self._aggregate_multiple_funcs(arg)
  1916. else:
  1917. cyfunc = _intercept_cython(arg)
  1918. if cyfunc and not args and not kwargs:
  1919. return getattr(self, cyfunc)()
  1920. if self.grouper.nkeys > 1:
  1921. return self._python_agg_general(arg, *args, **kwargs)
  1922. else:
  1923. # try to treat as if we are passing a list
  1924. try:
  1925. assert not args and not kwargs
  1926. result = self._aggregate_multiple_funcs([arg])
  1927. result.columns = Index(result.columns.levels[0],
  1928. name=self._selected_obj.columns.name)
  1929. except:
  1930. result = self._aggregate_generic(arg, *args, **kwargs)
  1931. if not self.as_index:
  1932. if isinstance(result.index, MultiIndex):
  1933. zipped = zip(result.index.levels, result.index.labels,
  1934. result.index.names)
  1935. for i, (lev, lab, name) in enumerate(zipped):
  1936. result.insert(i, name,
  1937. com.take_nd(lev.values, lab,
  1938. allow_fill=False))
  1939. result = result.consolidate()
  1940. else:
  1941. values = result.index.values
  1942. name = self.grouper.groupings[0].name
  1943. result.insert(0, name, values)
  1944. result.index = np.arange(len(result))
  1945. return result.convert_objects()
  1946. def _aggregate_multiple_funcs(self, arg):
  1947. from pandas.tools.merge import concat
  1948. if self.axis != 0:
  1949. raise NotImplementedError
  1950. obj = self._obj_with_exclusions
  1951. results = []
  1952. keys = []
  1953. for col in obj:
  1954. try:
  1955. colg = SeriesGroupBy(obj[col], selection=col,
  1956. grouper=self.grouper)
  1957. results.append(colg.aggregate(arg))
  1958. keys.append(col)
  1959. except (TypeError, DataError):
  1960. pass
  1961. except SpecificationError:
  1962. raise
  1963. result = concat(results, keys=keys, axis=1)
  1964. return result
  1965. def _aggregate_generic(self, func, *args, **kwargs):
  1966. if self.grouper.nkeys != 1:
  1967. raise AssertionError('Number of keys must be 1')
  1968. axis = self.axis
  1969. obj = self._obj_with_exclusions
  1970. result = {}
  1971. if axis != obj._info_axis_number:
  1972. try:
  1973. for name, data in self:
  1974. # for name in self.indices:
  1975. # data = self.get_group(name, obj=obj)
  1976. result[name] = self._try_cast(func(data, *args, **kwargs),
  1977. data)
  1978. except Exception:
  1979. return self._aggregate_item_by_item(func, *args, **kwargs)
  1980. else:
  1981. for name in self.indices:
  1982. try:
  1983. data = self.get_group(name, obj=obj)
  1984. result[name] = self._try_cast(func(data, *args, **kwargs),
  1985. data)
  1986. except Exception:
  1987. wrapper = lambda x: func(x, *args, **kwargs)
  1988. result[name] = data.apply(wrapper, axis=axis)
  1989. return self._wrap_generic_output(result, obj)
  1990. def _wrap_aggregated_output(self, output, names=None):
  1991. raise NotImplementedError
  1992. def _aggregate_item_by_item(self, func, *args, **kwargs):
  1993. # only for axis==0
  1994. obj = self._obj_with_exclusions
  1995. result = {}
  1996. cannot_agg = []
1997. errors = None
  1998. for item in obj:
  1999. try:
  2000. data = obj[item]
  2001. colg = SeriesGroupBy(data, selection=item,
  2002. grouper=self.grouper)
  2003. result[item] = self._try_cast(
  2004. colg.aggregate(func, *args, **kwargs), data)
  2005. except ValueError:
  2006. cannot_agg.append(item)
  2007. continue
  2008. except TypeError as e:
  2009. cannot_agg.append(item)
2010. errors = e
  2011. continue
  2012. result_columns = obj.columns
  2013. if cannot_agg:
  2014. result_columns = result_columns.drop(cannot_agg)
  2015. # GH6337
  2016. if not len(result_columns) and errors is not None:
  2017. raise errors
  2018. return DataFrame(result, columns=result_columns)
  2019. def _decide_output_index(self, output, labels):
  2020. if len(output) == len(labels):
  2021. output_keys = labels
  2022. else:
  2023. output_keys = sorted(output)
  2024. try:
  2025. output_keys.sort()
  2026. except Exception: # pragma: no cover
  2027. pass
  2028. if isinstance(labels, MultiIndex):
  2029. output_keys = MultiIndex.from_tuples(output_keys,
  2030. names=labels.names)
  2031. return output_keys
  2032. def _wrap_applied_output(self, keys, values, not_indexed_same=False):
  2033. from pandas.core.index import _all_indexes_same
  2034. if len(keys) == 0:
  2035. # XXX
  2036. return DataFrame({})
  2037. key_names = self.grouper.names
  2038. if isinstance(values[0], DataFrame):
  2039. return self._concat_objects(keys, values,
  2040. not_indexed_same=not_indexed_same)
  2041. elif hasattr(self.grouper, 'groupings'):
  2042. if len(self.grouper.groupings) > 1:
  2043. key_index = MultiIndex.from_tuples(keys, names=key_names)
  2044. else:
  2045. ping = self.grouper.groupings[0]
  2046. if len(keys) == ping.ngroups:
  2047. key_index = ping.group_index
  2048. key_index.name = key_names[0]
  2049. key_lookup = Index(keys)
  2050. indexer = key_lookup.get_indexer(key_index)
  2051. # reorder the values
  2052. values = [values[i] for i in indexer]
  2053. else:
  2054. key_index = Index(keys, name=key_names[0])
  2055. # don't use the key indexer
  2056. if not self.as_index:
  2057. key_index = None
  2058. # make Nones an empty object
  2059. if com._count_not_none(*values) != len(values):
  2060. v = next(v for v in values if v is not None)
  2061. if v is None:
  2062. return DataFrame()
  2063. elif isinstance(v, NDFrame):
  2064. values = [
  2065. x if x is not None else
  2066. v._constructor(**v._construct_axes_dict())
  2067. for x in values
  2068. ]
  2069. v = values[0]
  2070. if isinstance(v, (np.ndarray, Series)):
  2071. if isinstance(v, Series):
  2072. applied_index = self._selected_obj._get_axis(self.axis)
  2073. all_indexed_same = _all_indexes_same([
  2074. x.index for x in values
  2075. ])
  2076. singular_series = (len(values) == 1 and
  2077. applied_index.nlevels == 1)
  2078. # GH3596
  2079. # provide a reduction (Frame -> Series) if groups are
  2080. # unique
  2081. if self.squeeze:
  2082. # assign the name to this series
  2083. if singular_series:
  2084. values[0].name = keys[0]
  2085. # GH2893
  2086. # we have series in the values array, we want to
  2087. # produce a series:
  2088. # if any of the sub-series are not indexed the same
  2089. # OR we don't have a multi-index and we have only a
  2090. # single values
  2091. return self._concat_objects(
  2092. keys, values, not_indexed_same=not_indexed_same
  2093. )
  2094. # still a series
  2095. # path added as of GH 5545
  2096. elif all_indexed_same:
  2097. from pandas.tools.merge import concat
  2098. return concat(values)
  2099. if not all_indexed_same:
  2100. return self._concat_objects(
  2101. keys, values, not_indexed_same=not_indexed_same
  2102. )
  2103. try:
  2104. if self.axis == 0:
  2105. # GH6124 if the list of Series have a consistent name,
  2106. # then propagate that name to the result.
  2107. index = v.index.copy()
  2108. if index.name is None:
  2109. # Only propagate the series name to the result
  2110. # if all series have a consistent name. If the
  2111. # series do not have a consistent name, do
  2112. # nothing.
  2113. names = set(v.name for v in values)
  2114. if len(names) == 1:
  2115. index.name = list(names)[0]
  2116. # normally use vstack as its faster than concat
  2117. # and if we have mi-columns
2118. if not _np_version_under1p7 or isinstance(v.index, MultiIndex) or key_index is None:
2119. stacked_values = np.vstack([np.asarray(x) for x in values])
2120. result = DataFrame(stacked_values, index=key_index, columns=index)
2121. else:
2122. # GH5788 instead of stacking; concat gets the dtypes correct
2123. from pandas.tools.merge import concat
2124. result = concat(values, keys=key_index, names=key_index.names,
2125. axis=self.axis).unstack()
2126. result.columns = index
2127. else:
2128. stacked_values = np.vstack([np.asarray(x) for x in values])
2129. result = DataFrame(stacked_values.T, index=v.index, columns=key_index)
  2130. except (ValueError, AttributeError):
2131. # GH1738: values is a list of arrays of unequal lengths; fall
2132. # through to the outer else clause
  2133. return Series(values, index=key_index)
  2134. # if we have date/time like in the original, then coerce dates
  2135. # as we are stacking can easily have object dtypes here
  2136. if (self._selected_obj.ndim == 2
  2137. and self._selected_obj.dtypes.isin(_DATELIKE_DTYPES).any()):
  2138. cd = 'coerce'
  2139. else:
  2140. cd = True
  2141. return result.convert_objects(convert_dates=cd)
  2142. else:
  2143. # only coerce dates if we find at least 1 datetime
2144. cd = 'coerce' if any(isinstance(v, Timestamp) for v in values) else False
  2145. return Series(values, index=key_index).convert_objects(convert_dates=cd)
  2146. else:
  2147. # Handle cases like BinGrouper
  2148. return self._concat_objects(keys, values,
  2149. not_indexed_same=not_indexed_same)
  2150. def _transform_general(self, func, *args, **kwargs):
  2151. from pandas.tools.merge import concat
  2152. applied = []
  2153. obj = self._obj_with_exclusions
  2154. gen = self.grouper.get_iterator(obj, axis=self.axis)
  2155. fast_path, slow_path = self._define_paths(func, *args, **kwargs)
  2156. path = None
  2157. for name, group in gen:
  2158. object.__setattr__(group, 'name', name)
  2159. if path is None:
  2160. # Try slow path and fast path.
  2161. try:
  2162. path, res = self._choose_path(fast_path, slow_path, group)
  2163. except TypeError:
  2164. return self._transform_item_by_item(obj, fast_path)
  2165. except Exception: # pragma: no cover
  2166. res = fast_path(group)
  2167. path = fast_path
  2168. else:
  2169. res = path(group)
  2170. # broadcasting
  2171. if isinstance(res, Series):
  2172. if res.index.is_(obj.index):
  2173. group.T.values[:] = res
  2174. else:
  2175. group.values[:] = res
  2176. applied.append(group)
  2177. else:
  2178. applied.append(res)
  2179. concat_index = obj.columns if self.axis == 0 else obj.index
  2180. concatenated = concat(applied, join_axes=[concat_index],
  2181. axis=self.axis, verify_integrity=False)
  2182. concatenated.sort_index(inplace=True)
  2183. return concatenated
  2184. def transform(self, func, *args, **kwargs):
  2185. """
  2186. Call function producing a like-indexed DataFrame on each group and
  2187. return a DataFrame having the same indexes as the original object
  2188. filled with the transformed values
  2189. Parameters
  2190. ----------
2191. func : function
  2192. Function to apply to each subframe
  2193. Notes
  2194. -----
2195. Each subframe is endowed with the attribute 'name' in case you need to know
  2196. which group you are working on.
  2197. Examples
  2198. --------
  2199. >>> grouped = df.groupby(lambda x: mapping[x])
  2200. >>> grouped.transform(lambda x: (x - x.mean()) / x.std())
  2201. """
  2202. # try to do a fast transform via merge if possible
  2203. try:
  2204. obj = self._obj_with_exclusions
  2205. if isinstance(func, compat.string_types):
  2206. result = getattr(self, func)(*args, **kwargs)
  2207. else:
  2208. cyfunc = _intercept_cython(func)
  2209. if cyfunc and not args and not kwargs:
  2210. result = getattr(self, cyfunc)()
  2211. else:
  2212. return self._transform_general(func, *args, **kwargs)
2213. except Exception:
  2214. return self._transform_general(func, *args, **kwargs)
  2215. # a reduction transform
  2216. if not isinstance(result, DataFrame):
  2217. return self._transform_general(func, *args, **kwargs)
2218. # nuisance columns
  2219. if not result.columns.equals(obj.columns):
  2220. return self._transform_general(func, *args, **kwargs)
2221. # a grouped result that doesn't preserve the index: remap the index
2222. # based on the grouper and broadcast it
2223. if not isinstance(obj.index, MultiIndex) and type(result.index) != type(obj.index):
2224. results = obj.values.copy()
2225. for (name, group), (i, row) in zip(self, result.iterrows()):
2226. indexer = self._get_index(name)
2227. results[indexer] = np.tile(row.values, len(indexer)).reshape(len(indexer), -1)
2228. return DataFrame(results, columns=result.columns, index=obj.index).convert_objects()
  2229. # we can merge the result in
  2230. # GH 7383
  2231. names = result.columns
  2232. result = obj.merge(result, how='outer', left_index=True, right_index=True).ix[:,-result.shape[1]:]
  2233. result.columns = names
  2234. return result
  2235. def _define_paths(self, func, *args, **kwargs):
  2236. if isinstance(func, compat.string_types):
  2237. fast_path = lambda group: getattr(group, func)(*args, **kwargs)
  2238. slow_path = lambda group: group.apply(
  2239. lambda x: getattr(x, func)(*args, **kwargs), axis=self.axis)
  2240. else:
  2241. fast_path = lambda group: func(group, *args, **kwargs)
  2242. slow_path = lambda group: group.apply(
  2243. lambda x: func(x, *args, **kwargs), axis=self.axis)
  2244. return fast_path, slow_path
  2245. def _choose_path(self, fast_path, slow_path, group):
  2246. path = slow_path
  2247. res = slow_path(group)
  2248. # if we make it here, test if we can use the fast path
  2249. try:
  2250. res_fast = fast_path(group)
  2251. # compare that we get the same results
  2252. if res.shape == res_fast.shape:
  2253. res_r = res.values.ravel()
  2254. res_fast_r = res_fast.values.ravel()
  2255. mask = notnull(res_r)
  2256. if (res_r[mask] == res_fast_r[mask]).all():
  2257. path = fast_path
2258. except Exception:
  2259. pass
  2260. return path, res
  2261. def _transform_item_by_item(self, obj, wrapper):
  2262. # iterate through columns
  2263. output = {}
  2264. inds = []
  2265. for i, col in enumerate(obj):
  2266. try:
  2267. output[col] = self[col].transform(wrapper)
  2268. inds.append(i)
  2269. except Exception:
  2270. pass
  2271. if len(output) == 0: # pragma: no cover
  2272. raise TypeError('Transform function invalid for data types')
  2273. columns = obj.columns
  2274. if len(output) < len(obj.columns):
  2275. columns = columns.take(inds)
  2276. return DataFrame(output, index=obj.index, columns=columns)
  2277. def filter(self, func, dropna=True, *args, **kwargs):
  2278. """
  2279. Return a copy of a DataFrame excluding elements from groups that
  2280. do not satisfy the boolean criterion specified by func.
  2281. Parameters
  2282. ----------
2283. func : function
  2284. Function to apply to each subframe. Should return True or False.
2285. dropna : Drop groups that do not pass the filter; True by default.
2286. If False, groups that evaluate False are filled with NaNs.
  2287. Notes
  2288. -----
  2289. Each subframe is endowed the attribute 'name' in case you need to know
  2290. which group you are working on.
2291. Examples
  2292. --------
  2293. >>> grouped = df.groupby(lambda x: mapping[x])
  2294. >>> grouped.filter(lambda x: x['A'].sum() + x['B'].sum() > 0)
  2295. """
  2296. from pandas.tools.merge import concat
  2297. indices = []
  2298. obj = self._selected_obj
  2299. gen = self.grouper.get_iterator(obj, axis=self.axis)
  2300. fast_path, slow_path = self._define_paths(func, *args, **kwargs)
  2301. path = None
  2302. for name, group in gen:
  2303. object.__setattr__(group, 'name', name)
  2304. if path is None:
  2305. # Try slow path and fast path.
  2306. try:
  2307. path, res = self._choose_path(fast_path, slow_path, group)
  2308. except Exception: # pragma: no cover
  2309. res = fast_path(group)
  2310. path = fast_path
  2311. else:
  2312. res = path(group)
  2313. def add_indices():
  2314. indices.append(self._get_index(name))
  2315. # interpret the result of the filter
  2316. if isinstance(res, (bool, np.bool_)):
  2317. if res:
  2318. add_indices()
  2319. else:
  2320. if getattr(res, 'ndim', None) == 1:
  2321. val = res.ravel()[0]
  2322. if val and notnull(val):
  2323. add_indices()
  2324. else:
  2325. # in theory you could do .all() on the boolean result ?
  2326. raise TypeError("the filter must return a boolean result")
  2327. filtered = self._apply_filter(indices, dropna)
  2328. return filtered
  2329. class DataFrameGroupBy(NDFrameGroupBy):
  2330. _apply_whitelist = _dataframe_apply_whitelist
  2331. _block_agg_axis = 1
  2332. def __getitem__(self, key):
  2333. if self._selection is not None:
  2334. raise Exception('Column(s) %s already selected' % self._selection)
  2335. if isinstance(key, (list, tuple, Series, np.ndarray)):
  2336. if len(self.obj.columns.intersection(key)) != len(key):
  2337. bad_keys = list(set(key).difference(self.obj.columns))
  2338. raise KeyError("Columns not found: %s"
  2339. % str(bad_keys)[1:-1])
  2340. return DataFrameGroupBy(self.obj, self.grouper, selection=key,
  2341. grouper=self.grouper,
  2342. exclusions=self.exclusions,
  2343. as_index=self.as_index)
  2344. elif not self.as_index:
  2345. if key not in self.obj.columns:
  2346. raise KeyError("Column not found: %s" % key)
  2347. return DataFrameGroupBy(self.obj, self.grouper, selection=key,
  2348. grouper=self.grouper,
  2349. exclusions=self.exclusions,
  2350. as_index=self.as_index)
  2351. else:
  2352. if key not in self.obj:
  2353. raise KeyError("Column not found: %s" % key)
  2354. # kind of a kludge
  2355. return SeriesGroupBy(self.obj[key], selection=key,
  2356. grouper=self.grouper,
  2357. exclusions=self.exclusions)
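# Selection semantics sketched above, with a hypothetical grouped frame:
#
# >>> g = df.groupby('A')
# >>> g[['B', 'C']]   # list-like key -> DataFrameGroupBy
# >>> g['B']          # scalar key    -> SeriesGroupBy (when as_index=True)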
  2358. def _wrap_generic_output(self, result, obj):
  2359. result_index = self.grouper.levels[0]
  2360. if result:
  2361. if self.axis == 0:
  2362. result = DataFrame(result, index=obj.columns,
  2363. columns=result_index).T
  2364. else:
  2365. result = DataFrame(result, index=obj.index,
  2366. columns=result_index)
  2367. else:
  2368. result = DataFrame(result)
  2369. return result
  2370. def _get_data_to_aggregate(self):
  2371. obj = self._obj_with_exclusions
  2372. if self.axis == 1:
  2373. return obj.T._data, 1
  2374. else:
  2375. return obj._data, 1
  2376. def _wrap_aggregated_output(self, output, names=None):
  2377. agg_axis = 0 if self.axis == 1 else 1
  2378. agg_labels = self._obj_with_exclusions._get_axis(agg_axis)
  2379. output_keys = self._decide_output_index(output, agg_labels)
  2380. if not self.as_index:
  2381. result = DataFrame(output, columns=output_keys)
  2382. group_levels = self.grouper.get_group_levels()
  2383. zipped = zip(self.grouper.names, group_levels)
  2384. for i, (name, labels) in enumerate(zipped):
  2385. result.insert(i, name, labels)
  2386. result = result.consolidate()
  2387. else:
  2388. index = self.grouper.result_index
  2389. result = DataFrame(output, index=index, columns=output_keys)
  2390. if self.axis == 1:
  2391. result = result.T
  2392. return result.convert_objects()
  2393. def _wrap_agged_blocks(self, items, blocks):
  2394. if not self.as_index:
  2395. index = np.arange(blocks[0].values.shape[1])
  2396. mgr = BlockManager(blocks, [items, index])
  2397. result = DataFrame(mgr)
  2398. group_levels = self.grouper.get_group_levels()
  2399. zipped = zip(self.grouper.names, group_levels)
  2400. for i, (name, labels) in enumerate(zipped):
  2401. result.insert(i, name, labels)
  2402. result = result.consolidate()
  2403. else:
  2404. index = self.grouper.result_index
  2405. mgr = BlockManager(blocks, [items, index])
  2406. result = DataFrame(mgr)
  2407. if self.axis == 1:
  2408. result = result.T
  2409. return result.convert_objects()
  2410. def _iterate_column_groupbys(self):
  2411. for i, colname in enumerate(self._selected_obj.columns):
  2412. yield colname, SeriesGroupBy(self._selected_obj.iloc[:, i],
  2413. selection=colname,
  2414. grouper=self.grouper,
  2415. exclusions=self.exclusions)
  2416. def _apply_to_column_groupbys(self, func):
  2417. from pandas.tools.merge import concat
  2418. return concat(
  2419. (func(col_groupby) for _, col_groupby
  2420. in self._iterate_column_groupbys()),
  2421. keys=self._selected_obj.columns, axis=1)
  2422. from pandas.tools.plotting import boxplot_frame_groupby
  2423. DataFrameGroupBy.boxplot = boxplot_frame_groupby
  2424. class PanelGroupBy(NDFrameGroupBy):
  2425. def _iterate_slices(self):
  2426. if self.axis == 0:
  2427. # kludge
  2428. if self._selection is None:
  2429. slice_axis = self._selected_obj.items
  2430. else:
  2431. slice_axis = self._selection_list
  2432. slicer = lambda x: self._selected_obj[x]
  2433. else:
  2434. raise NotImplementedError
  2435. for val in slice_axis:
  2436. if val in self.exclusions:
  2437. continue
  2438. yield val, slicer(val)
  2439. def aggregate(self, arg, *args, **kwargs):
  2440. """
  2441. Aggregate using input function or dict of {column -> function}
  2442. Parameters
  2443. ----------
  2444. arg : function or dict
  2445. Function to use for aggregating groups. If a function, must either
  2446. work when passed a Panel or when passed to Panel.apply. If
2447. passed a dict, the keys must be DataFrame column names
  2448. Returns
  2449. -------
  2450. aggregated : Panel
  2451. """
  2452. if isinstance(arg, compat.string_types):
  2453. return getattr(self, arg)(*args, **kwargs)
  2454. return self._aggregate_generic(arg, *args, **kwargs)
  2455. def _wrap_generic_output(self, result, obj):
  2456. if self.axis == 0:
  2457. new_axes = list(obj.axes)
  2458. new_axes[0] = self.grouper.result_index
  2459. elif self.axis == 1:
  2460. x, y, z = obj.axes
  2461. new_axes = [self.grouper.result_index, z, x]
  2462. else:
  2463. x, y, z = obj.axes
  2464. new_axes = [self.grouper.result_index, y, x]
  2465. result = Panel._from_axes(result, new_axes)
  2466. if self.axis == 1:
  2467. result = result.swapaxes(0, 1).swapaxes(0, 2)
  2468. elif self.axis == 2:
  2469. result = result.swapaxes(0, 2)
  2470. return result
  2471. def _aggregate_item_by_item(self, func, *args, **kwargs):
  2472. obj = self._obj_with_exclusions
  2473. result = {}
  2474. if self.axis > 0:
  2475. for item in obj:
  2476. try:
  2477. itemg = DataFrameGroupBy(obj[item],
  2478. axis=self.axis - 1,
  2479. grouper=self.grouper)
  2480. result[item] = itemg.aggregate(func, *args, **kwargs)
  2481. except (ValueError, TypeError):
  2482. raise
  2483. new_axes = list(obj.axes)
  2484. new_axes[self.axis] = self.grouper.result_index
  2485. return Panel._from_axes(result, new_axes)
  2486. else:
  2487. raise NotImplementedError
  2488. def _wrap_aggregated_output(self, output, names=None):
  2489. raise NotImplementedError
  2490. class NDArrayGroupBy(GroupBy):
  2491. pass
  2492. #----------------------------------------------------------------------
  2493. # Splitting / application
  2494. class DataSplitter(object):
  2495. def __init__(self, data, labels, ngroups, axis=0):
  2496. self.data = data
  2497. self.labels = com._ensure_int64(labels)
  2498. self.ngroups = ngroups
  2499. self.axis = axis
  2500. @cache_readonly
  2501. def slabels(self):
  2502. # Sorted labels
  2503. return com.take_nd(self.labels, self.sort_idx, allow_fill=False)
  2504. @cache_readonly
  2505. def sort_idx(self):
  2506. # Counting sort indexer
  2507. return _algos.groupsort_indexer(self.labels, self.ngroups)[0]
  2508. def __iter__(self):
  2509. sdata = self._get_sorted_data()
  2510. if self.ngroups == 0:
  2511. raise StopIteration
  2512. starts, ends = lib.generate_slices(self.slabels, self.ngroups)
  2513. for i, (start, end) in enumerate(zip(starts, ends)):
2514. # Since I'm now compressing the group ids, it's not "possible"
  2515. # to produce empty slices because such groups would not be observed
  2516. # in the data
  2517. # if start >= end:
  2518. # raise AssertionError('Start %s must be less than end %s'
  2519. # % (str(start), str(end)))
  2520. yield i, self._chop(sdata, slice(start, end))
  2521. def _get_sorted_data(self):
  2522. return self.data.take(self.sort_idx, axis=self.axis, convert=False)
  2523. def _chop(self, sdata, slice_obj):
  2524. return sdata.iloc[slice_obj]
  2525. def apply(self, f):
  2526. raise NotImplementedError
  2527. class ArraySplitter(DataSplitter):
  2528. pass
  2529. class SeriesSplitter(DataSplitter):
  2530. def _chop(self, sdata, slice_obj):
  2531. return sdata._get_values(slice_obj).to_dense()
  2532. class FrameSplitter(DataSplitter):
  2533. def __init__(self, data, labels, ngroups, axis=0):
  2534. super(FrameSplitter, self).__init__(data, labels, ngroups, axis=axis)
  2535. def fast_apply(self, f, names):
2536. # must return values::list, mutated::bool
  2537. try:
  2538. starts, ends = lib.generate_slices(self.slabels, self.ngroups)
2539. except Exception:
  2540. # fails when all -1
  2541. return [], True
  2542. sdata = self._get_sorted_data()
  2543. results, mutated = lib.apply_frame_axis0(sdata, f, names, starts, ends)
  2544. return results, mutated
  2545. def _chop(self, sdata, slice_obj):
  2546. if self.axis == 0:
  2547. return sdata.iloc[slice_obj]
  2548. else:
  2549. return sdata._slice(slice_obj, axis=1) # ix[:, slice_obj]
  2550. class NDFrameSplitter(DataSplitter):
  2551. def __init__(self, data, labels, ngroups, axis=0):
  2552. super(NDFrameSplitter, self).__init__(data, labels, ngroups, axis=axis)
  2553. self.factory = data._constructor
  2554. def _get_sorted_data(self):
  2555. # this is the BlockManager
  2556. data = self.data._data
  2557. # this is sort of wasteful but...
  2558. sorted_axis = data.axes[self.axis].take(self.sort_idx)
  2559. sorted_data = data.reindex_axis(sorted_axis, axis=self.axis)
  2560. return sorted_data
  2561. def _chop(self, sdata, slice_obj):
  2562. return self.factory(sdata.get_slice(slice_obj, axis=self.axis))
  2563. def get_splitter(data, *args, **kwargs):
  2564. if isinstance(data, Series):
  2565. klass = SeriesSplitter
  2566. elif isinstance(data, DataFrame):
  2567. klass = FrameSplitter
  2568. else:
  2569. klass = NDFrameSplitter
  2570. return klass(data, *args, **kwargs)
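# Dispatch example: the splitter class is chosen purely by the data's type.
#
# >>> labels = np.array([0, 0, 1], dtype=np.int64)
# >>> get_splitter(Series([1, 2, 3]), labels, 2)            # SeriesSplitter
# >>> get_splitter(DataFrame({'a': [1, 2, 3]}), labels, 2)  # FrameSplitter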
  2571. #----------------------------------------------------------------------
  2572. # Misc utilities
  2573. def get_group_index(label_list, shape):
  2574. """
  2575. For the particular label_list, gets the offsets into the hypothetical list
  2576. representing the totally ordered cartesian product of all possible label
  2577. combinations.
  2578. """
  2579. if len(label_list) == 1:
  2580. return label_list[0]
  2581. n = len(label_list[0])
  2582. group_index = np.zeros(n, dtype=np.int64)
  2583. mask = np.zeros(n, dtype=bool)
  2584. for i in range(len(shape)):
  2585. stride = np.prod([x for x in shape[i + 1:]], dtype=np.int64)
  2586. group_index += com._ensure_int64(label_list[i]) * stride
  2587. mask |= label_list[i] < 0
  2588. np.putmask(group_index, mask, -1)
  2589. return group_index
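# Worked example: for shape (2, 2) the strides are (2, 1), so the flat group
# id is label0 * 2 + label1; any -1 label marks the observation as -1.
#
# >>> get_group_index([np.array([0, 0, 1, 1]), np.array([0, 1, 0, 1])], (2, 2))
# array([0, 1, 2, 3])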
  2590. _INT64_MAX = np.iinfo(np.int64).max
  2591. def _int64_overflow_possible(shape):
  2592. the_prod = long(1)
  2593. for x in shape:
  2594. the_prod *= long(x)
  2595. return the_prod >= _INT64_MAX
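# Example: it is the product of the level sizes, not the data length, that
# can overflow; 10**10 * 10**9 = 10**19 exceeds _INT64_MAX (~9.2 * 10**18).
#
# >>> _int64_overflow_possible((10 ** 10, 10 ** 9))
# True
# >>> _int64_overflow_possible((1000, 1000))
# False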
  2596. def decons_group_index(comp_labels, shape):
  2597. # reconstruct labels
  2598. label_list = []
  2599. factor = 1
  2600. y = 0
  2601. x = comp_labels
  2602. for i in reversed(range(len(shape))):
  2603. labels = (x - y) % (factor * shape[i]) // factor
  2604. np.putmask(labels, comp_labels < 0, -1)
  2605. label_list.append(labels)
  2606. y = labels * factor
  2607. factor *= shape[i]
  2608. return label_list[::-1]
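# Round trip with the get_group_index example above: the flat ids decompose
# back into the per-level labels.
#
# >>> decons_group_index(np.array([0, 1, 2, 3]), (2, 2))
# [array([0, 0, 1, 1]), array([0, 1, 0, 1])]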
  2609. def _indexer_from_factorized(labels, shape, compress=True):
  2610. if _int64_overflow_possible(shape):
  2611. indexer = np.lexsort(np.array(labels[::-1]))
  2612. return indexer
  2613. group_index = get_group_index(labels, shape)
  2614. if compress:
  2615. comp_ids, obs_ids = _compress_group_index(group_index)
  2616. max_group = len(obs_ids)
  2617. else:
  2618. comp_ids = group_index
  2619. max_group = com._long_prod(shape)
  2620. if max_group > 1e6:
  2621. # Use mergesort to avoid memory errors in counting sort
  2622. indexer = comp_ids.argsort(kind='mergesort')
  2623. else:
  2624. indexer, _ = _algos.groupsort_indexer(comp_ids.astype(np.int64),
  2625. max_group)
  2626. return indexer
  2627. def _lexsort_indexer(keys, orders=None, na_position='last'):
  2628. labels = []
  2629. shape = []
  2630. if isinstance(orders, bool):
  2631. orders = [orders] * len(keys)
  2632. elif orders is None:
  2633. orders = [True] * len(keys)
  2634. for key, order in zip(keys, orders):
  2635. key = np.asanyarray(key)
  2636. rizer = _hash.Factorizer(len(key))
  2637. if not key.dtype == np.object_:
  2638. key = key.astype('O')
  2639. # factorize maps nans to na_sentinel=-1
  2640. ids = rizer.factorize(key, sort=True)
  2641. n = len(rizer.uniques)
  2642. mask = (ids == -1)
  2643. if order: # ascending
  2644. if na_position == 'last':
  2645. ids = np.where(mask, n, ids)
  2646. elif na_position == 'first':
  2647. ids += 1
  2648. else:
  2649. raise ValueError('invalid na_position: {!r}'.format(na_position))
  2650. else: # not order means descending
  2651. if na_position == 'last':
  2652. ids = np.where(mask, n, n-ids-1)
  2653. elif na_position == 'first':
  2654. ids = np.where(mask, 0, n-ids)
  2655. else:
  2656. raise ValueError('invalid na_position: {!r}'.format(na_position))
  2657. if mask.any():
  2658. n += 1
  2659. shape.append(n)
  2660. labels.append(ids)
  2661. return _indexer_from_factorized(labels, shape)
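# Sketch of the NaN handling above: NaNs factorize to -1 and are remapped
# past the n valid codes, so with the defaults they sort last.
#
# >>> _lexsort_indexer([np.array(['b', 'a', np.nan], dtype=object)])
# expected: array([1, 0, 2]) -- 'a', then 'b', then NaN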
  2662. def _nargsort(items, kind='quicksort', ascending=True, na_position='last'):
  2663. """
2664. This is intended to be a drop-in replacement for np.argsort which handles NaNs.
2665. It adds ascending and na_position parameters.
  2666. GH #6399, #5231
  2667. """
  2668. items = np.asanyarray(items)
  2669. idx = np.arange(len(items))
  2670. mask = isnull(items)
  2671. non_nans = items[~mask]
  2672. non_nan_idx = idx[~mask]
  2673. nan_idx = np.nonzero(mask)[0]
  2674. if not ascending:
  2675. non_nans = non_nans[::-1]
  2676. non_nan_idx = non_nan_idx[::-1]
  2677. indexer = non_nan_idx[non_nans.argsort(kind=kind)]
  2678. if not ascending:
  2679. indexer = indexer[::-1]
  2680. # Finally, place the NaNs at the end or the beginning according to na_position
  2681. if na_position == 'last':
  2682. indexer = np.concatenate([indexer, nan_idx])
  2683. elif na_position == 'first':
  2684. indexer = np.concatenate([nan_idx, indexer])
  2685. else:
  2686. raise ValueError('invalid na_position: {!r}'.format(na_position))
  2687. return indexer
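# Doctest-style illustration: unlike np.argsort, NaNs go to the requested end.
#
# >>> _nargsort([3.0, np.nan, 1.0])
# array([2, 0, 1])
# >>> _nargsort([3.0, np.nan, 1.0], na_position='first')
# array([1, 2, 0])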
  2688. class _KeyMapper(object):
  2689. """
  2690. Ease my suffering. Map compressed group id -> key tuple
  2691. """
  2692. def __init__(self, comp_ids, ngroups, labels, levels):
  2693. self.levels = levels
  2694. self.labels = labels
  2695. self.comp_ids = comp_ids.astype(np.int64)
  2696. self.k = len(labels)
  2697. self.tables = [_hash.Int64HashTable(ngroups) for _ in range(self.k)]
  2698. self._populate_tables()
  2699. def _populate_tables(self):
  2700. for labs, table in zip(self.labels, self.tables):
  2701. table.map(self.comp_ids, labs.astype(np.int64))
  2702. def get_key(self, comp_id):
  2703. return tuple(level[table.get_item(comp_id)]
  2704. for table, level in zip(self.tables, self.levels))
  2705. def _get_indices_dict(label_list, keys):
  2706. shape = [len(x) for x in keys]
  2707. group_index = get_group_index(label_list, shape)
  2708. sorter, _ = _algos.groupsort_indexer(com._ensure_int64(group_index),
  2709. np.prod(shape))
  2710. sorter_int = com._ensure_platform_int(sorter)
  2711. sorted_labels = [lab.take(sorter_int) for lab in label_list]
  2712. group_index = group_index.take(sorter_int)
  2713. return lib.indices_fast(sorter, group_index, keys, sorted_labels)
  2714. #----------------------------------------------------------------------
  2715. # sorting levels...cleverly?
  2716. def _compress_group_index(group_index, sort=True):
  2717. """
  2718. Group_index is offsets into cartesian product of all possible labels. This
  2719. space can be huge, so this function compresses it, by computing offsets
  2720. (comp_ids) into the list of unique labels (obs_group_ids).
  2721. """
  2722. table = _hash.Int64HashTable(min(1000000, len(group_index)))
  2723. group_index = com._ensure_int64(group_index)
  2724. # note, group labels come out ascending (ie, 1,2,3 etc)
  2725. comp_ids, obs_group_ids = table.get_labels_groupby(group_index)
  2726. if sort and len(obs_group_ids) > 0:
  2727. obs_group_ids, comp_ids = _reorder_by_uniques(obs_group_ids, comp_ids)
  2728. return comp_ids, obs_group_ids
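# Example: ids come out of the hash table in order of appearance and are then
# reordered ascending, with comp_ids remapped to match.
#
# >>> _compress_group_index(np.array([5, 1, 5, 0]))
# expected: (array([2, 1, 2, 0]), array([0, 1, 5]))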
  2729. def _reorder_by_uniques(uniques, labels):
  2730. # sorter is index where elements ought to go
  2731. sorter = uniques.argsort()
  2732. # reverse_indexer is where elements came from
  2733. reverse_indexer = np.empty(len(sorter), dtype=np.int64)
  2734. reverse_indexer.put(sorter, np.arange(len(sorter)))
  2735. mask = labels < 0
  2736. # move labels to right locations (ie, unsort ascending labels)
  2737. labels = com.take_nd(reverse_indexer, labels, allow_fill=False)
  2738. np.putmask(labels, mask, -1)
  2739. # sort observed ids
  2740. uniques = com.take_nd(uniques, sorter, allow_fill=False)
  2741. return uniques, labels
  2742. _func_table = {
  2743. builtins.sum: np.sum
  2744. }
  2745. _cython_table = {
  2746. builtins.sum: 'sum',
  2747. np.sum: 'sum',
  2748. np.mean: 'mean',
  2749. np.prod: 'prod',
  2750. np.std: 'std',
  2751. np.var: 'var',
  2752. np.median: 'median',
  2753. np.max: 'max',
  2754. np.min: 'min'
  2755. }
  2756. def _intercept_function(func):
  2757. return _func_table.get(func, func)
  2758. def _intercept_cython(func):
  2759. return _cython_table.get(func)
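# Example of the two lookup tables: builtins.sum is rewritten to np.sum for
# the python aggregation path, and known reducers map to cython kernel names.
#
# >>> _intercept_function(builtins.sum) is np.sum
# True
# >>> _intercept_cython(np.mean)
# 'mean'
# >>> _intercept_cython(lambda x: x.mean()) is None   # unknown function
# True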
  2760. def _groupby_indices(values):
  2761. return _algos.groupby_indices(com._ensure_object(values))
  2762. def numpy_groupby(data, labels, axis=0):
  2763. s = np.argsort(labels)
  2764. keys, inv = np.unique(labels, return_inverse=True)
  2765. i = inv.take(s)
  2766. groups_at = np.where(i != np.concatenate(([-1], i[:-1])))[0]
  2767. ordered_data = data.take(s, axis=axis)
  2768. group_sums = np.add.reduceat(ordered_data, groups_at, axis=axis)
  2769. return group_sums
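# Worked example: rows are ordered by label, then each contiguous run is
# summed with np.add.reduceat.
#
# >>> data = np.array([[1, 2], [3, 4], [5, 6]])
# >>> numpy_groupby(data, np.array([0, 1, 0]))
# array([[6, 8],
#        [3, 4]])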