lifejacket 0.2.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,8 +24,8 @@ from .constants import (
     SandwichFormationMethods,
     SmallSampleCorrections,
 )
-from .form_adaptive_meat_adjustments_directly import (
-    form_adaptive_meat_adjustments_directly,
+from .form_adjusted_meat_adjustments_directly import (
+    form_adjusted_meat_adjustments_directly,
 )
 from . import input_checks
 from . import get_datum_for_blowup_supervised_learning
@@ -37,9 +37,9 @@ from .helper_functions import (
     calculate_beta_dim,
     collect_all_post_update_betas,
     construct_beta_index_by_policy_num_map,
-    extract_action_and_policy_by_decision_time_by_user_id,
+    extract_action_and_policy_by_decision_time_by_subject_id,
     flatten_params,
-    get_in_study_df_column,
+    get_active_df_column,
     get_min_time_by_policy_num,
     get_radon_nikodym_weight,
     load_function_from_same_named_file,
@@ -61,7 +61,7 @@ def cli():
 
 # TODO: Check all help strings for accuracy.
 # TODO: Deal with NA, -1, etc policy numbers
-# TODO: Make sure in study is never on for more than one stretch EDIT: unclear if
+# TODO: Make sure in deployment is never on for more than one stretch EDIT: unclear if
 # this will remain an invariant as we deal with more complicated data missingness
 # TODO: I think I'm agnostic to indexing of calendar times but should check because
 # otherwise need to add a check here to verify required format.
@@ -69,7 +69,7 @@ def cli():
 # Higher dimensional objects not supported. Not entirely sure what kind of "scalars" apply.
 @cli.command(name="analyze")
 @click.option(
-    "--study_df_pickle",
+    "--analysis_df_pickle",
     type=click.File("rb"),
     help="Pickled pandas dataframe in correct format (see contract/readme).",
     required=True,
@@ -83,7 +83,7 @@ def cli():
 @click.option(
     "--action_prob_func_args_pickle",
     type=click.File("rb"),
-    help="Pickled dictionary that contains the action probability function arguments for all decision times for all users.",
+    help="Pickled dictionary that contains the action probability function arguments for all decision times for all subjects.",
     required=True,
 )
 @click.option(
@@ -95,7 +95,7 @@ def cli():
 @click.option(
     "--alg_update_func_filename",
     type=click.Path(exists=True),
-    help="File that contains the per-user update function used to determine the algorithm parameters at each update and relevant imports. May be a loss or estimating function, specified in a separate argument. The filename without its extension will be assumed to match the function name.",
+    help="File that contains the per-subject update function used to determine the algorithm parameters at each update and relevant imports. May be a loss or estimating function, specified in a separate argument. The filename without its extension will be assumed to match the function name.",
     required=True,
 )
 @click.option(
@@ -107,7 +107,7 @@ def cli():
 @click.option(
     "--alg_update_func_args_pickle",
     type=click.File("rb"),
-    help="Pickled dictionary that contains the algorithm update function arguments for all update times for all users.",
+    help="Pickled dictionary that contains the algorithm update function arguments for all update times for all subjects.",
     required=True,
 )
 @click.option(
@@ -137,7 +137,7 @@ def cli():
 @click.option(
     "--inference_func_filename",
     type=click.Path(exists=True),
-    help="File that contains the per-user loss/estimating function used to determine the inference estimate and relevant imports. The filename without its extension will be assumed to match the function name.",
+    help="File that contains the per-subject loss/estimating function used to determine the inference estimate and relevant imports. The filename without its extension will be assumed to match the function name.",
     required=True,
 )
 @click.option(
@@ -155,56 +155,56 @@ def cli():
 @click.option(
     "--theta_calculation_func_filename",
     type=click.Path(exists=True),
-    help="Path to file that allows one to actually calculate a theta estimate given the study dataframe only. One must supply either this or a precomputed theta estimate. The filename without its extension will be assumed to match the function name.",
+    help="Path to file that allows one to actually calculate a theta estimate given the analysis dataframe only. One must supply either this or a precomputed theta estimate. The filename without its extension will be assumed to match the function name.",
     required=True,
 )
 @click.option(
-    "--in_study_col_name",
+    "--active_col_name",
     type=str,
     required=True,
-    help="Name of the binary column in the study dataframe that indicates whether a user is in the study.",
+    help="Name of the binary column in the analysis dataframe that indicates whether a subject is in the deployment.",
 )
 @click.option(
     "--action_col_name",
     type=str,
     required=True,
-    help="Name of the binary column in the study dataframe that indicates which action was taken.",
+    help="Name of the binary column in the analysis dataframe that indicates which action was taken.",
 )
 @click.option(
     "--policy_num_col_name",
     type=str,
     required=True,
-    help="Name of the column in the study dataframe that indicates the policy number in use.",
+    help="Name of the column in the analysis dataframe that indicates the policy number in use.",
 )
 @click.option(
     "--calendar_t_col_name",
     type=str,
     required=True,
-    help="Name of the column in the study dataframe that indicates calendar time (shared integer index across users).",
+    help="Name of the column in the analysis dataframe that indicates calendar time (shared integer index across subjects).",
 )
 @click.option(
-    "--user_id_col_name",
+    "--subject_id_col_name",
     type=str,
     required=True,
-    help="Name of the column in the study dataframe that indicates user id.",
+    help="Name of the column in the analysis dataframe that indicates subject id.",
 )
 @click.option(
     "--action_prob_col_name",
     type=str,
     required=True,
-    help="Name of the column in the study dataframe that gives action one probabilities.",
+    help="Name of the column in the analysis dataframe that gives action one probabilities.",
 )
 @click.option(
     "--reward_col_name",
     type=str,
     required=True,
-    help="Name of the column in the study dataframe that gives rewards.",
+    help="Name of the column in the analysis dataframe that gives rewards.",
 )
 @click.option(
     "--suppress_interactive_data_checks",
     type=bool,
     default=False,
-    help="Flag to suppress any data checks that require user input. This is suitable for tests and large simulations",
+    help="Flag to suppress any data checks that require subject input. This is suitable for tests and large simulations",
 )
 @click.option(
     "--suppress_all_data_checks",
@@ -232,13 +232,13 @@ def cli():
     help="Flag to collect data for supervised learning blowup detection. This will write a single datum and label to a file in the same directory as the input files.",
 )
 @click.option(
-    "--form_adaptive_meat_adjustments_explicitly",
+    "--form_adjusted_meat_adjustments_explicitly",
     type=bool,
     default=False,
-    help="If True, explicitly forms the per-user meat adjustments that differentiate the adaptive sandwich from the classical sandwich. This is for diagnostic purposes, as the adaptive sandwich is formed without doing this.",
+    help="If True, explicitly forms the per-subject meat adjustments that differentiate the adaptive sandwich from the classical sandwich. This is for diagnostic purposes, as the adaptive sandwich is formed without doing this.",
 )
 @click.option(
-    "--stabilize_joint_adaptive_bread_inverse",
+    "--stabilize_joint_adjusted_bread_inverse",
     type=bool,
     default=True,
     help="If True, stabilizes the joint adaptive bread inverse matrix if it does not meet conditioning thresholds.",
@@ -248,7 +248,7 @@ def analyze_dataset_wrapper(**kwargs):
     This function is a wrapper around analyze_dataset to facilitate command line use.
 
     From the command line, we will take pickles and filenames for Python objects.
-    Unpickle/load files here for passing to the implementation function, which
+    We unpickle/load files here for passing to the implementation function, which
     may also be called in its own right with in-memory objects.
 
     See analyze_dataset for the underlying details.
@@ -256,18 +256,20 @@ def analyze_dataset_wrapper(**kwargs):
     Returns: None
     """
 
-    # Pass along the folder the study dataframe is in as the output folder.
-    # Do it now because we will be removing the study dataframe pickle from kwargs.
-    kwargs["output_dir"] = pathlib.Path(kwargs["study_df_pickle"].name).parent.resolve()
+    # Pass along the folder the analysis dataframe is in as the output folder.
+    # Do it now because we will be removing the analysis dataframe pickle from kwargs.
+    kwargs["output_dir"] = pathlib.Path(
+        kwargs["analysis_df_pickle"].name
+    ).parent.resolve()
 
     # Unpickle pickles and replace those args in kwargs
-    kwargs["study_df"] = pickle.load(kwargs["study_df_pickle"])
+    kwargs["analysis_df"] = pickle.load(kwargs["analysis_df_pickle"])
     kwargs["action_prob_func_args"] = pickle.load(
         kwargs["action_prob_func_args_pickle"]
     )
     kwargs["alg_update_func_args"] = pickle.load(kwargs["alg_update_func_args_pickle"])
 
-    kwargs.pop("study_df_pickle")
+    kwargs.pop("analysis_df_pickle")
     kwargs.pop("action_prob_func_args_pickle")
     kwargs.pop("alg_update_func_args_pickle")
 
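For orientation, a minimal sketch of producing the pickles this wrapper unpickles. The column names and the (empty) argument dictionaries are placeholders; the real layout is defined by the package's contract/readme:

    import pickle
    import pandas as pd

    # Hypothetical analysis dataframe with columns named via the CLI options.
    analysis_df = pd.DataFrame(
        {
            "subject_id": [1, 1, 2, 2],
            "calendar_t": [1, 2, 1, 2],
            "active": [1, 1, 1, 0],
            "action": [0, 1, 1, 0],
            "policy_num": [0, 0, 0, 0],
            "action_prob": [0.5, 0.6, 0.5, 0.5],
            "reward": [0.2, 1.0, 0.0, 0.3],
        }
    )
    with open("analysis_df.pkl", "wb") as f:
        pickle.dump(analysis_df, f)

    # Per the help strings: args for all decision/update times, all subjects.
    action_prob_func_args = {}  # placeholder; see the contract for the real structure
    alg_update_func_args = {}   # placeholder
    with open("action_prob_func_args.pkl", "wb") as f:
        pickle.dump(action_prob_func_args, f)
    with open("alg_update_func_args.pkl", "wb") as f:
        pickle.dump(alg_update_func_args, f)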
@@ -295,7 +297,7 @@ def analyze_dataset_wrapper(**kwargs):
 
 def analyze_dataset(
     output_dir: pathlib.Path | str,
-    study_df: pd.DataFrame,
+    analysis_df: pd.DataFrame,
     action_prob_func: Callable,
     action_prob_func_args: dict[int, Any],
     action_prob_func_args_beta_index: int,
@@ -310,19 +312,19 @@ def analyze_dataset(
     inference_func_type: str,
     inference_func_args_theta_index: int,
     theta_calculation_func: Callable[[pd.DataFrame], jnp.ndarray],
-    in_study_col_name: str,
+    active_col_name: str,
     action_col_name: str,
     policy_num_col_name: str,
     calendar_t_col_name: str,
-    user_id_col_name: str,
+    subject_id_col_name: str,
     action_prob_col_name: str,
     reward_col_name: str,
     suppress_interactive_data_checks: bool,
     suppress_all_data_checks: bool,
     small_sample_correction: str,
     collect_data_for_blowup_supervised_learning: bool,
-    form_adaptive_meat_adjustments_explicitly: bool,
-    stabilize_joint_adaptive_bread_inverse: bool,
+    form_adjusted_meat_adjustments_explicitly: bool,
+    stabilize_joint_adjusted_bread_inverse: bool,
 ) -> None:
     """
     Analyzes a dataset to provide a parameter estimate and an estimate of its variance using adaptive and classical sandwich estimators.
@@ -337,8 +339,8 @@ def analyze_dataset(
     Parameters:
         output_dir (pathlib.Path | str):
             Directory in which to save output files.
-        study_df (pd.DataFrame):
-            DataFrame containing the study data.
+        analysis_df (pd.DataFrame):
+            DataFrame containing the deployment data.
         action_prob_func (callable):
             Action probability function.
         action_prob_func_args (dict[int, Any]):
@@ -364,21 +366,21 @@ def analyze_dataset(
         inference_func_args_theta_index (int):
             Index for theta in inference function arguments.
         theta_calculation_func (callable):
-            Function to estimate theta from the study DataFrame.
-        in_study_col_name (str):
-            Column name indicating if a user is in the study in the study dataframe.
+            Function to estimate theta from the analysis dataframe.
+        active_col_name (str):
+            Column name indicating if a subject is active in the analysis dataframe.
         action_col_name (str):
-            Column name for actions in the study dataframe.
+            Column name for actions in the analysis dataframe.
         policy_num_col_name (str):
-            Column name for policy numbers in the study dataframe.
+            Column name for policy numbers in the analysis dataframe.
         calendar_t_col_name (str):
-            Column name for calendar time in the study dataframe.
-        user_id_col_name (str):
-            Column name for user IDs in the study dataframe.
+            Column name for calendar time in the analysis dataframe.
+        subject_id_col_name (str):
+            Column name for subject IDs in the analysis dataframe.
         action_prob_col_name (str):
-            Column name for action probabilities in the study dataframe.
+            Column name for action probabilities in the analysis dataframe.
         reward_col_name (str):
-            Column name for rewards in the study dataframe.
+            Column name for rewards in the analysis dataframe.
         suppress_interactive_data_checks (bool):
             Whether to suppress interactive data checks. This should be used in simulations, for example.
         suppress_all_data_checks (bool):
@@ -387,11 +389,11 @@ def analyze_dataset(
             Type of small sample correction to apply.
         collect_data_for_blowup_supervised_learning (bool):
             Whether to collect data for doing supervised learning about adaptive sandwich blowup.
-        form_adaptive_meat_adjustments_explicitly (bool):
-            If True, explicitly forms the per-user meat adjustments that differentiate the adaptive
+        form_adjusted_meat_adjustments_explicitly (bool):
+            If True, explicitly forms the per-subject meat adjustments that differentiate the adaptive
             sandwich from the classical sandwich. This is for diagnostic purposes, as the
             adaptive sandwich is formed without doing this.
-        stabilize_joint_adaptive_bread_inverse (bool):
+        stabilize_joint_adjusted_bread_inverse (bool):
             If True, stabilizes the joint adaptive bread inverse matrix if it does not meet conditioning
             thresholds.
 
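Because this release renames keyword arguments wholesale, callers migrating from 0.2.1 may find a compact mapping useful. The sketch below is not part of the package; it simply collects the renames visible in this diff:

    # Renames from 0.2.1 to 1.0.0, as read off this diff.
    RENAMES_0_2_1_TO_1_0_0 = {
        "study_df": "analysis_df",
        "study_df_pickle": "analysis_df_pickle",
        "in_study_col_name": "active_col_name",
        "user_id_col_name": "subject_id_col_name",
        "form_adaptive_meat_adjustments_explicitly": "form_adjusted_meat_adjustments_explicitly",
        "stabilize_joint_adaptive_bread_inverse": "stabilize_joint_adjusted_bread_inverse",
    }

    def migrate_kwargs(old_kwargs: dict) -> dict:
        """Rename 0.2.1-style analyze_dataset keyword arguments to 1.0.0 names."""
        return {RENAMES_0_2_1_TO_1_0_0.get(key, key): value for key, value in old_kwargs.items()}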
@@ -406,19 +408,19 @@ def analyze_dataset(
         level=logging.INFO,
     )
 
-    theta_est = jnp.array(theta_calculation_func(study_df))
+    theta_est = jnp.array(theta_calculation_func(analysis_df))
 
     beta_dim = calculate_beta_dim(
         action_prob_func_args, action_prob_func_args_beta_index
     )
     if not suppress_all_data_checks:
         input_checks.perform_first_wave_input_checks(
-            study_df,
-            in_study_col_name,
+            analysis_df,
+            active_col_name,
             action_col_name,
             policy_num_col_name,
             calendar_t_col_name,
-            user_id_col_name,
+            subject_id_col_name,
             action_prob_col_name,
             reward_col_name,
             action_prob_func,
@@ -439,7 +441,7 @@ def analyze_dataset(
 
     beta_index_by_policy_num, initial_policy_num = (
         construct_beta_index_by_policy_num_map(
-            study_df, policy_num_col_name, in_study_col_name
+            analysis_df, policy_num_col_name, active_col_name
         )
     )
 
@@ -447,11 +449,11 @@ def analyze_dataset(
         beta_index_by_policy_num, alg_update_func_args, alg_update_func_args_beta_index
     )
 
-    action_by_decision_time_by_user_id, policy_num_by_decision_time_by_user_id = (
-        extract_action_and_policy_by_decision_time_by_user_id(
-            study_df,
-            user_id_col_name,
-            in_study_col_name,
+    action_by_decision_time_by_subject_id, policy_num_by_decision_time_by_subject_id = (
+        extract_action_and_policy_by_decision_time_by_subject_id(
+            analysis_df,
+            subject_id_col_name,
+            active_col_name,
             calendar_t_col_name,
             action_col_name,
             policy_num_col_name,
@@ -459,45 +461,45 @@ def analyze_dataset(
     )
 
     (
-        inference_func_args_by_user_id,
+        inference_func_args_by_subject_id,
         inference_func_args_action_prob_index,
-        inference_action_prob_decision_times_by_user_id,
+        inference_action_prob_decision_times_by_subject_id,
     ) = process_inference_func_args(
         inference_func,
         inference_func_args_theta_index,
-        study_df,
+        analysis_df,
         theta_est,
         action_prob_col_name,
         calendar_t_col_name,
-        user_id_col_name,
-        in_study_col_name,
+        subject_id_col_name,
+        active_col_name,
     )
 
-    # Use a per-user weighted estimating function stacking function to derive classical and joint
+    # Use a per-subject weighted estimating function stacking function to derive classical and joint
     # adaptive meat and inverse bread matrices. This is facilitated because the *values* of the
     # weighted and unweighted stacks are the same, as the weights evaluate to 1 pre-differentiation.
     logger.info(
-        "Constructing joint adaptive bread inverse matrix, joint adaptive meat matrix, the classical analogs, and the avg estimating function stack across users."
+        "Constructing joint adaptive bread inverse matrix, joint adaptive meat matrix, the classical analogs, and the avg estimating function stack across subjects."
     )
 
-    user_ids = jnp.array(study_df[user_id_col_name].unique())
+    subject_ids = jnp.array(analysis_df[subject_id_col_name].unique())
     (
-        stabilized_joint_adaptive_bread_inverse_matrix,
-        raw_joint_adaptive_bread_inverse_matrix,
-        joint_adaptive_meat_matrix,
-        joint_adaptive_sandwich_matrix,
+        stabilized_joint_adjusted_bread_inverse_matrix,
+        raw_joint_adjusted_bread_inverse_matrix,
+        joint_adjusted_meat_matrix,
+        joint_adjusted_sandwich_matrix,
         classical_bread_inverse_matrix,
         classical_meat_matrix,
         classical_sandwich_var_estimate,
         avg_estimating_function_stack,
-        per_user_estimating_function_stacks,
-        per_user_adaptive_corrections,
-        per_user_classical_corrections,
-        per_user_adaptive_meat_adjustments,
-    ) = construct_classical_and_adaptive_sandwiches(
+        per_subject_estimating_function_stacks,
+        per_subject_adjusted_corrections,
+        per_subject_classical_corrections,
+        per_subject_adjusted_meat_adjustments,
+    ) = construct_classical_and_adjusted_sandwiches(
         theta_est,
         all_post_update_betas,
-        user_ids,
+        subject_ids,
         action_prob_func,
         action_prob_func_args_beta_index,
         alg_update_func,
@@ -511,23 +513,23 @@ def analyze_dataset(
         inference_func_args_theta_index,
         inference_func_args_action_prob_index,
         action_prob_func_args,
-        policy_num_by_decision_time_by_user_id,
+        policy_num_by_decision_time_by_subject_id,
         initial_policy_num,
         beta_index_by_policy_num,
-        inference_func_args_by_user_id,
-        inference_action_prob_decision_times_by_user_id,
+        inference_func_args_by_subject_id,
+        inference_action_prob_decision_times_by_subject_id,
         alg_update_func_args,
-        action_by_decision_time_by_user_id,
+        action_by_decision_time_by_subject_id,
         suppress_all_data_checks,
         suppress_interactive_data_checks,
         small_sample_correction,
-        form_adaptive_meat_adjustments_explicitly,
-        stabilize_joint_adaptive_bread_inverse,
-        study_df,
-        in_study_col_name,
+        form_adjusted_meat_adjustments_explicitly,
+        stabilize_joint_adjusted_bread_inverse,
+        analysis_df,
+        active_col_name,
         action_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         action_prob_func_args,
         action_prob_col_name,
     )
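The bread/meat naming above follows the standard sandwich-estimator decomposition. As background (this is not the package's internal code, and its exact scaling conventions are not visible in this diff), the pieces combine roughly as:

    import jax.numpy as jnp

    def sandwich(bread_inverse: jnp.ndarray, meat: jnp.ndarray, n: int) -> jnp.ndarray:
        # With B the average derivative of the estimating function stack (whose
        # inverse is the "bread") and M the average outer product of per-subject
        # stacks (the "meat"), the sandwich variance is B^{-1} M B^{-T} / n.
        bread = jnp.linalg.inv(bread_inverse)
        return bread @ meat @ bread.T / n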
@@ -543,18 +545,18 @@ def analyze_dataset(
 
     # This bottom right corner of the joint (betas and theta) variance matrix is the portion
     # corresponding to just theta.
-    adaptive_sandwich_var_estimate = joint_adaptive_sandwich_matrix[
+    adjusted_sandwich_var_estimate = joint_adjusted_sandwich_matrix[
         -theta_dim:, -theta_dim:
     ]
 
     # Check for negative diagonal elements and set them to zero if found
-    adaptive_diagonal = np.diag(adaptive_sandwich_var_estimate)
+    adaptive_diagonal = np.diag(adjusted_sandwich_var_estimate)
     if np.any(adaptive_diagonal < 0):
         logger.warning(
             "Found negative diagonal elements in adaptive sandwich variance estimate. Setting them to zero."
         )
         np.fill_diagonal(
-            adaptive_sandwich_var_estimate, np.maximum(adaptive_diagonal, 0)
+            adjusted_sandwich_var_estimate, np.maximum(adaptive_diagonal, 0)
         )
 
     logger.info("Writing results to file...")
@@ -563,7 +565,7 @@ def analyze_dataset(
 
     analysis_dict = {
         "theta_est": theta_est,
-        "adaptive_sandwich_var_estimate": adaptive_sandwich_var_estimate,
+        "adjusted_sandwich_var_estimate": adjusted_sandwich_var_estimate,
         "classical_sandwich_var_estimate": classical_sandwich_var_estimate,
     }
     with open(output_folder_abs_path / "analysis.pkl", "wb") as f:
@@ -572,29 +574,29 @@ def analyze_dataset(
             f,
         )
 
-    joint_adaptive_bread_inverse_cond = jnp.linalg.cond(
-        raw_joint_adaptive_bread_inverse_matrix
+    joint_adjusted_bread_inverse_cond = jnp.linalg.cond(
+        raw_joint_adjusted_bread_inverse_matrix
     )
     logger.info(
-        "Joint adaptive bread inverse condition number: %f",
-        joint_adaptive_bread_inverse_cond,
+        "Joint adjusted bread inverse condition number: %f",
+        joint_adjusted_bread_inverse_cond,
     )
 
     debug_pieces_dict = {
         "theta_est": theta_est,
-        "adaptive_sandwich_var_estimate": adaptive_sandwich_var_estimate,
+        "adjusted_sandwich_var_estimate": adjusted_sandwich_var_estimate,
         "classical_sandwich_var_estimate": classical_sandwich_var_estimate,
-        "raw_joint_bread_inverse_matrix": raw_joint_adaptive_bread_inverse_matrix,
-        "stabilized_joint_bread_inverse_matrix": stabilized_joint_adaptive_bread_inverse_matrix,
-        "joint_meat_matrix": joint_adaptive_meat_matrix,
+        "raw_joint_bread_inverse_matrix": raw_joint_adjusted_bread_inverse_matrix,
+        "stabilized_joint_bread_inverse_matrix": stabilized_joint_adjusted_bread_inverse_matrix,
+        "joint_meat_matrix": joint_adjusted_meat_matrix,
         "classical_bread_inverse_matrix": classical_bread_inverse_matrix,
         "classical_meat_matrix": classical_meat_matrix,
-        "all_estimating_function_stacks": per_user_estimating_function_stacks,
-        "joint_bread_inverse_condition_number": joint_adaptive_bread_inverse_cond,
+        "all_estimating_function_stacks": per_subject_estimating_function_stacks,
+        "joint_bread_inverse_condition_number": joint_adjusted_bread_inverse_cond,
         "all_post_update_betas": all_post_update_betas,
-        "per_user_adaptive_corrections": per_user_adaptive_corrections,
-        "per_user_classical_corrections": per_user_classical_corrections,
-        "per_user_adaptive_meat_adjustments": per_user_adaptive_meat_adjustments,
+        "per_subject_adjusted_corrections": per_subject_adjusted_corrections,
+        "per_subject_classical_corrections": per_subject_classical_corrections,
+        "per_subject_adjusted_meat_adjustments": per_subject_adjusted_meat_adjustments,
    }
     with open(output_folder_abs_path / "debug_pieces.pkl", "wb") as f:
         pickle.dump(
@@ -604,25 +606,25 @@ def analyze_dataset(
 
     if collect_data_for_blowup_supervised_learning:
         datum_and_label_dict = get_datum_for_blowup_supervised_learning.get_datum_for_blowup_supervised_learning(
-            raw_joint_adaptive_bread_inverse_matrix,
-            joint_adaptive_bread_inverse_cond,
+            raw_joint_adjusted_bread_inverse_matrix,
+            joint_adjusted_bread_inverse_cond,
             avg_estimating_function_stack,
-            per_user_estimating_function_stacks,
+            per_subject_estimating_function_stacks,
             all_post_update_betas,
-            study_df,
-            in_study_col_name,
+            analysis_df,
+            active_col_name,
             calendar_t_col_name,
             action_prob_col_name,
-            user_id_col_name,
+            subject_id_col_name,
             reward_col_name,
             theta_est,
-            adaptive_sandwich_var_estimate,
-            user_ids,
+            adjusted_sandwich_var_estimate,
+            subject_ids,
             beta_dim,
             theta_dim,
             initial_policy_num,
             beta_index_by_policy_num,
-            policy_num_by_decision_time_by_user_id,
+            policy_num_by_decision_time_by_subject_id,
             theta_calculation_func,
             action_prob_func,
             action_prob_func_args_beta_index,
@@ -630,16 +632,16 @@ def analyze_dataset(
             inference_func_type,
             inference_func_args_theta_index,
             inference_func_args_action_prob_index,
-            inference_action_prob_decision_times_by_user_id,
+            inference_action_prob_decision_times_by_subject_id,
             action_prob_func_args,
-            action_by_decision_time_by_user_id,
+            action_by_decision_time_by_subject_id,
         )
 
         with open(output_folder_abs_path / "supervised_learning_datum.pkl", "wb") as f:
             pickle.dump(datum_and_label_dict, f)
 
     print(f"\nParameter estimate:\n {theta_est}")
-    print(f"\nAdaptive sandwich variance estimate:\n {adaptive_sandwich_var_estimate}")
+    print(f"\nAdjusted sandwich variance estimate:\n {adjusted_sandwich_var_estimate}")
     print(
         f"\nClassical sandwich variance estimate:\n {classical_sandwich_var_estimate}\n"
     )
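On the consumer side, the analysis.pkl written above can be read back with nothing but pickle; the keys are exactly those of `analysis_dict` (the directory path below is a placeholder):

    import pathlib
    import pickle

    output_dir = pathlib.Path(".")  # placeholder: the folder holding the input dataframe
    with open(output_dir / "analysis.pkl", "rb") as f:
        analysis = pickle.load(f)

    theta_est = analysis["theta_est"]
    adjusted_var = analysis["adjusted_sandwich_var_estimate"]
    classical_var = analysis["classical_sandwich_var_estimate"]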
@@ -650,15 +652,15 @@ def analyze_dataset(
 def process_inference_func_args(
     inference_func: callable,
     inference_func_args_theta_index: int,
-    study_df: pd.DataFrame,
+    analysis_df: pd.DataFrame,
     theta_est: jnp.ndarray,
     action_prob_col_name: str,
     calendar_t_col_name: str,
-    user_id_col_name: str,
-    in_study_col_name: str,
+    subject_id_col_name: str,
+    active_col_name: str,
 ) -> tuple[dict[collections.abc.Hashable, tuple[Any, ...]], int]:
     """
-    Collects the inference function arguments for each user from the study DataFrame.
+    Collects the inference function arguments for each subject from the analysis DataFrame.
 
     Note that theta and action probabilities, if present, will be replaced later
     so that the function can be differentiated with respect to shared versions
@@ -669,32 +671,32 @@ def process_inference_func_args(
             The inference function to be used.
         inference_func_args_theta_index (int):
             The index of the theta parameter in the inference function's arguments.
-        study_df (pandas.DataFrame):
-            The study DataFrame.
+        analysis_df (pandas.DataFrame):
+            The analysis DataFrame.
         theta_est (jnp.ndarray):
             The estimate of the parameter vector.
         action_prob_col_name (str):
-            The name of the column in the study DataFrame that gives action probabilities.
+            The name of the column in the analysis DataFrame that gives action probabilities.
         calendar_t_col_name (str):
-            The name of the column in the study DataFrame that indicates calendar time.
-        user_id_col_name (str):
-            The name of the column in the study DataFrame that indicates user ID.
-        in_study_col_name (str):
-            The name of the binary column in the study DataFrame that indicates whether a user is in the study.
+            The name of the column in the analysis DataFrame that indicates calendar time.
+        subject_id_col_name (str):
+            The name of the column in the analysis DataFrame that indicates subject ID.
+        active_col_name (str):
+            The name of the binary column in the analysis DataFrame that indicates whether a subject is in the deployment.
     Returns:
         tuple[dict[collections.abc.Hashable, tuple[Any, ...]], int, dict[collections.abc.Hashable, jnp.ndarray[int]]]:
             A tuple containing
-            - the inference function arguments dictionary for each user
+            - the inference function arguments dictionary for each subject
             - the index of the action probabilities argument
-            - a dictionary mapping user IDs to the decision times to which action probabilities correspond
+            - a dictionary mapping subject IDs to the decision times to which action probabilities correspond
     """
 
     num_args = inference_func.__code__.co_argcount
     inference_func_arg_names = inference_func.__code__.co_varnames[:num_args]
-    inference_func_args_by_user_id = {}
+    inference_func_args_by_subject_id = {}
 
     inference_func_args_action_prob_index = -1
-    inference_action_prob_decision_times_by_user_id = {}
+    inference_action_prob_decision_times_by_subject_id = {}
 
     using_action_probs = action_prob_col_name in inference_func_arg_names
     if using_action_probs:
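The argument matching in this function relies on standard CPython code-object introspection: the first `co_argcount` entries of `co_varnames` are the function's positional parameter names, which are then matched against dataframe column names. A standalone illustration with a hypothetical inference function:

    def example_inference_func(theta, reward, action_prob):
        ...  # hypothetical; any per-subject loss/estimating function shape works

    num_args = example_inference_func.__code__.co_argcount
    arg_names = example_inference_func.__code__.co_varnames[:num_args]
    assert arg_names == ("theta", "reward", "action_prob")
    # Every non-theta name is expected to match a column in the analysis dataframe.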
@@ -702,34 +704,36 @@ def process_inference_func_args(
             action_prob_col_name
         )
 
-    for user_id in study_df[user_id_col_name].unique():
-        user_args_list = []
-        filtered_user_data = study_df.loc[study_df[user_id_col_name] == user_id]
+    for subject_id in analysis_df[subject_id_col_name].unique():
+        subject_args_list = []
+        filtered_subject_data = analysis_df.loc[
+            analysis_df[subject_id_col_name] == subject_id
+        ]
         for idx, col_name in enumerate(inference_func_arg_names):
             if idx == inference_func_args_theta_index:
-                user_args_list.append(theta_est)
+                subject_args_list.append(theta_est)
                 continue
-            user_args_list.append(
-                get_in_study_df_column(filtered_user_data, col_name, in_study_col_name)
+            subject_args_list.append(
+                get_active_df_column(filtered_subject_data, col_name, active_col_name)
             )
-        inference_func_args_by_user_id[user_id] = tuple(user_args_list)
+        inference_func_args_by_subject_id[subject_id] = tuple(subject_args_list)
         if using_action_probs:
-            inference_action_prob_decision_times_by_user_id[user_id] = (
-                get_in_study_df_column(
-                    filtered_user_data, calendar_t_col_name, in_study_col_name
+            inference_action_prob_decision_times_by_subject_id[subject_id] = (
+                get_active_df_column(
+                    filtered_subject_data, calendar_t_col_name, active_col_name
                 )
             )
 
     return (
-        inference_func_args_by_user_id,
+        inference_func_args_by_subject_id,
         inference_func_args_action_prob_index,
-        inference_action_prob_decision_times_by_user_id,
+        inference_action_prob_decision_times_by_subject_id,
     )
 
 
-def single_user_weighted_estimating_function_stacker(
+def single_subject_weighted_estimating_function_stacker(
     beta_dim: int,
-    user_id: collections.abc.Hashable,
+    subject_id: collections.abc.Hashable,
     action_prob_func: callable,
     algorithm_estimating_func: callable,
     inference_estimating_func: callable,
@@ -763,12 +767,12 @@ def single_user_weighted_estimating_function_stacker(
         beta_dim (list[jnp.ndarray]):
             A list of 1D JAX NumPy arrays corresponding to the betas produced by all updates.
 
-        user_id (collections.abc.Hashable):
-            The user ID for which to compute the weighted estimating function stack.
+        subject_id (collections.abc.Hashable):
+            The subject ID for which to compute the weighted estimating function stack.
 
         action_prob_func (callable):
             The function used to compute the probability of action 1 at a given decision time for
-            a particular user given their state and the algorithm parameters.
+            a particular subject given their state and the algorithm parameters.
 
         algorithm_estimating_func (callable):
             The estimating function that corresponds to algorithm updates.
@@ -783,9 +787,9 @@ def single_user_weighted_estimating_function_stacker(
             The index of the theta parameter in the inference loss or estimating function arguments.
 
         action_prob_func_args_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
-            A map from decision times to tuples of arguments for this user for the action
+            A map from decision times to tuples of arguments for this subject for the action
             probability function. This is for all decision times (args are an empty
-            tuple if they are not in the study). Should be sorted by decision time. NOTE THAT THESE
+            tuple if they are not in the deployment). Should be sorted by decision time. NOTE THAT THESE
             ARGS DO NOT CONTAIN THE SHARED BETAS, making them impervious to the differentiation that
             will occur.
 
@@ -796,21 +800,21 @@ def single_user_weighted_estimating_function_stacker(
         threaded_update_func_args_by_policy_num (dict[int | float, dict[collections.abc.Hashable, tuple[Any, ...]]]):
             A map from policy numbers to tuples containing the arguments for
-            the corresponding estimating functions for this user, with the shared betas threaded in
+            the corresponding estimating functions for this subject, with the shared betas threaded in
             for differentiation. This is for all non-initial, non-fallback policies. Policy numbers
             should be sorted.
 
         threaded_inference_func_args (dict[collections.abc.Hashable, tuple[Any, ...]]):
             A tuple containing the arguments for the inference
-            estimating function for this user, with the shared betas threaded in for differentiation.
+            estimating function for this subject, with the shared betas threaded in for differentiation.
 
         policy_num_by_decision_time (dict[collections.abc.Hashable, dict[int, int | float]]):
             A dictionary mapping decision times to the policy number in use. This may be
-            user-specific. Should be sorted by decision time. Only applies to in-study decision
+            subject-specific. Should be sorted by decision time. Only applies to active decision
             times!
 
         action_by_decision_time (dict[collections.abc.Hashable, dict[int, int]]):
-            A dictionary mapping decision times to actions taken. Only applies to in-study decision
+            A dictionary mapping decision times to actions taken. Only applies to active decision
             times!
 
         beta_index_by_policy_num (dict[int | float, int]):
@@ -818,19 +822,21 @@ def single_user_weighted_estimating_function_stacker(
             all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
 
     Returns:
-        jnp.ndarray: A 1-D JAX NumPy array representing the user's weighted estimating function
+        jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted estimating function
             stack.
-        jnp.ndarray: A 2-D JAX NumPy matrix representing the user's adaptive meat contribution.
-        jnp.ndarray: A 2-D JAX NumPy matrix representing the user's classical meat contribution.
-        jnp.ndarray: A 2-D JAX NumPy matrix representing the user's classical bread contribution.
+        jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's adaptive meat contribution.
+        jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical meat contribution.
+        jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical bread contribution.
     """
 
-    logger.info("Computing weighted estimating function stack for user %s.", user_id)
+    logger.info(
+        "Computing weighted estimating function stack for subject %s.", subject_id
+    )
 
     # First, reformat the supplied data into more convenient structures.
 
     # 1. Form a dictionary mapping policy numbers to the first time they were
-    # applicable (for this user). Note that this includes ALL policies, initial
+    # applicable (for this subject). Note that this includes ALL policies, initial
     # fallbacks included.
     # Collect the first time after the first update separately for convenience.
     # These are both used to form the Radon-Nikodym weights for the right times.
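The per-subject mappings described in the docstring above look like this in an illustrative (entirely hypothetical) instance; the integer keys are decision times:

    policy_num_by_decision_time = {3: 0, 4: 0, 5: 1, 6: 2}  # active times -> policy number in use
    action_by_decision_time = {3: 1, 4: 0, 5: 1, 6: 0}      # active times -> binary action taken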
@@ -839,38 +845,38 @@ def single_user_weighted_estimating_function_stacker(
         beta_index_by_policy_num,
     )
 
-    # 2. Get the start and end times for this user.
-    user_start_time = math.inf
-    user_end_time = -math.inf
+    # 2. Get the start and end times for this subject.
+    subject_start_time = math.inf
+    subject_end_time = -math.inf
     for decision_time in action_by_decision_time:
-        user_start_time = min(user_start_time, decision_time)
-        user_end_time = max(user_end_time, decision_time)
+        subject_start_time = min(subject_start_time, decision_time)
+        subject_end_time = max(subject_end_time, decision_time)
 
     # 3. Form a stack of weighted estimating equations, one for each update of the algorithm.
     logger.info(
-        "Computing the algorithm component of the weighted estimating function stack for user %s.",
-        user_id,
+        "Computing the algorithm component of the weighted estimating function stack for subject %s.",
+        subject_id,
     )
 
-    in_study_action_prob_func_args = [
+    active_action_prob_func_args = [
         args for args in action_prob_func_args_by_decision_time.values() if args
     ]
-    in_study_betas_list_by_decision_time_index = jnp.array(
+    active_betas_list_by_decision_time_index = jnp.array(
         [
             action_prob_func_args[action_prob_func_args_beta_index]
-            for action_prob_func_args in in_study_action_prob_func_args
+            for action_prob_func_args in active_action_prob_func_args
         ]
     )
-    in_study_actions_list_by_decision_time_index = jnp.array(
+    active_actions_list_by_decision_time_index = jnp.array(
         list(action_by_decision_time.values())
     )
 
     # Sort the threaded args by decision time to be cautious. We check if the
-    # user id is present in the user args dict because we may call this on a
-    # subset of the user arg dict when we are batching arguments by shape
+    # subject id is present in the subject args dict because we may call this on a
+    # subset of the subject arg dict when we are batching arguments by shape
     sorted_threaded_action_prob_args_by_decision_time = {
         decision_time: threaded_action_prob_func_args_by_decision_time[decision_time]
-        for decision_time in range(user_start_time, user_end_time + 1)
+        for decision_time in range(subject_start_time, subject_end_time + 1)
         if decision_time in threaded_action_prob_func_args_by_decision_time
     }
 
@@ -901,19 +907,19 @@ def single_user_weighted_estimating_function_stacker(
     # Just grab the original beta from the update function arguments. This is the same
     # value, but impervious to differentiation with respect to all_post_update_betas. The
     # args, on the other hand, are a function of all_post_update_betas.
-    in_study_weights = jax.vmap(
+    active_weights = jax.vmap(
         fun=get_radon_nikodym_weight,
         in_axes=[0, None, None, 0] + batch_axes,
         out_axes=0,
     )(
-        in_study_betas_list_by_decision_time_index,
+        active_betas_list_by_decision_time_index,
         action_prob_func,
         action_prob_func_args_beta_index,
-        in_study_actions_list_by_decision_time_index,
+        active_actions_list_by_decision_time_index,
         *batched_threaded_arg_tensors,
     )
 
-    in_study_index = 0
+    active_index = 0
     decision_time_to_all_weights_index_offset = min(
         sorted_threaded_action_prob_args_by_decision_time
     )
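Per the comments in this file, the weight being vmapped is a likelihood ratio for the observed action that is numerically 1 but carries gradient with respect to the threaded betas. A hedged sketch of that idea follows; it is not the package's `get_radon_nikodym_weight`, whose exact argument layout is only partially visible in this diff:

    import jax.numpy as jnp

    def radon_nikodym_weight_sketch(
        logged_beta, action_prob_func, beta_index, action, *threaded_args
    ):
        # Probability of action 1 under the threaded (differentiable) arguments.
        p_threaded = action_prob_func(*threaded_args)
        # Same arguments with the logged beta swapped in: the same value, but
        # impervious to differentiation with respect to the shared betas.
        impervious_args = list(threaded_args)
        impervious_args[beta_index] = logged_beta
        p_logged = action_prob_func(*impervious_args)
        # Ratio of the observed action's probabilities (action is 0 or 1).
        numerator = jnp.where(action == 1, p_threaded, 1.0 - p_threaded)
        denominator = jnp.where(action == 1, p_logged, 1.0 - p_logged)
        return numerator / denominator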
@@ -922,35 +928,35 @@ def single_user_weighted_estimating_function_stacker(
         decision_time,
         args,
     ) in sorted_threaded_action_prob_args_by_decision_time.items():
-        all_weights_raw.append(in_study_weights[in_study_index] if args else 1.0)
-        in_study_index += 1
+        all_weights_raw.append(active_weights[active_index] if args else 1.0)
+        active_index += 1
     all_weights = jnp.array(all_weights_raw)
 
     algorithm_component = jnp.concatenate(
         [
             # Here we compute a product of Radon-Nikodym weights
             # for all decision times after the first update and before the
-            # update under consideration took effect, for which the user was in the study.
+            # update under consideration took effect, for which the subject was in the deployment.
             (
                 jnp.prod(
                     all_weights[
-                        # The earliest time after the first update where the user was in
-                        # the study
+                        # The earliest time after the first update where the subject was in
+                        # the deployment
                         max(
                             first_time_after_first_update,
-                            user_start_time,
+                            subject_start_time,
                         )
                         - decision_time_to_all_weights_index_offset :
-                        # One more than the latest time the user was in the study before the time
+                        # One more than the latest time the subject was in the deployment before the time
                         # the update under consideration first applied. Note the + 1 because range
                         # does not include the right endpoint.
                         min(
                             min_time_by_policy_num.get(policy_num, math.inf),
-                            user_end_time + 1,
+                            subject_end_time + 1,
                         )
                         - decision_time_to_all_weights_index_offset,
                     ]
-                # If the user exited the study before there were any updates,
+                # If the subject exited the deployment before there were any updates,
                 # this variable will be None and the above code to grab a weight would
                 # throw an error. Just use 1 to include the unweighted estimating function
                 # if they have data to contribute to the update.
@@ -958,8 +964,8 @@ def single_user_weighted_estimating_function_stacker(
                 else 1
            )  # Now use the above to weight the alg estimating function for this update
             * algorithm_estimating_func(*update_args)
-            # If there are no arguments for the update function, the user is not yet in the
-            # study, so we just add a zero vector contribution to the sum across users.
+            # If there are no arguments for the update function, the subject is not yet in the
+            # deployment, so we just add a zero vector contribution to the sum across subjects.
             # Note that after they exit, they still contribute all their data to later
             # updates.
             if update_args
@@ -978,17 +984,17 @@ def single_user_weighted_estimating_function_stacker(
     )
     # 4. Form the weighted inference estimating equation.
     logger.info(
-        "Computing the inference component of the weighted estimating function stack for user %s.",
-        user_id,
+        "Computing the inference component of the weighted estimating function stack for subject %s.",
+        subject_id,
     )
     inference_component = jnp.prod(
         all_weights[
-            max(first_time_after_first_update, user_start_time)
-            - decision_time_to_all_weights_index_offset : user_end_time
+            max(first_time_after_first_update, subject_start_time)
+            - decision_time_to_all_weights_index_offset : subject_end_time
             + 1
             - decision_time_to_all_weights_index_offset,
         ]
-        # If the user exited the study before there were any updates,
+        # If the subject exited the deployment before there were any updates,
         # this variable will be None and the above code to grab a weight would
         # throw an error. Just use 1 to include the unweighted estimating function
         # if they have data to contribute here (pretty sure everyone should?)
@@ -997,18 +1003,18 @@ def single_user_weighted_estimating_function_stacker(
     ) * inference_estimating_func(*threaded_inference_func_args)
 
     # 5. Concatenate the two components to form the weighted estimating function stack for this
-    # user.
+    # subject.
     weighted_stack = jnp.concatenate([algorithm_component, inference_component])
 
     # 6. Return the following outputs:
-    # a. The first is simply the weighted estimating function stack for this user. The average
+    # a. The first is simply the weighted estimating function stack for this subject. The average
     # of these is what we differentiate with respect to theta to form the inverse adaptive joint
     # bread matrix, and we also compare that average to zero to check the estimating functions'
     # fidelity.
-    # b. The average outer product of these per-user stacks across users is the adaptive joint meat
+    # b. The average outer product of these per-subject stacks across subjects is the adaptive joint meat
     # matrix, hence the second output.
-    # c. The third output is averaged across users to obtain the classical meat matrix.
-    # d. The fourth output is averaged across users to obtain the inverse classical bread
+    # c. The third output is averaged across subjects to obtain the classical meat matrix.
+    # d. The fourth output is averaged across subjects to obtain the inverse classical bread
     # matrix.
     return (
         weighted_stack,
@@ -1024,7 +1030,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
     flattened_betas_and_theta: jnp.ndarray,
     beta_dim: int,
     theta_dim: int,
-    user_ids: jnp.ndarray,
+    subject_ids: jnp.ndarray,
     action_prob_func: callable,
     action_prob_func_args_beta_index: int,
     alg_update_func: callable,
@@ -1037,29 +1043,31 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
     inference_func_type: str,
     inference_func_args_theta_index: int,
     inference_func_args_action_prob_index: int,
-    action_prob_func_args_by_user_id_by_decision_time: dict[
+    action_prob_func_args_by_subject_id_by_decision_time: dict[
         collections.abc.Hashable, dict[int, tuple[Any, ...]]
     ],
-    policy_num_by_decision_time_by_user_id: dict[
+    policy_num_by_decision_time_by_subject_id: dict[
         collections.abc.Hashable, dict[int, int | float]
     ],
     initial_policy_num: int | float,
     beta_index_by_policy_num: dict[int | float, int],
-    inference_func_args_by_user_id: dict[collections.abc.Hashable, tuple[Any, ...]],
-    inference_action_prob_decision_times_by_user_id: dict[
+    inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
+    inference_action_prob_decision_times_by_subject_id: dict[
         collections.abc.Hashable, list[int]
     ],
-    update_func_args_by_by_user_id_by_policy_num: dict[
+    update_func_args_by_by_subject_id_by_policy_num: dict[
         collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
     ],
-    action_by_decision_time_by_user_id: dict[collections.abc.Hashable, dict[int, int]],
+    action_by_decision_time_by_subject_id: dict[
+        collections.abc.Hashable, dict[int, int]
+    ],
     suppress_all_data_checks: bool,
     suppress_interactive_data_checks: bool,
 ) -> tuple[
     jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
 ]:
     """
-    Computes the average weighted estimating function stack across all users, along with
+    Computes the average weighted estimating function stack across all subjects, along with
     auxiliary values used to construct the adaptive and classical sandwich variances.
 
     Args:
@@ -1071,8 +1079,8 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
             The dimension of each of the beta parameters.
         theta_dim (int):
             The dimension of the theta parameter.
-        user_ids (jnp.ndarray):
-            A 1D JAX NumPy array of user IDs.
+        subject_ids (jnp.ndarray):
+            A 1D JAX NumPy array of subject IDs.
         action_prob_func (callable):
             The action probability function.
         action_prob_func_args_beta_index (int):
@@ -1100,29 +1108,29 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
         inference_func_args_action_prob_index (int):
             The index of action probabilities in the inference function arguments tuple, if
             applicable. -1 otherwise.
-        action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
-            A dictionary mapping decision times to maps of user ids to the function arguments
-            required to compute action probabilities for this user.
-        policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
-            A map of user ids to dictionaries mapping decision times to the policy number in use.
-            Only applies to in-study decision times!
+        action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
+            A dictionary mapping decision times to maps of subject ids to the function arguments
+            required to compute action probabilities for this subject.
+        policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
+            A map of subject ids to dictionaries mapping decision times to the policy number in use.
+            Only applies to active decision times!
         initial_policy_num (int | float):
             The policy number of the initial policy before any updates.
         beta_index_by_policy_num (dict[int | float, int]):
             A dictionary mapping policy numbers to the index of the corresponding beta in
             all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
-        inference_func_args_by_user_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
-            A dictionary mapping user IDs to their respective inference function arguments.
-        inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
-            For each user, a list of decision times to which action probabilities correspond if
-            provided. Typically just in-study times if action probabilities are used in the inference
+        inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
+            A dictionary mapping subject IDs to their respective inference function arguments.
+        inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
+            For each subject, a list of decision times to which action probabilities correspond if
+            provided. Typically just active times if action probabilities are used in the inference
             loss or estimating function.
-        update_func_args_by_by_user_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
-            A dictionary where keys are policy numbers and values are dictionaries mapping user IDs
+        update_func_args_by_by_subject_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
+            A dictionary where keys are policy numbers and values are dictionaries mapping subject IDs
             to their respective update function arguments.
-        action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
-            A dictionary mapping user IDs to their respective actions taken at each decision time.
-            Only applies to in-study decision times!
+        action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
+            A dictionary mapping subject IDs to their respective actions taken at each decision time.
+            Only applies to active decision times!
         suppress_all_data_checks (bool):
             If True, suppresses carrying out any data checks at all.
         suppress_interactive_data_checks (bool):
@@ -1136,10 +1144,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
         tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
             A tuple containing
             1. the average weighted estimating function stack
-            2. the user-level adaptive meat matrix contributions
-            3. the user-level classical meat matrix contributions
-            4. the user-level inverse classical bread matrix contributions
-            5. raw per-user weighted estimating function
+            2. the subject-level adaptive meat matrix contributions
+            3. the subject-level classical meat matrix contributions
+            4. the subject-level inverse classical bread matrix contributions
+            5. raw per-subject weighted estimating function
             stacks.
     """
 
@@ -1166,15 +1174,15 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
     # supplied for the above functions, so that differentiation works correctly. The existing
     # values should be the same, but not connected to the parameter we are differentiating
     # with respect to. Note we will also find it useful below to have the action probability args
-    # nested dict structure flipped to be user_id -> decision_time -> args, so we do that here too.
+    # nested dict structure flipped to be subject_id -> decision_time -> args, so we do that here too.
 
-    logger.info("Threading in betas to action probability arguments for all users.")
+    logger.info("Threading in betas to action probability arguments for all subjects.")
     (
-        threaded_action_prob_func_args_by_decision_time_by_user_id,
-        action_prob_func_args_by_decision_time_by_user_id,
+        threaded_action_prob_func_args_by_decision_time_by_subject_id,
+        action_prob_func_args_by_decision_time_by_subject_id,
     ) = thread_action_prob_func_args(
-        action_prob_func_args_by_user_id_by_decision_time,
-        policy_num_by_decision_time_by_user_id,
+        action_prob_func_args_by_subject_id_by_decision_time,
+        policy_num_by_decision_time_by_subject_id,
         initial_policy_num,
         betas,
         beta_index_by_policy_num,
@@ -1186,17 +1194,17 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
     # arguments with the central betas introduced.
     logger.info(
         "Threading in betas and beta-dependent action probabilities to algorithm update "
-        "function args for all users"
+        "function args for all subjects"
     )
-    threaded_update_func_args_by_policy_num_by_user_id = thread_update_func_args(
-        update_func_args_by_by_user_id_by_policy_num,
+    threaded_update_func_args_by_policy_num_by_subject_id = thread_update_func_args(
+        update_func_args_by_by_subject_id_by_policy_num,
         betas,
         beta_index_by_policy_num,
         alg_update_func_args_beta_index,
         alg_update_func_args_action_prob_index,
         alg_update_func_args_action_prob_times_index,
         alg_update_func_args_previous_betas_index,
-        threaded_action_prob_func_args_by_decision_time_by_user_id,
+        threaded_action_prob_func_args_by_decision_time_by_subject_id,
         action_prob_func,
     )
 
@@ -1206,8 +1214,8 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1206
1214
  if not suppress_all_data_checks and alg_update_func_args_action_prob_index >= 0:
1207
1215
  input_checks.require_threaded_algorithm_estimating_function_args_equivalent(
1208
1216
  algorithm_estimating_func,
1209
- update_func_args_by_by_user_id_by_policy_num,
1210
- threaded_update_func_args_by_policy_num_by_user_id,
1217
+ update_func_args_by_by_subject_id_by_policy_num,
1218
+ threaded_update_func_args_by_policy_num_by_subject_id,
1211
1219
  suppress_interactive_data_checks,
1212
1220
  )
1213
1221
 
@@ -1216,15 +1224,15 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1216
1224
  # arguments with the central betas introduced.
1217
1225
  logger.info(
1218
1226
  "Threading in theta and beta-dependent action probabilities to inference update "
1219
- "function args for all users"
1227
+ "function args for all subjects"
1220
1228
  )
1221
- threaded_inference_func_args_by_user_id = thread_inference_func_args(
1222
- inference_func_args_by_user_id,
1229
+ threaded_inference_func_args_by_subject_id = thread_inference_func_args(
1230
+ inference_func_args_by_subject_id,
1223
1231
  inference_func_args_theta_index,
1224
1232
  theta,
1225
1233
  inference_func_args_action_prob_index,
1226
- threaded_action_prob_func_args_by_decision_time_by_user_id,
1227
- inference_action_prob_decision_times_by_user_id,
1234
+ threaded_action_prob_func_args_by_decision_time_by_subject_id,
1235
+ inference_action_prob_decision_times_by_subject_id,
1228
1236
  action_prob_func,
1229
1237
  )
1230
1238
 
@@ -1234,32 +1242,32 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1234
1242
  if not suppress_all_data_checks and inference_func_args_action_prob_index >= 0:
1235
1243
  input_checks.require_threaded_inference_estimating_function_args_equivalent(
1236
1244
  inference_estimating_func,
1237
- inference_func_args_by_user_id,
1238
- threaded_inference_func_args_by_user_id,
1245
+ inference_func_args_by_subject_id,
1246
+ threaded_inference_func_args_by_subject_id,
1239
1247
  suppress_interactive_data_checks,
1240
1248
  )
1241
1249
 
1242
- # 5. Now we can compute the weighted estimating function stacks for all users
1250
+ # 5. Now we can compute the weighted estimating function stacks for all subjects
1243
1251
  # as well as collect related values used to construct the adaptive and classical
1244
1252
  # sandwich variances.
1245
1253
  results = [
1246
- single_user_weighted_estimating_function_stacker(
1254
+ single_subject_weighted_estimating_function_stacker(
1247
1255
  beta_dim,
1248
- user_id,
1256
+ subject_id,
1249
1257
  action_prob_func,
1250
1258
  algorithm_estimating_func,
1251
1259
  inference_estimating_func,
1252
1260
  action_prob_func_args_beta_index,
1253
1261
  inference_func_args_theta_index,
1254
- action_prob_func_args_by_decision_time_by_user_id[user_id],
1255
- threaded_action_prob_func_args_by_decision_time_by_user_id[user_id],
1256
- threaded_update_func_args_by_policy_num_by_user_id[user_id],
1257
- threaded_inference_func_args_by_user_id[user_id],
1258
- policy_num_by_decision_time_by_user_id[user_id],
1259
- action_by_decision_time_by_user_id[user_id],
1262
+ action_prob_func_args_by_decision_time_by_subject_id[subject_id],
1263
+ threaded_action_prob_func_args_by_decision_time_by_subject_id[subject_id],
1264
+ threaded_update_func_args_by_policy_num_by_subject_id[subject_id],
1265
+ threaded_inference_func_args_by_subject_id[subject_id],
1266
+ policy_num_by_decision_time_by_subject_id[subject_id],
1267
+ action_by_decision_time_by_subject_id[subject_id],
1260
1268
  beta_index_by_policy_num,
1261
1269
  )
1262
- for user_id in user_ids.tolist()
1270
+ for subject_id in subject_ids.tolist()
1263
1271
  ]
1264
1272
 
1265
1273
  stacks = jnp.array([result[0] for result in results])
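Note: a minimal sketch of the stack-then-average pattern in this hunk; the real stacker also returns several auxiliary arrays per subject, elided here:

import jax.numpy as jnp

# One estimating-function stack per subject (aux values omitted).
results = [(jnp.array([1.0, 2.0]), None), (jnp.array([3.0, 4.0]), None)]
stacks = jnp.array([result[0] for result in results])  # shape (num_subjects, stack_dim)
avg_stack = jnp.mean(stacks, axis=0)                   # averaged over subjects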
@@ -1270,10 +1278,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1270
1278
  # 6. Note this strange return structure! We will differentiate the first output,
1271
1279
  # but the second tuple will be passed along without modification via has_aux=True and then used
1272
1280
  # for the adaptive meat matrix, estimating functions sum check, and classical meat and inverse
1273
- # bread matrices. The raw per-user stacks are also returned for debugging purposes.
1281
+ # bread matrices. The raw per-subject stacks are also returned for debugging purposes.
1274
1282
 
1275
1283
  # Note that returning the raw stacks here as the first arguments is potentially
1276
- # memory-intensive when combined with differentiation. Keep this in mind if the per-user bread
1284
+ # memory-intensive when combined with differentiation. Keep this in mind if the per-subject bread
1277
1285
  # inverse contributions are needed for something like CR2/CR3 small-sample corrections.
1278
1286
  return jnp.mean(stacks, axis=0), (
1279
1287
  jnp.mean(stacks, axis=0),
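Note: the return structure above relies on jax.jacrev's has_aux mechanism, which differentiates only the primary output and passes the auxiliary tuple through unchanged. A self-contained sketch with illustrative names:

import jax
import jax.numpy as jnp

def stack_and_aux(params):
    stack = jnp.array([params.sum(), (params ** 2).sum()])
    aux = (2.0 * params,)  # carried along, never differentiated
    return stack, aux

jacobian, aux = jax.jacrev(stack_and_aux, has_aux=True)(jnp.array([1.0, 2.0]))
# jacobian has shape (2, 2): one gradient row per entry of the primary output.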
@@ -1284,10 +1292,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1284
1292
  )
1285
1293
 
1286
1294
 
1287
- def construct_classical_and_adaptive_sandwiches(
1295
+ def construct_classical_and_adjusted_sandwiches(
1288
1296
  theta_est: jnp.ndarray,
1289
1297
  all_post_update_betas: jnp.ndarray,
1290
- user_ids: jnp.ndarray,
1298
+ subject_ids: jnp.ndarray,
1291
1299
  action_prob_func: callable,
1292
1300
  action_prob_func_args_beta_index: int,
1293
1301
  alg_update_func: callable,
@@ -1300,32 +1308,34 @@ def construct_classical_and_adaptive_sandwiches(
1300
1308
  inference_func_type: str,
1301
1309
  inference_func_args_theta_index: int,
1302
1310
  inference_func_args_action_prob_index: int,
1303
- action_prob_func_args_by_user_id_by_decision_time: dict[
1311
+ action_prob_func_args_by_subject_id_by_decision_time: dict[
1304
1312
  collections.abc.Hashable, dict[int, tuple[Any, ...]]
1305
1313
  ],
1306
- policy_num_by_decision_time_by_user_id: dict[
1314
+ policy_num_by_decision_time_by_subject_id: dict[
1307
1315
  collections.abc.Hashable, dict[int, int | float]
1308
1316
  ],
1309
1317
  initial_policy_num: int | float,
1310
1318
  beta_index_by_policy_num: dict[int | float, int],
1311
- inference_func_args_by_user_id: dict[collections.abc.Hashable, tuple[Any, ...]],
1312
- inference_action_prob_decision_times_by_user_id: dict[
1319
+ inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
1320
+ inference_action_prob_decision_times_by_subject_id: dict[
1313
1321
  collections.abc.Hashable, list[int]
1314
1322
  ],
1315
- update_func_args_by_by_user_id_by_policy_num: dict[
1323
+ update_func_args_by_by_subject_id_by_policy_num: dict[
1316
1324
  collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
1317
1325
  ],
1318
- action_by_decision_time_by_user_id: dict[collections.abc.Hashable, dict[int, int]],
1326
+ action_by_decision_time_by_subject_id: dict[
1327
+ collections.abc.Hashable, dict[int, int]
1328
+ ],
1319
1329
  suppress_all_data_checks: bool,
1320
1330
  suppress_interactive_data_checks: bool,
1321
1331
  small_sample_correction: str,
1322
- form_adaptive_meat_adjustments_explicitly: bool,
1323
- stabilize_joint_adaptive_bread_inverse: bool,
1324
- study_df: pd.DataFrame | None,
1325
- in_study_col_name: str | None,
1332
+ form_adjusted_meat_adjustments_explicitly: bool,
1333
+ stabilize_joint_adjusted_bread_inverse: bool,
1334
+ analysis_df: pd.DataFrame | None,
1335
+ active_col_name: str | None,
1326
1336
  action_col_name: str | None,
1327
1337
  calendar_t_col_name: str | None,
1328
- user_id_col_name: str | None,
1338
+ subject_id_col_name: str | None,
1329
1339
  action_prob_func_args: tuple | None,
1330
1340
  action_prob_col_name: str | None,
1331
1341
  ) -> tuple[
@@ -1354,8 +1364,8 @@ def construct_classical_and_adaptive_sandwiches(
1354
1364
  A 1-D JAX NumPy array representing the parameter estimate for inference.
1355
1365
  all_post_update_betas (jnp.ndarray):
1356
1366
  A 2-D JAX NumPy array representing all parameter estimates for the algorithm updates.
1357
- user_ids (jnp.ndarray):
1358
- A 1-D JAX NumPy array holding all user IDs in the study.
1367
+ subject_ids (jnp.ndarray):
1368
+ A 1-D JAX NumPy array holding all subject IDs in the deployment.
1359
1369
  action_prob_func (callable):
1360
1370
  The action probability function.
1361
1371
  action_prob_func_args_beta_index (int):
@@ -1383,29 +1393,29 @@ def construct_classical_and_adaptive_sandwiches(
1383
1393
  inference_func_args_action_prob_index (int):
1384
1394
  The index of action probabilities in the inference function arguments tuple, if
1385
1395
  applicable. -1 otherwise.
1386
- action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
1387
- A dictionary mapping decision times to maps of user ids to the function arguments
1388
- required to compute action probabilities for this user.
1389
- policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
1390
- A map of user ids to dictionaries mapping decision times to the policy number in use.
1391
- Only applies to in-study decision times!
1396
+ action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
1397
+ A dictionary mapping decision times to maps of subject ids to the function arguments
1398
+ required to compute action probabilities for this subject.
1399
+ policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
1400
+ A map of subject ids to dictionaries mapping decision times to the policy number in use.
1401
+ Only applies to active decision times!
1392
1402
  initial_policy_num (int | float):
1393
1403
  The policy number of the initial policy before any updates.
1394
1404
  beta_index_by_policy_num (dict[int | float, int]):
1395
1405
  A dictionary mapping policy numbers to the index of the corresponding beta in
1396
1406
  all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
1397
- inference_func_args_by_user_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
1398
- A dictionary mapping user IDs to their respective inference function arguments.
1399
- inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
1400
- For each user, a list of decision times to which action probabilities correspond if
1401
- provided. Typically just in-study times if action probabilities are used in the inference
1407
+ inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
1408
+ A dictionary mapping subject IDs to their respective inference function arguments.
1409
+ inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
1410
+ For each subject, a list of decision times to which action probabilities correspond if
1411
+ provided. Typically just active times if action probabilities are used in the inference
1402
1412
  loss or estimating function.
1403
- update_func_args_by_by_user_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
1404
- A dictionary where keys are policy numbers and values are dictionaries mapping user IDs
1413
+ update_func_args_by_by_subject_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
1414
+ A dictionary where keys are policy numbers and values are dictionaries mapping subject IDs
1405
1415
  to their respective update function arguments.
1406
- action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
1407
- A dictionary mapping user IDs to their respective actions taken at each decision time.
1408
- Only applies to in-study decision times!
1416
+ action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
1417
+ A dictionary mapping subject IDs to their respective actions taken at each decision time.
1418
+ Only applies to active decision times!
1409
1419
  suppress_all_data_checks (bool):
1410
1420
  If True, suppresses carrying out any data checks at all.
1411
1421
  suppress_interactive_data_checks (bool):
@@ -1415,27 +1425,27 @@ def construct_classical_and_adaptive_sandwiches(
1415
1425
  small_sample_correction (str):
1416
1426
  The type of small sample correction to apply. See SmallSampleCorrections class for
1417
1427
  options.
1418
- form_adaptive_meat_adjustments_explicitly (bool):
1419
- If True, explicitly forms the per-user meat adjustments that differentiate the adaptive
1428
+ form_adjusted_meat_adjustments_explicitly (bool):
1429
+ If True, explicitly forms the per-subject meat adjustments that differentiate the adaptive
1420
1430
  sandwich from the classical sandwich. This is for diagnostic purposes only, since the
1421
1431
  adaptive sandwich is formed without these explicit adjustments.
1422
- stabilize_joint_adaptive_bread_inverse (bool):
1432
+ stabilize_joint_adjusted_bread_inverse (bool):
1423
1433
  If True, stabilizes the joint adaptive bread inverse when necessary, e.g., by increasing diagonal block dominance or adding a small ridge penalty to its diagonal blocks.
1424
- study_df (pd.DataFrame):
1425
- The full study dataframe, needed if forming the adaptive meat adjustments explicitly.
1426
- in_study_col_name (str):
1427
- The name of the column in study_df indicating whether a user is in-study at a given decision time.
1434
+ analysis_df (pd.DataFrame):
1435
+ The full analysis dataframe, needed if forming the adaptive meat adjustments explicitly.
1436
+ active_col_name (str):
1437
+ The name of the column in analysis_df indicating whether a subject is active at a given decision time.
1428
1438
  action_col_name (str):
1429
- The name of the column in study_df indicating the action taken at a given decision time.
1439
+ The name of the column in analysis_df indicating the action taken at a given decision time.
1430
1440
  calendar_t_col_name (str):
1431
- The name of the column in study_df indicating the calendar time of a given decision time.
1432
- user_id_col_name (str):
1433
- The name of the column in study_df indicating the user ID.
1441
+ The name of the column in analysis_df indicating the calendar time of a given decision time.
1442
+ subject_id_col_name (str):
1443
+ The name of the column in analysis_df indicating the subject ID.
1434
1444
  action_prob_func_args (tuple):
1435
1445
  The arguments to be passed to the action probability function, needed if forming the
1436
1446
  adaptive meat adjustments explicitly.
1437
1447
  action_prob_col_name (str):
1438
- The name of the column in study_df indicating the action probability of the action taken,
1448
+ The name of the column in analysis_df indicating the action probability of the action taken,
1439
1449
  needed if forming the adaptive meat adjustments explicitly.
1440
1450
  Returns:
1441
1451
  tuple[jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32]]:
@@ -1448,10 +1458,10 @@ def construct_classical_and_adaptive_sandwiches(
1448
1458
  - The classical meat matrix.
1449
1459
  - The classical sandwich matrix.
1450
1460
  - The average weighted estimating function stack.
1451
- - All per-user weighted estimating function stacks.
1452
- - The per-user adaptive meat small-sample corrections.
1453
- - The per-user classical meat small-sample corrections.
1454
- - The per-user adaptive meat adjustments, if form_adaptive_meat_adjustments_explicitly
1461
+ - All per-subject weighted estimating function stacks.
1462
+ - The per-subject adaptive meat small-sample corrections.
1463
+ - The per-subject classical meat small-sample corrections.
1464
+ - The per-subject adaptive meat adjustments, if form_adjusted_meat_adjustments_explicitly
1455
1465
  is True, otherwise an array of NaNs.
1456
1466
  """
1457
1467
  logger.info(
@@ -1459,13 +1469,13 @@ def construct_classical_and_adaptive_sandwiches(
1459
1469
  )
1460
1470
  theta_dim = theta_est.shape[0]
1461
1471
  beta_dim = all_post_update_betas.shape[1]
1462
- # Note that these "contributions" are per-user Jacobians of the weighted estimating function stack.
1463
- raw_joint_adaptive_bread_inverse_matrix, (
1472
+ # Note that these "contributions" are per-subject Jacobians of the weighted estimating function stack.
1473
+ raw_joint_adjusted_bread_inverse_matrix, (
1464
1474
  avg_estimating_function_stack,
1465
- per_user_joint_adaptive_meat_contributions,
1466
- per_user_classical_meat_contributions,
1467
- per_user_classical_bread_inverse_contributions,
1468
- per_user_estimating_function_stacks,
1475
+ per_subject_joint_adjusted_meat_contributions,
1476
+ per_subject_classical_meat_contributions,
1477
+ per_subject_classical_bread_inverse_contributions,
1478
+ per_subject_estimating_function_stacks,
1469
1479
  ) = jax.jacrev(
1470
1480
  get_avg_weighted_estimating_function_stacks_and_aux_values, has_aux=True
1471
1481
  )(
@@ -1475,7 +1485,7 @@ def construct_classical_and_adaptive_sandwiches(
1475
1485
  flatten_params(all_post_update_betas, theta_est),
1476
1486
  beta_dim,
1477
1487
  theta_dim,
1478
- user_ids,
1488
+ subject_ids,
1479
1489
  action_prob_func,
1480
1490
  action_prob_func_args_beta_index,
1481
1491
  alg_update_func,
@@ -1488,87 +1498,87 @@ def construct_classical_and_adaptive_sandwiches(
1488
1498
  inference_func_type,
1489
1499
  inference_func_args_theta_index,
1490
1500
  inference_func_args_action_prob_index,
1491
- action_prob_func_args_by_user_id_by_decision_time,
1492
- policy_num_by_decision_time_by_user_id,
1501
+ action_prob_func_args_by_subject_id_by_decision_time,
1502
+ policy_num_by_decision_time_by_subject_id,
1493
1503
  initial_policy_num,
1494
1504
  beta_index_by_policy_num,
1495
- inference_func_args_by_user_id,
1496
- inference_action_prob_decision_times_by_user_id,
1497
- update_func_args_by_by_user_id_by_policy_num,
1498
- action_by_decision_time_by_user_id,
1505
+ inference_func_args_by_subject_id,
1506
+ inference_action_prob_decision_times_by_subject_id,
1507
+ update_func_args_by_by_subject_id_by_policy_num,
1508
+ action_by_decision_time_by_subject_id,
1499
1509
  suppress_all_data_checks,
1500
1510
  suppress_interactive_data_checks,
1501
1511
  )
1502
1512
 
1503
- num_users = len(user_ids)
1513
+ num_subjects = len(subject_ids)
1504
1514
 
1505
1515
  (
1506
- joint_adaptive_meat_matrix,
1516
+ joint_adjusted_meat_matrix,
1507
1517
  classical_meat_matrix,
1508
- per_user_adaptive_corrections,
1509
- per_user_classical_corrections,
1518
+ per_subject_adjusted_corrections,
1519
+ per_subject_classical_corrections,
1510
1520
  ) = perform_desired_small_sample_correction(
1511
1521
  small_sample_correction,
1512
- per_user_joint_adaptive_meat_contributions,
1513
- per_user_classical_meat_contributions,
1514
- per_user_classical_bread_inverse_contributions,
1515
- num_users,
1522
+ per_subject_joint_adjusted_meat_contributions,
1523
+ per_subject_classical_meat_contributions,
1524
+ per_subject_classical_bread_inverse_contributions,
1525
+ num_subjects,
1516
1526
  theta_dim,
1517
1527
  )
1518
1528
 
1519
1529
  # Increase diagonal block dominance and possibly improve conditioning of diagonal
1520
1530
  # blocks as necessary, to ensure numerical stability of the joint bread inverse
1521
- stabilized_joint_adaptive_bread_inverse_matrix = (
1531
+ stabilized_joint_adjusted_bread_inverse_matrix = (
1522
1532
  (
1523
- stabilize_joint_adaptive_bread_inverse_if_necessary(
1524
- raw_joint_adaptive_bread_inverse_matrix,
1533
+ stabilize_joint_adjusted_bread_inverse_if_necessary(
1534
+ raw_joint_adjusted_bread_inverse_matrix,
1525
1535
  beta_dim,
1526
1536
  theta_dim,
1527
1537
  )
1528
1538
  )
1529
- if stabilize_joint_adaptive_bread_inverse
1530
- else raw_joint_adaptive_bread_inverse_matrix
1539
+ if stabilize_joint_adjusted_bread_inverse
1540
+ else raw_joint_adjusted_bread_inverse_matrix
1531
1541
  )
1532
1542
 
1533
1543
  # Now stably (no explicit inversion) form our sandwiches.
1534
- joint_adaptive_sandwich = form_sandwich_from_bread_inverse_and_meat(
1535
- stabilized_joint_adaptive_bread_inverse_matrix,
1536
- joint_adaptive_meat_matrix,
1537
- num_users,
1544
+ joint_adjusted_sandwich = form_sandwich_from_bread_inverse_and_meat(
1545
+ stabilized_joint_adjusted_bread_inverse_matrix,
1546
+ joint_adjusted_meat_matrix,
1547
+ num_subjects,
1538
1548
  method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
1539
1549
  )
1540
1550
  classical_bread_inverse_matrix = jnp.mean(
1541
- per_user_classical_bread_inverse_contributions, axis=0
1551
+ per_subject_classical_bread_inverse_contributions, axis=0
1542
1552
  )
1543
1553
  classical_sandwich = form_sandwich_from_bread_inverse_and_meat(
1544
1554
  classical_bread_inverse_matrix,
1545
1555
  classical_meat_matrix,
1546
- num_users,
1556
+ num_subjects,
1547
1557
  method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
1548
1558
  )
1549
1559
 
1550
- per_user_adaptive_meat_adjustments = jnp.full(
1551
- (len(user_ids), theta_dim, theta_dim), jnp.nan
1560
+ per_subject_adjusted_meat_adjustments = jnp.full(
1561
+ (len(subject_ids), theta_dim, theta_dim), jnp.nan
1552
1562
  )
1553
- if form_adaptive_meat_adjustments_explicitly:
1554
- per_user_adjusted_classical_meat_contributions = (
1555
- form_adaptive_meat_adjustments_directly(
1563
+ if form_adjusted_meat_adjustments_explicitly:
1564
+ per_subject_adjusted_classical_meat_contributions = (
1565
+ form_adjusted_meat_adjustments_directly(
1556
1566
  theta_dim,
1557
1567
  all_post_update_betas.shape[1],
1558
- stabilized_joint_adaptive_bread_inverse_matrix,
1559
- per_user_estimating_function_stacks,
1560
- study_df,
1561
- in_study_col_name,
1568
+ stabilized_joint_adjusted_bread_inverse_matrix,
1569
+ per_subject_estimating_function_stacks,
1570
+ analysis_df,
1571
+ active_col_name,
1562
1572
  action_col_name,
1563
1573
  calendar_t_col_name,
1564
- user_id_col_name,
1574
+ subject_id_col_name,
1565
1575
  action_prob_func,
1566
1576
  action_prob_func_args,
1567
1577
  action_prob_func_args_beta_index,
1568
1578
  theta_est,
1569
1579
  inference_func,
1570
1580
  inference_func_args_theta_index,
1571
- user_ids,
1581
+ subject_ids,
1572
1582
  action_prob_col_name,
1573
1583
  )
1574
1584
  )
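Note: the NaN-filled array allocated just before this call is a sentinel: if the explicit diagnostic adjustments are never formed, downstream code can detect that. A minimal sketch of the pattern with illustrative shapes:

import jax.numpy as jnp

num_subjects, theta_dim = 3, 2
adjustments = jnp.full((num_subjects, theta_dim, theta_dim), jnp.nan)
form_explicitly = False
if form_explicitly:
    adjustments = jnp.zeros((num_subjects, theta_dim, theta_dim))  # stand-in computation
assert bool(jnp.isnan(adjustments).all())  # still the NaN sentinel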
@@ -1578,30 +1588,30 @@ def construct_classical_and_adaptive_sandwiches(
1578
1588
  # First just apply any small-sample correction for parity.
1579
1589
  (
1580
1590
  _,
1581
- theta_only_adaptive_meat_matrix_v2,
1591
+ theta_only_adjusted_meat_matrix_v2,
1582
1592
  _,
1583
1593
  _,
1584
1594
  ) = perform_desired_small_sample_correction(
1585
1595
  small_sample_correction,
1586
- per_user_joint_adaptive_meat_contributions,
1587
- per_user_adjusted_classical_meat_contributions,
1588
- per_user_classical_bread_inverse_contributions,
1589
- num_users,
1596
+ per_subject_joint_adjusted_meat_contributions,
1597
+ per_subject_adjusted_classical_meat_contributions,
1598
+ per_subject_classical_bread_inverse_contributions,
1599
+ num_subjects,
1590
1600
  theta_dim,
1591
1601
  )
1592
- theta_only_adaptive_sandwich_from_adjustments = (
1602
+ theta_only_adjusted_sandwich_from_adjustments = (
1593
1603
  form_sandwich_from_bread_inverse_and_meat(
1594
1604
  classical_bread_inverse_matrix,
1595
- theta_only_adaptive_meat_matrix_v2,
1596
- num_users,
1605
+ theta_only_adjusted_meat_matrix_v2,
1606
+ num_subjects,
1597
1607
  method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
1598
1608
  )
1599
1609
  )
1600
- theta_only_adaptive_sandwich = joint_adaptive_sandwich[-theta_dim:, -theta_dim:]
1610
+ theta_only_adjusted_sandwich = joint_adjusted_sandwich[-theta_dim:, -theta_dim:]
1601
1611
 
1602
1612
  if not np.allclose(
1603
- theta_only_adaptive_sandwich,
1604
- theta_only_adaptive_sandwich_from_adjustments,
1613
+ theta_only_adjusted_sandwich,
1614
+ theta_only_adjusted_sandwich_from_adjustments,
1605
1615
  rtol=3e-2,
1606
1616
  ):
1607
1617
  logger.warning(
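Note: this check rebuilds the theta block of the adjusted sandwich from the explicit adjustments and warns if the two routes disagree beyond 3% relative tolerance. A minimal sketch of the pattern with made-up matrices:

import numpy as np

route_1 = np.array([[1.00, 0.10], [0.10, 2.00]])  # theta-only block of the joint sandwich
route_2 = np.array([[1.02, 0.10], [0.10, 1.97]])  # rebuilt from explicit adjustments
if not np.allclose(route_1, route_2, rtol=3e-2):
    print("warning: diagnostic sandwich disagrees with the joint computation")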
@@ -1611,26 +1621,26 @@ def construct_classical_and_adaptive_sandwiches(
1611
1621
  # Stack the joint adaptive inverse bread pieces together horizontally and return the auxiliary
1612
1622
  # values too. The joint adaptive bread inverse should always be block lower triangular.
1613
1623
  return (
1614
- raw_joint_adaptive_bread_inverse_matrix,
1615
- stabilized_joint_adaptive_bread_inverse_matrix,
1616
- joint_adaptive_meat_matrix,
1617
- joint_adaptive_sandwich,
1624
+ raw_joint_adjusted_bread_inverse_matrix,
1625
+ stabilized_joint_adjusted_bread_inverse_matrix,
1626
+ joint_adjusted_meat_matrix,
1627
+ joint_adjusted_sandwich,
1618
1628
  classical_bread_inverse_matrix,
1619
1629
  classical_meat_matrix,
1620
1630
  classical_sandwich,
1621
1631
  avg_estimating_function_stack,
1622
- per_user_estimating_function_stacks,
1623
- per_user_adaptive_corrections,
1624
- per_user_classical_corrections,
1625
- per_user_adaptive_meat_adjustments,
1632
+ per_subject_estimating_function_stacks,
1633
+ per_subject_adjusted_corrections,
1634
+ per_subject_classical_corrections,
1635
+ per_subject_adjusted_meat_adjustments,
1626
1636
  )
1627
1637
 
1628
1638
 
1629
1639
  # TODO: I think there should be interaction to confirm stabilization. It is
1630
- # important for the user to know if this is happening. Even if enabled, it is important
1631
- # that the user know it actually kicks in.
1632
- def stabilize_joint_adaptive_bread_inverse_if_necessary(
1633
- joint_adaptive_bread_inverse_matrix: jnp.ndarray,
1640
+ # important for the user to know if this is happening. Even if enabled, it is important
1641
+ # that the user know it actually kicks in.
1642
+ def stabilize_joint_adjusted_bread_inverse_if_necessary(
1643
+ joint_adjusted_bread_inverse_matrix: jnp.ndarray,
1634
1644
  beta_dim: int,
1635
1645
  theta_dim: int,
1636
1646
  ) -> jnp.ndarray:
@@ -1639,7 +1649,7 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
1639
1649
  dominance and/or adding a small ridge penalty to the diagonal blocks.
1640
1650
 
1641
1651
  Args:
1642
- joint_adaptive_bread_inverse_matrix (jnp.ndarray):
1652
+ joint_adjusted_bread_inverse_matrix (jnp.ndarray):
1643
1653
  A 2-D JAX NumPy array representing the joint adaptive bread inverse matrix.
1644
1654
  beta_dim (int):
1645
1655
  The dimension of each beta parameter.
@@ -1660,7 +1670,7 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
1660
1670
 
1661
1671
  # Grab just the RL block and convert to a NumPy array for easier manipulation.
1662
1672
  RL_stack_beta_derivatives_block = np.array(
1663
- joint_adaptive_bread_inverse_matrix[:-theta_dim, :-theta_dim]
1673
+ joint_adjusted_bread_inverse_matrix[:-theta_dim, :-theta_dim]
1664
1674
  )
1665
1675
  num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
1666
1676
  for i in range(1, num_updates + 1):
@@ -1688,7 +1698,7 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
1688
1698
  RL_stack_beta_derivatives_block[
1689
1699
  diagonal_block_slice, diagonal_block_slice
1690
1700
  ] = diagonal_block + ridge_penalty * np.eye(beta_dim)
1691
- # TODO: Require user input here in interactive settings?
1701
+ # TODO: Require user input here in interactive settings?
1692
1702
  logger.info(
1693
1703
  "Added ridge penalty of %s to diagonal block for update %s to improve conditioning from %s to %s",
1694
1704
  ridge_penalty,
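Note: a standalone sketch of the ridge stabilization applied per diagonal block above; the condition-number threshold and penalty size here are assumptions, since the package's exact criteria fall outside this hunk:

import numpy as np

def ridge_stabilize_block(block, max_cond=1e8, ridge=1e-6):
    # Add a small multiple of the identity if the block is ill-conditioned.
    if np.linalg.cond(block) > max_cond:
        block = block + ridge * np.eye(block.shape[0])
    return block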
@@ -1779,11 +1789,11 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
1779
1789
  [
1780
1790
  [
1781
1791
  RL_stack_beta_derivatives_block,
1782
- joint_adaptive_bread_inverse_matrix[:-theta_dim, -theta_dim:],
1792
+ joint_adjusted_bread_inverse_matrix[:-theta_dim, -theta_dim:],
1783
1793
  ],
1784
1794
  [
1785
- joint_adaptive_bread_inverse_matrix[-theta_dim:, :-theta_dim],
1786
- joint_adaptive_bread_inverse_matrix[-theta_dim:, -theta_dim:],
1795
+ joint_adjusted_bread_inverse_matrix[-theta_dim:, :-theta_dim],
1796
+ joint_adjusted_bread_inverse_matrix[-theta_dim:, -theta_dim:],
1787
1797
  ],
1788
1798
  ]
1789
1799
  )
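Note: minimal illustration of the np.block reassembly above, stitching the stabilized beta block back together with the untouched theta rows and columns (shapes illustrative):

import numpy as np

beta_block = np.eye(4)           # stabilized RL (beta-derivative) block
upper_right = np.zeros((4, 2))   # zero when the joint bread inverse is block lower triangular
lower_left = np.ones((2, 4))     # theta-by-beta cross derivatives
theta_block = np.eye(2)
full = np.block([[beta_block, upper_right], [lower_left, theta_block]])  # shape (6, 6)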
@@ -1792,7 +1802,7 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
1792
1802
  def form_sandwich_from_bread_inverse_and_meat(
1793
1803
  bread_inverse: jnp.ndarray,
1794
1804
  meat: jnp.ndarray,
1795
- num_users: int,
1805
+ num_subjects: int,
1796
1806
  method: str = SandwichFormationMethods.BREAD_INVERSE_T_QR,
1797
1807
  ) -> jnp.ndarray:
1798
1808
  """
@@ -1806,8 +1816,8 @@ def form_sandwich_from_bread_inverse_and_meat(
1806
1816
  A 2-D JAX NumPy array representing the bread inverse matrix.
1807
1817
  meat (jnp.ndarray):
1808
1818
  A 2-D JAX NumPy array representing the meat matrix.
1809
- num_users (int):
1810
- The number of users in the study, used to scale the sandwich appropriately.
1819
+ num_subjects (int):
1820
+ The number of subjects in the deployment, used to scale the sandwich appropriately.
1811
1821
  method (str):
1812
1822
  The method to use for forming the sandwich.
1813
1823
 
@@ -1833,7 +1843,7 @@ def form_sandwich_from_bread_inverse_and_meat(
1833
1843
  L, scipy.linalg.solve_triangular(L, meat.T, lower=True).T, lower=True
1834
1844
  )
1835
1845
 
1836
- return Q @ new_meat @ Q.T / num_users
1846
+ return Q @ new_meat @ Q.T / num_subjects
1837
1847
  elif method == SandwichFormationMethods.MEAT_SVD_SOLVE:
1838
1848
  # Factor the meat via SVD without any symmetrization or truncation.
1839
1849
  # For general (possibly slightly nonsymmetric) M, SVD gives M = U @ diag(s) @ Vh.
@@ -1847,14 +1857,14 @@ def form_sandwich_from_bread_inverse_and_meat(
1847
1857
  W_left = scipy.linalg.solve(bread_inverse, C_left)
1848
1858
  W_right = scipy.linalg.solve(bread_inverse, C_right)
1849
1859
 
1850
- # Return the exact sandwich: V = (B^{-1} C_left) (B^{-1} C_right)^T / num_users
1851
- return W_left @ W_right.T / num_users
1860
+ # Return the exact sandwich: V = (B^{-1} C_left) (B^{-1} C_right)^T / num_subjects
1861
+ return W_left @ W_right.T / num_subjects
1852
1862
 
1853
1863
  elif method == SandwichFormationMethods.NAIVE:
1854
1864
  # Simply invert the bread inverse and form the sandwich directly.
1855
1865
  # This is NOT numerically stable and is only included for comparison purposes.
1856
1866
  bread = np.linalg.inv(bread_inverse)
1857
- return bread @ meat @ meat.T / num_users
1867
+ return bread @ meat @ bread.T / num_subjects
1858
1868
 
1859
1869
  else:
1860
1870
  raise ValueError(
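Note: to summarize the BREAD_INVERSE_T_QR branch above, a self-contained reconstruction under the assumption (consistent with the visible fragment) that bread_inverse.T is factored as Q @ R with L = R.T, so the sandwich B @ M @ B.T / n becomes Q @ (L^{-1} M L^{-T}) @ Q.T / n using only triangular solves, never an explicit inverse:

import numpy as np
import scipy.linalg

def qr_sandwich(bread_inverse, meat, num_subjects):
    Q, R = np.linalg.qr(bread_inverse.T)  # bread_inverse.T = Q @ R
    L = R.T                               # lower triangular
    inner = scipy.linalg.solve_triangular(L, meat.T, lower=True).T  # M @ L^{-T}
    new_meat = scipy.linalg.solve_triangular(L, inner, lower=True)  # L^{-1} @ M @ L^{-T}
    return Q @ new_meat @ Q.T / num_subjects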