PageRenderTime 72ms CodeModel.GetById 18ms RepoModel.GetById 0ms app.codeStats 1ms

/pandas/tools/plotting.py

http://github.com/wesm/pandas
Python | 3996 lines | 3667 code | 214 blank | 115 comment | 268 complexity | 62ae60e55cfe6d932a716eec2ddd081d MD5 | raw file
Possible License(s): BSD-3-Clause, Apache-2.0
  1. # being a bit too dynamic
  2. # pylint: disable=E1101
  3. from __future__ import division
  4. import warnings
  5. import re
  6. from math import ceil
  7. from collections import namedtuple
  8. from contextlib import contextmanager
  9. from distutils.version import LooseVersion
  10. import numpy as np
  11. from pandas.types.common import (is_list_like,
  12. is_integer,
  13. is_number,
  14. is_hashable,
  15. is_iterator)
  16. from pandas.types.missing import isnull, notnull
  17. from pandas.util.decorators import cache_readonly, deprecate_kwarg
  18. from pandas.core.base import PandasObject
  19. from pandas.core.common import AbstractMethodError, _try_sort
  20. from pandas.core.generic import _shared_docs, _shared_doc_kwargs
  21. from pandas.core.index import Index, MultiIndex
  22. from pandas.core.series import Series, remove_na
  23. from pandas.tseries.period import PeriodIndex
  24. from pandas.compat import range, lrange, lmap, map, zip, string_types
  25. import pandas.compat as compat
  26. from pandas.formats.printing import pprint_thing
  27. from pandas.util.decorators import Appender
  28. try: # mpl optional
  29. import pandas.tseries.converter as conv
  30. conv.register() # needs to override so set_xlim works with str/number
  31. except ImportError:
  32. pass
  33. # Extracted from https://gist.github.com/huyng/816622
  34. # this is the rcParams set when setting display.with_mpl_style
  35. # to True.
  36. mpl_stylesheet = {
  37. 'axes.axisbelow': True,
  38. 'axes.color_cycle': ['#348ABD',
  39. '#7A68A6',
  40. '#A60628',
  41. '#467821',
  42. '#CF4457',
  43. '#188487',
  44. '#E24A33'],
  45. 'axes.edgecolor': '#bcbcbc',
  46. 'axes.facecolor': '#eeeeee',
  47. 'axes.grid': True,
  48. 'axes.labelcolor': '#555555',
  49. 'axes.labelsize': 'large',
  50. 'axes.linewidth': 1.0,
  51. 'axes.titlesize': 'x-large',
  52. 'figure.edgecolor': 'white',
  53. 'figure.facecolor': 'white',
  54. 'figure.figsize': (6.0, 4.0),
  55. 'figure.subplot.hspace': 0.5,
  56. 'font.family': 'monospace',
  57. 'font.monospace': ['Andale Mono',
  58. 'Nimbus Mono L',
  59. 'Courier New',
  60. 'Courier',
  61. 'Fixed',
  62. 'Terminal',
  63. 'monospace'],
  64. 'font.size': 10,
  65. 'interactive': True,
  66. 'keymap.all_axes': ['a'],
  67. 'keymap.back': ['left', 'c', 'backspace'],
  68. 'keymap.forward': ['right', 'v'],
  69. 'keymap.fullscreen': ['f'],
  70. 'keymap.grid': ['g'],
  71. 'keymap.home': ['h', 'r', 'home'],
  72. 'keymap.pan': ['p'],
  73. 'keymap.save': ['s'],
  74. 'keymap.xscale': ['L', 'k'],
  75. 'keymap.yscale': ['l'],
  76. 'keymap.zoom': ['o'],
  77. 'legend.fancybox': True,
  78. 'lines.antialiased': True,
  79. 'lines.linewidth': 1.0,
  80. 'patch.antialiased': True,
  81. 'patch.edgecolor': '#EEEEEE',
  82. 'patch.facecolor': '#348ABD',
  83. 'patch.linewidth': 0.5,
  84. 'toolbar': 'toolbar2',
  85. 'xtick.color': '#555555',
  86. 'xtick.direction': 'in',
  87. 'xtick.major.pad': 6.0,
  88. 'xtick.major.size': 0.0,
  89. 'xtick.minor.pad': 6.0,
  90. 'xtick.minor.size': 0.0,
  91. 'ytick.color': '#555555',
  92. 'ytick.direction': 'in',
  93. 'ytick.major.pad': 6.0,
  94. 'ytick.major.size': 0.0,
  95. 'ytick.minor.pad': 6.0,
  96. 'ytick.minor.size': 0.0
  97. }
  98. def _mpl_le_1_2_1():
  99. try:
  100. import matplotlib as mpl
  101. return (str(mpl.__version__) <= LooseVersion('1.2.1') and
  102. str(mpl.__version__)[0] != '0')
  103. except ImportError:
  104. return False
  105. def _mpl_ge_1_3_1():
  106. try:
  107. import matplotlib
  108. # The or v[0] == '0' is because their versioneer is
  109. # messed up on dev
  110. return (matplotlib.__version__ >= LooseVersion('1.3.1') or
  111. matplotlib.__version__[0] == '0')
  112. except ImportError:
  113. return False
  114. def _mpl_ge_1_4_0():
  115. try:
  116. import matplotlib
  117. return (matplotlib.__version__ >= LooseVersion('1.4') or
  118. matplotlib.__version__[0] == '0')
  119. except ImportError:
  120. return False
  121. def _mpl_ge_1_5_0():
  122. try:
  123. import matplotlib
  124. return (matplotlib.__version__ >= LooseVersion('1.5') or
  125. matplotlib.__version__[0] == '0')
  126. except ImportError:
  127. return False
  128. if _mpl_ge_1_5_0():
  129. # Compat with mp 1.5, which uses cycler.
  130. import cycler
  131. colors = mpl_stylesheet.pop('axes.color_cycle')
  132. mpl_stylesheet['axes.prop_cycle'] = cycler.cycler('color', colors)
  133. def _get_standard_kind(kind):
  134. return {'density': 'kde'}.get(kind, kind)
  135. def _get_standard_colors(num_colors=None, colormap=None, color_type='default',
  136. color=None):
  137. import matplotlib.pyplot as plt
  138. if color is None and colormap is not None:
  139. if isinstance(colormap, compat.string_types):
  140. import matplotlib.cm as cm
  141. cmap = colormap
  142. colormap = cm.get_cmap(colormap)
  143. if colormap is None:
  144. raise ValueError("Colormap {0} is not recognized".format(cmap))
  145. colors = lmap(colormap, np.linspace(0, 1, num=num_colors))
  146. elif color is not None:
  147. if colormap is not None:
  148. warnings.warn("'color' and 'colormap' cannot be used "
  149. "simultaneously. Using 'color'")
  150. colors = list(color) if is_list_like(color) else color
  151. else:
  152. if color_type == 'default':
  153. # need to call list() on the result to copy so we don't
  154. # modify the global rcParams below
  155. try:
  156. colors = [c['color']
  157. for c in list(plt.rcParams['axes.prop_cycle'])]
  158. except KeyError:
  159. colors = list(plt.rcParams.get('axes.color_cycle',
  160. list('bgrcmyk')))
  161. if isinstance(colors, compat.string_types):
  162. colors = list(colors)
  163. elif color_type == 'random':
  164. import random
  165. def random_color(column):
  166. random.seed(column)
  167. return [random.random() for _ in range(3)]
  168. colors = lmap(random_color, lrange(num_colors))
  169. else:
  170. raise ValueError("color_type must be either 'default' or 'random'")
  171. if isinstance(colors, compat.string_types):
  172. import matplotlib.colors
  173. conv = matplotlib.colors.ColorConverter()
  174. def _maybe_valid_colors(colors):
  175. try:
  176. [conv.to_rgba(c) for c in colors]
  177. return True
  178. except ValueError:
  179. return False
  180. # check whether the string can be convertable to single color
  181. maybe_single_color = _maybe_valid_colors([colors])
  182. # check whether each character can be convertable to colors
  183. maybe_color_cycle = _maybe_valid_colors(list(colors))
  184. if maybe_single_color and maybe_color_cycle and len(colors) > 1:
  185. msg = ("'{0}' can be parsed as both single color and "
  186. "color cycle. Specify each color using a list "
  187. "like ['{0}'] or {1}")
  188. raise ValueError(msg.format(colors, list(colors)))
  189. elif maybe_single_color:
  190. colors = [colors]
  191. else:
  192. # ``colors`` is regarded as color cycle.
  193. # mpl will raise error any of them is invalid
  194. pass
  195. if len(colors) != num_colors:
  196. multiple = num_colors // len(colors) - 1
  197. mod = num_colors % len(colors)
  198. colors += multiple * colors
  199. colors += colors[:mod]
  200. return colors
  201. class _Options(dict):
  202. """
  203. Stores pandas plotting options.
  204. Allows for parameter aliasing so you can just use parameter names that are
  205. the same as the plot function parameters, but is stored in a canonical
  206. format that makes it easy to breakdown into groups later
  207. """
  208. # alias so the names are same as plotting method parameter names
  209. _ALIASES = {'x_compat': 'xaxis.compat'}
  210. _DEFAULT_KEYS = ['xaxis.compat']
  211. def __init__(self):
  212. self['xaxis.compat'] = False
  213. def __getitem__(self, key):
  214. key = self._get_canonical_key(key)
  215. if key not in self:
  216. raise ValueError('%s is not a valid pandas plotting option' % key)
  217. return super(_Options, self).__getitem__(key)
  218. def __setitem__(self, key, value):
  219. key = self._get_canonical_key(key)
  220. return super(_Options, self).__setitem__(key, value)
  221. def __delitem__(self, key):
  222. key = self._get_canonical_key(key)
  223. if key in self._DEFAULT_KEYS:
  224. raise ValueError('Cannot remove default parameter %s' % key)
  225. return super(_Options, self).__delitem__(key)
  226. def __contains__(self, key):
  227. key = self._get_canonical_key(key)
  228. return super(_Options, self).__contains__(key)
  229. def reset(self):
  230. """
  231. Reset the option store to its initial state
  232. Returns
  233. -------
  234. None
  235. """
  236. self.__init__()
  237. def _get_canonical_key(self, key):
  238. return self._ALIASES.get(key, key)
  239. @contextmanager
  240. def use(self, key, value):
  241. """
  242. Temporarily set a parameter value using the with statement.
  243. Aliasing allowed.
  244. """
  245. old_value = self[key]
  246. try:
  247. self[key] = value
  248. yield self
  249. finally:
  250. self[key] = old_value
  251. plot_params = _Options()
  252. def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,
  253. diagonal='hist', marker='.', density_kwds=None,
  254. hist_kwds=None, range_padding=0.05, **kwds):
  255. """
  256. Draw a matrix of scatter plots.
  257. Parameters
  258. ----------
  259. frame : DataFrame
  260. alpha : float, optional
  261. amount of transparency applied
  262. figsize : (float,float), optional
  263. a tuple (width, height) in inches
  264. ax : Matplotlib axis object, optional
  265. grid : bool, optional
  266. setting this to True will show the grid
  267. diagonal : {'hist', 'kde'}
  268. pick between 'kde' and 'hist' for
  269. either Kernel Density Estimation or Histogram
  270. plot in the diagonal
  271. marker : str, optional
  272. Matplotlib marker type, default '.'
  273. hist_kwds : other plotting keyword arguments
  274. To be passed to hist function
  275. density_kwds : other plotting keyword arguments
  276. To be passed to kernel density estimate plot
  277. range_padding : float, optional
  278. relative extension of axis range in x and y
  279. with respect to (x_max - x_min) or (y_max - y_min),
  280. default 0.05
  281. kwds : other plotting keyword arguments
  282. To be passed to scatter function
  283. Examples
  284. --------
  285. >>> df = DataFrame(np.random.randn(1000, 4), columns=['A','B','C','D'])
  286. >>> scatter_matrix(df, alpha=0.2)
  287. """
  288. import matplotlib.pyplot as plt
  289. df = frame._get_numeric_data()
  290. n = df.columns.size
  291. naxes = n * n
  292. fig, axes = _subplots(naxes=naxes, figsize=figsize, ax=ax,
  293. squeeze=False)
  294. # no gaps between subplots
  295. fig.subplots_adjust(wspace=0, hspace=0)
  296. mask = notnull(df)
  297. marker = _get_marker_compat(marker)
  298. hist_kwds = hist_kwds or {}
  299. density_kwds = density_kwds or {}
  300. # workaround because `c='b'` is hardcoded in matplotlibs scatter method
  301. kwds.setdefault('c', plt.rcParams['patch.facecolor'])
  302. boundaries_list = []
  303. for a in df.columns:
  304. values = df[a].values[mask[a].values]
  305. rmin_, rmax_ = np.min(values), np.max(values)
  306. rdelta_ext = (rmax_ - rmin_) * range_padding / 2.
  307. boundaries_list.append((rmin_ - rdelta_ext, rmax_ + rdelta_ext))
  308. for i, a in zip(lrange(n), df.columns):
  309. for j, b in zip(lrange(n), df.columns):
  310. ax = axes[i, j]
  311. if i == j:
  312. values = df[a].values[mask[a].values]
  313. # Deal with the diagonal by drawing a histogram there.
  314. if diagonal == 'hist':
  315. ax.hist(values, **hist_kwds)
  316. elif diagonal in ('kde', 'density'):
  317. from scipy.stats import gaussian_kde
  318. y = values
  319. gkde = gaussian_kde(y)
  320. ind = np.linspace(y.min(), y.max(), 1000)
  321. ax.plot(ind, gkde.evaluate(ind), **density_kwds)
  322. ax.set_xlim(boundaries_list[i])
  323. else:
  324. common = (mask[a] & mask[b]).values
  325. ax.scatter(df[b][common], df[a][common],
  326. marker=marker, alpha=alpha, **kwds)
  327. ax.set_xlim(boundaries_list[j])
  328. ax.set_ylim(boundaries_list[i])
  329. ax.set_xlabel(b)
  330. ax.set_ylabel(a)
  331. if j != 0:
  332. ax.yaxis.set_visible(False)
  333. if i != n - 1:
  334. ax.xaxis.set_visible(False)
  335. if len(df.columns) > 1:
  336. lim1 = boundaries_list[0]
  337. locs = axes[0][1].yaxis.get_majorticklocs()
  338. locs = locs[(lim1[0] <= locs) & (locs <= lim1[1])]
  339. adj = (locs - lim1[0]) / (lim1[1] - lim1[0])
  340. lim0 = axes[0][0].get_ylim()
  341. adj = adj * (lim0[1] - lim0[0]) + lim0[0]
  342. axes[0][0].yaxis.set_ticks(adj)
  343. if np.all(locs == locs.astype(int)):
  344. # if all ticks are int
  345. locs = locs.astype(int)
  346. axes[0][0].yaxis.set_ticklabels(locs)
  347. _set_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)
  348. return axes
  349. def _gca():
  350. import matplotlib.pyplot as plt
  351. return plt.gca()
  352. def _gcf():
  353. import matplotlib.pyplot as plt
  354. return plt.gcf()
  355. def _get_marker_compat(marker):
  356. import matplotlib.lines as mlines
  357. import matplotlib as mpl
  358. if mpl.__version__ < '1.1.0' and marker == '.':
  359. return 'o'
  360. if marker not in mlines.lineMarkers:
  361. return 'o'
  362. return marker
  363. def radviz(frame, class_column, ax=None, color=None, colormap=None, **kwds):
  364. """RadViz - a multivariate data visualization algorithm
  365. Parameters:
  366. -----------
  367. frame: DataFrame
  368. class_column: str
  369. Column name containing class names
  370. ax: Matplotlib axis object, optional
  371. color: list or tuple, optional
  372. Colors to use for the different classes
  373. colormap : str or matplotlib colormap object, default None
  374. Colormap to select colors from. If string, load colormap with that name
  375. from matplotlib.
  376. kwds: keywords
  377. Options to pass to matplotlib scatter plotting method
  378. Returns:
  379. --------
  380. ax: Matplotlib axis object
  381. """
  382. import matplotlib.pyplot as plt
  383. import matplotlib.patches as patches
  384. def normalize(series):
  385. a = min(series)
  386. b = max(series)
  387. return (series - a) / (b - a)
  388. n = len(frame)
  389. classes = frame[class_column].drop_duplicates()
  390. class_col = frame[class_column]
  391. df = frame.drop(class_column, axis=1).apply(normalize)
  392. if ax is None:
  393. ax = plt.gca(xlim=[-1, 1], ylim=[-1, 1])
  394. to_plot = {}
  395. colors = _get_standard_colors(num_colors=len(classes), colormap=colormap,
  396. color_type='random', color=color)
  397. for kls in classes:
  398. to_plot[kls] = [[], []]
  399. m = len(frame.columns) - 1
  400. s = np.array([(np.cos(t), np.sin(t))
  401. for t in [2.0 * np.pi * (i / float(m))
  402. for i in range(m)]])
  403. for i in range(n):
  404. row = df.iloc[i].values
  405. row_ = np.repeat(np.expand_dims(row, axis=1), 2, axis=1)
  406. y = (s * row_).sum(axis=0) / row.sum()
  407. kls = class_col.iat[i]
  408. to_plot[kls][0].append(y[0])
  409. to_plot[kls][1].append(y[1])
  410. for i, kls in enumerate(classes):
  411. ax.scatter(to_plot[kls][0], to_plot[kls][1], color=colors[i],
  412. label=pprint_thing(kls), **kwds)
  413. ax.legend()
  414. ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor='none'))
  415. for xy, name in zip(s, df.columns):
  416. ax.add_patch(patches.Circle(xy, radius=0.025, facecolor='gray'))
  417. if xy[0] < 0.0 and xy[1] < 0.0:
  418. ax.text(xy[0] - 0.025, xy[1] - 0.025, name,
  419. ha='right', va='top', size='small')
  420. elif xy[0] < 0.0 and xy[1] >= 0.0:
  421. ax.text(xy[0] - 0.025, xy[1] + 0.025, name,
  422. ha='right', va='bottom', size='small')
  423. elif xy[0] >= 0.0 and xy[1] < 0.0:
  424. ax.text(xy[0] + 0.025, xy[1] - 0.025, name,
  425. ha='left', va='top', size='small')
  426. elif xy[0] >= 0.0 and xy[1] >= 0.0:
  427. ax.text(xy[0] + 0.025, xy[1] + 0.025, name,
  428. ha='left', va='bottom', size='small')
  429. ax.axis('equal')
  430. return ax
  431. @deprecate_kwarg(old_arg_name='data', new_arg_name='frame')
  432. def andrews_curves(frame, class_column, ax=None, samples=200, color=None,
  433. colormap=None, **kwds):
  434. """
  435. Generates a matplotlib plot of Andrews curves, for visualising clusters of
  436. multivariate data.
  437. Andrews curves have the functional form:
  438. f(t) = x_1/sqrt(2) + x_2 sin(t) + x_3 cos(t) +
  439. x_4 sin(2t) + x_5 cos(2t) + ...
  440. Where x coefficients correspond to the values of each dimension and t is
  441. linearly spaced between -pi and +pi. Each row of frame then corresponds to
  442. a single curve.
  443. Parameters:
  444. -----------
  445. frame : DataFrame
  446. Data to be plotted, preferably normalized to (0.0, 1.0)
  447. class_column : Name of the column containing class names
  448. ax : matplotlib axes object, default None
  449. samples : Number of points to plot in each curve
  450. color: list or tuple, optional
  451. Colors to use for the different classes
  452. colormap : str or matplotlib colormap object, default None
  453. Colormap to select colors from. If string, load colormap with that name
  454. from matplotlib.
  455. kwds: keywords
  456. Options to pass to matplotlib plotting method
  457. Returns:
  458. --------
  459. ax: Matplotlib axis object
  460. """
  461. from math import sqrt, pi
  462. import matplotlib.pyplot as plt
  463. def function(amplitudes):
  464. def f(t):
  465. x1 = amplitudes[0]
  466. result = x1 / sqrt(2.0)
  467. # Take the rest of the coefficients and resize them
  468. # appropriately. Take a copy of amplitudes as otherwise numpy
  469. # deletes the element from amplitudes itself.
  470. coeffs = np.delete(np.copy(amplitudes), 0)
  471. coeffs.resize(int((coeffs.size + 1) / 2), 2)
  472. # Generate the harmonics and arguments for the sin and cos
  473. # functions.
  474. harmonics = np.arange(0, coeffs.shape[0]) + 1
  475. trig_args = np.outer(harmonics, t)
  476. result += np.sum(coeffs[:, 0, np.newaxis] * np.sin(trig_args) +
  477. coeffs[:, 1, np.newaxis] * np.cos(trig_args),
  478. axis=0)
  479. return result
  480. return f
  481. n = len(frame)
  482. class_col = frame[class_column]
  483. classes = frame[class_column].drop_duplicates()
  484. df = frame.drop(class_column, axis=1)
  485. t = np.linspace(-pi, pi, samples)
  486. used_legends = set([])
  487. color_values = _get_standard_colors(num_colors=len(classes),
  488. colormap=colormap, color_type='random',
  489. color=color)
  490. colors = dict(zip(classes, color_values))
  491. if ax is None:
  492. ax = plt.gca(xlim=(-pi, pi))
  493. for i in range(n):
  494. row = df.iloc[i].values
  495. f = function(row)
  496. y = f(t)
  497. kls = class_col.iat[i]
  498. label = pprint_thing(kls)
  499. if label not in used_legends:
  500. used_legends.add(label)
  501. ax.plot(t, y, color=colors[kls], label=label, **kwds)
  502. else:
  503. ax.plot(t, y, color=colors[kls], **kwds)
  504. ax.legend(loc='upper right')
  505. ax.grid()
  506. return ax
  507. def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds):
  508. """Bootstrap plot.
  509. Parameters:
  510. -----------
  511. series: Time series
  512. fig: matplotlib figure object, optional
  513. size: number of data points to consider during each sampling
  514. samples: number of times the bootstrap procedure is performed
  515. kwds: optional keyword arguments for plotting commands, must be accepted
  516. by both hist and plot
  517. Returns:
  518. --------
  519. fig: matplotlib figure
  520. """
  521. import random
  522. import matplotlib.pyplot as plt
  523. # random.sample(ndarray, int) fails on python 3.3, sigh
  524. data = list(series.values)
  525. samplings = [random.sample(data, size) for _ in range(samples)]
  526. means = np.array([np.mean(sampling) for sampling in samplings])
  527. medians = np.array([np.median(sampling) for sampling in samplings])
  528. midranges = np.array([(min(sampling) + max(sampling)) * 0.5
  529. for sampling in samplings])
  530. if fig is None:
  531. fig = plt.figure()
  532. x = lrange(samples)
  533. axes = []
  534. ax1 = fig.add_subplot(2, 3, 1)
  535. ax1.set_xlabel("Sample")
  536. axes.append(ax1)
  537. ax1.plot(x, means, **kwds)
  538. ax2 = fig.add_subplot(2, 3, 2)
  539. ax2.set_xlabel("Sample")
  540. axes.append(ax2)
  541. ax2.plot(x, medians, **kwds)
  542. ax3 = fig.add_subplot(2, 3, 3)
  543. ax3.set_xlabel("Sample")
  544. axes.append(ax3)
  545. ax3.plot(x, midranges, **kwds)
  546. ax4 = fig.add_subplot(2, 3, 4)
  547. ax4.set_xlabel("Mean")
  548. axes.append(ax4)
  549. ax4.hist(means, **kwds)
  550. ax5 = fig.add_subplot(2, 3, 5)
  551. ax5.set_xlabel("Median")
  552. axes.append(ax5)
  553. ax5.hist(medians, **kwds)
  554. ax6 = fig.add_subplot(2, 3, 6)
  555. ax6.set_xlabel("Midrange")
  556. axes.append(ax6)
  557. ax6.hist(midranges, **kwds)
  558. for axis in axes:
  559. plt.setp(axis.get_xticklabels(), fontsize=8)
  560. plt.setp(axis.get_yticklabels(), fontsize=8)
  561. return fig
  562. @deprecate_kwarg(old_arg_name='colors', new_arg_name='color')
  563. @deprecate_kwarg(old_arg_name='data', new_arg_name='frame', stacklevel=3)
  564. def parallel_coordinates(frame, class_column, cols=None, ax=None, color=None,
  565. use_columns=False, xticks=None, colormap=None,
  566. axvlines=True, axvlines_kwds=None, **kwds):
  567. """Parallel coordinates plotting.
  568. Parameters
  569. ----------
  570. frame: DataFrame
  571. class_column: str
  572. Column name containing class names
  573. cols: list, optional
  574. A list of column names to use
  575. ax: matplotlib.axis, optional
  576. matplotlib axis object
  577. color: list or tuple, optional
  578. Colors to use for the different classes
  579. use_columns: bool, optional
  580. If true, columns will be used as xticks
  581. xticks: list or tuple, optional
  582. A list of values to use for xticks
  583. colormap: str or matplotlib colormap, default None
  584. Colormap to use for line colors.
  585. axvlines: bool, optional
  586. If true, vertical lines will be added at each xtick
  587. axvlines_kwds: keywords, optional
  588. Options to be passed to axvline method for vertical lines
  589. kwds: keywords
  590. Options to pass to matplotlib plotting method
  591. Returns
  592. -------
  593. ax: matplotlib axis object
  594. Examples
  595. --------
  596. >>> from pandas import read_csv
  597. >>> from pandas.tools.plotting import parallel_coordinates
  598. >>> from matplotlib import pyplot as plt
  599. >>> df = read_csv('https://raw.github.com/pydata/pandas/master'
  600. '/pandas/tests/data/iris.csv')
  601. >>> parallel_coordinates(df, 'Name', color=('#556270',
  602. '#4ECDC4', '#C7F464'))
  603. >>> plt.show()
  604. """
  605. if axvlines_kwds is None:
  606. axvlines_kwds = {'linewidth': 1, 'color': 'black'}
  607. import matplotlib.pyplot as plt
  608. n = len(frame)
  609. classes = frame[class_column].drop_duplicates()
  610. class_col = frame[class_column]
  611. if cols is None:
  612. df = frame.drop(class_column, axis=1)
  613. else:
  614. df = frame[cols]
  615. used_legends = set([])
  616. ncols = len(df.columns)
  617. # determine values to use for xticks
  618. if use_columns is True:
  619. if not np.all(np.isreal(list(df.columns))):
  620. raise ValueError('Columns must be numeric to be used as xticks')
  621. x = df.columns
  622. elif xticks is not None:
  623. if not np.all(np.isreal(xticks)):
  624. raise ValueError('xticks specified must be numeric')
  625. elif len(xticks) != ncols:
  626. raise ValueError('Length of xticks must match number of columns')
  627. x = xticks
  628. else:
  629. x = lrange(ncols)
  630. if ax is None:
  631. ax = plt.gca()
  632. color_values = _get_standard_colors(num_colors=len(classes),
  633. colormap=colormap, color_type='random',
  634. color=color)
  635. colors = dict(zip(classes, color_values))
  636. for i in range(n):
  637. y = df.iloc[i].values
  638. kls = class_col.iat[i]
  639. label = pprint_thing(kls)
  640. if label not in used_legends:
  641. used_legends.add(label)
  642. ax.plot(x, y, color=colors[kls], label=label, **kwds)
  643. else:
  644. ax.plot(x, y, color=colors[kls], **kwds)
  645. if axvlines:
  646. for i in x:
  647. ax.axvline(i, **axvlines_kwds)
  648. ax.set_xticks(x)
  649. ax.set_xticklabels(df.columns)
  650. ax.set_xlim(x[0], x[-1])
  651. ax.legend(loc='upper right')
  652. ax.grid()
  653. return ax
  654. def lag_plot(series, lag=1, ax=None, **kwds):
  655. """Lag plot for time series.
  656. Parameters:
  657. -----------
  658. series: Time series
  659. lag: lag of the scatter plot, default 1
  660. ax: Matplotlib axis object, optional
  661. kwds: Matplotlib scatter method keyword arguments, optional
  662. Returns:
  663. --------
  664. ax: Matplotlib axis object
  665. """
  666. import matplotlib.pyplot as plt
  667. # workaround because `c='b'` is hardcoded in matplotlibs scatter method
  668. kwds.setdefault('c', plt.rcParams['patch.facecolor'])
  669. data = series.values
  670. y1 = data[:-lag]
  671. y2 = data[lag:]
  672. if ax is None:
  673. ax = plt.gca()
  674. ax.set_xlabel("y(t)")
  675. ax.set_ylabel("y(t + %s)" % lag)
  676. ax.scatter(y1, y2, **kwds)
  677. return ax
  678. def autocorrelation_plot(series, ax=None, **kwds):
  679. """Autocorrelation plot for time series.
  680. Parameters:
  681. -----------
  682. series: Time series
  683. ax: Matplotlib axis object, optional
  684. kwds : keywords
  685. Options to pass to matplotlib plotting method
  686. Returns:
  687. -----------
  688. ax: Matplotlib axis object
  689. """
  690. import matplotlib.pyplot as plt
  691. n = len(series)
  692. data = np.asarray(series)
  693. if ax is None:
  694. ax = plt.gca(xlim=(1, n), ylim=(-1.0, 1.0))
  695. mean = np.mean(data)
  696. c0 = np.sum((data - mean) ** 2) / float(n)
  697. def r(h):
  698. return ((data[:n - h] - mean) *
  699. (data[h:] - mean)).sum() / float(n) / c0
  700. x = np.arange(n) + 1
  701. y = lmap(r, x)
  702. z95 = 1.959963984540054
  703. z99 = 2.5758293035489004
  704. ax.axhline(y=z99 / np.sqrt(n), linestyle='--', color='grey')
  705. ax.axhline(y=z95 / np.sqrt(n), color='grey')
  706. ax.axhline(y=0.0, color='black')
  707. ax.axhline(y=-z95 / np.sqrt(n), color='grey')
  708. ax.axhline(y=-z99 / np.sqrt(n), linestyle='--', color='grey')
  709. ax.set_xlabel("Lag")
  710. ax.set_ylabel("Autocorrelation")
  711. ax.plot(x, y, **kwds)
  712. if 'label' in kwds:
  713. ax.legend()
  714. ax.grid()
  715. return ax
  716. class MPLPlot(object):
  717. """
  718. Base class for assembling a pandas plot using matplotlib
  719. Parameters
  720. ----------
  721. data :
  722. """
  723. @property
  724. def _kind(self):
  725. """Specify kind str. Must be overridden in child class"""
  726. raise NotImplementedError
  727. _layout_type = 'vertical'
  728. _default_rot = 0
  729. orientation = None
  730. _pop_attributes = ['label', 'style', 'logy', 'logx', 'loglog',
  731. 'mark_right', 'stacked']
  732. _attr_defaults = {'logy': False, 'logx': False, 'loglog': False,
  733. 'mark_right': True, 'stacked': False}
  734. def __init__(self, data, kind=None, by=None, subplots=False, sharex=None,
  735. sharey=False, use_index=True,
  736. figsize=None, grid=None, legend=True, rot=None,
  737. ax=None, fig=None, title=None, xlim=None, ylim=None,
  738. xticks=None, yticks=None,
  739. sort_columns=False, fontsize=None,
  740. secondary_y=False, colormap=None,
  741. table=False, layout=None, **kwds):
  742. self.data = data
  743. self.by = by
  744. self.kind = kind
  745. self.sort_columns = sort_columns
  746. self.subplots = subplots
  747. if sharex is None:
  748. if ax is None:
  749. self.sharex = True
  750. else:
  751. # if we get an axis, the users should do the visibility
  752. # setting...
  753. self.sharex = False
  754. else:
  755. self.sharex = sharex
  756. self.sharey = sharey
  757. self.figsize = figsize
  758. self.layout = layout
  759. self.xticks = xticks
  760. self.yticks = yticks
  761. self.xlim = xlim
  762. self.ylim = ylim
  763. self.title = title
  764. self.use_index = use_index
  765. self.fontsize = fontsize
  766. if rot is not None:
  767. self.rot = rot
  768. # need to know for format_date_labels since it's rotated to 30 by
  769. # default
  770. self._rot_set = True
  771. else:
  772. self._rot_set = False
  773. self.rot = self._default_rot
  774. if grid is None:
  775. grid = False if secondary_y else self.plt.rcParams['axes.grid']
  776. self.grid = grid
  777. self.legend = legend
  778. self.legend_handles = []
  779. self.legend_labels = []
  780. for attr in self._pop_attributes:
  781. value = kwds.pop(attr, self._attr_defaults.get(attr, None))
  782. setattr(self, attr, value)
  783. self.ax = ax
  784. self.fig = fig
  785. self.axes = None
  786. # parse errorbar input if given
  787. xerr = kwds.pop('xerr', None)
  788. yerr = kwds.pop('yerr', None)
  789. self.errors = {}
  790. for kw, err in zip(['xerr', 'yerr'], [xerr, yerr]):
  791. self.errors[kw] = self._parse_errorbars(kw, err)
  792. if not isinstance(secondary_y, (bool, tuple, list, np.ndarray, Index)):
  793. secondary_y = [secondary_y]
  794. self.secondary_y = secondary_y
  795. # ugly TypeError if user passes matplotlib's `cmap` name.
  796. # Probably better to accept either.
  797. if 'cmap' in kwds and colormap:
  798. raise TypeError("Only specify one of `cmap` and `colormap`.")
  799. elif 'cmap' in kwds:
  800. self.colormap = kwds.pop('cmap')
  801. else:
  802. self.colormap = colormap
  803. self.table = table
  804. self.kwds = kwds
  805. self._validate_color_args()
  806. def _validate_color_args(self):
  807. if 'color' not in self.kwds and 'colors' in self.kwds:
  808. warnings.warn(("'colors' is being deprecated. Please use 'color'"
  809. "instead of 'colors'"))
  810. colors = self.kwds.pop('colors')
  811. self.kwds['color'] = colors
  812. if ('color' in self.kwds and self.nseries == 1):
  813. # support series.plot(color='green')
  814. self.kwds['color'] = [self.kwds['color']]
  815. if ('color' in self.kwds or 'colors' in self.kwds) and \
  816. self.colormap is not None:
  817. warnings.warn("'color' and 'colormap' cannot be used "
  818. "simultaneously. Using 'color'")
  819. if 'color' in self.kwds and self.style is not None:
  820. if is_list_like(self.style):
  821. styles = self.style
  822. else:
  823. styles = [self.style]
  824. # need only a single match
  825. for s in styles:
  826. if re.match('^[a-z]+?', s) is not None:
  827. raise ValueError(
  828. "Cannot pass 'style' string with a color "
  829. "symbol and 'color' keyword argument. Please"
  830. " use one or the other or pass 'style' "
  831. "without a color symbol")
  832. def _iter_data(self, data=None, keep_index=False, fillna=None):
  833. if data is None:
  834. data = self.data
  835. if fillna is not None:
  836. data = data.fillna(fillna)
  837. # TODO: unused?
  838. # if self.sort_columns:
  839. # columns = _try_sort(data.columns)
  840. # else:
  841. # columns = data.columns
  842. for col, values in data.iteritems():
  843. if keep_index is True:
  844. yield col, values
  845. else:
  846. yield col, values.values
  847. @property
  848. def nseries(self):
  849. if self.data.ndim == 1:
  850. return 1
  851. else:
  852. return self.data.shape[1]
  853. def draw(self):
  854. self.plt.draw_if_interactive()
  855. def generate(self):
  856. self._args_adjust()
  857. self._compute_plot_data()
  858. self._setup_subplots()
  859. self._make_plot()
  860. self._add_table()
  861. self._make_legend()
  862. self._adorn_subplots()
  863. for ax in self.axes:
  864. self._post_plot_logic_common(ax, self.data)
  865. self._post_plot_logic(ax, self.data)
  866. def _args_adjust(self):
  867. pass
  868. def _has_plotted_object(self, ax):
  869. """check whether ax has data"""
  870. return (len(ax.lines) != 0 or
  871. len(ax.artists) != 0 or
  872. len(ax.containers) != 0)
  873. def _maybe_right_yaxis(self, ax, axes_num):
  874. if not self.on_right(axes_num):
  875. # secondary axes may be passed via ax kw
  876. return self._get_ax_layer(ax)
  877. if hasattr(ax, 'right_ax'):
  878. # if it has right_ax proparty, ``ax`` must be left axes
  879. return ax.right_ax
  880. elif hasattr(ax, 'left_ax'):
  881. # if it has left_ax proparty, ``ax`` must be right axes
  882. return ax
  883. else:
  884. # otherwise, create twin axes
  885. orig_ax, new_ax = ax, ax.twinx()
  886. # TODO: use Matplotlib public API when available
  887. new_ax._get_lines = orig_ax._get_lines
  888. new_ax._get_patches_for_fill = orig_ax._get_patches_for_fill
  889. orig_ax.right_ax, new_ax.left_ax = new_ax, orig_ax
  890. if not self._has_plotted_object(orig_ax): # no data on left y
  891. orig_ax.get_yaxis().set_visible(False)
  892. return new_ax
  893. def _setup_subplots(self):
  894. if self.subplots:
  895. fig, axes = _subplots(naxes=self.nseries,
  896. sharex=self.sharex, sharey=self.sharey,
  897. figsize=self.figsize, ax=self.ax,
  898. layout=self.layout,
  899. layout_type=self._layout_type)
  900. else:
  901. if self.ax is None:
  902. fig = self.plt.figure(figsize=self.figsize)
  903. axes = fig.add_subplot(111)
  904. else:
  905. fig = self.ax.get_figure()
  906. if self.figsize is not None:
  907. fig.set_size_inches(self.figsize)
  908. axes = self.ax
  909. axes = _flatten(axes)
  910. if self.logx or self.loglog:
  911. [a.set_xscale('log') for a in axes]
  912. if self.logy or self.loglog:
  913. [a.set_yscale('log') for a in axes]
  914. self.fig = fig
  915. self.axes = axes
  916. @property
  917. def result(self):
  918. """
  919. Return result axes
  920. """
  921. if self.subplots:
  922. if self.layout is not None and not is_list_like(self.ax):
  923. return self.axes.reshape(*self.layout)
  924. else:
  925. return self.axes
  926. else:
  927. sec_true = isinstance(self.secondary_y, bool) and self.secondary_y
  928. all_sec = (is_list_like(self.secondary_y) and
  929. len(self.secondary_y) == self.nseries)
  930. if (sec_true or all_sec):
  931. # if all data is plotted on secondary, return right axes
  932. return self._get_ax_layer(self.axes[0], primary=False)
  933. else:
  934. return self.axes[0]
  935. def _compute_plot_data(self):
  936. data = self.data
  937. if isinstance(data, Series):
  938. label = self.label
  939. if label is None and data.name is None:
  940. label = 'None'
  941. data = data.to_frame(name=label)
  942. numeric_data = data._convert(datetime=True)._get_numeric_data()
  943. try:
  944. is_empty = numeric_data.empty
  945. except AttributeError:
  946. is_empty = not len(numeric_data)
  947. # no empty frames or series allowed
  948. if is_empty:
  949. raise TypeError('Empty {0!r}: no numeric data to '
  950. 'plot'.format(numeric_data.__class__.__name__))
  951. self.data = numeric_data
  952. def _make_plot(self):
  953. raise AbstractMethodError(self)
  954. def _add_table(self):
  955. if self.table is False:
  956. return
  957. elif self.table is True:
  958. data = self.data.transpose()
  959. else:
  960. data = self.table
  961. ax = self._get_ax(0)
  962. table(ax, data)
  963. def _post_plot_logic_common(self, ax, data):
  964. """Common post process for each axes"""
  965. labels = [pprint_thing(key) for key in data.index]
  966. labels = dict(zip(range(len(data.index)), labels))
  967. if self.orientation == 'vertical' or self.orientation is None:
  968. if self._need_to_set_index:
  969. xticklabels = [labels.get(x, '') for x in ax.get_xticks()]
  970. ax.set_xticklabels(xticklabels)
  971. self._apply_axis_properties(ax.xaxis, rot=self.rot,
  972. fontsize=self.fontsize)
  973. self._apply_axis_properties(ax.yaxis, fontsize=self.fontsize)
  974. elif self.orientation == 'horizontal':
  975. if self._need_to_set_index:
  976. yticklabels = [labels.get(y, '') for y in ax.get_yticks()]
  977. ax.set_yticklabels(yticklabels)
  978. self._apply_axis_properties(ax.yaxis, rot=self.rot,
  979. fontsize=self.fontsize)
  980. self._apply_axis_properties(ax.xaxis, fontsize=self.fontsize)
  981. else: # pragma no cover
  982. raise ValueError
  983. def _post_plot_logic(self, ax, data):
  984. """Post process for each axes. Overridden in child classes"""
  985. pass
  986. def _adorn_subplots(self):
  987. """Common post process unrelated to data"""
  988. if len(self.axes) > 0:
  989. all_axes = self._get_subplots()
  990. nrows, ncols = self._get_axes_layout()
  991. _handle_shared_axes(axarr=all_axes, nplots=len(all_axes),
  992. naxes=nrows * ncols, nrows=nrows,
  993. ncols=ncols, sharex=self.sharex,
  994. sharey=self.sharey)
  995. for ax in self.axes:
  996. if self.yticks is not None:
  997. ax.set_yticks(self.yticks)
  998. if self.xticks is not None:
  999. ax.set_xticks(self.xticks)
  1000. if self.ylim is not None:
  1001. ax.set_ylim(self.ylim)
  1002. if self.xlim is not None:
  1003. ax.set_xlim(self.xlim)
  1004. ax.grid(self.grid)
  1005. if self.title:
  1006. if self.subplots:
  1007. self.fig.suptitle(self.title)
  1008. else:
  1009. self.axes[0].set_title(self.title)
  1010. def _apply_axis_properties(self, axis, rot=None, fontsize=None):
  1011. labels = axis.get_majorticklabels() + axis.get_minorticklabels()
  1012. for label in labels:
  1013. if rot is not None:
  1014. label.set_rotation(rot)
  1015. if fontsize is not None:
  1016. label.set_fontsize(fontsize)
  1017. @property
  1018. def legend_title(self):
  1019. if not isinstance(self.data.columns, MultiIndex):
  1020. name = self.data.columns.name
  1021. if name is not None:
  1022. name = pprint_thing(name)
  1023. return name
  1024. else:
  1025. stringified = map(pprint_thing,
  1026. self.data.columns.names)
  1027. return ','.join(stringified)
  1028. def _add_legend_handle(self, handle, label, index=None):
  1029. if label is not None:
  1030. if self.mark_right and index is not None:
  1031. if self.on_right(index):
  1032. label = label + ' (right)'
  1033. self.legend_handles.append(handle)
  1034. self.legend_labels.append(label)
  1035. def _make_legend(self):
  1036. ax, leg = self._get_ax_legend(self.axes[0])
  1037. handles = []
  1038. labels = []
  1039. title = ''
  1040. if not self.subplots:
  1041. if leg is not None:
  1042. title = leg.get_title().get_text()
  1043. handles = leg.legendHandles
  1044. labels = [x.get_text() for x in leg.get_texts()]
  1045. if self.legend:
  1046. if self.legend == 'reverse':
  1047. self.legend_handles = reversed(self.legend_handles)
  1048. self.legend_labels = reversed(self.legend_labels)
  1049. handles += self.legend_handles
  1050. labels += self.legend_labels
  1051. if self.legend_title is not None:
  1052. title = self.legend_title
  1053. if len(handles) > 0:
  1054. ax.legend(handles, labels, loc='best', title=title)
  1055. elif self.subplots and self.legend:
  1056. for ax in self.axes:
  1057. if ax.get_visible():
  1058. ax.legend(loc='best')
  1059. def _get_ax_legend(self, ax):
  1060. leg = ax.get_legend()
  1061. other_ax = (getattr(ax, 'left_ax', None) or
  1062. getattr(ax, 'right_ax', None))
  1063. other_leg = None
  1064. if other_ax is not None:
  1065. other_leg = other_ax.get_legend()
  1066. if leg is None and other_leg is not None:
  1067. leg = other_leg
  1068. ax = other_ax
  1069. return ax, leg
  1070. @cache_readonly
  1071. def plt(self):
  1072. import matplotlib.pyplot as plt
  1073. return plt
  1074. @staticmethod
  1075. def mpl_ge_1_3_1():
  1076. return _mpl_ge_1_3_1()
  1077. @staticmethod
  1078. def mpl_ge_1_5_0():
  1079. return _mpl_ge_1_5_0()
  1080. _need_to_set_index = False
  1081. def _get_xticks(self, convert_period=False):
  1082. index = self.data.index
  1083. is_datetype = index.inferred_type in ('datetime', 'date',
  1084. 'datetime64', 'time')
  1085. if self.use_index:
  1086. if convert_period and isinstance(index, PeriodIndex):
  1087. self.data = self.data.reindex(index=index.sort_values())
  1088. x = self.data.index.to_timestamp()._mpl_repr()
  1089. elif index.is_numeric():
  1090. """
  1091. Matplotlib supports numeric values or datetime objects as
  1092. xaxis values. Taking LBYL approach here, by the time
  1093. matplotlib raises exception when using non numeric/datetime
  1094. values for xaxis, several actions are already taken by plt.
  1095. """
  1096. x = index._mpl_repr()
  1097. elif is_datetype:
  1098. self.data = self.data.sort_index()
  1099. x = self.data.index._mpl_repr()
  1100. else:
  1101. self._need_to_set_index = True
  1102. x = lrange(len(index))
  1103. else:
  1104. x = lrange(len(index))
  1105. return x
  1106. @classmethod
  1107. def _plot(cls, ax, x, y, style=None, is_errorbar=False, **kwds):
  1108. mask = isnull(y)
  1109. if mask.any():
  1110. y = np.ma.array(y)
  1111. y = np.ma.masked_where(mask, y)
  1112. if isinstance(x, Index):
  1113. x = x._mpl_repr()
  1114. if is_errorbar:
  1115. if 'xerr' in kwds:
  1116. kwds['xerr'] = np.array(kwds.get('xerr'))
  1117. if 'yerr' in kwds:
  1118. kwds['yerr'] = np.array(kwds.get('yerr'))
  1119. return ax.errorbar(x, y, **kwds)
  1120. else:
  1121. # prevent style kwarg from going to errorbar, where it is
  1122. # unsupported
  1123. if style is not None:
  1124. args = (x, y, style)
  1125. else:
  1126. args = (x, y)
  1127. return ax.plot(*args, **kwds)
  1128. def _get_index_name(self):
  1129. if isinstance(self.data.index, MultiIndex):
  1130. name = self.data.index.names
  1131. if any(x is not None for x in name):
  1132. name = ','.join([pprint_thing(x) for x in name])
  1133. else:
  1134. name = None
  1135. else:
  1136. name = self.data.index.name
  1137. if name is not None:
  1138. name = pprint_thing(name)
  1139. return name
  1140. @classmethod
  1141. def _get_ax_layer(cls, ax, primary=True):
  1142. """get left (primary) or right (secondary) axes"""
  1143. if primary:
  1144. return getattr(ax, 'left_ax', ax)
  1145. else:
  1146. return getattr(ax, 'right_ax', ax)
  1147. def _get_ax(self, i):
  1148. # get the twinx ax if appropriate
  1149. if self.subplots:
  1150. ax = self.axes[i]
  1151. ax = self._maybe_right_yaxis(ax, i)
  1152. self.axes[i] = ax
  1153. else:
  1154. ax = self.axes[0]
  1155. ax = self._maybe_right_yaxis(ax, i)
  1156. ax.get_yaxis().set_visible(True)
  1157. return ax
  1158. def on_right(self, i):
  1159. if isinstance(self.secondary_y, bool):
  1160. return self.secondary_y
  1161. if isinstance(self.secondary_y, (tuple, list, np.ndarray, Index)):
  1162. return self.data.columns[i] in self.secondary_y
  1163. def _apply_style_colors(self, colors, kwds, col_num, label):
  1164. """
  1165. Manage style and color based on column number and its label.
  1166. Returns tuple of appropriate style and kwds which "color" may be added.
  1167. """
  1168. style = None
  1169. if self.style is not None:
  1170. if isinstance(self.style, list):
  1171. try:
  1172. style = self.style[col_num]
  1173. except IndexError:
  1174. pass
  1175. elif isinstance(self.style, dict):
  1176. style = self.style.get(label, style)
  1177. else:
  1178. style = self.style
  1179. has_color = 'color' in kwds or self.colormap is not None
  1180. nocolor_style = style is None or re.match('[a-z]+', style) is None
  1181. if (has_color or self.subplots) and nocolor_style:
  1182. kwds['color'] = colors[col_num % len(colors)]
  1183. return style, kwds
  1184. def _get_colors(self, num_colors=None, color_kwds='color'):
  1185. if num_colors is None:
  1186. num_colors = self.nseries
  1187. return _get_standard_colors(num_colors=num_colors,
  1188. colormap=self.colormap,
  1189. color=self.kwds.get(color_kwds))
  1190. def _parse_errorbars(self, label, err):
  1191. """
  1192. Look for error keyword arguments and return the actual errorbar data
  1193. or return the error DataFrame/dict
  1194. Error bars can be specified in several ways:
  1195. Series: the user provides a pandas.Series object of the same
  1196. length as the data
  1197. ndarray: provides a np.ndarray of the same length as the data
  1198. DataFrame/dict: error values are paired with keys matching the
  1199. key in the plotted DataFrame
  1200. str: the name of the column within the plotted DataFrame
  1201. """
  1202. if err is None:
  1203. return None
  1204. from pandas import DataFrame, Series
  1205. def match_labels(data, e):
  1206. e = e.reindex_axis(data.index)
  1207. return e
  1208. # key-matched DataFrame
  1209. if isinstance(err, DataFrame):
  1210. err = match_labels(self.data, err)
  1211. # key-matched dict
  1212. elif isinstance(err, dict):
  1213. pass
  1214. # Series of error values
  1215. elif isinstance(err, Series):
  1216. # broadcast error series across data
  1217. err = match_labels(self.data, err)
  1218. err = np.atleast_2d(err)
  1219. err = np.tile(err, (self.nseries, 1))
  1220. # errors are a column in the dataframe
  1221. elif isinstance(err, string_types):
  1222. evalues = self.data[err].values
  1223. self.data = self.data[self.data.columns.drop(err)]
  1224. err = np.atleast_2d(evalues)
  1225. err = np.tile(err, (self.nseries, 1))
  1226. elif is_list_like(err):
  1227. if is_iterator(err):
  1228. err = np.atleast_2d(list(err))
  1229. else:
  1230. # raw error values
  1231. err = np.atleast_2d(err)
  1232. err_shape = err.shape
  1233. # asymmetrical error bars
  1234. if err.ndim == 3:
  1235. if (err_shape[0] != self.nseries) or \
  1236. (err_shape[1] != 2) or \
  1237. (err_shape[2] != len(self.data)):
  1238. msg = "Asymmetrical error bars should be provided " + \
  1239. "with the shape (%u, 2, %u)" % \
  1240. (self.nseries, len(self.data))
  1241. raise ValueError(msg)
  1242. # broadcast errors to each data series
  1243. if len(err) == 1:
  1244. err = np.tile(err, (self.nseries, 1))
  1245. elif is_number(err):
  1246. err = np.tile([err], (self.nseries, len(self.data)))
  1247. else:
  1248. msg = "No valid %s detected" % label
  1249. raise ValueError(msg)
  1250. return err
  1251. def _get_errorbars(self, label=None, index=None, xerr=True, yerr=True):
  1252. from pandas import DataFrame
  1253. errors = {}
  1254. for kw, flag in zip(['xerr', 'yerr'], [xerr, yerr]):
  1255. if flag:
  1256. err = self.errors[kw]
  1257. # user provided label-matched dataframe of errors
  1258. if isinstance(err, (DataFrame, dict)):
  1259. if label is not None and label in err.keys():
  1260. err = err[label]
  1261. else:
  1262. err = None
  1263. elif index is not None and err is not None:
  1264. err = err[index]
  1265. if err is not None:
  1266. errors[kw] = err
  1267. return errors
  1268. def _get_subplots(self):
  1269. from matplotlib.axes import Subplot
  1270. return [ax for ax in self.axes[0].get_figure().get_axes()
  1271. if isinstance(ax, Subplot)]
  1272. def _get_axes_layout(self):
  1273. axes = self._get_subplots()
  1274. x_set = set()
  1275. y_set = set()
  1276. for ax in axes:
  1277. # check axes coordinates to estimate layout
  1278. points = ax.get_position().get_points()
  1279. x_set.add(points[0][0])
  1280. y_set.add(points[0][1])
  1281. return (len(y_set), len(x_set))
  1282. class PlanePlot(MPLPlot):
  1283. """
  1284. Abstract class for plotting on plane, currently scatter and hexbin.
  1285. """
  1286. _layout_type = 'single'
  1287. def __init__(self, data, x, y, **kwargs):
  1288. MPLPlot.__init__(self, data, **kwargs)
  1289. if x is None or y is None:
  1290. raise ValueError(self._kind + ' requires and x and y column')
  1291. if is_integer(x) and not self.data.columns.holds_integer():
  1292. x = self.data.columns[x]
  1293. if is_integer(y) and not self.data.columns.holds_integer():
  1294. y = self.data.columns[y]
  1295. self.x = x
  1296. self.y = y
  1297. @property
  1298. def nseries(self):
  1299. return 1
  1300. def _post_plot_logic(self, ax, data):
  1301. x, y = self.x, self.y
  1302. ax.set_ylabel(pprint_thing(y))
  1303. ax.set_xlabel(pprint_thing(x))
  1304. class ScatterPlot(PlanePlot):
  1305. _kind = 'scatter'
  1306. def __init__(self, data, x, y, s=None, c=None, **kwargs):
  1307. if s is None:
  1308. # hide the matplotlib default for size, in case we want to change
  1309. # the handling of this argument later
  1310. s = 20
  1311. super(ScatterPlot, self).__init__(data, x, y, s=s, **kwargs)
  1312. if is_integer(c) and not self.data.columns.holds_integer():
  1313. c = self.data.columns[c]
  1314. self.c = c
  1315. def _make_plot(self):
  1316. x, y, c, data = self.x, self.y, self.c, self.data
  1317. ax = self.axes[0]
  1318. c_is_column = is_hashable(c) and c in self.data.columns
  1319. # plot a colorbar only if a colormap is provided or necessary
  1320. cb = self.kwds.pop('colorbar', self.colormap or c_is_column)
  1321. # pandas uses colormap, matplotlib uses cmap.
  1322. cmap = self.colormap or 'Greys'
  1323. cmap = self.plt.cm.get_cmap(cmap)
  1324. color = self.kwds.pop("color", None)
  1325. if c is not None and color is not None:
  1326. raise TypeError('Specify exactly one of `c` and `color`')
  1327. elif c is None and color is None:
  1328. c_values = self.plt.rcParams['patch.facecolor']
  1329. elif color is not None:
  1330. c_values = color
  1331. elif c_is_column:
  1332. c_values = self.data[c].values
  1333. else:
  1334. c_values = c
  1335. if self.legend and hasattr(self, 'label'):
  1336. label = self.label
  1337. else:
  1338. label = None
  1339. scatter = ax.scatter(data[x].values, data[y].values, c=c_values,
  1340. label=label, cmap=cmap, **self.kwds)
  1341. if cb:
  1342. img = ax.collections[0]
  1343. kws = dict(ax=ax)
  1344. if self.mpl_ge_1_3_1():
  1345. kws['label'] = c if c_is_column else ''
  1346. self.fig.colorbar(img, **kws)
  1347. if label is not None:
  1348. self._add_legend_handle(scatter, label)
  1349. else:
  1350. self.legend = False
  1351. errors_x = self._get_errorbars(label=x, index=0, yerr=False)
  1352. errors_y = self._get_errorbars(label=y, index=0, xerr=False)
  1353. if len(errors_x) > 0 or len(errors_y) > 0:
  1354. err_kwds = dict(errors_x, **errors_y)
  1355. err_kwds['ecolor'] = scatter.get_facecolor()[0]
  1356. ax.errorbar(data[x].values, data[y].values,
  1357. linestyle='none', **err_kwds)
  1358. class HexBinPlot(PlanePlot):
  1359. _kind = 'hexbin'
  1360. def __init__(self, data, x, y, C=None, **kwargs):
  1361. super(HexBinPlot, self).__init__(data, x, y, **kwargs)
  1362. if is_integer(C) and not self.data.columns.holds_integer():
  1363. C = self.data.columns[C]
  1364. self.C = C
  1365. def _make_plot(self):
  1366. x, y, data, C = self.x, self.y, self.data, self.C
  1367. ax = self.axes[0]
  1368. # pandas uses colormap, matplotlib uses cmap.
  1369. cmap = self.colormap or 'BuGn'
  1370. cmap = self.plt.cm.get_cmap(cmap)
  1371. cb = self.kwds.pop('colorbar', True)
  1372. if C is None:
  1373. c_values = None
  1374. else:
  1375. c_values = data[C].values
  1376. ax.hexbin(data[x].values, data[y].values, C=c_values, cmap=cmap,
  1377. **self.kwds)
  1378. if cb:
  1379. img = ax.collections[0]
  1380. self.fig.colorbar(img, ax=ax)
  1381. def _make_legend(self):
  1382. pass
  1383. class LinePlot(MPLPlot):
  1384. _kind = 'line'
  1385. _default_rot = 0
  1386. orientation = 'vertical'
  1387. def __init__(self, data, **kwargs):
  1388. MPLPlot.__init__(self, data, **kwargs)
  1389. if self.stacked:
  1390. self.data = self.data.fillna(value=0)
  1391. self.x_compat = plot_params['x_compat']
  1392. if 'x_compat' in self.kwds:
  1393. self.x_compat = bool(self.kwds.pop('x_compat'))
  1394. def _is_ts_plot(self):
  1395. # this is slightly deceptive
  1396. return not self.x_compat and self.use_index and self._use_dynamic_x()
  1397. def _use_dynamic_x(self):
  1398. from pandas.tseries.plotting import _use_dynamic_x
  1399. return _use_dynamic_x(self._get_ax(0), self.data)
  1400. def _make_plot(self):
  1401. if self._is_ts_plot():
  1402. from pandas.tseries.plotting import _maybe_convert_index
  1403. data = _maybe_convert_index(self._get_ax(0), self.data)
  1404. x = data.index # dummy, not used
  1405. plotf = self._ts_plot
  1406. it = self._iter_data(data=data, keep_index=True)
  1407. else:
  1408. x = self._get_xticks(convert_period=True)
  1409. plotf = self._plot
  1410. it = self._iter_data()
  1411. stacking_id = self._get_stacking_id()
  1412. is_errorbar = any(e is not None for e in self.errors.values())
  1413. colors = self._get_colors()
  1414. for i, (label, y) in enumerate(it):
  1415. ax = self._get_ax(i)
  1416. kwds = self.kwds.copy()
  1417. style, kwds = self._apply_style_colors(colors, kwds, i, label)
  1418. errors = self._get_errorbars(label=label, index=i)
  1419. kwds = dict(kwds, **errors)
  1420. label = pprint_thing(label) # .encode('utf-8')
  1421. kwds['label'] = label
  1422. newlines = plotf(ax, x, y, style=style, column_num=i,
  1423. stacking_id=stacking_id,
  1424. is_errorbar=is_errorbar,
  1425. **kwds)
  1426. self._add_legend_handle(newlines[0], label, index=i)
  1427. lines = _get_all_lines(ax)
  1428. left, right = _get_xlim(lines)
  1429. ax.set_xlim(left, right)
  1430. @classmethod
  1431. def _plot(cls, ax, x, y, style=None, column_num=None,
  1432. stacking_id=None, **kwds):
  1433. # column_num is used to get the target column from protf in line and
  1434. # area plots
  1435. if column_num == 0:
  1436. cls._initialize_stacker(ax, stacking_id, len(y))
  1437. y_values = cls._get_stacked_values(ax, stacking_id, y, kwds['label'])
  1438. lines = MPLPlot._plot(ax, x, y_values, style=style, **kwds)
  1439. cls._update_stacker(ax, stacking_id, y)
  1440. return lines
  1441. @classmethod
  1442. def _ts_plot(cls, ax, x, data, style=None, **kwds):
  1443. from pandas.tseries.plotting import (_maybe_resample,
  1444. _decorate_axes,
  1445. format_dateaxis)
  1446. # accept x to be consistent with normal plot func,
  1447. # x is not passed to tsplot as it uses data.index as x coordinate
  1448. # column_num must be in kwds for stacking purpose
  1449. freq, data = _maybe_resample(data, ax, kwds)
  1450. # Set ax with freq info
  1451. _decorate_axes(ax, freq, kwds)
  1452. # digging deeper
  1453. if hasattr(ax, 'left_ax'):
  1454. _decorate_axes(ax.left_ax, freq, kwds)
  1455. if hasattr(ax, 'right_ax'):
  1456. _decorate_axes(ax.right_ax, freq, kwds)
  1457. ax._plot_data.append((data, cls._kind, kwds))
  1458. lines = cls._plot(ax, data.index, data.values, style=style, **kwds)
  1459. # set date formatter, locators and rescale limits
  1460. format_dateaxis(ax, ax.freq)
  1461. return lines
  1462. def _get_stacking_id(self):
  1463. if self.stacked:
  1464. return id(self.data)
  1465. else:
  1466. return None
  1467. @classmethod
  1468. def _initialize_stacker(cls, ax, stacking_id, n):
  1469. if stacking_id is None:
  1470. return
  1471. if not hasattr(ax, '_stacker_pos_prior'):
  1472. ax._stacker_pos_prior = {}
  1473. if not hasattr(ax, '_stacker_neg_prior'):
  1474. ax._stacker_neg_prior = {}
  1475. ax._stacker_pos_prior[stacking_id] = np.zeros(n)
  1476. ax._stacker_neg_prior[stacking_id] = np.zeros(n)
  1477. @classmethod
  1478. def _get_stacked_values(cls, ax, stacking_id, values, label):
  1479. if stacking_id is None:
  1480. return values
  1481. if not hasattr(ax, '_stacker_pos_prior'):
  1482. # stacker may not be initialized for subplots
  1483. cls._initialize_stacker(ax, stacking_id, len(values))
  1484. if (values >= 0).all():
  1485. return ax._stacker_pos_prior[stacking_id] + values
  1486. elif (values <= 0).all():
  1487. return ax._stacker_neg_prior[stacking_id] + values
  1488. raise ValueError('When stacked is True, each column must be either '
  1489. 'all positive or negative.'
  1490. '{0} contains both positive and negative values'
  1491. .format(label))
  1492. @classmethod
  1493. def _update_stacker(cls, ax, stacking_id, values):
  1494. if stacking_id is None:
  1495. return
  1496. if (values >= 0).all():
  1497. ax._stacker_pos_prior[stacking_id] += values
  1498. elif (values <= 0).all():
  1499. ax._stacker_neg_prior[stacking_id] += values
  1500. def _post_plot_logic(self, ax, data):
  1501. condition = (not self._use_dynamic_x() and
  1502. data.index.is_all_dates and
  1503. not self.subplots or
  1504. (self.subplots and self.sharex))
  1505. index_name = self._get_index_name()
  1506. if condition:
  1507. # irregular TS rotated 30 deg. by default
  1508. # probably a better place to check / set this.
  1509. if not self._rot_set:
  1510. self.rot = 30
  1511. format_date_labels(ax, rot=self.rot)
  1512. if index_name is not None and self.use_index:
  1513. ax.set_xlabel(index_name)
  1514. class AreaPlot(LinePlot):
  1515. _kind = 'area'
  1516. def __init__(self, data, **kwargs):
  1517. kwargs.setdefault('stacked', True)
  1518. data = data.fillna(value=0)
  1519. LinePlot.__init__(self, data, **kwargs)
  1520. if not self.stacked:
  1521. # use smaller alpha to distinguish overlap
  1522. self.kwds.setdefault('alpha', 0.5)
  1523. if self.logy or self.loglog:
  1524. raise ValueError("Log-y scales are not supported in area plot")
  1525. @classmethod
  1526. def _plot(cls, ax, x, y, style=None, column_num=None,
  1527. stacking_id=None, is_errorbar=False, **kwds):
  1528. if column_num == 0:
  1529. cls._initialize_stacker(ax, stacking_id, len(y))
  1530. y_values = cls._get_stacked_values(ax, stacking_id, y, kwds['label'])
  1531. # need to remove label, because subplots uses mpl legend as it is
  1532. line_kwds = kwds.copy()
  1533. if cls.mpl_ge_1_5_0():
  1534. line_kwds.pop('label')
  1535. lines = MPLPlot._plot(ax, x, y_values, style=style, **line_kwds)
  1536. # get data from the line to get coordinates for fill_between
  1537. xdata, y_values = lines[0].get_data(orig=False)
  1538. # unable to use ``_get_stacked_values`` here to get starting point
  1539. if stacking_id is None:
  1540. start = np.zeros(len(y))
  1541. elif (y >= 0).all():
  1542. start = ax._stacker_pos_prior[stacking_id]
  1543. elif (y <= 0).all():
  1544. start = ax._stacker_neg_prior[stacking_id]
  1545. else:
  1546. start = np.zeros(len(y))
  1547. if 'color' not in kwds:
  1548. kwds['color'] = lines[0].get_color()
  1549. rect = ax.fill_between(xdata, start, y_values, **kwds)
  1550. cls._update_stacker(ax, stacking_id, y)
  1551. # LinePlot expects list of artists
  1552. res = [rect] if cls.mpl_ge_1_5_0() else lines
  1553. return res
  1554. def _add_legend_handle(self, handle, label, index=None):
  1555. if not self.mpl_ge_1_5_0():
  1556. from matplotlib.patches import Rectangle
  1557. # Because fill_between isn't supported in legend,
  1558. # specifically add Rectangle handle here
  1559. alpha = self.kwds.get('alpha', None)
  1560. handle = Rectangle((0, 0), 1, 1, fc=handle.get_color(),
  1561. alpha=alpha)
  1562. LinePlot._add_legend_handle(self, handle, label, index=index)
  1563. def _post_plot_logic(self, ax, data):
  1564. LinePlot._post_plot_logic(self, ax, data)
  1565. if self.ylim is None:
  1566. if (data >= 0).all().all():
  1567. ax.set_ylim(0, None)
  1568. elif (data <= 0).all().all():
  1569. ax.set_ylim(None, 0)
  1570. class BarPlot(MPLPlot):
  1571. _kind = 'bar'
  1572. _default_rot = 90
  1573. orientation = 'vertical'
  1574. def __init__(self, data, **kwargs):
  1575. self.bar_width = kwargs.pop('width', 0.5)
  1576. pos = kwargs.pop('position', 0.5)
  1577. kwargs.setdefault('align', 'center')
  1578. self.tick_pos = np.arange(len(data))
  1579. self.bottom = kwargs.pop('bottom', 0)
  1580. self.left = kwargs.pop('left', 0)
  1581. self.log = kwargs.pop('log', False)
  1582. MPLPlot.__init__(self, data, **kwargs)
  1583. if self.stacked or self.subplots:
  1584. self.tickoffset = self.bar_width * pos
  1585. if kwargs['align'] == 'edge':
  1586. self.lim_offset = self.bar_width / 2
  1587. else:
  1588. self.lim_offset = 0
  1589. else:
  1590. if kwargs['align'] == 'edge':
  1591. w = self.bar_width / self.nseries
  1592. self.tickoffset = self.bar_width * (pos - 0.5) + w * 0.5
  1593. self.lim_offset = w * 0.5
  1594. else:
  1595. self.tickoffset = self.bar_width * pos
  1596. self.lim_offset = 0
  1597. self.ax_pos = self.tick_pos - self.tickoffset
  1598. def _args_adjust(self):
  1599. if is_list_like(self.bottom):
  1600. self.bottom = np.array(self.bottom)
  1601. if is_list_like(self.left):
  1602. self.left = np.array(self.left)
  1603. @classmethod
  1604. def _plot(cls, ax, x, y, w, start=0, log=False, **kwds):
  1605. return ax.bar(x, y, w, bottom=start, log=log, **kwds)
  1606. @property
  1607. def _start_base(self):
  1608. return self.bottom
  1609. def _make_plot(self):
  1610. import matplotlib as mpl
  1611. colors = self._get_colors()
  1612. ncolors = len(colors)
  1613. pos_prior = neg_prior = np.zeros(len(self.data))
  1614. K = self.nseries
  1615. for i, (label, y) in enumerate(self._iter_data(fillna=0)):
  1616. ax = self._get_ax(i)
  1617. kwds = self.kwds.copy()
  1618. kwds['color'] = colors[i % ncolors]
  1619. errors = self._get_errorbars(label=label, index=i)
  1620. kwds = dict(kwds, **errors)
  1621. label = pprint_thing(label)
  1622. if (('yerr' in kwds) or ('xerr' in kwds)) \
  1623. and (kwds.get('ecolor') is None):
  1624. kwds['ecolor'] = mpl.rcParams['xtick.color']
  1625. start = 0
  1626. if self.log and (y >= 1).all():
  1627. start = 1
  1628. start = start + self._start_base
  1629. if self.subplots:
  1630. w = self.bar_width / 2
  1631. rect = self._plot(ax, self.ax_pos + w, y, self.bar_width,
  1632. start=start, label=label,
  1633. log=self.log, **kwds)
  1634. ax.set_title(label)
  1635. elif self.stacked:
  1636. mask = y > 0
  1637. start = np.where(mask, pos_prior, neg_prior) + self._start_base
  1638. w = self.bar_width / 2
  1639. rect = self._plot(ax, self.ax_pos + w, y, self.bar_width,
  1640. start=start, label=label,
  1641. log=self.log, **kwds)
  1642. pos_prior = pos_prior + np.where(mask, y, 0)
  1643. neg_prior = neg_prior + np.where(mask, 0, y)
  1644. else:
  1645. w = self.bar_width / K
  1646. rect = self._plot(ax, self.ax_pos + (i + 0.5) * w, y, w,
  1647. start=start, label=label,
  1648. log=self.log, **kwds)
  1649. self._add_legend_handle(rect, label, index=i)
  1650. def _post_plot_logic(self, ax, data):
  1651. if self.use_index:
  1652. str_index = [pprint_thing(key) for key in data.index]
  1653. else:
  1654. str_index = [pprint_thing(key) for key in range(data.shape[0])]
  1655. name = self._get_index_name()
  1656. s_edge = self.ax_pos[0] - 0.25 + self.lim_offset
  1657. e_edge = self.ax_pos[-1] + 0.25 + self.bar_width + self.lim_offset
  1658. self._decorate_ticks(ax, name, str_index, s_edge, e_edge)
  1659. def _decorate_ticks(self, ax, name, ticklabels, start_edge, end_edge):
  1660. ax.set_xlim((start_edge, end_edge))
  1661. ax.set_xticks(self.tick_pos)
  1662. ax.set_xticklabels(ticklabels)
  1663. if name is not None and self.use_index:
  1664. ax.set_xlabel(name)
  1665. class BarhPlot(BarPlot):
  1666. _kind = 'barh'
  1667. _default_rot = 0
  1668. orientation = 'horizontal'
  1669. @property
  1670. def _start_base(self):
  1671. return self.left
  1672. @classmethod
  1673. def _plot(cls, ax, x, y, w, start=0, log=False, **kwds):
  1674. return ax.barh(x, y, w, left=start, log=log, **kwds)
  1675. def _decorate_ticks(self, ax, name, ticklabels, start_edge, end_edge):
  1676. # horizontal bars
  1677. ax.set_ylim((start_edge, end_edge))
  1678. ax.set_yticks(self.tick_pos)
  1679. ax.set_yticklabels(ticklabels)
  1680. if name is not None and self.use_index:
  1681. ax.set_ylabel(name)
  1682. class HistPlot(LinePlot):
  1683. _kind = 'hist'
  1684. def __init__(self, data, bins=10, bottom=0, **kwargs):
  1685. self.bins = bins # use mpl default
  1686. self.bottom = bottom
  1687. # Do not call LinePlot.__init__ which may fill nan
  1688. MPLPlot.__init__(self, data, **kwargs)
  1689. def _args_adjust(self):
  1690. if is_integer(self.bins):
  1691. # create common bin edge
  1692. values = (self.data._convert(datetime=True)._get_numeric_data())
  1693. values = np.ravel(values)
  1694. values = values[~isnull(values)]
  1695. hist, self.bins = np.histogram(
  1696. values, bins=self.bins,
  1697. range=self.kwds.get('range', None),
  1698. weights=self.kwds.get('weights', None))
  1699. if is_list_like(self.bottom):
  1700. self.bottom = np.array(self.bottom)
  1701. @classmethod
  1702. def _plot(cls, ax, y, style=None, bins=None, bottom=0, column_num=0,
  1703. stacking_id=None, **kwds):
  1704. if column_num == 0:
  1705. cls._initialize_stacker(ax, stacking_id, len(bins) - 1)
  1706. y = y[~isnull(y)]
  1707. base = np.zeros(len(bins) - 1)
  1708. bottom = bottom + \
  1709. cls._get_stacked_values(ax, stacking_id, base, kwds['label'])
  1710. # ignore style
  1711. n, bins, patches = ax.hist(y, bins=bins, bottom=bottom, **kwds)
  1712. cls._update_stacker(ax, stacking_id, n)
  1713. return patches
  1714. def _make_plot(self):
  1715. colors = self._get_colors()
  1716. stacking_id = self._get_stacking_id()
  1717. for i, (label, y) in enumerate(self._iter_data()):
  1718. ax = self._get_ax(i)
  1719. kwds = self.kwds.copy()
  1720. label = pprint_thing(label)
  1721. kwds['label'] = label
  1722. style, kwds = self._apply_style_colors(colors, kwds, i, label)
  1723. if style is not None:
  1724. kwds['style'] = style
  1725. kwds = self._make_plot_keywords(kwds, y)
  1726. artists = self._plot(ax, y, column_num=i,
  1727. stacking_id=stacking_id, **kwds)
  1728. self._add_legend_handle(artists[0], label, index=i)
  1729. def _make_plot_keywords(self, kwds, y):
  1730. """merge BoxPlot/KdePlot properties to passed kwds"""
  1731. # y is required for KdePlot
  1732. kwds['bottom'] = self.bottom
  1733. kwds['bins'] = self.bins
  1734. return kwds
  1735. def _post_plot_logic(self, ax, data):
  1736. if self.orientation == 'horizontal':
  1737. ax.set_xlabel('Frequency')
  1738. else:
  1739. ax.set_ylabel('Frequency')
  1740. @property
  1741. def orientation(self):
  1742. if self.kwds.get('orientation', None) == 'horizontal':
  1743. return 'horizontal'
  1744. else:
  1745. return 'vertical'
  1746. class KdePlot(HistPlot):
  1747. _kind = 'kde'
  1748. orientation = 'vertical'
  1749. def __init__(self, data, bw_method=None, ind=None, **kwargs):
  1750. MPLPlot.__init__(self, data, **kwargs)
  1751. self.bw_method = bw_method
  1752. self.ind = ind
  1753. def _args_adjust(self):
  1754. pass
  1755. def _get_ind(self, y):
  1756. if self.ind is None:
  1757. sample_range = max(y) - min(y)
  1758. ind = np.linspace(min(y) - 0.5 * sample_range,
  1759. max(y) + 0.5 * sample_range, 1000)
  1760. else:
  1761. ind = self.ind
  1762. return ind
  1763. @classmethod
  1764. def _plot(cls, ax, y, style=None, bw_method=None, ind=None,
  1765. column_num=None, stacking_id=None, **kwds):
  1766. from scipy.stats import gaussian_kde
  1767. from scipy import __version__ as spv
  1768. y = remove_na(y)
  1769. if LooseVersion(spv) >= '0.11.0':
  1770. gkde = gaussian_kde(y, bw_method=bw_method)
  1771. else:
  1772. gkde = gaussian_kde(y)
  1773. if bw_method is not None:
  1774. msg = ('bw_method was added in Scipy 0.11.0.' +
  1775. ' Scipy version in use is %s.' % spv)
  1776. warnings.warn(msg)
  1777. y = gkde.evaluate(ind)
  1778. lines = MPLPlot._plot(ax, ind, y, style=style, **kwds)
  1779. return lines
  1780. def _make_plot_keywords(self, kwds, y):
  1781. kwds['bw_method'] = self.bw_method
  1782. kwds['ind'] = self._get_ind(y)
  1783. return kwds
  1784. def _post_plot_logic(self, ax, data):
  1785. ax.set_ylabel('Density')
  1786. class PiePlot(MPLPlot):
  1787. _kind = 'pie'
  1788. _layout_type = 'horizontal'
  1789. def __init__(self, data, kind=None, **kwargs):
  1790. data = data.fillna(value=0)
  1791. if (data < 0).any().any():
  1792. raise ValueError("{0} doesn't allow negative values".format(kind))
  1793. MPLPlot.__init__(self, data, kind=kind, **kwargs)
  1794. def _args_adjust(self):
  1795. self.grid = False
  1796. self.logy = False
  1797. self.logx = False
  1798. self.loglog = False
  1799. def _validate_color_args(self):
  1800. pass
  1801. def _make_plot(self):
  1802. colors = self._get_colors(
  1803. num_colors=len(self.data), color_kwds='colors')
  1804. self.kwds.setdefault('colors', colors)
  1805. for i, (label, y) in enumerate(self._iter_data()):
  1806. ax = self._get_ax(i)
  1807. if label is not None:
  1808. label = pprint_thing(label)
  1809. ax.set_ylabel(label)
  1810. kwds = self.kwds.copy()
  1811. def blank_labeler(label, value):
  1812. if value == 0:
  1813. return ''
  1814. else:
  1815. return label
  1816. idx = [pprint_thing(v) for v in self.data.index]
  1817. labels = kwds.pop('labels', idx)
  1818. # labels is used for each wedge's labels
  1819. # Blank out labels for values of 0 so they don't overlap
  1820. # with nonzero wedges
  1821. if labels is not None:
  1822. blabels = [blank_labeler(l, value) for
  1823. l, value in zip(labels, y)]
  1824. else:
  1825. blabels = None
  1826. results = ax.pie(y, labels=blabels, **kwds)
  1827. if kwds.get('autopct', None) is not None:
  1828. patches, texts, autotexts = results
  1829. else:
  1830. patches, texts = results
  1831. autotexts = []
  1832. if self.fontsize is not None:
  1833. for t in texts + autotexts:
  1834. t.set_fontsize(self.fontsize)
  1835. # leglabels is used for legend labels
  1836. leglabels = labels if labels is not None else idx
  1837. for p, l in zip(patches, leglabels):
  1838. self._add_legend_handle(p, l)
  1839. class BoxPlot(LinePlot):
  1840. _kind = 'box'
  1841. _layout_type = 'horizontal'
  1842. _valid_return_types = (None, 'axes', 'dict', 'both')
  1843. # namedtuple to hold results
  1844. BP = namedtuple("Boxplot", ['ax', 'lines'])
  1845. def __init__(self, data, return_type=None, **kwargs):
  1846. # Do not call LinePlot.__init__ which may fill nan
  1847. if return_type not in self._valid_return_types:
  1848. raise ValueError(
  1849. "return_type must be {None, 'axes', 'dict', 'both'}")
  1850. self.return_type = return_type
  1851. MPLPlot.__init__(self, data, **kwargs)
  1852. def _args_adjust(self):
  1853. if self.subplots:
  1854. # Disable label ax sharing. Otherwise, all subplots shows last
  1855. # column label
  1856. if self.orientation == 'vertical':
  1857. self.sharex = False
  1858. else:
  1859. self.sharey = False
  1860. @classmethod
  1861. def _plot(cls, ax, y, column_num=None, return_type=None, **kwds):
  1862. if y.ndim == 2:
  1863. y = [remove_na(v) for v in y]
  1864. # Boxplot fails with empty arrays, so need to add a NaN
  1865. # if any cols are empty
  1866. # GH 8181
  1867. y = [v if v.size > 0 else np.array([np.nan]) for v in y]
  1868. else:
  1869. y = remove_na(y)
  1870. bp = ax.boxplot(y, **kwds)
  1871. if return_type == 'dict':
  1872. return bp, bp
  1873. elif return_type == 'both':
  1874. return cls.BP(ax=ax, lines=bp), bp
  1875. else:
  1876. return ax, bp
  1877. def _validate_color_args(self):
  1878. if 'color' in self.kwds:
  1879. if self.colormap is not None:
  1880. warnings.warn("'color' and 'colormap' cannot be used "
  1881. "simultaneously. Using 'color'")
  1882. self.color = self.kwds.pop('color')
  1883. if isinstance(self.color, dict):
  1884. valid_keys = ['boxes', 'whiskers', 'medians', 'caps']
  1885. for key, values in compat.iteritems(self.color):
  1886. if key not in valid_keys:
  1887. raise ValueError("color dict contains invalid "
  1888. "key '{0}' "
  1889. "The key must be either {1}"
  1890. .format(key, valid_keys))
  1891. else:
  1892. self.color = None
  1893. # get standard colors for default
  1894. colors = _get_standard_colors(num_colors=3,
  1895. colormap=self.colormap,
  1896. color=None)
  1897. # use 2 colors by default, for box/whisker and median
  1898. # flier colors isn't needed here
  1899. # because it can be specified by ``sym`` kw
  1900. self._boxes_c = colors[0]
  1901. self._whiskers_c = colors[0]
  1902. self._medians_c = colors[2]
  1903. self._caps_c = 'k' # mpl default
  1904. def _get_colors(self, num_colors=None, color_kwds='color'):
  1905. pass
  1906. def maybe_color_bp(self, bp):
  1907. if isinstance(self.color, dict):
  1908. boxes = self.color.get('boxes', self._boxes_c)
  1909. whiskers = self.color.get('whiskers', self._whiskers_c)
  1910. medians = self.color.get('medians', self._medians_c)
  1911. caps = self.color.get('caps', self._caps_c)
  1912. else:
  1913. # Other types are forwarded to matplotlib
  1914. # If None, use default colors
  1915. boxes = self.color or self._boxes_c
  1916. whiskers = self.color or self._whiskers_c
  1917. medians = self.color or self._medians_c
  1918. caps = self.color or self._caps_c
  1919. from matplotlib.artist import setp
  1920. setp(bp['boxes'], color=boxes, alpha=1)
  1921. setp(bp['whiskers'], color=whiskers, alpha=1)
  1922. setp(bp['medians'], color=medians, alpha=1)
  1923. setp(bp['caps'], color=caps, alpha=1)
  1924. def _make_plot(self):
  1925. if self.subplots:
  1926. self._return_obj = compat.OrderedDict()
  1927. for i, (label, y) in enumerate(self._iter_data()):
  1928. ax = self._get_ax(i)
  1929. kwds = self.kwds.copy()
  1930. ret, bp = self._plot(ax, y, column_num=i,
  1931. return_type=self.return_type, **kwds)
  1932. self.maybe_color_bp(bp)
  1933. self._return_obj[label] = ret
  1934. label = [pprint_thing(label)]
  1935. self._set_ticklabels(ax, label)
  1936. else:
  1937. y = self.data.values.T
  1938. ax = self._get_ax(0)
  1939. kwds = self.kwds.copy()
  1940. ret, bp = self._plot(ax, y, column_num=0,
  1941. return_type=self.return_type, **kwds)
  1942. self.maybe_color_bp(bp)
  1943. self._return_obj = ret
  1944. labels = [l for l, _ in self._iter_data()]
  1945. labels = [pprint_thing(l) for l in labels]
  1946. if not self.use_index:
  1947. labels = [pprint_thing(key) for key in range(len(labels))]
  1948. self._set_ticklabels(ax, labels)
  1949. def _set_ticklabels(self, ax, labels):
  1950. if self.orientation == 'vertical':
  1951. ax.set_xticklabels(labels)
  1952. else:
  1953. ax.set_yticklabels(labels)
  1954. def _make_legend(self):
  1955. pass
  1956. def _post_plot_logic(self, ax, data):
  1957. pass
  1958. @property
  1959. def orientation(self):
  1960. if self.kwds.get('vert', True):
  1961. return 'vertical'
  1962. else:
  1963. return 'horizontal'
  1964. @property
  1965. def result(self):
  1966. if self.return_type is None:
  1967. return super(BoxPlot, self).result
  1968. else:
  1969. return self._return_obj
