diff --git a/tests/test_ttest.py b/tests/test_ttest.py new file mode 100644 index 0000000..0870bce --- /dev/null +++ b/tests/test_ttest.py @@ -0,0 +1,40 @@ +# scipy-yli: Helpful SciPy utilities and recipes +# Copyright © 2022 Lee Yingtong Li (RunasSudo) +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from pytest import approx + +import numpy as np +import pandas as pd + +import yli + +def test_ttest_ind_ol6_1(): + """Compare yli.ttest_ind for Ott & Longnecker (2016) example 6.1""" + + df = pd.DataFrame({ + 'Type': ['Fresh'] * 10 + ['Stored'] * 10, + 'Potency': [10.2, 10.5, 10.3, 10.8, 9.8, 10.6, 10.7, 10.2, 10.0, 10.6, 9.8, 9.6, 10.1, 10.2, 10.1, 9.7, 9.5, 9.6, 9.8, 9.9] + }) + + result = yli.ttest_ind(df, 'Potency', 'Type') + + t_expected = 0.54/(0.285*np.sqrt(1/10+1/10)) + + assert result.statistic == approx(t_expected, abs=0.01) + assert result.dof == 18 + assert result.delta.point == approx(0.54, abs=0.01) + assert result.delta.ci_lower == approx(0.272, abs=0.01) + assert result.delta.ci_upper == approx(0.808, abs=0.01) diff --git a/yli/sig_tests.py b/yli/sig_tests.py index 319eb68..dcdfbc8 100644 --- a/yli/sig_tests.py +++ b/yli/sig_tests.py @@ -21,74 +21,10 @@ import statsmodels.api as sm import functools import warnings -def check_nan(df, nan_policy): - """Check df against nan_policy and return cleaned input""" - - if nan_policy == 'raise': - if pd.isna(df).any(axis=None): - raise ValueError('NaN in input, pass nan_policy="warn" or "omit" to ignore') - elif nan_policy == 'warn': - df_cleaned = df.dropna() - if len(df_cleaned) < len(df): - warnings.warn('Omitting {} rows with NaN'.format(len(df) - len(df_cleaned))) - return df_cleaned - elif nan_policy == 'omit': - return df.dropna() - else: - raise Exception('Invalid nan_policy, expected "raise", "warn" or "omit"') +from .utils import Estimate, as_2groups, check_nan, fmt_p_html, fmt_p_text -def do_fmt_p(p): - """Return sign and formatted p value""" - - if p < 0.001: - return '<', '0.001*' - elif p < 0.0095: - return None, '{:.3f}*'.format(p) - elif p < 0.045: - return None, '{:.2f}*'.format(p) - elif p < 0.05: - return None, '{:.3f}*'.format(p) # 3dps to show significance - elif p < 0.055: - return None, '{:.3f}'.format(p) # 3dps to show non-significance - elif p < 0.095: - return None, '{:.2f}'.format(p) - else: - return None, '{:.1f}'.format(p) - -def fmt_p_text(p, nospace=False): - """Format p value for plaintext""" - - sign, fmt = do_fmt_p(p) - if sign is not None: - if nospace: - return sign + fmt # e.g. "<0.001" - else: - return sign + ' ' + fmt # e.g. "< 0.001" - else: - if nospace: - return fmt # e.g. "0.05" - else: - return '= ' + fmt # e.g. "= 0.05" - -def fmt_p_html(p, nospace=False): - """Format p value for HTML""" - - txt = fmt_p_text(p, nospace) - return txt.replace('<', '<') - -class Estimate: - """A point estimate and surrounding confidence interval""" - - def __init__(self, point, ci_lower, ci_upper): - self.point = point - self.ci_lower = ci_lower - self.ci_upper = ci_upper - - def _repr_html_(self): - return self.summary() - - def summary(self): - return '{:.2f} ({:.2f}–{:.2f})'.format(self.point, self.ci_lower, self.ci_upper) +# ---------------- +# Student's t test class TTestResult: """ @@ -97,44 +33,41 @@ class TTestResult: delta: Mean difference """ - def __init__(self, statistic, dof, pvalue, delta): + def __init__(self, statistic, dof, pvalue, delta, delta_direction): self.statistic = statistic self.dof = dof self.pvalue = pvalue self.delta = delta + self.delta_direction = delta_direction def _repr_html_(self): - return 't({:.0f}) = {:.2f}; p {}
δ (95% CI) = {}'.format(self.dof, self.statistic, fmt_p_html(self.pvalue), self.delta.summary()) + return 't({:.0f}) = {:.2f}; p {}
δ (95% CI) = {}, {}'.format(self.dof, self.statistic, fmt_p_html(self.pvalue), self.delta.summary(), self.delta_direction) def summary(self): - return 't({:.0f}) = {:.2f}; p {}\nδ (95% CI) = {}'.format(self.dof, self.statistic, fmt_p_text(self.pvalue), self.delta.summary()) + return 't({:.0f}) = {:.2f}; p {}\nδ (95% CI) = {}, {}'.format(self.dof, self.statistic, fmt_p_text(self.pvalue), self.delta.summary(), self.delta_direction) def ttest_ind(df, dep, ind, *, nan_policy='warn'): """Perform an independent-sample Student's t test""" + # Check for/clean NaNs df = check_nan(df[[ind, dep]], nan_policy) - # Get groupings for ind - groups = list(df.groupby(ind).groups.values()) - - # Ensure only 2 groups to compare - if len(groups) != 2: - raise Exception('Got {} values for {}, expected 2'.format(len(groups), ind)) - - # Get 2 groups - group1 = df.loc[groups[0], dep] - group2 = df.loc[groups[1], dep] + # Ensure 2 groups for ind + group1, data1, group2, data2 = as_2groups(df, dep, ind) # Do t test # Use statsmodels rather than SciPy because this provides the mean difference automatically - d1 = sm.stats.DescrStatsW(group1) - d2 = sm.stats.DescrStatsW(group2) + d1 = sm.stats.DescrStatsW(data1) + d2 = sm.stats.DescrStatsW(data2) - cm = sm.stats.CompareMeans(d2, d1) # This order to get correct CI + cm = sm.stats.CompareMeans(d1, d2) statistic, pvalue, dof = cm.ttest_ind() - delta = d2.mean - d1.mean + delta = d1.mean - d2.mean ci0, ci1 = cm.tconfint_diff() - return TTestResult(statistic=statistic, dof=dof, pvalue=pvalue, delta=Estimate(delta, ci0, ci1)) -0 \ No newline at end of file + # t test is symmetric so take absolute values + return TTestResult( + statistic=abs(statistic), dof=dof, pvalue=pvalue, + delta=abs(Estimate(delta, ci0, ci1)), + delta_direction=('{0} > {1}' if d1.mean > d2.mean else '{1} > {0}').format(group1, group2)) diff --git a/yli/utils.py b/yli/utils.py new file mode 100644 index 0000000..bc0840e --- /dev/null +++ b/yli/utils.py @@ -0,0 +1,119 @@ +# scipy-yli: Helpful SciPy utilities and recipes +# Copyright © 2022 Lee Yingtong Li (RunasSudo) +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import numpy as np +import pandas as pd + +import warnings + +def check_nan(df, nan_policy): + """Check df against nan_policy and return cleaned input""" + + if nan_policy == 'raise': + if pd.isna(df).any(axis=None): + raise ValueError('NaN in input, pass nan_policy="warn" or "omit" to ignore') + elif nan_policy == 'warn': + df_cleaned = df.dropna() + if len(df_cleaned) < len(df): + warnings.warn('Omitting {} rows with NaN'.format(len(df) - len(df_cleaned))) + return df_cleaned + elif nan_policy == 'omit': + return df.dropna() + else: + raise Exception('Invalid nan_policy, expected "raise", "warn" or "omit"') + +def as_2groups(df, data, group): + """Group the data by the given variable, ensuring only 2 groups""" + + # Get groupings + groups = list(df.groupby(group).groups.items()) + + # Ensure only 2 groups to compare + if len(groups) != 2: + raise Exception('Got {} values for {}, expected 2'.format(len(groups), group)) + + # Get 2 groups + group1 = groups[0][0] + data1 = df.loc[groups[0][1], data] + group2 = groups[1][0] + data2 = df.loc[groups[1][1], data] + + return group1, data1, group2, data2 + +def do_fmt_p(p): + """Return sign and formatted p value""" + + if p < 0.001: + return '<', '0.001*' + elif p < 0.0095: + return None, '{:.3f}*'.format(p) + elif p < 0.045: + return None, '{:.2f}*'.format(p) + elif p < 0.05: + return None, '{:.3f}*'.format(p) # 3dps to show significance + elif p < 0.055: + return None, '{:.3f}'.format(p) # 3dps to show non-significance + elif p < 0.095: + return None, '{:.2f}'.format(p) + else: + return None, '{:.1f}'.format(p) + +def fmt_p_text(p, nospace=False): + """Format p value for plaintext""" + + sign, fmt = do_fmt_p(p) + if sign is not None: + if nospace: + return sign + fmt # e.g. "<0.001" + else: + return sign + ' ' + fmt # e.g. "< 0.001" + else: + if nospace: + return fmt # e.g. "0.05" + else: + return '= ' + fmt # e.g. "= 0.05" + +def fmt_p_html(p, nospace=False): + """Format p value for HTML""" + + txt = fmt_p_text(p, nospace) + return txt.replace('<', '<') + +class Estimate: + """A point estimate and surrounding confidence interval""" + + def __init__(self, point, ci_lower, ci_upper): + self.point = point + self.ci_lower = ci_lower + self.ci_upper = ci_upper + + def _repr_html_(self): + return self.summary() + + def summary(self): + return '{:.2f} ({:.2f}–{:.2f})'.format(self.point, self.ci_lower, self.ci_upper) + + def __neg__(self): + return Estimate(-self.point, -self.ci_upper, -self.ci_lower) + + def __abs__(self): + if self.point < 0: + return -self + else: + return self + + def exp(self): + return Estimate(np.exp(self.point), np.exp(self.ci_lower), np.exp(self.ci_upper))