PageRenderTime 54ms CodeModel.GetById 0ms RepoModel.GetById 0ms app.codeStats 0ms

/pandas/tools/pivot.py

http://github.com/pydata/pandas
Python | 409 lines | 347 code | 16 blank | 46 comment | 14 complexity | 794644e044fe1ca15d6f5f7a8e758de3 MD5 | raw file
Possible License(s): BSD-3-Clause, Apache-2.0
  1. # pylint: disable=E1103
  2. import warnings
  3. from pandas import Series, DataFrame
  4. from pandas.core.index import MultiIndex
  5. from pandas.core.groupby import Grouper
  6. from pandas.tools.merge import concat
  7. from pandas.tools.util import cartesian_product
  8. from pandas.compat import range, lrange, zip
  9. from pandas.util.decorators import deprecate_kwarg
  10. from pandas import compat
  11. import pandas.core.common as com
  12. import numpy as np
  13. @deprecate_kwarg(old_arg_name='cols', new_arg_name='columns')
  14. @deprecate_kwarg(old_arg_name='rows', new_arg_name='index')
  15. def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
  16. fill_value=None, margins=False, dropna=True):
  17. """
  18. Create a spreadsheet-style pivot table as a DataFrame. The levels in the
  19. pivot table will be stored in MultiIndex objects (hierarchical indexes) on
  20. the index and columns of the result DataFrame
  21. Parameters
  22. ----------
  23. data : DataFrame
  24. values : column to aggregate, optional
  25. index : a column, Grouper, array which has the same length as data, or list of them.
  26. Keys to group by on the pivot table index.
  27. If an array is passed, it is being used as the same manner as column values.
  28. columns : a column, Grouper, array which has the same length as data, or list of them.
  29. Keys to group by on the pivot table column.
  30. If an array is passed, it is being used as the same manner as column values.
  31. aggfunc : function, default numpy.mean, or list of functions
  32. If list of functions passed, the resulting pivot table will have
  33. hierarchical columns whose top level are the function names (inferred
  34. from the function objects themselves)
  35. fill_value : scalar, default None
  36. Value to replace missing values with
  37. margins : boolean, default False
  38. Add all row / columns (e.g. for subtotal / grand totals)
  39. dropna : boolean, default True
  40. Do not include columns whose entries are all NaN
  41. rows : kwarg only alias of index [deprecated]
  42. cols : kwarg only alias of columns [deprecated]
  43. Examples
  44. --------
  45. >>> df
  46. A B C D
  47. 0 foo one small 1
  48. 1 foo one large 2
  49. 2 foo one large 2
  50. 3 foo two small 3
  51. 4 foo two small 3
  52. 5 bar one large 4
  53. 6 bar one small 5
  54. 7 bar two small 6
  55. 8 bar two large 7
  56. >>> table = pivot_table(df, values='D', index=['A', 'B'],
  57. ... columns=['C'], aggfunc=np.sum)
  58. >>> table
  59. small large
  60. foo one 1 4
  61. two 6 NaN
  62. bar one 5 4
  63. two 6 7
  64. Returns
  65. -------
  66. table : DataFrame
  67. """
  68. index = _convert_by(index)
  69. columns = _convert_by(columns)
  70. if isinstance(aggfunc, list):
  71. pieces = []
  72. keys = []
  73. for func in aggfunc:
  74. table = pivot_table(data, values=values, index=index, columns=columns,
  75. fill_value=fill_value, aggfunc=func,
  76. margins=margins)
  77. pieces.append(table)
  78. keys.append(func.__name__)
  79. return concat(pieces, keys=keys, axis=1)
  80. keys = index + columns
  81. values_passed = values is not None
  82. if values_passed:
  83. if isinstance(values, (list, tuple)):
  84. values_multi = True
  85. else:
  86. values_multi = False
  87. values = [values]
  88. else:
  89. values = list(data.columns.drop(keys))
  90. if values_passed:
  91. to_filter = []
  92. for x in keys + values:
  93. if isinstance(x, Grouper):
  94. x = x.key
  95. try:
  96. if x in data:
  97. to_filter.append(x)
  98. except TypeError:
  99. pass
  100. if len(to_filter) < len(data.columns):
  101. data = data[to_filter]
  102. grouped = data.groupby(keys)
  103. agged = grouped.agg(aggfunc)
  104. table = agged
  105. if table.index.nlevels > 1:
  106. to_unstack = [agged.index.names[i]
  107. for i in range(len(index), len(keys))]
  108. table = agged.unstack(to_unstack)
  109. if not dropna:
  110. try:
  111. m = MultiIndex.from_arrays(cartesian_product(table.index.levels))
  112. table = table.reindex_axis(m, axis=0)
  113. except AttributeError:
  114. pass # it's a single level
  115. try:
  116. m = MultiIndex.from_arrays(cartesian_product(table.columns.levels))
  117. table = table.reindex_axis(m, axis=1)
  118. except AttributeError:
  119. pass # it's a single level or a series
  120. if isinstance(table, DataFrame):
  121. if isinstance(table.columns, MultiIndex):
  122. table = table.sortlevel(axis=1)
  123. else:
  124. table = table.sort_index(axis=1)
  125. if fill_value is not None:
  126. table = table.fillna(value=fill_value, downcast='infer')
  127. if margins:
  128. table = _add_margins(table, data, values, rows=index,
  129. cols=columns, aggfunc=aggfunc)
  130. # discard the top level
  131. if values_passed and not values_multi:
  132. table = table[values[0]]
  133. if len(index) == 0 and len(columns) > 0:
  134. table = table.T
  135. return table
  136. DataFrame.pivot_table = pivot_table
