lifejacket 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/__init__.py +0 -0
- lifejacket/after_study_analysis.py +1845 -0
- lifejacket/arg_threading_helpers.py +354 -0
- lifejacket/calculate_derivatives.py +965 -0
- lifejacket/constants.py +28 -0
- lifejacket/form_adaptive_meat_adjustments_directly.py +333 -0
- lifejacket/get_datum_for_blowup_supervised_learning.py +1312 -0
- lifejacket/helper_functions.py +587 -0
- lifejacket/input_checks.py +1145 -0
- lifejacket/small_sample_corrections.py +125 -0
- lifejacket/trial_conditioning_monitor.py +870 -0
- lifejacket/vmap_helpers.py +71 -0
- lifejacket-0.1.0.dist-info/METADATA +100 -0
- lifejacket-0.1.0.dist-info/RECORD +17 -0
- lifejacket-0.1.0.dist-info/WHEEL +5 -0
- lifejacket-0.1.0.dist-info/entry_points.txt +2 -0
- lifejacket-0.1.0.dist-info/top_level.txt +1 -0
lifejacket/small_sample_corrections.py
@@ -0,0 +1,125 @@
+import logging
+
+import numpy as np
+from jax import numpy as jnp
+
+from .constants import SmallSampleCorrections
+from .helper_functions import invert_matrix_and_check_conditioning
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(
+    format="%(asctime)s,%(msecs)03d %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s",
+    datefmt="%Y-%m-%d:%H:%M:%S",
+    level=logging.INFO,
+)
+
+
+def perform_desired_small_sample_correction(
+    small_sample_correction,
+    per_user_joint_adaptive_meat_contributions,
+    per_user_classical_meat_contributions,
+    per_user_classical_bread_inverse_contributions,
+    num_users,
+    theta_dim,
+):
+
+    # We first compute the classical inverse bread matrix and invert it. While
+    # it is possible to avoid this inversion using a QR decomposition and
+    # solving linear systems (discussed more below), we typically don't have
+    # issues with the conditioning of just the classical bread.
+    classical_bread_inverse_matrix = jnp.mean(
+        per_user_classical_bread_inverse_contributions, axis=0
+    )
+    classical_bread_matrix = invert_matrix_and_check_conditioning(
+        classical_bread_inverse_matrix,
+    )[0]
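For orientation: bread and meat here are the two components of a sandwich variance estimator. Under the usual M-estimation convention (an assumption about this package's notation, since the diff itself never writes the formula), with B the averaged bread contribution and M the averaged meat contribution over n users, the estimated variance takes the form

    \widehat{\operatorname{Var}}(\hat{\theta}) \approx \frac{1}{n}\, B^{-1} M B^{-\top},

which is why the function averages per-user contributions before inverting and weighting them.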
+
+    # These will hold either corrective matrices or scalar weights depending on
+    # what small sample correction is requested.
+    per_user_adaptive_corrections = None
+    per_user_classical_corrections = None
+
+    per_user_adaptive_correction_weights = np.ones(num_users)
+    per_user_classical_correction_weights = np.ones(num_users)
+    if small_sample_correction == SmallSampleCorrections.NONE:
+        logger.info(
+            "No small sample correction requested. Using the raw per-user joint adaptive bread inverse contributions."
+        )
+    elif small_sample_correction == SmallSampleCorrections.HC1theta:
+        logger.info(
+            "Using HC1 small sample correction at the user trajectory level. Note that we are treating the number of parameters as simply the size of theta, despite the presence of betas."
+        )
+        per_user_adaptive_correction_weights = per_user_classical_correction_weights = (
+            num_users / (num_users - theta_dim) * np.ones(num_users)
+        )
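The HC1theta branch applies one shared degrees-of-freedom inflation to every user: writing n for num_users and p for theta_dim, each per-user weight is

    w_i = \frac{n}{n - p},

the standard HC1 factor, with the parameter count deliberately taken to be the size of theta alone despite the presence of the betas, as the log message notes.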
+    elif small_sample_correction in {
+        SmallSampleCorrections.HC2theta,
+        SmallSampleCorrections.HC3theta,
+    }:
+        logger.info(
+            "Using %s small sample correction at the user trajectory level.",
+            small_sample_correction,
+        )
+
+        power = 1 if small_sample_correction == SmallSampleCorrections.HC2theta else 2
+
+        # It typically does not make sense to compute the adaptive analog of the
+        # classical leverages, since that involves correcting the joint adaptive
+        # meat matrix over all beta and theta parameters. HC2/HC3 corrections
+        # assume that the number of parameters is smaller than the number of
+        # users, which will typically not hold when the number of users is small
+        # enough for these corrections to matter. We therefore use the "classical"
+        # leverages for the adaptive correction weights as well, which is sensible:
+        # it corresponds to adjusting based only on the estimating equations for theta.
+
+        # ALSO note that one way to test the correctness of the leverages is that
+        # they should sum to the number of inference parameters, i.e. the size of
+        # theta. I tested that this holds both for the classical leverages and for
+        # the larger joint adaptive leverages when they were still used, lending
+        # credence to the calculations below.
+
+        # TODO: Write a unit test for some of this logic, then rewrite it so the
+        # classical bread is not required explicitly. That may be slower and
+        # probably needs a for loop, so that a solver can be applied for each
+        # matrix multiplication after a QR decomposition of the transposed bread
+        # inverse.
+        classical_leverages_per_user = (
+            np.einsum(
+                "nij,ji->n",
+                per_user_classical_bread_inverse_contributions,
+                classical_bread_matrix,
+            )
+            / num_users
+        )
+        per_user_classical_correction_weights = 1 / (
+            (1 - classical_leverages_per_user) ** power
+        )
+
+        per_user_adaptive_correction_weights = per_user_classical_correction_weights
+    else:
+        raise ValueError(
+            f"Unknown small sample correction: {small_sample_correction}. "
+            "Please choose from the values in the SmallSampleCorrections class."
+        )
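The weights computed above are w_i = 1/(1 - h_i) for HC2theta and w_i = 1/(1 - h_i)^2 for HC3theta, with h_i the per-user leverage. A rough sketch of the bread-free rewrite the TODO asks for, together with the leverage-sum check described in the comments (a sketch only: it assumes the locals above are in scope, and uses a batched np.linalg.solve rather than the QR route the TODO mentions):

    # trace(A_n @ B) == trace(solve(B_inv, A_n)), so a batched solve avoids
    # forming classical_bread_matrix explicitly; np.linalg.solve broadcasts a
    # (d, d) system over an (n, d, d) stack of right-hand sides.
    solved = np.linalg.solve(
        np.asarray(classical_bread_inverse_matrix),
        np.asarray(per_user_classical_bread_inverse_contributions),
    )
    leverages = np.einsum("nii->n", solved) / num_users

    # The sanity check from the comment: the leverages sum to dim(theta),
    # since the mean contribution times the bread is the identity matrix.
    assert np.isclose(leverages.sum(), theta_dim)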
+
+    # If we used matrix corrections, they will be stored as these corrections.
+    # Otherwise, store the scalar weights.
+    if per_user_adaptive_corrections is None:
+        per_user_adaptive_corrections = per_user_adaptive_correction_weights
+    if per_user_classical_corrections is None:
+        per_user_classical_corrections = per_user_classical_correction_weights
+
+    # The scalar corrections will have computed weights that need to be applied here,
+    # whereas the matrix corrections will have been applied to the per-user
+    # contributions already.
+    joint_adaptive_meat_matrix = jnp.mean(
+        per_user_adaptive_correction_weights[:, None, None]
+        * per_user_joint_adaptive_meat_contributions,
+        axis=0,
+    )
+    classical_meat_matrix = jnp.mean(
+        per_user_classical_correction_weights[:, None, None]
+        * per_user_classical_meat_contributions,
+        axis=0,
+    )
+
+    return (
+        joint_adaptive_meat_matrix,
+        classical_meat_matrix,
+        per_user_adaptive_corrections,
+        per_user_classical_corrections,
+    )
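A minimal end-to-end sketch of calling the function, assuming import paths taken from the file list above and input shapes inferred from the body (classical stacks of shape (num_users, theta_dim, theta_dim); a joint adaptive stack of shape (num_users, joint_dim, joint_dim), where joint_dim is a hypothetical dimension covering the betas plus theta):

    import numpy as np

    from lifejacket.constants import SmallSampleCorrections
    from lifejacket.small_sample_corrections import (
        perform_desired_small_sample_correction,
    )

    rng = np.random.default_rng(0)
    num_users, theta_dim, joint_dim = 50, 3, 8

    def random_spd_stack(n, d):
        # One well-conditioned symmetric positive definite matrix per user.
        a = rng.normal(size=(n, d, d))
        return a @ a.transpose(0, 2, 1) + np.eye(d)

    joint_meat, classical_meat, adaptive_corr, classical_corr = (
        perform_desired_small_sample_correction(
            SmallSampleCorrections.HC3theta,
            random_spd_stack(num_users, joint_dim),
            random_spd_stack(num_users, theta_dim),
            random_spd_stack(num_users, theta_dim),
            num_users,
            theta_dim,
        )
    )
    # For HC3theta the returned corrections are the scalar per-user weights.
    assert classical_corr.shape == (num_users,)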