lifejacket 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -53,6 +53,8 @@ logging.basicConfig(
53
53
  level=logging.INFO,
54
54
  )
55
55
 
56
+ jax.config.update("jax_enable_x64", True)
57
+
56
58
 
57
59
  @click.group()
58
60
  def cli():
@@ -217,9 +219,9 @@ def cli():
217
219
  type=click.Choice(
218
220
  [
219
221
  SmallSampleCorrections.NONE,
220
- SmallSampleCorrections.HC1theta,
221
- SmallSampleCorrections.HC2theta,
222
- SmallSampleCorrections.HC3theta,
222
+ SmallSampleCorrections.Z1theta,
223
+ SmallSampleCorrections.Z2theta,
224
+ SmallSampleCorrections.Z3theta,
223
225
  ]
224
226
  ),
225
227
  default=SmallSampleCorrections.NONE,
@@ -235,13 +237,13 @@ def cli():
235
237
  "--form_adjusted_meat_adjustments_explicitly",
236
238
  type=bool,
237
239
  default=False,
238
- help="If True, explicitly forms the per-subject meat adjustments that differentiate the adaptive sandwich from the classical sandwich. This is for diagnostic purposes, as the adaptive sandwich is formed without doing this.",
240
+ help="If True, explicitly forms the per-subject meat adjustments that differentiate the adjusted sandwich from the classical sandwich. This is for diagnostic purposes, as the adjusted sandwich is formed without doing this.",
239
241
  )
240
242
  @click.option(
241
- "--stabilize_joint_adjusted_bread_inverse",
243
+ "--stabilize_joint_bread",
242
244
  type=bool,
243
245
  default=True,
244
- help="If True, stabilizes the joint adaptive bread inverse matrix if it does not meet conditioning thresholds.",
246
+ help="If True, stabilizes the joint bread matrix if it does not meet conditioning thresholds.",
245
247
  )
246
248
  def analyze_dataset_wrapper(**kwargs):
247
249
  """
@@ -324,15 +326,15 @@ def analyze_dataset(
324
326
  small_sample_correction: str,
325
327
  collect_data_for_blowup_supervised_learning: bool,
326
328
  form_adjusted_meat_adjustments_explicitly: bool,
327
- stabilize_joint_adjusted_bread_inverse: bool,
329
+ stabilize_joint_bread: bool,
328
330
  ) -> None:
329
331
  """
330
- Analyzes a dataset to provide a parameter estimate and an estimate of its variance using adaptive and classical sandwich estimators.
332
+ Analyzes a dataset to provide a parameter estimate and an estimate of its variance using adjusted and classical sandwich estimators.
331
333
 
332
334
  There are two modes of use for this function.
333
335
 
334
336
  First, it may be called indirectly from the command line by passing through
335
- analyze_dataset.
337
+ analyze_dataset_wrapper.
336
338
 
337
339
  Second, it may be called directly from Python code with in-memory objects.
338
340
 
