lifejacket 0.2.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,25 +25,25 @@ logging.basicConfig(
25
25
 
26
26
 
27
27
  def get_datum_for_blowup_supervised_learning(
28
- joint_adaptive_bread_inverse_matrix,
29
- joint_adaptive_bread_inverse_cond,
28
+ joint_adjusted_bread_inverse_matrix,
29
+ joint_adjusted_bread_inverse_cond,
30
30
  avg_estimating_function_stack,
31
- per_user_estimating_function_stacks,
31
+ per_subject_estimating_function_stacks,
32
32
  all_post_update_betas,
33
- study_df,
34
- in_study_col_name,
33
+ analysis_df,
34
+ active_col_name,
35
35
  calendar_t_col_name,
36
36
  action_prob_col_name,
37
- user_id_col_name,
37
+ subject_id_col_name,
38
38
  reward_col_name,
39
39
  theta_est,
40
- adaptive_sandwich_var_estimate,
41
- user_ids,
40
+ adjusted_sandwich_var_estimate,
41
+ subject_ids,
42
42
  beta_dim,
43
43
  theta_dim,
44
44
  initial_policy_num,
45
45
  beta_index_by_policy_num,
46
- policy_num_by_decision_time_by_user_id,
46
+ policy_num_by_decision_time_by_subject_id,
47
47
  theta_calculation_func,
48
48
  action_prob_func,
49
49
  action_prob_func_args_beta_index,
@@ -51,46 +51,46 @@ def get_datum_for_blowup_supervised_learning(
51
51
  inference_func_type,
52
52
  inference_func_args_theta_index,
53
53
  inference_func_args_action_prob_index,
54
- inference_action_prob_decision_times_by_user_id,
54
+ inference_action_prob_decision_times_by_subject_id,
55
55
  action_prob_func_args,
56
- action_by_decision_time_by_user_id,
56
+ action_by_decision_time_by_subject_id,
57
57
  ) -> dict[str, Any]:
58
58
  """
59
- Collects a datum for supervised learning about adaptive sandwich blowup.
59
+ Collects a datum for supervised learning about adjusted sandwich blowup.
60
60
 
61
- The datum consists of features and the raw adaptive sandwich variance estimate as a label.
61
+ The datum consists of features and the raw adjusted sandwich variance estimate as a label.
62
62
 
63
63
  A few plots are produced along the way to help visualize the data.
64
64
 
65
65
  Args:
66
- joint_adaptive_bread_inverse_matrix (jnp.ndarray):
67
- The joint adaptive bread inverse matrix.
68
- joint_adaptive_bread_inverse_cond (float):
69
- The condition number of the joint adaptive bread inverse matrix.
66
+ joint_adjusted_bread_inverse_matrix (jnp.ndarray):
67
+ The joint adjusted bread inverse matrix.
68
+ joint_adjusted_bread_inverse_cond (float):
69
+ The condition number of the joint adjusted bread inverse matrix.
70
70
  avg_estimating_function_stack (jnp.ndarray):
71
- The average estimating function stack across users.
72
- per_user_estimating_function_stacks (jnp.ndarray):
73
- The estimating function stacks for each user.
71
+ The average estimating function stack across subjects.
72
+ per_subject_estimating_function_stacks (jnp.ndarray):
73
+ The estimating function stacks for each subject.
74
74
  all_post_update_betas (jnp.ndarray):
75
75
  All post-update beta parameters.
76
- study_df (pd.DataFrame):
76
+ analysis_df (pd.DataFrame):
77
77
  The study DataFrame.
78
- in_study_col_name (str):
79
- Column name indicating if a user is in the study in the study dataframe.
78
+ active_col_name (str):
79
+ Column name indicating if a subject is in the study in the study dataframe.
80
80
  calendar_t_col_name (str):
81
81
  Column name for calendar time in the study dataframe.
82
82
  action_prob_col_name (str):
83
83
  Column name for action probabilities in the study dataframe.
84
- user_id_col_name (str):
85
- Column name for user IDs in the study dataframe
84
+ subject_id_col_name (str):
85
+ Column name for subject IDs in the study dataframe
86
86
  reward_col_name (str):
87
87
  Column name for rewards in the study dataframe.
88
88
  theta_est (jnp.ndarray):
89
89
  The estimate of the parameter vector theta.
90
- adaptive_sandwich_var_estimate (jnp.ndarray):
91
- The adaptive sandwich variance estimate for theta.
92
- user_ids (jnp.ndarray):
93
- Array of unique user IDs.
90
+ adjusted_sandwich_var_estimate (jnp.ndarray):
91
+ The adjusted sandwich variance estimate for theta.
92
+ subject_ids (jnp.ndarray):
93
+ Array of unique subject IDs.
94
94
  beta_dim (int):
95
95
  Dimension of the beta parameter vector.
96
96
  theta_dim (int):
@@ -99,8 +99,8 @@ def get_datum_for_blowup_supervised_learning(
99
99
  The initial policy number used in the study.
100
100
  beta_index_by_policy_num (dict[int | float, int]):
101
101
  Mapping from policy numbers to indices in all_post_update_betas.
102
- policy_num_by_decision_time_by_user_id (dict):
103
- Mapping from user IDs to their policy numbers by decision time.
102
+ policy_num_by_decision_time_by_subject_id (dict):
103
+ Mapping from subject IDs to their policy numbers by decision time.
104
104
  theta_calculation_func (callable):
105
105
  The theta calculation function.
106
106
  action_prob_func (callable):
@@ -115,17 +115,17 @@ def get_datum_for_blowup_supervised_learning(
115
115
  Index for theta in inference function arguments.
116
116
  inference_func_args_action_prob_index (int):
117
117
  Index for action probability in inference function arguments.
118
- inference_action_prob_decision_times_by_user_id (dict):
119
- Mapping from user IDs to decision times for action probabilities used in inference.
118
+ inference_action_prob_decision_times_by_subject_id (dict):
119
+ Mapping from subject IDs to decision times for action probabilities used in inference.
120
120
  action_prob_func_args (dict):
121
121
  Arguments for the action probability function.
122
- action_by_decision_time_by_user_id (dict):
123
- Mapping from user IDs to their actions by decision time.
122
+ action_by_decision_time_by_subject_id (dict):
123
+ Mapping from subject IDs to their actions by decision time.
124
124
  Returns:
125
125
  dict[str, Any]: A dictionary containing features and the label for supervised learning.
126
126
  """
127
127
  num_diagonal_blocks = (
128
- (joint_adaptive_bread_inverse_matrix.shape[0] - theta_dim) // beta_dim
128
+ (joint_adjusted_bread_inverse_matrix.shape[0] - theta_dim) // beta_dim
129
129
  ) + 1
130
130
  diagonal_block_sizes = ([beta_dim] * (num_diagonal_blocks - 1)) + [theta_dim]
131
131
 
@@ -144,7 +144,7 @@ def get_datum_for_blowup_supervised_learning(
144
144
  row_slice = slice(block_bounds[i], block_bounds[i + 1])
145
145
  col_slice = slice(block_bounds[j], block_bounds[j + 1])
146
146
  block_norm = np.linalg.norm(
147
- joint_adaptive_bread_inverse_matrix[row_slice, col_slice],
147
+ joint_adjusted_bread_inverse_matrix[row_slice, col_slice],
148
148
  ord="fro",
149
149
  )
150
150
  # We will sum here and take the square root later
@@ -155,9 +155,9 @@ def get_datum_for_blowup_supervised_learning(
155
155
  # handle diagonal blocks
156
156
  sl = slice(block_bounds[i], block_bounds[i + 1])
157
157
  diag_norms.append(
158
- np.linalg.norm(joint_adaptive_bread_inverse_matrix[sl, sl], ord="fro")
158
+ np.linalg.norm(joint_adjusted_bread_inverse_matrix[sl, sl], ord="fro")
159
159
  )
160
- diag_conds.append(np.linalg.cond(joint_adaptive_bread_inverse_matrix[sl, sl]))
160
+ diag_conds.append(np.linalg.cond(joint_adjusted_bread_inverse_matrix[sl, sl]))
161
161
 
162
162
  # Sqrt each row/col sum to truly get row/column norms.
163
163
  # Perhaps not necessary for learning, but more natural
@@ -166,7 +166,7 @@ def get_datum_for_blowup_supervised_learning(
166
166
 
167
167
  # Get the per-person estimating function stack norms
168
168
  estimating_function_stack_norms = np.linalg.norm(
169
- per_user_estimating_function_stacks, axis=1
169
+ per_subject_estimating_function_stacks, axis=1
170
170
  )
171
171
 
172
172
  # Get the average estimating function stack norms by update/inference
@@ -175,7 +175,7 @@ def get_datum_for_blowup_supervised_learning(
175
175
  avg_estimating_function_stack_norms_per_segment = [
176
176
  np.mean(
177
177
  np.linalg.norm(
178
- per_user_estimating_function_stacks[
178
+ per_subject_estimating_function_stacks[
179
179
  :, block_bounds[i] : block_bounds[i + 1]
180
180
  ],
181
181
  axis=1,
@@ -191,67 +191,67 @@ def get_datum_for_blowup_supervised_learning(
191
191
  std_successive_beta_diff_norm = np.std(successive_beta_diff_norms)
192
192
 
193
193
  # Add a column with logits of the action probabilities
194
- # Compute the average and standard deviation of the logits of the action probabilities at each decision time using study_df
194
+ # Compute the average and standard deviation of the logits of the action probabilities at each decision time using analysis_df
195
195
  # action_prob_logit_means and action_prob_logit_stds are numpy arrays of mean and stddev at each decision time
196
- # Only compute logits for rows where user is in the study; set others to NaN
197
- in_study_mask = study_df[in_study_col_name] == 1
198
- study_df["action_prob_logit"] = np.where(
196
+ # Only compute logits for rows where subject is in the study; set others to NaN
197
+ in_study_mask = analysis_df[active_col_name] == 1
198
+ analysis_df["action_prob_logit"] = np.where(
199
199
  in_study_mask,
200
- logit(study_df[action_prob_col_name]),
200
+ logit(analysis_df[action_prob_col_name]),
201
201
  np.nan,
202
202
  )
203
- grouped_action_prob_logit = study_df.loc[in_study_mask].groupby(
203
+ grouped_action_prob_logit = analysis_df.loc[in_study_mask].groupby(
204
204
  calendar_t_col_name
205
205
  )["action_prob_logit"]
206
206
  action_prob_logit_means_by_t = grouped_action_prob_logit.mean().values
207
207
  action_prob_logit_stds_by_t = grouped_action_prob_logit.std().values
208
208
 
209
- # Compute the average and standard deviation of the rewards at each decision time using study_df
209
+ # Compute the average and standard deviation of the rewards at each decision time using analysis_df
210
210
  # reward_means and reward_stds are numpy arrays of mean and stddev at each decision time
211
- grouped_reward = study_df.loc[in_study_mask].groupby(calendar_t_col_name)[
211
+ grouped_reward = analysis_df.loc[in_study_mask].groupby(calendar_t_col_name)[
212
212
  reward_col_name
213
213
  ]
214
214
  reward_means_by_t = grouped_reward.mean().values
215
215
  reward_stds_by_t = grouped_reward.std().values
216
216
 
217
217
  joint_bread_inverse_min_singular_value = np.linalg.svd(
218
- joint_adaptive_bread_inverse_matrix, compute_uv=False
218
+ joint_adjusted_bread_inverse_matrix, compute_uv=False
219
219
  )[-1]
220
220
 
221
- max_reward = study_df.loc[in_study_mask][reward_col_name].max()
221
+ max_reward = analysis_df.loc[in_study_mask][reward_col_name].max()
222
222
 
223
223
  norm_avg_estimating_function_stack = np.linalg.norm(avg_estimating_function_stack)
224
224
  max_estimating_function_stack_norm = np.max(estimating_function_stack_norms)
225
225
 
226
226
  (
227
227
  premature_thetas,
228
- premature_adaptive_sandwiches,
228
+ premature_adjusted_sandwiches,
229
229
  premature_classical_sandwiches,
230
- premature_joint_adaptive_bread_inverse_condition_numbers,
230
+ premature_joint_adjusted_bread_inverse_condition_numbers,
231
231
  premature_avg_inference_estimating_functions,
232
- ) = calculate_sequence_of_premature_adaptive_estimates(
233
- study_df,
232
+ ) = calculate_sequence_of_premature_adjusted_estimates(
233
+ analysis_df,
234
234
  initial_policy_num,
235
235
  beta_index_by_policy_num,
236
- policy_num_by_decision_time_by_user_id,
236
+ policy_num_by_decision_time_by_subject_id,
237
237
  theta_calculation_func,
238
238
  calendar_t_col_name,
239
239
  action_prob_col_name,
240
- user_id_col_name,
241
- in_study_col_name,
240
+ subject_id_col_name,
241
+ active_col_name,
242
242
  all_post_update_betas,
243
- user_ids,
243
+ subject_ids,
244
244
  action_prob_func,
245
245
  action_prob_func_args_beta_index,
246
246
  inference_func,
247
247
  inference_func_type,
248
248
  inference_func_args_theta_index,
249
249
  inference_func_args_action_prob_index,
250
- inference_action_prob_decision_times_by_user_id,
250
+ inference_action_prob_decision_times_by_subject_id,
251
251
  action_prob_func_args,
252
- action_by_decision_time_by_user_id,
253
- joint_adaptive_bread_inverse_matrix,
254
- per_user_estimating_function_stacks,
252
+ action_by_decision_time_by_subject_id,
253
+ joint_adjusted_bread_inverse_matrix,
254
+ per_subject_estimating_function_stacks,
255
255
  beta_dim,
256
256
  )
257
257
 
@@ -261,42 +261,42 @@ def get_datum_for_blowup_supervised_learning(
261
261
  atol=1e-3,
262
262
  )
263
263
 
264
- # Plot premature joint adaptive bread inverse log condition numbers
264
+ # Plot premature joint adjusted bread inverse log condition numbers
265
265
  plt.clear_figure()
266
266
  plt.title("Premature Joint Adaptive Bread Inverse Log Condition Numbers")
267
267
  plt.xlabel("Premature Update Index")
268
268
  plt.ylabel("Log Condition Number")
269
269
  plt.scatter(
270
- np.log(premature_joint_adaptive_bread_inverse_condition_numbers),
270
+ np.log(premature_joint_adjusted_bread_inverse_condition_numbers),
271
271
  color="blue+",
272
272
  )
273
273
  plt.grid(True)
274
274
  plt.xticks(
275
275
  range(
276
276
  0,
277
- len(premature_joint_adaptive_bread_inverse_condition_numbers),
277
+ len(premature_joint_adjusted_bread_inverse_condition_numbers),
278
278
  max(
279
279
  1,
280
- len(premature_joint_adaptive_bread_inverse_condition_numbers) // 10,
280
+ len(premature_joint_adjusted_bread_inverse_condition_numbers) // 10,
281
281
  ),
282
282
  )
283
283
  )
284
284
  plt.show()
285
285
 
286
- # Plot each diagonal element of premature adaptive sandwiches
287
- num_diag = premature_adaptive_sandwiches.shape[-1]
286
+ # Plot each diagonal element of premature adjusted sandwiches
287
+ num_diag = premature_adjusted_sandwiches.shape[-1]
288
288
  for i in range(num_diag):
289
289
  plt.clear_figure()
290
290
  plt.title(f"Premature Adaptive Sandwich Diagonal Element {i}")
291
291
  plt.xlabel("Premature Update Index")
292
292
  plt.ylabel(f"Variance (Diagonal {i})")
293
- plt.scatter(np.array(premature_adaptive_sandwiches[:, i, i]), color="blue+")
293
+ plt.scatter(np.array(premature_adjusted_sandwiches[:, i, i]), color="blue+")
294
294
  plt.grid(True)
295
295
  plt.xticks(
296
296
  range(
297
297
  0,
298
- int(premature_adaptive_sandwiches.shape[0]),
299
- max(1, int(premature_adaptive_sandwiches.shape[0]) // 10),
298
+ int(premature_adjusted_sandwiches.shape[0]),
299
+ max(1, int(premature_adjusted_sandwiches.shape[0]) // 10),
300
300
  )
301
301
  )
302
302
  plt.show()
@@ -308,7 +308,7 @@ def get_datum_for_blowup_supervised_learning(
308
308
  plt.xlabel("Premature Update Index")
309
309
  plt.ylabel(f"Variance (Diagonal {i})")
310
310
  plt.scatter(
311
- np.array(premature_adaptive_sandwiches[:, i, i])
311
+ np.array(premature_adjusted_sandwiches[:, i, i])
312
312
  / np.array(premature_classical_sandwiches[:, i, i]),
313
313
  color="red+",
314
314
  )
@@ -316,8 +316,8 @@ def get_datum_for_blowup_supervised_learning(
316
316
  plt.xticks(
317
317
  range(
318
318
  0,
319
- int(premature_adaptive_sandwiches.shape[0]),
320
- max(1, int(premature_adaptive_sandwiches.shape[0]) // 10),
319
+ int(premature_adjusted_sandwiches.shape[0]),
320
+ max(1, int(premature_adjusted_sandwiches.shape[0]) // 10),
321
321
  )
322
322
  )
323
323
  plt.show()
@@ -331,14 +331,14 @@ def get_datum_for_blowup_supervised_learning(
331
331
  plt.xticks(
332
332
  range(
333
333
  0,
334
- int(premature_adaptive_sandwiches.shape[0]),
335
- max(1, int(premature_adaptive_sandwiches.shape[0]) // 10),
334
+ int(premature_adjusted_sandwiches.shape[0]),
335
+ max(1, int(premature_adjusted_sandwiches.shape[0]) // 10),
336
336
  )
337
337
  )
338
338
  plt.show()
339
339
 
340
340
  # Grab predictors related to premature Phi-dot-bars
341
- RL_stack_beta_derivatives_block = joint_adaptive_bread_inverse_matrix[
341
+ RL_stack_beta_derivatives_block = joint_adjusted_bread_inverse_matrix[
342
342
  :-theta_dim, :-theta_dim
343
343
  ]
344
344
  num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
@@ -397,14 +397,14 @@ def get_datum_for_blowup_supervised_learning(
397
397
  )
398
398
  return {
399
399
  **{
400
- "joint_bread_inverse_condition_number": joint_adaptive_bread_inverse_cond,
400
+ "joint_bread_inverse_condition_number": joint_adjusted_bread_inverse_cond,
401
401
  "joint_bread_inverse_min_singular_value": joint_bread_inverse_min_singular_value,
402
402
  "max_reward": max_reward,
403
403
  "norm_avg_estimating_function_stack": norm_avg_estimating_function_stack,
404
404
  "max_estimating_function_stack_norm": max_estimating_function_stack_norm,
405
405
  "max_successive_beta_diff_norm": max_successive_beta_diff_norm,
406
406
  "std_successive_beta_diff_norm": std_successive_beta_diff_norm,
407
- "label": adaptive_sandwich_var_estimate,
407
+ "label": adjusted_sandwich_var_estimate,
408
408
  },
409
409
  **{
410
410
  f"off_diag_block_{i}_{j}_norm": off_diag_block_norms[(i, j)]
@@ -422,10 +422,10 @@ def get_datum_for_blowup_supervised_learning(
422
422
  for i in range(num_block_rows_cols)
423
423
  },
424
424
  **{
425
- f"estimating_function_stack_norm_user_{user_id}": estimating_function_stack_norms[
425
+ f"estimating_function_stack_norm_subject_{subject_id}": estimating_function_stack_norms[
426
426
  i
427
427
  ]
428
- for i, user_id in enumerate(user_ids)
428
+ for i, subject_id in enumerate(subject_ids)
429
429
  },
430
430
  **{
431
431
  f"avg_estimating_function_stack_norm_segment_{i}": avg_estimating_function_stack_norms_per_segment[
@@ -455,18 +455,18 @@ def get_datum_for_blowup_supervised_learning(
455
455
  },
456
456
  **{f"theta_est_{i}": theta_est[i].item() for i in range(len(theta_est))},
457
457
  **{
458
- f"premature_joint_adaptive_bread_inverse_condition_number_{i}": premature_joint_adaptive_bread_inverse_condition_numbers[
458
+ f"premature_joint_adjusted_bread_inverse_condition_number_{i}": premature_joint_adjusted_bread_inverse_condition_numbers[
459
459
  i
460
460
  ]
461
461
  for i in range(
462
- len(premature_joint_adaptive_bread_inverse_condition_numbers)
462
+ len(premature_joint_adjusted_bread_inverse_condition_numbers)
463
463
  )
464
464
  },
465
465
  **{
466
- f"premature_adaptive_sandwich_update_{i}_diag_position_{j}": premature_adaptive_sandwich[
466
+ f"premature_adjusted_sandwich_update_{i}_diag_position_{j}": premature_adjusted_sandwich[
467
467
  j, j
468
468
  ]
469
- for premature_adaptive_sandwich in premature_adaptive_sandwiches
469
+ for premature_adjusted_sandwich in premature_adjusted_sandwiches
470
470
  for j in range(theta_dim)
471
471
  },
472
472
  **{
@@ -497,44 +497,46 @@ def get_datum_for_blowup_supervised_learning(
497
497
  }
498
498
 
499
499
 
500
- def calculate_sequence_of_premature_adaptive_estimates(
501
- study_df: pd.DataFrame,
500
+ def calculate_sequence_of_premature_adjusted_estimates(
501
+ analysis_df: pd.DataFrame,
502
502
  initial_policy_num: int | float,
503
503
  beta_index_by_policy_num: dict[int | float, int],
504
- policy_num_by_decision_time_by_user_id: dict[
504
+ policy_num_by_decision_time_by_subject_id: dict[
505
505
  collections.abc.Hashable, dict[int, int | float]
506
506
  ],
507
507
  theta_calculation_func: str,
508
508
  calendar_t_col_name: str,
509
509
  action_prob_col_name: str,
510
- user_id_col_name: str,
511
- in_study_col_name: str,
510
+ subject_id_col_name: str,
511
+ active_col_name: str,
512
512
  all_post_update_betas: jnp.ndarray,
513
- user_ids: jnp.ndarray,
513
+ subject_ids: jnp.ndarray,
514
514
  action_prob_func: str,
515
515
  action_prob_func_args_beta_index: int,
516
516
  inference_func: str,
517
517
  inference_func_type: str,
518
518
  inference_func_args_theta_index: int,
519
519
  inference_func_args_action_prob_index: int,
520
- inference_action_prob_decision_times_by_user_id: dict[
520
+ inference_action_prob_decision_times_by_subject_id: dict[
521
521
  collections.abc.Hashable, list[int]
522
522
  ],
523
- action_prob_func_args_by_user_id_by_decision_time: dict[
523
+ action_prob_func_args_by_subject_id_by_decision_time: dict[
524
524
  int, dict[collections.abc.Hashable, tuple[Any, ...]]
525
525
  ],
526
- action_by_decision_time_by_user_id: dict[collections.abc.Hashable, dict[int, int]],
527
- full_joint_adaptive_bread_inverse_matrix: jnp.ndarray,
528
- per_user_estimating_function_stacks: jnp.ndarray,
526
+ action_by_decision_time_by_subject_id: dict[
527
+ collections.abc.Hashable, dict[int, int]
528
+ ],
529
+ full_joint_adjusted_bread_inverse_matrix: jnp.ndarray,
530
+ per_subject_estimating_function_stacks: jnp.ndarray,
529
531
  beta_dim: int,
530
532
  ) -> jnp.ndarray:
531
533
  """
532
- Calculates a sequence of premature adaptive estimates for the given study DataFrame, where we
534
+ Calculates a sequence of premature adjusted estimates for the given study DataFrame, where we
533
535
  pretend the study ended after each update in sequence. The behavior of this sequence may provide
534
- insight into the stability of the final adaptive estimate.
536
+ insight into the stability of the final adjusted estimate.
535
537
 
536
538
  Args:
537
- study_df (pandas.DataFrame):
539
+ analysis_df (pandas.DataFrame):
538
540
  The DataFrame containing the study data.
539
541
  initial_policy_num (int | float): The policy number of the initial policy before any updates.
540
542
  initial_policy_num (int | float):
@@ -542,23 +544,23 @@ def calculate_sequence_of_premature_adaptive_estimates(
542
544
  beta_index_by_policy_num (dict[int | float, int]):
543
545
  A dictionary mapping policy numbers to the index of the corresponding beta in
544
546
  all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
545
- policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
546
- A map of user ids to dictionaries mapping decision times to the policy number in use.
547
+ policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
548
+ A map of subject ids to dictionaries mapping decision times to the policy number in use.
547
549
  Only applies to in-study decision times!
548
550
  theta_calculation_func (callable):
549
551
  The filename for the theta calculation function.
550
552
  calendar_t_col_name (str):
551
- The name of the column in study_df representing calendar time.
553
+ The name of the column in analysis_df representing calendar time.
552
554
  action_prob_col_name (str):
553
- The name of the column in study_df representing action probabilities.
554
- user_id_col_name (str):
555
- The name of the column in study_df representing user IDs.
556
- in_study_col_name (str):
557
- The name of the column in study_df indicating whether the user is in the study at that time.
555
+ The name of the column in analysis_df representing action probabilities.
556
+ subject_id_col_name (str):
557
+ The name of the column in analysis_df representing subject IDs.
558
+ active_col_name (str):
559
+ The name of the column in analysis_df indicating whether the subject is in the study at that time.
558
560
  all_post_update_betas (jnp.ndarray):
559
561
  A NumPy array containing all post-update beta values.
560
- user_ids (jnp.ndarray):
561
- A NumPy array containing all user IDs in the study.
562
+ subject_ids (jnp.ndarray):
563
+ A NumPy array containing all subject IDs in the study.
562
564
  action_prob_func (callable):
563
565
  The action probability function.
564
566
  action_prob_func_args_beta_index (int):
@@ -572,56 +574,56 @@ def calculate_sequence_of_premature_adaptive_estimates(
572
574
  inference_func_args_action_prob_index (int):
573
575
  The index of action probabilities in the inference function arguments tuple, if
574
576
  applicable. -1 otherwise.
575
- inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
576
- For each user, a list of decision times to which action probabilities correspond if
577
+ inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
578
+ For each subject, a list of decision times to which action probabilities correspond if
577
579
  provided. Typically just in-study times if action probabilities are used in the inference
578
580
  loss or estimating function.
579
- action_prob_func_args_by_user_id_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
580
- A dictionary mapping decision times to maps of user ids to the function arguments
581
- required to compute action probabilities for this user.
582
- action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
583
- A dictionary mapping user IDs to their respective actions taken at each decision time.
581
+ action_prob_func_args_by_subject_id_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
582
+ A dictionary mapping decision times to maps of subject ids to the function arguments
583
+ required to compute action probabilities for this subject.
584
+ action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
585
+ A dictionary mapping subject IDs to their respective actions taken at each decision time.
584
586
  Only applies to in-study decision times!
585
- full_joint_adaptive_bread_inverse_matrix (jnp.ndarray):
586
- The full joint adaptive bread inverse matrix as a NumPy array.
587
- per_user_estimating_function_stacks (jnp.ndarray):
588
- A NumPy array containing all per-user (weighted) estimating function stacks.
587
+ full_joint_adjusted_bread_inverse_matrix (jnp.ndarray):
588
+ The full joint adjusted bread inverse matrix as a NumPy array.
589
+ per_subject_estimating_function_stacks (jnp.ndarray):
590
+ A NumPy array containing all per-subject (weighted) estimating function stacks.
589
591
  beta_dim (int):
590
592
  The dimension of the beta parameters.
591
593
  Returns:
592
- tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: A NumPy array containing the sequence of premature adaptive estimates.
594
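+ tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: Arrays of the premature thetas, premature adjusted sandwiches, premature classical sandwiches, premature joint adjusted bread inverse condition numbers, and premature average inference estimating functions, in that order.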
593
595
  """
594
596
 
595
- # Loop through the non-initial (ie not before an update has occurred), non-final policy numbers in sorted order, forming adaptive and classical
597
+ # Loop through the non-initial (ie not before an update has occurred), non-final policy numbers in sorted order, forming adjusted and classical
596
598
  # variance estimates pretending that each was the final policy.
597
- premature_adaptive_sandwiches = []
599
+ premature_adjusted_sandwiches = []
598
600
  premature_thetas = []
599
- premature_joint_adaptive_bread_inverse_condition_numbers = []
601
+ premature_joint_adjusted_bread_inverse_condition_numbers = []
600
602
  premature_avg_inference_estimating_functions = []
601
603
  premature_classical_sandwiches = []
602
604
  logger.info(
603
- "Calculating sequence of premature adaptive estimates by pretending the study ended after each update in sequence."
605
+ "Calculating sequence of premature adjusted estimates by pretending the study ended after each update in sequence."
604
606
  )
605
607
  for policy_num in sorted(beta_index_by_policy_num):
606
608
  logger.info(
607
- "Calculating premature adaptive estimate assuming policy %s is the final one.",
609
+ "Calculating premature adjusted estimate assuming policy %s is the final one.",
608
610
  policy_num,
609
611
  )
610
612
  pretend_max_policy = policy_num
611
613
 
612
- truncated_joint_adaptive_bread_inverse_matrix = (
613
- full_joint_adaptive_bread_inverse_matrix[
614
+ truncated_joint_adjusted_bread_inverse_matrix = (
615
+ full_joint_adjusted_bread_inverse_matrix[
614
616
  : (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
615
617
  : (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
616
618
  ]
617
619
  )
618
620
 
619
- max_decision_time = study_df[study_df["policy_num"] == pretend_max_policy][
620
- calendar_t_col_name
621
- ].max()
621
+ max_decision_time = analysis_df[
622
+ analysis_df["policy_num"] == pretend_max_policy
623
+ ][calendar_t_col_name].max()
622
624
 
623
- truncated_study_df = study_df[
624
- study_df[calendar_t_col_name] <= max_decision_time
625
+ truncated_analysis_df = analysis_df[
626
+ analysis_df[calendar_t_col_name] <= max_decision_time
625
627
  ].copy()
626
628
 
627
629
  truncated_beta_index_by_policy_num = {
@@ -632,83 +634,83 @@ def calculate_sequence_of_premature_adaptive_estimates(
632
634
 
633
635
  truncated_all_post_update_betas = all_post_update_betas[: max_beta_index + 1, :]
634
636
 
635
- premature_theta = jnp.array(theta_calculation_func(truncated_study_df))
637
+ premature_theta = jnp.array(theta_calculation_func(truncated_analysis_df))
636
638
 
637
- truncated_action_prob_func_args_by_user_id_by_decision_time = {
638
- decision_time: args_by_user_id
639
- for decision_time, args_by_user_id in action_prob_func_args_by_user_id_by_decision_time.items()
639
+ truncated_action_prob_func_args_by_subject_id_by_decision_time = {
640
+ decision_time: args_by_subject_id
641
+ for decision_time, args_by_subject_id in action_prob_func_args_by_subject_id_by_decision_time.items()
640
642
  if decision_time <= max_decision_time
641
643
  }
642
644
 
643
- truncated_inference_func_args_by_user_id, _, _ = (
645
+ truncated_inference_func_args_by_subject_id, _, _ = (
644
646
  after_study_analysis.process_inference_func_args(
645
647
  inference_func,
646
648
  inference_func_args_theta_index,
647
- truncated_study_df,
649
+ truncated_analysis_df,
648
650
  premature_theta,
649
651
  action_prob_col_name,
650
652
  calendar_t_col_name,
651
- user_id_col_name,
652
- in_study_col_name,
653
+ subject_id_col_name,
654
+ active_col_name,
653
655
  )
654
656
  )
655
657
 
656
- truncated_inference_action_prob_decision_times_by_user_id = {
657
- user_id: [
658
+ truncated_inference_action_prob_decision_times_by_subject_id = {
659
+ subject_id: [
658
660
  decision_time
659
- for decision_time in inference_action_prob_decision_times_by_user_id[
660
- user_id
661
+ for decision_time in inference_action_prob_decision_times_by_subject_id[
662
+ subject_id
661
663
  ]
662
664
  if decision_time <= max_decision_time
663
665
  ]
664
666
  # writing this way is important, handles empty dicts correctly
665
- for user_id in inference_action_prob_decision_times_by_user_id
667
+ for subject_id in inference_action_prob_decision_times_by_subject_id
666
668
  }
667
669
 
668
- truncated_action_by_decision_time_by_user_id = {
669
- user_id: {
670
+ truncated_action_by_decision_time_by_subject_id = {
671
+ subject_id: {
670
672
  decision_time: action
671
- for decision_time, action in action_by_decision_time_by_user_id[
672
- user_id
673
+ for decision_time, action in action_by_decision_time_by_subject_id[
674
+ subject_id
673
675
  ].items()
674
676
  if decision_time <= max_decision_time
675
677
  }
676
- for user_id in action_by_decision_time_by_user_id
678
+ for subject_id in action_by_decision_time_by_subject_id
677
679
  }
678
680
 
679
- truncated_per_user_estimating_function_stacks = (
680
- per_user_estimating_function_stacks[
681
+ truncated_per_subject_estimating_function_stacks = (
682
+ per_subject_estimating_function_stacks[
681
683
  :,
682
684
  : (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
683
685
  ]
684
686
  )
685
687
 
686
688
  (
687
- premature_adaptive_sandwich,
689
+ premature_adjusted_sandwich,
688
690
  premature_classical_sandwich,
689
691
  premature_avg_inference_estimating_function,
690
- ) = construct_premature_classical_and_adaptive_sandwiches(
691
- truncated_joint_adaptive_bread_inverse_matrix,
692
- truncated_per_user_estimating_function_stacks,
692
+ ) = construct_premature_classical_and_adjusted_sandwiches(
693
+ truncated_joint_adjusted_bread_inverse_matrix,
694
+ truncated_per_subject_estimating_function_stacks,
693
695
  premature_theta,
694
696
  truncated_all_post_update_betas,
695
- user_ids,
697
+ subject_ids,
696
698
  action_prob_func,
697
699
  action_prob_func_args_beta_index,
698
700
  inference_func,
699
701
  inference_func_type,
700
702
  inference_func_args_theta_index,
701
703
  inference_func_args_action_prob_index,
702
- truncated_action_prob_func_args_by_user_id_by_decision_time,
703
- policy_num_by_decision_time_by_user_id,
704
+ truncated_action_prob_func_args_by_subject_id_by_decision_time,
705
+ policy_num_by_decision_time_by_subject_id,
704
706
  initial_policy_num,
705
707
  truncated_beta_index_by_policy_num,
706
- truncated_inference_func_args_by_user_id,
707
- truncated_inference_action_prob_decision_times_by_user_id,
708
- truncated_action_by_decision_time_by_user_id,
708
+ truncated_inference_func_args_by_subject_id,
709
+ truncated_inference_action_prob_decision_times_by_subject_id,
710
+ truncated_action_by_decision_time_by_subject_id,
709
711
  )
710
712
 
711
- premature_adaptive_sandwiches.append(premature_adaptive_sandwich)
713
+ premature_adjusted_sandwiches.append(premature_adjusted_sandwich)
712
714
  premature_classical_sandwiches.append(premature_classical_sandwich)
713
715
  premature_thetas.append(premature_theta)
714
716
  premature_avg_inference_estimating_functions.append(
@@ -716,38 +718,40 @@ def calculate_sequence_of_premature_adaptive_estimates(
716
718
  )
717
719
  return (
718
720
  jnp.array(premature_thetas),
719
- jnp.array(premature_adaptive_sandwiches),
721
+ jnp.array(premature_adjusted_sandwiches),
720
722
  jnp.array(premature_classical_sandwiches),
721
- jnp.array(premature_joint_adaptive_bread_inverse_condition_numbers),
723
+ jnp.array(premature_joint_adjusted_bread_inverse_condition_numbers),
722
724
  jnp.array(premature_avg_inference_estimating_functions),
723
725
  )
724
726
 
725
727
 
726
- def construct_premature_classical_and_adaptive_sandwiches(
727
- truncated_joint_adaptive_bread_inverse_matrix: jnp.ndarray,
728
- per_user_truncated_estimating_function_stacks: jnp.ndarray,
728
+ def construct_premature_classical_and_adjusted_sandwiches(
729
+ truncated_joint_adjusted_bread_inverse_matrix: jnp.ndarray,
730
+ per_subject_truncated_estimating_function_stacks: jnp.ndarray,
729
731
  theta: jnp.ndarray,
730
732
  all_post_update_betas: jnp.ndarray,
731
- user_ids: jnp.ndarray,
733
+ subject_ids: jnp.ndarray,
732
734
  action_prob_func: str,
733
735
  action_prob_func_args_beta_index: int,
734
736
  inference_func: str,
735
737
  inference_func_type: str,
736
738
  inference_func_args_theta_index: int,
737
739
  inference_func_args_action_prob_index: int,
738
- action_prob_func_args_by_user_id_by_decision_time: dict[
740
+ action_prob_func_args_by_subject_id_by_decision_time: dict[
739
741
  collections.abc.Hashable, dict[int, tuple[Any, ...]]
740
742
  ],
741
- policy_num_by_decision_time_by_user_id: dict[
743
+ policy_num_by_decision_time_by_subject_id: dict[
742
744
  collections.abc.Hashable, dict[int, int | float]
743
745
  ],
744
746
  initial_policy_num: int | float,
745
747
  beta_index_by_policy_num: dict[int | float, int],
746
- inference_func_args_by_user_id: dict[collections.abc.Hashable, tuple[Any, ...]],
747
- inference_action_prob_decision_times_by_user_id: dict[
748
+ inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
749
+ inference_action_prob_decision_times_by_subject_id: dict[
748
750
  collections.abc.Hashable, list[int]
749
751
  ],
750
- action_by_decision_time_by_user_id: dict[collections.abc.Hashable, dict[int, int]],
752
+ action_by_decision_time_by_subject_id: dict[
753
+ collections.abc.Hashable, dict[int, int]
754
+ ],
751
755
  ) -> tuple[
752
756
  jnp.ndarray[jnp.float32],
753
757
  jnp.ndarray[jnp.float32],
@@ -759,33 +763,33 @@ def construct_premature_classical_and_adaptive_sandwiches(
759
763
  jnp.ndarray[jnp.float32],
760
764
  ]:
761
765
  """
762
- Constructs the classical bread and meat matrices, as well as the adaptive bread matrix
766
+ Constructs the classical bread and meat matrices, as well as the adjusted bread matrix
763
767
  and the average weighted inference estimating function for the premature variance estimation
764
768
  procedure.
765
769
 
766
770
  This is done by computing and differentiating the new average inference estimating function
767
771
  with respect to the betas and theta, and stitching this together with the existing
768
- adaptive bread inverse matrix portion (corresponding to the updates still under consideration)
769
- to form the new premature joint adaptive bread inverse matrix.
772
+ adjusted bread inverse matrix portion (corresponding to the updates still under consideration)
773
+ to form the new premature joint adjusted bread inverse matrix.
770
774
 
771
775
  Args:
772
- truncated_joint_adaptive_bread_inverse_matrix (jnp.ndarray):
773
- A 2-D JAX NumPy array holding the existing joint adaptive bread inverse but
776
+ truncated_joint_adjusted_bread_inverse_matrix (jnp.ndarray):
777
+ A 2-D JAX NumPy array holding the existing joint adjusted bread inverse but
774
778
  with rows corresponding to updates not under consideration and inference dropped.
775
779
  We will stitch this together with the newly computed inference portion to form
776
- our "premature" joint adaptive bread inverse matrix.
777
- per_user_truncated_estimating_function_stacks (jnp.ndarray):
778
- A 2-D JAX NumPy array holding the existing per-user weighted estimating function
780
+ our "premature" joint adjusted bread inverse matrix.
781
+ per_subject_truncated_estimating_function_stacks (jnp.ndarray):
782
+ A 2-D JAX NumPy array holding the existing per-subject weighted estimating function
779
783
  stacks but with rows corresponding to updates not under consideration dropped.
780
784
  We will stitch this together with the newly computed inference estimating functions
781
- to form our "premature" joint adaptive estimating function stacks from which the new
782
- adaptive meat matrix can be computed.
785
+ to form our "premature" joint adjusted estimating function stacks from which the new
786
+ adjusted meat matrix can be computed.
783
787
  theta (jnp.ndarray):
784
788
  A 1-D JAX NumPy array representing the parameter estimate for inference.
785
789
  all_post_update_betas (jnp.ndarray):
786
790
  A 2-D JAX NumPy array representing all parameter estimates for the algorithm updates.
787
- user_ids (jnp.ndarray):
788
- A 1-D JAX NumPy array holding all user IDs in the study.
791
+ subject_ids (jnp.ndarray):
792
+ A 1-D JAX NumPy array holding all subject IDs in the study.
789
793
  action_prob_func (callable):
790
794
  The action probability function.
791
795
  action_prob_func_args_beta_index (int):
@@ -799,51 +803,51 @@ def construct_premature_classical_and_adaptive_sandwiches(
799
803
  inference_func_args_action_prob_index (int):
800
804
  The index of action probabilities in the inference function arguments tuple, if
801
805
  applicable. -1 otherwise.
802
- action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
803
- A dictionary mapping decision times to maps of user ids to the function arguments
804
- required to compute action probabilities for this user.
805
- policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
806
- A map of user ids to dictionaries mapping decision times to the policy number in use.
806
+ action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
807
+ A dictionary mapping decision times to maps of subject ids to the function arguments
808
+ required to compute action probabilities for this subject.
809
+ policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
810
+ A map of subject ids to dictionaries mapping decision times to the policy number in use.
807
811
  Only applies to in-study decision times!
808
812
  initial_policy_num (int | float):
809
813
  The policy number of the initial policy before any updates.
810
814
  beta_index_by_policy_num (dict[int | float, int]):
811
815
  A dictionary mapping policy numbers to the index of the corresponding beta in
812
816
  all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
813
- inference_func_args_by_user_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
814
- A dictionary mapping user IDs to their respective inference function arguments.
815
- inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
816
- For each user, a list of decision times to which action probabilities correspond if
817
+ inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
818
+ A dictionary mapping subject IDs to their respective inference function arguments.
819
+ inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
820
+ For each subject, a list of decision times to which action probabilities correspond if
817
821
  provided. Typically just in-study times if action probabilities are used in the inference
818
822
  loss or estimating function.
819
- action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
820
- A dictionary mapping user IDs to their respective actions taken at each decision time.
823
+ action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
824
+ A dictionary mapping subject IDs to their respective actions taken at each decision time.
821
825
  Only applies to in-study decision times!
822
826
  Returns:
823
827
  tuple[jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32],
824
828
  jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32],
825
829
  jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32]]:
826
830
  A tuple containing:
827
- - The joint adaptive inverse bread matrix.
828
- - The joint adaptive bread matrix.
829
- - The joint adaptive meat matrix.
831
+ - The joint adjusted inverse bread matrix.
832
+ - The joint adjusted bread matrix.
833
+ - The joint adjusted meat matrix.
830
834
  - The classical inverse bread matrix.
831
835
  - The classical bread matrix.
832
836
  - The classical meat matrix.
833
837
  - The average (weighted) inference estimating function.
834
- - The joint adaptive inverse bread matrix condition number.
838
+ - The joint adjusted inverse bread matrix condition number.
835
839
  """
836
840
  logger.info(
837
841
  "Differentiating average weighted inference estimating function stack and collecting auxiliary values."
838
842
  )
839
843
  # jax.jacobian may perform worse here--seemed to hang indefinitely while jacrev is merely very
840
844
  # slow.
841
- # Note that these "contributions" are per-user Jacobians of the weighted estimating function stack.
845
+ # Note that these "contributions" are per-subject Jacobians of the weighted estimating function stack.
842
846
  new_inference_block_row, (
843
- per_user_inference_estimating_functions,
847
+ per_subject_inference_estimating_functions,
844
848
  avg_inference_estimating_function,
845
- per_user_classical_meat_contributions,
846
- per_user_classical_bread_inverse_contributions,
849
+ per_subject_classical_meat_contributions,
850
+ per_subject_classical_bread_inverse_contributions,
847
851
  ) = jax.jacrev(get_weighted_inference_estimating_functions_only, has_aux=True)(
848
852
  # While JAX can technically differentiate with respect to a list of JAX arrays,
849
853
  # it is more efficient to flatten them into a single array. This is done
@@ -851,29 +855,29 @@ def construct_premature_classical_and_adaptive_sandwiches(
851
855
  after_study_analysis.flatten_params(all_post_update_betas, theta),
852
856
  all_post_update_betas.shape[1],
853
857
  theta.shape[0],
854
- user_ids,
858
+ subject_ids,
855
859
  action_prob_func,
856
860
  action_prob_func_args_beta_index,
857
861
  inference_func,
858
862
  inference_func_type,
859
863
  inference_func_args_theta_index,
860
864
  inference_func_args_action_prob_index,
861
- action_prob_func_args_by_user_id_by_decision_time,
862
- policy_num_by_decision_time_by_user_id,
865
+ action_prob_func_args_by_subject_id_by_decision_time,
866
+ policy_num_by_decision_time_by_subject_id,
863
867
  initial_policy_num,
864
868
  beta_index_by_policy_num,
865
- inference_func_args_by_user_id,
866
- inference_action_prob_decision_times_by_user_id,
867
- action_by_decision_time_by_user_id,
869
+ inference_func_args_by_subject_id,
870
+ inference_action_prob_decision_times_by_subject_id,
871
+ action_by_decision_time_by_subject_id,
868
872
  )
869
873
 
870
- joint_adaptive_bread_inverse_matrix = jnp.block(
874
+ joint_adjusted_bread_inverse_matrix = jnp.block(
871
875
  [
872
876
  [
873
- truncated_joint_adaptive_bread_inverse_matrix,
877
+ truncated_joint_adjusted_bread_inverse_matrix,
874
878
  np.zeros(
875
879
  (
876
- truncated_joint_adaptive_bread_inverse_matrix.shape[0],
880
+ truncated_joint_adjusted_bread_inverse_matrix.shape[0],
877
881
  new_inference_block_row.shape[0],
878
882
  )
879
883
  ),
@@ -881,51 +885,53 @@ def construct_premature_classical_and_adaptive_sandwiches(
881
885
  [new_inference_block_row],
882
886
  ]
883
887
  )
884
- per_user_estimating_function_stacks = jnp.concatenate(
888
+ per_subject_estimating_function_stacks = jnp.concatenate(
885
889
  [
886
- per_user_truncated_estimating_function_stacks,
887
- per_user_inference_estimating_functions,
890
+ per_subject_truncated_estimating_function_stacks,
891
+ per_subject_inference_estimating_functions,
888
892
  ],
889
893
  axis=1,
890
894
  )
891
- per_user_adaptive_meat_contributions = jnp.einsum(
895
+ per_subject_adjusted_meat_contributions = jnp.einsum(
892
896
  "ni,nj->nij",
893
- per_user_estimating_function_stacks,
894
- per_user_estimating_function_stacks,
897
+ per_subject_estimating_function_stacks,
898
+ per_subject_estimating_function_stacks,
895
899
  )
896
900
 
897
- joint_adaptive_meat_matrix = jnp.mean(per_user_adaptive_meat_contributions, axis=0)
901
+ joint_adjusted_meat_matrix = jnp.mean(
902
+ per_subject_adjusted_meat_contributions, axis=0
903
+ )
898
904
 
899
905
  classical_bread_inverse_matrix = jnp.mean(
900
- per_user_classical_bread_inverse_contributions, axis=0
906
+ per_subject_classical_bread_inverse_contributions, axis=0
901
907
  )
902
- classical_meat_matrix = jnp.mean(per_user_classical_meat_contributions, axis=0)
908
+ classical_meat_matrix = jnp.mean(per_subject_classical_meat_contributions, axis=0)
903
909
 
904
- num_users = user_ids.shape[0]
905
- joint_adaptive_sandwich = (
910
+ num_subjects = subject_ids.shape[0]
911
+ joint_adjusted_sandwich = (
906
912
  after_study_analysis.form_sandwich_from_bread_inverse_and_meat(
907
- joint_adaptive_bread_inverse_matrix,
908
- joint_adaptive_meat_matrix,
909
- num_users,
913
+ joint_adjusted_bread_inverse_matrix,
914
+ joint_adjusted_meat_matrix,
915
+ num_subjects,
910
916
  method="bread_inverse_T_qr",
911
917
  )
912
918
  )
913
- adaptive_sandwich = joint_adaptive_sandwich[-theta.shape[0] :, -theta.shape[0] :]
919
+ adjusted_sandwich = joint_adjusted_sandwich[-theta.shape[0] :, -theta.shape[0] :]
914
920
 
915
921
  classical_bread_inverse_matrix = jnp.mean(
916
- per_user_classical_bread_inverse_contributions, axis=0
922
+ per_subject_classical_bread_inverse_contributions, axis=0
917
923
  )
918
924
  classical_sandwich = after_study_analysis.form_sandwich_from_bread_inverse_and_meat(
919
925
  classical_bread_inverse_matrix,
920
926
  classical_meat_matrix,
921
- num_users,
927
+ num_subjects,
922
928
  method="bread_inverse_T_qr",
923
929
  )
924
930
 
925
- # Stack the joint adaptive inverse bread pieces together horizontally and return the auxiliary
926
- # values too. The joint adaptive bread inverse should always be block lower triangular.
931
+ # Stack the joint adjusted inverse bread pieces together horizontally and return the auxiliary
932
+ # values too. The joint adjusted bread inverse should always be block lower triangular.
927
933
  return (
928
- adaptive_sandwich,
934
+ adjusted_sandwich,
929
935
  classical_sandwich,
930
936
  avg_inference_estimating_function,
931
937
  )
@@ -935,32 +941,34 @@ def get_weighted_inference_estimating_functions_only(
935
941
  flattened_betas_and_theta: jnp.ndarray,
936
942
  beta_dim: int,
937
943
  theta_dim: int,
938
- user_ids: jnp.ndarray,
944
+ subject_ids: jnp.ndarray,
939
945
  action_prob_func: callable,
940
946
  action_prob_func_args_beta_index: int,
941
947
  inference_func: callable,
942
948
  inference_func_type: str,
943
949
  inference_func_args_theta_index: int,
944
950
  inference_func_args_action_prob_index: int,
945
- action_prob_func_args_by_user_id_by_decision_time: dict[
951
+ action_prob_func_args_by_subject_id_by_decision_time: dict[
946
952
  collections.abc.Hashable, dict[int, tuple[Any, ...]]
947
953
  ],
948
- policy_num_by_decision_time_by_user_id: dict[
954
+ policy_num_by_decision_time_by_subject_id: dict[
949
955
  collections.abc.Hashable, dict[int, int | float]
950
956
  ],
951
957
  initial_policy_num: int | float,
952
958
  beta_index_by_policy_num: dict[int | float, int],
953
- inference_func_args_by_user_id: dict[collections.abc.Hashable, tuple[Any, ...]],
954
- inference_action_prob_decision_times_by_user_id: dict[
959
+ inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
960
+ inference_action_prob_decision_times_by_subject_id: dict[
955
961
  collections.abc.Hashable, list[int]
956
962
  ],
957
- action_by_decision_time_by_user_id: dict[collections.abc.Hashable, dict[int, int]],
963
+ action_by_decision_time_by_subject_id: dict[
964
+ collections.abc.Hashable, dict[int, int]
965
+ ],
958
966
  ) -> tuple[
959
967
  jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
960
968
  ]:
961
969
  """
962
- Computes the average weighted inference estimating function across users, along with
963
- auxiliary values used to construct the adaptive and classical sandwich variances.
970
+ Computes the average weighted inference estimating function across subjects, along with
971
+ auxiliary values used to construct the adjusted and classical sandwich variances.
964
972
 
965
973
  Note that input data should have been adjusted to only correspond to updates/decision times
966
974
  that are being considered for the current "premature" variance estimation procedure.
@@ -974,8 +982,8 @@ def get_weighted_inference_estimating_functions_only(
974
982
  The dimension of each of the beta parameters.
975
983
  theta_dim (int):
976
984
  The dimension of the theta parameter.
977
- user_ids (jnp.ndarray):
978
- A 1D JAX NumPy array of user IDs.
985
+ subject_ids (jnp.ndarray):
986
+ A 1D JAX NumPy array of subject IDs.
979
987
  action_prob_func (str):
980
988
  The action probability function.
981
989
  action_prob_func_args_beta_index (int):
@@ -989,25 +997,25 @@ def get_weighted_inference_estimating_functions_only(
989
997
  inference_func_args_action_prob_index (int):
990
998
  The index of action probabilities in the inference function arguments tuple, if
991
999
  applicable. -1 otherwise.
992
- action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
993
- A dictionary mapping decision times to maps of user ids to the function arguments
994
- required to compute action probabilities for this user.
995
- policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
996
- A map of user ids to dictionaries mapping decision times to the policy number in use.
1000
+ action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
1001
+ A dictionary mapping decision times to maps of subject ids to the function arguments
1002
+ required to compute action probabilities for this subject.
1003
+ policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
1004
+ A map of subject ids to dictionaries mapping decision times to the policy number in use.
997
1005
  Only applies to in-study decision times!
998
1006
  initial_policy_num (int | float):
999
1007
  The policy number of the initial policy before any updates.
1000
1008
  beta_index_by_policy_num (dict[int | float, int]):
1001
1009
  A dictionary mapping policy numbers to the index of the corresponding beta in
1002
1010
  all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
1003
- inference_func_args_by_user_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
1004
- A dictionary mapping user IDs to their respective inference function arguments.
1005
- inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
1006
- For each user, a list of decision times to which action probabilities correspond if
1011
+ inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
1012
+ A dictionary mapping subject IDs to their respective inference function arguments.
1013
+ inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
1014
+ For each subject, a list of decision times to which action probabilities correspond if
1007
1015
  provided. Typically just in-study times if action probabilities are used in the inference
1008
1016
  loss or estimating function.
1009
- action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
1010
- A dictionary mapping user IDs to their respective actions taken at each decision time.
1017
+ action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
1018
+ A dictionary mapping subject IDs to their respective actions taken at each decision time.
1011
1019
  Only applies to in-study decision times!
1012
1020
 
1013
1021
  Returns:
@@ -1015,10 +1023,10 @@ def get_weighted_inference_estimating_functions_only(
1015
1023
  A 2D JAX NumPy array holding the average weighted inference estimating function.
1016
1024
  tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
1017
1025
  A tuple containing
1018
- 1. the per-user weighted inference estimating function stacks
1026
+ 1. the per-subject weighted inference estimating function stacks
1019
1027
  2. the average weighted inference estimating function
1020
- 3. the user-level classical meat matrix contributions
1021
- 4. the user-level inverse classical bread matrix contributions
1028
+ 3. the subject-level classical meat matrix contributions
1029
+ 4. the subject-level inverse classical bread matrix contributions
1022
1030
  stacks.
1023
1031
  """
1024
1032
 
@@ -1038,15 +1046,15 @@ def get_weighted_inference_estimating_functions_only(
1038
1046
  # supplied for the above functions, so that differentiation works correctly. The existing
1039
1047
  # values should be the same, but not connected to the parameter we are differentiating
1040
1048
  # with respect to. Note we will also find it useful below to have the action probability args
1041
- # nested dict structure flipped to be user_id -> decision_time -> args, so we do that here too.
1049
+ # nested dict structure flipped to be subject_id -> decision_time -> args, so we do that here too.
1042
1050
 
1043
- logger.info("Threading in betas to action probability arguments for all users.")
1051
+ logger.info("Threading in betas to action probability arguments for all subjects.")
1044
1052
  (
1045
- threaded_action_prob_func_args_by_decision_time_by_user_id,
1046
- action_prob_func_args_by_decision_time_by_user_id,
1053
+ threaded_action_prob_func_args_by_decision_time_by_subject_id,
1054
+ action_prob_func_args_by_decision_time_by_subject_id,
1047
1055
  ) = after_study_analysis.thread_action_prob_func_args(
1048
- action_prob_func_args_by_user_id_by_decision_time,
1049
- policy_num_by_decision_time_by_user_id,
1056
+ action_prob_func_args_by_subject_id_by_decision_time,
1057
+ policy_num_by_decision_time_by_subject_id,
1050
1058
  initial_policy_num,
1051
1059
  betas,
1052
1060
  beta_index_by_policy_num,
@@ -1058,38 +1066,38 @@ def get_weighted_inference_estimating_functions_only(
1058
1066
  # arguments with the central betas introduced.
1059
1067
  logger.info(
1060
1068
  "Threading in theta and beta-dependent action probabilities to inference update "
1061
- "function args for all users"
1069
+ "function args for all subjects"
1062
1070
  )
1063
- threaded_inference_func_args_by_user_id = (
1071
+ threaded_inference_func_args_by_subject_id = (
1064
1072
  after_study_analysis.thread_inference_func_args(
1065
- inference_func_args_by_user_id,
1073
+ inference_func_args_by_subject_id,
1066
1074
  inference_func_args_theta_index,
1067
1075
  theta,
1068
1076
  inference_func_args_action_prob_index,
1069
- threaded_action_prob_func_args_by_decision_time_by_user_id,
1070
- inference_action_prob_decision_times_by_user_id,
1077
+ threaded_action_prob_func_args_by_decision_time_by_subject_id,
1078
+ inference_action_prob_decision_times_by_subject_id,
1071
1079
  action_prob_func,
1072
1080
  )
1073
1081
  )
1074
1082
 
1075
- # 5. Now we can compute the the weighted inference estimating functions for all users
1076
- # as well as collect related values used to construct the adaptive and classical
1083
+ # 5. Now we can compute the weighted inference estimating functions for all subjects
1084
+ # as well as collect related values used to construct the adjusted and classical
1077
1085
  # sandwich variances.
1078
1086
  results = [
1079
- single_user_weighted_inference_estimating_function(
1080
- user_id,
1087
+ single_subject_weighted_inference_estimating_function(
1088
+ subject_id,
1081
1089
  action_prob_func,
1082
1090
  inference_estimating_func,
1083
1091
  action_prob_func_args_beta_index,
1084
1092
  inference_func_args_theta_index,
1085
- action_prob_func_args_by_decision_time_by_user_id[user_id],
1086
- threaded_action_prob_func_args_by_decision_time_by_user_id[user_id],
1087
- threaded_inference_func_args_by_user_id[user_id],
1088
- policy_num_by_decision_time_by_user_id[user_id],
1089
- action_by_decision_time_by_user_id[user_id],
1093
+ action_prob_func_args_by_decision_time_by_subject_id[subject_id],
1094
+ threaded_action_prob_func_args_by_decision_time_by_subject_id[subject_id],
1095
+ threaded_inference_func_args_by_subject_id[subject_id],
1096
+ policy_num_by_decision_time_by_subject_id[subject_id],
1097
+ action_by_decision_time_by_subject_id[subject_id],
1090
1098
  beta_index_by_policy_num,
1091
1099
  )
1092
- for user_id in user_ids.tolist()
1100
+ for subject_id in subject_ids.tolist()
1093
1101
  ]
1094
1102
 
1095
1103
  weighted_inference_estimating_functions = jnp.array(
@@ -1100,8 +1108,8 @@ def get_weighted_inference_estimating_functions_only(
1100
1108
 
1101
1109
  # 6. Note this strange return structure! We will differentiate the first output,
1102
1110
  # but the second tuple will be passed along without modification via has_aux=True and then used
1103
- # for the adaptive meat matrix, estimating functions sum check, and classical meat and inverse
1104
- # bread matrices. The raw per-user estimating functions are also returned again for debugging
1111
+ # for the adjusted meat matrix, estimating functions sum check, and classical meat and inverse
1112
+ # bread matrices. The raw per-subject estimating functions are also returned again for debugging
1105
1113
  # purposes.
1106
1114
  return jnp.mean(weighted_inference_estimating_functions, axis=0), (
1107
1115
  weighted_inference_estimating_functions,
@@ -1111,8 +1119,8 @@ def get_weighted_inference_estimating_functions_only(
1111
1119
  )
1112
1120
 
1113
1121
 
1114
- def single_user_weighted_inference_estimating_function(
1115
- user_id: collections.abc.Hashable,
1122
+ def single_subject_weighted_inference_estimating_function(
1123
+ subject_id: collections.abc.Hashable,
1116
1124
  action_prob_func: callable,
1117
1125
  inference_estimating_func: callable,
1118
1126
  action_prob_func_args_beta_index: int,
@@ -1137,12 +1145,12 @@ def single_user_weighted_inference_estimating_function(
1137
1145
  and action probability function and arguments if applicable.
1138
1146
 
1139
1147
  Args:
1140
- user_id (collections.abc.Hashable):
1141
- The user ID for which to compute the weighted estimating function stack.
1148
+ subject_id (collections.abc.Hashable):
1149
+ The subject ID for which to compute the weighted estimating function stack.
1142
1150
 
1143
1151
  action_prob_func (callable):
1144
1152
  The function used to compute the probability of action 1 at a given decision time for
1145
- a particular user given their state and the algorithm parameters.
1153
+ a particular subject given their state and the algorithm parameters.
1146
1154
 
1147
1155
  inference_estimating_func (callable):
1148
1156
  The estimating function that corresponds to inference.
@@ -1154,7 +1162,7 @@ def single_user_weighted_inference_estimating_function(
1154
1162
  The index of the theta parameter in the inference loss or estimating function arguments.
1155
1163
 
1156
1164
  action_prob_func_args_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
1157
- A map from decision times to tuples of arguments for this user for the action
1165
+ A map from decision times to tuples of arguments for this subject for the action
1158
1166
  probability function. This is for all decision times (args are an empty
1159
1167
  tuple if they are not in the study). Should be sorted by decision time. NOTE THAT THESE
1160
1168
  ARGS DO NOT CONTAIN THE SHARED BETAS, making them impervious to the differentiation that
@@ -1167,11 +1175,11 @@ def single_user_weighted_inference_estimating_function(
1167
1175
 
1168
1176
  threaded_inference_func_args (dict[collections.abc.Hashable, tuple[Any, ...]]):
1169
1177
  A tuple containing the arguments for the inference
1170
- estimating function for this user, with the shared betas threaded in for differentiation.
1178
+ estimating function for this subject, with the shared betas threaded in for differentiation.
1171
1179
 
1172
1180
  policy_num_by_decision_time (dict[collections.abc.Hashable, dict[int, int | float]]):
1173
1181
  A dictionary mapping decision times to the policy number in use. This may be
1174
- user-specific. Should be sorted by decision time. Only applies to in-study decision
1182
+ subject-specific. Should be sorted by decision time. Only applies to in-study decision
1175
1183
  times!
1176
1184
 
1177
1185
  action_by_decision_time (dict[collections.abc.Hashable, dict[int, int]]):
@@ -1183,14 +1191,14 @@ def single_user_weighted_inference_estimating_function(
1183
1191
  all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
1184
1192
 
1185
1193
  Returns:
1186
- jnp.ndarray: A 1-D JAX NumPy array representing the user's weighted inference estimating function.
1187
- jnp.ndarray: A 2-D JAX NumPy matrix representing the user's classical meat contribution.
1188
- jnp.ndarray: A 2-D JAX NumPy matrix representing the user's classical bread contribution.
1194
+ jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted inference estimating function.
1195
+ jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical meat contribution.
1196
+ jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical bread contribution.
1189
1197
  """
1190
1198
 
1191
1199
  logger.info(
1192
- "Computing only weighted inference estimating function stack for user %s.",
1193
- user_id,
1200
+ "Computing only weighted inference estimating function stack for subject %s.",
1201
+ subject_id,
1194
1202
  )
1195
1203
 
1196
1204
  # First, reformat the supplied data into more convenient structures.
@@ -1202,12 +1210,12 @@ def single_user_weighted_inference_estimating_function(
1202
1210
  beta_index_by_policy_num,
1203
1211
  )
1204
1212
 
1205
- # 2. Get the start and end times for this user.
1206
- user_start_time = math.inf
1207
- user_end_time = -math.inf
1213
+ # 2. Get the start and end times for this subject.
1214
+ subject_start_time = math.inf
1215
+ subject_end_time = -math.inf
1208
1216
  for decision_time in action_by_decision_time:
1209
- user_start_time = min(user_start_time, decision_time)
1210
- user_end_time = max(user_end_time, decision_time)
1217
+ subject_start_time = min(subject_start_time, decision_time)
1218
+ subject_end_time = max(subject_end_time, decision_time)
1211
1219
 
1212
1220
  # 3. Calculate the Radon-Nikodym weights for the inference estimating function.
1213
1221
  in_study_action_prob_func_args = [
@@ -1224,11 +1232,11 @@ def single_user_weighted_inference_estimating_function(
1224
1232
  )
1225
1233
 
1226
1234
  # Sort the threaded args by decision time to be cautious. We check if the
1227
- # user id is present in the user args dict because we may call this on a
1228
- # subset of the user arg dict when we are batching arguments by shape
1235
+ # subject id is present in the subject args dict because we may call this on a
1236
+ # subset of the subject arg dict when we are batching arguments by shape
1229
1237
  sorted_threaded_action_prob_args_by_decision_time = {
1230
1238
  decision_time: threaded_action_prob_func_args_by_decision_time[decision_time]
1231
- for decision_time in range(user_start_time, user_end_time + 1)
1239
+ for decision_time in range(subject_start_time, subject_end_time + 1)
1232
1240
  if decision_time in threaded_action_prob_func_args_by_decision_time
1233
1241
  }
1234
1242
 
@@ -1287,12 +1295,12 @@ def single_user_weighted_inference_estimating_function(
1287
1295
  # 4. Form the weighted inference estimating equation.
1288
1296
  weighted_inference_estimating_function = jnp.prod(
1289
1297
  all_weights[
1290
- max(first_time_after_first_update, user_start_time)
1291
- - decision_time_to_all_weights_index_offset : user_end_time
1298
+ max(first_time_after_first_update, subject_start_time)
1299
+ - decision_time_to_all_weights_index_offset : subject_end_time
1292
1300
  + 1
1293
1301
  - decision_time_to_all_weights_index_offset,
1294
1302
  ]
1295
- # If the user exited the study before there were any updates,
1303
+ # If the subject exited the study before there were any updates,
1296
1304
  # this variable will be None and the above code to grab a weight would
1297
1305
  # throw an error. Just use 1 to include the unweighted estimating function
1298
1306
  # if they have data to contribute here (pretty sure everyone should?)