lifejacket 1.0.2__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -53,6 +53,8 @@ logging.basicConfig(
53
53
  level=logging.INFO,
54
54
  )
55
55
 
56
+ jax.config.update("jax_enable_x64", True)
57
+
56
58
 
57
59
  @click.group()
58
60
  def cli():
@@ -483,10 +485,10 @@ def analyze_dataset(
483
485
 
484
486
  subject_ids = jnp.array(analysis_df[subject_id_col_name].unique())
485
487
  (
486
- stabilized_joint_adjusted_bread_matrix,
487
- raw_joint_adjusted_bread_matrix,
488
+ stabilized_joint_bread_matrix,
489
+ raw_joint_bread_matrix,
488
490
  joint_adjusted_meat_matrix,
489
- joint_adjusted_sandwich_matrix,
491
+ joint_sandwich_matrix,
490
492
  classical_bread_matrix,
491
493
  classical_meat_matrix,
492
494
  classical_sandwich_var_estimate,
@@ -544,9 +546,7 @@ def analyze_dataset(
544
546
 
545
547
  # This bottom right corner of the joint (betas and theta) variance matrix is the portion
546
548
  # corresponding to just theta.
547
- adjusted_sandwich_var_estimate = joint_adjusted_sandwich_matrix[
548
- -theta_dim:, -theta_dim:
549
- ]
549
+ adjusted_sandwich_var_estimate = joint_sandwich_matrix[-theta_dim:, -theta_dim:]
550
550
 
551
551
  # Check for negative diagonal elements and set them to zero if found
552
552
  adjusted_diagonal = np.diag(adjusted_sandwich_var_estimate)
@@ -572,31 +572,229 @@ def analyze_dataset(
572
572
  f,
573
573
  )
574
574
 
575
- joint_adjusted_bread_cond = jnp.linalg.cond(raw_joint_adjusted_bread_matrix)
575
+ joint_bread_cond = jnp.linalg.cond(raw_joint_bread_matrix)
576
+ logger.info(
577
+ "Joint bread condition number: %f",
578
+ joint_bread_cond,
579
+ )
580
+
581
+ # calculate the max eigenvalue of the theta-only adjusted sandwich
582
+ max_eigenvalue_theta_only_adjusted_sandwich = scipy.linalg.eigvalsh(
583
+ adjusted_sandwich_var_estimate
584
+ ).max()
576
585
  logger.info(
577
- "Joint adjusted bread condition number: %f",
578
- joint_adjusted_bread_cond,
586
+ "Max eigenvalue of theta-only adjusted sandwich matrix: %f",
587
+ max_eigenvalue_theta_only_adjusted_sandwich,
579
588
  )
580
589
 
581
- # calculate the max eigenvalue of the joint adjusted sandwich
582
- max_eigenvalue = scipy.linalg.eigvalsh(joint_adjusted_sandwich_matrix).max()
590
+ # Compute ratios: max eigenvalue / median eigenvalue among those >= 1e-8 * max.
591
+ eigvals_joint_sandwich = scipy.linalg.eigvalsh(joint_sandwich_matrix)
592
+ max_eig_joint = float(eigvals_joint_sandwich.max())
583
593
  logger.info(
584
594
  "Max eigenvalue of joint adjusted sandwich matrix: %f",
585
- max_eigenvalue,
595
+ max_eig_joint,
586
596
  )
587
597
 
598
+ joint_keep = eigvals_joint_sandwich >= (1e-8 * max_eig_joint)
599
+ joint_median_kept = (
600
+ float(np.median(eigvals_joint_sandwich[joint_keep]))
601
+ if np.any(joint_keep)
602
+ else math.nan
603
+ )
604
+ max_to_median_ratio_joint_sandwich = (
605
+ (max_eig_joint / joint_median_kept)
606
+ if (not math.isnan(joint_median_kept) and joint_median_kept > 0)
607
+ else (
608
+ math.inf
609
+ if (not math.isnan(joint_median_kept) and joint_median_kept == 0)
610
+ else math.nan
611
+ )
612
+ )
613
+ logger.info(
614
+ "Max/median eigenvalue ratio (joint sandwich; median over eigvals >= 1e-8*max): %f",
615
+ max_to_median_ratio_joint_sandwich,
616
+ )
617
+
618
+ eigvals_theta_only_adjusted_sandwich = scipy.linalg.eigvalsh(
619
+ adjusted_sandwich_var_estimate
620
+ )
621
+ max_eig_theta = float(eigvals_theta_only_adjusted_sandwich.max())
622
+ theta_keep = eigvals_theta_only_adjusted_sandwich >= (1e-8 * max_eig_theta)
623
+ theta_median_kept = (
624
+ float(np.median(eigvals_theta_only_adjusted_sandwich[theta_keep]))
625
+ if np.any(theta_keep)
626
+ else math.nan
627
+ )
628
+ max_to_median_ratio_theta_only_adjusted_sandwich = (
629
+ (max_eig_theta / theta_median_kept)
630
+ if (not math.isnan(theta_median_kept) and theta_median_kept > 0)
631
+ else (
632
+ math.inf
633
+ if (not math.isnan(theta_median_kept) and theta_median_kept == 0)
634
+ else math.nan
635
+ )
636
+ )
637
+ logger.info(
638
+ "Max/median eigenvalue ratio (theta-only adjusted sandwich; median over eigvals >= 1e-8*max): %f",
639
+ max_to_median_ratio_theta_only_adjusted_sandwich,
640
+ )
641
+
642
+ # --- Local linearization validity diagnostic (single-run) ---
643
+ # We compare the nonlinear Taylor remainder of the joint estimating-function map to the
644
+ # retained linear term, at perturbations on the O(1/sqrt(n)) scale.
645
+ #
646
+ # Define r(delta) = || g(eta+delta) - g(eta) - B delta ||_2 / || B delta ||_2,
647
+ # where g(eta) is the avg per-subject weighted estimating-function stack and B is the
648
+ # stabilized joint bread (Jacobian of g w.r.t. flattened betas+theta).
649
+ #
650
+ # This ratio is dimensionless and can be used as a necessary/sanity diagnostic that the
651
+ # first-order linearization is locally accurate at the estimation scale.
652
+
653
+ def _compute_local_linearization_error_ratio() -> tuple[float, float]:
654
+ # Ensure float64 for diagnostics even if upstream ran in float32.
655
+ joint_bread_float64 = jnp.asarray(
656
+ stabilized_joint_bread_matrix, dtype=jnp.float64
657
+ )
658
+ g_hat = jnp.asarray(avg_estimating_function_stack, dtype=jnp.float64)
659
+ stacks_float64 = jnp.asarray(
660
+ per_subject_estimating_function_stacks, dtype=jnp.float64
661
+ )
662
+
663
+ num_subjects = stacks_float64.shape[0]
664
+
665
+ def _eval_avg_stack_jit(flattened_betas_and_theta: jnp.ndarray) -> jnp.ndarray:
666
+ return jnp.asarray(
667
+ get_avg_weighted_estimating_function_stacks_and_aux_values(
668
+ flattened_betas_and_theta,
669
+ beta_dim,
670
+ theta_dim,
671
+ subject_ids,
672
+ action_prob_func,
673
+ action_prob_func_args_beta_index,
674
+ alg_update_func,
675
+ alg_update_func_type,
676
+ alg_update_func_args_beta_index,
677
+ alg_update_func_args_action_prob_index,
678
+ alg_update_func_args_action_prob_times_index,
679
+ alg_update_func_args_previous_betas_index,
680
+ inference_func,
681
+ inference_func_type,
682
+ inference_func_args_theta_index,
683
+ inference_func_args_action_prob_index,
684
+ action_prob_func_args,
685
+ policy_num_by_decision_time_by_subject_id,
686
+ initial_policy_num,
687
+ beta_index_by_policy_num,
688
+ inference_func_args_by_subject_id,
689
+ inference_action_prob_decision_times_by_subject_id,
690
+ alg_update_func_args,
691
+ action_by_decision_time_by_subject_id,
692
+ True, # suppress_all_data_checks
693
+ True, # suppress_interactive_data_checks
694
+ False, # include_auxiliary_outputs
695
+ ),
696
+ dtype=jnp.float64,
697
+ )
698
+
699
+ # Evaluate at the final estimate.
700
+ eta_hat = jnp.asarray(
701
+ flatten_params(all_post_update_betas, theta_est), dtype=jnp.float64
702
+ )
703
+
704
+ # Draw perturbations delta_j on the O(1/sqrt(n)) scale, aligned with the empirical
705
+ # joint estimating function stack covariance, without forming a d_joint x d_joint matrix
706
+ # square-root. If G is the (n x d) matrix of per-subject stacks, then (1/n) G^T G is the
707
+ # empirical covariance in joint estimating function stack space. Sampling u = (G^T w)/sqrt(n) with w~N(0, I_n) gives
708
+ # u ~ N(0, empirical joint estimating function stack covariance G^T G/n ) in joint estimating function stack space.
709
+ key = jax.random.PRNGKey(0)
710
+
711
+ # The number of perturbations we will probe
712
+ J = 15
713
+ # Each requires num_subjects standard normal draws, which we will then transform
714
+ # into joint estimating function space perturbations in U
715
+ W = jax.random.normal(key, shape=(J, num_subjects), dtype=jnp.float64)
716
+
717
+ # Joint estimating function space perturbations: u_j in R^{d_joint}
718
+ # U = (1/sqrt(n)) * W G, where rows of G are g_i^T
719
+ U = (W @ stacks_float64) / jnp.sqrt(num_subjects)
720
+
721
+ # Parameter perturbations: delta = (c/sqrt(n)) * B^{-1} u
722
+ # Use solve rather than explicit inverse.
723
+ c = 1.0
724
+ delta = (c / jnp.sqrt(num_subjects)) * jnp.linalg.solve(
725
+ joint_bread_float64, U.T
726
+ ).T
727
+
728
+ # Compute ratios r_j.
729
+ # NOTE: We use the Euclidean norm in score space; this is dimensionless and avoids
730
+ # forming/pseudoinverting a potentially rank-deficient matrix.
731
+ B_delta = (joint_bread_float64 @ delta.T).T
732
+ g_plus = jax.vmap(lambda d: _eval_avg_stack_jit(eta_hat + d))(delta)
733
+ remainder = g_plus - g_hat - B_delta
734
+
735
+ denom = jnp.linalg.norm(B_delta, axis=1)
736
+ numer = jnp.linalg.norm(remainder, axis=1)
737
+
738
+ # Avoid division by zero (should not happen unless delta collapses numerically).
739
+ ratios = jnp.where(denom > 0, numer / denom, jnp.inf)
740
+
741
+ local_error_ratio_median = float(jnp.median(ratios))
742
+ local_error_ratio_p90 = float(jnp.quantile(ratios, 0.9))
743
+ local_error_ratio_max = float(jnp.max(ratios))
744
+
745
+ logger.info(
746
+ "Local linearization error ratio (median over %d draws): %.6f",
747
+ J,
748
+ local_error_ratio_median,
749
+ )
750
+ logger.info(
751
+ "Local linearization error ratio (90th pct over %d draws): %.6f",
752
+ J,
753
+ local_error_ratio_p90,
754
+ )
755
+
756
+ logger.info(
757
+ "Local linearization error ratio (max over %d draws): %.6f",
758
+ J,
759
+ local_error_ratio_max,
760
+ )
761
+
762
+ return local_error_ratio_median, local_error_ratio_p90, local_error_ratio_max
763
+
764
+ try:
765
+ local_error_ratio_median, local_error_ratio_p90, local_error_ratio_max = (
766
+ _compute_local_linearization_error_ratio()
767
+ )
768
+ except Exception as e:
769
+ # This diagnostic is best-effort; failure should not break analysis.
770
+ logger.warning(
771
+ "Failed to compute local linearization error ratio diagnostic: %s",
772
+ str(e),
773
+ )
774
+ local_error_ratio_median = math.nan
775
+ local_error_ratio_p90 = math.nan
776
+ local_error_ratio_max = math.nan
777
+
588
778
  debug_pieces_dict = {
589
779
  "theta_est": theta_est,
590
780
  "adjusted_sandwich_var_estimate": adjusted_sandwich_var_estimate,
591
781
  "classical_sandwich_var_estimate": classical_sandwich_var_estimate,
592
- "raw_joint_bread_matrix": raw_joint_adjusted_bread_matrix,
593
- "stabilized_joint_bread_matrix": stabilized_joint_adjusted_bread_matrix,
782
+ "raw_joint_bread_matrix": raw_joint_bread_matrix,
783
+ "stabilized_joint_bread_matrix": stabilized_joint_bread_matrix,
594
784
  "joint_meat_matrix": joint_adjusted_meat_matrix,
595
785
  "classical_bread_matrix": classical_bread_matrix,
596
786
  "classical_meat_matrix": classical_meat_matrix,
597
787
  "all_estimating_function_stacks": per_subject_estimating_function_stacks,
598
- "joint_bread_condition_number": joint_adjusted_bread_cond,
599
- "max_eigenvalue_joint_adjusted_sandwich": max_eigenvalue,
788
+ "joint_bread_condition_number": joint_bread_cond,
789
+ "max_eigenvalue_joint_sandwich": max_eig_joint,
790
+ "all_eigenvalues_joint_sandwich": eigvals_joint_sandwich,
791
+ "max_to_median_ratio_joint_sandwich": max_to_median_ratio_joint_sandwich,
792
+ "max_eigenvalue_theta_only_adjusted_sandwich": max_eig_theta,
793
+ "all_eigenvalues_theta_only_adjusted_sandwich": eigvals_theta_only_adjusted_sandwich,
794
+ "max_to_median_ratio_theta_only_adjusted_sandwich": max_to_median_ratio_theta_only_adjusted_sandwich,
795
+ "local_linearization_error_ratio_median": local_error_ratio_median,
796
+ "local_linearization_error_ratio_p90": local_error_ratio_p90,
797
+ "local_linearization_error_ratio_max": local_error_ratio_max,
600
798
  "all_post_update_betas": all_post_update_betas,
601
799
  "per_subject_adjusted_corrections": per_subject_adjusted_corrections,
602
800
  "per_subject_classical_corrections": per_subject_classical_corrections,
@@ -610,8 +808,8 @@ def analyze_dataset(
610
808
 
611
809
  if collect_data_for_blowup_supervised_learning:
612
810
  datum_and_label_dict = get_datum_for_blowup_supervised_learning.get_datum_for_blowup_supervised_learning(
613
- raw_joint_adjusted_bread_matrix,
614
- joint_adjusted_bread_cond,
811
+ raw_joint_bread_matrix,
812
+ joint_bread_cond,
615
813
  avg_estimating_function_stack,
616
814
  per_subject_estimating_function_stacks,
617
815
  all_post_update_betas,
@@ -756,12 +954,16 @@ def single_subject_weighted_estimating_function_stacker(
756
954
  policy_num_by_decision_time: dict[collections.abc.Hashable, dict[int, int | float]],
757
955
  action_by_decision_time: dict[collections.abc.Hashable, dict[int, int]],
758
956
  beta_index_by_policy_num: dict[int | float, int],
759
- ) -> tuple[
760
- jnp.ndarray[jnp.float32],
761
- jnp.ndarray[jnp.float32],
762
- jnp.ndarray[jnp.float32],
763
- jnp.ndarray[jnp.float32],
764
- ]:
957
+ include_auxiliary_outputs: bool = True,
958
+ ) -> (
959
+ tuple[
960
+ jnp.ndarray[jnp.float32],
961
+ jnp.ndarray[jnp.float32],
962
+ jnp.ndarray[jnp.float32],
963
+ jnp.ndarray[jnp.float32],
964
+ ]
965
+ | jnp.ndarray[jnp.float32]
966
+ ):
765
967
  """
766
968
  Computes a weighted estimating function stack for a given algorithm estimating function
767
969
  and arguments, inference estimating functio and arguments, and action probability function and
@@ -825,12 +1027,23 @@ def single_subject_weighted_estimating_function_stacker(
825
1027
  A dictionary mapping policy numbers to the index of the corresponding beta in
826
1028
  all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
827
1029
 
1030
+ include_auxiliary_outputs (bool):
1031
+ If True, returns the adjusted meat, classical meat, and classical bread contributions in
1032
+ a second returned tuple. If False, only returns the weighted estimating function stack.
1033
+
828
1034
  Returns:
829
1035
  jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted estimating function
830
1036
  stack.
831
1037
  jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's adjusted meat contribution.
832
1038
  jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical meat contribution.
833
1039
  jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical bread contribution.
1040
+
1041
+ or
1042
+
1043
+ jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted estimating function
1044
+ stack.
1045
+
1046
+ depending on the value of include_auxiliary_outputs.
834
1047
  """
835
1048
 
836
1049
  logger.info(
@@ -1020,14 +1233,18 @@ def single_subject_weighted_estimating_function_stacker(
1020
1233
  # c. The third output is averaged across subjects to obtain the classical meat matrix.
1021
1234
  # d. The fourth output is averaged across subjects to obtain the inverse classical bread
1022
1235
  # matrix.
1023
- return (
1024
- weighted_stack,
1025
- jnp.outer(weighted_stack, weighted_stack),
1026
- jnp.outer(inference_component, inference_component),
1027
- jax.jacrev(inference_estimating_func, argnums=inference_func_args_theta_index)(
1028
- *threaded_inference_func_args
1029
- ),
1030
- )
1236
+ if include_auxiliary_outputs:
1237
+ return (
1238
+ weighted_stack,
1239
+ jnp.outer(weighted_stack, weighted_stack),
1240
+ jnp.outer(inference_component, inference_component),
1241
+ jax.jacrev(
1242
+ inference_estimating_func, argnums=inference_func_args_theta_index
1243
+ )(*threaded_inference_func_args),
1244
+ )
1245
+
1246
+ else:
1247
+ return weighted_stack
1031
1248
 
1032
1249
 
1033
1250
  def get_avg_weighted_estimating_function_stacks_and_aux_values(
@@ -1067,6 +1284,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1067
1284
  ],
1068
1285
  suppress_all_data_checks: bool,
1069
1286
  suppress_interactive_data_checks: bool,
1287
+ include_auxiliary_outputs: bool = True,
1070
1288
  ) -> tuple[
1071
1289
  jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
1072
1290
  ]:
@@ -1141,10 +1359,14 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1141
1359
  If True, suppresses interactive data checks that would otherwise be performed to ensure
1142
1360
  the correctness of the threaded arguments. The checks are still performed, but
1143
1361
  any interactive prompts are suppressed.
1362
+ include_auxiliary_outputs (bool):
1363
+ If True, returns the adjusted meat, classical meat, and classical bread contributions in addition to the average weighted estimating function stack.
1364
+ If False, returns only the average weighted estimating function stack.
1144
1365
 
1145
1366
  Returns:
1146
1367
  jnp.ndarray:
1147
1368
  A 2D JAX NumPy array holding the average weighted estimating function stack.
1369
+
1148
1370
  tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
1149
1371
  A tuple containing
1150
1372
  1. the average weighted estimating function stack
@@ -1153,6 +1375,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1153
1375
  4. the subject-level inverse classical bread matrix contributions
1154
1376
  5. raw per-subject weighted estimating function
1155
1377
  stacks.
1378
+ or jnp.ndarray:
1379
+ A 1-D JAX NumPy array representing the subject's weighted estimating function
1380
+ stack.
1381
+ depending on the value of include_auxiliary_outputs.
1156
1382
  """
1157
1383
 
1158
1384
  # 1. Collect estimating functions by differentiating the loss functions if needed.
@@ -1275,6 +1501,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1275
1501
  ]
1276
1502
 
1277
1503
  stacks = jnp.array([result[0] for result in results])
1504
+
1505
+ if not include_auxiliary_outputs:
1506
+ return jnp.mean(stacks, axis=0)
1507
+
1278
1508
  outer_products = jnp.array([result[1] for result in results])
1279
1509
  inference_only_outer_products = jnp.array([result[2] for result in results])
1280
1510
  inference_hessians = jnp.array([result[3] for result in results])
@@ -1475,7 +1705,7 @@ def construct_classical_and_adjusted_sandwiches(
1475
1705
  theta_dim = theta_est.shape[0]
1476
1706
  beta_dim = all_post_update_betas.shape[1]
1477
1707
  # Note that these "contributions" are per-subject Jacobians of the weighted estimating function stack.
1478
- raw_joint_adjusted_bread_matrix, (
1708
+ raw_joint_bread_matrix, (
1479
1709
  avg_estimating_function_stack,
1480
1710
  per_subject_joint_adjusted_meat_contributions,
1481
1711
  per_subject_classical_meat_contributions,
@@ -1533,21 +1763,21 @@ def construct_classical_and_adjusted_sandwiches(
1533
1763
 
1534
1764
  # Increase diagonal block dominance possibly improve conditioning of diagonal
1535
1765
  # blocks as necessary, to ensure mathematical stability of joint bread
1536
- stabilized_joint_adjusted_bread_matrix = (
1766
+ stabilized_joint_bread_matrix = (
1537
1767
  (
1538
1768
  stabilize_joint_bread_if_necessary(
1539
- raw_joint_adjusted_bread_matrix,
1769
+ raw_joint_bread_matrix,
1540
1770
  beta_dim,
1541
1771
  theta_dim,
1542
1772
  )
1543
1773
  )
1544
1774
  if stabilize_joint_bread
1545
- else raw_joint_adjusted_bread_matrix
1775
+ else raw_joint_bread_matrix
1546
1776
  )
1547
1777
 
1548
1778
  # Now stably (no explicit inversion) form our sandwiches.
1549
- joint_adjusted_sandwich = form_sandwich_from_bread_and_meat(
1550
- stabilized_joint_adjusted_bread_matrix,
1779
+ joint_sandwich = form_sandwich_from_bread_and_meat(
1780
+ stabilized_joint_bread_matrix,
1551
1781
  joint_adjusted_meat_matrix,
1552
1782
  num_subjects,
1553
1783
  method=SandwichFormationMethods.BREAD_T_QR,
@@ -1568,7 +1798,7 @@ def construct_classical_and_adjusted_sandwiches(
1568
1798
  form_adjusted_meat_adjustments_directly(
1569
1799
  theta_dim,
1570
1800
  all_post_update_betas.shape[1],
1571
- stabilized_joint_adjusted_bread_matrix,
1801
+ stabilized_joint_bread_matrix,
1572
1802
  per_subject_estimating_function_stacks,
1573
1803
  analysis_df,
1574
1804
  active_col_name,
@@ -1610,7 +1840,7 @@ def construct_classical_and_adjusted_sandwiches(
1610
1840
  method=SandwichFormationMethods.BREAD_T_QR,
1611
1841
  )
1612
1842
  )
1613
- theta_only_adjusted_sandwich = joint_adjusted_sandwich[-theta_dim:, -theta_dim:]
1843
+ theta_only_adjusted_sandwich = joint_sandwich[-theta_dim:, -theta_dim:]
1614
1844
 
1615
1845
  if not np.allclose(
1616
1846
  theta_only_adjusted_sandwich,
@@ -1624,10 +1854,10 @@ def construct_classical_and_adjusted_sandwiches(
1624
1854
  # Stack the joint bread pieces together horizontally and return the auxiliary
1625
1855
  # values too. The joint bread should always be block lower triangular.
1626
1856
  return (
1627
- raw_joint_adjusted_bread_matrix,
1628
- stabilized_joint_adjusted_bread_matrix,
1857
+ raw_joint_bread_matrix,
1858
+ stabilized_joint_bread_matrix,
1629
1859
  joint_adjusted_meat_matrix,
1630
- joint_adjusted_sandwich,
1860
+ joint_sandwich,
1631
1861
  classical_bread_matrix,
1632
1862
  classical_meat_matrix,
1633
1863
  classical_sandwich,
@@ -1643,7 +1873,7 @@ def construct_classical_and_adjusted_sandwiches(
1643
1873
  # important for the subject to know if this is happening. Even if enabled, it is important
1644
1874
  # that the subject know it actually kicks in.
1645
1875
  def stabilize_joint_bread_if_necessary(
1646
- joint_adjusted_bread_matrix: jnp.ndarray,
1876
+ joint_bread_matrix: jnp.ndarray,
1647
1877
  beta_dim: int,
1648
1878
  theta_dim: int,
1649
1879
  ) -> jnp.ndarray:
@@ -1652,7 +1882,7 @@ def stabilize_joint_bread_if_necessary(
1652
1882
  dominance and/or adding a small ridge penalty to the diagonal blocks.
1653
1883
 
1654
1884
  Args:
1655
- joint_adjusted_bread_matrix (jnp.ndarray):
1885
+ joint_bread_matrix (jnp.ndarray):
1656
1886
  A 2-D JAX NumPy array representing the joint bread matrix.
1657
1887
  beta_dim (int):
1658
1888
  The dimension of each beta parameter.
@@ -1673,7 +1903,7 @@ def stabilize_joint_bread_if_necessary(
1673
1903
 
1674
1904
  # Grab just the RL block and convert numpy array for easier manipulation.
1675
1905
  RL_stack_beta_derivatives_block = np.array(
1676
- joint_adjusted_bread_matrix[:-theta_dim, :-theta_dim]
1906
+ joint_bread_matrix[:-theta_dim, :-theta_dim]
1677
1907
  )
1678
1908
  num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
1679
1909
  for i in range(1, num_updates + 1):
@@ -1792,11 +2022,11 @@ def stabilize_joint_bread_if_necessary(
1792
2022
  [
1793
2023
  [
1794
2024
  RL_stack_beta_derivatives_block,
1795
- joint_adjusted_bread_matrix[:-theta_dim, -theta_dim:],
2025
+ joint_bread_matrix[:-theta_dim, -theta_dim:],
1796
2026
  ],
1797
2027
  [
1798
- joint_adjusted_bread_matrix[-theta_dim:, :-theta_dim],
1799
- joint_adjusted_bread_matrix[-theta_dim:, -theta_dim:],
2028
+ joint_bread_matrix[-theta_dim:, :-theta_dim],
2029
+ joint_bread_matrix[-theta_dim:, -theta_dim:],
1800
2030
  ],
1801
2031
  ]
1802
2032
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lifejacket
3
- Version: 1.0.2
3
+ Version: 1.1.0
4
4
  Summary: Consistent standard errors for longitudinal data collected under pooling online decision policies.
5
5
  Author-email: Nowell Closser <nowellclosser@gmail.com>
6
6
  Requires-Python: >=3.10
@@ -7,11 +7,11 @@ lifejacket/form_adjusted_meat_adjustments_directly.py,sha256=AVlGOuw_FgVDcVnhQs1
7
7
  lifejacket/get_datum_for_blowup_supervised_learning.py,sha256=sCH-PlrFlLJgCYpTmdeasiHwHYSEy9wxspkOTDuDPuY,58594
8
8
  lifejacket/helper_functions.py,sha256=SdAbUwXNx-3JFsyTfLyliQ7kUOm0eABaiNgoYLR8NG0,16967
9
9
  lifejacket/input_checks.py,sha256=q7HFZq5n18edQU8X5laONsBgWSMidLRy6Nhqdw5FpOw,47084
10
- lifejacket/post_deployment_analysis.py,sha256=XdKObve0hOXVwPWSDD2lEEfrWAdcZK-c-uh53HIrKLM,82664
10
+ lifejacket/post_deployment_analysis.py,sha256=M6-qQ1vynJ7-iHi-vgTdS71Vp-27gioB1b6qt0sCdek,92470
11
11
  lifejacket/small_sample_corrections.py,sha256=aB6qi-r3ANoBMgf2Oo5-lCXCy_L4H3FlBffGwPcfXkg,5610
12
12
  lifejacket/vmap_helpers.py,sha256=pZqYN3p9Ty9DPOeeY9TKbRJXR2AV__HBwwDFOvdOQ84,2688
13
- lifejacket-1.0.2.dist-info/METADATA,sha256=L48u8IMsEXMTwtBJlFwh4mJ-pZUD8NcQEuqgjV0zkjo,1773
14
- lifejacket-1.0.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
15
- lifejacket-1.0.2.dist-info/entry_points.txt,sha256=CZ9AUPN0xfnpYqwtGTr6n9l5mpyEOddsXX8fnxKRB6U,71
16
- lifejacket-1.0.2.dist-info/top_level.txt,sha256=vKl8m7jOQ4pkbzVuHCJsq-8LcXRrOAWnok3bBo9qpsE,11
17
- lifejacket-1.0.2.dist-info/RECORD,,
13
+ lifejacket-1.1.0.dist-info/METADATA,sha256=RPJwaJyvfQGCKxgNvg7tx-QPu_Wa89pIPv34fFzqjzA,1773
14
+ lifejacket-1.1.0.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
15
+ lifejacket-1.1.0.dist-info/entry_points.txt,sha256=CZ9AUPN0xfnpYqwtGTr6n9l5mpyEOddsXX8fnxKRB6U,71
16
+ lifejacket-1.1.0.dist-info/top_level.txt,sha256=vKl8m7jOQ4pkbzVuHCJsq-8LcXRrOAWnok3bBo9qpsE,11
17
+ lifejacket-1.1.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.9.0)
2
+ Generator: setuptools (80.10.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5