lifejacket 1.0.0__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -217,9 +217,9 @@ def cli():
217
217
  type=click.Choice(
218
218
  [
219
219
  SmallSampleCorrections.NONE,
220
- SmallSampleCorrections.HC1theta,
221
- SmallSampleCorrections.HC2theta,
222
- SmallSampleCorrections.HC3theta,
220
+ SmallSampleCorrections.Z1theta,
221
+ SmallSampleCorrections.Z2theta,
222
+ SmallSampleCorrections.Z3theta,
223
223
  ]
224
224
  ),
225
225
  default=SmallSampleCorrections.NONE,
@@ -235,13 +235,13 @@ def cli():
235
235
  "--form_adjusted_meat_adjustments_explicitly",
236
236
  type=bool,
237
237
  default=False,
238
- help="If True, explicitly forms the per-subject meat adjustments that differentiate the adaptive sandwich from the classical sandwich. This is for diagnostic purposes, as the adaptive sandwich is formed without doing this.",
238
+ help="If True, explicitly forms the per-subject meat adjustments that differentiate the adjusted sandwich from the classical sandwich. This is for diagnostic purposes, as the adjusted sandwich is formed without doing this.",
239
239
  )
240
240
  @click.option(
241
- "--stabilize_joint_adjusted_bread_inverse",
241
+ "--stabilize_joint_bread",
242
242
  type=bool,
243
243
  default=True,
244
- help="If True, stabilizes the joint adaptive bread inverse matrix if it does not meet conditioning thresholds.",
244
+ help="If True, stabilizes the joint bread matrix if it does not meet conditioning thresholds.",
245
245
  )
246
246
  def analyze_dataset_wrapper(**kwargs):
247
247
  """
@@ -324,15 +324,15 @@ def analyze_dataset(
324
324
  small_sample_correction: str,
325
325
  collect_data_for_blowup_supervised_learning: bool,
326
326
  form_adjusted_meat_adjustments_explicitly: bool,
327
- stabilize_joint_adjusted_bread_inverse: bool,
327
+ stabilize_joint_bread: bool,
328
328
  ) -> None:
329
329
  """
330
- Analyzes a dataset to provide a parameter estimate and an estimate of its variance using adaptive and classical sandwich estimators.
330
+ Analyzes a dataset to provide a parameter estimate and an estimate of its variance using and classical sandwich estimators.
331
331
 
332
332
  There are two modes of use for this function.
333
333
 
334
334
  First, it may be called indirectly from the command line by passing through
335
- analyze_dataset.
335
+ analyze_dataset_wrapper.
336
336
 
337
337
  Second, it may be called directly from Python code with in-memory objects.
338
338
 
