lifejacket 0.2.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,8 +24,8 @@ from .constants import (
     SandwichFormationMethods,
     SmallSampleCorrections,
 )
-from .form_adaptive_meat_adjustments_directly import (
-    form_adaptive_meat_adjustments_directly,
+from .form_adjusted_meat_adjustments_directly import (
+    form_adjusted_meat_adjustments_directly,
 )
 from . import input_checks
 from . import get_datum_for_blowup_supervised_learning
@@ -37,9 +37,9 @@ from .helper_functions import (
     calculate_beta_dim,
     collect_all_post_update_betas,
     construct_beta_index_by_policy_num_map,
-    extract_action_and_policy_by_decision_time_by_user_id,
+    extract_action_and_policy_by_decision_time_by_subject_id,
     flatten_params,
-    get_in_study_df_column,
+    get_active_df_column,
     get_min_time_by_policy_num,
     get_radon_nikodym_weight,
     load_function_from_same_named_file,
@@ -61,7 +61,7 @@ def cli():

 # TODO: Check all help strings for accuracy.
 # TODO: Deal with NA, -1, etc policy numbers
-# TODO: Make sure in study is never on for more than one stretch EDIT: unclear if
+# TODO: Make sure in deployment is never on for more than one stretch EDIT: unclear if
 # this will remain an invariant as we deal with more complicated data missingness
 # TODO: I think I'm agnostic to indexing of calendar times but should check because
 # otherwise need to add a check here to verify required format.
@@ -69,7 +69,7 @@ def cli():
 # Higher dimensional objects not supported. Not entirely sure what kind of "scalars" apply.
 @cli.command(name="analyze")
 @click.option(
-    "--study_df_pickle",
+    "--analysis_df_pickle",
     type=click.File("rb"),
     help="Pickled pandas dataframe in correct format (see contract/readme).",
     required=True,
@@ -83,7 +83,7 @@ def cli():
 @click.option(
     "--action_prob_func_args_pickle",
     type=click.File("rb"),
-    help="Pickled dictionary that contains the action probability function arguments for all decision times for all users.",
+    help="Pickled dictionary that contains the action probability function arguments for all decision times for all subjects.",
     required=True,
 )
 @click.option(
@@ -95,7 +95,7 @@ def cli():
 @click.option(
     "--alg_update_func_filename",
     type=click.Path(exists=True),
-    help="File that contains the per-user update function used to determine the algorithm parameters at each update and relevant imports. May be a loss or estimating function, specified in a separate argument. The filename without its extension will be assumed to match the function name.",
+    help="File that contains the per-subject update function used to determine the algorithm parameters at each update and relevant imports. May be a loss or estimating function, specified in a separate argument. The filename without its extension will be assumed to match the function name.",
     required=True,
 )
 @click.option(
@@ -107,7 +107,7 @@ def cli():
 @click.option(
     "--alg_update_func_args_pickle",
     type=click.File("rb"),
-    help="Pickled dictionary that contains the algorithm update function arguments for all update times for all users.",
+    help="Pickled dictionary that contains the algorithm update function arguments for all update times for all subjects.",
     required=True,
 )
 @click.option(
@@ -137,7 +137,7 @@ def cli():
 @click.option(
     "--inference_func_filename",
     type=click.Path(exists=True),
-    help="File that contains the per-user loss/estimating function used to determine the inference estimate and relevant imports. The filename without its extension will be assumed to match the function name.",
+    help="File that contains the per-subject loss/estimating function used to determine the inference estimate and relevant imports. The filename without its extension will be assumed to match the function name.",
     required=True,
 )
 @click.option(
@@ -155,56 +155,56 @@ def cli():
 @click.option(
     "--theta_calculation_func_filename",
     type=click.Path(exists=True),
-    help="Path to file that allows one to actually calculate a theta estimate given the study dataframe only. One must supply either this or a precomputed theta estimate. The filename without its extension will be assumed to match the function name.",
+    help="Path to file that allows one to actually calculate a theta estimate given the analysis dataframe only. One must supply either this or a precomputed theta estimate. The filename without its extension will be assumed to match the function name.",
     required=True,
 )
 @click.option(
-    "--in_study_col_name",
+    "--active_col_name",
     type=str,
     required=True,
-    help="Name of the binary column in the study dataframe that indicates whether a user is in the study.",
+    help="Name of the binary column in the analysis dataframe that indicates whether a subject is in the deployment.",
 )
 @click.option(
     "--action_col_name",
     type=str,
     required=True,
-    help="Name of the binary column in the study dataframe that indicates which action was taken.",
+    help="Name of the binary column in the analysis dataframe that indicates which action was taken.",
 )
 @click.option(
     "--policy_num_col_name",
     type=str,
     required=True,
-    help="Name of the column in the study dataframe that indicates the policy number in use.",
+    help="Name of the column in the analysis dataframe that indicates the policy number in use.",
 )
 @click.option(
     "--calendar_t_col_name",
     type=str,
     required=True,
-    help="Name of the column in the study dataframe that indicates calendar time (shared integer index across users).",
+    help="Name of the column in the analysis dataframe that indicates calendar time (shared integer index across subjects).",
 )
 @click.option(
-    "--user_id_col_name",
+    "--subject_id_col_name",
     type=str,
     required=True,
-    help="Name of the column in the study dataframe that indicates user id.",
+    help="Name of the column in the analysis dataframe that indicates subject id.",
 )
 @click.option(
     "--action_prob_col_name",
     type=str,
     required=True,
-    help="Name of the column in the study dataframe that gives action one probabilities.",
+    help="Name of the column in the analysis dataframe that gives action one probabilities.",
 )
 @click.option(
     "--reward_col_name",
     type=str,
     required=True,
-    help="Name of the column in the study dataframe that gives rewards.",
+    help="Name of the column in the analysis dataframe that gives rewards.",
 )
 @click.option(
     "--suppress_interactive_data_checks",
     type=bool,
     default=False,
-    help="Flag to suppress any data checks that require user input. This is suitable for tests and large simulations",
+    help="Flag to suppress any data checks that require subject input. This is suitable for tests and large simulations",
 )
 @click.option(
     "--suppress_all_data_checks",
@@ -232,13 +232,13 @@ def cli():
     help="Flag to collect data for supervised learning blowup detection. This will write a single datum and label to a file in the same directory as the input files.",
 )
 @click.option(
-    "--form_adaptive_meat_adjustments_explicitly",
+    "--form_adjusted_meat_adjustments_explicitly",
     type=bool,
     default=False,
-    help="If True, explicitly forms the per-user meat adjustments that differentiate the adaptive sandwich from the classical sandwich. This is for diagnostic purposes, as the adaptive sandwich is formed without doing this.",
+    help="If True, explicitly forms the per-subject meat adjustments that differentiate the adaptive sandwich from the classical sandwich. This is for diagnostic purposes, as the adaptive sandwich is formed without doing this.",
 )
 @click.option(
-    "--stabilize_joint_adaptive_bread_inverse",
+    "--stabilize_joint_adjusted_bread_inverse",
     type=bool,
     default=True,
     help="If True, stabilizes the joint adaptive bread inverse matrix if it does not meet conditioning thresholds.",
@@ -248,7 +248,7 @@ def analyze_dataset_wrapper(**kwargs):
     This function is a wrapper around analyze_dataset to facilitate command line use.

     From the command line, we will take pickles and filenames for Python objects.
-    Unpickle/load files here for passing to the implementation function, which
+    We unpickle/load files here for passing to the implementation function, which
     may also be called in its own right with in-memory objects.

     See analyze_dataset for the underlying details.
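Because the wrapper unpickles its file inputs itself, those inputs can be produced with plain pandas and pickle. A hedged sketch follows; the column names and the args-dict layout are illustrative assumptions, not the package's documented contract.

```python
# Hedged sketch: produce the pickled inputs the wrapper expects.
# Column names and the args-dict structure are illustrative assumptions;
# see the package's contract/readme for the real format.
import pickle

import pandas as pd

analysis_df = pd.DataFrame(
    {
        "subject_id": [1, 1, 2, 2],
        "calendar_t": [1, 2, 1, 2],
        "active": [1, 1, 1, 1],  # binary: subject active at this decision time
        "action": [0, 1, 1, 0],
        "action_prob": [0.5, 0.6, 0.5, 0.4],
        "policy_num": [0, 1, 0, 1],
        "reward": [0.2, 1.0, 0.7, 0.1],
    }
)
with open("analysis_df.pkl", "wb") as f:
    pickle.dump(analysis_df, f)

# Args dicts are keyed by decision/update time; contents are placeholders here.
action_prob_func_args = {1: {}, 2: {}}
alg_update_func_args = {2: {}}
with open("action_prob_func_args.pkl", "wb") as f:
    pickle.dump(action_prob_func_args, f)
with open("alg_update_func_args.pkl", "wb") as f:
    pickle.dump(alg_update_func_args, f)
```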
@@ -256,18 +256,20 @@ def analyze_dataset_wrapper(**kwargs):
     Returns: None
     """

-    # Pass along the folder the study dataframe is in as the output folder.
-    # Do it now because we will be removing the study dataframe pickle from kwargs.
-    kwargs["output_dir"] = pathlib.Path(kwargs["study_df_pickle"].name).parent.resolve()
+    # Pass along the folder the analysis dataframe is in as the output folder.
+    # Do it now because we will be removing the analysis dataframe pickle from kwargs.
+    kwargs["output_dir"] = pathlib.Path(
+        kwargs["analysis_df_pickle"].name
+    ).parent.resolve()

     # Unpickle pickles and replace those args in kwargs
-    kwargs["study_df"] = pickle.load(kwargs["study_df_pickle"])
+    kwargs["analysis_df"] = pickle.load(kwargs["analysis_df_pickle"])
     kwargs["action_prob_func_args"] = pickle.load(
         kwargs["action_prob_func_args_pickle"]
     )
     kwargs["alg_update_func_args"] = pickle.load(kwargs["alg_update_func_args_pickle"])

-    kwargs.pop("study_df_pickle")
+    kwargs.pop("analysis_df_pickle")
     kwargs.pop("action_prob_func_args_pickle")
     kwargs.pop("alg_update_func_args_pickle")

@@ -295,7 +297,7 @@ def analyze_dataset_wrapper(**kwargs):

 def analyze_dataset(
     output_dir: pathlib.Path | str,
-    study_df: pd.DataFrame,
+    analysis_df: pd.DataFrame,
     action_prob_func: Callable,
     action_prob_func_args: dict[int, Any],
     action_prob_func_args_beta_index: int,
@@ -310,19 +312,19 @@ def analyze_dataset(
     inference_func_type: str,
     inference_func_args_theta_index: int,
     theta_calculation_func: Callable[[pd.DataFrame], jnp.ndarray],
-    in_study_col_name: str,
+    active_col_name: str,
     action_col_name: str,
     policy_num_col_name: str,
     calendar_t_col_name: str,
-    user_id_col_name: str,
+    subject_id_col_name: str,
     action_prob_col_name: str,
     reward_col_name: str,
     suppress_interactive_data_checks: bool,
     suppress_all_data_checks: bool,
     small_sample_correction: str,
     collect_data_for_blowup_supervised_learning: bool,
-    form_adaptive_meat_adjustments_explicitly: bool,
-    stabilize_joint_adaptive_bread_inverse: bool,
+    form_adjusted_meat_adjustments_explicitly: bool,
+    stabilize_joint_adjusted_bread_inverse: bool,
 ) -> None:
     """
     Analyzes a dataset to provide a parameter estimate and an estimate of its variance using adaptive and classical sandwich estimators.
@@ -337,8 +339,8 @@ def analyze_dataset(
     Parameters:
         output_dir (pathlib.Path | str):
             Directory in which to save output files.
-        study_df (pd.DataFrame):
-            DataFrame containing the study data.
+        analysis_df (pd.DataFrame):
+            DataFrame containing the deployment data.
         action_prob_func (callable):
             Action probability function.
         action_prob_func_args (dict[int, Any]):
@@ -364,21 +366,21 @@ def analyze_dataset(
         inference_func_args_theta_index (int):
             Index for theta in inference function arguments.
         theta_calculation_func (callable):
-            Function to estimate theta from the study DataFrame.
-        in_study_col_name (str):
-            Column name indicating if a user is in the study in the study dataframe.
+            Function to estimate theta from the analysis dataframe.
+        active_col_name (str):
+            Column name indicating if a subject is active in the analysis dataframe.
         action_col_name (str):
-            Column name for actions in the study dataframe.
+            Column name for actions in the analysis dataframe.
         policy_num_col_name (str):
-            Column name for policy numbers in the study dataframe.
+            Column name for policy numbers in the analysis dataframe.
         calendar_t_col_name (str):
-            Column name for calendar time in the study dataframe.
-        user_id_col_name (str):
-            Column name for user IDs in the study dataframe.
+            Column name for calendar time in the analysis dataframe.
+        subject_id_col_name (str):
+            Column name for subject IDs in the analysis dataframe.
         action_prob_col_name (str):
-            Column name for action probabilities in the study dataframe.
+            Column name for action probabilities in the analysis dataframe.
         reward_col_name (str):
-            Column name for rewards in the study dataframe.
+            Column name for rewards in the analysis dataframe.
         suppress_interactive_data_checks (bool):
             Whether to suppress interactive data checks. This should be used in simulations, for example.
         suppress_all_data_checks (bool):
@@ -387,11 +389,11 @@ def analyze_dataset(
             Type of small sample correction to apply.
         collect_data_for_blowup_supervised_learning (bool):
             Whether to collect data for doing supervised learning about adaptive sandwich blowup.
-        form_adaptive_meat_adjustments_explicitly (bool):
-            If True, explicitly forms the per-user meat adjustments that differentiate the adaptive
+        form_adjusted_meat_adjustments_explicitly (bool):
+            If True, explicitly forms the per-subject meat adjustments that differentiate the adaptive
             sandwich from the classical sandwich. This is for diagnostic purposes, as the
             adaptive sandwich is formed without doing this.
-        stabilize_joint_adaptive_bread_inverse (bool):
+        stabilize_joint_adjusted_bread_inverse (bool):
             If True, stabilizes the joint adaptive bread inverse matrix if it does not meet conditioning
             thresholds.

@@ -406,19 +408,19 @@ def analyze_dataset(
         level=logging.INFO,
     )

-    theta_est = jnp.array(theta_calculation_func(study_df))
+    theta_est = jnp.array(theta_calculation_func(analysis_df))

     beta_dim = calculate_beta_dim(
         action_prob_func_args, action_prob_func_args_beta_index
     )
     if not suppress_all_data_checks:
         input_checks.perform_first_wave_input_checks(
-            study_df,
-            in_study_col_name,
+            analysis_df,
+            active_col_name,
             action_col_name,
             policy_num_col_name,
             calendar_t_col_name,
-            user_id_col_name,
+            subject_id_col_name,
             action_prob_col_name,
             reward_col_name,
             action_prob_func,
@@ -439,7 +441,7 @@ def analyze_dataset(

     beta_index_by_policy_num, initial_policy_num = (
         construct_beta_index_by_policy_num_map(
-            study_df, policy_num_col_name, in_study_col_name
+            analysis_df, policy_num_col_name, active_col_name
         )
     )

@@ -447,11 +449,11 @@ def analyze_dataset(
         beta_index_by_policy_num, alg_update_func_args, alg_update_func_args_beta_index
     )

-    action_by_decision_time_by_user_id, policy_num_by_decision_time_by_user_id = (
-        extract_action_and_policy_by_decision_time_by_user_id(
-            study_df,
-            user_id_col_name,
-            in_study_col_name,
+    action_by_decision_time_by_subject_id, policy_num_by_decision_time_by_subject_id = (
+        extract_action_and_policy_by_decision_time_by_subject_id(
+            analysis_df,
+            subject_id_col_name,
+            active_col_name,
             calendar_t_col_name,
             action_col_name,
             policy_num_col_name,
@@ -459,45 +461,45 @@ def analyze_dataset(
     )

     (
-        inference_func_args_by_user_id,
+        inference_func_args_by_subject_id,
         inference_func_args_action_prob_index,
-        inference_action_prob_decision_times_by_user_id,
+        inference_action_prob_decision_times_by_subject_id,
     ) = process_inference_func_args(
         inference_func,
         inference_func_args_theta_index,
-        study_df,
+        analysis_df,
         theta_est,
         action_prob_col_name,
         calendar_t_col_name,
-        user_id_col_name,
-        in_study_col_name,
+        subject_id_col_name,
+        active_col_name,
     )

-    # Use a per-user weighted estimating function stacking functino to derive classical and joint
+    # Use a per-subject weighted estimating function stacking function to derive classical and joint
     # adaptive meat and inverse bread matrices. This is facilitated because the *value* of the
     # weighted and unweighted stacks are the same, as the weights evaluate to 1 pre-differentiation.
     logger.info(
-        "Constructing joint adaptive bread inverse matrix, joint adaptive meat matrix, the classical analogs, and the avg estimating function stack across users."
+        "Constructing joint adaptive bread inverse matrix, joint adaptive meat matrix, the classical analogs, and the avg estimating function stack across subjects."
     )

-    user_ids = jnp.array(study_df[user_id_col_name].unique())
+    subject_ids = jnp.array(analysis_df[subject_id_col_name].unique())
     (
-        stabilized_joint_adaptive_bread_inverse_matrix,
-        raw_joint_adaptive_bread_inverse_matrix,
-        joint_adaptive_meat_matrix,
-        joint_adaptive_sandwich_matrix,
+        stabilized_joint_adjusted_bread_inverse_matrix,
+        raw_joint_adjusted_bread_inverse_matrix,
+        joint_adjusted_meat_matrix,
+        joint_adjusted_sandwich_matrix,
         classical_bread_inverse_matrix,
         classical_meat_matrix,
         classical_sandwich_var_estimate,
         avg_estimating_function_stack,
-        per_user_estimating_function_stacks,
-        per_user_adaptive_corrections,
-        per_user_classical_corrections,
-        per_user_adaptive_meat_adjustments,
-    ) = construct_classical_and_adaptive_sandwiches(
+        per_subject_estimating_function_stacks,
+        per_subject_adjusted_corrections,
+        per_subject_classical_corrections,
+        per_subject_adjusted_meat_adjustments,
+    ) = construct_classical_and_adjusted_sandwiches(
         theta_est,
         all_post_update_betas,
-        user_ids,
+        subject_ids,
         action_prob_func,
         action_prob_func_args_beta_index,
         alg_update_func,
@@ -511,31 +513,27 @@ def analyze_dataset(
         inference_func_args_theta_index,
         inference_func_args_action_prob_index,
         action_prob_func_args,
-        policy_num_by_decision_time_by_user_id,
+        policy_num_by_decision_time_by_subject_id,
         initial_policy_num,
         beta_index_by_policy_num,
-        inference_func_args_by_user_id,
-        inference_action_prob_decision_times_by_user_id,
+        inference_func_args_by_subject_id,
+        inference_action_prob_decision_times_by_subject_id,
         alg_update_func_args,
-        action_by_decision_time_by_user_id,
+        action_by_decision_time_by_subject_id,
         suppress_all_data_checks,
         suppress_interactive_data_checks,
         small_sample_correction,
-        form_adaptive_meat_adjustments_explicitly,
-        stabilize_joint_adaptive_bread_inverse,
-        study_df,
-        in_study_col_name,
+        form_adjusted_meat_adjustments_explicitly,
+        stabilize_joint_adjusted_bread_inverse,
+        analysis_df,
+        active_col_name,
         action_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         action_prob_func_args,
         action_prob_col_name,
     )

-    joint_adaptive_bread_inverse_cond = jnp.linalg.cond(
-        stabilized_joint_adaptive_bread_inverse_matrix
-    )
-
     theta_dim = len(theta_est)
     if not suppress_all_data_checks:
         input_checks.require_estimating_functions_sum_to_zero(
@@ -547,18 +545,18 @@ def analyze_dataset(


     # This bottom right corner of the joint (betas and theta) variance matrix is the portion
     # corresponding to just theta.
-    adaptive_sandwich_var_estimate = joint_adaptive_sandwich_matrix[
+    adjusted_sandwich_var_estimate = joint_adjusted_sandwich_matrix[
         -theta_dim:, -theta_dim:
     ]

     # Check for negative diagonal elements and set them to zero if found
-    adaptive_diagonal = np.diag(adaptive_sandwich_var_estimate)
+    adaptive_diagonal = np.diag(adjusted_sandwich_var_estimate)
     if np.any(adaptive_diagonal < 0):
         logger.warning(
             "Found negative diagonal elements in adaptive sandwich variance estimate. Setting them to zero."
         )
         np.fill_diagonal(
-            adaptive_sandwich_var_estimate, np.maximum(adaptive_diagonal, 0)
+            adjusted_sandwich_var_estimate, np.maximum(adaptive_diagonal, 0)
         )

     logger.info("Writing results to file...")
@@ -567,7 +565,7 @@ def analyze_dataset(


     analysis_dict = {
         "theta_est": theta_est,
-        "adaptive_sandwich_var_estimate": adaptive_sandwich_var_estimate,
+        "adjusted_sandwich_var_estimate": adjusted_sandwich_var_estimate,
         "classical_sandwich_var_estimate": classical_sandwich_var_estimate,
     }
     with open(output_folder_abs_path / "analysis.pkl", "wb") as f:
@@ -576,21 +574,29 @@
             f,
         )

+    joint_adjusted_bread_inverse_cond = jnp.linalg.cond(
+        raw_joint_adjusted_bread_inverse_matrix
+    )
+    logger.info(
+        "Joint adjusted bread inverse condition number: %f",
+        joint_adjusted_bread_inverse_cond,
+    )
+
     debug_pieces_dict = {
         "theta_est": theta_est,
-        "adaptive_sandwich_var_estimate": adaptive_sandwich_var_estimate,
+        "adjusted_sandwich_var_estimate": adjusted_sandwich_var_estimate,
         "classical_sandwich_var_estimate": classical_sandwich_var_estimate,
-        "raw_joint_bread_inverse_matrix": raw_joint_adaptive_bread_inverse_matrix,
-        "stabilized_joint_bread_inverse_matrix": stabilized_joint_adaptive_bread_inverse_matrix,
-        "joint_meat_matrix": joint_adaptive_meat_matrix,
+        "raw_joint_bread_inverse_matrix": raw_joint_adjusted_bread_inverse_matrix,
+        "stabilized_joint_bread_inverse_matrix": stabilized_joint_adjusted_bread_inverse_matrix,
+        "joint_meat_matrix": joint_adjusted_meat_matrix,
         "classical_bread_inverse_matrix": classical_bread_inverse_matrix,
         "classical_meat_matrix": classical_meat_matrix,
-        "all_estimating_function_stacks": per_user_estimating_function_stacks,
-        "joint_bread_inverse_condition_number": joint_adaptive_bread_inverse_cond,
+        "all_estimating_function_stacks": per_subject_estimating_function_stacks,
+        "joint_bread_inverse_condition_number": joint_adjusted_bread_inverse_cond,
         "all_post_update_betas": all_post_update_betas,
-        "per_user_adaptive_corrections": per_user_adaptive_corrections,
-        "per_user_classical_corrections": per_user_classical_corrections,
-        "per_user_adaptive_meat_adjustments": per_user_adaptive_meat_adjustments,
+        "per_subject_adjusted_corrections": per_subject_adjusted_corrections,
+        "per_subject_classical_corrections": per_subject_classical_corrections,
+        "per_subject_adjusted_meat_adjustments": per_subject_adjusted_meat_adjustments,
     }
     with open(output_folder_abs_path / "debug_pieces.pkl", "wb") as f:
         pickle.dump(
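Downstream code can read these artifacts back directly with pickle. A sketch assuming only the dictionary keys shown above (the path is illustrative):

```python
# Sketch: load the written analysis output, assuming only the keys present
# in analysis_dict above. The file path is illustrative.
import pickle

with open("analysis.pkl", "rb") as f:
    analysis = pickle.load(f)

theta_est = analysis["theta_est"]
adjusted_var = analysis["adjusted_sandwich_var_estimate"]
classical_var = analysis["classical_sandwich_var_estimate"]
print(theta_est)
print(adjusted_var.diagonal(), classical_var.diagonal())
```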
@@ -600,25 +606,25 @@

     if collect_data_for_blowup_supervised_learning:
         datum_and_label_dict = get_datum_for_blowup_supervised_learning.get_datum_for_blowup_supervised_learning(
-            raw_joint_adaptive_bread_inverse_matrix,
-            joint_adaptive_bread_inverse_cond,
+            raw_joint_adjusted_bread_inverse_matrix,
+            joint_adjusted_bread_inverse_cond,
             avg_estimating_function_stack,
-            per_user_estimating_function_stacks,
+            per_subject_estimating_function_stacks,
             all_post_update_betas,
-            study_df,
-            in_study_col_name,
+            analysis_df,
+            active_col_name,
             calendar_t_col_name,
             action_prob_col_name,
-            user_id_col_name,
+            subject_id_col_name,
             reward_col_name,
             theta_est,
-            adaptive_sandwich_var_estimate,
-            user_ids,
+            adjusted_sandwich_var_estimate,
+            subject_ids,
             beta_dim,
             theta_dim,
             initial_policy_num,
             beta_index_by_policy_num,
-            policy_num_by_decision_time_by_user_id,
+            policy_num_by_decision_time_by_subject_id,
             theta_calculation_func,
             action_prob_func,
             action_prob_func_args_beta_index,
@@ -626,16 +632,16 @@ def analyze_dataset(
             inference_func_type,
             inference_func_args_theta_index,
             inference_func_args_action_prob_index,
-            inference_action_prob_decision_times_by_user_id,
+            inference_action_prob_decision_times_by_subject_id,
             action_prob_func_args,
-            action_by_decision_time_by_user_id,
+            action_by_decision_time_by_subject_id,
         )

         with open(output_folder_abs_path / "supervised_learning_datum.pkl", "wb") as f:
             pickle.dump(datum_and_label_dict, f)

     print(f"\nParameter estimate:\n {theta_est}")
-    print(f"\nAdaptive sandwich variance estimate:\n {adaptive_sandwich_var_estimate}")
+    print(f"\nAdjusted sandwich variance estimate:\n {adjusted_sandwich_var_estimate}")
     print(
         f"\nClassical sandwich variance estimate:\n {classical_sandwich_var_estimate}\n"
     )
@@ -646,15 +652,15 @@

 def process_inference_func_args(
     inference_func: callable,
     inference_func_args_theta_index: int,
-    study_df: pd.DataFrame,
+    analysis_df: pd.DataFrame,
     theta_est: jnp.ndarray,
     action_prob_col_name: str,
     calendar_t_col_name: str,
-    user_id_col_name: str,
-    in_study_col_name: str,
+    subject_id_col_name: str,
+    active_col_name: str,
 ) -> tuple[dict[collections.abc.Hashable, tuple[Any, ...]], int]:
     """
-    Collects the inference function arguments for each user from the study DataFrame.
+    Collects the inference function arguments for each subject from the analysis DataFrame.

     Note that theta and action probabilities, if present, will be replaced later
     so that the function can be differentiated with respect to shared versions
@@ -665,32 +671,32 @@ def process_inference_func_args(
             The inference function to be used.
         inference_func_args_theta_index (int):
             The index of the theta parameter in the inference function's arguments.
-        study_df (pandas.DataFrame):
-            The study DataFrame.
+        analysis_df (pandas.DataFrame):
+            The analysis DataFrame.
         theta_est (jnp.ndarray):
             The estimate of the parameter vector.
         action_prob_col_name (str):
-            The name of the column in the study DataFrame that gives action probabilities.
+            The name of the column in the analysis DataFrame that gives action probabilities.
         calendar_t_col_name (str):
-            The name of the column in the study DataFrame that indicates calendar time.
-        user_id_col_name (str):
-            The name of the column in the study DataFrame that indicates user ID.
-        in_study_col_name (str):
-            The name of the binary column in the study DataFrame that indicates whether a user is in the study.
+            The name of the column in the analysis DataFrame that indicates calendar time.
+        subject_id_col_name (str):
+            The name of the column in the analysis DataFrame that indicates subject ID.
+        active_col_name (str):
+            The name of the binary column in the analysis DataFrame that indicates whether a subject is in the deployment.
     Returns:
         tuple[dict[collections.abc.Hashable, tuple[Any, ...]], int, dict[collections.abc.Hashable, jnp.ndarray[int]]]:
             A tuple containing
-            - the inference function arguments dictionary for each user
+            - the inference function arguments dictionary for each subject
             - the index of the action probabilities argument
-            - a dictionary mapping user IDs to the decision times to which action probabilities correspond
+            - a dictionary mapping subject IDs to the decision times to which action probabilities correspond
     """

     num_args = inference_func.__code__.co_argcount
     inference_func_arg_names = inference_func.__code__.co_varnames[:num_args]
-    inference_func_args_by_user_id = {}
+    inference_func_args_by_subject_id = {}

     inference_func_args_action_prob_index = -1
-    inference_action_prob_decision_times_by_user_id = {}
+    inference_action_prob_decision_times_by_subject_id = {}

     using_action_probs = action_prob_col_name in inference_func_arg_names
     if using_action_probs:
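The column-to-argument matching here rests on CPython code-object introspection: positional parameter names are read off the function object and treated as dataframe column names. A standalone sketch:

```python
# Standalone sketch of the introspection used above: positional parameter
# names are recovered from a function's code object, so non-theta arguments
# can be matched to analysis dataframe columns by name.
def example_inference_func(theta, reward, action_prob):
    return None  # body irrelevant for introspection

num_args = example_inference_func.__code__.co_argcount
arg_names = example_inference_func.__code__.co_varnames[:num_args]
print(arg_names)  # ('theta', 'reward', 'action_prob')
```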
@@ -698,34 +704,36 @@ def process_inference_func_args(
             action_prob_col_name
         )

-    for user_id in study_df[user_id_col_name].unique():
-        user_args_list = []
-        filtered_user_data = study_df.loc[study_df[user_id_col_name] == user_id]
+    for subject_id in analysis_df[subject_id_col_name].unique():
+        subject_args_list = []
+        filtered_subject_data = analysis_df.loc[
+            analysis_df[subject_id_col_name] == subject_id
+        ]
         for idx, col_name in enumerate(inference_func_arg_names):
             if idx == inference_func_args_theta_index:
-                user_args_list.append(theta_est)
+                subject_args_list.append(theta_est)
                 continue
-            user_args_list.append(
-                get_in_study_df_column(filtered_user_data, col_name, in_study_col_name)
+            subject_args_list.append(
+                get_active_df_column(filtered_subject_data, col_name, active_col_name)
             )
-        inference_func_args_by_user_id[user_id] = tuple(user_args_list)
+        inference_func_args_by_subject_id[subject_id] = tuple(subject_args_list)
         if using_action_probs:
-            inference_action_prob_decision_times_by_user_id[user_id] = (
-                get_in_study_df_column(
-                    filtered_user_data, calendar_t_col_name, in_study_col_name
+            inference_action_prob_decision_times_by_subject_id[subject_id] = (
+                get_active_df_column(
+                    filtered_subject_data, calendar_t_col_name, active_col_name
                 )
             )

     return (
-        inference_func_args_by_user_id,
+        inference_func_args_by_subject_id,
         inference_func_args_action_prob_index,
-        inference_action_prob_decision_times_by_user_id,
+        inference_action_prob_decision_times_by_subject_id,
     )


-def single_user_weighted_estimating_function_stacker(
+def single_subject_weighted_estimating_function_stacker(
     beta_dim: int,
-    user_id: collections.abc.Hashable,
+    subject_id: collections.abc.Hashable,
     action_prob_func: callable,
     algorithm_estimating_func: callable,
     inference_estimating_func: callable,
@@ -759,12 +767,12 @@ def single_user_weighted_estimating_function_stacker(
         beta_dim (list[jnp.ndarray]):
             A list of 1D JAX NumPy arrays corresponding to the betas produced by all updates.

-        user_id (collections.abc.Hashable):
-            The user ID for which to compute the weighted estimating function stack.
+        subject_id (collections.abc.Hashable):
+            The subject ID for which to compute the weighted estimating function stack.

         action_prob_func (callable):
             The function used to compute the probability of action 1 at a given decision time for
-            a particular user given their state and the algorithm parameters.
+            a particular subject given their state and the algorithm parameters.

         algorithm_estimating_func (callable):
             The estimating function that corresponds to algorithm updates.
@@ -779,9 +787,9 @@ def single_user_weighted_estimating_function_stacker(
             The index of the theta parameter in the inference loss or estimating function arguments.

         action_prob_func_args_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
-            A map from decision times to tuples of arguments for this user for the action
+            A map from decision times to tuples of arguments for this subject for the action
             probability function. This is for all decision times (args are an empty
-            tuple if they are not in the study). Should be sorted by decision time. NOTE THAT THESE
+            tuple if they are not in the deployment). Should be sorted by decision time. NOTE THAT THESE
             ARGS DO NOT CONTAIN THE SHARED BETAS, making them impervious to the differentiation that
             will occur.

@@ -792,21 +800,21 @@ def single_user_weighted_estimating_function_stacker(

         threaded_update_func_args_by_policy_num (dict[int | float, dict[collections.abc.Hashable, tuple[Any, ...]]]):
             A map from policy numbers to tuples containing the arguments for
-            the corresponding estimating functions for this user, with the shared betas threaded in
+            the corresponding estimating functions for this subject, with the shared betas threaded in
             for differentiation. This is for all non-initial, non-fallback policies. Policy numbers
             should be sorted.

         threaded_inference_func_args (dict[collections.abc.Hashable, tuple[Any, ...]]):
             A tuple containing the arguments for the inference
-            estimating function for this user, with the shared betas threaded in for differentiation.
+            estimating function for this subject, with the shared betas threaded in for differentiation.

         policy_num_by_decision_time (dict[collections.abc.Hashable, dict[int, int | float]]):
             A dictionary mapping decision times to the policy number in use. This may be
-            user-specific. Should be sorted by decision time. Only applies to in-study decision
+            subject-specific. Should be sorted by decision time. Only applies to active decision
             times!

         action_by_decision_time (dict[collections.abc.Hashable, dict[int, int]]):
-            A dictionary mapping decision times to actions taken. Only applies to in-study decision
+            A dictionary mapping decision times to actions taken. Only applies to active decision
             times!

         beta_index_by_policy_num (dict[int | float, int]):
@@ -814,19 +822,21 @@ def single_user_weighted_estimating_function_stacker(
             all_post_update_betas. Note that this is only for non-initial, non-fallback policies.

     Returns:
-        jnp.ndarray: A 1-D JAX NumPy array representing the user's weighted estimating function
+        jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted estimating function
             stack.
-        jnp.ndarray: A 2-D JAX NumPy matrix representing the user's adaptive meat contribution.
-        jnp.ndarray: A 2-D JAX NumPy matrix representing the user's classical meat contribution.
-        jnp.ndarray: A 2-D JAX NumPy matrix representing the user's classical bread contribution.
+        jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's adaptive meat contribution.
+        jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical meat contribution.
+        jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical bread contribution.
     """

-    logger.info("Computing weighted estimating function stack for user %s.", user_id)
+    logger.info(
+        "Computing weighted estimating function stack for subject %s.", subject_id
+    )

     # First, reformat the supplied data into more convenient structures.

     # 1. Form a dictionary mapping policy numbers to the first time they were
-    # applicable (for this user). Note that this includes ALL policies, initial
+    # applicable (for this subject). Note that this includes ALL policies, initial
     # fallbacks included.
     # Collect the first time after the first update separately for convenience.
     # These are both used to form the Radon-Nikodym weights for the right times.
@@ -835,38 +845,38 @@ def single_user_weighted_estimating_function_stacker(
         beta_index_by_policy_num,
     )

-    # 2. Get the start and end times for this user.
-    user_start_time = math.inf
-    user_end_time = -math.inf
+    # 2. Get the start and end times for this subject.
+    subject_start_time = math.inf
+    subject_end_time = -math.inf
     for decision_time in action_by_decision_time:
-        user_start_time = min(user_start_time, decision_time)
-        user_end_time = max(user_end_time, decision_time)
+        subject_start_time = min(subject_start_time, decision_time)
+        subject_end_time = max(subject_end_time, decision_time)

     # 3. Form a stack of weighted estimating equations, one for each update of the algorithm.
     logger.info(
-        "Computing the algorithm component of the weighted estimating function stack for user %s.",
-        user_id,
+        "Computing the algorithm component of the weighted estimating function stack for subject %s.",
+        subject_id,
     )

-    in_study_action_prob_func_args = [
+    active_action_prob_func_args = [
         args for args in action_prob_func_args_by_decision_time.values() if args
     ]
-    in_study_betas_list_by_decision_time_index = jnp.array(
+    active_betas_list_by_decision_time_index = jnp.array(
         [
             action_prob_func_args[action_prob_func_args_beta_index]
-            for action_prob_func_args in in_study_action_prob_func_args
+            for action_prob_func_args in active_action_prob_func_args
         ]
     )
-    in_study_actions_list_by_decision_time_index = jnp.array(
+    active_actions_list_by_decision_time_index = jnp.array(
         list(action_by_decision_time.values())
     )

     # Sort the threaded args by decision time to be cautious. We check if the
-    # user id is present in the user args dict because we may call this on a
-    # subset of the user arg dict when we are batching arguments by shape
+    # subject id is present in the subject args dict because we may call this on a
+    # subset of the subject arg dict when we are batching arguments by shape
     sorted_threaded_action_prob_args_by_decision_time = {
         decision_time: threaded_action_prob_func_args_by_decision_time[decision_time]
-        for decision_time in range(user_start_time, user_end_time + 1)
+        for decision_time in range(subject_start_time, subject_end_time + 1)
         if decision_time in threaded_action_prob_func_args_by_decision_time
     }

@@ -897,19 +907,19 @@ def single_user_weighted_estimating_function_stacker(
     # Just grab the original beta from the update function arguments. This is the same
     # value, but impervious to differentiation with respect to all_post_update_betas. The
     # args, on the other hand, are a function of all_post_update_betas.
-    in_study_weights = jax.vmap(
+    active_weights = jax.vmap(
         fun=get_radon_nikodym_weight,
         in_axes=[0, None, None, 0] + batch_axes,
         out_axes=0,
     )(
-        in_study_betas_list_by_decision_time_index,
+        active_betas_list_by_decision_time_index,
         action_prob_func,
         action_prob_func_args_beta_index,
-        in_study_actions_list_by_decision_time_index,
+        active_actions_list_by_decision_time_index,
         *batched_threaded_arg_tensors,
     )

-    in_study_index = 0
+    active_index = 0
     decision_time_to_all_weights_index_offset = min(
         sorted_threaded_action_prob_args_by_decision_time
     )
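The vmap call in this hunk batches decision-time-indexed arguments over axis 0 while broadcasting shared arguments with axis None. A self-contained sketch of that pattern under toy assumptions — a sigmoid action-probability model and a simplified weight function whose signature differs from the real get_radon_nikodym_weight:

```python
# Hedged sketch of the vmap pattern: per-decision-time arrays map over
# axis 0, shared arguments broadcast with axis None. The weight function
# here is a toy stand-in, not get_radon_nikodym_weight's real signature.
import jax
import jax.numpy as jnp


def action_one_prob(beta, state):
    return jax.nn.sigmoid(jnp.dot(beta, state))


def rn_weight(deployed_beta, shared_beta, action, state):
    p_deployed = action_one_prob(deployed_beta, state)
    p_shared = action_one_prob(shared_beta, state)
    numerator = jnp.where(action == 1, p_shared, 1.0 - p_shared)
    denominator = jnp.where(action == 1, p_deployed, 1.0 - p_deployed)
    return numerator / denominator


deployed_betas = jnp.array([[0.1, -0.2], [0.3, 0.0], [0.3, 0.0]])  # (T, d)
shared_beta = jnp.array([0.3, 0.0])                                # (d,)
actions = jnp.array([1, 0, 1])                                     # (T,)
states = jnp.array([[1.0, 2.0], [1.0, -1.0], [0.5, 0.5]])          # (T, d)

weights = jax.vmap(rn_weight, in_axes=(0, None, 0, 0))(
    deployed_betas, shared_beta, actions, states
)
print(weights)  # 1.0 wherever the deployed beta equals the shared beta
```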
@@ -918,35 +928,35 @@ def single_user_weighted_estimating_function_stacker(
         decision_time,
         args,
     ) in sorted_threaded_action_prob_args_by_decision_time.items():
-        all_weights_raw.append(in_study_weights[in_study_index] if args else 1.0)
-        in_study_index += 1
+        all_weights_raw.append(active_weights[active_index] if args else 1.0)
+        active_index += 1
     all_weights = jnp.array(all_weights_raw)

     algorithm_component = jnp.concatenate(
         [
             # Here we compute a product of Radon-Nikodym weights
             # for all decision times after the first update and before the update
-            # update under consideration took effect, for which the user was in the study.
+            # under consideration took effect, for which the subject was in the deployment.
             (
                 jnp.prod(
                     all_weights[
                         # The earliest time after the first update where the subject was in
                         # the deployment
                         max(
                             first_time_after_first_update,
-                            user_start_time,
+                            subject_start_time,
                         )
                         - decision_time_to_all_weights_index_offset :
-                        # One more than the latest time the user was in the study before the time
+                        # One more than the latest time the subject was in the deployment before the time
                         # the update under consideration first applied. Note the + 1 because range
                         # does not include the right endpoint.
                         min(
                             min_time_by_policy_num.get(policy_num, math.inf),
-                            user_end_time + 1,
+                            subject_end_time + 1,
                         )
                         - decision_time_to_all_weights_index_offset,
                     ]
-                    # If the user exited the study before there were any updates,
+                    # If the subject exited the deployment before there were any updates,
                     # this variable will be None and the above code to grab a weight would
                     # throw an error. Just use 1 to include the unweighted estimating function
                     # if they have data to contribute to the update.
@@ -954,8 +964,8 @@
                     else 1
                 )  # Now use the above to weight the alg estimating function for this update
                 * algorithm_estimating_func(*update_args)
-            # If there are no arguments for the update function, the user is not yet in the
-            # study, so we just add a zero vector contribution to the sum across users.
+            # If there are no arguments for the update function, the subject is not yet in the
+            # deployment, so we just add a zero vector contribution to the sum across subjects.
             # Note that after they exit, they still contribute all their data to later
             # updates.
             if update_args
@@ -974,17 +984,17 @@
     )
     # 4. Form the weighted inference estimating equation.
     logger.info(
-        "Computing the inference component of the weighted estimating function stack for user %s.",
-        user_id,
+        "Computing the inference component of the weighted estimating function stack for subject %s.",
+        subject_id,
     )
     inference_component = jnp.prod(
         all_weights[
-            max(first_time_after_first_update, user_start_time)
-            - decision_time_to_all_weights_index_offset : user_end_time
+            max(first_time_after_first_update, subject_start_time)
+            - decision_time_to_all_weights_index_offset : subject_end_time
             + 1
             - decision_time_to_all_weights_index_offset,
         ]
-        # If the user exited the study before there were any updates,
+        # If the subject exited the deployment before there were any updates,
         # this variable will be None and the above code to grab a weight would
         # throw an error. Just use 1 to include the unweighted estimating function
         # if they have data to contribute here (pretty sure everyone should?)
@@ -993,18 +1003,18 @@
     ) * inference_estimating_func(*threaded_inference_func_args)

     # 5. Concatenate the two components to form the weighted estimating function stack for this
-    # user.
+    # subject.
     weighted_stack = jnp.concatenate([algorithm_component, inference_component])

     # 6. Return the following outputs:
-    # a. The first is simply the weighted estimating function stack for this user. The average
+    # a. The first is simply the weighted estimating function stack for this subject. The average
     # of these is what we differentiate with respect to theta to form the inverse adaptive joint
     # bread matrix, and we also compare that average to zero to check the estimating functions'
     # fidelity.
-    # b. The average outer product of these per-user stacks across users is the adaptive joint meat
+    # b. The average outer product of these per-subject stacks across subjects is the adaptive joint meat
     # matrix, hence the second output.
-    # c. The third output is averaged across users to obtain the classical meat matrix.
-    # d. The fourth output is averaged across users to obtain the inverse classical bread
+    # c. The third output is averaged across subjects to obtain the classical meat matrix.
+    # d. The fourth output is averaged across subjects to obtain the inverse classical bread
     # matrix.
     return (
         weighted_stack,
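The numbered outputs (a)-(d) above map onto the usual sandwich assembly. A toy NumPy sketch of that arithmetic (the random stacks and the bread-inverse stand-in are purely illustrative; in the package the bread inverse comes from differentiating the average weighted stack):

```python
# Toy sketch of assembling a sandwich variance from per-subject stacks.
# All inputs are synthetic; in the package the bread inverse is obtained
# by autodiff of the average weighted estimating function stack.
import numpy as np

rng = np.random.default_rng(0)
n, p = 50, 3
per_subject_stacks = rng.normal(size=(n, p))  # output (a), one row per subject
bread_inverse = np.eye(p) + 0.1 * rng.normal(size=(p, p))  # derivative stand-in

# Output (b): the meat is the average outer product of per-subject stacks.
meat = per_subject_stacks.T @ per_subject_stacks / n

# Sandwich form: B^{-1} M B^{-T}, with an extra 1/n for the variance of the
# estimator built from the averaged estimating equation.
bread = np.linalg.inv(bread_inverse)
sandwich_var = bread @ meat @ bread.T / n
print(sandwich_var.shape)  # (p, p)
```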
@@ -1020,7 +1030,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
     flattened_betas_and_theta: jnp.ndarray,
     beta_dim: int,
     theta_dim: int,
-    user_ids: jnp.ndarray,
+    subject_ids: jnp.ndarray,
     action_prob_func: callable,
     action_prob_func_args_beta_index: int,
     alg_update_func: callable,
@@ -1033,29 +1043,31 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
     inference_func_type: str,
     inference_func_args_theta_index: int,
     inference_func_args_action_prob_index: int,
-    action_prob_func_args_by_user_id_by_decision_time: dict[
+    action_prob_func_args_by_subject_id_by_decision_time: dict[
         collections.abc.Hashable, dict[int, tuple[Any, ...]]
     ],
-    policy_num_by_decision_time_by_user_id: dict[
+    policy_num_by_decision_time_by_subject_id: dict[
         collections.abc.Hashable, dict[int, int | float]
     ],
     initial_policy_num: int | float,
     beta_index_by_policy_num: dict[int | float, int],
-    inference_func_args_by_user_id: dict[collections.abc.Hashable, tuple[Any, ...]],
-    inference_action_prob_decision_times_by_user_id: dict[
+    inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
+    inference_action_prob_decision_times_by_subject_id: dict[
         collections.abc.Hashable, list[int]
     ],
-    update_func_args_by_by_user_id_by_policy_num: dict[
+    update_func_args_by_by_subject_id_by_policy_num: dict[
         collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
     ],
-    action_by_decision_time_by_user_id: dict[collections.abc.Hashable, dict[int, int]],
+    action_by_decision_time_by_subject_id: dict[
+        collections.abc.Hashable, dict[int, int]
+    ],
     suppress_all_data_checks: bool,
     suppress_interactive_data_checks: bool,
 ) -> tuple[
     jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
 ]:
     """
-    Computes the average weighted estimating function stack across all users, along with
+    Computes the average weighted estimating function stack across all subjects, along with
     auxiliary values used to construct the adaptive and classical sandwich variances.

     Args:
@@ -1067,8 +1079,8 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
             The dimension of each of the beta parameters.
         theta_dim (int):
             The dimension of the theta parameter.
-        user_ids (jnp.ndarray):
-            A 1D JAX NumPy array of user IDs.
+        subject_ids (jnp.ndarray):
+            A 1D JAX NumPy array of subject IDs.
         action_prob_func (callable):
             The action probability function.
         action_prob_func_args_beta_index (int):
@@ -1096,29 +1108,29 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
         inference_func_args_action_prob_index (int):
             The index of action probabilities in the inference function arguments tuple, if
             applicable. -1 otherwise.
-        action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
-            A dictionary mapping decision times to maps of user ids to the function arguments
-            required to compute action probabilities for this user.
-        policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
-            A map of user ids to dictionaries mapping decision times to the policy number in use.
-            Only applies to in-study decision times!
+        action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
+            A dictionary mapping decision times to maps of subject ids to the function arguments
+            required to compute action probabilities for this subject.
+        policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
+            A map of subject ids to dictionaries mapping decision times to the policy number in use.
+            Only applies to active decision times!
         initial_policy_num (int | float):
             The policy number of the initial policy before any updates.
         beta_index_by_policy_num (dict[int | float, int]):
             A dictionary mapping policy numbers to the index of the corresponding beta in
             all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
-        inference_func_args_by_user_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
-            A dictionary mapping user IDs to their respective inference function arguments.
-        inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
-            For each user, a list of decision times to which action probabilities correspond if
-            provided. Typically just in-study times if action probabilites are used in the inference
+        inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
+            A dictionary mapping subject IDs to their respective inference function arguments.
+        inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
+            For each subject, a list of decision times to which action probabilities correspond if
+            provided. Typically just active times if action probabilities are used in the inference
             loss or estimating function.
-        update_func_args_by_by_user_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
-            A dictionary where keys are policy numbers and values are dictionaries mapping user IDs
+        update_func_args_by_by_subject_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
+            A dictionary where keys are policy numbers and values are dictionaries mapping subject IDs
             to their respective update function arguments.
-        action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
-            A dictionary mapping user IDs to their respective actions taken at each decision time.
-            Only applies to in-study decision times!
+        action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
+            A dictionary mapping subject IDs to their respective actions taken at each decision time.
+            Only applies to active decision times!
         suppress_all_data_checks (bool):
             If True, suppresses carrying out any data checks at all.
         suppress_interactive_data_checks (bool):
@@ -1132,10 +1144,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
         tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
             A tuple containing
             1. the average weighted estimating function stack
-            2. the user-level adaptive meat matrix contributions
-            3. the user-level classical meat matrix contributions
-            4. the user-level inverse classical bread matrix contributions
-            5. raw per-user weighted estimating function
+            2. the subject-level adaptive meat matrix contributions
+            3. the subject-level classical meat matrix contributions
+            4. the subject-level inverse classical bread matrix contributions
+            5. raw per-subject weighted estimating function
             stacks.
     """

@@ -1162,15 +1174,15 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
     # supplied for the above functions, so that differentiation works correctly. The existing
     # values should be the same, but not connected to the parameter we are differentiating
     # with respect to. Note we will also find it useful below to have the action probability args
-    # nested dict structure flipped to be user_id -> decision_time -> args, so we do that here too.
+    # nested dict structure flipped to be subject_id -> decision_time -> args, so we do that here too.

-    logger.info("Threading in betas to action probability arguments for all users.")
+    logger.info("Threading in betas to action probability arguments for all subjects.")
     (
-        threaded_action_prob_func_args_by_decision_time_by_user_id,
-        action_prob_func_args_by_decision_time_by_user_id,
+        threaded_action_prob_func_args_by_decision_time_by_subject_id,
+        action_prob_func_args_by_decision_time_by_subject_id,
     ) = thread_action_prob_func_args(
-        action_prob_func_args_by_user_id_by_decision_time,
-        policy_num_by_decision_time_by_user_id,
+        action_prob_func_args_by_subject_id_by_decision_time,
+        policy_num_by_decision_time_by_subject_id,
         initial_policy_num,
         betas,
         beta_index_by_policy_num,
1182
1194
  # arguments with the central betas introduced.
1183
1195
  logger.info(
1184
1196
  "Threading in betas and beta-dependent action probabilities to algorithm update "
1185
- "function args for all users"
1197
+ "function args for all subjects"
1186
1198
  )
1187
- threaded_update_func_args_by_policy_num_by_user_id = thread_update_func_args(
1188
- update_func_args_by_by_user_id_by_policy_num,
1199
+ threaded_update_func_args_by_policy_num_by_subject_id = thread_update_func_args(
1200
+ update_func_args_by_by_subject_id_by_policy_num,
1189
1201
  betas,
1190
1202
  beta_index_by_policy_num,
1191
1203
  alg_update_func_args_beta_index,
1192
1204
  alg_update_func_args_action_prob_index,
1193
1205
  alg_update_func_args_action_prob_times_index,
1194
1206
  alg_update_func_args_previous_betas_index,
1195
- threaded_action_prob_func_args_by_decision_time_by_user_id,
1207
+ threaded_action_prob_func_args_by_decision_time_by_subject_id,
1196
1208
  action_prob_func,
1197
1209
  )
1198
1210
 
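The "threading" performed by thread_update_func_args (and thread_action_prob_func_args above) amounts to rebuilding each stored argument tuple with one slot replaced by the centrally tracked parameter, so that JAX's differentiation is connected to it while the numerical values stay identical. A minimal sketch of that idea, with a hypothetical helper thread_value_into_args:

import jax
import jax.numpy as jnp

def thread_value_into_args(args, index, value):
    # Illustrative only: rebuild an argument tuple with slot `index`
    # replaced by a traced value so differentiation flows through it.
    return args[:index] + (value,) + args[index + 1:]

# The stored beta (a plain array) is numerically identical, but only the
# threaded copy is connected to the parameter we differentiate.
stored_args = (jnp.ones(2), jnp.array([0.5, -0.5]))  # (features, beta)

def loss(beta):
    features, b = thread_value_into_args(stored_args, 1, beta)
    return jnp.dot(features, b)

print(jax.grad(loss)(jnp.array([0.5, -0.5])))  # [1. 1.]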
@@ -1202,8 +1214,8 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1202
1214
  if not suppress_all_data_checks and alg_update_func_args_action_prob_index >= 0:
1203
1215
  input_checks.require_threaded_algorithm_estimating_function_args_equivalent(
1204
1216
  algorithm_estimating_func,
1205
- update_func_args_by_by_user_id_by_policy_num,
1206
- threaded_update_func_args_by_policy_num_by_user_id,
1217
+ update_func_args_by_by_subject_id_by_policy_num,
1218
+ threaded_update_func_args_by_policy_num_by_subject_id,
1207
1219
  suppress_interactive_data_checks,
1208
1220
  )
1209
1221
 
@@ -1212,15 +1224,15 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1212
1224
  # arguments with the central betas introduced.
1213
1225
  logger.info(
1214
1226
  "Threading in theta and beta-dependent action probabilities to inference update "
1215
- "function args for all users"
1227
+ "function args for all subjects"
1216
1228
  )
1217
- threaded_inference_func_args_by_user_id = thread_inference_func_args(
1218
- inference_func_args_by_user_id,
1229
+ threaded_inference_func_args_by_subject_id = thread_inference_func_args(
1230
+ inference_func_args_by_subject_id,
1219
1231
  inference_func_args_theta_index,
1220
1232
  theta,
1221
1233
  inference_func_args_action_prob_index,
1222
- threaded_action_prob_func_args_by_decision_time_by_user_id,
1223
- inference_action_prob_decision_times_by_user_id,
1234
+ threaded_action_prob_func_args_by_decision_time_by_subject_id,
1235
+ inference_action_prob_decision_times_by_subject_id,
1224
1236
  action_prob_func,
1225
1237
  )
1226
1238
 
@@ -1230,32 +1242,32 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1230
1242
  if not suppress_all_data_checks and inference_func_args_action_prob_index >= 0:
1231
1243
  input_checks.require_threaded_inference_estimating_function_args_equivalent(
1232
1244
  inference_estimating_func,
1233
- inference_func_args_by_user_id,
1234
- threaded_inference_func_args_by_user_id,
1245
+ inference_func_args_by_subject_id,
1246
+ threaded_inference_func_args_by_subject_id,
1235
1247
  suppress_interactive_data_checks,
1236
1248
  )
1237
1249
 
1238
- # 5. Now we can compute the weighted estimating function stacks for all users
1250
+ # 5. Now we can compute the weighted estimating function stacks for all subjects
1239
1251
  # as well as collect related values used to construct the adaptive and classical
1240
1252
  # sandwich variances.
1241
1253
  results = [
1242
- single_user_weighted_estimating_function_stacker(
1254
+ single_subject_weighted_estimating_function_stacker(
1243
1255
  beta_dim,
1244
- user_id,
1256
+ subject_id,
1245
1257
  action_prob_func,
1246
1258
  algorithm_estimating_func,
1247
1259
  inference_estimating_func,
1248
1260
  action_prob_func_args_beta_index,
1249
1261
  inference_func_args_theta_index,
1250
- action_prob_func_args_by_decision_time_by_user_id[user_id],
1251
- threaded_action_prob_func_args_by_decision_time_by_user_id[user_id],
1252
- threaded_update_func_args_by_policy_num_by_user_id[user_id],
1253
- threaded_inference_func_args_by_user_id[user_id],
1254
- policy_num_by_decision_time_by_user_id[user_id],
1255
- action_by_decision_time_by_user_id[user_id],
1262
+ action_prob_func_args_by_decision_time_by_subject_id[subject_id],
1263
+ threaded_action_prob_func_args_by_decision_time_by_subject_id[subject_id],
1264
+ threaded_update_func_args_by_policy_num_by_subject_id[subject_id],
1265
+ threaded_inference_func_args_by_subject_id[subject_id],
1266
+ policy_num_by_decision_time_by_subject_id[subject_id],
1267
+ action_by_decision_time_by_subject_id[subject_id],
1256
1268
  beta_index_by_policy_num,
1257
1269
  )
1258
- for user_id in user_ids.tolist()
1270
+ for subject_id in subject_ids.tolist()
1259
1271
  ]
1260
1272
 
1261
1273
  stacks = jnp.array([result[0] for result in results])
@@ -1266,10 +1278,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1266
1278
  # 6. Note this strange return structure! We will differentiate the first output,
1267
1279
  # but the second tuple will be passed along without modification via has_aux=True and then used
1268
1280
  # for the adaptive meat matrix, estimating functions sum check, and classical meat and inverse
1269
- # bread matrices. The raw per-user stacks are also returned for debugging purposes.
1281
+ # bread matrices. The raw per-subject stacks are also returned for debugging purposes.
1270
1282
 
1271
1283
  # Note that returning the raw stacks here as the first arguments is potentially
1272
- # memory-intensive when combined with differentiation. Keep this in mind if the per-user bread
1284
+ # memory-intensive when combined with differentiation. Keep this in mind if the per-subject bread
1273
1285
  # inverse contributions are needed for something like CR2/CR3 small-sample corrections.
1274
1286
  return jnp.mean(stacks, axis=0), (
1275
1287
  jnp.mean(stacks, axis=0),
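For readers unfamiliar with the has_aux pattern referenced in the comment above: jax.jacrev(f, has_aux=True) expects f to return (primary_output, aux), differentiates only the primary output, and passes aux through unchanged. A self-contained toy example (all values illustrative):

import jax
import jax.numpy as jnp

def f(params):
    # Primary output is differentiated; the aux tuple rides along untouched.
    primary = jnp.sum(params ** 2)
    aux = (params * 2.0,)  # stand-in for per-subject contributions
    return primary, aux

jacobian, aux = jax.jacrev(f, has_aux=True)(jnp.array([1.0, 2.0]))
print(jacobian)  # [2. 4.] -- gradient of sum(x^2)
print(aux[0])    # [2. 4.] -- returned unmodified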
@@ -1280,10 +1292,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1280
1292
  )
1281
1293
 
1282
1294
 
1283
- def construct_classical_and_adaptive_sandwiches(
1295
+ def construct_classical_and_adjusted_sandwiches(
1284
1296
  theta_est: jnp.ndarray,
1285
1297
  all_post_update_betas: jnp.ndarray,
1286
- user_ids: jnp.ndarray,
1298
+ subject_ids: jnp.ndarray,
1287
1299
  action_prob_func: callable,
1288
1300
  action_prob_func_args_beta_index: int,
1289
1301
  alg_update_func: callable,
@@ -1296,32 +1308,34 @@ def construct_classical_and_adaptive_sandwiches(
1296
1308
  inference_func_type: str,
1297
1309
  inference_func_args_theta_index: int,
1298
1310
  inference_func_args_action_prob_index: int,
1299
- action_prob_func_args_by_user_id_by_decision_time: dict[
1311
+ action_prob_func_args_by_subject_id_by_decision_time: dict[
1300
1312
  collections.abc.Hashable, dict[int, tuple[Any, ...]]
1301
1313
  ],
1302
- policy_num_by_decision_time_by_user_id: dict[
1314
+ policy_num_by_decision_time_by_subject_id: dict[
1303
1315
  collections.abc.Hashable, dict[int, int | float]
1304
1316
  ],
1305
1317
  initial_policy_num: int | float,
1306
1318
  beta_index_by_policy_num: dict[int | float, int],
1307
- inference_func_args_by_user_id: dict[collections.abc.Hashable, tuple[Any, ...]],
1308
- inference_action_prob_decision_times_by_user_id: dict[
1319
+ inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
1320
+ inference_action_prob_decision_times_by_subject_id: dict[
1309
1321
  collections.abc.Hashable, list[int]
1310
1322
  ],
1311
- update_func_args_by_by_user_id_by_policy_num: dict[
1323
+ update_func_args_by_by_subject_id_by_policy_num: dict[
1312
1324
  collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
1313
1325
  ],
1314
- action_by_decision_time_by_user_id: dict[collections.abc.Hashable, dict[int, int]],
1326
+ action_by_decision_time_by_subject_id: dict[
1327
+ collections.abc.Hashable, dict[int, int]
1328
+ ],
1315
1329
  suppress_all_data_checks: bool,
1316
1330
  suppress_interactive_data_checks: bool,
1317
1331
  small_sample_correction: str,
1318
- form_adaptive_meat_adjustments_explicitly: bool,
1319
- stabilize_joint_adaptive_bread_inverse: bool,
1320
- study_df: pd.DataFrame | None,
1321
- in_study_col_name: str | None,
1332
+ form_adjusted_meat_adjustments_explicitly: bool,
1333
+ stabilize_joint_adjusted_bread_inverse: bool,
1334
+ analysis_df: pd.DataFrame | None,
1335
+ active_col_name: str | None,
1322
1336
  action_col_name: str | None,
1323
1337
  calendar_t_col_name: str | None,
1324
- user_id_col_name: str | None,
1338
+ subject_id_col_name: str | None,
1325
1339
  action_prob_func_args: tuple | None,
1326
1340
  action_prob_col_name: str | None,
1327
1341
  ) -> tuple[
@@ -1350,8 +1364,8 @@ def construct_classical_and_adaptive_sandwiches(
1350
1364
  A 1-D JAX NumPy array representing the parameter estimate for inference.
1351
1365
  all_post_update_betas (jnp.ndarray):
1352
1366
  A 2-D JAX NumPy array representing all parameter estimates for the algorithm updates.
1353
- user_ids (jnp.ndarray):
1354
- A 1-D JAX NumPy array holding all user IDs in the study.
1367
+ subject_ids (jnp.ndarray):
1368
+ A 1-D JAX NumPy array holding all subject IDs in the deployment.
1355
1369
  action_prob_func (callable):
1356
1370
  The action probability function.
1357
1371
  action_prob_func_args_beta_index (int):
@@ -1379,29 +1393,29 @@ def construct_classical_and_adaptive_sandwiches(
1379
1393
  inference_func_args_action_prob_index (int):
1380
1394
  The index of action probabilities in the inference function arguments tuple, if
1381
1395
  applicable. -1 otherwise.
1382
- action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
1383
- A dictionary mapping decision times to maps of user ids to the function arguments
1384
- required to compute action probabilities for this user.
1385
- policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
1386
- A map of user ids to dictionaries mapping decision times to the policy number in use.
1387
- Only applies to in-study decision times!
1396
+ action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
1397
+ A dictionary mapping decision times to maps of subject ids to the function arguments
1398
+ required to compute action probabilities for each subject.
1399
+ policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
1400
+ A map of subject ids to dictionaries mapping decision times to the policy number in use.
1401
+ Only applies to active decision times!
1388
1402
  initial_policy_num (int | float):
1389
1403
  The policy number of the initial policy before any updates.
1390
1404
  beta_index_by_policy_num (dict[int | float, int]):
1391
1405
  A dictionary mapping policy numbers to the index of the corresponding beta in
1392
1406
  all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
1393
- inference_func_args_by_user_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
1394
- A dictionary mapping user IDs to their respective inference function arguments.
1395
- inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
1396
- For each user, a list of decision times to which action probabilities correspond if
1397
- provided. Typically just in-study times if action probabilites are used in the inference
1407
+ inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
1408
+ A dictionary mapping subject IDs to their respective inference function arguments.
1409
+ inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
1410
+ For each subject, a list of decision times to which action probabilities correspond if
1411
+ provided. Typically just active times if action probabilities are used in the inference
1398
1412
  loss or estimating function.
1399
- update_func_args_by_by_user_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
1400
- A dictionary where keys are policy numbers and values are dictionaries mapping user IDs
1413
+ update_func_args_by_by_subject_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
1414
+ A dictionary where keys are policy numbers and values are dictionaries mapping subject IDs
1401
1415
  to their respective update function arguments.
1402
- action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
1403
- A dictionary mapping user IDs to their respective actions taken at each decision time.
1404
- Only applies to in-study decision times!
1416
+ action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
1417
+ A dictionary mapping subject IDs to their respective actions taken at each decision time.
1418
+ Only applies to active decision times!
1405
1419
  suppress_all_data_checks (bool):
1406
1420
  If True, suppresses carrying out any data checks at all.
1407
1421
  suppress_interactive_data_checks (bool):
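The nested dictionary arguments documented above follow a consistent naming convention: reading the _by_ suffixes right to left gives the nesting order, with the rightmost suffix naming the outermost key. Illustrative literals only (subject IDs, times, and values are made up):

# subject_id -> decision_time -> policy number
policy_num_by_decision_time_by_subject_id = {
    "subject_1": {0: 0.0, 1: 1.0},
    "subject_2": {0: 0.0, 1: 1.0},
}
# subject_id -> decision_time -> action
action_by_decision_time_by_subject_id = {
    "subject_1": {0: 1, 1: 0},
    "subject_2": {0: 0, 1: 0},
}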
@@ -1411,27 +1425,27 @@ def construct_classical_and_adaptive_sandwiches(
1411
1425
  small_sample_correction (str):
1412
1426
  The type of small sample correction to apply. See SmallSampleCorrections class for
1413
1427
  options.
1414
- form_adaptive_meat_adjustments_explicitly (bool):
1415
- If True, explicitly forms the per-user meat adjustments that differentiate the adaptive
1428
+ form_adjusted_meat_adjustments_explicitly (bool):
1429
+ If True, explicitly forms the per-subject meat adjustments that differentiate the adaptive
1416
1430
  sandwich from the classical sandwich. This is for diagnostic purposes, as the
1417
1431
  adaptive sandwich itself is formed without these explicit adjustments.
1418
- stabilize_joint_adaptive_bread_inverse (bool):
1432
+ stabilize_joint_adjusted_bread_inverse (bool):
1419
1433
  If True, applies stabilization techniques (such as increasing diagonal block dominance or adding a small ridge penalty to the diagonal blocks) to the joint adaptive bread inverse if necessary.
1420
- study_df (pd.DataFrame):
1421
- The full study dataframe, needed if forming the adaptive meat adjustments explicitly.
1422
- in_study_col_name (str):
1423
- The name of the column in study_df indicating whether a user is in-study at a given decision time.
1434
+ analysis_df (pd.DataFrame):
1435
+ The full analysis dataframe, needed if forming the adaptive meat adjustments explicitly.
1436
+ active_col_name (str):
1437
+ The name of the column in analysis_df indicating whether a subject is active at a given decision time.
1424
1438
  action_col_name (str):
1425
- The name of the column in study_df indicating the action taken at a given decision time.
1439
+ The name of the column in analysis_df indicating the action taken at a given decision time.
1426
1440
  calendar_t_col_name (str):
1427
- The name of the column in study_df indicating the calendar time of a given decision time.
1428
- user_id_col_name (str):
1429
- The name of the column in study_df indicating the user ID.
1441
+ The name of the column in analysis_df indicating the calendar time of a given decision time.
1442
+ subject_id_col_name (str):
1443
+ The name of the column in analysis_df indicating the subject ID.
1430
1444
  action_prob_func_args (tuple):
1431
1445
  The arguments to be passed to the action probability function, needed if forming the
1432
1446
  adaptive meat adjustments explicitly.
1433
1447
  action_prob_col_name (str):
1434
- The name of the column in study_df indicating the action probability of the action taken,
1448
+ The name of the column in analysis_df indicating the action probability of the action taken,
1435
1449
  needed if forming the adaptive meat adjustments explicitly.
1436
1450
  Returns:
1437
1451
  tuple[jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32]]:
@@ -1444,10 +1458,10 @@ def construct_classical_and_adaptive_sandwiches(
1444
1458
  - The classical meat matrix.
1445
1459
  - The classical sandwich matrix.
1446
1460
  - The average weighted estimating function stack.
1447
- - All per-user weighted estimating function stacks.
1448
- - The per-user adaptive meat small-sample corrections.
1449
- - The per-user classical meat small-sample corrections.
1450
- - The per-user adaptive meat adjustments, if form_adaptive_meat_adjustments_explicitly
1461
+ - All per-subject weighted estimating function stacks.
1462
+ - The per-subject adaptive meat small-sample corrections.
1463
+ - The per-subject classical meat small-sample corrections.
1464
+ - The per-subject adaptive meat adjustments, if form_adjusted_meat_adjustments_explicitly
1451
1465
  is True, otherwise an array of NaNs.
1452
1466
  """
1453
1467
  logger.info(
@@ -1455,13 +1469,13 @@ def construct_classical_and_adaptive_sandwiches(
1455
1469
  )
1456
1470
  theta_dim = theta_est.shape[0]
1457
1471
  beta_dim = all_post_update_betas.shape[1]
1458
- # Note that these "contributions" are per-user Jacobians of the weighted estimating function stack.
1459
- raw_joint_adaptive_bread_inverse_matrix, (
1472
+ # Note that these "contributions" are per-subject Jacobians of the weighted estimating function stack.
1473
+ raw_joint_adjusted_bread_inverse_matrix, (
1460
1474
  avg_estimating_function_stack,
1461
- per_user_joint_adaptive_meat_contributions,
1462
- per_user_classical_meat_contributions,
1463
- per_user_classical_bread_inverse_contributions,
1464
- per_user_estimating_function_stacks,
1475
+ per_subject_joint_adjusted_meat_contributions,
1476
+ per_subject_classical_meat_contributions,
1477
+ per_subject_classical_bread_inverse_contributions,
1478
+ per_subject_estimating_function_stacks,
1465
1479
  ) = jax.jacrev(
1466
1480
  get_avg_weighted_estimating_function_stacks_and_aux_values, has_aux=True
1467
1481
  )(
@@ -1471,7 +1485,7 @@ def construct_classical_and_adaptive_sandwiches(
1471
1485
  flatten_params(all_post_update_betas, theta_est),
1472
1486
  beta_dim,
1473
1487
  theta_dim,
1474
- user_ids,
1488
+ subject_ids,
1475
1489
  action_prob_func,
1476
1490
  action_prob_func_args_beta_index,
1477
1491
  alg_update_func,
@@ -1484,87 +1498,87 @@ def construct_classical_and_adaptive_sandwiches(
1484
1498
  inference_func_type,
1485
1499
  inference_func_args_theta_index,
1486
1500
  inference_func_args_action_prob_index,
1487
- action_prob_func_args_by_user_id_by_decision_time,
1488
- policy_num_by_decision_time_by_user_id,
1501
+ action_prob_func_args_by_subject_id_by_decision_time,
1502
+ policy_num_by_decision_time_by_subject_id,
1489
1503
  initial_policy_num,
1490
1504
  beta_index_by_policy_num,
1491
- inference_func_args_by_user_id,
1492
- inference_action_prob_decision_times_by_user_id,
1493
- update_func_args_by_by_user_id_by_policy_num,
1494
- action_by_decision_time_by_user_id,
1505
+ inference_func_args_by_subject_id,
1506
+ inference_action_prob_decision_times_by_subject_id,
1507
+ update_func_args_by_by_subject_id_by_policy_num,
1508
+ action_by_decision_time_by_subject_id,
1495
1509
  suppress_all_data_checks,
1496
1510
  suppress_interactive_data_checks,
1497
1511
  )
1498
1512
 
1499
- num_users = len(user_ids)
1513
+ num_subjects = len(subject_ids)
1500
1514
 
1501
1515
  (
1502
- joint_adaptive_meat_matrix,
1516
+ joint_adjusted_meat_matrix,
1503
1517
  classical_meat_matrix,
1504
- per_user_adaptive_corrections,
1505
- per_user_classical_corrections,
1518
+ per_subject_adjusted_corrections,
1519
+ per_subject_classical_corrections,
1506
1520
  ) = perform_desired_small_sample_correction(
1507
1521
  small_sample_correction,
1508
- per_user_joint_adaptive_meat_contributions,
1509
- per_user_classical_meat_contributions,
1510
- per_user_classical_bread_inverse_contributions,
1511
- num_users,
1522
+ per_subject_joint_adjusted_meat_contributions,
1523
+ per_subject_classical_meat_contributions,
1524
+ per_subject_classical_bread_inverse_contributions,
1525
+ num_subjects,
1512
1526
  theta_dim,
1513
1527
  )
1514
1528
 
1515
1529
  # Increase diagonal block dominance and possibly improve the conditioning of diagonal
1516
1530
  # blocks as necessary, to ensure numerical stability of the joint bread inverse.
1517
- stabilized_joint_adaptive_bread_inverse_matrix = (
1531
+ stabilized_joint_adjusted_bread_inverse_matrix = (
1518
1532
  (
1519
- stabilize_joint_adaptive_bread_inverse_if_necessary(
1520
- raw_joint_adaptive_bread_inverse_matrix,
1533
+ stabilize_joint_adjusted_bread_inverse_if_necessary(
1534
+ raw_joint_adjusted_bread_inverse_matrix,
1521
1535
  beta_dim,
1522
1536
  theta_dim,
1523
1537
  )
1524
1538
  )
1525
- if stabilize_joint_adaptive_bread_inverse
1526
- else raw_joint_adaptive_bread_inverse_matrix
1539
+ if stabilize_joint_adjusted_bread_inverse
1540
+ else raw_joint_adjusted_bread_inverse_matrix
1527
1541
  )
1528
1542
 
1529
1543
  # Now stably (no explicit inversion) form our sandwiches.
1530
- joint_adaptive_sandwich = form_sandwich_from_bread_inverse_and_meat(
1531
- stabilized_joint_adaptive_bread_inverse_matrix,
1532
- joint_adaptive_meat_matrix,
1533
- num_users,
1544
+ joint_adjusted_sandwich = form_sandwich_from_bread_inverse_and_meat(
1545
+ stabilized_joint_adjusted_bread_inverse_matrix,
1546
+ joint_adjusted_meat_matrix,
1547
+ num_subjects,
1534
1548
  method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
1535
1549
  )
1536
1550
  classical_bread_inverse_matrix = jnp.mean(
1537
- per_user_classical_bread_inverse_contributions, axis=0
1551
+ per_subject_classical_bread_inverse_contributions, axis=0
1538
1552
  )
1539
1553
  classical_sandwich = form_sandwich_from_bread_inverse_and_meat(
1540
1554
  classical_bread_inverse_matrix,
1541
1555
  classical_meat_matrix,
1542
- num_users,
1556
+ num_subjects,
1543
1557
  method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
1544
1558
  )
1545
1559
 
1546
- per_user_adaptive_meat_adjustments = jnp.full(
1547
- (len(user_ids), theta_dim, theta_dim), jnp.nan
1560
+ per_subject_adjusted_meat_adjustments = jnp.full(
1561
+ (len(subject_ids), theta_dim, theta_dim), jnp.nan
1548
1562
  )
1549
- if form_adaptive_meat_adjustments_explicitly:
1550
- per_user_adjusted_classical_meat_contributions = (
1551
- form_adaptive_meat_adjustments_directly(
1563
+ if form_adjusted_meat_adjustments_explicitly:
1564
+ per_subject_adjusted_classical_meat_contributions = (
1565
+ form_adjusted_meat_adjustments_directly(
1552
1566
  theta_dim,
1553
1567
  all_post_update_betas.shape[1],
1554
- stabilized_joint_adaptive_bread_inverse_matrix,
1555
- per_user_estimating_function_stacks,
1556
- study_df,
1557
- in_study_col_name,
1568
+ stabilized_joint_adjusted_bread_inverse_matrix,
1569
+ per_subject_estimating_function_stacks,
1570
+ analysis_df,
1571
+ active_col_name,
1558
1572
  action_col_name,
1559
1573
  calendar_t_col_name,
1560
- user_id_col_name,
1574
+ subject_id_col_name,
1561
1575
  action_prob_func,
1562
1576
  action_prob_func_args,
1563
1577
  action_prob_func_args_beta_index,
1564
1578
  theta_est,
1565
1579
  inference_func,
1566
1580
  inference_func_args_theta_index,
1567
- user_ids,
1581
+ subject_ids,
1568
1582
  action_prob_col_name,
1569
1583
  )
1570
1584
  )
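perform_desired_small_sample_correction, used above, applies one of the options in SmallSampleCorrections to the per-subject meat contributions. As a rough illustration of the general shape of such a correction (a hypothetical HC1-style rescaling, not necessarily one of the package's options):

import jax.numpy as jnp

def hc1_style_meat_correction(per_subject_meat_contributions, theta_dim):
    # Hypothetical sketch: average the per-subject meat contributions,
    # then inflate by n / (n - p) to offset small-sample bias. The
    # package's actual options live in SmallSampleCorrections.
    num_subjects = per_subject_meat_contributions.shape[0]
    meat = jnp.mean(per_subject_meat_contributions, axis=0)
    return meat * num_subjects / (num_subjects - theta_dim)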
@@ -1574,30 +1588,30 @@ def construct_classical_and_adaptive_sandwiches(
1574
1588
  # First just apply any small-sample correction for parity.
1575
1589
  (
1576
1590
  _,
1577
- theta_only_adaptive_meat_matrix_v2,
1591
+ theta_only_adjusted_meat_matrix_v2,
1578
1592
  _,
1579
1593
  _,
1580
1594
  ) = perform_desired_small_sample_correction(
1581
1595
  small_sample_correction,
1582
- per_user_joint_adaptive_meat_contributions,
1583
- per_user_adjusted_classical_meat_contributions,
1584
- per_user_classical_bread_inverse_contributions,
1585
- num_users,
1596
+ per_subject_joint_adjusted_meat_contributions,
1597
+ per_subject_adjusted_classical_meat_contributions,
1598
+ per_subject_classical_bread_inverse_contributions,
1599
+ num_subjects,
1586
1600
  theta_dim,
1587
1601
  )
1588
- theta_only_adaptive_sandwich_from_adjustments = (
1602
+ theta_only_adjusted_sandwich_from_adjustments = (
1589
1603
  form_sandwich_from_bread_inverse_and_meat(
1590
1604
  classical_bread_inverse_matrix,
1591
- theta_only_adaptive_meat_matrix_v2,
1592
- num_users,
1605
+ theta_only_adjusted_meat_matrix_v2,
1606
+ num_subjects,
1593
1607
  method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
1594
1608
  )
1595
1609
  )
1596
- theta_only_adaptive_sandwich = joint_adaptive_sandwich[-theta_dim:, -theta_dim:]
1610
+ theta_only_adjusted_sandwich = joint_adjusted_sandwich[-theta_dim:, -theta_dim:]
1597
1611
 
1598
1612
  if not np.allclose(
1599
- theta_only_adaptive_sandwich,
1600
- theta_only_adaptive_sandwich_from_adjustments,
1613
+ theta_only_adjusted_sandwich,
1614
+ theta_only_adjusted_sandwich_from_adjustments,
1601
1615
  rtol=3e-2,
1602
1616
  ):
1603
1617
  logger.warning(
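The np.allclose guard above is a diagnostic parity check: the theta-only block of the joint adjusted sandwich should agree, up to a loose relative tolerance, with the sandwich rebuilt from the explicit adjustments, and disagreement is logged rather than raised. A standalone sketch of that pattern with hypothetical names:

import logging

import numpy as np

logger = logging.getLogger(__name__)

def warn_if_estimates_disagree(sandwich_a, sandwich_b, rtol=3e-2):
    # Diagnostic parity check: warn (rather than raise) when two
    # constructions of the same variance block disagree beyond rtol.
    if not np.allclose(sandwich_a, sandwich_b, rtol=rtol):
        max_abs_diff = np.max(np.abs(sandwich_a - sandwich_b))
        logger.warning(
            "Variance constructions disagree (max abs diff %.3g)", max_abs_diff
        )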
@@ -1607,26 +1621,26 @@ def construct_classical_and_adaptive_sandwiches(
1607
1621
  # Stack the joint adaptive inverse bread pieces together horizontally and return the auxiliary
1608
1622
  # values too. The joint adaptive bread inverse should always be block lower triangular.
1609
1623
  return (
1610
- raw_joint_adaptive_bread_inverse_matrix,
1611
- stabilized_joint_adaptive_bread_inverse_matrix,
1612
- joint_adaptive_meat_matrix,
1613
- joint_adaptive_sandwich,
1624
+ raw_joint_adjusted_bread_inverse_matrix,
1625
+ stabilized_joint_adjusted_bread_inverse_matrix,
1626
+ joint_adjusted_meat_matrix,
1627
+ joint_adjusted_sandwich,
1614
1628
  classical_bread_inverse_matrix,
1615
1629
  classical_meat_matrix,
1616
1630
  classical_sandwich,
1617
1631
  avg_estimating_function_stack,
1618
- per_user_estimating_function_stacks,
1619
- per_user_adaptive_corrections,
1620
- per_user_classical_corrections,
1621
- per_user_adaptive_meat_adjustments,
1632
+ per_subject_estimating_function_stacks,
1633
+ per_subject_adjusted_corrections,
1634
+ per_subject_classical_corrections,
1635
+ per_subject_adjusted_meat_adjustments,
1622
1636
  )
1623
1637
 
1624
1638
 
1625
1639
  # TODO: I think there should be interaction to confirm stabilization. It is
1626
- # important for the user to know if this is happening. Even if enabled, it is important
1627
- # that the user know it actually kicks in.
1628
- def stabilize_joint_adaptive_bread_inverse_if_necessary(
1629
- joint_adaptive_bread_inverse_matrix: jnp.ndarray,
1640
+ # important for the subject to know if this is happening. Even if enabled, it is important
1641
+ # that the subject know it actually kicks in.
1642
+ def stabilize_joint_adjusted_bread_inverse_if_necessary(
1643
+ joint_adjusted_bread_inverse_matrix: jnp.ndarray,
1630
1644
  beta_dim: int,
1631
1645
  theta_dim: int,
1632
1646
  ) -> jnp.ndarray:
@@ -1635,7 +1649,7 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
1635
1649
  dominance and/or adding a small ridge penalty to the diagonal blocks.
1636
1650
 
1637
1651
  Args:
1638
- joint_adaptive_bread_inverse_matrix (jnp.ndarray):
1652
+ joint_adjusted_bread_inverse_matrix (jnp.ndarray):
1639
1653
  A 2-D JAX NumPy array representing the joint adaptive bread inverse matrix.
1640
1654
  beta_dim (int):
1641
1655
  The dimension of each beta parameter.
@@ -1656,7 +1670,7 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
1656
1670
 
1657
1671
  # Grab just the RL block and convert it to a NumPy array for easier manipulation.
1658
1672
  RL_stack_beta_derivatives_block = np.array(
1659
- joint_adaptive_bread_inverse_matrix[:-theta_dim, :-theta_dim]
1673
+ joint_adjusted_bread_inverse_matrix[:-theta_dim, :-theta_dim]
1660
1674
  )
1661
1675
  num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
1662
1676
  for i in range(1, num_updates + 1):
@@ -1684,7 +1698,7 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
1684
1698
  RL_stack_beta_derivatives_block[
1685
1699
  diagonal_block_slice, diagonal_block_slice
1686
1700
  ] = diagonal_block + ridge_penalty * np.eye(beta_dim)
1687
- # TODO: Require user input here in interactive settings?
1701
+ # TODO: Require user input here in interactive settings?
1688
1702
  logger.info(
1689
1703
  "Added ridge penalty of %s to diagonal block for update %s to improve conditioning from %s to %s",
1690
1704
  ridge_penalty,
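The loop above conditions each diagonal block independently. A minimal sketch of the ridge step on a single block, with made-up threshold and penalty values (a hypothetical helper, not the package's implementation):

import numpy as np

def ridge_if_ill_conditioned(block, cond_threshold=1e8, ridge_penalty=1e-6):
    # Sketch: if the block's condition number is too large, add a small
    # multiple of the identity to its diagonal before relying on it in
    # the joint bread inverse. Threshold and penalty are illustrative.
    if np.linalg.cond(block) > cond_threshold:
        block = block + ridge_penalty * np.eye(block.shape[0])
    return block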
@@ -1775,11 +1789,11 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
1775
1789
  [
1776
1790
  [
1777
1791
  RL_stack_beta_derivatives_block,
1778
- joint_adaptive_bread_inverse_matrix[:-theta_dim, -theta_dim:],
1792
+ joint_adjusted_bread_inverse_matrix[:-theta_dim, -theta_dim:],
1779
1793
  ],
1780
1794
  [
1781
- joint_adaptive_bread_inverse_matrix[-theta_dim:, :-theta_dim],
1782
- joint_adaptive_bread_inverse_matrix[-theta_dim:, -theta_dim:],
1795
+ joint_adjusted_bread_inverse_matrix[-theta_dim:, :-theta_dim],
1796
+ joint_adjusted_bread_inverse_matrix[-theta_dim:, -theta_dim:],
1783
1797
  ],
1784
1798
  ]
1785
1799
  )
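The np.block call above reassembles the joint bread inverse from its four blocks, preserving the block lower-triangular layout noted earlier. A toy reassembly with invented dimensions:

import numpy as np

beta_block = 2.0 * np.eye(4)   # stand-in for the stabilized RL block
theta_rows = np.ones((2, 4))   # lower-left cross-derivative block
theta_block = np.eye(2)

# Block lower-triangular layout matching the structure noted above.
joint = np.block([
    [beta_block, np.zeros((4, 2))],
    [theta_rows, theta_block],
])
assert joint.shape == (6, 6)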
@@ -1788,7 +1802,7 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
1788
1802
  def form_sandwich_from_bread_inverse_and_meat(
1789
1803
  bread_inverse: jnp.ndarray,
1790
1804
  meat: jnp.ndarray,
1791
- num_users: int,
1805
+ num_subjects: int,
1792
1806
  method: str = SandwichFormationMethods.BREAD_INVERSE_T_QR,
1793
1807
  ) -> jnp.ndarray:
1794
1808
  """
@@ -1802,8 +1816,8 @@ def form_sandwich_from_bread_inverse_and_meat(
1802
1816
  A 2-D JAX NumPy array representing the bread inverse matrix.
1803
1817
  meat (jnp.ndarray):
1804
1818
  A 2-D JAX NumPy array representing the meat matrix.
1805
- num_users (int):
1806
- The number of users in the study, used to scale the sandwich appropriately.
1819
+ num_subjects (int):
1820
+ The number of subjects in the deployment, used to scale the sandwich appropriately.
1807
1821
  method (str):
1808
1822
  The method to use for forming the sandwich.
1809
1823
 
@@ -1829,7 +1843,7 @@ def form_sandwich_from_bread_inverse_and_meat(
1829
1843
  L, scipy.linalg.solve_triangular(L, meat.T, lower=True).T, lower=True
1830
1844
  )
1831
1845
 
1832
- return Q @ new_meat @ Q.T / num_users
1846
+ return Q @ new_meat @ Q.T / num_subjects
1833
1847
  elif method == SandwichFormationMethods.MEAT_SVD_SOLVE:
1834
1848
  # Factor the meat via SVD without any symmetrization or truncation.
1835
1849
  # For general (possibly slightly nonsymmetric) M, SVD gives M = U @ diag(s) @ Vh.
@@ -1843,14 +1857,14 @@ def form_sandwich_from_bread_inverse_and_meat(
1843
1857
  W_left = scipy.linalg.solve(bread_inverse, C_left)
1844
1858
  W_right = scipy.linalg.solve(bread_inverse, C_right)
1845
1859
 
1846
- # Return the exact sandwich: V = (B^{-1} C_left) (B^{-1} C_right)^T / num_users
1847
- return W_left @ W_right.T / num_users
1860
+ # Return the exact sandwich: V = (B^{-1} C_left) (B^{-1} C_right)^T / num_subjects
1861
+ return W_left @ W_right.T / num_subjects
1848
1862
 
1849
1863
  elif method == SandwichFormationMethods.NAIVE:
1850
1864
  # Simply invert the bread inverse and form the sandwich directly.
1851
1865
  # This is NOT numerically stable and is only included for comparison purposes.
1852
1866
  bread = np.linalg.inv(bread_inverse)
1853
- return bread @ meat @ meat.T / num_users
1867
+ return bread @ meat @ bread.T / num_subjects
1854
1868
 
1855
1869
  else:
1856
1870
  raise ValueError(
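All of the methods above share one goal: form bread @ meat @ bread.T / n without explicitly inverting the bread inverse (NAIVE is kept only for comparison). As a general illustration of the solve-based approach, a hypothetical helper that uses two linear solves in place of np.linalg.inv:

import numpy as np
import scipy.linalg

def sandwich_via_solves(bread_inverse, meat, num_subjects):
    # Sketch: compute V = B @ M @ B.T / n where B = inv(bread_inverse),
    # using two linear solves instead of forming B explicitly.
    bm = scipy.linalg.solve(bread_inverse, meat)     # B @ M
    v = scipy.linalg.solve(bread_inverse, bm.T).T    # (B @ (B M).T).T = B M B.T
    return v / num_subjects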