/pandas/tools/plotting.py
Python | 3075 lines | 2730 code | 172 blank | 173 comment | 229 complexity | fb838f97cf0225a6a627aac3491130f7 MD5 | raw file
Possible License(s): BSD-3-Clause, Apache-2.0
Large files are truncated, but you can click here to view the full file
- # being a bit too dynamic
- # pylint: disable=E1101
- import datetime
- import warnings
- import re
- from collections import namedtuple
- from contextlib import contextmanager
- from distutils.version import LooseVersion
- import numpy as np
- from pandas.util.decorators import cache_readonly, deprecate_kwarg
- import pandas.core.common as com
- from pandas.core.generic import _shared_docs, _shared_doc_kwargs
- from pandas.core.index import MultiIndex
- from pandas.core.series import Series, remove_na
- from pandas.tseries.index import DatetimeIndex
- from pandas.tseries.period import PeriodIndex, Period
- from pandas.tseries.frequencies import get_period_alias, get_base_alias
- from pandas.tseries.offsets import DateOffset
- from pandas.compat import range, lrange, lmap, map, zip, string_types
- import pandas.compat as compat
- from pandas.util.decorators import Appender
- try: # mpl optional
- import pandas.tseries.converter as conv
- conv.register() # needs to override so set_xlim works with str/number
- except ImportError:
- pass
# Matplotlib rcParams that are applied when ``display.with_mpl_style`` is
# set to True.
# Extracted from https://gist.github.com/huyng/816622
mpl_stylesheet = {
    # --- axes ---
    'axes.axisbelow': True,
    'axes.color_cycle': [
        '#348ABD',
        '#7A68A6',
        '#A60628',
        '#467821',
        '#CF4457',
        '#188487',
        '#E24A33',
    ],
    'axes.edgecolor': '#bcbcbc',
    'axes.facecolor': '#eeeeee',
    'axes.grid': True,
    'axes.labelcolor': '#555555',
    'axes.labelsize': 'large',
    'axes.linewidth': 1.0,
    'axes.titlesize': 'x-large',
    # --- figure ---
    'figure.edgecolor': 'white',
    'figure.facecolor': 'white',
    'figure.figsize': (6.0, 4.0),
    'figure.subplot.hspace': 0.5,
    # --- fonts ---
    'font.family': 'monospace',
    'font.monospace': [
        'Andale Mono',
        'Nimbus Mono L',
        'Courier New',
        'Courier',
        'Fixed',
        'Terminal',
        'monospace',
    ],
    'font.size': 10,
    'interactive': True,
    # --- keyboard shortcuts ---
    'keymap.all_axes': ['a'],
    'keymap.back': ['left', 'c', 'backspace'],
    'keymap.forward': ['right', 'v'],
    'keymap.fullscreen': ['f'],
    'keymap.grid': ['g'],
    'keymap.home': ['h', 'r', 'home'],
    'keymap.pan': ['p'],
    'keymap.save': ['s'],
    'keymap.xscale': ['L', 'k'],
    'keymap.yscale': ['l'],
    'keymap.zoom': ['o'],
    # --- legend, lines and patches ---
    'legend.fancybox': True,
    'lines.antialiased': True,
    'lines.linewidth': 1.0,
    'patch.antialiased': True,
    'patch.edgecolor': '#EEEEEE',
    'patch.facecolor': '#348ABD',
    'patch.linewidth': 0.5,
    'toolbar': 'toolbar2',
    # --- x ticks ---
    'xtick.color': '#555555',
    'xtick.direction': 'in',
    'xtick.major.pad': 6.0,
    'xtick.major.size': 0.0,
    'xtick.minor.pad': 6.0,
    'xtick.minor.size': 0.0,
    # --- y ticks ---
    'ytick.color': '#555555',
    'ytick.direction': 'in',
    'ytick.major.pad': 6.0,
    'ytick.major.size': 0.0,
    'ytick.minor.pad': 6.0,
    'ytick.minor.size': 0.0,
}
- def _get_standard_kind(kind):
- return {'density': 'kde'}.get(kind, kind)
- def _get_standard_colors(num_colors=None, colormap=None, color_type='default',
- color=None):
- import matplotlib.pyplot as plt
- if color is None and colormap is not None:
- if isinstance(colormap, compat.string_types):
- import matplotlib.cm as cm
- cmap = colormap
- colormap = cm.get_cmap(colormap)
- if colormap is None:
- raise ValueError("Colormap {0} is not recognized".format(cmap))
- colors = lmap(colormap, np.linspace(0, 1, num=num_colors))
- elif color is not None:
- if colormap is not None:
- warnings.warn("'color' and 'colormap' cannot be used "
- "simultaneously. Using 'color'")
- colors = color
- else:
- if color_type == 'default':
- colors = plt.rcParams.get('axes.color_cycle', list('bgrcmyk'))
- if isinstance(colors, compat.string_types):
- colors = list(colors)
- elif color_type == 'random':
- import random
- def random_color(column):
- random.seed(column)
- return [random.random() for _ in range(3)]
- colors = lmap(random_color, lrange(num_colors))
- else:
- raise NotImplementedError
- if len(colors) != num_colors:
- multiple = num_colors//len(colors) - 1
- mod = num_colors % len(colors)
- colors += multiple * colors
- colors += colors[:mod]
- return colors
- class _Options(dict):
- """
- Stores pandas plotting options.
- Allows for parameter aliasing so you can just use parameter names that are
- the same as the plot function parameters, but is stored in a canonical
- format that makes it easy to breakdown into groups later
- """
- # alias so the names are same as plotting method parameter names
- _ALIASES = {'x_compat': 'xaxis.compat'}
- _DEFAULT_KEYS = ['xaxis.compat']
- def __init__(self):
- self['xaxis.compat'] = False
- def __getitem__(self, key):
- key = self._get_canonical_key(key)
- if key not in self:
- raise ValueError('%s is not a valid pandas plotting option' % key)
- return super(_Options, self).__getitem__(key)
- def __setitem__(self, key, value):
- key = self._get_canonical_key(key)
- return super(_Options, self).__setitem__(key, value)
- def __delitem__(self, key):
- key = self._get_canonical_key(key)
- if key in self._DEFAULT_KEYS:
- raise ValueError('Cannot remove default parameter %s' % key)
- return super(_Options, self).__delitem__(key)
- def __contains__(self, key):
- key = self._get_canonical_key(key)
- return super(_Options, self).__contains__(key)
- def reset(self):
- """
- Reset the option store to its initial state
- Returns
- -------
- None
- """
- self.__init__()
- def _get_canonical_key(self, key):
- return self._ALIASES.get(key, key)
- @contextmanager
- def use(self, key, value):
- """
- Temporarily set a parameter value using the with statement.
- Aliasing allowed.
- """
- old_value = self[key]
- try:
- self[key] = value
- yield self
- finally:
- self[key] = old_value
- plot_params = _Options()
def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,
                   diagonal='hist', marker='.', density_kwds=None,
                   hist_kwds=None, range_padding=0.05, **kwds):
    """
    Draw a matrix of scatter plots.

    Only the numeric columns of ``frame`` are used.  Cell (i, j) shows
    column j on the x axis against column i on the y axis; the diagonal
    shows a histogram or a kernel density estimate of the column.

    Parameters
    ----------
    frame : DataFrame
    alpha : float, optional
        amount of transparency applied
    figsize : (float,float), optional
        a tuple (width, height) in inches
    ax : Matplotlib axis object, optional
    grid : bool, optional
        setting this to True will show the grid
    diagonal : {'hist', 'kde'}
        pick between 'kde' and 'hist' for
        either Kernel Density Estimation or Histogram
        plot in the diagonal ('density' is accepted as an alias of 'kde')
    marker : str, optional
        Matplotlib marker type, default '.'
    hist_kwds : other plotting keyword arguments
        To be passed to hist function
    density_kwds : other plotting keyword arguments
        To be passed to kernel density estimate plot
    range_padding : float, optional
        relative extension of axis range in x and y
        with respect to (x_max - x_min) or (y_max - y_min),
        default 0.05
    kwds : other plotting keyword arguments
        To be passed to scatter function

    Returns
    -------
    axes : array of matplotlib axes, one per cell of the matrix

    Examples
    --------
    >>> df = DataFrame(np.random.randn(1000, 4), columns=['A','B','C','D'])
    >>> scatter_matrix(df, alpha=0.2)
    """
    import matplotlib.pyplot as plt
    from matplotlib.artist import setp

    df = frame._get_numeric_data()
    n = df.columns.size
    fig, axes = _subplots(nrows=n, ncols=n, figsize=figsize, ax=ax,
                          squeeze=False)

    # no gaps between subplots
    fig.subplots_adjust(wspace=0, hspace=0)

    mask = com.notnull(df)

    marker = _get_marker_compat(marker)

    hist_kwds = hist_kwds or {}
    density_kwds = density_kwds or {}

    # workaround because `c='b'` is hardcoded in matplotlibs scatter method
    kwds.setdefault('c', plt.rcParams['patch.facecolor'])

    # padded (low, high) plotting range of every column, NaNs excluded
    boundaries_list = []
    for name in df.columns:
        present = df[name].values[mask[name].values]
        low, high = np.min(present), np.max(present)
        pad = (high - low) * range_padding / 2.
        boundaries_list.append((low - pad, high + pad))

    for i, row_name in enumerate(df.columns):
        for j, col_name in enumerate(df.columns):
            cell = axes[i, j]

            if i != j:
                # off-diagonal: scatter the rows where both columns are
                # present
                common = (mask[row_name] & mask[col_name]).values

                cell.scatter(df[col_name][common], df[row_name][common],
                             marker=marker, alpha=alpha, **kwds)

                cell.set_xlim(boundaries_list[j])
                cell.set_ylim(boundaries_list[i])
            else:
                # diagonal: distribution of the column itself
                values = df[row_name].values[mask[row_name].values]

                if diagonal == 'hist':
                    cell.hist(values, **hist_kwds)
                elif diagonal in ('kde', 'density'):
                    from scipy.stats import gaussian_kde
                    gkde = gaussian_kde(values)
                    ind = np.linspace(values.min(), values.max(), 1000)
                    cell.plot(ind, gkde.evaluate(ind), **density_kwds)

                cell.set_xlim(boundaries_list[i])

            cell.set_xlabel('')
            cell.set_ylabel('')

            _label_axis(cell, kind='x', label=col_name, position='bottom',
                        rotate=True)
            _label_axis(cell, kind='y', label=row_name, position='left')

            # only the left column keeps its y axis and only the bottom
            # row keeps its x axis
            if j != 0:
                cell.yaxis.set_visible(False)
            if i != n - 1:
                cell.xaxis.set_visible(False)

    for cell in axes.flat:
        setp(cell.get_xticklabels(), fontsize=8)
        setp(cell.get_yticklabels(), fontsize=8)

    return axes
