/lincs/stats/stats.py
Python | 320 lines | 206 code | 51 blank | 63 comment | 33 complexity | c07d6dddafbb8a0d6e179f161408b2d3 MD5 | raw file
- # -*- coding: utf-8 -*-
- import pandas
- import itertools
- import operator
- import math
- import logging
def foo(km, antibodies):
    """Summarise each antibody column of `km` for cell counts 0 and 1.

    For every name in `antibodies`, the rows of `km` whose
    ``cell_count_a`` column equals 0 (respectively 1) are selected and the
    mean, median and standard deviation of that antibody's column are
    recorded.

    :param km: Data with a ``cell_count_a`` column and one column per
        antibody.
    :type km: pandas.DataFrame
    :param antibodies: Column names of `km` to summarise.
    :returns: DataFrame indexed by antibody with the columns ``mean0``,
        ``median0``, ``stddev0``, ``mean1``, ``median1`` and ``stddev1``.
    """
    kmstats = pandas.DataFrame(
        columns=['mean0', 'median0', 'stddev0', 'mean1', 'median1', 'stddev1'],
        index=antibodies
    )
    for ab in antibodies:
        for i in (0, 1):
            d = km.loc[km['cell_count_a'] == i, ab]
            # `.at` is the supported scalar setter; `set_value` no longer
            # exists in current pandas.
            kmstats.at[ab, "mean{0}".format(i)] = d.mean()
            kmstats.at[ab, "stddev{0}".format(i)] = d.std()
            kmstats.at[ab, "median{0}".format(i)] = d.median()
    # The table used to be built and then dropped; hand it to the caller.
    return kmstats
def true_len(x, real=True):
    """ Count the elements of `x` that are truthy.

    Each truthy element adds 1.0 when `real` is set (the default) and the
    integer 1 otherwise, so a non-zero count is a float or an int
    accordingly.  When nothing in `x` is truthy the result is the
    integer 0 in both modes.
    """
    unit = 1. if real else 1
    total = 0
    for item in x:
        if item:
            total += unit
    return total
