PageRenderTime 56ms CodeModel.GetById 12ms RepoModel.GetById 1ms app.codeStats 0ms

/pandas/tools/plotting.py

http://github.com/wesm/pandas
Python | 3996 lines | 3667 code | 214 blank | 115 comment | 268 complexity | 62ae60e55cfe6d932a716eec2ddd081d MD5 | raw file
Possible License(s): BSD-3-Clause, Apache-2.0

Large files are truncated, but you can click here to view the full file

  1. # being a bit too dynamic
  2. # pylint: disable=E1101
  3. from __future__ import division
  4. import warnings
  5. import re
  6. from math import ceil
  7. from collections import namedtuple
  8. from contextlib import contextmanager
  9. from distutils.version import LooseVersion
  10. import numpy as np
  11. from pandas.types.common import (is_list_like,
  12. is_integer,
  13. is_number,
  14. is_hashable,
  15. is_iterator)
  16. from pandas.types.missing import isnull, notnull
  17. from pandas.util.decorators import cache_readonly, deprecate_kwarg
  18. from pandas.core.base import PandasObject
  19. from pandas.core.common import AbstractMethodError, _try_sort
  20. from pandas.core.generic import _shared_docs, _shared_doc_kwargs
  21. from pandas.core.index import Index, MultiIndex
  22. from pandas.core.series import Series, remove_na
  23. from pandas.tseries.period import PeriodIndex
  24. from pandas.compat import range, lrange, lmap, map, zip, string_types
  25. import pandas.compat as compat
  26. from pandas.formats.printing import pprint_thing
  27. from pandas.util.decorators import Appender
  28. try: # mpl optional
  29. import pandas.tseries.converter as conv
  30. conv.register() # needs to override so set_xlim works with str/number
  31. except ImportError:
  32. pass
  33. # Extracted from https://gist.github.com/huyng/816622
  34. # this is the rcParams set when setting display.with_mpl_style
  35. # to True.
  36. mpl_stylesheet = {
  37. 'axes.axisbelow': True,
  38. 'axes.color_cycle': ['#348ABD',
  39. '#7A68A6',
  40. '#A60628',
  41. '#467821',
  42. '#CF4457',
  43. '#188487',
  44. '#E24A33'],
  45. 'axes.edgecolor': '#bcbcbc',
  46. 'axes.facecolor': '#eeeeee',
  47. 'axes.grid': True,
  48. 'axes.labelcolor': '#555555',
  49. 'axes.labelsize': 'large',
  50. 'axes.linewidth': 1.0,
  51. 'axes.titlesize': 'x-large',
  52. 'figure.edgecolor': 'white',
  53. 'figure.facecolor': 'white',
  54. 'figure.figsize': (6.0, 4.0),
  55. 'figure.subplot.hspace': 0.5,
  56. 'font.family': 'monospace',
  57. 'font.monospace': ['Andale Mono',
  58. 'Nimbus Mono L',
  59. 'Courier New',
  60. 'Courier',
  61. 'Fixed',
  62. 'Terminal',
  63. 'monospace'],
  64. 'font.size': 10,
  65. 'interactive': True,
  66. 'keymap.all_axes': ['a'],
  67. 'keymap.back': ['left', 'c', 'backspace'],
  68. 'keymap.forward': ['right', 'v'],
  69. 'keymap.fullscreen': ['f'],
  70. 'keymap.grid': ['g'],
  71. 'keymap.home': ['h', 'r', 'home'],
  72. 'keymap.pan': ['p'],
  73. 'keymap.save': ['s'],
  74. 'keymap.xscale': ['L', 'k'],
  75. 'keymap.yscale': ['l'],
  76. 'keymap.zoom': ['o'],
  77. 'legend.fancybox': True,
  78. 'lines.antialiased': True,
  79. 'lines.linewidth': 1.0,
  80. 'patch.antialiased': True,
  81. 'patch.edgecolor': '#EEEEEE',
  82. 'patch.facecolor': '#348ABD',
  83. 'patch.linewidth': 0.5,
  84. 'toolbar': 'toolbar2',
  85. 'xtick.color': '#555555',
  86. 'xtick.direction': 'in',
  87. 'xtick.major.pad': 6.0,
  88. 'xtick.major.size': 0.0,
  89. 'xtick.minor.pad': 6.0,
  90. 'xtick.minor.size': 0.0,
  91. 'ytick.color': '#555555',
  92. 'ytick.direction': 'in',
  93. 'ytick.major.pad': 6.0,
  94. 'ytick.major.size': 0.0,
  95. 'ytick.minor.pad': 6.0,
  96. 'ytick.minor.size': 0.0
  97. }
  98. def _mpl_le_1_2_1():
  99. try:
  100. import matplotlib as mpl
  101. return (str(mpl.__version__) <= LooseVersion('1.2.1') and
  102. str(mpl.__version__)[0] != '0')
  103. except ImportError:
  104. return False
  105. def _mpl_ge_1_3_1():
  106. try:
  107. import matplotlib
  108. # The or v[0] == '0' is because their versioneer is
  109. # messed up on dev
  110. return (matplotlib.__version__ >= LooseVersion('1.3.1') or
  111. matplotlib.__version__[0] == '0')
  112. except ImportError:
  113. return False
  114. def _mpl_ge_1_4_0():
  115. try:
  116. import matplotlib
  117. return (matplotlib.__version__ >= LooseVersion('1.4') or
  118. matplotlib.__version__[0] == '0')
  119. except ImportError:
  120. return False
  121. def _mpl_ge_1_5_0():
  122. try:
  123. import matplotlib
  124. return (matplotlib.__version__ >= LooseVersion('1.5') or
  125. matplotlib.__version__[0] == '0')
  126. except ImportError:
  127. return False