def _label_axis(ax, kind='x', label='', position='top',
                ticks=True, rotate=False):
    """
    Label one axis of ``ax`` and place its ticks and label at ``position``.

    ``kind`` selects the axis ('x' or 'y'); any other value leaves ``ax``
    untouched.  ``rotate`` turns the x tick labels by 90 degrees and has
    no effect for ``kind='y'``.  ``ticks`` is accepted but not used.
    """
    from matplotlib.artist import setp

    if kind == 'x':
        ax.set_xlabel(label, visible=True)
        xaxis = ax.xaxis
        xaxis.set_visible(True)
        xaxis.set_ticks_position(position)
        xaxis.set_label_position(position)
        if rotate:
            setp(ax.get_xticklabels(), rotation=90)
        return

    if kind == 'y':
        yaxis = ax.yaxis
        yaxis.set_visible(True)
        ax.set_ylabel(label, visible=True)
        yaxis.set_ticks_position(position)
        yaxis.set_label_position(position)
    return
def _gca():
    """Return ``matplotlib.pyplot.gca()``; matplotlib is imported lazily."""
    from matplotlib import pyplot
    return pyplot.gca()
def _gcf():
    """Return ``matplotlib.pyplot.gcf()``; matplotlib is imported lazily."""
    from matplotlib import pyplot
    return pyplot.gcf()
def _get_marker_compat(marker):
    """
    Return a marker that the installed matplotlib can draw.

    ``'o'`` is substituted when ``marker`` is ``'.'`` on a matplotlib whose
    version string sorts below '1.1.0', or when ``marker`` is not listed in
    ``matplotlib.lines.lineMarkers``; otherwise ``marker`` is returned
    unchanged.
    """
    import matplotlib.lines as mlines
    import matplotlib as mpl

    # NOTE(review): this compares version *strings* lexicographically, not
    # parsed version numbers -- confirm that is acceptable.
    too_old_for_dot = mpl.__version__ < '1.1.0' and marker == '.'
    if too_old_for_dot or marker not in mlines.lineMarkers:
        return 'o'
    return marker
def radviz(frame, class_column, ax=None, color=None, colormap=None, **kwds):
    """RadViz - a multivariate data visualization algorithm

    Every non-class column is min-max normalized and assigned an anchor
    point, evenly spaced on the unit circle.  Each row is drawn at the
    average of the anchors weighted by the row's normalized values, and is
    colored by its class.

    Parameters
    ----------
    frame: DataFrame
    class_column: str
        Column name containing class names
    ax: Matplotlib axis object, optional
    color: list or tuple, optional
        Colors to use for the different classes
    colormap : str or matplotlib colormap object, default None
        Colormap to select colors from. If string, load colormap with that
        name from matplotlib.
    kwds: keywords
        Options to pass to matplotlib scatter plotting method

    Returns
    -------
    ax: Matplotlib axis object
    """
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches

    def normalize(series):
        # rescale a column linearly onto [0, 1]
        a = min(series)
        b = max(series)
        return (series - a) / (b - a)

    n_rows = len(frame)
    classes = frame[class_column].drop_duplicates()
    class_col = frame[class_column]
    df = frame.drop(class_column, axis=1).apply(normalize)

    if ax is None:
        ax = plt.gca(xlim=[-1, 1], ylim=[-1, 1])

    colors = _get_standard_colors(num_colors=len(classes), colormap=colormap,
                                  color_type='random', color=color)

    # per class: [x coordinates, y coordinates]
    to_plot = {}
    for kls in classes:
        to_plot[kls] = [[], []]

    # One anchor per feature.  This count is kept separate from the row
    # count: reusing a single variable for both made the loop below stop
    # after as many rows as there are features.
    n_features = len(df.columns)
    s = np.array([(np.cos(t), np.sin(t))
                  for t in [2.0 * np.pi * (i / float(n_features))
                            for i in range(n_features)]])

    for i in range(n_rows):
        row = df.iloc[i].values
        row_ = np.repeat(np.expand_dims(row, axis=1), 2, axis=1)
        # weighted average of the anchors
        y = (s * row_).sum(axis=0) / row.sum()
        kls = class_col.iat[i]
        to_plot[kls][0].append(y[0])
        to_plot[kls][1].append(y[1])

    for i, kls in enumerate(classes):
        ax.scatter(to_plot[kls][0], to_plot[kls][1], color=colors[i],
                   label=com.pprint_thing(kls), **kwds)
    ax.legend()

    ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor='none'))

    # mark every anchor and put its column name on the outside of the
    # circle, aligned according to the quadrant the anchor lies in
    for xy, name in zip(s, df.columns):

        ax.add_patch(patches.Circle(xy, radius=0.025, facecolor='gray'))

        if xy[0] < 0.0 and xy[1] < 0.0:
            ax.text(xy[0] - 0.025, xy[1] - 0.025, name,
                    ha='right', va='top', size='small')
        elif xy[0] < 0.0 and xy[1] >= 0.0:
            ax.text(xy[0] - 0.025, xy[1] + 0.025, name,
                    ha='right', va='bottom', size='small')
        elif xy[0] >= 0.0 and xy[1] < 0.0:
            ax.text(xy[0] + 0.025, xy[1] - 0.025, name,
                    ha='left', va='top', size='small')
        elif xy[0] >= 0.0 and xy[1] >= 0.0:
            ax.text(xy[0] + 0.025, xy[1] + 0.025, name,
                    ha='left', va='bottom', size='small')

    ax.axis('equal')
    return ax
