/pandas/tools/plotting.py
Python | 3996 lines | 3667 code | 214 blank | 115 comment | 268 complexity | 62ae60e55cfe6d932a716eec2ddd081d MD5 | raw file
Possible License(s): BSD-3-Clause, Apache-2.0
Large files are truncated, but you can click here to view the full file
- # being a bit too dynamic
- # pylint: disable=E1101
- from __future__ import division
- import warnings
- import re
- from math import ceil
- from collections import namedtuple
- from contextlib import contextmanager
- from distutils.version import LooseVersion
- import numpy as np
- from pandas.types.common import (is_list_like,
- is_integer,
- is_number,
- is_hashable,
- is_iterator)
- from pandas.types.missing import isnull, notnull
- from pandas.util.decorators import cache_readonly, deprecate_kwarg
- from pandas.core.base import PandasObject
- from pandas.core.common import AbstractMethodError, _try_sort
- from pandas.core.generic import _shared_docs, _shared_doc_kwargs
- from pandas.core.index import Index, MultiIndex
- from pandas.core.series import Series, remove_na
- from pandas.tseries.period import PeriodIndex
- from pandas.compat import range, lrange, lmap, map, zip, string_types
- import pandas.compat as compat
- from pandas.formats.printing import pprint_thing
- from pandas.util.decorators import Appender
- try: # mpl optional
- import pandas.tseries.converter as conv
- conv.register() # needs to override so set_xlim works with str/number
- except ImportError:
- pass
# Extracted from https://gist.github.com/huyng/816622
# This is the set of matplotlib rcParams applied when the
# ``display.with_mpl_style`` option is set to True.
mpl_stylesheet = {
    # -- axes -------------------------------------------------------------
    'axes.axisbelow': True,
    # NOTE: this key is swapped for 'axes.prop_cycle' right after the
    # version helpers below when matplotlib >= 1.5 is detected.
    'axes.color_cycle': ['#348ABD',
                         '#7A68A6',
                         '#A60628',
                         '#467821',
                         '#CF4457',
                         '#188487',
                         '#E24A33'],
    'axes.edgecolor': '#bcbcbc',
    'axes.facecolor': '#eeeeee',
    'axes.grid': True,
    'axes.labelcolor': '#555555',
    'axes.labelsize': 'large',
    'axes.linewidth': 1.0,
    'axes.titlesize': 'x-large',
    # -- figure -----------------------------------------------------------
    'figure.edgecolor': 'white',
    'figure.facecolor': 'white',
    'figure.figsize': (6.0, 4.0),
    'figure.subplot.hspace': 0.5,
    # -- fonts ------------------------------------------------------------
    'font.family': 'monospace',
    'font.monospace': ['Andale Mono',
                       'Nimbus Mono L',
                       'Courier New',
                       'Courier',
                       'Fixed',
                       'Terminal',
                       'monospace'],
    'font.size': 10,
    'interactive': True,
    # -- keyboard shortcuts -----------------------------------------------
    'keymap.all_axes': ['a'],
    'keymap.back': ['left', 'c', 'backspace'],
    'keymap.forward': ['right', 'v'],
    'keymap.fullscreen': ['f'],
    'keymap.grid': ['g'],
    'keymap.home': ['h', 'r', 'home'],
    'keymap.pan': ['p'],
    'keymap.save': ['s'],
    'keymap.xscale': ['L', 'k'],
    'keymap.yscale': ['l'],
    'keymap.zoom': ['o'],
    # -- legend, lines and patches ----------------------------------------
    'legend.fancybox': True,
    'lines.antialiased': True,
    'lines.linewidth': 1.0,
    'patch.antialiased': True,
    'patch.edgecolor': '#EEEEEE',
    'patch.facecolor': '#348ABD',
    'patch.linewidth': 0.5,
    'toolbar': 'toolbar2',
    # -- ticks ------------------------------------------------------------
    'xtick.color': '#555555',
    'xtick.direction': 'in',
    'xtick.major.pad': 6.0,
    'xtick.major.size': 0.0,
    'xtick.minor.pad': 6.0,
    'xtick.minor.size': 0.0,
    'ytick.color': '#555555',
    'ytick.direction': 'in',
    'ytick.major.pad': 6.0,
    'ytick.major.size': 0.0,
    'ytick.minor.pad': 6.0,
    'ytick.minor.size': 0.0
}
- def _mpl_le_1_2_1():
- try:
- import matplotlib as mpl
- return (str(mpl.__version__) <= LooseVersion('1.2.1') and
- str(mpl.__version__)[0] != '0')
- except ImportError:
- return False
- def _mpl_ge_1_3_1():
- try:
- import matplotlib
- # The or v[0] == '0' is because their versioneer is
- # messed up on dev
- return (matplotlib.__version__ >= LooseVersion('1.3.1') or
- matplotlib.__version__[0] == '0')
- except ImportError:
- return False
- def _mpl_ge_1_4_0():
- try:
- import matplotlib
- return (matplotlib.__version__ >= LooseVersion('1.4') or
- matplotlib.__version__[0] == '0')
- except ImportError:
- return False
- def _mpl_ge_1_5_0():
- try:
- import matplotlib
- return (matplotlib.__version__ >= LooseVersion('1.5') or
- matplotlib.__version__[0] == '0')
- except ImportError:
- return False
# Runs once at import time: adapt the stylesheet to the installed matplotlib.
if _mpl_ge_1_5_0():
    # Compat with mpl 1.5, which uses cycler: the 'axes.color_cycle' entry
    # is removed and the same color list is stored under 'axes.prop_cycle'
    # as a cycler object instead.
    import cycler
    colors = mpl_stylesheet.pop('axes.color_cycle')
    mpl_stylesheet['axes.prop_cycle'] = cycler.cycler('color', colors)
