PageRenderTime 55ms CodeModel.GetById 24ms RepoModel.GetById 1ms app.codeStats 0ms

/pandas/tests/test_graphics.py

https://github.com/lenolib/pandas
Python | 503 lines | 400 code | 95 blank | 8 comment | 34 complexity | 3f81103f2a49a1fcba1a9d3c09572ec5 MD5 | raw file
Possible License(s): BSD-3-Clause
  1. import nose
  2. import os
  3. import string
  4. import unittest
  5. from datetime import datetime
  6. from pandas import Series, DataFrame, MultiIndex, PeriodIndex, date_range
  7. import pandas.util.testing as tm
  8. import numpy as np
  9. from numpy.testing import assert_array_equal
  10. from numpy.testing.decorators import slow
  11. import pandas.tools.plotting as plotting
  12. class TestSeriesPlots(unittest.TestCase):
  13. @classmethod
  14. def setUpClass(cls):
  15. import sys
  16. if 'IPython' in sys.modules:
  17. raise nose.SkipTest
  18. try:
  19. import matplotlib as mpl
  20. mpl.use('Agg', warn=False)
  21. except ImportError:
  22. raise nose.SkipTest
  23. def setUp(self):
  24. self.ts = tm.makeTimeSeries()
  25. self.ts.name = 'ts'
  26. self.series = tm.makeStringSeries()
  27. self.series.name = 'series'
  28. self.iseries = tm.makePeriodSeries()
  29. self.iseries.name = 'iseries'
  30. @slow
  31. def test_plot(self):
  32. _check_plot_works(self.ts.plot, label='foo')
  33. _check_plot_works(self.ts.plot, use_index=False)
  34. _check_plot_works(self.ts.plot, rot=0)
  35. _check_plot_works(self.ts.plot, style='.', logy=True)
  36. _check_plot_works(self.ts.plot, style='.', logx=True)
  37. _check_plot_works(self.ts.plot, style='.', loglog=True)
  38. _check_plot_works(self.ts[:10].plot, kind='bar')
  39. _check_plot_works(self.series[:5].plot, kind='bar')
  40. _check_plot_works(self.series[:5].plot, kind='line')
  41. _check_plot_works(self.series[:5].plot, kind='barh')
  42. _check_plot_works(self.series[:10].plot, kind='barh')
  43. Series(np.random.randn(10)).plot(kind='bar',color='black')
  44. @slow
  45. def test_bar_colors(self):
  46. import matplotlib.pyplot as plt
  47. import matplotlib.colors as colors
  48. default_colors = 'brgyk'
  49. custom_colors = 'rgcby'
  50. plt.close('all')
  51. df = DataFrame(np.random.randn(5, 5))
  52. ax = df.plot(kind='bar')
  53. rects = ax.patches
  54. conv = colors.colorConverter
  55. for i, rect in enumerate(rects[::5]):
  56. xp = conv.to_rgba(default_colors[i])
  57. rs = rect.get_facecolor()
  58. self.assert_(xp == rs)
  59. plt.close('all')
  60. ax = df.plot(kind='bar', color=custom_colors)
  61. rects = ax.patches
  62. conv = colors.colorConverter
  63. for i, rect in enumerate(rects[::5]):
  64. xp = conv.to_rgba(custom_colors[i])
  65. rs = rect.get_facecolor()
  66. self.assert_(xp == rs)
  67. @slow
  68. def test_bar_linewidth(self):
  69. df = DataFrame(np.random.randn(5, 5))
  70. # regular
  71. ax = df.plot(kind='bar', linewidth=2)
  72. for r in ax.patches:
  73. self.assert_(r.get_linewidth() == 2)
  74. # stacked
  75. ax = df.plot(kind='bar', stacked=True, linewidth=2)
  76. for r in ax.patches:
  77. self.assert_(r.get_linewidth() == 2)
  78. # subplots
  79. axes = df.plot(kind='bar', linewidth=2, subplots=True)
  80. for ax in axes:
  81. for r in ax.patches:
  82. self.assert_(r.get_linewidth() == 2)
  83. @slow
  84. def test_1rotation(self):
  85. df = DataFrame(np.random.randn(5, 5))
  86. ax = df.plot(rot=30)
  87. for l in ax.get_xticklabels():
  88. self.assert_(l.get_rotation() == 30)
  89. @slow
  90. def test_irregular_datetime(self):
  91. rng = date_range('1/1/2000', '3/1/2000')
  92. rng = rng[[0,1,2,3,5,9,10,11,12]]
  93. ser = Series(np.random.randn(len(rng)), rng)
  94. ax = ser.plot()
  95. xp = datetime(1999, 1, 1).toordinal()
  96. ax.set_xlim('1/1/1999', '1/1/2001')
  97. self.assert_(xp == ax.get_xlim()[0])
  98. @slow
  99. def test_hist(self):
  100. _check_plot_works(self.ts.hist)
  101. _check_plot_works(self.ts.hist, grid=False)
  102. @slow
  103. def test_kde(self):
  104. _check_plot_works(self.ts.plot, kind='kde')
  105. _check_plot_works(self.ts.plot, kind='density')
  106. ax = self.ts.plot(kind='kde', logy=True)
  107. self.assert_(ax.get_yscale() == 'log')
  108. @slow
  109. def test_autocorrelation_plot(self):
  110. from pandas.tools.plotting import autocorrelation_plot
  111. _check_plot_works(autocorrelation_plot, self.ts)
  112. _check_plot_works(autocorrelation_plot, self.ts.values)
  113. @slow
  114. def test_lag_plot(self):
  115. from pandas.tools.plotting import lag_plot
  116. _check_plot_works(lag_plot, self.ts)
  117. @slow
  118. def test_bootstrap_plot(self):
  119. from pandas.tools.plotting import bootstrap_plot
  120. _check_plot_works(bootstrap_plot, self.ts, size=10)
  121. class TestDataFramePlots(unittest.TestCase):
  122. @classmethod
  123. def setUpClass(cls):
  124. import sys
  125. if 'IPython' in sys.modules:
  126. raise nose.SkipTest
  127. try:
  128. import matplotlib as mpl
  129. mpl.use('Agg', warn=False)
  130. except ImportError:
  131. raise nose.SkipTest
  132. @slow
  133. def test_plot(self):
  134. df = tm.makeTimeDataFrame()
  135. _check_plot_works(df.plot, grid=False)
  136. _check_plot_works(df.plot, subplots=True)
  137. _check_plot_works(df.plot, subplots=True, use_index=False)
  138. df = DataFrame({'x':[1,2], 'y':[3,4]})
  139. self._check_plot_fails(df.plot, kind='line', blarg=True)
  140. df = DataFrame(np.random.rand(10, 3),
  141. index=list(string.ascii_letters[:10]))
  142. _check_plot_works(df.plot, use_index=True)
  143. _check_plot_works(df.plot, sort_columns=False)
  144. _check_plot_works(df.plot, yticks=[1, 5, 10])
  145. _check_plot_works(df.plot, xticks=[1, 5, 10])
  146. _check_plot_works(df.plot, ylim=(-100, 100), xlim=(-100, 100))
  147. _check_plot_works(df.plot, subplots=True, title='blah')
  148. _check_plot_works(df.plot, title='blah')
  149. tuples = zip(list(string.ascii_letters[:10]), range(10))
  150. df = DataFrame(np.random.rand(10, 3),
  151. index=MultiIndex.from_tuples(tuples))
  152. _check_plot_works(df.plot, use_index=True)
  153. # unicode
  154. index = MultiIndex.from_tuples([(u'\u03b1', 0),
  155. (u'\u03b1', 1),
  156. (u'\u03b2', 2),
  157. (u'\u03b2', 3),
  158. (u'\u03b3', 4),
  159. (u'\u03b3', 5),
  160. (u'\u03b4', 6),
  161. (u'\u03b4', 7)], names=['i0', 'i1'])
  162. columns = MultiIndex.from_tuples([('bar', u'\u0394'),
  163. ('bar', u'\u0395')], names=['c0', 'c1'])
  164. df = DataFrame(np.random.randint(0, 10, (8, 2)),
  165. columns=columns,
  166. index=index)
  167. _check_plot_works(df.plot, title=u'\u03A3')
  168. @slow
  169. def test_plot_xy(self):
  170. df = tm.makeTimeDataFrame()
  171. self._check_data(df.plot(x=0, y=1),
  172. df.set_index('A').sort_index()['B'].plot())
  173. self._check_data(df.plot(x=0), df.set_index('A').sort_index().plot())
  174. self._check_data(df.plot(y=0), df.B.plot())
  175. self._check_data(df.plot(x='A', y='B'),
  176. df.set_index('A').sort_index().B.plot())
  177. self._check_data(df.plot(x='A'), df.set_index('A').sort_index().plot())
  178. self._check_data(df.plot(y='B'), df.B.plot())
  179. def _check_data(self, xp, rs):
  180. xp_lines = xp.get_lines()
  181. rs_lines = rs.get_lines()
  182. def check_line(xpl, rsl):
  183. xpdata = xpl.get_xydata()
  184. rsdata = rsl.get_xydata()
  185. assert_array_equal(xpdata, rsdata)
  186. [check_line(xpl, rsl) for xpl, rsl in zip(xp_lines, rs_lines)]
  187. @slow
  188. def test_subplots(self):
  189. df = DataFrame(np.random.rand(10, 3),
  190. index=list(string.ascii_letters[:10]))
  191. axes = df.plot(subplots=True, sharex=True, legend=True)
  192. for ax in axes:
  193. self.assert_(ax.get_legend() is not None)
  194. axes = df.plot(subplots=True, sharex=True)
  195. for ax in axes[:-2]:
  196. [self.assert_(not label.get_visible())
  197. for label in ax.get_xticklabels()]
  198. [self.assert_(label.get_visible())
  199. for label in ax.get_yticklabels()]
  200. [self.assert_(label.get_visible())
  201. for label in axes[-1].get_xticklabels()]
  202. [self.assert_(label.get_visible())
  203. for label in axes[-1].get_yticklabels()]
  204. axes = df.plot(subplots=True, sharex=False)
  205. for ax in axes:
  206. [self.assert_(label.get_visible())
  207. for label in ax.get_xticklabels()]
  208. [self.assert_(label.get_visible())
  209. for label in ax.get_yticklabels()]
  210. @slow
  211. def test_plot_bar(self):
  212. df = DataFrame(np.random.randn(6, 4),
  213. index=list(string.ascii_letters[:6]),
  214. columns=['one', 'two', 'three', 'four'])
  215. _check_plot_works(df.plot, kind='bar')
  216. _check_plot_works(df.plot, kind='bar', legend=False)
  217. _check_plot_works(df.plot, kind='bar', subplots=True)
  218. _check_plot_works(df.plot, kind='bar', stacked=True)
  219. df = DataFrame(np.random.randn(10, 15),
  220. index=list(string.ascii_letters[:10]),
  221. columns=range(15))
  222. _check_plot_works(df.plot, kind='bar')
  223. df = DataFrame({'a': [0, 1], 'b': [1, 0]})
  224. _check_plot_works(df.plot, kind='bar')
  225. @slow
  226. def test_boxplot(self):
  227. df = DataFrame(np.random.randn(6, 4),
  228. index=list(string.ascii_letters[:6]),
  229. columns=['one', 'two', 'three', 'four'])
  230. df['indic'] = ['foo', 'bar'] * 3
  231. df['indic2'] = ['foo', 'bar', 'foo'] * 2
  232. _check_plot_works(df.boxplot)
  233. _check_plot_works(df.boxplot, column=['one', 'two'])
  234. _check_plot_works(df.boxplot, column=['one', 'two'],
  235. by='indic')
  236. _check_plot_works(df.boxplot, column='one', by=['indic', 'indic2'])
  237. _check_plot_works(df.boxplot, by='indic')
  238. _check_plot_works(df.boxplot, by=['indic', 'indic2'])
  239. _check_plot_works(lambda x: plotting.boxplot(x), df['one'])
  240. _check_plot_works(df.boxplot, notch=1)
  241. _check_plot_works(df.boxplot, by='indic', notch=1)
  242. df = DataFrame(np.random.rand(10,2), columns=['Col1', 'Col2'] )
  243. df['X'] = Series(['A','A','A','A','A','B','B','B','B','B'])
  244. _check_plot_works(df.boxplot, by='X')
  245. @slow
  246. def test_kde(self):
  247. df = DataFrame(np.random.randn(100, 4))
  248. _check_plot_works(df.plot, kind='kde')
  249. _check_plot_works(df.plot, kind='kde', subplots=True)
  250. axes = df.plot(kind='kde', logy=True, subplots=True)
  251. for ax in axes:
  252. self.assert_(ax.get_yscale() == 'log')
  253. @slow
  254. def test_hist(self):
  255. df = DataFrame(np.random.randn(100, 4))
  256. _check_plot_works(df.hist)
  257. _check_plot_works(df.hist, grid=False)
  258. #make sure layout is handled
  259. df = DataFrame(np.random.randn(100, 3))
  260. _check_plot_works(df.hist)
  261. axes = df.hist(grid=False)
  262. self.assert_(not axes[1, 1].get_visible())
  263. df = DataFrame(np.random.randn(100, 1))
  264. _check_plot_works(df.hist)
  265. #make sure layout is handled
  266. df = DataFrame(np.random.randn(100, 6))
  267. _check_plot_works(df.hist)
  268. #make sure sharex, sharey is handled
  269. _check_plot_works(df.hist, sharex=True, sharey=True)
  270. #make sure kwargs are handled
  271. ser = df[0]
  272. xf, yf = 20, 20
  273. xrot, yrot = 30, 30
  274. ax = ser.hist(xlabelsize=xf, xrot=30, ylabelsize=yf, yrot=30)
  275. ytick = ax.get_yticklabels()[0]
  276. xtick = ax.get_xticklabels()[0]
  277. self.assertAlmostEqual(ytick.get_fontsize(), yf)
  278. self.assertAlmostEqual(ytick.get_rotation(), yrot)
  279. self.assertAlmostEqual(xtick.get_fontsize(), xf)
  280. self.assertAlmostEqual(xtick.get_rotation(), xrot)
  281. xf, yf = 20, 20
  282. xrot, yrot = 30, 30
  283. axes = df.hist(xlabelsize=xf, xrot=30, ylabelsize=yf, yrot=30)
  284. for i, ax in enumerate(axes.ravel()):
  285. if i < len(df.columns):
  286. ytick = ax.get_yticklabels()[0]
  287. xtick = ax.get_xticklabels()[0]
  288. self.assertAlmostEqual(ytick.get_fontsize(), yf)
  289. self.assertAlmostEqual(ytick.get_rotation(), yrot)
  290. self.assertAlmostEqual(xtick.get_fontsize(), xf)
  291. self.assertAlmostEqual(xtick.get_rotation(), xrot)
  292. @slow
  293. def test_scatter(self):
  294. df = DataFrame(np.random.randn(100, 4))
  295. import pandas.tools.plotting as plt
  296. def scat(**kwds):
  297. return plt.scatter_matrix(df, **kwds)
  298. _check_plot_works(scat)
  299. _check_plot_works(scat, marker='+')
  300. _check_plot_works(scat, vmin=0)
  301. _check_plot_works(scat, diagonal='kde')
  302. _check_plot_works(scat, diagonal='density')
  303. _check_plot_works(scat, diagonal='hist')
  304. def scat2(x, y, by=None, ax=None, figsize=None):
  305. return plt.scatter_plot(df, x, y, by, ax, figsize=None)
  306. _check_plot_works(scat2, 0, 1)
  307. grouper = Series(np.repeat([1, 2, 3, 4, 5], 20), df.index)
  308. _check_plot_works(scat2, 0, 1, by=grouper)
  309. @slow
  310. def test_andrews_curves(self):
  311. from pandas import read_csv
  312. from pandas.tools.plotting import andrews_curves
  313. path = os.path.join(curpath(), 'data/iris.csv')
  314. df = read_csv(path)
  315. _check_plot_works(andrews_curves, df, 'Name')
  316. @slow
  317. def test_parallel_coordinates(self):
  318. from pandas import read_csv
  319. from pandas.tools.plotting import parallel_coordinates
  320. path = os.path.join(curpath(), 'data/iris.csv')
  321. df = read_csv(path)
  322. _check_plot_works(parallel_coordinates, df, 'Name')
  323. @slow
  324. def test_radviz(self):
  325. from pandas import read_csv
  326. from pandas.tools.plotting import radviz
  327. path = os.path.join(curpath(), 'data/iris.csv')
  328. df = read_csv(path)
  329. _check_plot_works(radviz, df, 'Name')
  330. @slow
  331. def test_plot_int_columns(self):
  332. df = DataFrame(np.random.randn(100, 4)).cumsum()
  333. _check_plot_works(df.plot, legend=True)
  334. @slow
  335. def test_legend_name(self):
  336. multi = DataFrame(np.random.randn(4, 4),
  337. columns=[np.array(['a', 'a', 'b', 'b']),
  338. np.array(['x', 'y', 'x', 'y'])])
  339. multi.columns.names = ['group', 'individual']
  340. ax = multi.plot()
  341. leg_title = ax.legend_.get_title()
  342. self.assert_(leg_title.get_text(), 'group,individual')
  343. def _check_plot_fails(self, f, *args, **kwargs):
  344. self.assertRaises(Exception, f, *args, **kwargs)
  345. @slow
  346. def test_style_by_column(self):
  347. import matplotlib.pyplot as plt
  348. fig = plt.gcf()
  349. df = DataFrame(np.random.randn(100, 3))
  350. for markers in [{0: '^', 1: '+', 2: 'o'},
  351. {0: '^', 1: '+'},
  352. ['^', '+', 'o'],
  353. ['^', '+']]:
  354. fig.clf()
  355. fig.add_subplot(111)
  356. ax = df.plot(style=markers)
  357. for i, l in enumerate(ax.get_lines()[:len(markers)]):
  358. self.assertEqual(l.get_marker(), markers[i])
  359. class TestDataFrameGroupByPlots(unittest.TestCase):
  360. @classmethod
  361. def setUpClass(cls):
  362. import sys
  363. if 'IPython' in sys.modules:
  364. raise nose.SkipTest
  365. try:
  366. import matplotlib as mpl
  367. mpl.use('Agg', warn=False)
  368. except ImportError:
  369. raise nose.SkipTest
  370. @slow
  371. def test_boxplot(self):
  372. df = DataFrame(np.random.rand(10,2), columns=['Col1', 'Col2'] )
  373. df['X'] = Series(['A','A','A','A','A','B','B','B','B','B'])
  374. grouped = df.groupby(by='X')
  375. _check_plot_works(grouped.boxplot)
  376. _check_plot_works(grouped.boxplot, subplots=False)
  377. tuples = zip(list(string.ascii_letters[:10]), range(10))
  378. df = DataFrame(np.random.rand(10, 3),
  379. index=MultiIndex.from_tuples(tuples))
  380. grouped = df.groupby(level=1)
  381. _check_plot_works(grouped.boxplot)
  382. _check_plot_works(grouped.boxplot, subplots=False)
  383. grouped = df.unstack(level=1).groupby(level=0, axis=1)
  384. _check_plot_works(grouped.boxplot)
  385. _check_plot_works(grouped.boxplot, subplots=False)
  386. PNG_PATH = 'tmp.png'
  387. def _check_plot_works(f, *args, **kwargs):
  388. import matplotlib.pyplot as plt
  389. fig = plt.gcf()
  390. plt.clf()
  391. ax = fig.add_subplot(211)
  392. ret = f(*args, **kwargs)
  393. assert(ret is not None) # do something more intelligent
  394. ax = fig.add_subplot(212)
  395. try:
  396. kwargs['ax'] = ax
  397. ret = f(*args, **kwargs)
  398. assert(ret is not None) # do something more intelligent
  399. except Exception:
  400. pass
  401. plt.savefig(PNG_PATH)
  402. os.remove(PNG_PATH)
  403. def curpath():
  404. pth, _ = os.path.split(os.path.abspath(__file__))
  405. return pth
  406. if __name__ == '__main__':
  407. nose.runmodule(argv=[__file__,'-vvs','-x','--pdb', '--pdb-failure'],
  408. exit=False)