@deprecate_kwarg(old_arg_name='data', new_arg_name='frame')
def andrews_curves(frame, class_column, ax=None, samples=200, color=None,
                   colormap=None, **kwds):
    """
    Draw one Andrews curve per row of ``frame``, colored by class.

    A row ``(x1, x2, x3, x4, x5, ...)`` is mapped to the function::

        f(t) = x1/sqrt(2) + x2 sin(t) + x3 cos(t)
                          + x4 sin(2t) + x5 cos(2t) + ...

    evaluated at ``samples`` points starting at -pi with a step of
    ``2 * pi / samples``.

    Parameters
    ----------
    frame : DataFrame
        Data to be plotted, preferably normalized to (0.0, 1.0)
    class_column : Name of the column containing class names
    ax : matplotlib axes object, default None
    samples : Number of points to plot in each curve
    color: list or tuple, optional
        Colors to use for the different classes
    colormap : str or matplotlib colormap object, default None
        Colormap to select colors from. If string, load colormap with that
        name from matplotlib.
    kwds: keywords
        Options to pass to matplotlib plotting method

    Returns
    -------
    ax: Matplotlib axis object
    """
    from math import sqrt, pi, sin, cos
    import matplotlib.pyplot as plt

    def function(amplitudes):
        def f(x):
            x1 = amplitudes[0]
            result = x1 / sqrt(2.0)
            harmonic = 1.0
            # coefficients after the first are consumed in (sin, cos) pairs
            for x_even, x_odd in zip(amplitudes[1::2], amplitudes[2::2]):
                result += (x_even * sin(harmonic * x) +
                           x_odd * cos(harmonic * x))
                harmonic += 1.0
            # With an even number of coefficients the last one has no
            # cosine partner and was skipped by zip(): add its sine term.
            # (With an odd number everything has already been paired.)
            if len(amplitudes) % 2 == 0:
                result += amplitudes[-1] * sin(harmonic * x)
            return result
        return f

    n = len(frame)
    class_col = frame[class_column]
    classes = frame[class_column].drop_duplicates()
    df = frame.drop(class_column, axis=1)
    x = [-pi + 2.0 * pi * (t / float(samples)) for t in range(samples)]
    used_legends = set([])

    color_values = _get_standard_colors(num_colors=len(classes),
                                        colormap=colormap, color_type='random',
                                        color=color)
    colors = dict(zip(classes, color_values))
    if ax is None:
        ax = plt.gca(xlim=(-pi, pi))
    for i in range(n):
        row = df.iloc[i].values
        f = function(row)
        y = [f(t) for t in x]
        kls = class_col.iat[i]
        label = com.pprint_thing(kls)
        # only the first curve of each class gets a legend entry
        if label not in used_legends:
            used_legends.add(label)
            ax.plot(x, y, color=colors[kls], label=label, **kwds)
        else:
            ax.plot(x, y, color=colors[kls], **kwds)

    ax.legend(loc='upper right')
    ax.grid()
    return ax
def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds):
    """Bootstrap plot.

    Draws ``samples`` random subsets of ``size`` values from ``series``
    (using ``random.sample``) and plots the mean, median and midrange of
    every subset: as line plots in the top row and as histograms in the
    bottom row of a 2x3 grid.

    Parameters
    ----------
    series: Time series
    fig: matplotlib figure object, optional
    size: number of data points to consider during each sampling
    samples: number of times the bootstrap procedure is performed
    kwds: optional keyword arguments for plotting commands, must be accepted
        by both hist and plot

    Returns
    -------
    fig: matplotlib figure
    """
    import random
    import matplotlib.pyplot as plt

    # random.sample(ndarray, int) fails on python 3.3, sigh
    data = list(series.values)
    samplings = [random.sample(data, size) for _ in range(samples)]

    means = np.array([np.mean(drawn) for drawn in samplings])
    medians = np.array([np.median(drawn) for drawn in samplings])
    midranges = np.array([(min(drawn) + max(drawn)) * 0.5
                          for drawn in samplings])
    if fig is None:
        fig = plt.figure()
    x = lrange(samples)

    # (x label, statistic, draw as histogram?) in subplot order 1..6
    panels = [
        ("Sample", means, False),
        ("Sample", medians, False),
        ("Sample", midranges, False),
        ("Mean", means, True),
        ("Median", medians, True),
        ("Midrange", midranges, True),
    ]

    axes = []
    for position, (xlabel, statistic, as_hist) in enumerate(panels, 1):
        panel = fig.add_subplot(2, 3, position)
        panel.set_xlabel(xlabel)
        axes.append(panel)
        if as_hist:
            panel.hist(statistic, **kwds)
        else:
            panel.plot(x, statistic, **kwds)

    for panel in axes:
        plt.setp(panel.get_xticklabels(), fontsize=8)
        plt.setp(panel.get_yticklabels(), fontsize=8)
    return fig
