PageRenderTime 481ms CodeModel.GetById 21ms RepoModel.GetById 0ms app.codeStats 0ms

/pandas/tools/pivot.py

https://github.com/smc77/pandas
Python | 291 lines | 252 code | 2 blank | 37 comment | 0 complexity | 6a5d3440c71474639096df242ca53d81 MD5 | raw file
  1. # pylint: disable=E1103
  2. from pandas import Series, DataFrame
  3. from pandas.tools.merge import concat
  4. import pandas.core.common as com
  5. import numpy as np
  6. import types
  7. def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
  8. fill_value=None, margins=False):
  9. """
  10. Create a spreadsheet-style pivot table as a DataFrame. The levels in the
  11. pivot table will be stored in MultiIndex objects (hierarchical indexes) on
  12. the index and columns of the result DataFrame
  13. Parameters
  14. ----------
  15. data : DataFrame
  16. values : column to aggregate, optional
  17. rows : list of column names or arrays to group on
  18. Keys to group on the x-axis of the pivot table
  19. cols : list of column names or arrays to group on
  20. Keys to group on the x-axis of the pivot table
  21. aggfunc : function, default numpy.mean, or list of functions
  22. If list of functions passed, the resulting pivot table will have
  23. hierarchical columns whose top level are the function names (inferred
  24. from the function objects themselves)
  25. fill_value : scalar, default None
  26. Value to replace missing values with
  27. margins : boolean, default False
  28. Add all row / columns (e.g. for subtotal / grand totals)
  29. Examples
  30. --------
  31. >>> df
  32. A B C D
  33. 0 foo one small 1
  34. 1 foo one large 2
  35. 2 foo one large 2
  36. 3 foo two small 3
  37. 4 foo two small 3
  38. 5 bar one large 4
  39. 6 bar one small 5
  40. 7 bar two small 6
  41. 8 bar two large 7
  42. >>> table = pivot_table(df, values='D', rows=['A', 'B'],
  43. ... cols=['C'], aggfunc=np.sum)
  44. >>> table
  45. small large
  46. foo one 1 4
  47. two 6 NaN
  48. bar one 5 4
  49. two 6 7
  50. Returns
  51. -------
  52. table : DataFrame
  53. """
  54. rows = _convert_by(rows)
  55. cols = _convert_by(cols)
  56. if isinstance(aggfunc, list):
  57. pieces = []
  58. keys = []
  59. for func in aggfunc:
  60. table = pivot_table(data, values=values, rows=rows, cols=cols,
  61. fill_value=fill_value, aggfunc=func,
  62. margins=margins)
  63. pieces.append(table)
  64. keys.append(func.__name__)
  65. return concat(pieces, keys=keys, axis=1)
  66. keys = rows + cols
  67. values_passed = values is not None
  68. if values_passed:
  69. if isinstance(values, (list, tuple)):
  70. values_multi = True
  71. else:
  72. values_multi = False
  73. values = [values]
  74. else:
  75. values = list(data.columns.drop(keys))
  76. if values_passed:
  77. to_filter = []
  78. for x in keys + values:
  79. try:
  80. if x in data:
  81. to_filter.append(x)
  82. except TypeError:
  83. pass
  84. if len(to_filter) < len(data.columns):
  85. data = data[to_filter]
  86. grouped = data.groupby(keys)
  87. agged = grouped.agg(aggfunc)
  88. table = agged
  89. for i in range(len(cols)):
  90. name = table.index.names[len(rows)]
  91. table = table.unstack(name)
  92. if fill_value is not None:
  93. table = table.fillna(value=fill_value)
  94. if margins:
  95. table = _add_margins(table, data, values, rows=rows,
  96. cols=cols, aggfunc=aggfunc)
  97. # discard the top level
  98. if values_passed and not values_multi:
  99. table = table[values[0]]
  100. return table
  101. DataFrame.pivot_table = pivot_table
