lifejacket-1.0.2-py3-none-any.whl → lifejacket-1.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/post_deployment_analysis.py +279 -49
- {lifejacket-1.0.2.dist-info → lifejacket-1.1.0.dist-info}/METADATA +1 -1
- {lifejacket-1.0.2.dist-info → lifejacket-1.1.0.dist-info}/RECORD +6 -6
- {lifejacket-1.0.2.dist-info → lifejacket-1.1.0.dist-info}/WHEEL +1 -1
- {lifejacket-1.0.2.dist-info → lifejacket-1.1.0.dist-info}/entry_points.txt +0 -0
- {lifejacket-1.0.2.dist-info → lifejacket-1.1.0.dist-info}/top_level.txt +0 -0
lifejacket/post_deployment_analysis.py

@@ -53,6 +53,8 @@ logging.basicConfig(
     level=logging.INFO,
 )
 
+jax.config.update("jax_enable_x64", True)
+
 
 @click.group()
 def cli():
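Note on the new configuration line: enabling "jax_enable_x64" makes JAX default to 64-bit floating point instead of its usual 32-bit default, which matters for the conditioning and eigenvalue diagnostics added further down. A minimal standalone check of that behavior (not taken from the package):

import jax
import jax.numpy as jnp

# With x64 enabled, Python floats become float64 arrays; without it they would be float32.
jax.config.update("jax_enable_x64", True)
print(jnp.array([1.0]).dtype)  # float64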
@@ -483,10 +485,10 @@ def analyze_dataset(
 
     subject_ids = jnp.array(analysis_df[subject_id_col_name].unique())
     (
-
-
+        stabilized_joint_bread_matrix,
+        raw_joint_bread_matrix,
         joint_adjusted_meat_matrix,
-
+        joint_sandwich_matrix,
         classical_bread_matrix,
         classical_meat_matrix,
         classical_sandwich_var_estimate,
@@ -544,9 +546,7 @@ def analyze_dataset(
 
     # This bottom right corner of the joint (betas and theta) variance matrix is the portion
     # corresponding to just theta.
-    adjusted_sandwich_var_estimate =
-        -theta_dim:, -theta_dim:
-    ]
+    adjusted_sandwich_var_estimate = joint_sandwich_matrix[-theta_dim:, -theta_dim:]
 
     # Check for negative diagonal elements and set them to zero if found
     adjusted_diagonal = np.diag(adjusted_sandwich_var_estimate)
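For readers skimming the hunk above: the theta-only variance is the bottom-right theta_dim x theta_dim corner of the joint sandwich, and (per the surrounding context lines) negative diagonal entries are treated as numerical artifacts and set to zero. A toy NumPy sketch of that slicing, with hypothetical dimensions:

import numpy as np

theta_dim = 2
joint_sandwich_matrix = np.arange(25, dtype=float).reshape(5, 5)  # stand-in joint matrix

# Bottom-right theta_dim x theta_dim block corresponds to theta alone.
adjusted = joint_sandwich_matrix[-theta_dim:, -theta_dim:]

# Clip any negative diagonal entries (variances) to zero, as the surrounding code does.
clipped = adjusted.copy()
np.fill_diagonal(clipped, np.clip(np.diag(adjusted), 0.0, None))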
@@ -572,31 +572,229 @@ def analyze_dataset(
         f,
     )
 
-
+    joint_bread_cond = jnp.linalg.cond(raw_joint_bread_matrix)
+    logger.info(
+        "Joint bread condition number: %f",
+        joint_bread_cond,
+    )
+
+    # calculate the max eigenvalue of the theta-only adjusted sandwich
+    max_eigenvalue_theta_only_adjusted_sandwich = scipy.linalg.eigvalsh(
+        adjusted_sandwich_var_estimate
+    ).max()
     logger.info(
-        "
-
+        "Max eigenvalue of theta-only adjusted sandwich matrix: %f",
+        max_eigenvalue_theta_only_adjusted_sandwich,
     )
 
-    #
-
+    # Compute ratios: max eigenvalue / median eigenvalue among those >= 1e-8 * max.
+    eigvals_joint_sandwich = scipy.linalg.eigvalsh(joint_sandwich_matrix)
+    max_eig_joint = float(eigvals_joint_sandwich.max())
     logger.info(
         "Max eigenvalue of joint adjusted sandwich matrix: %f",
-
+        max_eig_joint,
     )
 
+    joint_keep = eigvals_joint_sandwich >= (1e-8 * max_eig_joint)
+    joint_median_kept = (
+        float(np.median(eigvals_joint_sandwich[joint_keep]))
+        if np.any(joint_keep)
+        else math.nan
+    )
+    max_to_median_ratio_joint_sandwich = (
+        (max_eig_joint / joint_median_kept)
+        if (not math.isnan(joint_median_kept) and joint_median_kept > 0)
+        else (
+            math.inf
+            if (not math.isnan(joint_median_kept) and joint_median_kept == 0)
+            else math.nan
+        )
+    )
+    logger.info(
+        "Max/median eigenvalue ratio (joint sandwich; median over eigvals >= 1e-8*max): %f",
+        max_to_median_ratio_joint_sandwich,
+    )
+
+    eigvals_theta_only_adjusted_sandwich = scipy.linalg.eigvalsh(
+        adjusted_sandwich_var_estimate
+    )
+    max_eig_theta = float(eigvals_theta_only_adjusted_sandwich.max())
+    theta_keep = eigvals_theta_only_adjusted_sandwich >= (1e-8 * max_eig_theta)
+    theta_median_kept = (
+        float(np.median(eigvals_theta_only_adjusted_sandwich[theta_keep]))
+        if np.any(theta_keep)
+        else math.nan
+    )
+    max_to_median_ratio_theta_only_adjusted_sandwich = (
+        (max_eig_theta / theta_median_kept)
+        if (not math.isnan(theta_median_kept) and theta_median_kept > 0)
+        else (
+            math.inf
+            if (not math.isnan(theta_median_kept) and theta_median_kept == 0)
+            else math.nan
+        )
+    )
+    logger.info(
+        "Max/median eigenvalue ratio (theta-only adjusted sandwich; median over eigvals >= 1e-8*max): %f",
+        max_to_median_ratio_theta_only_adjusted_sandwich,
+    )
+
+    # --- Local linearization validity diagnostic (single-run) ---
+    # We compare the nonlinear Taylor remainder of the joint estimating-function map to the
+    # retained linear term, at perturbations on the O(1/sqrt(n)) scale.
+    #
+    # Define r(delta) = || g(eta+delta) - g(eta) - B delta ||_2 / || B delta ||_2,
+    # where g(eta) is the avg per-subject weighted estimating-function stack and B is the
+    # stabilized joint bread (Jacobian of g w.r.t. flattened betas+theta).
+    #
+    # This ratio is dimensionless and can be used as a necessary/sanity diagnostic that the
+    # first-order linearization is locally accurate at the estimation scale.
+
+    def _compute_local_linearization_error_ratio() -> tuple[float, float]:
+        # Ensure float64 for diagnostics even if upstream ran in float32.
+        joint_bread_float64 = jnp.asarray(
+            stabilized_joint_bread_matrix, dtype=jnp.float64
+        )
+        g_hat = jnp.asarray(avg_estimating_function_stack, dtype=jnp.float64)
+        stacks_float64 = jnp.asarray(
+            per_subject_estimating_function_stacks, dtype=jnp.float64
+        )
+
+        num_subjects = stacks_float64.shape[0]
+
+        def _eval_avg_stack_jit(flattened_betas_and_theta: jnp.ndarray) -> jnp.ndarray:
+            return jnp.asarray(
+                get_avg_weighted_estimating_function_stacks_and_aux_values(
+                    flattened_betas_and_theta,
+                    beta_dim,
+                    theta_dim,
+                    subject_ids,
+                    action_prob_func,
+                    action_prob_func_args_beta_index,
+                    alg_update_func,
+                    alg_update_func_type,
+                    alg_update_func_args_beta_index,
+                    alg_update_func_args_action_prob_index,
+                    alg_update_func_args_action_prob_times_index,
+                    alg_update_func_args_previous_betas_index,
+                    inference_func,
+                    inference_func_type,
+                    inference_func_args_theta_index,
+                    inference_func_args_action_prob_index,
+                    action_prob_func_args,
+                    policy_num_by_decision_time_by_subject_id,
+                    initial_policy_num,
+                    beta_index_by_policy_num,
+                    inference_func_args_by_subject_id,
+                    inference_action_prob_decision_times_by_subject_id,
+                    alg_update_func_args,
+                    action_by_decision_time_by_subject_id,
+                    True,  # suppress_all_data_checks
+                    True,  # suppress_interactive_data_checks
+                    False,  # include_auxiliary_outputs
+                ),
+                dtype=jnp.float64,
+            )
+
+        # Evaluate at the final estimate.
+        eta_hat = jnp.asarray(
+            flatten_params(all_post_update_betas, theta_est), dtype=jnp.float64
+        )
+
+        # Draw perturbations delta_j on the O(1/sqrt(n)) scale, aligned with the empirical
+        # joint estimating function stack covariance, without forming a d_joint x d_joint matrix
+        # square-root. If G is the (n x d) matrix of per-subject stacks, then (1/n) G^T G is the
+        # empirical covariance in joint estimating function stack space. Sampling u = (G^T w)/sqrt(n) with w~N(0, I_n) gives
+        # u ~ N(0, empirical joint estimating function stack covariance G^T G/n ) in joint estimating function stack space.
+        key = jax.random.PRNGKey(0)
+
+        # The number of perturbations we will probe
+        J = 15
+        # Each requires num_subjects standard normal draws, which we will then transform
+        # into joint estimating function space perturbations in U
+        W = jax.random.normal(key, shape=(J, num_subjects), dtype=jnp.float64)
+
+        # Joint estimating function space perturbations: u_j in R^{d_joint}
+        # U = (1/sqrt(n)) * W G, where rows of G are g_i^T
+        U = (W @ stacks_float64) / jnp.sqrt(num_subjects)
+
+        # Parameter perturbations: delta = (c/sqrt(n)) * B^{-1} u
+        # Use solve rather than explicit inverse.
+        c = 1.0
+        delta = (c / jnp.sqrt(num_subjects)) * jnp.linalg.solve(
+            joint_bread_float64, U.T
+        ).T
+
+        # Compute ratios r_j.
+        # NOTE: We use the Euclidean norm in score space; this is dimensionless and avoids
+        # forming/pseudoinverting a potentially rank-deficient matrix.
+        B_delta = (joint_bread_float64 @ delta.T).T
+        g_plus = jax.vmap(lambda d: _eval_avg_stack_jit(eta_hat + d))(delta)
+        remainder = g_plus - g_hat - B_delta
+
+        denom = jnp.linalg.norm(B_delta, axis=1)
+        numer = jnp.linalg.norm(remainder, axis=1)
+
+        # Avoid division by zero (should not happen unless delta collapses numerically).
+        ratios = jnp.where(denom > 0, numer / denom, jnp.inf)
+
+        local_error_ratio_median = float(jnp.median(ratios))
+        local_error_ratio_p90 = float(jnp.quantile(ratios, 0.9))
+        local_error_ratio_max = float(jnp.max(ratios))
+
+        logger.info(
+            "Local linearization error ratio (median over %d draws): %.6f",
+            J,
+            local_error_ratio_median,
+        )
+        logger.info(
+            "Local linearization error ratio (90th pct over %d draws): %.6f",
+            J,
+            local_error_ratio_p90,
+        )
+
+        logger.info(
+            "Local linearization error ratio (max over %d draws): %.6f",
+            J,
+            local_error_ratio_max,
+        )
+
+        return local_error_ratio_median, local_error_ratio_p90, local_error_ratio_max
+
+    try:
+        local_error_ratio_median, local_error_ratio_p90, local_error_ratio_max = (
+            _compute_local_linearization_error_ratio()
+        )
+    except Exception as e:
+        # This diagnostic is best-effort; failure should not break analysis.
+        logger.warning(
+            "Failed to compute local linearization error ratio diagnostic: %s",
+            str(e),
+        )
+        local_error_ratio_median = math.nan
+        local_error_ratio_p90 = math.nan
+        local_error_ratio_max = math.nan
+
     debug_pieces_dict = {
         "theta_est": theta_est,
         "adjusted_sandwich_var_estimate": adjusted_sandwich_var_estimate,
         "classical_sandwich_var_estimate": classical_sandwich_var_estimate,
-        "raw_joint_bread_matrix":
-        "stabilized_joint_bread_matrix":
+        "raw_joint_bread_matrix": raw_joint_bread_matrix,
+        "stabilized_joint_bread_matrix": stabilized_joint_bread_matrix,
         "joint_meat_matrix": joint_adjusted_meat_matrix,
         "classical_bread_matrix": classical_bread_matrix,
         "classical_meat_matrix": classical_meat_matrix,
         "all_estimating_function_stacks": per_subject_estimating_function_stacks,
-        "joint_bread_condition_number":
-        "
+        "joint_bread_condition_number": joint_bread_cond,
+        "max_eigenvalue_joint_sandwich": max_eig_joint,
+        "all_eigenvalues_joint_sandwich": eigvals_joint_sandwich,
+        "max_to_median_ratio_joint_sandwich": max_to_median_ratio_joint_sandwich,
+        "max_eigenvalue_theta_only_adjusted_sandwich": max_eig_theta,
+        "all_eigenvalues_theta_only_adjusted_sandwich": eigvals_theta_only_adjusted_sandwich,
+        "max_to_median_ratio_theta_only_adjusted_sandwich": max_to_median_ratio_theta_only_adjusted_sandwich,
+        "local_linearization_error_ratio_median": local_error_ratio_median,
+        "local_linearization_error_ratio_p90": local_error_ratio_p90,
+        "local_linearization_error_ratio_max": local_error_ratio_max,
         "all_post_update_betas": all_post_update_betas,
         "per_subject_adjusted_corrections": per_subject_adjusted_corrections,
         "per_subject_classical_corrections": per_subject_classical_corrections,
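The comments in the hunk above define the new single-run diagnostic r(delta) = ||g(eta + delta) - g(eta) - B delta|| / ||B delta||, probed at covariance-aligned perturbations u = (G^T w)/sqrt(n) mapped through B^{-1}. The following is a self-contained toy sketch of that recipe, using a made-up quadratic map g and stand-in matrices rather than the package's estimating-function machinery:

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def g(eta):
    # Toy stand-in for the averaged estimating-function map: linear plus a mild quadratic term.
    return eta + 0.05 * eta**2

eta_hat = jnp.array([0.3, -0.2, 0.1])
B = jax.jacobian(g)(eta_hat)                         # stand-in for the stabilized joint bread

n, J = 50, 15                                        # subjects and number of probe draws
key_G, key_W = jax.random.split(jax.random.PRNGKey(0))
G = jax.random.normal(key_G, (n, eta_hat.shape[0]))  # stand-in per-subject stacks
W = jax.random.normal(key_W, (J, n))

U = (W @ G) / jnp.sqrt(n)                            # u_j ~ N(0, G^T G / n)
delta = jnp.linalg.solve(B, U.T).T / jnp.sqrt(n)     # delta_j = B^{-1} u_j / sqrt(n)

B_delta = delta @ B.T                                # rows are B @ delta_j
remainder = jax.vmap(lambda d: g(eta_hat + d) - g(eta_hat))(delta) - B_delta
ratios = jnp.linalg.norm(remainder, axis=1) / jnp.linalg.norm(B_delta, axis=1)

print(float(jnp.median(ratios)), float(jnp.quantile(ratios, 0.9)), float(ratios.max()))

Small ratios across the draws indicate the first-order (linearized) expansion is locally accurate at the estimation scale; large ratios flag that the sandwich-based variance may be unreliable.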
@@ -610,8 +808,8 @@ def analyze_dataset(
 
     if collect_data_for_blowup_supervised_learning:
         datum_and_label_dict = get_datum_for_blowup_supervised_learning.get_datum_for_blowup_supervised_learning(
-
-
+            raw_joint_bread_matrix,
+            joint_bread_cond,
             avg_estimating_function_stack,
             per_subject_estimating_function_stacks,
             all_post_update_betas,
@@ -756,12 +954,16 @@ def single_subject_weighted_estimating_function_stacker(
     policy_num_by_decision_time: dict[collections.abc.Hashable, dict[int, int | float]],
     action_by_decision_time: dict[collections.abc.Hashable, dict[int, int]],
     beta_index_by_policy_num: dict[int | float, int],
-
-
-
-
-
-]
+    include_auxiliary_outputs: bool = True,
+) -> (
+    tuple[
+        jnp.ndarray[jnp.float32],
+        jnp.ndarray[jnp.float32],
+        jnp.ndarray[jnp.float32],
+        jnp.ndarray[jnp.float32],
+    ]
+    | jnp.ndarray[jnp.float32]
+):
     """
     Computes a weighted estimating function stack for a given algorithm estimating function
     and arguments, inference estimating functio and arguments, and action probability function and
@@ -825,12 +1027,23 @@ def single_subject_weighted_estimating_function_stacker(
             A dictionary mapping policy numbers to the index of the corresponding beta in
             all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
 
+        include_auxiliary_outputs (bool):
+            If True, returns the adjusted meat, classical meat, and classical bread contributions in
+            a second returned tuple. If False, only returns the weighted estimating function stack.
+
     Returns:
         jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted estimating function
             stack.
         jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's adjusted meat contribution.
         jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical meat contribution.
         jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical bread contribution.
+
+        or
+
+        jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted estimating function
+            stack.
+
+        depending on the value of include_auxiliary_outputs.
     """
 
     logger.info(
@@ -1020,14 +1233,18 @@ def single_subject_weighted_estimating_function_stacker(
     # c. The third output is averaged across subjects to obtain the classical meat matrix.
     # d. The fourth output is averaged across subjects to obtain the inverse classical bread
     # matrix.
-
-
-
-
-
-
-
-
+    if include_auxiliary_outputs:
+        return (
+            weighted_stack,
+            jnp.outer(weighted_stack, weighted_stack),
+            jnp.outer(inference_component, inference_component),
+            jax.jacrev(
+                inference_estimating_func, argnums=inference_func_args_theta_index
+            )(*threaded_inference_func_args),
+        )
+
+    else:
+        return weighted_stack
 
 
 def get_avg_weighted_estimating_function_stacks_and_aux_values(
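The auxiliary branch above uses jax.jacrev with argnums to differentiate the inference estimating function with respect to its theta argument only. A tiny illustration of that call pattern, with a made-up two-argument function standing in for the package's inference_estimating_func:

import jax
import jax.numpy as jnp

def toy_estimating_func(x, theta):
    # Made-up stand-in: elementwise, so the Jacobian w.r.t. theta is diagonal.
    return jnp.tanh(x * theta)

x = jnp.array([1.0, 2.0])
theta = jnp.array([0.5, -0.3])

# argnums=1 selects theta; the result has shape (output_dim, theta_dim) == (2, 2).
jac_wrt_theta = jax.jacrev(toy_estimating_func, argnums=1)(x, theta)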
@@ -1067,6 +1284,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
     ],
     suppress_all_data_checks: bool,
     suppress_interactive_data_checks: bool,
+    include_auxiliary_outputs: bool = True,
 ) -> tuple[
     jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
 ]:
@@ -1141,10 +1359,14 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
             If True, suppresses interactive data checks that would otherwise be performed to ensure
             the correctness of the threaded arguments. The checks are still performed, but
             any interactive prompts are suppressed.
+        include_auxiliary_outputs (bool):
+            If True, returns the adjusted meat, classical meat, and classical bread contributions in addition to the average weighted estimating function stack.
+            If False, returns only the average weighted estimating function stack.
 
     Returns:
         jnp.ndarray:
             A 2D JAX NumPy array holding the average weighted estimating function stack.
+
         tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
             A tuple containing
             1. the average weighted estimating function stack
@@ -1153,6 +1375,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
             4. the subject-level inverse classical bread matrix contributions
             5. raw per-subject weighted estimating function
             stacks.
+        or jnp.ndarray:
+            A 1-D JAX NumPy array representing the subject's weighted estimating function
+            stack.
+        depending on the value of include_auxiliary_outputs.
     """
 
     # 1. Collect estimating functions by differentiating the loss functions if needed.
@@ -1275,6 +1501,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
     ]
 
     stacks = jnp.array([result[0] for result in results])
+
+    if not include_auxiliary_outputs:
+        return jnp.mean(stacks, axis=0)
+
     outer_products = jnp.array([result[1] for result in results])
     inference_only_outer_products = jnp.array([result[2] for result in results])
     inference_hessians = jnp.array([result[3] for result in results])
@@ -1475,7 +1705,7 @@ def construct_classical_and_adjusted_sandwiches(
     theta_dim = theta_est.shape[0]
     beta_dim = all_post_update_betas.shape[1]
     # Note that these "contributions" are per-subject Jacobians of the weighted estimating function stack.
-
+    raw_joint_bread_matrix, (
         avg_estimating_function_stack,
         per_subject_joint_adjusted_meat_contributions,
         per_subject_classical_meat_contributions,
@@ -1533,21 +1763,21 @@ def construct_classical_and_adjusted_sandwiches(
 
     # Increase diagonal block dominance possibly improve conditioning of diagonal
     # blocks as necessary, to ensure mathematical stability of joint bread
-
+    stabilized_joint_bread_matrix = (
         (
             stabilize_joint_bread_if_necessary(
-
+                raw_joint_bread_matrix,
                 beta_dim,
                 theta_dim,
             )
         )
         if stabilize_joint_bread
-        else
+        else raw_joint_bread_matrix
     )
 
     # Now stably (no explicit inversion) form our sandwiches.
-
-
+    joint_sandwich = form_sandwich_from_bread_and_meat(
+        stabilized_joint_bread_matrix,
         joint_adjusted_meat_matrix,
         num_subjects,
         method=SandwichFormationMethods.BREAD_T_QR,
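The context line "Now stably (no explicit inversion) form our sandwiches" refers to the usual sandwich estimator B^{-1} M B^{-T} / n, computed via factorizations and linear solves rather than inv(B). A rough NumPy sketch of the arithmetic (toy sizes and names, not the package's BREAD_T_QR implementation):

import numpy as np

rng = np.random.default_rng(0)
n, d = 40, 4
B = rng.normal(size=(d, d)) + d * np.eye(d)       # stand-in, well-conditioned bread
M = rng.normal(size=(d, d))
M = M @ M.T                                       # stand-in positive semidefinite meat

Binv_M = np.linalg.solve(B, M)                    # B^{-1} M without forming inv(B)
sandwich = np.linalg.solve(B, Binv_M.T).T / n     # (B^{-1} M) B^{-T} / n

# Same value obtained with explicit inverses, shown only for comparison.
assert np.allclose(sandwich, np.linalg.inv(B) @ M @ np.linalg.inv(B).T / n)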
@@ -1568,7 +1798,7 @@ def construct_classical_and_adjusted_sandwiches(
         form_adjusted_meat_adjustments_directly(
             theta_dim,
             all_post_update_betas.shape[1],
-
+            stabilized_joint_bread_matrix,
             per_subject_estimating_function_stacks,
             analysis_df,
             active_col_name,
@@ -1610,7 +1840,7 @@ def construct_classical_and_adjusted_sandwiches(
             method=SandwichFormationMethods.BREAD_T_QR,
         )
     )
-    theta_only_adjusted_sandwich =
+    theta_only_adjusted_sandwich = joint_sandwich[-theta_dim:, -theta_dim:]
 
     if not np.allclose(
         theta_only_adjusted_sandwich,
@@ -1624,10 +1854,10 @@ def construct_classical_and_adjusted_sandwiches(
     # Stack the joint bread pieces together horizontally and return the auxiliary
     # values too. The joint bread should always be block lower triangular.
     return (
-
-
+        raw_joint_bread_matrix,
+        stabilized_joint_bread_matrix,
         joint_adjusted_meat_matrix,
-
+        joint_sandwich,
         classical_bread_matrix,
         classical_meat_matrix,
         classical_sandwich,
@@ -1643,7 +1873,7 @@ def construct_classical_and_adjusted_sandwiches(
 # important for the subject to know if this is happening. Even if enabled, it is important
 # that the subject know it actually kicks in.
 def stabilize_joint_bread_if_necessary(
-
+    joint_bread_matrix: jnp.ndarray,
     beta_dim: int,
     theta_dim: int,
 ) -> jnp.ndarray:
@@ -1652,7 +1882,7 @@ def stabilize_joint_bread_if_necessary(
     dominance and/or adding a small ridge penalty to the diagonal blocks.
 
     Args:
-
+        joint_bread_matrix (jnp.ndarray):
             A 2-D JAX NumPy array representing the joint bread matrix.
         beta_dim (int):
             The dimension of each beta parameter.
@@ -1673,7 +1903,7 @@ def stabilize_joint_bread_if_necessary(
 
     # Grab just the RL block and convert numpy array for easier manipulation.
     RL_stack_beta_derivatives_block = np.array(
-
+        joint_bread_matrix[:-theta_dim, :-theta_dim]
     )
     num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
     for i in range(1, num_updates + 1):
@@ -1792,11 +2022,11 @@ def stabilize_joint_bread_if_necessary(
         [
             [
                 RL_stack_beta_derivatives_block,
-
+                joint_bread_matrix[:-theta_dim, -theta_dim:],
             ],
             [
-
-
+                joint_bread_matrix[-theta_dim:, :-theta_dim],
+                joint_bread_matrix[-theta_dim:, -theta_dim:],
             ],
         ]
     )
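For context on the reassembly above: after the RL (beta) block is stabilized, the four sub-blocks of the joint bread are stitched back into a single matrix, with theta occupying the last theta_dim rows and columns. A toy round-trip with hypothetical dimensions, using np.block as a stand-in for whatever block-assembly call the surrounding code makes:

import numpy as np

theta_dim = 2
joint_bread_matrix = np.arange(36, dtype=float).reshape(6, 6)  # stand-in joint bread

rebuilt = np.block(
    [
        [joint_bread_matrix[:-theta_dim, :-theta_dim], joint_bread_matrix[:-theta_dim, -theta_dim:]],
        [joint_bread_matrix[-theta_dim:, :-theta_dim], joint_bread_matrix[-theta_dim:, -theta_dim:]],
    ]
)
assert np.array_equal(rebuilt, joint_bread_matrix)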
{lifejacket-1.0.2.dist-info → lifejacket-1.1.0.dist-info}/RECORD

@@ -7,11 +7,11 @@ lifejacket/form_adjusted_meat_adjustments_directly.py,sha256=AVlGOuw_FgVDcVnhQs1
 lifejacket/get_datum_for_blowup_supervised_learning.py,sha256=sCH-PlrFlLJgCYpTmdeasiHwHYSEy9wxspkOTDuDPuY,58594
 lifejacket/helper_functions.py,sha256=SdAbUwXNx-3JFsyTfLyliQ7kUOm0eABaiNgoYLR8NG0,16967
 lifejacket/input_checks.py,sha256=q7HFZq5n18edQU8X5laONsBgWSMidLRy6Nhqdw5FpOw,47084
-lifejacket/post_deployment_analysis.py,sha256=
+lifejacket/post_deployment_analysis.py,sha256=M6-qQ1vynJ7-iHi-vgTdS71Vp-27gioB1b6qt0sCdek,92470
 lifejacket/small_sample_corrections.py,sha256=aB6qi-r3ANoBMgf2Oo5-lCXCy_L4H3FlBffGwPcfXkg,5610
 lifejacket/vmap_helpers.py,sha256=pZqYN3p9Ty9DPOeeY9TKbRJXR2AV__HBwwDFOvdOQ84,2688
-lifejacket-1.0.
-lifejacket-1.0.
-lifejacket-1.0.
-lifejacket-1.0.
-lifejacket-1.0.
+lifejacket-1.1.0.dist-info/METADATA,sha256=RPJwaJyvfQGCKxgNvg7tx-QPu_Wa89pIPv34fFzqjzA,1773
+lifejacket-1.1.0.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
+lifejacket-1.1.0.dist-info/entry_points.txt,sha256=CZ9AUPN0xfnpYqwtGTr6n9l5mpyEOddsXX8fnxKRB6U,71
+lifejacket-1.1.0.dist-info/top_level.txt,sha256=vKl8m7jOQ4pkbzVuHCJsq-8LcXRrOAWnok3bBo9qpsE,11
+lifejacket-1.1.0.dist-info/RECORD,,
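As background on the RECORD entries updated above (path, hash, size): wheel RECORD hashes are conventionally the urlsafe base64 of the file's SHA-256 digest with the trailing "=" padding stripped. A small sketch of producing one such line (illustrative only, not how this wheel was built):

import base64
import hashlib
from pathlib import Path

def record_entry(path: str) -> str:
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
    return f"{path},sha256={digest.decode()},{len(data)}"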
{lifejacket-1.0.2.dist-info → lifejacket-1.1.0.dist-info}/entry_points.txt: File without changes

{lifejacket-1.0.2.dist-info → lifejacket-1.1.0.dist-info}/top_level.txt: File without changes