diff --git a/tests/test_ttest.py b/tests/test_ttest.py
new file mode 100644
index 0000000..0870bce
--- /dev/null
+++ b/tests/test_ttest.py
@@ -0,0 +1,40 @@
+# scipy-yli: Helpful SciPy utilities and recipes
+# Copyright © 2022 Lee Yingtong Li (RunasSudo)
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from pytest import approx
+
+import numpy as np
+import pandas as pd
+
+import yli
+
+def test_ttest_ind_ol6_1():
+    """Check yli.ttest_ind against Ott & Longnecker (2016) example 6.1"""
+
+    # Potency readings, 10 observations per group
+    fresh = [10.2, 10.5, 10.3, 10.8, 9.8, 10.6, 10.7, 10.2, 10.0, 10.6]
+    stored = [9.8, 9.6, 10.1, 10.2, 10.1, 9.7, 9.5, 9.6, 9.8, 9.9]
+
+    df = pd.DataFrame({
+        'Type': ['Fresh'] * len(fresh) + ['Stored'] * len(stored),
+        'Potency': fresh + stored,
+    })
+
+    result = yli.ttest_ind(df, 'Potency', 'Type')
+
+    # Reference t statistic built from the quoted figures 0.54 and 0.285
+    # (presumably the textbook's mean difference and pooled SD - confirm against the source)
+    n_per_group = 10
+    t_expected = 0.54 / (0.285 * np.sqrt(1/n_per_group + 1/n_per_group))
+
+    assert result.statistic == approx(t_expected, abs=0.01)
+    assert result.dof == 18
+
+    # Mean difference and its confidence interval
+    delta = result.delta
+    assert delta.point == approx(0.54, abs=0.01)
+    assert delta.ci_lower == approx(0.272, abs=0.01)
+    assert delta.ci_upper == approx(0.808, abs=0.01)
diff --git a/yli/sig_tests.py b/yli/sig_tests.py
index 319eb68..dcdfbc8 100644
--- a/yli/sig_tests.py
+++ b/yli/sig_tests.py
@@ -21,74 +21,10 @@ import statsmodels.api as sm
import functools
import warnings
-def check_nan(df, nan_policy):
- """Check df against nan_policy and return cleaned input"""
-
- if nan_policy == 'raise':
- if pd.isna(df).any(axis=None):
- raise ValueError('NaN in input, pass nan_policy="warn" or "omit" to ignore')
- elif nan_policy == 'warn':
- df_cleaned = df.dropna()
- if len(df_cleaned) < len(df):
- warnings.warn('Omitting {} rows with NaN'.format(len(df) - len(df_cleaned)))
- return df_cleaned
- elif nan_policy == 'omit':
- return df.dropna()
- else:
- raise Exception('Invalid nan_policy, expected "raise", "warn" or "omit"')
+from .utils import Estimate, as_2groups, check_nan, fmt_p_html, fmt_p_text
-def do_fmt_p(p):
- """Return sign and formatted p value"""
-
- if p < 0.001:
- return '<', '0.001*'
- elif p < 0.0095:
- return None, '{:.3f}*'.format(p)
- elif p < 0.045:
- return None, '{:.2f}*'.format(p)
- elif p < 0.05:
- return None, '{:.3f}*'.format(p) # 3dps to show significance
- elif p < 0.055:
- return None, '{:.3f}'.format(p) # 3dps to show non-significance
- elif p < 0.095:
- return None, '{:.2f}'.format(p)
- else:
- return None, '{:.1f}'.format(p)
-
-def fmt_p_text(p, nospace=False):
- """Format p value for plaintext"""
-
- sign, fmt = do_fmt_p(p)
- if sign is not None:
- if nospace:
- return sign + fmt # e.g. "<0.001"
- else:
- return sign + ' ' + fmt # e.g. "< 0.001"
- else:
- if nospace:
- return fmt # e.g. "0.05"
- else:
- return '= ' + fmt # e.g. "= 0.05"
-
-def fmt_p_html(p, nospace=False):
- """Format p value for HTML"""
-
- txt = fmt_p_text(p, nospace)
- return txt.replace('<', '&lt;')
-
-class Estimate:
- """A point estimate and surrounding confidence interval"""
-
- def __init__(self, point, ci_lower, ci_upper):
- self.point = point
- self.ci_lower = ci_lower
- self.ci_upper = ci_upper
-
- def _repr_html_(self):
- return self.summary()
-
- def summary(self):
- return '{:.2f} ({:.2f}–{:.2f})'.format(self.point, self.ci_lower, self.ci_upper)
+# ----------------
+# Student's t test
class TTestResult:
"""
@@ -97,44 +33,41 @@ class TTestResult:
delta: Mean difference
"""
- def __init__(self, statistic, dof, pvalue, delta):
+ def __init__(self, statistic, dof, pvalue, delta, delta_direction):
self.statistic = statistic
self.dof = dof
self.pvalue = pvalue
self.delta = delta
+ self.delta_direction = delta_direction
def _repr_html_(self):
- return 't({:.0f}) = {:.2f}; p {}<br>δ (95% CI) = {}'.format(self.dof, self.statistic, fmt_p_html(self.pvalue), self.delta.summary())
+ return 't({:.0f}) = {:.2f}; p {}<br>δ (95% CI) = {}, {}'.format(self.dof, self.statistic, fmt_p_html(self.pvalue), self.delta.summary(), self.delta_direction)
def summary(self):
- return 't({:.0f}) = {:.2f}; p {}\nδ (95% CI) = {}'.format(self.dof, self.statistic, fmt_p_text(self.pvalue), self.delta.summary())
+ return 't({:.0f}) = {:.2f}; p {}\nδ (95% CI) = {}, {}'.format(self.dof, self.statistic, fmt_p_text(self.pvalue), self.delta.summary(), self.delta_direction)
def ttest_ind(df, dep, ind, *, nan_policy='warn'):
"""Perform an independent-sample Student's t test"""
+ # Check for/clean NaNs
df = check_nan(df[[ind, dep]], nan_policy)
- # Get groupings for ind
- groups = list(df.groupby(ind).groups.values())
-
- # Ensure only 2 groups to compare
- if len(groups) != 2:
- raise Exception('Got {} values for {}, expected 2'.format(len(groups), ind))
-
- # Get 2 groups
- group1 = df.loc[groups[0], dep]
- group2 = df.loc[groups[1], dep]
+ # Ensure 2 groups for ind
+ group1, data1, group2, data2 = as_2groups(df, dep, ind)
# Do t test
# Use statsmodels rather than SciPy because this provides the mean difference automatically
- d1 = sm.stats.DescrStatsW(group1)
- d2 = sm.stats.DescrStatsW(group2)
+ d1 = sm.stats.DescrStatsW(data1)
+ d2 = sm.stats.DescrStatsW(data2)
- cm = sm.stats.CompareMeans(d2, d1) # This order to get correct CI
+ cm = sm.stats.CompareMeans(d1, d2)
statistic, pvalue, dof = cm.ttest_ind()
- delta = d2.mean - d1.mean
+ delta = d1.mean - d2.mean
ci0, ci1 = cm.tconfint_diff()
- return TTestResult(statistic=statistic, dof=dof, pvalue=pvalue, delta=Estimate(delta, ci0, ci1))
-0
\ No newline at end of file
+ # t test is symmetric so take absolute values
+ return TTestResult(
+ statistic=abs(statistic), dof=dof, pvalue=pvalue,
+ delta=abs(Estimate(delta, ci0, ci1)),
+ delta_direction=('{0} > {1}' if d1.mean > d2.mean else '{1} > {0}').format(group1, group2))
diff --git a/yli/utils.py b/yli/utils.py
new file mode 100644
index 0000000..bc0840e
--- /dev/null
+++ b/yli/utils.py
@@ -0,0 +1,119 @@
+# scipy-yli: Helpful SciPy utilities and recipes
+# Copyright © 2022 Lee Yingtong Li (RunasSudo)
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import numpy as np
+import pandas as pd
+
+import warnings
+
+def check_nan(df, nan_policy):
+    """
+    Check df against nan_policy and return cleaned input
+
+    df: pandas object to check for NaN
+    nan_policy: 'raise' to raise ValueError if any NaN is present,
+        'warn' to drop rows with NaN and emit a warning if any were dropped,
+        'omit' to drop rows with NaN silently
+
+    Returns df itself under 'raise', otherwise the result of df.dropna()
+    Raises ValueError for NaN under 'raise', and Exception for an unrecognised nan_policy
+    """
+
+    if nan_policy == 'raise':
+        if pd.isna(df).any(axis=None):
+            raise ValueError('NaN in input, pass nan_policy="warn" or "omit" to ignore')
+        # Input is clean: return it, so callers that assign the result do not receive None
+        return df
+    elif nan_policy == 'warn':
+        df_cleaned = df.dropna()
+        if len(df_cleaned) < len(df):
+            warnings.warn('Omitting {} rows with NaN'.format(len(df) - len(df_cleaned)))
+        return df_cleaned
+    elif nan_policy == 'omit':
+        return df.dropna()
+    else:
+        raise Exception('Invalid nan_policy, expected "raise", "warn" or "omit"')
+
+def as_2groups(df, data, group):
+    """
+    Split the data column into exactly 2 groups defined by the group column
+
+    Returns (label of group 1, data of group 1, label of group 2, data of group 2)
+    Raises Exception if the group column does not yield exactly 2 groups
+    """
+
+    # (label, row index) pairs, one per distinct value of the grouping column
+    by_label = list(df.groupby(group).groups.items())
+
+    # A two-sample comparison needs exactly 2 groups
+    if len(by_label) != 2:
+        raise Exception('Got {} values for {}, expected 2'.format(len(by_label), group))
+
+    (group1, rows1), (group2, rows2) = by_label
+
+    return group1, df.loc[rows1, data], group2, df.loc[rows2, data]
+
+def do_fmt_p(p):
+    """
+    Return sign and formatted p value
+
+    The sign is '<' when p is below 0.001 and None otherwise
+    Values below 0.05 are suffixed with '*'
+    """
+
+    # Too small to print exactly
+    if p < 0.001:
+        return '<', '0.001*'
+
+    # (exclusive upper bound, format) pairs in ascending order; first match wins
+    # 3dps is used either side of 0.05 to show significance/non-significance
+    bands = (
+        (0.0095, '{:.3f}*'),
+        (0.045, '{:.2f}*'),
+        (0.05, '{:.3f}*'),
+        (0.055, '{:.3f}'),
+        (0.095, '{:.2f}'),
+    )
+    for upper, template in bands:
+        if p < upper:
+            return None, template.format(p)
+
+    return None, '{:.1f}'.format(p)
+
+def fmt_p_text(p, nospace=False):
+    """
+    Format p value for plaintext
+
+    nospace: If True, omit the space after the sign, and omit the "= " prefix when there is no sign
+    """
+
+    sign, fmt = do_fmt_p(p)
+
+    if sign is None:
+        # e.g. "0.05" or "= 0.05"
+        return fmt if nospace else '= ' + fmt
+
+    # e.g. "<0.001" or "< 0.001"
+    return sign + fmt if nospace else sign + ' ' + fmt
+
+def fmt_p_html(p, nospace=False):
+    """
+    Format p value for HTML
+
+    Produces the same text as fmt_p_text(p, nospace), with any "<" escaped as an HTML entity
+    """
+
+    txt = fmt_p_text(p, nospace)
+    # A bare "<" (as in "< 0.001*") would be parsed as the start of a tag, so escape it
+    return txt.replace('<', '&lt;')
+
+class Estimate:
+    """A point estimate together with the lower and upper bounds of its confidence interval"""
+
+    def __init__(self, point, ci_lower, ci_upper):
+        self.point, self.ci_lower, self.ci_upper = point, ci_lower, ci_upper
+
+    def summary(self):
+        """Return the estimate as "point (lower–upper)", each value to 2 decimal places"""
+        values = (self.point, self.ci_lower, self.ci_upper)
+        return '{:.2f} ({:.2f}–{:.2f})'.format(*values)
+
+    def _repr_html_(self):
+        # HTML representation is the same text as summary()
+        return self.summary()
+
+    def __neg__(self):
+        """Return the negated estimate; the bounds are negated and swapped"""
+        new_lower = -self.ci_upper
+        new_upper = -self.ci_lower
+        return Estimate(-self.point, new_lower, new_upper)
+
+    def __abs__(self):
+        """Return the negated estimate if the point is negative, otherwise self"""
+        return -self if self.point < 0 else self
+
+    def exp(self):
+        """Return a new Estimate with np.exp applied to the point and both bounds"""
+        return Estimate(*(np.exp(value) for value in (self.point, self.ci_lower, self.ci_upper)))