lifejacket 1.0.2.tar.gz → 1.2.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22)
  1. {lifejacket-1.0.2 → lifejacket-1.2.0}/PKG-INFO +1 -1
  2. {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/post_deployment_analysis.py +277 -49
  3. {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket.egg-info/PKG-INFO +1 -1
  4. {lifejacket-1.0.2 → lifejacket-1.2.0}/pyproject.toml +1 -1
  5. {lifejacket-1.0.2 → lifejacket-1.2.0}/README.md +0 -0
  6. {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/__init__.py +0 -0
  7. {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/arg_threading_helpers.py +0 -0
  8. {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/calculate_derivatives.py +0 -0
  9. {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/constants.py +0 -0
  10. {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/deployment_conditioning_monitor.py +0 -0
  11. {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/form_adjusted_meat_adjustments_directly.py +0 -0
  12. {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/get_datum_for_blowup_supervised_learning.py +0 -0
  13. {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/helper_functions.py +0 -0
  14. {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/input_checks.py +0 -0
  15. {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/small_sample_corrections.py +0 -0
  16. {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/vmap_helpers.py +0 -0
  17. {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket.egg-info/SOURCES.txt +0 -0
  18. {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket.egg-info/dependency_links.txt +0 -0
  19. {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket.egg-info/entry_points.txt +0 -0
  20. {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket.egg-info/requires.txt +0 -0
  21. {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket.egg-info/top_level.txt +0 -0
  22. {lifejacket-1.0.2 → lifejacket-1.2.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lifejacket
3
- Version: 1.0.2
3
+ Version: 1.2.0
4
4
  Summary: Consistent standard errors for longitudinal data collected under pooling online decision policies.
5
5
  Author-email: Nowell Closser <nowellclosser@gmail.com>
6
6
  Requires-Python: >=3.10
@@ -483,10 +483,10 @@ def analyze_dataset(
483
483
 
484
484
  subject_ids = jnp.array(analysis_df[subject_id_col_name].unique())
485
485
  (
486
- stabilized_joint_adjusted_bread_matrix,
487
- raw_joint_adjusted_bread_matrix,
486
+ stabilized_joint_bread_matrix,
487
+ raw_joint_bread_matrix,
488
488
  joint_adjusted_meat_matrix,
489
- joint_adjusted_sandwich_matrix,
489
+ joint_sandwich_matrix,
490
490
  classical_bread_matrix,
491
491
  classical_meat_matrix,
492
492
  classical_sandwich_var_estimate,
@@ -544,9 +544,7 @@ def analyze_dataset(
544
544
 
545
545
  # This bottom right corner of the joint (betas and theta) variance matrix is the portion
546
546
  # corresponding to just theta.
547
- adjusted_sandwich_var_estimate = joint_adjusted_sandwich_matrix[
548
- -theta_dim:, -theta_dim:
549
- ]
547
+ adjusted_sandwich_var_estimate = joint_sandwich_matrix[-theta_dim:, -theta_dim:]
550
548
 
551
549
  # Check for negative diagonal elements and set them to zero if found
552
550
  adjusted_diagonal = np.diag(adjusted_sandwich_var_estimate)
@@ -572,31 +570,229 @@ def analyze_dataset(
572
570
  f,
573
571
  )
574
572
 
575
- joint_adjusted_bread_cond = jnp.linalg.cond(raw_joint_adjusted_bread_matrix)
573
+ joint_bread_cond = jnp.linalg.cond(raw_joint_bread_matrix)
574
+ logger.info(
575
+ "Joint bread condition number: %f",
576
+ joint_bread_cond,
577
+ )
578
+
579
+ # calculate the max eigenvalue of the theta-only adjusted sandwich
580
+ max_eigenvalue_theta_only_adjusted_sandwich = scipy.linalg.eigvalsh(
581
+ adjusted_sandwich_var_estimate
582
+ ).max()
576
583
  logger.info(
577
- "Joint adjusted bread condition number: %f",
578
- joint_adjusted_bread_cond,
584
+ "Max eigenvalue of theta-only adjusted sandwich matrix: %f",
585
+ max_eigenvalue_theta_only_adjusted_sandwich,
579
586
  )
580
587
 
581
- # calculate the max eigenvalue of the joint adjusted sandwich
582
- max_eigenvalue = scipy.linalg.eigvalsh(joint_adjusted_sandwich_matrix).max()
588
+ # Compute ratios: max eigenvalue / median eigenvalue among those >= 1e-8 * max.
589
+ eigvals_joint_sandwich = scipy.linalg.eigvalsh(joint_sandwich_matrix)
590
+ max_eig_joint = float(eigvals_joint_sandwich.max())
583
591
  logger.info(
584
592
  "Max eigenvalue of joint adjusted sandwich matrix: %f",
585
- max_eigenvalue,
593
+ max_eig_joint,
594
+ )
595
+
596
+ joint_keep = eigvals_joint_sandwich >= (1e-8 * max_eig_joint)
597
+ joint_median_kept = (
598
+ float(np.median(eigvals_joint_sandwich[joint_keep]))
599
+ if np.any(joint_keep)
600
+ else math.nan
601
+ )
602
+ max_to_median_ratio_joint_sandwich = (
603
+ (max_eig_joint / joint_median_kept)
604
+ if (not math.isnan(joint_median_kept) and joint_median_kept > 0)
605
+ else (
606
+ math.inf
607
+ if (not math.isnan(joint_median_kept) and joint_median_kept == 0)
608
+ else math.nan
609
+ )
610
+ )
611
+ logger.info(
612
+ "Max/median eigenvalue ratio (joint sandwich; median over eigvals >= 1e-8*max): %f",
613
+ max_to_median_ratio_joint_sandwich,
614
+ )
615
+
616
+ eigvals_theta_only_adjusted_sandwich = scipy.linalg.eigvalsh(
617
+ adjusted_sandwich_var_estimate
618
+ )
619
+ max_eig_theta = float(eigvals_theta_only_adjusted_sandwich.max())
620
+ theta_keep = eigvals_theta_only_adjusted_sandwich >= (1e-8 * max_eig_theta)
621
+ theta_median_kept = (
622
+ float(np.median(eigvals_theta_only_adjusted_sandwich[theta_keep]))
623
+ if np.any(theta_keep)
624
+ else math.nan
625
+ )
626
+ max_to_median_ratio_theta_only_adjusted_sandwich = (
627
+ (max_eig_theta / theta_median_kept)
628
+ if (not math.isnan(theta_median_kept) and theta_median_kept > 0)
629
+ else (
630
+ math.inf
631
+ if (not math.isnan(theta_median_kept) and theta_median_kept == 0)
632
+ else math.nan
633
+ )
634
+ )
635
+ logger.info(
636
+ "Max/median eigenvalue ratio (theta-only adjusted sandwich; median over eigvals >= 1e-8*max): %f",
637
+ max_to_median_ratio_theta_only_adjusted_sandwich,
586
638
  )
587
639
 
640
+ # --- Local linearization validity diagnostic (single-run) ---
641
+ # We compare the nonlinear Taylor remainder of the joint estimating-function map to the
642
+ # retained linear term, at perturbations on the O(1/sqrt(n)) scale.
643
+ #
644
+ # Define r(delta) = || g(eta+delta) - g(eta) - B delta ||_2 / || B delta ||_2,
645
+ # where g(eta) is the avg per-subject weighted estimating-function stack and B is the
646
+ # stabilized joint bread (Jacobian of g w.r.t. flattened betas+theta).
647
+ #
648
+ # This ratio is dimensionless and can be used as a necessary/sanity diagnostic that the
649
+ # first-order linearization is locally accurate at the estimation scale.
650
+
651
+ def _compute_local_linearization_error_ratio() -> tuple[float, float]:
652
+ # Ensure float64 for diagnostics even if upstream ran in float32.
653
+ joint_bread_float64 = jnp.asarray(
654
+ stabilized_joint_bread_matrix, dtype=jnp.float64
655
+ )
656
+ g_hat = jnp.asarray(avg_estimating_function_stack, dtype=jnp.float64)
657
+ stacks_float64 = jnp.asarray(
658
+ per_subject_estimating_function_stacks, dtype=jnp.float64
659
+ )
660
+
661
+ num_subjects = stacks_float64.shape[0]
662
+
663
+ def _eval_avg_stack_jit(flattened_betas_and_theta: jnp.ndarray) -> jnp.ndarray:
664
+ return jnp.asarray(
665
+ get_avg_weighted_estimating_function_stacks_and_aux_values(
666
+ flattened_betas_and_theta,
667
+ beta_dim,
668
+ theta_dim,
669
+ subject_ids,
670
+ action_prob_func,
671
+ action_prob_func_args_beta_index,
672
+ alg_update_func,
673
+ alg_update_func_type,
674
+ alg_update_func_args_beta_index,
675
+ alg_update_func_args_action_prob_index,
676
+ alg_update_func_args_action_prob_times_index,
677
+ alg_update_func_args_previous_betas_index,
678
+ inference_func,
679
+ inference_func_type,
680
+ inference_func_args_theta_index,
681
+ inference_func_args_action_prob_index,
682
+ action_prob_func_args,
683
+ policy_num_by_decision_time_by_subject_id,
684
+ initial_policy_num,
685
+ beta_index_by_policy_num,
686
+ inference_func_args_by_subject_id,
687
+ inference_action_prob_decision_times_by_subject_id,
688
+ alg_update_func_args,
689
+ action_by_decision_time_by_subject_id,
690
+ True, # suppress_all_data_checks
691
+ True, # suppress_interactive_data_checks
692
+ False, # include_auxiliary_outputs
693
+ ),
694
+ dtype=jnp.float64,
695
+ )
696
+
697
+ # Evaluate at the final estimate.
698
+ eta_hat = jnp.asarray(
699
+ flatten_params(all_post_update_betas, theta_est), dtype=jnp.float64
700
+ )
701
+
702
+ # Draw perturbations delta_j on the O(1/sqrt(n)) scale, aligned with the empirical
703
+ # joint estimating function stack covariance, without forming a d_joint x d_joint matrix
704
+ # square-root. If G is the (n x d) matrix of per-subject stacks, then (1/n) G^T G is the
705
+ # empirical covariance in joint estimating function stack space. Sampling u = (G^T w)/sqrt(n) with w~N(0, I_n) gives
706
+ # u ~ N(0, empirical joint estimating function stack covariance G^T G/n ) in joint estimating function stack space.
707
+ key = jax.random.PRNGKey(0)
708
+
709
+ # The number of perturbations we will probe
710
+ J = 15
711
+ # Each requires num_subjects standard normal draws, which we will then transform
712
+ # into joint estimating function space perturbations in U
713
+ W = jax.random.normal(key, shape=(J, num_subjects), dtype=jnp.float64)
714
+
715
+ # Joint estimating function space perturbations: u_j in R^{d_joint}
716
+ # U = (1/sqrt(n)) * W G, where rows of G are g_i^T
717
+ U = (W @ stacks_float64) / jnp.sqrt(num_subjects)
718
+
719
+ # Parameter perturbations: delta = (c/sqrt(n)) * B^{-1} u
720
+ # Use solve rather than explicit inverse.
721
+ c = 1.0
722
+ delta = (c / jnp.sqrt(num_subjects)) * jnp.linalg.solve(
723
+ joint_bread_float64, U.T
724
+ ).T
725
+
726
+ # Compute ratios r_j.
727
+ # NOTE: We use the Euclidean norm in score space; this is dimensionless and avoids
728
+ # forming/pseudoinverting a potentially rank-deficient matrix.
729
+ B_delta = (joint_bread_float64 @ delta.T).T
730
+ g_plus = jax.vmap(lambda d: _eval_avg_stack_jit(eta_hat + d))(delta)
731
+ remainder = g_plus - g_hat - B_delta
732
+
733
+ denom = jnp.linalg.norm(B_delta, axis=1)
734
+ numer = jnp.linalg.norm(remainder, axis=1)
735
+
736
+ # Avoid division by zero (should not happen unless delta collapses numerically).
737
+ ratios = jnp.where(denom > 0, numer / denom, jnp.inf)
738
+
739
+ local_error_ratio_median = float(jnp.median(ratios))
740
+ local_error_ratio_p90 = float(jnp.quantile(ratios, 0.9))
741
+ local_error_ratio_max = float(jnp.max(ratios))
742
+
743
+ logger.info(
744
+ "Local linearization error ratio (median over %d draws): %.6f",
745
+ J,
746
+ local_error_ratio_median,
747
+ )
748
+ logger.info(
749
+ "Local linearization error ratio (90th pct over %d draws): %.6f",
750
+ J,
751
+ local_error_ratio_p90,
752
+ )
753
+
754
+ logger.info(
755
+ "Local linearization error ratio (max over %d draws): %.6f",
756
+ J,
757
+ local_error_ratio_max,
758
+ )
759
+
760
+ return local_error_ratio_median, local_error_ratio_p90, local_error_ratio_max
761
+
762
+ try:
763
+ local_error_ratio_median, local_error_ratio_p90, local_error_ratio_max = (
764
+ _compute_local_linearization_error_ratio()
765
+ )
766
+ except Exception as e:
767
+ # This diagnostic is best-effort; failure should not break analysis.
768
+ logger.warning(
769
+ "Failed to compute local linearization error ratio diagnostic: %s",
770
+ str(e),
771
+ )
772
+ local_error_ratio_median = math.nan
773
+ local_error_ratio_p90 = math.nan
774
+ local_error_ratio_max = math.nan
775
+
588
776
  debug_pieces_dict = {
589
777
  "theta_est": theta_est,
590
778
  "adjusted_sandwich_var_estimate": adjusted_sandwich_var_estimate,
591
779
  "classical_sandwich_var_estimate": classical_sandwich_var_estimate,
592
- "raw_joint_bread_matrix": raw_joint_adjusted_bread_matrix,
593
- "stabilized_joint_bread_matrix": stabilized_joint_adjusted_bread_matrix,
780
+ "raw_joint_bread_matrix": raw_joint_bread_matrix,
781
+ "stabilized_joint_bread_matrix": stabilized_joint_bread_matrix,
594
782
  "joint_meat_matrix": joint_adjusted_meat_matrix,
595
783
  "classical_bread_matrix": classical_bread_matrix,
596
784
  "classical_meat_matrix": classical_meat_matrix,
597
785
  "all_estimating_function_stacks": per_subject_estimating_function_stacks,
598
- "joint_bread_condition_number": joint_adjusted_bread_cond,
599
- "max_eigenvalue_joint_adjusted_sandwich": max_eigenvalue,
786
+ "joint_bread_condition_number": joint_bread_cond,
787
+ "max_eigenvalue_joint_sandwich": max_eig_joint,
788
+ "all_eigenvalues_joint_sandwich": eigvals_joint_sandwich,
789
+ "max_to_median_ratio_joint_sandwich": max_to_median_ratio_joint_sandwich,
790
+ "max_eigenvalue_theta_only_adjusted_sandwich": max_eig_theta,
791
+ "all_eigenvalues_theta_only_adjusted_sandwich": eigvals_theta_only_adjusted_sandwich,
792
+ "max_to_median_ratio_theta_only_adjusted_sandwich": max_to_median_ratio_theta_only_adjusted_sandwich,
793
+ "local_linearization_error_ratio_median": local_error_ratio_median,
794
+ "local_linearization_error_ratio_p90": local_error_ratio_p90,
795
+ "local_linearization_error_ratio_max": local_error_ratio_max,
600
796
  "all_post_update_betas": all_post_update_betas,
601
797
  "per_subject_adjusted_corrections": per_subject_adjusted_corrections,
602
798
  "per_subject_classical_corrections": per_subject_classical_corrections,
@@ -610,8 +806,8 @@ def analyze_dataset(
610
806
 
611
807
  if collect_data_for_blowup_supervised_learning:
612
808
  datum_and_label_dict = get_datum_for_blowup_supervised_learning.get_datum_for_blowup_supervised_learning(
613
- raw_joint_adjusted_bread_matrix,
614
- joint_adjusted_bread_cond,
809
+ raw_joint_bread_matrix,
810
+ joint_bread_cond,
615
811
  avg_estimating_function_stack,
616
812
  per_subject_estimating_function_stacks,
617
813
  all_post_update_betas,
@@ -756,12 +952,16 @@ def single_subject_weighted_estimating_function_stacker(
756
952
  policy_num_by_decision_time: dict[collections.abc.Hashable, dict[int, int | float]],
757
953
  action_by_decision_time: dict[collections.abc.Hashable, dict[int, int]],
758
954
  beta_index_by_policy_num: dict[int | float, int],
759
- ) -> tuple[
760
- jnp.ndarray[jnp.float32],
761
- jnp.ndarray[jnp.float32],
762
- jnp.ndarray[jnp.float32],
763
- jnp.ndarray[jnp.float32],
764
- ]:
955
+ include_auxiliary_outputs: bool = True,
956
+ ) -> (
957
+ tuple[
958
+ jnp.ndarray[jnp.float32],
959
+ jnp.ndarray[jnp.float32],
960
+ jnp.ndarray[jnp.float32],
961
+ jnp.ndarray[jnp.float32],
962
+ ]
963
+ | jnp.ndarray[jnp.float32]
964
+ ):
765
965
  """
766
966
  Computes a weighted estimating function stack for a given algorithm estimating function
767
967
  and arguments, inference estimating functio and arguments, and action probability function and
@@ -825,12 +1025,23 @@ def single_subject_weighted_estimating_function_stacker(
825
1025
  A dictionary mapping policy numbers to the index of the corresponding beta in
826
1026
  all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
827
1027
 
1028
+ include_auxiliary_outputs (bool):
1029
+ If True, returns the adjusted meat, classical meat, and classical bread contributions in
1030
+ a second returned tuple. If False, only returns the weighted estimating function stack.
1031
+
828
1032
  Returns:
829
1033
  jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted estimating function
830
1034
  stack.
831
1035
  jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's adjusted meat contribution.
832
1036
  jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical meat contribution.
833
1037
  jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical bread contribution.
1038
+
1039
+ or
1040
+
1041
+ jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted estimating function
1042
+ stack.
1043
+
1044
+ depending on the value of include_auxiliary_outputs.
834
1045
  """
835
1046
 
836
1047
  logger.info(
@@ -1020,14 +1231,18 @@ def single_subject_weighted_estimating_function_stacker(
1020
1231
  # c. The third output is averaged across subjects to obtain the classical meat matrix.
1021
1232
  # d. The fourth output is averaged across subjects to obtain the inverse classical bread
1022
1233
  # matrix.
1023
- return (
1024
- weighted_stack,
1025
- jnp.outer(weighted_stack, weighted_stack),
1026
- jnp.outer(inference_component, inference_component),
1027
- jax.jacrev(inference_estimating_func, argnums=inference_func_args_theta_index)(
1028
- *threaded_inference_func_args
1029
- ),
1030
- )
1234
+ if include_auxiliary_outputs:
1235
+ return (
1236
+ weighted_stack,
1237
+ jnp.outer(weighted_stack, weighted_stack),
1238
+ jnp.outer(inference_component, inference_component),
1239
+ jax.jacrev(
1240
+ inference_estimating_func, argnums=inference_func_args_theta_index
1241
+ )(*threaded_inference_func_args),
1242
+ )
1243
+
1244
+ else:
1245
+ return weighted_stack
1031
1246
 
1032
1247
 
1033
1248
  def get_avg_weighted_estimating_function_stacks_and_aux_values(
@@ -1067,6 +1282,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1067
1282
  ],
1068
1283
  suppress_all_data_checks: bool,
1069
1284
  suppress_interactive_data_checks: bool,
1285
+ include_auxiliary_outputs: bool = True,
1070
1286
  ) -> tuple[
1071
1287
  jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
1072
1288
  ]:
@@ -1141,10 +1357,14 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1141
1357
  If True, suppresses interactive data checks that would otherwise be performed to ensure
1142
1358
  the correctness of the threaded arguments. The checks are still performed, but
1143
1359
  any interactive prompts are suppressed.
1360
+ include_auxiliary_outputs (bool):
1361
+ If True, returns the adjusted meat, classical meat, and classical bread contributions in addition to the average weighted estimating function stack.
1362
+ If False, returns only the average weighted estimating function stack.
1144
1363
 
1145
1364
  Returns:
1146
1365
  jnp.ndarray:
1147
1366
  A 2D JAX NumPy array holding the average weighted estimating function stack.
1367
+
1148
1368
  tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
1149
1369
  A tuple containing
1150
1370
  1. the average weighted estimating function stack
@@ -1153,6 +1373,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1153
1373
  4. the subject-level inverse classical bread matrix contributions
1154
1374
  5. raw per-subject weighted estimating function
1155
1375
  stacks.
1376
+ or jnp.ndarray:
1377
+ A 1-D JAX NumPy array representing the subject's weighted estimating function
1378
+ stack.
1379
+ depending on the value of include_auxiliary_outputs.
1156
1380
  """
1157
1381
 
1158
1382
  # 1. Collect estimating functions by differentiating the loss functions if needed.
@@ -1275,6 +1499,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1275
1499
  ]
1276
1500
 
1277
1501
  stacks = jnp.array([result[0] for result in results])
1502
+
1503
+ if not include_auxiliary_outputs:
1504
+ return jnp.mean(stacks, axis=0)
1505
+
1278
1506
  outer_products = jnp.array([result[1] for result in results])
1279
1507
  inference_only_outer_products = jnp.array([result[2] for result in results])
1280
1508
  inference_hessians = jnp.array([result[3] for result in results])
@@ -1475,7 +1703,7 @@ def construct_classical_and_adjusted_sandwiches(
1475
1703
  theta_dim = theta_est.shape[0]
1476
1704
  beta_dim = all_post_update_betas.shape[1]
1477
1705
  # Note that these "contributions" are per-subject Jacobians of the weighted estimating function stack.
1478
- raw_joint_adjusted_bread_matrix, (
1706
+ raw_joint_bread_matrix, (
1479
1707
  avg_estimating_function_stack,
1480
1708
  per_subject_joint_adjusted_meat_contributions,
1481
1709
  per_subject_classical_meat_contributions,
@@ -1533,21 +1761,21 @@ def construct_classical_and_adjusted_sandwiches(
1533
1761
 
1534
1762
  # Increase diagonal block dominance possibly improve conditioning of diagonal
1535
1763
  # blocks as necessary, to ensure mathematical stability of joint bread
1536
- stabilized_joint_adjusted_bread_matrix = (
1764
+ stabilized_joint_bread_matrix = (
1537
1765
  (
1538
1766
  stabilize_joint_bread_if_necessary(
1539
- raw_joint_adjusted_bread_matrix,
1767
+ raw_joint_bread_matrix,
1540
1768
  beta_dim,
1541
1769
  theta_dim,
1542
1770
  )
1543
1771
  )
1544
1772
  if stabilize_joint_bread
1545
- else raw_joint_adjusted_bread_matrix
1773
+ else raw_joint_bread_matrix
1546
1774
  )
1547
1775
 
1548
1776
  # Now stably (no explicit inversion) form our sandwiches.
1549
- joint_adjusted_sandwich = form_sandwich_from_bread_and_meat(
1550
- stabilized_joint_adjusted_bread_matrix,
1777
+ joint_sandwich = form_sandwich_from_bread_and_meat(
1778
+ stabilized_joint_bread_matrix,
1551
1779
  joint_adjusted_meat_matrix,
1552
1780
  num_subjects,
1553
1781
  method=SandwichFormationMethods.BREAD_T_QR,
@@ -1568,7 +1796,7 @@ def construct_classical_and_adjusted_sandwiches(
1568
1796
  form_adjusted_meat_adjustments_directly(
1569
1797
  theta_dim,
1570
1798
  all_post_update_betas.shape[1],
1571
- stabilized_joint_adjusted_bread_matrix,
1799
+ stabilized_joint_bread_matrix,
1572
1800
  per_subject_estimating_function_stacks,
1573
1801
  analysis_df,
1574
1802
  active_col_name,
@@ -1610,7 +1838,7 @@ def construct_classical_and_adjusted_sandwiches(
1610
1838
  method=SandwichFormationMethods.BREAD_T_QR,
1611
1839
  )
1612
1840
  )
1613
- theta_only_adjusted_sandwich = joint_adjusted_sandwich[-theta_dim:, -theta_dim:]
1841
+ theta_only_adjusted_sandwich = joint_sandwich[-theta_dim:, -theta_dim:]
1614
1842
 
1615
1843
  if not np.allclose(
1616
1844
  theta_only_adjusted_sandwich,
@@ -1624,10 +1852,10 @@ def construct_classical_and_adjusted_sandwiches(
1624
1852
  # Stack the joint bread pieces together horizontally and return the auxiliary
1625
1853
  # values too. The joint bread should always be block lower triangular.
1626
1854
  return (
1627
- raw_joint_adjusted_bread_matrix,
1628
- stabilized_joint_adjusted_bread_matrix,
1855
+ raw_joint_bread_matrix,
1856
+ stabilized_joint_bread_matrix,
1629
1857
  joint_adjusted_meat_matrix,
1630
- joint_adjusted_sandwich,
1858
+ joint_sandwich,
1631
1859
  classical_bread_matrix,
1632
1860
  classical_meat_matrix,
1633
1861
  classical_sandwich,
@@ -1643,7 +1871,7 @@ def construct_classical_and_adjusted_sandwiches(
1643
1871
  # important for the subject to know if this is happening. Even if enabled, it is important
1644
1872
  # that the subject know it actually kicks in.
1645
1873
  def stabilize_joint_bread_if_necessary(
1646
- joint_adjusted_bread_matrix: jnp.ndarray,
1874
+ joint_bread_matrix: jnp.ndarray,
1647
1875
  beta_dim: int,
1648
1876
  theta_dim: int,
1649
1877
  ) -> jnp.ndarray:
@@ -1652,7 +1880,7 @@ def stabilize_joint_bread_if_necessary(
1652
1880
  dominance and/or adding a small ridge penalty to the diagonal blocks.
1653
1881
 
1654
1882
  Args:
1655
- joint_adjusted_bread_matrix (jnp.ndarray):
1883
+ joint_bread_matrix (jnp.ndarray):
1656
1884
  A 2-D JAX NumPy array representing the joint bread matrix.
1657
1885
  beta_dim (int):
1658
1886
  The dimension of each beta parameter.
@@ -1673,7 +1901,7 @@ def stabilize_joint_bread_if_necessary(
1673
1901
 
1674
1902
  # Grab just the RL block and convert numpy array for easier manipulation.
1675
1903
  RL_stack_beta_derivatives_block = np.array(
1676
- joint_adjusted_bread_matrix[:-theta_dim, :-theta_dim]
1904
+ joint_bread_matrix[:-theta_dim, :-theta_dim]
1677
1905
  )
1678
1906
  num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
1679
1907
  for i in range(1, num_updates + 1):
@@ -1792,11 +2020,11 @@ def stabilize_joint_bread_if_necessary(
1792
2020
  [
1793
2021
  [
1794
2022
  RL_stack_beta_derivatives_block,
1795
- joint_adjusted_bread_matrix[:-theta_dim, -theta_dim:],
2023
+ joint_bread_matrix[:-theta_dim, -theta_dim:],
1796
2024
  ],
1797
2025
  [
1798
- joint_adjusted_bread_matrix[-theta_dim:, :-theta_dim],
1799
- joint_adjusted_bread_matrix[-theta_dim:, -theta_dim:],
2026
+ joint_bread_matrix[-theta_dim:, :-theta_dim],
2027
+ joint_bread_matrix[-theta_dim:, -theta_dim:],
1800
2028
  ],
1801
2029
  ]
1802
2030
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lifejacket
3
- Version: 1.0.2
3
+ Version: 1.2.0
4
4
  Summary: Consistent standard errors for longitudinal data collected under pooling online decision policies.
5
5
  Author-email: Nowell Closser <nowellclosser@gmail.com>
6
6
  Requires-Python: >=3.10
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "lifejacket"
7
- version = "1.0.2"
7
+ version = "1.2.0"
8
8
  description = "Consistent standard errors for longitudinal data collected under pooling online decision policies."
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.10"
File without changes
File without changes