PageRenderTime 153ms CodeModel.GetById 20ms RepoModel.GetById 0ms app.codeStats 1ms

/pandas/tests/test_graphics.py

https://github.com/kljensen/pandas
Python | 519 lines | 412 code | 94 blank | 13 comment | 34 complexity | b413a1f1dea57735842fa1a36d4fe217 MD5 | raw file
Possible License(s): BSD-3-Clause
  1. import nose
  2. import os
  3. import string
  4. import unittest
  5. from datetime import datetime
  6. from pandas import Series, DataFrame, MultiIndex, PeriodIndex, date_range
  7. import pandas.util.testing as tm
  8. import numpy as np
  9. from numpy.testing import assert_array_equal
  10. from numpy.testing.decorators import slow
  11. import pandas.tools.plotting as plotting
  12. class TestSeriesPlots(unittest.TestCase):
  13. @classmethod
  14. def setUpClass(cls):
  15. import sys
  16. if 'IPython' in sys.modules:
  17. raise nose.SkipTest
  18. try:
  19. import matplotlib as mpl
  20. mpl.use('Agg', warn=False)
  21. except ImportError:
  22. raise nose.SkipTest
  23. def setUp(self):
  24. self.ts = tm.makeTimeSeries()
  25. self.ts.name = 'ts'
  26. self.series = tm.makeStringSeries()
  27. self.series.name = 'series'
  28. self.iseries = tm.makePeriodSeries()
  29. self.iseries.name = 'iseries'
  30. @slow
  31. def test_plot(self):
  32. _check_plot_works(self.ts.plot, label='foo')
  33. _check_plot_works(self.ts.plot, use_index=False)
  34. _check_plot_works(self.ts.plot, rot=0)
  35. _check_plot_works(self.ts.plot, style='.', logy=True)
  36. _check_plot_works(self.ts.plot, style='.', logx=True)
  37. _check_plot_works(self.ts.plot, style='.', loglog=True)
  38. _check_plot_works(self.ts[:10].plot, kind='bar')
  39. _check_plot_works(self.series[:5].plot, kind='bar')
  40. _check_plot_works(self.series[:5].plot, kind='line')
  41. _check_plot_works(self.series[:5].plot, kind='barh')
  42. _check_plot_works(self.series[:10].plot, kind='barh')
  43. Series(np.random.randn(10)).plot(kind='bar',color='black')
  44. @slow
  45. def test_bar_colors(self):
  46. import matplotlib.pyplot as plt
  47. import matplotlib.colors as colors
  48. default_colors = 'brgyk'
  49. custom_colors = 'rgcby'
  50. plt.close('all')
  51. df = DataFrame(np.random.randn(5, 5))
  52. ax = df.plot(kind='bar')
  53. rects = ax.patches
  54. conv = colors.colorConverter
  55. for i, rect in enumerate(rects[::5]):
  56. xp = conv.to_rgba(default_colors[i])
  57. rs = rect.get_facecolor()
  58. self.assert_(xp == rs)
  59. plt.close('all')
  60. ax = df.plot(kind='bar', color=custom_colors)
  61. rects = ax.patches
  62. conv = colors.colorConverter
  63. for i, rect in enumerate(rects[::5]):
  64. xp = conv.to_rgba(custom_colors[i])
  65. rs = rect.get_facecolor()
  66. self.assert_(xp == rs)
  67. @slow
  68. def test_bar_linewidth(self):
  69. df = DataFrame(np.random.randn(5, 5))
  70. # regular
  71. ax = df.plot(kind='bar', linewidth=2)
  72. for r in ax.patches:
  73. self.assert_(r.get_linewidth() == 2)
  74. # stacked
  75. ax = df.plot(kind='bar', stacked=True, linewidth=2)
  76. for r in ax.patches:
  77. self.assert_(r.get_linewidth() == 2)
  78. # subplots
  79. axes = df.plot(kind='bar', linewidth=2, subplots=True)
  80. for ax in axes:
  81. for r in ax.patches:
  82. self.assert_(r.get_linewidth() == 2)
  83. @slow
  84. def test_1rotation(self):
  85. df = DataFrame(np.random.randn(5, 5))
  86. ax = df.plot(rot=30)
  87. for l in ax.get_xticklabels():
  88. self.assert_(l.get_rotation() == 30)
  89. @slow
  90. def test_irregular_datetime(self):
  91. rng = date_range('1/1/2000', '3/1/2000')
  92. rng = rng[[0,1,2,3,5,9,10,11,12]]
  93. ser = Series(np.random.randn(len(rng)), rng)
  94. ax = ser.plot()
  95. xp = datetime(1999, 1, 1).toordinal()
  96. ax.set_xlim('1/1/1999', '1/1/2001')
  97. self.assert_(xp == ax.get_xlim()[0])
  98. @slow
  99. def test_hist(self):
  100. _check_plot_works(self.ts.hist)
  101. _check_plot_works(self.ts.hist, grid=False)
  102. @slow
  103. def test_kde(self):
  104. _check_plot_works(self.ts.plot, kind='kde')
  105. _check_plot_works(self.ts.plot, kind='density')
  106. ax = self.ts.plot(kind='kde', logy=True)
  107. self.assert_(ax.get_yscale() == 'log')
  108. @slow
  109. def test_autocorrelation_plot(self):
  110. from pandas.tools.plotting import autocorrelation_plot
  111. _check_plot_works(autocorrelation_plot, self.ts)
  112. _check_plot_works(autocorrelation_plot, self.ts.values)
  113. @slow
  114. def test_lag_plot(self):
  115. from pandas.tools.plotting import lag_plot
  116. _check_plot_works(lag_plot, self.ts)
  117. @slow
  118. def test_bootstrap_plot(self):
  119. from pandas.tools.plotting import bootstrap_plot
  120. _check_plot_works(bootstrap_plot, self.ts, size=10)
  121. class TestDataFramePlots(unittest.TestCase):
  122. @classmethod
  123. def setUpClass(cls):
  124. import sys
  125. if 'IPython' in sys.modules:
  126. raise nose.SkipTest
  127. try:
  128. import matplotlib as mpl
  129. mpl.use('Agg', warn=False)
  130. except ImportError:
  131. raise nose.SkipTest
  132. @slow
  133. def test_plot(self):
  134. df = tm.makeTimeDataFrame()
  135. _check_plot_works(df.plot, grid=False)
  136. _check_plot_works(df.plot, subplots=True)
  137. _check_plot_works(df.plot, subplots=True, use_index=False)
  138. df = DataFrame({'x':[1,2], 'y':[3,4]})
  139. self._check_plot_fails(df.plot, kind='line', blarg=True)
  140. df = DataFrame(np.random.rand(10, 3),
  141. index=list(string.ascii_letters[:10]))
  142. _check_plot_works(df.plot, use_index=True)
  143. _check_plot_works(df.plot, sort_columns=False)
  144. _check_plot_works(df.plot, yticks=[1, 5, 10])
  145. _check_plot_works(df.plot, xticks=[1, 5, 10])
  146. _check_plot_works(df.plot, ylim=(-100, 100), xlim=(-100, 100))
  147. _check_plot_works(df.plot, subplots=True, title='blah')
  148. _check_plot_works(df.plot, title='blah')
  149. tuples = zip(list(string.ascii_letters[:10]), range(10))
  150. df = DataFrame(np.random.rand(10, 3),
  151. index=MultiIndex.from_tuples(tuples))
  152. _check_plot_works(df.plot, use_index=True)
  153. # unicode
  154. index = MultiIndex.from_tuples([(u'\u03b1', 0),
  155. (u'\u03b1', 1),
  156. (u'\u03b2', 2),
  157. (u'\u03b2', 3),
  158. (u'\u03b3', 4),
  159. (u'\u03b3', 5),
  160. (u'\u03b4', 6),
  161. (u'\u03b4', 7)], names=['i0', 'i1'])
  162. columns = MultiIndex.from_tuples([('bar', u'\u0394'),
  163. ('bar', u'\u0395')], names=['c0', 'c1'])
  164. df = DataFrame(np.random.randint(0, 10, (8, 2)),
  165. columns=columns,
  166. index=index)
  167. _check_plot_works(df.plot, title=u'\u03A3')
  168. @slow
  169. def test_plot_xy(self):
  170. # columns.inferred_type == 'string'
  171. df = tm.makeTimeDataFrame()
  172. self._check_data(df.plot(x=0, y=1),
  173. df.set_index('A')['B'].plot())
  174. self._check_data(df.plot(x=0), df.set_index('A').plot())
  175. self._check_data(df.plot(y=0), df.B.plot())
  176. self._check_data(df.plot(x='A', y='B'),
  177. df.set_index('A').B.plot())
  178. self._check_data(df.plot(x='A'), df.set_index('A').plot())
  179. self._check_data(df.plot(y='B'), df.B.plot())
  180. # columns.inferred_type == 'integer'
  181. df.columns = range(1, len(df.columns) + 1)
  182. self._check_data(df.plot(x=1, y=2),
  183. df.set_index(1)[2].plot())
  184. self._check_data(df.plot(x=1), df.set_index(1).plot())
  185. self._check_data(df.plot(y=1), df[1].plot())
  186. # columns.inferred_type == 'mixed'
  187. # TODO add MultiIndex test
  188. def _check_data(self, xp, rs):
  189. xp_lines = xp.get_lines()
  190. rs_lines = rs.get_lines()
  191. def check_line(xpl, rsl):
  192. xpdata = xpl.get_xydata()
  193. rsdata = rsl.get_xydata()
  194. assert_array_equal(xpdata, rsdata)
  195. [check_line(xpl, rsl) for xpl, rsl in zip(xp_lines, rs_lines)]
  196. @slow
  197. def test_subplots(self):
  198. df = DataFrame(np.random.rand(10, 3),
  199. index=list(string.ascii_letters[:10]))
  200. axes = df.plot(subplots=True, sharex=True, legend=True)
  201. for ax in axes:
  202. self.assert_(ax.get_legend() is not None)
  203. axes = df.plot(subplots=True, sharex=True)
  204. for ax in axes[:-2]:
  205. [self.assert_(not label.get_visible())
  206. for label in ax.get_xticklabels()]
  207. [self.assert_(label.get_visible())
  208. for label in ax.get_yticklabels()]
  209. [self.assert_(label.get_visible())
  210. for label in axes[-1].get_xticklabels()]
  211. [self.assert_(label.get_visible())
  212. for label in axes[-1].get_yticklabels()]
  213. axes = df.plot(subplots=True, sharex=False)
  214. for ax in axes:
  215. [self.assert_(label.get_visible())
  216. for label in ax.get_xticklabels()]
  217. [self.assert_(label.get_visible())
  218. for label in ax.get_yticklabels()]
  219. @slow
  220. def test_plot_bar(self):
  221. df = DataFrame(np.random.randn(6, 4),
  222. index=list(string.ascii_letters[:6]),
  223. columns=['one', 'two', 'three', 'four'])
  224. _check_plot_works(df.plot, kind='bar')
  225. _check_plot_works(df.plot, kind='bar', legend=False)
  226. _check_plot_works(df.plot, kind='bar', subplots=True)
  227. _check_plot_works(df.plot, kind='bar', stacked=True)
  228. df = DataFrame(np.random.randn(10, 15),
  229. index=list(string.ascii_letters[:10]),
  230. columns=range(15))
  231. _check_plot_works(df.plot, kind='bar')
  232. df = DataFrame({'a': [0, 1], 'b': [1, 0]})
  233. _check_plot_works(df.plot, kind='bar')
  234. @slow
  235. def test_boxplot(self):
  236. df = DataFrame(np.random.randn(6, 4),
  237. index=list(string.ascii_letters[:6]),
  238. columns=['one', 'two', 'three', 'four'])
  239. df['indic'] = ['foo', 'bar'] * 3
  240. df['indic2'] = ['foo', 'bar', 'foo'] * 2
  241. _check_plot_works(df.boxplot)
  242. _check_plot_works(df.boxplot, column=['one', 'two'])
  243. _check_plot_works(df.boxplot, column=['one', 'two'],
  244. by='indic')
  245. _check_plot_works(df.boxplot, column='one', by=['indic', 'indic2'])
  246. _check_plot_works(df.boxplot, by='indic')
  247. _check_plot_works(df.boxplot, by=['indic', 'indic2'])
  248. _check_plot_works(lambda x: plotting.boxplot(x), df['one'])
  249. _check_plot_works(df.boxplot, notch=1)
  250. _check_plot_works(df.boxplot, by='indic', notch=1)
  251. df = DataFrame(np.random.rand(10,2), columns=['Col1', 'Col2'] )
  252. df['X'] = Series(['A','A','A','A','A','B','B','B','B','B'])
  253. _check_plot_works(df.boxplot, by='X')
  254. @slow
  255. def test_kde(self):
  256. df = DataFrame(np.random.randn(100, 4))
  257. _check_plot_works(df.plot, kind='kde')
  258. _check_plot_works(df.plot, kind='kde', subplots=True)
  259. axes = df.plot(kind='kde', logy=True, subplots=True)
  260. for ax in axes:
  261. self.assert_(ax.get_yscale() == 'log')
  262. @slow
  263. def test_hist(self):
  264. df = DataFrame(np.random.randn(100, 4))
  265. _check_plot_works(df.hist)
  266. _check_plot_works(df.hist, grid=False)
  267. #make sure layout is handled
  268. df = DataFrame(np.random.randn(100, 3))
  269. _check_plot_works(df.hist)
  270. axes = df.hist(grid=False)
  271. self.assert_(not axes[1, 1].get_visible())
  272. df = DataFrame(np.random.randn(100, 1))
  273. _check_plot_works(df.hist)
  274. #make sure layout is handled
  275. df = DataFrame(np.random.randn(100, 6))
  276. _check_plot_works(df.hist)
  277. #make sure sharex, sharey is handled
  278. _check_plot_works(df.hist, sharex=True, sharey=True)
  279. #make sure kwargs are handled
  280. ser = df[0]
  281. xf, yf = 20, 20
  282. xrot, yrot = 30, 30
  283. ax = ser.hist(xlabelsize=xf, xrot=30, ylabelsize=yf, yrot=30)
  284. ytick = ax.get_yticklabels()[0]
  285. xtick = ax.get_xticklabels()[0]
  286. self.assertAlmostEqual(ytick.get_fontsize(), yf)
  287. self.assertAlmostEqual(ytick.get_rotation(), yrot)
  288. self.assertAlmostEqual(xtick.get_fontsize(), xf)
  289. self.assertAlmostEqual(xtick.get_rotation(), xrot)
  290. xf, yf = 20, 20
  291. xrot, yrot = 30, 30
  292. axes = df.hist(xlabelsize=xf, xrot=30, ylabelsize=yf, yrot=30)
  293. for i, ax in enumerate(axes.ravel()):
  294. if i < len(df.columns):
  295. ytick = ax.get_yticklabels()[0]
  296. xtick = ax.get_xticklabels()[0]
  297. self.assertAlmostEqual(ytick.get_fontsize(), yf)
  298. self.assertAlmostEqual(ytick.get_rotation(), yrot)
  299. self.assertAlmostEqual(xtick.get_fontsize(), xf)
  300. self.assertAlmostEqual(xtick.get_rotation(), xrot)
  301. @slow
  302. def test_scatter(self):
  303. df = DataFrame(np.random.randn(100, 4))
  304. import pandas.tools.plotting as plt
  305. def scat(**kwds):
  306. return plt.scatter_matrix(df, **kwds)
  307. _check_plot_works(scat)
  308. _check_plot_works(scat, marker='+')
  309. _check_plot_works(scat, vmin=0)
  310. _check_plot_works(scat, diagonal='kde')
  311. _check_plot_works(scat, diagonal='density')
  312. _check_plot_works(scat, diagonal='hist')
  313. def scat2(x, y, by=None, ax=None, figsize=None):
  314. return plt.scatter_plot(df, x, y, by, ax, figsize=None)
  315. _check_plot_works(scat2, 0, 1)
  316. grouper = Series(np.repeat([1, 2, 3, 4, 5], 20), df.index)
  317. _check_plot_works(scat2, 0, 1, by=grouper)
  318. @slow
  319. def test_andrews_curves(self):
  320. from pandas import read_csv
  321. from pandas.tools.plotting import andrews_curves
  322. path = os.path.join(curpath(), 'data/iris.csv')
  323. df = read_csv(path)
  324. _check_plot_works(andrews_curves, df, 'Name')
  325. @slow
  326. def test_parallel_coordinates(self):
  327. from pandas import read_csv
  328. from pandas.tools.plotting import parallel_coordinates
  329. path = os.path.join(curpath(), 'data/iris.csv')
  330. df = read_csv(path)
  331. _check_plot_works(parallel_coordinates, df, 'Name')
  332. @slow
  333. def test_radviz(self):
  334. from pandas import read_csv
  335. from pandas.tools.plotting import radviz
  336. path = os.path.join(curpath(), 'data/iris.csv')
  337. df = read_csv(path)
  338. _check_plot_works(radviz, df, 'Name')
  339. @slow
  340. def test_plot_int_columns(self):
  341. df = DataFrame(np.random.randn(100, 4)).cumsum()
  342. _check_plot_works(df.plot, legend=True)
  343. @slow
  344. def test_legend_name(self):
  345. multi = DataFrame(np.random.randn(4, 4),
  346. columns=[np.array(['a', 'a', 'b', 'b']),
  347. np.array(['x', 'y', 'x', 'y'])])
  348. multi.columns.names = ['group', 'individual']
  349. ax = multi.plot()
  350. leg_title = ax.legend_.get_title()
  351. self.assert_(leg_title.get_text(), 'group,individual')
  352. def _check_plot_fails(self, f, *args, **kwargs):
  353. self.assertRaises(Exception, f, *args, **kwargs)
  354. @slow
  355. def test_style_by_column(self):
  356. import matplotlib.pyplot as plt
  357. fig = plt.gcf()
  358. df = DataFrame(np.random.randn(100, 3))
  359. for markers in [{0: '^', 1: '+', 2: 'o'},
  360. {0: '^', 1: '+'},
  361. ['^', '+', 'o'],
  362. ['^', '+']]:
  363. fig.clf()
  364. fig.add_subplot(111)
  365. ax = df.plot(style=markers)
  366. for i, l in enumerate(ax.get_lines()[:len(markers)]):
  367. self.assertEqual(l.get_marker(), markers[i])
  368. class TestDataFrameGroupByPlots(unittest.TestCase):
  369. @classmethod
  370. def setUpClass(cls):
  371. import sys
  372. if 'IPython' in sys.modules:
  373. raise nose.SkipTest
  374. try:
  375. import matplotlib as mpl
  376. mpl.use('Agg', warn=False)
  377. except ImportError:
  378. raise nose.SkipTest
  379. @slow
  380. def test_boxplot(self):
  381. df = DataFrame(np.random.rand(10,2), columns=['Col1', 'Col2'] )
  382. df['X'] = Series(['A','A','A','A','A','B','B','B','B','B'])
  383. grouped = df.groupby(by='X')
  384. _check_plot_works(grouped.boxplot)
  385. _check_plot_works(grouped.boxplot, subplots=False)
  386. tuples = zip(list(string.ascii_letters[:10]), range(10))
  387. df = DataFrame(np.random.rand(10, 3),
  388. index=MultiIndex.from_tuples(tuples))
  389. grouped = df.groupby(level=1)
  390. _check_plot_works(grouped.boxplot)
  391. _check_plot_works(grouped.boxplot, subplots=False)
  392. grouped = df.unstack(level=1).groupby(level=0, axis=1)
  393. _check_plot_works(grouped.boxplot)
  394. _check_plot_works(grouped.boxplot, subplots=False)
  395. @slow
  396. def test_series_plot_color_kwargs(self):
  397. # #1890
  398. import matplotlib.pyplot as plt
  399. plt.close('all')
  400. ax = Series(np.arange(12) + 1).plot(color='green')
  401. line = ax.get_lines()[0]
  402. self.assert_(line.get_color() == 'green')
  403. PNG_PATH = 'tmp.png'
  404. def _check_plot_works(f, *args, **kwargs):
  405. import matplotlib.pyplot as plt
  406. fig = plt.gcf()
  407. plt.clf()
  408. ax = fig.add_subplot(211)
  409. ret = f(*args, **kwargs)
  410. assert(ret is not None) # do something more intelligent
  411. ax = fig.add_subplot(212)
  412. try:
  413. kwargs['ax'] = ax
  414. ret = f(*args, **kwargs)
  415. assert(ret is not None) # do something more intelligent
  416. except Exception:
  417. pass
  418. plt.savefig(PNG_PATH)
  419. os.remove(PNG_PATH)
  420. def curpath():
  421. pth, _ = os.path.split(os.path.abspath(__file__))
  422. return pth
  423. if __name__ == '__main__':
  424. nose.runmodule(argv=[__file__,'-vvs','-x','--pdb', '--pdb-failure'],
  425. exit=False)