@deprecate_kwarg(old_arg_name='colors', new_arg_name='color')
@deprecate_kwarg(old_arg_name='data', new_arg_name='frame')
def parallel_coordinates(frame, class_column, cols=None, ax=None, color=None,
                         use_columns=False, xticks=None, colormap=None,
                         **kwds):
    """Parallel coordinates plotting.

    Every row is drawn as one line across the plotted columns, colored by
    the row's class; a black vertical line marks each column.

    Parameters
    ----------
    frame: DataFrame
    class_column: str
        Column name containing class names
    cols: list, optional
        A list of column names to use
    ax: matplotlib.axis, optional
        matplotlib axis object
    color: list or tuple, optional
        Colors to use for the different classes
    use_columns: bool, optional
        If true, columns will be used as xticks
    xticks: list or tuple, optional
        A list of values to use for xticks
    colormap: str or matplotlib colormap, default None
        Colormap to use for line colors.
    kwds: keywords
        Options to pass to matplotlib plotting method

    Returns
    -------
    ax: matplotlib axis object

    Raises
    ------
    ValueError
        If ``use_columns`` is True and the column labels are not numeric,
        or if ``xticks`` is not numeric or its length differs from the
        number of plotted columns.

    Examples
    --------
    >>> from pandas import read_csv
    >>> from pandas.tools.plotting import parallel_coordinates
    >>> from matplotlib import pyplot as plt
    >>> df = read_csv('https://raw.github.com/pydata/pandas/master/pandas/tests/data/iris.csv')
    >>> parallel_coordinates(df, 'Name', color=('#556270', '#4ECDC4', '#C7F464'))
    >>> plt.show()
    """
    import matplotlib.pyplot as plt

    n = len(frame)
    classes = frame[class_column].drop_duplicates()
    class_col = frame[class_column]

    df = frame.drop(class_column, axis=1) if cols is None else frame[cols]

    used_legends = set([])

    ncols = len(df.columns)

    # determine values to use for xticks
    if use_columns is True:
        if not np.all(np.isreal(list(df.columns))):
            raise ValueError('Columns must be numeric to be used as xticks')
        x = df.columns
    elif xticks is None:
        x = lrange(ncols)
    else:
        if not np.all(np.isreal(xticks)):
            raise ValueError('xticks specified must be numeric')
        elif len(xticks) != ncols:
            raise ValueError('Length of xticks must match number of columns')
        x = xticks

    if ax is None:
        ax = plt.gca()

    color_values = _get_standard_colors(num_colors=len(classes),
                                        colormap=colormap, color_type='random',
                                        color=color)
    colors = dict(zip(classes, color_values))

    for row_number in range(n):
        y = df.iloc[row_number].values
        kls = class_col.iat[row_number]
        label = com.pprint_thing(kls)
        line_kwds = dict(kwds)
        # only the first line of each class gets a legend entry
        if label not in used_legends:
            used_legends.add(label)
            line_kwds['label'] = label
        ax.plot(x, y, color=colors[kls], **line_kwds)

    for tick in x:
        ax.axvline(tick, linewidth=1, color='black')

    ax.set_xticks(x)
    ax.set_xticklabels(df.columns)
    ax.set_xlim(x[0], x[-1])
    ax.legend(loc='upper right')
    ax.grid()
    return ax
def lag_plot(series, lag=1, ax=None, **kwds):
    """Lag plot for time series.

    Scatters every value of ``series`` against the value ``lag`` positions
    later.

    Parameters
    ----------
    series: Time series
    lag: lag of the scatter plot, default 1
    ax: Matplotlib axis object, optional
    kwds: Matplotlib scatter method keyword arguments, optional

    Returns
    -------
    ax: Matplotlib axis object
    """
    import matplotlib.pyplot as plt

    # workaround because `c='b'` is hardcoded in matplotlibs scatter method
    kwds.setdefault('c', plt.rcParams['patch.facecolor'])

    values = series.values
    current, lagged = values[:-lag], values[lag:]
    if ax is None:
        ax = plt.gca()
    ax.set_xlabel("y(t)")
    ax.set_ylabel("y(t + %s)" % lag)
    ax.scatter(current, lagged, **kwds)
    return ax
def autocorrelation_plot(series, ax=None, **kwds):
    """Autocorrelation plot for time series.

    Plots the sample autocorrelation for lags 1..n together with
    horizontal bands at ``+/- z / sqrt(n)`` for z95 (solid) and z99
    (dashed).

    Parameters
    ----------
    series: Time series
    ax: Matplotlib axis object, optional
    kwds : keywords
        Options to pass to matplotlib plotting method

    Returns
    -------
    ax: Matplotlib axis object
    """
    import matplotlib.pyplot as plt
    n = len(series)
    data = np.asarray(series)
    if ax is None:
        ax = plt.gca(xlim=(1, n), ylim=(-1.0, 1.0))
    mean = np.mean(data)
    c0 = np.sum((data - mean) ** 2) / float(n)

    def r(h):
        # autocovariance at lag h, scaled by c0 (the value computed above
        # from the squared deviations)
        return ((data[:n - h] - mean) * (data[h:] - mean)).sum() / float(n) / c0

    x = np.arange(n) + 1
    y = [r(h) for h in x]
    z95 = 1.959963984540054
    z99 = 2.5758293035489004
    root_n = np.sqrt(n)
    # (height, extra line style) of each horizontal guide, top to bottom
    guides = [
        (z99 / root_n, {'linestyle': '--', 'color': 'grey'}),
        (z95 / root_n, {'color': 'grey'}),
        (0.0, {'color': 'black'}),
        (-z95 / root_n, {'color': 'grey'}),
        (-z99 / root_n, {'linestyle': '--', 'color': 'grey'}),
    ]
    for height, line_style in guides:
        ax.axhline(y=height, **line_style)
    ax.set_xlabel("Lag")
    ax.set_ylabel("Autocorrelation")
    ax.plot(x, y, **kwds)
    if 'label' in kwds:
        ax.legend()
    ax.grid()
    return ax
- class MPLPlot(object):
- """
- Base class for assembling a pandas plot using matplotlib
- Parameters
- ----------
- data :
- """
- _default_rot = 0
- _pop_attributes = ['label', 'style', 'logy', 'logx', 'loglog',
- 'mark_right']
- _attr_defaults = {'logy': False, 'logx': False, 'loglog': False,
- 'mark_right': True}
def __init__(self, data, kind=None, by=None, subplots=False, sharex=True,
             sharey=False, use_index=True,
             figsize=None, grid=None, legend=True, rot=None,
             ax=None, fig=None, title=None, xlim=None, ylim=None,
             xticks=None, yticks=None,
             sort_columns=False, fontsize=None,
             secondary_y=False, colormap=None,
             table=False, **kwds):
    """
    Store the plot configuration on the instance (MPLPlot constructor).

    The names listed in ``_pop_attributes`` as well as ``xerr``, ``yerr``
    and ``cmap`` are removed from ``kwds``; whatever is left is kept in
    ``self.kwds``.  ``cmap`` is accepted as an alternative spelling of
    ``colormap``.

    Raises
    ------
    TypeError
        If both ``cmap`` and a truthy ``colormap`` are given.
    """
    self.data = data
    self.by = by

    self.kind = kind

    self.sort_columns = sort_columns

    self.subplots = subplots
    self.sharex = sharex
    self.sharey = sharey
    self.figsize = figsize

    self.xticks = xticks
    self.yticks = yticks
    self.xlim = xlim
    self.ylim = ylim
    self.title = title
    self.use_index = use_index

    self.fontsize = fontsize
    self.rot = rot

    # the grid defaults to on, unless a secondary y axis is requested
    if grid is None:
        grid = not secondary_y
    self.grid = grid

    self.legend = legend
    self.legend_handles = []
    self.legend_labels = []

    # move the instance-level options out of the matplotlib keywords
    for attr in self._pop_attributes:
        setattr(self, attr, kwds.pop(attr, self._attr_defaults.get(attr)))

    self.ax = ax
    self.fig = fig
    self.axes = None

    # parse errorbar input if given
    xerr = kwds.pop('xerr', None)
    yerr = kwds.pop('yerr', None)
    self.errors = {}
    for kw, err in (('xerr', xerr), ('yerr', yerr)):
        self.errors[kw] = self._parse_errorbars(kw, err)

    # a single non-bool, non-sequence value is wrapped in a list
    if not isinstance(secondary_y, (bool, tuple, list, np.ndarray)):
        secondary_y = [secondary_y]
    self.secondary_y = secondary_y

    # ugly TypeError if user passes matplotlib's `cmap` name.
    # Probably better to accept either.
    if 'cmap' in kwds:
        if colormap:
            raise TypeError("Only specify one of `cmap` and `colormap`.")
        colormap = kwds.pop('cmap')
    self.colormap = colormap

    self.table = table

    self.kwds = kwds

    self._validate_color_args()
