/pandas/util/doctools.py
Python | 190 lines | 145 code | 21 blank | 24 comment | 20 complexity | 8630ad34f087afd4b9dedfc7b16acbb3 MD5 | raw file
Possible License(s): BSD-3-Clause, Apache-2.0
- import numpy as np
- import pandas as pd
- import pandas.compat as compat
class TablePlotter(object):
    """
    Layout some DataFrames in vertical/horizontal layout for explanation.
    Used in merging.rst

    Parameters
    ----------
    cell_width : float
        Figure width allotted to one table cell.
    cell_height : float
        Figure height allotted to one table cell.
    font_size : float
        Font size used for table text and subplot titles.
    """

    def __init__(self, cell_width=0.37, cell_height=0.25, font_size=7.5):
        self.cell_width = cell_width
        self.cell_height = cell_height
        self.font_size = font_size

    def _shape(self, df):
        """
        Calculate table shape considering index levels.

        Returns
        -------
        tuple of int
            ``(rows, cols)`` where the column header levels are counted as
            extra rows and the index levels are counted as extra columns.
        """
        row, col = df.shape
        return row + df.columns.nlevels, col + df.index.nlevels

    def _get_cells(self, left, right, vertical):
        """
        Calculate the number of grid cells needed for left and right data.

        Parameters
        ----------
        left : list of DataFrames
        right : DataFrame
        vertical : bool
            If True the left tables are stacked on top of each other next to
            ``right``; otherwise every table is placed side by side.

        Returns
        -------
        tuple of int
            ``(hcells, vcells)``: horizontal and vertical cell counts.
        """
        right_rows, right_cols = self._shape(right)
        left_shapes = [self._shape(df) for df in left]

        if vertical:
            # left tables are stacked, so their heights add up while only the
            # widest one determines the width of the left-hand side
            vcells = max(sum(rows for rows, _ in left_shapes), right_rows)
            hcells = max(cols for _, cols in left_shapes) + right_cols
        else:
            # everything sits on one row: widths add up, tallest table wins
            vcells = max([rows for rows, _ in left_shapes] + [right_rows])
            hcells = sum([cols for _, cols in left_shapes] + [right_cols])
        return hcells, vcells

    def plot(self, left, right, labels=None, vertical=True):
        """
        Plot left / right DataFrames in specified layout.

        Parameters
        ----------
        left : list of DataFrames before operation is applied
            A single DataFrame / Series is also accepted and wrapped in a
            list.
        right : DataFrame of operation result
        labels : list of str to be drawn as titles of left DataFrames
            If None, the left tables are drawn with empty titles.
        vertical : bool
            If True, use vertical layout. If False, use horizontal layout.

        Returns
        -------
        matplotlib Figure holding one subplot per table.
        """
        import matplotlib.pyplot as plt
        import matplotlib.gridspec as gridspec

        if not isinstance(left, list):
            left = [left]
        left = [self._conv(df) for df in left]
        right = self._conv(right)

        if labels is None:
            # the default must be iterable alongside ``left``; previously
            # zip(left, None) raised TypeError
            labels = [''] * len(left)

        hcells, vcells = self._get_cells(left, right, vertical)
        # both layouts use the same figure size: one cell per grid unit
        figsize = self.cell_width * hcells, self.cell_height * vcells
        fig = plt.figure(figsize=figsize)

        if vertical:
            gs = gridspec.GridSpec(len(left), hcells)
            # left
            max_left_cols = max(self._shape(df)[1] for df in left)
            max_left_rows = max(self._shape(df)[0] for df in left)
            for i, (df, label) in enumerate(zip(left, labels)):
                ax = fig.add_subplot(gs[i, 0:max_left_cols])
                self._make_table(ax, df, title=label,
                                 height=1.0 / max_left_rows)
            # right
            ax = fig.add_subplot(gs[:, max_left_cols:])
            self._make_table(ax, right, title='Result', height=1.05 / vcells)
            fig.subplots_adjust(top=0.9, bottom=0.05, left=0.05, right=0.95)
        else:
            max_rows = max(self._shape(df)[0] for df in left + [right])
            height = 1.0 / max_rows
            gs = gridspec.GridSpec(1, hcells)
            # left
            offset = 0
            for df, label in zip(left, labels):
                width = self._shape(df)[1]
                ax = fig.add_subplot(gs[0, offset:offset + width])
                self._make_table(ax, df, title=label, height=height)
                offset += width
            # right
            ax = fig.add_subplot(gs[0, offset:])
            self._make_table(ax, right, title='Result', height=height)
            fig.subplots_adjust(top=0.85, bottom=0.05, left=0.05, right=0.95)

        return fig

    def _conv(self, data):
        """Convert each input to a DataFrame appropriate for table output"""
        if isinstance(data, pd.Series):
            if data.name is None:
                # an unnamed Series gets an empty column header
                data = data.to_frame(name='')
            else:
                data = data.to_frame()
        # missing values are drawn as the literal text 'NaN'
        data = data.fillna('NaN')
        return data

    def _insert_index(self, data):
        """
        Return a copy of ``data`` with its index levels inserted as leading
        columns and extra column levels inserted as leading rows.
        """
        # insert is destructive
        data = data.copy()
        idx_nlevels = data.index.nlevels
        if idx_nlevels == 1:
            data.insert(0, 'Index', data.index)
        else:
            for i in range(idx_nlevels):
                data.insert(i, 'Index{0}'.format(i),
                            data.index.get_level_values(i))

        col_nlevels = data.columns.nlevels
        if col_nlevels > 1:
            # keep the first column level as the header and turn every
            # remaining level into a regular row on top of the data
            col = data.columns.get_level_values(0)
            values = [data.columns.get_level_values(i).values
                      for i in range(1, col_nlevels)]
            col_df = pd.DataFrame(values)
            # align labels positionally so concat stacks column-by-column
            data.columns = col_df.columns
            data = pd.concat([col_df, data])
            data.columns = col
        return data

    def _make_table(self, ax, df, title, height=None):
        """
        Draw ``df`` as a table on ``ax``.

        Parameters
        ----------
        ax : matplotlib Axes to draw on
        df : DataFrame or None
            If None, the axes is hidden and nothing is drawn.
        title : str
            Title of the axes.
        height : float, optional
            Height applied to every cell. Defaults to
            ``1.0 / (number of drawn rows + 1)``.
        """
        if df is None:
            ax.set_visible(False)
            return

        try:
            from pandas.plotting import table
        except ImportError:
            # older pandas releases only expose the helper here
            from pandas.tools.plotting import table

        idx_nlevels = df.index.nlevels
        col_nlevels = df.columns.nlevels
        # must be convert here to get index levels for colorization
        df = self._insert_index(df)
        tb = table(ax, df, loc=9)
        tb.set_fontsize(self.font_size)

        if height is None:
            height = 1.0 / (len(df) + 1)

        props = tb.properties()
        # plain dict iteration; pandas.compat.iteritems is not available in
        # newer pandas releases
        for (r, c), cell in props['celld'].items():
            if c == -1:
                # presumably the row-label column added by the table helper;
                # the index is already present as regular columns
                cell.set_visible(False)
            elif r < col_nlevels and c < idx_nlevels:
                # corner where header rows and index columns intersect
                cell.set_visible(False)
            elif r < col_nlevels or c < idx_nlevels:
                # header rows / index columns are shaded
                cell.set_facecolor('#AAAAAA')
            cell.set_height(height)

        ax.set_title(title, size=self.font_size)
        ax.axis('off')