# kinds supported by both dataframe and series
_common_kinds = ['line', 'bar', 'barh',
                 'kde', 'density', 'area', 'hist', 'box']
# kinds supported by dataframe
_dataframe_kinds = ['scatter', 'hexbin']
# kinds supported only by series or dataframe single column
_series_kinds = ['pie']
# every kind name that _plot accepts
_all_kinds = _common_kinds + _dataframe_kinds + _series_kinds

# plot classes to register in the kind -> class lookup below
_klasses = [LinePlot, BarPlot, BarhPlot, KdePlot, HistPlot, BoxPlot,
            ScatterPlot, HexBinPlot, AreaPlot, PiePlot]

# Lookup table keyed by each class's ``_kind`` attribute.
# NOTE(review): 'density' is listed as a kind but no class visible here has
# that ``_kind``; presumably _get_standard_kind maps it to 'kde' before the
# lookup in _plot -- confirm.
_plot_klass = {}
for klass in _klasses:
    _plot_klass[klass._kind] = klass
  1983. def _plot(data, x=None, y=None, subplots=False,
  1984. ax=None, kind='line', **kwds):
  1985. kind = _get_standard_kind(kind.lower().strip())
  1986. if kind in _all_kinds:
  1987. klass = _plot_klass[kind]
  1988. else:
  1989. raise ValueError("%r is not a valid plot kind" % kind)
  1990. from pandas import DataFrame
  1991. if kind in _dataframe_kinds:
  1992. if isinstance(data, DataFrame):
  1993. plot_obj = klass(data, x=x, y=y, subplots=subplots, ax=ax,
  1994. kind=kind, **kwds)
  1995. else:
  1996. raise ValueError("plot kind %r can only be used for data frames"
  1997. % kind)
  1998. elif kind in _series_kinds:
  1999. if isinstance(data, DataFrame):
  2000. if y is None and subplots is False:
  2001. msg = "{0} requires either y column or 'subplots=True'"
  2002. raise ValueError(msg.format(kind))
  2003. elif y is not None:
  2004. if is_integer(y) and not data.columns.holds_integer():
  2005. y = data.columns[y]
  2006. # converted to series actually. copy to not modify
  2007. data = data[y].copy()
  2008. data.index.name = y
  2009. plot_obj = klass(data, subplots=subplots, ax=ax, kind=kind, **kwds)
  2010. else:
  2011. if isinstance(data, DataFrame):
  2012. if x is not None:
  2013. if is_integer(x) and not data.columns.holds_integer():
  2014. x = data.columns[x]
  2015. data = data.set_index(x)
  2016. if y is not None:
  2017. if is_integer(y) and not data.columns.holds_integer():
  2018. y = data.columns[y]
  2019. label = kwds['label'] if 'label' in kwds else y
  2020. series = data[y].copy() # Don't modify
  2021. series.name = label
  2022. for kw in ['xerr', 'yerr']:
  2023. if (kw in kwds) and \
  2024. (isinstance(kwds[kw], string_types) or
  2025. is_integer(kwds[kw])):
  2026. try:
  2027. kwds[kw] = data[kwds[kw]]
  2028. except (IndexError, KeyError, TypeError):
  2029. pass
  2030. data = series
  2031. plot_obj = klass(data, subplots=subplots, ax=ax, kind=kind, **kwds)
  2032. plot_obj.generate()
  2033. plot_obj.draw()
  2034. return plot_obj.result