- def _get_standard_kind(kind):
- return {'density': 'kde'}.get(kind, kind)
def _get_standard_colors(num_colors=None, colormap=None, color_type='default',
                         color=None):
    """
    Build the sequence of colors used to draw ``num_colors`` items.

    Parameters
    ----------
    num_colors : int
        Number of colors required.
    colormap : str or matplotlib colormap, optional
        Only consulted when ``color`` is None. A string is resolved through
        ``matplotlib.cm.get_cmap``; the colormap is then sampled at
        ``num_colors`` evenly spaced points in [0, 1].
    color_type : {'default', 'random'}
        Used when neither ``color`` nor ``colormap`` is given. 'default'
        reads the matplotlib rcParams color cycle, 'random' generates one
        RGB triple per index, seeded by that index.
    color : color or list-like of colors, optional
        Explicit colors. Takes precedence over ``colormap`` (a warning is
        issued when both are passed).

    Returns
    -------
    colors
        The resolved colors. When fewer than ``num_colors`` are available
        they are repeated cyclically to reach ``num_colors``; a longer
        sequence is returned as is.

    Raises
    ------
    ValueError
        If the colormap name is not recognized, ``color_type`` is invalid,
        a string is ambiguous between a single color and a color cycle, or
        no color at all is available to cycle through.
    """
    import matplotlib.pyplot as plt

    if color is None and colormap is not None:
        if isinstance(colormap, compat.string_types):
            import matplotlib.cm as cm
            cmap = colormap
            colormap = cm.get_cmap(colormap)
            if colormap is None:
                raise ValueError("Colormap {0} is not recognized".format(cmap))
        colors = lmap(colormap, np.linspace(0, 1, num=num_colors))
    elif color is not None:
        if colormap is not None:
            warnings.warn("'color' and 'colormap' cannot be used "
                          "simultaneously. Using 'color'")
        colors = list(color) if is_list_like(color) else color
    else:
        if color_type == 'default':
            # need to call list() on the result to copy so we don't
            # modify the global rcParams below
            try:
                colors = [c['color']
                          for c in list(plt.rcParams['axes.prop_cycle'])]
            except KeyError:
                colors = list(plt.rcParams.get('axes.color_cycle',
                                               list('bgrcmyk')))
            if isinstance(colors, compat.string_types):
                colors = list(colors)
        elif color_type == 'random':
            import random

            def random_color(column):
                # seeding with the index makes each color reproducible
                random.seed(column)
                return [random.random() for _ in range(3)]

            colors = lmap(random_color, lrange(num_colors))
        else:
            raise ValueError("color_type must be either 'default' or 'random'")

    if isinstance(colors, compat.string_types):
        import matplotlib.colors
        conv = matplotlib.colors.ColorConverter()

        def _maybe_valid_colors(colors):
            try:
                [conv.to_rgba(c) for c in colors]
                return True
            except ValueError:
                return False

        # check whether the string can be converted to a single color
        maybe_single_color = _maybe_valid_colors([colors])
        # check whether each character can be converted to a color
        maybe_color_cycle = _maybe_valid_colors(list(colors))
        if maybe_single_color and maybe_color_cycle and len(colors) > 1:
            msg = ("'{0}' can be parsed as both single color and "
                   "color cycle. Specify each color using a list "
                   "like ['{0}'] or {1}")
            raise ValueError(msg.format(colors, list(colors)))
        elif maybe_single_color:
            colors = [colors]
        else:
            # ``colors`` is regarded as color cycle.
            # mpl will raise error any of them is invalid
            pass

    # Only pad when there are too few colors. Previously any length mismatch
    # triggered the padding, which appended spurious entries to a sequence
    # that was already long enough and divided by zero on an empty one.
    if len(colors) < num_colors:
        if len(colors) == 0:
            raise ValueError("Invalid color argument: no colors available "
                             "to cycle through")
        multiple = num_colors // len(colors) - 1
        mod = num_colors % len(colors)
        colors += multiple * colors
        colors += colors[:mod]

    return colors
- class _Options(dict):
- """
- Stores pandas plotting options.
- Allows for parameter aliasing so you can just use parameter names that are
- the same as the plot function parameters, but is stored in a canonical
- format that makes it easy to breakdown into groups later
- """
- # alias so the names are same as plotting method parameter names
- _ALIASES = {'x_compat': 'xaxis.compat'}
- _DEFAULT_KEYS = ['xaxis.compat']
- def __init__(self):
- self['xaxis.compat'] = False
- def __getitem__(self, key):
- key = self._get_canonical_key(key)
- if key not in self:
- raise ValueError('%s is not a valid pandas plotting option' % key)
- return super(_Options, self).__getitem__(key)
- def __setitem__(self, key, value):
- key = self._get_canonical_key(key)
- return super(_Options, self).__setitem__(key, value)
- def __delitem__(self, key):
- key = self._get_canonical_key(key)
- if key in self._DEFAULT_KEYS:
- raise ValueError('Cannot remove default parameter %s' % key)
- return super(_Options, self).__delitem__(key)
- def __contains__(self, key):
- key = self._get_canonical_key(key)
- return super(_Options, self).__contains__(key)
- def reset(self):
- """
- Reset the option store to its initial state
- Returns
- -------
- None
- """
- self.__init__()
- def _get_canonical_key(self, key):
- return self._ALIASES.get(key, key)
- @contextmanager
- def use(self, key, value):
- """
- Temporarily set a parameter value using the with statement.
- Aliasing allowed.
- """
- old_value = self[key]
- try:
- self[key] = value
- yield self
- finally:
- self[key] = old_value
# Module-level instance of the plotting option store.
plot_params = _Options()
def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,
                   diagonal='hist', marker='.', density_kwds=None,
                   hist_kwds=None, range_padding=0.05, **kwds):
    """
    Draw a matrix of scatter plots.

    Only the numeric columns of ``frame`` are plotted. Off-diagonal cells
    show one column against another; diagonal cells show the distribution
    of a single column.

    Parameters
    ----------
    frame : DataFrame
    alpha : float, optional
        amount of transparency applied
    figsize : (float,float), optional
        a tuple (width, height) in inches
    ax : Matplotlib axis object, optional
    grid : bool, optional
        setting this to True will show the grid
        (NOTE: not referenced by the function body as written)
    diagonal : {'hist', 'kde'}
        pick between 'kde' and 'hist' for
        either Kernel Density Estimation or Histogram
        plot in the diagonal ('density' is accepted as a synonym of 'kde')
    marker : str, optional
        Matplotlib marker type, default '.'
    hist_kwds : other plotting keyword arguments
        To be passed to hist function
    density_kwds : other plotting keyword arguments
        To be passed to kernel density estimate plot
    range_padding : float, optional
        relative extension of axis range in x and y
        with respect to (x_max - x_min) or (y_max - y_min),
        default 0.05
    kwds : other plotting keyword arguments
        To be passed to scatter function

    Returns
    -------
    axes : the 2-D collection of axes returned by ``_subplots``

    Examples
    --------
    >>> df = DataFrame(np.random.randn(1000, 4), columns=['A','B','C','D'])
    >>> scatter_matrix(df, alpha=0.2)
    """
    import matplotlib.pyplot as plt

    numeric = frame._get_numeric_data()
    columns = numeric.columns
    n = columns.size
    fig, axes = _subplots(naxes=n * n, figsize=figsize, ax=ax,
                          squeeze=False)

    # no gaps between subplots
    fig.subplots_adjust(wspace=0, hspace=0)

    valid = notnull(numeric)
    marker = _get_marker_compat(marker)

    hist_kwds = hist_kwds or {}
    density_kwds = density_kwds or {}

    # workaround because `c='b'` is hardcoded in matplotlibs scatter method
    kwds.setdefault('c', plt.rcParams['patch.facecolor'])

    # Padded (min, max) range of the non-null values of every column.
    limits = []
    for name in columns:
        observed = numeric[name].values[valid[name].values]
        low, high = np.min(observed), np.max(observed)
        pad = (high - low) * range_padding / 2.
        limits.append((low - pad, high + pad))

    for row, row_name in enumerate(columns):
        for col, col_name in enumerate(columns):
            cell = axes[row, col]

            if row != col:
                # scatter only the rows where both columns are non-null
                both = (valid[row_name] & valid[col_name]).values
                cell.scatter(numeric[col_name][both], numeric[row_name][both],
                             marker=marker, alpha=alpha, **kwds)
                cell.set_xlim(limits[col])
                cell.set_ylim(limits[row])
            else:
                # diagonal: distribution of a single column
                observed = numeric[row_name].values[valid[row_name].values]
                if diagonal == 'hist':
                    cell.hist(observed, **hist_kwds)
                elif diagonal in ('kde', 'density'):
                    from scipy.stats import gaussian_kde
                    y = observed
                    gkde = gaussian_kde(y)
                    ind = np.linspace(y.min(), y.max(), 1000)
                    cell.plot(ind, gkde.evaluate(ind), **density_kwds)
                cell.set_xlim(limits[row])

            cell.set_xlabel(col_name)
            cell.set_ylabel(row_name)

            # only the first column keeps its y axis and only the last row
            # keeps its x axis
            if j_is_not_first(col):
                cell.yaxis.set_visible(False)
            if row != n - 1:
                cell.xaxis.set_visible(False)

    if len(columns) > 1:
        # The top-left cell is a diagonal plot, so its y axis is rescaled
        # and relabelled using the tick locations of its right neighbour.
        first_limits = limits[0]
        locs = axes[0][1].yaxis.get_majorticklocs()
        locs = locs[(first_limits[0] <= locs) & (locs <= first_limits[1])]
        adj = (locs - first_limits[0]) / (first_limits[1] - first_limits[0])

        corner_limits = axes[0][0].get_ylim()
        adj = adj * (corner_limits[1] - corner_limits[0]) + corner_limits[0]
        axes[0][0].yaxis.set_ticks(adj)

        if np.all(locs == locs.astype(int)):
            # if all ticks are int
            locs = locs.astype(int)
        axes[0][0].yaxis.set_ticklabels(locs)

    _set_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)

    return axes


