PageRenderTime 38ms CodeModel.GetById 26ms RepoModel.GetById 0ms app.codeStats 0ms

/notify_user/pymodules/python2.7/lib/python/statsmodels-0.5.0-py2.7-linux-x86_64.egg/statsmodels/graphics/factorplots.py

https://gitlab.com/pooja043/Globus_Docker_4
Python | 211 lines | 180 code | 3 blank | 28 comment | 2 complexity | 0bf42b510785d308dfc80e99ee1abd42 MD5 | raw file
  1. # -*- coding: utf-8 -*-
  2. """
  3. Authors: Josef Perktold, Skipper Seabold, Denis A. Engemann
  4. """
  5. import numpy as np
  6. from statsmodels.graphics.plottools import rainbow
  7. import utils
  8. def interaction_plot(x, trace, response, func=np.mean, ax=None, plottype='b',
  9. xlabel=None, ylabel=None, colors=[], markers=[],
  10. linestyles=[], legendloc='best', legendtitle=None,
  11. **kwargs):
  12. """
  13. Interaction plot for factor level statistics.
  14. Note. If categorial factors are supplied levels will be internally
  15. recoded to integers. This ensures matplotlib compatiblity.
  16. uses pandas.DataFrame to calculate an `aggregate` statistic for each
  17. level of the factor or group given by `trace`.
  18. Parameters
  19. ----------
  20. x : array-like
  21. The `x` factor levels constitute the x-axis. If a `pandas.Series` is
  22. given its name will be used in `xlabel` if `xlabel` is None.
  23. trace : array-like
  24. The `trace` factor levels will be drawn as lines in the plot.
  25. If `trace` is a `pandas.Series` its name will be used as the
  26. `legendtitle` if `legendtitle` is None.
  27. response : array-like
  28. The reponse or dependent variable. If a `pandas.Series` is given
  29. its name will be used in `ylabel` if `ylabel` is None.
  30. func : function
  31. Anything accepted by `pandas.DataFrame.aggregate`. This is applied to
  32. the response variable grouped by the trace levels.
  33. plottype : str {'line', 'scatter', 'both'}, optional
  34. The type of plot to return. Can be 'l', 's', or 'b'
  35. ax : axes, optional
  36. Matplotlib axes instance
  37. xlabel : str, optional
  38. Label to use for `x`. Default is 'X'. If `x` is a `pandas.Series` it
  39. will use the series names.
  40. ylabel : str, optional
  41. Label to use for `response`. Default is 'func of response'. If
  42. `response` is a `pandas.Series` it will use the series names.
  43. colors : list, optional
  44. If given, must have length == number of levels in trace.
  45. linestyles : list, optional
  46. If given, must have length == number of levels in trace.
  47. markers : list, optional
  48. If given, must have length == number of lovels in trace
  49. kwargs
  50. These will be passed to the plot command used either plot or scatter.
  51. If you want to control the overall plotting options, use kwargs.
  52. Returns
  53. -------
  54. fig : Figure
  55. The figure given by `ax.figure` or a new instance.
  56. Examples
  57. --------
  58. >>> import numpy as np
  59. >>> np.random.seed(12345)
  60. >>> weight = np.random.randint(1,4,size=60)
  61. >>> duration = np.random.randint(1,3,size=60)
  62. >>> days = np.log(np.random.randint(1,30, size=60))
  63. >>> fig = interaction_plot(weight, duration, days,
  64. ... colors=['red','blue'], markers=['D','^'], ms=10)
  65. >>> import matplotlib.pyplot as plt
  66. >>> plt.show()
  67. .. plot::
  68. import numpy as np
  69. from statsmodels.graphics.factorplots import interaction_plot
  70. np.random.seed(12345)
  71. weight = np.random.randint(1,4,size=60)
  72. duration = np.random.randint(1,3,size=60)
  73. days = np.log(np.random.randint(1,30, size=60))
  74. fig = interaction_plot(weight, duration, days,
  75. colors=['red','blue'], markers=['D','^'], ms=10)
  76. import matplotlib.pyplot as plt
  77. #plt.show()
  78. """
  79. from pandas import DataFrame
  80. fig, ax = utils.create_mpl_ax(ax)
  81. response_name = ylabel or getattr(response, 'name', 'response')
  82. ylabel = '%s of %s' % (func.func_name, response_name)
  83. xlabel = xlabel or getattr(x, 'name', 'X')
  84. legendtitle = legendtitle or getattr(trace, 'name', 'Trace')
  85. ax.set_ylabel(ylabel)
  86. ax.set_xlabel(xlabel)
  87. x_values = x_levels = None
  88. if isinstance(x[0], str):
  89. x_levels = [l for l in np.unique(x)]
  90. x_values = xrange(len(x_levels))
  91. x = _recode(x, dict(zip(x_levels, x_values)))
  92. data = DataFrame(dict(x=x, trace=trace, response=response))
  93. plot_data = data.groupby(['trace', 'x']).aggregate(func).reset_index()
  94. # return data
  95. # check plot args
  96. n_trace = len(plot_data['trace'].unique())
  97. if linestyles:
  98. try:
  99. assert len(linestyles) == n_trace
  100. except AssertionError, err:
  101. raise ValueError("Must be a linestyle for each trace level")
  102. else: # set a default
  103. linestyles = ['-'] * n_trace
  104. if markers:
  105. try:
  106. assert len(markers) == n_trace
  107. except AssertionError, err:
  108. raise ValueError("Must be a linestyle for each trace level")
  109. else: # set a default
  110. markers = ['.'] * n_trace
  111. if colors:
  112. try:
  113. assert len(colors) == n_trace
  114. except AssertionError, err:
  115. raise ValueError("Must be a linestyle for each trace level")
  116. else: # set a default
  117. #TODO: how to get n_trace different colors?
  118. colors = rainbow(n_trace)
  119. if plottype == 'both' or plottype == 'b':
  120. for i, (values, group) in enumerate(plot_data.groupby(['trace'])):
  121. # trace label
  122. label = str(group['trace'].values[0])
  123. ax.plot(group['x'], group['response'], color=colors[i],
  124. marker=markers[i], label=label,
  125. linestyle=linestyles[i], **kwargs)
  126. elif plottype == 'line' or plottype == 'l':
  127. for i, (values, group) in enumerate(plot_data.groupby(['trace'])):
  128. # trace label
  129. label = str(group['trace'].values[0])
  130. ax.plot(group['x'], group['response'], color=colors[i],
  131. label=label, linestyle=linestyles[i], **kwargs)
  132. elif plottype == 'scatter' or plottype == 's':
  133. for i, (values, group) in enumerate(plot_data.groupby(['trace'])):
  134. # trace label
  135. label = str(group['trace'].values[0])
  136. ax.scatter(group['x'], group['response'], color=colors[i],
  137. label=label, marker=markers[i], **kwargs)
  138. else:
  139. raise ValueError("Plot type %s not understood" % plottype)
  140. ax.legend(loc=legendloc, title=legendtitle)
  141. ax.margins(.1)
  142. if all([x_levels, x_values]):
  143. ax.set_xticks(x_values)
  144. ax.set_xticklabels(x_levels)
  145. return fig
  146. def _recode(x, levels):
  147. """ Recode categorial data to int factor.
  148. Parameters
  149. ----------
  150. x : array-like
  151. array like object supporting with numpy array methods of categorially
  152. coded data.
  153. levels : dict
  154. mapping of labels to integer-codings
  155. Returns
  156. -------
  157. out : instance numpy.ndarray
  158. """
  159. from pandas import Series
  160. name = None
  161. if isinstance(x, Series):
  162. name = x.name
  163. x = x.values
  164. if x.dtype.type not in [np.str_, np.object_]:
  165. raise ValueError('This is not a categorial factor.'
  166. ' Array of str type required.')
  167. elif not isinstance(levels, dict):
  168. raise ValueError('This is not a valid value for levels.'
  169. ' Dict required.')
  170. elif not (np.unique(x) == np.unique(levels.keys())).all():
  171. raise ValueError('The levels do not match the array values.')
  172. else:
  173. out = np.empty(x.shape[0], dtype=np.int)
  174. for level, coding in levels.items():
  175. out[x == level] = coding
  176. if name:
  177. out = Series(out)
  178. out.name = name
  179. return out