PageRenderTime 516ms CodeModel.GetById 28ms RepoModel.GetById 2ms app.codeStats 1ms

/pandas/tools/plotting.py

http://github.com/pydata/pandas
Python | 3075 lines | 2730 code | 172 blank | 173 comment | 229 complexity | fb838f97cf0225a6a627aac3491130f7 MD5 | raw file
Possible License(s): BSD-3-Clause, Apache-2.0
  1. # being a bit too dynamic
  2. # pylint: disable=E1101
  3. import datetime
  4. import warnings
  5. import re
  6. from collections import namedtuple
  7. from contextlib import contextmanager
  8. from distutils.version import LooseVersion
  9. import numpy as np
  10. from pandas.util.decorators import cache_readonly, deprecate_kwarg
  11. import pandas.core.common as com
  12. from pandas.core.generic import _shared_docs, _shared_doc_kwargs
  13. from pandas.core.index import MultiIndex
  14. from pandas.core.series import Series, remove_na
  15. from pandas.tseries.index import DatetimeIndex
  16. from pandas.tseries.period import PeriodIndex, Period
  17. from pandas.tseries.frequencies import get_period_alias, get_base_alias
  18. from pandas.tseries.offsets import DateOffset
  19. from pandas.compat import range, lrange, lmap, map, zip, string_types
  20. import pandas.compat as compat
  21. from pandas.util.decorators import Appender
  22. try: # mpl optional
  23. import pandas.tseries.converter as conv
  24. conv.register() # needs to override so set_xlim works with str/number
  25. except ImportError:
  26. pass
  27. # Extracted from https://gist.github.com/huyng/816622
  28. # this is the rcParams set when setting display.with_mpl_style
  29. # to True.
  30. mpl_stylesheet = {
  31. 'axes.axisbelow': True,
  32. 'axes.color_cycle': ['#348ABD',
  33. '#7A68A6',
  34. '#A60628',
  35. '#467821',
  36. '#CF4457',
  37. '#188487',
  38. '#E24A33'],
  39. 'axes.edgecolor': '#bcbcbc',
  40. 'axes.facecolor': '#eeeeee',
  41. 'axes.grid': True,
  42. 'axes.labelcolor': '#555555',
  43. 'axes.labelsize': 'large',
  44. 'axes.linewidth': 1.0,
  45. 'axes.titlesize': 'x-large',
  46. 'figure.edgecolor': 'white',
  47. 'figure.facecolor': 'white',
  48. 'figure.figsize': (6.0, 4.0),
  49. 'figure.subplot.hspace': 0.5,
  50. 'font.family': 'monospace',
  51. 'font.monospace': ['Andale Mono',
  52. 'Nimbus Mono L',
  53. 'Courier New',
  54. 'Courier',
  55. 'Fixed',
  56. 'Terminal',
  57. 'monospace'],
  58. 'font.size': 10,
  59. 'interactive': True,
  60. 'keymap.all_axes': ['a'],
  61. 'keymap.back': ['left', 'c', 'backspace'],
  62. 'keymap.forward': ['right', 'v'],
  63. 'keymap.fullscreen': ['f'],
  64. 'keymap.grid': ['g'],
  65. 'keymap.home': ['h', 'r', 'home'],
  66. 'keymap.pan': ['p'],
  67. 'keymap.save': ['s'],
  68. 'keymap.xscale': ['L', 'k'],
  69. 'keymap.yscale': ['l'],
  70. 'keymap.zoom': ['o'],
  71. 'legend.fancybox': True,
  72. 'lines.antialiased': True,
  73. 'lines.linewidth': 1.0,
  74. 'patch.antialiased': True,
  75. 'patch.edgecolor': '#EEEEEE',
  76. 'patch.facecolor': '#348ABD',
  77. 'patch.linewidth': 0.5,
  78. 'toolbar': 'toolbar2',
  79. 'xtick.color': '#555555',
  80. 'xtick.direction': 'in',
  81. 'xtick.major.pad': 6.0,
  82. 'xtick.major.size': 0.0,
  83. 'xtick.minor.pad': 6.0,
  84. 'xtick.minor.size': 0.0,
  85. 'ytick.color': '#555555',
  86. 'ytick.direction': 'in',
  87. 'ytick.major.pad': 6.0,
  88. 'ytick.major.size': 0.0,
  89. 'ytick.minor.pad': 6.0,
  90. 'ytick.minor.size': 0.0
  91. }
  92. def _get_standard_kind(kind):
  93. return {'density': 'kde'}.get(kind, kind)
  94. def _get_standard_colors(num_colors=None, colormap=None, color_type='default',
  95. color=None):
  96. import matplotlib.pyplot as plt
  97. if color is None and colormap is not None:
  98. if isinstance(colormap, compat.string_types):
  99. import matplotlib.cm as cm
  100. cmap = colormap
  101. colormap = cm.get_cmap(colormap)
  102. if colormap is None:
  103. raise ValueError("Colormap {0} is not recognized".format(cmap))
  104. colors = lmap(colormap, np.linspace(0, 1, num=num_colors))
  105. elif color is not None:
  106. if colormap is not None:
  107. warnings.warn("'color' and 'colormap' cannot be used "
  108. "simultaneously. Using 'color'")
  109. colors = color
  110. else:
  111. if color_type == 'default':
  112. colors = plt.rcParams.get('axes.color_cycle', list('bgrcmyk'))
  113. if isinstance(colors, compat.string_types):
  114. colors = list(colors)
  115. elif color_type == 'random':
  116. import random
  117. def random_color(column):
  118. random.seed(column)
  119. return [random.random() for _ in range(3)]
  120. colors = lmap(random_color, lrange(num_colors))
  121. else:
  122. raise NotImplementedError
  123. if len(colors) != num_colors:
  124. multiple = num_colors//len(colors) - 1
  125. mod = num_colors % len(colors)
  126. colors += multiple * colors
  127. colors += colors[:mod]
  128. return colors
  129. class _Options(dict):
  130. """
  131. Stores pandas plotting options.
  132. Allows for parameter aliasing so you can just use parameter names that are
  133. the same as the plot function parameters, but is stored in a canonical
  134. format that makes it easy to breakdown into groups later
  135. """
  136. # alias so the names are same as plotting method parameter names
  137. _ALIASES = {'x_compat': 'xaxis.compat'}
  138. _DEFAULT_KEYS = ['xaxis.compat']
  139. def __init__(self):
  140. self['xaxis.compat'] = False
  141. def __getitem__(self, key):
  142. key = self._get_canonical_key(key)
  143. if key not in self:
  144. raise ValueError('%s is not a valid pandas plotting option' % key)
  145. return super(_Options, self).__getitem__(key)
  146. def __setitem__(self, key, value):
  147. key = self._get_canonical_key(key)
  148. return super(_Options, self).__setitem__(key, value)
  149. def __delitem__(self, key):
  150. key = self._get_canonical_key(key)
  151. if key in self._DEFAULT_KEYS:
  152. raise ValueError('Cannot remove default parameter %s' % key)
  153. return super(_Options, self).__delitem__(key)
  154. def __contains__(self, key):
  155. key = self._get_canonical_key(key)
  156. return super(_Options, self).__contains__(key)
  157. def reset(self):
  158. """
  159. Reset the option store to its initial state
  160. Returns
  161. -------
  162. None
  163. """
  164. self.__init__()
  165. def _get_canonical_key(self, key):
  166. return self._ALIASES.get(key, key)
  167. @contextmanager
  168. def use(self, key, value):
  169. """
  170. Temporarily set a parameter value using the with statement.
  171. Aliasing allowed.
  172. """
  173. old_value = self[key]
  174. try:
  175. self[key] = value
  176. yield self
  177. finally:
  178. self[key] = old_value
  179. plot_params = _Options()
  180. def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,
  181. diagonal='hist', marker='.', density_kwds=None,
  182. hist_kwds=None, range_padding=0.05, **kwds):
  183. """
  184. Draw a matrix of scatter plots.
  185. Parameters
  186. ----------
  187. frame : DataFrame
  188. alpha : float, optional
  189. amount of transparency applied
  190. figsize : (float,float), optional
  191. a tuple (width, height) in inches
  192. ax : Matplotlib axis object, optional
  193. grid : bool, optional
  194. setting this to True will show the grid
  195. diagonal : {'hist', 'kde'}
  196. pick between 'kde' and 'hist' for
  197. either Kernel Density Estimation or Histogram
  198. plot in the diagonal
  199. marker : str, optional
  200. Matplotlib marker type, default '.'
  201. hist_kwds : other plotting keyword arguments
  202. To be passed to hist function
  203. density_kwds : other plotting keyword arguments
  204. To be passed to kernel density estimate plot
  205. range_padding : float, optional
  206. relative extension of axis range in x and y
  207. with respect to (x_max - x_min) or (y_max - y_min),
  208. default 0.05
  209. kwds : other plotting keyword arguments
  210. To be passed to scatter function
  211. Examples
  212. --------
  213. >>> df = DataFrame(np.random.randn(1000, 4), columns=['A','B','C','D'])
  214. >>> scatter_matrix(df, alpha=0.2)
  215. """
  216. import matplotlib.pyplot as plt
  217. from matplotlib.artist import setp
  218. df = frame._get_numeric_data()
  219. n = df.columns.size
  220. fig, axes = _subplots(nrows=n, ncols=n, figsize=figsize, ax=ax,
  221. squeeze=False)
  222. # no gaps between subplots
  223. fig.subplots_adjust(wspace=0, hspace=0)
  224. mask = com.notnull(df)
  225. marker = _get_marker_compat(marker)
  226. hist_kwds = hist_kwds or {}
  227. density_kwds = density_kwds or {}
  228. # workaround because `c='b'` is hardcoded in matplotlibs scatter method
  229. kwds.setdefault('c', plt.rcParams['patch.facecolor'])
  230. boundaries_list = []
  231. for a in df.columns:
  232. values = df[a].values[mask[a].values]
  233. rmin_, rmax_ = np.min(values), np.max(values)
  234. rdelta_ext = (rmax_ - rmin_) * range_padding / 2.
  235. boundaries_list.append((rmin_ - rdelta_ext, rmax_+ rdelta_ext))
  236. for i, a in zip(lrange(n), df.columns):
  237. for j, b in zip(lrange(n), df.columns):
  238. ax = axes[i, j]
  239. if i == j:
  240. values = df[a].values[mask[a].values]
  241. # Deal with the diagonal by drawing a histogram there.
  242. if diagonal == 'hist':
  243. ax.hist(values, **hist_kwds)
  244. elif diagonal in ('kde', 'density'):
  245. from scipy.stats import gaussian_kde
  246. y = values
  247. gkde = gaussian_kde(y)
  248. ind = np.linspace(y.min(), y.max(), 1000)
  249. ax.plot(ind, gkde.evaluate(ind), **density_kwds)
  250. ax.set_xlim(boundaries_list[i])
  251. else:
  252. common = (mask[a] & mask[b]).values
  253. ax.scatter(df[b][common], df[a][common],
  254. marker=marker, alpha=alpha, **kwds)
  255. ax.set_xlim(boundaries_list[j])
  256. ax.set_ylim(boundaries_list[i])
  257. ax.set_xlabel('')
  258. ax.set_ylabel('')
  259. _label_axis(ax, kind='x', label=b, position='bottom', rotate=True)
  260. _label_axis(ax, kind='y', label=a, position='left')
  261. if j!= 0:
  262. ax.yaxis.set_visible(False)
  263. if i != n-1:
  264. ax.xaxis.set_visible(False)
  265. for ax in axes.flat:
  266. setp(ax.get_xticklabels(), fontsize=8)
  267. setp(ax.get_yticklabels(), fontsize=8)
  268. return axes
  269. def _label_axis(ax, kind='x', label='', position='top',
  270. ticks=True, rotate=False):
  271. from matplotlib.artist import setp
  272. if kind == 'x':
  273. ax.set_xlabel(label, visible=True)
  274. ax.xaxis.set_visible(True)
  275. ax.xaxis.set_ticks_position(position)
  276. ax.xaxis.set_label_position(position)
  277. if rotate:
  278. setp(ax.get_xticklabels(), rotation=90)
  279. elif kind == 'y':
  280. ax.yaxis.set_visible(True)
  281. ax.set_ylabel(label, visible=True)
  282. # ax.set_ylabel(a)
  283. ax.yaxis.set_ticks_position(position)
  284. ax.yaxis.set_label_position(position)
  285. return
  286. def _gca():
  287. import matplotlib.pyplot as plt
  288. return plt.gca()
  289. def _gcf():
  290. import matplotlib.pyplot as plt
  291. return plt.gcf()
  292. def _get_marker_compat(marker):
  293. import matplotlib.lines as mlines
  294. import matplotlib as mpl
  295. if mpl.__version__ < '1.1.0' and marker == '.':
  296. return 'o'
  297. if marker not in mlines.lineMarkers:
  298. return 'o'
  299. return marker
  300. def radviz(frame, class_column, ax=None, color=None, colormap=None, **kwds):
  301. """RadViz - a multivariate data visualization algorithm
  302. Parameters:
  303. -----------
  304. frame: DataFrame
  305. class_column: str
  306. Column name containing class names
  307. ax: Matplotlib axis object, optional
  308. color: list or tuple, optional
  309. Colors to use for the different classes
  310. colormap : str or matplotlib colormap object, default None
  311. Colormap to select colors from. If string, load colormap with that name
  312. from matplotlib.
  313. kwds: keywords
  314. Options to pass to matplotlib scatter plotting method
  315. Returns:
  316. --------
  317. ax: Matplotlib axis object
  318. """
  319. import matplotlib.pyplot as plt
  320. import matplotlib.patches as patches
  321. def normalize(series):
  322. a = min(series)
  323. b = max(series)
  324. return (series - a) / (b - a)
  325. n = len(frame)
  326. classes = frame[class_column].drop_duplicates()
  327. class_col = frame[class_column]
  328. df = frame.drop(class_column, axis=1).apply(normalize)
  329. if ax is None:
  330. ax = plt.gca(xlim=[-1, 1], ylim=[-1, 1])
  331. to_plot = {}
  332. colors = _get_standard_colors(num_colors=len(classes), colormap=colormap,
  333. color_type='random', color=color)
  334. for kls in classes:
  335. to_plot[kls] = [[], []]
  336. n = len(frame.columns) - 1
  337. s = np.array([(np.cos(t), np.sin(t))
  338. for t in [2.0 * np.pi * (i / float(n))
  339. for i in range(n)]])
  340. for i in range(n):
  341. row = df.iloc[i].values
  342. row_ = np.repeat(np.expand_dims(row, axis=1), 2, axis=1)
  343. y = (s * row_).sum(axis=0) / row.sum()
  344. kls = class_col.iat[i]
  345. to_plot[kls][0].append(y[0])
  346. to_plot[kls][1].append(y[1])
  347. for i, kls in enumerate(classes):
  348. ax.scatter(to_plot[kls][0], to_plot[kls][1], color=colors[i],
  349. label=com.pprint_thing(kls), **kwds)
  350. ax.legend()
  351. ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor='none'))
  352. for xy, name in zip(s, df.columns):
  353. ax.add_patch(patches.Circle(xy, radius=0.025, facecolor='gray'))
  354. if xy[0] < 0.0 and xy[1] < 0.0:
  355. ax.text(xy[0] - 0.025, xy[1] - 0.025, name,
  356. ha='right', va='top', size='small')
  357. elif xy[0] < 0.0 and xy[1] >= 0.0:
  358. ax.text(xy[0] - 0.025, xy[1] + 0.025, name,
  359. ha='right', va='bottom', size='small')
  360. elif xy[0] >= 0.0 and xy[1] < 0.0:
  361. ax.text(xy[0] + 0.025, xy[1] - 0.025, name,
  362. ha='left', va='top', size='small')
  363. elif xy[0] >= 0.0 and xy[1] >= 0.0:
  364. ax.text(xy[0] + 0.025, xy[1] + 0.025, name,
  365. ha='left', va='bottom', size='small')
  366. ax.axis('equal')
  367. return ax
  368. @deprecate_kwarg(old_arg_name='data', new_arg_name='frame')
  369. def andrews_curves(frame, class_column, ax=None, samples=200, color=None,
  370. colormap=None, **kwds):
  371. """
  372. Parameters:
  373. -----------
  374. frame : DataFrame
  375. Data to be plotted, preferably normalized to (0.0, 1.0)
  376. class_column : Name of the column containing class names
  377. ax : matplotlib axes object, default None
  378. samples : Number of points to plot in each curve
  379. color: list or tuple, optional
  380. Colors to use for the different classes
  381. colormap : str or matplotlib colormap object, default None
  382. Colormap to select colors from. If string, load colormap with that name
  383. from matplotlib.
  384. kwds: keywords
  385. Options to pass to matplotlib plotting method
  386. Returns:
  387. --------
  388. ax: Matplotlib axis object
  389. """
  390. from math import sqrt, pi, sin, cos
  391. import matplotlib.pyplot as plt
  392. def function(amplitudes):
  393. def f(x):
  394. x1 = amplitudes[0]
  395. result = x1 / sqrt(2.0)
  396. harmonic = 1.0
  397. for x_even, x_odd in zip(amplitudes[1::2], amplitudes[2::2]):
  398. result += (x_even * sin(harmonic * x) +
  399. x_odd * cos(harmonic * x))
  400. harmonic += 1.0
  401. if len(amplitudes) % 2 != 0:
  402. result += amplitudes[-1] * sin(harmonic * x)
  403. return result
  404. return f
  405. n = len(frame)
  406. class_col = frame[class_column]
  407. classes = frame[class_column].drop_duplicates()
  408. df = frame.drop(class_column, axis=1)
  409. x = [-pi + 2.0 * pi * (t / float(samples)) for t in range(samples)]
  410. used_legends = set([])
  411. color_values = _get_standard_colors(num_colors=len(classes),
  412. colormap=colormap, color_type='random',
  413. color=color)
  414. colors = dict(zip(classes, color_values))
  415. if ax is None:
  416. ax = plt.gca(xlim=(-pi, pi))
  417. for i in range(n):
  418. row = df.iloc[i].values
  419. f = function(row)
  420. y = [f(t) for t in x]
  421. kls = class_col.iat[i]
  422. label = com.pprint_thing(kls)
  423. if label not in used_legends:
  424. used_legends.add(label)
  425. ax.plot(x, y, color=colors[kls], label=label, **kwds)
  426. else:
  427. ax.plot(x, y, color=colors[kls], **kwds)
  428. ax.legend(loc='upper right')
  429. ax.grid()
  430. return ax
  431. def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds):
  432. """Bootstrap plot.
  433. Parameters:
  434. -----------
  435. series: Time series
  436. fig: matplotlib figure object, optional
  437. size: number of data points to consider during each sampling
  438. samples: number of times the bootstrap procedure is performed
  439. kwds: optional keyword arguments for plotting commands, must be accepted
  440. by both hist and plot
  441. Returns:
  442. --------
  443. fig: matplotlib figure
  444. """
  445. import random
  446. import matplotlib.pyplot as plt
  447. # random.sample(ndarray, int) fails on python 3.3, sigh
  448. data = list(series.values)
  449. samplings = [random.sample(data, size) for _ in range(samples)]
  450. means = np.array([np.mean(sampling) for sampling in samplings])
  451. medians = np.array([np.median(sampling) for sampling in samplings])
  452. midranges = np.array([(min(sampling) + max(sampling)) * 0.5
  453. for sampling in samplings])
  454. if fig is None:
  455. fig = plt.figure()
  456. x = lrange(samples)
  457. axes = []
  458. ax1 = fig.add_subplot(2, 3, 1)
  459. ax1.set_xlabel("Sample")
  460. axes.append(ax1)
  461. ax1.plot(x, means, **kwds)
  462. ax2 = fig.add_subplot(2, 3, 2)
  463. ax2.set_xlabel("Sample")
  464. axes.append(ax2)
  465. ax2.plot(x, medians, **kwds)
  466. ax3 = fig.add_subplot(2, 3, 3)
  467. ax3.set_xlabel("Sample")
  468. axes.append(ax3)
  469. ax3.plot(x, midranges, **kwds)
  470. ax4 = fig.add_subplot(2, 3, 4)
  471. ax4.set_xlabel("Mean")
  472. axes.append(ax4)
  473. ax4.hist(means, **kwds)
  474. ax5 = fig.add_subplot(2, 3, 5)
  475. ax5.set_xlabel("Median")
  476. axes.append(ax5)
  477. ax5.hist(medians, **kwds)
  478. ax6 = fig.add_subplot(2, 3, 6)
  479. ax6.set_xlabel("Midrange")
  480. axes.append(ax6)
  481. ax6.hist(midranges, **kwds)
  482. for axis in axes:
  483. plt.setp(axis.get_xticklabels(), fontsize=8)
  484. plt.setp(axis.get_yticklabels(), fontsize=8)
  485. return fig
  486. @deprecate_kwarg(old_arg_name='colors', new_arg_name='color')
  487. @deprecate_kwarg(old_arg_name='data', new_arg_name='frame')
  488. def parallel_coordinates(frame, class_column, cols=None, ax=None, color=None,
  489. use_columns=False, xticks=None, colormap=None,
  490. **kwds):
  491. """Parallel coordinates plotting.
  492. Parameters
  493. ----------
  494. frame: DataFrame
  495. class_column: str
  496. Column name containing class names
  497. cols: list, optional
  498. A list of column names to use
  499. ax: matplotlib.axis, optional
  500. matplotlib axis object
  501. color: list or tuple, optional
  502. Colors to use for the different classes
  503. use_columns: bool, optional
  504. If true, columns will be used as xticks
  505. xticks: list or tuple, optional
  506. A list of values to use for xticks
  507. colormap: str or matplotlib colormap, default None
  508. Colormap to use for line colors.
  509. kwds: keywords
  510. Options to pass to matplotlib plotting method
  511. Returns
  512. -------
  513. ax: matplotlib axis object
  514. Examples
  515. --------
  516. >>> from pandas import read_csv
  517. >>> from pandas.tools.plotting import parallel_coordinates
  518. >>> from matplotlib import pyplot as plt
  519. >>> df = read_csv('https://raw.github.com/pydata/pandas/master/pandas/tests/data/iris.csv')
  520. >>> parallel_coordinates(df, 'Name', color=('#556270', '#4ECDC4', '#C7F464'))
  521. >>> plt.show()
  522. """
  523. import matplotlib.pyplot as plt
  524. n = len(frame)
  525. classes = frame[class_column].drop_duplicates()
  526. class_col = frame[class_column]
  527. if cols is None:
  528. df = frame.drop(class_column, axis=1)
  529. else:
  530. df = frame[cols]
  531. used_legends = set([])
  532. ncols = len(df.columns)
  533. # determine values to use for xticks
  534. if use_columns is True:
  535. if not np.all(np.isreal(list(df.columns))):
  536. raise ValueError('Columns must be numeric to be used as xticks')
  537. x = df.columns
  538. elif xticks is not None:
  539. if not np.all(np.isreal(xticks)):
  540. raise ValueError('xticks specified must be numeric')
  541. elif len(xticks) != ncols:
  542. raise ValueError('Length of xticks must match number of columns')
  543. x = xticks
  544. else:
  545. x = lrange(ncols)
  546. if ax is None:
  547. ax = plt.gca()
  548. color_values = _get_standard_colors(num_colors=len(classes),
  549. colormap=colormap, color_type='random',
  550. color=color)
  551. colors = dict(zip(classes, color_values))
  552. for i in range(n):
  553. y = df.iloc[i].values
  554. kls = class_col.iat[i]
  555. label = com.pprint_thing(kls)
  556. if label not in used_legends:
  557. used_legends.add(label)
  558. ax.plot(x, y, color=colors[kls], label=label, **kwds)
  559. else:
  560. ax.plot(x, y, color=colors[kls], **kwds)
  561. for i in x:
  562. ax.axvline(i, linewidth=1, color='black')
  563. ax.set_xticks(x)
  564. ax.set_xticklabels(df.columns)
  565. ax.set_xlim(x[0], x[-1])
  566. ax.legend(loc='upper right')
  567. ax.grid()
  568. return ax
  569. def lag_plot(series, lag=1, ax=None, **kwds):
  570. """Lag plot for time series.
  571. Parameters:
  572. -----------
  573. series: Time series
  574. lag: lag of the scatter plot, default 1
  575. ax: Matplotlib axis object, optional
  576. kwds: Matplotlib scatter method keyword arguments, optional
  577. Returns:
  578. --------
  579. ax: Matplotlib axis object
  580. """
  581. import matplotlib.pyplot as plt
  582. # workaround because `c='b'` is hardcoded in matplotlibs scatter method
  583. kwds.setdefault('c', plt.rcParams['patch.facecolor'])
  584. data = series.values
  585. y1 = data[:-lag]
  586. y2 = data[lag:]
  587. if ax is None:
  588. ax = plt.gca()
  589. ax.set_xlabel("y(t)")
  590. ax.set_ylabel("y(t + %s)" % lag)
  591. ax.scatter(y1, y2, **kwds)
  592. return ax
  593. def autocorrelation_plot(series, ax=None, **kwds):
  594. """Autocorrelation plot for time series.
  595. Parameters:
  596. -----------
  597. series: Time series
  598. ax: Matplotlib axis object, optional
  599. kwds : keywords
  600. Options to pass to matplotlib plotting method
  601. Returns:
  602. -----------
  603. ax: Matplotlib axis object
  604. """
  605. import matplotlib.pyplot as plt
  606. n = len(series)
  607. data = np.asarray(series)
  608. if ax is None:
  609. ax = plt.gca(xlim=(1, n), ylim=(-1.0, 1.0))
  610. mean = np.mean(data)
  611. c0 = np.sum((data - mean) ** 2) / float(n)
  612. def r(h):
  613. return ((data[:n - h] - mean) * (data[h:] - mean)).sum() / float(n) / c0
  614. x = np.arange(n) + 1
  615. y = lmap(r, x)
  616. z95 = 1.959963984540054
  617. z99 = 2.5758293035489004
  618. ax.axhline(y=z99 / np.sqrt(n), linestyle='--', color='grey')
  619. ax.axhline(y=z95 / np.sqrt(n), color='grey')
  620. ax.axhline(y=0.0, color='black')
  621. ax.axhline(y=-z95 / np.sqrt(n), color='grey')
  622. ax.axhline(y=-z99 / np.sqrt(n), linestyle='--', color='grey')
  623. ax.set_xlabel("Lag")
  624. ax.set_ylabel("Autocorrelation")
  625. ax.plot(x, y, **kwds)
  626. if 'label' in kwds:
  627. ax.legend()
  628. ax.grid()
  629. return ax
  630. class MPLPlot(object):
  631. """
  632. Base class for assembling a pandas plot using matplotlib
  633. Parameters
  634. ----------
  635. data :
  636. """
  637. _default_rot = 0
  638. _pop_attributes = ['label', 'style', 'logy', 'logx', 'loglog',
  639. 'mark_right']
  640. _attr_defaults = {'logy': False, 'logx': False, 'loglog': False,
  641. 'mark_right': True}
  642. def __init__(self, data, kind=None, by=None, subplots=False, sharex=True,
  643. sharey=False, use_index=True,
  644. figsize=None, grid=None, legend=True, rot=None,
  645. ax=None, fig=None, title=None, xlim=None, ylim=None,
  646. xticks=None, yticks=None,
  647. sort_columns=False, fontsize=None,
  648. secondary_y=False, colormap=None,
  649. table=False, **kwds):
  650. self.data = data
  651. self.by = by
  652. self.kind = kind
  653. self.sort_columns = sort_columns
  654. self.subplots = subplots
  655. self.sharex = sharex
  656. self.sharey = sharey
  657. self.figsize = figsize
  658. self.xticks = xticks
  659. self.yticks = yticks
  660. self.xlim = xlim
  661. self.ylim = ylim
  662. self.title = title
  663. self.use_index = use_index
  664. self.fontsize = fontsize
  665. self.rot = rot
  666. if grid is None:
  667. grid = False if secondary_y else True
  668. self.grid = grid
  669. self.legend = legend
  670. self.legend_handles = []
  671. self.legend_labels = []
  672. for attr in self._pop_attributes:
  673. value = kwds.pop(attr, self._attr_defaults.get(attr, None))
  674. setattr(self, attr, value)
  675. self.ax = ax
  676. self.fig = fig
  677. self.axes = None
  678. # parse errorbar input if given
  679. xerr = kwds.pop('xerr', None)
  680. yerr = kwds.pop('yerr', None)
  681. self.errors = {}
  682. for kw, err in zip(['xerr', 'yerr'], [xerr, yerr]):
  683. self.errors[kw] = self._parse_errorbars(kw, err)
  684. if not isinstance(secondary_y, (bool, tuple, list, np.ndarray)):
  685. secondary_y = [secondary_y]
  686. self.secondary_y = secondary_y
  687. # ugly TypeError if user passes matplotlib's `cmap` name.
  688. # Probably better to accept either.
  689. if 'cmap' in kwds and colormap:
  690. raise TypeError("Only specify one of `cmap` and `colormap`.")
  691. elif 'cmap' in kwds:
  692. self.colormap = kwds.pop('cmap')
  693. else:
  694. self.colormap = colormap
  695. self.table = table
  696. self.kwds = kwds
  697. self._validate_color_args()
  698. def _validate_color_args(self):
  699. from pandas import DataFrame
  700. if 'color' not in self.kwds and 'colors' in self.kwds:
  701. warnings.warn(("'colors' is being deprecated. Please use 'color'"
  702. "instead of 'colors'"))
  703. colors = self.kwds.pop('colors')
  704. self.kwds['color'] = colors
  705. if ('color' in self.kwds and
  706. (isinstance(self.data, Series) or
  707. isinstance(self.data, DataFrame) and len(self.data.columns) == 1)):
  708. # support series.plot(color='green')
  709. self.kwds['color'] = [self.kwds['color']]
  710. if ('color' in self.kwds or 'colors' in self.kwds) and \
  711. self.colormap is not None:
  712. warnings.warn("'color' and 'colormap' cannot be used "
  713. "simultaneously. Using 'color'")
  714. if 'color' in self.kwds and self.style is not None:
  715. # need only a single match
  716. if re.match('^[a-z]+?', self.style) is not None:
  717. raise ValueError("Cannot pass 'style' string with a color "
  718. "symbol and 'color' keyword argument. Please"
  719. " use one or the other or pass 'style' "
  720. "without a color symbol")
  721. def _iter_data(self, data=None, keep_index=False):
  722. if data is None:
  723. data = self.data
  724. from pandas.core.frame import DataFrame
  725. if isinstance(data, (Series, np.ndarray)):
  726. if keep_index is True:
  727. yield self.label, data
  728. else:
  729. yield self.label, np.asarray(data)
  730. elif isinstance(data, DataFrame):
  731. if self.sort_columns:
  732. columns = com._try_sort(data.columns)
  733. else:
  734. columns = data.columns
  735. for col in columns:
  736. # # is this right?
  737. # empty = df[col].count() == 0
  738. # values = df[col].values if not empty else np.zeros(len(df))
  739. if keep_index is True:
  740. yield col, data[col]
  741. else:
  742. yield col, data[col].values
  743. @property
  744. def nseries(self):
  745. if self.data.ndim == 1:
  746. return 1
  747. else:
  748. return self.data.shape[1]
  749. def draw(self):
  750. self.plt.draw_if_interactive()
  751. def generate(self):
  752. self._args_adjust()
  753. self._compute_plot_data()
  754. self._setup_subplots()
  755. self._make_plot()
  756. self._add_table()
  757. self._make_legend()
  758. self._post_plot_logic()
  759. self._adorn_subplots()
  760. def _args_adjust(self):
  761. pass
  762. def _maybe_right_yaxis(self, ax):
  763. if hasattr(ax, 'right_ax'):
  764. return ax.right_ax
  765. else:
  766. orig_ax, new_ax = ax, ax.twinx()
  767. new_ax._get_lines.color_cycle = orig_ax._get_lines.color_cycle
  768. orig_ax.right_ax, new_ax.left_ax = new_ax, orig_ax
  769. new_ax.right_ax = new_ax
  770. if len(orig_ax.get_lines()) == 0: # no data on left y
  771. orig_ax.get_yaxis().set_visible(False)
  772. return new_ax
  773. def _setup_subplots(self):
  774. if self.subplots:
  775. nrows, ncols = self._get_layout()
  776. fig, axes = _subplots(nrows=nrows, ncols=ncols,
  777. sharex=self.sharex, sharey=self.sharey,
  778. figsize=self.figsize, ax=self.ax)
  779. if not com.is_list_like(axes):
  780. axes = np.array([axes])
  781. else:
  782. if self.ax is None:
  783. fig = self.plt.figure(figsize=self.figsize)
  784. ax = fig.add_subplot(111)
  785. else:
  786. fig = self.ax.get_figure()
  787. if self.figsize is not None:
  788. fig.set_size_inches(self.figsize)
  789. ax = self.ax
  790. axes = [ax]
  791. if self.logx or self.loglog:
  792. [a.set_xscale('log') for a in axes]
  793. if self.logy or self.loglog:
  794. [a.set_yscale('log') for a in axes]
  795. self.fig = fig
  796. self.axes = axes
  797. def _get_layout(self):
  798. from pandas.core.frame import DataFrame
  799. if isinstance(self.data, DataFrame):
  800. return (len(self.data.columns), 1)
  801. else:
  802. return (1, 1)
  803. def _compute_plot_data(self):
  804. numeric_data = self.data.convert_objects()._get_numeric_data()
  805. try:
  806. is_empty = numeric_data.empty
  807. except AttributeError:
  808. is_empty = not len(numeric_data)
  809. # no empty frames or series allowed
  810. if is_empty:
  811. raise TypeError('Empty {0!r}: no numeric data to '
  812. 'plot'.format(numeric_data.__class__.__name__))
  813. self.data = numeric_data
  814. def _make_plot(self):
  815. raise NotImplementedError
  816. def _add_table(self):
  817. if self.table is False:
  818. return
  819. elif self.table is True:
  820. from pandas.core.frame import DataFrame
  821. if isinstance(self.data, Series):
  822. data = DataFrame(self.data, columns=[self.data.name])
  823. elif isinstance(self.data, DataFrame):
  824. data = self.data
  825. data = data.transpose()
  826. else:
  827. data = self.table
  828. ax = self._get_ax(0)
  829. table(ax, data)
  830. def _post_plot_logic(self):
  831. pass
  832. def _adorn_subplots(self):
  833. to_adorn = self.axes
  834. # todo: sharex, sharey handling?
  835. for ax in to_adorn:
  836. if self.yticks is not None:
  837. ax.set_yticks(self.yticks)
  838. if self.xticks is not None:
  839. ax.set_xticks(self.xticks)
  840. if self.ylim is not None:
  841. ax.set_ylim(self.ylim)
  842. if self.xlim is not None:
  843. ax.set_xlim(self.xlim)
  844. ax.grid(self.grid)
  845. if self.title:
  846. if self.subplots:
  847. self.fig.suptitle(self.title)
  848. else:
  849. self.axes[0].set_title(self.title)
  850. if self._need_to_set_index:
  851. labels = [com.pprint_thing(key) for key in self.data.index]
  852. labels = dict(zip(range(len(self.data.index)), labels))
  853. for ax_ in self.axes:
  854. # ax_.set_xticks(self.xticks)
  855. xticklabels = [labels.get(x, '') for x in ax_.get_xticks()]
  856. ax_.set_xticklabels(xticklabels, rotation=self.rot)
  857. @property
  858. def legend_title(self):
  859. if hasattr(self.data, 'columns'):
  860. if not isinstance(self.data.columns, MultiIndex):
  861. name = self.data.columns.name
  862. if name is not None:
  863. name = com.pprint_thing(name)
  864. return name
  865. else:
  866. stringified = map(com.pprint_thing,
  867. self.data.columns.names)
  868. return ','.join(stringified)
  869. else:
  870. return None
  871. def _add_legend_handle(self, handle, label, index=None):
  872. if not label is None:
  873. if self.mark_right and index is not None:
  874. if self.on_right(index):
  875. label = label + ' (right)'
  876. self.legend_handles.append(handle)
  877. self.legend_labels.append(label)
  878. def _make_legend(self):
  879. ax, leg = self._get_ax_legend(self.axes[0])
  880. handles = []
  881. labels = []
  882. title = ''
  883. if not self.subplots:
  884. if not leg is None:
  885. title = leg.get_title().get_text()
  886. handles = leg.legendHandles
  887. labels = [x.get_text() for x in leg.get_texts()]
  888. if self.legend:
  889. if self.legend == 'reverse':
  890. self.legend_handles = reversed(self.legend_handles)
  891. self.legend_labels = reversed(self.legend_labels)
  892. handles += self.legend_handles
  893. labels += self.legend_labels
  894. if not self.legend_title is None:
  895. title = self.legend_title
  896. if len(handles) > 0:
  897. ax.legend(handles, labels, loc='best', title=title)
  898. elif self.subplots and self.legend:
  899. for ax in self.axes:
  900. ax.legend(loc='best')
  901. def _get_ax_legend(self, ax):
  902. leg = ax.get_legend()
  903. other_ax = (getattr(ax, 'right_ax', None) or
  904. getattr(ax, 'left_ax', None))
  905. other_leg = None
  906. if other_ax is not None:
  907. other_leg = other_ax.get_legend()
  908. if leg is None and other_leg is not None:
  909. leg = other_leg
  910. ax = other_ax
  911. return ax, leg
  912. @cache_readonly
  913. def plt(self):
  914. import matplotlib.pyplot as plt
  915. return plt
  916. _need_to_set_index = False
  917. def _get_xticks(self, convert_period=False):
  918. index = self.data.index
  919. is_datetype = index.inferred_type in ('datetime', 'date',
  920. 'datetime64', 'time')
  921. if self.use_index:
  922. if convert_period and isinstance(index, PeriodIndex):
  923. self.data = self.data.reindex(index=index.order())
  924. x = self.data.index.to_timestamp()._mpl_repr()
  925. elif index.is_numeric():
  926. """
  927. Matplotlib supports numeric values or datetime objects as
  928. xaxis values. Taking LBYL approach here, by the time
  929. matplotlib raises exception when using non numeric/datetime
  930. values for xaxis, several actions are already taken by plt.
  931. """
  932. x = index._mpl_repr()
  933. elif is_datetype:
  934. self.data = self.data.sort_index()
  935. x = self.data.index._mpl_repr()
  936. else:
  937. self._need_to_set_index = True
  938. x = lrange(len(index))
  939. else:
  940. x = lrange(len(index))
  941. return x
  942. def _is_datetype(self):
  943. index = self.data.index
  944. return (isinstance(index, (PeriodIndex, DatetimeIndex)) or
  945. index.inferred_type in ('datetime', 'date', 'datetime64',
  946. 'time'))
  947. def _get_plot_function(self):
  948. '''
  949. Returns the matplotlib plotting function (plot or errorbar) based on
  950. the presence of errorbar keywords.
  951. '''
  952. if all(e is None for e in self.errors.values()):
  953. plotf = self.plt.Axes.plot
  954. else:
  955. plotf = self.plt.Axes.errorbar
  956. return plotf
  957. def _get_index_name(self):
  958. if isinstance(self.data.index, MultiIndex):
  959. name = self.data.index.names
  960. if any(x is not None for x in name):
  961. name = ','.join([com.pprint_thing(x) for x in name])
  962. else:
  963. name = None
  964. else:
  965. name = self.data.index.name
  966. if name is not None:
  967. name = com.pprint_thing(name)
  968. return name
  969. def _get_ax(self, i):
  970. # get the twinx ax if appropriate
  971. if self.subplots:
  972. ax = self.axes[i]
  973. if self.on_right(i):
  974. ax = self._maybe_right_yaxis(ax)
  975. self.axes[i] = ax
  976. else:
  977. ax = self.axes[0]
  978. if self.on_right(i):
  979. ax = self._maybe_right_yaxis(ax)
  980. sec_true = isinstance(self.secondary_y, bool) and self.secondary_y
  981. all_sec = (com.is_list_like(self.secondary_y) and
  982. len(self.secondary_y) == self.nseries)
  983. if sec_true or all_sec:
  984. self.axes[0] = ax
  985. ax.get_yaxis().set_visible(True)
  986. return ax
  987. def on_right(self, i):
  988. from pandas.core.frame import DataFrame
  989. if isinstance(self.secondary_y, bool):
  990. return self.secondary_y
  991. if (isinstance(self.data, DataFrame) and
  992. isinstance(self.secondary_y, (tuple, list, np.ndarray))):
  993. return self.data.columns[i] in self.secondary_y
  994. def _get_style(self, i, col_name):
  995. style = ''
  996. if self.subplots:
  997. style = 'k'
  998. if self.style is not None:
  999. if isinstance(self.style, list):
  1000. try:
  1001. style = self.style[i]
  1002. except IndexError:
  1003. pass
  1004. elif isinstance(self.style, dict):
  1005. style = self.style.get(col_name, style)
  1006. else:
  1007. style = self.style
  1008. return style or None
  1009. def _get_colors(self, num_colors=None, color_kwds='color'):
  1010. from pandas.core.frame import DataFrame
  1011. if num_colors is None:
  1012. if isinstance(self.data, DataFrame):
  1013. num_colors = len(self.data.columns)
  1014. else:
  1015. num_colors = 1
  1016. return _get_standard_colors(num_colors=num_colors,
  1017. colormap=self.colormap,
  1018. color=self.kwds.get(color_kwds))
  1019. def _maybe_add_color(self, colors, kwds, style, i):
  1020. has_color = 'color' in kwds or self.colormap is not None
  1021. if has_color and (style is None or re.match('[a-z]+', style) is None):
  1022. kwds['color'] = colors[i % len(colors)]
    def _parse_errorbars(self, label, err):
        '''
        Look for error keyword arguments and return the actual errorbar data
        or return the error DataFrame/dict

        Error bars can be specified in several ways:
            Series: the user provides a pandas.Series object of the same
                    length as the data
            ndarray: provides a np.ndarray of the same length as the data
            DataFrame/dict: error values are paired with keys matching the
                    key in the plotted DataFrame
            str: the name of the column within the plotted DataFrame

        Parameters
        ----------
        label : str
            Only used in the "No valid ... detected" error message;
            presumably the keyword name ('xerr'/'yerr') -- confirm at caller.
        err : None, DataFrame, dict, Series, str, list-like or number

        Returns
        -------
        None when ``err`` is None; a dict unchanged; a DataFrame reindexed
        to ``self.data.index``; otherwise an ndarray. Series, str and number
        input, and single-row list-like input, are tiled to
        ``self.nseries`` rows.

        Raises
        ------
        ValueError
            If a 3-D ``err`` is not shaped ``(nseries, 2, len(data))``, or
            ``err`` is none of the supported kinds.

        Notes
        -----
        A str ``err`` drops that column from ``self.data`` as a side effect.
        '''

        if err is None:
            return None

        from pandas import DataFrame, Series

        def match_labels(data, e):
            # align the error values to the plotted data's index
            e = e.reindex_axis(data.index)
            return e

        # key-matched DataFrame
        if isinstance(err, DataFrame):

            err = match_labels(self.data, err)
        # key-matched dict
        elif isinstance(err, dict):
            # returned as-is
            pass

        # Series of error values
        elif isinstance(err, Series):
            # broadcast error series across data
            err = match_labels(self.data, err)
            err = np.atleast_2d(err)
            err = np.tile(err, (self.nseries, 1))

        # errors are a column in the dataframe
        elif isinstance(err, string_types):
            evalues = self.data[err].values
            # drop the error column so it is not plotted as data
            self.data = self.data[self.data.columns.drop(err)]
            err = np.atleast_2d(evalues)
            err = np.tile(err, (self.nseries, 1))

        elif com.is_list_like(err):
            if com.is_iterator(err):
                # materialize iterators before converting to an array
                err = np.atleast_2d(list(err))
            else:
                # raw error values
                err = np.atleast_2d(err)

            err_shape = err.shape

            # asymmetrical error bars
            if err.ndim == 3:
                if (err_shape[0] != self.nseries) or \
                        (err_shape[1] != 2) or \
                        (err_shape[2] != len(self.data)):
                    msg = "Asymmetrical error bars should be provided " + \
                        "with the shape (%u, 2, %u)" % \
                        (self.nseries, len(self.data))
                    raise ValueError(msg)

            # broadcast errors to each data series
            if len(err) == 1:
                err = np.tile(err, (self.nseries, 1))

        elif com.is_number(err):
            # scalar: same error for every point of every series
            err = np.tile([err], (self.nseries, len(self.data)))

        else:
            msg = "No valid %s detected" % label
            raise ValueError(msg)

        return err
  1084. def _get_errorbars(self, label=None, index=None, xerr=True, yerr=True):
  1085. from pandas import DataFrame
  1086. errors = {}
  1087. for kw, flag in zip(['xerr', 'yerr'], [xerr, yerr]):
  1088. if flag:
  1089. err = self.errors[kw]
  1090. # user provided label-matched dataframe of errors
  1091. if isinstance(err, (DataFrame, dict)):
  1092. if label is not None and label in err.keys():
  1093. err = err[label]
  1094. else:
  1095. err = None
  1096. elif index is not None and err is not None:
  1097. err = err[index]
  1098. if err is not None:
  1099. errors[kw] = err
  1100. return errors
  1101. class KdePlot(MPLPlot):
  1102. def __init__(self, data, bw_method=None, ind=None, **kwargs):
  1103. MPLPlot.__init__(self, data, **kwargs)
  1104. self.bw_method=bw_method
  1105. self.ind=ind
  1106. def _make_plot(self):
  1107. from scipy.stats import gaussian_kde
  1108. from scipy import __version__ as spv
  1109. from distutils.version import LooseVersion
  1110. plotf = self.plt.Axes.plot
  1111. colors = self._get_colors()
  1112. for i, (label, y) in enumerate(self._iter_data()):
  1113. ax = self._get_ax(i)
  1114. style = self._get_style(i, label)
  1115. label = com.pprint_thing(label)
  1116. if LooseVersion(spv) >= '0.11.0':
  1117. gkde = gaussian_kde(y, bw_method=self.bw_method)
  1118. else:
  1119. gkde = gaussian_kde(y)
  1120. if self.bw_method is not None:
  1121. msg = ('bw_method was added in Scipy 0.11.0.' +
  1122. ' Scipy version in use is %s.' % spv)
  1123. warnings.warn(msg)
  1124. sample_range = max(y) - min(y)
  1125. if self.ind is None:
  1126. ind = np.linspace(min(y) - 0.5 * sample_range,
  1127. max(y) + 0.5 * sample_range, 1000)
  1128. else:
  1129. ind = self.ind
  1130. ax.set_ylabel("Density")
  1131. y = gkde.evaluate(ind)
  1132. kwds = self.kwds.copy()
  1133. kwds['label'] = label
  1134. self._maybe_add_color(colors, kwds, style, i)
  1135. if style is None:
  1136. args = (ax, ind, y)
  1137. else:
  1138. args = (ax, ind, y, style)
  1139. newlines = plotf(*args, **kwds)
  1140. self._add_legend_handle(newlines[0], label)
  1141. class ScatterPlot(MPLPlot):
  1142. def __init__(self, data, x, y, **kwargs):
  1143. MPLPlot.__init__(self, data, **kwargs)
  1144. self.kwds.setdefault('c', self.plt.rcParams['patch.facecolor'])
  1145. if x is None or y is None:
  1146. raise ValueError( 'scatter requires and x and y column')
  1147. if com.is_integer(x) and not self.data.columns.holds_integer():
  1148. x = self.data.columns[x]
  1149. if com.is_integer(y) and not self.data.columns.holds_integer():
  1150. y = self.data.columns[y]
  1151. self.x = x
  1152. self.y = y
  1153. def _get_layout(self):
  1154. return (1, 1)
  1155. def _make_plot(self):
  1156. x, y, data = self.x, self.y, self.data
  1157. ax = self.axes[0]
  1158. if self.legend and hasattr(self, 'label'):
  1159. label = self.label
  1160. else:
  1161. label = None
  1162. scatter = ax.scatter(data[x].values, data[y].values, label=label,
  1163. **self.kwds)
  1164. self._add_legend_handle(scatter, label)
  1165. errors_x = self._get_errorbars(label=x, index=0, yerr=False)
  1166. errors_y = self._get_errorbars(label=y, index=1, xerr=False)
  1167. if len(errors_x) > 0 or len(errors_y) > 0:
  1168. err_kwds = dict(errors_x, **errors_y)
  1169. if 'color' in self.kwds:
  1170. err_kwds['color'] = self.kwds['color']
  1171. ax.errorbar(data[x].values, data[y].values, linestyle='none', **err_kwds)
  1172. def _post_plot_logic(self):
  1173. ax = self.axes[0]
  1174. x, y = self.x, self.y
  1175. ax.set_ylabel(com.pprint_thing(y))
  1176. ax.set_xlabel(com.pprint_thing(x))
  1177. class HexBinPlot(MPLPlot):
  1178. def __init__(self, data, x, y, C=None, **kwargs):
  1179. MPLPlot.__init__(self, data, **kwargs)
  1180. if x is None or y is None:
  1181. raise ValueError('hexbin requires and x and y column')
  1182. if com.is_integer(x) and not self.data.columns.holds_integer():
  1183. x = self.data.columns[x]
  1184. if com.is_integer(y) and not self.data.columns.holds_integer():
  1185. y = self.data.columns[y]
  1186. if com.is_integer(C) and not self.data.columns.holds_integer():
  1187. C = self.data.columns[C]
  1188. self.x = x
  1189. self.y = y
  1190. self.C = C
  1191. def _get_layout(self):
  1192. return (1, 1)
  1193. def _make_plot(self):
  1194. import matplotlib.pyplot as plt
  1195. x, y, data, C = self.x, self.y, self.data, self.C
  1196. ax = self.axes[0]
  1197. # pandas uses colormap, matplotlib uses cmap.
  1198. cmap = self.colormap or 'BuGn'
  1199. cmap = plt.cm.get_cmap(cmap)
  1200. cb = self.kwds.pop('colorbar', True)
  1201. if C is None:
  1202. c_values = None
  1203. else:
  1204. c_values = data[C].values
  1205. ax.hexbin(data[x].values, data[y].values, C=c_values, cmap=cmap,
  1206. **self.kwds)
  1207. if cb:
  1208. img = ax.collections[0]
  1209. self.fig.colorbar(img, ax=ax)
  1210. def _post_plot_logic(self):
  1211. ax = self.axes[0]
  1212. x, y = self.x, self.y
  1213. ax.set_ylabel(com.pprint_thing(y))
  1214. ax.set_xlabel(com.pprint_thing(x))
  1215. class LinePlot(MPLPlot):
  1216. def __init__(self, data, **kwargs):
  1217. self.stacked = kwargs.pop('stacked', False)
  1218. if self.stacked:
  1219. data = data.fillna(value=0)
  1220. MPLPlot.__init__(self, data, **kwargs)
  1221. self.x_compat = plot_params['x_compat']
  1222. if 'x_compat' in self.kwds:
  1223. self.x_compat = bool(self.kwds.pop('x_compat'))
  1224. def _index_freq(self):
  1225. from pandas.core.frame import DataFrame
  1226. if isinstance(self.data, (Series, DataFrame)):
  1227. freq = getattr(self.data.index, 'freq', None)
  1228. if freq is None:
  1229. freq = getattr(self.data.index, 'inferred_freq', None)
  1230. if freq == 'B':
  1231. weekdays = np.unique(self.data.index.dayofweek)
  1232. if (5 in weekdays) or (6 in weekdays):
  1233. freq = None
  1234. return freq
  1235. def _is_dynamic_freq(self, freq):
  1236. if isinstance(freq, DateOffset):
  1237. freq = freq.rule_code
  1238. else:
  1239. freq = get_base_alias(freq)
  1240. freq = get_period_alias(freq)
  1241. return freq is not None and self._no_base(freq)
  1242. def _no_base(self, freq):
  1243. # hack this for 0.10.1, creating more technical debt...sigh
  1244. from pandas.core.frame import DataFrame
  1245. if (isinstance(self.data, (Series, DataFrame))
  1246. and isinstance(self.data.index, DatetimeIndex)):
  1247. import pandas.tseries.frequencies as freqmod
  1248. base = freqmod.get_freq(freq)
  1249. x = self.data.index
  1250. if (base <= freqmod.FreqGroup.FR_DAY):
  1251. return x[:1].is_normalized
  1252. return Period(x[0], freq).to_timestamp(tz=x.tz) == x[0]
  1253. return True
  1254. def _use_dynamic_x(self):
  1255. freq = self._index_freq()
  1256. ax = self._get_ax(0)
  1257. ax_freq = getattr(ax, 'freq', None)
  1258. if freq is None: # convert irregular if axes has freq info
  1259. freq = ax_freq
  1260. else: # do not use tsplot if irregular was plotted first
  1261. if (ax_freq is None) and (len(ax.get_lines()) > 0):
  1262. return False
  1263. return (freq is not None) and self._is_dynamic_freq(freq)
  1264. def _is_ts_plot(self):
  1265. # this is slightly deceptive
  1266. return not self.x_compat and self.use_index and self._use_dynamic_x()
  1267. def _make_plot(self):
  1268. self._pos_prior = np.zeros(len(self.data))
  1269. self._neg_prior = np.zeros(len(self.data))
  1270. if self._is_ts_plot():
  1271. data = self._maybe_convert_index(self.data)
  1272. self._make_ts_plot(data)
  1273. else:
  1274. x = self._get_xticks(convert_period=True)
  1275. plotf = self._get_plot_function()
  1276. colors = self._get_colors()
  1277. for i, (label, y) in enumerate(self._iter_data()):
  1278. ax = self._get_ax(i)
  1279. style = self._get_style(i, label)
  1280. kwds = self.kwds.copy()
  1281. self._maybe_add_color(colors, kwds, style, i)
  1282. errors = self._get_errorbars(label=label, index=i)
  1283. kwds = dict(kwds, **errors)
  1284. label = com.pprint_thing(label) # .encode('utf-8')
  1285. kwds['label'] = label
  1286. y_values = self._get_stacked_values(y, label)
  1287. if not self.stacked:
  1288. mask = com.isnull(y_values)
  1289. if mask.any():
  1290. y_values = np.ma.array(y_values)
  1291. y_values = np.ma.masked_where(mask, y_values)
  1292. # prevent style kwarg from going to errorbar, where it is unsupported
  1293. if style is not None and plotf.__name__ != 'errorbar':
  1294. args = (ax, x, y_values, style)
  1295. else:
  1296. args = (ax, x, y_values)
  1297. newlines = plotf(*args, **kwds)
  1298. self._add_legend_handle(newlines[0], label, index=i)
  1299. if self.stacked and not self.subplots:
  1300. if (y >= 0).all():
  1301. self._pos_prior += y
  1302. elif (y <= 0).all():
  1303. self._neg_prior += y
  1304. lines = _get_all_lines(ax)
  1305. left, right = _get_xlim(lines)
  1306. ax.set_xlim(left, right)
  1307. def _get_stacked_values(self, y, label):
  1308. if self.stacked:
  1309. if (y >= 0).all():
  1310. return self._pos_prior + y
  1311. elif (y <= 0).all():
  1312. return self._neg_prior + y
  1313. else:
  1314. raise ValueError('When stacked is True, each column must be either all positive or negative.'
  1315. '{0} contains both positive and negative values'.format(label))
  1316. else:
  1317. return y
  1318. def _get_ts_plot_function(self):
  1319. from pandas.tseries.plotting import tsplot
  1320. plotf = self._get_plot_function()
  1321. def _plot(data, ax, label, style, **kwds):
  1322. # errorbar function does not support style argument
  1323. if plotf.__name__ == 'errorbar':
  1324. lines = tsplot(data, plotf, ax=ax, label=label,
  1325. **kwds)
  1326. return lines
  1327. else:
  1328. lines = tsplot(data, plotf, ax=ax, label=label,
  1329. style=style, **kwds)
  1330. return lines
  1331. return _plot
  1332. def _make_ts_plot(self, data, **kwargs):
  1333. colors = self._get_colors()
  1334. plotf = self._get_ts_plot_function()
  1335. it = self._iter_data(data=data, keep_index=True)
  1336. for i, (label, y) in enumerate(it):
  1337. ax = self._get_ax(i)
  1338. style = self._get_style(i, label)
  1339. kwds = self.kwds.copy()
  1340. self._maybe_add_color(colors, kwds, style, i)
  1341. errors = self._get_errorbars(label=label, index=i, xerr=False)
  1342. kwds = dict(kwds, **errors)
  1343. label = com.pprint_thing(label)
  1344. y_values = self._get_stacked_values(y, label)
  1345. newlines = plotf(y_values, ax, label, style, **kwds)
  1346. self._add_legend_handle(newlines[0], label, index=i)
  1347. if self.stacked and not self.subplots:
  1348. if (y >= 0).all():
  1349. self._pos_prior += y
  1350. elif (y <= 0).all():
  1351. self._neg_prior += y
  1352. def _maybe_convert_index(self, data):
  1353. # tsplot converts automatically, but don't want to convert index
  1354. # over and over for DataFrames
  1355. from pandas.core.frame import DataFrame
  1356. if (isinstance(data.index, DatetimeIndex) and
  1357. isinstance(data, DataFrame)):
  1358. freq = getattr(data.index, 'freq', None)
  1359. if freq is None:
  1360. freq = getattr(data.index, 'inferred_freq', None)
  1361. if isinstance(freq, DateOffset):
  1362. freq = freq.rule_code
  1363. freq = get_base_alias(freq)
  1364. freq = get_period_alias(freq)
  1365. if freq is None:
  1366. ax = self._get_ax(0)
  1367. freq = getattr(ax, 'freq', None)
  1368. if freq is None:
  1369. raise ValueError('Could not get frequency alias for plotting')
  1370. data = DataFrame(data.values,
  1371. index=data.index.to_period(freq=freq),
  1372. columns=data.columns)
  1373. return data
  1374. def _post_plot_logic(self):
  1375. df = self.data
  1376. condition = (not self._use_dynamic_x()
  1377. and df.index.is_all_dates
  1378. and not self.subplots
  1379. or (self.subplots and self.sharex))
  1380. index_name = self._get_index_name()
  1381. rot = 30
  1382. if self.rot is not None:
  1383. rot = self.rot
  1384. for ax in self.axes:
  1385. if condition:
  1386. format_date_labels(ax, rot=rot)
  1387. elif self.rot is not None:
  1388. for l in ax.get_xticklabels():
  1389. l.set_rotation(self.rot)
  1390. if index_name is not None:
  1391. ax.set_xlabel(index_name)
  1392. class AreaPlot(LinePlot):
  1393. def __init__(self, data, **kwargs):
  1394. kwargs.setdefault('stacked', True)
  1395. data = data.fillna(value=0)
  1396. LinePlot.__init__(self, data, **kwargs)
  1397. if not self.stacked:
  1398. # use smaller alpha to distinguish overlap
  1399. self.kwds.setdefault('alpha', 0.5)
  1400. def _get_plot_function(self):
  1401. if self.logy or self.loglog:
  1402. raise ValueError("Log-y scales are not supported in area plot")
  1403. else:
  1404. f = LinePlot._get_plot_function(self)
  1405. def plotf(*args, **kwds):
  1406. lines = f(*args, **kwds)
  1407. # insert fill_between starting point
  1408. y = args[2]
  1409. if (y >= 0).all():
  1410. start = self._pos_prior
  1411. elif (y <= 0).all():
  1412. start = self._neg_prior
  1413. else:
  1414. start = np.zeros(len(y))
  1415. # get x data from the line
  1416. # to retrieve x coodinates of tsplot
  1417. xdata = lines[0].get_data()[0]
  1418. # remove style
  1419. args = (args[0], xdata, start, y)
  1420. if not 'color' in kwds:
  1421. kwds['color'] = lines[0].get_color()
  1422. self.plt.Axes.fill_between(*args, **kwds)
  1423. return lines
  1424. return plotf
  1425. def _add_legend_handle(self, handle, label, index=None):
  1426. from matplotlib.patches import Rectangle
  1427. # Because fill_between isn't supported in legend,
  1428. # specifically add Rectangle handle here
  1429. alpha = self.kwds.get('alpha', 0.5)
  1430. handle = Rectangle((0, 0), 1, 1, fc=handle.get_color(), alpha=alpha)
  1431. LinePlot._add_legend_handle(self, handle, label, index=index)
  1432. def _post_plot_logic(self):
  1433. LinePlot._post_plot_logic(self)
  1434. if self._is_ts_plot():
  1435. pass
  1436. else:
  1437. if self.xlim is None:
  1438. for ax in self.axes:
  1439. lines = _get_all_lines(ax)
  1440. left, right = _get_xlim(lines)
  1441. ax.set_xlim(left, right)
  1442. if self.ylim is None:
  1443. if (self.data >= 0).all().all():
  1444. for ax in self.axes:
  1445. ax.set_ylim(0, None)
  1446. elif (self.data <= 0).all().all():
  1447. for ax in self.axes:
  1448. ax.set_ylim(None, 0)
  1449. class BarPlot(MPLPlot):
  1450. _default_rot = {'bar': 90, 'barh': 0}
  1451. def __init__(self, data, **kwargs):
  1452. self.stacked = kwargs.pop('stacked', False)
  1453. self.bar_width = kwargs.pop('width', 0.5)
  1454. pos = kwargs.pop('position', 0.5)
  1455. kwargs.setdefault('align', 'center')
  1456. self.tick_pos = np.arange(len(data))
  1457. self.bottom = kwargs.pop('bottom', None)
  1458. self.left = kwargs.pop('left', None)
  1459. self.log = kwargs.pop('log',False)
  1460. MPLPlot.__init__(self, data, **kwargs)
  1461. if self.stacked or self.subplots:
  1462. self.tickoffset = self.bar_width * pos
  1463. if kwargs['align'] == 'edge':
  1464. self.lim_offset = self.bar_width / 2
  1465. else:
  1466. self.lim_offset = 0
  1467. else:
  1468. if kwargs['align'] == 'edge':
  1469. w = self.bar_width / self.nseries
  1470. self.tickoffset = self.bar_width * (pos - 0.5) + w * 0.5
  1471. self.lim_offset = w * 0.5
  1472. else:
  1473. self.tickoffset = self.bar_width * pos
  1474. self.lim_offset = 0
  1475. self.ax_pos = self.tick_pos - self.tickoffset
  1476. def _args_adjust(self):
  1477. if self.rot is None:
  1478. self.rot = self._default_rot[self.kind]
  1479. if com.is_list_like(self.bottom):
  1480. self.bottom = np.array(self.bottom)
  1481. if com.is_list_like(self.left):
  1482. self.left = np.array(self.left)
  1483. def _get_plot_function(self):
  1484. if self.kind == 'bar':
  1485. def f(ax, x, y, w, start=None, **kwds):
  1486. if self.bottom is not None:
  1487. start = start + self.bottom
  1488. return ax.bar(x, y, w, bottom=start,log=self.log, **kwds)
  1489. elif self.kind == 'barh':
  1490. def f(ax, x, y, w, start=None, log=self.log, **kwds):
  1491. if self.left is not None:
  1492. start = start + self.left
  1493. return ax.barh(x, y, w, left=start, **kwds)
  1494. else:
  1495. raise NotImplementedError
  1496. return f
  1497. def _make_plot(self):
  1498. import matplotlib as mpl
  1499. # mpl decided to make their version string unicode across all Python
  1500. # versions for mpl >= 1.3 so we have to call str here for python 2
  1501. mpl_le_1_2_1 = str(mpl.__version__) <= LooseVersion('1.2.1')
  1502. colors = self._get_colors()
  1503. ncolors = len(colors)
  1504. bar_f = self._get_plot_function()
  1505. pos_prior = neg_prior = np.zeros(len(self.data))
  1506. K = self.nseries
  1507. for i, (label, y) in enumerate(self._iter_data()):
  1508. ax = self._get_ax(i)
  1509. kwds = self.kwds.copy()
  1510. kwds['color'] = colors[i % ncolors]
  1511. errors = self._get_errorbars(label=label, index=i)
  1512. kwds = dict(kwds, **errors)
  1513. label = com.pprint_thing(label)
  1514. if (('yerr' in kwds) or ('xerr' in kwds)) \
  1515. and (kwds.get('ecolor') is None):
  1516. kwds['ecolor'] = mpl.rcParams['xtick.color']
  1517. start = 0
  1518. if self.log:
  1519. start = 1
  1520. if any(y < 1):
  1521. # GH3254
  1522. start = 0 if mpl_le_1_2_1 else None
  1523. if self.subplots:
  1524. w = self.bar_width / 2
  1525. rect = bar_f(ax, self.ax_pos + w, y, self.bar_width,
  1526. start=start, label=label, **kwds)
  1527. ax.set_title(label)
  1528. elif self.stacked:
  1529. mask = y > 0
  1530. start = np.where(mask, pos_prior, neg_prior)
  1531. w = self.bar_width / 2
  1532. rect = bar_f(ax, self.ax_pos + w, y, self.bar_width,
  1533. start=start, label=label, **kwds)
  1534. pos_prior = pos_prior + np.where(mask, y, 0)
  1535. neg_prior = neg_prior + np.where(mask, 0, y)
  1536. else:
  1537. w = self.bar_width / K
  1538. rect = bar_f(ax, self.ax_pos + (i + 0.5) * w, y, w,
  1539. start=start, label=label, **kwds)
  1540. self._add_legend_handle(rect, label, index=i)
  1541. def _post_plot_logic(self):
  1542. for ax in self.axes:
  1543. if self.use_index:
  1544. str_index = [com.pprint_thing(key) for key in self.data.index]
  1545. else:
  1546. str_index = [com.pprint_thing(key) for key in
  1547. range(self.data.shape[0])]
  1548. name = self._get_index_name()
  1549. s_edge = self.ax_pos[0] - 0.25 + self.lim_offset
  1550. e_edge = self.ax_pos[-1] + 0.25 + self.bar_width + self.lim_offset
  1551. if self.kind == 'bar':
  1552. ax.set_xlim((s_edge, e_edge))
  1553. ax.set_xticks(self.tick_pos)
  1554. ax.set_xticklabels(str_index, rotation=self.rot,
  1555. fontsize=self.fontsize)
  1556. if not self.log: # GH3254+
  1557. ax.axhline(0, color='k', linestyle='--')
  1558. if name is not None:
  1559. ax.set_xlabel(name)
  1560. elif self.kind == 'barh':
  1561. # horizontal bars
  1562. ax.set_ylim((s_edge, e_edge))
  1563. ax.set_yticks(self.tick_pos)
  1564. ax.set_yticklabels(str_index, rotation=self.rot,
  1565. fontsize=self.fontsize)
  1566. ax.axvline(0, color='k', linestyle='--')
  1567. if name is not None:
  1568. ax.set_ylabel(name)
  1569. else:
  1570. raise NotImplementedError(self.kind)
  1571. class PiePlot(MPLPlot):
  1572. def __init__(self, data, kind=None, **kwargs):
  1573. data = data.fillna(value=0)
  1574. if (data < 0).any().any():
  1575. raise ValueError("{0} doesn't allow negative values".format(kind))
  1576. MPLPlot.__init__(self, data, kind=kind, **kwargs)
  1577. def _args_adjust(self):
  1578. self.grid = False
  1579. self.logy = False
  1580. self.logx = False
  1581. self.loglog = False
  1582. def _get_layout(self):
  1583. from pandas import DataFrame
  1584. if isinstance(self.data, DataFrame):
  1585. return (1, len(self.data.columns))
  1586. else:
  1587. return (1, 1)
  1588. def _validate_color_args(self):
  1589. pass
  1590. def _make_plot(self):
  1591. self.kwds.setdefault('colors', self._get_colors(num_colors=len(self.data),
  1592. color_kwds='colors'))
  1593. for i, (label, y) in enumerate(self._iter_data()):
  1594. ax = self._get_ax(i)
  1595. if label is not None:
  1596. label = com.pprint_thing(label)
  1597. ax.set_ylabel(label)
  1598. kwds = self.kwds.copy()
  1599. idx = [com.pprint_thing(v) for v in self.data.index]
  1600. labels = kwds.pop('labels', idx)
  1601. # labels is used for each wedge's labels
  1602. results = ax.pie(y, labels=labels, **kwds)
  1603. if kwds.get('autopct', None) is not None:
  1604. patches, texts, autotexts = results
  1605. else:
  1606. patches, texts = results
  1607. autotexts = []
  1608. if self.fontsize is not None:
  1609. for t in texts + autotexts:
  1610. t.set_fontsize(self.fontsize)
  1611. # leglabels is used for legend labels
  1612. leglabels = labels if labels is not None else idx
  1613. for p, l in zip(patches, leglabels):
  1614. self._add_legend_handle(p, l)
  1615. class BoxPlot(MPLPlot):
  1616. pass
  1617. class HistPlot(MPLPlot):
  1618. pass
  1619. # kinds supported by both dataframe and series
  1620. _common_kinds = ['line', 'bar', 'barh', 'kde', 'density', 'area']
  1621. # kinds supported by dataframe
  1622. _dataframe_kinds = ['scatter', 'hexbin']
  1623. # kinds supported only by series or dataframe single column
  1624. _series_kinds = ['pie']
  1625. _all_kinds = _common_kinds + _dataframe_kinds + _series_kinds
  1626. _plot_klass = {'line': LinePlot, 'bar': BarPlot, 'barh': BarPlot,
  1627. 'kde': KdePlot,
  1628. 'scatter': ScatterPlot, 'hexbin': HexBinPlot,
  1629. 'area': AreaPlot, 'pie': PiePlot}
  1630. def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
  1631. sharey=False, use_index=True, figsize=None, grid=None,
  1632. legend=True, rot=None, ax=None, style=None, title=None,
  1633. xlim=None, ylim=None, logx=False, logy=False, xticks=None,
  1634. yticks=None, kind='line', sort_columns=False, fontsize=None,
  1635. secondary_y=False, **kwds):
  1636. """
  1637. Make line, bar, or scatter plots of DataFrame series with the index on the x-axis
  1638. using matplotlib / pylab.
  1639. Parameters
  1640. ----------
  1641. frame : DataFrame
  1642. x : label or position, default None
  1643. y : label or position, default None
  1644. Allows plotting of one column versus another
  1645. yerr : DataFrame (with matching labels), Series, list-type (tuple, list,
  1646. ndarray), or str of column name containing y error values
  1647. xerr : similar functionality as yerr, but for x error values
  1648. subplots : boolean, default False
  1649. Make separate subplots for each time series
  1650. sharex : boolean, default True
  1651. In case subplots=True, share x axis
  1652. sharey : boolean, default False
  1653. In case subplots=True, share y axis
  1654. use_index : boolean, default True
  1655. Use index as ticks for x axis
  1656. stacked : boolean, default False
  1657. If True, create stacked bar plot. Only valid for DataFrame input
  1658. sort_columns: boolean, default False
  1659. Sort column names to determine plot ordering
  1660. title : string
  1661. Title to use for the plot
  1662. grid : boolean, default None (matlab style default)
  1663. Axis grid lines
  1664. legend : False/True/'reverse'
  1665. Place legend on axis subplots
  1666. ax : matplotlib axis object, default None
  1667. style : list or dict
  1668. matplotlib line style per column
  1669. kind : {'line', 'bar', 'barh', 'kde', 'density', 'area', scatter', 'hexbin'}
  1670. line : line plot
  1671. bar : vertical bar plot
  1672. barh : horizontal bar plot
  1673. kde/density : Kernel Density Estimation plot
  1674. area : area plot
  1675. scatter : scatter plot
  1676. hexbin : hexbin plot
  1677. logx : boolean, default False
  1678. Use log scaling on x axis
  1679. logy : boolean, default False
  1680. Use log scaling on y axis
  1681. loglog : boolean, default False
  1682. Use log scaling on both x and y axes
  1683. xticks : sequence
  1684. Values to use for the xticks
  1685. yticks : sequence
  1686. Values to use for the yticks
  1687. xlim : 2-tuple/list
  1688. ylim : 2-tuple/list
  1689. rot : int, default None
  1690. Rotation for ticks
  1691. secondary_y : boolean or sequence, default False
  1692. Whether to plot on the secondary y-axis
  1693. If a list/tuple, which columns to plot on secondary y-axis
  1694. mark_right: boolean, default True
  1695. When using a secondary_y axis, should the legend label the axis of
  1696. the various columns automatically
  1697. colormap : str or matplotlib colormap object, default None
  1698. Colormap to select colors from. If string, load colormap with that name
  1699. from matplotlib.
  1700. position : float
  1701. Specify relative alignments for bar plot layout.
  1702. From 0 (left/bottom-end) to 1 (right/top-end). Default is 0.5 (center)
  1703. table : boolean, Series or DataFrame, default False
  1704. If True, draw a table using the data in the DataFrame and the data will
  1705. be transposed to meet matplotlib's default layout.
  1706. If a Series or DataFrame is passed, use passed data to draw a table.
  1707. kwds : keywords
  1708. Options to pass to matplotlib plotting method
  1709. Returns
  1710. -------
  1711. ax_or_axes : matplotlib.AxesSubplot or list of them
  1712. Notes
  1713. -----
  1714. If `kind`='hexbin', you can control the size of the bins with the
  1715. `gridsize` argument. By default, a histogram of the counts around each
  1716. `(x, y)` point is computed. You can specify alternative aggregations
  1717. by passing values to the `C` and `reduce_C_function` arguments.
  1718. `C` specifies the value at each `(x, y)` point and `reduce_C_function`
  1719. is a function of one argument that reduces all the values in a bin to
  1720. a single number (e.g. `mean`, `max`, `sum`, `std`).
  1721. """
  1722. kind = _get_standard_kind(kind.lower().strip())
  1723. if kind in _all_kinds:
  1724. klass = _plot_klass[kind]
  1725. else:
  1726. raise ValueError('Invalid chart type given %s' % kind)
  1727. if kind in _dataframe_kinds:
  1728. plot_obj = klass(frame, x=x, y=y, kind=kind, subplots=subplots,
  1729. rot=rot,legend=legend, ax=ax, style=style,
  1730. fontsize=fontsize, use_index=use_index, sharex=sharex,
  1731. sharey=sharey, xticks=xticks, yticks=yticks,
  1732. xlim=xlim, ylim=ylim, title=title, grid=grid,
  1733. figsize=figsize, logx=logx, logy=logy,
  1734. sort_columns=sort_columns, secondary_y=secondary_y,
  1735. **kwds)
  1736. elif kind in _series_kinds:
  1737. if y is None and subplots is False:
  1738. msg = "{0} requires either y column or 'subplots=True'"
  1739. raise ValueError(msg.format(kind))
  1740. elif y is not None:
  1741. if com.is_integer(y) and not frame.columns.holds_integer():
  1742. y = frame.columns[y]
  1743. frame = frame[y] # converted to series actually
  1744. frame.index.name = y
  1745. plot_obj = klass(frame, kind=kind, subplots=subplots,
  1746. rot=rot,legend=legend, ax=ax, style=style,
  1747. fontsize=fontsize, use_index=use_index, sharex=sharex,
  1748. sharey=sharey, xticks=xticks, yticks=yticks,
  1749. xlim=xlim, ylim=ylim, title=title, grid=grid,
  1750. figsize=figsize,
  1751. sort_columns=sort_columns,
  1752. **kwds)
  1753. else:
  1754. if x is not None:
  1755. if com.is_integer(x) and not frame.columns.holds_integer():
  1756. x = frame.columns[x]
  1757. frame = frame.set_index(x)
  1758. if y is not None:
  1759. if com.is_integer(y) and not frame.columns.holds_integer():
  1760. y = frame.columns[y]
  1761. label = x if x is not None else frame.index.name
  1762. label = kwds.pop('label', label)
  1763. ser = frame[y]
  1764. ser.index.name = label
  1765. for kw in ['xerr', 'yerr']:
  1766. if (kw in kwds) and \
  1767. (isinstance(kwds[kw], string_types) or com.is_integer(kwds[kw])):
  1768. try:
  1769. kwds[kw] = frame[kwds[kw]]
  1770. except (IndexError, KeyError, TypeError):
  1771. pass
  1772. return plot_series(ser, label=label, kind=kind,
  1773. use_index=use_index,
  1774. rot=rot, xticks=xticks, yticks=yticks,
  1775. xlim=xlim, ylim=ylim, ax=ax, style=style,
  1776. grid=grid, logx=logx, logy=logy,
  1777. secondary_y=secondary_y, title=title,
  1778. figsize=figsize, fontsize=fontsize, **kwds)
  1779. else:
  1780. plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot,
  1781. legend=legend, ax=ax, style=style, fontsize=fontsize,
  1782. use_index=use_index, sharex=sharex, sharey=sharey,
  1783. xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim,
  1784. title=title, grid=grid, figsize=figsize, logx=logx,
  1785. logy=logy, sort_columns=sort_columns,
  1786. secondary_y=secondary_y, **kwds)
  1787. plot_obj.generate()
  1788. plot_obj.draw()
  1789. if subplots:
  1790. return plot_obj.axes
  1791. else:
  1792. return plot_obj.axes[0]
  1793. def plot_series(series, label=None, kind='line', use_index=True, rot=None,
  1794. xticks=None, yticks=None, xlim=None, ylim=None,
  1795. ax=None, style=None, grid=None, legend=False, logx=False,
  1796. logy=False, secondary_y=False, **kwds):
  1797. """
  1798. Plot the input series with the index on the x-axis using matplotlib
  1799. Parameters
  1800. ----------
  1801. label : label argument to provide to plot
  1802. kind : {'line', 'bar', 'barh', 'kde', 'density', 'area'}
  1803. line : line plot
  1804. bar : vertical bar plot
  1805. barh : horizontal bar plot
  1806. kde/density : Kernel Density Estimation plot
  1807. area : area plot
  1808. use_index : boolean, default True
  1809. Plot index as axis tick labels
  1810. rot : int, default None
  1811. Rotation for tick labels
  1812. xticks : sequence
  1813. Values to use for the xticks
  1814. yticks : sequence
  1815. Values to use for the yticks
  1816. xlim : 2-tuple/list
  1817. ylim : 2-tuple/list
  1818. ax : matplotlib axis object
  1819. If not passed, uses gca()
  1820. style : string, default matplotlib default
  1821. matplotlib line style to use
  1822. grid : matplotlib grid
  1823. legend: matplotlib legend
  1824. logx : boolean, default False
  1825. Use log scaling on x axis
  1826. logy : boolean, default False
  1827. Use log scaling on y axis
  1828. loglog : boolean, default False
  1829. Use log scaling on both x and y axes
  1830. secondary_y : boolean or sequence of ints, default False
  1831. If True then y-axis will be on the right
  1832. figsize : a tuple (width, height) in inches
  1833. position : float
  1834. Specify relative alignments for bar plot layout.
  1835. From 0 (left/bottom-end) to 1 (right/top-end). Default is 0.5 (center)
  1836. table : boolean, Series or DataFrame, default False
  1837. If True, draw a table using the data in the Series and the data will
  1838. be transposed to meet matplotlib's default layout.
  1839. If a Series or DataFrame is passed, use passed data to draw a table.
  1840. kwds : keywords
  1841. Options to pass to matplotlib plotting method
  1842. Notes
  1843. -----
  1844. See matplotlib documentation online for more on this subject
  1845. """
  1846. kind = _get_standard_kind(kind.lower().strip())
  1847. if kind in _common_kinds or kind in _series_kinds:
  1848. klass = _plot_klass[kind]
  1849. else:
  1850. raise ValueError('Invalid chart type given %s' % kind)
  1851. """
  1852. If no axis is specified, we check whether there are existing figures.
  1853. If so, we get the current axis and check whether yaxis ticks are on the
  1854. right. Ticks for the plot of the series will be on the right unless
  1855. there is at least one axis with ticks on the left.
  1856. If we do not check for whether there are existing figures, _gca() will
  1857. create a figure with the default figsize, causing the figsize= parameter to
  1858. be ignored.
  1859. """
  1860. import matplotlib.pyplot as plt
  1861. if ax is None and len(plt.get_fignums()) > 0:
  1862. ax = _gca()
  1863. ax = getattr(ax, 'left_ax', ax)
  1864. # is there harm in this?
  1865. if label is None:
  1866. label = series.name
  1867. plot_obj = klass(series, kind=kind, rot=rot, logx=logx, logy=logy,
  1868. ax=ax, use_index=use_index, style=style,
  1869. xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim,
  1870. legend=legend, grid=grid, label=label,
  1871. secondary_y=secondary_y, **kwds)
  1872. plot_obj.generate()
  1873. plot_obj.draw()
  1874. # plot_obj.ax is None if we created the first figure
  1875. return plot_obj.axes[0]