def j_is_not_first(col):
    """Return True when ``col`` is not the first column index."""
    return col != 0
def _gca():
    """Return ``matplotlib.pyplot.gca()``; pyplot is imported lazily."""
    from matplotlib import pyplot
    return pyplot.gca()
def _gcf():
    """Return ``matplotlib.pyplot.gcf()``; pyplot is imported lazily."""
    from matplotlib import pyplot
    return pyplot.gcf()
def _get_marker_compat(marker):
    """
    Return a marker the installed matplotlib can draw.

    'o' is substituted when ``marker`` is '.' on matplotlib older than
    1.1.0, or when ``marker`` is not listed in
    ``matplotlib.lines.lineMarkers``; otherwise ``marker`` is returned
    unchanged.
    """
    import matplotlib
    from matplotlib import lines as mpl_lines

    legacy_dot = matplotlib.__version__ < '1.1.0' and marker == '.'
    if legacy_dot or marker not in mpl_lines.lineMarkers:
        return 'o'
    return marker
def radviz(frame, class_column, ax=None, color=None, colormap=None, **kwds):
    """RadViz - a multivariate data visualization algorithm

    Every non-class column is min-max normalized and assigned an anchor
    point evenly spaced on the unit circle. Each row is drawn at the
    average of the anchors weighted by its normalized values.

    Parameters:
    -----------
    frame: DataFrame
    class_column: str
        Column name containing class names
    ax: Matplotlib axis object, optional
        When omitted the current axes are used and their limits are set to
        [-1, 1] in both directions.
    color: list or tuple, optional
        Colors to use for the different classes
    colormap : str or matplotlib colormap object, default None
        Colormap to select colors from. If string, load colormap with that name
        from matplotlib.
    kwds: keywords
        Options to pass to matplotlib scatter plotting method

    Returns:
    --------
    ax: Matplotlib axis object
    """
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches

    def normalize(series):
        a = min(series)
        b = max(series)
        return (series - a) / (b - a)

    n = len(frame)
    classes = frame[class_column].drop_duplicates()
    class_col = frame[class_column]
    df = frame.drop(class_column, axis=1).apply(normalize)

    if ax is None:
        # Set the limits explicitly instead of passing them to gca():
        # keyword forwarding by gca() is not available in newer matplotlib.
        ax = plt.gca()
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)

    to_plot = {}
    colors = _get_standard_colors(num_colors=len(classes), colormap=colormap,
                                  color_type='random', color=color)

    for kls in classes:
        to_plot[kls] = [[], []]

    # one anchor per feature column, evenly spaced on the unit circle
    m = len(frame.columns) - 1
    s = np.array([(np.cos(t), np.sin(t))
                  for t in [2.0 * np.pi * (i / float(m))
                            for i in range(m)]])

    for i in range(n):
        row = df.iloc[i].values
        # duplicate the row values so they weight both x and y of an anchor
        row_ = np.repeat(np.expand_dims(row, axis=1), 2, axis=1)
        y = (s * row_).sum(axis=0) / row.sum()
        kls = class_col.iat[i]
        to_plot[kls][0].append(y[0])
        to_plot[kls][1].append(y[1])

    for i, kls in enumerate(classes):
        ax.scatter(to_plot[kls][0], to_plot[kls][1], color=colors[i],
                   label=pprint_thing(kls), **kwds)
    ax.legend()

    ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor='none'))

    # draw every anchor and place its label away from the circle centre
    for xy, name in zip(s, df.columns):

        ax.add_patch(patches.Circle(xy, radius=0.025, facecolor='gray'))

        if xy[0] < 0.0 and xy[1] < 0.0:
            ax.text(xy[0] - 0.025, xy[1] - 0.025, name,
                    ha='right', va='top', size='small')
        elif xy[0] < 0.0 and xy[1] >= 0.0:
            ax.text(xy[0] - 0.025, xy[1] + 0.025, name,
                    ha='right', va='bottom', size='small')
        elif xy[0] >= 0.0 and xy[1] < 0.0:
            ax.text(xy[0] + 0.025, xy[1] - 0.025, name,
                    ha='left', va='top', size='small')
        elif xy[0] >= 0.0 and xy[1] >= 0.0:
            ax.text(xy[0] + 0.025, xy[1] + 0.025, name,
                    ha='left', va='bottom', size='small')

    ax.axis('equal')
    return ax
