lifejacket 0.2.1__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/arg_threading_helpers.py +75 -69
- lifejacket/calculate_derivatives.py +19 -23
- lifejacket/constants.py +4 -16
- lifejacket/{trial_conditioning_monitor.py → deployment_conditioning_monitor.py} +163 -138
- lifejacket/{form_adaptive_meat_adjustments_directly.py → form_adjusted_meat_adjustments_directly.py} +32 -34
- lifejacket/get_datum_for_blowup_supervised_learning.py +341 -339
- lifejacket/helper_functions.py +60 -186
- lifejacket/input_checks.py +303 -302
- lifejacket/{after_study_analysis.py → post_deployment_analysis.py} +470 -457
- lifejacket/small_sample_corrections.py +49 -49
- lifejacket-1.0.2.dist-info/METADATA +56 -0
- lifejacket-1.0.2.dist-info/RECORD +17 -0
- lifejacket-1.0.2.dist-info/entry_points.txt +2 -0
- lifejacket-0.2.1.dist-info/METADATA +0 -100
- lifejacket-0.2.1.dist-info/RECORD +0 -17
- lifejacket-0.2.1.dist-info/entry_points.txt +0 -2
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.2.dist-info}/WHEEL +0 -0
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.2.dist-info}/top_level.txt +0 -0
|
@@ -12,7 +12,7 @@ import jax
|
|
|
12
12
|
from jax import numpy as jnp
|
|
13
13
|
import pandas as pd
|
|
14
14
|
|
|
15
|
-
from . import
|
|
15
|
+
from . import post_deployment_analysis
|
|
16
16
|
from .constants import FunctionTypes
|
|
17
17
|
from .vmap_helpers import stack_batched_arg_lists_into_tensors
|
|
18
18
|
|
|
@@ -25,25 +25,25 @@ logging.basicConfig(
|
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
def get_datum_for_blowup_supervised_learning(
|
|
28
|
-
|
|
29
|
-
|
|
28
|
+
joint_adjusted_bread_matrix,
|
|
29
|
+
joint_adjusted_bread_cond,
|
|
30
30
|
avg_estimating_function_stack,
|
|
31
|
-
|
|
31
|
+
per_subject_estimating_function_stacks,
|
|
32
32
|
all_post_update_betas,
|
|
33
|
-
|
|
34
|
-
|
|
33
|
+
analysis_df,
|
|
34
|
+
active_col_name,
|
|
35
35
|
calendar_t_col_name,
|
|
36
36
|
action_prob_col_name,
|
|
37
|
-
|
|
37
|
+
subject_id_col_name,
|
|
38
38
|
reward_col_name,
|
|
39
39
|
theta_est,
|
|
40
|
-
|
|
41
|
-
|
|
40
|
+
adjusted_sandwich_var_estimate,
|
|
41
|
+
subject_ids,
|
|
42
42
|
beta_dim,
|
|
43
43
|
theta_dim,
|
|
44
44
|
initial_policy_num,
|
|
45
45
|
beta_index_by_policy_num,
|
|
46
|
-
|
|
46
|
+
policy_num_by_decision_time_by_subject_id,
|
|
47
47
|
theta_calculation_func,
|
|
48
48
|
action_prob_func,
|
|
49
49
|
action_prob_func_args_beta_index,
|
|
@@ -51,46 +51,46 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
51
51
|
inference_func_type,
|
|
52
52
|
inference_func_args_theta_index,
|
|
53
53
|
inference_func_args_action_prob_index,
|
|
54
|
-
|
|
54
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
55
55
|
action_prob_func_args,
|
|
56
|
-
|
|
56
|
+
action_by_decision_time_by_subject_id,
|
|
57
57
|
) -> dict[str, Any]:
|
|
58
58
|
"""
|
|
59
|
-
Collects a datum for supervised learning about
|
|
59
|
+
Collects a datum for supervised learning about adjusted sandwich blowup.
|
|
60
60
|
|
|
61
|
-
The datum consists of features and the raw
|
|
61
|
+
The datum consists of features and the raw adjusted sandwich variance estimate as a label.
|
|
62
62
|
|
|
63
63
|
A few plots are produced along the way to help visualize the data.
|
|
64
64
|
|
|
65
65
|
Args:
|
|
66
|
-
|
|
67
|
-
The joint
|
|
68
|
-
|
|
69
|
-
The condition number of the joint
|
|
66
|
+
joint_adjusted_bread_matrix (jnp.ndarray):
|
|
67
|
+
The joint adjusted bread matrix.
|
|
68
|
+
joint_adjusted_bread_cond (float):
|
|
69
|
+
The condition number of the joint adjusted bread matrix.
|
|
70
70
|
avg_estimating_function_stack (jnp.ndarray):
|
|
71
|
-
The average estimating function stack across
|
|
72
|
-
|
|
73
|
-
The estimating function stacks for each
|
|
71
|
+
The average estimating function stack across subjects.
|
|
72
|
+
per_subject_estimating_function_stacks (jnp.ndarray):
|
|
73
|
+
The estimating function stacks for each subject.
|
|
74
74
|
all_post_update_betas (jnp.ndarray):
|
|
75
75
|
All post-update beta parameters.
|
|
76
|
-
|
|
76
|
+
analysis_df (pd.DataFrame):
|
|
77
77
|
The study DataFrame.
|
|
78
|
-
|
|
79
|
-
Column name indicating if a
|
|
78
|
+
active_col_name (str):
|
|
79
|
+
Column name in the study dataframe indicating whether a subject is in the study.
|
|
80
80
|
calendar_t_col_name (str):
|
|
81
81
|
Column name for calendar time in the study dataframe.
|
|
82
82
|
action_prob_col_name (str):
|
|
83
83
|
Column name for action probabilities in the study dataframe.
|
|
84
|
-
|
|
85
|
-
Column name for
|
|
84
|
+
subject_id_col_name (str):
|
|
85
|
+
Column name for subject IDs in the study dataframe.
|
|
86
86
|
reward_col_name (str):
|
|
87
87
|
Column name for rewards in the study dataframe.
|
|
88
88
|
theta_est (jnp.ndarray):
|
|
89
89
|
The estimate of the parameter vector theta.
|
|
90
|
-
|
|
91
|
-
The
|
|
92
|
-
|
|
93
|
-
Array of unique
|
|
90
|
+
adjusted_sandwich_var_estimate (jnp.ndarray):
|
|
91
|
+
The adjusted sandwich variance estimate for theta.
|
|
92
|
+
subject_ids (jnp.ndarray):
|
|
93
|
+
Array of unique subject IDs.
|
|
94
94
|
beta_dim (int):
|
|
95
95
|
Dimension of the beta parameter vector.
|
|
96
96
|
theta_dim (int):
|
|
@@ -99,8 +99,8 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
99
99
|
The initial policy number used in the study.
|
|
100
100
|
beta_index_by_policy_num (dict[int | float, int]):
|
|
101
101
|
Mapping from policy numbers to indices in all_post_update_betas.
|
|
102
|
-
|
|
103
|
-
Mapping from
|
|
102
|
+
policy_num_by_decision_time_by_subject_id (dict):
|
|
103
|
+
Mapping from subject IDs to their policy numbers by decision time.
|
|
104
104
|
theta_calculation_func (callable):
|
|
105
105
|
The theta calculation function.
|
|
106
106
|
action_prob_func (callable):
|
|
@@ -115,17 +115,17 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
115
115
|
Index for theta in inference function arguments.
|
|
116
116
|
inference_func_args_action_prob_index (int):
|
|
117
117
|
Index for action probability in inference function arguments.
|
|
118
|
-
|
|
119
|
-
Mapping from
|
|
118
|
+
inference_action_prob_decision_times_by_subject_id (dict):
|
|
119
|
+
Mapping from subject IDs to decision times for action probabilities used in inference.
|
|
120
120
|
action_prob_func_args (dict):
|
|
121
121
|
Arguments for the action probability function.
|
|
122
|
-
|
|
123
|
-
Mapping from
|
|
122
|
+
action_by_decision_time_by_subject_id (dict):
|
|
123
|
+
Mapping from subject IDs to their actions by decision time.
|
|
124
124
|
Returns:
|
|
125
125
|
dict[str, Any]: A dictionary containing features and the label for supervised learning.
|
|
126
126
|
"""
|
|
127
127
|
num_diagonal_blocks = (
|
|
128
|
-
(
|
|
128
|
+
(joint_adjusted_bread_matrix.shape[0] - theta_dim) // beta_dim
|
|
129
129
|
) + 1
|
|
130
130
|
diagonal_block_sizes = ([beta_dim] * (num_diagonal_blocks - 1)) + [theta_dim]
|
|
131
131
|
|
|
@@ -144,7 +144,7 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
144
144
|
row_slice = slice(block_bounds[i], block_bounds[i + 1])
|
|
145
145
|
col_slice = slice(block_bounds[j], block_bounds[j + 1])
|
|
146
146
|
block_norm = np.linalg.norm(
|
|
147
|
-
|
|
147
|
+
joint_adjusted_bread_matrix[row_slice, col_slice],
|
|
148
148
|
ord="fro",
|
|
149
149
|
)
|
|
150
150
|
# We will sum here and take the square root later
|
|
@@ -155,9 +155,9 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
155
155
|
# handle diagonal blocks
|
|
156
156
|
sl = slice(block_bounds[i], block_bounds[i + 1])
|
|
157
157
|
diag_norms.append(
|
|
158
|
-
np.linalg.norm(
|
|
158
|
+
np.linalg.norm(joint_adjusted_bread_matrix[sl, sl], ord="fro")
|
|
159
159
|
)
|
|
160
|
-
diag_conds.append(np.linalg.cond(
|
|
160
|
+
diag_conds.append(np.linalg.cond(joint_adjusted_bread_matrix[sl, sl]))
|
|
161
161
|
|
|
162
162
|
# Sqrt each row/col sum to truly get row/column norms.
|
|
163
163
|
# Perhaps not necessary for learning, but more natural
|
|
@@ -166,7 +166,7 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
166
166
|
|
|
167
167
|
# Get the per-person estimating function stack norms
|
|
168
168
|
estimating_function_stack_norms = np.linalg.norm(
|
|
169
|
-
|
|
169
|
+
per_subject_estimating_function_stacks, axis=1
|
|
170
170
|
)
|
|
171
171
|
|
|
172
172
|
# Get the average estimating function stack norms by update/inference
|
|
@@ -175,7 +175,7 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
175
175
|
avg_estimating_function_stack_norms_per_segment = [
|
|
176
176
|
np.mean(
|
|
177
177
|
np.linalg.norm(
|
|
178
|
-
|
|
178
|
+
per_subject_estimating_function_stacks[
|
|
179
179
|
:, block_bounds[i] : block_bounds[i + 1]
|
|
180
180
|
],
|
|
181
181
|
axis=1,
|
|
@@ -191,67 +191,67 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
191
191
|
std_successive_beta_diff_norm = np.std(successive_beta_diff_norms)
|
|
192
192
|
|
|
193
193
|
# Add a column with logits of the action probabilities
|
|
194
|
-
# Compute the average and standard deviation of the logits of the action probabilities at each decision time using
|
|
194
|
+
# Compute the average and standard deviation of the logits of the action probabilities at each decision time using analysis_df
|
|
195
195
|
# action_prob_logit_means and action_prob_logit_stds are numpy arrays of mean and stddev at each decision time
|
|
196
|
-
# Only compute logits for rows where
|
|
197
|
-
in_study_mask =
|
|
198
|
-
|
|
196
|
+
# Only compute logits for rows where the subject is in the study; set others to NaN
|
|
197
|
+
in_study_mask = analysis_df[active_col_name] == 1
|
|
198
|
+
analysis_df["action_prob_logit"] = np.where(
|
|
199
199
|
in_study_mask,
|
|
200
|
-
logit(
|
|
200
|
+
logit(analysis_df[action_prob_col_name]),
|
|
201
201
|
np.nan,
|
|
202
202
|
)
|
|
203
|
-
grouped_action_prob_logit =
|
|
203
|
+
grouped_action_prob_logit = analysis_df.loc[in_study_mask].groupby(
|
|
204
204
|
calendar_t_col_name
|
|
205
205
|
)["action_prob_logit"]
|
|
206
206
|
action_prob_logit_means_by_t = grouped_action_prob_logit.mean().values
|
|
207
207
|
action_prob_logit_stds_by_t = grouped_action_prob_logit.std().values
|
|
208
208
|
|
|
209
|
-
# Compute the average and standard deviation of the rewards at each decision time using
|
|
209
|
+
# Compute the average and standard deviation of the rewards at each decision time using analysis_df
|
|
210
210
|
# reward_means and reward_stds are numpy arrays of mean and stddev at each decision time
|
|
211
|
-
grouped_reward =
|
|
211
|
+
grouped_reward = analysis_df.loc[in_study_mask].groupby(calendar_t_col_name)[
|
|
212
212
|
reward_col_name
|
|
213
213
|
]
|
|
214
214
|
reward_means_by_t = grouped_reward.mean().values
|
|
215
215
|
reward_stds_by_t = grouped_reward.std().values
|
|
216
216
|
|
|
217
|
-
|
|
218
|
-
|
|
217
|
+
joint_bread_min_singular_value = np.linalg.svd(
|
|
218
|
+
joint_adjusted_bread_matrix, compute_uv=False
|
|
219
219
|
)[-1]
|
|
220
220
|
|
|
221
|
-
max_reward =
|
|
221
|
+
max_reward = analysis_df.loc[in_study_mask][reward_col_name].max()
|
|
222
222
|
|
|
223
223
|
norm_avg_estimating_function_stack = np.linalg.norm(avg_estimating_function_stack)
|
|
224
224
|
max_estimating_function_stack_norm = np.max(estimating_function_stack_norms)
|
|
225
225
|
|
|
226
226
|
(
|
|
227
227
|
premature_thetas,
|
|
228
|
-
|
|
228
|
+
premature_adjusted_sandwiches,
|
|
229
229
|
premature_classical_sandwiches,
|
|
230
|
-
|
|
230
|
+
premature_joint_adjusted_bread_condition_numbers,
|
|
231
231
|
premature_avg_inference_estimating_functions,
|
|
232
|
-
) =
|
|
233
|
-
|
|
232
|
+
) = calculate_sequence_of_premature_adjusted_estimates(
|
|
233
|
+
analysis_df,
|
|
234
234
|
initial_policy_num,
|
|
235
235
|
beta_index_by_policy_num,
|
|
236
|
-
|
|
236
|
+
policy_num_by_decision_time_by_subject_id,
|
|
237
237
|
theta_calculation_func,
|
|
238
238
|
calendar_t_col_name,
|
|
239
239
|
action_prob_col_name,
|
|
240
|
-
|
|
241
|
-
|
|
240
|
+
subject_id_col_name,
|
|
241
|
+
active_col_name,
|
|
242
242
|
all_post_update_betas,
|
|
243
|
-
|
|
243
|
+
subject_ids,
|
|
244
244
|
action_prob_func,
|
|
245
245
|
action_prob_func_args_beta_index,
|
|
246
246
|
inference_func,
|
|
247
247
|
inference_func_type,
|
|
248
248
|
inference_func_args_theta_index,
|
|
249
249
|
inference_func_args_action_prob_index,
|
|
250
|
-
|
|
250
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
251
251
|
action_prob_func_args,
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
252
|
+
action_by_decision_time_by_subject_id,
|
|
253
|
+
joint_adjusted_bread_matrix,
|
|
254
|
+
per_subject_estimating_function_stacks,
|
|
255
255
|
beta_dim,
|
|
256
256
|
)
|
|
257
257
|
|
|
@@ -261,54 +261,54 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
261
261
|
atol=1e-3,
|
|
262
262
|
)
|
|
263
263
|
|
|
264
|
-
# Plot premature joint
|
|
264
|
+
# Plot premature joint adjusted bread log condition numbers
|
|
265
265
|
plt.clear_figure()
|
|
266
|
-
plt.title("Premature Joint
|
|
266
|
+
plt.title("Premature Joint Adjusted Bread Inverse Log Condition Numbers")
|
|
267
267
|
plt.xlabel("Premature Update Index")
|
|
268
268
|
plt.ylabel("Log Condition Number")
|
|
269
269
|
plt.scatter(
|
|
270
|
-
np.log(
|
|
270
|
+
np.log(premature_joint_adjusted_bread_condition_numbers),
|
|
271
271
|
color="blue+",
|
|
272
272
|
)
|
|
273
273
|
plt.grid(True)
|
|
274
274
|
plt.xticks(
|
|
275
275
|
range(
|
|
276
276
|
0,
|
|
277
|
-
len(
|
|
277
|
+
len(premature_joint_adjusted_bread_condition_numbers),
|
|
278
278
|
max(
|
|
279
279
|
1,
|
|
280
|
-
len(
|
|
280
|
+
len(premature_joint_adjusted_bread_condition_numbers) // 10,
|
|
281
281
|
),
|
|
282
282
|
)
|
|
283
283
|
)
|
|
284
284
|
plt.show()
|
|
285
285
|
|
|
286
|
-
# Plot each diagonal element of premature
|
|
287
|
-
num_diag =
|
|
286
|
+
# Plot each diagonal element of premature adjusted sandwiches
|
|
287
|
+
num_diag = premature_adjusted_sandwiches.shape[-1]
|
|
288
288
|
for i in range(num_diag):
|
|
289
289
|
plt.clear_figure()
|
|
290
|
-
plt.title(f"Premature
|
|
290
|
+
plt.title(f"Premature Adjusted Sandwich Diagonal Element {i}")
|
|
291
291
|
plt.xlabel("Premature Update Index")
|
|
292
292
|
plt.ylabel(f"Variance (Diagonal {i})")
|
|
293
|
-
plt.scatter(np.array(
|
|
293
|
+
plt.scatter(np.array(premature_adjusted_sandwiches[:, i, i]), color="blue+")
|
|
294
294
|
plt.grid(True)
|
|
295
295
|
plt.xticks(
|
|
296
296
|
range(
|
|
297
297
|
0,
|
|
298
|
-
int(
|
|
299
|
-
max(1, int(
|
|
298
|
+
int(premature_adjusted_sandwiches.shape[0]),
|
|
299
|
+
max(1, int(premature_adjusted_sandwiches.shape[0]) // 10),
|
|
300
300
|
)
|
|
301
301
|
)
|
|
302
302
|
plt.show()
|
|
303
303
|
|
|
304
304
|
plt.clear_figure()
|
|
305
305
|
plt.title(
|
|
306
|
-
f"Premature
|
|
306
|
+
f"Premature Adjusted Sandwich Diagonal Element {i} Ratio to Classical"
|
|
307
307
|
)
|
|
308
308
|
plt.xlabel("Premature Update Index")
|
|
309
309
|
plt.ylabel(f"Variance (Diagonal {i})")
|
|
310
310
|
plt.scatter(
|
|
311
|
-
np.array(
|
|
311
|
+
np.array(premature_adjusted_sandwiches[:, i, i])
|
|
312
312
|
/ np.array(premature_classical_sandwiches[:, i, i]),
|
|
313
313
|
color="red+",
|
|
314
314
|
)
|
|
@@ -316,8 +316,8 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
316
316
|
plt.xticks(
|
|
317
317
|
range(
|
|
318
318
|
0,
|
|
319
|
-
int(
|
|
320
|
-
max(1, int(
|
|
319
|
+
int(premature_adjusted_sandwiches.shape[0]),
|
|
320
|
+
max(1, int(premature_adjusted_sandwiches.shape[0]) // 10),
|
|
321
321
|
)
|
|
322
322
|
)
|
|
323
323
|
plt.show()
|
|
@@ -331,14 +331,14 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
331
331
|
plt.xticks(
|
|
332
332
|
range(
|
|
333
333
|
0,
|
|
334
|
-
int(
|
|
335
|
-
max(1, int(
|
|
334
|
+
int(premature_adjusted_sandwiches.shape[0]),
|
|
335
|
+
max(1, int(premature_adjusted_sandwiches.shape[0]) // 10),
|
|
336
336
|
)
|
|
337
337
|
)
|
|
338
338
|
plt.show()
|
|
339
339
|
|
|
340
340
|
# Grab predictors related to premature Phi-dot-bars
|
|
341
|
-
RL_stack_beta_derivatives_block =
|
|
341
|
+
RL_stack_beta_derivatives_block = joint_adjusted_bread_matrix[
|
|
342
342
|
:-theta_dim, :-theta_dim
|
|
343
343
|
]
|
|
344
344
|
num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
|
|
@@ -397,14 +397,14 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
397
397
|
)
|
|
398
398
|
return {
|
|
399
399
|
**{
|
|
400
|
-
"
|
|
401
|
-
"
|
|
400
|
+
"joint_bread_condition_number": joint_adjusted_bread_cond,
|
|
401
|
+
"joint_bread_min_singular_value": joint_bread_min_singular_value,
|
|
402
402
|
"max_reward": max_reward,
|
|
403
403
|
"norm_avg_estimating_function_stack": norm_avg_estimating_function_stack,
|
|
404
404
|
"max_estimating_function_stack_norm": max_estimating_function_stack_norm,
|
|
405
405
|
"max_successive_beta_diff_norm": max_successive_beta_diff_norm,
|
|
406
406
|
"std_successive_beta_diff_norm": std_successive_beta_diff_norm,
|
|
407
|
-
"label":
|
|
407
|
+
"label": adjusted_sandwich_var_estimate,
|
|
408
408
|
},
|
|
409
409
|
**{
|
|
410
410
|
f"off_diag_block_{i}_{j}_norm": off_diag_block_norms[(i, j)]
|
|
@@ -422,10 +422,10 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
422
422
|
for i in range(num_block_rows_cols)
|
|
423
423
|
},
|
|
424
424
|
**{
|
|
425
|
-
f"
|
|
425
|
+
f"estimating_function_stack_norm_subject_{subject_id}": estimating_function_stack_norms[
|
|
426
426
|
i
|
|
427
427
|
]
|
|
428
|
-
for i,
|
|
428
|
+
for i, subject_id in enumerate(subject_ids)
|
|
429
429
|
},
|
|
430
430
|
**{
|
|
431
431
|
f"avg_estimating_function_stack_norm_segment_{i}": avg_estimating_function_stack_norms_per_segment[
|
|
@@ -455,18 +455,16 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
455
455
|
},
|
|
456
456
|
**{f"theta_est_{i}": theta_est[i].item() for i in range(len(theta_est))},
|
|
457
457
|
**{
|
|
458
|
-
f"
|
|
458
|
+
f"premature_joint_adjusted_bread_condition_number_{i}": premature_joint_adjusted_bread_condition_numbers[
|
|
459
459
|
i
|
|
460
460
|
]
|
|
461
|
-
for i in range(
|
|
462
|
-
len(premature_joint_adaptive_bread_inverse_condition_numbers)
|
|
463
|
-
)
|
|
461
|
+
for i in range(len(premature_joint_adjusted_bread_condition_numbers))
|
|
464
462
|
},
|
|
465
463
|
**{
|
|
466
|
-
f"
|
|
464
|
+
f"premature_adjusted_sandwich_update_{i}_diag_position_{j}": premature_adjusted_sandwich[
|
|
467
465
|
j, j
|
|
468
466
|
]
|
|
469
|
-
for
|
|
467
|
+
for premature_adjusted_sandwich in premature_adjusted_sandwiches
|
|
470
468
|
for j in range(theta_dim)
|
|
471
469
|
},
|
|
472
470
|
**{
|
|
@@ -497,44 +495,46 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
497
495
|
}
|
|
498
496
|
|
|
499
497
|
|
|
500
|
-
def
|
|
501
|
-
|
|
498
|
+
def calculate_sequence_of_premature_adjusted_estimates(
|
|
499
|
+
analysis_df: pd.DataFrame,
|
|
502
500
|
initial_policy_num: int | float,
|
|
503
501
|
beta_index_by_policy_num: dict[int | float, int],
|
|
504
|
-
|
|
502
|
+
policy_num_by_decision_time_by_subject_id: dict[
|
|
505
503
|
collections.abc.Hashable, dict[int, int | float]
|
|
506
504
|
],
|
|
507
505
|
theta_calculation_func: str,
|
|
508
506
|
calendar_t_col_name: str,
|
|
509
507
|
action_prob_col_name: str,
|
|
510
|
-
|
|
511
|
-
|
|
508
|
+
subject_id_col_name: str,
|
|
509
|
+
active_col_name: str,
|
|
512
510
|
all_post_update_betas: jnp.ndarray,
|
|
513
|
-
|
|
511
|
+
subject_ids: jnp.ndarray,
|
|
514
512
|
action_prob_func: str,
|
|
515
513
|
action_prob_func_args_beta_index: int,
|
|
516
514
|
inference_func: str,
|
|
517
515
|
inference_func_type: str,
|
|
518
516
|
inference_func_args_theta_index: int,
|
|
519
517
|
inference_func_args_action_prob_index: int,
|
|
520
|
-
|
|
518
|
+
inference_action_prob_decision_times_by_subject_id: dict[
|
|
521
519
|
collections.abc.Hashable, list[int]
|
|
522
520
|
],
|
|
523
|
-
|
|
521
|
+
action_prob_func_args_by_subject_id_by_decision_time: dict[
|
|
524
522
|
int, dict[collections.abc.Hashable, tuple[Any, ...]]
|
|
525
523
|
],
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
524
|
+
action_by_decision_time_by_subject_id: dict[
|
|
525
|
+
collections.abc.Hashable, dict[int, int]
|
|
526
|
+
],
|
|
527
|
+
full_joint_adjusted_bread_matrix: jnp.ndarray,
|
|
528
|
+
per_subject_estimating_function_stacks: jnp.ndarray,
|
|
529
529
|
beta_dim: int,
|
|
530
530
|
) -> jnp.ndarray:
|
|
531
531
|
"""
|
|
532
|
-
Calculates a sequence of premature
|
|
532
|
+
Calculates a sequence of premature adjusted estimates for the given study DataFrame, where we
|
|
533
533
|
pretend the study ended after each update in sequence. The behavior of this sequence may provide
|
|
534
|
-
insight into the stability of the final
|
|
534
|
+
insight into the stability of the final adjusted estimate.
|
|
535
535
|
|
|
536
536
|
Args:
|
|
537
|
-
|
|
537
|
+
analysis_df (pandas.DataFrame):
|
|
538
538
|
The DataFrame containing the study data.
|
|
539
539
|
initial_policy_num (int | float): The policy number of the initial policy before any updates.
|
|
540
540
|
initial_policy_num (int | float):
|
|
@@ -542,23 +542,23 @@ def calculate_sequence_of_premature_adaptive_estimates(
|
|
|
542
542
|
beta_index_by_policy_num (dict[int | float, int]):
|
|
543
543
|
A dictionary mapping policy numbers to the index of the corresponding beta in
|
|
544
544
|
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
545
|
-
|
|
546
|
-
A map of
|
|
545
|
+
policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
|
|
546
|
+
A map of subject ids to dictionaries mapping decision times to the policy number in use.
|
|
547
547
|
Only applies to in-study decision times!
|
|
548
548
|
theta_calculation_func (callable):
|
|
549
549
|
The theta calculation function; it is called on the (truncated) study DataFrame.
|
|
550
550
|
calendar_t_col_name (str):
|
|
551
|
-
The name of the column in
|
|
551
|
+
The name of the column in analysis_df representing calendar time.
|
|
552
552
|
action_prob_col_name (str):
|
|
553
|
-
The name of the column in
|
|
554
|
-
|
|
555
|
-
The name of the column in
|
|
556
|
-
|
|
557
|
-
The name of the column in
|
|
553
|
+
The name of the column in analysis_df representing action probabilities.
|
|
554
|
+
subject_id_col_name (str):
|
|
555
|
+
The name of the column in analysis_df representing subject IDs.
|
|
556
|
+
active_col_name (str):
|
|
557
|
+
The name of the column in analysis_df indicating whether the subject is in the study at that time.
|
|
558
558
|
all_post_update_betas (jnp.ndarray):
|
|
559
559
|
A NumPy array containing all post-update beta values.
|
|
560
|
-
|
|
561
|
-
A NumPy array containing all
|
|
560
|
+
subject_ids (jnp.ndarray):
|
|
561
|
+
A NumPy array containing all subject IDs in the study.
|
|
562
562
|
action_prob_func (callable):
|
|
563
563
|
The action probability function.
|
|
564
564
|
action_prob_func_args_beta_index (int):
|
|
@@ -572,56 +572,54 @@ def calculate_sequence_of_premature_adaptive_estimates(
|
|
|
572
572
|
inference_func_args_action_prob_index (int):
|
|
573
573
|
The index of action probabilities in the inference function arguments tuple, if
|
|
574
574
|
applicable. -1 otherwise.
|
|
575
|
-
|
|
576
|
-
For each
|
|
575
|
+
inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
|
|
576
|
+
For each subject, a list of decision times to which action probabilities correspond if
|
|
577
577
|
provided. Typically just in-study times if action probabilities are used in the inference
|
|
578
578
|
loss or estimating function.
|
|
579
|
-
|
|
580
|
-
A dictionary mapping decision times to maps of
|
|
581
|
-
required to compute action probabilities for this
|
|
582
|
-
|
|
583
|
-
A dictionary mapping
|
|
579
|
+
action_prob_func_args_by_subject_id_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
|
|
580
|
+
A dictionary mapping decision times to maps of subject ids to the function arguments
|
|
581
|
+
required to compute action probabilities for this subject.
|
|
582
|
+
action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
|
|
583
|
+
A dictionary mapping subject IDs to their respective actions taken at each decision time.
|
|
584
584
|
Only applies to in-study decision times!
|
|
585
|
-
|
|
586
|
-
The full joint
|
|
587
|
-
|
|
588
|
-
A NumPy array containing all per-
|
|
585
|
+
full_joint_adjusted_bread_matrix (jnp.ndarray):
|
|
586
|
+
The full joint adjusted bread matrix as a NumPy array.
|
|
587
|
+
per_subject_estimating_function_stacks (jnp.ndarray):
|
|
588
|
+
A NumPy array containing all per-subject (weighted) estimating function stacks.
|
|
589
589
|
beta_dim (int):
|
|
590
590
|
The dimension of the beta parameters.
|
|
591
591
|
Returns:
|
|
592
|
-
tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: A NumPy array containing the sequence of premature
|
|
592
|
+
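tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: Arrays, in order, of the premature thetas, premature adjusted sandwiches, premature classical sandwiches, premature joint adjusted bread condition numbers, and premature average inference estimating functions.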
|
|
593
593
|
"""
|
|
594
594
|
|
|
595
|
-
# Loop through the non-initial (ie not before an update has occurred), non-final policy numbers in sorted order, forming
|
|
595
|
+
# Loop through the non-initial (i.e., not before an update has occurred), non-final policy numbers in sorted order, forming adjusted and classical
|
|
596
596
|
# variance estimates pretending that each was the final policy.
|
|
597
|
-
|
|
597
|
+
premature_adjusted_sandwiches = []
|
|
598
598
|
premature_thetas = []
|
|
599
|
-
|
|
599
|
+
premature_joint_adjusted_bread_condition_numbers = []
|
|
600
600
|
premature_avg_inference_estimating_functions = []
|
|
601
601
|
premature_classical_sandwiches = []
|
|
602
602
|
logger.info(
|
|
603
|
-
"Calculating sequence of premature
|
|
603
|
+
"Calculating sequence of premature adjusted estimates by pretending the study ended after each update in sequence."
|
|
604
604
|
)
|
|
605
605
|
for policy_num in sorted(beta_index_by_policy_num):
|
|
606
606
|
logger.info(
|
|
607
|
-
"Calculating premature
|
|
607
|
+
"Calculating premature adjusted estimate assuming policy %s is the final one.",
|
|
608
608
|
policy_num,
|
|
609
609
|
)
|
|
610
610
|
pretend_max_policy = policy_num
|
|
611
611
|
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
]
|
|
617
|
-
)
|
|
612
|
+
truncated_joint_adjusted_bread_matrix = full_joint_adjusted_bread_matrix[
|
|
613
|
+
: (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
|
|
614
|
+
: (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
|
|
615
|
+
]
|
|
618
616
|
|
|
619
|
-
max_decision_time =
|
|
620
|
-
|
|
621
|
-
].max()
|
|
617
|
+
max_decision_time = analysis_df[
|
|
618
|
+
analysis_df["policy_num"] == pretend_max_policy
|
|
619
|
+
][calendar_t_col_name].max()
|
|
622
620
|
|
|
623
|
-
|
|
624
|
-
|
|
621
|
+
truncated_analysis_df = analysis_df[
|
|
622
|
+
analysis_df[calendar_t_col_name] <= max_decision_time
|
|
625
623
|
].copy()
|
|
626
624
|
|
|
627
625
|
truncated_beta_index_by_policy_num = {
|
|
@@ -632,83 +630,83 @@ def calculate_sequence_of_premature_adaptive_estimates(
|
|
|
632
630
|
|
|
633
631
|
truncated_all_post_update_betas = all_post_update_betas[: max_beta_index + 1, :]
|
|
634
632
|
|
|
635
|
-
premature_theta = jnp.array(theta_calculation_func(
|
|
633
|
+
premature_theta = jnp.array(theta_calculation_func(truncated_analysis_df))
|
|
636
634
|
|
|
637
|
-
|
|
638
|
-
decision_time:
|
|
639
|
-
for decision_time,
|
|
635
|
+
truncated_action_prob_func_args_by_subject_id_by_decision_time = {
|
|
636
|
+
decision_time: args_by_subject_id
|
|
637
|
+
for decision_time, args_by_subject_id in action_prob_func_args_by_subject_id_by_decision_time.items()
|
|
640
638
|
if decision_time <= max_decision_time
|
|
641
639
|
}
|
|
642
640
|
|
|
643
|
-
|
|
644
|
-
|
|
641
|
+
truncated_inference_func_args_by_subject_id, _, _ = (
|
|
642
|
+
post_deployment_analysis.process_inference_func_args(
|
|
645
643
|
inference_func,
|
|
646
644
|
inference_func_args_theta_index,
|
|
647
|
-
|
|
645
|
+
truncated_analysis_df,
|
|
648
646
|
premature_theta,
|
|
649
647
|
action_prob_col_name,
|
|
650
648
|
calendar_t_col_name,
|
|
651
|
-
|
|
652
|
-
|
|
649
|
+
subject_id_col_name,
|
|
650
|
+
active_col_name,
|
|
653
651
|
)
|
|
654
652
|
)
|
|
655
653
|
|
|
656
|
-
|
|
657
|
-
|
|
654
|
+
truncated_inference_action_prob_decision_times_by_subject_id = {
|
|
655
|
+
subject_id: [
|
|
658
656
|
decision_time
|
|
659
|
-
for decision_time in
|
|
660
|
-
|
|
657
|
+
for decision_time in inference_action_prob_decision_times_by_subject_id[
|
|
658
|
+
subject_id
|
|
661
659
|
]
|
|
662
660
|
if decision_time <= max_decision_time
|
|
663
661
|
]
|
|
664
662
|
# writing this way is important, handles empty dicts correctly
|
|
665
|
-
for
|
|
663
|
+
for subject_id in inference_action_prob_decision_times_by_subject_id
|
|
666
664
|
}
|
|
667
665
|
|
|
668
|
-
|
|
669
|
-
|
|
666
|
+
truncated_action_by_decision_time_by_subject_id = {
|
|
667
|
+
subject_id: {
|
|
670
668
|
decision_time: action
|
|
671
|
-
for decision_time, action in
|
|
672
|
-
|
|
669
|
+
for decision_time, action in action_by_decision_time_by_subject_id[
|
|
670
|
+
subject_id
|
|
673
671
|
].items()
|
|
674
672
|
if decision_time <= max_decision_time
|
|
675
673
|
}
|
|
676
|
-
for
|
|
674
|
+
for subject_id in action_by_decision_time_by_subject_id
|
|
677
675
|
}
|
|
678
676
|
|
|
679
|
-
|
|
680
|
-
|
|
677
|
+
truncated_per_subject_estimating_function_stacks = (
|
|
678
|
+
per_subject_estimating_function_stacks[
|
|
681
679
|
:,
|
|
682
680
|
: (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
|
|
683
681
|
]
|
|
684
682
|
)
|
|
685
683
|
|
|
686
684
|
(
|
|
687
|
-
|
|
685
|
+
premature_adjusted_sandwich,
|
|
688
686
|
premature_classical_sandwich,
|
|
689
687
|
premature_avg_inference_estimating_function,
|
|
690
|
-
) =
|
|
691
|
-
|
|
692
|
-
|
|
688
|
+
) = construct_premature_classical_and_adjusted_sandwiches(
|
|
689
|
+
truncated_joint_adjusted_bread_matrix,
|
|
690
|
+
truncated_per_subject_estimating_function_stacks,
|
|
693
691
|
premature_theta,
|
|
694
692
|
truncated_all_post_update_betas,
|
|
695
|
-
|
|
693
|
+
subject_ids,
|
|
696
694
|
action_prob_func,
|
|
697
695
|
action_prob_func_args_beta_index,
|
|
698
696
|
inference_func,
|
|
699
697
|
inference_func_type,
|
|
700
698
|
inference_func_args_theta_index,
|
|
701
699
|
inference_func_args_action_prob_index,
|
|
702
|
-
|
|
703
|
-
|
|
700
|
+
truncated_action_prob_func_args_by_subject_id_by_decision_time,
|
|
701
|
+
policy_num_by_decision_time_by_subject_id,
|
|
704
702
|
initial_policy_num,
|
|
705
703
|
truncated_beta_index_by_policy_num,
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
704
|
+
truncated_inference_func_args_by_subject_id,
|
|
705
|
+
truncated_inference_action_prob_decision_times_by_subject_id,
|
|
706
|
+
truncated_action_by_decision_time_by_subject_id,
|
|
709
707
|
)
|
|
710
708
|
|
|
711
|
-
|
|
709
|
+
premature_adjusted_sandwiches.append(premature_adjusted_sandwich)
|
|
712
710
|
premature_classical_sandwiches.append(premature_classical_sandwich)
|
|
713
711
|
premature_thetas.append(premature_theta)
|
|
714
712
|
premature_avg_inference_estimating_functions.append(
|
|
@@ -716,38 +714,40 @@ def calculate_sequence_of_premature_adaptive_estimates(
|
|
|
716
714
|
)
|
|
717
715
|
return (
|
|
718
716
|
jnp.array(premature_thetas),
|
|
719
|
-
jnp.array(
|
|
717
|
+
jnp.array(premature_adjusted_sandwiches),
|
|
720
718
|
jnp.array(premature_classical_sandwiches),
|
|
721
|
-
jnp.array(
|
|
719
|
+
jnp.array(premature_joint_adjusted_bread_condition_numbers),
|
|
722
720
|
jnp.array(premature_avg_inference_estimating_functions),
|
|
723
721
|
)
|
|
724
722
|
|
|
725
723
|
|
|
726
|
-
def
|
|
727
|
-
|
|
728
|
-
|
|
724
|
+
def construct_premature_classical_and_adjusted_sandwiches(
|
|
725
|
+
truncated_joint_adjusted_bread_matrix: jnp.ndarray,
|
|
726
|
+
per_subject_truncated_estimating_function_stacks: jnp.ndarray,
|
|
729
727
|
theta: jnp.ndarray,
|
|
730
728
|
all_post_update_betas: jnp.ndarray,
|
|
731
|
-
|
|
729
|
+
subject_ids: jnp.ndarray,
|
|
732
730
|
action_prob_func: str,
|
|
733
731
|
action_prob_func_args_beta_index: int,
|
|
734
732
|
inference_func: str,
|
|
735
733
|
inference_func_type: str,
|
|
736
734
|
inference_func_args_theta_index: int,
|
|
737
735
|
inference_func_args_action_prob_index: int,
|
|
738
|
-
|
|
736
|
+
action_prob_func_args_by_subject_id_by_decision_time: dict[
|
|
739
737
|
collections.abc.Hashable, dict[int, tuple[Any, ...]]
|
|
740
738
|
],
|
|
741
|
-
|
|
739
|
+
policy_num_by_decision_time_by_subject_id: dict[
|
|
742
740
|
collections.abc.Hashable, dict[int, int | float]
|
|
743
741
|
],
|
|
744
742
|
initial_policy_num: int | float,
|
|
745
743
|
beta_index_by_policy_num: dict[int | float, int],
|
|
746
|
-
|
|
747
|
-
|
|
744
|
+
inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
|
|
745
|
+
inference_action_prob_decision_times_by_subject_id: dict[
|
|
748
746
|
collections.abc.Hashable, list[int]
|
|
749
747
|
],
|
|
750
|
-
|
|
748
|
+
action_by_decision_time_by_subject_id: dict[
|
|
749
|
+
collections.abc.Hashable, dict[int, int]
|
|
750
|
+
],
|
|
751
751
|
) -> tuple[
|
|
752
752
|
jnp.ndarray[jnp.float32],
|
|
753
753
|
jnp.ndarray[jnp.float32],
|
|
@@ -759,33 +759,33 @@ def construct_premature_classical_and_adaptive_sandwiches(
|
|
|
759
759
|
jnp.ndarray[jnp.float32],
|
|
760
760
|
]:
|
|
761
761
|
"""
|
|
762
|
-
Constructs the classical bread and meat matrices, as well as the
|
|
762
|
+
Constructs the classical bread and meat matrices, as well as the adjusted bread matrix
|
|
763
763
|
and the average weighted inference estimating function for the premature variance estimation
|
|
764
764
|
procedure.
|
|
765
765
|
|
|
766
766
|
This is done by computing and differentiating the new average inference estimating function
|
|
767
767
|
with respect to the betas and theta, and stitching this together with the existing
|
|
768
|
-
|
|
769
|
-
to form the new premature joint
|
|
768
|
+
adjusted bread matrix portion (corresponding to the updates still under consideration)
|
|
769
|
+
to form the new premature joint adjusted bread matrix.
|
|
770
770
|
|
|
771
771
|
Args:
|
|
772
|
-
|
|
773
|
-
A 2-D JAX NumPy array holding the existing joint
|
|
772
|
+
truncated_joint_adjusted_bread_matrix (jnp.ndarray):
|
|
773
|
+
A 2-D JAX NumPy array holding the existing joint adjusted bread but
|
|
774
774
|
with rows corresponding to updates not under consideration and inference dropped.
|
|
775
775
|
We will stitch this together with the newly computed inference portion to form
|
|
776
|
-
our "premature" joint
|
|
777
|
-
|
|
778
|
-
A 2-D JAX NumPy array holding the existing per-
|
|
776
|
+
our "premature" joint adjusted bread matrix.
|
|
777
|
+
per_subject_truncated_estimating_function_stacks (jnp.ndarray):
|
|
778
|
+
A 2-D JAX NumPy array holding the existing per-subject weighted estimating function
|
|
779
779
|
stacks but with rows corresponding to updates not under consideration dropped.
|
|
780
780
|
We will stitch this together with the newly computed inference estimating functions
|
|
781
|
-
to form our "premature" joint
|
|
782
|
-
|
|
781
|
+
to form our "premature" joint adjusted estimating function stacks from which the new
|
|
782
|
+
adjusted meat matrix can be computed.
|
|
783
783
|
theta (jnp.ndarray):
|
|
784
784
|
A 1-D JAX NumPy array representing the parameter estimate for inference.
|
|
785
785
|
all_post_update_betas (jnp.ndarray):
|
|
786
786
|
A 2-D JAX NumPy array representing all parameter estimates for the algorithm updates.
|
|
787
|
-
|
|
788
|
-
A 1-D JAX NumPy array holding all
|
|
787
|
+
subject_ids (jnp.ndarray):
|
|
788
|
+
A 1-D JAX NumPy array holding all subject IDs in the study.
|
|
789
789
|
action_prob_func (callable):
|
|
790
790
|
The action probability function.
|
|
791
791
|
action_prob_func_args_beta_index (int):
|
|
@@ -799,81 +799,81 @@ def construct_premature_classical_and_adaptive_sandwiches(
|
|
|
799
799
|
inference_func_args_action_prob_index (int):
|
|
800
800
|
The index of action probabilities in the inference function arguments tuple, if
|
|
801
801
|
applicable. -1 otherwise.
|
|
802
|
-
|
|
803
|
-
A dictionary mapping decision times to maps of
|
|
804
|
-
required to compute action probabilities for this
|
|
805
|
-
|
|
806
|
-
A map of
|
|
802
|
+
action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
|
|
803
|
+
A dictionary mapping decision times to maps of subject ids to the function arguments
|
|
804
|
+
required to compute action probabilities for this subject.
|
|
805
|
+
policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
|
|
806
|
+
A map of subject ids to dictionaries mapping decision times to the policy number in use.
|
|
807
807
|
Only applies to in-study decision times!
|
|
808
808
|
initial_policy_num (int | float):
|
|
809
809
|
The policy number of the initial policy before any updates.
|
|
810
810
|
beta_index_by_policy_num (dict[int | float, int]):
|
|
811
811
|
A dictionary mapping policy numbers to the index of the corresponding beta in
|
|
812
812
|
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
813
|
-
|
|
814
|
-
A dictionary mapping
|
|
815
|
-
|
|
816
|
-
For each
|
|
813
|
+
inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
|
|
814
|
+
A dictionary mapping subject IDs to their respective inference function arguments.
|
|
815
|
+
inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
|
|
816
|
+
For each subject, a list of decision times to which action probabilities correspond if
|
|
817
817
|
provided. Typically just in-study times if action probabilities are used in the inference
|
|
818
818
|
loss or estimating function.
|
|
819
|
-
|
|
820
|
-
A dictionary mapping
|
|
819
|
+
action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
|
|
820
|
+
A dictionary mapping subject IDs to their respective actions taken at each decision time.
|
|
821
821
|
Only applies to in-study decision times!
|
|
822
822
|
Returns:
|
|
823
823
|
tuple[jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32],
|
|
824
824
|
jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32],
|
|
825
825
|
jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32]]:
|
|
826
826
|
A tuple containing:
|
|
827
|
-
- The joint
|
|
828
|
-
- The joint
|
|
829
|
-
- The joint
|
|
830
|
-
- The classical
|
|
827
|
+
- The joint adjusted bread matrix.
|
|
828
|
+
- The joint adjusted bread matrix.
|
|
829
|
+
- The joint adjusted meat matrix.
|
|
830
|
+
- The classical bread matrix.
|
|
831
831
|
- The classical bread matrix.
|
|
832
832
|
- The classical meat matrix.
|
|
833
833
|
- The average (weighted) inference estimating function.
|
|
834
|
-
- The joint
|
|
834
|
+
- The joint adjusted bread matrix condition number.
|
|
835
835
|
"""
|
|
836
836
|
logger.info(
|
|
837
837
|
"Differentiating average weighted inference estimating function stack and collecting auxiliary values."
|
|
838
838
|
)
|
|
839
839
|
# jax.jacobian may perform worse here--seemed to hang indefinitely while jacrev is merely very
|
|
840
840
|
# slow.
|
|
841
|
-
# Note that these "contributions" are per-
|
|
841
|
+
# Note that these "contributions" are per-subject Jacobians of the weighted estimating function stack.
|
|
842
842
|
new_inference_block_row, (
|
|
843
|
-
|
|
843
|
+
per_subject_inference_estimating_functions,
|
|
844
844
|
avg_inference_estimating_function,
|
|
845
|
-
|
|
846
|
-
|
|
845
|
+
per_subject_classical_meat_contributions,
|
|
846
|
+
per_subject_classical_bread_contributions,
|
|
847
847
|
) = jax.jacrev(get_weighted_inference_estimating_functions_only, has_aux=True)(
|
|
848
848
|
# While JAX can technically differentiate with respect to a list of JAX arrays,
|
|
849
849
|
# it is more efficient to flatten them into a single array. This is done
|
|
850
850
|
# here to improve performance. We can simply unflatten them inside the function.
|
|
851
|
-
|
|
851
|
+
post_deployment_analysis.flatten_params(all_post_update_betas, theta),
|
|
852
852
|
all_post_update_betas.shape[1],
|
|
853
853
|
theta.shape[0],
|
|
854
|
-
|
|
854
|
+
subject_ids,
|
|
855
855
|
action_prob_func,
|
|
856
856
|
action_prob_func_args_beta_index,
|
|
857
857
|
inference_func,
|
|
858
858
|
inference_func_type,
|
|
859
859
|
inference_func_args_theta_index,
|
|
860
860
|
inference_func_args_action_prob_index,
|
|
861
|
-
|
|
862
|
-
|
|
861
|
+
action_prob_func_args_by_subject_id_by_decision_time,
|
|
862
|
+
policy_num_by_decision_time_by_subject_id,
|
|
863
863
|
initial_policy_num,
|
|
864
864
|
beta_index_by_policy_num,
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
865
|
+
inference_func_args_by_subject_id,
|
|
866
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
867
|
+
action_by_decision_time_by_subject_id,
|
|
868
868
|
)
|
|
869
869
|
|
|
870
|
-
|
|
870
|
+
joint_adjusted_bread_matrix = jnp.block(
|
|
871
871
|
[
|
|
872
872
|
[
|
|
873
|
-
|
|
873
|
+
truncated_joint_adjusted_bread_matrix,
|
|
874
874
|
np.zeros(
|
|
875
875
|
(
|
|
876
|
-
|
|
876
|
+
truncated_joint_adjusted_bread_matrix.shape[0],
|
|
877
877
|
new_inference_block_row.shape[0],
|
|
878
878
|
)
|
|
879
879
|
),
|
|
@@ -881,51 +881,49 @@ def construct_premature_classical_and_adaptive_sandwiches(
|
|
|
881
881
|
[new_inference_block_row],
|
|
882
882
|
]
|
|
883
883
|
)
|
|
884
|
-
|
|
884
|
+
per_subject_estimating_function_stacks = jnp.concatenate(
|
|
885
885
|
[
|
|
886
|
-
|
|
887
|
-
|
|
886
|
+
per_subject_truncated_estimating_function_stacks,
|
|
887
|
+
per_subject_inference_estimating_functions,
|
|
888
888
|
],
|
|
889
889
|
axis=1,
|
|
890
890
|
)
|
|
891
|
-
|
|
891
|
+
per_subject_adjusted_meat_contributions = jnp.einsum(
|
|
892
892
|
"ni,nj->nij",
|
|
893
|
-
|
|
894
|
-
|
|
893
|
+
per_subject_estimating_function_stacks,
|
|
894
|
+
per_subject_estimating_function_stacks,
|
|
895
895
|
)
|
|
896
896
|
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
classical_bread_inverse_matrix = jnp.mean(
|
|
900
|
-
per_user_classical_bread_inverse_contributions, axis=0
|
|
897
|
+
joint_adjusted_meat_matrix = jnp.mean(
|
|
898
|
+
per_subject_adjusted_meat_contributions, axis=0
|
|
901
899
|
)
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
900
|
+
|
|
901
|
+
classical_bread_matrix = jnp.mean(per_subject_classical_bread_contributions, axis=0)
|
|
902
|
+
classical_meat_matrix = jnp.mean(per_subject_classical_meat_contributions, axis=0)
|
|
903
|
+
|
|
904
|
+
num_subjects = subject_ids.shape[0]
|
|
905
|
+
joint_adjusted_sandwich = (
|
|
906
|
+
post_deployment_analysis.form_sandwich_from_bread_and_meat(
|
|
907
|
+
joint_adjusted_bread_matrix,
|
|
908
|
+
joint_adjusted_meat_matrix,
|
|
909
|
+
num_subjects,
|
|
910
|
+
method="bread_T_qr",
|
|
911
911
|
)
|
|
912
912
|
)
|
|
913
|
-
|
|
913
|
+
adjusted_sandwich = joint_adjusted_sandwich[-theta.shape[0] :, -theta.shape[0] :]
|
|
914
914
|
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
classical_sandwich = after_study_analysis.form_sandwich_from_bread_inverse_and_meat(
|
|
919
|
-
classical_bread_inverse_matrix,
|
|
915
|
+
classical_bread_matrix = jnp.mean(per_subject_classical_bread_contributions, axis=0)
|
|
916
|
+
classical_sandwich = post_deployment_analysis.form_sandwich_from_bread_and_meat(
|
|
917
|
+
classical_bread_matrix,
|
|
920
918
|
classical_meat_matrix,
|
|
921
|
-
|
|
922
|
-
method="
|
|
919
|
+
num_subjects,
|
|
920
|
+
method="bread_T_qr",
|
|
923
921
|
)
|
|
924
922
|
|
|
925
|
-
# Stack the joint
|
|
926
|
-
# values too. The joint
|
|
923
|
+
# Stack the joint adjusted bread pieces together horizontally and return the auxiliary
|
|
924
|
+
# values too. The joint adjusted bread should always be block lower triangular.
|
|
927
925
|
return (
|
|
928
|
-
|
|
926
|
+
adjusted_sandwich,
|
|
929
927
|
classical_sandwich,
|
|
930
928
|
avg_inference_estimating_function,
|
|
931
929
|
)
|
|
@@ -935,32 +933,34 @@ def get_weighted_inference_estimating_functions_only(
|
|
|
935
933
|
flattened_betas_and_theta: jnp.ndarray,
|
|
936
934
|
beta_dim: int,
|
|
937
935
|
theta_dim: int,
|
|
938
|
-
|
|
936
|
+
subject_ids: jnp.ndarray,
|
|
939
937
|
action_prob_func: callable,
|
|
940
938
|
action_prob_func_args_beta_index: int,
|
|
941
939
|
inference_func: callable,
|
|
942
940
|
inference_func_type: str,
|
|
943
941
|
inference_func_args_theta_index: int,
|
|
944
942
|
inference_func_args_action_prob_index: int,
|
|
945
|
-
|
|
943
|
+
action_prob_func_args_by_subject_id_by_decision_time: dict[
|
|
946
944
|
collections.abc.Hashable, dict[int, tuple[Any, ...]]
|
|
947
945
|
],
|
|
948
|
-
|
|
946
|
+
policy_num_by_decision_time_by_subject_id: dict[
|
|
949
947
|
collections.abc.Hashable, dict[int, int | float]
|
|
950
948
|
],
|
|
951
949
|
initial_policy_num: int | float,
|
|
952
950
|
beta_index_by_policy_num: dict[int | float, int],
|
|
953
|
-
|
|
954
|
-
|
|
951
|
+
inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
|
|
952
|
+
inference_action_prob_decision_times_by_subject_id: dict[
|
|
955
953
|
collections.abc.Hashable, list[int]
|
|
956
954
|
],
|
|
957
|
-
|
|
955
|
+
action_by_decision_time_by_subject_id: dict[
|
|
956
|
+
collections.abc.Hashable, dict[int, int]
|
|
957
|
+
],
|
|
958
958
|
) -> tuple[
|
|
959
959
|
jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
|
|
960
960
|
]:
|
|
961
961
|
"""
|
|
962
|
-
Computes the average weighted inference estimating function across
|
|
963
|
-
auxiliary values used to construct the
|
|
962
|
+
Computes the average weighted inference estimating function across subjects, along with
|
|
963
|
+
auxiliary values used to construct the adjusted and classical sandwich variances.
|
|
964
964
|
|
|
965
965
|
Note that input data should have been adjusted to only correspond to updates/decision times
|
|
966
966
|
that are being considered for the current "premature" variance estimation procedure.
|
|
@@ -974,8 +974,8 @@ def get_weighted_inference_estimating_functions_only(
|
|
|
974
974
|
The dimension of each of the beta parameters.
|
|
975
975
|
theta_dim (int):
|
|
976
976
|
The dimension of the theta parameter.
|
|
977
|
-
|
|
978
|
-
A 1D JAX NumPy array of
|
|
977
|
+
subject_ids (jnp.ndarray):
|
|
978
|
+
A 1D JAX NumPy array of subject IDs.
|
|
979
979
|
action_prob_func (str):
|
|
980
980
|
The action probability function.
|
|
981
981
|
action_prob_func_args_beta_index (int):
|
|
@@ -989,25 +989,25 @@ def get_weighted_inference_estimating_functions_only(
|
|
|
989
989
|
inference_func_args_action_prob_index (int):
|
|
990
990
|
The index of action probabilities in the inference function arguments tuple, if
|
|
991
991
|
applicable. -1 otherwise.
|
|
992
|
-
|
|
993
|
-
A dictionary mapping decision times to maps of
|
|
994
|
-
required to compute action probabilities for this
|
|
995
|
-
|
|
996
|
-
A map of
|
|
992
|
+
action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
|
|
993
|
+
A dictionary mapping decision times to maps of subject ids to the function arguments
|
|
994
|
+
required to compute action probabilities for this subject.
|
|
995
|
+
policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
|
|
996
|
+
A map of subject ids to dictionaries mapping decision times to the policy number in use.
|
|
997
997
|
Only applies to in-study decision times!
|
|
998
998
|
initial_policy_num (int | float):
|
|
999
999
|
The policy number of the initial policy before any updates.
|
|
1000
1000
|
beta_index_by_policy_num (dict[int | float, int]):
|
|
1001
1001
|
A dictionary mapping policy numbers to the index of the corresponding beta in
|
|
1002
1002
|
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
1003
|
-
|
|
1004
|
-
A dictionary mapping
|
|
1005
|
-
|
|
1006
|
-
For each
|
|
1003
|
+
inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
|
|
1004
|
+
A dictionary mapping subject IDs to their respective inference function arguments.
|
|
1005
|
+
inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
|
|
1006
|
+
For each subject, a list of decision times to which action probabilities correspond if
|
|
1007
1007
|
provided. Typically just in-study times if action probabilities are used in the inference
|
|
1008
1008
|
loss or estimating function.
|
|
1009
|
-
|
|
1010
|
-
A dictionary mapping
|
|
1009
|
+
action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
|
|
1010
|
+
A dictionary mapping subject IDs to their respective actions taken at each decision time.
|
|
1011
1011
|
Only applies to in-study decision times!
|
|
1012
1012
|
|
|
1013
1013
|
Returns:
|
|
@@ -1015,10 +1015,10 @@ def get_weighted_inference_estimating_functions_only(
|
|
|
1015
1015
|
A 2D JAX NumPy array holding the average weighted inference estimating function.
|
|
1016
1016
|
tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
|
|
1017
1017
|
A tuple containing
|
|
1018
|
-
1. the per-
|
|
1018
|
+
1. the per-subject weighted inference estimating function stacks
|
|
1019
1019
|
2. the average weighted inference estimating function
|
|
1020
|
-
3. the
|
|
1021
|
-
4. the
|
|
1020
|
+
3. the subject-level classical meat matrix contributions
|
|
1021
|
+
4. the subject-level inverse classical bread matrix contributions
|
|
1022
1022
|
stacks.
|
|
1023
1023
|
"""
|
|
1024
1024
|
|
|
@@ -1028,7 +1028,7 @@ def get_weighted_inference_estimating_functions_only(
|
|
|
1028
1028
|
else inference_func
|
|
1029
1029
|
)
|
|
1030
1030
|
|
|
1031
|
-
betas, theta =
|
|
1031
|
+
betas, theta = post_deployment_analysis.unflatten_params(
|
|
1032
1032
|
flattened_betas_and_theta,
|
|
1033
1033
|
beta_dim,
|
|
1034
1034
|
theta_dim,
|
|
@@ -1038,15 +1038,15 @@ def get_weighted_inference_estimating_functions_only(
|
|
|
1038
1038
|
# supplied for the above functions, so that differentiation works correctly. The existing
|
|
1039
1039
|
# values should be the same, but not connected to the parameter we are differentiating
|
|
1040
1040
|
# with respect to. Note we will also find it useful below to have the action probability args
|
|
1041
|
-
# nested dict structure flipped to be
|
|
1041
|
+
# nested dict structure flipped to be subject_id -> decision_time -> args, so we do that here too.
|
|
1042
1042
|
|
|
1043
|
-
logger.info("Threading in betas to action probability arguments for all
|
|
1043
|
+
logger.info("Threading in betas to action probability arguments for all subjects.")
|
|
1044
1044
|
(
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
) =
|
|
1048
|
-
|
|
1049
|
-
|
|
1045
|
+
threaded_action_prob_func_args_by_decision_time_by_subject_id,
|
|
1046
|
+
action_prob_func_args_by_decision_time_by_subject_id,
|
|
1047
|
+
) = post_deployment_analysis.thread_action_prob_func_args(
|
|
1048
|
+
action_prob_func_args_by_subject_id_by_decision_time,
|
|
1049
|
+
policy_num_by_decision_time_by_subject_id,
|
|
1050
1050
|
initial_policy_num,
|
|
1051
1051
|
betas,
|
|
1052
1052
|
beta_index_by_policy_num,
|
|
@@ -1058,38 +1058,38 @@ def get_weighted_inference_estimating_functions_only(
|
|
|
1058
1058
|
# arguments with the central betas introduced.
|
|
1059
1059
|
logger.info(
|
|
1060
1060
|
"Threading in theta and beta-dependent action probabilities to inference update "
|
|
1061
|
-
"function args for all
|
|
1061
|
+
"function args for all subjects"
|
|
1062
1062
|
)
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1063
|
+
threaded_inference_func_args_by_subject_id = (
|
|
1064
|
+
post_deployment_analysis.thread_inference_func_args(
|
|
1065
|
+
inference_func_args_by_subject_id,
|
|
1066
1066
|
inference_func_args_theta_index,
|
|
1067
1067
|
theta,
|
|
1068
1068
|
inference_func_args_action_prob_index,
|
|
1069
|
-
|
|
1070
|
-
|
|
1069
|
+
threaded_action_prob_func_args_by_decision_time_by_subject_id,
|
|
1070
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
1071
1071
|
action_prob_func,
|
|
1072
1072
|
)
|
|
1073
1073
|
)
|
|
1074
1074
|
|
|
1075
|
-
# 5. Now we can compute the weighted inference estimating functions for all
|
|
1076
|
-
# as well as collect related values used to construct the
|
|
1075
|
+
# 5. Now we can compute the weighted inference estimating functions for all subjects
|
|
1076
|
+
# as well as collect related values used to construct the adjusted and classical
|
|
1077
1077
|
# sandwich variances.
|
|
1078
1078
|
results = [
|
|
1079
|
-
|
|
1080
|
-
|
|
1079
|
+
single_subject_weighted_inference_estimating_function(
|
|
1080
|
+
subject_id,
|
|
1081
1081
|
action_prob_func,
|
|
1082
1082
|
inference_estimating_func,
|
|
1083
1083
|
action_prob_func_args_beta_index,
|
|
1084
1084
|
inference_func_args_theta_index,
|
|
1085
|
-
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1085
|
+
action_prob_func_args_by_decision_time_by_subject_id[subject_id],
|
|
1086
|
+
threaded_action_prob_func_args_by_decision_time_by_subject_id[subject_id],
|
|
1087
|
+
threaded_inference_func_args_by_subject_id[subject_id],
|
|
1088
|
+
policy_num_by_decision_time_by_subject_id[subject_id],
|
|
1089
|
+
action_by_decision_time_by_subject_id[subject_id],
|
|
1090
1090
|
beta_index_by_policy_num,
|
|
1091
1091
|
)
|
|
1092
|
-
for
|
|
1092
|
+
for subject_id in subject_ids.tolist()
|
|
1093
1093
|
]
|
|
1094
1094
|
|
|
1095
1095
|
weighted_inference_estimating_functions = jnp.array(
|
|
@@ -1100,8 +1100,8 @@ def get_weighted_inference_estimating_functions_only(
|
|
|
1100
1100
|
|
|
1101
1101
|
# 6. Note this strange return structure! We will differentiate the first output,
|
|
1102
1102
|
# but the second tuple will be passed along without modification via has_aux=True and then used
|
|
1103
|
-
# for the
|
|
1104
|
-
# bread matrices. The raw per-
|
|
1103
|
+
# for the adjusted meat matrix, estimating functions sum check, and classical meat and inverse
|
|
1104
|
+
# bread matrices. The raw per-subject estimating functions are also returned again for debugging
|
|
1105
1105
|
# purposes.
|
|
1106
1106
|
return jnp.mean(weighted_inference_estimating_functions, axis=0), (
|
|
1107
1107
|
weighted_inference_estimating_functions,
|
|
@@ -1111,8 +1111,8 @@ def get_weighted_inference_estimating_functions_only(
|
|
|
1111
1111
|
)
|
|
1112
1112
|
|
|
1113
1113
|
|
|
1114
|
-
def
|
|
1115
|
-
|
|
1114
|
+
def single_subject_weighted_inference_estimating_function(
|
|
1115
|
+
subject_id: collections.abc.Hashable,
|
|
1116
1116
|
action_prob_func: callable,
|
|
1117
1117
|
inference_estimating_func: callable,
|
|
1118
1118
|
action_prob_func_args_beta_index: int,
|
|
@@ -1137,12 +1137,12 @@ def single_user_weighted_inference_estimating_function(
|
|
|
1137
1137
|
and action probability function and arguments if applicable.
|
|
1138
1138
|
|
|
1139
1139
|
Args:
|
|
1140
|
-
|
|
1141
|
-
The
|
|
1140
|
+
subject_id (collections.abc.Hashable):
|
|
1141
|
+
The subject ID for which to compute the weighted estimating function stack.
|
|
1142
1142
|
|
|
1143
1143
|
action_prob_func (callable):
|
|
1144
1144
|
The function used to compute the probability of action 1 at a given decision time for
|
|
1145
|
-
a particular
|
|
1145
|
+
a particular subject given their state and the algorithm parameters.
|
|
1146
1146
|
|
|
1147
1147
|
inference_estimating_func (callable):
|
|
1148
1148
|
The estimating function that corresponds to inference.
|
|
@@ -1154,7 +1154,7 @@ def single_user_weighted_inference_estimating_function(
|
|
|
1154
1154
|
The index of the theta parameter in the inference loss or estimating function arguments.
|
|
1155
1155
|
|
|
1156
1156
|
action_prob_func_args_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
|
|
1157
|
-
A map from decision times to tuples of arguments for this
|
|
1157
|
+
A map from decision times to tuples of arguments for this subject for the action
|
|
1158
1158
|
probability function. This is for all decision times (args are an empty
|
|
1159
1159
|
tuple if they are not in the study). Should be sorted by decision time. NOTE THAT THESE
|
|
1160
1160
|
ARGS DO NOT CONTAIN THE SHARED BETAS, making them impervious to the differentiation that
|
|
@@ -1167,11 +1167,11 @@ def single_user_weighted_inference_estimating_function(
|
|
|
1167
1167
|
|
|
1168
1168
|
threaded_inference_func_args (dict[collections.abc.Hashable, tuple[Any, ...]]):
|
|
1169
1169
|
A tuple containing the arguments for the inference
|
|
1170
|
-
estimating function for this
|
|
1170
|
+
estimating function for this subject, with the shared betas threaded in for differentiation.
|
|
1171
1171
|
|
|
1172
1172
|
policy_num_by_decision_time (dict[collections.abc.Hashable, dict[int, int | float]]):
|
|
1173
1173
|
A dictionary mapping decision times to the policy number in use. This may be
|
|
1174
|
-
|
|
1174
|
+
subject-specific. Should be sorted by decision time. Only applies to in-study decision
|
|
1175
1175
|
times!
|
|
1176
1176
|
|
|
1177
1177
|
action_by_decision_time (dict[collections.abc.Hashable, dict[int, int]]):
|
|
@@ -1183,31 +1183,33 @@ def single_user_weighted_inference_estimating_function(
|
|
|
1183
1183
|
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
1184
1184
|
|
|
1185
1185
|
Returns:
|
|
1186
|
-
jnp.ndarray: A 1-D JAX NumPy array representing the
|
|
1187
|
-
jnp.ndarray: A 2-D JAX NumPy matrix representing the
|
|
1188
|
-
jnp.ndarray: A 2-D JAX NumPy matrix representing the
|
|
1186
|
+
jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted inference estimating function.
|
|
1187
|
+
jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical meat contribution.
|
|
1188
|
+
jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical bread contribution.
|
|
1189
1189
|
"""
|
|
1190
1190
|
|
|
1191
1191
|
logger.info(
|
|
1192
|
-
"Computing only weighted inference estimating function stack for
|
|
1193
|
-
|
|
1192
|
+
"Computing only weighted inference estimating function stack for subject %s.",
|
|
1193
|
+
subject_id,
|
|
1194
1194
|
)
|
|
1195
1195
|
|
|
1196
1196
|
# First, reformat the supplied data into more convenient structures.
|
|
1197
1197
|
|
|
1198
1198
|
# 1. Get the first time after the first update for convenience.
|
|
1199
1199
|
# This is used to form the Radon-Nikodym weights for the right times.
|
|
1200
|
-
_, first_time_after_first_update =
|
|
1201
|
-
|
|
1202
|
-
|
|
1200
|
+
_, first_time_after_first_update = (
|
|
1201
|
+
post_deployment_analysis.get_min_time_by_policy_num(
|
|
1202
|
+
policy_num_by_decision_time,
|
|
1203
|
+
beta_index_by_policy_num,
|
|
1204
|
+
)
|
|
1203
1205
|
)
|
|
1204
1206
|
|
|
1205
|
-
# 2. Get the start and end times for this
|
|
1206
|
-
|
|
1207
|
-
|
|
1207
|
+
# 2. Get the start and end times for this subject.
|
|
1208
|
+
subject_start_time = math.inf
|
|
1209
|
+
subject_end_time = -math.inf
|
|
1208
1210
|
for decision_time in action_by_decision_time:
|
|
1209
|
-
|
|
1210
|
-
|
|
1211
|
+
subject_start_time = min(subject_start_time, decision_time)
|
|
1212
|
+
subject_end_time = max(subject_end_time, decision_time)
|
|
1211
1213
|
|
|
1212
1214
|
# 3. Calculate the Radon-Nikodym weights for the inference estimating function.
|
|
1213
1215
|
in_study_action_prob_func_args = [
|
|
@@ -1224,11 +1226,11 @@ def single_user_weighted_inference_estimating_function(
|
|
|
1224
1226
|
)
|
|
1225
1227
|
|
|
1226
1228
|
# Sort the threaded args by decision time to be cautious. We check if the
|
|
1227
|
-
#
|
|
1228
|
-
# subset of the
|
|
1229
|
+
# subject id is present in the subject args dict because we may call this on a
|
|
1230
|
+
# subset of the subject arg dict when we are batching arguments by shape
|
|
1229
1231
|
sorted_threaded_action_prob_args_by_decision_time = {
|
|
1230
1232
|
decision_time: threaded_action_prob_func_args_by_decision_time[decision_time]
|
|
1231
|
-
for decision_time in range(
|
|
1233
|
+
for decision_time in range(subject_start_time, subject_end_time + 1)
|
|
1232
1234
|
if decision_time in threaded_action_prob_func_args_by_decision_time
|
|
1233
1235
|
}
|
|
1234
1236
|
|
|
@@ -1260,7 +1262,7 @@ def single_user_weighted_inference_estimating_function(
|
|
|
1260
1262
|
# value, but impervious to differentiation with respect to all_post_update_betas. The
|
|
1261
1263
|
# args, on the other hand, are a function of all_post_update_betas.
|
|
1262
1264
|
in_study_weights = jax.vmap(
|
|
1263
|
-
fun=
|
|
1265
|
+
fun=post_deployment_analysis.get_radon_nikodym_weight,
|
|
1264
1266
|
in_axes=[0, None, None, 0] + batch_axes,
|
|
1265
1267
|
out_axes=0,
|
|
1266
1268
|
)(
|
|
@@ -1287,12 +1289,12 @@ def single_user_weighted_inference_estimating_function(
|
|
|
1287
1289
|
# 4. Form the weighted inference estimating equation.
|
|
1288
1290
|
weighted_inference_estimating_function = jnp.prod(
|
|
1289
1291
|
all_weights[
|
|
1290
|
-
max(first_time_after_first_update,
|
|
1291
|
-
- decision_time_to_all_weights_index_offset :
|
|
1292
|
+
max(first_time_after_first_update, subject_start_time)
|
|
1293
|
+
- decision_time_to_all_weights_index_offset : subject_end_time
|
|
1292
1294
|
+ 1
|
|
1293
1295
|
- decision_time_to_all_weights_index_offset,
|
|
1294
1296
|
]
|
|
1295
|
-
# If the
|
|
1297
|
+
# If the subject exited the study before there were any updates,
|
|
1296
1298
|
# this variable will be None and the above code to grab a weight would
|
|
1297
1299
|
# throw an error. Just use 1 to include the unweighted estimating function
|
|
1298
1300
|
# if they have data to contribute here (pretty sure everyone should?)
|