lifejacket 0.2.1-py3-none-any.whl → 1.0.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,8 +24,8 @@ from .constants import (
     SandwichFormationMethods,
     SmallSampleCorrections,
 )
-from .form_adaptive_meat_adjustments_directly import (
-    form_adaptive_meat_adjustments_directly,
+from .form_adjusted_meat_adjustments_directly import (
+    form_adjusted_meat_adjustments_directly,
 )
 from . import input_checks
 from . import get_datum_for_blowup_supervised_learning
@@ -37,9 +37,9 @@ from .helper_functions import (
     calculate_beta_dim,
     collect_all_post_update_betas,
     construct_beta_index_by_policy_num_map,
-    extract_action_and_policy_by_decision_time_by_user_id,
+    extract_action_and_policy_by_decision_time_by_subject_id,
     flatten_params,
-    get_in_study_df_column,
+    get_active_df_column,
     get_min_time_by_policy_num,
     get_radon_nikodym_weight,
     load_function_from_same_named_file,
@@ -61,7 +61,7 @@ def cli():
 
 # TODO: Check all help strings for accuracy.
 # TODO: Deal with NA, -1, etc policy numbers
-# TODO: Make sure in study is never on for more than one stretch EDIT: unclear if
+# TODO: Make sure in deployment is never on for more than one stretch EDIT: unclear if
 # this will remain an invariant as we deal with more complicated data missingness
 # TODO: I think I'm agnostic to indexing of calendar times but should check because
 # otherwise need to add a check here to verify required format.
@@ -69,7 +69,7 @@ def cli():
 # Higher dimensional objects not supported. Not entirely sure what kind of "scalars" apply.
 @cli.command(name="analyze")
 @click.option(
-    "--study_df_pickle",
+    "--analysis_df_pickle",
     type=click.File("rb"),
     help="Pickled pandas dataframe in correct format (see contract/readme).",
     required=True,
@@ -83,7 +83,7 @@ def cli():
 @click.option(
     "--action_prob_func_args_pickle",
     type=click.File("rb"),
-    help="Pickled dictionary that contains the action probability function arguments for all decision times for all users.",
+    help="Pickled dictionary that contains the action probability function arguments for all decision times for all subjects.",
     required=True,
 )
 @click.option(
@@ -95,7 +95,7 @@ def cli():
 @click.option(
     "--alg_update_func_filename",
     type=click.Path(exists=True),
-    help="File that contains the per-user update function used to determine the algorithm parameters at each update and relevant imports. May be a loss or estimating function, specified in a separate argument. The filename without its extension will be assumed to match the function name.",
+    help="File that contains the per-subject update function used to determine the algorithm parameters at each update and relevant imports. May be a loss or estimating function, specified in a separate argument. The filename without its extension will be assumed to match the function name.",
     required=True,
 )
 @click.option(
@@ -107,7 +107,7 @@ def cli():
 @click.option(
     "--alg_update_func_args_pickle",
     type=click.File("rb"),
-    help="Pickled dictionary that contains the algorithm update function arguments for all update times for all users.",
+    help="Pickled dictionary that contains the algorithm update function arguments for all update times for all subjects.",
     required=True,
 )
 @click.option(
@@ -137,7 +137,7 @@ def cli():
 @click.option(
     "--inference_func_filename",
     type=click.Path(exists=True),
-    help="File that contains the per-user loss/estimating function used to determine the inference estimate and relevant imports. The filename without its extension will be assumed to match the function name.",
+    help="File that contains the per-subject loss/estimating function used to determine the inference estimate and relevant imports. The filename without its extension will be assumed to match the function name.",
     required=True,
 )
 @click.option(
@@ -155,56 +155,56 @@ def cli():
 @click.option(
     "--theta_calculation_func_filename",
     type=click.Path(exists=True),
-    help="Path to file that allows one to actually calculate a theta estimate given the study dataframe only. One must supply either this or a precomputed theta estimate. The filename without its extension will be assumed to match the function name.",
+    help="Path to file that allows one to actually calculate a theta estimate given the analysis dataframe only. One must supply either this or a precomputed theta estimate. The filename without its extension will be assumed to match the function name.",
     required=True,
 )
 @click.option(
-    "--in_study_col_name",
+    "--active_col_name",
     type=str,
     required=True,
-    help="Name of the binary column in the study dataframe that indicates whether a user is in the study.",
+    help="Name of the binary column in the analysis dataframe that indicates whether a subject is in the deployment.",
 )
 @click.option(
     "--action_col_name",
     type=str,
     required=True,
-    help="Name of the binary column in the study dataframe that indicates which action was taken.",
+    help="Name of the binary column in the analysis dataframe that indicates which action was taken.",
 )
 @click.option(
     "--policy_num_col_name",
     type=str,
     required=True,
-    help="Name of the column in the study dataframe that indicates the policy number in use.",
+    help="Name of the column in the analysis dataframe that indicates the policy number in use.",
 )
 @click.option(
     "--calendar_t_col_name",
     type=str,
     required=True,
-    help="Name of the column in the study dataframe that indicates calendar time (shared integer index across users).",
+    help="Name of the column in the analysis dataframe that indicates calendar time (shared integer index across subjects).",
 )
 @click.option(
-    "--user_id_col_name",
+    "--subject_id_col_name",
     type=str,
     required=True,
-    help="Name of the column in the study dataframe that indicates user id.",
+    help="Name of the column in the analysis dataframe that indicates subject id.",
 )
 @click.option(
     "--action_prob_col_name",
     type=str,
     required=True,
-    help="Name of the column in the study dataframe that gives action one probabilities.",
+    help="Name of the column in the analysis dataframe that gives action one probabilities.",
 )
 @click.option(
     "--reward_col_name",
     type=str,
     required=True,
-    help="Name of the column in the study dataframe that gives rewards.",
+    help="Name of the column in the analysis dataframe that gives rewards.",
 )
 @click.option(
     "--suppress_interactive_data_checks",
     type=bool,
     default=False,
-    help="Flag to suppress any data checks that require user input. This is suitable for tests and large simulations",
+    help="Flag to suppress any data checks that require subject input. This is suitable for tests and large simulations",
 )
 @click.option(
     "--suppress_all_data_checks",
@@ -217,9 +217,9 @@
     type=click.Choice(
         [
             SmallSampleCorrections.NONE,
-            SmallSampleCorrections.HC1theta,
-            SmallSampleCorrections.HC2theta,
-            SmallSampleCorrections.HC3theta,
+            SmallSampleCorrections.Z1theta,
+            SmallSampleCorrections.Z2theta,
+            SmallSampleCorrections.Z3theta,
         ]
     ),
     default=SmallSampleCorrections.NONE,
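
The small-sample correction choices are renamed from the HC*theta family to Z*theta, and several flags above are renamed as well. A minimal invocation sketch using click's test runner; the import path `lifejacket.cli`, the flag values, and the file paths are illustrative assumptions, not taken from this diff:

```python
from click.testing import CliRunner

from lifejacket.cli import cli  # hypothetical import path for the click group

runner = CliRunner()
result = runner.invoke(
    cli,
    [
        "analyze",
        "--analysis_df_pickle", "analysis_df.pkl",  # was --study_df_pickle in 0.2.1
        "--active_col_name", "active",              # was --in_study_col_name
        "--subject_id_col_name", "subject_id",      # was --user_id_col_name
        # ...remaining required options omitted for brevity...
    ],
)
print(result.exit_code)
```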
@@ -232,23 +232,23 @@
     help="Flag to collect data for supervised learning blowup detection. This will write a single datum and label to a file in the same directory as the input files.",
 )
 @click.option(
-    "--form_adaptive_meat_adjustments_explicitly",
+    "--form_adjusted_meat_adjustments_explicitly",
     type=bool,
     default=False,
-    help="If True, explicitly forms the per-user meat adjustments that differentiate the adaptive sandwich from the classical sandwich. This is for diagnostic purposes, as the adaptive sandwich is formed without doing this.",
+    help="If True, explicitly forms the per-subject meat adjustments that differentiate the adjusted sandwich from the classical sandwich. This is for diagnostic purposes, as the adjusted sandwich is formed without doing this.",
 )
 @click.option(
-    "--stabilize_joint_adaptive_bread_inverse",
+    "--stabilize_joint_bread",
     type=bool,
     default=True,
-    help="If True, stabilizes the joint adaptive bread inverse matrix if it does not meet conditioning thresholds.",
+    help="If True, stabilizes the joint bread matrix if it does not meet conditioning thresholds.",
 )
 def analyze_dataset_wrapper(**kwargs):
     """
     This function is a wrapper around analyze_dataset to facilitate command line use.
 
     From the command line, we will take pickles and filenames for Python objects.
-    Unpickle/load files here for passing to the implementation function, which
+    We unpickle/load files here for passing to the implementation function, which
     may also be called in its own right with in-memory objects.
 
     See analyze_dataset for the underlying details.
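
Since the wrapper unpickles the dataframe and argument dictionaries itself (next hunk), a caller only needs to write ordinary pickles. A minimal sketch with illustrative column names; nothing here is a fixed schema, because the column names are supplied through the `--*_col_name` flags:

```python
import pickle

import pandas as pd

# Illustrative analysis dataframe; column names are free-form and are
# passed to the CLI via the --*_col_name options.
analysis_df = pd.DataFrame(
    {
        "subject_id": [1, 1, 2, 2],
        "calendar_t": [1, 2, 1, 2],
        "active": [1, 1, 1, 1],
        "action": [0, 1, 1, 0],
        "policy_num": [0, 1, 0, 1],
        "action_prob": [0.5, 0.6, 0.5, 0.4],
        "reward": [0.1, 0.3, 0.2, 0.0],
    }
)

# The output files land next to this pickle: the wrapper derives output_dir
# from the parent folder of --analysis_df_pickle.
with open("analysis_df.pkl", "wb") as f:
    pickle.dump(analysis_df, f)
```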
@@ -256,18 +256,20 @@ def analyze_dataset_wrapper(**kwargs):
     Returns: None
     """
 
-    # Pass along the folder the study dataframe is in as the output folder.
-    # Do it now because we will be removing the study dataframe pickle from kwargs.
-    kwargs["output_dir"] = pathlib.Path(kwargs["study_df_pickle"].name).parent.resolve()
+    # Pass along the folder the analysis dataframe is in as the output folder.
+    # Do it now because we will be removing the analysis dataframe pickle from kwargs.
+    kwargs["output_dir"] = pathlib.Path(
+        kwargs["analysis_df_pickle"].name
+    ).parent.resolve()
 
     # Unpickle pickles and replace those args in kwargs
-    kwargs["study_df"] = pickle.load(kwargs["study_df_pickle"])
+    kwargs["analysis_df"] = pickle.load(kwargs["analysis_df_pickle"])
     kwargs["action_prob_func_args"] = pickle.load(
         kwargs["action_prob_func_args_pickle"]
     )
     kwargs["alg_update_func_args"] = pickle.load(kwargs["alg_update_func_args_pickle"])
 
-    kwargs.pop("study_df_pickle")
+    kwargs.pop("analysis_df_pickle")
     kwargs.pop("action_prob_func_args_pickle")
     kwargs.pop("alg_update_func_args_pickle")
 
@@ -295,7 +297,7 @@ def analyze_dataset_wrapper(**kwargs):
 
 def analyze_dataset(
     output_dir: pathlib.Path | str,
-    study_df: pd.DataFrame,
+    analysis_df: pd.DataFrame,
     action_prob_func: Callable,
     action_prob_func_args: dict[int, Any],
     action_prob_func_args_beta_index: int,
@@ -310,35 +312,35 @@ def analyze_dataset(
     inference_func_type: str,
     inference_func_args_theta_index: int,
     theta_calculation_func: Callable[[pd.DataFrame], jnp.ndarray],
-    in_study_col_name: str,
+    active_col_name: str,
     action_col_name: str,
     policy_num_col_name: str,
     calendar_t_col_name: str,
-    user_id_col_name: str,
+    subject_id_col_name: str,
     action_prob_col_name: str,
     reward_col_name: str,
     suppress_interactive_data_checks: bool,
     suppress_all_data_checks: bool,
     small_sample_correction: str,
     collect_data_for_blowup_supervised_learning: bool,
-    form_adaptive_meat_adjustments_explicitly: bool,
-    stabilize_joint_adaptive_bread_inverse: bool,
+    form_adjusted_meat_adjustments_explicitly: bool,
+    stabilize_joint_bread: bool,
 ) -> None:
     """
-    Analyzes a dataset to provide a parameter estimate and an estimate of its variance using adaptive and classical sandwich estimators.
+    Analyzes a dataset to provide a parameter estimate and an estimate of its variance using and classical sandwich estimators.
 
     There are two modes of use for this function.
 
     First, it may be called indirectly from the command line by passing through
-    analyze_dataset.
+    analyze_dataset_wrapper.
 
     Second, it may be called directly from Python code with in-memory objects.
 
     Parameters:
         output_dir (pathlib.Path | str):
             Directory in which to save output files.
-        study_df (pd.DataFrame):
-            DataFrame containing the study data.
+        analysis_df (pd.DataFrame):
+            DataFrame containing the deployment data.
         action_prob_func (callable):
             Action probability function.
         action_prob_func_args (dict[int, Any]):
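
For callers using the direct-call mode, the renames above are purely mechanical. A hypothetical migration helper (not part of the package) mapping the 0.2.1 keyword names to their 1.0.2 equivalents, collected from this diff:

```python
# Keyword renames between 0.2.1 and 1.0.2, collected from this diff.
ANALYZE_DATASET_RENAMES = {
    "study_df": "analysis_df",
    "in_study_col_name": "active_col_name",
    "user_id_col_name": "subject_id_col_name",
    "form_adaptive_meat_adjustments_explicitly": "form_adjusted_meat_adjustments_explicitly",
    "stabilize_joint_adaptive_bread_inverse": "stabilize_joint_bread",
}


def migrate_analyze_dataset_kwargs(old_kwargs: dict) -> dict:
    """Rename 0.2.1-style keyword arguments to their 1.0.2 equivalents."""
    return {ANALYZE_DATASET_RENAMES.get(k, k): v for k, v in old_kwargs.items()}
```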
@@ -364,21 +366,21 @@ def analyze_dataset(
         inference_func_args_theta_index (int):
             Index for theta in inference function arguments.
         theta_calculation_func (callable):
-            Function to estimate theta from the study DataFrame.
-        in_study_col_name (str):
-            Column name indicating if a user is in the study in the study dataframe.
+            Function to estimate theta from the analysis dataframe.
+        active_col_name (str):
+            Column name indicating if a subject is active in the analysis dataframe.
         action_col_name (str):
-            Column name for actions in the study dataframe.
+            Column name for actions in the analysis dataframe.
         policy_num_col_name (str):
-            Column name for policy numbers in the study dataframe.
+            Column name for policy numbers in the analysis dataframe.
         calendar_t_col_name (str):
-            Column name for calendar time in the study dataframe.
-        user_id_col_name (str):
-            Column name for user IDs in the study dataframe.
+            Column name for calendar time in the analysis dataframe.
+        subject_id_col_name (str):
+            Column name for subject IDs in the analysis dataframe.
         action_prob_col_name (str):
-            Column name for action probabilities in the study dataframe.
+            Column name for action probabilities in the analysis dataframe.
         reward_col_name (str):
-            Column name for rewards in the study dataframe.
+            Column name for rewards in the analysis dataframe.
         suppress_interactive_data_checks (bool):
             Whether to suppress interactive data checks. This should be used in simulations, for example.
         suppress_all_data_checks (bool):
@@ -386,17 +388,17 @@
         small_sample_correction (str):
             Type of small sample correction to apply.
         collect_data_for_blowup_supervised_learning (bool):
-            Whether to collect data for doing supervised learning about adaptive sandwich blowup.
-        form_adaptive_meat_adjustments_explicitly (bool):
-            If True, explicitly forms the per-user meat adjustments that differentiate the adaptive
+            Whether to collect data for doing supervised learning about adjusted sandwich blowup.
+        form_adjusted_meat_adjustments_explicitly (bool):
+            If True, explicitly forms the per-subject meat adjustments that differentiate the
             sandwich from the classical sandwich. This is for diagnostic purposes, as the
-            adaptive sandwich is formed without doing this.
-        stabilize_joint_adaptive_bread_inverse (bool):
-            If True, stabilizes the joint adaptive bread inverse matrix if it does not meet conditioning
+            adjusted sandwich is formed without doing this.
+        stabilize_joint_bread (bool):
+            If True, stabilizes the joint bread matrix if it does not meet conditioning
             thresholds.
 
     Returns:
-        dict: A dictionary containing the theta estimate, adaptive sandwich variance estimate, and
+        dict: A dictionary containing the theta estimate, adjusted sandwich variance estimate, and
             classical sandwich variance estimate.
     """
 
@@ -406,19 +408,19 @@ def analyze_dataset(
         level=logging.INFO,
     )
 
-    theta_est = jnp.array(theta_calculation_func(study_df))
+    theta_est = jnp.array(theta_calculation_func(analysis_df))
 
     beta_dim = calculate_beta_dim(
         action_prob_func_args, action_prob_func_args_beta_index
     )
     if not suppress_all_data_checks:
         input_checks.perform_first_wave_input_checks(
-            study_df,
-            in_study_col_name,
+            analysis_df,
+            active_col_name,
             action_col_name,
             policy_num_col_name,
             calendar_t_col_name,
-            user_id_col_name,
+            subject_id_col_name,
             action_prob_col_name,
             reward_col_name,
             action_prob_func,
@@ -436,10 +438,9 @@
         )
 
     ### Begin collecting data structures that will be used to compute the joint bread matrix.
-
     beta_index_by_policy_num, initial_policy_num = (
         construct_beta_index_by_policy_num_map(
-            study_df, policy_num_col_name, in_study_col_name
+            analysis_df, policy_num_col_name, active_col_name
         )
     )
 
@@ -447,11 +448,11 @@
         beta_index_by_policy_num, alg_update_func_args, alg_update_func_args_beta_index
     )
 
-    action_by_decision_time_by_user_id, policy_num_by_decision_time_by_user_id = (
-        extract_action_and_policy_by_decision_time_by_user_id(
-            study_df,
-            user_id_col_name,
-            in_study_col_name,
+    action_by_decision_time_by_subject_id, policy_num_by_decision_time_by_subject_id = (
+        extract_action_and_policy_by_decision_time_by_subject_id(
+            analysis_df,
+            subject_id_col_name,
+            active_col_name,
             calendar_t_col_name,
             action_col_name,
             policy_num_col_name,
@@ -459,45 +460,45 @@
     )
 
     (
-        inference_func_args_by_user_id,
+        inference_func_args_by_subject_id,
         inference_func_args_action_prob_index,
-        inference_action_prob_decision_times_by_user_id,
+        inference_action_prob_decision_times_by_subject_id,
     ) = process_inference_func_args(
         inference_func,
         inference_func_args_theta_index,
-        study_df,
+        analysis_df,
         theta_est,
         action_prob_col_name,
         calendar_t_col_name,
-        user_id_col_name,
-        in_study_col_name,
+        subject_id_col_name,
+        active_col_name,
     )
 
-    # Use a per-user weighted estimating function stacking functino to derive classical and joint
-    # adaptive meat and inverse bread matrices. This is facilitated because the *value* of the
+    # Use a per-subject weighted estimating function stacking function to derive classical and joint
+    # meat and bread matrices. This is facilitated because the *value* of the
     # weighted and unweighted stacks are the same, as the weights evaluate to 1 pre-differentiation.
     logger.info(
-        "Constructing joint adaptive bread inverse matrix, joint adaptive meat matrix, the classical analogs, and the avg estimating function stack across users."
+        "Constructing joint bread matrix, joint meat matrix, the classical analogs, and the avg estimating function stack across subjects."
     )
 
-    user_ids = jnp.array(study_df[user_id_col_name].unique())
+    subject_ids = jnp.array(analysis_df[subject_id_col_name].unique())
     (
-        stabilized_joint_adaptive_bread_inverse_matrix,
-        raw_joint_adaptive_bread_inverse_matrix,
-        joint_adaptive_meat_matrix,
-        joint_adaptive_sandwich_matrix,
-        classical_bread_inverse_matrix,
+        stabilized_joint_adjusted_bread_matrix,
+        raw_joint_adjusted_bread_matrix,
+        joint_adjusted_meat_matrix,
+        joint_adjusted_sandwich_matrix,
+        classical_bread_matrix,
         classical_meat_matrix,
         classical_sandwich_var_estimate,
         avg_estimating_function_stack,
-        per_user_estimating_function_stacks,
-        per_user_adaptive_corrections,
-        per_user_classical_corrections,
-        per_user_adaptive_meat_adjustments,
-    ) = construct_classical_and_adaptive_sandwiches(
+        per_subject_estimating_function_stacks,
+        per_subject_adjusted_corrections,
+        per_subject_classical_corrections,
+        per_subject_adjusted_meat_adjustments,
+    ) = construct_classical_and_adjusted_sandwiches(
         theta_est,
         all_post_update_betas,
-        user_ids,
+        subject_ids,
         action_prob_func,
         action_prob_func_args_beta_index,
         alg_update_func,
@@ -511,23 +512,23 @@ def analyze_dataset(
         inference_func_args_theta_index,
         inference_func_args_action_prob_index,
         action_prob_func_args,
-        policy_num_by_decision_time_by_user_id,
+        policy_num_by_decision_time_by_subject_id,
         initial_policy_num,
         beta_index_by_policy_num,
-        inference_func_args_by_user_id,
-        inference_action_prob_decision_times_by_user_id,
+        inference_func_args_by_subject_id,
+        inference_action_prob_decision_times_by_subject_id,
         alg_update_func_args,
-        action_by_decision_time_by_user_id,
+        action_by_decision_time_by_subject_id,
         suppress_all_data_checks,
         suppress_interactive_data_checks,
         small_sample_correction,
-        form_adaptive_meat_adjustments_explicitly,
-        stabilize_joint_adaptive_bread_inverse,
-        study_df,
-        in_study_col_name,
+        form_adjusted_meat_adjustments_explicitly,
+        stabilize_joint_bread,
+        analysis_df,
+        active_col_name,
         action_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         action_prob_func_args,
         action_prob_col_name,
     )
@@ -543,27 +544,26 @@
 
     # This bottom right corner of the joint (betas and theta) variance matrix is the portion
     # corresponding to just theta.
-    adaptive_sandwich_var_estimate = joint_adaptive_sandwich_matrix[
+    adjusted_sandwich_var_estimate = joint_adjusted_sandwich_matrix[
         -theta_dim:, -theta_dim:
     ]
 
     # Check for negative diagonal elements and set them to zero if found
-    adaptive_diagonal = np.diag(adaptive_sandwich_var_estimate)
-    if np.any(adaptive_diagonal < 0):
+    adjusted_diagonal = np.diag(adjusted_sandwich_var_estimate)
+    if np.any(adjusted_diagonal < 0):
         logger.warning(
-            "Found negative diagonal elements in adaptive sandwich variance estimate. Setting them to zero."
+            "Found negative diagonal elements in adjusted sandwich variance estimate. Setting them to zero."
         )
         np.fill_diagonal(
-            adaptive_sandwich_var_estimate, np.maximum(adaptive_diagonal, 0)
+            adjusted_sandwich_var_estimate, np.maximum(adjusted_diagonal, 0)
         )
 
     logger.info("Writing results to file...")
-    # Write analysis results to same directory that input files are in
     output_folder_abs_path = pathlib.Path(output_dir).resolve()
 
     analysis_dict = {
         "theta_est": theta_est,
-        "adaptive_sandwich_var_estimate": adaptive_sandwich_var_estimate,
+        "adjusted_sandwich_var_estimate": adjusted_sandwich_var_estimate,
         "classical_sandwich_var_estimate": classical_sandwich_var_estimate,
     }
     with open(output_folder_abs_path / "analysis.pkl", "wb") as f:
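
The keys written to analysis.pkl change accordingly. A small sketch of reading the 1.0.2 output; the results directory is a placeholder (the file is written next to the input pickles):

```python
import pathlib
import pickle

output_dir = pathlib.Path("results")  # placeholder: wherever the input pickles live
with open(output_dir / "analysis.pkl", "rb") as f:
    analysis = pickle.load(f)

theta_est = analysis["theta_est"]
adjusted_var = analysis["adjusted_sandwich_var_estimate"]    # was "adaptive_sandwich_var_estimate"
classical_var = analysis["classical_sandwich_var_estimate"]  # unchanged
```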
@@ -572,29 +572,35 @@
             f,
         )
 
-    joint_adaptive_bread_inverse_cond = jnp.linalg.cond(
-        raw_joint_adaptive_bread_inverse_matrix
+    joint_adjusted_bread_cond = jnp.linalg.cond(raw_joint_adjusted_bread_matrix)
+    logger.info(
+        "Joint adjusted bread condition number: %f",
+        joint_adjusted_bread_cond,
     )
+
+    # calculate the max eigenvalue of the joint adjusted sandwich
+    max_eigenvalue = scipy.linalg.eigvalsh(joint_adjusted_sandwich_matrix).max()
     logger.info(
-        "Joint adaptive bread inverse condition number: %f",
-        joint_adaptive_bread_inverse_cond,
+        "Max eigenvalue of joint adjusted sandwich matrix: %f",
+        max_eigenvalue,
     )
 
     debug_pieces_dict = {
         "theta_est": theta_est,
-        "adaptive_sandwich_var_estimate": adaptive_sandwich_var_estimate,
+        "adjusted_sandwich_var_estimate": adjusted_sandwich_var_estimate,
         "classical_sandwich_var_estimate": classical_sandwich_var_estimate,
-        "raw_joint_bread_inverse_matrix": raw_joint_adaptive_bread_inverse_matrix,
-        "stabilized_joint_bread_inverse_matrix": stabilized_joint_adaptive_bread_inverse_matrix,
-        "joint_meat_matrix": joint_adaptive_meat_matrix,
-        "classical_bread_inverse_matrix": classical_bread_inverse_matrix,
+        "raw_joint_bread_matrix": raw_joint_adjusted_bread_matrix,
+        "stabilized_joint_bread_matrix": stabilized_joint_adjusted_bread_matrix,
+        "joint_meat_matrix": joint_adjusted_meat_matrix,
+        "classical_bread_matrix": classical_bread_matrix,
         "classical_meat_matrix": classical_meat_matrix,
-        "all_estimating_function_stacks": per_user_estimating_function_stacks,
-        "joint_bread_inverse_condition_number": joint_adaptive_bread_inverse_cond,
+        "all_estimating_function_stacks": per_subject_estimating_function_stacks,
+        "joint_bread_condition_number": joint_adjusted_bread_cond,
+        "max_eigenvalue_joint_adjusted_sandwich": max_eigenvalue,
         "all_post_update_betas": all_post_update_betas,
-        "per_user_adaptive_corrections": per_user_adaptive_corrections,
-        "per_user_classical_corrections": per_user_classical_corrections,
-        "per_user_adaptive_meat_adjustments": per_user_adaptive_meat_adjustments,
+        "per_subject_adjusted_corrections": per_subject_adjusted_corrections,
+        "per_subject_classical_corrections": per_subject_classical_corrections,
+        "per_subject_adjusted_meat_adjustments": per_subject_adjusted_meat_adjustments,
     }
     with open(output_folder_abs_path / "debug_pieces.pkl", "wb") as f:
         pickle.dump(
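
The two diagnostics logged above are a condition number and, new in 1.0.2, a maximum eigenvalue. A standalone sketch of the same calls on a toy symmetric matrix standing in for the joint adjusted sandwich:

```python
import jax.numpy as jnp
import scipy.linalg

# Toy symmetric stand-in for the joint adjusted sandwich matrix.
sandwich = jnp.array([[4.0, 1.0], [1.0, 3.0]])

# Condition number: ratio of the largest to smallest singular value.
cond = jnp.linalg.cond(sandwich)

# eigvalsh returns the eigenvalues of a symmetric matrix in ascending order,
# so the last (max) one is what gets logged.
max_eigenvalue = scipy.linalg.eigvalsh(sandwich).max()

print(f"condition number: {cond:.3f}, max eigenvalue: {max_eigenvalue:.3f}")
```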
@@ -604,25 +610,25 @@
 
     if collect_data_for_blowup_supervised_learning:
         datum_and_label_dict = get_datum_for_blowup_supervised_learning.get_datum_for_blowup_supervised_learning(
-            raw_joint_adaptive_bread_inverse_matrix,
-            joint_adaptive_bread_inverse_cond,
+            raw_joint_adjusted_bread_matrix,
+            joint_adjusted_bread_cond,
             avg_estimating_function_stack,
-            per_user_estimating_function_stacks,
+            per_subject_estimating_function_stacks,
             all_post_update_betas,
-            study_df,
-            in_study_col_name,
+            analysis_df,
+            active_col_name,
             calendar_t_col_name,
             action_prob_col_name,
-            user_id_col_name,
+            subject_id_col_name,
             reward_col_name,
             theta_est,
-            adaptive_sandwich_var_estimate,
-            user_ids,
+            adjusted_sandwich_var_estimate,
+            subject_ids,
             beta_dim,
             theta_dim,
             initial_policy_num,
             beta_index_by_policy_num,
-            policy_num_by_decision_time_by_user_id,
+            policy_num_by_decision_time_by_subject_id,
             theta_calculation_func,
             action_prob_func,
             action_prob_func_args_beta_index,
@@ -630,16 +636,16 @@
             inference_func_type,
             inference_func_args_theta_index,
             inference_func_args_action_prob_index,
-            inference_action_prob_decision_times_by_user_id,
+            inference_action_prob_decision_times_by_subject_id,
             action_prob_func_args,
-            action_by_decision_time_by_user_id,
+            action_by_decision_time_by_subject_id,
         )
 
         with open(output_folder_abs_path / "supervised_learning_datum.pkl", "wb") as f:
             pickle.dump(datum_and_label_dict, f)
 
     print(f"\nParameter estimate:\n {theta_est}")
-    print(f"\nAdaptive sandwich variance estimate:\n {adaptive_sandwich_var_estimate}")
+    print(f"\nAdjusted sandwich variance estimate:\n {adjusted_sandwich_var_estimate}")
     print(
         f"\nClassical sandwich variance estimate:\n {classical_sandwich_var_estimate}\n"
     )
@@ -650,15 +656,15 @@
 def process_inference_func_args(
     inference_func: callable,
     inference_func_args_theta_index: int,
-    study_df: pd.DataFrame,
+    analysis_df: pd.DataFrame,
     theta_est: jnp.ndarray,
     action_prob_col_name: str,
     calendar_t_col_name: str,
-    user_id_col_name: str,
-    in_study_col_name: str,
+    subject_id_col_name: str,
+    active_col_name: str,
 ) -> tuple[dict[collections.abc.Hashable, tuple[Any, ...]], int]:
     """
-    Collects the inference function arguments for each user from the study DataFrame.
+    Collects the inference function arguments for each subject from the analysis DataFrame.
 
     Note that theta and action probabilities, if present, will be replaced later
     so that the function can be differentiated with respect to shared versions
@@ -669,32 +675,32 @@ def process_inference_func_args(
             The inference function to be used.
         inference_func_args_theta_index (int):
             The index of the theta parameter in the inference function's arguments.
-        study_df (pandas.DataFrame):
-            The study DataFrame.
+        analysis_df (pandas.DataFrame):
+            The analysis DataFrame.
         theta_est (jnp.ndarray):
             The estimate of the parameter vector.
         action_prob_col_name (str):
-            The name of the column in the study DataFrame that gives action probabilities.
+            The name of the column in the analysis DataFrame that gives action probabilities.
         calendar_t_col_name (str):
-            The name of the column in the study DataFrame that indicates calendar time.
-        user_id_col_name (str):
-            The name of the column in the study DataFrame that indicates user ID.
-        in_study_col_name (str):
-            The name of the binary column in the study DataFrame that indicates whether a user is in the study.
+            The name of the column in the analysis DataFrame that indicates calendar time.
+        subject_id_col_name (str):
+            The name of the column in the analysis DataFrame that indicates subject ID.
+        active_col_name (str):
+            The name of the binary column in the analysis DataFrame that indicates whether a subject is in the deployment.
     Returns:
         tuple[dict[collections.abc.Hashable, tuple[Any, ...]], int, dict[collections.abc.Hashable, jnp.ndarray[int]]]:
             A tuple containing
-            - the inference function arguments dictionary for each user
+            - the inference function arguments dictionary for each subject
             - the index of the action probabilities argument
-            - a dictionary mapping user IDs to the decision times to which action probabilities correspond
+            - a dictionary mapping subject IDs to the decision times to which action probabilities correspond
     """
 
     num_args = inference_func.__code__.co_argcount
     inference_func_arg_names = inference_func.__code__.co_varnames[:num_args]
-    inference_func_args_by_user_id = {}
+    inference_func_args_by_subject_id = {}
 
     inference_func_args_action_prob_index = -1
-    inference_action_prob_decision_times_by_user_id = {}
+    inference_action_prob_decision_times_by_subject_id = {}
 
     using_action_probs = action_prob_col_name in inference_func_arg_names
     if using_action_probs:
702
708
  action_prob_col_name
703
709
  )
704
710
 
705
- for user_id in study_df[user_id_col_name].unique():
706
- user_args_list = []
707
- filtered_user_data = study_df.loc[study_df[user_id_col_name] == user_id]
711
+ for subject_id in analysis_df[subject_id_col_name].unique():
712
+ subject_args_list = []
713
+ filtered_subject_data = analysis_df.loc[
714
+ analysis_df[subject_id_col_name] == subject_id
715
+ ]
708
716
  for idx, col_name in enumerate(inference_func_arg_names):
709
717
  if idx == inference_func_args_theta_index:
710
- user_args_list.append(theta_est)
718
+ subject_args_list.append(theta_est)
711
719
  continue
712
- user_args_list.append(
713
- get_in_study_df_column(filtered_user_data, col_name, in_study_col_name)
720
+ subject_args_list.append(
721
+ get_active_df_column(filtered_subject_data, col_name, active_col_name)
714
722
  )
715
- inference_func_args_by_user_id[user_id] = tuple(user_args_list)
723
+ inference_func_args_by_subject_id[subject_id] = tuple(subject_args_list)
716
724
  if using_action_probs:
717
- inference_action_prob_decision_times_by_user_id[user_id] = (
718
- get_in_study_df_column(
719
- filtered_user_data, calendar_t_col_name, in_study_col_name
725
+ inference_action_prob_decision_times_by_subject_id[subject_id] = (
726
+ get_active_df_column(
727
+ filtered_subject_data, calendar_t_col_name, active_col_name
720
728
  )
721
729
  )
722
730
 
723
731
  return (
724
- inference_func_args_by_user_id,
732
+ inference_func_args_by_subject_id,
725
733
  inference_func_args_action_prob_index,
726
- inference_action_prob_decision_times_by_user_id,
734
+ inference_action_prob_decision_times_by_subject_id,
727
735
  )
728
736
 
729
737
 
730
- def single_user_weighted_estimating_function_stacker(
738
+ def single_subject_weighted_estimating_function_stacker(
731
739
  beta_dim: int,
732
- user_id: collections.abc.Hashable,
740
+ subject_id: collections.abc.Hashable,
733
741
  action_prob_func: callable,
734
742
  algorithm_estimating_func: callable,
735
743
  inference_estimating_func: callable,
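
process_inference_func_args matches dataframe columns to the inference function's positional parameter names via code-object introspection; the column renames do not change that mechanism. A standalone illustration of the lookup:

```python
def inference_func(theta, reward, action_prob):
    # Toy per-subject estimating function; parameter names other than theta
    # must match column names in the analysis dataframe.
    return reward - action_prob


# Mirrors the introspection above: positional parameter names are read off
# the function's code object, and each name except theta is looked up as a
# dataframe column.
num_args = inference_func.__code__.co_argcount
arg_names = inference_func.__code__.co_varnames[:num_args]
assert arg_names == ("theta", "reward", "action_prob")
```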
@@ -763,12 +771,12 @@ def single_user_weighted_estimating_function_stacker(
         beta_dim (list[jnp.ndarray]):
             A list of 1D JAX NumPy arrays corresponding to the betas produced by all updates.
 
-        user_id (collections.abc.Hashable):
-            The user ID for which to compute the weighted estimating function stack.
+        subject_id (collections.abc.Hashable):
+            The subject ID for which to compute the weighted estimating function stack.
 
         action_prob_func (callable):
             The function used to compute the probability of action 1 at a given decision time for
-            a particular user given their state and the algorithm parameters.
+            a particular subject given their state and the algorithm parameters.
 
         algorithm_estimating_func (callable):
             The estimating function that corresponds to algorithm updates.
@@ -783,9 +791,9 @@
             The index of the theta parameter in the inference loss or estimating function arguments.
 
         action_prob_func_args_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
-            A map from decision times to tuples of arguments for this user for the action
+            A map from decision times to tuples of arguments for this subject for the action
             probability function. This is for all decision times (args are an empty
-            tuple if they are not in the study). Should be sorted by decision time. NOTE THAT THESE
+            tuple if they are not in the deployment). Should be sorted by decision time. NOTE THAT THESE
             ARGS DO NOT CONTAIN THE SHARED BETAS, making them impervious to the differentiation that
             will occur.
 
@@ -796,21 +804,21 @@
 
         threaded_update_func_args_by_policy_num (dict[int | float, dict[collections.abc.Hashable, tuple[Any, ...]]]):
             A map from policy numbers to tuples containing the arguments for
-            the corresponding estimating functions for this user, with the shared betas threaded in
+            the corresponding estimating functions for this subject, with the shared betas threaded in
             for differentiation. This is for all non-initial, non-fallback policies. Policy numbers
             should be sorted.
 
         threaded_inference_func_args (dict[collections.abc.Hashable, tuple[Any, ...]]):
             A tuple containing the arguments for the inference
-            estimating function for this user, with the shared betas threaded in for differentiation.
+            estimating function for this subject, with the shared betas threaded in for differentiation.
 
         policy_num_by_decision_time (dict[collections.abc.Hashable, dict[int, int | float]]):
             A dictionary mapping decision times to the policy number in use. This may be
-            user-specific. Should be sorted by decision time. Only applies to in-study decision
+            subject-specific. Should be sorted by decision time. Only applies to active decision
             times!
 
         action_by_decision_time (dict[collections.abc.Hashable, dict[int, int]]):
-            A dictionary mapping decision times to actions taken. Only applies to in-study decision
+            A dictionary mapping decision times to actions taken. Only applies to active decision
             times!
 
         beta_index_by_policy_num (dict[int | float, int]):
@@ -818,19 +826,21 @@
             all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
 
     Returns:
-        jnp.ndarray: A 1-D JAX NumPy array representing the user's weighted estimating function
+        jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted estimating function
             stack.
-        jnp.ndarray: A 2-D JAX NumPy matrix representing the user's adaptive meat contribution.
-        jnp.ndarray: A 2-D JAX NumPy matrix representing the user's classical meat contribution.
-        jnp.ndarray: A 2-D JAX NumPy matrix representing the user's classical bread contribution.
+        jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's adjusted meat contribution.
+        jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical meat contribution.
+        jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical bread contribution.
     """
 
-    logger.info("Computing weighted estimating function stack for user %s.", user_id)
+    logger.info(
+        "Computing weighted estimating function stack for subject %s.", subject_id
+    )
 
     # First, reformat the supplied data into more convenient structures.
 
     # 1. Form a dictionary mapping policy numbers to the first time they were
-    # applicable (for this user). Note that this includes ALL policies, initial
+    # applicable (for this subject). Note that this includes ALL policies, initial
     # fallbacks included.
     # Collect the first time after the first update separately for convenience.
     # These are both used to form the Radon-Nikodym weights for the right times.
@@ -839,38 +849,38 @@
         beta_index_by_policy_num,
     )
 
-    # 2. Get the start and end times for this user.
-    user_start_time = math.inf
-    user_end_time = -math.inf
+    # 2. Get the start and end times for this subject.
+    subject_start_time = math.inf
+    subject_end_time = -math.inf
     for decision_time in action_by_decision_time:
-        user_start_time = min(user_start_time, decision_time)
-        user_end_time = max(user_end_time, decision_time)
+        subject_start_time = min(subject_start_time, decision_time)
+        subject_end_time = max(subject_end_time, decision_time)
 
     # 3. Form a stack of weighted estimating equations, one for each update of the algorithm.
     logger.info(
-        "Computing the algorithm component of the weighted estimating function stack for user %s.",
-        user_id,
+        "Computing the algorithm component of the weighted estimating function stack for subject %s.",
+        subject_id,
     )
 
-    in_study_action_prob_func_args = [
+    active_action_prob_func_args = [
         args for args in action_prob_func_args_by_decision_time.values() if args
     ]
-    in_study_betas_list_by_decision_time_index = jnp.array(
+    active_betas_list_by_decision_time_index = jnp.array(
         [
             action_prob_func_args[action_prob_func_args_beta_index]
-            for action_prob_func_args in in_study_action_prob_func_args
+            for action_prob_func_args in active_action_prob_func_args
         ]
     )
-    in_study_actions_list_by_decision_time_index = jnp.array(
+    active_actions_list_by_decision_time_index = jnp.array(
         list(action_by_decision_time.values())
     )
 
     # Sort the threaded args by decision time to be cautious. We check if the
-    # user id is present in the user args dict because we may call this on a
-    # subset of the user arg dict when we are batching arguments by shape
+    # subject id is present in the subject args dict because we may call this on a
+    # subset of the subject arg dict when we are batching arguments by shape
    sorted_threaded_action_prob_args_by_decision_time = {
         decision_time: threaded_action_prob_func_args_by_decision_time[decision_time]
-        for decision_time in range(user_start_time, user_end_time + 1)
+        for decision_time in range(subject_start_time, subject_end_time + 1)
         if decision_time in threaded_action_prob_func_args_by_decision_time
     }
 
@@ -901,19 +911,19 @@
     # Just grab the original beta from the update function arguments. This is the same
     # value, but impervious to differentiation with respect to all_post_update_betas. The
     # args, on the other hand, are a function of all_post_update_betas.
-    in_study_weights = jax.vmap(
+    active_weights = jax.vmap(
         fun=get_radon_nikodym_weight,
         in_axes=[0, None, None, 0] + batch_axes,
         out_axes=0,
     )(
-        in_study_betas_list_by_decision_time_index,
+        active_betas_list_by_decision_time_index,
         action_prob_func,
         action_prob_func_args_beta_index,
-        in_study_actions_list_by_decision_time_index,
+        active_actions_list_by_decision_time_index,
         *batched_threaded_arg_tensors,
     )
 
-    in_study_index = 0
+    active_index = 0
     decision_time_to_all_weights_index_offset = min(
         sorted_threaded_action_prob_args_by_decision_time
     )
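
The weights are computed in one batched call: jax.vmap maps axis 0 over the per-decision-time betas and actions while broadcasting shared arguments with None axes. A toy, self-contained version of that pattern (the weight function here is a stand-in, not get_radon_nikodym_weight):

```python
import jax
import jax.numpy as jnp


def toy_weight(temperature, beta, action, logged_prob):
    """Stand-in importance weight: P_beta(action) / logged_prob."""
    p1 = jax.nn.sigmoid(jnp.sum(beta) / temperature)  # action-1 probability under beta
    p = jnp.where(action == 1, p1, 1.0 - p1)
    return p / logged_prob


betas = jnp.ones((5, 2))                 # one beta per decision time
actions = jnp.array([1, 0, 1, 1, 0])
logged_probs = jnp.full(5, 0.5)

# in_axes mixes None (shared, broadcast) with 0 (mapped over decision times),
# mirroring the [0, None, None, 0] + batch_axes pattern above.
weights = jax.vmap(toy_weight, in_axes=(None, 0, 0, 0))(
    1.0, betas, actions, logged_probs
)
```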
@@ -922,35 +932,35 @@
         decision_time,
         args,
     ) in sorted_threaded_action_prob_args_by_decision_time.items():
-        all_weights_raw.append(in_study_weights[in_study_index] if args else 1.0)
-        in_study_index += 1
+        all_weights_raw.append(active_weights[active_index] if args else 1.0)
+        active_index += 1
     all_weights = jnp.array(all_weights_raw)
 
     algorithm_component = jnp.concatenate(
         [
             # Here we compute a product of Radon-Nikodym weights
             # for all decision times after the first update and before the update
-            # update under consideration took effect, for which the user was in the study.
+            # update under consideration took effect, for which the subject was in the deployment.
             (
                 jnp.prod(
                     all_weights[
-                        # The earliest time after the first update where the user was in
-                        # the study
+                        # The earliest time after the first update where the subject was in
+                        # the deployment
                         max(
                             first_time_after_first_update,
-                            user_start_time,
+                            subject_start_time,
                         )
                         - decision_time_to_all_weights_index_offset :
-                        # One more than the latest time the user was in the study before the time
+                        # One more than the latest time the subject was in the deployment before the time
                         # the update under consideration first applied. Note the + 1 because range
                         # does not include the right endpoint.
                         min(
                             min_time_by_policy_num.get(policy_num, math.inf),
-                            user_end_time + 1,
+                            subject_end_time + 1,
                         )
                         - decision_time_to_all_weights_index_offset,
                     ]
-                # If the user exited the study before there were any updates,
+                # If the subject exited the deployment before there were any updates,
                 # this variable will be None and the above code to grab a weight would
                 # throw an error. Just use 1 to include the unweighted estimating function
                 # if they have data to contribute to the update.
@@ -958,8 +968,8 @@
                 else 1
             )  # Now use the above to weight the alg estimating function for this update
             * algorithm_estimating_func(*update_args)
-            # If there are no arguments for the update function, the user is not yet in the
-            # study, so we just add a zero vector contribution to the sum across users.
+            # If there are no arguments for the update function, the subject is not yet in the
+            # deployment, so we just add a zero vector contribution to the sum across subjects.
             # Note that after they exit, they still contribute all their data to later
             # updates.
             if update_args
@@ -978,17 +988,17 @@
     )
     # 4. Form the weighted inference estimating equation.
     logger.info(
-        "Computing the inference component of the weighted estimating function stack for user %s.",
-        user_id,
+        "Computing the inference component of the weighted estimating function stack for subject %s.",
+        subject_id,
     )
     inference_component = jnp.prod(
         all_weights[
-            max(first_time_after_first_update, user_start_time)
-            - decision_time_to_all_weights_index_offset : user_end_time
+            max(first_time_after_first_update, subject_start_time)
+            - decision_time_to_all_weights_index_offset : subject_end_time
             + 1
             - decision_time_to_all_weights_index_offset,
         ]
-        # If the user exited the study before there were any updates,
+        # If the subject exited the deployment before there were any updates,
         # this variable will be None and the above code to grab a weight would
         # throw an error. Just use 1 to include the unweighted estimating function
         # if they have data to contribute here (pretty sure everyone should?)
@@ -997,18 +1007,18 @@
     ) * inference_estimating_func(*threaded_inference_func_args)
 
     # 5. Concatenate the two components to form the weighted estimating function stack for this
-    # user.
+    # subject.
     weighted_stack = jnp.concatenate([algorithm_component, inference_component])
 
     # 6. Return the following outputs:
-    # a. The first is simply the weighted estimating function stack for this user. The average
-    # of these is what we differentiate with respect to theta to form the inverse adaptive joint
+    # a. The first is simply the weighted estimating function stack for this subject. The average
+    # of these is what we differentiate with respect to theta to form the joint
     # bread matrix, and we also compare that average to zero to check the estimating functions'
     # fidelity.
-    # b. The average outer product of these per-user stacks across users is the adaptive joint meat
+    # b. The average outer product of these per-subject stacks across subjects is the adjusted joint meat
     # matrix, hence the second output.
-    # c. The third output is averaged across users to obtain the classical meat matrix.
-    # d. The fourth output is averaged across users to obtain the inverse classical bread
+    # c. The third output is averaged across subjects to obtain the classical meat matrix.
+    # d. The fourth output is averaged across subjects to obtain the inverse classical bread
     # matrix.
     return (
         weighted_stack,
@@ -1024,7 +1034,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
     flattened_betas_and_theta: jnp.ndarray,
     beta_dim: int,
     theta_dim: int,
-    user_ids: jnp.ndarray,
+    subject_ids: jnp.ndarray,
     action_prob_func: callable,
     action_prob_func_args_beta_index: int,
     alg_update_func: callable,
@@ -1037,30 +1047,32 @@
     inference_func_type: str,
     inference_func_args_theta_index: int,
     inference_func_args_action_prob_index: int,
-    action_prob_func_args_by_user_id_by_decision_time: dict[
+    action_prob_func_args_by_subject_id_by_decision_time: dict[
         collections.abc.Hashable, dict[int, tuple[Any, ...]]
     ],
-    policy_num_by_decision_time_by_user_id: dict[
+    policy_num_by_decision_time_by_subject_id: dict[
         collections.abc.Hashable, dict[int, int | float]
     ],
     initial_policy_num: int | float,
     beta_index_by_policy_num: dict[int | float, int],
-    inference_func_args_by_user_id: dict[collections.abc.Hashable, tuple[Any, ...]],
-    inference_action_prob_decision_times_by_user_id: dict[
+    inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
+    inference_action_prob_decision_times_by_subject_id: dict[
         collections.abc.Hashable, list[int]
     ],
-    update_func_args_by_by_user_id_by_policy_num: dict[
+    update_func_args_by_by_subject_id_by_policy_num: dict[
         collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
     ],
-    action_by_decision_time_by_user_id: dict[collections.abc.Hashable, dict[int, int]],
+    action_by_decision_time_by_subject_id: dict[
+        collections.abc.Hashable, dict[int, int]
+    ],
     suppress_all_data_checks: bool,
     suppress_interactive_data_checks: bool,
 ) -> tuple[
     jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
 ]:
     """
-    Computes the average weighted estimating function stack across all users, along with
-    auxiliary values used to construct the adaptive and classical sandwich variances.
+    Computes the average weighted estimating function stack across all subjects, along with
+    auxiliary values used to construct the adjusted and classical sandwich variances.
 
     Args:
         flattened_betas_and_theta (jnp.ndarray):
@@ -1071,8 +1083,8 @@
             The dimension of each of the beta parameters.
         theta_dim (int):
             The dimension of the theta parameter.
-        user_ids (jnp.ndarray):
-            A 1D JAX NumPy array of user IDs.
+        subject_ids (jnp.ndarray):
+            A 1D JAX NumPy array of subject IDs.
         action_prob_func (callable):
             The action probability function.
         action_prob_func_args_beta_index (int):
@@ -1100,29 +1112,29 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
         inference_func_args_action_prob_index (int):
             The index of action probabilities in the inference function arguments tuple, if
             applicable. -1 otherwise.
-        action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
-            A dictionary mapping decision times to maps of user ids to the function arguments
-            required to compute action probabilities for this user.
-        policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
-            A map of user ids to dictionaries mapping decision times to the policy number in use.
-            Only applies to in-study decision times!
+        action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
+            A dictionary mapping decision times to maps of subject ids to the function arguments
+            required to compute action probabilities for this subject.
+        policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
+            A map of subject ids to dictionaries mapping decision times to the policy number in use.
+            Only applies to active decision times!
         initial_policy_num (int | float):
             The policy number of the initial policy before any updates.
         beta_index_by_policy_num (dict[int | float, int]):
             A dictionary mapping policy numbers to the index of the corresponding beta in
             all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
-        inference_func_args_by_user_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
-            A dictionary mapping user IDs to their respective inference function arguments.
-        inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
-            For each user, a list of decision times to which action probabilities correspond if
-            provided. Typically just in-study times if action probabilites are used in the inference
+        inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
+            A dictionary mapping subject IDs to their respective inference function arguments.
+        inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
+            For each subject, a list of decision times to which action probabilities correspond if
+            provided. Typically just active times if action probabilites are used in the inference
             loss or estimating function.
-        update_func_args_by_by_user_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
-            A dictionary where keys are policy numbers and values are dictionaries mapping user IDs
+        update_func_args_by_by_subject_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
+            A dictionary where keys are policy numbers and values are dictionaries mapping subject IDs
             to their respective update function arguments.
-        action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
-            A dictionary mapping user IDs to their respective actions taken at each decision time.
-            Only applies to in-study decision times!
+        action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
+            A dictionary mapping subject IDs to their respective actions taken at each decision time.
+            Only applies to active decision times!
         suppress_all_data_checks (bool):
             If True, suppresses carrying out any data checks at all.
         suppress_interactive_data_checks (bool):
  suppress_interactive_data_checks (bool):
@@ -1136,10 +1148,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1136
1148
  tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
1137
1149
  A tuple containing
1138
1150
  1. the average weighted estimating function stack
1139
- 2. the user-level adaptive meat matrix contributions
1140
- 3. the user-level classical meat matrix contributions
1141
- 4. the user-level inverse classical bread matrix contributions
1142
- 5. raw per-user weighted estimating function
1151
+ 2. the subject-level adjusted meat matrix contributions
1152
+ 3. the subject-level classical meat matrix contributions
1153
+ 4. the subject-level inverse classical bread matrix contributions
1154
+ 5. raw per-subject weighted estimating function
1143
1155
  stacks.
1144
1156
  """
1145
1157
 
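The nested-dictionary arguments above appear to follow a right-to-left naming convention: the last "by_X" component is the outermost key. A minimal sketch of the shapes with toy, purely illustrative values (subject IDs, times, and policy numbers are hypothetical):

    # Toy values for illustration only.
    # subject_id -> decision_time -> policy number, active decision times only.
    policy_num_by_decision_time_by_subject_id = {
        "subject_a": {0: 0, 1: 0, 2: 1},  # policy 1 first in use at time 2
        "subject_b": {1: 0, 2: 1, 3: 2},  # active over a different stretch
    }

    # subject_id -> decision_time -> binary action, active decision times only.
    action_by_decision_time_by_subject_id = {
        "subject_a": {0: 1, 1: 0, 2: 1},
        "subject_b": {1: 0, 2: 1, 3: 1},
    }

    # Non-initial, non-fallback policy number -> row of all_post_update_betas.
    beta_index_by_policy_num = {1: 0, 2: 1}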
@@ -1166,15 +1178,15 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1166
1178
  # supplied for the above functions, so that differentiation works correctly. The existing
1167
1179
  # values should be the same, but not connected to the parameter we are differentiating
1168
1180
  # with respect to. Note we will also find it useful below to have the action probability args
1169
- # nested dict structure flipped to be user_id -> decision_time -> args, so we do that here too.
1181
+ # nested dict structure flipped to be subject_id -> decision_time -> args, so we do that here too.
1170
1182
 
1171
- logger.info("Threading in betas to action probability arguments for all users.")
1183
+ logger.info("Threading in betas to action probability arguments for all subjects.")
1172
1184
  (
1173
- threaded_action_prob_func_args_by_decision_time_by_user_id,
1174
- action_prob_func_args_by_decision_time_by_user_id,
1185
+ threaded_action_prob_func_args_by_decision_time_by_subject_id,
1186
+ action_prob_func_args_by_decision_time_by_subject_id,
1175
1187
  ) = thread_action_prob_func_args(
1176
- action_prob_func_args_by_user_id_by_decision_time,
1177
- policy_num_by_decision_time_by_user_id,
1188
+ action_prob_func_args_by_subject_id_by_decision_time,
1189
+ policy_num_by_decision_time_by_subject_id,
1178
1190
  initial_policy_num,
1179
1191
  betas,
1180
1192
  beta_index_by_policy_num,
@@ -1186,17 +1198,17 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1186
1198
  # arguments with the central betas introduced.
1187
1199
  logger.info(
1188
1200
  "Threading in betas and beta-dependent action probabilities to algorithm update "
1189
- "function args for all users"
1201
+ "function args for all subjects"
1190
1202
  )
1191
- threaded_update_func_args_by_policy_num_by_user_id = thread_update_func_args(
1192
- update_func_args_by_by_user_id_by_policy_num,
1203
+ threaded_update_func_args_by_policy_num_by_subject_id = thread_update_func_args(
1204
+ update_func_args_by_by_subject_id_by_policy_num,
1193
1205
  betas,
1194
1206
  beta_index_by_policy_num,
1195
1207
  alg_update_func_args_beta_index,
1196
1208
  alg_update_func_args_action_prob_index,
1197
1209
  alg_update_func_args_action_prob_times_index,
1198
1210
  alg_update_func_args_previous_betas_index,
1199
- threaded_action_prob_func_args_by_decision_time_by_user_id,
1211
+ threaded_action_prob_func_args_by_decision_time_by_subject_id,
1200
1212
  action_prob_func,
1201
1213
  )
1202
1214
 
@@ -1206,8 +1218,8 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1206
1218
  if not suppress_all_data_checks and alg_update_func_args_action_prob_index >= 0:
1207
1219
  input_checks.require_threaded_algorithm_estimating_function_args_equivalent(
1208
1220
  algorithm_estimating_func,
1209
- update_func_args_by_by_user_id_by_policy_num,
1210
- threaded_update_func_args_by_policy_num_by_user_id,
1221
+ update_func_args_by_by_subject_id_by_policy_num,
1222
+ threaded_update_func_args_by_policy_num_by_subject_id,
1211
1223
  suppress_interactive_data_checks,
1212
1224
  )
1213
1225
 
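The "threading" described above swaps each stored beta for the central copy being differentiated: the values are identical, but the threaded copy participates in the autodiff graph. A hedged sketch of the idea, using a hypothetical helper that is not part of this package's API:

    import jax.numpy as jnp

    def thread_beta(args: tuple, beta_index: int, central_beta: jnp.ndarray) -> tuple:
        # Replace the stored beta in an argument tuple with the central,
        # differentiable copy; the surrounding values are untouched.
        return args[:beta_index] + (central_beta,) + args[beta_index + 1 :]

    stored_args = (jnp.ones(3), jnp.array([0.1, 0.2]))  # (features, beta)
    central_beta = jnp.array([0.1, 0.2])                # same values, new node
    threaded_args = thread_beta(stored_args, 1, central_beta)
    # Value parity is exactly what the equivalence checks above verify.
    assert bool(jnp.all(threaded_args[1] == stored_args[1]))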
@@ -1216,15 +1228,15 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1216
1228
  # arguments with the central betas introduced.
1217
1229
  logger.info(
1218
1230
  "Threading in theta and beta-dependent action probabilities to inference update "
1219
- "function args for all users"
1231
+ "function args for all subjects"
1220
1232
  )
1221
- threaded_inference_func_args_by_user_id = thread_inference_func_args(
1222
- inference_func_args_by_user_id,
1233
+ threaded_inference_func_args_by_subject_id = thread_inference_func_args(
1234
+ inference_func_args_by_subject_id,
1223
1235
  inference_func_args_theta_index,
1224
1236
  theta,
1225
1237
  inference_func_args_action_prob_index,
1226
- threaded_action_prob_func_args_by_decision_time_by_user_id,
1227
- inference_action_prob_decision_times_by_user_id,
1238
+ threaded_action_prob_func_args_by_decision_time_by_subject_id,
1239
+ inference_action_prob_decision_times_by_subject_id,
1228
1240
  action_prob_func,
1229
1241
  )
1230
1242
 
@@ -1234,32 +1246,32 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1234
1246
  if not suppress_all_data_checks and inference_func_args_action_prob_index >= 0:
1235
1247
  input_checks.require_threaded_inference_estimating_function_args_equivalent(
1236
1248
  inference_estimating_func,
1237
- inference_func_args_by_user_id,
1238
- threaded_inference_func_args_by_user_id,
1249
+ inference_func_args_by_subject_id,
1250
+ threaded_inference_func_args_by_subject_id,
1239
1251
  suppress_interactive_data_checks,
1240
1252
  )
1241
1253
 
1242
- # 5. Now we can compute the weighted estimating function stacks for all users
1243
- # as well as collect related values used to construct the adaptive and classical
1254
+ # 5. Now we can compute the weighted estimating function stacks for all subjects
1255
+ # as well as collect related values used to construct the adjusted and classical
1244
1256
  # sandwich variances.
1245
1257
  results = [
1246
- single_user_weighted_estimating_function_stacker(
1258
+ single_subject_weighted_estimating_function_stacker(
1247
1259
  beta_dim,
1248
- user_id,
1260
+ subject_id,
1249
1261
  action_prob_func,
1250
1262
  algorithm_estimating_func,
1251
1263
  inference_estimating_func,
1252
1264
  action_prob_func_args_beta_index,
1253
1265
  inference_func_args_theta_index,
1254
- action_prob_func_args_by_decision_time_by_user_id[user_id],
1255
- threaded_action_prob_func_args_by_decision_time_by_user_id[user_id],
1256
- threaded_update_func_args_by_policy_num_by_user_id[user_id],
1257
- threaded_inference_func_args_by_user_id[user_id],
1258
- policy_num_by_decision_time_by_user_id[user_id],
1259
- action_by_decision_time_by_user_id[user_id],
1266
+ action_prob_func_args_by_decision_time_by_subject_id[subject_id],
1267
+ threaded_action_prob_func_args_by_decision_time_by_subject_id[subject_id],
1268
+ threaded_update_func_args_by_policy_num_by_subject_id[subject_id],
1269
+ threaded_inference_func_args_by_subject_id[subject_id],
1270
+ policy_num_by_decision_time_by_subject_id[subject_id],
1271
+ action_by_decision_time_by_subject_id[subject_id],
1260
1272
  beta_index_by_policy_num,
1261
1273
  )
1262
- for user_id in user_ids.tolist()
1274
+ for subject_id in subject_ids.tolist()
1263
1275
  ]
1264
1276
 
1265
1277
  stacks = jnp.array([result[0] for result in results])
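Each entry of results pairs a per-subject stack with its auxiliary matrices, and only the first elements are gathered and averaged. A toy illustration of the pattern (shapes are hypothetical):

    import jax.numpy as jnp

    # Pretend four subjects each contribute a length-3 stack and one aux matrix.
    results = [(jnp.full(3, float(k)), jnp.eye(3) * k) for k in range(4)]
    stacks = jnp.array([result[0] for result in results])  # shape (4, 3)
    avg_stack = jnp.mean(stacks, axis=0)                   # shape (3,)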
@@ -1269,11 +1281,12 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1269
1281
 
1270
1282
  # 6. Note this strange return structure! We will differentiate the first output,
1271
1283
  # but the second tuple will be passed along without modification via has_aux=True and then used
1272
- # for the adaptive meat matrix, estimating functions sum check, and classical meat and inverse
1273
- # bread matrices. The raw per-user stacks are also returned for debugging purposes.
1284
+ # for the estimating functions sum check, per_subject_classical_bread_contributions, and
1285
+ # classical meat and bread matrices. The raw per-subject stacks are also returned for
1286
+ # debugging purposes.
1274
1287
 
1275
- # Note that returning the raw stacks here as the first arguments is potentially
1276
- # memory-intensive when combined with differentiation. Keep this in mind if the per-user bread
1288
+ # Note that returning the raw stacks here as the first argument is potentially
1289
+ # memory-intensive when combined with differentiation. Keep this in mind if the per-subject bread
1277
1290
  # inverse contributions are needed for something like CR2/CR3 small-sample corrections.
1278
1291
  return jnp.mean(stacks, axis=0), (
1279
1292
  jnp.mean(stacks, axis=0),
@@ -1284,10 +1297,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1284
1297
  )
1285
1298
 
1286
1299
 
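The has_aux=True pattern used above differentiates only the first return value while the auxiliary tuple passes through untouched. A minimal, self-contained sketch with a toy function (not the package's estimating stack):

    import jax
    import jax.numpy as jnp

    def stack_and_aux(params):
        stack = jnp.sin(params)               # this output is differentiated
        aux = (jnp.cos(params), params ** 2)  # these ride along, no derivatives
        return stack, aux

    jacobian, (aux0, aux1) = jax.jacrev(stack_and_aux, has_aux=True)(jnp.arange(3.0))
    # jacobian has shape (3, 3); aux0 and aux1 are ordinary arrays.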
1287
- def construct_classical_and_adaptive_sandwiches(
1300
+ def construct_classical_and_adjusted_sandwiches(
1288
1301
  theta_est: jnp.ndarray,
1289
1302
  all_post_update_betas: jnp.ndarray,
1290
- user_ids: jnp.ndarray,
1303
+ subject_ids: jnp.ndarray,
1291
1304
  action_prob_func: callable,
1292
1305
  action_prob_func_args_beta_index: int,
1293
1306
  alg_update_func: callable,
@@ -1300,32 +1313,34 @@ def construct_classical_and_adaptive_sandwiches(
1300
1313
  inference_func_type: str,
1301
1314
  inference_func_args_theta_index: int,
1302
1315
  inference_func_args_action_prob_index: int,
1303
- action_prob_func_args_by_user_id_by_decision_time: dict[
1316
+ action_prob_func_args_by_subject_id_by_decision_time: dict[
1304
1317
  int, dict[collections.abc.Hashable, tuple[Any, ...]]
1305
1318
  ],
1306
- policy_num_by_decision_time_by_user_id: dict[
1319
+ policy_num_by_decision_time_by_subject_id: dict[
1307
1320
  collections.abc.Hashable, dict[int, int | float]
1308
1321
  ],
1309
1322
  initial_policy_num: int | float,
1310
1323
  beta_index_by_policy_num: dict[int | float, int],
1311
- inference_func_args_by_user_id: dict[collections.abc.Hashable, tuple[Any, ...]],
1312
- inference_action_prob_decision_times_by_user_id: dict[
1324
+ inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
1325
+ inference_action_prob_decision_times_by_subject_id: dict[
1313
1326
  collections.abc.Hashable, list[int]
1314
1327
  ],
1315
- update_func_args_by_by_user_id_by_policy_num: dict[
1328
+ update_func_args_by_by_subject_id_by_policy_num: dict[
1316
1329
  int | float, dict[collections.abc.Hashable, tuple[Any, ...]]
1317
1330
  ],
1318
- action_by_decision_time_by_user_id: dict[collections.abc.Hashable, dict[int, int]],
1331
+ action_by_decision_time_by_subject_id: dict[
1332
+ collections.abc.Hashable, dict[int, int]
1333
+ ],
1319
1334
  suppress_all_data_checks: bool,
1320
1335
  suppress_interactive_data_checks: bool,
1321
1336
  small_sample_correction: str,
1322
- form_adaptive_meat_adjustments_explicitly: bool,
1323
- stabilize_joint_adaptive_bread_inverse: bool,
1324
- study_df: pd.DataFrame | None,
1325
- in_study_col_name: str | None,
1337
+ form_adjusted_meat_adjustments_explicitly: bool,
1338
+ stabilize_joint_bread: bool,
1339
+ analysis_df: pd.DataFrame | None,
1340
+ active_col_name: str | None,
1326
1341
  action_col_name: str | None,
1327
1342
  calendar_t_col_name: str | None,
1328
- user_id_col_name: str | None,
1343
+ subject_id_col_name: str | None,
1329
1344
  action_prob_func_args: tuple | None,
1330
1345
  action_prob_col_name: str | None,
1331
1346
  ) -> tuple[
@@ -1342,11 +1357,11 @@ def construct_classical_and_adaptive_sandwiches(
1342
1357
  jnp.ndarray[jnp.float32],
1343
1358
  ]:
1344
1359
  """
1345
- Constructs the classical and adaptive sandwich matrices, as well as various
1360
+ Constructs the classical and adjusted sandwich matrices, as well as various
1346
1361
  intermediate pieces in their construction.
1347
1362
 
1348
1363
  This is done by computing and differentiating the average weighted estimating function stack
1349
- with respect to the betas and theta, using the resulting Jacobian to compute the inverse bread
1364
+ with respect to the betas and theta, using the resulting Jacobian to compute the bread
1350
1365
  and meat matrices, and then stably computing sandwiches.
1351
1366
 
1352
1367
  Args:
@@ -1354,8 +1369,8 @@ def construct_classical_and_adaptive_sandwiches(
1354
1369
  A 1-D JAX NumPy array representing the parameter estimate for inference.
1355
1370
  all_post_update_betas (jnp.ndarray):
1356
1371
  A 2-D JAX NumPy array representing all parameter estimates for the algorithm updates.
1357
- user_ids (jnp.ndarray):
1358
- A 1-D JAX NumPy array holding all user IDs in the study.
1372
+ subject_ids (jnp.ndarray):
1373
+ A 1-D JAX NumPy array holding all subject IDs in the deployment.
1359
1374
  action_prob_func (callable):
1360
1375
  The action probability function.
1361
1376
  action_prob_func_args_beta_index (int):
@@ -1383,29 +1398,29 @@ def construct_classical_and_adaptive_sandwiches(
1383
1398
  inference_func_args_action_prob_index (int):
1384
1399
  The index of action probabilities in the inference function arguments tuple, if
1385
1400
  applicable. -1 otherwise.
1386
- action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
1387
- A dictionary mapping decision times to maps of user ids to the function arguments
1388
- required to compute action probabilities for this user.
1389
- policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
1390
- A map of user ids to dictionaries mapping decision times to the policy number in use.
1391
- Only applies to in-study decision times!
1401
+ action_prob_func_args_by_subject_id_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
1402
+ A dictionary mapping decision times to maps of subject ids to the function arguments
1403
+ required to compute action probabilities for this subject.
1404
+ policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
1405
+ A map of subject ids to dictionaries mapping decision times to the policy number in use.
1406
+ Only applies to active decision times!
1392
1407
  initial_policy_num (int | float):
1393
1408
  The policy number of the initial policy before any updates.
1394
1409
  beta_index_by_policy_num (dict[int | float, int]):
1395
1410
  A dictionary mapping policy numbers to the index of the corresponding beta in
1396
1411
  all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
1397
- inference_func_args_by_user_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
1398
- A dictionary mapping user IDs to their respective inference function arguments.
1399
- inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
1400
- For each user, a list of decision times to which action probabilities correspond if
1401
- provided. Typically just in-study times if action probabilites are used in the inference
1412
+ inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
1413
+ A dictionary mapping subject IDs to their respective inference function arguments.
1414
+ inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
1415
+ For each subject, a list of decision times to which action probabilities correspond if
1416
+ provided. Typically just active times if action probabilities are used in the inference
1402
1417
  loss or estimating function.
1403
- update_func_args_by_by_user_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
1404
- A dictionary where keys are policy numbers and values are dictionaries mapping user IDs
1418
+ update_func_args_by_by_subject_id_by_policy_num (dict[int | float, dict[collections.abc.Hashable, tuple[Any, ...]]]):
1419
+ A dictionary where keys are policy numbers and values are dictionaries mapping subject IDs
1405
1420
  to their respective update function arguments.
1406
- action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
1407
- A dictionary mapping user IDs to their respective actions taken at each decision time.
1408
- Only applies to in-study decision times!
1421
+ action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
1422
+ A dictionary mapping subject IDs to their respective actions taken at each decision time.
1423
+ Only applies to active decision times!
1409
1424
  suppress_all_data_checks (bool):
1410
1425
  If True, suppresses carrying out any data checks at all.
1411
1426
  suppress_interactive_data_checks (bool):
@@ -1415,43 +1430,43 @@ def construct_classical_and_adaptive_sandwiches(
1415
1430
  small_sample_correction (str):
1416
1431
  The type of small sample correction to apply. See SmallSampleCorrections class for
1417
1432
  options.
1418
- form_adaptive_meat_adjustments_explicitly (bool):
1419
- If True, explicitly forms the per-user meat adjustments that differentiate the adaptive
1433
+ form_adjusted_meat_adjustments_explicitly (bool):
1434
+ If True, explicitly forms the per-subject meat adjustments that differentiate the adjusted
1420
1435
  sandwich from the classical sandwich. This is for diagnostic purposes, as the
1421
- adaptive sandwich is formed without doing this.
1422
- stabilize_joint_adaptive_bread_inverse (bool):
1423
- If True, will apply various techniques to stabilize the joint adaptive bread inverse if necessary.
1424
- study_df (pd.DataFrame):
1425
- The full study dataframe, needed if forming the adaptive meat adjustments explicitly.
1426
- in_study_col_name (str):
1427
- The name of the column in study_df indicating whether a user is in-study at a given decision time.
1436
+ adjusted sandwich is formed without doing this.
1437
+ stabilize_joint_bread (bool):
1438
+ If True, will apply various techniques to stabilize the joint bread if necessary.
1439
+ analysis_df (pd.DataFrame):
1440
+ The full analysis dataframe, needed if forming the adjusted meat adjustments explicitly.
1441
+ active_col_name (str):
1442
+ The name of the column in analysis_df indicating whether a subject is active at a given decision time.
1428
1443
  action_col_name (str):
1429
- The name of the column in study_df indicating the action taken at a given decision time.
1444
+ The name of the column in analysis_df indicating the action taken at a given decision time.
1430
1445
  calendar_t_col_name (str):
1431
- The name of the column in study_df indicating the calendar time of a given decision time.
1432
- user_id_col_name (str):
1433
- The name of the column in study_df indicating the user ID.
1446
+ The name of the column in analysis_df indicating the calendar time of a given decision time.
1447
+ subject_id_col_name (str):
1448
+ The name of the column in analysis_df indicating the subject ID.
1434
1449
  action_prob_func_args (tuple):
1435
1450
  The arguments to be passed to the action probability function, needed if forming the
1436
- adaptive meat adjustments explicitly.
1451
+ adjusted meat adjustments explicitly.
1437
1452
  action_prob_col_name (str):
1438
- The name of the column in study_df indicating the action probability of the action taken,
1439
- needed if forming the adaptive meat adjustments explicitly.
1453
+ The name of the column in analysis_df indicating the action probability of the action taken,
1454
+ needed if forming the adjusted meat adjustments explicitly.
1440
1455
  Returns:
1441
1456
  tuple[jnp.ndarray[jnp.float32], ...]:
1442
1457
  A tuple containing:
1443
- - The raw joint adaptive inverse bread matrix.
1444
- - The (possibly) stabilized joint adaptive inverse bread matrix.
1445
- - The joint adaptive meat matrix.
1446
- - The joint adaptive sandwich matrix.
1447
- - The classical inverse bread matrix.
1458
+ - The raw joint bread matrix.
1459
+ - The (possibly) stabilized joint bread matrix.
1460
+ - The joint meat matrix.
1461
+ - The joint sandwich matrix.
1462
+ - The classical bread matrix.
1448
1463
  - The classical meat matrix.
1449
1464
  - The classical sandwich matrix.
1450
1465
  - The average weighted estimating function stack.
1451
- - All per-user weighted estimating function stacks.
1452
- - The per-user adaptive meat small-sample corrections.
1453
- - The per-user classical meat small-sample corrections.
1454
- - The per-user adaptive meat adjustments, if form_adaptive_meat_adjustments_explicitly
1466
+ - All per-subject weighted estimating function stacks.
1467
+ - The per-subject adjusted meat small-sample corrections.
1468
+ - The per-subject classical meat small-sample corrections.
1469
+ - The per-subject adjusted meat adjustments, if form_adjusted_meat_adjustments_explicitly
1455
1470
  is True, otherwise an array of NaNs.
1456
1471
  """
1457
1472
  logger.info(
@@ -1459,13 +1474,13 @@ def construct_classical_and_adaptive_sandwiches(
1459
1474
  )
1460
1475
  theta_dim = theta_est.shape[0]
1461
1476
  beta_dim = all_post_update_betas.shape[1]
1462
- # Note that these "contributions" are per-user Jacobians of the weighted estimating function stack.
1463
- raw_joint_adaptive_bread_inverse_matrix, (
1477
+ # Note that these "contributions" are per-subject Jacobians of the weighted estimating function stack.
1478
+ raw_joint_adjusted_bread_matrix, (
1464
1479
  avg_estimating_function_stack,
1465
- per_user_joint_adaptive_meat_contributions,
1466
- per_user_classical_meat_contributions,
1467
- per_user_classical_bread_inverse_contributions,
1468
- per_user_estimating_function_stacks,
1480
+ per_subject_joint_adjusted_meat_contributions,
1481
+ per_subject_classical_meat_contributions,
1482
+ per_subject_classical_bread_contributions,
1483
+ per_subject_estimating_function_stacks,
1469
1484
  ) = jax.jacrev(
1470
1485
  get_avg_weighted_estimating_function_stacks_and_aux_values, has_aux=True
1471
1486
  )(
@@ -1475,7 +1490,7 @@ def construct_classical_and_adaptive_sandwiches(
1475
1490
  flatten_params(all_post_update_betas, theta_est),
1476
1491
  beta_dim,
1477
1492
  theta_dim,
1478
- user_ids,
1493
+ subject_ids,
1479
1494
  action_prob_func,
1480
1495
  action_prob_func_args_beta_index,
1481
1496
  alg_update_func,
@@ -1488,166 +1503,164 @@ def construct_classical_and_adaptive_sandwiches(
1488
1503
  inference_func_type,
1489
1504
  inference_func_args_theta_index,
1490
1505
  inference_func_args_action_prob_index,
1491
- action_prob_func_args_by_user_id_by_decision_time,
1492
- policy_num_by_decision_time_by_user_id,
1506
+ action_prob_func_args_by_subject_id_by_decision_time,
1507
+ policy_num_by_decision_time_by_subject_id,
1493
1508
  initial_policy_num,
1494
1509
  beta_index_by_policy_num,
1495
- inference_func_args_by_user_id,
1496
- inference_action_prob_decision_times_by_user_id,
1497
- update_func_args_by_by_user_id_by_policy_num,
1498
- action_by_decision_time_by_user_id,
1510
+ inference_func_args_by_subject_id,
1511
+ inference_action_prob_decision_times_by_subject_id,
1512
+ update_func_args_by_by_subject_id_by_policy_num,
1513
+ action_by_decision_time_by_subject_id,
1499
1514
  suppress_all_data_checks,
1500
1515
  suppress_interactive_data_checks,
1501
1516
  )
1502
1517
 
1503
- num_users = len(user_ids)
1518
+ num_subjects = len(subject_ids)
1504
1519
 
1505
1520
  (
1506
- joint_adaptive_meat_matrix,
1521
+ joint_adjusted_meat_matrix,
1507
1522
  classical_meat_matrix,
1508
- per_user_adaptive_corrections,
1509
- per_user_classical_corrections,
1523
+ per_subject_adjusted_corrections,
1524
+ per_subject_classical_corrections,
1510
1525
  ) = perform_desired_small_sample_correction(
1511
1526
  small_sample_correction,
1512
- per_user_joint_adaptive_meat_contributions,
1513
- per_user_classical_meat_contributions,
1514
- per_user_classical_bread_inverse_contributions,
1515
- num_users,
1527
+ per_subject_joint_adjusted_meat_contributions,
1528
+ per_subject_classical_meat_contributions,
1529
+ per_subject_classical_bread_contributions,
1530
+ num_subjects,
1516
1531
  theta_dim,
1517
1532
  )
1518
1533
 
1519
1534
  # Increase diagonal block dominance and possibly improve conditioning of diagonal
1520
- # blocks as necessary, to ensure mathematical stability of joint bread inverse
1521
- stabilized_joint_adaptive_bread_inverse_matrix = (
1535
+ # blocks as necessary, to ensure numerical stability of the joint bread
1536
+ stabilized_joint_adjusted_bread_matrix = (
1522
1537
  (
1523
- stabilize_joint_adaptive_bread_inverse_if_necessary(
1524
- raw_joint_adaptive_bread_inverse_matrix,
1538
+ stabilize_joint_bread_if_necessary(
1539
+ raw_joint_adjusted_bread_matrix,
1525
1540
  beta_dim,
1526
1541
  theta_dim,
1527
1542
  )
1528
1543
  )
1529
- if stabilize_joint_adaptive_bread_inverse
1530
- else raw_joint_adaptive_bread_inverse_matrix
1544
+ if stabilize_joint_bread
1545
+ else raw_joint_adjusted_bread_matrix
1531
1546
  )
1532
1547
 
1533
1548
  # Now stably (no explicit inversion) form our sandwiches.
1534
- joint_adaptive_sandwich = form_sandwich_from_bread_inverse_and_meat(
1535
- stabilized_joint_adaptive_bread_inverse_matrix,
1536
- joint_adaptive_meat_matrix,
1537
- num_users,
1538
- method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
1539
- )
1540
- classical_bread_inverse_matrix = jnp.mean(
1541
- per_user_classical_bread_inverse_contributions, axis=0
1549
+ joint_adjusted_sandwich = form_sandwich_from_bread_and_meat(
1550
+ stabilized_joint_adjusted_bread_matrix,
1551
+ joint_adjusted_meat_matrix,
1552
+ num_subjects,
1553
+ method=SandwichFormationMethods.BREAD_T_QR,
1542
1554
  )
1543
- classical_sandwich = form_sandwich_from_bread_inverse_and_meat(
1544
- classical_bread_inverse_matrix,
1555
+ classical_bread_matrix = jnp.mean(per_subject_classical_bread_contributions, axis=0)
1556
+ classical_sandwich = form_sandwich_from_bread_and_meat(
1557
+ classical_bread_matrix,
1545
1558
  classical_meat_matrix,
1546
- num_users,
1547
- method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
1559
+ num_subjects,
1560
+ method=SandwichFormationMethods.BREAD_T_QR,
1548
1561
  )
1549
1562
 
1550
- per_user_adaptive_meat_adjustments = jnp.full(
1551
- (len(user_ids), theta_dim, theta_dim), jnp.nan
1563
+ per_subject_adjusted_meat_adjustments = jnp.full(
1564
+ (len(subject_ids), theta_dim, theta_dim), jnp.nan
1552
1565
  )
1553
- if form_adaptive_meat_adjustments_explicitly:
1554
- per_user_adjusted_classical_meat_contributions = (
1555
- form_adaptive_meat_adjustments_directly(
1566
+ if form_adjusted_meat_adjustments_explicitly:
1567
+ per_subject_adjusted_classical_meat_contributions = (
1568
+ form_adjusted_meat_adjustments_directly(
1556
1569
  theta_dim,
1557
1570
  all_post_update_betas.shape[1],
1558
- stabilized_joint_adaptive_bread_inverse_matrix,
1559
- per_user_estimating_function_stacks,
1560
- study_df,
1561
- in_study_col_name,
1571
+ stabilized_joint_adjusted_bread_matrix,
1572
+ per_subject_estimating_function_stacks,
1573
+ analysis_df,
1574
+ active_col_name,
1562
1575
  action_col_name,
1563
1576
  calendar_t_col_name,
1564
- user_id_col_name,
1577
+ subject_id_col_name,
1565
1578
  action_prob_func,
1566
1579
  action_prob_func_args,
1567
1580
  action_prob_func_args_beta_index,
1568
1581
  theta_est,
1569
1582
  inference_func,
1570
1583
  inference_func_args_theta_index,
1571
- user_ids,
1584
+ subject_ids,
1572
1585
  action_prob_col_name,
1573
1586
  )
1574
1587
  )
1575
- # Validate that the adaptive meat adjustments we just formed are accurate by constructing
1576
- # the theta-only adaptive sandwich from them and checking that it matches the standard result
1577
- # we get by taking a subset of the joint adaptive sandwich.
1588
+ # Validate that the adjusted meat adjustments we just formed are accurate by constructing
1589
+ # the theta-only adjusted sandwich from them and checking that it matches the standard result
1590
+ # we get by taking a subset of the joint sandwich.
1578
1591
  # First just apply any small-sample correction for parity.
1579
1592
  (
1580
1593
  _,
1581
- theta_only_adaptive_meat_matrix_v2,
1594
+ theta_only_adjusted_meat_matrix_v2,
1582
1595
  _,
1583
1596
  _,
1584
1597
  ) = perform_desired_small_sample_correction(
1585
1598
  small_sample_correction,
1586
- per_user_joint_adaptive_meat_contributions,
1587
- per_user_adjusted_classical_meat_contributions,
1588
- per_user_classical_bread_inverse_contributions,
1589
- num_users,
1599
+ per_subject_joint_adjusted_meat_contributions,
1600
+ per_subject_adjusted_classical_meat_contributions,
1601
+ per_subject_classical_bread_contributions,
1602
+ num_subjects,
1590
1603
  theta_dim,
1591
1604
  )
1592
- theta_only_adaptive_sandwich_from_adjustments = (
1593
- form_sandwich_from_bread_inverse_and_meat(
1594
- classical_bread_inverse_matrix,
1595
- theta_only_adaptive_meat_matrix_v2,
1596
- num_users,
1597
- method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
1605
+ theta_only_adjusted_sandwich_from_adjustments = (
1606
+ form_sandwich_from_bread_and_meat(
1607
+ classical_bread_matrix,
1608
+ theta_only_adjusted_meat_matrix_v2,
1609
+ num_subjects,
1610
+ method=SandwichFormationMethods.BREAD_T_QR,
1598
1611
  )
1599
1612
  )
1600
- theta_only_adaptive_sandwich = joint_adaptive_sandwich[-theta_dim:, -theta_dim:]
1613
+ theta_only_adjusted_sandwich = joint_adjusted_sandwich[-theta_dim:, -theta_dim:]
1601
1614
 
1602
1615
  if not np.allclose(
1603
- theta_only_adaptive_sandwich,
1604
- theta_only_adaptive_sandwich_from_adjustments,
1616
+ theta_only_adjusted_sandwich,
1617
+ theta_only_adjusted_sandwich_from_adjustments,
1605
1618
  rtol=3e-2,
1606
1619
  ):
1607
1620
  logger.warning(
1608
- "There may be a bug in the explicit meat adjustment calculation (this doesn't affect the actual calculation, just diagnostics). We've calculated the theta-only adaptive sandwich two different ways and they do not match sufficiently."
1621
+ "There may be a bug in the explicit meat adjustment calculation (this doesn't affect the actual calculation, just diagnostics). We've calculated the theta-only adjusted sandwich two different ways and they do not match sufficiently."
1609
1622
  )
1610
1623
 
1611
- # Stack the joint adaptive inverse bread pieces together horizontally and return the auxiliary
1612
- # values too. The joint adaptive bread inverse should always be block lower triangular.
1624
+ # Stack the joint bread pieces together horizontally and return the auxiliary
1625
+ # values too. The joint bread should always be block lower triangular.
1613
1626
  return (
1614
- raw_joint_adaptive_bread_inverse_matrix,
1615
- stabilized_joint_adaptive_bread_inverse_matrix,
1616
- joint_adaptive_meat_matrix,
1617
- joint_adaptive_sandwich,
1618
- classical_bread_inverse_matrix,
1627
+ raw_joint_adjusted_bread_matrix,
1628
+ stabilized_joint_adjusted_bread_matrix,
1629
+ joint_adjusted_meat_matrix,
1630
+ joint_adjusted_sandwich,
1631
+ classical_bread_matrix,
1619
1632
  classical_meat_matrix,
1620
1633
  classical_sandwich,
1621
1634
  avg_estimating_function_stack,
1622
- per_user_estimating_function_stacks,
1623
- per_user_adaptive_corrections,
1624
- per_user_classical_corrections,
1625
- per_user_adaptive_meat_adjustments,
1635
+ per_subject_estimating_function_stacks,
1636
+ per_subject_adjusted_corrections,
1637
+ per_subject_classical_corrections,
1638
+ per_subject_adjusted_meat_adjustments,
1626
1639
  )
1627
1640
 
1628
1641
 
1629
1642
  # TODO: I think there should be interaction to confirm stabilization. It is
1630
- # important for the user to know if this is happening. Even if enabled, it is important
1631
- # that the user know it actually kicks in.
1632
- def stabilize_joint_adaptive_bread_inverse_if_necessary(
1633
- joint_adaptive_bread_inverse_matrix: jnp.ndarray,
1643
+ # important for the user to know if this is happening. Even if enabled, it is important
1644
+ # that the user know it actually kicks in.
1645
+ def stabilize_joint_bread_if_necessary(
1646
+ joint_adjusted_bread_matrix: jnp.ndarray,
1634
1647
  beta_dim: int,
1635
1648
  theta_dim: int,
1636
1649
  ) -> jnp.ndarray:
1637
1650
  """
1638
- Stabilizes the joint adaptive bread inverse matrix if necessary by increasing diagonal block
1651
+ Stabilizes the joint bread matrix if necessary by increasing diagonal block
1639
1652
  dominance and/or adding a small ridge penalty to the diagonal blocks.
1640
1653
 
1641
1654
  Args:
1642
- joint_adaptive_bread_inverse_matrix (jnp.ndarray):
1643
- A 2-D JAX NumPy array representing the joint adaptive bread inverse matrix.
1655
+ joint_adjusted_bread_matrix (jnp.ndarray):
1656
+ A 2-D JAX NumPy array representing the joint bread matrix.
1644
1657
  beta_dim (int):
1645
1658
  The dimension of each beta parameter.
1646
1659
  theta_dim (int):
1647
1660
  The dimension of the theta parameter.
1648
1661
  Returns:
1649
1662
  jnp.ndarray:
1650
- A 2-D NumPy array representing the stabilized joint adaptive bread inverse matrix.
1663
+ A 2-D NumPy array representing the stabilized joint bread matrix.
1651
1664
  """
1652
1665
 
1653
1666
  # TODO: come up with more sophisticated settings here. These are maybe a little loose,
@@ -1660,7 +1673,7 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
1660
1673
 
1661
1674
  # Grab just the RL block and convert numpy array for easier manipulation.
1662
1675
  RL_stack_beta_derivatives_block = np.array(
1663
- joint_adaptive_bread_inverse_matrix[:-theta_dim, :-theta_dim]
1676
+ joint_adjusted_bread_matrix[:-theta_dim, :-theta_dim]
1664
1677
  )
1665
1678
  num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
1666
1679
  for i in range(1, num_updates + 1):
@@ -1688,7 +1701,7 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
1688
1701
  RL_stack_beta_derivatives_block[
1689
1702
  diagonal_block_slice, diagonal_block_slice
1690
1703
  ] = diagonal_block + ridge_penalty * np.eye(beta_dim)
1691
- # TODO: Require user input here in interactive settings?
1704
+ # TODO: Require subject input here in interactive settings?
1692
1705
  logger.info(
1693
1706
  "Added ridge penalty of %s to diagonal block for update %s to improve conditioning from %s to %s",
1694
1707
  ridge_penalty,
@@ -1779,44 +1792,44 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
1779
1792
  [
1780
1793
  [
1781
1794
  RL_stack_beta_derivatives_block,
1782
- joint_adaptive_bread_inverse_matrix[:-theta_dim, -theta_dim:],
1795
+ joint_adjusted_bread_matrix[:-theta_dim, -theta_dim:],
1783
1796
  ],
1784
1797
  [
1785
- joint_adaptive_bread_inverse_matrix[-theta_dim:, :-theta_dim],
1786
- joint_adaptive_bread_inverse_matrix[-theta_dim:, -theta_dim:],
1798
+ joint_adjusted_bread_matrix[-theta_dim:, :-theta_dim],
1799
+ joint_adjusted_bread_matrix[-theta_dim:, -theta_dim:],
1787
1800
  ],
1788
1801
  ]
1789
1802
  )
1790
1803
 
1791
1804
 
1792
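The ridge step above trades a small, logged perturbation of a diagonal block for much better conditioning. A toy sketch of the effect, with illustrative numbers rather than the package's actual thresholds:

    import numpy as np

    block = np.array([[1.0, 0.999], [0.999, 1.0]])  # nearly singular block
    print(np.linalg.cond(block))                    # roughly 2e3

    ridge_penalty = 1e-2
    stabilized = block + ridge_penalty * np.eye(2)
    print(np.linalg.cond(stabilized))               # roughly 2e2, much tamer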
- def form_sandwich_from_bread_inverse_and_meat(
1793
- bread_inverse: jnp.ndarray,
1805
+ def form_sandwich_from_bread_and_meat(
1806
+ bread: jnp.ndarray,
1794
1807
  meat: jnp.ndarray,
1795
- num_users: int,
1796
- method: str = SandwichFormationMethods.BREAD_INVERSE_T_QR,
1808
+ num_subjects: int,
1809
+ method: str = SandwichFormationMethods.BREAD_T_QR,
1797
1810
  ) -> jnp.ndarray:
1798
1811
  """
1799
- Forms a sandwich variance matrix from the provided bread inverse and meat matrices.
1812
+ Forms a sandwich variance matrix from the provided bread and meat matrices.
1800
1813
 
1801
- Attempts to do so STABLY without ever forming the bread matrix itself
1814
+ Attempts to do so STABLY without ever forming the bread inverse matrix itself
1802
1815
  (except with naive option).
1803
1816
 
1804
1817
  Args:
1805
- bread_inverse (jnp.ndarray):
1806
- A 2-D JAX NumPy array representing the bread inverse matrix.
1818
+ bread (jnp.ndarray):
1819
+ A 2-D JAX NumPy array representing the bread matrix.
1807
1820
  meat (jnp.ndarray):
1808
1821
  A 2-D JAX NumPy array representing the meat matrix.
1809
- num_users (int):
1810
- The number of users in the study, used to scale the sandwich appropriately.
1822
+ num_subjects (int):
1823
+ The number of subjects in the deployment, used to scale the sandwich appropriately.
1811
1824
  method (str):
1812
1825
  The method to use for forming the sandwich.
1813
1826
 
1814
- SandwichFormationMethods.BREAD_INVERSE_T_QR uses the QR decomposition of the transpose
1815
- of the bread inverse matrix.
1827
+ SandwichFormationMethods.BREAD_T_QR uses the QR decomposition of the transpose
1828
+ of the bread matrix.
1816
1829
 
1817
1830
  SandwichFormationMethods.MEAT_SVD_SOLVE uses a decomposition of the meat matrix.
1818
1831
 
1819
- SandwichFormationMethods.NAIVE simply inverts the bread inverse and forms the sandwich.
1832
+ SandwichFormationMethods.NAIVE simply inverts the bread and forms the sandwich.
1820
1833
 
1821
1834
 
1822
1835
  Returns:
@@ -1824,16 +1837,16 @@ def form_sandwich_from_bread_inverse_and_meat(
1824
1837
  A 2-D JAX NumPy array representing the sandwich variance matrix.
1825
1838
  """
1826
1839
 
1827
- if method == SandwichFormationMethods.BREAD_INVERSE_T_QR:
1840
+ if method == SandwichFormationMethods.BREAD_T_QR:
1828
1841
  # QR of B^T → Q orthogonal, R upper triangular; L = R^T lower triangular
1829
- Q, R = np.linalg.qr(bread_inverse.T, mode="reduced")
1842
+ Q, R = np.linalg.qr(bread.T, mode="reduced")
1830
1843
  L = R.T
1831
1844
 
1832
1845
  new_meat = scipy.linalg.solve_triangular(
1833
1846
  L, scipy.linalg.solve_triangular(L, meat.T, lower=True).T, lower=True
1834
1847
  )
1835
1848
 
1836
- return Q @ new_meat @ Q.T / num_users
1849
+ return Q @ new_meat @ Q.T / num_subjects
1837
1850
  elif method == SandwichFormationMethods.MEAT_SVD_SOLVE:
1838
1851
  # Factor the meat via SVD without any symmetrization or truncation.
1839
1852
  # For general (possibly slightly nonsymmetric) M, SVD gives M = U @ diag(s) @ Vh.
@@ -1844,21 +1857,21 @@ def form_sandwich_from_bread_inverse_and_meat(
1844
1857
  C_right = Vh.T * np.sqrt(s)
1845
1858
 
1846
1859
  # Solve B W_left = C_left and B W_right = C_right (no explicit inverses).
1847
- W_left = scipy.linalg.solve(bread_inverse, C_left)
1848
- W_right = scipy.linalg.solve(bread_inverse, C_right)
1860
+ W_left = scipy.linalg.solve(bread, C_left)
1861
+ W_right = scipy.linalg.solve(bread, C_right)
1849
1862
 
1850
- # Return the exact sandwich: V = (B^{-1} C_left) (B^{-1} C_right)^T / num_users
1851
- return W_left @ W_right.T / num_users
1863
+ # Return the exact sandwich: V = (B^{-1} C_left) (B^{-1} C_right)^T / num_subjects
1864
+ return W_left @ W_right.T / num_subjects
1852
1865
 
1853
1866
  elif method == SandwichFormationMethods.NAIVE:
1854
- # Simply invert the bread inverse and form the sandwich directly.
1867
+ # Simply invert the bread and form the sandwich directly.
1855
1868
  # This is NOT numerically stable and is only included for comparison purposes.
1856
- bread = np.linalg.inv(bread_inverse)
1857
- return bread @ meat @ meat.T / num_users
1869
+ bread_inverse = np.linalg.inv(bread)
1870
+ return bread_inverse @ meat @ bread_inverse.T / num_subjects
1858
1871
 
1859
1872
  else:
1860
1873
  raise ValueError(
1861
- f"Unknown sandwich method: {method}. Please use 'bread_inverse_t_qr' or 'meat_decomposition_solve'."
1874
+ f"Unknown sandwich method: {method}. Please use 'bread_t_qr' or 'meat_decomposition_solve'."
1862
1875
  )
1863
1876
 
1864
1877
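To make the two stable formation paths concrete, the toy check below confirms that both the BREAD_T_QR and MEAT_SVD_SOLVE recipes reproduce the naive sandwich B^{-1} M B^{-T} / n on a small, well-conditioned problem. All matrices are random and purely illustrative:

    import numpy as np
    import scipy.linalg

    rng = np.random.default_rng(0)
    B = rng.normal(size=(4, 4)) + 4.0 * np.eye(4)  # toy, well-conditioned bread
    M = rng.normal(size=(4, 4))
    M = M @ M.T                                    # symmetric PSD toy meat
    n = 25
    naive = np.linalg.inv(B) @ M @ np.linalg.inv(B).T / n

    # BREAD_T_QR: QR of B^T gives B = L @ Q.T with L = R.T lower triangular,
    # so B^{-1} M B^{-T} = Q (L^{-1} M L^{-T}) Q^T needs only triangular solves.
    Q, R = np.linalg.qr(B.T, mode="reduced")
    L = R.T
    new_meat = scipy.linalg.solve_triangular(
        L, scipy.linalg.solve_triangular(L, M.T, lower=True).T, lower=True
    )
    assert np.allclose(Q @ new_meat @ Q.T / n, naive)

    # MEAT_SVD_SOLVE: factor M = C_left @ C_right.T via SVD, then solve
    # B @ W = C on each side so no inverse is ever formed explicitly.
    U, s, Vh = np.linalg.svd(M)
    C_left = U * np.sqrt(s)
    C_right = Vh.T * np.sqrt(s)
    W_left = scipy.linalg.solve(B, C_left)
    W_right = scipy.linalg.solve(B, C_right)
    assert np.allclose(W_left @ W_right.T / n, naive)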