@deprecate_kwarg(old_arg_name='data', new_arg_name='frame')
def andrews_curves(frame, class_column, ax=None, samples=200, color=None,
                   colormap=None, **kwds):
    """
    Generates a matplotlib plot of Andrews curves, for visualising clusters of
    multivariate data.

    Andrews curves have the functional form:

    f(t) = x_1/sqrt(2) + x_2 sin(t) + x_3 cos(t) +
           x_4 sin(2t) + x_5 cos(2t) + ...

    Where x coefficients correspond to the values of each dimension and t is
    linearly spaced between -pi and +pi. Each row of frame then corresponds to
    a single curve.

    Parameters:
    -----------
    frame : DataFrame
        Data to be plotted, preferably normalized to (0.0, 1.0)
    class_column : Name of the column containing class names
    ax : matplotlib axes object, default None
        When omitted the current axes are used and their x limits are set
        to (-pi, pi).
    samples : Number of points to plot in each curve
    color: list or tuple, optional
        Colors to use for the different classes
    colormap : str or matplotlib colormap object, default None
        Colormap to select colors from. If string, load colormap with that name
        from matplotlib.
    kwds: keywords
        Options to pass to matplotlib plotting method

    Returns:
    --------
    ax: Matplotlib axis object

    """
    from math import sqrt, pi
    import matplotlib.pyplot as plt

    def function(amplitudes):
        def f(t):
            x1 = amplitudes[0]
            result = x1 / sqrt(2.0)

            # Take the rest of the coefficients and resize them
            # appropriately. Take a copy of amplitudes as otherwise numpy
            # deletes the element from amplitudes itself.
            coeffs = np.delete(np.copy(amplitudes), 0)
            coeffs.resize(int((coeffs.size + 1) / 2), 2)

            # Generate the harmonics and arguments for the sin and cos
            # functions.
            harmonics = np.arange(0, coeffs.shape[0]) + 1
            trig_args = np.outer(harmonics, t)

            result += np.sum(coeffs[:, 0, np.newaxis] * np.sin(trig_args) +
                             coeffs[:, 1, np.newaxis] * np.cos(trig_args),
                             axis=0)
            return result
        return f

    n = len(frame)
    class_col = frame[class_column]
    classes = frame[class_column].drop_duplicates()
    df = frame.drop(class_column, axis=1)
    t = np.linspace(-pi, pi, samples)
    used_legends = set()

    color_values = _get_standard_colors(num_colors=len(classes),
                                        colormap=colormap, color_type='random',
                                        color=color)
    colors = dict(zip(classes, color_values))
    if ax is None:
        # Set the limits explicitly instead of passing them to gca():
        # keyword forwarding by gca() is not available in newer matplotlib.
        ax = plt.gca()
        ax.set_xlim(-pi, pi)
    for i in range(n):
        row = df.iloc[i].values
        f = function(row)
        y = f(t)
        kls = class_col.iat[i]
        label = pprint_thing(kls)
        # label only the first curve of each class so that the legend shows
        # one entry per class
        if label not in used_legends:
            used_legends.add(label)
            ax.plot(t, y, color=colors[kls], label=label, **kwds)
        else:
            ax.plot(t, y, color=colors[kls], **kwds)

    ax.legend(loc='upper right')
    ax.grid()
    return ax
def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds):
    """Bootstrap plot.

    Draws ``samples`` random subsets of ``size`` values from ``series`` and
    plots the mean, median and midrange of every subset: as a line per
    statistic in the top row and as a histogram per statistic in the bottom
    row of a 2x3 grid.

    Parameters:
    -----------
    series: Time series
    fig: matplotlib figure object, optional
    size: number of data points to consider during each sampling
    samples: number of times the bootstrap procedure is performed
    kwds: optional keyword arguments for plotting commands, must be accepted
        by both hist and plot

    Returns:
    --------
    fig: matplotlib figure
    """
    import random
    import matplotlib.pyplot as plt

    # random.sample(ndarray, int) fails on python 3.3, sigh
    population = list(series.values)
    draws = [random.sample(population, size) for _ in range(samples)]

    means = np.array([np.mean(draw) for draw in draws])
    medians = np.array([np.median(draw) for draw in draws])
    midranges = np.array([(min(draw) + max(draw)) * 0.5
                          for draw in draws])
    statistics = [("Mean", means),
                  ("Median", medians),
                  ("Midrange", midranges)]

    if fig is None:
        fig = plt.figure()
    x = lrange(samples)
    axes = []

    # top row: value of each statistic per bootstrap sample
    for position, (_, values) in enumerate(statistics, 1):
        panel = fig.add_subplot(2, 3, position)
        panel.set_xlabel("Sample")
        axes.append(panel)
        panel.plot(x, values, **kwds)

    # bottom row: distribution of each statistic
    for position, (title, values) in enumerate(statistics, 4):
        panel = fig.add_subplot(2, 3, position)
        panel.set_xlabel(title)
        axes.append(panel)
        panel.hist(values, **kwds)

    for panel in axes:
        plt.setp(panel.get_xticklabels(), fontsize=8)
        plt.setp(panel.get_yticklabels(), fontsize=8)
    return fig
@deprecate_kwarg(old_arg_name='colors', new_arg_name='color')
@deprecate_kwarg(old_arg_name='data', new_arg_name='frame', stacklevel=3)
def parallel_coordinates(frame, class_column, cols=None, ax=None, color=None,
                         use_columns=False, xticks=None, colormap=None,
                         axvlines=True, axvlines_kwds=None, **kwds):
    """Parallel coordinates plotting.

    Each row of ``frame`` is drawn as one line across the selected columns,
    colored by the value found in ``class_column``.

    Parameters
    ----------
    frame: DataFrame
    class_column: str
        Column name containing class names
    cols: list, optional
        A list of column names to use
    ax: matplotlib.axis, optional
        matplotlib axis object
    color: list or tuple, optional
        Colors to use for the different classes
    use_columns: bool, optional
        If true, columns will be used as xticks
    xticks: list or tuple, optional
        A list of values to use for xticks
    colormap: str or matplotlib colormap, default None
        Colormap to use for line colors.
    axvlines: bool, optional
        If true, vertical lines will be added at each xtick
    axvlines_kwds: keywords, optional
        Options to be passed to axvline method for vertical lines
    kwds: keywords
        Options to pass to matplotlib plotting method

    Returns
    -------
    ax: matplotlib axis object

    Raises
    ------
    ValueError
        If the columns (with ``use_columns``) or ``xticks`` are not numeric,
        or if ``xticks`` does not have one entry per plotted column.

    Examples
    --------
    >>> from pandas import read_csv
    >>> from pandas.tools.plotting import parallel_coordinates
    >>> from matplotlib import pyplot as plt
    >>> df = read_csv('https://raw.github.com/pydata/pandas/master'
    ...               '/pandas/tests/data/iris.csv')
    >>> parallel_coordinates(df, 'Name', color=('#556270',
    ...                      '#4ECDC4', '#C7F464'))
    >>> plt.show()
    """
    if axvlines_kwds is None:
        axvlines_kwds = {'linewidth': 1, 'color': 'black'}
    import matplotlib.pyplot as plt

    n = len(frame)
    classes = frame[class_column].drop_duplicates()
    class_col = frame[class_column]

    df = frame.drop(class_column, axis=1) if cols is None else frame[cols]

    used_legends = set([])
    ncols = len(df.columns)

    # determine values to use for xticks
    if use_columns is True:
        if not np.all(np.isreal(list(df.columns))):
            raise ValueError('Columns must be numeric to be used as xticks')
        x = df.columns
    elif xticks is None:
        x = lrange(ncols)
    else:
        if not np.all(np.isreal(xticks)):
            raise ValueError('xticks specified must be numeric')
        if len(xticks) != ncols:
            raise ValueError('Length of xticks must match number of columns')
        x = xticks

    if ax is None:
        ax = plt.gca()

    color_values = _get_standard_colors(num_colors=len(classes),
                                        colormap=colormap, color_type='random',
                                        color=color)
    colors = dict(zip(classes, color_values))

    for position in range(n):
        y = df.iloc[position].values
        kls = class_col.iat[position]
        label = pprint_thing(kls)
        if label in used_legends:
            ax.plot(x, y, color=colors[kls], **kwds)
        else:
            # first line of this class: attach the label for the legend
            used_legends.add(label)
            ax.plot(x, y, color=colors[kls], label=label, **kwds)

    if axvlines:
        for tick in x:
            ax.axvline(tick, **axvlines_kwds)

    ax.set_xticks(x)
    ax.set_xticklabels(df.columns)
    ax.set_xlim(x[0], x[-1])
    ax.legend(loc='upper right')
    ax.grid()
    return ax