if _mpl_ge_1_5_0():
    # Compat with matplotlib 1.5, which uses cycler: move the plain color
    # list out of 'axes.color_cycle' into an 'axes.prop_cycle' cycler.
    # NOTE: this leaves a module-level name ``colors`` bound.
    import cycler
    colors = mpl_stylesheet.pop('axes.color_cycle')
    mpl_stylesheet['axes.prop_cycle'] = cycler.cycler('color', colors)
  133. def _get_standard_kind(kind):
  134. return {'density': 'kde'}.get(kind, kind)
  135. def _get_standard_colors(num_colors=None, colormap=None, color_type='default',
  136. color=None):
  137. import matplotlib.pyplot as plt
  138. if color is None and colormap is not None:
  139. if isinstance(colormap, compat.string_types):
  140. import matplotlib.cm as cm
  141. cmap = colormap
  142. colormap = cm.get_cmap(colormap)
  143. if colormap is None:
  144. raise ValueError("Colormap {0} is not recognized".format(cmap))
  145. colors = lmap(colormap, np.linspace(0, 1, num=num_colors))
  146. elif color is not None:
  147. if colormap is not None:
  148. warnings.warn("'color' and 'colormap' cannot be used "
  149. "simultaneously. Using 'color'")
  150. colors = list(color) if is_list_like(color) else color
  151. else:
  152. if color_type == 'default':
  153. # need to call list() on the result to copy so we don't
  154. # modify the global rcParams below
  155. try:
  156. colors = [c['color']
  157. for c in list(plt.rcParams['axes.prop_cycle'])]
  158. except KeyError:
  159. colors = list(plt.rcParams.get('axes.color_cycle',
  160. list('bgrcmyk')))
  161. if isinstance(colors, compat.string_types):
  162. colors = list(colors)
  163. elif color_type == 'random':
  164. import random
  165. def random_color(column):
  166. random.seed(column)
  167. return [random.random() for _ in range(3)]
  168. colors = lmap(random_color, lrange(num_colors))
  169. else:
  170. raise ValueError("color_type must be either 'default' or 'random'")
  171. if isinstance(colors, compat.string_types):
  172. import matplotlib.colors
  173. conv = matplotlib.colors.ColorConverter()
  174. def _maybe_valid_colors(colors):
  175. try:
  176. [conv.to_rgba(c) for c in colors]
  177. return True
  178. except ValueError:
  179. return False
  180. # check whether the string can be convertable to single color
  181. maybe_single_color = _maybe_valid_colors([colors])
  182. # check whether each character can be convertable to colors
  183. maybe_color_cycle = _maybe_valid_colors(list(colors))
  184. if maybe_single_color and maybe_color_cycle and len(colors) > 1:
  185. msg = ("'{0}' can be parsed as both single color and "
  186. "color cycle. Specify each color using a list "
  187. "like ['{0}'] or {1}")
  188. raise ValueError(msg.format(colors, list(colors)))
  189. elif maybe_single_color:
  190. colors = [colors]
  191. else:
  192. # ``colors`` is regarded as color cycle.
  193. # mpl will raise error any of them is invalid
  194. pass
  195. if len(colors) != num_colors:
  196. multiple = num_colors // len(colors) - 1
  197. mod = num_colors % len(colors)
  198. colors += multiple * colors
  199. colors += colors[:mod]
  200. return colors
  201. class _Options(dict):
  202. """
  203. Stores pandas plotting options.
  204. Allows for parameter aliasing so you can just use parameter names that are
  205. the same as the plot function parameters, but is stored in a canonical
  206. format that makes it easy to breakdown into groups later
  207. """
  208. # alias so the names are same as plotting method parameter names
  209. _ALIASES = {'x_compat': 'xaxis.compat'}
  210. _DEFAULT_KEYS = ['xaxis.compat']
  211. def __init__(self):
  212. self['xaxis.compat'] = False
  213. def __getitem__(self, key):
  214. key = self._get_canonical_key(key)
  215. if key not in self:
  216. raise ValueError('%s is not a valid pandas plotting option' % key)
  217. return super(_Options, self).__getitem__(key)
  218. def __setitem__(self, key, value):
  219. key = self._get_canonical_key(key)
  220. return super(_Options, self).__setitem__(key, value)
  221. def __delitem__(self, key):
  222. key = self._get_canonical_key(key)
  223. if key in self._DEFAULT_KEYS:
  224. raise ValueError('Cannot remove default parameter %s' % key)
  225. return super(_Options, self).__delitem__(key)
  226. def __contains__(self, key):
  227. key = self._get_canonical_key(key)
  228. return super(_Options, self).__contains__(key)
  229. def reset(self):
  230. """
  231. Reset the option store to its initial state
  232. Returns
  233. -------
  234. None
  235. """
  236. self.__init__()
  237. def _get_canonical_key(self, key):
  238. return self._ALIASES.get(key, key)
  239. @contextmanager
  240. def use(self, key, value):
  241. """
  242. Temporarily set a parameter value using the with statement.
  243. Aliasing allowed.
  244. """
  245. old_value = self[key]
  246. try:
  247. self[key] = value
  248. yield self
  249. finally:
  250. self[key] = old_value