- def _validate_color_args(self):
- from pandas import DataFrame
- if 'color' not in self.kwds and 'colors' in self.kwds:
- warnings.warn(("'colors' is being deprecated. Please use 'color'"
- "instead of 'colors'"))
- colors = self.kwds.pop('colors')
- self.kwds['color'] = colors
- if ('color' in self.kwds and
- (isinstance(self.data, Series) or
- isinstance(self.data, DataFrame) and len(self.data.columns) == 1)):
- # support series.plot(color='green')
- self.kwds['color'] = [self.kwds['color']]
- if ('color' in self.kwds or 'colors' in self.kwds) and \
- self.colormap is not None:
- warnings.warn("'color' and 'colormap' cannot be used "
- "simultaneously. Using 'color'")
- if 'color' in self.kwds and self.style is not None:
- # need only a single match
- if re.match('^[a-z]+?', self.style) is not None:
- raise ValueError("Cannot pass 'style' string with a color "
- "symbol and 'color' keyword argument. Please"
- " use one or the other or pass 'style' "
- "without a color symbol")
- def _iter_data(self, data=None, keep_index=False):
- if data is None:
- data = self.data
- from pandas.core.frame import DataFrame
- if isinstance(data, (Series, np.ndarray)):
- if keep_index is True:
- yield self.label, data
- else:
- yield self.label, np.asarray(data)
- elif isinstance(data, DataFrame):
- if self.sort_columns:
- columns = com._try_sort(data.columns)
- else:
- columns = data.columns
- for col in columns:
- # # is this right?
- # empty = df[col].count() == 0
- # values = df[col].values if not empty else np.zeros(len(df))
- if keep_index is True:
- yield col, data[col]
- else:
- yield col, data[col].values
@property
def nseries(self):
    """Number of series held in ``self.data``: 1 for one-dimensional data,
    otherwise the size of its second dimension."""
    data = self.data
    return 1 if data.ndim == 1 else data.shape[1]
def draw(self):
    """Call pyplot's ``draw_if_interactive`` (MPLPlot method)."""
    pyplot = self.plt
    pyplot.draw_if_interactive()
def generate(self):
    """
    Build the complete plot by running every stage in order (MPLPlot
    method).
    """
    # preparation: adjust arguments, pick the data, create figure and axes
    self._args_adjust()
    self._compute_plot_data()
    self._setup_subplots()

    # drawing
    self._make_plot()
    self._add_table()
    self._make_legend()

    # finishing touches
    self._post_plot_logic()
    self._adorn_subplots()
def _args_adjust(self):
    """First stage of ``generate``; does nothing in this base class."""
def _maybe_right_yaxis(self, ax):
    """
    Return the secondary (right-hand) y axis that belongs to ``ax``,
    creating it with ``twinx`` on first use (MPLPlot method).

    A newly created axis shares the color cycle of ``ax``; both axes are
    cross-linked through ``right_ax`` / ``left_ax`` attributes.  If
    ``ax`` holds no lines yet, its own y axis is hidden.
    """
    if hasattr(ax, 'right_ax'):
        return ax.right_ax

    left, right = ax, ax.twinx()
    right._get_lines.color_cycle = left._get_lines.color_cycle

    left.right_ax = right
    right.left_ax = left
    right.right_ax = right

    if len(left.get_lines()) == 0:  # no data on left y
        left.get_yaxis().set_visible(False)
    return right
def _setup_subplots(self):
    """
    Create (or adopt) the figure and axes and store them in ``self.fig``
    and ``self.axes`` (MPLPlot method).

    With ``self.subplots`` a grid shaped by ``_get_layout`` is requested
    from ``_subplots``.  Otherwise a single axes is used: ``self.ax`` when
    supplied (its figure is resized if ``figsize`` is set), else a new
    figure with one subplot.  Log scales are applied as requested by
    ``logx`` / ``logy`` / ``loglog``.
    """
    if self.subplots:
        nrows, ncols = self._get_layout()
        fig, axes = _subplots(nrows=nrows, ncols=ncols,
                              sharex=self.sharex, sharey=self.sharey,
                              figsize=self.figsize, ax=self.ax)
        if not com.is_list_like(axes):
            axes = np.array([axes])
    else:
        if self.ax is not None:
            fig = self.ax.get_figure()
            if self.figsize is not None:
                fig.set_size_inches(self.figsize)
            ax = self.ax
        else:
            fig = self.plt.figure(figsize=self.figsize)
            ax = fig.add_subplot(111)

        axes = [ax]

    if self.logx or self.loglog:
        for a in axes:
            a.set_xscale('log')
    if self.logy or self.loglog:
        for a in axes:
            a.set_yscale('log')

    self.fig = fig
    self.axes = axes
- def _get_layout(self):
- from pandas.core.frame import DataFrame
- if isinstance(self.data, DataFrame):
- return (len(self.data.columns), 1)
- else:
- return (1, 1)
def _compute_plot_data(self):
    """
    Replace ``self.data`` with its numeric part (MPLPlot method).

    The data is first passed through ``convert_objects()`` and then
    reduced with ``_get_numeric_data()``.

    Raises
    ------
    TypeError
        If nothing numeric is left to plot.
    """
    numeric = self.data.convert_objects()._get_numeric_data()

    try:
        nothing_left = numeric.empty
    except AttributeError:
        # objects without an ``empty`` attribute: fall back to the length
        nothing_left = not len(numeric)

    # no empty frames or series allowed
    if nothing_left:
        raise TypeError('Empty {0!r}: no numeric data to '
                        'plot'.format(numeric.__class__.__name__))

    self.data = numeric
def _make_plot(self):
    """Drawing stage of ``generate``; not implemented in this base class."""
    raise NotImplementedError
def _add_table(self):
    """
    Draw a data table on the first axes when requested (MPLPlot method).

    ``self.table`` False: nothing is drawn.  True: the plotted data itself
    is used, transposed.  Any other value is passed to ``table`` as the
    table data.
    """
    requested = self.table
    if requested is False:
        return

    if requested is True:
        from pandas.core.frame import DataFrame
        if isinstance(self.data, Series):
            data = DataFrame(self.data, columns=[self.data.name])
        elif isinstance(self.data, DataFrame):
            data = self.data
        data = data.transpose()
    else:
        data = requested

    ax = self._get_ax(0)
    table(ax, data)
def _post_plot_logic(self):
    """Stage of ``generate`` that runs after the legend is made; does
    nothing in this base class."""