def lag_plot(series, lag=1, ax=None, **kwds):
    """Lag plot for time series.

    Scatters every value of ``series`` against the value ``lag`` positions
    later.

    Parameters:
    -----------
    series: Time series
    lag: lag of the scatter plot, default 1
    ax: Matplotlib axis object, optional
    kwds: Matplotlib scatter method keyword arguments, optional

    Returns:
    --------
    ax: Matplotlib axis object
    """
    import matplotlib.pyplot as plt

    # workaround because `c='b'` is hardcoded in matplotlibs scatter method
    kwds.setdefault('c', plt.rcParams['patch.facecolor'])

    values = series.values
    current, lagged = values[:-lag], values[lag:]
    if ax is None:
        ax = plt.gca()
    ax.set_xlabel("y(t)")
    ax.set_ylabel("y(t + %s)" % lag)
    ax.scatter(current, lagged, **kwds)
    return ax
def autocorrelation_plot(series, ax=None, **kwds):
    """Autocorrelation plot for time series.

    Plots the sample autocorrelation for lags 1..n together with horizontal
    reference lines at +/- z95/sqrt(n) (solid) and +/- z99/sqrt(n) (dashed).

    Parameters:
    -----------
    series: Time series
    ax: Matplotlib axis object, optional
        When omitted the current axes are used, with x limits (1, n) and
        y limits (-1.0, 1.0).
    kwds : keywords
        Options to pass to matplotlib plotting method

    Returns:
    -----------
    ax: Matplotlib axis object
    """
    import matplotlib.pyplot as plt
    n = len(series)
    data = np.asarray(series)
    if ax is None:
        # Set the limits explicitly instead of passing them to gca():
        # keyword forwarding by gca() is not available in newer matplotlib.
        ax = plt.gca()
        ax.set_xlim(1, n)
        ax.set_ylim(-1.0, 1.0)
    mean = np.mean(data)
    c0 = np.sum((data - mean) ** 2) / float(n)

    def r(h):
        # autocovariance at lag h, normalized by the lag-0 value c0
        return ((data[:n - h] - mean) *
                (data[h:] - mean)).sum() / float(n) / c0

    x = np.arange(n) + 1
    y = lmap(r, x)
    # standard normal quantiles for the 95% and 99% two-sided bands
    z95 = 1.959963984540054
    z99 = 2.5758293035489004
    ax.axhline(y=z99 / np.sqrt(n), linestyle='--', color='grey')
    ax.axhline(y=z95 / np.sqrt(n), color='grey')
    ax.axhline(y=0.0, color='black')
    ax.axhline(y=-z95 / np.sqrt(n), color='grey')
    ax.axhline(y=-z99 / np.sqrt(n), linestyle='--', color='grey')
    ax.set_xlabel("Lag")
    ax.set_ylabel("Autocorrelation")
    ax.plot(x, y, **kwds)
    if 'label' in kwds:
        ax.legend()
    ax.grid()
    return ax
class MPLPlot(object):
    """
    Base class for assembling a pandas plot using matplotlib

    Parameters
    ----------
    data :
        The object to plot; stored unchanged as ``self.data`` by
        ``__init__``.
    """

    @property
    def _kind(self):
        """Specify kind str. Must be overridden in child class"""
        raise NotImplementedError

    # forwarded to ``_subplots`` as ``layout_type`` when subplots are used
    _layout_type = 'vertical'
    # tick-label rotation used when the caller passes ``rot=None``
    _default_rot = 0
    # 'vertical', 'horizontal' or None; selects which axis receives the
    # index labels in ``_post_plot_logic_common``
    orientation = None
    # keyword arguments that ``__init__`` removes from ``**kwds`` and stores
    # as instance attributes of the same name
    _pop_attributes = ['label', 'style', 'logy', 'logx', 'loglog',
                       'mark_right', 'stacked']
    # fallback values for the attributes above; names without an entry
    # default to None
    _attr_defaults = {'logy': False, 'logx': False, 'loglog': False,
                      'mark_right': True, 'stacked': False}
def __init__(self, data, kind=None, by=None, subplots=False, sharex=None,
             sharey=False, use_index=True,
             figsize=None, grid=None, legend=True, rot=None,
             ax=None, fig=None, title=None, xlim=None, ylim=None,
             xticks=None, yticks=None,
             sort_columns=False, fontsize=None,
             secondary_y=False, colormap=None,
             table=False, layout=None, **kwds):
    """
    Store the plot configuration on the instance.

    Most arguments are copied verbatim to attributes of the same name.
    The names listed in ``_pop_attributes`` plus 'xerr', 'yerr' and 'cmap'
    are taken out of ``kwds``; whatever remains is kept as ``self.kwds``.

    Raises
    ------
    TypeError
        If both ``cmap`` and a truthy ``colormap`` are given.
    """
    self.data = data
    self.by = by
    self.kind = kind
    self.sort_columns = sort_columns
    self.subplots = subplots

    # Default sharex to True only when pandas creates the axes itself; if
    # we get an axis, the users should do the visibility setting...
    self.sharex = (ax is None) if sharex is None else sharex
    self.sharey = sharey

    self.figsize = figsize
    self.layout = layout
    self.xticks = xticks
    self.yticks = yticks
    self.xlim = xlim
    self.ylim = ylim
    self.title = title
    self.use_index = use_index
    self.fontsize = fontsize

    # need to know whether rot was given for format_date_labels since it's
    # rotated to 30 by default
    self._rot_set = rot is not None
    self.rot = rot if self._rot_set else self._default_rot

    if grid is None:
        grid = False if secondary_y else self.plt.rcParams['axes.grid']
    self.grid = grid

    self.legend = legend
    self.legend_handles = []
    self.legend_labels = []

    for name in self._pop_attributes:
        fallback = self._attr_defaults.get(name, None)
        setattr(self, name, kwds.pop(name, fallback))

    self.ax = ax
    self.fig = fig
    self.axes = None

    # parse errorbar input if given
    xerr = kwds.pop('xerr', None)
    yerr = kwds.pop('yerr', None)
    self.errors = {}
    for err_name, err_value in (('xerr', xerr), ('yerr', yerr)):
        self.errors[err_name] = self._parse_errorbars(err_name, err_value)

    # a single non-container value is wrapped in a list
    if not isinstance(secondary_y, (bool, tuple, list, np.ndarray, Index)):
        secondary_y = [secondary_y]
    self.secondary_y = secondary_y

    # ugly TypeError if user passes matplotlib's `cmap` name.
    # Probably better to accept either.
    if 'cmap' in kwds:
        if colormap:
            raise TypeError("Only specify one of `cmap` and `colormap`.")
        self.colormap = kwds.pop('cmap')
    else:
        self.colormap = colormap

    self.table = table
    self.kwds = kwds

    self._validate_color_args()
