/kcbo/statistical_tests/lognormal_comparison_test.py

https://github.com/cwharland/kcbo · Python · 168 lines · 104 code · 38 blank · 26 comment · 4 complexity · 4e8973121d11082410c7e316796e5813 MD5 · raw file

  1. from kcbo.statistical_tests.utils import StatisticalTest, statistic
  2. from kcbo.utils import output_templates
  3. import numpy as np
  4. from tabulate import tabulate
  5. from itertools import combinations
  6. class LognormalMedianComparison(StatisticalTest):
  7. TYPE = 'Lognormal Median Comparison Test'
  8. def __init__(self, *args, **kwargs):
  9. super(type(self), self).__init__(*args, **kwargs)
  10. def initialize_test(self, dataframe, groups=None, groupcol='group', valuecol='value', samples=100000, **kwargs):
  11. df = dataframe
  12. df = df[df[valuecol] > 0]
  13. if not groups:
  14. groups = df[groupcol].unique()
  15. pooled = df[valuecol]
  16. self.pooled = pooled
  17. self.groups = groups
  18. self.groupcol = groupcol
  19. self.valuecol = valuecol
  20. self.df = df
  21. self.samples = samples
  22. self.keys = list(combinations(groups, 2))
  23. self.median_distributions = {}
  24. self.mean_distributions = {}
  25. def run_model(self, *args, **kwargs):
  26. groups = kwargs.get('groups', self.groups)
  27. m, v = (self.pooled.mean(), self.pooled.var())
  28. compute_mu = lambda m, v: np.log(m ** 2 / np.sqrt(v + m ** 2))
  29. compute_var = lambda m, v: np.log(v * 1. / m ** 2 + 1)
  30. pooled_mean = compute_mu(m, v)
  31. pooled_variance = compute_var(m, v)
  32. pooled_tau = 1. / 1000000
  33. mc_samples = self.samples
  34. tau = 1. / pooled_variance
  35. for group in groups:
  36. g = self.df[self.df[self.groupcol] == group][self.valuecol]
  37. n = g.shape[0]
  38. # MC Simulation to generate distribution
  39. mean_data = np.random.normal(
  40. loc=(pooled_tau * pooled_mean + tau * np.log(g).mean() * n) /
  41. (pooled_tau + n * tau),
  42. scale= np.sqrt(1. / (pooled_tau + n * tau)),
  43. size=mc_samples
  44. )
  45. median_data = np.exp(mean_data)
  46. self.median_distributions[group] = median_data
  47. self.mean_distributions[group] = mean_data
  48. @statistic('median', individual=True, is_distribution=True, is_estimate=True)
  49. def group_median_distribution(self, group):
  50. return self.median_distributions.get(group, [])
  51. @statistic('mu', individual=True, is_distribution=True, is_estimate=True)
  52. def group_mean_distribution(self, group):
  53. return self.mean_distributions.get(group, [])
  54. @statistic('diff_medians', is_distribution=True, pairwise=True, is_estimate=True)
  55. def diff_medians(self, groups):
  56. group1, group2 = groups
  57. return self.median_distributions[group2] - self.median_distributions[group1]
  58. @statistic('p_diff_medians', pairwise=True, is_estimate=True)
  59. def p_diff_medians(self, groups):
  60. return (self.diff_medians(groups) > 0).mean()
  61. def summary(self, *args, **kwargs):
  62. summary_data = self.compute_statistic(
  63. keys=list(self.keys).extend(self.groups))
  64. return self.generate_text_description(summary_data), summary_data
  65. def generate_text_description(self, summary_data):
  66. group_summary_header = [
  67. 'Group', 'Median', '95% CI Lower', '95% CI Upper', 'Mu', '95% CI Lower', '95% CI Upper']
  68. group_summary_table_data = [
  69. [
  70. group,
  71. summary_data[group]['estimate median'],
  72. summary_data[group]['95_CI median'][0],
  73. summary_data[group]['95_CI median'][1],
  74. summary_data[group]['estimate mu'],
  75. summary_data[group]['95_CI mu'][0],
  76. summary_data[group]['95_CI mu'][1]
  77. ]
  78. for group in self.groups]
  79. group_summary_table = tabulate(
  80. group_summary_table_data, group_summary_header, tablefmt="pipe")
  81. comparisons_header = [
  82. "Hypothesis", "Difference of Medians", "P.Value", "95% CI Lower", "95% CI Upper"]
  83. comparisons_data = [
  84. [
  85. "{} < {}".format(*pair),
  86. self.diff_medians(pair).mean(),
  87. summary_data[pair]['p_diff_medians'],
  88. summary_data[pair]['95_CI diff_medians'][0],
  89. summary_data[pair]['95_CI diff_medians'][1],
  90. ] for pair in self.keys
  91. ]
  92. comparison_summary_table = tabulate(
  93. comparisons_data, comparisons_header, tablefmt="pipe")
  94. description = output_templates['groups with comparison'].format(
  95. title=self.TYPE,
  96. groups_header="Groups:",
  97. groups_string=", ".join(self.groups),
  98. groups_summary=group_summary_table,
  99. comparison_summary=comparison_summary_table,
  100. )
  101. return description
  102. def lognormal_comparison_test(dataframe, groups=None, groupcol='group', valuecol='value', **kwargs):
  103. """Lognormal Median Comparison
  104. Given a dataframe of the form:
  105. |Group |Observed Value|
  106. |-------|--------------|
  107. |<group>| <float>|
  108. ...
  109. Compute estimates of the difference of medians between groups.
  110. Note: This test assumes that input comes from distributions with the same variance.
  111. Inputs:
  112. dataframe -- Pandas dataframe of form above
  113. groups -- (optional) list of groups to look at. Excluded looks at all groups
  114. groupcol -- string for indexing dataframe column for groups
  115. valuecol -- string for indexing dataframe column for values of observations
  116. Returns:
  117. (description, raw_data)
  118. description: table describing output data
  119. raw_data: dictionary of output data
  120. """
  121. results = LognormalMedianComparison(
  122. dataframe, groups=None, groupcol='group', valuecol='value', **kwargs)
  123. return results.summary()