def _adorn_subplots(self):
    """
    Apply ticks, limits, grid, title and index tick labels to the axes
    (MPLPlot method).
    """
    # todo: sharex, sharey handling?
    for ax in self.axes:
        if self.yticks is not None:
            ax.set_yticks(self.yticks)

        if self.xticks is not None:
            ax.set_xticks(self.xticks)

        if self.ylim is not None:
            ax.set_ylim(self.ylim)

        if self.xlim is not None:
            ax.set_xlim(self.xlim)

        ax.grid(self.grid)

    if self.title:
        if self.subplots:
            # one title for the whole figure
            self.fig.suptitle(self.title)
        else:
            self.axes[0].set_title(self.title)

    if self._need_to_set_index:
        # the x positions are 0..n-1: label each tick that falls on a
        # position with the matching index value and blank out the rest
        by_position = dict(enumerate(
            [com.pprint_thing(key) for key in self.data.index]))
        for ax_ in self.axes:
            ticklabels = [by_position.get(tick, '')
                          for tick in ax_.get_xticks()]
            ax_.set_xticklabels(ticklabels, rotation=self.rot)
@property
def legend_title(self):
    """
    Title for the legend, derived from the column labels (MPLPlot
    property).

    Returns the pretty-printed ``columns.name`` (or None when it is
    unset), the comma-joined level names for MultiIndex columns, or None
    when the data has no ``columns`` attribute.
    """
    if not hasattr(self.data, 'columns'):
        return None

    if isinstance(self.data.columns, MultiIndex):
        level_names = map(com.pprint_thing, self.data.columns.names)
        return ','.join(level_names)

    name = self.data.columns.name
    return name if name is None else com.pprint_thing(name)
- def _add_legend_handle(self, handle, label, index=None):
- if not label is None:
- if self.mark_right and index is not None:
- if self.on_right(index):
- label = label + ' (right)'
- self.legend_handles.append(handle)
- self.legend_labels.append(label)
def _make_legend(self):
    """
    Draw the legend(s) (MPLPlot method).

    With subplots every axes gets its own default legend (when
    ``self.legend`` is set).  Otherwise the entries of a legend that
    already exists on the axes (or its twin) are combined with the
    entries collected in ``legend_handles`` / ``legend_labels`` --
    reversed when ``self.legend == 'reverse'`` -- and drawn as a single
    legend.
    """
    ax, leg = self._get_ax_legend(self.axes[0])

    handles = []
    labels = []
    title = ''

    if self.subplots:
        if self.legend:
            for ax in self.axes:
                ax.legend(loc='best')
        return

    if leg is not None:
        # start from whatever the existing legend already shows
        title = leg.get_title().get_text()
        handles = leg.legendHandles
        labels = [text.get_text() for text in leg.get_texts()]

    if self.legend:
        if self.legend == 'reverse':
            self.legend_handles = reversed(self.legend_handles)
            self.legend_labels = reversed(self.legend_labels)

        handles += self.legend_handles
        labels += self.legend_labels
        own_title = self.legend_title
        if own_title is not None:
            title = own_title

    if len(handles) > 0:
        ax.legend(handles, labels, loc='best', title=title)
- def _get_ax_legend(self, ax):
- leg = ax.get_legend()
- other_ax = (getattr(ax, 'right_ax', None) or
- getattr(ax, 'left_ax', None))
- other_leg = None
- if other_ax is not None:
- other_leg = other_ax.get_legend()
- if leg is None and other_leg is not None:
- leg = other_leg
- ax = other_ax
- return ax, leg
@cache_readonly
def plt(self):
    """The ``matplotlib.pyplot`` module, imported on first access."""
    from matplotlib import pyplot
    return pyplot

# class-level default; _get_xticks sets it to True on the instance when the
# index is drawn at integer positions and needs explicit tick labels
_need_to_set_index = False
def _get_xticks(self, convert_period=False):
    """
    Return the x values to plot against (MPLPlot method).

    Without ``self.use_index`` the positions 0..n-1 are returned.
    Otherwise:

    * a PeriodIndex (only when ``convert_period`` is set) is ordered and
      converted to timestamps;
    * a numeric index is used as it is;
    * a date/time-like index is used after sorting the data by index;
    * any other index is replaced by the positions 0..n-1 and
      ``_need_to_set_index`` is set.

    ``self.data`` is replaced by the reordered data in the PeriodIndex
    and date/time cases.
    """
    index = self.data.index
    is_datetype = index.inferred_type in ('datetime', 'date',
                                          'datetime64', 'time')

    if not self.use_index:
        return lrange(len(index))

    if convert_period and isinstance(index, PeriodIndex):
        self.data = self.data.reindex(index=index.order())
        return self.data.index.to_timestamp()._mpl_repr()

    if index.is_numeric():
        # Matplotlib supports numeric values or datetime objects as
        # xaxis values. Taking LBYL approach here, by the time
        # matplotlib raises exception when using non numeric/datetime
        # values for xaxis, several actions are already taken by plt.
        return index._mpl_repr()

    if is_datetype:
        self.data = self.data.sort_index()
        return self.data.index._mpl_repr()

    self._need_to_set_index = True
    return lrange(len(index))
def _is_datetype(self):
    """
    True when the data's index is a PeriodIndex or DatetimeIndex, or when
    its inferred type is date/time-like (MPLPlot method).
    """
    index = self.data.index
    if isinstance(index, (PeriodIndex, DatetimeIndex)):
        return True
    return index.inferred_type in ('datetime', 'date', 'datetime64',
                                   'time')
def _get_plot_function(self):
    '''
    Returns the matplotlib plotting function (plot or errorbar) based on
    the presence of errorbar keywords.

    ``Axes.errorbar`` is chosen as soon as one entry of ``self.errors`` is
    not None; otherwise ``Axes.plot``.
    '''
    has_errorbars = any(e is not None for e in self.errors.values())
    if has_errorbars:
        return self.plt.Axes.errorbar
    return self.plt.Axes.plot
def _get_index_name(self):
    """
    Return a printable name for the data's index, or None (MPLPlot
    method).

    For a MultiIndex the level names are joined with commas, unless all
    of them are None.
    """
    index = self.data.index

    if not isinstance(index, MultiIndex):
        name = index.name
        return name if name is None else com.pprint_thing(name)

    names = index.names
    if any(level is not None for level in names):
        return ','.join([com.pprint_thing(level) for level in names])
    return None
def _get_ax(self, i):
    """
    Return the axes on which series ``i`` is drawn (MPLPlot method).

    Series assigned to the secondary y axis get the twin axis from
    ``_maybe_right_yaxis``.  With subplots the twin replaces entry ``i``
    of ``self.axes``; without subplots it replaces ``self.axes[0]`` only
    when every series is on the secondary axis.  The y axis of the
    returned axes is made visible.
    """
    # get the twinx ax if appropriate
    if self.subplots:
        ax = self.axes[i]

        if self.on_right(i):
            ax = self._maybe_right_yaxis(ax)
            self.axes[i] = ax
    else:
        ax = self.axes[0]

        if self.on_right(i):
            ax = self._maybe_right_yaxis(ax)

            secondary = self.secondary_y
            everything_right = (
                (isinstance(secondary, bool) and secondary) or
                (com.is_list_like(secondary) and
                 len(secondary) == self.nseries))
            if everything_right:
                self.axes[0] = ax

    ax.get_yaxis().set_visible(True)
    return ax
def on_right(self, i):
    """
    Tell whether series ``i`` goes on the secondary y axis (MPLPlot
    method).

    A boolean ``secondary_y`` applies to every series.  For a DataFrame
    with a tuple/list/ndarray ``secondary_y`` the answer is whether the
    label of column ``i`` is contained in it.  In every other case None
    is returned.
    """
    from pandas.core.frame import DataFrame

    secondary = self.secondary_y
    if isinstance(secondary, bool):
        return secondary

    if (isinstance(self.data, DataFrame) and
            isinstance(secondary, (tuple, list, np.ndarray))):
        return self.data.columns[i] in secondary
    return None