- def _validate_color_args(self):
- if 'color' not in self.kwds and 'colors' in self.kwds:
- warnings.warn(("'colors' is being deprecated. Please use 'color'"
- "instead of 'colors'"))
- colors = self.kwds.pop('colors')
- self.kwds['color'] = colors
- if ('color' in self.kwds and self.nseries == 1):
- # support series.plot(color='green')
- self.kwds['color'] = [self.kwds['color']]
- if ('color' in self.kwds or 'colors' in self.kwds) and \
- self.colormap is not None:
- warnings.warn("'color' and 'colormap' cannot be used "
- "simultaneously. Using 'color'")
- if 'color' in self.kwds and self.style is not None:
- if is_list_like(self.style):
- styles = self.style
- else:
- styles = [self.style]
- # need only a single match
- for s in styles:
- if re.match('^[a-z]+?', s) is not None:
- raise ValueError(
- "Cannot pass 'style' string with a color "
- "symbol and 'color' keyword argument. Please"
- " use one or the other or pass 'style' "
- "without a color symbol")
- def _iter_data(self, data=None, keep_index=False, fillna=None):
- if data is None:
- data = self.data
- if fillna is not None:
- data = data.fillna(fillna)
- # TODO: unused?
- # if self.sort_columns:
- # columns = _try_sort(data.columns)
- # else:
- # columns = data.columns
- for col, values in data.iteritems():
- if keep_index is True:
- yield col, values
- else:
- yield col, values.values
@property
def nseries(self):
    """Number of series in ``self.data``: 1 if 1-D, else the column count."""
    data = self.data
    return 1 if data.ndim == 1 else data.shape[1]
def draw(self):
    """Call ``draw_if_interactive()`` on ``self.plt``."""
    self.plt.draw_if_interactive()
def generate(self):
    """
    Run the whole plotting pipeline.

    The build steps are executed in a fixed order, after which the common
    and the class-specific post-processing hooks are applied to every axes
    in ``self.axes``.
    """
    pipeline = ('_args_adjust', '_compute_plot_data', '_setup_subplots',
                '_make_plot', '_add_table', '_make_legend',
                '_adorn_subplots')
    for step in pipeline:
        getattr(self, step)()

    for ax in self.axes:
        self._post_plot_logic_common(ax, self.data)
        self._post_plot_logic(ax, self.data)
def _args_adjust(self):
    """No-op here; first step called by ``generate``."""
    pass
- def _has_plotted_object(self, ax):
- """check whether ax has data"""
- return (len(ax.lines) != 0 or
- len(ax.artists) != 0 or
- len(ax.containers) != 0)
- def _maybe_right_yaxis(self, ax, axes_num):
- if not self.on_right(axes_num):
- # secondary axes may be passed via ax kw
- return self._get_ax_layer(ax)
- if hasattr(ax, 'right_ax'):
- # if it has right_ax proparty, ``ax`` must be left axes
- return ax.right_ax
- elif hasattr(ax, 'left_ax'):
- # if it has left_ax proparty, ``ax`` must be right axes
- return ax
- else:
- # otherwise, create twin axes
- orig_ax, new_ax = ax, ax.twinx()
- # TODO: use Matplotlib public API when available
- new_ax._get_lines = orig_ax._get_lines
- new_ax._get_patches_for_fill = orig_ax._get_patches_for_fill
- orig_ax.right_ax, new_ax.left_ax = new_ax, orig_ax
- if not self._has_plotted_object(orig_ax): # no data on left y
- orig_ax.get_yaxis().set_visible(False)
- return new_ax
def _setup_subplots(self):
    """
    Create or adopt the figure and axes and store them on the instance.

    With ``self.subplots`` one axes per series is requested from
    ``_subplots``. Otherwise a single axes is used: ``self.ax`` when given
    (its figure is resized if ``self.figsize`` is set), else a new
    figure with one subplot. Log scales are applied afterwards and the
    flattened axes are stored as ``self.axes``.
    """
    if self.subplots:
        fig, axes = _subplots(naxes=self.nseries,
                              sharex=self.sharex, sharey=self.sharey,
                              figsize=self.figsize, ax=self.ax,
                              layout=self.layout,
                              layout_type=self._layout_type)
    elif self.ax is None:
        fig = self.plt.figure(figsize=self.figsize)
        axes = fig.add_subplot(111)
    else:
        fig = self.ax.get_figure()
        if self.figsize is not None:
            fig.set_size_inches(self.figsize)
        axes = self.ax

    axes = _flatten(axes)

    if self.logx or self.loglog:
        for target in axes:
            target.set_xscale('log')
    if self.logy or self.loglog:
        for target in axes:
            target.set_yscale('log')

    self.fig = fig
    self.axes = axes
@property
def result(self):
    """
    Return result axes

    With subplots: ``self.axes``, reshaped to ``self.layout`` when a layout
    is set and ``self.ax`` is not list-like. Without subplots: the first
    axes, or its secondary layer when every series is on the secondary y
    axis.
    """
    if self.subplots:
        reshape = self.layout is not None and not is_list_like(self.ax)
        return self.axes.reshape(*self.layout) if reshape else self.axes

    secondary = self.secondary_y
    sec_true = isinstance(secondary, bool) and self.secondary_y
    all_sec = (is_list_like(self.secondary_y) and
               len(self.secondary_y) == self.nseries)
    if not (sec_true or all_sec):
        return self.axes[0]
    # if all data is plotted on secondary, return right axes
    return self._get_ax_layer(self.axes[0], primary=False)