# Shared docstring for the box-plot entry point.  ``boxplot`` below attaches
# it through ``@Appender(_shared_docs['boxplot'] % _shared_doc_kwargs)``.
# Because of that %-interpolation, any literal percent sign added to this
# text must be written as ``%%``.
_shared_docs['boxplot'] = """
    Make a box plot from DataFrame column optionally grouped by some columns or
    other inputs

    Parameters
    ----------
    data : the pandas object holding the data
    column : column name or list of names, or vector
        Can be any valid input to groupby
    by : string or sequence
        Column in the DataFrame to group by
    ax : Matplotlib axes object, optional
    fontsize : int or string
    rot : label rotation angle
    figsize : A tuple (width, height) in inches
    grid : Setting this to True will show the grid
    layout : tuple (optional)
        (rows, columns) for the layout of the plot
    return_type : {'axes', 'dict', 'both'}, default 'dict'
        The kind of object to return. 'dict' returns a dictionary
        whose values are the matplotlib Lines of the boxplot;
        'axes' returns the matplotlib axes the boxplot is drawn on;
        'both' returns a namedtuple with the axes and dict.

        When grouping with ``by``, a dict mapping columns to ``return_type``
        is returned.

    kwds : other plotting keyword arguments to be passed to matplotlib boxplot
           function

    Returns
    -------
    lines : dict
    ax : matplotlib Axes
    (ax, lines): namedtuple

    Notes
    -----
    Use ``return_type='dict'`` when you want to tweak the appearance
    of the lines after plotting. In this case a dict containing the Lines
    making up the boxes, caps, fliers, medians, and whiskers is returned.
    """
  1913. @Appender(_shared_docs['boxplot'] % _shared_doc_kwargs)
  1914. def boxplot(data, column=None, by=None, ax=None, fontsize=None,
  1915. rot=0, grid=True, figsize=None, layout=None, return_type=None,
  1916. **kwds):
  1917. # validate return_type:
  1918. valid_types = (None, 'axes', 'dict', 'both')
  1919. if return_type not in valid_types:
  1920. raise ValueError("return_type")
  1921. from pandas import Series, DataFrame
  1922. if isinstance(data, Series):
  1923. data = DataFrame({'x': data})
  1924. column = 'x'
  1925. def _get_colors():
  1926. return _get_standard_colors(color=kwds.get('color'), num_colors=1)
  1927. def maybe_color_bp(bp):
  1928. if 'color' not in kwds :
  1929. from matplotlib.artist import setp
  1930. setp(bp['boxes'],color=colors[0],alpha=1)
  1931. setp(bp['whiskers'],color=colors[0],alpha=1)
  1932. setp(bp['medians'],color=colors[2],alpha=1)
  1933. BP = namedtuple("Boxplot", ['ax', 'lines']) # namedtuple to hold results
  1934. def plot_group(keys, values, ax):
  1935. keys = [com.pprint_thing(x) for x in keys]
  1936. values = [remove_na(v) for v in values]
  1937. bp = ax.boxplot(values, **kwds)
  1938. if kwds.get('vert', 1):
  1939. ax.set_xticklabels(keys, rotation=rot, fontsize=fontsize)
  1940. else:
  1941. ax.set_yticklabels(keys, rotation=rot, fontsize=fontsize)
  1942. maybe_color_bp(bp)
  1943. # Return axes in multiplot case, maybe revisit later # 985
  1944. if return_type == 'dict':
  1945. return bp
  1946. elif return_type == 'both':
  1947. return BP(ax=ax, lines=bp)
  1948. else:
  1949. return ax
  1950. colors = _get_colors()
  1951. if column is None:
  1952. columns = None
  1953. else:
  1954. if isinstance(column, (list, tuple)):
  1955. columns = column
  1956. else:
  1957. columns = [column]
  1958. if by is not None:
  1959. result = _grouped_plot_by_column(plot_group, data, columns=columns,
  1960. by=by, grid=grid, figsize=figsize,
  1961. ax=ax, layout=layout, return_type=return_type)
  1962. else:
  1963. if layout is not None:
  1964. raise ValueError("The 'layout' keyword is not supported when "
  1965. "'by' is None")
  1966. if return_type is None:
  1967. msg = ("\nThe default value for 'return_type' will change to "
  1968. "'axes' in a future release.\n To use the future behavior "
  1969. "now, set return_type='axes'.\n To keep the previous "
  1970. "behavior and silence this warning, set "
  1971. "return_type='dict'.")
  1972. warnings.warn(msg, FutureWarning)
  1973. return_type = 'dict'
  1974. if ax is None:
  1975. ax = _gca()
  1976. data = data._get_numeric_data()
  1977. if columns is None:
  1978. columns = data.columns
  1979. else:
  1980. data = data[columns]
  1981. result = plot_group(columns, data.values.T, ax)
  1982. ax.grid(grid)
  1983. return result
  1984. def format_date_labels(ax, rot):
  1985. # mini version of autofmt_xdate
  1986. try:
  1987. for label in ax.get_xticklabels():
  1988. label.set_ha('right')
  1989. label.set_rotation(rot)
  1990. fig = ax.get_figure()
  1991. fig.subplots_adjust(bottom=0.2)
  1992. except Exception: # pragma: no cover
  1993. pass
  1994. def scatter_plot(data, x, y, by=None, ax=None, figsize=None, grid=False, **kwargs):
  1995. """
  1996. Make a scatter plot from two DataFrame columns
  1997. Parameters
  1998. ----------
  1999. data : DataFrame
  2000. x : Column name for the x-axis values
  2001. y : Column name for the y-axis values
  2002. ax : Matplotlib axis object
  2003. figsize : A tuple (width, height) in inches
  2004. grid : Setting this to True will show the grid
  2005. kwargs : other plotting keyword arguments
  2006. To be passed to scatter function
  2007. Returns
  2008. -------
  2009. fig : matplotlib.Figure
  2010. """
  2011. import matplotlib.pyplot as plt
  2012. # workaround because `c='b'` is hardcoded in matplotlibs scatter method
  2013. kwargs.setdefault('c', plt.rcParams['patch.facecolor'])
  2014. def plot_group(group, ax):
  2015. xvals = group[x].values
  2016. yvals = group[y].values
  2017. ax.scatter(xvals, yvals, **kwargs)
  2018. ax.grid(grid)
  2019. if by is not None:
  2020. fig = _grouped_plot(plot_group, data, by=by, figsize=figsize, ax=ax)
  2021. else:
  2022. if ax is None:
  2023. fig = plt.figure()
  2024. ax = fig.add_subplot(111)
  2025. else:
  2026. fig = ax.get_figure()
  2027. plot_group(data, ax)
  2028. ax.set_ylabel(com.pprint_thing(y))
  2029. ax.set_xlabel(com.pprint_thing(x))
  2030. ax.grid(grid)
  2031. return fig
  2032. def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None,
  2033. xrot=None, ylabelsize=None, yrot=None, ax=None, sharex=False,
  2034. sharey=False, figsize=None, layout=None, bins=10, **kwds):
  2035. """
  2036. Draw histogram of the DataFrame's series using matplotlib / pylab.
  2037. Parameters
  2038. ----------
  2039. data : DataFrame
  2040. column : string or sequence
  2041. If passed, will be used to limit data to a subset of columns
  2042. by : object, optional
  2043. If passed, then used to form histograms for separate groups
  2044. grid : boolean, default True
  2045. Whether to show axis grid lines
  2046. xlabelsize : int, default None
  2047. If specified changes the x-axis label size
  2048. xrot : float, default None
  2049. rotation of x axis labels
  2050. ylabelsize : int, default None
  2051. If specified changes the y-axis label size
  2052. yrot : float, default None
  2053. rotation of y axis labels
  2054. ax : matplotlib axes object, default None
  2055. sharex : bool, if True, the X axis will be shared amongst all subplots.
  2056. sharey : bool, if True, the Y axis will be shared amongst all subplots.
  2057. figsize : tuple
  2058. The size of the figure to create in inches by default
  2059. layout: (optional) a tuple (rows, columns) for the layout of the histograms
  2060. bins: integer, default 10
  2061. Number of histogram bins to be used
  2062. kwds : other plotting keyword arguments
  2063. To be passed to hist function
  2064. """
  2065. if by is not None:
  2066. axes = grouped_hist(data, column=column, by=by, ax=ax, grid=grid, figsize=figsize,
  2067. sharex=sharex, sharey=sharey, layout=layout, bins=bins,
  2068. xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot,
  2069. **kwds)
  2070. return axes
  2071. if column is not None:
  2072. if not isinstance(column, (list, np.ndarray)):
  2073. column = [column]
  2074. data = data[column]
  2075. data = data._get_numeric_data()
  2076. naxes = len(data.columns)
  2077. nrows, ncols = _get_layout(naxes, layout=layout)
  2078. fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes, ax=ax, squeeze=False,
  2079. sharex=sharex, sharey=sharey, figsize=figsize)
  2080. for i, col in enumerate(com._try_sort(data.columns)):
  2081. ax = axes[i // ncols, i % ncols]
  2082. ax.hist(data[col].dropna().values, bins=bins, **kwds)
  2083. ax.set_title(col)
  2084. ax.grid(grid)
  2085. _set_ticks_props(axes, xlabelsize=xlabelsize, xrot=xrot,
  2086. ylabelsize=ylabelsize, yrot=yrot)
  2087. fig.subplots_adjust(wspace=0.3, hspace=0.3)
  2088. return axes
  2089. def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None,
  2090. xrot=None, ylabelsize=None, yrot=None, figsize=None, bins=10, **kwds):
  2091. """
  2092. Draw histogram of the input series using matplotlib
  2093. Parameters
  2094. ----------
  2095. by : object, optional
  2096. If passed, then used to form histograms for separate groups
  2097. ax : matplotlib axis object
  2098. If not passed, uses gca()
  2099. grid : boolean, default True
  2100. Whether to show axis grid lines
  2101. xlabelsize : int, default None
  2102. If specified changes the x-axis label size
  2103. xrot : float, default None
  2104. rotation of x axis labels
  2105. ylabelsize : int, default None
  2106. If specified changes the y-axis label size
  2107. yrot : float, default None
  2108. rotation of y axis labels
  2109. figsize : tuple, default None
  2110. figure size in inches by default
  2111. bins: integer, default 10
  2112. Number of histogram bins to be used
  2113. kwds : keywords
  2114. To be passed to the actual plotting function
  2115. Notes
  2116. -----
  2117. See matplotlib documentation online for more on this
  2118. """
  2119. import matplotlib.pyplot as plt
  2120. if by is None:
  2121. if kwds.get('layout', None) is not None:
  2122. raise ValueError("The 'layout' keyword is not supported when "
  2123. "'by' is None")
  2124. # hack until the plotting interface is a bit more unified
  2125. fig = kwds.pop('figure', plt.gcf() if plt.get_fignums() else
  2126. plt.figure(figsize=figsize))
  2127. if (figsize is not None and tuple(figsize) !=
  2128. tuple(fig.get_size_inches())):
  2129. fig.set_size_inches(*figsize, forward=True)
  2130. if ax is None:
  2131. ax = fig.gca()
  2132. elif ax.get_figure() != fig:
  2133. raise AssertionError('passed axis not bound to passed figure')
  2134. values = self.dropna().values
  2135. ax.hist(values, bins=bins, **kwds)
  2136. ax.grid(grid)
  2137. axes = np.array([ax])
  2138. _set_ticks_props(axes, xlabelsize=xlabelsize, xrot=xrot,
  2139. ylabelsize=ylabelsize, yrot=yrot)
  2140. else:
  2141. if 'figure' in kwds:
  2142. raise ValueError("Cannot pass 'figure' when using the "
  2143. "'by' argument, since a new 'Figure' instance "
  2144. "will be created")
  2145. axes = grouped_hist(self, by=by, ax=ax, grid=grid, figsize=figsize, bins=bins,
  2146. xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot,
  2147. **kwds)
  2148. if axes.ndim == 1 and len(axes) == 1:
  2149. return axes[0]
  2150. return axes
  2151. def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None,
  2152. layout=None, sharex=False, sharey=False, rot=90, grid=True,
  2153. xlabelsize=None, xrot=None, ylabelsize=None, yrot=None,
  2154. **kwargs):
  2155. """
  2156. Grouped histogram
  2157. Parameters
  2158. ----------
  2159. data: Series/DataFrame
  2160. column: object, optional
  2161. by: object, optional
  2162. ax: axes, optional
  2163. bins: int, default 50
  2164. figsize: tuple, optional
  2165. layout: optional
  2166. sharex: boolean, default False
  2167. sharey: boolean, default False
  2168. rot: int, default 90
  2169. grid: bool, default True
  2170. kwargs: dict, keyword arguments passed to matplotlib.Axes.hist
  2171. Returns
  2172. -------
  2173. axes: collection of Matplotlib Axes
  2174. """
  2175. def plot_group(group, ax):
  2176. ax.hist(group.dropna().values, bins=bins, **kwargs)
  2177. xrot = xrot or rot
  2178. fig, axes = _grouped_plot(plot_group, data, column=column,
  2179. by=by, sharex=sharex, sharey=sharey,
  2180. figsize=figsize, layout=layout, rot=rot)
  2181. _set_ticks_props(axes, xlabelsize=xlabelsize, xrot=xrot,
  2182. ylabelsize=ylabelsize, yrot=yrot)
  2183. fig.subplots_adjust(bottom=0.15, top=0.9, left=0.1, right=0.9,
  2184. hspace=0.5, wspace=0.3)
  2185. return axes
  2186. def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,
  2187. rot=0, grid=True, ax=None, figsize=None,
  2188. layout=None, **kwds):
  2189. """
  2190. Make box plots from DataFrameGroupBy data.
  2191. Parameters
  2192. ----------
  2193. grouped : Grouped DataFrame
  2194. subplots :
  2195. * ``False`` - no subplots will be used
  2196. * ``True`` - create a subplot for each group
  2197. column : column name or list of names, or vector
  2198. Can be any valid input to groupby
  2199. fontsize : int or string
  2200. rot : label rotation angle
  2201. grid : Setting this to True will show the grid
  2202. figsize : A tuple (width, height) in inches
  2203. layout : tuple (optional)
  2204. (rows, columns) for the layout of the plot
  2205. kwds : other plotting keyword arguments to be passed to matplotlib boxplot
  2206. function
  2207. Returns
  2208. -------
  2209. dict of key/value = group key/DataFrame.boxplot return value
  2210. or DataFrame.boxplot return value in case subplots=figures=False
  2211. Examples
  2212. --------
  2213. >>> import pandas
  2214. >>> import numpy as np
  2215. >>> import itertools
  2216. >>>
  2217. >>> tuples = [t for t in itertools.product(range(1000), range(4))]
  2218. >>> index = pandas.MultiIndex.from_tuples(tuples, names=['lvl0', 'lvl1'])
  2219. >>> data = np.random.randn(len(index),4)
  2220. >>> df = pandas.DataFrame(data, columns=list('ABCD'), index=index)
  2221. >>>
  2222. >>> grouped = df.groupby(level='lvl1')
  2223. >>> boxplot_frame_groupby(grouped)
  2224. >>>
  2225. >>> grouped = df.unstack(level='lvl1').groupby(level=0, axis=1)
  2226. >>> boxplot_frame_groupby(grouped, subplots=False)
  2227. """
  2228. if subplots is True:
  2229. naxes = len(grouped)
  2230. nrows, ncols = _get_layout(naxes, layout=layout)
  2231. fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes, squeeze=False,
  2232. ax=ax, sharex=False, sharey=True, figsize=figsize)
  2233. axes = _flatten(axes)
  2234. ret = compat.OrderedDict()
  2235. for (key, group), ax in zip(grouped, axes):
  2236. d = group.boxplot(ax=ax, column=column, fontsize=fontsize,
  2237. rot=rot, grid=grid, **kwds)
  2238. ax.set_title(com.pprint_thing(key))
  2239. ret[key] = d
  2240. fig.subplots_adjust(bottom=0.15, top=0.9, left=0.1, right=0.9, wspace=0.2)
  2241. else:
  2242. from pandas.tools.merge import concat
  2243. keys, frames = zip(*grouped)
  2244. if grouped.axis == 0:
  2245. df = concat(frames, keys=keys, axis=1)
  2246. else:
  2247. if len(frames) > 1:
  2248. df = frames[0].join(frames[1::])
  2249. else:
  2250. df = frames[0]
  2251. ret = df.boxplot(column=column, fontsize=fontsize, rot=rot,
  2252. grid=grid, ax=ax, figsize=figsize, layout=layout, **kwds)
  2253. return ret
  2254. def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
  2255. figsize=None, sharex=True, sharey=True, layout=None,
  2256. rot=0, ax=None, **kwargs):
  2257. from pandas import DataFrame
  2258. if figsize == 'default':
  2259. # allowed to specify mpl default with 'default'
  2260. warnings.warn("figsize='default' is deprecated. Specify figure"
  2261. "size by tuple instead", FutureWarning)
  2262. figsize = None
  2263. grouped = data.groupby(by)
  2264. if column is not None:
  2265. grouped = grouped[column]
  2266. naxes = len(grouped)
  2267. nrows, ncols = _get_layout(naxes, layout=layout)
  2268. fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes,
  2269. figsize=figsize, sharex=sharex, sharey=sharey, ax=ax)
  2270. ravel_axes = _flatten(axes)
  2271. for i, (key, group) in enumerate(grouped):
  2272. ax = ravel_axes[i]
  2273. if numeric_only and isinstance(group, DataFrame):
  2274. group = group._get_numeric_data()
  2275. plotf(group, ax, **kwargs)
  2276. ax.set_title(com.pprint_thing(key))
  2277. return fig, axes
  2278. def _grouped_plot_by_column(plotf, data, columns=None, by=None,
  2279. numeric_only=True, grid=False,
  2280. figsize=None, ax=None, layout=None, return_type=None,
  2281. **kwargs):
  2282. grouped = data.groupby(by)
  2283. if columns is None:
  2284. if not isinstance(by, (list, tuple)):
  2285. by = [by]
  2286. columns = data._get_numeric_data().columns - by
  2287. naxes = len(columns)
  2288. nrows, ncols = _get_layout(naxes, layout=layout)
  2289. fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes,
  2290. sharex=True, sharey=True,
  2291. figsize=figsize, ax=ax)
  2292. ravel_axes = _flatten(axes)
  2293. result = compat.OrderedDict()
  2294. for i, col in enumerate(columns):
  2295. ax = ravel_axes[i]
  2296. gp_col = grouped[col]
  2297. keys, values = zip(*gp_col)
  2298. re_plotf = plotf(keys, values, ax, **kwargs)
  2299. ax.set_title(col)
  2300. ax.set_xlabel(com.pprint_thing(by))
  2301. result[col] = re_plotf
  2302. ax.grid(grid)
  2303. # Return axes in multiplot case, maybe revisit later # 985
  2304. if return_type is None:
  2305. result = axes
  2306. byline = by[0] if len(by) == 1 else by
  2307. fig.suptitle('Boxplot grouped by %s' % byline)
  2308. fig.subplots_adjust(bottom=0.15, top=0.9, left=0.1, right=0.9, wspace=0.2)
  2309. return result
  2310. def table(ax, data, rowLabels=None, colLabels=None,
  2311. **kwargs):
  2312. """
  2313. Helper function to convert DataFrame and Series to matplotlib.table
  2314. Parameters
  2315. ----------
  2316. `ax`: Matplotlib axes object
  2317. `data`: DataFrame or Series
  2318. data for table contents
  2319. `kwargs`: keywords, optional
  2320. keyword arguments which passed to matplotlib.table.table.
  2321. If `rowLabels` or `colLabels` is not specified, data index or column name will be used.
  2322. Returns
  2323. -------
  2324. matplotlib table object
  2325. """
  2326. from pandas import DataFrame
  2327. if isinstance(data, Series):
  2328. data = DataFrame(data, columns=[data.name])
  2329. elif isinstance(data, DataFrame):
  2330. pass
  2331. else:
  2332. raise ValueError('Input data must be DataFrame or Series')
  2333. if rowLabels is None:
  2334. rowLabels = data.index
  2335. if colLabels is None:
  2336. colLabels = data.columns
  2337. cellText = data.values
  2338. import matplotlib.table
  2339. table = matplotlib.table.table(ax, cellText=cellText,
  2340. rowLabels=rowLabels, colLabels=colLabels, **kwargs)
  2341. return table
  2342. def _get_layout(nplots, layout=None):
  2343. if layout is not None:
  2344. if not isinstance(layout, (tuple, list)) or len(layout) != 2:
  2345. raise ValueError('Layout must be a tuple of (rows, columns)')
  2346. nrows, ncols = layout
  2347. if nrows * ncols < nplots:
  2348. raise ValueError('Layout of %sx%s must be larger than required size %s' %
  2349. (nrows, ncols, nplots))
  2350. return layout
  2351. if nplots == 1:
  2352. return (1, 1)
  2353. elif nplots == 2:
  2354. return (1, 2)
  2355. elif nplots < 4:
  2356. return (2, 2)
  2357. k = 1
  2358. while k ** 2 < nplots:
  2359. k += 1
  2360. if (k - 1) * k >= nplots:
  2361. return k, (k - 1)
  2362. else:
  2363. return k, k
  2364. # copied from matplotlib/pyplot.py for compatibility with matplotlib < 1.0
  2365. def _subplots(nrows=1, ncols=1, naxes=None, sharex=False, sharey=False, squeeze=True,
  2366. subplot_kw=None, ax=None, **fig_kw):
  2367. """Create a figure with a set of subplots already made.
  2368. This utility wrapper makes it convenient to create common layouts of
  2369. subplots, including the enclosing figure object, in a single call.
  2370. Keyword arguments:
  2371. nrows : int
  2372. Number of rows of the subplot grid. Defaults to 1.
  2373. ncols : int
  2374. Number of columns of the subplot grid. Defaults to 1.
  2375. naxes : int
  2376. Number of required axes. Exceeded axes are set invisible. Default is nrows * ncols.
  2377. sharex : bool
  2378. If True, the X axis will be shared amongst all subplots.
  2379. sharey : bool
  2380. If True, the Y axis will be shared amongst all subplots.
  2381. squeeze : bool
  2382. If True, extra dimensions are squeezed out from the returned axis object:
  2383. - if only one subplot is constructed (nrows=ncols=1), the resulting
  2384. single Axis object is returned as a scalar.
  2385. - for Nx1 or 1xN subplots, the returned object is a 1-d numpy object
  2386. array of Axis objects are returned as numpy 1-d arrays.
  2387. - for NxM subplots with N>1 and M>1 are returned as a 2d array.
  2388. If False, no squeezing at all is done: the returned axis object is always
  2389. a 2-d array containing Axis instances, even if it ends up being 1x1.
  2390. subplot_kw : dict
  2391. Dict with keywords passed to the add_subplot() call used to create each
  2392. subplots.
  2393. ax : Matplotlib axis object, optional
  2394. fig_kw : Other keyword arguments to be passed to the figure() call.
  2395. Note that all keywords not recognized above will be
  2396. automatically included here.
  2397. Returns:
  2398. fig, ax : tuple
  2399. - fig is the Matplotlib Figure object
  2400. - ax can be either a single axis object or an array of axis objects if
  2401. more than one subplot was created. The dimensions of the resulting array
  2402. can be controlled with the squeeze keyword, see above.
  2403. **Examples:**
  2404. x = np.linspace(0, 2*np.pi, 400)
  2405. y = np.sin(x**2)
  2406. # Just a figure and one subplot
  2407. f, ax = plt.subplots()
  2408. ax.plot(x, y)
  2409. ax.set_title('Simple plot')
  2410. # Two subplots, unpack the output array immediately
  2411. f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
  2412. ax1.plot(x, y)
  2413. ax1.set_title('Sharing Y axis')
  2414. ax2.scatter(x, y)
  2415. # Four polar axes
  2416. plt.subplots(2, 2, subplot_kw=dict(polar=True))
  2417. """
  2418. import matplotlib.pyplot as plt
  2419. from pandas.core.frame import DataFrame
  2420. if subplot_kw is None:
  2421. subplot_kw = {}
  2422. # Create empty object array to hold all axes. It's easiest to make it 1-d
  2423. # so we can just append subplots upon creation, and then
  2424. nplots = nrows * ncols
  2425. if naxes is None:
  2426. naxes = nrows * ncols
  2427. elif nplots < naxes:
  2428. raise ValueError("naxes {0} is larger than layour size defined by nrows * ncols".format(naxes))
  2429. if ax is None:
  2430. fig = plt.figure(**fig_kw)
  2431. else:
  2432. fig = ax.get_figure()
  2433. # if ax is passed and a number of subplots is 1, return ax as it is
  2434. if naxes == 1:
  2435. if squeeze:
  2436. return fig, ax
  2437. else:
  2438. return fig, _flatten(ax)
  2439. else:
  2440. warnings.warn("To output multiple subplots, the figure containing the passed axes "
  2441. "is being cleared", UserWarning)
  2442. fig.clear()
  2443. axarr = np.empty(nplots, dtype=object)
  2444. # Create first subplot separately, so we can share it if requested
  2445. ax0 = fig.add_subplot(nrows, ncols, 1, **subplot_kw)
  2446. if sharex:
  2447. subplot_kw['sharex'] = ax0
  2448. if sharey:
  2449. subplot_kw['sharey'] = ax0
  2450. axarr[0] = ax0
  2451. # Note off-by-one counting because add_subplot uses the MATLAB 1-based
  2452. # convention.
  2453. for i in range(1, nplots):
  2454. ax = fig.add_subplot(nrows, ncols, i + 1, **subplot_kw)
  2455. axarr[i] = ax
  2456. if nplots > 1:
  2457. if sharex and nrows > 1:
  2458. for ax in axarr[:naxes][:-ncols]: # only bottom row
  2459. for label in ax.get_xticklabels():
  2460. label.set_visible(False)
  2461. ax.xaxis.get_label().set_visible(False)
  2462. if sharey and ncols > 1:
  2463. for i, ax in enumerate(axarr):
  2464. if (i % ncols) != 0: # only first column
  2465. for label in ax.get_yticklabels():
  2466. label.set_visible(False)
  2467. ax.yaxis.get_label().set_visible(False)
  2468. if naxes != nplots:
  2469. for ax in axarr[naxes:]:
  2470. ax.set_visible(False)
  2471. if squeeze:
  2472. # Reshape the array to have the final desired dimension (nrow,ncol),
  2473. # though discarding unneeded dimensions that equal 1. If we only have
  2474. # one subplot, just return it instead of a 1-element array.
  2475. if nplots == 1:
  2476. axes = axarr[0]
  2477. else:
  2478. axes = axarr.reshape(nrows, ncols).squeeze()
  2479. else:
  2480. # returned axis array will be always 2-d, even if nrows=ncols=1
  2481. axes = axarr.reshape(nrows, ncols)
  2482. return fig, axes
  2483. def _flatten(axes):
  2484. if not com.is_list_like(axes):
  2485. axes = [axes]
  2486. elif isinstance(axes, np.ndarray):
  2487. axes = axes.ravel()
  2488. return axes
  2489. def _get_all_lines(ax):
  2490. lines = ax.get_lines()
  2491. # check for right_ax, which can oddly sometimes point back to ax
  2492. if hasattr(ax, 'right_ax') and ax.right_ax != ax:
  2493. lines += ax.right_ax.get_lines()
  2494. # no such risk with left_ax
  2495. if hasattr(ax, 'left_ax'):
  2496. lines += ax.left_ax.get_lines()
  2497. return lines
  2498. def _get_xlim(lines):
  2499. left, right = np.inf, -np.inf
  2500. for l in lines:
  2501. x = l.get_xdata(orig=False)
  2502. left = min(x[0], left)
  2503. right = max(x[-1], right)
  2504. return left, right
  2505. def _set_ticks_props(axes, xlabelsize=None, xrot=None,
  2506. ylabelsize=None, yrot=None):
  2507. import matplotlib.pyplot as plt
  2508. for ax in _flatten(axes):
  2509. if xlabelsize is not None:
  2510. plt.setp(ax.get_xticklabels(), fontsize=xlabelsize)
  2511. if xrot is not None:
  2512. plt.setp(ax.get_xticklabels(), rotation=xrot)
  2513. if ylabelsize is not None:
  2514. plt.setp(ax.get_yticklabels(), fontsize=ylabelsize)
  2515. if yrot is not None:
  2516. plt.setp(ax.get_yticklabels(), rotation=yrot)
  2517. return axes
  2518. if __name__ == '__main__':
  2519. # import pandas.rpy.common as com
  2520. # sales = com.load_data('sanfrancisco.home.sales', package='nutshell')
  2521. # top10 = sales['zip'].value_counts()[:10].index
  2522. # sales2 = sales[sales.zip.isin(top10)]
  2523. # _ = scatter_plot(sales2, 'squarefeet', 'price', by='zip')
  2524. # plt.show()
  2525. import matplotlib.pyplot as plt
  2526. import pandas.tools.plotting as plots
  2527. import pandas.core.frame as fr
  2528. reload(plots)
  2529. reload(fr)
  2530. from pandas.core.frame import DataFrame
  2531. data = DataFrame([[3, 6, -5], [4, 8, 2], [4, 9, -6],
  2532. [4, 9, -3], [2, 5, -1]],
  2533. columns=['A', 'B', 'C'])
  2534. data.plot(kind='barh', stacked=True)
  2535. plt.show()