lifejacket 0.2.1__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,7 +12,7 @@ import jax
12
12
  from jax import numpy as jnp
13
13
  import pandas as pd
14
14
 
15
- from . import after_study_analysis
15
+ from . import post_deployment_analysis
16
16
  from .constants import FunctionTypes
17
17
  from .vmap_helpers import stack_batched_arg_lists_into_tensors
18
18
 
@@ -25,25 +25,25 @@ logging.basicConfig(
25
25
 
26
26
 
27
27
  def get_datum_for_blowup_supervised_learning(
28
- joint_adaptive_bread_inverse_matrix,
29
- joint_adaptive_bread_inverse_cond,
28
+ joint_adjusted_bread_matrix,
29
+ joint_adjusted_bread_cond,
30
30
  avg_estimating_function_stack,
31
- per_user_estimating_function_stacks,
31
+ per_subject_estimating_function_stacks,
32
32
  all_post_update_betas,
33
- study_df,
34
- in_study_col_name,
33
+ analysis_df,
34
+ active_col_name,
35
35
  calendar_t_col_name,
36
36
  action_prob_col_name,
37
- user_id_col_name,
37
+ subject_id_col_name,
38
38
  reward_col_name,
39
39
  theta_est,
40
- adaptive_sandwich_var_estimate,
41
- user_ids,
40
+ adjusted_sandwich_var_estimate,
41
+ subject_ids,
42
42
  beta_dim,
43
43
  theta_dim,
44
44
  initial_policy_num,
45
45
  beta_index_by_policy_num,
46
- policy_num_by_decision_time_by_user_id,
46
+ policy_num_by_decision_time_by_subject_id,
47
47
  theta_calculation_func,
48
48
  action_prob_func,
49
49
  action_prob_func_args_beta_index,
@@ -51,46 +51,46 @@ def get_datum_for_blowup_supervised_learning(
51
51
  inference_func_type,
52
52
  inference_func_args_theta_index,
53
53
  inference_func_args_action_prob_index,
54
- inference_action_prob_decision_times_by_user_id,
54
+ inference_action_prob_decision_times_by_subject_id,
55
55
  action_prob_func_args,
56
- action_by_decision_time_by_user_id,
56
+ action_by_decision_time_by_subject_id,
57
57
  ) -> dict[str, Any]:
58
58
  """
59
- Collects a datum for supervised learning about adaptive sandwich blowup.
59
+ Collects a datum for supervised learning about adjusted sandwich blowup.
60
60
 
61
- The datum consists of features and the raw adaptive sandwich variance estimate as a label.
61
+ The datum consists of features and the raw adjusted sandwich variance estimate as a label.
62
62
 
63
63
  A few plots are produced along the way to help visualize the data.
64
64
 
65
65
  Args:
66
- joint_adaptive_bread_inverse_matrix (jnp.ndarray):
67
- The joint adaptive bread inverse matrix.
68
- joint_adaptive_bread_inverse_cond (float):
69
- The condition number of the joint adaptive bread inverse matrix.
66
+ joint_adjusted_bread_matrix (jnp.ndarray):
67
+ The joint adjusted bread matrix.
68
+ joint_adjusted_bread_cond (float):
69
+ The condition number of the joint adjusted bread matrix.
70
70
  avg_estimating_function_stack (jnp.ndarray):
71
- The average estimating function stack across users.
72
- per_user_estimating_function_stacks (jnp.ndarray):
73
- The estimating function stacks for each user.
71
+ The average estimating function stack across subjects.
72
+ per_subject_estimating_function_stacks (jnp.ndarray):
73
+ The estimating function stacks for each subject.
74
74
  all_post_update_betas (jnp.ndarray):
75
75
  All post-update beta parameters.
76
- study_df (pd.DataFrame):
76
+ analysis_df (pd.DataFrame):
77
77
  The study DataFrame.
78
- in_study_col_name (str):
79
- Column name indicating if a user is in the study in the study dataframe.
78
+ active_col_name (str):
79
+ Column name indicating if a subject is in the study in the study dataframe.
80
80
  calendar_t_col_name (str):
81
81
  Column name for calendar time in the study dataframe.
82
82
  action_prob_col_name (str):
83
83
  Column name for action probabilities in the study dataframe.
84
- user_id_col_name (str):
85
- Column name for user IDs in the study dataframe
84
+ subject_id_col_name (str):
85
+ Column name for subject IDs in the study dataframe
86
86
  reward_col_name (str):
87
87
  Column name for rewards in the study dataframe.
88
88
  theta_est (jnp.ndarray):
89
89
  The estimate of the parameter vector theta.
90
- adaptive_sandwich_var_estimate (jnp.ndarray):
91
- The adaptive sandwich variance estimate for theta.
92
- user_ids (jnp.ndarray):
93
- Array of unique user IDs.
90
+ adjusted_sandwich_var_estimate (jnp.ndarray):
91
+ The adjusted sandwich variance estimate for theta.
92
+ subject_ids (jnp.ndarray):
93
+ Array of unique subject IDs.
94
94
  beta_dim (int):
95
95
  Dimension of the beta parameter vector.
96
96
  theta_dim (int):
@@ -99,8 +99,8 @@ def get_datum_for_blowup_supervised_learning(
99
99
  The initial policy number used in the study.
100
100
  beta_index_by_policy_num (dict[int | float, int]):
101
101
  Mapping from policy numbers to indices in all_post_update_betas.
102
- policy_num_by_decision_time_by_user_id (dict):
103
- Mapping from user IDs to their policy numbers by decision time.
102
+ policy_num_by_decision_time_by_subject_id (dict):
103
+ Mapping from subject IDs to their policy numbers by decision time.
104
104
  theta_calculation_func (callable):
105
105
  The theta calculation function.
106
106
  action_prob_func (callable):
@@ -115,17 +115,17 @@ def get_datum_for_blowup_supervised_learning(
115
115
  Index for theta in inference function arguments.
116
116
  inference_func_args_action_prob_index (int):
117
117
  Index for action probability in inference function arguments.
118
- inference_action_prob_decision_times_by_user_id (dict):
119
- Mapping from user IDs to decision times for action probabilities used in inference.
118
+ inference_action_prob_decision_times_by_subject_id (dict):
119
+ Mapping from subject IDs to decision times for action probabilities used in inference.
120
120
  action_prob_func_args (dict):
121
121
  Arguments for the action probability function.
122
- action_by_decision_time_by_user_id (dict):
123
- Mapping from user IDs to their actions by decision time.
122
+ action_by_decision_time_by_subject_id (dict):
123
+ Mapping from subject IDs to their actions by decision time.
124
124
  Returns:
125
125
  dict[str, Any]: A dictionary containing features and the label for supervised learning.
126
126
  """
127
127
  num_diagonal_blocks = (
128
- (joint_adaptive_bread_inverse_matrix.shape[0] - theta_dim) // beta_dim
128
+ (joint_adjusted_bread_matrix.shape[0] - theta_dim) // beta_dim
129
129
  ) + 1
130
130
  diagonal_block_sizes = ([beta_dim] * (num_diagonal_blocks - 1)) + [theta_dim]
131
131
 
@@ -144,7 +144,7 @@ def get_datum_for_blowup_supervised_learning(
144
144
  row_slice = slice(block_bounds[i], block_bounds[i + 1])
145
145
  col_slice = slice(block_bounds[j], block_bounds[j + 1])
146
146
  block_norm = np.linalg.norm(
147
- joint_adaptive_bread_inverse_matrix[row_slice, col_slice],
147
+ joint_adjusted_bread_matrix[row_slice, col_slice],
148
148
  ord="fro",
149
149
  )
150
150
  # We will sum here and take the square root later
@@ -155,9 +155,9 @@ def get_datum_for_blowup_supervised_learning(
155
155
  # handle diagonal blocks
156
156
  sl = slice(block_bounds[i], block_bounds[i + 1])
157
157
  diag_norms.append(
158
- np.linalg.norm(joint_adaptive_bread_inverse_matrix[sl, sl], ord="fro")
158
+ np.linalg.norm(joint_adjusted_bread_matrix[sl, sl], ord="fro")
159
159
  )
160
- diag_conds.append(np.linalg.cond(joint_adaptive_bread_inverse_matrix[sl, sl]))
160
+ diag_conds.append(np.linalg.cond(joint_adjusted_bread_matrix[sl, sl]))
161
161
 
162
162
  # Sqrt each row/col sum to truly get row/column norms.
163
163
  # Perhaps not necessary for learning, but more natural
@@ -166,7 +166,7 @@ def get_datum_for_blowup_supervised_learning(
166
166
 
167
167
  # Get the per-person estimating function stack norms
168
168
  estimating_function_stack_norms = np.linalg.norm(
169
- per_user_estimating_function_stacks, axis=1
169
+ per_subject_estimating_function_stacks, axis=1
170
170
  )
171
171
 
172
172
  # Get the average estimating function stack norms by update/inference
@@ -175,7 +175,7 @@ def get_datum_for_blowup_supervised_learning(
175
175
  avg_estimating_function_stack_norms_per_segment = [
176
176
  np.mean(
177
177
  np.linalg.norm(
178
- per_user_estimating_function_stacks[
178
+ per_subject_estimating_function_stacks[
179
179
  :, block_bounds[i] : block_bounds[i + 1]
180
180
  ],
181
181
  axis=1,
@@ -191,67 +191,67 @@ def get_datum_for_blowup_supervised_learning(
191
191
  std_successive_beta_diff_norm = np.std(successive_beta_diff_norms)
192
192
 
193
193
  # Add a column with logits of the action probabilities
194
- # Compute the average and standard deviation of the logits of the action probabilities at each decision time using study_df
194
+ # Compute the average and standard deviation of the logits of the action probabilities at each decision time using analysis_df
195
195
  # action_prob_logit_means and action_prob_logit_stds are numpy arrays of mean and stddev at each decision time
196
- # Only compute logits for rows where user is in the study; set others to NaN
197
- in_study_mask = study_df[in_study_col_name] == 1
198
- study_df["action_prob_logit"] = np.where(
196
+ # Only compute logits for rows where subject is in the study; set others to NaN
197
+ in_study_mask = analysis_df[active_col_name] == 1
198
+ analysis_df["action_prob_logit"] = np.where(
199
199
  in_study_mask,
200
- logit(study_df[action_prob_col_name]),
200
+ logit(analysis_df[action_prob_col_name]),
201
201
  np.nan,
202
202
  )
203
- grouped_action_prob_logit = study_df.loc[in_study_mask].groupby(
203
+ grouped_action_prob_logit = analysis_df.loc[in_study_mask].groupby(
204
204
  calendar_t_col_name
205
205
  )["action_prob_logit"]
206
206
  action_prob_logit_means_by_t = grouped_action_prob_logit.mean().values
207
207
  action_prob_logit_stds_by_t = grouped_action_prob_logit.std().values
208
208
 
209
- # Compute the average and standard deviation of the rewards at each decision time using study_df
209
+ # Compute the average and standard deviation of the rewards at each decision time using analysis_df
210
210
  # reward_means and reward_stds are numpy arrays of mean and stddev at each decision time
211
- grouped_reward = study_df.loc[in_study_mask].groupby(calendar_t_col_name)[
211
+ grouped_reward = analysis_df.loc[in_study_mask].groupby(calendar_t_col_name)[
212
212
  reward_col_name
213
213
  ]
214
214
  reward_means_by_t = grouped_reward.mean().values
215
215
  reward_stds_by_t = grouped_reward.std().values
216
216
 
217
- joint_bread_inverse_min_singular_value = np.linalg.svd(
218
- joint_adaptive_bread_inverse_matrix, compute_uv=False
217
+ joint_bread_min_singular_value = np.linalg.svd(
218
+ joint_adjusted_bread_matrix, compute_uv=False
219
219
  )[-1]
220
220
 
221
- max_reward = study_df.loc[in_study_mask][reward_col_name].max()
221
+ max_reward = analysis_df.loc[in_study_mask][reward_col_name].max()
222
222
 
223
223
  norm_avg_estimating_function_stack = np.linalg.norm(avg_estimating_function_stack)
224
224
  max_estimating_function_stack_norm = np.max(estimating_function_stack_norms)
225
225
 
226
226
  (
227
227
  premature_thetas,
228
- premature_adaptive_sandwiches,
228
+ premature_adjusted_sandwiches,
229
229
  premature_classical_sandwiches,
230
- premature_joint_adaptive_bread_inverse_condition_numbers,
230
+ premature_joint_adjusted_bread_condition_numbers,
231
231
  premature_avg_inference_estimating_functions,
232
- ) = calculate_sequence_of_premature_adaptive_estimates(
233
- study_df,
232
+ ) = calculate_sequence_of_premature_adjusted_estimates(
233
+ analysis_df,
234
234
  initial_policy_num,
235
235
  beta_index_by_policy_num,
236
- policy_num_by_decision_time_by_user_id,
236
+ policy_num_by_decision_time_by_subject_id,
237
237
  theta_calculation_func,
238
238
  calendar_t_col_name,
239
239
  action_prob_col_name,
240
- user_id_col_name,
241
- in_study_col_name,
240
+ subject_id_col_name,
241
+ active_col_name,
242
242
  all_post_update_betas,
243
- user_ids,
243
+ subject_ids,
244
244
  action_prob_func,
245
245
  action_prob_func_args_beta_index,
246
246
  inference_func,
247
247
  inference_func_type,
248
248
  inference_func_args_theta_index,
249
249
  inference_func_args_action_prob_index,
250
- inference_action_prob_decision_times_by_user_id,
250
+ inference_action_prob_decision_times_by_subject_id,
251
251
  action_prob_func_args,
252
- action_by_decision_time_by_user_id,
253
- joint_adaptive_bread_inverse_matrix,
254
- per_user_estimating_function_stacks,
252
+ action_by_decision_time_by_subject_id,
253
+ joint_adjusted_bread_matrix,
254
+ per_subject_estimating_function_stacks,
255
255
  beta_dim,
256
256
  )
257
257
 
@@ -261,54 +261,54 @@ def get_datum_for_blowup_supervised_learning(
261
261
  atol=1e-3,
262
262
  )
263
263
 
264
- # Plot premature joint adaptive bread inverse log condition numbers
264
+ # Plot premature joint adjusted bread log condition numbers
265
265
  plt.clear_figure()
266
- plt.title("Premature Joint Adaptive Bread Inverse Log Condition Numbers")
266
+ plt.title("Premature Joint Adjusted Bread Inverse Log Condition Numbers")
267
267
  plt.xlabel("Premature Update Index")
268
268
  plt.ylabel("Log Condition Number")
269
269
  plt.scatter(
270
- np.log(premature_joint_adaptive_bread_inverse_condition_numbers),
270
+ np.log(premature_joint_adjusted_bread_condition_numbers),
271
271
  color="blue+",
272
272
  )
273
273
  plt.grid(True)
274
274
  plt.xticks(
275
275
  range(
276
276
  0,
277
- len(premature_joint_adaptive_bread_inverse_condition_numbers),
277
+ len(premature_joint_adjusted_bread_condition_numbers),
278
278
  max(
279
279
  1,
280
- len(premature_joint_adaptive_bread_inverse_condition_numbers) // 10,
280
+ len(premature_joint_adjusted_bread_condition_numbers) // 10,
281
281
  ),
282
282
  )
283
283
  )
284
284
  plt.show()
285
285
 
286
- # Plot each diagonal element of premature adaptive sandwiches
287
- num_diag = premature_adaptive_sandwiches.shape[-1]
286
+ # Plot each diagonal element of premature adjusted sandwiches
287
+ num_diag = premature_adjusted_sandwiches.shape[-1]
288
288
  for i in range(num_diag):
289
289
  plt.clear_figure()
290
- plt.title(f"Premature Adaptive Sandwich Diagonal Element {i}")
290
+ plt.title(f"Premature Adjusted Sandwich Diagonal Element {i}")
291
291
  plt.xlabel("Premature Update Index")
292
292
  plt.ylabel(f"Variance (Diagonal {i})")
293
- plt.scatter(np.array(premature_adaptive_sandwiches[:, i, i]), color="blue+")
293
+ plt.scatter(np.array(premature_adjusted_sandwiches[:, i, i]), color="blue+")
294
294
  plt.grid(True)
295
295
  plt.xticks(
296
296
  range(
297
297
  0,
298
- int(premature_adaptive_sandwiches.shape[0]),
299
- max(1, int(premature_adaptive_sandwiches.shape[0]) // 10),
298
+ int(premature_adjusted_sandwiches.shape[0]),
299
+ max(1, int(premature_adjusted_sandwiches.shape[0]) // 10),
300
300
  )
301
301
  )
302
302
  plt.show()
303
303
 
304
304
  plt.clear_figure()
305
305
  plt.title(
306
- f"Premature Adaptive Sandwich Diagonal Element {i} Ratio to Classical"
306
+ f"Premature Adjusted Sandwich Diagonal Element {i} Ratio to Classical"
307
307
  )
308
308
  plt.xlabel("Premature Update Index")
309
309
  plt.ylabel(f"Variance (Diagonal {i})")
310
310
  plt.scatter(
311
- np.array(premature_adaptive_sandwiches[:, i, i])
311
+ np.array(premature_adjusted_sandwiches[:, i, i])
312
312
  / np.array(premature_classical_sandwiches[:, i, i]),
313
313
  color="red+",
314
314
  )
@@ -316,8 +316,8 @@ def get_datum_for_blowup_supervised_learning(
316
316
  plt.xticks(
317
317
  range(
318
318
  0,
319
- int(premature_adaptive_sandwiches.shape[0]),
320
- max(1, int(premature_adaptive_sandwiches.shape[0]) // 10),
319
+ int(premature_adjusted_sandwiches.shape[0]),
320
+ max(1, int(premature_adjusted_sandwiches.shape[0]) // 10),
321
321
  )
322
322
  )
323
323
  plt.show()
@@ -331,14 +331,14 @@ def get_datum_for_blowup_supervised_learning(
331
331
  plt.xticks(
332
332
  range(
333
333
  0,
334
- int(premature_adaptive_sandwiches.shape[0]),
335
- max(1, int(premature_adaptive_sandwiches.shape[0]) // 10),
334
+ int(premature_adjusted_sandwiches.shape[0]),
335
+ max(1, int(premature_adjusted_sandwiches.shape[0]) // 10),
336
336
  )
337
337
  )
338
338
  plt.show()
339
339
 
340
340
  # Grab predictors related to premature Phi-dot-bars
341
- RL_stack_beta_derivatives_block = joint_adaptive_bread_inverse_matrix[
341
+ RL_stack_beta_derivatives_block = joint_adjusted_bread_matrix[
342
342
  :-theta_dim, :-theta_dim
343
343
  ]
344
344
  num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
@@ -397,14 +397,14 @@ def get_datum_for_blowup_supervised_learning(
397
397
  )
398
398
  return {
399
399
  **{
400
- "joint_bread_inverse_condition_number": joint_adaptive_bread_inverse_cond,
401
- "joint_bread_inverse_min_singular_value": joint_bread_inverse_min_singular_value,
400
+ "joint_bread_condition_number": joint_adjusted_bread_cond,
401
+ "joint_bread_min_singular_value": joint_bread_min_singular_value,
402
402
  "max_reward": max_reward,
403
403
  "norm_avg_estimating_function_stack": norm_avg_estimating_function_stack,
404
404
  "max_estimating_function_stack_norm": max_estimating_function_stack_norm,
405
405
  "max_successive_beta_diff_norm": max_successive_beta_diff_norm,
406
406
  "std_successive_beta_diff_norm": std_successive_beta_diff_norm,
407
- "label": adaptive_sandwich_var_estimate,
407
+ "label": adjusted_sandwich_var_estimate,
408
408
  },
409
409
  **{
410
410
  f"off_diag_block_{i}_{j}_norm": off_diag_block_norms[(i, j)]
@@ -422,10 +422,10 @@ def get_datum_for_blowup_supervised_learning(
422
422
  for i in range(num_block_rows_cols)
423
423
  },
424
424
  **{
425
- f"estimating_function_stack_norm_user_{user_id}": estimating_function_stack_norms[
425
+ f"estimating_function_stack_norm_subject_{subject_id}": estimating_function_stack_norms[
426
426
  i
427
427
  ]
428
- for i, user_id in enumerate(user_ids)
428
+ for i, subject_id in enumerate(subject_ids)
429
429
  },
430
430
  **{
431
431
  f"avg_estimating_function_stack_norm_segment_{i}": avg_estimating_function_stack_norms_per_segment[
@@ -455,18 +455,16 @@ def get_datum_for_blowup_supervised_learning(
455
455
  },
456
456
  **{f"theta_est_{i}": theta_est[i].item() for i in range(len(theta_est))},
457
457
  **{
458
- f"premature_joint_adaptive_bread_inverse_condition_number_{i}": premature_joint_adaptive_bread_inverse_condition_numbers[
458
+ f"premature_joint_adjusted_bread_condition_number_{i}": premature_joint_adjusted_bread_condition_numbers[
459
459
  i
460
460
  ]
461
- for i in range(
462
- len(premature_joint_adaptive_bread_inverse_condition_numbers)
463
- )
461
+ for i in range(len(premature_joint_adjusted_bread_condition_numbers))
464
462
  },
465
463
  **{
466
- f"premature_adaptive_sandwich_update_{i}_diag_position_{j}": premature_adaptive_sandwich[
464
+ f"premature_adjusted_sandwich_update_{i}_diag_position_{j}": premature_adjusted_sandwich[
467
465
  j, j
468
466
  ]
469
- for premature_adaptive_sandwich in premature_adaptive_sandwiches
467
+ for premature_adjusted_sandwich in premature_adjusted_sandwiches
470
468
  for j in range(theta_dim)
471
469
  },
472
470
  **{
@@ -497,44 +495,46 @@ def get_datum_for_blowup_supervised_learning(
497
495
  }
498
496
 
499
497
 
500
- def calculate_sequence_of_premature_adaptive_estimates(
501
- study_df: pd.DataFrame,
498
+ def calculate_sequence_of_premature_adjusted_estimates(
499
+ analysis_df: pd.DataFrame,
502
500
  initial_policy_num: int | float,
503
501
  beta_index_by_policy_num: dict[int | float, int],
504
- policy_num_by_decision_time_by_user_id: dict[
502
+ policy_num_by_decision_time_by_subject_id: dict[
505
503
  collections.abc.Hashable, dict[int, int | float]
506
504
  ],
507
505
  theta_calculation_func: str,
508
506
  calendar_t_col_name: str,
509
507
  action_prob_col_name: str,
510
- user_id_col_name: str,
511
- in_study_col_name: str,
508
+ subject_id_col_name: str,
509
+ active_col_name: str,
512
510
  all_post_update_betas: jnp.ndarray,
513
- user_ids: jnp.ndarray,
511
+ subject_ids: jnp.ndarray,
514
512
  action_prob_func: str,
515
513
  action_prob_func_args_beta_index: int,
516
514
  inference_func: str,
517
515
  inference_func_type: str,
518
516
  inference_func_args_theta_index: int,
519
517
  inference_func_args_action_prob_index: int,
520
- inference_action_prob_decision_times_by_user_id: dict[
518
+ inference_action_prob_decision_times_by_subject_id: dict[
521
519
  collections.abc.Hashable, list[int]
522
520
  ],
523
- action_prob_func_args_by_user_id_by_decision_time: dict[
521
+ action_prob_func_args_by_subject_id_by_decision_time: dict[
524
522
  int, dict[collections.abc.Hashable, tuple[Any, ...]]
525
523
  ],
526
- action_by_decision_time_by_user_id: dict[collections.abc.Hashable, dict[int, int]],
527
- full_joint_adaptive_bread_inverse_matrix: jnp.ndarray,
528
- per_user_estimating_function_stacks: jnp.ndarray,
524
+ action_by_decision_time_by_subject_id: dict[
525
+ collections.abc.Hashable, dict[int, int]
526
+ ],
527
+ full_joint_adjusted_bread_matrix: jnp.ndarray,
528
+ per_subject_estimating_function_stacks: jnp.ndarray,
529
529
  beta_dim: int,
530
530
  ) -> jnp.ndarray:
531
531
  """
532
- Calculates a sequence of premature adaptive estimates for the given study DataFrame, where we
532
+ Calculates a sequence of premature adjusted estimates for the given study DataFrame, where we
533
533
  pretend the study ended after each update in sequence. The behavior of this sequence may provide
534
- insight into the stability of the final adaptive estimate.
534
+ insight into the stability of the final adjusted estimate.
535
535
 
536
536
  Args:
537
- study_df (pandas.DataFrame):
537
+ analysis_df (pandas.DataFrame):
538
538
  The DataFrame containing the study data.
539
539
  initial_policy_num (int | float): The policy number of the initial policy before any updates.
540
540
  initial_policy_num (int | float):
@@ -542,23 +542,23 @@ def calculate_sequence_of_premature_adaptive_estimates(
542
542
  beta_index_by_policy_num (dict[int | float, int]):
543
543
  A dictionary mapping policy numbers to the index of the corresponding beta in
544
544
  all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
545
- policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
546
- A map of user ids to dictionaries mapping decision times to the policy number in use.
545
+ policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
546
+ A map of subject ids to dictionaries mapping decision times to the policy number in use.
547
547
  Only applies to in-study decision times!
548
548
  theta_calculation_func (callable):
549
549
  The filename for the theta calculation function.
550
550
  calendar_t_col_name (str):
551
- The name of the column in study_df representing calendar time.
551
+ The name of the column in analysis_df representing calendar time.
552
552
  action_prob_col_name (str):
553
- The name of the column in study_df representing action probabilities.
554
- user_id_col_name (str):
555
- The name of the column in study_df representing user IDs.
556
- in_study_col_name (str):
557
- The name of the column in study_df indicating whether the user is in the study at that time.
553
+ The name of the column in analysis_df representing action probabilities.
554
+ subject_id_col_name (str):
555
+ The name of the column in analysis_df representing subject IDs.
556
+ active_col_name (str):
557
+ The name of the column in analysis_df indicating whether the subject is in the study at that time.
558
558
  all_post_update_betas (jnp.ndarray):
559
559
  A NumPy array containing all post-update beta values.
560
- user_ids (jnp.ndarray):
561
- A NumPy array containing all user IDs in the study.
560
+ subject_ids (jnp.ndarray):
561
+ A NumPy array containing all subject IDs in the study.
562
562
  action_prob_func (callable):
563
563
  The action probability function.
564
564
  action_prob_func_args_beta_index (int):
@@ -572,56 +572,54 @@ def calculate_sequence_of_premature_adaptive_estimates(
572
572
  inference_func_args_action_prob_index (int):
573
573
  The index of action probabilities in the inference function arguments tuple, if
574
574
  applicable. -1 otherwise.
575
- inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
576
- For each user, a list of decision times to which action probabilities correspond if
575
+ inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
576
+ For each subject, a list of decision times to which action probabilities correspond if
577
577
  provided. Typically just in-study times if action probabilites are used in the inference
578
578
  loss or estimating function.
579
- action_prob_func_args_by_user_id_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
580
- A dictionary mapping decision times to maps of user ids to the function arguments
581
- required to compute action probabilities for this user.
582
- action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
583
- A dictionary mapping user IDs to their respective actions taken at each decision time.
579
+ action_prob_func_args_by_subject_id_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
580
+ A dictionary mapping decision times to maps of subject ids to the function arguments
581
+ required to compute action probabilities for this subject.
582
+ action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
583
+ A dictionary mapping subject IDs to their respective actions taken at each decision time.
584
584
  Only applies to in-study decision times!
585
- full_joint_adaptive_bread_inverse_matrix (jnp.ndarray):
586
- The full joint adaptive bread inverse matrix as a NumPy array.
587
- per_user_estimating_function_stacks (jnp.ndarray):
588
- A NumPy array containing all per-user (weighted) estimating function stacks.
585
+ full_joint_adjusted_bread_matrix (jnp.ndarray):
586
+ The full joint adjusted bread matrix as a NumPy array.
587
+ per_subject_estimating_function_stacks (jnp.ndarray):
588
+ A NumPy array containing all per-subject (weighted) estimating function stacks.
589
589
  beta_dim (int):
590
590
  The dimension of the beta parameters.
591
591
  Returns:
592
- tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: A NumPy array containing the sequence of premature adaptive estimates.
592
+ tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: A NumPy array containing the sequence of premature adjusted estimates.
593
593
  """
594
594
 
595
- # Loop through the non-initial (ie not before an update has occurred), non-final policy numbers in sorted order, forming adaptive and classical
595
+ # Loop through the non-initial (ie not before an update has occurred), non-final policy numbers in sorted order, forming adjusted and classical
596
596
  # variance estimates pretending that each was the final policy.
597
- premature_adaptive_sandwiches = []
597
+ premature_adjusted_sandwiches = []
598
598
  premature_thetas = []
599
- premature_joint_adaptive_bread_inverse_condition_numbers = []
599
+ premature_joint_adjusted_bread_condition_numbers = []
600
600
  premature_avg_inference_estimating_functions = []
601
601
  premature_classical_sandwiches = []
602
602
  logger.info(
603
- "Calculating sequence of premature adaptive estimates by pretending the study ended after each update in sequence."
603
+ "Calculating sequence of premature adjusted estimates by pretending the study ended after each update in sequence."
604
604
  )
605
605
  for policy_num in sorted(beta_index_by_policy_num):
606
606
  logger.info(
607
- "Calculating premature adaptive estimate assuming policy %s is the final one.",
607
+ "Calculating premature adjusted estimate assuming policy %s is the final one.",
608
608
  policy_num,
609
609
  )
610
610
  pretend_max_policy = policy_num
611
611
 
612
- truncated_joint_adaptive_bread_inverse_matrix = (
613
- full_joint_adaptive_bread_inverse_matrix[
614
- : (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
615
- : (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
616
- ]
617
- )
612
+ truncated_joint_adjusted_bread_matrix = full_joint_adjusted_bread_matrix[
613
+ : (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
614
+ : (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
615
+ ]
618
616
 
619
- max_decision_time = study_df[study_df["policy_num"] == pretend_max_policy][
620
- calendar_t_col_name
621
- ].max()
617
+ max_decision_time = analysis_df[
618
+ analysis_df["policy_num"] == pretend_max_policy
619
+ ][calendar_t_col_name].max()
622
620
 
623
- truncated_study_df = study_df[
624
- study_df[calendar_t_col_name] <= max_decision_time
621
+ truncated_analysis_df = analysis_df[
622
+ analysis_df[calendar_t_col_name] <= max_decision_time
625
623
  ].copy()
626
624
 
627
625
  truncated_beta_index_by_policy_num = {
@@ -632,83 +630,83 @@ def calculate_sequence_of_premature_adaptive_estimates(
632
630
 
633
631
  truncated_all_post_update_betas = all_post_update_betas[: max_beta_index + 1, :]
634
632
 
635
- premature_theta = jnp.array(theta_calculation_func(truncated_study_df))
633
+ premature_theta = jnp.array(theta_calculation_func(truncated_analysis_df))
636
634
 
637
- truncated_action_prob_func_args_by_user_id_by_decision_time = {
638
- decision_time: args_by_user_id
639
- for decision_time, args_by_user_id in action_prob_func_args_by_user_id_by_decision_time.items()
635
+ truncated_action_prob_func_args_by_subject_id_by_decision_time = {
636
+ decision_time: args_by_subject_id
637
+ for decision_time, args_by_subject_id in action_prob_func_args_by_subject_id_by_decision_time.items()
640
638
  if decision_time <= max_decision_time
641
639
  }
642
640
 
643
- truncated_inference_func_args_by_user_id, _, _ = (
644
- after_study_analysis.process_inference_func_args(
641
+ truncated_inference_func_args_by_subject_id, _, _ = (
642
+ post_deployment_analysis.process_inference_func_args(
645
643
  inference_func,
646
644
  inference_func_args_theta_index,
647
- truncated_study_df,
645
+ truncated_analysis_df,
648
646
  premature_theta,
649
647
  action_prob_col_name,
650
648
  calendar_t_col_name,
651
- user_id_col_name,
652
- in_study_col_name,
649
+ subject_id_col_name,
650
+ active_col_name,
653
651
  )
654
652
  )
655
653
 
656
- truncated_inference_action_prob_decision_times_by_user_id = {
657
- user_id: [
654
+ truncated_inference_action_prob_decision_times_by_subject_id = {
655
+ subject_id: [
658
656
  decision_time
659
- for decision_time in inference_action_prob_decision_times_by_user_id[
660
- user_id
657
+ for decision_time in inference_action_prob_decision_times_by_subject_id[
658
+ subject_id
661
659
  ]
662
660
  if decision_time <= max_decision_time
663
661
  ]
664
662
  # writing this way is important, handles empty dicts correctly
665
- for user_id in inference_action_prob_decision_times_by_user_id
663
+ for subject_id in inference_action_prob_decision_times_by_subject_id
666
664
  }
667
665
 
668
- truncated_action_by_decision_time_by_user_id = {
669
- user_id: {
666
+ truncated_action_by_decision_time_by_subject_id = {
667
+ subject_id: {
670
668
  decision_time: action
671
- for decision_time, action in action_by_decision_time_by_user_id[
672
- user_id
669
+ for decision_time, action in action_by_decision_time_by_subject_id[
670
+ subject_id
673
671
  ].items()
674
672
  if decision_time <= max_decision_time
675
673
  }
676
- for user_id in action_by_decision_time_by_user_id
674
+ for subject_id in action_by_decision_time_by_subject_id
677
675
  }
678
676
 
679
- truncated_per_user_estimating_function_stacks = (
680
- per_user_estimating_function_stacks[
677
+ truncated_per_subject_estimating_function_stacks = (
678
+ per_subject_estimating_function_stacks[
681
679
  :,
682
680
  : (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
683
681
  ]
684
682
  )
685
683
 
686
684
  (
687
- premature_adaptive_sandwich,
685
+ premature_adjusted_sandwich,
688
686
  premature_classical_sandwich,
689
687
  premature_avg_inference_estimating_function,
690
- ) = construct_premature_classical_and_adaptive_sandwiches(
691
- truncated_joint_adaptive_bread_inverse_matrix,
692
- truncated_per_user_estimating_function_stacks,
688
+ ) = construct_premature_classical_and_adjusted_sandwiches(
689
+ truncated_joint_adjusted_bread_matrix,
690
+ truncated_per_subject_estimating_function_stacks,
693
691
  premature_theta,
694
692
  truncated_all_post_update_betas,
695
- user_ids,
693
+ subject_ids,
696
694
  action_prob_func,
697
695
  action_prob_func_args_beta_index,
698
696
  inference_func,
699
697
  inference_func_type,
700
698
  inference_func_args_theta_index,
701
699
  inference_func_args_action_prob_index,
702
- truncated_action_prob_func_args_by_user_id_by_decision_time,
703
- policy_num_by_decision_time_by_user_id,
700
+ truncated_action_prob_func_args_by_subject_id_by_decision_time,
701
+ policy_num_by_decision_time_by_subject_id,
704
702
  initial_policy_num,
705
703
  truncated_beta_index_by_policy_num,
706
- truncated_inference_func_args_by_user_id,
707
- truncated_inference_action_prob_decision_times_by_user_id,
708
- truncated_action_by_decision_time_by_user_id,
704
+ truncated_inference_func_args_by_subject_id,
705
+ truncated_inference_action_prob_decision_times_by_subject_id,
706
+ truncated_action_by_decision_time_by_subject_id,
709
707
  )
710
708
 
711
- premature_adaptive_sandwiches.append(premature_adaptive_sandwich)
709
+ premature_adjusted_sandwiches.append(premature_adjusted_sandwich)
712
710
  premature_classical_sandwiches.append(premature_classical_sandwich)
713
711
  premature_thetas.append(premature_theta)
714
712
  premature_avg_inference_estimating_functions.append(
@@ -716,38 +714,40 @@ def calculate_sequence_of_premature_adaptive_estimates(
716
714
  )
717
715
  return (
718
716
  jnp.array(premature_thetas),
719
- jnp.array(premature_adaptive_sandwiches),
717
+ jnp.array(premature_adjusted_sandwiches),
720
718
  jnp.array(premature_classical_sandwiches),
721
- jnp.array(premature_joint_adaptive_bread_inverse_condition_numbers),
719
+ jnp.array(premature_joint_adjusted_bread_condition_numbers),
722
720
  jnp.array(premature_avg_inference_estimating_functions),
723
721
  )
724
722
 
725
723
 
726
- def construct_premature_classical_and_adaptive_sandwiches(
727
- truncated_joint_adaptive_bread_inverse_matrix: jnp.ndarray,
728
- per_user_truncated_estimating_function_stacks: jnp.ndarray,
724
+ def construct_premature_classical_and_adjusted_sandwiches(
725
+ truncated_joint_adjusted_bread_matrix: jnp.ndarray,
726
+ per_subject_truncated_estimating_function_stacks: jnp.ndarray,
729
727
  theta: jnp.ndarray,
730
728
  all_post_update_betas: jnp.ndarray,
731
- user_ids: jnp.ndarray,
729
+ subject_ids: jnp.ndarray,
732
730
  action_prob_func: str,
733
731
  action_prob_func_args_beta_index: int,
734
732
  inference_func: str,
735
733
  inference_func_type: str,
736
734
  inference_func_args_theta_index: int,
737
735
  inference_func_args_action_prob_index: int,
738
- action_prob_func_args_by_user_id_by_decision_time: dict[
736
+ action_prob_func_args_by_subject_id_by_decision_time: dict[
739
737
  collections.abc.Hashable, dict[int, tuple[Any, ...]]
740
738
  ],
741
- policy_num_by_decision_time_by_user_id: dict[
739
+ policy_num_by_decision_time_by_subject_id: dict[
742
740
  collections.abc.Hashable, dict[int, int | float]
743
741
  ],
744
742
  initial_policy_num: int | float,
745
743
  beta_index_by_policy_num: dict[int | float, int],
746
- inference_func_args_by_user_id: dict[collections.abc.Hashable, tuple[Any, ...]],
747
- inference_action_prob_decision_times_by_user_id: dict[
744
+ inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
745
+ inference_action_prob_decision_times_by_subject_id: dict[
748
746
  collections.abc.Hashable, list[int]
749
747
  ],
750
- action_by_decision_time_by_user_id: dict[collections.abc.Hashable, dict[int, int]],
748
+ action_by_decision_time_by_subject_id: dict[
749
+ collections.abc.Hashable, dict[int, int]
750
+ ],
751
751
  ) -> tuple[
752
752
  jnp.ndarray[jnp.float32],
753
753
  jnp.ndarray[jnp.float32],
@@ -759,33 +759,33 @@ def construct_premature_classical_and_adaptive_sandwiches(
759
759
  jnp.ndarray[jnp.float32],
760
760
  ]:
761
761
  """
762
- Constructs the classical bread and meat matrices, as well as the adaptive bread matrix
762
+ Constructs the classical bread and meat matrices, as well as the adjusted bread matrix
763
763
  and the average weighted inference estimating function for the premature variance estimation
764
764
  procedure.
765
765
 
766
766
  This is done by computing and differentiating the new average inference estimating function
767
767
  with respect to the betas and theta, and stitching this together with the existing
768
- adaptive bread inverse matrix portion (corresponding to the updates still under consideration)
769
- to form the new premature joint adaptive bread inverse matrix.
768
+ adjusted bread matrix portion (corresponding to the updates still under consideration)
769
+ to form the new premature joint adjusted bread matrix.
770
770
 
771
771
  Args:
772
- truncated_joint_adaptive_bread_inverse_matrix (jnp.ndarray):
773
- A 2-D JAX NumPy array holding the existing joint adaptive bread inverse but
772
+ truncated_joint_adjusted_bread_matrix (jnp.ndarray):
773
+ A 2-D JAX NumPy array holding the existing joint adjusted bread but
774
774
  with rows corresponding to updates not under consideration and inference dropped.
775
775
  We will stitch this together with the newly computed inference portion to form
776
- our "premature" joint adaptive bread inverse matrix.
777
- per_user_truncated_estimating_function_stacks (jnp.ndarray):
778
- A 2-D JAX NumPy array holding the existing per-user weighted estimating function
776
+ our "premature" joint adjusted bread matrix.
777
+ per_subject_truncated_estimating_function_stacks (jnp.ndarray):
778
+ A 2-D JAX NumPy array holding the existing per-subject weighted estimating function
779
779
  stacks but with rows corresponding to updates not under consideration dropped.
780
780
  We will stitch this together with the newly computed inference estimating functions
781
- to form our "premature" joint adaptive estimating function stacks from which the new
782
- adaptive meat matrix can be computed.
781
+ to form our "premature" joint adjusted estimating function stacks from which the new
782
+ adjusted meat matrix can be computed.
783
783
  theta (jnp.ndarray):
784
784
  A 1-D JAX NumPy array representing the parameter estimate for inference.
785
785
  all_post_update_betas (jnp.ndarray):
786
786
  A 2-D JAX NumPy array representing all parameter estimates for the algorithm updates.
787
- user_ids (jnp.ndarray):
788
- A 1-D JAX NumPy array holding all user IDs in the study.
787
+ subject_ids (jnp.ndarray):
788
+ A 1-D JAX NumPy array holding all subject IDs in the study.
789
789
  action_prob_func (callable):
790
790
  The action probability function.
791
791
  action_prob_func_args_beta_index (int):
@@ -799,81 +799,81 @@ def construct_premature_classical_and_adaptive_sandwiches(
799
799
  inference_func_args_action_prob_index (int):
800
800
  The index of action probabilities in the inference function arguments tuple, if
801
801
  applicable. -1 otherwise.
802
- action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
803
- A dictionary mapping decision times to maps of user ids to the function arguments
804
- required to compute action probabilities for this user.
805
- policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
806
- A map of user ids to dictionaries mapping decision times to the policy number in use.
802
+ action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
803
+ A dictionary mapping decision times to maps of subject ids to the function arguments
804
+ required to compute action probabilities for this subject.
805
+ policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
806
+ A map of subject ids to dictionaries mapping decision times to the policy number in use.
807
807
  Only applies to in-study decision times!
808
808
  initial_policy_num (int | float):
809
809
  The policy number of the initial policy before any updates.
810
810
  beta_index_by_policy_num (dict[int | float, int]):
811
811
  A dictionary mapping policy numbers to the index of the corresponding beta in
812
812
  all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
813
- inference_func_args_by_user_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
814
- A dictionary mapping user IDs to their respective inference function arguments.
815
- inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
816
- For each user, a list of decision times to which action probabilities correspond if
813
+ inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
814
+ A dictionary mapping subject IDs to their respective inference function arguments.
815
+ inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
816
+ For each subject, a list of decision times to which action probabilities correspond if
817
817
  provided. Typically just in-study times if action probabilites are used in the inference
818
818
  loss or estimating function.
819
- action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
820
- A dictionary mapping user IDs to their respective actions taken at each decision time.
819
+ action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
820
+ A dictionary mapping subject IDs to their respective actions taken at each decision time.
821
821
  Only applies to in-study decision times!
822
822
  Returns:
823
823
  tuple[jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32],
824
824
  jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32],
825
825
  jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32]]:
826
826
  A tuple containing:
827
- - The joint adaptive inverse bread matrix.
828
- - The joint adaptive bread matrix.
829
- - The joint adaptive meat matrix.
830
- - The classical inverse bread matrix.
827
+ - The joint adjusted bread matrix.
828
+ - The joint adjusted bread matrix.
829
+ - The joint adjusted meat matrix.
830
+ - The classical bread matrix.
831
831
  - The classical bread matrix.
832
832
  - The classical meat matrix.
833
833
  - The average (weighted) inference estimating function.
834
- - The joint adaptive inverse bread matrix condition number.
834
+ - The joint adjusted bread matrix condition number.
835
835
  """
836
836
  logger.info(
837
837
  "Differentiating average weighted inference estimating function stack and collecting auxiliary values."
838
838
  )
839
839
  # jax.jacobian may perform worse here--seemed to hang indefinitely while jacrev is merely very
840
840
  # slow.
841
- # Note that these "contributions" are per-user Jacobians of the weighted estimating function stack.
841
+ # Note that these "contributions" are per-subject Jacobians of the weighted estimating function stack.
842
842
  new_inference_block_row, (
843
- per_user_inference_estimating_functions,
843
+ per_subject_inference_estimating_functions,
844
844
  avg_inference_estimating_function,
845
- per_user_classical_meat_contributions,
846
- per_user_classical_bread_inverse_contributions,
845
+ per_subject_classical_meat_contributions,
846
+ per_subject_classical_bread_contributions,
847
847
  ) = jax.jacrev(get_weighted_inference_estimating_functions_only, has_aux=True)(
848
848
  # While JAX can technically differentiate with respect to a list of JAX arrays,
849
849
  # it is more efficient to flatten them into a single array. This is done
850
850
  # here to improve performance. We can simply unflatten them inside the function.
851
- after_study_analysis.flatten_params(all_post_update_betas, theta),
851
+ post_deployment_analysis.flatten_params(all_post_update_betas, theta),
852
852
  all_post_update_betas.shape[1],
853
853
  theta.shape[0],
854
- user_ids,
854
+ subject_ids,
855
855
  action_prob_func,
856
856
  action_prob_func_args_beta_index,
857
857
  inference_func,
858
858
  inference_func_type,
859
859
  inference_func_args_theta_index,
860
860
  inference_func_args_action_prob_index,
861
- action_prob_func_args_by_user_id_by_decision_time,
862
- policy_num_by_decision_time_by_user_id,
861
+ action_prob_func_args_by_subject_id_by_decision_time,
862
+ policy_num_by_decision_time_by_subject_id,
863
863
  initial_policy_num,
864
864
  beta_index_by_policy_num,
865
- inference_func_args_by_user_id,
866
- inference_action_prob_decision_times_by_user_id,
867
- action_by_decision_time_by_user_id,
865
+ inference_func_args_by_subject_id,
866
+ inference_action_prob_decision_times_by_subject_id,
867
+ action_by_decision_time_by_subject_id,
868
868
  )
869
869
 
870
- joint_adaptive_bread_inverse_matrix = jnp.block(
870
+ joint_adjusted_bread_matrix = jnp.block(
871
871
  [
872
872
  [
873
- truncated_joint_adaptive_bread_inverse_matrix,
873
+ truncated_joint_adjusted_bread_matrix,
874
874
  np.zeros(
875
875
  (
876
- truncated_joint_adaptive_bread_inverse_matrix.shape[0],
876
+ truncated_joint_adjusted_bread_matrix.shape[0],
877
877
  new_inference_block_row.shape[0],
878
878
  )
879
879
  ),
@@ -881,51 +881,49 @@ def construct_premature_classical_and_adaptive_sandwiches(
881
881
  [new_inference_block_row],
882
882
  ]
883
883
  )
884
- per_user_estimating_function_stacks = jnp.concatenate(
884
+ per_subject_estimating_function_stacks = jnp.concatenate(
885
885
  [
886
- per_user_truncated_estimating_function_stacks,
887
- per_user_inference_estimating_functions,
886
+ per_subject_truncated_estimating_function_stacks,
887
+ per_subject_inference_estimating_functions,
888
888
  ],
889
889
  axis=1,
890
890
  )
891
- per_user_adaptive_meat_contributions = jnp.einsum(
891
+ per_subject_adjusted_meat_contributions = jnp.einsum(
892
892
  "ni,nj->nij",
893
- per_user_estimating_function_stacks,
894
- per_user_estimating_function_stacks,
893
+ per_subject_estimating_function_stacks,
894
+ per_subject_estimating_function_stacks,
895
895
  )
896
896
 
897
- joint_adaptive_meat_matrix = jnp.mean(per_user_adaptive_meat_contributions, axis=0)
898
-
899
- classical_bread_inverse_matrix = jnp.mean(
900
- per_user_classical_bread_inverse_contributions, axis=0
897
+ joint_adjusted_meat_matrix = jnp.mean(
898
+ per_subject_adjusted_meat_contributions, axis=0
901
899
  )
902
- classical_meat_matrix = jnp.mean(per_user_classical_meat_contributions, axis=0)
903
-
904
- num_users = user_ids.shape[0]
905
- joint_adaptive_sandwich = (
906
- after_study_analysis.form_sandwich_from_bread_inverse_and_meat(
907
- joint_adaptive_bread_inverse_matrix,
908
- joint_adaptive_meat_matrix,
909
- num_users,
910
- method="bread_inverse_T_qr",
900
+
901
+ classical_bread_matrix = jnp.mean(per_subject_classical_bread_contributions, axis=0)
902
+ classical_meat_matrix = jnp.mean(per_subject_classical_meat_contributions, axis=0)
903
+
904
+ num_subjects = subject_ids.shape[0]
905
+ joint_adjusted_sandwich = (
906
+ post_deployment_analysis.form_sandwich_from_bread_and_meat(
907
+ joint_adjusted_bread_matrix,
908
+ joint_adjusted_meat_matrix,
909
+ num_subjects,
910
+ method="bread_T_qr",
911
911
  )
912
912
  )
913
- adaptive_sandwich = joint_adaptive_sandwich[-theta.shape[0] :, -theta.shape[0] :]
913
+ adjusted_sandwich = joint_adjusted_sandwich[-theta.shape[0] :, -theta.shape[0] :]
914
914
 
915
- classical_bread_inverse_matrix = jnp.mean(
916
- per_user_classical_bread_inverse_contributions, axis=0
917
- )
918
- classical_sandwich = after_study_analysis.form_sandwich_from_bread_inverse_and_meat(
919
- classical_bread_inverse_matrix,
915
+ classical_bread_matrix = jnp.mean(per_subject_classical_bread_contributions, axis=0)
916
+ classical_sandwich = post_deployment_analysis.form_sandwich_from_bread_and_meat(
917
+ classical_bread_matrix,
920
918
  classical_meat_matrix,
921
- num_users,
922
- method="bread_inverse_T_qr",
919
+ num_subjects,
920
+ method="bread_T_qr",
923
921
  )
924
922
 
925
- # Stack the joint adaptive inverse bread pieces together horizontally and return the auxiliary
926
- # values too. The joint adaptive bread inverse should always be block lower triangular.
923
+ # Stack the joint adjusted bread pieces together horizontally and return the auxiliary
924
+ # values too. The joint adjusted bread should always be block lower triangular.
927
925
  return (
928
- adaptive_sandwich,
926
+ adjusted_sandwich,
929
927
  classical_sandwich,
930
928
  avg_inference_estimating_function,
931
929
  )
@@ -935,32 +933,34 @@ def get_weighted_inference_estimating_functions_only(
935
933
  flattened_betas_and_theta: jnp.ndarray,
936
934
  beta_dim: int,
937
935
  theta_dim: int,
938
- user_ids: jnp.ndarray,
936
+ subject_ids: jnp.ndarray,
939
937
  action_prob_func: callable,
940
938
  action_prob_func_args_beta_index: int,
941
939
  inference_func: callable,
942
940
  inference_func_type: str,
943
941
  inference_func_args_theta_index: int,
944
942
  inference_func_args_action_prob_index: int,
945
- action_prob_func_args_by_user_id_by_decision_time: dict[
943
+ action_prob_func_args_by_subject_id_by_decision_time: dict[
946
944
  collections.abc.Hashable, dict[int, tuple[Any, ...]]
947
945
  ],
948
- policy_num_by_decision_time_by_user_id: dict[
946
+ policy_num_by_decision_time_by_subject_id: dict[
949
947
  collections.abc.Hashable, dict[int, int | float]
950
948
  ],
951
949
  initial_policy_num: int | float,
952
950
  beta_index_by_policy_num: dict[int | float, int],
953
- inference_func_args_by_user_id: dict[collections.abc.Hashable, tuple[Any, ...]],
954
- inference_action_prob_decision_times_by_user_id: dict[
951
+ inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
952
+ inference_action_prob_decision_times_by_subject_id: dict[
955
953
  collections.abc.Hashable, list[int]
956
954
  ],
957
- action_by_decision_time_by_user_id: dict[collections.abc.Hashable, dict[int, int]],
955
+ action_by_decision_time_by_subject_id: dict[
956
+ collections.abc.Hashable, dict[int, int]
957
+ ],
958
958
  ) -> tuple[
959
959
  jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
960
960
  ]:
961
961
  """
962
- Computes the average weighted inference estimating function across users, along with
963
- auxiliary values used to construct the adaptive and classical sandwich variances.
962
+ Computes the average weighted inference estimating function across subjects, along with
963
+ auxiliary values used to construct the adjusted and classical sandwich variances.
964
964
 
965
965
  Note that input data should have been adjusted to only correspond to updates/decision times
966
966
  that are being considered for the current "premature" variance estimation procedure.
@@ -974,8 +974,8 @@ def get_weighted_inference_estimating_functions_only(
974
974
  The dimension of each of the beta parameters.
975
975
  theta_dim (int):
976
976
  The dimension of the theta parameter.
977
- user_ids (jnp.ndarray):
978
- A 1D JAX NumPy array of user IDs.
977
+ subject_ids (jnp.ndarray):
978
+ A 1D JAX NumPy array of subject IDs.
979
979
  action_prob_func (str):
980
980
  The action probability function.
981
981
  action_prob_func_args_beta_index (int):
@@ -989,25 +989,25 @@ def get_weighted_inference_estimating_functions_only(
989
989
  inference_func_args_action_prob_index (int):
990
990
  The index of action probabilities in the inference function arguments tuple, if
991
991
  applicable. -1 otherwise.
992
- action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
993
- A dictionary mapping decision times to maps of user ids to the function arguments
994
- required to compute action probabilities for this user.
995
- policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
996
- A map of user ids to dictionaries mapping decision times to the policy number in use.
992
+ action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
993
+ A dictionary mapping decision times to maps of subject ids to the function arguments
994
+ required to compute action probabilities for this subject.
995
+ policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
996
+ A map of subject ids to dictionaries mapping decision times to the policy number in use.
997
997
  Only applies to in-study decision times!
998
998
  initial_policy_num (int | float):
999
999
  The policy number of the initial policy before any updates.
1000
1000
  beta_index_by_policy_num (dict[int | float, int]):
1001
1001
  A dictionary mapping policy numbers to the index of the corresponding beta in
1002
1002
  all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
1003
- inference_func_args_by_user_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
1004
- A dictionary mapping user IDs to their respective inference function arguments.
1005
- inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
1006
- For each user, a list of decision times to which action probabilities correspond if
1003
+ inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
1004
+ A dictionary mapping subject IDs to their respective inference function arguments.
1005
+ inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
1006
+ For each subject, a list of decision times to which action probabilities correspond if
1007
1007
  provided. Typically just in-study times if action probabilites are used in the inference
1008
1008
  loss or estimating function.
1009
- action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
1010
- A dictionary mapping user IDs to their respective actions taken at each decision time.
1009
+ action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
1010
+ A dictionary mapping subject IDs to their respective actions taken at each decision time.
1011
1011
  Only applies to in-study decision times!
1012
1012
 
1013
1013
  Returns:
@@ -1015,10 +1015,10 @@ def get_weighted_inference_estimating_functions_only(
1015
1015
  A 2D JAX NumPy array holding the average weighted inference estimating function.
1016
1016
  tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
1017
1017
  A tuple containing
1018
- 1. the per-user weighted inference estimating function stacks
1018
+ 1. the per-subject weighted inference estimating function stacks
1019
1019
  2. the average weighted inference estimating function
1020
- 3. the user-level classical meat matrix contributions
1021
- 4. the user-level inverse classical bread matrix contributions
1020
+ 3. the subject-level classical meat matrix contributions
1021
+ 4. the subject-level inverse classical bread matrix contributions
1022
1022
  stacks.
1023
1023
  """
1024
1024
 
@@ -1028,7 +1028,7 @@ def get_weighted_inference_estimating_functions_only(
1028
1028
  else inference_func
1029
1029
  )
1030
1030
 
1031
- betas, theta = after_study_analysis.unflatten_params(
1031
+ betas, theta = post_deployment_analysis.unflatten_params(
1032
1032
  flattened_betas_and_theta,
1033
1033
  beta_dim,
1034
1034
  theta_dim,
@@ -1038,15 +1038,15 @@ def get_weighted_inference_estimating_functions_only(
1038
1038
  # supplied for the above functions, so that differentiation works correctly. The existing
1039
1039
  # values should be the same, but not connected to the parameter we are differentiating
1040
1040
  # with respect to. Note we will also find it useful below to have the action probability args
1041
- # nested dict structure flipped to be user_id -> decision_time -> args, so we do that here too.
1041
+ # nested dict structure flipped to be subject_id -> decision_time -> args, so we do that here too.
1042
1042
 
1043
- logger.info("Threading in betas to action probability arguments for all users.")
1043
+ logger.info("Threading in betas to action probability arguments for all subjects.")
1044
1044
  (
1045
- threaded_action_prob_func_args_by_decision_time_by_user_id,
1046
- action_prob_func_args_by_decision_time_by_user_id,
1047
- ) = after_study_analysis.thread_action_prob_func_args(
1048
- action_prob_func_args_by_user_id_by_decision_time,
1049
- policy_num_by_decision_time_by_user_id,
1045
+ threaded_action_prob_func_args_by_decision_time_by_subject_id,
1046
+ action_prob_func_args_by_decision_time_by_subject_id,
1047
+ ) = post_deployment_analysis.thread_action_prob_func_args(
1048
+ action_prob_func_args_by_subject_id_by_decision_time,
1049
+ policy_num_by_decision_time_by_subject_id,
1050
1050
  initial_policy_num,
1051
1051
  betas,
1052
1052
  beta_index_by_policy_num,
@@ -1058,38 +1058,38 @@ def get_weighted_inference_estimating_functions_only(
1058
1058
  # arguments with the central betas introduced.
1059
1059
  logger.info(
1060
1060
  "Threading in theta and beta-dependent action probabilities to inference update "
1061
- "function args for all users"
1061
+ "function args for all subjects"
1062
1062
  )
1063
- threaded_inference_func_args_by_user_id = (
1064
- after_study_analysis.thread_inference_func_args(
1065
- inference_func_args_by_user_id,
1063
+ threaded_inference_func_args_by_subject_id = (
1064
+ post_deployment_analysis.thread_inference_func_args(
1065
+ inference_func_args_by_subject_id,
1066
1066
  inference_func_args_theta_index,
1067
1067
  theta,
1068
1068
  inference_func_args_action_prob_index,
1069
- threaded_action_prob_func_args_by_decision_time_by_user_id,
1070
- inference_action_prob_decision_times_by_user_id,
1069
+ threaded_action_prob_func_args_by_decision_time_by_subject_id,
1070
+ inference_action_prob_decision_times_by_subject_id,
1071
1071
  action_prob_func,
1072
1072
  )
1073
1073
  )
1074
1074
 
1075
- # 5. Now we can compute the the weighted inference estimating functions for all users
1076
- # as well as collect related values used to construct the adaptive and classical
1075
+ # 5. Now we can compute the the weighted inference estimating functions for all subjects
1076
+ # as well as collect related values used to construct the adjusted and classical
1077
1077
  # sandwich variances.
1078
1078
  results = [
1079
- single_user_weighted_inference_estimating_function(
1080
- user_id,
1079
+ single_subject_weighted_inference_estimating_function(
1080
+ subject_id,
1081
1081
  action_prob_func,
1082
1082
  inference_estimating_func,
1083
1083
  action_prob_func_args_beta_index,
1084
1084
  inference_func_args_theta_index,
1085
- action_prob_func_args_by_decision_time_by_user_id[user_id],
1086
- threaded_action_prob_func_args_by_decision_time_by_user_id[user_id],
1087
- threaded_inference_func_args_by_user_id[user_id],
1088
- policy_num_by_decision_time_by_user_id[user_id],
1089
- action_by_decision_time_by_user_id[user_id],
1085
+ action_prob_func_args_by_decision_time_by_subject_id[subject_id],
1086
+ threaded_action_prob_func_args_by_decision_time_by_subject_id[subject_id],
1087
+ threaded_inference_func_args_by_subject_id[subject_id],
1088
+ policy_num_by_decision_time_by_subject_id[subject_id],
1089
+ action_by_decision_time_by_subject_id[subject_id],
1090
1090
  beta_index_by_policy_num,
1091
1091
  )
1092
- for user_id in user_ids.tolist()
1092
+ for subject_id in subject_ids.tolist()
1093
1093
  ]
1094
1094
 
1095
1095
  weighted_inference_estimating_functions = jnp.array(
@@ -1100,8 +1100,8 @@ def get_weighted_inference_estimating_functions_only(
1100
1100
 
1101
1101
  # 6. Note this strange return structure! We will differentiate the first output,
1102
1102
  # but the second tuple will be passed along without modification via has_aux=True and then used
1103
- # for the adaptive meat matrix, estimating functions sum check, and classical meat and inverse
1104
- # bread matrices. The raw per-user estimating functions are also returned again for debugging
1103
+ # for the adjusted meat matrix, estimating functions sum check, and classical meat and inverse
1104
+ # bread matrices. The raw per-subject estimating functions are also returned again for debugging
1105
1105
  # purposes.
1106
1106
  return jnp.mean(weighted_inference_estimating_functions, axis=0), (
1107
1107
  weighted_inference_estimating_functions,
@@ -1111,8 +1111,8 @@ def get_weighted_inference_estimating_functions_only(
1111
1111
  )
1112
1112
 
1113
1113
 
1114
- def single_user_weighted_inference_estimating_function(
1115
- user_id: collections.abc.Hashable,
1114
+ def single_subject_weighted_inference_estimating_function(
1115
+ subject_id: collections.abc.Hashable,
1116
1116
  action_prob_func: callable,
1117
1117
  inference_estimating_func: callable,
1118
1118
  action_prob_func_args_beta_index: int,
@@ -1137,12 +1137,12 @@ def single_user_weighted_inference_estimating_function(
1137
1137
  and action probability function and arguments if applicable.
1138
1138
 
1139
1139
  Args:
1140
- user_id (collections.abc.Hashable):
1141
- The user ID for which to compute the weighted estimating function stack.
1140
+ subject_id (collections.abc.Hashable):
1141
+ The subject ID for which to compute the weighted estimating function stack.
1142
1142
 
1143
1143
  action_prob_func (callable):
1144
1144
  The function used to compute the probability of action 1 at a given decision time for
1145
- a particular user given their state and the algorithm parameters.
1145
+ a particular subject given their state and the algorithm parameters.
1146
1146
 
1147
1147
  inference_estimating_func (callable):
1148
1148
  The estimating function that corresponds to inference.
@@ -1154,7 +1154,7 @@ def single_user_weighted_inference_estimating_function(
1154
1154
  The index of the theta parameter in the inference loss or estimating function arguments.
1155
1155
 
1156
1156
  action_prob_func_args_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
1157
- A map from decision times to tuples of arguments for this user for the action
1157
+ A map from decision times to tuples of arguments for this subject for the action
1158
1158
  probability function. This is for all decision times (args are an empty
1159
1159
  tuple if they are not in the study). Should be sorted by decision time. NOTE THAT THESE
1160
1160
  ARGS DO NOT CONTAIN THE SHARED BETAS, making them impervious to the differentiation that
@@ -1167,11 +1167,11 @@ def single_user_weighted_inference_estimating_function(
1167
1167
 
1168
1168
  threaded_inference_func_args (dict[collections.abc.Hashable, tuple[Any, ...]]):
1169
1169
  A tuple containing the arguments for the inference
1170
- estimating function for this user, with the shared betas threaded in for differentiation.
1170
+ estimating function for this subject, with the shared betas threaded in for differentiation.
1171
1171
 
1172
1172
  policy_num_by_decision_time (dict[collections.abc.Hashable, dict[int, int | float]]):
1173
1173
  A dictionary mapping decision times to the policy number in use. This may be
1174
- user-specific. Should be sorted by decision time. Only applies to in-study decision
1174
+ subject-specific. Should be sorted by decision time. Only applies to in-study decision
1175
1175
  times!
1176
1176
 
1177
1177
  action_by_decision_time (dict[collections.abc.Hashable, dict[int, int]]):
@@ -1183,31 +1183,33 @@ def single_user_weighted_inference_estimating_function(
1183
1183
  all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
1184
1184
 
1185
1185
  Returns:
1186
- jnp.ndarray: A 1-D JAX NumPy array representing the user's weighted inference estimating function.
1187
- jnp.ndarray: A 2-D JAX NumPy matrix representing the user's classical meat contribution.
1188
- jnp.ndarray: A 2-D JAX NumPy matrix representing the user's classical bread contribution.
1186
+ jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted inference estimating function.
1187
+ jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical meat contribution.
1188
+ jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical bread contribution.
1189
1189
  """
1190
1190
 
1191
1191
  logger.info(
1192
- "Computing only weighted inference estimating function stack for user %s.",
1193
- user_id,
1192
+ "Computing only weighted inference estimating function stack for subject %s.",
1193
+ subject_id,
1194
1194
  )
1195
1195
 
1196
1196
  # First, reformat the supplied data into more convenient structures.
1197
1197
 
1198
1198
  # 1. Get the first time after the first update for convenience.
1199
1199
  # This is used to form the Radon-Nikodym weights for the right times.
1200
- _, first_time_after_first_update = after_study_analysis.get_min_time_by_policy_num(
1201
- policy_num_by_decision_time,
1202
- beta_index_by_policy_num,
1200
+ _, first_time_after_first_update = (
1201
+ post_deployment_analysis.get_min_time_by_policy_num(
1202
+ policy_num_by_decision_time,
1203
+ beta_index_by_policy_num,
1204
+ )
1203
1205
  )
1204
1206
 
1205
- # 2. Get the start and end times for this user.
1206
- user_start_time = math.inf
1207
- user_end_time = -math.inf
1207
+ # 2. Get the start and end times for this subject.
1208
+ subject_start_time = math.inf
1209
+ subject_end_time = -math.inf
1208
1210
  for decision_time in action_by_decision_time:
1209
- user_start_time = min(user_start_time, decision_time)
1210
- user_end_time = max(user_end_time, decision_time)
1211
+ subject_start_time = min(subject_start_time, decision_time)
1212
+ subject_end_time = max(subject_end_time, decision_time)
1211
1213
 
1212
1214
  # 3. Calculate the Radon-Nikodym weights for the inference estimating function.
1213
1215
  in_study_action_prob_func_args = [
@@ -1224,11 +1226,11 @@ def single_user_weighted_inference_estimating_function(
1224
1226
  )
1225
1227
 
1226
1228
  # Sort the threaded args by decision time to be cautious. We check if the
1227
- # user id is present in the user args dict because we may call this on a
1228
- # subset of the user arg dict when we are batching arguments by shape
1229
+ # subject id is present in the subject args dict because we may call this on a
1230
+ # subset of the subject arg dict when we are batching arguments by shape
1229
1231
  sorted_threaded_action_prob_args_by_decision_time = {
1230
1232
  decision_time: threaded_action_prob_func_args_by_decision_time[decision_time]
1231
- for decision_time in range(user_start_time, user_end_time + 1)
1233
+ for decision_time in range(subject_start_time, subject_end_time + 1)
1232
1234
  if decision_time in threaded_action_prob_func_args_by_decision_time
1233
1235
  }
1234
1236
 
@@ -1260,7 +1262,7 @@ def single_user_weighted_inference_estimating_function(
1260
1262
  # value, but impervious to differentiation with respect to all_post_update_betas. The
1261
1263
  # args, on the other hand, are a function of all_post_update_betas.
1262
1264
  in_study_weights = jax.vmap(
1263
- fun=after_study_analysis.get_radon_nikodym_weight,
1265
+ fun=post_deployment_analysis.get_radon_nikodym_weight,
1264
1266
  in_axes=[0, None, None, 0] + batch_axes,
1265
1267
  out_axes=0,
1266
1268
  )(
@@ -1287,12 +1289,12 @@ def single_user_weighted_inference_estimating_function(
1287
1289
  # 4. Form the weighted inference estimating equation.
1288
1290
  weighted_inference_estimating_function = jnp.prod(
1289
1291
  all_weights[
1290
- max(first_time_after_first_update, user_start_time)
1291
- - decision_time_to_all_weights_index_offset : user_end_time
1292
+ max(first_time_after_first_update, subject_start_time)
1293
+ - decision_time_to_all_weights_index_offset : subject_end_time
1292
1294
  + 1
1293
1295
  - decision_time_to_all_weights_index_offset,
1294
1296
  ]
1295
- # If the user exited the study before there were any updates,
1297
+ # If the subject exited the study before there were any updates,
1296
1298
  # this variable will be None and the above code to grab a weight would
1297
1299
  # throw an error. Just use 1 to include the unweighted estimating function
1298
1300
  # if they have data to contribute here (pretty sure everyone should?)