def _add_margins(table, data, values, rows, cols, aggfunc):
    """Append 'All' subtotal/grand-total margins to a pivot table.

    Parameters
    ----------
    table : DataFrame or Series
        The already-pivoted result to decorate with margins.
    data : DataFrame
        The original (pre-pivot) data, re-aggregated to compute the margins.
    values : list or falsy
        Value columns being aggregated (empty/None when no values were passed).
    rows, cols : list
        The pivot's index and column grouping keys.
    aggfunc : function or string
        Aggregation used for the margin cells.

    Returns
    -------
    DataFrame or Series with an 'All' row (and, via the marginal-result
    helpers, 'All' columns) appended.
    """
    grand_margin = _compute_grand_margin(data, values, aggfunc)

    if not values and isinstance(table, Series):
        # If there are no values and the table is a series, then there is only
        # one column in the data. Compute grand margin and return it.
        row_key = ('All',) + ('',) * (len(rows) - 1) if len(rows) > 1 else 'All'
        return table.append(Series({row_key: grand_margin['All']}))

    # Delegate to the appropriate helper; each either returns a finished
    # result directly (non-tuple) or a (result, margin_keys, row_margin)
    # triple for the grand-margin row to be filled in below.
    if values:
        marginal_result_set = _generate_marginal_results(table, data, values, rows, cols, aggfunc, grand_margin)
        if not isinstance(marginal_result_set, tuple):
            return marginal_result_set
        result, margin_keys, row_margin = marginal_result_set
    else:
        marginal_result_set = _generate_marginal_results_without_values(table, data, rows, cols, aggfunc)
        if not isinstance(marginal_result_set, tuple):
            return marginal_result_set
        result, margin_keys, row_margin = marginal_result_set

    # Label for the grand-total row: pad with '' so it fits a MultiIndex.
    key = ('All',) + ('',) * (len(rows) - 1) if len(rows) > 1 else 'All'

    row_margin = row_margin.reindex(result.columns)
    # populate grand margin
    for k in margin_keys:
        if isinstance(k, compat.string_types):
            row_margin[k] = grand_margin[k]
        else:
            # k is a tuple column key; its first element is the value column.
            row_margin[k] = grand_margin[k[0]]

    margin_dummy = DataFrame(row_margin, columns=[key]).T

    # Appending the margin row can clobber the index names; restore them.
    row_names = result.index.names
    result = result.append(margin_dummy)
    result.index.names = row_names

    return result
  167. def _compute_grand_margin(data, values, aggfunc):
  168. if values:
  169. grand_margin = {}
  170. for k, v in data[values].iteritems():
  171. try:
  172. if isinstance(aggfunc, compat.string_types):
  173. grand_margin[k] = getattr(v, aggfunc)()
  174. else:
  175. grand_margin[k] = aggfunc(v)
  176. except TypeError:
  177. pass
  178. return grand_margin
  179. else:
  180. return {'All': aggfunc(data.index)}