# Docstring fragments substituted into the ``_shared_docs['plot']`` template.
# Each ``df_*`` fragment has a ``series_*`` counterpart; an empty string
# means there is nothing to add for that placeholder.

# additional ``kind`` values
df_kind = """- 'scatter' : scatter plot
        - 'hexbin' : hexbin plot"""
series_kind = ""

# coordinate parameters
df_coord = """x : label or position, default None
    y : label or position, default None
        Allows plotting of one column versus another"""
series_coord = ""

# parameters specific to one of the two objects
df_unique = """stacked : boolean, default False in line and
        bar plots, and True in area plot. If True, create stacked plot.
    sort_columns : boolean, default False
        Sort column names to determine plot ordering
    secondary_y : boolean or sequence, default False
        Whether to plot on the secondary y-axis
        If a list/tuple, which columns to plot on secondary y-axis"""
series_unique = """label : label argument to provide to plot
    secondary_y : boolean or sequence of ints, default False
        If True then y-axis will be on the right"""

# axes related parameters
df_ax = """ax : matplotlib axes object, default None
    subplots : boolean, default False
        Make separate subplots for each column
    sharex : boolean, default True if ax is None else False
        In case subplots=True, share x axis and set some x axis labels to
        invisible; defaults to True if ax is None otherwise False if an ax
        is passed in; Be aware, that passing in both an ax and sharex=True
        will alter all x axis labels for all axis in a figure!
    sharey : boolean, default False
        In case subplots=True, share y axis and set some y axis labels to
        invisible
    layout : tuple (optional)
        (rows, columns) for the layout of subplots"""