def _add_margins(table, data, values, rows=None, cols=None, aggfunc=np.mean):
    """
    Add "All" margins (subtotals / grand totals) to an unstacked pivot table.

    Used by ``pivot_table`` when ``margins=True``.

    Parameters
    ----------
    table : pandas object
        The already-unstacked pivot table (presumably a Series when there
        are no row keys -- TODO confirm).
    data : DataFrame
        The data the table was aggregated from.
    values : list
        Value columns being aggregated.
    rows, cols : list
        Row and column grouping keys. Despite the ``None`` defaults, both are
        used with ``len()`` and list concatenation, so lists must be passed.
    aggfunc : function or str
        Aggregation. A string is called as a method name on each column for
        the grand totals.

    Returns
    -------
    The table with margins. When there are column keys, an "All" entry is
    interleaved after each top-level group. Unless there are column keys but
    no row keys (early return), a final row labelled 'All' (or
    ('All', '', ...) for several row levels) is appended, holding the column
    subtotals and grand totals.
    """
    # Grand total of each value column over all of `data`. Columns whose
    # aggregation raises TypeError are left out (deliberate best-effort).
    # NOTE(review): `basestring` exists only on Python 2.
    grand_margin = {}
    for k, v in data[values].iteritems():
        try:
            if isinstance(aggfunc, basestring):
                grand_margin[k] = getattr(v, aggfunc)()
            else:
                grand_margin[k] = aggfunc(v)
        except TypeError:
            pass

    if len(cols) > 0:
        # need to "interleave" the margins
        table_pieces = []
        margin_keys = []

        # Label of the margin for top-level group `key`:
        # (key, 'All', '', ..., ''), one entry per column level -- presumably
        # matching the (value, col_1, ..., col_n) column layout of `table`.
        def _all_key(key):
            return (key, 'All') + ('',) * (len(cols) - 1)

        if len(rows) > 0:
            # Per-row subtotals: aggregate each value column over the row
            # keys only, and add it as a column right after its group.
            margin = data[rows + values].groupby(rows).agg(aggfunc)
            cat_axis = 1
            for key, piece in table.groupby(level=0, axis=cat_axis):
                all_key = _all_key(key)
                piece[all_key] = margin[key]
                table_pieces.append(piece)
                margin_keys.append(all_key)
        else:
            # No row keys: each group's margin is the value column's grand
            # total, appended after the group along axis 0.
            margin = grand_margin
            cat_axis = 0
            for key, piece in table.groupby(level=0, axis=cat_axis):
                all_key = _all_key(key)
                table_pieces.append(piece)
                table_pieces.append(Series(margin[key], index=[all_key]))
                margin_keys.append(all_key)

        result = concat(table_pieces, axis=cat_axis)

        if len(rows) == 0:
            # nothing to put in a bottom margin row
            return result
    else:
        result = table
        margin_keys = table.columns

    if len(cols) > 0:
        # Column subtotals: aggregate over the column keys only, then stack
        # the value columns into the index.
        row_margin = data[cols + values].groupby(cols).agg(aggfunc)
        row_margin = row_margin.stack()

        # slight hack
        # Move the stacked level (last, at position len(cols)) to the front
        # so the index lines up with the (value, cols...) column labels.
        # NOTE(review): `list + range(...)` only works on Python 2.
        new_order = [len(cols)] + range(len(cols))
        row_margin.index = row_margin.index.reorder_levels(new_order)
    else:
        # No column keys: start from an all-NaN row; only the grand totals
        # are filled in below.
        row_margin = Series(np.nan, index=result.columns)

    # Index label of the margin row, padded to the number of row levels.
    key = ('All',) + ('',) * (len(rows) - 1) if len(rows) > 1 else 'All'

    row_margin = row_margin.reindex(result.columns)
    # populate grand margin
    # With column keys each margin key is a tuple whose first element is the
    # value column name. NOTE(review): raises KeyError if that value column
    # was skipped in grand_margin above (TypeError during aggregation).
    for k in margin_keys:
        if len(cols) > 0:
            row_margin[k] = grand_margin[k[0]]
        else:
            row_margin[k] = grand_margin[k]

    # One-row frame so the margin can be appended below the table.
    margin_dummy = DataFrame(row_margin, columns=[key]).T

    # Save the index names and reassign them after append (the code implies
    # append does not keep them).
    row_names = result.index.names
    result = result.append(margin_dummy)
    result.index.names = row_names

    return result
  161. def _convert_by(by):
  162. if by is None:
  163. by = []
  164. elif (np.isscalar(by) or isinstance(by, np.ndarray)
  165. or hasattr(by, '__call__')):
  166. by = [by]
  167. else:
  168. by = list(by)
  169. return by
  170. def crosstab(rows, cols, values=None, rownames=None, colnames=None,
  171. aggfunc=None, margins=False):
  172. """
  173. Compute a simple cross-tabulation of two (or more) factors. By default
  174. computes a frequency table of the factors unless an array of values and an
  175. aggregation function are passed
  176. Parameters
  177. ----------
  178. rows : array-like, Series, or list of arrays/Series
  179. Values to group by in the rows
  180. cols : array-like, Series, or list of arrays/Series
  181. Values to group by in the columns
  182. values : array-like, optional
  183. Array of values to aggregate according to the factors
  184. aggfunc : function, optional
  185. If no values array is passed, computes a frequency table
  186. rownames : sequence, default None
  187. If passed, must match number of row arrays passed
  188. colnames : sequence, default None
  189. If passed, must match number of column arrays passed
  190. margins : boolean, default False
  191. Add row/column margins (subtotals)
  192. Notes
  193. -----
  194. Any Series passed will have their name attributes used unless row or column
  195. names for the cross-tabulation are specified
  196. Examples
  197. --------
  198. >>> a
  199. array([foo, foo, foo, foo, bar, bar,
  200. bar, bar, foo, foo, foo], dtype=object)
  201. >>> b
  202. array([one, one, one, two, one, one,
  203. one, two, two, two, one], dtype=object)
  204. >>> c
  205. array([dull, dull, shiny, dull, dull, shiny,
  206. shiny, dull, shiny, shiny, shiny], dtype=object)
  207. >>> crosstab(a, [b, c], rownames=['a'], colnames=['b', 'c'])
  208. b one two
  209. c dull shiny dull shiny
  210. a
  211. bar 1 2 1 0
  212. foo 2 2 1 2
  213. Returns
  214. -------
  215. crosstab : DataFrame
  216. """
  217. rows = com._maybe_make_list(rows)
  218. cols = com._maybe_make_list(cols)
  219. rownames = _get_names(rows, rownames, prefix='row')
  220. colnames = _get_names(cols, colnames, prefix='col')
  221. data = {}
  222. data.update(zip(rownames, rows))
  223. data.update(zip(colnames, cols))
  224. if values is None:
  225. df = DataFrame(data)
  226. df['__dummy__'] = 0
  227. table = df.pivot_table('__dummy__', rows=rownames, cols=colnames,
  228. aggfunc=len, margins=margins)
  229. return table.fillna(0).astype(np.int64)
  230. else:
  231. data['__dummy__'] = values
  232. df = DataFrame(data)
  233. table = df.pivot_table('__dummy__', rows=rownames, cols=colnames,
  234. aggfunc=aggfunc, margins=margins)
  235. return table
  236. def _get_names(arrs, names, prefix='row'):
  237. if names is None:
  238. names = []
  239. for i, arr in enumerate(arrs):
  240. if isinstance(arr, Series) and arr.name is not None:
  241. names.append(arr.name)
  242. else:
  243. names.append('%s_%d' % (prefix, i))
  244. else:
  245. assert(len(names) == len(arrs))
  246. if not isinstance(names, list):
  247. names = list(names)
  248. return names