
/pymc/diagnostics.py

http://github.com/jseabold/pymc
Possible License(s): Apache-2.0
  1. """Convergence diagnostics and model validation"""
  2. import numpy as np
  3. from .stats import autocorr, autocov, statfunc
  4. from copy import copy
  5. __all__ = ['geweke', 'gelman_rubin', 'trace_to_dataframe']
  6. @statfunc
  7. def geweke(x, first=.1, last=.5, intervals=20):
  8. """Return z-scores for convergence diagnostics.
  9. Compare the mean of the first % of series with the mean of the last % of
  10. series. x is divided into a number of segments for which this difference is
  11. computed. If the series is converged, this score should oscillate between
  12. -1 and 1.
  13. Parameters
  14. ----------
  15. x : array-like
  16. The trace of some stochastic parameter.
  17. first : float
  18. The fraction of series at the beginning of the trace.
  19. last : float
  20. The fraction of series at the end to be compared with the section
  21. at the beginning.
  22. intervals : int
  23. The number of segments.
  24. Returns
  25. -------
  26. scores : list [[]]
  27. Return a list of [i, score], where i is the starting index for each
  28. interval and score the Geweke score on the interval.
  29. Notes
  30. -----
  31. The Geweke score on some series x is computed by:
  32. .. math:: \frac{E[x_s] - E[x_e]}{\sqrt{V[x_s] + V[x_e]}}
  33. where :math:`E` stands for the mean, :math:`V` the variance,
  34. :math:`x_s` a section at the start of the series and
  35. :math:`x_e` a section at the end of the series.
  36. References
  37. ----------
  38. Geweke (1992)
  39. """
  40. if np.rank(x) > 1:
  41. return [geweke(y, first, last, intervals) for y in np.transpose(x)]
  42. # Filter out invalid intervals
  43. if first + last >= 1:
  44. raise ValueError(
  45. "Invalid intervals for Geweke convergence analysis",
  46. (first,
  47. last))
  48. # Initialize list of z-scores
  49. zscores = []
  50. # Last index value
  51. end = len(x) - 1
  52. # Calculate starting indices
  53. sindices = np.arange(0, end // 2, step=int((end / 2) / (intervals - 1)))
  54. # Loop over start indices
  55. for start in sindices:
  56. # Calculate slices
  57. first_slice = x[start: start + int(first * (end - start))]
  58. last_slice = x[int(end - last * (end - start)):]
  59. z = (first_slice.mean() - last_slice.mean())
  60. z /= np.sqrt(first_slice.std() ** 2 + last_slice.std() ** 2)
  61. zscores.append([start, z])
  62. if intervals is None:
  63. return np.array(zscores[0])
  64. else:
  65. return np.array(zscores)
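

# A minimal usage sketch (not part of the original file). It assumes the
# @statfunc wrapper passes a plain NumPy array through unchanged; the names
# `chain` and `scores` are made up for illustration.
#
#     chain = np.random.normal(size=2000)   # stand-in for a single MCMC trace
#     scores = geweke(chain, first=0.1, last=0.5, intervals=20)
#     # scores[:, 0] holds the segment start indices, scores[:, 1] the z-scores;
#     # for a converged chain the z-scores oscillate within roughly +/- 1.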


def gelman_rubin(mtrace):
    R"""Returns estimate of R for a set of traces.

    The Gelman-Rubin diagnostic tests for lack of convergence by comparing
    the variance between multiple chains to the variance within each chain.
    If convergence has been achieved, the between-chain and within-chain
    variances should be identical. To be most effective in detecting evidence
    for nonconvergence, each chain should have been initialized to starting
    values that are dispersed relative to the target distribution.

    Parameters
    ----------
    mtrace : MultiTrace
      A MultiTrace object containing parallel traces (minimum 2)
      of one or more stochastic parameters.

    Returns
    -------
    Rhat : dict
      Dictionary of the potential scale reduction factors, :math:`\hat{R}`.

    Notes
    -----
    The diagnostic is computed by:

      .. math:: \hat{R} = \sqrt{\frac{\hat{V}}{W}}

    where :math:`W` is the within-chain variance and :math:`\hat{V}` is
    the posterior variance estimate for the pooled traces. This is the
    potential scale reduction factor, which converges to unity when each
    of the traces is a sample from the target posterior. Values greater
    than one indicate that one or more chains have not yet converged.

    References
    ----------
    Brooks and Gelman (1998)
    Gelman and Rubin (1992)
    """

    if mtrace.nchains < 2:
        raise ValueError(
            'Gelman-Rubin diagnostic requires multiple chains of the same length.')

    def calc_rhat(x):

        try:
            # When the variable is multidimensional, this assignment will fail,
            # triggering a ValueError that handles the multidimensional case
            m, n = x.shape

            # Calculate between-chain variance
            B = n * np.var(np.mean(x, axis=1), ddof=1)

            # Calculate within-chain variance
            W = np.mean(np.var(x, axis=1, ddof=1))

            # Estimate of marginal posterior variance
            Vhat = W * (n - 1) / n + B / n

            return np.sqrt(Vhat / W)

        except ValueError:
            # Tricky transpose here, shifting the last dimension to the first
            rotated_indices = np.roll(np.arange(x.ndim), 1)
            # Now iterate over the dimensions of the variable
            return np.squeeze([calc_rhat(xi) for xi in x.transpose(rotated_indices)])

    Rhat = {}
    for var in mtrace.varnames:

        # Get all traces for var
        x = np.array(mtrace.get_values(var))

        try:
            Rhat[var] = calc_rhat(x)
        except ValueError:
            Rhat[var] = [calc_rhat(y.transpose()) for y in x.transpose()]

    return Rhat
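

# A worked sketch of the R-hat computation performed by calc_rhat above, on
# made-up data (the names `chains`, `B`, `W`, `Vhat`, `rhat` are illustrative
# only). For m chains of n draws stored row-wise in an (m, n) array:
#
#     chains = np.random.normal(size=(2, 1000))        # m=2 chains, n=1000 draws each
#     m, n = chains.shape
#     B = n * np.var(np.mean(chains, axis=1), ddof=1)  # between-chain variance
#     W = np.mean(np.var(chains, axis=1, ddof=1))      # within-chain variance
#     Vhat = W * (n - 1) / n + B / n                   # pooled posterior variance estimate
#     rhat = np.sqrt(Vhat / W)                         # close to 1.0 when the chains agree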


def trace_to_dataframe(trace):
    """Convert a PyMC trace consisting of 1-D variables to a pandas DataFrame
    """
    import pandas as pd

    return pd.DataFrame(
        {varname: np.squeeze(trace.get_values(varname, combine=True))
         for varname in trace.varnames})
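

# Usage sketch (illustrative): `trace` below stands for any object exposing the
# `varnames` and `get_values(..., combine=True)` interface used above, such as
# a MultiTrace produced by sampling; the sampling step itself is not shown.
#
#     df = trace_to_dataframe(trace)
#     df.describe()   # pandas summary statistics for each 1-D variable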