lifejacket 0.2.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/after_study_analysis.py +397 -387
- lifejacket/arg_threading_helpers.py +75 -69
- lifejacket/calculate_derivatives.py +19 -21
- lifejacket/{trial_conditioning_monitor.py → deployment_conditioning_monitor.py} +146 -128
- lifejacket/{form_adaptive_meat_adjustments_directly.py → form_adjusted_meat_adjustments_directly.py} +7 -7
- lifejacket/get_datum_for_blowup_supervised_learning.py +315 -307
- lifejacket/helper_functions.py +45 -38
- lifejacket/input_checks.py +263 -261
- lifejacket/small_sample_corrections.py +42 -40
- lifejacket-1.0.0.dist-info/METADATA +56 -0
- lifejacket-1.0.0.dist-info/RECORD +17 -0
- lifejacket-0.2.1.dist-info/METADATA +0 -100
- lifejacket-0.2.1.dist-info/RECORD +0 -17
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.0.dist-info}/WHEEL +0 -0
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.0.dist-info}/entry_points.txt +0 -0
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -25,25 +25,25 @@ logging.basicConfig(
|
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
def get_datum_for_blowup_supervised_learning(
|
|
28
|
-
|
|
29
|
-
|
|
28
|
+
joint_adjusted_bread_inverse_matrix,
|
|
29
|
+
joint_adjusted_bread_inverse_cond,
|
|
30
30
|
avg_estimating_function_stack,
|
|
31
|
-
|
|
31
|
+
per_subject_estimating_function_stacks,
|
|
32
32
|
all_post_update_betas,
|
|
33
|
-
|
|
34
|
-
|
|
33
|
+
analysis_df,
|
|
34
|
+
active_col_name,
|
|
35
35
|
calendar_t_col_name,
|
|
36
36
|
action_prob_col_name,
|
|
37
|
-
|
|
37
|
+
subject_id_col_name,
|
|
38
38
|
reward_col_name,
|
|
39
39
|
theta_est,
|
|
40
|
-
|
|
41
|
-
|
|
40
|
+
adjusted_sandwich_var_estimate,
|
|
41
|
+
subject_ids,
|
|
42
42
|
beta_dim,
|
|
43
43
|
theta_dim,
|
|
44
44
|
initial_policy_num,
|
|
45
45
|
beta_index_by_policy_num,
|
|
46
|
-
|
|
46
|
+
policy_num_by_decision_time_by_subject_id,
|
|
47
47
|
theta_calculation_func,
|
|
48
48
|
action_prob_func,
|
|
49
49
|
action_prob_func_args_beta_index,
|
|
@@ -51,46 +51,46 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
51
51
|
inference_func_type,
|
|
52
52
|
inference_func_args_theta_index,
|
|
53
53
|
inference_func_args_action_prob_index,
|
|
54
|
-
|
|
54
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
55
55
|
action_prob_func_args,
|
|
56
|
-
|
|
56
|
+
action_by_decision_time_by_subject_id,
|
|
57
57
|
) -> dict[str, Any]:
|
|
58
58
|
"""
|
|
59
|
-
Collects a datum for supervised learning about
|
|
59
|
+
Collects a datum for supervised learning about adjusted sandwich blowup.
|
|
60
60
|
|
|
61
|
-
The datum consists of features and the raw
|
|
61
|
+
The datum consists of features and the raw adjusted sandwich variance estimate as a label.
|
|
62
62
|
|
|
63
63
|
A few plots are produced along the way to help visualize the data.
|
|
64
64
|
|
|
65
65
|
Args:
|
|
66
|
-
|
|
67
|
-
The joint
|
|
68
|
-
|
|
69
|
-
The condition number of the joint
|
|
66
|
+
joint_adjusted_bread_inverse_matrix (jnp.ndarray):
|
|
67
|
+
The joint adjusted bread inverse matrix.
|
|
68
|
+
joint_adjusted_bread_inverse_cond (float):
|
|
69
|
+
The condition number of the joint adjusted bread inverse matrix.
|
|
70
70
|
avg_estimating_function_stack (jnp.ndarray):
|
|
71
|
-
The average estimating function stack across
|
|
72
|
-
|
|
73
|
-
The estimating function stacks for each
|
|
71
|
+
The average estimating function stack across subjects.
|
|
72
|
+
per_subject_estimating_function_stacks (jnp.ndarray):
|
|
73
|
+
The estimating function stacks for each subject.
|
|
74
74
|
all_post_update_betas (jnp.ndarray):
|
|
75
75
|
All post-update beta parameters.
|
|
76
|
-
|
|
76
|
+
analysis_df (pd.DataFrame):
|
|
77
77
|
The study DataFrame.
|
|
78
|
-
|
|
79
|
-
Column name indicating if a
|
|
78
|
+
active_col_name (str):
|
|
79
|
+
Column name in the study dataframe indicating whether a subject is in the study.
|
|
80
80
|
calendar_t_col_name (str):
|
|
81
81
|
Column name for calendar time in the study dataframe.
|
|
82
82
|
action_prob_col_name (str):
|
|
83
83
|
Column name for action probabilities in the study dataframe.
|
|
84
|
-
|
|
85
|
-
Column name for
|
|
84
|
+
subject_id_col_name (str):
|
|
85
|
+
Column name for subject IDs in the study dataframe.
|
|
86
86
|
reward_col_name (str):
|
|
87
87
|
Column name for rewards in the study dataframe.
|
|
88
88
|
theta_est (jnp.ndarray):
|
|
89
89
|
The estimate of the parameter vector theta.
|
|
90
|
-
|
|
91
|
-
The
|
|
92
|
-
|
|
93
|
-
Array of unique
|
|
90
|
+
adjusted_sandwich_var_estimate (jnp.ndarray):
|
|
91
|
+
The adjusted sandwich variance estimate for theta.
|
|
92
|
+
subject_ids (jnp.ndarray):
|
|
93
|
+
Array of unique subject IDs.
|
|
94
94
|
beta_dim (int):
|
|
95
95
|
Dimension of the beta parameter vector.
|
|
96
96
|
theta_dim (int):
|
|
@@ -99,8 +99,8 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
99
99
|
The initial policy number used in the study.
|
|
100
100
|
beta_index_by_policy_num (dict[int | float, int]):
|
|
101
101
|
Mapping from policy numbers to indices in all_post_update_betas.
|
|
102
|
-
|
|
103
|
-
Mapping from
|
|
102
|
+
policy_num_by_decision_time_by_subject_id (dict):
|
|
103
|
+
Mapping from subject IDs to their policy numbers by decision time.
|
|
104
104
|
theta_calculation_func (callable):
|
|
105
105
|
The theta calculation function.
|
|
106
106
|
action_prob_func (callable):
|
|
@@ -115,17 +115,17 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
115
115
|
Index for theta in inference function arguments.
|
|
116
116
|
inference_func_args_action_prob_index (int):
|
|
117
117
|
Index for action probability in inference function arguments.
|
|
118
|
-
|
|
119
|
-
Mapping from
|
|
118
|
+
inference_action_prob_decision_times_by_subject_id (dict):
|
|
119
|
+
Mapping from subject IDs to decision times for action probabilities used in inference.
|
|
120
120
|
action_prob_func_args (dict):
|
|
121
121
|
Arguments for the action probability function.
|
|
122
|
-
|
|
123
|
-
Mapping from
|
|
122
|
+
action_by_decision_time_by_subject_id (dict):
|
|
123
|
+
Mapping from subject IDs to their actions by decision time.
|
|
124
124
|
Returns:
|
|
125
125
|
dict[str, Any]: A dictionary containing features and the label for supervised learning.
|
|
126
126
|
"""
|
|
127
127
|
num_diagonal_blocks = (
|
|
128
|
-
(
|
|
128
|
+
(joint_adjusted_bread_inverse_matrix.shape[0] - theta_dim) // beta_dim
|
|
129
129
|
) + 1
|
|
130
130
|
diagonal_block_sizes = ([beta_dim] * (num_diagonal_blocks - 1)) + [theta_dim]
|
|
131
131
|
|
|
@@ -144,7 +144,7 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
144
144
|
row_slice = slice(block_bounds[i], block_bounds[i + 1])
|
|
145
145
|
col_slice = slice(block_bounds[j], block_bounds[j + 1])
|
|
146
146
|
block_norm = np.linalg.norm(
|
|
147
|
-
|
|
147
|
+
joint_adjusted_bread_inverse_matrix[row_slice, col_slice],
|
|
148
148
|
ord="fro",
|
|
149
149
|
)
|
|
150
150
|
# We will sum here and take the square root later
|
|
@@ -155,9 +155,9 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
155
155
|
# handle diagonal blocks
|
|
156
156
|
sl = slice(block_bounds[i], block_bounds[i + 1])
|
|
157
157
|
diag_norms.append(
|
|
158
|
-
np.linalg.norm(
|
|
158
|
+
np.linalg.norm(joint_adjusted_bread_inverse_matrix[sl, sl], ord="fro")
|
|
159
159
|
)
|
|
160
|
-
diag_conds.append(np.linalg.cond(
|
|
160
|
+
diag_conds.append(np.linalg.cond(joint_adjusted_bread_inverse_matrix[sl, sl]))
|
|
161
161
|
|
|
162
162
|
# Sqrt each row/col sum to truly get row/column norms.
|
|
163
163
|
# Perhaps not necessary for learning, but more natural
|
|
@@ -166,7 +166,7 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
166
166
|
|
|
167
167
|
# Get the per-person estimating function stack norms
|
|
168
168
|
estimating_function_stack_norms = np.linalg.norm(
|
|
169
|
-
|
|
169
|
+
per_subject_estimating_function_stacks, axis=1
|
|
170
170
|
)
|
|
171
171
|
|
|
172
172
|
# Get the average estimating function stack norms by update/inference
|
|
@@ -175,7 +175,7 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
175
175
|
avg_estimating_function_stack_norms_per_segment = [
|
|
176
176
|
np.mean(
|
|
177
177
|
np.linalg.norm(
|
|
178
|
-
|
|
178
|
+
per_subject_estimating_function_stacks[
|
|
179
179
|
:, block_bounds[i] : block_bounds[i + 1]
|
|
180
180
|
],
|
|
181
181
|
axis=1,
|
|
@@ -191,67 +191,67 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
191
191
|
std_successive_beta_diff_norm = np.std(successive_beta_diff_norms)
|
|
192
192
|
|
|
193
193
|
# Add a column with logits of the action probabilities
|
|
194
|
-
# Compute the average and standard deviation of the logits of the action probabilities at each decision time using
|
|
194
|
+
# Compute the average and standard deviation of the logits of the action probabilities at each decision time using analysis_df
|
|
195
195
|
# action_prob_logit_means and action_prob_logit_stds are numpy arrays of mean and stddev at each decision time
|
|
196
|
-
# Only compute logits for rows where
|
|
197
|
-
in_study_mask =
|
|
198
|
-
|
|
196
|
+
# Only compute logits for rows where subject is in the study; set others to NaN
|
|
197
|
+
in_study_mask = analysis_df[active_col_name] == 1
|
|
198
|
+
analysis_df["action_prob_logit"] = np.where(
|
|
199
199
|
in_study_mask,
|
|
200
|
-
logit(
|
|
200
|
+
logit(analysis_df[action_prob_col_name]),
|
|
201
201
|
np.nan,
|
|
202
202
|
)
|
|
203
|
-
grouped_action_prob_logit =
|
|
203
|
+
grouped_action_prob_logit = analysis_df.loc[in_study_mask].groupby(
|
|
204
204
|
calendar_t_col_name
|
|
205
205
|
)["action_prob_logit"]
|
|
206
206
|
action_prob_logit_means_by_t = grouped_action_prob_logit.mean().values
|
|
207
207
|
action_prob_logit_stds_by_t = grouped_action_prob_logit.std().values
|
|
208
208
|
|
|
209
|
-
# Compute the average and standard deviation of the rewards at each decision time using
|
|
209
|
+
# Compute the average and standard deviation of the rewards at each decision time using analysis_df
|
|
210
210
|
# reward_means and reward_stds are numpy arrays of mean and stddev at each decision time
|
|
211
|
-
grouped_reward =
|
|
211
|
+
grouped_reward = analysis_df.loc[in_study_mask].groupby(calendar_t_col_name)[
|
|
212
212
|
reward_col_name
|
|
213
213
|
]
|
|
214
214
|
reward_means_by_t = grouped_reward.mean().values
|
|
215
215
|
reward_stds_by_t = grouped_reward.std().values
|
|
216
216
|
|
|
217
217
|
joint_bread_inverse_min_singular_value = np.linalg.svd(
|
|
218
|
-
|
|
218
|
+
joint_adjusted_bread_inverse_matrix, compute_uv=False
|
|
219
219
|
)[-1]
|
|
220
220
|
|
|
221
|
-
max_reward =
|
|
221
|
+
max_reward = analysis_df.loc[in_study_mask][reward_col_name].max()
|
|
222
222
|
|
|
223
223
|
norm_avg_estimating_function_stack = np.linalg.norm(avg_estimating_function_stack)
|
|
224
224
|
max_estimating_function_stack_norm = np.max(estimating_function_stack_norms)
|
|
225
225
|
|
|
226
226
|
(
|
|
227
227
|
premature_thetas,
|
|
228
|
-
|
|
228
|
+
premature_adjusted_sandwiches,
|
|
229
229
|
premature_classical_sandwiches,
|
|
230
|
-
|
|
230
|
+
premature_joint_adjusted_bread_inverse_condition_numbers,
|
|
231
231
|
premature_avg_inference_estimating_functions,
|
|
232
|
-
) =
|
|
233
|
-
|
|
232
|
+
) = calculate_sequence_of_premature_adjusted_estimates(
|
|
233
|
+
analysis_df,
|
|
234
234
|
initial_policy_num,
|
|
235
235
|
beta_index_by_policy_num,
|
|
236
|
-
|
|
236
|
+
policy_num_by_decision_time_by_subject_id,
|
|
237
237
|
theta_calculation_func,
|
|
238
238
|
calendar_t_col_name,
|
|
239
239
|
action_prob_col_name,
|
|
240
|
-
|
|
241
|
-
|
|
240
|
+
subject_id_col_name,
|
|
241
|
+
active_col_name,
|
|
242
242
|
all_post_update_betas,
|
|
243
|
-
|
|
243
|
+
subject_ids,
|
|
244
244
|
action_prob_func,
|
|
245
245
|
action_prob_func_args_beta_index,
|
|
246
246
|
inference_func,
|
|
247
247
|
inference_func_type,
|
|
248
248
|
inference_func_args_theta_index,
|
|
249
249
|
inference_func_args_action_prob_index,
|
|
250
|
-
|
|
250
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
251
251
|
action_prob_func_args,
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
252
|
+
action_by_decision_time_by_subject_id,
|
|
253
|
+
joint_adjusted_bread_inverse_matrix,
|
|
254
|
+
per_subject_estimating_function_stacks,
|
|
255
255
|
beta_dim,
|
|
256
256
|
)
|
|
257
257
|
|
|
@@ -261,42 +261,42 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
261
261
|
atol=1e-3,
|
|
262
262
|
)
|
|
263
263
|
|
|
264
|
-
# Plot premature joint
|
|
264
|
+
# Plot premature joint adjusted bread inverse log condition numbers
|
|
265
265
|
plt.clear_figure()
|
|
266
266
|
plt.title("Premature Joint Adaptive Bread Inverse Log Condition Numbers")
|
|
267
267
|
plt.xlabel("Premature Update Index")
|
|
268
268
|
plt.ylabel("Log Condition Number")
|
|
269
269
|
plt.scatter(
|
|
270
|
-
np.log(
|
|
270
|
+
np.log(premature_joint_adjusted_bread_inverse_condition_numbers),
|
|
271
271
|
color="blue+",
|
|
272
272
|
)
|
|
273
273
|
plt.grid(True)
|
|
274
274
|
plt.xticks(
|
|
275
275
|
range(
|
|
276
276
|
0,
|
|
277
|
-
len(
|
|
277
|
+
len(premature_joint_adjusted_bread_inverse_condition_numbers),
|
|
278
278
|
max(
|
|
279
279
|
1,
|
|
280
|
-
len(
|
|
280
|
+
len(premature_joint_adjusted_bread_inverse_condition_numbers) // 10,
|
|
281
281
|
),
|
|
282
282
|
)
|
|
283
283
|
)
|
|
284
284
|
plt.show()
|
|
285
285
|
|
|
286
|
-
# Plot each diagonal element of premature
|
|
287
|
-
num_diag =
|
|
286
|
+
# Plot each diagonal element of premature adjusted sandwiches
|
|
287
|
+
num_diag = premature_adjusted_sandwiches.shape[-1]
|
|
288
288
|
for i in range(num_diag):
|
|
289
289
|
plt.clear_figure()
|
|
290
290
|
plt.title(f"Premature Adaptive Sandwich Diagonal Element {i}")
|
|
291
291
|
plt.xlabel("Premature Update Index")
|
|
292
292
|
plt.ylabel(f"Variance (Diagonal {i})")
|
|
293
|
-
plt.scatter(np.array(
|
|
293
|
+
plt.scatter(np.array(premature_adjusted_sandwiches[:, i, i]), color="blue+")
|
|
294
294
|
plt.grid(True)
|
|
295
295
|
plt.xticks(
|
|
296
296
|
range(
|
|
297
297
|
0,
|
|
298
|
-
int(
|
|
299
|
-
max(1, int(
|
|
298
|
+
int(premature_adjusted_sandwiches.shape[0]),
|
|
299
|
+
max(1, int(premature_adjusted_sandwiches.shape[0]) // 10),
|
|
300
300
|
)
|
|
301
301
|
)
|
|
302
302
|
plt.show()
|
|
@@ -308,7 +308,7 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
308
308
|
plt.xlabel("Premature Update Index")
|
|
309
309
|
plt.ylabel(f"Variance (Diagonal {i})")
|
|
310
310
|
plt.scatter(
|
|
311
|
-
np.array(
|
|
311
|
+
np.array(premature_adjusted_sandwiches[:, i, i])
|
|
312
312
|
/ np.array(premature_classical_sandwiches[:, i, i]),
|
|
313
313
|
color="red+",
|
|
314
314
|
)
|
|
@@ -316,8 +316,8 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
316
316
|
plt.xticks(
|
|
317
317
|
range(
|
|
318
318
|
0,
|
|
319
|
-
int(
|
|
320
|
-
max(1, int(
|
|
319
|
+
int(premature_adjusted_sandwiches.shape[0]),
|
|
320
|
+
max(1, int(premature_adjusted_sandwiches.shape[0]) // 10),
|
|
321
321
|
)
|
|
322
322
|
)
|
|
323
323
|
plt.show()
|
|
@@ -331,14 +331,14 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
331
331
|
plt.xticks(
|
|
332
332
|
range(
|
|
333
333
|
0,
|
|
334
|
-
int(
|
|
335
|
-
max(1, int(
|
|
334
|
+
int(premature_adjusted_sandwiches.shape[0]),
|
|
335
|
+
max(1, int(premature_adjusted_sandwiches.shape[0]) // 10),
|
|
336
336
|
)
|
|
337
337
|
)
|
|
338
338
|
plt.show()
|
|
339
339
|
|
|
340
340
|
# Grab predictors related to premature Phi-dot-bars
|
|
341
|
-
RL_stack_beta_derivatives_block =
|
|
341
|
+
RL_stack_beta_derivatives_block = joint_adjusted_bread_inverse_matrix[
|
|
342
342
|
:-theta_dim, :-theta_dim
|
|
343
343
|
]
|
|
344
344
|
num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
|
|
@@ -397,14 +397,14 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
397
397
|
)
|
|
398
398
|
return {
|
|
399
399
|
**{
|
|
400
|
-
"joint_bread_inverse_condition_number":
|
|
400
|
+
"joint_bread_inverse_condition_number": joint_adjusted_bread_inverse_cond,
|
|
401
401
|
"joint_bread_inverse_min_singular_value": joint_bread_inverse_min_singular_value,
|
|
402
402
|
"max_reward": max_reward,
|
|
403
403
|
"norm_avg_estimating_function_stack": norm_avg_estimating_function_stack,
|
|
404
404
|
"max_estimating_function_stack_norm": max_estimating_function_stack_norm,
|
|
405
405
|
"max_successive_beta_diff_norm": max_successive_beta_diff_norm,
|
|
406
406
|
"std_successive_beta_diff_norm": std_successive_beta_diff_norm,
|
|
407
|
-
"label":
|
|
407
|
+
"label": adjusted_sandwich_var_estimate,
|
|
408
408
|
},
|
|
409
409
|
**{
|
|
410
410
|
f"off_diag_block_{i}_{j}_norm": off_diag_block_norms[(i, j)]
|
|
@@ -422,10 +422,10 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
422
422
|
for i in range(num_block_rows_cols)
|
|
423
423
|
},
|
|
424
424
|
**{
|
|
425
|
-
f"
|
|
425
|
+
f"estimating_function_stack_norm_subject_{subject_id}": estimating_function_stack_norms[
|
|
426
426
|
i
|
|
427
427
|
]
|
|
428
|
-
for i,
|
|
428
|
+
for i, subject_id in enumerate(subject_ids)
|
|
429
429
|
},
|
|
430
430
|
**{
|
|
431
431
|
f"avg_estimating_function_stack_norm_segment_{i}": avg_estimating_function_stack_norms_per_segment[
|
|
@@ -455,18 +455,18 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
455
455
|
},
|
|
456
456
|
**{f"theta_est_{i}": theta_est[i].item() for i in range(len(theta_est))},
|
|
457
457
|
**{
|
|
458
|
-
f"
|
|
458
|
+
f"premature_joint_adjusted_bread_inverse_condition_number_{i}": premature_joint_adjusted_bread_inverse_condition_numbers[
|
|
459
459
|
i
|
|
460
460
|
]
|
|
461
461
|
for i in range(
|
|
462
|
-
len(
|
|
462
|
+
len(premature_joint_adjusted_bread_inverse_condition_numbers)
|
|
463
463
|
)
|
|
464
464
|
},
|
|
465
465
|
**{
|
|
466
|
-
f"
|
|
466
|
+
f"premature_adjusted_sandwich_update_{i}_diag_position_{j}": premature_adjusted_sandwich[
|
|
467
467
|
j, j
|
|
468
468
|
]
|
|
469
|
-
for
|
|
469
|
+
for premature_adjusted_sandwich in premature_adjusted_sandwiches
|
|
470
470
|
for j in range(theta_dim)
|
|
471
471
|
},
|
|
472
472
|
**{
|
|
@@ -497,44 +497,46 @@ def get_datum_for_blowup_supervised_learning(
|
|
|
497
497
|
}
|
|
498
498
|
|
|
499
499
|
|
|
500
|
-
def
|
|
501
|
-
|
|
500
|
+
def calculate_sequence_of_premature_adjusted_estimates(
|
|
501
|
+
analysis_df: pd.DataFrame,
|
|
502
502
|
initial_policy_num: int | float,
|
|
503
503
|
beta_index_by_policy_num: dict[int | float, int],
|
|
504
|
-
|
|
504
|
+
policy_num_by_decision_time_by_subject_id: dict[
|
|
505
505
|
collections.abc.Hashable, dict[int, int | float]
|
|
506
506
|
],
|
|
507
507
|
theta_calculation_func: str,
|
|
508
508
|
calendar_t_col_name: str,
|
|
509
509
|
action_prob_col_name: str,
|
|
510
|
-
|
|
511
|
-
|
|
510
|
+
subject_id_col_name: str,
|
|
511
|
+
active_col_name: str,
|
|
512
512
|
all_post_update_betas: jnp.ndarray,
|
|
513
|
-
|
|
513
|
+
subject_ids: jnp.ndarray,
|
|
514
514
|
action_prob_func: str,
|
|
515
515
|
action_prob_func_args_beta_index: int,
|
|
516
516
|
inference_func: str,
|
|
517
517
|
inference_func_type: str,
|
|
518
518
|
inference_func_args_theta_index: int,
|
|
519
519
|
inference_func_args_action_prob_index: int,
|
|
520
|
-
|
|
520
|
+
inference_action_prob_decision_times_by_subject_id: dict[
|
|
521
521
|
collections.abc.Hashable, list[int]
|
|
522
522
|
],
|
|
523
|
-
|
|
523
|
+
action_prob_func_args_by_subject_id_by_decision_time: dict[
|
|
524
524
|
int, dict[collections.abc.Hashable, tuple[Any, ...]]
|
|
525
525
|
],
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
526
|
+
action_by_decision_time_by_subject_id: dict[
|
|
527
|
+
collections.abc.Hashable, dict[int, int]
|
|
528
|
+
],
|
|
529
|
+
full_joint_adjusted_bread_inverse_matrix: jnp.ndarray,
|
|
530
|
+
per_subject_estimating_function_stacks: jnp.ndarray,
|
|
529
531
|
beta_dim: int,
|
|
530
532
|
) -> jnp.ndarray:
|
|
531
533
|
"""
|
|
532
|
-
Calculates a sequence of premature
|
|
534
|
+
Calculates a sequence of premature adjusted estimates for the given study DataFrame, where we
|
|
533
535
|
pretend the study ended after each update in sequence. The behavior of this sequence may provide
|
|
534
|
-
insight into the stability of the final
|
|
536
|
+
insight into the stability of the final adjusted estimate.
|
|
535
537
|
|
|
536
538
|
Args:
|
|
537
|
-
|
|
539
|
+
analysis_df (pandas.DataFrame):
|
|
538
540
|
The DataFrame containing the study data.
|
|
539
541
|
initial_policy_num (int | float): The policy number of the initial policy before any updates.
|
|
540
542
|
initial_policy_num (int | float):
|
|
@@ -542,23 +544,23 @@ def calculate_sequence_of_premature_adaptive_estimates(
|
|
|
542
544
|
beta_index_by_policy_num (dict[int | float, int]):
|
|
543
545
|
A dictionary mapping policy numbers to the index of the corresponding beta in
|
|
544
546
|
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
545
|
-
|
|
546
|
-
A map of
|
|
547
|
+
policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
|
|
548
|
+
A map of subject ids to dictionaries mapping decision times to the policy number in use.
|
|
547
549
|
Only applies to in-study decision times!
|
|
548
550
|
theta_calculation_func (callable):
|
|
549
551
|
The theta calculation function; it is called with the (truncated) study DataFrame.
|
|
550
552
|
calendar_t_col_name (str):
|
|
551
|
-
The name of the column in
|
|
553
|
+
The name of the column in analysis_df representing calendar time.
|
|
552
554
|
action_prob_col_name (str):
|
|
553
|
-
The name of the column in
|
|
554
|
-
|
|
555
|
-
The name of the column in
|
|
556
|
-
|
|
557
|
-
The name of the column in
|
|
555
|
+
The name of the column in analysis_df representing action probabilities.
|
|
556
|
+
subject_id_col_name (str):
|
|
557
|
+
The name of the column in analysis_df representing subject IDs.
|
|
558
|
+
active_col_name (str):
|
|
559
|
+
The name of the column in analysis_df indicating whether the subject is in the study at that time.
|
|
558
560
|
all_post_update_betas (jnp.ndarray):
|
|
559
561
|
A NumPy array containing all post-update beta values.
|
|
560
|
-
|
|
561
|
-
A NumPy array containing all
|
|
562
|
+
subject_ids (jnp.ndarray):
|
|
563
|
+
A NumPy array containing all subject IDs in the study.
|
|
562
564
|
action_prob_func (callable):
|
|
563
565
|
The action probability function.
|
|
564
566
|
action_prob_func_args_beta_index (int):
|
|
@@ -572,56 +574,56 @@ def calculate_sequence_of_premature_adaptive_estimates(
|
|
|
572
574
|
inference_func_args_action_prob_index (int):
|
|
573
575
|
The index of action probabilities in the inference function arguments tuple, if
|
|
574
576
|
applicable. -1 otherwise.
|
|
575
|
-
|
|
576
|
-
For each
|
|
577
|
+
inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
|
|
578
|
+
For each subject, a list of decision times to which action probabilities correspond if
|
|
577
579
|
provided. Typically just in-study times if action probabilities are used in the inference
|
|
578
580
|
loss or estimating function.
|
|
579
|
-
|
|
580
|
-
A dictionary mapping decision times to maps of
|
|
581
|
-
required to compute action probabilities for this
|
|
582
|
-
|
|
583
|
-
A dictionary mapping
|
|
581
|
+
action_prob_func_args_by_subject_id_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
|
|
582
|
+
A dictionary mapping decision times to maps of subject ids to the function arguments
|
|
583
|
+
required to compute action probabilities for this subject.
|
|
584
|
+
action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
|
|
585
|
+
A dictionary mapping subject IDs to their respective actions taken at each decision time.
|
|
584
586
|
Only applies to in-study decision times!
|
|
585
|
-
|
|
586
|
-
The full joint
|
|
587
|
-
|
|
588
|
-
A NumPy array containing all per-
|
|
587
|
+
full_joint_adjusted_bread_inverse_matrix (jnp.ndarray):
|
|
588
|
+
The full joint adjusted bread inverse matrix as a NumPy array.
|
|
589
|
+
per_subject_estimating_function_stacks (jnp.ndarray):
|
|
590
|
+
A NumPy array containing all per-subject (weighted) estimating function stacks.
|
|
589
591
|
beta_dim (int):
|
|
590
592
|
The dimension of the beta parameters.
|
|
591
593
|
Returns:
|
|
592
|
-
tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: A NumPy array containing the sequence of premature
|
|
594
|
+
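tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: The sequences of premature thetas, premature adjusted sandwiches, premature classical sandwiches, premature joint adjusted bread inverse condition numbers, and premature average inference estimating functions, in that order.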
|
|
593
595
|
"""
|
|
594
596
|
|
|
595
|
-
# Loop through the non-initial (ie not before an update has occurred), non-final policy numbers in sorted order, forming
|
|
597
|
+
# Loop through the non-initial (ie not before an update has occurred), non-final policy numbers in sorted order, forming adjusted and classical
|
|
596
598
|
# variance estimates pretending that each was the final policy.
|
|
597
|
-
|
|
599
|
+
premature_adjusted_sandwiches = []
|
|
598
600
|
premature_thetas = []
|
|
599
|
-
|
|
601
|
+
premature_joint_adjusted_bread_inverse_condition_numbers = []
|
|
600
602
|
premature_avg_inference_estimating_functions = []
|
|
601
603
|
premature_classical_sandwiches = []
|
|
602
604
|
logger.info(
|
|
603
|
-
"Calculating sequence of premature
|
|
605
|
+
"Calculating sequence of premature adjusted estimates by pretending the study ended after each update in sequence."
|
|
604
606
|
)
|
|
605
607
|
for policy_num in sorted(beta_index_by_policy_num):
|
|
606
608
|
logger.info(
|
|
607
|
-
"Calculating premature
|
|
609
|
+
"Calculating premature adjusted estimate assuming policy %s is the final one.",
|
|
608
610
|
policy_num,
|
|
609
611
|
)
|
|
610
612
|
pretend_max_policy = policy_num
|
|
611
613
|
|
|
612
|
-
|
|
613
|
-
|
|
614
|
+
truncated_joint_adjusted_bread_inverse_matrix = (
|
|
615
|
+
full_joint_adjusted_bread_inverse_matrix[
|
|
614
616
|
: (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
|
|
615
617
|
: (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
|
|
616
618
|
]
|
|
617
619
|
)
|
|
618
620
|
|
|
619
|
-
max_decision_time =
|
|
620
|
-
|
|
621
|
-
].max()
|
|
621
|
+
max_decision_time = analysis_df[
|
|
622
|
+
analysis_df["policy_num"] == pretend_max_policy
|
|
623
|
+
][calendar_t_col_name].max()
|
|
622
624
|
|
|
623
|
-
|
|
624
|
-
|
|
625
|
+
truncated_analysis_df = analysis_df[
|
|
626
|
+
analysis_df[calendar_t_col_name] <= max_decision_time
|
|
625
627
|
].copy()
|
|
626
628
|
|
|
627
629
|
truncated_beta_index_by_policy_num = {
|
|
@@ -632,83 +634,83 @@ def calculate_sequence_of_premature_adaptive_estimates(
|
|
|
632
634
|
|
|
633
635
|
truncated_all_post_update_betas = all_post_update_betas[: max_beta_index + 1, :]
|
|
634
636
|
|
|
635
|
-
premature_theta = jnp.array(theta_calculation_func(
|
|
637
|
+
premature_theta = jnp.array(theta_calculation_func(truncated_analysis_df))
|
|
636
638
|
|
|
637
|
-
|
|
638
|
-
decision_time:
|
|
639
|
-
for decision_time,
|
|
639
|
+
truncated_action_prob_func_args_by_subject_id_by_decision_time = {
|
|
640
|
+
decision_time: args_by_subject_id
|
|
641
|
+
for decision_time, args_by_subject_id in action_prob_func_args_by_subject_id_by_decision_time.items()
|
|
640
642
|
if decision_time <= max_decision_time
|
|
641
643
|
}
|
|
642
644
|
|
|
643
|
-
|
|
645
|
+
truncated_inference_func_args_by_subject_id, _, _ = (
|
|
644
646
|
after_study_analysis.process_inference_func_args(
|
|
645
647
|
inference_func,
|
|
646
648
|
inference_func_args_theta_index,
|
|
647
|
-
|
|
649
|
+
truncated_analysis_df,
|
|
648
650
|
premature_theta,
|
|
649
651
|
action_prob_col_name,
|
|
650
652
|
calendar_t_col_name,
|
|
651
|
-
|
|
652
|
-
|
|
653
|
+
subject_id_col_name,
|
|
654
|
+
active_col_name,
|
|
653
655
|
)
|
|
654
656
|
)
|
|
655
657
|
|
|
656
|
-
|
|
657
|
-
|
|
658
|
+
truncated_inference_action_prob_decision_times_by_subject_id = {
|
|
659
|
+
subject_id: [
|
|
658
660
|
decision_time
|
|
659
|
-
for decision_time in
|
|
660
|
-
|
|
661
|
+
for decision_time in inference_action_prob_decision_times_by_subject_id[
|
|
662
|
+
subject_id
|
|
661
663
|
]
|
|
662
664
|
if decision_time <= max_decision_time
|
|
663
665
|
]
|
|
664
666
|
# writing this way is important, handles empty dicts correctly
|
|
665
|
-
for
|
|
667
|
+
for subject_id in inference_action_prob_decision_times_by_subject_id
|
|
666
668
|
}
|
|
667
669
|
|
|
668
|
-
|
|
669
|
-
|
|
670
|
+
truncated_action_by_decision_time_by_subject_id = {
|
|
671
|
+
subject_id: {
|
|
670
672
|
decision_time: action
|
|
671
|
-
for decision_time, action in
|
|
672
|
-
|
|
673
|
+
for decision_time, action in action_by_decision_time_by_subject_id[
|
|
674
|
+
subject_id
|
|
673
675
|
].items()
|
|
674
676
|
if decision_time <= max_decision_time
|
|
675
677
|
}
|
|
676
|
-
for
|
|
678
|
+
for subject_id in action_by_decision_time_by_subject_id
|
|
677
679
|
}
|
|
678
680
|
|
|
679
|
-
|
|
680
|
-
|
|
681
|
+
truncated_per_subject_estimating_function_stacks = (
|
|
682
|
+
per_subject_estimating_function_stacks[
|
|
681
683
|
:,
|
|
682
684
|
: (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
|
|
683
685
|
]
|
|
684
686
|
)
|
|
685
687
|
|
|
686
688
|
(
|
|
687
|
-
|
|
689
|
+
premature_adjusted_sandwich,
|
|
688
690
|
premature_classical_sandwich,
|
|
689
691
|
premature_avg_inference_estimating_function,
|
|
690
|
-
) =
|
|
691
|
-
|
|
692
|
-
|
|
692
|
+
) = construct_premature_classical_and_adjusted_sandwiches(
|
|
693
|
+
truncated_joint_adjusted_bread_inverse_matrix,
|
|
694
|
+
truncated_per_subject_estimating_function_stacks,
|
|
693
695
|
premature_theta,
|
|
694
696
|
truncated_all_post_update_betas,
|
|
695
|
-
|
|
697
|
+
subject_ids,
|
|
696
698
|
action_prob_func,
|
|
697
699
|
action_prob_func_args_beta_index,
|
|
698
700
|
inference_func,
|
|
699
701
|
inference_func_type,
|
|
700
702
|
inference_func_args_theta_index,
|
|
701
703
|
inference_func_args_action_prob_index,
|
|
702
|
-
|
|
703
|
-
|
|
704
|
+
truncated_action_prob_func_args_by_subject_id_by_decision_time,
|
|
705
|
+
policy_num_by_decision_time_by_subject_id,
|
|
704
706
|
initial_policy_num,
|
|
705
707
|
truncated_beta_index_by_policy_num,
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
708
|
+
truncated_inference_func_args_by_subject_id,
|
|
709
|
+
truncated_inference_action_prob_decision_times_by_subject_id,
|
|
710
|
+
truncated_action_by_decision_time_by_subject_id,
|
|
709
711
|
)
|
|
710
712
|
|
|
711
|
-
|
|
713
|
+
premature_adjusted_sandwiches.append(premature_adjusted_sandwich)
|
|
712
714
|
premature_classical_sandwiches.append(premature_classical_sandwich)
|
|
713
715
|
premature_thetas.append(premature_theta)
|
|
714
716
|
premature_avg_inference_estimating_functions.append(
|
|
@@ -716,38 +718,40 @@ def calculate_sequence_of_premature_adaptive_estimates(
|
|
|
716
718
|
)
|
|
717
719
|
return (
|
|
718
720
|
jnp.array(premature_thetas),
|
|
719
|
-
jnp.array(
|
|
721
|
+
jnp.array(premature_adjusted_sandwiches),
|
|
720
722
|
jnp.array(premature_classical_sandwiches),
|
|
721
|
-
jnp.array(
|
|
723
|
+
jnp.array(premature_joint_adjusted_bread_inverse_condition_numbers),
|
|
722
724
|
jnp.array(premature_avg_inference_estimating_functions),
|
|
723
725
|
)
|
|
724
726
|
|
|
725
727
|
|
|
726
|
-
def
|
|
727
|
-
|
|
728
|
-
|
|
728
|
+
def construct_premature_classical_and_adjusted_sandwiches(
|
|
729
|
+
truncated_joint_adjusted_bread_inverse_matrix: jnp.ndarray,
|
|
730
|
+
per_subject_truncated_estimating_function_stacks: jnp.ndarray,
|
|
729
731
|
theta: jnp.ndarray,
|
|
730
732
|
all_post_update_betas: jnp.ndarray,
|
|
731
|
-
|
|
733
|
+
subject_ids: jnp.ndarray,
|
|
732
734
|
action_prob_func: str,
|
|
733
735
|
action_prob_func_args_beta_index: int,
|
|
734
736
|
inference_func: str,
|
|
735
737
|
inference_func_type: str,
|
|
736
738
|
inference_func_args_theta_index: int,
|
|
737
739
|
inference_func_args_action_prob_index: int,
|
|
738
|
-
|
|
740
|
+
action_prob_func_args_by_subject_id_by_decision_time: dict[
|
|
739
741
|
collections.abc.Hashable, dict[int, tuple[Any, ...]]
|
|
740
742
|
],
|
|
741
|
-
|
|
743
|
+
policy_num_by_decision_time_by_subject_id: dict[
|
|
742
744
|
collections.abc.Hashable, dict[int, int | float]
|
|
743
745
|
],
|
|
744
746
|
initial_policy_num: int | float,
|
|
745
747
|
beta_index_by_policy_num: dict[int | float, int],
|
|
746
|
-
|
|
747
|
-
|
|
748
|
+
inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
|
|
749
|
+
inference_action_prob_decision_times_by_subject_id: dict[
|
|
748
750
|
collections.abc.Hashable, list[int]
|
|
749
751
|
],
|
|
750
|
-
|
|
752
|
+
action_by_decision_time_by_subject_id: dict[
|
|
753
|
+
collections.abc.Hashable, dict[int, int]
|
|
754
|
+
],
|
|
751
755
|
) -> tuple[
|
|
752
756
|
jnp.ndarray[jnp.float32],
|
|
753
757
|
jnp.ndarray[jnp.float32],
|
|
@@ -759,33 +763,33 @@ def construct_premature_classical_and_adaptive_sandwiches(
|
|
|
759
763
|
jnp.ndarray[jnp.float32],
|
|
760
764
|
]:
|
|
761
765
|
"""
|
|
762
|
-
Constructs the classical bread and meat matrices, as well as the
|
|
766
|
+
Constructs the classical bread and meat matrices, as well as the adjusted bread matrix
|
|
763
767
|
and the average weighted inference estimating function for the premature variance estimation
|
|
764
768
|
procedure.
|
|
765
769
|
|
|
766
770
|
This is done by computing and differentiating the new average inference estimating function
|
|
767
771
|
with respect to the betas and theta, and stitching this together with the existing
|
|
768
|
-
|
|
769
|
-
to form the new premature joint
|
|
772
|
+
adjusted bread inverse matrix portion (corresponding to the updates still under consideration)
|
|
773
|
+
to form the new premature joint adjusted bread inverse matrix.
|
|
770
774
|
|
|
771
775
|
Args:
|
|
772
|
-
|
|
773
|
-
A 2-D JAX NumPy array holding the existing joint
|
|
776
|
+
truncated_joint_adjusted_bread_inverse_matrix (jnp.ndarray):
|
|
777
|
+
A 2-D JAX NumPy array holding the existing joint adjusted bread inverse but
|
|
774
778
|
with rows corresponding to updates not under consideration and inference dropped.
|
|
775
779
|
We will stitch this together with the newly computed inference portion to form
|
|
776
|
-
our "premature" joint
|
|
777
|
-
|
|
778
|
-
A 2-D JAX NumPy array holding the existing per-
|
|
780
|
+
our "premature" joint adjusted bread inverse matrix.
|
|
781
|
+
per_subject_truncated_estimating_function_stacks (jnp.ndarray):
|
|
782
|
+
A 2-D JAX NumPy array holding the existing per-subject weighted estimating function
|
|
779
783
|
stacks but with rows corresponding to updates not under consideration dropped.
|
|
780
784
|
We will stitch this together with the newly computed inference estimating functions
|
|
781
|
-
to form our "premature" joint
|
|
782
|
-
|
|
785
|
+
to form our "premature" joint adjusted estimating function stacks from which the new
|
|
786
|
+
adjusted meat matrix can be computed.
|
|
783
787
|
theta (jnp.ndarray):
|
|
784
788
|
A 1-D JAX NumPy array representing the parameter estimate for inference.
|
|
785
789
|
all_post_update_betas (jnp.ndarray):
|
|
786
790
|
A 2-D JAX NumPy array representing all parameter estimates for the algorithm updates.
|
|
787
|
-
|
|
788
|
-
A 1-D JAX NumPy array holding all
|
|
791
|
+
subject_ids (jnp.ndarray):
|
|
792
|
+
A 1-D JAX NumPy array holding all subject IDs in the study.
|
|
789
793
|
action_prob_func (callable):
|
|
790
794
|
The action probability function.
|
|
791
795
|
action_prob_func_args_beta_index (int):
|
|
@@ -799,51 +803,51 @@ def construct_premature_classical_and_adaptive_sandwiches(
|
|
|
799
803
|
inference_func_args_action_prob_index (int):
|
|
800
804
|
The index of action probabilities in the inference function arguments tuple, if
|
|
801
805
|
applicable. -1 otherwise.
|
|
802
|
-
|
|
803
|
-
A dictionary mapping decision times to maps of
|
|
804
|
-
required to compute action probabilities for this
|
|
805
|
-
|
|
806
|
-
A map of
|
|
806
|
+
action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
|
|
807
|
+
A dictionary mapping decision times to maps of subject ids to the function arguments
|
|
808
|
+
required to compute action probabilities for this subject.
|
|
809
|
+
policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
|
|
810
|
+
A map of subject ids to dictionaries mapping decision times to the policy number in use.
|
|
807
811
|
Only applies to in-study decision times!
|
|
808
812
|
initial_policy_num (int | float):
|
|
809
813
|
The policy number of the initial policy before any updates.
|
|
810
814
|
beta_index_by_policy_num (dict[int | float, int]):
|
|
811
815
|
A dictionary mapping policy numbers to the index of the corresponding beta in
|
|
812
816
|
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
813
|
-
|
|
814
|
-
A dictionary mapping
|
|
815
|
-
|
|
816
|
-
For each
|
|
817
|
+
inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
|
|
818
|
+
A dictionary mapping subject IDs to their respective inference function arguments.
|
|
819
|
+
inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
|
|
820
|
+
For each subject, a list of decision times to which action probabilities correspond if
|
|
817
821
|
provided. Typically just in-study times if action probabilities are used in the inference
|
|
818
822
|
loss or estimating function.
|
|
819
|
-
|
|
820
|
-
A dictionary mapping
|
|
823
|
+
action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
|
|
824
|
+
A dictionary mapping subject IDs to their respective actions taken at each decision time.
|
|
821
825
|
Only applies to in-study decision times!
|
|
822
826
|
Returns:
|
|
823
827
|
tuple[jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32],
|
|
824
828
|
jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32],
|
|
825
829
|
jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32]]:
|
|
826
830
|
A tuple containing:
|
|
827
|
-
- The joint
|
|
828
|
-
- The joint
|
|
829
|
-
- The joint
|
|
831
|
+
- The joint adjusted inverse bread matrix.
|
|
832
|
+
- The joint adjusted bread matrix.
|
|
833
|
+
- The joint adjusted meat matrix.
|
|
830
834
|
- The classical inverse bread matrix.
|
|
831
835
|
- The classical bread matrix.
|
|
832
836
|
- The classical meat matrix.
|
|
833
837
|
- The average (weighted) inference estimating function.
|
|
834
|
-
- The joint
|
|
838
|
+
- The joint adjusted inverse bread matrix condition number.
|
|
835
839
|
"""
|
|
836
840
|
logger.info(
|
|
837
841
|
"Differentiating average weighted inference estimating function stack and collecting auxiliary values."
|
|
838
842
|
)
|
|
839
843
|
# jax.jacobian may perform worse here--seemed to hang indefinitely while jacrev is merely very
|
|
840
844
|
# slow.
|
|
841
|
-
# Note that these "contributions" are per-
|
|
845
|
+
# Note that these "contributions" are per-subject Jacobians of the weighted estimating function stack.
|
|
842
846
|
new_inference_block_row, (
|
|
843
|
-
|
|
847
|
+
per_subject_inference_estimating_functions,
|
|
844
848
|
avg_inference_estimating_function,
|
|
845
|
-
|
|
846
|
-
|
|
849
|
+
per_subject_classical_meat_contributions,
|
|
850
|
+
per_subject_classical_bread_inverse_contributions,
|
|
847
851
|
) = jax.jacrev(get_weighted_inference_estimating_functions_only, has_aux=True)(
|
|
848
852
|
# While JAX can technically differentiate with respect to a list of JAX arrays,
|
|
849
853
|
# it is more efficient to flatten them into a single array. This is done
|
|
@@ -851,29 +855,29 @@ def construct_premature_classical_and_adaptive_sandwiches(
|
|
|
851
855
|
after_study_analysis.flatten_params(all_post_update_betas, theta),
|
|
852
856
|
all_post_update_betas.shape[1],
|
|
853
857
|
theta.shape[0],
|
|
854
|
-
|
|
858
|
+
subject_ids,
|
|
855
859
|
action_prob_func,
|
|
856
860
|
action_prob_func_args_beta_index,
|
|
857
861
|
inference_func,
|
|
858
862
|
inference_func_type,
|
|
859
863
|
inference_func_args_theta_index,
|
|
860
864
|
inference_func_args_action_prob_index,
|
|
861
|
-
|
|
862
|
-
|
|
865
|
+
action_prob_func_args_by_subject_id_by_decision_time,
|
|
866
|
+
policy_num_by_decision_time_by_subject_id,
|
|
863
867
|
initial_policy_num,
|
|
864
868
|
beta_index_by_policy_num,
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
869
|
+
inference_func_args_by_subject_id,
|
|
870
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
871
|
+
action_by_decision_time_by_subject_id,
|
|
868
872
|
)
|
|
869
873
|
|
|
870
|
-
|
|
874
|
+
joint_adjusted_bread_inverse_matrix = jnp.block(
|
|
871
875
|
[
|
|
872
876
|
[
|
|
873
|
-
|
|
877
|
+
truncated_joint_adjusted_bread_inverse_matrix,
|
|
874
878
|
np.zeros(
|
|
875
879
|
(
|
|
876
|
-
|
|
880
|
+
truncated_joint_adjusted_bread_inverse_matrix.shape[0],
|
|
877
881
|
new_inference_block_row.shape[0],
|
|
878
882
|
)
|
|
879
883
|
),
|
|
@@ -881,51 +885,53 @@ def construct_premature_classical_and_adaptive_sandwiches(
|
|
|
881
885
|
[new_inference_block_row],
|
|
882
886
|
]
|
|
883
887
|
)
|
|
884
|
-
|
|
888
|
+
per_subject_estimating_function_stacks = jnp.concatenate(
|
|
885
889
|
[
|
|
886
|
-
|
|
887
|
-
|
|
890
|
+
per_subject_truncated_estimating_function_stacks,
|
|
891
|
+
per_subject_inference_estimating_functions,
|
|
888
892
|
],
|
|
889
893
|
axis=1,
|
|
890
894
|
)
|
|
891
|
-
|
|
895
|
+
per_subject_adjusted_meat_contributions = jnp.einsum(
|
|
892
896
|
"ni,nj->nij",
|
|
893
|
-
|
|
894
|
-
|
|
897
|
+
per_subject_estimating_function_stacks,
|
|
898
|
+
per_subject_estimating_function_stacks,
|
|
895
899
|
)
|
|
896
900
|
|
|
897
|
-
|
|
901
|
+
joint_adjusted_meat_matrix = jnp.mean(
|
|
902
|
+
per_subject_adjusted_meat_contributions, axis=0
|
|
903
|
+
)
|
|
898
904
|
|
|
899
905
|
classical_bread_inverse_matrix = jnp.mean(
|
|
900
|
-
|
|
906
|
+
per_subject_classical_bread_inverse_contributions, axis=0
|
|
901
907
|
)
|
|
902
|
-
classical_meat_matrix = jnp.mean(
|
|
908
|
+
classical_meat_matrix = jnp.mean(per_subject_classical_meat_contributions, axis=0)
|
|
903
909
|
|
|
904
|
-
|
|
905
|
-
|
|
910
|
+
num_subjects = subject_ids.shape[0]
|
|
911
|
+
joint_adjusted_sandwich = (
|
|
906
912
|
after_study_analysis.form_sandwich_from_bread_inverse_and_meat(
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
913
|
+
joint_adjusted_bread_inverse_matrix,
|
|
914
|
+
joint_adjusted_meat_matrix,
|
|
915
|
+
num_subjects,
|
|
910
916
|
method="bread_inverse_T_qr",
|
|
911
917
|
)
|
|
912
918
|
)
|
|
913
|
-
|
|
919
|
+
adjusted_sandwich = joint_adjusted_sandwich[-theta.shape[0] :, -theta.shape[0] :]
|
|
914
920
|
|
|
915
921
|
classical_bread_inverse_matrix = jnp.mean(
|
|
916
|
-
|
|
922
|
+
per_subject_classical_bread_inverse_contributions, axis=0
|
|
917
923
|
)
|
|
918
924
|
classical_sandwich = after_study_analysis.form_sandwich_from_bread_inverse_and_meat(
|
|
919
925
|
classical_bread_inverse_matrix,
|
|
920
926
|
classical_meat_matrix,
|
|
921
|
-
|
|
927
|
+
num_subjects,
|
|
922
928
|
method="bread_inverse_T_qr",
|
|
923
929
|
)
|
|
924
930
|
|
|
925
|
-
# Stack the joint
|
|
926
|
-
# values too. The joint
|
|
931
|
+
# Stack the joint adjusted inverse bread pieces together horizontally and return the auxiliary
|
|
932
|
+
# values too. The joint adjusted bread inverse should always be block lower triangular.
|
|
927
933
|
return (
|
|
928
|
-
|
|
934
|
+
adjusted_sandwich,
|
|
929
935
|
classical_sandwich,
|
|
930
936
|
avg_inference_estimating_function,
|
|
931
937
|
)
|
|
@@ -935,32 +941,34 @@ def get_weighted_inference_estimating_functions_only(
|
|
|
935
941
|
flattened_betas_and_theta: jnp.ndarray,
|
|
936
942
|
beta_dim: int,
|
|
937
943
|
theta_dim: int,
|
|
938
|
-
|
|
944
|
+
subject_ids: jnp.ndarray,
|
|
939
945
|
action_prob_func: callable,
|
|
940
946
|
action_prob_func_args_beta_index: int,
|
|
941
947
|
inference_func: callable,
|
|
942
948
|
inference_func_type: str,
|
|
943
949
|
inference_func_args_theta_index: int,
|
|
944
950
|
inference_func_args_action_prob_index: int,
|
|
945
|
-
|
|
951
|
+
action_prob_func_args_by_subject_id_by_decision_time: dict[
|
|
946
952
|
collections.abc.Hashable, dict[int, tuple[Any, ...]]
|
|
947
953
|
],
|
|
948
|
-
|
|
954
|
+
policy_num_by_decision_time_by_subject_id: dict[
|
|
949
955
|
collections.abc.Hashable, dict[int, int | float]
|
|
950
956
|
],
|
|
951
957
|
initial_policy_num: int | float,
|
|
952
958
|
beta_index_by_policy_num: dict[int | float, int],
|
|
953
|
-
|
|
954
|
-
|
|
959
|
+
inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
|
|
960
|
+
inference_action_prob_decision_times_by_subject_id: dict[
|
|
955
961
|
collections.abc.Hashable, list[int]
|
|
956
962
|
],
|
|
957
|
-
|
|
963
|
+
action_by_decision_time_by_subject_id: dict[
|
|
964
|
+
collections.abc.Hashable, dict[int, int]
|
|
965
|
+
],
|
|
958
966
|
) -> tuple[
|
|
959
967
|
jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
|
|
960
968
|
]:
|
|
961
969
|
"""
|
|
962
|
-
Computes the average weighted inference estimating function across
|
|
963
|
-
auxiliary values used to construct the
|
|
970
|
+
Computes the average weighted inference estimating function across subjects, along with
|
|
971
|
+
auxiliary values used to construct the adjusted and classical sandwich variances.
|
|
964
972
|
|
|
965
973
|
Note that input data should have been adjusted to only correspond to updates/decision times
|
|
966
974
|
that are being considered for the current "premature" variance estimation procedure.
|
|
@@ -974,8 +982,8 @@ def get_weighted_inference_estimating_functions_only(
|
|
|
974
982
|
The dimension of each of the beta parameters.
|
|
975
983
|
theta_dim (int):
|
|
976
984
|
The dimension of the theta parameter.
|
|
977
|
-
|
|
978
|
-
A 1D JAX NumPy array of
|
|
985
|
+
subject_ids (jnp.ndarray):
|
|
986
|
+
A 1D JAX NumPy array of subject IDs.
|
|
979
987
|
action_prob_func (str):
|
|
980
988
|
The action probability function.
|
|
981
989
|
action_prob_func_args_beta_index (int):
|
|
@@ -989,25 +997,25 @@ def get_weighted_inference_estimating_functions_only(
|
|
|
989
997
|
inference_func_args_action_prob_index (int):
|
|
990
998
|
The index of action probabilities in the inference function arguments tuple, if
|
|
991
999
|
applicable. -1 otherwise.
|
|
992
|
-
|
|
993
|
-
A dictionary mapping decision times to maps of
|
|
994
|
-
required to compute action probabilities for this
|
|
995
|
-
|
|
996
|
-
A map of
|
|
1000
|
+
action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
|
|
1001
|
+
A dictionary mapping decision times to maps of subject ids to the function arguments
|
|
1002
|
+
required to compute action probabilities for this subject.
|
|
1003
|
+
policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
|
|
1004
|
+
A map of subject ids to dictionaries mapping decision times to the policy number in use.
|
|
997
1005
|
Only applies to in-study decision times!
|
|
998
1006
|
initial_policy_num (int | float):
|
|
999
1007
|
The policy number of the initial policy before any updates.
|
|
1000
1008
|
beta_index_by_policy_num (dict[int | float, int]):
|
|
1001
1009
|
A dictionary mapping policy numbers to the index of the corresponding beta in
|
|
1002
1010
|
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
1003
|
-
|
|
1004
|
-
A dictionary mapping
|
|
1005
|
-
|
|
1006
|
-
For each
|
|
1011
|
+
inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
|
|
1012
|
+
A dictionary mapping subject IDs to their respective inference function arguments.
|
|
1013
|
+
inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
|
|
1014
|
+
For each subject, a list of decision times to which action probabilities correspond if
|
|
1007
1015
|
provided. Typically just in-study times if action probabilities are used in the inference
|
|
1008
1016
|
loss or estimating function.
|
|
1009
|
-
|
|
1010
|
-
A dictionary mapping
|
|
1017
|
+
action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
|
|
1018
|
+
A dictionary mapping subject IDs to their respective actions taken at each decision time.
|
|
1011
1019
|
Only applies to in-study decision times!
|
|
1012
1020
|
|
|
1013
1021
|
Returns:
|
|
@@ -1015,10 +1023,10 @@ def get_weighted_inference_estimating_functions_only(
|
|
|
1015
1023
|
A 2D JAX NumPy array holding the average weighted inference estimating function.
|
|
1016
1024
|
tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
|
|
1017
1025
|
A tuple containing
|
|
1018
|
-
1. the per-
|
|
1026
|
+
1. the per-subject weighted inference estimating function stacks
|
|
1019
1027
|
2. the average weighted inference estimating function
|
|
1020
|
-
3. the
|
|
1021
|
-
4. the
|
|
1028
|
+
3. the subject-level classical meat matrix contributions
|
|
1029
|
+
4. the subject-level inverse classical bread matrix contributions
|
|
1022
1030
|
stacks.
|
|
1023
1031
|
"""
|
|
1024
1032
|
|
|
@@ -1038,15 +1046,15 @@ def get_weighted_inference_estimating_functions_only(
|
|
|
1038
1046
|
# supplied for the above functions, so that differentiation works correctly. The existing
|
|
1039
1047
|
# values should be the same, but not connected to the parameter we are differentiating
|
|
1040
1048
|
# with respect to. Note we will also find it useful below to have the action probability args
|
|
1041
|
-
# nested dict structure flipped to be
|
|
1049
|
+
# nested dict structure flipped to be subject_id -> decision_time -> args, so we do that here too.
|
|
1042
1050
|
|
|
1043
|
-
logger.info("Threading in betas to action probability arguments for all
|
|
1051
|
+
logger.info("Threading in betas to action probability arguments for all subjects.")
|
|
1044
1052
|
(
|
|
1045
|
-
|
|
1046
|
-
|
|
1053
|
+
threaded_action_prob_func_args_by_decision_time_by_subject_id,
|
|
1054
|
+
action_prob_func_args_by_decision_time_by_subject_id,
|
|
1047
1055
|
) = after_study_analysis.thread_action_prob_func_args(
|
|
1048
|
-
|
|
1049
|
-
|
|
1056
|
+
action_prob_func_args_by_subject_id_by_decision_time,
|
|
1057
|
+
policy_num_by_decision_time_by_subject_id,
|
|
1050
1058
|
initial_policy_num,
|
|
1051
1059
|
betas,
|
|
1052
1060
|
beta_index_by_policy_num,
|
|
@@ -1058,38 +1066,38 @@ def get_weighted_inference_estimating_functions_only(
|
|
|
1058
1066
|
# arguments with the central betas introduced.
|
|
1059
1067
|
logger.info(
|
|
1060
1068
|
"Threading in theta and beta-dependent action probabilities to inference update "
|
|
1061
|
-
"function args for all
|
|
1069
|
+
"function args for all subjects"
|
|
1062
1070
|
)
|
|
1063
|
-
|
|
1071
|
+
threaded_inference_func_args_by_subject_id = (
|
|
1064
1072
|
after_study_analysis.thread_inference_func_args(
|
|
1065
|
-
|
|
1073
|
+
inference_func_args_by_subject_id,
|
|
1066
1074
|
inference_func_args_theta_index,
|
|
1067
1075
|
theta,
|
|
1068
1076
|
inference_func_args_action_prob_index,
|
|
1069
|
-
|
|
1070
|
-
|
|
1077
|
+
threaded_action_prob_func_args_by_decision_time_by_subject_id,
|
|
1078
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
1071
1079
|
action_prob_func,
|
|
1072
1080
|
)
|
|
1073
1081
|
)
|
|
1074
1082
|
|
|
1075
|
-
# 5. Now we can compute the the weighted inference estimating functions for all
|
|
1076
|
-
# as well as collect related values used to construct the
|
|
1083
|
+
# 5. Now we can compute the weighted inference estimating functions for all subjects
|
|
1084
|
+
# as well as collect related values used to construct the adjusted and classical
|
|
1077
1085
|
# sandwich variances.
|
|
1078
1086
|
results = [
|
|
1079
|
-
|
|
1080
|
-
|
|
1087
|
+
single_subject_weighted_inference_estimating_function(
|
|
1088
|
+
subject_id,
|
|
1081
1089
|
action_prob_func,
|
|
1082
1090
|
inference_estimating_func,
|
|
1083
1091
|
action_prob_func_args_beta_index,
|
|
1084
1092
|
inference_func_args_theta_index,
|
|
1085
|
-
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1093
|
+
action_prob_func_args_by_decision_time_by_subject_id[subject_id],
|
|
1094
|
+
threaded_action_prob_func_args_by_decision_time_by_subject_id[subject_id],
|
|
1095
|
+
threaded_inference_func_args_by_subject_id[subject_id],
|
|
1096
|
+
policy_num_by_decision_time_by_subject_id[subject_id],
|
|
1097
|
+
action_by_decision_time_by_subject_id[subject_id],
|
|
1090
1098
|
beta_index_by_policy_num,
|
|
1091
1099
|
)
|
|
1092
|
-
for
|
|
1100
|
+
for subject_id in subject_ids.tolist()
|
|
1093
1101
|
]
|
|
1094
1102
|
|
|
1095
1103
|
weighted_inference_estimating_functions = jnp.array(
|
|
@@ -1100,8 +1108,8 @@ def get_weighted_inference_estimating_functions_only(
|
|
|
1100
1108
|
|
|
1101
1109
|
# 6. Note this strange return structure! We will differentiate the first output,
|
|
1102
1110
|
# but the second tuple will be passed along without modification via has_aux=True and then used
|
|
1103
|
-
# for the
|
|
1104
|
-
# bread matrices. The raw per-
|
|
1111
|
+
# for the adjusted meat matrix, estimating functions sum check, and classical meat and inverse
|
|
1112
|
+
# bread matrices. The raw per-subject estimating functions are also returned again for debugging
|
|
1105
1113
|
# purposes.
|
|
1106
1114
|
return jnp.mean(weighted_inference_estimating_functions, axis=0), (
|
|
1107
1115
|
weighted_inference_estimating_functions,
|
|
@@ -1111,8 +1119,8 @@ def get_weighted_inference_estimating_functions_only(
|
|
|
1111
1119
|
)
|
|
1112
1120
|
|
|
1113
1121
|
|
|
1114
|
-
def
|
|
1115
|
-
|
|
1122
|
+
def single_subject_weighted_inference_estimating_function(
|
|
1123
|
+
subject_id: collections.abc.Hashable,
|
|
1116
1124
|
action_prob_func: callable,
|
|
1117
1125
|
inference_estimating_func: callable,
|
|
1118
1126
|
action_prob_func_args_beta_index: int,
|
|
@@ -1137,12 +1145,12 @@ def single_user_weighted_inference_estimating_function(
|
|
|
1137
1145
|
and action probability function and arguments if applicable.
|
|
1138
1146
|
|
|
1139
1147
|
Args:
|
|
1140
|
-
|
|
1141
|
-
The
|
|
1148
|
+
subject_id (collections.abc.Hashable):
|
|
1149
|
+
The subject ID for which to compute the weighted estimating function stack.
|
|
1142
1150
|
|
|
1143
1151
|
action_prob_func (callable):
|
|
1144
1152
|
The function used to compute the probability of action 1 at a given decision time for
|
|
1145
|
-
a particular
|
|
1153
|
+
a particular subject given their state and the algorithm parameters.
|
|
1146
1154
|
|
|
1147
1155
|
inference_estimating_func (callable):
|
|
1148
1156
|
The estimating function that corresponds to inference.
|
|
@@ -1154,7 +1162,7 @@ def single_user_weighted_inference_estimating_function(
|
|
|
1154
1162
|
The index of the theta parameter in the inference loss or estimating function arguments.
|
|
1155
1163
|
|
|
1156
1164
|
action_prob_func_args_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
|
|
1157
|
-
A map from decision times to tuples of arguments for this
|
|
1165
|
+
A map from decision times to tuples of arguments for this subject for the action
|
|
1158
1166
|
probability function. This is for all decision times (args are an empty
|
|
1159
1167
|
tuple if they are not in the study). Should be sorted by decision time. NOTE THAT THESE
|
|
1160
1168
|
ARGS DO NOT CONTAIN THE SHARED BETAS, making them impervious to the differentiation that
|
|
@@ -1167,11 +1175,11 @@ def single_user_weighted_inference_estimating_function(
|
|
|
1167
1175
|
|
|
1168
1176
|
threaded_inference_func_args (dict[collections.abc.Hashable, tuple[Any, ...]]):
|
|
1169
1177
|
A tuple containing the arguments for the inference
|
|
1170
|
-
estimating function for this
|
|
1178
|
+
estimating function for this subject, with the shared betas threaded in for differentiation.
|
|
1171
1179
|
|
|
1172
1180
|
policy_num_by_decision_time (dict[collections.abc.Hashable, dict[int, int | float]]):
|
|
1173
1181
|
A dictionary mapping decision times to the policy number in use. This may be
|
|
1174
|
-
|
|
1182
|
+
subject-specific. Should be sorted by decision time. Only applies to in-study decision
|
|
1175
1183
|
times!
|
|
1176
1184
|
|
|
1177
1185
|
action_by_decision_time (dict[collections.abc.Hashable, dict[int, int]]):
|
|
@@ -1183,14 +1191,14 @@ def single_user_weighted_inference_estimating_function(
|
|
|
1183
1191
|
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
1184
1192
|
|
|
1185
1193
|
Returns:
|
|
1186
|
-
jnp.ndarray: A 1-D JAX NumPy array representing the
|
|
1187
|
-
jnp.ndarray: A 2-D JAX NumPy matrix representing the
|
|
1188
|
-
jnp.ndarray: A 2-D JAX NumPy matrix representing the
|
|
1194
|
+
jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted inference estimating function.
|
|
1195
|
+
jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical meat contribution.
|
|
1196
|
+
jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical bread contribution.
|
|
1189
1197
|
"""
|
|
1190
1198
|
|
|
1191
1199
|
logger.info(
|
|
1192
|
-
"Computing only weighted inference estimating function stack for
|
|
1193
|
-
|
|
1200
|
+
"Computing only weighted inference estimating function stack for subject %s.",
|
|
1201
|
+
subject_id,
|
|
1194
1202
|
)
|
|
1195
1203
|
|
|
1196
1204
|
# First, reformat the supplied data into more convenient structures.
|
|
@@ -1202,12 +1210,12 @@ def single_user_weighted_inference_estimating_function(
|
|
|
1202
1210
|
beta_index_by_policy_num,
|
|
1203
1211
|
)
|
|
1204
1212
|
|
|
1205
|
-
# 2. Get the start and end times for this
|
|
1206
|
-
|
|
1207
|
-
|
|
1213
|
+
# 2. Get the start and end times for this subject.
|
|
1214
|
+
subject_start_time = math.inf
|
|
1215
|
+
subject_end_time = -math.inf
|
|
1208
1216
|
for decision_time in action_by_decision_time:
|
|
1209
|
-
|
|
1210
|
-
|
|
1217
|
+
subject_start_time = min(subject_start_time, decision_time)
|
|
1218
|
+
subject_end_time = max(subject_end_time, decision_time)
|
|
1211
1219
|
|
|
1212
1220
|
# 3. Calculate the Radon-Nikodym weights for the inference estimating function.
|
|
1213
1221
|
in_study_action_prob_func_args = [
|
|
@@ -1224,11 +1232,11 @@ def single_user_weighted_inference_estimating_function(
|
|
|
1224
1232
|
)
|
|
1225
1233
|
|
|
1226
1234
|
# Sort the threaded args by decision time to be cautious. We check if the
|
|
1227
|
-
#
|
|
1228
|
-
# subset of the
|
|
1235
|
+
# subject id is present in the subject args dict because we may call this on a
|
|
1236
|
+
# subset of the subject arg dict when we are batching arguments by shape
|
|
1229
1237
|
sorted_threaded_action_prob_args_by_decision_time = {
|
|
1230
1238
|
decision_time: threaded_action_prob_func_args_by_decision_time[decision_time]
|
|
1231
|
-
for decision_time in range(
|
|
1239
|
+
for decision_time in range(subject_start_time, subject_end_time + 1)
|
|
1232
1240
|
if decision_time in threaded_action_prob_func_args_by_decision_time
|
|
1233
1241
|
}
|
|
1234
1242
|
|
|
@@ -1287,12 +1295,12 @@ def single_user_weighted_inference_estimating_function(
|
|
|
1287
1295
|
# 4. Form the weighted inference estimating equation.
|
|
1288
1296
|
weighted_inference_estimating_function = jnp.prod(
|
|
1289
1297
|
all_weights[
|
|
1290
|
-
max(first_time_after_first_update,
|
|
1291
|
-
- decision_time_to_all_weights_index_offset :
|
|
1298
|
+
max(first_time_after_first_update, subject_start_time)
|
|
1299
|
+
- decision_time_to_all_weights_index_offset : subject_end_time
|
|
1292
1300
|
+ 1
|
|
1293
1301
|
- decision_time_to_all_weights_index_offset,
|
|
1294
1302
|
]
|
|
1295
|
-
# If the
|
|
1303
|
+
# If the subject exited the study before there were any updates,
|
|
1296
1304
|
# this variable will be None and the above code to grab a weight would
|
|
1297
1305
|
# throw an error. Just use 1 to include the unweighted estimating function
|
|
1298
1306
|
# if they have data to contribute here (pretty sure everyone should?)
|