def extract_stats(treatment_data, signals, on_threshold):
    """ Compute per-treatment summary statistics and the tables derived
    from them.

    For every treatment this computes the mean, median and standard
    deviation of the wells at each cell count (0-3), derives an "on"
    threshold per signal from the zero-cell wells, and then uses the
    'mid98' threshold to build the on-fraction, conditional-probability,
    mutual-information and on-well statistics tables.  Finally it calls
    `make_thresholded_data` and `normalize_treatments` on the result.

    Note that `treatment_data` is updated in place: each treatment's
    frame is replaced by one whose zero-cell wells are limited to the
    middle-98% subset that was used for the 'mid98' threshold.

    :param treatment_data: Dictionary mapping treatments to data.  Each
        frame needs a ``cell_count`` column and one column per signal.
    :type treatment_data: dictionary of pandas.DataFrame objects.
    :param signals: List of the signal column names to analyse.
    :param on_threshold: Number of standard deviations added to the
        zero-cell mean to obtain the "on" threshold.
    :returns: Tuple ``(treatment_data, means, medians, stddevs,
        thresholds, oncounts, cprobs, minfo, onstats, thresholded_data,
        signature_data, normalized_treatments,
        normalized_treatments_clipped)``.
    """
    count_tuples = (
        ('0', operator.eq, 0),
        ('1', operator.eq, 1),
        ('2', operator.eq, 2),
        ('3', operator.eq, 3),
    )
    index_tuples = list(itertools.product(
        treatment_data.keys(),
        ('0', '1', '2', '3'))
    )
    index = pandas.MultiIndex.from_tuples(
        index_tuples,
        names=['treatment', 'cell_count']
    )
    means = pandas.DataFrame(index=index, columns=signals)
    medians = pandas.DataFrame(index=index, columns=signals)
    stddevs = pandas.DataFrame(index=index, columns=signals)

    # Get aggregate stats for each treatment and cell count combination.
    #
    for treatment, data in treatment_data.items():
        for (label, op, val) in count_tuples:
            # Select the wells once, then find the mean, median, and
            # standard deviation.  Row assignment aligns the resulting
            # Series on the signal columns.
            #
            wells = data[op(data["cell_count"], val)]
            means.loc[(treatment, label), :] = wells.mean()
            medians.loc[(treatment, label), :] = wells.median()
            stddevs.loc[(treatment, label), :] = wells.std()

    # Work out what threshold constitutes "on" for this experiment.
    #
    threshold_index = pandas.MultiIndex.from_tuples(
        list(itertools.product(
            treatment_data.keys(),
            ('all', 'mid98', 'delta'))
        ),
        names=['treatment', 'threshold']
    )
    thresholds = pandas.DataFrame(
        index=threshold_index,
        columns=signals
    )
    # Iterate over a snapshot because the dictionary's values are
    # replaced inside the loop.
    for treatment, data in list(treatment_data.items()):
        zero_mask = (data.cell_count == 0)
        zero_wells = data[zero_mask]
        # Keep only the zero-cell wells whose row total lies strictly
        # between the 1st and 99th percentiles.
        sumi = zero_wells.sum(axis=1)
        middle_98_mask = ((sumi > sumi.quantile(0.01))
                          & (sumi < sumi.quantile(0.99)))
        zero_wells_98 = zero_wells[middle_98_mask]
        thresholds.loc[(treatment, 'mid98'), :] = (
            zero_wells_98.mean() + on_threshold * zero_wells_98.std()
        )
        thresholds.loc[(treatment, 'all'), :] = (
            zero_wells.mean() + on_threshold * zero_wells.std()
        )
        thresholds.loc[(treatment, 'delta'), :] = (
            thresholds.loc[(treatment, 'mid98'), :]
            - thresholds.loc[(treatment, 'all'), :]
        )
        # Throw out the zero cell wells that are out
        # of the range we're keeping.  A boolean mask is inverted with
        # `~`; unary minus is rejected for boolean data.
        data = data[~zero_mask]
        data = data.merge(zero_wells_98, how="outer")
        treatment_data[treatment] = data

    # Find out what fraction are "on" for each experiment and cell
    # count combination.
    #
    oncounts = pandas.DataFrame(index=index, columns=signals)
    for treatment, data in treatment_data.items():
        for signal in signals:
            threshold = thresholds.loc[(treatment, 'mid98'), signal]
            signal_filter = (data[signal] > threshold)
            for (label, op, count) in count_tuples:
                # Calculate the fraction that are on for
                # this cell population.
                #
                count_filter = op(data["cell_count"], count)
                num_cells = true_len(count_filter, real=False)
                num_cells_on = true_len(count_filter & signal_filter)
                # Write through a single .loc call so the value lands in
                # `oncounts` itself rather than in a temporary column.
                if num_cells > 0:
                    oncounts.loc[(treatment, label), signal] \
                        = float(num_cells_on) / num_cells
                else:
                    oncounts.loc[(treatment, label), signal] = None

    # Make conditional probability and mutual information DataFrames.
    #
    cprob_index = pandas.MultiIndex.from_tuples(
        list(itertools.product(treatment_data.keys(), signals)),
        names=['treatment', 'signal']
    )
    cprobs = pandas.DataFrame(index=cprob_index, columns=signals)
    minfo = pandas.DataFrame(index=cprob_index, columns=signals)
    for treatment, data in treatment_data.items():
        logging.info("Treatment: {0}".format(treatment))
        # Only single-cell wells take part in these statistics.
        single_cell = (data["cell_count"] == 1)
        num_single = true_len(single_cell)
        for s1 in signals:
            # Conditional probability and mutual info are both undefined
            # when there are no single-cell wells.
            #
            if num_single == 0:
                continue
            # Threshold for s1
            t1 = thresholds.loc[(treatment, 'mid98'), s1]
            # Bit masks saying which cells are "on" / "off" for s1
            s1on = (single_cell & (data[s1] > t1))
            s1off = (single_cell & (data[s1] <= t1))
            num_s1on = true_len(s1on)
            for s2 in signals:
                t2 = thresholds.loc[(treatment, 'mid98'), s2]
                s2on = (single_cell & (data[s2] > t2))
                s2off = (single_cell & (data[s2] <= t2))
                # Save the conditional probability that s2 is on
                # given that s1 is on.
                #
                if num_s1on > 0:
                    cprobs.loc[(treatment, s1), s2] \
                        = true_len(s1on & s2on) / num_s1on
                else:
                    cprobs.loc[(treatment, s1), s2] = 0
                # Mutual information (in bits) over the four on/off
                # combinations of s1 and s2.
                mi = 0
                for c1, c2 in itertools.product((s1on, s1off), (s2on, s2off)):
                    if len(c1) != len(c2):
                        raise Exception("WTF")
                    p_xy = true_len(c1 & c2) / num_single
                    p_x = true_len(c1) / num_single
                    p_y = true_len(c2) / num_single
                    # p_xy > 0 implies p_x and p_y are both > 0, so the
                    # ratio below is well defined.
                    if p_xy > 0:
                        mi += p_xy * math.log((p_xy / (p_x * p_y)), 2)
                minfo.loc[(treatment, s1), s2] = mi

    # Calculate mean and stddev of those wells that are
    # over the threshold
    #
    onstats_index = pandas.MultiIndex.from_tuples(
        list(
            itertools.product(
                treatment_data.keys(),
                ("{0}-{1}".format(s, c)
                 for (s, c)
                 in itertools.product(('mean', 'stddev'), range(4))
                 )
            ),
        ),
        names=['treatment', 'stat']
    )
    onstats = pandas.DataFrame(index=onstats_index, columns=signals)
    for treatment, data in treatment_data.items():
        for s in signals:
            # Threshold for s
            t1 = thresholds.loc[(treatment, 'mid98'), s]
            for cell_count in range(4):
                # Bit mask saying which cells are "on"
                s_on = ((data["cell_count"] == cell_count) & (data[s] > t1))
                on_values = data[s_on][s]
                onstats.loc[(treatment, "mean-{0}".format(cell_count)), s] \
                    = on_values.mean()
                onstats.loc[(treatment, "stddev-{0}".format(cell_count)), s] \
                    = on_values.std()

    # Get a single spreadsheet in which the intensities are
    # erased for those signals that are below threshold.
    #
    thresholded_data, signature_data = make_thresholded_data(
        treatment_data,
        signals,
        thresholds,
        'mid98'
    )
    (normalized_treatments, normalized_treatments_clipped) = normalize_treatments(
        treatment_data,
        signals,
        thresholds
    )
    return (
        treatment_data,
        means,
        medians,
        stddevs,
        thresholds,
        oncounts,
        cprobs,
        minfo,
        onstats,
        thresholded_data,
        signature_data,
        normalized_treatments,
        normalized_treatments_clipped
    )