@@ -388,17 +388,17 @@ def analyze_dataset(
388
388
  small_sample_correction (str):
389
389
  Type of small sample correction to apply.
390
390
  collect_data_for_blowup_supervised_learning (bool):
391
- Whether to collect data for doing supervised learning about adaptive sandwich blowup.
391
+ Whether to collect data for doing supervised learning about adjusted sandwich blowup.
392
392
  form_adjusted_meat_adjustments_explicitly (bool):
393
- If True, explicitly forms the per-subject meat adjustments that differentiate the adaptive
393
+ If True, explicitly forms the per-subject meat adjustments that differentiate the
394
394
  sandwich from the classical sandwich. This is for diagnostic purposes, as the
395
- adaptive sandwich is formed without doing this.
396
- stabilize_joint_adjusted_bread_inverse (bool):
397
- If True, stabilizes the joint adaptive bread inverse matrix if it does not meet conditioning
395
+ adjusted sandwich is formed without doing this.
396
+ stabilize_joint_bread (bool):
397
+ If True, stabilizes the joint bread matrix if it does not meet conditioning
398
398
  thresholds.
399
399
 
400
400
  Returns:
401
- dict: A dictionary containing the theta estimate, adaptive sandwich variance estimate, and
401
+ dict: A dictionary containing the theta estimate, adjusted sandwich variance estimate, and
402
402
  classical sandwich variance estimate.
403
403
  """
404
404
 
@@ -438,7 +438,6 @@ def analyze_dataset(
438
438
  )
439
439
 
440
440
  ### Begin collecting data structures that will be used to compute the joint bread matrix.
441
-
442
441
  beta_index_by_policy_num, initial_policy_num = (
443
442
  construct_beta_index_by_policy_num_map(
444
443
  analysis_df, policy_num_col_name, active_col_name
@@ -475,20 +474,20 @@ def analyze_dataset(
475
474
  active_col_name,
476
475
  )
477
476
 
478
- # Use a per-subject weighted estimating function stacking functino to derive classical and joint
479
- # adaptive meat and inverse bread matrices. This is facilitated because the *value* of the
477
+ # Use a per-subject weighted estimating function stacking function to derive classical and joint
478
+ # meat and bread matrices. This is facilitated because the *value* of the
480
479
  # weighted and unweighted stacks are the same, as the weights evaluate to 1 pre-differentiation.
481
480
  logger.info(
482
- "Constructing joint adaptive bread inverse matrix, joint adaptive meat matrix, the classical analogs, and the avg estimating function stack across subjects."
481
+ "Constructing joint bread matrix, joint meat matrix, the classical analogs, and the avg estimating function stack across subjects."
483
482
  )
484
483
 
485
484
  subject_ids = jnp.array(analysis_df[subject_id_col_name].unique())
486
485
  (
487
- stabilized_joint_adjusted_bread_inverse_matrix,
488
- raw_joint_adjusted_bread_inverse_matrix,
486
+ stabilized_joint_adjusted_bread_matrix,
487
+ raw_joint_adjusted_bread_matrix,
489
488
  joint_adjusted_meat_matrix,
490
489
  joint_adjusted_sandwich_matrix,
491
- classical_bread_inverse_matrix,
490
+ classical_bread_matrix,
492
491
  classical_meat_matrix,
493
492
  classical_sandwich_var_estimate,
494
493
  avg_estimating_function_stack,
@@ -524,7 +523,7 @@ def analyze_dataset(
524
523
  suppress_interactive_data_checks,
525
524
  small_sample_correction,
526
525
  form_adjusted_meat_adjustments_explicitly,
527
- stabilize_joint_adjusted_bread_inverse,
526
+ stabilize_joint_bread,
528
527
  analysis_df,
529
528
  active_col_name,
530
529
  action_col_name,
@@ -550,17 +549,16 @@ def analyze_dataset(
550
549
  ]
551
550
 
552
551
  # Check for negative diagonal elements and set them to zero if found
553
- adaptive_diagonal = np.diag(adjusted_sandwich_var_estimate)
554
- if np.any(adaptive_diagonal < 0):
552
+ adjusted_diagonal = np.diag(adjusted_sandwich_var_estimate)
553
+ if np.any(adjusted_diagonal < 0):
555
554
  logger.warning(
556
- "Found negative diagonal elements in adaptive sandwich variance estimate. Setting them to zero."
555
+ "Found negative diagonal elements in adjusted sandwich variance estimate. Setting them to zero."
557
556
  )
558
557
  np.fill_diagonal(
559
- adjusted_sandwich_var_estimate, np.maximum(adaptive_diagonal, 0)
558
+ adjusted_sandwich_var_estimate, np.maximum(adjusted_diagonal, 0)
560
559
  )
561
560
 
562
561
  logger.info("Writing results to file...")
563
- # Write analysis results to same directory that input files are in
564
562
  output_folder_abs_path = pathlib.Path(output_dir).resolve()
565
563
 
566
564
  analysis_dict = {
@@ -574,25 +572,31 @@ def analyze_dataset(
574
572
  f,
575
573
  )
576
574
 
577
- joint_adjusted_bread_inverse_cond = jnp.linalg.cond(
578
- raw_joint_adjusted_bread_inverse_matrix
575
+ joint_adjusted_bread_cond = jnp.linalg.cond(raw_joint_adjusted_bread_matrix)
576
+ logger.info(
577
+ "Joint adjusted bread condition number: %f",
578
+ joint_adjusted_bread_cond,
579
579
  )
580
+
581
+ # calculate the max eigenvalue of the joint adjusted sandwich
582
+ max_eigenvalue = scipy.linalg.eigvalsh(joint_adjusted_sandwich_matrix).max()
580
583
  logger.info(
581
- "Joint adjusted bread inverse condition number: %f",
582
- joint_adjusted_bread_inverse_cond,
584
+ "Max eigenvalue of joint adjusted sandwich matrix: %f",
585
+ max_eigenvalue,
583
586
  )
584
587
 
585
588
  debug_pieces_dict = {
586
589
  "theta_est": theta_est,
587
590
  "adjusted_sandwich_var_estimate": adjusted_sandwich_var_estimate,
588
591
  "classical_sandwich_var_estimate": classical_sandwich_var_estimate,
589
- "raw_joint_bread_inverse_matrix": raw_joint_adjusted_bread_inverse_matrix,
590
- "stabilized_joint_bread_inverse_matrix": stabilized_joint_adjusted_bread_inverse_matrix,
592
+ "raw_joint_bread_matrix": raw_joint_adjusted_bread_matrix,
593
+ "stabilized_joint_bread_matrix": stabilized_joint_adjusted_bread_matrix,
591
594
  "joint_meat_matrix": joint_adjusted_meat_matrix,
592
- "classical_bread_inverse_matrix": classical_bread_inverse_matrix,
595
+ "classical_bread_matrix": classical_bread_matrix,
593
596
  "classical_meat_matrix": classical_meat_matrix,
594
597
  "all_estimating_function_stacks": per_subject_estimating_function_stacks,
595
- "joint_bread_inverse_condition_number": joint_adjusted_bread_inverse_cond,
598
+ "joint_bread_condition_number": joint_adjusted_bread_cond,
599
+ "max_eigenvalue_joint_adjusted_sandwich": max_eigenvalue,
596
600
  "all_post_update_betas": all_post_update_betas,
597
601
  "per_subject_adjusted_corrections": per_subject_adjusted_corrections,
598
602
  "per_subject_classical_corrections": per_subject_classical_corrections,
@@ -606,8 +610,8 @@ def analyze_dataset(
606
610
 
607
611
  if collect_data_for_blowup_supervised_learning:
608
612
  datum_and_label_dict = get_datum_for_blowup_supervised_learning.get_datum_for_blowup_supervised_learning(
609
- raw_joint_adjusted_bread_inverse_matrix,
610
- joint_adjusted_bread_inverse_cond,
613
+ raw_joint_adjusted_bread_matrix,
614
+ joint_adjusted_bread_cond,
611
615
  avg_estimating_function_stack,
612
616
  per_subject_estimating_function_stacks,
613
617
  all_post_update_betas,
@@ -824,7 +828,7 @@ def single_subject_weighted_estimating_function_stacker(
824
828
  Returns:
825
829
  jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted estimating function
826
830
  stack.
827
- jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's adaptive meat contribution.
831
+ jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's adjusted meat contribution.
828
832
  jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical meat contribution.
829
833
  jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical bread contribution.
830
834
  """
@@ -1008,10 +1012,10 @@ def single_subject_weighted_estimating_function_stacker(
1008
1012
 
1009
1013
  # 6. Return the following outputs:
1010
1014
  # a. The first is simply the weighted estimating function stack for this subject. The average
1011
- # of these is what we differentiate with respect to theta to form the inverse adaptive joint
1015
+ # of these is what we differentiate with respect to theta to form the joint
1012
1016
  # bread matrix, and we also compare that average to zero to check the estimating functions'
1013
1017
  # fidelity.
1014
- # b. The average outer product of these per-subject stacks across subjects is the adaptive joint meat
1018
+ # b. The average outer product of these per-subject stacks across subjects is the adjusted joint meat
1015
1019
  # matrix, hence the second output.
1016
1020
  # c. The third output is averaged across subjects to obtain the classical meat matrix.
1017
1021
  # d. The fourth output is averaged across subjects to obtain the inverse classical bread
@@ -1068,7 +1072,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1068
1072
  ]:
1069
1073
  """
1070
1074
  Computes the average weighted estimating function stack across all subjects, along with
1071
- auxiliary values used to construct the adaptive and classical sandwich variances.
1075
+ auxiliary values used to construct the adjusted and classical sandwich variances.
1072
1076
 
1073
1077
  Args:
1074
1078
  flattened_betas_and_theta (jnp.ndarray):
@@ -1144,7 +1148,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1144
1148
  tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
1145
1149
  A tuple containing
1146
1150
  1. the average weighted estimating function stack
1147
- 2. the subject-level adaptive meat matrix contributions
1151
+ 2. the subject-level adjusted meat matrix contributions
1148
1152
  3. the subject-level classical meat matrix contributions
1149
1153
  4. the subject-level inverse classical bread matrix contributions
1150
1154
  5. raw per-subject weighted estimating function
@@ -1248,7 +1252,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1248
1252
  )
1249
1253
 
1250
1254
  # 5. Now we can compute the weighted estimating function stacks for all subjects
1251
- # as well as collect related values used to construct the adaptive and classical
1255
+ # as well as collect related values used to construct the adjusted and classical
1252
1256
  # sandwich variances.
1253
1257
  results = [
1254
1258
  single_subject_weighted_estimating_function_stacker(
@@ -1277,10 +1281,11 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1277
1281
 
1278
1282
  # 6. Note this strange return structure! We will differentiate the first output,
1279
1283
  # but the second tuple will be passed along without modification via has_aux=True and then used
1280
- # for the adaptive meat matrix, estimating functions sum check, and classical meat and inverse
1281
- # bread matrices. The raw per-subject stacks are also returned for debugging purposes.
1284
+ # for the estimating functions sum check, per_subject_classical_bread_contributions, and
1285
+ # classical meat and inverse read matrices. The raw per-subject stacks are also returned for
1286
+ # debugging purposes.
1282
1287
 
1283
- # Note that returning the raw stacks here as the first arguments is potentially
1288
+ # Note that returning the raw stacks here as the first argument is potentially
1284
1289
  # memory-intensive when combined with differentiation. Keep this in mind if the per-subject bread
1285
1290
  # inverse contributions are needed for something like CR2/CR3 small-sample corrections.
1286
1291
  return jnp.mean(stacks, axis=0), (
@@ -1330,7 +1335,7 @@ def construct_classical_and_adjusted_sandwiches(
1330
1335
  suppress_interactive_data_checks: bool,
1331
1336
  small_sample_correction: str,
1332
1337
  form_adjusted_meat_adjustments_explicitly: bool,
1333
- stabilize_joint_adjusted_bread_inverse: bool,
1338
+ stabilize_joint_bread: bool,
1334
1339
  analysis_df: pd.DataFrame | None,
1335
1340
  active_col_name: str | None,
1336
1341
  action_col_name: str | None,
@@ -1352,11 +1357,11 @@ def construct_classical_and_adjusted_sandwiches(
1352
1357
  jnp.ndarray[jnp.float32],
1353
1358
  ]:
1354
1359
  """
1355
- Constructs the classical and adaptive sandwich matrices, as well as various
1360
+ Constructs the classical and adjusted sandwich matrices, as well as various
1356
1361
  intermediate pieces in their consruction.
1357
1362
 
1358
1363
  This is done by computing and differentiating the average weighted estimating function stack
1359
- with respect to the betas and theta, using the resulting Jacobian to compute the inverse bread
1364
+ with respect to the betas and theta, using the resulting Jacobian to compute the bread
1360
1365
  and meat matrices, and then stably computing sandwiches.
1361
1366
 
1362
1367
  Args:
@@ -1426,13 +1431,13 @@ def construct_classical_and_adjusted_sandwiches(
1426
1431
  The type of small sample correction to apply. See SmallSampleCorrections class for
1427
1432
  options.
1428
1433
  form_adjusted_meat_adjustments_explicitly (bool):
1429
- If True, explicitly forms the per-subject meat adjustments that differentiate the adaptive
1434
+ If True, explicitly forms the per-subject meat adjustments that differentiate the adjusted
1430
1435
  sandwich from the classical sandwich. This is for diagnostic purposes, as the
1431
- adaptive sandwich is formed without doing this.
1432
- stabilize_joint_adjusted_bread_inverse (bool):
1433
- If True, will apply various techniques to stabilize the joint adaptive bread inverse if necessary.
1436
+ adjusted sandwich is formed without doing this.
1437
+ stabilize_joint_bread (bool):
1438
+ If True, will apply various techniques to stabilize the joint bread if necessary.
1434
1439
  analysis_df (pd.DataFrame):
1435
- The full analysis dataframe, needed if forming the adaptive meat adjustments explicitly.
1440
+ The full analysis dataframe, needed if forming the adjusted meat adjustments explicitly.
1436
1441
  active_col_name (str):
1437
1442
  The name of the column in analysis_df indicating whether a subject is active at a given decision time.
1438
1443
  action_col_name (str):
@@ -1443,25 +1448,25 @@ def construct_classical_and_adjusted_sandwiches(
1443
1448
  The name of the column in analysis_df indicating the subject ID.
1444
1449
  action_prob_func_args (tuple):
1445
1450
  The arguments to be passed to the action probability function, needed if forming the
1446
- adaptive meat adjustments explicitly.
1451
+ adjusted meat adjustments explicitly.
1447
1452
  action_prob_col_name (str):
1448
1453
  The name of the column in analysis_df indicating the action probability of the action taken,
1449
- needed if forming the adaptive meat adjustments explicitly.
1454
+ needed if forming the adjusted meat adjustments explicitly.
1450
1455
  Returns:
1451
1456
  tuple[jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32]]:
1452
1457
  A tuple containing:
1453
- - The raw joint adaptive inverse bread matrix.
1454
- - The (possibly) stabilized joint adaptive inverse bread matrix.
1455
- - The joint adaptive meat matrix.
1456
- - The joint adaptive sandwich matrix.
1457
- - The classical inverse bread matrix.
1458
+ - The raw joint bread matrix.
1459
+ - The (possibly) stabilized joint bread matrix.
1460
+ - The joint meat matrix.
1461
+ - The joint sandwich matrix.
1462
+ - The classical bread matrix.
1458
1463
  - The classical meat matrix.
1459
1464
  - The classical sandwich matrix.
1460
1465
  - The average weighted estimating function stack.
1461
1466
  - All per-subject weighted estimating function stacks.
1462
- - The per-subject adaptive meat small-sample corrections.
1467
+ - The per-subject adjusted meat small-sample corrections.
1463
1468
  - The per-subject classical meat small-sample corrections.
1464
- - The per-subject adaptive meat adjustments, if form_adjusted_meat_adjustments_explicitly
1469
+ - The per-subject adjusted meat adjustments, if form_adjusted_meat_adjustments_explicitly
1465
1470
  is True, otherwise an array of NaNs.
1466
1471
  """
1467
1472
  logger.info(
@@ -1470,11 +1475,11 @@ def construct_classical_and_adjusted_sandwiches(
1470
1475
  theta_dim = theta_est.shape[0]
1471
1476
  beta_dim = all_post_update_betas.shape[1]
1472
1477
  # Note that these "contributions" are per-subject Jacobians of the weighted estimating function stack.
1473
- raw_joint_adjusted_bread_inverse_matrix, (
1478
+ raw_joint_adjusted_bread_matrix, (
1474
1479
  avg_estimating_function_stack,
1475
1480
  per_subject_joint_adjusted_meat_contributions,
1476
1481
  per_subject_classical_meat_contributions,
1477
- per_subject_classical_bread_inverse_contributions,
1482
+ per_subject_classical_bread_contributions,
1478
1483
  per_subject_estimating_function_stacks,
1479
1484
  ) = jax.jacrev(
1480
1485
  get_avg_weighted_estimating_function_stacks_and_aux_values, has_aux=True
@@ -1521,40 +1526,38 @@ def construct_classical_and_adjusted_sandwiches(
1521
1526
  small_sample_correction,
1522
1527
  per_subject_joint_adjusted_meat_contributions,
1523
1528
  per_subject_classical_meat_contributions,
1524
- per_subject_classical_bread_inverse_contributions,
1529
+ per_subject_classical_bread_contributions,
1525
1530
  num_subjects,
1526
1531
  theta_dim,
1527
1532
  )
1528
1533
 
1529
1534
  # Increase diagonal block dominance possibly improve conditioning of diagonal
1530
- # blocks as necessary, to ensure mathematical stability of joint bread inverse
1531
- stabilized_joint_adjusted_bread_inverse_matrix = (
1535
+ # blocks as necessary, to ensure mathematical stability of joint bread
1536
+ stabilized_joint_adjusted_bread_matrix = (
1532
1537
  (
1533
- stabilize_joint_adjusted_bread_inverse_if_necessary(
1534
- raw_joint_adjusted_bread_inverse_matrix,
1538
+ stabilize_joint_bread_if_necessary(
1539
+ raw_joint_adjusted_bread_matrix,
1535
1540
  beta_dim,
1536
1541
  theta_dim,
1537
1542
  )
1538
1543
  )
1539
- if stabilize_joint_adjusted_bread_inverse
1540
- else raw_joint_adjusted_bread_inverse_matrix
1544
+ if stabilize_joint_bread
1545
+ else raw_joint_adjusted_bread_matrix
1541
1546
  )
1542
1547
 
1543
1548
  # Now stably (no explicit inversion) form our sandwiches.
1544
- joint_adjusted_sandwich = form_sandwich_from_bread_inverse_and_meat(
1545
- stabilized_joint_adjusted_bread_inverse_matrix,
1549
+ joint_adjusted_sandwich = form_sandwich_from_bread_and_meat(
1550
+ stabilized_joint_adjusted_bread_matrix,
1546
1551
  joint_adjusted_meat_matrix,
1547
1552
  num_subjects,
1548
- method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
1549
- )
1550
- classical_bread_inverse_matrix = jnp.mean(
1551
- per_subject_classical_bread_inverse_contributions, axis=0
1553
+ method=SandwichFormationMethods.BREAD_T_QR,
1552
1554
  )
1553
- classical_sandwich = form_sandwich_from_bread_inverse_and_meat(
1554
- classical_bread_inverse_matrix,
1555
+ classical_bread_matrix = jnp.mean(per_subject_classical_bread_contributions, axis=0)
1556
+ classical_sandwich = form_sandwich_from_bread_and_meat(
1557
+ classical_bread_matrix,
1555
1558
  classical_meat_matrix,
1556
1559
  num_subjects,
1557
- method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
1560
+ method=SandwichFormationMethods.BREAD_T_QR,
1558
1561
  )
1559
1562
 
1560
1563
  per_subject_adjusted_meat_adjustments = jnp.full(
@@ -1565,7 +1568,7 @@ def construct_classical_and_adjusted_sandwiches(
1565
1568
  form_adjusted_meat_adjustments_directly(
1566
1569
  theta_dim,
1567
1570
  all_post_update_betas.shape[1],
1568
- stabilized_joint_adjusted_bread_inverse_matrix,
1571
+ stabilized_joint_adjusted_bread_matrix,
1569
1572
  per_subject_estimating_function_stacks,
1570
1573
  analysis_df,
1571
1574
  active_col_name,
@@ -1582,9 +1585,9 @@ def construct_classical_and_adjusted_sandwiches(
1582
1585
  action_prob_col_name,
1583
1586
  )
1584
1587
  )
1585
- # Validate that the adaptive meat adjustments we just formed are accurate by constructing
1586
- # the theta-only adaptive sandwich from them and checking that it matches the standard result
1587
- # we get by taking a subset of the joint adaptive sandwich.
1588
+ # Validate that the adjusted meat adjustments we just formed are accurate by constructing
1589
+ # the theta-only adjusted sandwich from them and checking that it matches the standard result
1590
+ # we get by taking a subset of the joint sandwich.
1588
1591
  # First just apply any small-sample correction for parity.
1589
1592
  (
1590
1593
  _,
@@ -1595,16 +1598,16 @@ def construct_classical_and_adjusted_sandwiches(
1595
1598
  small_sample_correction,
1596
1599
  per_subject_joint_adjusted_meat_contributions,
1597
1600
  per_subject_adjusted_classical_meat_contributions,
1598
- per_subject_classical_bread_inverse_contributions,
1601
+ per_subject_classical_bread_contributions,
1599
1602
  num_subjects,
1600
1603
  theta_dim,
1601
1604
  )
1602
1605
  theta_only_adjusted_sandwich_from_adjustments = (
1603
- form_sandwich_from_bread_inverse_and_meat(
1604
- classical_bread_inverse_matrix,
1606
+ form_sandwich_from_bread_and_meat(
1607
+ classical_bread_matrix,
1605
1608
  theta_only_adjusted_meat_matrix_v2,
1606
1609
  num_subjects,
1607
- method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
1610
+ method=SandwichFormationMethods.BREAD_T_QR,
1608
1611
  )
1609
1612
  )
1610
1613
  theta_only_adjusted_sandwich = joint_adjusted_sandwich[-theta_dim:, -theta_dim:]
@@ -1615,17 +1618,17 @@ def construct_classical_and_adjusted_sandwiches(
1615
1618
  rtol=3e-2,
1616
1619
  ):
1617
1620
  logger.warning(
1618
- "There may be a bug in the explicit meat adjustment calculation (this doesn't affect the actual calculation, just diagnostics). We've calculated the theta-only adaptive sandwich two different ways and they do not match sufficiently."
1621
+ "There may be a bug in the explicit meat adjustment calculation (this doesn't affect the actual calculation, just diagnostics). We've calculated the theta-only adjusted sandwich two different ways and they do not match sufficiently."
1619
1622
  )
1620
1623
 
1621
- # Stack the joint adaptive inverse bread pieces together horizontally and return the auxiliary
1622
- # values too. The joint adaptive bread inverse should always be block lower triangular.
1624
+ # Stack the joint bread pieces together horizontally and return the auxiliary
1625
+ # values too. The joint bread should always be block lower triangular.
1623
1626
  return (
1624
- raw_joint_adjusted_bread_inverse_matrix,
1625
- stabilized_joint_adjusted_bread_inverse_matrix,
1627
+ raw_joint_adjusted_bread_matrix,
1628
+ stabilized_joint_adjusted_bread_matrix,
1626
1629
  joint_adjusted_meat_matrix,
1627
1630
  joint_adjusted_sandwich,
1628
- classical_bread_inverse_matrix,
1631
+ classical_bread_matrix,
1629
1632
  classical_meat_matrix,
1630
1633
  classical_sandwich,
1631
1634
  avg_estimating_function_stack,
@@ -1639,25 +1642,25 @@ def construct_classical_and_adjusted_sandwiches(
1639
1642
  # TODO: I think there should be interaction to confirm stabilization. It is
1640
1643
  # important for the subject to know if this is happening. Even if enabled, it is important
1641
1644
  # that the subject know it actually kicks in.
1642
- def stabilize_joint_adjusted_bread_inverse_if_necessary(
1643
- joint_adjusted_bread_inverse_matrix: jnp.ndarray,
1645
+ def stabilize_joint_bread_if_necessary(
1646
+ joint_adjusted_bread_matrix: jnp.ndarray,
1644
1647
  beta_dim: int,
1645
1648
  theta_dim: int,
1646
1649
  ) -> jnp.ndarray:
1647
1650
  """
1648
- Stabilizes the joint adaptive bread inverse matrix if necessary by increasing diagonal block
1651
+ Stabilizes the joint bread matrix if necessary by increasing diagonal block
1649
1652
  dominance and/or adding a small ridge penalty to the diagonal blocks.
1650
1653
 
1651
1654
  Args:
1652
- joint_adjusted_bread_inverse_matrix (jnp.ndarray):
1653
- A 2-D JAX NumPy array representing the joint adaptive bread inverse matrix.
1655
+ joint_adjusted_bread_matrix (jnp.ndarray):
1656
+ A 2-D JAX NumPy array representing the joint bread matrix.
1654
1657
  beta_dim (int):
1655
1658
  The dimension of each beta parameter.
1656
1659
  theta_dim (int):
1657
1660
  The dimension of the theta parameter.
1658
1661
  Returns:
1659
1662
  jnp.ndarray:
1660
- A 2-D NumPy array representing the stabilized joint adaptive bread inverse matrix.
1663
+ A 2-D NumPy array representing the stabilized joint bread matrix.
1661
1664
  """
1662
1665
 
1663
1666
  # TODO: come up with more sophisticated settings here. These are maybe a little loose,
@@ -1670,7 +1673,7 @@ def stabilize_joint_adjusted_bread_inverse_if_necessary(
1670
1673
 
1671
1674
  # Grab just the RL block and convert numpy array for easier manipulation.
1672
1675
  RL_stack_beta_derivatives_block = np.array(
1673
- joint_adjusted_bread_inverse_matrix[:-theta_dim, :-theta_dim]
1676
+ joint_adjusted_bread_matrix[:-theta_dim, :-theta_dim]
1674
1677
  )
1675
1678
  num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
1676
1679
  for i in range(1, num_updates + 1):
@@ -1789,31 +1792,31 @@ def stabilize_joint_adjusted_bread_inverse_if_necessary(
1789
1792
  [
1790
1793
  [
1791
1794
  RL_stack_beta_derivatives_block,
1792
- joint_adjusted_bread_inverse_matrix[:-theta_dim, -theta_dim:],
1795
+ joint_adjusted_bread_matrix[:-theta_dim, -theta_dim:],
1793
1796
  ],
1794
1797
  [
1795
- joint_adjusted_bread_inverse_matrix[-theta_dim:, :-theta_dim],
1796
- joint_adjusted_bread_inverse_matrix[-theta_dim:, -theta_dim:],
1798
+ joint_adjusted_bread_matrix[-theta_dim:, :-theta_dim],
1799
+ joint_adjusted_bread_matrix[-theta_dim:, -theta_dim:],
1797
1800
  ],
1798
1801
  ]
1799
1802
  )
1800
1803
 
1801
1804
 
1802
- def form_sandwich_from_bread_inverse_and_meat(
1803
- bread_inverse: jnp.ndarray,
1805
+ def form_sandwich_from_bread_and_meat(
1806
+ bread: jnp.ndarray,
1804
1807
  meat: jnp.ndarray,
1805
1808
  num_subjects: int,
1806
- method: str = SandwichFormationMethods.BREAD_INVERSE_T_QR,
1809
+ method: str = SandwichFormationMethods.BREAD_T_QR,
1807
1810
  ) -> jnp.ndarray:
1808
1811
  """
1809
- Forms a sandwich variance matrix from the provided bread inverse and meat matrices.
1812
+ Forms a sandwich variance matrix from the provided bread and meat matrices.
1810
1813
 
1811
- Attempts to do so STABLY without ever forming the bread matrix itself
1814
+ Attempts to do so STABLY without ever forming the bread inverse matrix itself
1812
1815
  (except with naive option).
1813
1816
 
1814
1817
  Args:
1815
- bread_inverse (jnp.ndarray):
1816
- A 2-D JAX NumPy array representing the bread inverse matrix.
1818
+ bread (jnp.ndarray):
1819
+ A 2-D JAX NumPy array representing the bread matrix.
1817
1820
  meat (jnp.ndarray):
1818
1821
  A 2-D JAX NumPy array representing the meat matrix.
1819
1822
  num_subjects (int):
@@ -1821,12 +1824,12 @@ def form_sandwich_from_bread_inverse_and_meat(
1821
1824
  method (str):
1822
1825
  The method to use for forming the sandwich.
1823
1826
 
1824
- SandwichFormationMethods.BREAD_INVERSE_T_QR uses the QR decomposition of the transpose
1825
- of the bread inverse matrix.
1827
+ SandwichFormationMethods.BREAD_T_QR uses the QR decomposition of the transpose
1828
+ of the bread matrix.
1826
1829
 
1827
1830
  SandwichFormationMethods.MEAT_SVD_SOLVE uses a decomposition of the meat matrix.
1828
1831
 
1829
- SandwichFormationMethods.NAIVE simply inverts the bread inverse and forms the sandwich.
1832
+ SandwichFormationMethods.NAIVE simply inverts the bread and forms the sandwich.
1830
1833
 
1831
1834
 
1832
1835
  Returns:
@@ -1834,9 +1837,9 @@ def form_sandwich_from_bread_inverse_and_meat(
1834
1837
  A 2-D JAX NumPy array representing the sandwich variance matrix.
1835
1838
  """
1836
1839
 
1837
- if method == SandwichFormationMethods.BREAD_INVERSE_T_QR:
1840
+ if method == SandwichFormationMethods.BREAD_T_QR:
1838
1841
  # QR of B^T → Q orthogonal, R upper triangular; L = R^T lower triangular
1839
- Q, R = np.linalg.qr(bread_inverse.T, mode="reduced")
1842
+ Q, R = np.linalg.qr(bread.T, mode="reduced")
1840
1843
  L = R.T
1841
1844
 
1842
1845
  new_meat = scipy.linalg.solve_triangular(
@@ -1854,21 +1857,21 @@ def form_sandwich_from_bread_inverse_and_meat(
1854
1857
  C_right = Vh.T * np.sqrt(s)
1855
1858
 
1856
1859
  # Solve B W_left = C_left and B W_right = C_right (no explicit inverses).
1857
- W_left = scipy.linalg.solve(bread_inverse, C_left)
1858
- W_right = scipy.linalg.solve(bread_inverse, C_right)
1860
+ W_left = scipy.linalg.solve(bread, C_left)
1861
+ W_right = scipy.linalg.solve(bread, C_right)
1859
1862
 
1860
1863
  # Return the exact sandwich: V = (B^{-1} C_left) (B^{-1} C_right)^T / num_subjects
1861
1864
  return W_left @ W_right.T / num_subjects
1862
1865
 
1863
1866
  elif method == SandwichFormationMethods.NAIVE:
1864
- # Simply invert the bread inverse and form the sandwich directly.
1867
+ # Simply invert the bread and form the sandwich directly.
1865
1868
  # This is NOT numerically stable and is only included for comparison purposes.
1866
- bread = np.linalg.inv(bread_inverse)
1867
- return bread @ meat @ meat.T / num_subjects
1869
+ bread_inverse = np.linalg.inv(bread)
1870
+ return bread_inverse @ meat @ bread_inverse.T / num_subjects
1868
1871
 
1869
1872
  else:
1870
1873
  raise ValueError(
1871
- f"Unknown sandwich method: {method}. Please use 'bread_inverse_t_qr' or 'meat_decomposition_solve'."
1874
+ f"Unknown sandwich method: {method}. Please use 'bread_t_qr' or 'meat_decomposition_solve'."
1872
1875
  )
1873
1876
 
1874
1877