lifejacket 0.2.1__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/arg_threading_helpers.py +75 -69
- lifejacket/calculate_derivatives.py +19 -23
- lifejacket/constants.py +4 -16
- lifejacket/{trial_conditioning_monitor.py → deployment_conditioning_monitor.py} +163 -138
- lifejacket/{form_adaptive_meat_adjustments_directly.py → form_adjusted_meat_adjustments_directly.py} +32 -34
- lifejacket/get_datum_for_blowup_supervised_learning.py +341 -339
- lifejacket/helper_functions.py +60 -186
- lifejacket/input_checks.py +303 -302
- lifejacket/{after_study_analysis.py → post_deployment_analysis.py} +470 -457
- lifejacket/small_sample_corrections.py +49 -49
- lifejacket-1.0.2.dist-info/METADATA +56 -0
- lifejacket-1.0.2.dist-info/RECORD +17 -0
- lifejacket-1.0.2.dist-info/entry_points.txt +2 -0
- lifejacket-0.2.1.dist-info/METADATA +0 -100
- lifejacket-0.2.1.dist-info/RECORD +0 -17
- lifejacket-0.2.1.dist-info/entry_points.txt +0 -2
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.2.dist-info}/WHEEL +0 -0
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.2.dist-info}/top_level.txt +0 -0
lifejacket/{form_adaptive_meat_adjustments_directly.py → form_adjusted_meat_adjustments_directly.py}
RENAMED
|
@@ -18,16 +18,16 @@ logging.basicConfig(
|
|
|
18
18
|
)
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
def
|
|
21
|
+
def form_adjusted_meat_adjustments_directly(
|
|
22
22
|
theta_dim: int,
|
|
23
23
|
beta_dim: int,
|
|
24
|
-
|
|
24
|
+
joint_bread_matrix: jnp.ndarray,
|
|
25
25
|
per_user_estimating_function_stacks: jnp.ndarray,
|
|
26
26
|
study_df: pd.DataFrame,
|
|
27
|
-
|
|
27
|
+
active_col_name: str,
|
|
28
28
|
action_col_name: str,
|
|
29
29
|
calendar_t_col_name: str,
|
|
30
|
-
|
|
30
|
+
subject_id_col_name: str,
|
|
31
31
|
action_prob_func: callable,
|
|
32
32
|
action_prob_func_args: dict,
|
|
33
33
|
action_prob_func_args_beta_index: int,
|
|
@@ -38,18 +38,16 @@ def form_adaptive_meat_adjustments_directly(
|
|
|
38
38
|
action_prob_col_name: str,
|
|
39
39
|
) -> jnp.ndarray:
|
|
40
40
|
logger.info(
|
|
41
|
-
"Explicitly forming the per-user meat adjustments that differentiate the
|
|
41
|
+
"Explicitly forming the per-user meat adjustments that differentiate the adjusted sandwich from the classical sandwich."
|
|
42
42
|
)
|
|
43
43
|
|
|
44
44
|
# 1. Form the M-matrices, which are shared across users.
|
|
45
45
|
# This is not quite the paper definition of the M-matrices, which
|
|
46
46
|
# includes multiplication by the classical bread. We don't care about
|
|
47
47
|
# that here, since in forming the adjustments there is a multiplication
|
|
48
|
-
# by the classical bread
|
|
49
|
-
V_blocks_together =
|
|
50
|
-
RL_stack_beta_derivatives_block =
|
|
51
|
-
:-theta_dim, :-theta_dim
|
|
52
|
-
]
|
|
48
|
+
# by the classical bread that cancels it out.
|
|
49
|
+
V_blocks_together = joint_bread_matrix[-theta_dim:, :-theta_dim]
|
|
50
|
+
RL_stack_beta_derivatives_block = joint_bread_matrix[:-theta_dim, :-theta_dim]
|
|
53
51
|
effective_M_blocks_together = np.linalg.solve(
|
|
54
52
|
RL_stack_beta_derivatives_block.T, V_blocks_together.T
|
|
55
53
|
).T
|
|
@@ -89,17 +87,17 @@ def form_adaptive_meat_adjustments_directly(
|
|
|
89
87
|
# be added to each users inference estimating function before an outer product is taken with
|
|
90
88
|
# itself to get each users's contributioan theta-only meat matrix).
|
|
91
89
|
# Result is shape (num_users, theta_dim).
|
|
92
|
-
# Form the per-user
|
|
90
|
+
# Form the per-user adjusted meat adjustments explicitly for diagnostic purposes.
|
|
93
91
|
per_user_meat_adjustments_stacked = np.einsum(
|
|
94
92
|
"utb,nub->nt", M_blocks_stacked, per_user_RL_only_est_fns_stacked
|
|
95
93
|
)
|
|
96
94
|
|
|
97
|
-
# Log some diagnostics about the pieces going into the
|
|
95
|
+
# Log some diagnostics about the pieces going into the adjusted meat adjustments
|
|
98
96
|
# and the adjustments themselves.
|
|
99
97
|
V_blocks = np.split(
|
|
100
98
|
V_blocks_together, V_blocks_together.shape[1] // beta_dim, axis=1
|
|
101
99
|
)
|
|
102
|
-
logger.info("Examining
|
|
100
|
+
logger.info("Examining adjusted meat adjustments.")
|
|
103
101
|
# No scientific notation
|
|
104
102
|
np.set_printoptions(suppress=True)
|
|
105
103
|
|
|
@@ -118,11 +116,11 @@ def form_adaptive_meat_adjustments_directly(
|
|
|
118
116
|
)
|
|
119
117
|
|
|
120
118
|
logger.debug(
|
|
121
|
-
"Per-user
|
|
119
|
+
"Per-user adjusted meat adjustments, to be added to inference estimating functions before forming the meat. Formed from the sum of the products of the M-blocks and the corresponding RL update estimating functions for each user: %s",
|
|
122
120
|
per_user_meat_adjustments_stacked,
|
|
123
121
|
)
|
|
124
122
|
logger.debug(
|
|
125
|
-
"Norms of per-user
|
|
123
|
+
"Norms of per-user adjusted meat adjustments: %s",
|
|
126
124
|
np.linalg.norm(per_user_meat_adjustments_stacked, axis=1),
|
|
127
125
|
)
|
|
128
126
|
|
|
@@ -148,10 +146,10 @@ def form_adaptive_meat_adjustments_directly(
|
|
|
148
146
|
|
|
149
147
|
logger.debug(
|
|
150
148
|
"M_blocks, one per update, each shape theta_dim x beta_dim. The sum of the products "
|
|
151
|
-
"of each of these times a user's corresponding RL estimating function forms their
|
|
149
|
+
"of each of these times a user's corresponding RL estimating function forms their adjusted "
|
|
152
150
|
"adjustment. The M's are the blocks of the the product of the V's concatened and the inverse of "
|
|
153
|
-
"the RL-only upper-left block of the joint
|
|
154
|
-
"left block of the joint
|
|
151
|
+
"the RL-only upper-left block of the joint bread. In other words, the lower "
|
|
152
|
+
"left block of the joint bread. Also note that the inference estimating function "
|
|
155
153
|
"derivative inverse is omitted here despite the definition of the M's in the paper, because "
|
|
156
154
|
"that factor simply cancels later: %s",
|
|
157
155
|
M_blocks_stacked,
|
|
@@ -159,11 +157,11 @@ def form_adaptive_meat_adjustments_directly(
|
|
|
159
157
|
logger.debug("Norms of M-blocks: %s", np.linalg.norm(M_blocks_stacked, axis=(1, 2)))
|
|
160
158
|
|
|
161
159
|
logger.debug(
|
|
162
|
-
"RL block of joint
|
|
160
|
+
"RL block of joint bread. The *inverse* of this goes into the M's: %s",
|
|
163
161
|
RL_stack_beta_derivatives_block,
|
|
164
162
|
)
|
|
165
163
|
logger.debug(
|
|
166
|
-
"Norm of RL block of joint
|
|
164
|
+
"Norm of RL block of joint bread: %s",
|
|
167
165
|
np.linalg.norm(RL_stack_beta_derivatives_block),
|
|
168
166
|
)
|
|
169
167
|
|
|
@@ -171,11 +169,11 @@ def form_adaptive_meat_adjustments_directly(
|
|
|
171
169
|
RL_stack_beta_derivatives_block
|
|
172
170
|
)
|
|
173
171
|
logger.debug(
|
|
174
|
-
"Inverse of RL block of joint
|
|
172
|
+
"Inverse of RL block of joint bread. This goes into the M's: %s",
|
|
175
173
|
inverse_RL_stack_beta_derivatives_block,
|
|
176
174
|
)
|
|
177
175
|
logger.debug(
|
|
178
|
-
"Norm of Inverse of RL block of joint
|
|
176
|
+
"Norm of Inverse of RL block of joint bread: %s",
|
|
179
177
|
np.linalg.norm(inverse_RL_stack_beta_derivatives_block),
|
|
180
178
|
)
|
|
181
179
|
|
|
@@ -205,10 +203,10 @@ def form_adaptive_meat_adjustments_directly(
|
|
|
205
203
|
|
|
206
204
|
pi_and_weight_gradients_by_calendar_t = calculate_pi_and_weight_gradients(
|
|
207
205
|
study_df,
|
|
208
|
-
|
|
206
|
+
active_col_name,
|
|
209
207
|
action_col_name,
|
|
210
208
|
calendar_t_col_name,
|
|
211
|
-
|
|
209
|
+
subject_id_col_name,
|
|
212
210
|
action_prob_func,
|
|
213
211
|
action_prob_func_args,
|
|
214
212
|
action_prob_func_args_beta_index,
|
|
@@ -220,9 +218,9 @@ def form_adaptive_meat_adjustments_directly(
|
|
|
220
218
|
inference_func,
|
|
221
219
|
inference_func_args_theta_index,
|
|
222
220
|
user_ids,
|
|
223
|
-
|
|
221
|
+
subject_id_col_name,
|
|
224
222
|
action_prob_col_name,
|
|
225
|
-
|
|
223
|
+
active_col_name,
|
|
226
224
|
calendar_t_col_name,
|
|
227
225
|
)
|
|
228
226
|
# Take the outer product of each row of (per_user_meat_adjustments_stacked + per_user_inference_estimating_functions_stacked)
|
|
@@ -230,17 +228,17 @@ def form_adaptive_meat_adjustments_directly(
|
|
|
230
228
|
per_user_meat_adjustments_stacked
|
|
231
229
|
+ per_user_inference_estimating_functions_stacked
|
|
232
230
|
)
|
|
233
|
-
|
|
231
|
+
per_user_theta_only_adjusted_meat_contributions = jnp.einsum(
|
|
234
232
|
"ni,nj->nij",
|
|
235
233
|
per_user_adjusted_inference_estimating_functions_stacked,
|
|
236
234
|
per_user_adjusted_inference_estimating_functions_stacked,
|
|
237
235
|
)
|
|
238
|
-
|
|
239
|
-
|
|
236
|
+
adjusted_theta_only_meat_matrix = jnp.mean(
|
|
237
|
+
per_user_theta_only_adjusted_meat_contributions, axis=0
|
|
240
238
|
)
|
|
241
239
|
logger.info(
|
|
242
|
-
"Theta-only
|
|
243
|
-
|
|
240
|
+
"Theta-only adjusted meat matrix (no small sample corrections): %s",
|
|
241
|
+
adjusted_theta_only_meat_matrix,
|
|
244
242
|
)
|
|
245
243
|
classical_theta_only_meat_matrix = jnp.mean(
|
|
246
244
|
jnp.einsum(
|
|
@@ -311,7 +309,7 @@ def form_adaptive_meat_adjustments_directly(
|
|
|
311
309
|
# much or being too opinionated about what to log.
|
|
312
310
|
breakpoint()
|
|
313
311
|
|
|
314
|
-
# # Visualize the inverse RL block of joint
|
|
312
|
+
# # Visualize the inverse RL block of joint bread using seaborn heatmap
|
|
315
313
|
# pyplt.figure(figsize=(8, 6))
|
|
316
314
|
# sns.heatmap(inverse_RL_stack_beta_derivatives_block, annot=False, cmap="viridis")
|
|
317
315
|
# pyplt.title("Inverse RL Block of Joint Adaptive Bread Inverse")
|
|
@@ -320,7 +318,7 @@ def form_adaptive_meat_adjustments_directly(
|
|
|
320
318
|
# pyplt.tight_layout()
|
|
321
319
|
# pyplt.show()
|
|
322
320
|
|
|
323
|
-
# # # Visualize the RL block of joint
|
|
321
|
+
# # # Visualize the RL block of joint bread using seaborn heatmap
|
|
324
322
|
|
|
325
323
|
# pyplt.figure(figsize=(8, 6))
|
|
326
324
|
# sns.heatmap(RL_stack_beta_derivatives_block, annot=False, cmap="viridis")
|
|
@@ -330,4 +328,4 @@ def form_adaptive_meat_adjustments_directly(
|
|
|
330
328
|
# pyplt.tight_layout()
|
|
331
329
|
# pyplt.show()
|
|
332
330
|
|
|
333
|
-
return
|
|
331
|
+
return per_user_theta_only_adjusted_meat_contributions
|