- def _get_style(self, i, col_name):
- style = ''
- if self.subplots:
- style = 'k'
- if self.style is not None:
- if isinstance(self.style, list):
- try:
- style = self.style[i]
- except IndexError:
- pass
- elif isinstance(self.style, dict):
- style = self.style.get(col_name, style)
- else:
- style = self.style
- return style or None
def _get_colors(self, num_colors=None, color_kwds='color'):
    '''
    Ask ``_get_standard_colors`` for the colors of this plot.

    ``num_colors`` defaults to the number of columns when the data is a
    DataFrame and to 1 otherwise. The explicit color is read from
    ``self.kwds`` under the key ``color_kwds``.
    '''
    from pandas.core.frame import DataFrame

    if num_colors is None:
        is_frame = isinstance(self.data, DataFrame)
        num_colors = len(self.data.columns) if is_frame else 1

    palette = self.colormap
    requested = self.kwds.get(color_kwds)
    return _get_standard_colors(num_colors=num_colors,
                                colormap=palette,
                                color=requested)
- def _maybe_add_color(self, colors, kwds, style, i):
- has_color = 'color' in kwds or self.colormap is not None
- if has_color and (style is None or re.match('[a-z]+', style) is None):
- kwds['color'] = colors[i % len(colors)]
- def _parse_errorbars(self, label, err):
- '''
- Look for error keyword arguments and return the actual errorbar data
- or return the error DataFrame/dict
- Error bars can be specified in several ways:
- Series: the user provides a pandas.Series object of the same
- length as the data
- ndarray: provides a np.ndarray of the same length as the data
- DataFrame/dict: error values are paired with keys matching the
- key in the plotted DataFrame
- str: the name of the column within the plotted DataFrame
- '''
- if err is None:
- return None
- from pandas import DataFrame, Series
- def match_labels(data, e):
- e = e.reindex_axis(data.index)
- return e
- # key-matched DataFrame
- if isinstance(err, DataFrame):
- err = match_labels(self.data, err)
- # key-matched dict
- elif isinstance(err, dict):
- pass
- # Series of error values
- elif isinstance(err, Series):
- # broadcast error series across data
- err = match_labels(self.data, err)
- err = np.atleast_2d(err)
- err = np.tile(err, (self.nseries, 1))
- # errors are a column in the dataframe
- elif isinstance(err, string_types):
- evalues = self.data[err].values
- self.data = self.data[self.data.columns.drop(err)]
- err = np.atleast_2d(evalues)
- err = np.tile(err, (self.nseries, 1))
- elif com.is_list_like(err):
- if com.is_iterator(err):
- err = np.atleast_2d(list(err))
- else:
- # raw error values
- err = np.atleast_2d(err)
- err_shape = err.shape
- # asymmetrical error bars
- if err.ndim == 3:
- if (err_shape[0] != self.nseries) or \
- (err_shape[1] != 2) or \
- (err_shape[2] != len(self.data)):
- msg = "Asymmetrical error bars should be provided " + \
- "with the shape (%u, 2, %u)" % \
- (self.nseries, len(self.data))
- raise ValueError(msg)
- # broadcast errors to each data series
- if len(err) == 1:
- err = np.tile(err, (self.nseries, 1))
- elif com.is_number(err):
- err = np.tile([err], (self.nseries, len(self.data)))
- else:
- msg = "No valid %s detected" % label
- raise ValueError(msg)
- return err
- def _get_errorbars(self, label=None, index=None, xerr=True, yerr=True):
- from pandas import DataFrame
- errors = {}
- for kw, flag in zip(['xerr', 'yerr'], [xerr, yerr]):
- if flag:
- err = self.errors[kw]
- # user provided label-matched dataframe of errors
- if isinstance(err, (DataFrame, dict)):
- if label is not None and label in err.keys():
- err = err[label]
- else:
- err = None
- elif index is not None and err is not None:
- err = err[index]
- if err is not None:
- errors[kw] = err
- return errors
class KdePlot(MPLPlot):
    '''
    Kernel density estimate plot: one Gaussian KDE curve per data series.

    Parameters
    ----------
    data : passed through to ``MPLPlot.__init__``
    bw_method : optional
        Forwarded to ``scipy.stats.gaussian_kde`` as ``bw_method``. Only
        honoured with SciPy >= 0.11.0; with an older SciPy it is ignored
        and a warning is issued.
    ind : optional
        Points at which the density is evaluated. When None, 1000 evenly
        spaced points are used, spanning the range of each series extended
        by half of that range on both sides.
    **kwargs : passed through to ``MPLPlot.__init__``
    '''

    def __init__(self, data, bw_method=None, ind=None, **kwargs):
        MPLPlot.__init__(self, data, **kwargs)
        self.bw_method = bw_method
        self.ind = ind

    def _make_plot(self):
        '''
        Fit a Gaussian KDE to every series and draw it on its axes.
        '''
        from scipy.stats import gaussian_kde
        from scipy import __version__ as spv
        from distutils.version import LooseVersion

        # The SciPy capability check does not depend on the series, so it
        # is done once here instead of on every loop iteration.
        kde_kwds = {}
        if LooseVersion(spv) >= '0.11.0':
            kde_kwds['bw_method'] = self.bw_method
        elif self.bw_method is not None:
            msg = ('bw_method was added in Scipy 0.11.0.' +
                   ' Scipy version in use is %s.' % spv)
            warnings.warn(msg)

        plotf = self.plt.Axes.plot
        colors = self._get_colors()

        for i, (label, y) in enumerate(self._iter_data()):
            ax = self._get_ax(i)
            # the style lookup uses the raw label, before pretty-printing
            style = self._get_style(i, label)
            label = com.pprint_thing(label)

            gkde = gaussian_kde(y, **kde_kwds)

            if self.ind is None:
                # Scan the sample once for each bound; the range is only
                # needed when the caller did not supply evaluation points.
                lower, upper = min(y), max(y)
                margin = 0.5 * (upper - lower)
                ind = np.linspace(lower - margin, upper + margin, 1000)
            else:
                ind = self.ind

            ax.set_ylabel("Density")

            density = gkde.evaluate(ind)
            kwds = self.kwds.copy()
            kwds['label'] = label
            self._maybe_add_color(colors, kwds, style, i)

            if style is None:
                args = (ax, ind, density)
            else:
                args = (ax, ind, density, style)

            newlines = plotf(*args, **kwds)
            self._add_legend_handle(newlines[0], label)