def _generate_marginal_results(table, data, values, rows, cols, aggfunc, grand_margin):
    """Build the column-wise 'All' margins when value columns are present.

    Returns either a finished result directly (when there are columns but no
    rows), or a (result, margin_keys, row_margin) tuple for _add_margins to
    attach the grand-total row.
    """
    if len(cols) > 0:
        # need to "interleave" the margins
        table_pieces = []
        margin_keys = []

        def _all_key(key):
            # Column label for a per-value 'All' margin, padded with '' to
            # match the depth of the column MultiIndex.
            return (key, 'All') + ('',) * (len(cols) - 1)

        if len(rows) > 0:
            # Margin per value column: re-aggregate grouped only by rows,
            # then splice an 'All' column into each value's column group.
            margin = data[rows + values].groupby(rows).agg(aggfunc)
            cat_axis = 1
            for key, piece in table.groupby(level=0, axis=cat_axis):
                all_key = _all_key(key)
                piece[all_key] = margin[key]
                table_pieces.append(piece)
                margin_keys.append(all_key)
        else:
            # No row keys: margins come from the grand margin and are
            # interleaved along the index (the table is column-major here).
            margin = grand_margin
            cat_axis = 0
            for key, piece in table.groupby(level=0, axis=cat_axis):
                all_key = _all_key(key)
                table_pieces.append(piece)
                table_pieces.append(Series(margin[key], index=[all_key]))
                margin_keys.append(all_key)

        result = concat(table_pieces, axis=cat_axis)

        if len(rows) == 0:
            # Nothing more to add; _add_margins returns this as-is.
            return result
    else:
        result = table
        margin_keys = table.columns

    if len(cols) > 0:
        # Grand-total row: aggregate by the column keys only, then stack so
        # it lines up with the pivoted column index.
        row_margin = data[cols + values].groupby(cols).agg(aggfunc)
        row_margin = row_margin.stack()

        # slight hack
        new_order = [len(cols)] + lrange(len(cols))
        row_margin.index = row_margin.index.reorder_levels(new_order)
    else:
        row_margin = Series(np.nan, index=result.columns)

    return result, margin_keys, row_margin