# Module-level plotting option store; supports aliases, e.g.
# ``plot_params['x_compat']`` or ``with plot_params.use('x_compat', True):``.
plot_params = _Options()
def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,
                   diagonal='hist', marker='.', density_kwds=None,
                   hist_kwds=None, range_padding=0.05, **kwds):
    """
    Draw a matrix of scatter plots.

    Only the numeric columns of ``frame`` are used. Cell (i, j) scatters
    column j (x) against column i (y) using the rows where both are non-null.
    Each diagonal cell shows a histogram or a Gaussian KDE of that column.

    Parameters
    ----------
    frame : DataFrame
    alpha : float, optional
        amount of transparency applied
    figsize : (float,float), optional
        a tuple (width, height) in inches
    ax : Matplotlib axis object, optional
    grid : bool, optional
        setting this to True will show the grid
        NOTE(review): ``grid`` is not read anywhere in this function body.
    diagonal : {'hist', 'kde'}
        pick between 'kde' and 'hist' for
        either Kernel Density Estimation or Histogram
        plot in the diagonal ('density' is accepted as an alias of 'kde';
        any other value leaves the diagonal empty)
    marker : str, optional
        Matplotlib marker type, default '.'
    hist_kwds : other plotting keyword arguments
        To be passed to hist function
    density_kwds : other plotting keyword arguments
        To be passed to kernel density estimate plot
    range_padding : float, optional
        relative extension of axis range in x and y
        with respect to (x_max - x_min) or (y_max - y_min),
        default 0.05
    kwds : other plotting keyword arguments
        To be passed to scatter function

    Returns
    -------
    axes : 2-D array of matplotlib Axes, indexed ``axes[i, j]``

    Examples
    --------
    >>> df = DataFrame(np.random.randn(1000, 4), columns=['A','B','C','D'])
    >>> scatter_matrix(df, alpha=0.2)
    """
    import matplotlib.pyplot as plt

    df = frame._get_numeric_data()
    n = df.columns.size
    naxes = n * n
    fig, axes = _subplots(naxes=naxes, figsize=figsize, ax=ax,
                          squeeze=False)

    # no gaps between subplots
    fig.subplots_adjust(wspace=0, hspace=0)

    mask = notnull(df)

    marker = _get_marker_compat(marker)

    hist_kwds = hist_kwds or {}
    density_kwds = density_kwds or {}

    # workaround because `c='b'` is hardcoded in matplotlibs scatter method
    kwds.setdefault('c', plt.rcParams['patch.facecolor'])

    # Per-column (low, high) axis limits: the non-null range widened by
    # range_padding / 2 of its span on each side.
    # NOTE(review): an all-null column makes np.min/np.max fail on an
    # empty array.
    boundaries_list = []
    for a in df.columns:
        values = df[a].values[mask[a].values]
        rmin_, rmax_ = np.min(values), np.max(values)
        rdelta_ext = (rmax_ - rmin_) * range_padding / 2.
        boundaries_list.append((rmin_ - rdelta_ext, rmax_ + rdelta_ext))

    for i, a in zip(lrange(n), df.columns):
        for j, b in zip(lrange(n), df.columns):
            ax = axes[i, j]

            if i == j:
                values = df[a].values[mask[a].values]

                # Deal with the diagonal by drawing a histogram there.
                if diagonal == 'hist':
                    ax.hist(values, **hist_kwds)

                elif diagonal in ('kde', 'density'):
                    from scipy.stats import gaussian_kde
                    y = values
                    gkde = gaussian_kde(y)
                    ind = np.linspace(y.min(), y.max(), 1000)
                    ax.plot(ind, gkde.evaluate(ind), **density_kwds)

                ax.set_xlim(boundaries_list[i])

            else:
                # rows where both columns are present
                common = (mask[a] & mask[b]).values

                ax.scatter(df[b][common], df[a][common],
                           marker=marker, alpha=alpha, **kwds)

                ax.set_xlim(boundaries_list[j])
                ax.set_ylim(boundaries_list[i])

            ax.set_xlabel(b)
            ax.set_ylabel(a)

            # only the left column keeps y tick labels and only the bottom
            # row keeps x tick labels
            if j != 0:
                ax.yaxis.set_visible(False)
            if i != n - 1:
                ax.xaxis.set_visible(False)

    if len(df.columns) > 1:
        # The top-left cell holds a histogram/KDE whose y axis is counts or
        # density, not data values. Relabel it with the data-value ticks
        # of its row neighbour (axes[0][1]), rescaled from column 0's
        # data range into that cell's own y range.
        lim1 = boundaries_list[0]
        locs = axes[0][1].yaxis.get_majorticklocs()
        locs = locs[(lim1[0] <= locs) & (locs <= lim1[1])]
        adj = (locs - lim1[0]) / (lim1[1] - lim1[0])

        lim0 = axes[0][0].get_ylim()
        adj = adj * (lim0[1] - lim0[0]) + lim0[0]
        axes[0][0].yaxis.set_ticks(adj)

        if np.all(locs == locs.astype(int)):
            # if all ticks are int
            locs = locs.astype(int)
        axes[0][0].yaxis.set_ticklabels(locs)

    _set_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)

    return axes
  339. adj = (locs - lim1[0]) / (lim1[1] - lim1[0])
  340. lim0 = axes[0][0].get_ylim()
  341. adj = adj * (lim0[1] - lim0[0]) + lim0[0]
  342. axes[0][0].yaxis.set_ticks(adj)
  343. if np.all(locs == locs.astype(int)):
  344. # if all ticks are int
  345. locs = locs.astype(int)
  346. axes[0][0].yaxis.set_ticklabels(locs)
  347. _set_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)
  348. return axes
  349. def _gca():
  350. import matplotlib.pyplot as plt
  351. return plt.gca()
  352. def _gcf():
  353. import matplotlib.pyplot as plt
  354. return plt.gcf()
  355. def _get_marker_compat(marker):
  356. import matplotlib.lines as mlines
  357. import matplotlib as mpl
  358. if mpl.__version__ < '1.1.0' and marker == '.':
  359. return 'o'
  360. if marker not in mlines.lineMarkers:
  361. return 'o'
  362. return marker
  363. def radviz(frame, class_column, ax=None, color=None, colormap=None, **kwds):
  364. """RadViz - a multivariate data visualization algorithm
  365. Parameters:
  366. -----------
  367. frame: DataFrame
  368. class_column: str
  369. Column name containing class names
  370. ax: Matplotlib axis object, optional
  371. color: list or tuple, optional
  372. Colors to use for the different classes
  373. colormap : str or matplotlib colormap object, default None
  374. Colormap to select colors from. If string, load colormap with that name
  375. from matplotlib.
  376. kwds: keywords
  377. Options to pass to matplotlib scatter plotting method
  378. Returns:
  379. --------
  380. ax: Matplotlib axis object
  381. """
  382. import matplotlib.pyplot as plt
  383. import matplotlib.patches as patches
  384. def normalize(series):
  385. a = min(series)
  386. b = max(series)
  387. return (series - a) / (b - a)
  388. n = len(frame)
  389. classes = frame[class_column].drop_duplicates()
  390. class_col = frame[class_column]
  391. df = frame.drop(class_column, axis=1).apply(normalize)
  392. if ax is None:
  393. ax = plt.gca(xlim=[-1, 1], ylim=[-1, 1])
  394. to_plot = {}
  395. colors = _get_standard_colors(num_colors=len(classes), colormap=colormap,
  396. color_type='random', color=color)
  397. for kls in classes:
  398. to_plot[kls] = [[], []]
  399. m = len(frame.columns) - 1
  400. s = np.array([(np.cos(t), np.sin(t))
  401. for t in [2.0 * np.pi * (i / float(m))
  402. for i in range(m)]])
  403. for i in range(n):
  404. row = df.iloc[i].values
  405. row_ = np.repeat(np.expand_dims(row, axis=1), 2, axis=1)
  406. y = (s * row_).sum(axis=0) / row.sum()
  407. kls = class_col.iat[i]
  408. to_plot[kls][0].append(y[0])
  409. to_plot[kls][1].append(y[1])
  410. for i, kls in enumerate(classes):
  411. ax.scatter(to_plot[kls][0], to_plot[kls][1], color=colors[i],
  412. label=pprint_thing(kls), **kwds)
  413. ax.legend()
  414. ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor='none'))
  415. for xy, name in zip(s, df.columns):
  416. ax.add_patch(patches.Circle(xy, radius=0.025, facecolor='gray'))
  417. if xy[0] < 0.0 and xy[1] < 0.0:
  418. ax.text(xy[0] - 0.025, xy[1] - 0.025, name,
  419. ha='right', va='top', size='small')
  420. elif xy[0] < 0.0 and xy[1] >= 0.0:
  421. ax.text(xy[0] - 0.025, xy[1] + 0.025, name,
  422. ha='right', va='bottom', size='small')
  423. elif xy[0] >= 0.0 and xy[1] < 0.0:
  424. ax.text(xy[0] + 0.025, xy[1] - 0.025, name,
  425. ha='left', va='top', size='small')
  426. elif xy[0] >= 0.0 and xy[1] >= 0.0:
  427. ax.text(xy[0] + 0.025, xy[1] + 0.025, name,
  428. ha='left', va='bottom', size='small')
  429. ax.axis('equal')
  430. return ax
  431. @deprecate_kwarg(old_arg_name='data', new_arg_name='frame')
  432. def andrews_curves(frame, class_column, ax=None, samples=200, color=None,
  433. colormap=None, **kwds):
  434. """
  435. Generates a matplotlib plot of Andrews curves, for visualising clusters of
  436. multivariate data.
  437. Andrews curves have the functional form:
  438. f(t) = x_1/sqrt(2) + x_2 sin(t) + x_3 cos(t) +
  439. x_4 sin(2t) + x_5 cos(2t) + ...
  440. Where x coefficients correspond to the values of each dimension and t is
  441. linearly spaced between -pi and +pi. Each row of frame then corresponds to
  442. a single curve.
  443. Parameters:
  444. -----------
  445. frame : DataFrame
  446. Data to be plotted, preferably normalized to (0.0, 1.0)
  447. class_column : Name of the column containing class names
  448. ax : matplotlib axes object, default None
  449. samples : Number of points to plot in each curve
  450. color: list or tuple, optional
  451. Colors to use for the different classes
  452. colormap : str or matplotlib colormap object, default None
  453. Colormap to select colors from. If string, load colormap with that name
  454. from matplotlib.
  455. kwds: keywords
  456. Options to pass to matplotlib plotting method
  457. Returns:
  458. --------
  459. ax: Matplotlib axis object
  460. """
  461. from math import sqrt, pi
  462. import matplotlib.pyplot as plt
  463. def function(amplitudes):
  464. def f(t):
  465. x1 = amplitudes[0]
  466. result = x1 / sqrt(2.0)
  467. # Take the rest of the coefficients and resize them
  468. # appropriately. Take a copy of amplitudes as otherwise numpy
  469. # deletes the element from amplitudes itself.
  470. coeffs = np.delete(np.copy(amplitudes), 0)
  471. coeffs.resize(int((coeffs.size + 1) / 2), 2)
  472. # Generate the harmonics and arguments for the sin and cos
  473. # functions.
  474. harmonics = np.arange(0, coeffs.shape[0]) + 1
  475. trig_args = np.outer(harmonics, t)
  476. result += np.sum(coeffs[:, 0, np.newaxis] * np.sin(trig_args) +
  477. coeffs[:, 1, np.newaxis] * np.cos(trig_args),
  478. axis=0)
  479. return result
  480. return f
  481. n = len(frame)
  482. class_col = frame[class_column]
  483. classes = frame[class_column].drop_duplicates()
  484. df = frame.drop(class_column, axis=1)
  485. t = np.linspace(-pi, pi, samples)
  486. used_legends = set([])
  487. color_values = _get_standard_colors(num_colors=len(classes),
  488. colormap=colormap, color_type='random',
  489. color=color)
  490. colors = dict(zip(classes, color_values))
  491. if ax is None:
  492. ax = plt.gca(xlim=(-pi, pi))
  493. for i in range(n):
  494. row = df.iloc[i].values
  495. f = function(row)
  496. y = f(t)
  497. kls = class_col.iat[i]
  498. label = pprint_thing(kls)
  499. if label not in used_legends:
  500. used_legends.add(label)
  501. ax.plot(t, y, color=colors[kls], label=label, **kwds)
  502. else:
  503. ax.plot(t, y, color=colors[kls], **kwds)
  504. ax.legend(loc='upper right')
  505. ax.grid()
  506. return ax
  507. def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds):
  508. """Bootstrap plot.
  509. Parameters:
  510. -----------
  511. series: Time series
  512. fig: matplotlib figure object, optional
  513. size: number of data points to consider during each sampling
  514. samples: number of times the bootstrap procedure is performed
  515. kwds: optional keyword arguments for plotting commands, must be accepted
  516. by both hist and plot
  517. Returns:
  518. --------
  519. fig: matplotlib figure
  520. """
  521. import random
  522. import matplotlib.pyplot as plt
  523. # random.sample(ndarray, int) fails on python 3.3, sigh
  524. data = list(series.values)
  525. samplings = [random.sample(data, size) for _ in range(samples)]
  526. means = np.array([np.mean(sampling) for sampling in samplings])
  527. medians = np.array([np.median(sampling) for sampling in samplings])
  528. midranges = np.array([(min(sampling) + max(sampling)) * 0.5
  529. for sampling in samplings])
  530. if fig is None:
  531. fig = plt.figure()
  532. x = lrange(samples)
  533. axes = []
  534. ax1 = fig.add_subplot(2, 3, 1)
  535. ax1.set_xlabel("Sample")
  536. axes.append(ax1)
  537. ax1.plot(x, means, **kwds)
  538. ax2 = fig.add_subplot(2, 3, 2)
  539. ax2.set_xlabel("Sample")
  540. axes.append(ax2)
  541. ax2.plot(x, medians, **kwds)
  542. ax3 = fig.add_subplot(2, 3, 3)
  543. ax3.set_xlabel("Sample")
  544. axes.append(ax3)
  545. ax3.plot(x, midranges, **kwds)
  546. ax4 = fig.add_subplot(2, 3, 4)
  547. ax4.set_xlabel("Mean")
  548. axes.append(ax4)
  549. ax4.hist(means, **kwds)
  550. ax5 = fig.add_subplot(2, 3, 5)
  551. ax5.set_xlabel("Median")
  552. axes.append(ax5)
  553. ax5.hist(medians, **kwds)
  554. ax6 = fig.add_subplot(2, 3, 6)
  555. ax6.set_xlabel("Midrange")
  556. axes.append(ax6)
  557. ax6.hist(midranges, **kwds)
  558. for axis in axes:
  559. plt.setp(axis.get_xticklabels(), fontsize=8)
  560. plt.setp(axis.get_yticklabels(), fontsize=8)
  561. return fig
  562. @deprecate_kwarg(old_arg_name='colors', new_arg_name='color')
  563. @deprecate_kwarg(old_arg_name='data', new_arg_name='frame', stacklevel=3)
  564. def parallel_coordinates(frame, class_column, cols=None, ax=None, color=None,
  565. use_columns=False, xticks=None, colormap=None,
  566. axvlines=True, axvlines_kwds=None, **kwds):
  567. """Parallel coordinates plotting.
  568. Parameters
  569. ----------
  570. frame: DataFrame
  571. class_column: str
  572. Column name containing class names
  573. cols: list, optional
  574. A list of column names to use
  575. ax: matplotlib.axis, optional
  576. matplotlib axis object
  577. color: list or tuple, optional
  578. Colors to use for the different classes
  579. use_columns: bool, optional
  580. If true, columns will be used as xticks
  581. xticks: list or tuple, optional
  582. A list of values to use for xticks
  583. colormap: str or matplotlib colormap, default None
  584. Colormap to use for line colors.
  585. axvlines: bool, optional
  586. If true, vertical lines will be added at each xtick
  587. axvlines_kwds: keywords, optional
  588. Options to be passed to axvline method for vertical lines
  589. kwds: keywords
  590. Options to pass to matplotlib plotting method
  591. Returns
  592. -------
  593. ax: matplotlib axis object
  594. Examples
  595. --------
  596. >>> from pandas import read_csv
  597. >>> from pandas.tools.plotting import parallel_coordinates
  598. >>> from matplotlib import pyplot as plt
  599. >>> df = read_csv('https://raw.github.com/pydata/pandas/master'
  600. '/pandas/tests/data/iris.csv')
  601. >>> parallel_coordinates(df, 'Name', color=('#556270',
  602. '#4ECDC4', '#C7F464'))
  603. >>> plt.show()
  604. """
  605. if axvlines_kwds is None:
  606. axvlines_kwds = {'linewidth': 1, 'color': 'black'}
  607. import matplotlib.pyplot as plt
  608. n = len(frame)
  609. classes = frame[class_column].drop_duplicates()
  610. class_col = frame[class_column]
  611. if cols is None:
  612. df = frame.drop(class_column, axis=1)
  613. else:
  614. df = frame[cols]
  615. used_legends = set([])
  616. ncols = len(df.columns)
  617. # determine values to use for xticks
  618. if use_columns is True:
  619. if not np.all(np.isreal(list(df.columns))):
  620. raise ValueError('Columns must be numeric to be used as xticks')
  621. x = df.columns
  622. elif xticks is not None:
  623. if not np.all(np.isreal(xticks)):
  624. raise ValueError('xticks specified must be numeric')
  625. elif len(xticks) != ncols:
  626. raise ValueError('Length of xticks must match number of columns')
  627. x = xticks
  628. else:
  629. x = lrange(ncols)
  630. if ax is None:
  631. ax = plt.gca()
  632. color_values = _get_standard_colors(num_colors=len(classes),
  633. colormap=colormap, color_type='random',
  634. color=color)
  635. colors = dict(zip(classes, color_values))
  636. for i in range(n):
  637. y = df.iloc[i].values
  638. kls = class_col.iat[i]
  639. label = pprint_thing(kls)
  640. if label not in used_legends:
  641. used_legends.add(label)
  642. ax.plot(x, y, color=colors[kls], label=label, **kwds)
  643. else:
  644. ax.plot(x, y, color=colors[kls], **kwds)
  645. if axvlines:
  646. for i in x:
  647. ax.axvline(i, **axvlines_kwds)
  648. ax.set_xticks(x)
  649. ax.set_xticklabels(df.columns)
  650. ax.set_xlim(x[0], x[-1])
  651. ax.legend(loc='upper right')
  652. ax.grid()
  653. return ax
  654. def lag_plot(series, lag=1, ax=None, **kwds):
  655. """Lag plot for time series.
  656. Parameters:
  657. -----------
  658. series: Time series
  659. lag: lag of the scatter plot, default 1
  660. ax: Matplotlib axis object, optional
  661. kwds: Matplotlib scatter method keyword arguments, optional
  662. Returns:
  663. --------
  664. ax: Matplotlib axis object
  665. """
  666. import matplotlib.pyplot as plt
  667. # workaround because `c='b'` is hardcoded in matplotlibs scatter method
  668. kwds.setdefault('c', plt.rcParams['patch.facecolor'])
  669. data = series.values
  670. y1 = data[:-lag]
  671. y2 = data[lag:]
  672. if ax is None:
  673. ax = plt.gca()
  674. ax.set_xlabel("y(t)")
  675. ax.set_ylabel("y(t + %s)" % lag)
  676. ax.scatter(y1, y2, **kwds)
  677. return ax
  678. def autocorrelation_plot(series, ax=None, **kwds):
  679. """Autocorrelation plot for time series.
  680. Parameters:
  681. -----------
  682. series: Time series
  683. ax: Matplotlib axis object, optional
  684. kwds : keywords
  685. Options to pass to matplotlib plotting method
  686. Returns:
  687. -----------
  688. ax: Matplotlib axis object
  689. """
  690. import matplotlib.pyplot as plt
  691. n = len(series)
  692. data = np.asarray(series)
  693. if ax is None:
  694. ax = plt.gca(xlim=(1, n), ylim=(-1.0, 1.0))
  695. mean = np.mean(data)
  696. c0 = np.sum((data - mean) ** 2) / float(n)
  697. def r(h):
  698. return ((data[:n - h] - mean) *
  699. (data[h:] - mean)).sum() / float(n) / c0
  700. x = np.arange(n) + 1
  701. y = lmap(r, x)
  702. z95 = 1.959963984540054
  703. z99 = 2.5758293035489004
  704. ax.axhline(y=z99 / np.sqrt(n), linestyle='--', color='grey')
  705. ax.axhline(y=z95 / np.sqrt(n), color='grey')
  706. ax.axhline(y=0.0, color='black')
  707. ax.axhline(y=-z95 / np.sqrt(n), color='grey')
  708. ax.axhline(y=-z99 / np.sqrt(n), linestyle='--', color='grey')
  709. ax.set_xlabel("Lag")
  710. ax.set_ylabel("Autocorrelation")
  711. ax.plot(x, y, **kwds)
  712. if 'label' in kwds:
  713. ax.legend()
  714. ax.grid()
  715. return ax
  716. class MPLPlot(object):
  717. """
  718. Base class for assembling a pandas plot using matplotlib
  719. Parameters
  720. ----------
  721. data :
  722. """
    @property
    def _kind(self):
        """Specify kind str. Must be overridden in child class.

        This base implementation always raises NotImplementedError.
        """
        raise NotImplementedError
    # passed to _subplots as ``layout_type`` when subplots=True
    _layout_type = 'vertical'
    # tick-label rotation used when the caller passes rot=None
    _default_rot = 0
    orientation = None
    # keyword arguments that __init__ pops out of **kwds and stores as
    # attributes of the same name; missing ones fall back to _attr_defaults
    # (or None)
    _pop_attributes = ['label', 'style', 'logy', 'logx', 'loglog',
                       'mark_right', 'stacked']
    _attr_defaults = {'logy': False, 'logx': False, 'loglog': False,
                      'mark_right': True, 'stacked': False}
    def __init__(self, data, kind=None, by=None, subplots=False, sharex=None,
                 sharey=False, use_index=True,
                 figsize=None, grid=None, legend=True, rot=None,
                 ax=None, fig=None, title=None, xlim=None, ylim=None,
                 xticks=None, yticks=None,
                 sort_columns=False, fontsize=None,
                 secondary_y=False, colormap=None,
                 table=False, layout=None, **kwds):
        """Store plotting arguments; no drawing happens until generate().

        Keyword arguments listed in ``_pop_attributes``, plus ``xerr``,
        ``yerr`` and ``cmap``, are removed from ``kwds``. Whatever remains
        is kept in ``self.kwds``.

        Raises
        ------
        TypeError
            If both ``cmap`` (in kwds) and a truthy ``colormap`` are given.
        """

        self.data = data
        self.by = by

        self.kind = kind

        self.sort_columns = sort_columns

        self.subplots = subplots

        if sharex is None:
            if ax is None:
                self.sharex = True
            else:
                # if we get an axis, the users should do the visibility
                # setting...
                self.sharex = False
        else:
            self.sharex = sharex

        self.sharey = sharey
        self.figsize = figsize
        self.layout = layout

        self.xticks = xticks
        self.yticks = yticks
        self.xlim = xlim
        self.ylim = ylim
        self.title = title
        self.use_index = use_index

        self.fontsize = fontsize

        if rot is not None:
            self.rot = rot
            # need to know for format_date_labels since it's rotated to 30 by
            # default
            self._rot_set = True
        else:
            self._rot_set = False
            self.rot = self._default_rot

        # grid defaults to off for secondary-y plots, else to rcParams
        if grid is None:
            grid = False if secondary_y else self.plt.rcParams['axes.grid']

        self.grid = grid
        self.legend = legend
        self.legend_handles = []
        self.legend_labels = []

        for attr in self._pop_attributes:
            value = kwds.pop(attr, self._attr_defaults.get(attr, None))
            setattr(self, attr, value)

        self.ax = ax
        self.fig = fig
        self.axes = None

        # parse errorbar input if given
        xerr = kwds.pop('xerr', None)
        yerr = kwds.pop('yerr', None)
        self.errors = {}
        for kw, err in zip(['xerr', 'yerr'], [xerr, yerr]):
            self.errors[kw] = self._parse_errorbars(kw, err)

        # a single (non-bool, non-sequence) secondary_y is wrapped in a list
        if not isinstance(secondary_y, (bool, tuple, list, np.ndarray, Index)):
            secondary_y = [secondary_y]
        self.secondary_y = secondary_y

        # ugly TypeError if user passes matplotlib's `cmap` name.
        # Probably better to accept either.
        if 'cmap' in kwds and colormap:
            raise TypeError("Only specify one of `cmap` and `colormap`.")
        elif 'cmap' in kwds:
            self.colormap = kwds.pop('cmap')
        else:
            self.colormap = colormap

        self.table = table

        self.kwds = kwds

        self._validate_color_args()
  806. def _validate_color_args(self):
  807. if 'color' not in self.kwds and 'colors' in self.kwds:
  808. warnings.warn(("'colors' is being deprecated. Please use 'color'"
  809. "instead of 'colors'"))
  810. colors = self.kwds.pop('colors')
  811. self.kwds['color'] = colors
  812. if ('color' in self.kwds and self.nseries == 1):
  813. # support series.plot(color='green')
  814. self.kwds['color'] = [self.kwds['color']]
  815. if ('color' in self.kwds or 'colors' in self.kwds) and \
  816. self.colormap is not None:
  817. warnings.warn("'color' and 'colormap' cannot be used "
  818. "simultaneously. Using 'color'")
  819. if 'color' in self.kwds and self.style is not None:
  820. if is_list_like(self.style):
  821. styles = self.style
  822. else:
  823. styles = [self.style]
  824. # need only a single match
  825. for s in styles:
  826. if re.match('^[a-z]+?', s) is not None:
  827. raise ValueError(
  828. "Cannot pass 'style' string with a color "
  829. "symbol and 'color' keyword argument. Please"
  830. " use one or the other or pass 'style' "
  831. "without a color symbol")
  832. def _iter_data(self, data=None, keep_index=False, fillna=None):
  833. if data is None:
  834. data = self.data
  835. if fillna is not None:
  836. data = data.fillna(fillna)
  837. # TODO: unused?
  838. # if self.sort_columns:
  839. # columns = _try_sort(data.columns)
  840. # else:
  841. # columns = data.columns
  842. for col, values in data.iteritems():
  843. if keep_index is True:
  844. yield col, values
  845. else:
  846. yield col, values.values
  847. @property
  848. def nseries(self):
  849. if self.data.ndim == 1:
  850. return 1
  851. else:
  852. return self.data.shape[1]
    def draw(self):
        """Request a redraw via ``self.plt.draw_if_interactive()``."""
        self.plt.draw_if_interactive()
  855. def generate(self):
  856. self._args_adjust()
  857. self._compute_plot_data()
  858. self._setup_subplots()
  859. self._make_plot()
  860. self._add_table()
  861. self._make_legend()
  862. self._adorn_subplots()
  863. for ax in self.axes:
  864. self._post_plot_logic_common(ax, self.data)
  865. self._post_plot_logic(ax, self.data)
  866. def _args_adjust(self):
  867. pass
  868. def _has_plotted_object(self, ax):
  869. """check whether ax has data"""
  870. return (len(ax.lines) != 0 or
  871. len(ax.artists) != 0 or
  872. len(ax.containers) != 0)
  873. def _maybe_right_yaxis(self, ax, axes_num):
  874. if not self.on_right(axes_num):
  875. # secondary axes may be passed via ax kw
  876. return self._get_ax_layer(ax)
  877. if hasattr(ax, 'right_ax'):
  878. # if it has right_ax proparty, ``ax`` must be left axes
  879. return ax.right_ax
  880. elif hasattr(ax, 'left_ax'):
  881. # if it has left_ax proparty, ``ax`` must be right axes
  882. return ax
  883. else:
  884. # otherwise, create twin axes
  885. orig_ax, new_ax = ax, ax.twinx()
  886. # TODO: use Matplotlib public API when available
  887. new_ax._get_lines = orig_ax._get_lines
  888. new_ax._get_patches_for_fill = orig_ax._get_patches_for_fill
  889. orig_ax.right_ax, new_ax.left_ax = new_ax, orig_ax
  890. if not self._has_plotted_object(orig_ax): # no data on left y
  891. orig_ax.get_yaxis().set_visible(False)
  892. return new_ax
  893. def _setup_subplots(self):
  894. if self.subplots:
  895. fig, axes = _subplots(naxes=self.nseries,
  896. sharex=self.sharex, sharey=self.sharey,
  897. figsize=self.figsize, ax=self.ax,
  898. layout=self.layout,
  899. layout_type=self._layout_type)
  900. else:
  901. if self.ax is None:
  902. fig = self.plt.figure(figsize=self.figsize)
  903. axes = fig.add_subplot(111)
  904. else:
  905. fig = self.ax.get_figure()
  906. if self.figsize is not None:
  907. fig.set_size_inches(self.figsize)
  908. axes = self.ax
  909. axes = _flatten(axes)
  910. if self.logx or self.loglog:
  911. [a.set_xscale('log') for a in axes]
  912. if self.logy or self.loglog:
  913. [a.set_yscale('log') for a in axes]
  914. self.fig = fig
  915. self.axes = axes
  916. @property
  917. def result(self):
  918. """
  919. Return result axes
  920. """
  921. if self.subplots:
  922. if self.layout is not None and not is_list_like(self.ax):
  923. return self.axes.reshape(*self.layout)
  924. else:
  925. return self.axes
  926. else:
  927. sec_true = isinstance(self.secondary_y, bool) and self.secondary_y
  928. all_sec = (is_list_like(self.secondary_y) and
  929. len(self.secondary_y) == self.nseries)
  930. if (sec_true or all_sec):
  931. # if all data is plotted on secondary, return right axes
  932. return self._get_ax_layer(self.axes[0], primary=False)
  933. else:
  934. return self.axes[0]
  935. def _compute_plot_data(self):
  936. data = self.data
  937. if isinstance(data, Series):
  938. label = self.label
  939. if label is None and data.name is None:
  940. label = 'None'
  941. data = data.to_frame(name=label)
  942. numeric_data = data._convert(datetime=True)._get_numeric_data()
  943. try:
  944. is_empty = numeric_data.empty
  945. except AttributeError:
  946. is_empty = not len(numeric_data)
  947. # no empty frames or series allowed
  948. if is_empty:
  949. raise TypeError('Empty {0!r}: no numeric data to '
  950. 'plot'.format(numeric_data.__class__.__name__))
  951. self.data = numeric_data
  952. def _make_plot(self):
  953. raise AbstractMethodError(self)
  954. def _add_table(self):
  955. if self.table is False:
  956. return
  957. elif self.table is True:
  958. data = self.data.transpose()
  959. else:
  960. data = self.table
  961. ax = self._get_ax(0)
  962. table(ax, data)
  963. def _post_plot_logic_common(self, ax, data):
  964. """Common post process for each axes"""
  965. labels = [pprint_thing(key) for key in data.index]
  966. labels = dict(zip(range(len(data.index)), labels))
  967. if self.orientation == 'vertical' or self.orientation is None:
  968. if self._need_to_set_index:
  969. xticklabels = [labels.get(x, '') for x in ax.get_xticks()]
  970. ax.set_xticklabels(xticklabels)
  971. self._apply_axis_properties(ax.xaxis, rot=self.rot,
  972. fontsize=self.fontsize)
  973. self._apply_axis_properties(ax.yaxis, fontsize=self.fontsize)
  974. elif self.orientation == 'horizontal':
  975. if self._need_to_set_index:
  976. yticklabels = [labels.get(y, '') for y in ax.get_yticks()]
  977. ax.set_yticklabels(yticklabels)
  978. self._apply_axis_properties(ax.yaxis, rot=self.rot,
  979. fontsize=self.fontsize)
  980. self._apply_axis_properties(ax.xaxis, fontsize=self.fontsize)
  981. else: # pragma no cover
  982. raise ValueError
  983. def _post_plot_logic(self, ax, data):
  984. """Post process for each axes. Overridden in child classes"""
  985. pass
  986. def _adorn_subplots(self):
  987. """Common post process unrelated to data"""
  988. if len(self.axes) > 0:
  989. all_axes = self._get_subplots()
  990. nrows, ncols = self._get_axes_layout()
  991. _handle_shared_axes(axarr=all_axes, nplots=len(all_axes),
  992. naxes=nrows * ncols, nrows=nrows,
  993. ncols=ncols, sharex=self.sharex,
  994. sharey=self.sharey)
  995. for ax in self.axes:
  996. if self.yticks is not None:
  997. ax.set_yticks(self.yticks)
  998. if self.xticks is not None:
  999. ax.set_xticks(self.xticks)
  1000. if self.ylim is not None:
  1001. ax.set_ylim(self.ylim)
  1002. if self.xlim is not None:
  1003. ax.set_xlim(self.xlim)
  1004. ax.grid(self.grid)
  1005. if self.title:
  1006. if self.subplots:
  1007. self.fig.suptitle(self.title)
  1008. else:
  1009. self.axes[0].set_title(self.title)
  1010. def _apply_axis_properties(self, axis, rot=None, fontsize=None):
  1011. labels = axis.get_majorticklabels() + axis.get_minorticklabels()
  1012. for label in labels:
  1013. if rot is not None:
  1014. label.set_rotation(rot)
  1015. if fontsize is not None:
  1016. label.set_fontsize(fontsize)
  1017. @property
  1018. def legend_title(self):
  1019. if not isinstance(self.data.columns, MultiIndex):
  1020. name = self.data.columns.name
  1021. if name is not None:
  1022. name = pprint_thing(name)
  1023. return name
  1024. else:
  1025. stringified = map(pprint_thing,
  1026. self.data.columns.names)
  1027. return ','.join(stringified)
  1028. def _add_legend_handle(self, handle, label, index=None):
  1029. if label is not None:
  1030. if self.mark_right and index is not None:
  1031. if self.on_right(index):
  1032. label = label + ' (right)'
  1033. self.legend_handles.append(handle)
  1034. self.legend_labels.append(label)
  1035. def _make_legend(self):
  1036. ax, leg = self._get_ax_legend(self.axes[0])
  1037. handles = []
  1038. labels = []
  1039. title = ''
  1040. if not self.subplots:
  1041. if leg is not None:
  1042. title = leg.get_title().get_text()
  1043. handles = leg.legendHandles
  1044. labels = [x.get_text() for x in leg.get_texts()]
  1045. if self.legend:
  1046. if self.legend == 'reverse':
  1047. self.legend_handles = reversed(self.legend_handles)
  1048. self.legend_labels = reversed(self.legend_labels)
  1049. handles += self.legend_handles
  1050. labels += self.legend_labels
  1051. if self.legend_title is not None:
  1052. title = self.legend_title
  1053. if len(handles) > 0:
  1054. ax.legend(handles, labels, loc='best', title=title)
  1055. elif self.subplots and self.legend:
  1056. for ax in self.axes:
  1057. if ax.get_visible():
  1058. ax.legend(loc='best')
  1059. def _get_ax_legend(self, ax):
  1060. leg = ax.get_legend()
  1061. other_ax = (getattr(ax, 'left_ax', None) or
  1062. getattr(ax, 'right_ax', None))
  1063. other_leg = None
  1064. if other_ax is not None:
  1065. other_leg = other_ax.get_legend()
  1066. if leg is None and other_leg is not None:
  1067. leg = other_leg
  1068. ax = other_ax
  1069. return ax, leg
  1070. @cache_readonly
  1071. def plt(self):
  1072. import matplotlib.pyplot as plt
  1073. return plt
  1074. @staticmethod
  1075. def mpl_ge_1_3_1():
  1076. return _mpl_ge_1_3_1()
  1077. @staticmethod
  1078. def mpl_ge_1_5_0():
  1079. return _mpl_ge_1_5_0()
  1080. _need_to_set_index = False
  1081. def _get_xticks(self, convert_period=False):
  1082. index = self.data.index
  1083. is_datetype = index.inferred_type in ('datetime', 'date',
  1084. 'datetime64', 'time')
  1085. if self.use_index:
  1086. if convert_period and isinstance(index, PeriodIndex):
  1087. self.data = self.data.reindex(index=index.sort_values())
  1088. x = self.data.index.to_timestamp()._mpl_repr()
  1089. elif index.is_numeric():
  1090. """
  1091. Matplotlib supports numeric values or datetime objects as
  1092. xaxis values. Taking LBYL approach here, by the time
  1093. matplotlib raises exception when using non numeric/datetime
  1094. values for xaxis, several actions are already taken by plt.
  1095. """
  1096. x = index._mpl_repr()
  1097. elif is_datetype:
  1098. self.data = self.data.sort_index()
  1099. x = self.data.index._mpl_repr()
  1100. else:
  1101. self._need_to_set_index = True
  1102. x = lrange(len(index))
  1103. else:
  1104. x = lrange(len(index))
  1105. return x
  1106. @classmethod
  1107. def _plot(cls, ax, x, y, style=None, is_errorbar=False, **kwds):
  1108. mask = isnull(y)
  1109. if mask.any():
  1110. y = np.ma.array(y)
  1111. y = np.ma.masked_where(mask, y)
  1112. if isinstance(x, Index):
  1113. x = x._mpl_repr()
  1114. if is_errorbar:
  1115. if 'xerr' in kwds:
  1116. kwds['xerr'] = np.array(kwds.get('xerr'))
  1117. if 'yerr' in kwds:
  1118. kwds['yerr'] = np.array(kwds.get('yerr'))
  1119. return ax.errorbar(x, y, **kwds)
  1120. else:
  1121. # prevent style kwarg from going to errorbar, where it is
  1122. # unsupported
  1123. if style is not None:
  1124. args = (x, y, style)
  1125. else:
  1126. args = (x, y)
  1127. return ax.plot(*args, **kwds)
  1128. def _get_index_name(self):
  1129. if isinstance(self.data.index, MultiIndex):
  1130. name = self.data.index.names
  1131. if any(x is not None for x in name):
  1132. name = ','.join([pprint_thing(x) for x in name])
  1133. else:
  1134. name = None
  1135. else:
  1136. name = self.data.index.name
  1137. if name is not None:
  1138. name = pprint_thing(name)
  1139. return name
  1140. @classmethod
  1141. def _get_ax_layer(cls, ax, primary=True):
  1142. """get left (primary) or right (secondary) axes"""
  1143. if primary:
  1144. return getattr(ax, 'left_ax', ax)
  1145. else:
  1146. return getattr(ax, 'right_ax', ax)
  1147. def _get_ax(self, i):
  1148. # get the twinx ax if appropriate
  1149. if self.subplots:
  1150. ax = self.axes[i]
  1151. ax = self._maybe_right_yaxis(ax, i)
  1152. self.axes[i] = ax
  1153. else:
  1154. ax = self.axes[0]
  1155. ax = self._maybe_right_yaxis(ax, i)
  1156. ax.get_yaxis().set_visible(True)
  1157. return ax
  1158. def on_right(self, i):
  1159. if isinstance(self.secondary_y, bool):
  1160. return self.secondary_y
  1161. if isinstance(self.secondary_y, (tuple, list, np.ndarray, Index)):
  1162. return self.data.columns[i] in self.secondary_y
  1163. def _apply_style_colors(self, colors, kwds, col_num, label):
  1164. """
  1165. Manage style and color based on column number and its label.
  1166. Returns tuple of appropriate style and kwds which "color" may be added.
  1167. """
  1168. style = None
  1169. if self.style is not None:
  1170. if isinstance(self.style, list):
  1171. try:
  1172. style = self.style[col_num]
  1173. except IndexError:
  1174. pass
  1175. elif isinstance(self.style, dict):
  1176. style = self.style.get(label, style)
  1177. else:
  1178. style = self.style
  1179. has_color = 'color' in kwds or self.colormap is not None
  1180. nocolor_style = style is None or re.match('[a-z]+', style) is None
  1181. if (has_color or self.subplots) and nocolor_style:
  1182. kwds['color'] = colors[col_num % len(colors)]
  1183. return style, kwds
  1184. def _get_colors(self, num_colors=None, color_kwds='color'):
  1185. if num_colors is None:
  1186. num_colors = self.nseries
  1187. return _get_standard_colors(num_colors=num_colors,
  1188. colormap=self.colormap,
  1189. color=self.kwds.get(color_kwds))
  1190. def _parse_errorbars(self, label, err):
  1191. """
  1192. Look for error keyword arguments and return the actual errorbar data
  1193. or return the error DataFrame/dict
  1194. Error bars can be specified in several ways:
  1195. Series: the user provides a pandas.Series object of the same
  1196. length as the data
  1197. ndarray: provides a np.ndarray of the same length as the data
  1198. DataFrame/dict: error values are paired with keys matching the
  1199. key in the plotted DataFrame
  1200. str: the name of the column within the plotted DataFrame
  1201. """
  1202. if err is None:
  1203. return None
  1204. from pandas import DataFrame, Series
  1205. def match_labels(data, e):
  1206. e = e.reindex_axis(data.index)
  1207. return e
  1208. # key-matched DataFrame
  1209. if isinstance(err, DataFrame):
  1210. err = match_labels(self.data, err)
  1211. # key-matched dict
  1212. elif isinstance(err, dict):
  1213. pass
  1214. # Series of error values
  1215. elif isinstance(err, Series):
  1216. # broadcast error series across data
  1217. err = match_labels(self.data, err)
  1218. err = np.atleast_2d(err)
  1219. err = np.tile(err, (self.nseries, 1))
  1220. # errors are a column in the dataframe
  1221. elif isinstance(err, string_types):
  1222. evalues = self.data[err].values
  1223. self.data = self.data[self.data.columns.drop(err)]
  1224. err = np.atleast_2d(evalues)
  1225. err = np.tile(err, (self.nseries, 1))
  1226. elif is_list_like(err):
  1227. if is_iterator(err):
  1228. err = np.atleast_2d(list(err))
  1229. else:
  1230. # raw error values
  1231. err = np.atleast_2d(err)
  1232. err_shape = err.shape
  1233. # asymmetrical error bars
  1234. if err.ndim == 3:
  1235. if (err_shape[0] != self.nseries) or \
  1236. (err_shape[1] != 2) or \
  1237. (err_shape[2] != len(self.data)):
  1238. msg = "Asymmetrical error bars should be provided " + \
  1239. "with the shape (%u, 2, %u)" % \
  1240. (self.nseries, len(self.data))
  1241. raise ValueError(msg)
  1242. # broadcast errors to each data series
  1243. if len(err) == 1:
  1244. err = np.tile(err, (self.nseries, 1))
  1245. elif is_number(err):
  1246. err = np.tile([err], (self.nseries, len(self.data)))
  1247. else:
  1248. msg = "No valid %s detected" % label
  1249. raise ValueError(msg)
  1250. return err
  1251. def _get_errorbars(self, label=None, index=None, xerr=True, yerr=True):
  1252. from pandas import DataFrame
  1253. errors = {}
  1254. for kw, flag in zip(['xerr', 'yerr'], [xerr, yerr]):
  1255. if flag:
  1256. err = self.errors[kw]
  1257. # user provided label-matched dataframe of errors
  1258. if isinstance(err, (DataFrame, dict)):
  1259. if label is not None and label in err.keys():
  1260. err = err[label]
  1261. else:
  1262. err = None
  1263. elif index is not None and err is not None:
  1264. err = err[index]
  1265. if err is not None:
  1266. errors[kw] = err
  1267. return errors
  1268. def _get_subplots(self):
  1269. from matplotlib.axes import Subplot
  1270. return [ax for ax in self.axes[0].get_figure().get_axes()
  1271. if isinstance(ax, Subplot)]
  1272. def _get_axes_layout(self):
  1273. axes = self._get_

Large files files are truncated, but you can click here to view the full file