lifejacket 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,125 @@
1
+ import logging
2
+
3
+ import numpy as np
4
+ from jax import numpy as jnp
5
+
6
+ from .constants import SmallSampleCorrections
7
+ from .helper_functions import invert_matrix_and_check_conditioning
8
+
9
# Configure root logging at INFO with a timestamped, file:line-annotated format.
# NOTE(review): calling logging.basicConfig at import time configures the root
# logger for any application that imports this module (and is a no-op if the
# root logger already has handlers); consider moving it to an entry point.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y-%m-%d:%H:%M:%S",
    format=(
        "%(asctime)s,%(msecs)03d %(levelname)-2s "
        "[%(filename)s:%(lineno)d] %(message)s"
    ),
)

# Module-scoped logger used by the functions in this file.
logger = logging.getLogger(__name__)
15
+
16
+
17
def perform_desired_small_sample_correction(
    small_sample_correction,
    per_user_joint_adaptive_meat_contributions,
    per_user_classical_meat_contributions,
    per_user_classical_bread_inverse_contributions,
    num_users,
    theta_dim,
):
    """Reweight per-user meat contributions by a small sample correction and average them.

    Scalar per-user correction weights are chosen according to
    ``small_sample_correction``, multiplied into the per-user meat
    contributions, and the weighted contributions are averaged over users
    (axis 0).

    Args:
        small_sample_correction: A ``SmallSampleCorrections`` member. Handled
            values are ``NONE``, ``HC1theta``, ``HC2theta`` and ``HC3theta``.
        per_user_joint_adaptive_meat_contributions: Stacked per-user joint
            adaptive meat matrices, weighted and averaged over axis 0
            (presumably shape ``(num_users, d, d)`` — TODO confirm).
        per_user_classical_meat_contributions: Stacked per-user classical meat
            matrices, weighted and averaged over axis 0 the same way.
        per_user_classical_bread_inverse_contributions: Stacked per-user
            classical bread-inverse matrices. Their mean is inverted to obtain
            the classical bread, which HC2/HC3 use to compute leverages.
        num_users: Number of users (length of the per-user axis).
        theta_dim: Size of theta; used as the parameter count for HC1.

    Returns:
        Tuple ``(joint_adaptive_meat_matrix, classical_meat_matrix,
        per_user_adaptive_corrections, per_user_classical_corrections)``:
        the two corrected, user-averaged meat matrices followed by the
        per-user scalar weights that were applied to produce them.

    Raises:
        ValueError: If ``small_sample_correction`` is not recognized, or if
            HC1 is requested with ``num_users <= theta_dim`` (the HC1 factor
            ``n / (n - p)`` would be undefined or negative).
    """
    # We first compute the classical inverse bread matrix and invert it. While
    # it is possible to avoid this inversion using a QR decomposition and
    # solving linear systems (discussed more below), we typically don't have
    # issues with the conditioning of just the classical bread.
    classical_bread_inverse_matrix = jnp.mean(
        per_user_classical_bread_inverse_contributions, axis=0
    )
    classical_bread_matrix = invert_matrix_and_check_conditioning(
        classical_bread_inverse_matrix,
    )[0]

    # Scalar per-user weights; all ones means "no correction".
    per_user_adaptive_correction_weights = np.ones(num_users)
    per_user_classical_correction_weights = np.ones(num_users)

    if small_sample_correction == SmallSampleCorrections.NONE:
        logger.info(
            "No small sample correction requested. Using the raw per-user meat contributions."
        )
    elif small_sample_correction == SmallSampleCorrections.HC1theta:
        # n / (n - p) is only meaningful for n > p: n == p divides by zero and
        # n < p would silently produce negative weights.
        if num_users <= theta_dim:
            raise ValueError(
                "HC1 small sample correction requires num_users > theta_dim, "
                f"got num_users={num_users} and theta_dim={theta_dim}."
            )
        logger.info(
            "Using HC1 small sample correction at the user trajectory level. Note that we are treating the number of parameters as simply the size of theta, despite the presence of betas."
        )
        per_user_adaptive_correction_weights = per_user_classical_correction_weights = (
            num_users / (num_users - theta_dim) * np.ones(num_users)
        )
    elif small_sample_correction in {
        SmallSampleCorrections.HC2theta,
        SmallSampleCorrections.HC3theta,
    }:
        logger.info(
            "Using %s small sample correction at the user trajectory level.",
            small_sample_correction,
        )

        power = 1 if small_sample_correction == SmallSampleCorrections.HC2theta else 2

        # It turns out to typically not make sense to compute the adaptive analog
        # of the classical leverages, since this involves correcting the joint adaptive meat matrix
        # involving all beta and theta parameters. HC2/HC3 corrections assume that
        # the number of parameters is smaller than the number of users, which will not typically be
        # the case if the number of users is small enough for these corrections to be important.
        # Therefore we also use the "classical" leverages for the adaptive correction weights, which
        # is sensible, corresponding to only adjusting based on the estimating equations for theta.

        # ALSO note that one way to test correctness of the leverages is that they should sum
        # to the number of inference parameters, ie the size of theta. I tested that this is
        # true both for the classical leverages and the larger joint adaptive leverages when they
        # were still used, lending credence to the below calculations.

        # TODO: Write a unit test for some level of logic here and then rewrite this to not require
        # the classical bread explicitly. May be slower, probably needs a for loop so that can use
        # a solver for each matrix multiplication after a QR decomposition of the bread inverse
        # transpose.

        # "nij,ji->n" computes trace(A_n @ bread) for each user n, where A_n is
        # that user's bread-inverse contribution; dividing by num_users gives
        # the per-user leverage.
        classical_leverages_per_user = (
            np.einsum(
                "nij,ji->n",
                per_user_classical_bread_inverse_contributions,
                classical_bread_matrix,
            )
            / num_users
        )

        # A leverage of 1 makes the weight below infinite; above 1 it is
        # negative (HC2) or meaningless (HC3). Surface this instead of
        # silently propagating it into the meat matrices.
        if np.any(classical_leverages_per_user >= 1):
            logger.warning(
                "Some per-user classical leverages are >= 1 (max leverage %s); "
                "the resulting %s correction weights are infinite or not meaningful.",
                float(np.max(classical_leverages_per_user)),
                small_sample_correction,
            )

        per_user_classical_correction_weights = 1 / (
            (1 - classical_leverages_per_user) ** power
        )

        per_user_adaptive_correction_weights = per_user_classical_correction_weights
    else:
        raise ValueError(
            f"Unknown small sample correction: {small_sample_correction}. "
            "Please choose from values in SmallSampleCorrections class."
        )

    # Apply the scalar weights per user (broadcast over the matrix axes) and
    # average over users.
    joint_adaptive_meat_matrix = jnp.mean(
        per_user_adaptive_correction_weights[:, None, None]
        * per_user_joint_adaptive_meat_contributions,
        axis=0,
    )
    classical_meat_matrix = jnp.mean(
        per_user_classical_correction_weights[:, None, None]
        * per_user_classical_meat_contributions,
        axis=0,
    )

    return (
        joint_adaptive_meat_matrix,
        classical_meat_matrix,
        per_user_adaptive_correction_weights,
        per_user_classical_correction_weights,
    )