series_ax = """ax : matplotlib axes object
        If not passed, uses gca()"""

# extra entries for the Notes section
df_note = """- If `kind` = 'scatter' and the argument `c` is the name of a dataframe
      column, the values of that column are used to color each point.
    - If `kind` = 'hexbin', you can control the size of the bins with the
      `gridsize` argument. By default, a histogram of the counts around each
      `(x, y)` point is computed. You can specify alternative aggregations
      by passing values to the `C` and `reduce_C_function` arguments.
      `C` specifies the value at each `(x, y)` point and `reduce_C_function`
      is a function of one argument that reduces all the values in a bin to
      a single number (e.g. `mean`, `max`, `sum`, `std`)."""
series_note = ""

# substitution mappings applied to the template with the ``%`` operator
_shared_doc_df_kwargs = dict(klass='DataFrame', klass_obj='df',
                             klass_kind=df_kind, klass_coord=df_coord,
                             klass_ax=df_ax, klass_unique=df_unique,
                             klass_note=df_note)
_shared_doc_series_kwargs = dict(klass='Series', klass_obj='s',
                                 klass_kind=series_kind,
                                 klass_coord=series_coord, klass_ax=series_ax,
                                 klass_unique=series_unique,
                                 klass_note=series_note)
# Docstring template used for both plot_frame and plot_series.  The
# ``%(klass...)s`` placeholders are filled from _shared_doc_df_kwargs or
# _shared_doc_series_kwargs when the functions are decorated.
_shared_docs['plot'] = """
    Make plots of %(klass)s using matplotlib / pylab.

    *New in version 0.17.0:* Each plot kind has a corresponding method on the
    ``%(klass)s.plot`` accessor:
    ``%(klass_obj)s.plot(kind='line')`` is equivalent to
    ``%(klass_obj)s.plot.line()``.

    Parameters
    ----------
    data : %(klass)s
    %(klass_coord)s
    kind : str
        - 'line' : line plot (default)
        - 'bar' : vertical bar plot
        - 'barh' : horizontal bar plot
        - 'hist' : histogram
        - 'box' : boxplot
        - 'kde' : Kernel Density Estimation plot
        - 'density' : same as 'kde'
        - 'area' : area plot
        - 'pie' : pie plot
        %(klass_kind)s
    %(klass_ax)s
    figsize : a tuple (width, height) in inches
    use_index : boolean, default True
        Use index as ticks for x axis
    title : string
        Title to use for the plot
    grid : boolean, default None (matlab style default)
        Axis grid lines
    legend : False/True/'reverse'
        Place legend on axis subplots
    style : list or dict
        matplotlib line style per column
    logx : boolean, default False
        Use log scaling on x axis
    logy : boolean, default False
        Use log scaling on y axis
    loglog : boolean, default False
        Use log scaling on both x and y axes
    xticks : sequence
        Values to use for the xticks
    yticks : sequence
        Values to use for the yticks
    xlim : 2-tuple/list
    ylim : 2-tuple/list
    rot : int, default None
        Rotation for ticks (xticks for vertical, yticks for horizontal plots)
    fontsize : int, default None
        Font size for xticks and yticks
    colormap : str or matplotlib colormap object, default None
        Colormap to select colors from. If string, load colormap with that name
        from matplotlib.
    colorbar : boolean, optional
        If True, plot colorbar (only relevant for 'scatter' and 'hexbin' plots)
    position : float
        Specify relative alignments for bar plot layout.
        From 0 (left/bottom-end) to 1 (right/top-end). Default is 0.5 (center)
    layout : tuple (optional)
        (rows, columns) for the layout of the plot
    table : boolean, Series or DataFrame, default False
        If True, draw a table using the data in the DataFrame and the data will
        be transposed to meet matplotlib's default layout.
        If a Series or DataFrame is passed, use passed data to draw a table.
    yerr : DataFrame, Series, array-like, dict and str
        See :ref:`Plotting with Error Bars <visualization.errorbars>` for
        detail.
    xerr : same types as yerr.
    %(klass_unique)s
    mark_right : boolean, default True
        When using a secondary_y axis, automatically mark the column
        labels with "(right)" in the legend
    kwds : keywords
        Options to pass to matplotlib plotting method

    Returns
    -------
    axes : matplotlib.AxesSubplot or np.array of them

    Notes
    -----

    - See matplotlib documentation online for more on this subject
    - If `kind` = 'bar' or 'barh', you can specify relative alignments
      for bar plot layout by `position` keyword.
      From 0 (left/bottom-end) to 1 (right/top-end). Default is 0.5 (center)
    %(klass_note)s

    """
  2170. @Appender(_shared_docs['plot'] % _shared_doc_df_kwargs)
  2171. def plot_frame(data, x=None, y=None, kind='line', ax=None,
  2172. subplots=False, sharex=None, sharey=False, layout=None,
  2173. figsize=None, use_index=True, title=None, grid=None,
  2174. legend=True, style=None, logx=False, logy=False, loglog=False,
  2175. xticks=None, yticks=None, xlim=None, ylim=None,
  2176. rot=None, fontsize=None, colormap=None, table=False,
  2177. yerr=None, xerr=None,
  2178. secondary_y=False, sort_columns=False,
  2179. **kwds):
  2180. return _plot(data, kind=kind, x=x, y=y, ax=ax,
  2181. subplots=subplots, sharex=sharex, sharey=sharey,
  2182. layout=layout, figsize=figsize, use_index=use_index,
  2183. title=title, grid=grid, legend=legend,
  2184. style=style, logx=logx, logy=logy, loglog=loglog,
  2185. xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim,
  2186. rot=rot, fontsize=fontsize, colormap=colormap, table=table,
  2187. yerr=yerr, xerr=xerr,
  2188. secondary_y=secondary_y, sort_columns=sort_columns,
  2189. **kwds)
  2190. @Appender(_shared_docs['plot'] % _shared_doc_series_kwargs)
  2191. def plot_series(data, kind='line', ax=None, # Series unique
  2192. figsize=None, use_index=True, title=None, grid=None,
  2193. legend=False, style=None, logx=False, logy=False, loglog=False,
  2194. xticks=None, yticks=None, xlim=None, ylim=None,
  2195. rot=None, fontsize=None, colormap=None, table=False,
  2196. yerr=None, xerr=None,
  2197. label=None, secondary_y=False, # Series unique
  2198. **kwds):
  2199. import matplotlib.pyplot as plt
  2200. """
  2201. If no axes is specified, check whether there are existing figures
  2202. If there is no existing figures, _gca() will
  2203. create a figure with the default figsize, causing the figsize=parameter to
  2204. be ignored.
  2205. """
  2206. if ax is None and len(plt.get_fignums()) > 0:
  2207. ax = _gca()
  2208. ax = MPLPlot._get_ax_layer(ax)
  2209. return _plot(data, kind=kind, ax=ax,
  2210. figsize=figsize, use_index=use_index, title=title,
  2211. grid=grid, legend=legend,
  2212. style=style, logx=logx, logy=logy, loglog=loglog,
  2213. xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim,
  2214. rot=rot, fontsize=fontsize, colormap=colormap, table=table,
  2215. yerr=yerr, xerr=xerr,
  2216. label=label, secondary_y=secondary_y,
  2217. **kwds)