class ScatterPlot(MPLPlot):
    '''
    Scatter plot of one column of the data against another.

    Parameters
    ----------
    data : passed through to ``MPLPlot.__init__``
    x, y : column label or integer position
        Columns holding the point coordinates. An integer is treated as a
        position into ``data.columns`` unless the columns themselves hold
        integers, in which case it is used as a label.
    **kwargs : passed through to ``MPLPlot.__init__``

    Raises
    ------
    ValueError
        If ``x`` or ``y`` is None.
    '''

    def __init__(self, data, x, y, **kwargs):
        MPLPlot.__init__(self, data, **kwargs)

        # matplotlib treats ``c`` and ``color`` as alternatives for scatter,
        # so the default ``c`` is only supplied when the caller did not
        # choose a color through ``color``.
        if 'color' not in self.kwds:
            self.kwds.setdefault('c', self.plt.rcParams['patch.facecolor'])

        if x is None or y is None:
            raise ValueError('scatter requires an x and y column')

        if com.is_integer(x) and not self.data.columns.holds_integer():
            x = self.data.columns[x]
        if com.is_integer(y) and not self.data.columns.holds_integer():
            y = self.data.columns[y]

        self.x = x
        self.y = y

    def _get_layout(self):
        '''A scatter plot always uses a single axes.'''
        return (1, 1)

    def _make_plot(self):
        '''
        Draw the scatter points and, when error data is present, error bars.
        '''
        x, y, data = self.x, self.y, self.data
        ax = self.axes[0]

        if self.legend and hasattr(self, 'label'):
            label = self.label
        else:
            label = None

        scatter = ax.scatter(data[x].values, data[y].values, label=label,
                             **self.kwds)
        self._add_legend_handle(scatter, label)

        # errors may be keyed by column label or given positionally
        # (position 0 for x, position 1 for y)
        errors_x = self._get_errorbars(label=x, index=0, yerr=False)
        errors_y = self._get_errorbars(label=y, index=1, xerr=False)

        if errors_x or errors_y:
            err_kwds = dict(errors_x, **errors_y)
            if 'color' in self.kwds:
                err_kwds['color'] = self.kwds['color']
            ax.errorbar(data[x].values, data[y].values, linestyle='none',
                        **err_kwds)

    def _post_plot_logic(self):
        '''Label both axes with the pretty-printed column names.'''
        ax = self.axes[0]
        x, y = self.x, self.y
        ax.set_ylabel(com.pprint_thing(y))
        ax.set_xlabel(com.pprint_thing(x))
class HexBinPlot(MPLPlot):
    '''
    Hexagonal binning plot of one column of the data against another.

    Parameters
    ----------
    data : passed through to ``MPLPlot.__init__``
    x, y : column label or integer position
        Columns holding the point coordinates.
    C : column label or integer position, optional
        Column whose values are handed to ``Axes.hexbin`` as ``C``.
    **kwargs : passed through to ``MPLPlot.__init__``
        A ``colorbar`` entry (default True) is consumed by ``_make_plot``
        and controls whether a colorbar is added to the figure.

    Integer ``x`` / ``y`` / ``C`` are treated as positions into
    ``data.columns`` unless the columns themselves hold integers.

    Raises
    ------
    ValueError
        If ``x`` or ``y`` is None.
    '''

    def __init__(self, data, x, y, C=None, **kwargs):
        MPLPlot.__init__(self, data, **kwargs)

        if x is None or y is None:
            raise ValueError('hexbin requires an x and y column')

        if com.is_integer(x) and not self.data.columns.holds_integer():
            x = self.data.columns[x]
        if com.is_integer(y) and not self.data.columns.holds_integer():
            y = self.data.columns[y]
        if com.is_integer(C) and not self.data.columns.holds_integer():
            C = self.data.columns[C]

        self.x = x
        self.y = y
        self.C = C

    def _get_layout(self):
        '''A hexbin plot always uses a single axes.'''
        return (1, 1)

    def _make_plot(self):
        '''
        Draw the hexbin collection and, unless disabled, its colorbar.
        '''
        import matplotlib.pyplot as plt

        x, y, data, C = self.x, self.y, self.data, self.C
        ax = self.axes[0]

        # pandas uses colormap, matplotlib uses cmap.
        cmap = plt.cm.get_cmap(self.colormap or 'BuGn')
        cb = self.kwds.pop('colorbar', True)

        c_values = None if C is None else data[C].values

        # Keep the collection returned by hexbin: ax.collections[0] is a
        # different artist when the axes already contained a collection.
        img = ax.hexbin(data[x].values, data[y].values, C=c_values,
                        cmap=cmap, **self.kwds)
        if cb:
            self.fig.colorbar(img, ax=ax)

    def _post_plot_logic(self):
        '''Label both axes with the pretty-printed column names.'''
        ax = self.axes[0]
        x, y = self.x, self.y
        ax.set_ylabel(com.pprint_thing(y))
        ax.set_xlabel(com.pprint_thing(x))
- class LinePlot(MPLPlot):
def __init__(self, data, **kwargs):
    '''
    Set up a line plot.

    ``stacked`` (default False) is taken out of ``kwargs``; when it is
    set, missing values in ``data`` are replaced by 0 before the base
    class is initialised. ``x_compat`` comes from ``plot_params`` unless
    it was passed as a keyword, in which case it is removed from
    ``self.kwds`` and coerced to bool.
    '''
    stacked = kwargs.pop('stacked', False)
    self.stacked = stacked
    if stacked:
        data = data.fillna(value=0)

    MPLPlot.__init__(self, data, **kwargs)

    default_compat = plot_params['x_compat']
    if 'x_compat' in self.kwds:
        self.x_compat = bool(self.kwds.pop('x_compat'))
    else:
        self.x_compat = default_compat
- def _index_freq(self):
- from pandas.core.frame import DataFrame
- if isinstance(self.data, (Series, DataFrame)):
- freq = getattr(self.data.index, 'freq', None)
- if freq is None:
- freq = getattr(self.data.index, 'inferred_freq', None)
- if freq == 'B':
- weekdays = np.unique(self.data.index.dayofweek)
- if (5 in weekdays) or (6 in weekdays):
- freq = None
- return freq
def _is_dynamic_freq(self, freq):
    '''
    Decide whether ``freq`` maps to a period alias usable for this data.

    A DateOffset contributes its ``rule_code``; anything else goes through
    ``get_base_alias``. The result is converted with ``get_period_alias``;
    when that yields None the answer is False, otherwise it is whatever
    ``self._no_base`` returns for the alias.
    '''
    if isinstance(freq, DateOffset):
        code = freq.rule_code
    else:
        code = get_base_alias(freq)

    period_alias = get_period_alias(code)
    if period_alias is None:
        return False
    return self._no_base(period_alias)
- def _no_base(self, freq):
- # hack this for 0.10.1, creating more technical debt...sigh
- from pandas.core.frame import DataFrame
- if (isinstance(self.data, (Series, DataFrame))
- and isinstance(self.data.index, DatetimeIndex)):
- import pandas.tseries.frequencies as freqmod
- base = freqmod.get_freq(freq)
- x = self.data.index
- if (base <= freqmod.FreqGroup.FR_DAY):
- return x[:1].is_normalized
- return Period(x[0], freq).to_timestamp(tz=x.tz) == x[0]
- return True
- def _use_dynamic_x(self):
- freq = self._index_freq()
- ax = self._get_ax(0)
- ax_freq = getattr(ax, 'freq', None)
- if freq is None: # convert irregular if axes has freq info
- freq = ax_freq
- else: # do not use tsplot if irregular was plotted first
- if (ax_freq is None) and (len(ax.get_lines()) > 0):
- return False
- return (freq is not None) and self._is_dynamic_freq(freq)
- def _is_ts_plot(self):
- # this is slightly deceptive
- return not self.x_compat and self.use_index and self._use_dynamic_x()
- def _make_plot(self):
- self._pos_prior = np.zeros(len(self.data))
- self._neg_prior = np.zeros(len(self.data))
- if self._is_ts_plot():
- data = self._maybe_convert_index(self.data)
- self._make_ts_plot(data)
- else:
- x = self._get_xticks(convert_period=True)
- plotf = self._get_plot_function()
- …
Large files are truncated, but you can click here to view the full file