lifejacket 1.0.2__tar.gz → 1.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lifejacket-1.0.2 → lifejacket-1.2.0}/PKG-INFO +1 -1
- {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/post_deployment_analysis.py +277 -49
- {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket.egg-info/PKG-INFO +1 -1
- {lifejacket-1.0.2 → lifejacket-1.2.0}/pyproject.toml +1 -1
- {lifejacket-1.0.2 → lifejacket-1.2.0}/README.md +0 -0
- {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/__init__.py +0 -0
- {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/arg_threading_helpers.py +0 -0
- {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/calculate_derivatives.py +0 -0
- {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/constants.py +0 -0
- {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/deployment_conditioning_monitor.py +0 -0
- {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/form_adjusted_meat_adjustments_directly.py +0 -0
- {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/get_datum_for_blowup_supervised_learning.py +0 -0
- {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/helper_functions.py +0 -0
- {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/input_checks.py +0 -0
- {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/small_sample_corrections.py +0 -0
- {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/vmap_helpers.py +0 -0
- {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket.egg-info/SOURCES.txt +0 -0
- {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket.egg-info/dependency_links.txt +0 -0
- {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket.egg-info/entry_points.txt +0 -0
- {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket.egg-info/requires.txt +0 -0
- {lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket.egg-info/top_level.txt +0 -0
- {lifejacket-1.0.2 → lifejacket-1.2.0}/setup.cfg +0 -0
|
@@ -483,10 +483,10 @@ def analyze_dataset(
|
|
|
483
483
|
|
|
484
484
|
subject_ids = jnp.array(analysis_df[subject_id_col_name].unique())
|
|
485
485
|
(
|
|
486
|
-
|
|
487
|
-
|
|
486
|
+
stabilized_joint_bread_matrix,
|
|
487
|
+
raw_joint_bread_matrix,
|
|
488
488
|
joint_adjusted_meat_matrix,
|
|
489
|
-
|
|
489
|
+
joint_sandwich_matrix,
|
|
490
490
|
classical_bread_matrix,
|
|
491
491
|
classical_meat_matrix,
|
|
492
492
|
classical_sandwich_var_estimate,
|
|
@@ -544,9 +544,7 @@ def analyze_dataset(
|
|
|
544
544
|
|
|
545
545
|
# This bottom right corner of the joint (betas and theta) variance matrix is the portion
|
|
546
546
|
# corresponding to just theta.
|
|
547
|
-
adjusted_sandwich_var_estimate =
|
|
548
|
-
-theta_dim:, -theta_dim:
|
|
549
|
-
]
|
|
547
|
+
adjusted_sandwich_var_estimate = joint_sandwich_matrix[-theta_dim:, -theta_dim:]
|
|
550
548
|
|
|
551
549
|
# Check for negative diagonal elements and set them to zero if found
|
|
552
550
|
adjusted_diagonal = np.diag(adjusted_sandwich_var_estimate)
|
|
@@ -572,31 +570,229 @@ def analyze_dataset(
|
|
|
572
570
|
f,
|
|
573
571
|
)
|
|
574
572
|
|
|
575
|
-
|
|
573
|
+
joint_bread_cond = jnp.linalg.cond(raw_joint_bread_matrix)
|
|
574
|
+
logger.info(
|
|
575
|
+
"Joint bread condition number: %f",
|
|
576
|
+
joint_bread_cond,
|
|
577
|
+
)
|
|
578
|
+
|
|
579
|
+
# calculate the max eigenvalue of the theta-only adjusted sandwich
|
|
580
|
+
max_eigenvalue_theta_only_adjusted_sandwich = scipy.linalg.eigvalsh(
|
|
581
|
+
adjusted_sandwich_var_estimate
|
|
582
|
+
).max()
|
|
576
583
|
logger.info(
|
|
577
|
-
"
|
|
578
|
-
|
|
584
|
+
"Max eigenvalue of theta-only adjusted sandwich matrix: %f",
|
|
585
|
+
max_eigenvalue_theta_only_adjusted_sandwich,
|
|
579
586
|
)
|
|
580
587
|
|
|
581
|
-
#
|
|
582
|
-
|
|
588
|
+
# Compute ratios: max eigenvalue / median eigenvalue among those >= 1e-8 * max.
|
|
589
|
+
eigvals_joint_sandwich = scipy.linalg.eigvalsh(joint_sandwich_matrix)
|
|
590
|
+
max_eig_joint = float(eigvals_joint_sandwich.max())
|
|
583
591
|
logger.info(
|
|
584
592
|
"Max eigenvalue of joint adjusted sandwich matrix: %f",
|
|
585
|
-
|
|
593
|
+
max_eig_joint,
|
|
594
|
+
)
|
|
595
|
+
|
|
596
|
+
joint_keep = eigvals_joint_sandwich >= (1e-8 * max_eig_joint)
|
|
597
|
+
joint_median_kept = (
|
|
598
|
+
float(np.median(eigvals_joint_sandwich[joint_keep]))
|
|
599
|
+
if np.any(joint_keep)
|
|
600
|
+
else math.nan
|
|
601
|
+
)
|
|
602
|
+
max_to_median_ratio_joint_sandwich = (
|
|
603
|
+
(max_eig_joint / joint_median_kept)
|
|
604
|
+
if (not math.isnan(joint_median_kept) and joint_median_kept > 0)
|
|
605
|
+
else (
|
|
606
|
+
math.inf
|
|
607
|
+
if (not math.isnan(joint_median_kept) and joint_median_kept == 0)
|
|
608
|
+
else math.nan
|
|
609
|
+
)
|
|
610
|
+
)
|
|
611
|
+
logger.info(
|
|
612
|
+
"Max/median eigenvalue ratio (joint sandwich; median over eigvals >= 1e-8*max): %f",
|
|
613
|
+
max_to_median_ratio_joint_sandwich,
|
|
614
|
+
)
|
|
615
|
+
|
|
616
|
+
eigvals_theta_only_adjusted_sandwich = scipy.linalg.eigvalsh(
|
|
617
|
+
adjusted_sandwich_var_estimate
|
|
618
|
+
)
|
|
619
|
+
max_eig_theta = float(eigvals_theta_only_adjusted_sandwich.max())
|
|
620
|
+
theta_keep = eigvals_theta_only_adjusted_sandwich >= (1e-8 * max_eig_theta)
|
|
621
|
+
theta_median_kept = (
|
|
622
|
+
float(np.median(eigvals_theta_only_adjusted_sandwich[theta_keep]))
|
|
623
|
+
if np.any(theta_keep)
|
|
624
|
+
else math.nan
|
|
625
|
+
)
|
|
626
|
+
max_to_median_ratio_theta_only_adjusted_sandwich = (
|
|
627
|
+
(max_eig_theta / theta_median_kept)
|
|
628
|
+
if (not math.isnan(theta_median_kept) and theta_median_kept > 0)
|
|
629
|
+
else (
|
|
630
|
+
math.inf
|
|
631
|
+
if (not math.isnan(theta_median_kept) and theta_median_kept == 0)
|
|
632
|
+
else math.nan
|
|
633
|
+
)
|
|
634
|
+
)
|
|
635
|
+
logger.info(
|
|
636
|
+
"Max/median eigenvalue ratio (theta-only adjusted sandwich; median over eigvals >= 1e-8*max): %f",
|
|
637
|
+
max_to_median_ratio_theta_only_adjusted_sandwich,
|
|
586
638
|
)
|
|
587
639
|
|
|
640
|
+
# --- Local linearization validity diagnostic (single-run) ---
|
|
641
|
+
# We compare the nonlinear Taylor remainder of the joint estimating-function map to the
|
|
642
|
+
# retained linear term, at perturbations on the O(1/sqrt(n)) scale.
|
|
643
|
+
#
|
|
644
|
+
# Define r(delta) = || g(eta+delta) - g(eta) - B delta ||_2 / || B delta ||_2,
|
|
645
|
+
# where g(eta) is the avg per-subject weighted estimating-function stack and B is the
|
|
646
|
+
# stabilized joint bread (Jacobian of g w.r.t. flattened betas+theta).
|
|
647
|
+
#
|
|
648
|
+
# This ratio is dimensionless and can be used as a necessary/sanity diagnostic that the
|
|
649
|
+
# first-order linearization is locally accurate at the estimation scale.
|
|
650
|
+
|
|
651
|
+
def _compute_local_linearization_error_ratio() -> tuple[float, float]:
|
|
652
|
+
# Ensure float64 for diagnostics even if upstream ran in float32.
|
|
653
|
+
joint_bread_float64 = jnp.asarray(
|
|
654
|
+
stabilized_joint_bread_matrix, dtype=jnp.float64
|
|
655
|
+
)
|
|
656
|
+
g_hat = jnp.asarray(avg_estimating_function_stack, dtype=jnp.float64)
|
|
657
|
+
stacks_float64 = jnp.asarray(
|
|
658
|
+
per_subject_estimating_function_stacks, dtype=jnp.float64
|
|
659
|
+
)
|
|
660
|
+
|
|
661
|
+
num_subjects = stacks_float64.shape[0]
|
|
662
|
+
|
|
663
|
+
def _eval_avg_stack_jit(flattened_betas_and_theta: jnp.ndarray) -> jnp.ndarray:
|
|
664
|
+
return jnp.asarray(
|
|
665
|
+
get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
666
|
+
flattened_betas_and_theta,
|
|
667
|
+
beta_dim,
|
|
668
|
+
theta_dim,
|
|
669
|
+
subject_ids,
|
|
670
|
+
action_prob_func,
|
|
671
|
+
action_prob_func_args_beta_index,
|
|
672
|
+
alg_update_func,
|
|
673
|
+
alg_update_func_type,
|
|
674
|
+
alg_update_func_args_beta_index,
|
|
675
|
+
alg_update_func_args_action_prob_index,
|
|
676
|
+
alg_update_func_args_action_prob_times_index,
|
|
677
|
+
alg_update_func_args_previous_betas_index,
|
|
678
|
+
inference_func,
|
|
679
|
+
inference_func_type,
|
|
680
|
+
inference_func_args_theta_index,
|
|
681
|
+
inference_func_args_action_prob_index,
|
|
682
|
+
action_prob_func_args,
|
|
683
|
+
policy_num_by_decision_time_by_subject_id,
|
|
684
|
+
initial_policy_num,
|
|
685
|
+
beta_index_by_policy_num,
|
|
686
|
+
inference_func_args_by_subject_id,
|
|
687
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
688
|
+
alg_update_func_args,
|
|
689
|
+
action_by_decision_time_by_subject_id,
|
|
690
|
+
True, # suppress_all_data_checks
|
|
691
|
+
True, # suppress_interactive_data_checks
|
|
692
|
+
False, # include_auxiliary_outputs
|
|
693
|
+
),
|
|
694
|
+
dtype=jnp.float64,
|
|
695
|
+
)
|
|
696
|
+
|
|
697
|
+
# Evaluate at the final estimate.
|
|
698
|
+
eta_hat = jnp.asarray(
|
|
699
|
+
flatten_params(all_post_update_betas, theta_est), dtype=jnp.float64
|
|
700
|
+
)
|
|
701
|
+
|
|
702
|
+
# Draw perturbations delta_j on the O(1/sqrt(n)) scale, aligned with the empirical
|
|
703
|
+
# joint estimating function stack covariance, without forming a d_joint x d_joint matrix
|
|
704
|
+
# square-root. If G is the (n x d) matrix of per-subject stacks, then (1/n) G^T G is the
|
|
705
|
+
# empirical covariance in joint estimating function stack space. Sampling u = (G^T w)/sqrt(n) with w~N(0, I_n) gives
|
|
706
|
+
# u ~ N(0, empirical joint estimating function stack covariance G^T G/n ) in joint estimating function stack space.
|
|
707
|
+
key = jax.random.PRNGKey(0)
|
|
708
|
+
|
|
709
|
+
# The number of perturbations we will probe
|
|
710
|
+
J = 15
|
|
711
|
+
# Each requires num_subjects standard normal draws, which we will then transform
|
|
712
|
+
# into joint estimating function space perturbations in U
|
|
713
|
+
W = jax.random.normal(key, shape=(J, num_subjects), dtype=jnp.float64)
|
|
714
|
+
|
|
715
|
+
# Joint estimating function space perturbations: u_j in R^{d_joint}
|
|
716
|
+
# U = (1/sqrt(n)) * W G, where rows of G are g_i^T
|
|
717
|
+
U = (W @ stacks_float64) / jnp.sqrt(num_subjects)
|
|
718
|
+
|
|
719
|
+
# Parameter perturbations: delta = (c/sqrt(n)) * B^{-1} u
|
|
720
|
+
# Use solve rather than explicit inverse.
|
|
721
|
+
c = 1.0
|
|
722
|
+
delta = (c / jnp.sqrt(num_subjects)) * jnp.linalg.solve(
|
|
723
|
+
joint_bread_float64, U.T
|
|
724
|
+
).T
|
|
725
|
+
|
|
726
|
+
# Compute ratios r_j.
|
|
727
|
+
# NOTE: We use the Euclidean norm in score space; this is dimensionless and avoids
|
|
728
|
+
# forming/pseudoinverting a potentially rank-deficient matrix.
|
|
729
|
+
B_delta = (joint_bread_float64 @ delta.T).T
|
|
730
|
+
g_plus = jax.vmap(lambda d: _eval_avg_stack_jit(eta_hat + d))(delta)
|
|
731
|
+
remainder = g_plus - g_hat - B_delta
|
|
732
|
+
|
|
733
|
+
denom = jnp.linalg.norm(B_delta, axis=1)
|
|
734
|
+
numer = jnp.linalg.norm(remainder, axis=1)
|
|
735
|
+
|
|
736
|
+
# Avoid division by zero (should not happen unless delta collapses numerically).
|
|
737
|
+
ratios = jnp.where(denom > 0, numer / denom, jnp.inf)
|
|
738
|
+
|
|
739
|
+
local_error_ratio_median = float(jnp.median(ratios))
|
|
740
|
+
local_error_ratio_p90 = float(jnp.quantile(ratios, 0.9))
|
|
741
|
+
local_error_ratio_max = float(jnp.max(ratios))
|
|
742
|
+
|
|
743
|
+
logger.info(
|
|
744
|
+
"Local linearization error ratio (median over %d draws): %.6f",
|
|
745
|
+
J,
|
|
746
|
+
local_error_ratio_median,
|
|
747
|
+
)
|
|
748
|
+
logger.info(
|
|
749
|
+
"Local linearization error ratio (90th pct over %d draws): %.6f",
|
|
750
|
+
J,
|
|
751
|
+
local_error_ratio_p90,
|
|
752
|
+
)
|
|
753
|
+
|
|
754
|
+
logger.info(
|
|
755
|
+
"Local linearization error ratio (max over %d draws): %.6f",
|
|
756
|
+
J,
|
|
757
|
+
local_error_ratio_max,
|
|
758
|
+
)
|
|
759
|
+
|
|
760
|
+
return local_error_ratio_median, local_error_ratio_p90, local_error_ratio_max
|
|
761
|
+
|
|
762
|
+
try:
|
|
763
|
+
local_error_ratio_median, local_error_ratio_p90, local_error_ratio_max = (
|
|
764
|
+
_compute_local_linearization_error_ratio()
|
|
765
|
+
)
|
|
766
|
+
except Exception as e:
|
|
767
|
+
# This diagnostic is best-effort; failure should not break analysis.
|
|
768
|
+
logger.warning(
|
|
769
|
+
"Failed to compute local linearization error ratio diagnostic: %s",
|
|
770
|
+
str(e),
|
|
771
|
+
)
|
|
772
|
+
local_error_ratio_median = math.nan
|
|
773
|
+
local_error_ratio_p90 = math.nan
|
|
774
|
+
local_error_ratio_max = math.nan
|
|
775
|
+
|
|
588
776
|
debug_pieces_dict = {
|
|
589
777
|
"theta_est": theta_est,
|
|
590
778
|
"adjusted_sandwich_var_estimate": adjusted_sandwich_var_estimate,
|
|
591
779
|
"classical_sandwich_var_estimate": classical_sandwich_var_estimate,
|
|
592
|
-
"raw_joint_bread_matrix":
|
|
593
|
-
"stabilized_joint_bread_matrix":
|
|
780
|
+
"raw_joint_bread_matrix": raw_joint_bread_matrix,
|
|
781
|
+
"stabilized_joint_bread_matrix": stabilized_joint_bread_matrix,
|
|
594
782
|
"joint_meat_matrix": joint_adjusted_meat_matrix,
|
|
595
783
|
"classical_bread_matrix": classical_bread_matrix,
|
|
596
784
|
"classical_meat_matrix": classical_meat_matrix,
|
|
597
785
|
"all_estimating_function_stacks": per_subject_estimating_function_stacks,
|
|
598
|
-
"joint_bread_condition_number":
|
|
599
|
-
"
|
|
786
|
+
"joint_bread_condition_number": joint_bread_cond,
|
|
787
|
+
"max_eigenvalue_joint_sandwich": max_eig_joint,
|
|
788
|
+
"all_eigenvalues_joint_sandwich": eigvals_joint_sandwich,
|
|
789
|
+
"max_to_median_ratio_joint_sandwich": max_to_median_ratio_joint_sandwich,
|
|
790
|
+
"max_eigenvalue_theta_only_adjusted_sandwich": max_eig_theta,
|
|
791
|
+
"all_eigenvalues_theta_only_adjusted_sandwich": eigvals_theta_only_adjusted_sandwich,
|
|
792
|
+
"max_to_median_ratio_theta_only_adjusted_sandwich": max_to_median_ratio_theta_only_adjusted_sandwich,
|
|
793
|
+
"local_linearization_error_ratio_median": local_error_ratio_median,
|
|
794
|
+
"local_linearization_error_ratio_p90": local_error_ratio_p90,
|
|
795
|
+
"local_linearization_error_ratio_max": local_error_ratio_max,
|
|
600
796
|
"all_post_update_betas": all_post_update_betas,
|
|
601
797
|
"per_subject_adjusted_corrections": per_subject_adjusted_corrections,
|
|
602
798
|
"per_subject_classical_corrections": per_subject_classical_corrections,
|
|
@@ -610,8 +806,8 @@ def analyze_dataset(
|
|
|
610
806
|
|
|
611
807
|
if collect_data_for_blowup_supervised_learning:
|
|
612
808
|
datum_and_label_dict = get_datum_for_blowup_supervised_learning.get_datum_for_blowup_supervised_learning(
|
|
613
|
-
|
|
614
|
-
|
|
809
|
+
raw_joint_bread_matrix,
|
|
810
|
+
joint_bread_cond,
|
|
615
811
|
avg_estimating_function_stack,
|
|
616
812
|
per_subject_estimating_function_stacks,
|
|
617
813
|
all_post_update_betas,
|
|
@@ -756,12 +952,16 @@ def single_subject_weighted_estimating_function_stacker(
|
|
|
756
952
|
policy_num_by_decision_time: dict[collections.abc.Hashable, dict[int, int | float]],
|
|
757
953
|
action_by_decision_time: dict[collections.abc.Hashable, dict[int, int]],
|
|
758
954
|
beta_index_by_policy_num: dict[int | float, int],
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
]
|
|
955
|
+
include_auxiliary_outputs: bool = True,
|
|
956
|
+
) -> (
|
|
957
|
+
tuple[
|
|
958
|
+
jnp.ndarray[jnp.float32],
|
|
959
|
+
jnp.ndarray[jnp.float32],
|
|
960
|
+
jnp.ndarray[jnp.float32],
|
|
961
|
+
jnp.ndarray[jnp.float32],
|
|
962
|
+
]
|
|
963
|
+
| jnp.ndarray[jnp.float32]
|
|
964
|
+
):
|
|
765
965
|
"""
|
|
766
966
|
Computes a weighted estimating function stack for a given algorithm estimating function
|
|
767
967
|
and arguments, inference estimating functio and arguments, and action probability function and
|
|
@@ -825,12 +1025,23 @@ def single_subject_weighted_estimating_function_stacker(
|
|
|
825
1025
|
A dictionary mapping policy numbers to the index of the corresponding beta in
|
|
826
1026
|
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
827
1027
|
|
|
1028
|
+
include_auxiliary_outputs (bool):
|
|
1029
|
+
If True, returns the adjusted meat, classical meat, and classical bread contributions in
|
|
1030
|
+
a second returned tuple. If False, only returns the weighted estimating function stack.
|
|
1031
|
+
|
|
828
1032
|
Returns:
|
|
829
1033
|
jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted estimating function
|
|
830
1034
|
stack.
|
|
831
1035
|
jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's adjusted meat contribution.
|
|
832
1036
|
jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical meat contribution.
|
|
833
1037
|
jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical bread contribution.
|
|
1038
|
+
|
|
1039
|
+
or
|
|
1040
|
+
|
|
1041
|
+
jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted estimating function
|
|
1042
|
+
stack.
|
|
1043
|
+
|
|
1044
|
+
depending on the value of include_auxiliary_outputs.
|
|
834
1045
|
"""
|
|
835
1046
|
|
|
836
1047
|
logger.info(
|
|
@@ -1020,14 +1231,18 @@ def single_subject_weighted_estimating_function_stacker(
|
|
|
1020
1231
|
# c. The third output is averaged across subjects to obtain the classical meat matrix.
|
|
1021
1232
|
# d. The fourth output is averaged across subjects to obtain the inverse classical bread
|
|
1022
1233
|
# matrix.
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
1234
|
+
if include_auxiliary_outputs:
|
|
1235
|
+
return (
|
|
1236
|
+
weighted_stack,
|
|
1237
|
+
jnp.outer(weighted_stack, weighted_stack),
|
|
1238
|
+
jnp.outer(inference_component, inference_component),
|
|
1239
|
+
jax.jacrev(
|
|
1240
|
+
inference_estimating_func, argnums=inference_func_args_theta_index
|
|
1241
|
+
)(*threaded_inference_func_args),
|
|
1242
|
+
)
|
|
1243
|
+
|
|
1244
|
+
else:
|
|
1245
|
+
return weighted_stack
|
|
1031
1246
|
|
|
1032
1247
|
|
|
1033
1248
|
def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
@@ -1067,6 +1282,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1067
1282
|
],
|
|
1068
1283
|
suppress_all_data_checks: bool,
|
|
1069
1284
|
suppress_interactive_data_checks: bool,
|
|
1285
|
+
include_auxiliary_outputs: bool = True,
|
|
1070
1286
|
) -> tuple[
|
|
1071
1287
|
jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
|
|
1072
1288
|
]:
|
|
@@ -1141,10 +1357,14 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1141
1357
|
If True, suppresses interactive data checks that would otherwise be performed to ensure
|
|
1142
1358
|
the correctness of the threaded arguments. The checks are still performed, but
|
|
1143
1359
|
any interactive prompts are suppressed.
|
|
1360
|
+
include_auxiliary_outputs (bool):
|
|
1361
|
+
If True, returns the adjusted meat, classical meat, and classical bread contributions in addition to the average weighted estimating function stack.
|
|
1362
|
+
If False, returns only the average weighted estimating function stack.
|
|
1144
1363
|
|
|
1145
1364
|
Returns:
|
|
1146
1365
|
jnp.ndarray:
|
|
1147
1366
|
A 2D JAX NumPy array holding the average weighted estimating function stack.
|
|
1367
|
+
|
|
1148
1368
|
tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
|
|
1149
1369
|
A tuple containing
|
|
1150
1370
|
1. the average weighted estimating function stack
|
|
@@ -1153,6 +1373,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1153
1373
|
4. the subject-level inverse classical bread matrix contributions
|
|
1154
1374
|
5. raw per-subject weighted estimating function
|
|
1155
1375
|
stacks.
|
|
1376
|
+
or jnp.ndarray:
|
|
1377
|
+
A 1-D JAX NumPy array representing the subject's weighted estimating function
|
|
1378
|
+
stack.
|
|
1379
|
+
depending on the value of include_auxiliary_outputs.
|
|
1156
1380
|
"""
|
|
1157
1381
|
|
|
1158
1382
|
# 1. Collect estimating functions by differentiating the loss functions if needed.
|
|
@@ -1275,6 +1499,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1275
1499
|
]
|
|
1276
1500
|
|
|
1277
1501
|
stacks = jnp.array([result[0] for result in results])
|
|
1502
|
+
|
|
1503
|
+
if not include_auxiliary_outputs:
|
|
1504
|
+
return jnp.mean(stacks, axis=0)
|
|
1505
|
+
|
|
1278
1506
|
outer_products = jnp.array([result[1] for result in results])
|
|
1279
1507
|
inference_only_outer_products = jnp.array([result[2] for result in results])
|
|
1280
1508
|
inference_hessians = jnp.array([result[3] for result in results])
|
|
@@ -1475,7 +1703,7 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1475
1703
|
theta_dim = theta_est.shape[0]
|
|
1476
1704
|
beta_dim = all_post_update_betas.shape[1]
|
|
1477
1705
|
# Note that these "contributions" are per-subject Jacobians of the weighted estimating function stack.
|
|
1478
|
-
|
|
1706
|
+
raw_joint_bread_matrix, (
|
|
1479
1707
|
avg_estimating_function_stack,
|
|
1480
1708
|
per_subject_joint_adjusted_meat_contributions,
|
|
1481
1709
|
per_subject_classical_meat_contributions,
|
|
@@ -1533,21 +1761,21 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1533
1761
|
|
|
1534
1762
|
# Increase diagonal block dominance possibly improve conditioning of diagonal
|
|
1535
1763
|
# blocks as necessary, to ensure mathematical stability of joint bread
|
|
1536
|
-
|
|
1764
|
+
stabilized_joint_bread_matrix = (
|
|
1537
1765
|
(
|
|
1538
1766
|
stabilize_joint_bread_if_necessary(
|
|
1539
|
-
|
|
1767
|
+
raw_joint_bread_matrix,
|
|
1540
1768
|
beta_dim,
|
|
1541
1769
|
theta_dim,
|
|
1542
1770
|
)
|
|
1543
1771
|
)
|
|
1544
1772
|
if stabilize_joint_bread
|
|
1545
|
-
else
|
|
1773
|
+
else raw_joint_bread_matrix
|
|
1546
1774
|
)
|
|
1547
1775
|
|
|
1548
1776
|
# Now stably (no explicit inversion) form our sandwiches.
|
|
1549
|
-
|
|
1550
|
-
|
|
1777
|
+
joint_sandwich = form_sandwich_from_bread_and_meat(
|
|
1778
|
+
stabilized_joint_bread_matrix,
|
|
1551
1779
|
joint_adjusted_meat_matrix,
|
|
1552
1780
|
num_subjects,
|
|
1553
1781
|
method=SandwichFormationMethods.BREAD_T_QR,
|
|
@@ -1568,7 +1796,7 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1568
1796
|
form_adjusted_meat_adjustments_directly(
|
|
1569
1797
|
theta_dim,
|
|
1570
1798
|
all_post_update_betas.shape[1],
|
|
1571
|
-
|
|
1799
|
+
stabilized_joint_bread_matrix,
|
|
1572
1800
|
per_subject_estimating_function_stacks,
|
|
1573
1801
|
analysis_df,
|
|
1574
1802
|
active_col_name,
|
|
@@ -1610,7 +1838,7 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1610
1838
|
method=SandwichFormationMethods.BREAD_T_QR,
|
|
1611
1839
|
)
|
|
1612
1840
|
)
|
|
1613
|
-
theta_only_adjusted_sandwich =
|
|
1841
|
+
theta_only_adjusted_sandwich = joint_sandwich[-theta_dim:, -theta_dim:]
|
|
1614
1842
|
|
|
1615
1843
|
if not np.allclose(
|
|
1616
1844
|
theta_only_adjusted_sandwich,
|
|
@@ -1624,10 +1852,10 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1624
1852
|
# Stack the joint bread pieces together horizontally and return the auxiliary
|
|
1625
1853
|
# values too. The joint bread should always be block lower triangular.
|
|
1626
1854
|
return (
|
|
1627
|
-
|
|
1628
|
-
|
|
1855
|
+
raw_joint_bread_matrix,
|
|
1856
|
+
stabilized_joint_bread_matrix,
|
|
1629
1857
|
joint_adjusted_meat_matrix,
|
|
1630
|
-
|
|
1858
|
+
joint_sandwich,
|
|
1631
1859
|
classical_bread_matrix,
|
|
1632
1860
|
classical_meat_matrix,
|
|
1633
1861
|
classical_sandwich,
|
|
@@ -1643,7 +1871,7 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1643
1871
|
# important for the subject to know if this is happening. Even if enabled, it is important
|
|
1644
1872
|
# that the subject know it actually kicks in.
|
|
1645
1873
|
def stabilize_joint_bread_if_necessary(
|
|
1646
|
-
|
|
1874
|
+
joint_bread_matrix: jnp.ndarray,
|
|
1647
1875
|
beta_dim: int,
|
|
1648
1876
|
theta_dim: int,
|
|
1649
1877
|
) -> jnp.ndarray:
|
|
@@ -1652,7 +1880,7 @@ def stabilize_joint_bread_if_necessary(
|
|
|
1652
1880
|
dominance and/or adding a small ridge penalty to the diagonal blocks.
|
|
1653
1881
|
|
|
1654
1882
|
Args:
|
|
1655
|
-
|
|
1883
|
+
joint_bread_matrix (jnp.ndarray):
|
|
1656
1884
|
A 2-D JAX NumPy array representing the joint bread matrix.
|
|
1657
1885
|
beta_dim (int):
|
|
1658
1886
|
The dimension of each beta parameter.
|
|
@@ -1673,7 +1901,7 @@ def stabilize_joint_bread_if_necessary(
|
|
|
1673
1901
|
|
|
1674
1902
|
# Grab just the RL block and convert numpy array for easier manipulation.
|
|
1675
1903
|
RL_stack_beta_derivatives_block = np.array(
|
|
1676
|
-
|
|
1904
|
+
joint_bread_matrix[:-theta_dim, :-theta_dim]
|
|
1677
1905
|
)
|
|
1678
1906
|
num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
|
|
1679
1907
|
for i in range(1, num_updates + 1):
|
|
@@ -1792,11 +2020,11 @@ def stabilize_joint_bread_if_necessary(
|
|
|
1792
2020
|
[
|
|
1793
2021
|
[
|
|
1794
2022
|
RL_stack_beta_derivatives_block,
|
|
1795
|
-
|
|
2023
|
+
joint_bread_matrix[:-theta_dim, -theta_dim:],
|
|
1796
2024
|
],
|
|
1797
2025
|
[
|
|
1798
|
-
|
|
1799
|
-
|
|
2026
|
+
joint_bread_matrix[-theta_dim:, :-theta_dim],
|
|
2027
|
+
joint_bread_matrix[-theta_dim:, -theta_dim:],
|
|
1800
2028
|
],
|
|
1801
2029
|
]
|
|
1802
2030
|
)
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "lifejacket"
|
|
7
|
-
version = "1.0
|
|
7
|
+
version = "1.2.0"
|
|
8
8
|
description = "Consistent standard errors for longitudinal data collected under pooling online decision policies."
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
requires-python = ">=3.10"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{lifejacket-1.0.2 → lifejacket-1.2.0}/lifejacket/get_datum_for_blowup_supervised_learning.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|