# Docstring attached to ``boxplot`` below via the Appender decorator.
_shared_docs['boxplot'] = """
    Make a box plot from DataFrame column optionally grouped by some columns or
    other inputs

    Parameters
    ----------
    data : the pandas object holding the data
    column : column name or list of names, or vector
        Can be any valid input to groupby
    by : string or sequence
        Column in the DataFrame to group by
    ax : Matplotlib axes object, optional
    fontsize : int or string
    rot : label rotation angle
    figsize : A tuple (width, height) in inches
    grid : Setting this to True will show the grid
    layout : tuple (optional)
        (rows, columns) for the layout of the plot
    return_type : {'axes', 'dict', 'both'}, default 'dict'
        The kind of object to return. 'dict' returns a dictionary
        whose values are the matplotlib Lines of the boxplot;
        'axes' returns the matplotlib axes the boxplot is drawn on;
        'both' returns a namedtuple with the axes and dict.

        When grouping with ``by``, a dict mapping columns to ``return_type``
        is returned.

    kwds : other plotting keyword arguments to be passed to matplotlib boxplot
           function

    Returns
    -------
    lines : dict
    ax : matplotlib Axes
    (ax, lines): namedtuple

    Notes
    -----
    Use ``return_type='dict'`` when you want to tweak the appearance
    of the lines after plotting. In this case a dict containing the Lines
    making up the boxes, caps, fliers, medians, and whiskers is returned.
    """
  2255. @Appender(_shared_docs['boxplot'] % _shared_doc_kwargs)
  2256. def boxplot(data, column=None, by=None, ax=None, fontsize=None,
  2257. rot=0, grid=True, figsize=None, layout=None, return_type=None,
  2258. **kwds):
  2259. # validate return_type:
  2260. if return_type not in BoxPlot._valid_return_types:
  2261. raise ValueError("return_type must be {None, 'axes', 'dict', 'both'}")
  2262. from pandas import Series, DataFrame
  2263. if isinstance(data, Series):
  2264. data = DataFrame({'x': data})
  2265. column = 'x'
  2266. def _get_colors():
  2267. return _get_standard_colors(color=kwds.get('color'), num_colors=1)
  2268. def maybe_color_bp(bp):
  2269. if 'color' not in kwds:
  2270. from matplotlib.artist import setp
  2271. setp(bp['boxes'], color=colors[0], alpha=1)
  2272. setp(bp['whiskers'], color=colors[0], alpha=1)
  2273. setp(bp['medians'], color=colors[2], alpha=1)
  2274. def plot_group(keys, values, ax):
  2275. keys = [pprint_thing(x) for x in keys]
  2276. values = [remove_na(v) for v in values]
  2277. bp = ax.boxplot(values, **kwds)
  2278. if kwds.get('vert', 1):
  2279. ax.set_xticklabels(keys, rotation=rot, fontsize=fontsize)
  2280. else:
  2281. ax.set_yticklabels(keys, rotation=rot, fontsize=fontsize)
  2282. maybe_color_bp(bp)
  2283. # Return axes in multiplot case, maybe revisit later # 985
  2284. if return_type == 'dict':
  2285. return bp
  2286. elif return_type == 'both':
  2287. return BoxPlot.BP(ax=ax, lines=bp)
  2288. else:
  2289. return ax
  2290. colors = _get_colors()
  2291. if column is None:
  2292. columns = None
  2293. else:
  2294. if isinstance(column, (list, tuple)):
  2295. columns = column
  2296. else:
  2297. columns = [column]
  2298. if by is not None:
  2299. result = _grouped_plot_by_column(plot_group, data, columns=columns,
  2300. by=by, grid=grid, figsize=figsize,
  2301. ax=ax, layout=layout,
  2302. return_type=return_type)
  2303. else:
  2304. if layout is not None:
  2305. raise ValueError("The 'layout' keyword is not supported when "
  2306. "'by' is None")
  2307. if return_type is None:
  2308. msg = ("\nThe default value for 'return_type' will change to "
  2309. "'axes' in a future release.\n To use the future behavior "
  2310. "now, set return_type='axes'.\n To keep the previous "
  2311. "behavior and silence this warning, set "
  2312. "return_type='dict'.")
  2313. warnings.warn(msg, FutureWarning, stacklevel=3)
  2314. return_type = 'dict'
  2315. if ax is None:
  2316. ax = _gca()
  2317. data = data._get_numeric_data()
  2318. if columns is None:
  2319. columns = data.columns
  2320. else:
  2321. data = data[columns]
  2322. result = plot_group(columns, data.values.T, ax)
  2323. ax.grid(grid)
  2324. return result
  2325. def format_date_labels(ax, rot):
  2326. # mini version of autofmt_xdate
  2327. try:
  2328. for label in ax.get_xticklabels():
  2329. label.set_ha('right')
  2330. label.set_rotation(rot)
  2331. fig = ax.get_figure()
  2332. fig.subplots_adjust(bottom=0.2)
  2333. except Exception: # pragma: no cover
  2334. pass
  2335. def scatter_plot(data, x, y, by=None, ax=None, figsize=None, grid=False,
  2336. **kwargs):
  2337. """
  2338. Make a scatter plot from two DataFrame columns
  2339. Parameters
  2340. ----------
  2341. data : DataFrame
  2342. x : Column name for the x-axis values
  2343. y : Column name for the y-axis values
  2344. ax : Matplotlib axis object
  2345. figsize : A tuple (width, height) in inches
  2346. grid : Setting this to True will show the grid
  2347. kwargs : other plotting keyword arguments
  2348. To be passed to scatter function
  2349. Returns
  2350. -------
  2351. fig : matplotlib.Figure
  2352. """
  2353. import matplotlib.pyplot as plt
  2354. # workaround because `c='b'` is hardcoded in matplotlibs scatter method
  2355. kwargs.setdefault('c', plt.rcParams['patch.facecolor'])
  2356. def plot_group(group, ax):
  2357. xvals = group[x].values
  2358. yvals = group[y].values
  2359. ax.scatter(xvals, yvals, **kwargs)
  2360. ax.grid(grid)
  2361. if by is not None:
  2362. fig = _grouped_plot(plot_group, data, by=by, figsize=figsize, ax=ax)
  2363. else:
  2364. if ax is None:
  2365. fig = plt.figure()
  2366. ax = fig.add_subplot(111)
  2367. else:
  2368. fig = ax.get_figure()
  2369. plot_group(data, ax)
  2370. ax.set_ylabel(pprint_thing(y))
  2371. ax.set_xlabel(pprint_thing(x))
  2372. ax.grid(grid)
  2373. return fig
  2374. def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None,
  2375. xrot=None, ylabelsize=None, yrot=None, ax=None, sharex=False,
  2376. sharey=False, figsize=None, layout=None, bins=10, **kwds):
  2377. """
  2378. Draw histogram of the DataFrame's series using matplotlib / pylab.
  2379. Parameters
  2380. ----------
  2381. data : DataFrame
  2382. column : string or sequence
  2383. If passed, will be used to limit data to a subset of columns
  2384. by : object, optional
  2385. If passed, then used to form histograms for separate groups
  2386. grid : boolean, default True
  2387. Whether to show axis grid lines
  2388. xlabelsize : int, default None
  2389. If specified changes the x-axis label size
  2390. xrot : float, default None
  2391. rotation of x axis labels
  2392. ylabelsize : int, default None
  2393. If specified changes the y-axis label size
  2394. yrot : float, default None
  2395. rotation of y axis labels
  2396. ax : matplotlib axes object, default None
  2397. sharex : boolean, default True if ax is None else False
  2398. In case subplots=True, share x axis and set some x axis labels to
  2399. invisible; defaults to True if ax is None otherwise False if an ax
  2400. is passed in; Be aware, that passing in both an ax and sharex=True
  2401. will alter all x axis labels for all subplots in a figure!
  2402. sharey : boolean, default False
  2403. In case subplots=True, share y axis and set some y axis labels to
  2404. invisible
  2405. figsize : tuple
  2406. The size of the figure to create in inches by default
  2407. layout: (optional) a tuple (rows, columns) for the layout of the histograms
  2408. bins: integer, default 10
  2409. Number of histogram bins to be used
  2410. kwds : other plotting keyword arguments
  2411. To be passed to hist function
  2412. """
  2413. if by is not None:
  2414. axes = grouped_hist(data, column=column, by=by, ax=ax, grid=grid,
  2415. figsize=figsize, sharex=sharex, sharey=sharey,
  2416. layout=layout, bins=bins, xlabelsize=xlabelsize,
  2417. xrot=xrot, ylabelsize=ylabelsize,
  2418. yrot=yrot, **kwds)
  2419. return axes
  2420. if column is not None:
  2421. if not isinstance(column, (list, np.ndarray, Index)):
  2422. column = [column]
  2423. data = data[column]
  2424. data = data._get_numeric_data()
  2425. naxes = len(data.columns)
  2426. fig, axes = _subplots(naxes=naxes, ax=ax, squeeze=False,
  2427. sharex=sharex, sharey=sharey, figsize=figsize,
  2428. layout=layout)
  2429. _axes = _flatten(axes)
  2430. for i, col in enumerate(_try_sort(data.columns)):
  2431. ax = _axes[i]
  2432. ax.hist(data[col].dropna().values, bins=bins, **kwds)
  2433. ax.set_title(col)
  2434. ax.grid(grid)
  2435. _set_ticks_props(axes, xlabelsize=xlabelsize, xrot=xrot,
  2436. ylabelsize=ylabelsize, yrot=yrot)
  2437. fig.subplots_adjust(wspace=0.3, hspace=0.3)
  2438. return axes
  2439. def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None,
  2440. xrot=None, ylabelsize=None, yrot=None, figsize=None,
  2441. bins=10, **kwds):
  2442. """
  2443. Draw histogram of the input series using matplotlib
  2444. Parameters
  2445. ----------
  2446. by : object, optional
  2447. If passed, then used to form histograms for separate groups
  2448. ax : matplotlib axis object
  2449. If not passed, uses gca()
  2450. grid : boolean, default True
  2451. Whether to show axis grid lines
  2452. xlabelsize : int, default None
  2453. If specified changes the x-axis label size
  2454. xrot : float, default None
  2455. rotation of x axis labels
  2456. ylabelsize : int, default None
  2457. If specified changes the y-axis label size
  2458. yrot : float, default None
  2459. rotation of y axis labels
  2460. figsize : tuple, default None
  2461. figure size in inches by default
  2462. bins: integer, default 10
  2463. Number of histogram bins to be used
  2464. kwds : keywords
  2465. To be passed to the actual plotting function
  2466. Notes
  2467. -----
  2468. See matplotlib documentation online for more on this
  2469. """
  2470. import matplotlib.pyplot as plt
  2471. if by is None:
  2472. if kwds.get('layout', None) is not None:
  2473. raise ValueError("The 'layout' keyword is not supported when "
  2474. "'by' is None")
  2475. # hack until the plotting interface is a bit more unified
  2476. fig = kwds.pop('figure', plt.gcf() if plt.get_fignums() else
  2477. plt.figure(figsize=figsize))
  2478. if (figsize is not None and tuple(figsize) !=
  2479. tuple(fig.get_size_inches())):
  2480. fig.set_size_inches(*figsize, forward=True)
  2481. if ax is None:
  2482. ax = fig.gca()
  2483. elif ax.get_figure() != fig:
  2484. raise AssertionError('passed axis not bound to passed figure')
  2485. values = self.dropna().values
  2486. ax.hist(values, bins=bins, **kwds)
  2487. ax.grid(grid)
  2488. axes = np.array([ax])
  2489. _set_ticks_props(axes, xlabelsize=xlabelsize, xrot=xrot,
  2490. ylabelsize=ylabelsize, yrot=yrot)
  2491. else:
  2492. if 'figure' in kwds:
  2493. raise ValueError("Cannot pass 'figure' when using the "
  2494. "'by' argument, since a new 'Figure' instance "
  2495. "will be created")
  2496. axes = grouped_hist(self, by=by, ax=ax, grid=grid, figsize=figsize,
  2497. bins=bins, xlabelsize=xlabelsize, xrot=xrot,
  2498. ylabelsize=ylabelsize, yrot=yrot, **kwds)
  2499. if hasattr(axes, 'ndim'):
  2500. if axes.ndim == 1 and len(axes) == 1:
  2501. return axes[0]
  2502. return axes
  2503. def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None,
  2504. layout=None, sharex=False, sharey=False, rot=90, grid=True,
  2505. xlabelsize=None, xrot=None, ylabelsize=None, yrot=None,
  2506. **kwargs):
  2507. """
  2508. Grouped histogram
  2509. Parameters
  2510. ----------
  2511. data: Series/DataFrame
  2512. column: object, optional
  2513. by: object, optional
  2514. ax: axes, optional
  2515. bins: int, default 50
  2516. figsize: tuple, optional
  2517. layout: optional
  2518. sharex: boolean, default False
  2519. sharey: boolean, default False
  2520. rot: int, default 90
  2521. grid: bool, default True
  2522. kwargs: dict, keyword arguments passed to matplotlib.Axes.hist
  2523. Returns
  2524. -------
  2525. axes: collection of Matplotlib Axes
  2526. """
  2527. def plot_group(group, ax):
  2528. ax.hist(group.dropna().values, bins=bins, **kwargs)
  2529. xrot = xrot or rot
  2530. fig, axes = _grouped_plot(plot_group, data, column=column,
  2531. by=by, sharex=sharex, sharey=sharey, ax=ax,
  2532. figsize=figsize, layout=layout, rot=rot)
  2533. _set_ticks_props(axes, xlabelsize=xlabelsize, xrot=xrot,
  2534. ylabelsize=ylabelsize, yrot=yrot)
  2535. fig.subplots_adjust(bottom=0.15, top=0.9, left=0.1, right=0.9,
  2536. hspace=0.5, wspace=0.3)
  2537. return axes
  2538. def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,
  2539. rot=0, grid=True, ax=None, figsize=None,
  2540. layout=None, **kwds):
  2541. """
  2542. Make box plots from DataFrameGroupBy data.
  2543. Parameters
  2544. ----------
  2545. grouped : Grouped DataFrame
  2546. subplots :
  2547. * ``False`` - no subplots will be used
  2548. * ``True`` - create a subplot for each group
  2549. column : column name or list of names, or vector
  2550. Can be any valid input to groupby
  2551. fontsize : int or string
  2552. rot : label rotation angle
  2553. grid : Setting this to True will show the grid
  2554. ax : Matplotlib axis object, default None
  2555. figsize : A tuple (width, height) in inches
  2556. layout : tuple (optional)
  2557. (rows, columns) for the layout of the plot
  2558. kwds : other plotting keyword arguments to be passed to matplotlib boxplot
  2559. function
  2560. Returns
  2561. -------
  2562. dict of key/value = group key/DataFrame.boxplot return value
  2563. or DataFrame.boxplot return value in case subplots=figures=False
  2564. Examples
  2565. --------
  2566. >>> import pandas
  2567. >>> import numpy as np
  2568. >>> import itertools
  2569. >>>
  2570. >>> tuples = [t for t in itertools.product(range(1000), range(4))]
  2571. >>> index = pandas.MultiIndex.from_tuples(tuples, names=['lvl0', 'lvl1'])
  2572. >>> data = np.random.randn(len(index),4)
  2573. >>> df = pandas.DataFrame(data, columns=list('ABCD'), index=index)
  2574. >>>
  2575. >>> grouped = df.groupby(level='lvl1')
  2576. >>> boxplot_frame_groupby(grouped)
  2577. >>>
  2578. >>> grouped = df.unstack(level='lvl1').groupby(level=0, axis=1)
  2579. >>> boxplot_frame_groupby(grouped, subplots=False)
  2580. """
  2581. if subplots is True:
  2582. naxes = len(grouped)
  2583. fig, axes = _subplots(naxes=naxes, squeeze=False,
  2584. ax=ax, sharex=False, sharey=True,
  2585. figsize=figsize, layout=layout)
  2586. axes = _flatten(axes)
  2587. ret = compat.OrderedDict()
  2588. for (key, group), ax in zip(grouped, axes):
  2589. d = group.boxplot(ax=ax, column=column, fontsize=fontsize,
  2590. rot=rot, grid=grid, **kwds)
  2591. ax.set_title(pprint_thing(key))
  2592. ret[key] = d
  2593. fig.subplots_adjust(bottom=0.15, top=0.9, left=0.1,
  2594. right=0.9, wspace=0.2)
  2595. else:
  2596. from pandas.tools.merge import concat
  2597. keys, frames = zip(*grouped)
  2598. if grouped.axis == 0:
  2599. df = concat(frames, keys=keys, axis=1)
  2600. else:
  2601. if len(frames) > 1:
  2602. df = frames[0].join(frames[1::])
  2603. else:
  2604. df = frames[0]
  2605. ret = df.boxplot(column=column, fontsize=fontsize, rot=rot,
  2606. grid=grid, ax=ax, figsize=figsize,
  2607. layout=layout, **kwds)
  2608. return ret
  2609. def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
  2610. figsize=None, sharex=True, sharey=True, layout=None,
  2611. rot=0, ax=None, **kwargs):
  2612. from pandas import DataFrame
  2613. if figsize == 'default':
  2614. # allowed to specify mpl default with 'default'
  2615. warnings.warn("figsize='default' is deprecated. Specify figure"
  2616. "size by tuple instead", FutureWarning, stacklevel=4)
  2617. figsize = None
  2618. grouped = data.groupby(by)
  2619. if column is not None:
  2620. grouped = grouped[column]
  2621. naxes = len(grouped)
  2622. fig, axes = _subplots(naxes=naxes, figsize=figsize,
  2623. sharex=sharex, sharey=sharey, ax=ax,
  2624. layout=layout)
  2625. _axes = _flatten(axes)
  2626. for i, (key, group) in enumerate(grouped):
  2627. ax = _axes[i]
  2628. if numeric_only and isinstance(group, DataFrame):
  2629. group = group._get_numeric_data()
  2630. plotf(group, ax, **kwargs)
  2631. ax.set_title(pprint_thing(key))
  2632. return fig, axes
  2633. def _grouped_plot_by_column(plotf, data, columns=None, by=None,
  2634. numeric_only=True, grid=False,
  2635. figsize=None, ax=None, layout=None,
  2636. return_type=None, **kwargs):
  2637. grouped = data.groupby(by)
  2638. if columns is None:
  2639. if not isinstance(by, (list, tuple)):
  2640. by = [by]
  2641. columns = data._get_numeric_data().columns.difference(by)
  2642. naxes = len(columns)
  2643. fig, axes = _subplots(naxes=naxes, sharex=True, sharey=True,
  2644. figsize=figsize, ax=ax, layout=layout)
  2645. _axes = _flatten(axes)
  2646. result = compat.OrderedDict()
  2647. for i, col in enumerate(columns):
  2648. ax = _axes[i]
  2649. gp_col = grouped[col]
  2650. keys, values = zip(*gp_col)
  2651. re_plotf = plotf(keys, values, ax, **kwargs)
  2652. ax.set_title(col)
  2653. ax.set_xlabel(pprint_thing(by))
  2654. result[col] = re_plotf
  2655. ax.grid(grid)
  2656. # Return axes in multiplot case, maybe revisit later # 985
  2657. if return_type is None:
  2658. result = axes
  2659. byline = by[0] if len(by) == 1 else by
  2660. fig.suptitle('Boxplot grouped by %s' % byline)
  2661. fig.subplots_adjust(bottom=0.15, top=0.9, left=0.1, right=0.9, wspace=0.2)
  2662. return result
  2663. def table(ax, data, rowLabels=None, colLabels=None,
  2664. **kwargs):
  2665. """
  2666. Helper function to convert DataFrame and Series to matplotlib.table
  2667. Parameters
  2668. ----------
  2669. `ax`: Matplotlib axes object
  2670. `data`: DataFrame or Series
  2671. data for table contents
  2672. `kwargs`: keywords, optional
  2673. keyword arguments which passed to matplotlib.table.table.
  2674. If `rowLabels` or `colLabels` is not specified, data index or column
  2675. name will be used.
  2676. Returns
  2677. -------
  2678. matplotlib table object
  2679. """
  2680. from pandas import DataFrame
  2681. if isinstance(data, Series):
  2682. data = DataFrame(data, columns=[data.name])
  2683. elif isinstance(data, DataFrame):
  2684. pass
  2685. else:
  2686. raise ValueError('Input data must be DataFrame or Series')
  2687. if rowLabels is None:
  2688. rowLabels = data.index
  2689. if colLabels is None:
  2690. colLabels = data.columns
  2691. cellText = data.values
  2692. import matplotlib.table
  2693. table = matplotlib.table.table(ax, cellText=cellText,
  2694. rowLabels=rowLabels,
  2695. colLabels=colLabels, **kwargs)
  2696. return table
  2697. def _get_layout(nplots, layout=None, layout_type='box'):
  2698. if layout is not None:
  2699. if not isinstance(layout, (tuple, list)) or len(layout) != 2:
  2700. raise ValueError('Layout must be a tuple of (rows, columns)')
  2701. nrows, ncols = layout
  2702. # Python 2 compat
  2703. ceil_ = lambda x: int(ceil(x))
  2704. if nrows == -1 and ncols > 0:
  2705. layout = nrows, ncols = (ceil_(float(nplots) / ncols), ncols)
  2706. elif ncols == -1 and nrows > 0:
  2707. layout = nrows, ncols = (nrows, ceil_(float(nplots) / nrows))
  2708. elif ncols <= 0 and nrows <= 0:
  2709. msg = "At least one dimension of layout must be positive"
  2710. raise ValueError(msg)
  2711. if nrows * ncols < nplots:
  2712. raise ValueError('Layout of %sx%s must be larger than '
  2713. 'required size %s' % (nrows, ncols, nplots))
  2714. return layout
  2715. if layout_type == 'single':
  2716. return (1, 1)
  2717. elif layout_type == 'horizontal':
  2718. return (1, nplots)
  2719. elif layout_type == 'vertical':
  2720. return (nplots, 1)
  2721. layouts = {1: (1, 1), 2: (1, 2), 3: (2, 2), 4: (2, 2)}
  2722. try:
  2723. return layouts[nplots]
  2724. except KeyError:
  2725. k = 1
  2726. while k ** 2 < nplots:
  2727. k += 1
  2728. if (k - 1) * k >= nplots:
  2729. return k, (k - 1)
  2730. else:
  2731. return k, k
  2732. # copied from matplotlib/pyplot.py and modified for pandas.plotting
  2733. def _subplots(naxes=None, sharex=False, sharey=False, squeeze=True,
  2734. subplot_kw=None, ax=None, layout=None, layout_type='box',
  2735. **fig_kw):
  2736. """Create a figure with a set of subplots already made.
  2737. This utility wrapper makes it convenient to create common layouts of
  2738. subplots, including the enclosing figure object, in a single call.
  2739. Keyword arguments:
  2740. naxes : int
  2741. Number of required axes. Exceeded axes are set invisible. Default is
  2742. nrows * ncols.
  2743. sharex : bool
  2744. If True, the X axis will be shared amongst all subplots.
  2745. sharey : bool
  2746. If True, the Y axis will be shared amongst all subplots.
  2747. squeeze : bool
  2748. If True, extra dimensions are squeezed out from the returned axis object:
  2749. - if only one subplot is constructed (nrows=ncols=1), the resulting
  2750. single Axis object is returned as a scalar.
  2751. - for Nx1 or 1xN subplots, the returned object is a 1-d numpy object
  2752. array of Axis objects are returned as numpy 1-d arrays.
  2753. - for NxM subplots with N>1 and M>1 are returned as a 2d array.
  2754. If False, no squeezing at all is done: the returned axis object is always
  2755. a 2-d array containing Axis instances, even if it ends up being 1x1.
  2756. subplot_kw : dict
  2757. Dict with keywords passed to the add_subplot() call used to create each
  2758. subplots.
  2759. ax : Matplotlib axis object, optional
  2760. layout : tuple
  2761. Number of rows and columns of the subplot grid.
  2762. If not specified, calculated from naxes and layout_type
  2763. layout_type : {'box', 'horziontal', 'vertical'}, default 'box'
  2764. Specify how to layout the subplot grid.
  2765. fig_kw : Other keyword arguments to be passed to the figure() call.
  2766. Note that all keywords not recognized above will be
  2767. automatically included here.
  2768. Returns:
  2769. fig, ax : tuple
  2770. - fig is the Matplotlib Figure object
  2771. - ax can be either a single axis object or an array of axis objects if
  2772. more than one subplot was created. The dimensions of the resulting array
  2773. can be controlled with the squeeze keyword, see above.
  2774. **Examples:**
  2775. x = np.linspace(0, 2*np.pi, 400)
  2776. y = np.sin(x**2)
  2777. # Just a figure and one subplot
  2778. f, ax = plt.subplots()
  2779. ax.plot(x, y)
  2780. ax.set_title('Simple plot')
  2781. # Two subplots, unpack the output array immediately
  2782. f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
  2783. ax1.plot(x, y)
  2784. ax1.set_title('Sharing Y axis')
  2785. ax2.scatter(x, y)
  2786. # Four polar axes
  2787. plt.subplots(2, 2, subplot_kw=dict(polar=True))
  2788. """
  2789. import matplotlib.pyplot as plt
  2790. if subplot_kw is None:
  2791. subplot_kw = {}
  2792. if ax is None:
  2793. fig = plt.figure(**fig_kw)
  2794. else:
  2795. if is_list_like(ax):
  2796. ax = _flatten(ax)
  2797. if layout is not None:
  2798. warnings.warn("When passing multiple axes, layout keyword is "
  2799. "ignored", UserWarning)
  2800. if sharex or sharey:
  2801. warnings.warn("When passing multiple axes, sharex and sharey "
  2802. "are ignored. These settings must be specified "
  2803. "when creating axes", UserWarning,
  2804. stacklevel=4)
  2805. if len(ax) == naxes:
  2806. fig = ax[0].get_figure()
  2807. return fig, ax
  2808. else:
  2809. raise ValueError("The number of passed axes must be {0}, the "
  2810. "same as the output plot".format(naxes))
  2811. fig = ax.get_figure()
  2812. # if ax is passed and a number of subplots is 1, return ax as it is
  2813. if naxes == 1:
  2814. if squeeze:
  2815. return fig, ax
  2816. else:
  2817. return fig, _flatten(ax)
  2818. else:
  2819. warnings.warn("To output multiple subplots, the figure containing "
  2820. "the passed axes is being cleared", UserWarning,
  2821. stacklevel=4)
  2822. fig.clear()
  2823. nrows, ncols = _get_layout(naxes, layout=layout, layout_type=layout_type)
  2824. nplots = nrows * ncols
  2825. # Create empty object array to hold all axes. It's easiest to make it 1-d
  2826. # so we can just append subplots upon creation, and then
  2827. axarr = np.empty(nplots, dtype=object)
  2828. # Create first subplot separately, so we can share it if requested
  2829. ax0 = fig.add_subplot(nrows, ncols, 1, **subplot_kw)
  2830. if sharex:
  2831. subplot_kw['sharex'] = ax0
  2832. if sharey:
  2833. subplot_kw['sharey'] = ax0
  2834. axarr[0] = ax0
  2835. # Note off-by-one counting because add_subplot uses the MATLAB 1-based
  2836. # convention.
  2837. for i in range(1, nplots):
  2838. kwds = subplot_kw.copy()
  2839. # Set sharex and sharey to None for blank/dummy axes, these can
  2840. # interfere with proper axis limits on the visible axes if
  2841. # they share axes e.g. issue #7528
  2842. if i >= naxes:
  2843. kwds['sharex'] = None
  2844. kwds['sharey'] = None
  2845. ax = fig.add_subplot(nrows, ncols, i + 1, **kwds)
  2846. axarr[i] = ax
  2847. if naxes != nplots:
  2848. for ax in axarr[naxes:]:
  2849. ax.set_visible(False)
  2850. _handle_shared_axes(axarr, nplots, naxes, nrows, ncols, sharex, sharey)
  2851. if squeeze:
  2852. # Reshape the array to have the final desired dimension (nrow,ncol),
  2853. # though discarding unneeded dimensions that equal 1. If we only have
  2854. # one subplot, just return it instead of a 1-element array.
  2855. if nplots == 1:
  2856. axes = axarr[0]
  2857. else:
  2858. axes = axarr.reshape(nrows, ncols).squeeze()
  2859. else:
  2860. # returned axis array will be always 2-d, even if nrows=ncols=1
  2861. axes = axarr.reshape(nrows, ncols)
  2862. return fig, axes
  2863. def _remove_labels_from_axis(axis):
  2864. for t in axis.get_majorticklabels():
  2865. t.set_visible(False)
  2866. try:
  2867. # set_visible will not be effective if
  2868. # minor axis has NullLocator and NullFormattor (default)
  2869. import matplotlib.ticker as ticker
  2870. if isinstance(axis.get_minor_locator(), ticker.NullLocator):
  2871. axis.set_minor_locator(ticker.AutoLocator())
  2872. if isinstance(axis.get_minor_formatter(), ticker.NullFormatter):
  2873. axis.set_minor_formatter(ticker.FormatStrFormatter(''))
  2874. for t in axis.get_minorticklabels():
  2875. t.set_visible(False)
  2876. except Exception: # pragma no cover
  2877. raise
  2878. axis.get_label().set_visible(False)
  2879. def _handle_shared_axes(axarr, nplots, naxes, nrows, ncols, sharex, sharey):
  2880. if nplots > 1:
  2881. if nrows > 1:
  2882. try:
  2883. # first find out the ax layout,
  2884. # so that we can correctly handle 'gaps"
  2885. layout = np.zeros((nrows + 1, ncols + 1), dtype=np.bool)
  2886. for ax in axarr:
  2887. layout[ax.rowNum, ax.colNum] = ax.get_visible()
  2888. for ax in axarr:
  2889. # only the last row of subplots should get x labels -> all
  2890. # other off layout handles the case that the subplot is
  2891. # the last in the column, because below is no subplot/gap.
  2892. if not layout[ax.rowNum + 1, ax.colNum]:
  2893. continue
  2894. if sharex or len(ax.get_shared_x_axes()
  2895. .get_siblings(ax)) > 1:
  2896. _remove_labels_from_axis(ax.xaxis)
  2897. except IndexError:
  2898. # if gridspec is used, ax.rowNum and ax.colNum may different
  2899. # from layout shape. in this case, use last_row logic
  2900. for ax in axarr:
  2901. if ax.is_last_row():
  2902. continue
  2903. if sharex or len(ax.get_shared_x_axes()
  2904. .get_siblings(ax)) > 1:
  2905. _remove_labels_from_axis(ax.xaxis)
  2906. if ncols > 1:
  2907. for ax in axarr:
  2908. # only the first column should get y labels -> set all other to
  2909. # off as we only have labels in teh first column and we always
  2910. # have a subplot there, we can skip the layout test
  2911. if ax.is_first_col():
  2912. continue
  2913. if sharey or len(ax.get_shared_y_axes().get_siblings(ax)) > 1:
  2914. _remove_labels_from_axis(ax.yaxis)
  2915. def _flatten(axes):
  2916. if not is_list_like(axes):
  2917. return np.array([axes])
  2918. elif isinstance(axes, (np.ndarray, Index)):
  2919. return axes.ravel()
  2920. return np.array(axes)
  2921. def _get_all_lines(ax):
  2922. lines = ax.get_lines()
  2923. if hasattr(ax, 'right_ax'):
  2924. lines += ax.right_ax.get_lines()
  2925. if hasattr(ax, 'left_ax'):
  2926. lines += ax.left_ax.get_lines()
  2927. return lines
  2928. def _get_xlim(lines):
  2929. left, right = np.inf, -np.inf
  2930. for l in lines:
  2931. x = l.get_xdata(orig=False)
  2932. left = min(x[0], left)
  2933. right = max(x[-1], right)
  2934. return left, right
  2935. def _set_ticks_props(axes, xlabelsize=None, xrot=None,
  2936. ylabelsize=None, yrot=None):
  2937. import matplotlib.pyplot as plt
  2938. for ax in _flatten(axes):
  2939. if xlabelsize is not None:
  2940. plt.setp(ax.get_xticklabels(), fontsize=xlabelsize)
  2941. if xrot is not None:
  2942. plt.setp(ax.get_xticklabels(), rotation=xrot)
  2943. if ylabelsize is not None:
  2944. plt.setp(ax.get_yticklabels(), fontsize=ylabelsize)
  2945. if yrot is not None:
  2946. plt.setp(ax.get_yticklabels(), rotation=yrot)
  2947. return axes
  2948. class BasePlotMethods(PandasObject):
  2949. def __init__(self, data):
  2950. self._data = data
  2951. def __call__(self, *args, **kwargs):
  2952. raise NotImplementedError
  2953. class SeriesPlotMethods(BasePlotMethods):
  2954. """Series plotting accessor and method
  2955. Examples
  2956. --------
  2957. >>> s.plot.line()
  2958. >>> s.plot.bar()
  2959. >>> s.plot.hist()
  2960. Plotting methods can also be accessed by calling the accessor as a method
  2961. with the ``kind`` argument:
  2962. ``s.plot(kind='line')`` is equivalent to ``s.plot.line()``
  2963. """
  2964. def __call__(self, kind='line', ax=None,
  2965. figsize=None, use_index=True, title=None, grid=None,
  2966. legend=False, style=None, logx=False, logy=False,
  2967. loglog=False, xticks=None, yticks=None,
  2968. xlim=None, ylim=None,
  2969. rot=None, fontsize=None, colormap=None, table=False,
  2970. yerr=None, xerr=None,
  2971. label=None, secondary_y=False, **kwds):
  2972. return plot_series(self._data, kind=kind, ax=ax, figsize=figsize,
  2973. use_index=use_index, title=title, grid=grid,
  2974. legend=legend, style=style, logx=logx, logy=logy,
  2975. loglog=loglog, xticks=xticks, yticks=yticks,
  2976. xlim=xlim, ylim=ylim, rot=rot, fontsize=fontsize,
  2977. colormap=colormap, table=table, yerr=yerr,
  2978. xerr=xerr, label=label, secondary_y=secondary_y,
  2979. **kwds)
  2980. __call__.__doc__ = plot_series.__doc__
  2981. def line(self, **kwds):
  2982. """
  2983. Line plot
  2984. .. versionadded:: 0.17.0
  2985. Parameters
  2986. ----------
  2987. **kwds : optional
  2988. Keyword arguments to pass on to :py:meth:`pandas.Series.plot`.
  2989. Returns
  2990. -------
  2991. axes : matplotlib.AxesSubplot or np.array of them
  2992. """
  2993. return self(kind='line', **kwds)
  2994. def bar(self, **kwds):
  2995. """
  2996. Vertical bar plot
  2997. .. versionadded:: 0.17.0
  2998. Parameters
  2999. ----------
  3000. **kwds : optional
  3001. Keyword arguments to pass on to :py:meth:`pandas.Series.plot`.
  3002. Returns
  3003. -------
  3004. axes : matplotlib.AxesSubplot or np.array of them
  3005. """
  3006. return self(kind='bar', **kwds)
  3007. def barh(self, **kwds):
  3008. """
  3009. Horizontal bar plot
  3010. .. versionadded:: 0.17.0
  3011. Parameters
  3012. ----------
  3013. **kwds : optional
  3014. Keyword arguments to pass on to :py:meth:`pandas.Series.plot`.
  3015. Returns
  3016. -------
  3017. axes : matplotlib.AxesSubplot or np.array of them
  3018. """
  3019. return self(kind='barh', **kwds)
  3020. def box(self, **kwds):
  3021. """
  3022. Boxplot
  3023. .. versionadded:: 0.17.0
  3024. Parameters
  3025. ----------
  3026. **kwds : optional
  3027. Keyword arguments to pass on to :py:meth:`pandas.Series.plot`.
  3028. Returns
  3029. -------
  3030. axes : matplotlib.AxesSubplot or np.array of them
  3031. """
  3032. return self(kind='box', **kwds)
  3033. def hist(self, bins=10, **kwds):
  3034. """
  3035. Histogram
  3036. .. versionadded:: 0.17.0
  3037. Parameters
  3038. ----------
  3039. bins: integer, default 10
  3040. Number of histogram bins to be used
  3041. **kwds : optional
  3042. Keyword arguments to pass on to :py:meth:`pandas.Series.plot`.
  3043. Returns
  3044. -------
  3045. axes : matplotlib.AxesSubplot or np.array of them
  3046. """
  3047. return self(kind='hist', bins=bins, **kwds)
  3048. def kde(self, **kwds):
  3049. """
  3050. Kernel Density Estimate plot
  3051. .. versionadded:: 0.17.0
  3052. Parameters
  3053. ----------
  3054. **kwds : optional
  3055. Keyword arguments to pass on to :py:meth:`pandas.Series.plot`.
  3056. Returns
  3057. -------
  3058. axes : matplotlib.AxesSubplot or np.array of them
  3059. """
  3060. return self(kind='kde', **kwds)
  3061. density = kde
  3062. def area(self, **kwds):
  3063. """
  3064. Area plot
  3065. .. versionadded:: 0.17.0
  3066. Parameters
  3067. ----------
  3068. **kwds : optional
  3069. Keyword arguments to pass on to :py:meth:`pandas.Series.plot`.
  3070. Returns
  3071. -------
  3072. axes : matplotlib.AxesSubplot or np.array of them
  3073. """
  3074. return self(kind='area', **kwds)
  3075. def pie(self, **kwds):
  3076. """
  3077. Pie chart
  3078. .. versionadded:: 0.17.0
  3079. Parameters
  3080. ----------
  3081. **kwds : optional
  3082. Keyword arguments to pass on to :py:meth:`pandas.Series.plot`.
  3083. Returns
  3084. -------
  3085. axes : matplotlib.AxesSubplot or np.array of them
  3086. """
  3087. return self(kind='pie', **kwds)
  3088. class FramePlotMethods(BasePlotMethods):
  3089. """DataFrame plotting accessor and method
  3090. Examples
  3091. --------
  3092. >>> df.plot.line()
  3093. >>> df.plot.scatter('x', 'y')
  3094. >>> df.plot.hexbin()
  3095. These plotting methods can also be accessed by calling the accessor as a
  3096. method with the ``kind`` argument:
  3097. ``df.plot(kind='line')`` is equivalent to ``df.plot.line()``
  3098. """
  3099. def __call__(self, x=None, y=None, kind='line', ax=None,
  3100. subplots=False, sharex=None, sharey=False, layout=None,
  3101. figsize=None, use_index=True, title=None, grid=None,
  3102. legend=True, style=None, logx=False, logy=False, loglog=False,
  3103. xticks=None, yticks=None, xlim=None, ylim=None,
  3104. rot=None, fontsize=None, colormap=None, table=False,
  3105. yerr=None, xerr=None,
  3106. secondary_y=False, sort_columns=False, **kwds):
  3107. return plot_frame(self._data, kind=kind, x=x, y=y, ax=ax,
  3108. subplots=subplots, sharex=sharex, sharey=sharey,
  3109. layout=layout, figsize=figsize, use_index=use_index,
  3110. title=title, grid=grid, legend=legend, style=style,
  3111. logx=logx, logy=logy, loglog=loglog, xticks=xticks,
  3112. yticks=yticks, xlim=xlim, ylim=ylim, rot=rot,
  3113. fontsize=fontsize, colormap=colormap, table=table,
  3114. yerr=yerr, xerr=xerr, secondary_y=secondary_y,
  3115. sort_columns=sort_columns, **kwds)
  3116. __call__.__doc__ = plot_frame.__doc__
  3117. def line(self, x=None, y=None, **kwds):
  3118. """
  3119. Line plot
  3120. .. versionadded:: 0.17.0
  3121. Parameters
  3122. ----------
  3123. x, y : label or position, optional
  3124. Coordinates for each point.
  3125. **kwds : optional
  3126. Keyword arguments to pass on to :py:meth:`pandas.DataFrame.plot`.
  3127. Returns
  3128. -------
  3129. axes : matplotlib.AxesSubplot or np.array of them
  3130. """
  3131. return self(kind='line', x=x, y=y, **kwds)
  3132. def bar(self, x=None, y=None, **kwds):
  3133. """
  3134. Vertical bar plot
  3135. .. versionadded:: 0.17.0
  3136. Parameters
  3137. ----------
  3138. x, y : label or position, optional
  3139. Coordinates for each point.
  3140. **kwds : optional
  3141. Keyword arguments to pass on to :py:meth:`pandas.DataFrame.plot`.
  3142. Returns
  3143. -------
  3144. axes : matplotlib.AxesSubplot or np.array of them
  3145. """
  3146. return self(kind='bar', x=x, y=y, **kwds)
  3147. def barh(self, x=None, y=None, **kwds):
  3148. """
  3149. Horizontal bar plot
  3150. .. versionadded:: 0.17.0
  3151. Parameters
  3152. ----------
  3153. x, y : label or position, optional
  3154. Coordinates for each point.
  3155. **kwds : optional
  3156. Keyword arguments to pass on to :py:meth:`pandas.DataFrame.plot`.
  3157. Returns
  3158. -------
  3159. axes : matplotlib.AxesSubplot or np.array of them
  3160. """
  3161. return self(kind='barh', x=x, y=y, **kwds)
  3162. def box(self, by=None, **kwds):
  3163. """
  3164. Boxplot
  3165. .. versionadded:: 0.17.0
  3166. Parameters
  3167. ----------
  3168. by : string or sequence
  3169. Column in the DataFrame to group by.
  3170. \*\*kwds : optional
  3171. Keyword arguments to pass on to :py:meth:`pandas.DataFrame.plot`.
  3172. Returns
  3173. -------
  3174. axes : matplotlib.AxesSubplot or np.array of them
  3175. """
  3176. return self(kind='box', by=by, **kwds)
  3177. def hist(self, by=None, bins=10, **kwds):
  3178. """
  3179. Histogram
  3180. .. versionadded:: 0.17.0
  3181. Parameters
  3182. ----------
  3183. by : string or sequence
  3184. Column in the DataFrame to group by.
  3185. bins: integer, default 10
  3186. Number of histogram bins to be used
  3187. **kwds : optional
  3188. Keyword arguments to pass on to :py:meth:`pandas.DataFrame.plot`.
  3189. Returns
  3190. -------
  3191. axes : matplotlib.AxesSubplot or np.array of them
  3192. """
  3193. return self(kind='hist', by=by, bins=bins, **kwds)
  3194. def kde(self, **kwds):
  3195. """
  3196. Kernel Density Estimate plot
  3197. .. versionadded:: 0.17.0
  3198. Parameters
  3199. ----------
  3200. **kwds : optional
  3201. Keyword arguments to pass on to :py:meth:`pandas.DataFrame.plot`.
  3202. Returns
  3203. -------
  3204. axes : matplotlib.AxesSubplot or np.array of them
  3205. """
  3206. return self(kind='kde', **kwds)
  3207. density = kde
  3208. def area(self, x=None, y=None, **kwds):
  3209. """
  3210. Area plot
  3211. .. versionadded:: 0.17.0
  3212. Parameters
  3213. ----------
  3214. x, y : label or position, optional
  3215. Coordinates for each point.
  3216. **kwds : optional
  3217. Keyword arguments to pass on to :py:meth:`pandas.DataFrame.plot`.
  3218. Returns
  3219. -------
  3220. axes : matplotlib.AxesSubplot or np.array of them
  3221. """
  3222. return self(kind='area', x=x, y=y, **kwds)
  3223. def pie(self, y=None, **kwds):
  3224. """
  3225. Pie chart
  3226. .. versionadded:: 0.17.0
  3227. Parameters
  3228. ----------
  3229. y : label or position, optional
  3230. Column to plot.
  3231. **kwds : optional
  3232. Keyword arguments to pass on to :py:meth:`pandas.DataFrame.plot`.
  3233. Returns
  3234. -------
  3235. axes : matplotlib.AxesSubplot or np.array of them
  3236. """
  3237. return self(kind='pie', y=y, **kwds)
  3238. def scatter(self, x, y, s=None, c=None, **kwds):
  3239. """
  3240. Scatter plot
  3241. .. versionadded:: 0.17.0
  3242. Parameters
  3243. ----------
  3244. x, y : label or position, optional
  3245. Coordinates for each point.
  3246. s : scalar or array_like, optional
  3247. Size of each point.
  3248. c : label or position, optional
  3249. Color of each point.
  3250. **kwds : optional
  3251. Keyword arguments to pass on to :py:meth:`pandas.DataFrame.plot`.
  3252. Returns
  3253. -------
  3254. axes : matplotlib.AxesSubplot or np.array of them
  3255. """
  3256. return self(kind='scatter', x=x, y=y, c=c, s=s, **kwds)
  3257. def hexbin(self, x, y, C=None, reduce_C_function=None, gridsize=None,
  3258. **kwds):
  3259. """
  3260. Hexbin plot
  3261. .. versionadded:: 0.17.0
  3262. Parameters
  3263. ----------
  3264. x, y : label or position, optional
  3265. Coordinates for each point.
  3266. C : label or position, optional
  3267. The value at each `(x, y)` point.
  3268. reduce_C_function : callable, optional
  3269. Function of one argument that reduces all the values in a bin to
  3270. a single number (e.g. `mean`, `max`, `sum`, `std`).
  3271. gridsize : int, optional
  3272. Number of bins.
  3273. **kwds : optional
  3274. Keyword arguments to pass on to :py:meth:`pandas.DataFrame.plot`.
  3275. Returns
  3276. -------
  3277. axes : matplotlib.AxesSubplot or np.array of them
  3278. """
  3279. if reduce_C_function is not None:
  3280. kwds['reduce_C_function'] = reduce_C_function
  3281. if gridsize is not None:
  3282. kwds['gridsize'] = gridsize
  3283. return self(kind='hexbin', x=x, y=y, C=C, **kwds)
  3284. if __name__ == '__main__':
  3285. # import pandas.rpy.common as com
  3286. # sales = com.load_data('sanfrancisco.home.sales', package='nutshell')
  3287. # top10 = sales['zip'].value_counts()[:10].index
  3288. # sales2 = sales[sales.zip.isin(top10)]
  3289. # _ = scatter_plot(sales2, 'squarefeet', 'price', by='zip')
  3290. # plt.show()
  3291. import matplotlib.pyplot as plt
  3292. import pandas.tools.plotting as plots
  3293. import pandas.core.frame as fr
  3294. reload(plots) # noqa
  3295. reload(fr) # noqa
  3296. from pandas.core.frame import DataFrame
  3297. data = DataFrame([[3, 6, -5], [4, 8, 2], [4, 9, -6],
  3298. [4, 9, -3], [2, 5, -1]],
  3299. columns=['A', 'B', 'C'])
  3300. data.plot(kind='barh', stacked=True)
  3301. plt.show()