def normalize_treatments(treatment_data, signals, thresholds, cell_counts=(1, 2)):
    """ For each requested cell count and each treatment, take the wells
    with that cell count and subtract the treatment's 'mid98' threshold
    from each signal.

    The ``cell_count`` column of every result is reset to the cell count
    after the subtraction, since the thresholds carry no value for it.

    :param treatment_data: Dictionary mapping treatments to DataFrames
        that have a ``cell_count`` column.
    :param signals: Unused; kept so existing callers keep working.
    :param thresholds: DataFrame whose rows are keyed by
        ``(treatment, threshold name)``; the ``'mid98'`` row is used.
    :param cell_counts: Cell counts to normalize (defaults to 1 and 2).
    :returns: Tuple ``(nd, nd_clipped)`` of dictionaries keyed first by
        cell count and then by treatment.  ``nd`` holds the raw
        differences; ``nd_clipped`` holds the same data with values below
        zero set to zero.
    """
    # Normalized data & the same, clipped
    nd = {}
    nd_clipped = {}
    for cell_count in cell_counts:
        nd[cell_count] = {}
        nd_clipped[cell_count] = {}
        for t, v in treatment_data.items():
            # Grab wells matching this cell count for this treatment
            # and subtract thresholds for this treatment.
            result = v[v['cell_count'] == cell_count] - thresholds.loc[(t, 'mid98'), :]
            result['cell_count'] = cell_count
            # This is our normalized data
            nd[cell_count][t] = result
            # Now make the clipped copy.  The bound is passed by keyword:
            # positionally, `clip(None, 0)` is an *upper* bound of zero on
            # current pandas, which would zero the positive signals.
            nd_clipped[cell_count][t] = result.clip(lower=0)
    return (nd, nd_clipped)
def make_thresholded_data(treatment_data, signals, thresholds, key):
    """ Apply thresholds to the data such that the intensity values below
    the threshold for a particular treatment and signal are blanked out
    (set to missing).

    The input frames are copied, so `treatment_data` is left untouched.

    :param treatment_data: Dictionary mapping treatments to DataFrames
        with one column per signal.
    :param signals: List of signal column names.
    :param thresholds: DataFrame with one column per signal whose rows
        are keyed by ``(treatment, threshold name)``.
    :param key: Name of the threshold row to apply (e.g. ``'mid98'``).
    :returns: Tuple ``(ttd, sigs)``.  ``ttd`` stacks the thresholded rows
        of all treatments, keeping only rows in which at least one signal
        survived.  ``sigs`` stacks, for every row, whether each signal
        survived, alongside the row's remaining columns.  Both carry a
        ``treatment`` column.
    """
    # Empty templates: they fix the leading column order and are what is
    # returned when there are no treatments at all.
    ttd_parts = [pandas.DataFrame(columns=signals + ["treatment"])]
    sig_parts = [pandas.DataFrame(columns=signals + ["treatment"])]
    # Recall, `treatment_data` is a dictionary where the
    # keys are strings and the values are of type DataFrame.
    #
    for t, v in treatment_data.items():
        tmp = v.copy()
        tmp['treatment'] = t
        for s in signals:
            # Where we're below the threshold, remove the data.  Replacing
            # the whole column avoids chained assignment, which may write
            # to a temporary and leave `tmp` unchanged.
            tmp[s] = tmp[s].mask(tmp[s] < thresholds.loc[(t, key), s])
        # Get the signatures for this treatment
        tmp_sigs = (~pandas.isnull(tmp[signals])).combine_first(tmp)
        sig_parts.append(tmp_sigs)
        # Select only those rows for which at least one of
        # the signals is above the threshold.
        #
        ttd_parts.append(tmp[~pandas.isnull(tmp[signals]).all(axis=1)])
    # Concatenate once at the end instead of re-copying the accumulated
    # frame for every treatment.
    ttd = pandas.concat(ttd_parts, ignore_index=True)
    sigs = pandas.concat(sig_parts, ignore_index=True)
    return ttd, sigs