def _generate_marginal_results_without_values(table, data, rows, cols, aggfunc):
    """Build the column-wise 'All' margins when no value columns were passed.

    Mirrors _generate_marginal_results: returns either a finished result
    directly (columns but no rows), or a (result, margin_keys, row_margin)
    tuple for _add_margins to attach the grand-total row.
    """
    if len(cols) > 0:
        # need to "interleave" the margins
        margin_keys = []

        def _all_key():
            # Column label for the 'All' margin; a plain string for a flat
            # column index, otherwise padded with '' to the MultiIndex depth.
            if len(cols) == 1:
                return 'All'
            return ('All', ) + ('', ) * (len(cols) - 1)

        if len(rows) > 0:
            margin = data[rows].groupby(rows).apply(aggfunc)
            all_key = _all_key()
            table[all_key] = margin
            result = table
            margin_keys.append(all_key)
        else:
            margin = data.groupby(level=0, axis=0).apply(aggfunc)
            all_key = _all_key()
            table[all_key] = margin
            result = table
            margin_keys.append(all_key)
            # No row keys: result is complete; _add_margins returns it as-is.
            return result
    else:
        result = table
        margin_keys = table.columns

    if len(cols):
        # Grand-total row aggregated by the column keys.
        row_margin = data[cols].groupby(cols).apply(aggfunc)
    else:
        row_margin = Series(np.nan, index=result.columns)

    return result, margin_keys, row_margin
  248. def _convert_by(by):
  249. if by is None:
  250. by = []
  251. elif (np.isscalar(by) or isinstance(by, (np.ndarray, Series, Grouper))
  252. or hasattr(by, '__call__')):
  253. by = [by]
  254. else:
  255. by = list(by)
  256. return by
  257. @deprecate_kwarg(old_arg_name='cols', new_arg_name='columns')
  258. @deprecate_kwarg(old_arg_name='rows', new_arg_name='index')
  259. def crosstab(index, columns, values=None, rownames=None, colnames=None,
  260. aggfunc=None, margins=False, dropna=True):
  261. """
  262. Compute a simple cross-tabulation of two (or more) factors. By default
  263. computes a frequency table of the factors unless an array of values and an
  264. aggregation function are passed
  265. Parameters
  266. ----------
  267. index : array-like, Series, or list of arrays/Series
  268. Values to group by in the rows
  269. columns : array-like, Series, or list of arrays/Series
  270. Values to group by in the columns
  271. values : array-like, optional
  272. Array of values to aggregate according to the factors
  273. aggfunc : function, optional
  274. If no values array is passed, computes a frequency table
  275. rownames : sequence, default None
  276. If passed, must match number of row arrays passed
  277. colnames : sequence, default None
  278. If passed, must match number of column arrays passed
  279. margins : boolean, default False
  280. Add row/column margins (subtotals)
  281. dropna : boolean, default True
  282. Do not include columns whose entries are all NaN
  283. rows : kwarg only alias of index [deprecated]
  284. cols : kwarg only alias of columns [deprecated]
  285. Notes
  286. -----
  287. Any Series passed will have their name attributes used unless row or column
  288. names for the cross-tabulation are specified
  289. Examples
  290. --------
  291. >>> a
  292. array([foo, foo, foo, foo, bar, bar,
  293. bar, bar, foo, foo, foo], dtype=object)
  294. >>> b
  295. array([one, one, one, two, one, one,
  296. one, two, two, two, one], dtype=object)
  297. >>> c
  298. array([dull, dull, shiny, dull, dull, shiny,
  299. shiny, dull, shiny, shiny, shiny], dtype=object)
  300. >>> crosstab(a, [b, c], rownames=['a'], colnames=['b', 'c'])
  301. b one two
  302. c dull shiny dull shiny
  303. a
  304. bar 1 2 1 0
  305. foo 2 2 1 2
  306. Returns
  307. -------
  308. crosstab : DataFrame
  309. """
  310. index = com._maybe_make_list(index)
  311. columns = com._maybe_make_list(columns)
  312. rownames = _get_names(index, rownames, prefix='row')
  313. colnames = _get_names(columns, colnames, prefix='col')
  314. data = {}
  315. data.update(zip(rownames, index))
  316. data.update(zip(colnames, columns))
  317. if values is None:
  318. df = DataFrame(data)
  319. df['__dummy__'] = 0
  320. table = df.pivot_table('__dummy__', index=rownames, columns=colnames,
  321. aggfunc=len, margins=margins, dropna=dropna)
  322. return table.fillna(0).astype(np.int64)
  323. else:
  324. data['__dummy__'] = values
  325. df = DataFrame(data)
  326. table = df.pivot_table('__dummy__', index=rownames, columns=colnames,
  327. aggfunc=aggfunc, margins=margins, dropna=dropna)
  328. return table
  329. def _get_names(arrs, names, prefix='row'):
  330. if names is None:
  331. names = []
  332. for i, arr in enumerate(arrs):
  333. if isinstance(arr, Series) and arr.name is not None:
  334. names.append(arr.name)
  335. else:
  336. names.append('%s_%d' % (prefix, i))
  337. else:
  338. if len(names) != len(arrs):
  339. raise AssertionError('arrays and names must have the same length')
  340. if not isinstance(names, list):
  341. names = list(names)
  342. return names