if __name__ == "__main__":
    # Manual demo: draws three figures exercising the vertical layout, the
    # horizontal layout and MultiIndex rows/columns respectively.
    import matplotlib.pyplot as plt

    p = TablePlotter()

    df1 = pd.DataFrame({'A': [10, 11, 12],
                        'B': [20, 21, 22],
                        'C': [30, 31, 32]})
    df2 = pd.DataFrame({'A': [10, 12],
                        'C': [30, 32]})

    # vertical layout: inputs stacked on the left, result on the right
    p.plot([df1, df2], pd.concat([df1, df2]),
           labels=['df1', 'df2'], vertical=True)
    plt.show()

    df3 = pd.DataFrame({'X': [10, 12],
                        'Z': [30, 32]})

    # horizontal layout; the titles now match the frames actually plotted
    # (the second frame is df3, it used to be titled 'df2')
    p.plot([df1, df3], pd.concat([df1, df3], axis=1),
           labels=['df1', 'df3'], vertical=False)
    plt.show()

    # MultiIndex on both axes; given its own name instead of rebinding df3
    idx = pd.MultiIndex.from_tuples([(1, 'A'), (1, 'B'), (1, 'C'),
                                     (2, 'A'), (2, 'B'), (2, 'C')])
    col = pd.MultiIndex.from_tuples([(1, 'A'), (1, 'B')])
    df4 = pd.DataFrame({'v1': [1, 2, 3, 4, 5, 6],
                        'v2': [5, 6, 7, 8, 9, 10]},
                       index=idx)
    df4.columns = col
    p.plot(df4, df4, labels=['df4'])
    plt.show()