def _compute_plot_data(self):
    """
    Replace ``self.data`` by its numeric part.

    A Series is first turned into a one-column frame named after
    ``self.label`` (the string 'None' when both the label and the series
    name are None).

    Raises
    ------
    TypeError
        If nothing numeric is left to plot.
    """
    frame = self.data

    if isinstance(frame, Series):
        column_name = self.label
        if column_name is None and frame.name is None:
            column_name = 'None'
        frame = frame.to_frame(name=column_name)

    numeric_data = frame._convert(datetime=True)._get_numeric_data()

    try:
        nothing_to_plot = numeric_data.empty
    except AttributeError:
        nothing_to_plot = not len(numeric_data)

    # no empty frames or series allowed
    if nothing_to_plot:
        raise TypeError('Empty {0!r}: no numeric data to '
                        'plot'.format(numeric_data.__class__.__name__))

    self.data = numeric_data
def _make_plot(self):
    """Always raises ``AbstractMethodError`` here; called by ``generate``."""
    raise AbstractMethodError(self)
- def _add_table(self):
- if self.table is False:
- return
- elif self.table is True:
- data = self.data.transpose()
- else:
- data = self.table
- ax = self._get_ax(0)
- table(ax, data)
def _post_plot_logic_common(self, ax, data):
    """
    Common post process for each axes

    Tick labels along the index axis (x for a vertical or unset
    orientation, y for a horizontal one) are replaced by the formatted
    index values when ``self._need_to_set_index`` is set. Rotation and
    font size are then applied to that axis, font size only to the other.
    """
    # position -> printable index label
    labels = dict((position, pprint_thing(key))
                  for position, key in zip(range(len(data.index)),
                                           data.index))

    if self.orientation == 'vertical' or self.orientation is None:
        if self._need_to_set_index:
            ax.set_xticklabels([labels.get(x, '') for x in ax.get_xticks()])
        self._apply_axis_properties(ax.xaxis, rot=self.rot,
                                    fontsize=self.fontsize)
        self._apply_axis_properties(ax.yaxis, fontsize=self.fontsize)
        return

    if self.orientation == 'horizontal':
        if self._need_to_set_index:
            ax.set_yticklabels([labels.get(y, '') for y in ax.get_yticks()])
        self._apply_axis_properties(ax.yaxis, rot=self.rot,
                                    fontsize=self.fontsize)
        self._apply_axis_properties(ax.xaxis, fontsize=self.fontsize)
        return

    raise ValueError  # pragma no cover
def _post_plot_logic(self, ax, data):
    """
    Post process for each axes. Overridden in child classes

    Called by ``generate`` for every axes, right after
    ``_post_plot_logic_common``; does nothing here.
    """
    pass
- def _adorn_subplots(self):
- """Common post process unrelated to data"""
- if len(self.axes) > 0:
- all_axes = self._get_subplots()
- nrows, ncols = self._get_axes_layout()
- _handle_shared_axes(axarr=all_axes, nplots=len(all_axes),
- naxes=nrows * ncols, nrows=nrows,
- ncols=ncols, sharex=self.sharex,
- sharey=self.sharey)
- for ax in self.axes:
- if self.yticks is not None:
- ax.set_yticks(self.yticks)
- if self.xticks is not None:
- ax.set_xticks(self.xticks)
- if self.ylim is not None:
- ax.set_ylim(self.ylim)
- if self.xlim is not None:
- ax.set_xlim(self.xlim)
- ax.grid(self.grid)
- if self.title:
- if self.subplots:
- self.fig.suptitle(self.title)
- else:
- self.axes[0].set_title(self.title)
- def _apply_axis_properties(self, axis, rot=None, fontsize=None):
- labels = axis.get_majorticklabels() + axis.get_minorticklabels()
- for label in labels:
- if rot is not None:
- label.set_rotation(rot)
- if fontsize is not None:
- label.set_fontsize(fontsize)
@property
def legend_title(self):
    """Legend title derived from the names of the data's columns.

    For a MultiIndex the printed level names are joined with commas.
    Otherwise the printed columns name is returned, or None when the
    columns have no name.
    """
    columns = self.data.columns
    if isinstance(columns, MultiIndex):
        return ','.join(pprint_thing(level) for level in columns.names)

    title = columns.name
    return title if title is None else pprint_thing(title)
- def _add_legend_handle(self, handle, label, index=None):
- if label is not None:
- if self.mark_right and index is not None:
- if self.on_right(index):
- label = label + ' (right)'
- self.legend_handles.append(handle)
- self.legend_labels.append(label)
- def _make_legend(self):
- ax, leg = self._get_ax_legend(self.axes[0])
- handles = []
- labels = []
- title = ''
- if not self.subplots:
- if leg is not None:
- title = leg.get_title().get_text()
- handles = leg.legendHandles
- labels = [x.get_text() for x in leg.get_texts()]
- if self.legend:
- if self.legend == 'reverse':
- self.legend_handles = reversed(self.legend_handles)
- self.legend_labels = reversed(self.legend_labels)
- handles += self.legend_handles
- labels += self.legend_labels
- if self.legend_title is not None:
- title = self.legend_title
- if len(handles) > 0:
- ax.legend(handles, labels, loc='best', title=title)
- elif self.subplots and self.legend:
- for ax in self.axes:
- if ax.get_visible():
- ax.legend(loc='best')
- def _get_ax_legend(self, ax):
- leg = ax.get_legend()
- other_ax = (getattr(ax, 'left_ax', None) or
- getattr(ax, 'right_ax', None))
- other_leg = None
- if other_ax is not None:
- other_leg = other_ax.get_legend()
- if leg is None and other_leg is not None:
- leg = other_leg
- ax = other_ax
- return ax, leg
@cache_readonly
def plt(self):
    """Return the ``matplotlib.pyplot`` module.

    The import happens inside the method rather than at module import
    time.
    """
    from matplotlib import pyplot
    return pyplot
@staticmethod
def mpl_ge_1_3_1():
    """Return the result of the module-level ``_mpl_ge_1_3_1()`` helper
    (presumably a matplotlib >= 1.3.1 check; confirm against the helper)."""
    result = _mpl_ge_1_3_1()
    return result
@staticmethod
def mpl_ge_1_5_0():
    """Return the result of the module-level ``_mpl_ge_1_5_0()`` helper
    (presumably a matplotlib >= 1.5.0 check; confirm against the helper)."""
    result = _mpl_ge_1_5_0()
    return result