@@ -388,17 +390,17 @@ def analyze_dataset(
388
390
  small_sample_correction (str):
389
391
  Type of small sample correction to apply.
390
392
  collect_data_for_blowup_supervised_learning (bool):
391
- Whether to collect data for doing supervised learning about adaptive sandwich blowup.
393
+ Whether to collect data for doing supervised learning about adjusted sandwich blowup.
392
394
  form_adjusted_meat_adjustments_explicitly (bool):
393
- If True, explicitly forms the per-subject meat adjustments that differentiate the adaptive
395
+ If True, explicitly forms the per-subject meat adjustments that differentiate the adjusted
394
396
  sandwich from the classical sandwich. This is for diagnostic purposes, as the
395
- adaptive sandwich is formed without doing this.
396
- stabilize_joint_adjusted_bread_inverse (bool):
397
- If True, stabilizes the joint adaptive bread inverse matrix if it does not meet conditioning
397
+ adjusted sandwich is formed without doing this.
398
+ stabilize_joint_bread (bool):
399
+ If True, stabilizes the joint bread matrix if it does not meet conditioning
398
400
  thresholds.
399
401
 
400
402
  Returns:
401
- dict: A dictionary containing the theta estimate, adaptive sandwich variance estimate, and
403
+ dict: A dictionary containing the theta estimate, adjusted sandwich variance estimate, and
402
404
  classical sandwich variance estimate.
403
405
  """
404
406
 
@@ -438,7 +440,6 @@ def analyze_dataset(
438
440
  )
439
441
 
440
442
  ### Begin collecting data structures that will be used to compute the joint bread matrix.
441
-
442
443
  beta_index_by_policy_num, initial_policy_num = (
443
444
  construct_beta_index_by_policy_num_map(
444
445
  analysis_df, policy_num_col_name, active_col_name
@@ -475,20 +476,20 @@ def analyze_dataset(
475
476
  active_col_name,
476
477
  )
477
478
 
478
- # Use a per-subject weighted estimating function stacking functino to derive classical and joint
479
- # adaptive meat and inverse bread matrices. This is facilitated because the *value* of the
479
+ # Use a per-subject weighted estimating function stacking function to derive classical and joint
480
+ # meat and bread matrices. This is facilitated because the *value* of the
480
481
  # weighted and unweighted stacks are the same, as the weights evaluate to 1 pre-differentiation.
481
482
  logger.info(
482
- "Constructing joint adaptive bread inverse matrix, joint adaptive meat matrix, the classical analogs, and the avg estimating function stack across subjects."
483
+ "Constructing joint bread matrix, joint meat matrix, the classical analogs, and the avg estimating function stack across subjects."
483
484
  )
484
485
 
485
486
  subject_ids = jnp.array(analysis_df[subject_id_col_name].unique())
486
487
  (
487
- stabilized_joint_adjusted_bread_inverse_matrix,
488
- raw_joint_adjusted_bread_inverse_matrix,
488
+ stabilized_joint_bread_matrix,
489
+ raw_joint_bread_matrix,
489
490
  joint_adjusted_meat_matrix,
490
- joint_adjusted_sandwich_matrix,
491
- classical_bread_inverse_matrix,
491
+ joint_sandwich_matrix,
492
+ classical_bread_matrix,
492
493
  classical_meat_matrix,
493
494
  classical_sandwich_var_estimate,
494
495
  avg_estimating_function_stack,
@@ -524,7 +525,7 @@ def analyze_dataset(
524
525
  suppress_interactive_data_checks,
525
526
  small_sample_correction,
526
527
  form_adjusted_meat_adjustments_explicitly,
527
- stabilize_joint_adjusted_bread_inverse,
528
+ stabilize_joint_bread,
528
529
  analysis_df,
529
530
  active_col_name,
530
531
  action_col_name,
@@ -545,22 +546,19 @@ def analyze_dataset(
545
546
 
546
547
  # This bottom right corner of the joint (betas and theta) variance matrix is the portion
547
548
  # corresponding to just theta.
548
- adjusted_sandwich_var_estimate = joint_adjusted_sandwich_matrix[
549
- -theta_dim:, -theta_dim:
550
- ]
549
+ adjusted_sandwich_var_estimate = joint_sandwich_matrix[-theta_dim:, -theta_dim:]
551
550
 
552
551
  # Check for negative diagonal elements and set them to zero if found
553
- adaptive_diagonal = np.diag(adjusted_sandwich_var_estimate)
554
- if np.any(adaptive_diagonal < 0):
552
+ adjusted_diagonal = np.diag(adjusted_sandwich_var_estimate)
553
+ if np.any(adjusted_diagonal < 0):
555
554
  logger.warning(
556
- "Found negative diagonal elements in adaptive sandwich variance estimate. Setting them to zero."
555
+ "Found negative diagonal elements in adjusted sandwich variance estimate. Setting them to zero."
557
556
  )
558
557
  np.fill_diagonal(
559
- adjusted_sandwich_var_estimate, np.maximum(adaptive_diagonal, 0)
558
+ adjusted_sandwich_var_estimate, np.maximum(adjusted_diagonal, 0)
560
559
  )
561
560
 
562
561
  logger.info("Writing results to file...")
563
- # Write analysis results to same directory that input files are in
564
562
  output_folder_abs_path = pathlib.Path(output_dir).resolve()
565
563
 
566
564
  analysis_dict = {
@@ -574,25 +572,229 @@ def analyze_dataset(
574
572
  f,
575
573
  )
576
574
 
577
- joint_adjusted_bread_inverse_cond = jnp.linalg.cond(
578
- raw_joint_adjusted_bread_inverse_matrix
575
+ joint_bread_cond = jnp.linalg.cond(raw_joint_bread_matrix)
576
+ logger.info(
577
+ "Joint bread condition number: %f",
578
+ joint_bread_cond,
579
+ )
580
+
581
+ # calculate the max eigenvalue of the theta-only adjusted sandwich
582
+ max_eigenvalue_theta_only_adjusted_sandwich = scipy.linalg.eigvalsh(
583
+ adjusted_sandwich_var_estimate
584
+ ).max()
585
+ logger.info(
586
+ "Max eigenvalue of theta-only adjusted sandwich matrix: %f",
587
+ max_eigenvalue_theta_only_adjusted_sandwich,
588
+ )
589
+
590
+ # Compute ratios: max eigenvalue / median eigenvalue among those >= 1e-8 * max.
591
+ eigvals_joint_sandwich = scipy.linalg.eigvalsh(joint_sandwich_matrix)
592
+ max_eig_joint = float(eigvals_joint_sandwich.max())
593
+ logger.info(
594
+ "Max eigenvalue of joint adjusted sandwich matrix: %f",
595
+ max_eig_joint,
596
+ )
597
+
598
+ joint_keep = eigvals_joint_sandwich >= (1e-8 * max_eig_joint)
599
+ joint_median_kept = (
600
+ float(np.median(eigvals_joint_sandwich[joint_keep]))
601
+ if np.any(joint_keep)
602
+ else math.nan
603
+ )
604
+ max_to_median_ratio_joint_sandwich = (
605
+ (max_eig_joint / joint_median_kept)
606
+ if (not math.isnan(joint_median_kept) and joint_median_kept > 0)
607
+ else (
608
+ math.inf
609
+ if (not math.isnan(joint_median_kept) and joint_median_kept == 0)
610
+ else math.nan
611
+ )
579
612
  )
580
613
  logger.info(
581
- "Joint adjusted bread inverse condition number: %f",
582
- joint_adjusted_bread_inverse_cond,
614
+ "Max/median eigenvalue ratio (joint sandwich; median over eigvals >= 1e-8*max): %f",
615
+ max_to_median_ratio_joint_sandwich,
616
+ )
617
+
618
+ eigvals_theta_only_adjusted_sandwich = scipy.linalg.eigvalsh(
619
+ adjusted_sandwich_var_estimate
583
620
  )
621
+ max_eig_theta = float(eigvals_theta_only_adjusted_sandwich.max())
622
+ theta_keep = eigvals_theta_only_adjusted_sandwich >= (1e-8 * max_eig_theta)
623
+ theta_median_kept = (
624
+ float(np.median(eigvals_theta_only_adjusted_sandwich[theta_keep]))
625
+ if np.any(theta_keep)
626
+ else math.nan
627
+ )
628
+ max_to_median_ratio_theta_only_adjusted_sandwich = (
629
+ (max_eig_theta / theta_median_kept)
630
+ if (not math.isnan(theta_median_kept) and theta_median_kept > 0)
631
+ else (
632
+ math.inf
633
+ if (not math.isnan(theta_median_kept) and theta_median_kept == 0)
634
+ else math.nan
635
+ )
636
+ )
637
+ logger.info(
638
+ "Max/median eigenvalue ratio (theta-only adjusted sandwich; median over eigvals >= 1e-8*max): %f",
639
+ max_to_median_ratio_theta_only_adjusted_sandwich,
640
+ )
641
+
642
+ # --- Local linearization validity diagnostic (single-run) ---
643
+ # We compare the nonlinear Taylor remainder of the joint estimating-function map to the
644
+ # retained linear term, at perturbations on the O(1/sqrt(n)) scale.
645
+ #
646
+ # Define r(delta) = || g(eta+delta) - g(eta) - B delta ||_2 / || B delta ||_2,
647
+ # where g(eta) is the avg per-subject weighted estimating-function stack and B is the
648
+ # stabilized joint bread (Jacobian of g w.r.t. flattened betas+theta).
649
+ #
650
+ # This ratio is dimensionless and can be used as a necessary/sanity diagnostic that the
651
+ # first-order linearization is locally accurate at the estimation scale.
652
+
653
+ def _compute_local_linearization_error_ratio() -> tuple[float, float]:
654
+ # Ensure float64 for diagnostics even if upstream ran in float32.
655
+ joint_bread_float64 = jnp.asarray(
656
+ stabilized_joint_bread_matrix, dtype=jnp.float64
657
+ )
658
+ g_hat = jnp.asarray(avg_estimating_function_stack, dtype=jnp.float64)
659
+ stacks_float64 = jnp.asarray(
660
+ per_subject_estimating_function_stacks, dtype=jnp.float64
661
+ )
662
+
663
+ num_subjects = stacks_float64.shape[0]
664
+
665
+ def _eval_avg_stack_jit(flattened_betas_and_theta: jnp.ndarray) -> jnp.ndarray:
666
+ return jnp.asarray(
667
+ get_avg_weighted_estimating_function_stacks_and_aux_values(
668
+ flattened_betas_and_theta,
669
+ beta_dim,
670
+ theta_dim,
671
+ subject_ids,
672
+ action_prob_func,
673
+ action_prob_func_args_beta_index,
674
+ alg_update_func,
675
+ alg_update_func_type,
676
+ alg_update_func_args_beta_index,
677
+ alg_update_func_args_action_prob_index,
678
+ alg_update_func_args_action_prob_times_index,
679
+ alg_update_func_args_previous_betas_index,
680
+ inference_func,
681
+ inference_func_type,
682
+ inference_func_args_theta_index,
683
+ inference_func_args_action_prob_index,
684
+ action_prob_func_args,
685
+ policy_num_by_decision_time_by_subject_id,
686
+ initial_policy_num,
687
+ beta_index_by_policy_num,
688
+ inference_func_args_by_subject_id,
689
+ inference_action_prob_decision_times_by_subject_id,
690
+ alg_update_func_args,
691
+ action_by_decision_time_by_subject_id,
692
+ True, # suppress_all_data_checks
693
+ True, # suppress_interactive_data_checks
694
+ False, # include_auxiliary_outputs
695
+ ),
696
+ dtype=jnp.float64,
697
+ )
698
+
699
+ # Evaluate at the final estimate.
700
+ eta_hat = jnp.asarray(
701
+ flatten_params(all_post_update_betas, theta_est), dtype=jnp.float64
702
+ )
703
+
704
+ # Draw perturbations delta_j on the O(1/sqrt(n)) scale, aligned with the empirical
705
+ # joint estimating function stack covariance, without forming a d_joint x d_joint matrix
706
+ # square-root. If G is the (n x d) matrix of per-subject stacks, then (1/n) G^T G is the
707
+ # empirical covariance in joint estimating function stack space. Sampling u = (G^T w)/sqrt(n) with w~N(0, I_n) gives
708
+ # u ~ N(0, empirical joint estimating function stack covariance G^T G/n ) in joint estimating function stack space.
709
+ key = jax.random.PRNGKey(0)
710
+
711
+ # The number of perturbations we will probe
712
+ J = 15
713
+ # Each requires num_subjects standard normal draws, which we will then transform
714
+ # into joint estimating function space perturbations in U
715
+ W = jax.random.normal(key, shape=(J, num_subjects), dtype=jnp.float64)
716
+
717
+ # Joint estimating function space perturbations: u_j in R^{d_joint}
718
+ # U = (1/sqrt(n)) * W G, where rows of G are g_i^T
719
+ U = (W @ stacks_float64) / jnp.sqrt(num_subjects)
720
+
721
+ # Parameter perturbations: delta = (c/sqrt(n)) * B^{-1} u
722
+ # Use solve rather than explicit inverse.
723
+ c = 1.0
724
+ delta = (c / jnp.sqrt(num_subjects)) * jnp.linalg.solve(
725
+ joint_bread_float64, U.T
726
+ ).T
727
+
728
+ # Compute ratios r_j.
729
+ # NOTE: We use the Euclidean norm in score space; this is dimensionless and avoids
730
+ # forming/pseudoinverting a potentially rank-deficient matrix.
731
+ B_delta = (joint_bread_float64 @ delta.T).T
732
+ g_plus = jax.vmap(lambda d: _eval_avg_stack_jit(eta_hat + d))(delta)
733
+ remainder = g_plus - g_hat - B_delta
734
+
735
+ denom = jnp.linalg.norm(B_delta, axis=1)
736
+ numer = jnp.linalg.norm(remainder, axis=1)
737
+
738
+ # Avoid division by zero (should not happen unless delta collapses numerically).
739
+ ratios = jnp.where(denom > 0, numer / denom, jnp.inf)
740
+
741
+ local_error_ratio_median = float(jnp.median(ratios))
742
+ local_error_ratio_p90 = float(jnp.quantile(ratios, 0.9))
743
+ local_error_ratio_max = float(jnp.max(ratios))
744
+
745
+ logger.info(
746
+ "Local linearization error ratio (median over %d draws): %.6f",
747
+ J,
748
+ local_error_ratio_median,
749
+ )
750
+ logger.info(
751
+ "Local linearization error ratio (90th pct over %d draws): %.6f",
752
+ J,
753
+ local_error_ratio_p90,
754
+ )
755
+
756
+ logger.info(
757
+ "Local linearization error ratio (max over %d draws): %.6f",
758
+ J,
759
+ local_error_ratio_max,
760
+ )
761
+
762
+ return local_error_ratio_median, local_error_ratio_p90, local_error_ratio_max
763
+
764
+ try:
765
+ local_error_ratio_median, local_error_ratio_p90, local_error_ratio_max = (
766
+ _compute_local_linearization_error_ratio()
767
+ )
768
+ except Exception as e:
769
+ # This diagnostic is best-effort; failure should not break analysis.
770
+ logger.warning(
771
+ "Failed to compute local linearization error ratio diagnostic: %s",
772
+ str(e),
773
+ )
774
+ local_error_ratio_median = math.nan
775
+ local_error_ratio_p90 = math.nan
776
+ local_error_ratio_max = math.nan
584
777
 
585
778
  debug_pieces_dict = {
586
779
  "theta_est": theta_est,
587
780
  "adjusted_sandwich_var_estimate": adjusted_sandwich_var_estimate,
588
781
  "classical_sandwich_var_estimate": classical_sandwich_var_estimate,
589
- "raw_joint_bread_inverse_matrix": raw_joint_adjusted_bread_inverse_matrix,
590
- "stabilized_joint_bread_inverse_matrix": stabilized_joint_adjusted_bread_inverse_matrix,
782
+ "raw_joint_bread_matrix": raw_joint_bread_matrix,
783
+ "stabilized_joint_bread_matrix": stabilized_joint_bread_matrix,
591
784
  "joint_meat_matrix": joint_adjusted_meat_matrix,
592
- "classical_bread_inverse_matrix": classical_bread_inverse_matrix,
785
+ "classical_bread_matrix": classical_bread_matrix,
593
786
  "classical_meat_matrix": classical_meat_matrix,
594
787
  "all_estimating_function_stacks": per_subject_estimating_function_stacks,
595
- "joint_bread_inverse_condition_number": joint_adjusted_bread_inverse_cond,
788
+ "joint_bread_condition_number": joint_bread_cond,
789
+ "max_eigenvalue_joint_sandwich": max_eig_joint,
790
+ "all_eigenvalues_joint_sandwich": eigvals_joint_sandwich,
791
+ "max_to_median_ratio_joint_sandwich": max_to_median_ratio_joint_sandwich,
792
+ "max_eigenvalue_theta_only_adjusted_sandwich": max_eig_theta,
793
+ "all_eigenvalues_theta_only_adjusted_sandwich": eigvals_theta_only_adjusted_sandwich,
794
+ "max_to_median_ratio_theta_only_adjusted_sandwich": max_to_median_ratio_theta_only_adjusted_sandwich,
795
+ "local_linearization_error_ratio_median": local_error_ratio_median,
796
+ "local_linearization_error_ratio_p90": local_error_ratio_p90,
797
+ "local_linearization_error_ratio_max": local_error_ratio_max,
596
798
  "all_post_update_betas": all_post_update_betas,
597
799
  "per_subject_adjusted_corrections": per_subject_adjusted_corrections,
598
800
  "per_subject_classical_corrections": per_subject_classical_corrections,
@@ -606,8 +808,8 @@ def analyze_dataset(
606
808
 
607
809
  if collect_data_for_blowup_supervised_learning:
608
810
  datum_and_label_dict = get_datum_for_blowup_supervised_learning.get_datum_for_blowup_supervised_learning(
609
- raw_joint_adjusted_bread_inverse_matrix,
610
- joint_adjusted_bread_inverse_cond,
811
+ raw_joint_bread_matrix,
812
+ joint_bread_cond,
611
813
  avg_estimating_function_stack,
612
814
  per_subject_estimating_function_stacks,
613
815
  all_post_update_betas,
@@ -752,12 +954,16 @@ def single_subject_weighted_estimating_function_stacker(
752
954
  policy_num_by_decision_time: dict[collections.abc.Hashable, dict[int, int | float]],
753
955
  action_by_decision_time: dict[collections.abc.Hashable, dict[int, int]],
754
956
  beta_index_by_policy_num: dict[int | float, int],
755
- ) -> tuple[
756
- jnp.ndarray[jnp.float32],
757
- jnp.ndarray[jnp.float32],
758
- jnp.ndarray[jnp.float32],
759
- jnp.ndarray[jnp.float32],
760
- ]:
957
+ include_auxiliary_outputs: bool = True,
958
+ ) -> (
959
+ tuple[
960
+ jnp.ndarray[jnp.float32],
961
+ jnp.ndarray[jnp.float32],
962
+ jnp.ndarray[jnp.float32],
963
+ jnp.ndarray[jnp.float32],
964
+ ]
965
+ | jnp.ndarray[jnp.float32]
966
+ ):
761
967
  """
762
968
  Computes a weighted estimating function stack for a given algorithm estimating function
763
969
  and arguments, inference estimating function and arguments, and action probability function and
@@ -821,12 +1027,23 @@ def single_subject_weighted_estimating_function_stacker(
821
1027
  A dictionary mapping policy numbers to the index of the corresponding beta in
822
1028
  all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
823
1029
 
1030
+ include_auxiliary_outputs (bool):
1031
+ If True, returns the weighted estimating function stack together with the adjusted meat,
1032
+ classical meat, and classical bread contributions as a single 4-tuple. If False, only returns the weighted estimating function stack.
1033
+
824
1034
  Returns:
825
1035
  jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted estimating function
826
1036
  stack.
827
- jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's adaptive meat contribution.
1037
+ jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's adjusted meat contribution.
828
1038
  jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical meat contribution.
829
1039
  jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical bread contribution.
1040
+
1041
+ or
1042
+
1043
+ jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted estimating function
1044
+ stack.
1045
+
1046
+ depending on the value of include_auxiliary_outputs.
830
1047
  """
831
1048
 
832
1049
  logger.info(
@@ -1008,22 +1225,26 @@ def single_subject_weighted_estimating_function_stacker(
1008
1225
 
1009
1226
  # 6. Return the following outputs:
1010
1227
  # a. The first is simply the weighted estimating function stack for this subject. The average
1011
- # of these is what we differentiate with respect to theta to form the inverse adaptive joint
1228
+ # of these is what we differentiate with respect to theta to form the joint
1012
1229
  # bread matrix, and we also compare that average to zero to check the estimating functions'
1013
1230
  # fidelity.
1014
- # b. The average outer product of these per-subject stacks across subjects is the adaptive joint meat
1231
+ # b. The average outer product of these per-subject stacks across subjects is the adjusted joint meat
1015
1232
  # matrix, hence the second output.
1016
1233
  # c. The third output is averaged across subjects to obtain the classical meat matrix.
1017
1234
  # d. The fourth output is averaged across subjects to obtain the inverse classical bread
1018
1235
  # matrix.
1019
- return (
1020
- weighted_stack,
1021
- jnp.outer(weighted_stack, weighted_stack),
1022
- jnp.outer(inference_component, inference_component),
1023
- jax.jacrev(inference_estimating_func, argnums=inference_func_args_theta_index)(
1024
- *threaded_inference_func_args
1025
- ),
1026
- )
1236
+ if include_auxiliary_outputs:
1237
+ return (
1238
+ weighted_stack,
1239
+ jnp.outer(weighted_stack, weighted_stack),
1240
+ jnp.outer(inference_component, inference_component),
1241
+ jax.jacrev(
1242
+ inference_estimating_func, argnums=inference_func_args_theta_index
1243
+ )(*threaded_inference_func_args),
1244
+ )
1245
+
1246
+ else:
1247
+ return weighted_stack
1027
1248
 
1028
1249
 
1029
1250
  def get_avg_weighted_estimating_function_stacks_and_aux_values(
@@ -1063,12 +1284,13 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1063
1284
  ],
1064
1285
  suppress_all_data_checks: bool,
1065
1286
  suppress_interactive_data_checks: bool,
1287
+ include_auxiliary_outputs: bool = True,
1066
1288
  ) -> tuple[
1067
1289
  jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
1068
1290
  ]:
1069
1291
  """
1070
1292
  Computes the average weighted estimating function stack across all subjects, along with
1071
- auxiliary values used to construct the adaptive and classical sandwich variances.
1293
+ auxiliary values used to construct the adjusted and classical sandwich variances.
1072
1294
 
1073
1295
  Args:
1074
1296
  flattened_betas_and_theta (jnp.ndarray):
@@ -1137,18 +1359,26 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1137
1359
  If True, suppresses interactive data checks that would otherwise be performed to ensure
1138
1360
  the correctness of the threaded arguments. The checks are still performed, but
1139
1361
  any interactive prompts are suppressed.
1362
+ include_auxiliary_outputs (bool):
1363
+ If True, returns the adjusted meat, classical meat, and classical bread contributions in addition to the average weighted estimating function stack.
1364
+ If False, returns only the average weighted estimating function stack.
1140
1365
 
1141
1366
  Returns:
1142
1367
  jnp.ndarray:
1143
1368
  A 2D JAX NumPy array holding the average weighted estimating function stack.
1369
+
1144
1370
  tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
1145
1371
  A tuple containing
1146
1372
  1. the average weighted estimating function stack
1147
- 2. the subject-level adaptive meat matrix contributions
1373
+ 2. the subject-level adjusted meat matrix contributions
1148
1374
  3. the subject-level classical meat matrix contributions
1149
1375
  4. the subject-level inverse classical bread matrix contributions
1150
1376
  5. raw per-subject weighted estimating function
1151
1377
  stacks.
1378
+ or jnp.ndarray:
1379
+ A 1-D JAX NumPy array representing the average weighted estimating function
1380
+ stack.
1381
+ depending on the value of include_auxiliary_outputs.
1152
1382
  """
1153
1383
 
1154
1384
  # 1. Collect estimating functions by differentiating the loss functions if needed.
@@ -1248,7 +1478,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1248
1478
  )
1249
1479
 
1250
1480
  # 5. Now we can compute the weighted estimating function stacks for all subjects
1251
- # as well as collect related values used to construct the adaptive and classical
1481
+ # as well as collect related values used to construct the adjusted and classical
1252
1482
  # sandwich variances.
1253
1483
  results = [
1254
1484
  single_subject_weighted_estimating_function_stacker(
@@ -1271,16 +1501,21 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1271
1501
  ]
1272
1502
 
1273
1503
  stacks = jnp.array([result[0] for result in results])
1504
+
1505
+ if not include_auxiliary_outputs:
1506
+ return jnp.mean(stacks, axis=0)
1507
+
1274
1508
  outer_products = jnp.array([result[1] for result in results])
1275
1509
  inference_only_outer_products = jnp.array([result[2] for result in results])
1276
1510
  inference_hessians = jnp.array([result[3] for result in results])
1277
1511
 
1278
1512
  # 6. Note this strange return structure! We will differentiate the first output,
1279
1513
  # but the second tuple will be passed along without modification via has_aux=True and then used
1280
- # for the adaptive meat matrix, estimating functions sum check, and classical meat and inverse
1281
- # bread matrices. The raw per-subject stacks are also returned for debugging purposes.
1514
+ # for the estimating functions sum check, per_subject_classical_bread_contributions, and
1515
+ # classical meat and inverse bread matrices. The raw per-subject stacks are also returned for
1516
+ # debugging purposes.
1282
1517
 
1283
- # Note that returning the raw stacks here as the first arguments is potentially
1518
+ # Note that returning the raw stacks here as the first argument is potentially
1284
1519
  # memory-intensive when combined with differentiation. Keep this in mind if the per-subject bread
1285
1520
  # inverse contributions are needed for something like CR2/CR3 small-sample corrections.
1286
1521
  return jnp.mean(stacks, axis=0), (
@@ -1330,7 +1565,7 @@ def construct_classical_and_adjusted_sandwiches(
1330
1565
  suppress_interactive_data_checks: bool,
1331
1566
  small_sample_correction: str,
1332
1567
  form_adjusted_meat_adjustments_explicitly: bool,
1333
- stabilize_joint_adjusted_bread_inverse: bool,
1568
+ stabilize_joint_bread: bool,
1334
1569
  analysis_df: pd.DataFrame | None,
1335
1570
  active_col_name: str | None,
1336
1571
  action_col_name: str | None,
@@ -1352,11 +1587,11 @@ def construct_classical_and_adjusted_sandwiches(
1352
1587
  jnp.ndarray[jnp.float32],
1353
1588
  ]:
1354
1589
  """
1355
- Constructs the classical and adaptive sandwich matrices, as well as various
1590
+ Constructs the classical and adjusted sandwich matrices, as well as various
1356
1591
  intermediate pieces in their construction.
1357
1592
 
1358
1593
  This is done by computing and differentiating the average weighted estimating function stack
1359
- with respect to the betas and theta, using the resulting Jacobian to compute the inverse bread
1594
+ with respect to the betas and theta, using the resulting Jacobian to compute the bread
1360
1595
  and meat matrices, and then stably computing sandwiches.
1361
1596
 
1362
1597
  Args:
@@ -1426,13 +1661,13 @@ def construct_classical_and_adjusted_sandwiches(
1426
1661
  The type of small sample correction to apply. See SmallSampleCorrections class for
1427
1662
  options.
1428
1663
  form_adjusted_meat_adjustments_explicitly (bool):
1429
- If True, explicitly forms the per-subject meat adjustments that differentiate the adaptive
1664
+ If True, explicitly forms the per-subject meat adjustments that differentiate the adjusted
1430
1665
  sandwich from the classical sandwich. This is for diagnostic purposes, as the
1431
- adaptive sandwich is formed without doing this.
1432
- stabilize_joint_adjusted_bread_inverse (bool):
1433
- If True, will apply various techniques to stabilize the joint adaptive bread inverse if necessary.
1666
+ adjusted sandwich is formed without doing this.
1667
+ stabilize_joint_bread (bool):
1668
+ If True, will apply various techniques to stabilize the joint bread if necessary.
1434
1669
  analysis_df (pd.DataFrame):
1435
- The full analysis dataframe, needed if forming the adaptive meat adjustments explicitly.
1670
+ The full analysis dataframe, needed if forming the adjusted meat adjustments explicitly.
1436
1671
  active_col_name (str):
1437
1672
  The name of the column in analysis_df indicating whether a subject is active at a given decision time.
1438
1673
  action_col_name (str):
@@ -1443,25 +1678,25 @@ def construct_classical_and_adjusted_sandwiches(
1443
1678
  The name of the column in analysis_df indicating the subject ID.
1444
1679
  action_prob_func_args (tuple):
1445
1680
  The arguments to be passed to the action probability function, needed if forming the
1446
- adaptive meat adjustments explicitly.
1681
+ adjusted meat adjustments explicitly.
1447
1682
  action_prob_col_name (str):
1448
1683
  The name of the column in analysis_df indicating the action probability of the action taken,
1449
- needed if forming the adaptive meat adjustments explicitly.
1684
+ needed if forming the adjusted meat adjustments explicitly.
1450
1685
  Returns:
1451
1686
  tuple[jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32]]:
1452
1687
  A tuple containing:
1453
- - The raw joint adaptive inverse bread matrix.
1454
- - The (possibly) stabilized joint adaptive inverse bread matrix.
1455
- - The joint adaptive meat matrix.
1456
- - The joint adaptive sandwich matrix.
1457
- - The classical inverse bread matrix.
1688
+ - The raw joint bread matrix.
1689
+ - The (possibly) stabilized joint bread matrix.
1690
+ - The joint meat matrix.
1691
+ - The joint sandwich matrix.
1692
+ - The classical bread matrix.
1458
1693
  - The classical meat matrix.
1459
1694
  - The classical sandwich matrix.
1460
1695
  - The average weighted estimating function stack.
1461
1696
  - All per-subject weighted estimating function stacks.
1462
- - The per-subject adaptive meat small-sample corrections.
1697
+ - The per-subject adjusted meat small-sample corrections.
1463
1698
  - The per-subject classical meat small-sample corrections.
1464
- - The per-subject adaptive meat adjustments, if form_adjusted_meat_adjustments_explicitly
1699
+ - The per-subject adjusted meat adjustments, if form_adjusted_meat_adjustments_explicitly
1465
1700
  is True, otherwise an array of NaNs.
1466
1701
  """
1467
1702
  logger.info(
@@ -1470,11 +1705,11 @@ def construct_classical_and_adjusted_sandwiches(
1470
1705
  theta_dim = theta_est.shape[0]
1471
1706
  beta_dim = all_post_update_betas.shape[1]
1472
1707
  # Note that these "contributions" are per-subject Jacobians of the weighted estimating function stack.
1473
- raw_joint_adjusted_bread_inverse_matrix, (
1708
+ raw_joint_bread_matrix, (
1474
1709
  avg_estimating_function_stack,
1475
1710
  per_subject_joint_adjusted_meat_contributions,
1476
1711
  per_subject_classical_meat_contributions,
1477
- per_subject_classical_bread_inverse_contributions,
1712
+ per_subject_classical_bread_contributions,
1478
1713
  per_subject_estimating_function_stacks,
1479
1714
  ) = jax.jacrev(
1480
1715
  get_avg_weighted_estimating_function_stacks_and_aux_values, has_aux=True
@@ -1521,40 +1756,38 @@ def construct_classical_and_adjusted_sandwiches(
1521
1756
  small_sample_correction,
1522
1757
  per_subject_joint_adjusted_meat_contributions,
1523
1758
  per_subject_classical_meat_contributions,
1524
- per_subject_classical_bread_inverse_contributions,
1759
+ per_subject_classical_bread_contributions,
1525
1760
  num_subjects,
1526
1761
  theta_dim,
1527
1762
  )
1528
1763
 
1529
1764
  # Increase diagonal block dominance and possibly improve conditioning of diagonal
1530
- # blocks as necessary, to ensure mathematical stability of joint bread inverse
1531
- stabilized_joint_adjusted_bread_inverse_matrix = (
1765
+ # blocks as necessary, to ensure mathematical stability of joint bread
1766
+ stabilized_joint_bread_matrix = (
1532
1767
  (
1533
- stabilize_joint_adjusted_bread_inverse_if_necessary(
1534
- raw_joint_adjusted_bread_inverse_matrix,
1768
+ stabilize_joint_bread_if_necessary(
1769
+ raw_joint_bread_matrix,
1535
1770
  beta_dim,
1536
1771
  theta_dim,
1537
1772
  )
1538
1773
  )
1539
- if stabilize_joint_adjusted_bread_inverse
1540
- else raw_joint_adjusted_bread_inverse_matrix
1774
+ if stabilize_joint_bread
1775
+ else raw_joint_bread_matrix
1541
1776
  )
1542
1777
 
1543
1778
  # Now stably (no explicit inversion) form our sandwiches.
1544
- joint_adjusted_sandwich = form_sandwich_from_bread_inverse_and_meat(
1545
- stabilized_joint_adjusted_bread_inverse_matrix,
1779
+ joint_sandwich = form_sandwich_from_bread_and_meat(
1780
+ stabilized_joint_bread_matrix,
1546
1781
  joint_adjusted_meat_matrix,
1547
1782
  num_subjects,
1548
- method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
1549
- )
1550
- classical_bread_inverse_matrix = jnp.mean(
1551
- per_subject_classical_bread_inverse_contributions, axis=0
1783
+ method=SandwichFormationMethods.BREAD_T_QR,
1552
1784
  )
1553
- classical_sandwich = form_sandwich_from_bread_inverse_and_meat(
1554
- classical_bread_inverse_matrix,
1785
+ classical_bread_matrix = jnp.mean(per_subject_classical_bread_contributions, axis=0)
1786
+ classical_sandwich = form_sandwich_from_bread_and_meat(
1787
+ classical_bread_matrix,
1555
1788
  classical_meat_matrix,
1556
1789
  num_subjects,
1557
- method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
1790
+ method=SandwichFormationMethods.BREAD_T_QR,
1558
1791
  )
1559
1792
 
1560
1793
  per_subject_adjusted_meat_adjustments = jnp.full(
@@ -1565,7 +1798,7 @@ def construct_classical_and_adjusted_sandwiches(
1565
1798
  form_adjusted_meat_adjustments_directly(
1566
1799
  theta_dim,
1567
1800
  all_post_update_betas.shape[1],
1568
- stabilized_joint_adjusted_bread_inverse_matrix,
1801
+ stabilized_joint_bread_matrix,
1569
1802
  per_subject_estimating_function_stacks,
1570
1803
  analysis_df,
1571
1804
  active_col_name,
@@ -1582,9 +1815,9 @@ def construct_classical_and_adjusted_sandwiches(
1582
1815
  action_prob_col_name,
1583
1816
  )
1584
1817
  )
1585
- # Validate that the adaptive meat adjustments we just formed are accurate by constructing
1586
- # the theta-only adaptive sandwich from them and checking that it matches the standard result
1587
- # we get by taking a subset of the joint adaptive sandwich.
1818
+ # Validate that the adjusted meat adjustments we just formed are accurate by constructing
1819
+ # the theta-only adjusted sandwich from them and checking that it matches the standard result
1820
+ # we get by taking a subset of the joint sandwich.
1588
1821
  # First just apply any small-sample correction for parity.
1589
1822
  (
1590
1823
  _,
@@ -1595,19 +1828,19 @@ def construct_classical_and_adjusted_sandwiches(
1595
1828
  small_sample_correction,
1596
1829
  per_subject_joint_adjusted_meat_contributions,
1597
1830
  per_subject_adjusted_classical_meat_contributions,
1598
- per_subject_classical_bread_inverse_contributions,
1831
+ per_subject_classical_bread_contributions,
1599
1832
  num_subjects,
1600
1833
  theta_dim,
1601
1834
  )
1602
1835
  theta_only_adjusted_sandwich_from_adjustments = (
1603
- form_sandwich_from_bread_inverse_and_meat(
1604
- classical_bread_inverse_matrix,
1836
+ form_sandwich_from_bread_and_meat(
1837
+ classical_bread_matrix,
1605
1838
  theta_only_adjusted_meat_matrix_v2,
1606
1839
  num_subjects,
1607
- method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
1840
+ method=SandwichFormationMethods.BREAD_T_QR,
1608
1841
  )
1609
1842
  )
1610
- theta_only_adjusted_sandwich = joint_adjusted_sandwich[-theta_dim:, -theta_dim:]
1843
+ theta_only_adjusted_sandwich = joint_sandwich[-theta_dim:, -theta_dim:]
1611
1844
 
1612
1845
  if not np.allclose(
1613
1846
  theta_only_adjusted_sandwich,
@@ -1615,17 +1848,17 @@ def construct_classical_and_adjusted_sandwiches(
1615
1848
  rtol=3e-2,
1616
1849
  ):
1617
1850
  logger.warning(
1618
- "There may be a bug in the explicit meat adjustment calculation (this doesn't affect the actual calculation, just diagnostics). We've calculated the theta-only adaptive sandwich two different ways and they do not match sufficiently."
1851
+ "There may be a bug in the explicit meat adjustment calculation (this doesn't affect the actual calculation, just diagnostics). We've calculated the theta-only adjusted sandwich two different ways and they do not match sufficiently."
1619
1852
  )
1620
1853
 
1621
- # Stack the joint adaptive inverse bread pieces together horizontally and return the auxiliary
1622
- # values too. The joint adaptive bread inverse should always be block lower triangular.
1854
+ # Stack the joint bread pieces together horizontally and return the auxiliary
1855
+ # values too. The joint bread should always be block lower triangular.
1623
1856
  return (
1624
- raw_joint_adjusted_bread_inverse_matrix,
1625
- stabilized_joint_adjusted_bread_inverse_matrix,
1857
+ raw_joint_bread_matrix,
1858
+ stabilized_joint_bread_matrix,
1626
1859
  joint_adjusted_meat_matrix,
1627
- joint_adjusted_sandwich,
1628
- classical_bread_inverse_matrix,
1860
+ joint_sandwich,
1861
+ classical_bread_matrix,
1629
1862
  classical_meat_matrix,
1630
1863
  classical_sandwich,
1631
1864
  avg_estimating_function_stack,
@@ -1639,25 +1872,25 @@ def construct_classical_and_adjusted_sandwiches(
1639
1872
  # TODO: I think there should be interaction to confirm stabilization. It is
1640
1873
  # important for the user to know if this is happening. Even if enabled, it is important
1641
1874
  # that the user know it actually kicks in.
1642
- def stabilize_joint_adjusted_bread_inverse_if_necessary(
1643
- joint_adjusted_bread_inverse_matrix: jnp.ndarray,
1875
+ def stabilize_joint_bread_if_necessary(
1876
+ joint_bread_matrix: jnp.ndarray,
1644
1877
  beta_dim: int,
1645
1878
  theta_dim: int,
1646
1879
  ) -> jnp.ndarray:
1647
1880
  """
1648
- Stabilizes the joint adaptive bread inverse matrix if necessary by increasing diagonal block
1881
+ Stabilizes the joint bread matrix if necessary by increasing diagonal block
1649
1882
  dominance and/or adding a small ridge penalty to the diagonal blocks.
1650
1883
 
1651
1884
  Args:
1652
- joint_adjusted_bread_inverse_matrix (jnp.ndarray):
1653
- A 2-D JAX NumPy array representing the joint adaptive bread inverse matrix.
1885
+ joint_bread_matrix (jnp.ndarray):
1886
+ A 2-D JAX NumPy array representing the joint bread matrix.
1654
1887
  beta_dim (int):
1655
1888
  The dimension of each beta parameter.
1656
1889
  theta_dim (int):
1657
1890
  The dimension of the theta parameter.
1658
1891
  Returns:
1659
1892
  jnp.ndarray:
1660
- A 2-D NumPy array representing the stabilized joint adaptive bread inverse matrix.
1893
+ A 2-D NumPy array representing the stabilized joint bread matrix.
1661
1894
  """
1662
1895
 
1663
1896
  # TODO: come up with more sophisticated settings here. These are maybe a little loose,
@@ -1670,7 +1903,7 @@ def stabilize_joint_adjusted_bread_inverse_if_necessary(
1670
1903
 
1671
1904
  # Grab just the RL block and convert numpy array for easier manipulation.
1672
1905
  RL_stack_beta_derivatives_block = np.array(
1673
- joint_adjusted_bread_inverse_matrix[:-theta_dim, :-theta_dim]
1906
+ joint_bread_matrix[:-theta_dim, :-theta_dim]
1674
1907
  )
1675
1908
  num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
1676
1909
  for i in range(1, num_updates + 1):
@@ -1789,31 +2022,31 @@ def stabilize_joint_adjusted_bread_inverse_if_necessary(
1789
2022
  [
1790
2023
  [
1791
2024
  RL_stack_beta_derivatives_block,
1792
- joint_adjusted_bread_inverse_matrix[:-theta_dim, -theta_dim:],
2025
+ joint_bread_matrix[:-theta_dim, -theta_dim:],
1793
2026
  ],
1794
2027
  [
1795
- joint_adjusted_bread_inverse_matrix[-theta_dim:, :-theta_dim],
1796
- joint_adjusted_bread_inverse_matrix[-theta_dim:, -theta_dim:],
2028
+ joint_bread_matrix[-theta_dim:, :-theta_dim],
2029
+ joint_bread_matrix[-theta_dim:, -theta_dim:],
1797
2030
  ],
1798
2031
  ]
1799
2032
  )
1800
2033
 
1801
2034
 
1802
- def form_sandwich_from_bread_inverse_and_meat(
1803
- bread_inverse: jnp.ndarray,
2035
+ def form_sandwich_from_bread_and_meat(
2036
+ bread: jnp.ndarray,
1804
2037
  meat: jnp.ndarray,
1805
2038
  num_subjects: int,
1806
- method: str = SandwichFormationMethods.BREAD_INVERSE_T_QR,
2039
+ method: str = SandwichFormationMethods.BREAD_T_QR,
1807
2040
  ) -> jnp.ndarray:
1808
2041
  """
1809
- Forms a sandwich variance matrix from the provided bread inverse and meat matrices.
2042
+ Forms a sandwich variance matrix from the provided bread and meat matrices.
1810
2043
 
1811
- Attempts to do so STABLY without ever forming the bread matrix itself
2044
+ Attempts to do so STABLY without ever forming the bread inverse matrix itself
1812
2045
  (except with naive option).
1813
2046
 
1814
2047
  Args:
1815
- bread_inverse (jnp.ndarray):
1816
- A 2-D JAX NumPy array representing the bread inverse matrix.
2048
+ bread (jnp.ndarray):
2049
+ A 2-D JAX NumPy array representing the bread matrix.
1817
2050
  meat (jnp.ndarray):
1818
2051
  A 2-D JAX NumPy array representing the meat matrix.
1819
2052
  num_subjects (int):
@@ -1821,12 +2054,12 @@ def form_sandwich_from_bread_inverse_and_meat(
1821
2054
  method (str):
1822
2055
  The method to use for forming the sandwich.
1823
2056
 
1824
- SandwichFormationMethods.BREAD_INVERSE_T_QR uses the QR decomposition of the transpose
1825
- of the bread inverse matrix.
2057
+ SandwichFormationMethods.BREAD_T_QR uses the QR decomposition of the transpose
2058
+ of the bread matrix.
1826
2059
 
1827
2060
  SandwichFormationMethods.MEAT_SVD_SOLVE uses a decomposition of the meat matrix.
1828
2061
 
1829
- SandwichFormationMethods.NAIVE simply inverts the bread inverse and forms the sandwich.
2062
+ SandwichFormationMethods.NAIVE simply inverts the bread and forms the sandwich.
1830
2063
 
1831
2064
 
1832
2065
  Returns:
@@ -1834,9 +2067,9 @@ def form_sandwich_from_bread_inverse_and_meat(
1834
2067
  A 2-D JAX NumPy array representing the sandwich variance matrix.
1835
2068
  """
1836
2069
 
1837
- if method == SandwichFormationMethods.BREAD_INVERSE_T_QR:
2070
+ if method == SandwichFormationMethods.BREAD_T_QR:
1838
2071
  # QR of B^T → Q orthogonal, R upper triangular; L = R^T lower triangular
1839
- Q, R = np.linalg.qr(bread_inverse.T, mode="reduced")
2072
+ Q, R = np.linalg.qr(bread.T, mode="reduced")
1840
2073
  L = R.T
1841
2074
 
1842
2075
  new_meat = scipy.linalg.solve_triangular(
@@ -1854,21 +2087,21 @@ def form_sandwich_from_bread_inverse_and_meat(
1854
2087
  C_right = Vh.T * np.sqrt(s)
1855
2088
 
1856
2089
  # Solve B W_left = C_left and B W_right = C_right (no explicit inverses).
1857
- W_left = scipy.linalg.solve(bread_inverse, C_left)
1858
- W_right = scipy.linalg.solve(bread_inverse, C_right)
2090
+ W_left = scipy.linalg.solve(bread, C_left)
2091
+ W_right = scipy.linalg.solve(bread, C_right)
1859
2092
 
1860
2093
  # Return the exact sandwich: V = (B^{-1} C_left) (B^{-1} C_right)^T / num_subjects
1861
2094
  return W_left @ W_right.T / num_subjects
1862
2095
 
1863
2096
  elif method == SandwichFormationMethods.NAIVE:
1864
- # Simply invert the bread inverse and form the sandwich directly.
2097
+ # Simply invert the bread and form the sandwich directly.
1865
2098
  # This is NOT numerically stable and is only included for comparison purposes.
1866
- bread = np.linalg.inv(bread_inverse)
1867
- return bread @ meat @ meat.T / num_subjects
2099
+ bread_inverse = np.linalg.inv(bread)
2100
+ return bread_inverse @ meat @ bread_inverse.T / num_subjects
1868
2101
 
1869
2102
  else:
1870
2103
  raise ValueError(
1871
- f"Unknown sandwich method: {method}. Please use 'bread_inverse_t_qr' or 'meat_decomposition_solve'."
2104
+ f"Unknown sandwich method: {method}. Please use 'bread_t_qr' or 'meat_decomposition_solve'."
1872
2105
  )
1873
2106
 
1874
2107