- _need_to_set_index = False
- def _get_xticks(self, convert_period=False):
- index = self.data.index
- is_datetype = index.inferred_type in ('datetime', 'date',
- 'datetime64', 'time')
- if self.use_index:
- if convert_period and isinstance(index, PeriodIndex):
- self.data = self.data.reindex(index=index.sort_values())
- x = self.data.index.to_timestamp()._mpl_repr()
- elif index.is_numeric():
- """
- Matplotlib supports numeric values or datetime objects as
- xaxis values. Taking LBYL approach here, by the time
- matplotlib raises exception when using non numeric/datetime
- values for xaxis, several actions are already taken by plt.
- """
- x = index._mpl_repr()
- elif is_datetype:
- self.data = self.data.sort_index()
- x = self.data.index._mpl_repr()
- else:
- self._need_to_set_index = True
- x = lrange(len(index))
- else:
- x = lrange(len(index))
- return x
- @classmethod
- def _plot(cls, ax, x, y, style=None, is_errorbar=False, **kwds):
- mask = isnull(y)
- if mask.any():
- y = np.ma.array(y)
- y = np.ma.masked_where(mask, y)
- if isinstance(x, Index):
- x = x._mpl_repr()
- if is_errorbar:
- if 'xerr' in kwds:
- kwds['xerr'] = np.array(kwds.get('xerr'))
- if 'yerr' in kwds:
- kwds['yerr'] = np.array(kwds.get('yerr'))
- return ax.errorbar(x, y, **kwds)
- else:
- # prevent style kwarg from going to errorbar, where it is
- # unsupported
- if style is not None:
- args = (x, y, style)
- else:
- args = (x, y)
- return ax.plot(*args, **kwds)
- def _get_index_name(self):
- if isinstance(self.data.index, MultiIndex):
- name = self.data.index.names
- if any(x is not None for x in name):
- name = ','.join([pprint_thing(x) for x in name])
- else:
- name = None
- else:
- name = self.data.index.name
- if name is not None:
- name = pprint_thing(name)
- return name
- @classmethod
- def _get_ax_layer(cls, ax, primary=True):
- """get left (primary) or right (secondary) axes"""
- if primary:
- return getattr(ax, 'left_ax', ax)
- else:
- return getattr(ax, 'right_ax', ax)
- def _get_ax(self, i):
- # get the twinx ax if appropriate
- if self.subplots:
- ax = self.axes[i]
- ax = self._maybe_right_yaxis(ax, i)
- self.axes[i] = ax
- else:
- ax = self.axes[0]
- ax = self._maybe_right_yaxis(ax, i)
- ax.get_yaxis().set_visible(True)
- return ax
def on_right(self, i):
    """Tell whether column ``i`` is selected by ``self.secondary_y``.

    A bool ``secondary_y`` is returned as is.  For a tuple, list, ndarray
    or Index the result is whether the label of column ``i`` is contained
    in it.  Any other value yields None.
    """
    secondary = self.secondary_y

    if isinstance(secondary, bool):
        return secondary

    if isinstance(secondary, (tuple, list, np.ndarray, Index)):
        return self.data.columns[i] in secondary

    return None
- def _apply_style_colors(self, colors, kwds, col_num, label):
- """
- Manage style and color based on column number and its label.
- Returns tuple of appropriate style and kwds which "color" may be added.
- """
- style = None
- if self.style is not None:
- if isinstance(self.style, list):
- try:
- style = self.style[col_num]
- except IndexError:
- pass
- elif isinstance(self.style, dict):
- style = self.style.get(label, style)
- else:
- style = self.style
- has_color = 'color' in kwds or self.colormap is not None
- nocolor_style = style is None or re.match('[a-z]+', style) is None
- if (has_color or self.subplots) and nocolor_style:
- kwds['color'] = colors[col_num % len(colors)]
- return style, kwds
def _get_colors(self, num_colors=None, color_kwds='color'):
    """Return colors from ``_get_standard_colors`` for this plot.

    ``num_colors`` defaults to ``self.nseries``.  The colormap comes from
    ``self.colormap`` and the explicit color from the ``color_kwds``
    entry of ``self.kwds`` (None when absent).
    """
    count = self.nseries if num_colors is None else num_colors
    return _get_standard_colors(num_colors=count,
                                colormap=self.colormap,
                                color=self.kwds.get(color_kwds))
- def _parse_errorbars(self, label, err):
- """
- Look for error keyword arguments and return the actual errorbar data
- or return the error DataFrame/dict
- Error bars can be specified in several ways:
- Series: the user provides a pandas.Series object of the same
- length as the data
- ndarray: provides a np.ndarray of the same length as the data
- DataFrame/dict: error values are paired with keys matching the
- key in the plotted DataFrame
- str: the name of the column within the plotted DataFrame
- """
- if err is None:
- return None
- from pandas import DataFrame, Series
- def match_labels(data, e):
- e = e.reindex_axis(data.index)
- return e
- # key-matched DataFrame
- if isinstance(err, DataFrame):
- err = match_labels(self.data, err)
- # key-matched dict
- elif isinstance(err, dict):
- pass
- # Series of error values
- elif isinstance(err, Series):
- # broadcast error series across data
- err = match_labels(self.data, err)
- err = np.atleast_2d(err)
- err = np.tile(err, (self.nseries, 1))
- # errors are a column in the dataframe
- elif isinstance(err, string_types):
- evalues = self.data[err].values
- self.data = self.data[self.data.columns.drop(err)]
- err = np.atleast_2d(evalues)
- err = np.tile(err, (self.nseries, 1))
- elif is_list_like(err):
- if is_iterator(err):
- err = np.atleast_2d(list(err))
- else:
- # raw error values
- err = np.atleast_2d(err)
- err_shape = err.shape
- # asymmetrical error bars
- if err.ndim == 3:
- if (err_shape[0] != self.nseries) or \
- (err_shape[1] != 2) or \
- (err_shape[2] != len(self.data)):
- msg = "Asymmetrical error bars should be provided " + \
- "with the shape (%u, 2, %u)" % \
- (self.nseries, len(self.data))
- raise ValueError(msg)
- # broadcast errors to each data series
- if len(err) == 1:
- err = np.tile(err, (self.nseries, 1))
- elif is_number(err):
- err = np.tile([err], (self.nseries, len(self.data)))
- else:
- msg = "No valid %s detected" % label
- raise ValueError(msg)
- return err
- def _get_errorbars(self, label=None, index=None, xerr=True, yerr=True):
- from pandas import DataFrame
- errors = {}
- for kw, flag in zip(['xerr', 'yerr'], [xerr, yerr]):
- if flag:
- err = self.errors[kw]
- # user provided label-matched dataframe of errors
- if isinstance(err, (DataFrame, dict)):
- if label is not None and label in err.keys():
- err = err[label]
- else:
- err = None
- elif index is not None and err is not None:
- err = err[index]
- if err is not None:
- errors[kw] = err
- return errors
def _get_subplots(self):
    """Return the ``Subplot`` axes of the figure owning ``self.axes[0]``.

    Axes of that figure which are not ``matplotlib.axes.Subplot``
    instances are left out.
    """
    from matplotlib.axes import Subplot

    figure = self.axes[0].get_figure()
    return [candidate for candidate in figure.get_axes()
            if isinstance(candidate, Subplot)]
- def _get_axes_layout(self):
- axes = self._get_…
Large files are truncated, but you can click here to view the full file