lifejacket 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/calculate_derivatives.py +0 -2
- lifejacket/constants.py +4 -16
- lifejacket/deployment_conditioning_monitor.py +19 -12
- lifejacket/form_adjusted_meat_adjustments_directly.py +25 -27
- lifejacket/get_datum_for_blowup_supervised_learning.py +71 -77
- lifejacket/helper_functions.py +15 -148
- lifejacket/input_checks.py +49 -50
- lifejacket/{after_study_analysis.py → post_deployment_analysis.py} +377 -144
- lifejacket/small_sample_corrections.py +11 -13
- {lifejacket-1.0.0.dist-info → lifejacket-1.1.0.dist-info}/METADATA +1 -1
- lifejacket-1.1.0.dist-info/RECORD +17 -0
- {lifejacket-1.0.0.dist-info → lifejacket-1.1.0.dist-info}/WHEEL +1 -1
- lifejacket-1.1.0.dist-info/entry_points.txt +2 -0
- lifejacket-1.0.0.dist-info/RECORD +0 -17
- lifejacket-1.0.0.dist-info/entry_points.txt +0 -2
- {lifejacket-1.0.0.dist-info → lifejacket-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -53,6 +53,8 @@ logging.basicConfig(
|
|
|
53
53
|
level=logging.INFO,
|
|
54
54
|
)
|
|
55
55
|
|
|
56
|
+
jax.config.update("jax_enable_x64", True)
|
|
57
|
+
|
|
56
58
|
|
|
57
59
|
@click.group()
|
|
58
60
|
def cli():
|
|
@@ -217,9 +219,9 @@ def cli():
|
|
|
217
219
|
type=click.Choice(
|
|
218
220
|
[
|
|
219
221
|
SmallSampleCorrections.NONE,
|
|
220
|
-
SmallSampleCorrections.
|
|
221
|
-
SmallSampleCorrections.
|
|
222
|
-
SmallSampleCorrections.
|
|
222
|
+
SmallSampleCorrections.Z1theta,
|
|
223
|
+
SmallSampleCorrections.Z2theta,
|
|
224
|
+
SmallSampleCorrections.Z3theta,
|
|
223
225
|
]
|
|
224
226
|
),
|
|
225
227
|
default=SmallSampleCorrections.NONE,
|
|
@@ -235,13 +237,13 @@ def cli():
|
|
|
235
237
|
"--form_adjusted_meat_adjustments_explicitly",
|
|
236
238
|
type=bool,
|
|
237
239
|
default=False,
|
|
238
|
-
help="If True, explicitly forms the per-subject meat adjustments that differentiate the
|
|
240
|
+
help="If True, explicitly forms the per-subject meat adjustments that differentiate the adjusted sandwich from the classical sandwich. This is for diagnostic purposes, as the adjusted sandwich is formed without doing this.",
|
|
239
241
|
)
|
|
240
242
|
@click.option(
|
|
241
|
-
"--
|
|
243
|
+
"--stabilize_joint_bread",
|
|
242
244
|
type=bool,
|
|
243
245
|
default=True,
|
|
244
|
-
help="If True, stabilizes the joint
|
|
246
|
+
help="If True, stabilizes the joint bread matrix if it does not meet conditioning thresholds.",
|
|
245
247
|
)
|
|
246
248
|
def analyze_dataset_wrapper(**kwargs):
|
|
247
249
|
"""
|
|
@@ -324,15 +326,15 @@ def analyze_dataset(
|
|
|
324
326
|
small_sample_correction: str,
|
|
325
327
|
collect_data_for_blowup_supervised_learning: bool,
|
|
326
328
|
form_adjusted_meat_adjustments_explicitly: bool,
|
|
327
|
-
|
|
329
|
+
stabilize_joint_bread: bool,
|
|
328
330
|
) -> None:
|
|
329
331
|
"""
|
|
330
|
-
Analyzes a dataset to provide a parameter estimate and an estimate of its variance using
|
|
332
|
+
Analyzes a dataset to provide a parameter estimate and an estimate of its variance using and classical sandwich estimators.
|
|
331
333
|
|
|
332
334
|
There are two modes of use for this function.
|
|
333
335
|
|
|
334
336
|
First, it may be called indirectly from the command line by passing through
|
|
335
|
-
|
|
337
|
+
analyze_dataset_wrapper.
|
|
336
338
|
|
|
337
339
|
Second, it may be called directly from Python code with in-memory objects.
|
|
338
340
|
|
|
@@ -388,17 +390,17 @@ def analyze_dataset(
|
|
|
388
390
|
small_sample_correction (str):
|
|
389
391
|
Type of small sample correction to apply.
|
|
390
392
|
collect_data_for_blowup_supervised_learning (bool):
|
|
391
|
-
Whether to collect data for doing supervised learning about
|
|
393
|
+
Whether to collect data for doing supervised learning about adjusted sandwich blowup.
|
|
392
394
|
form_adjusted_meat_adjustments_explicitly (bool):
|
|
393
|
-
If True, explicitly forms the per-subject meat adjustments that differentiate the
|
|
395
|
+
If True, explicitly forms the per-subject meat adjustments that differentiate the
|
|
394
396
|
sandwich from the classical sandwich. This is for diagnostic purposes, as the
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
If True, stabilizes the joint
|
|
397
|
+
adjusted sandwich is formed without doing this.
|
|
398
|
+
stabilize_joint_bread (bool):
|
|
399
|
+
If True, stabilizes the joint bread matrix if it does not meet conditioning
|
|
398
400
|
thresholds.
|
|
399
401
|
|
|
400
402
|
Returns:
|
|
401
|
-
dict: A dictionary containing the theta estimate,
|
|
403
|
+
dict: A dictionary containing the theta estimate, adjusted sandwich variance estimate, and
|
|
402
404
|
classical sandwich variance estimate.
|
|
403
405
|
"""
|
|
404
406
|
|
|
@@ -438,7 +440,6 @@ def analyze_dataset(
|
|
|
438
440
|
)
|
|
439
441
|
|
|
440
442
|
### Begin collecting data structures that will be used to compute the joint bread matrix.
|
|
441
|
-
|
|
442
443
|
beta_index_by_policy_num, initial_policy_num = (
|
|
443
444
|
construct_beta_index_by_policy_num_map(
|
|
444
445
|
analysis_df, policy_num_col_name, active_col_name
|
|
@@ -475,20 +476,20 @@ def analyze_dataset(
|
|
|
475
476
|
active_col_name,
|
|
476
477
|
)
|
|
477
478
|
|
|
478
|
-
# Use a per-subject weighted estimating function stacking
|
|
479
|
-
#
|
|
479
|
+
# Use a per-subject weighted estimating function stacking function to derive classical and joint
|
|
480
|
+
# meat and bread matrices. This is facilitated because the *value* of the
|
|
480
481
|
# weighted and unweighted stacks are the same, as the weights evaluate to 1 pre-differentiation.
|
|
481
482
|
logger.info(
|
|
482
|
-
"Constructing joint
|
|
483
|
+
"Constructing joint bread matrix, joint meat matrix, the classical analogs, and the avg estimating function stack across subjects."
|
|
483
484
|
)
|
|
484
485
|
|
|
485
486
|
subject_ids = jnp.array(analysis_df[subject_id_col_name].unique())
|
|
486
487
|
(
|
|
487
|
-
|
|
488
|
-
|
|
488
|
+
stabilized_joint_bread_matrix,
|
|
489
|
+
raw_joint_bread_matrix,
|
|
489
490
|
joint_adjusted_meat_matrix,
|
|
490
|
-
|
|
491
|
-
|
|
491
|
+
joint_sandwich_matrix,
|
|
492
|
+
classical_bread_matrix,
|
|
492
493
|
classical_meat_matrix,
|
|
493
494
|
classical_sandwich_var_estimate,
|
|
494
495
|
avg_estimating_function_stack,
|
|
@@ -524,7 +525,7 @@ def analyze_dataset(
|
|
|
524
525
|
suppress_interactive_data_checks,
|
|
525
526
|
small_sample_correction,
|
|
526
527
|
form_adjusted_meat_adjustments_explicitly,
|
|
527
|
-
|
|
528
|
+
stabilize_joint_bread,
|
|
528
529
|
analysis_df,
|
|
529
530
|
active_col_name,
|
|
530
531
|
action_col_name,
|
|
@@ -545,22 +546,19 @@ def analyze_dataset(
|
|
|
545
546
|
|
|
546
547
|
# This bottom right corner of the joint (betas and theta) variance matrix is the portion
|
|
547
548
|
# corresponding to just theta.
|
|
548
|
-
adjusted_sandwich_var_estimate =
|
|
549
|
-
-theta_dim:, -theta_dim:
|
|
550
|
-
]
|
|
549
|
+
adjusted_sandwich_var_estimate = joint_sandwich_matrix[-theta_dim:, -theta_dim:]
|
|
551
550
|
|
|
552
551
|
# Check for negative diagonal elements and set them to zero if found
|
|
553
|
-
|
|
554
|
-
if np.any(
|
|
552
|
+
adjusted_diagonal = np.diag(adjusted_sandwich_var_estimate)
|
|
553
|
+
if np.any(adjusted_diagonal < 0):
|
|
555
554
|
logger.warning(
|
|
556
|
-
"Found negative diagonal elements in
|
|
555
|
+
"Found negative diagonal elements in adjusted sandwich variance estimate. Setting them to zero."
|
|
557
556
|
)
|
|
558
557
|
np.fill_diagonal(
|
|
559
|
-
adjusted_sandwich_var_estimate, np.maximum(
|
|
558
|
+
adjusted_sandwich_var_estimate, np.maximum(adjusted_diagonal, 0)
|
|
560
559
|
)
|
|
561
560
|
|
|
562
561
|
logger.info("Writing results to file...")
|
|
563
|
-
# Write analysis results to same directory that input files are in
|
|
564
562
|
output_folder_abs_path = pathlib.Path(output_dir).resolve()
|
|
565
563
|
|
|
566
564
|
analysis_dict = {
|
|
@@ -574,25 +572,229 @@ def analyze_dataset(
|
|
|
574
572
|
f,
|
|
575
573
|
)
|
|
576
574
|
|
|
577
|
-
|
|
578
|
-
|
|
575
|
+
joint_bread_cond = jnp.linalg.cond(raw_joint_bread_matrix)
|
|
576
|
+
logger.info(
|
|
577
|
+
"Joint bread condition number: %f",
|
|
578
|
+
joint_bread_cond,
|
|
579
|
+
)
|
|
580
|
+
|
|
581
|
+
# calculate the max eigenvalue of the theta-only adjusted sandwich
|
|
582
|
+
max_eigenvalue_theta_only_adjusted_sandwich = scipy.linalg.eigvalsh(
|
|
583
|
+
adjusted_sandwich_var_estimate
|
|
584
|
+
).max()
|
|
585
|
+
logger.info(
|
|
586
|
+
"Max eigenvalue of theta-only adjusted sandwich matrix: %f",
|
|
587
|
+
max_eigenvalue_theta_only_adjusted_sandwich,
|
|
588
|
+
)
|
|
589
|
+
|
|
590
|
+
# Compute ratios: max eigenvalue / median eigenvalue among those >= 1e-8 * max.
|
|
591
|
+
eigvals_joint_sandwich = scipy.linalg.eigvalsh(joint_sandwich_matrix)
|
|
592
|
+
max_eig_joint = float(eigvals_joint_sandwich.max())
|
|
593
|
+
logger.info(
|
|
594
|
+
"Max eigenvalue of joint adjusted sandwich matrix: %f",
|
|
595
|
+
max_eig_joint,
|
|
596
|
+
)
|
|
597
|
+
|
|
598
|
+
joint_keep = eigvals_joint_sandwich >= (1e-8 * max_eig_joint)
|
|
599
|
+
joint_median_kept = (
|
|
600
|
+
float(np.median(eigvals_joint_sandwich[joint_keep]))
|
|
601
|
+
if np.any(joint_keep)
|
|
602
|
+
else math.nan
|
|
603
|
+
)
|
|
604
|
+
max_to_median_ratio_joint_sandwich = (
|
|
605
|
+
(max_eig_joint / joint_median_kept)
|
|
606
|
+
if (not math.isnan(joint_median_kept) and joint_median_kept > 0)
|
|
607
|
+
else (
|
|
608
|
+
math.inf
|
|
609
|
+
if (not math.isnan(joint_median_kept) and joint_median_kept == 0)
|
|
610
|
+
else math.nan
|
|
611
|
+
)
|
|
579
612
|
)
|
|
580
613
|
logger.info(
|
|
581
|
-
"
|
|
582
|
-
|
|
614
|
+
"Max/median eigenvalue ratio (joint sandwich; median over eigvals >= 1e-8*max): %f",
|
|
615
|
+
max_to_median_ratio_joint_sandwich,
|
|
616
|
+
)
|
|
617
|
+
|
|
618
|
+
eigvals_theta_only_adjusted_sandwich = scipy.linalg.eigvalsh(
|
|
619
|
+
adjusted_sandwich_var_estimate
|
|
583
620
|
)
|
|
621
|
+
max_eig_theta = float(eigvals_theta_only_adjusted_sandwich.max())
|
|
622
|
+
theta_keep = eigvals_theta_only_adjusted_sandwich >= (1e-8 * max_eig_theta)
|
|
623
|
+
theta_median_kept = (
|
|
624
|
+
float(np.median(eigvals_theta_only_adjusted_sandwich[theta_keep]))
|
|
625
|
+
if np.any(theta_keep)
|
|
626
|
+
else math.nan
|
|
627
|
+
)
|
|
628
|
+
max_to_median_ratio_theta_only_adjusted_sandwich = (
|
|
629
|
+
(max_eig_theta / theta_median_kept)
|
|
630
|
+
if (not math.isnan(theta_median_kept) and theta_median_kept > 0)
|
|
631
|
+
else (
|
|
632
|
+
math.inf
|
|
633
|
+
if (not math.isnan(theta_median_kept) and theta_median_kept == 0)
|
|
634
|
+
else math.nan
|
|
635
|
+
)
|
|
636
|
+
)
|
|
637
|
+
logger.info(
|
|
638
|
+
"Max/median eigenvalue ratio (theta-only adjusted sandwich; median over eigvals >= 1e-8*max): %f",
|
|
639
|
+
max_to_median_ratio_theta_only_adjusted_sandwich,
|
|
640
|
+
)
|
|
641
|
+
|
|
642
|
+
# --- Local linearization validity diagnostic (single-run) ---
|
|
643
|
+
# We compare the nonlinear Taylor remainder of the joint estimating-function map to the
|
|
644
|
+
# retained linear term, at perturbations on the O(1/sqrt(n)) scale.
|
|
645
|
+
#
|
|
646
|
+
# Define r(delta) = || g(eta+delta) - g(eta) - B delta ||_2 / || B delta ||_2,
|
|
647
|
+
# where g(eta) is the avg per-subject weighted estimating-function stack and B is the
|
|
648
|
+
# stabilized joint bread (Jacobian of g w.r.t. flattened betas+theta).
|
|
649
|
+
#
|
|
650
|
+
# This ratio is dimensionless and can be used as a necessary/sanity diagnostic that the
|
|
651
|
+
# first-order linearization is locally accurate at the estimation scale.
|
|
652
|
+
|
|
653
|
+
def _compute_local_linearization_error_ratio() -> tuple[float, float]:
|
|
654
|
+
# Ensure float64 for diagnostics even if upstream ran in float32.
|
|
655
|
+
joint_bread_float64 = jnp.asarray(
|
|
656
|
+
stabilized_joint_bread_matrix, dtype=jnp.float64
|
|
657
|
+
)
|
|
658
|
+
g_hat = jnp.asarray(avg_estimating_function_stack, dtype=jnp.float64)
|
|
659
|
+
stacks_float64 = jnp.asarray(
|
|
660
|
+
per_subject_estimating_function_stacks, dtype=jnp.float64
|
|
661
|
+
)
|
|
662
|
+
|
|
663
|
+
num_subjects = stacks_float64.shape[0]
|
|
664
|
+
|
|
665
|
+
def _eval_avg_stack_jit(flattened_betas_and_theta: jnp.ndarray) -> jnp.ndarray:
|
|
666
|
+
return jnp.asarray(
|
|
667
|
+
get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
668
|
+
flattened_betas_and_theta,
|
|
669
|
+
beta_dim,
|
|
670
|
+
theta_dim,
|
|
671
|
+
subject_ids,
|
|
672
|
+
action_prob_func,
|
|
673
|
+
action_prob_func_args_beta_index,
|
|
674
|
+
alg_update_func,
|
|
675
|
+
alg_update_func_type,
|
|
676
|
+
alg_update_func_args_beta_index,
|
|
677
|
+
alg_update_func_args_action_prob_index,
|
|
678
|
+
alg_update_func_args_action_prob_times_index,
|
|
679
|
+
alg_update_func_args_previous_betas_index,
|
|
680
|
+
inference_func,
|
|
681
|
+
inference_func_type,
|
|
682
|
+
inference_func_args_theta_index,
|
|
683
|
+
inference_func_args_action_prob_index,
|
|
684
|
+
action_prob_func_args,
|
|
685
|
+
policy_num_by_decision_time_by_subject_id,
|
|
686
|
+
initial_policy_num,
|
|
687
|
+
beta_index_by_policy_num,
|
|
688
|
+
inference_func_args_by_subject_id,
|
|
689
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
690
|
+
alg_update_func_args,
|
|
691
|
+
action_by_decision_time_by_subject_id,
|
|
692
|
+
True, # suppress_all_data_checks
|
|
693
|
+
True, # suppress_interactive_data_checks
|
|
694
|
+
False, # include_auxiliary_outputs
|
|
695
|
+
),
|
|
696
|
+
dtype=jnp.float64,
|
|
697
|
+
)
|
|
698
|
+
|
|
699
|
+
# Evaluate at the final estimate.
|
|
700
|
+
eta_hat = jnp.asarray(
|
|
701
|
+
flatten_params(all_post_update_betas, theta_est), dtype=jnp.float64
|
|
702
|
+
)
|
|
703
|
+
|
|
704
|
+
# Draw perturbations delta_j on the O(1/sqrt(n)) scale, aligned with the empirical
|
|
705
|
+
# joint estimating function stack covariance, without forming a d_joint x d_joint matrix
|
|
706
|
+
# square-root. If G is the (n x d) matrix of per-subject stacks, then (1/n) G^T G is the
|
|
707
|
+
# empirical covariance in joint estimating function stack space. Sampling u = (G^T w)/sqrt(n) with w~N(0, I_n) gives
|
|
708
|
+
# u ~ N(0, empirical joint estimating function stack covariance G^T G/n ) in joint estimating function stack space.
|
|
709
|
+
key = jax.random.PRNGKey(0)
|
|
710
|
+
|
|
711
|
+
# The number of perturbations we will probe
|
|
712
|
+
J = 15
|
|
713
|
+
# Each requires num_subjects standard normal draws, which we will then transform
|
|
714
|
+
# into joint estimating function space perturbations in U
|
|
715
|
+
W = jax.random.normal(key, shape=(J, num_subjects), dtype=jnp.float64)
|
|
716
|
+
|
|
717
|
+
# Joint estimating function space perturbations: u_j in R^{d_joint}
|
|
718
|
+
# U = (1/sqrt(n)) * W G, where rows of G are g_i^T
|
|
719
|
+
U = (W @ stacks_float64) / jnp.sqrt(num_subjects)
|
|
720
|
+
|
|
721
|
+
# Parameter perturbations: delta = (c/sqrt(n)) * B^{-1} u
|
|
722
|
+
# Use solve rather than explicit inverse.
|
|
723
|
+
c = 1.0
|
|
724
|
+
delta = (c / jnp.sqrt(num_subjects)) * jnp.linalg.solve(
|
|
725
|
+
joint_bread_float64, U.T
|
|
726
|
+
).T
|
|
727
|
+
|
|
728
|
+
# Compute ratios r_j.
|
|
729
|
+
# NOTE: We use the Euclidean norm in score space; this is dimensionless and avoids
|
|
730
|
+
# forming/pseudoinverting a potentially rank-deficient matrix.
|
|
731
|
+
B_delta = (joint_bread_float64 @ delta.T).T
|
|
732
|
+
g_plus = jax.vmap(lambda d: _eval_avg_stack_jit(eta_hat + d))(delta)
|
|
733
|
+
remainder = g_plus - g_hat - B_delta
|
|
734
|
+
|
|
735
|
+
denom = jnp.linalg.norm(B_delta, axis=1)
|
|
736
|
+
numer = jnp.linalg.norm(remainder, axis=1)
|
|
737
|
+
|
|
738
|
+
# Avoid division by zero (should not happen unless delta collapses numerically).
|
|
739
|
+
ratios = jnp.where(denom > 0, numer / denom, jnp.inf)
|
|
740
|
+
|
|
741
|
+
local_error_ratio_median = float(jnp.median(ratios))
|
|
742
|
+
local_error_ratio_p90 = float(jnp.quantile(ratios, 0.9))
|
|
743
|
+
local_error_ratio_max = float(jnp.max(ratios))
|
|
744
|
+
|
|
745
|
+
logger.info(
|
|
746
|
+
"Local linearization error ratio (median over %d draws): %.6f",
|
|
747
|
+
J,
|
|
748
|
+
local_error_ratio_median,
|
|
749
|
+
)
|
|
750
|
+
logger.info(
|
|
751
|
+
"Local linearization error ratio (90th pct over %d draws): %.6f",
|
|
752
|
+
J,
|
|
753
|
+
local_error_ratio_p90,
|
|
754
|
+
)
|
|
755
|
+
|
|
756
|
+
logger.info(
|
|
757
|
+
"Local linearization error ratio (max over %d draws): %.6f",
|
|
758
|
+
J,
|
|
759
|
+
local_error_ratio_max,
|
|
760
|
+
)
|
|
761
|
+
|
|
762
|
+
return local_error_ratio_median, local_error_ratio_p90, local_error_ratio_max
|
|
763
|
+
|
|
764
|
+
try:
|
|
765
|
+
local_error_ratio_median, local_error_ratio_p90, local_error_ratio_max = (
|
|
766
|
+
_compute_local_linearization_error_ratio()
|
|
767
|
+
)
|
|
768
|
+
except Exception as e:
|
|
769
|
+
# This diagnostic is best-effort; failure should not break analysis.
|
|
770
|
+
logger.warning(
|
|
771
|
+
"Failed to compute local linearization error ratio diagnostic: %s",
|
|
772
|
+
str(e),
|
|
773
|
+
)
|
|
774
|
+
local_error_ratio_median = math.nan
|
|
775
|
+
local_error_ratio_p90 = math.nan
|
|
776
|
+
local_error_ratio_max = math.nan
|
|
584
777
|
|
|
585
778
|
debug_pieces_dict = {
|
|
586
779
|
"theta_est": theta_est,
|
|
587
780
|
"adjusted_sandwich_var_estimate": adjusted_sandwich_var_estimate,
|
|
588
781
|
"classical_sandwich_var_estimate": classical_sandwich_var_estimate,
|
|
589
|
-
"
|
|
590
|
-
"
|
|
782
|
+
"raw_joint_bread_matrix": raw_joint_bread_matrix,
|
|
783
|
+
"stabilized_joint_bread_matrix": stabilized_joint_bread_matrix,
|
|
591
784
|
"joint_meat_matrix": joint_adjusted_meat_matrix,
|
|
592
|
-
"
|
|
785
|
+
"classical_bread_matrix": classical_bread_matrix,
|
|
593
786
|
"classical_meat_matrix": classical_meat_matrix,
|
|
594
787
|
"all_estimating_function_stacks": per_subject_estimating_function_stacks,
|
|
595
|
-
"
|
|
788
|
+
"joint_bread_condition_number": joint_bread_cond,
|
|
789
|
+
"max_eigenvalue_joint_sandwich": max_eig_joint,
|
|
790
|
+
"all_eigenvalues_joint_sandwich": eigvals_joint_sandwich,
|
|
791
|
+
"max_to_median_ratio_joint_sandwich": max_to_median_ratio_joint_sandwich,
|
|
792
|
+
"max_eigenvalue_theta_only_adjusted_sandwich": max_eig_theta,
|
|
793
|
+
"all_eigenvalues_theta_only_adjusted_sandwich": eigvals_theta_only_adjusted_sandwich,
|
|
794
|
+
"max_to_median_ratio_theta_only_adjusted_sandwich": max_to_median_ratio_theta_only_adjusted_sandwich,
|
|
795
|
+
"local_linearization_error_ratio_median": local_error_ratio_median,
|
|
796
|
+
"local_linearization_error_ratio_p90": local_error_ratio_p90,
|
|
797
|
+
"local_linearization_error_ratio_max": local_error_ratio_max,
|
|
596
798
|
"all_post_update_betas": all_post_update_betas,
|
|
597
799
|
"per_subject_adjusted_corrections": per_subject_adjusted_corrections,
|
|
598
800
|
"per_subject_classical_corrections": per_subject_classical_corrections,
|
|
@@ -606,8 +808,8 @@ def analyze_dataset(
|
|
|
606
808
|
|
|
607
809
|
if collect_data_for_blowup_supervised_learning:
|
|
608
810
|
datum_and_label_dict = get_datum_for_blowup_supervised_learning.get_datum_for_blowup_supervised_learning(
|
|
609
|
-
|
|
610
|
-
|
|
811
|
+
raw_joint_bread_matrix,
|
|
812
|
+
joint_bread_cond,
|
|
611
813
|
avg_estimating_function_stack,
|
|
612
814
|
per_subject_estimating_function_stacks,
|
|
613
815
|
all_post_update_betas,
|
|
@@ -752,12 +954,16 @@ def single_subject_weighted_estimating_function_stacker(
|
|
|
752
954
|
policy_num_by_decision_time: dict[collections.abc.Hashable, dict[int, int | float]],
|
|
753
955
|
action_by_decision_time: dict[collections.abc.Hashable, dict[int, int]],
|
|
754
956
|
beta_index_by_policy_num: dict[int | float, int],
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
]
|
|
957
|
+
include_auxiliary_outputs: bool = True,
|
|
958
|
+
) -> (
|
|
959
|
+
tuple[
|
|
960
|
+
jnp.ndarray[jnp.float32],
|
|
961
|
+
jnp.ndarray[jnp.float32],
|
|
962
|
+
jnp.ndarray[jnp.float32],
|
|
963
|
+
jnp.ndarray[jnp.float32],
|
|
964
|
+
]
|
|
965
|
+
| jnp.ndarray[jnp.float32]
|
|
966
|
+
):
|
|
761
967
|
"""
|
|
762
968
|
Computes a weighted estimating function stack for a given algorithm estimating function
|
|
763
969
|
and arguments, inference estimating functio and arguments, and action probability function and
|
|
@@ -821,12 +1027,23 @@ def single_subject_weighted_estimating_function_stacker(
|
|
|
821
1027
|
A dictionary mapping policy numbers to the index of the corresponding beta in
|
|
822
1028
|
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
823
1029
|
|
|
1030
|
+
include_auxiliary_outputs (bool):
|
|
1031
|
+
If True, returns the adjusted meat, classical meat, and classical bread contributions in
|
|
1032
|
+
a second returned tuple. If False, only returns the weighted estimating function stack.
|
|
1033
|
+
|
|
824
1034
|
Returns:
|
|
825
1035
|
jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted estimating function
|
|
826
1036
|
stack.
|
|
827
|
-
jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's
|
|
1037
|
+
jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's adjusted meat contribution.
|
|
828
1038
|
jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical meat contribution.
|
|
829
1039
|
jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical bread contribution.
|
|
1040
|
+
|
|
1041
|
+
or
|
|
1042
|
+
|
|
1043
|
+
jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted estimating function
|
|
1044
|
+
stack.
|
|
1045
|
+
|
|
1046
|
+
depending on the value of include_auxiliary_outputs.
|
|
830
1047
|
"""
|
|
831
1048
|
|
|
832
1049
|
logger.info(
|
|
@@ -1008,22 +1225,26 @@ def single_subject_weighted_estimating_function_stacker(
|
|
|
1008
1225
|
|
|
1009
1226
|
# 6. Return the following outputs:
|
|
1010
1227
|
# a. The first is simply the weighted estimating function stack for this subject. The average
|
|
1011
|
-
# of these is what we differentiate with respect to theta to form the
|
|
1228
|
+
# of these is what we differentiate with respect to theta to form the joint
|
|
1012
1229
|
# bread matrix, and we also compare that average to zero to check the estimating functions'
|
|
1013
1230
|
# fidelity.
|
|
1014
|
-
# b. The average outer product of these per-subject stacks across subjects is the
|
|
1231
|
+
# b. The average outer product of these per-subject stacks across subjects is the adjusted joint meat
|
|
1015
1232
|
# matrix, hence the second output.
|
|
1016
1233
|
# c. The third output is averaged across subjects to obtain the classical meat matrix.
|
|
1017
1234
|
# d. The fourth output is averaged across subjects to obtain the inverse classical bread
|
|
1018
1235
|
# matrix.
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1236
|
+
if include_auxiliary_outputs:
|
|
1237
|
+
return (
|
|
1238
|
+
weighted_stack,
|
|
1239
|
+
jnp.outer(weighted_stack, weighted_stack),
|
|
1240
|
+
jnp.outer(inference_component, inference_component),
|
|
1241
|
+
jax.jacrev(
|
|
1242
|
+
inference_estimating_func, argnums=inference_func_args_theta_index
|
|
1243
|
+
)(*threaded_inference_func_args),
|
|
1244
|
+
)
|
|
1245
|
+
|
|
1246
|
+
else:
|
|
1247
|
+
return weighted_stack
|
|
1027
1248
|
|
|
1028
1249
|
|
|
1029
1250
|
def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
@@ -1063,12 +1284,13 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1063
1284
|
],
|
|
1064
1285
|
suppress_all_data_checks: bool,
|
|
1065
1286
|
suppress_interactive_data_checks: bool,
|
|
1287
|
+
include_auxiliary_outputs: bool = True,
|
|
1066
1288
|
) -> tuple[
|
|
1067
1289
|
jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
|
|
1068
1290
|
]:
|
|
1069
1291
|
"""
|
|
1070
1292
|
Computes the average weighted estimating function stack across all subjects, along with
|
|
1071
|
-
auxiliary values used to construct the
|
|
1293
|
+
auxiliary values used to construct the adjusted and classical sandwich variances.
|
|
1072
1294
|
|
|
1073
1295
|
Args:
|
|
1074
1296
|
flattened_betas_and_theta (jnp.ndarray):
|
|
@@ -1137,18 +1359,26 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1137
1359
|
If True, suppresses interactive data checks that would otherwise be performed to ensure
|
|
1138
1360
|
the correctness of the threaded arguments. The checks are still performed, but
|
|
1139
1361
|
any interactive prompts are suppressed.
|
|
1362
|
+
include_auxiliary_outputs (bool):
|
|
1363
|
+
If True, returns the adjusted meat, classical meat, and classical bread contributions in addition to the average weighted estimating function stack.
|
|
1364
|
+
If False, returns only the average weighted estimating function stack.
|
|
1140
1365
|
|
|
1141
1366
|
Returns:
|
|
1142
1367
|
jnp.ndarray:
|
|
1143
1368
|
A 2D JAX NumPy array holding the average weighted estimating function stack.
|
|
1369
|
+
|
|
1144
1370
|
tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
|
|
1145
1371
|
A tuple containing
|
|
1146
1372
|
1. the average weighted estimating function stack
|
|
1147
|
-
2. the subject-level
|
|
1373
|
+
2. the subject-level adjusted meat matrix contributions
|
|
1148
1374
|
3. the subject-level classical meat matrix contributions
|
|
1149
1375
|
4. the subject-level inverse classical bread matrix contributions
|
|
1150
1376
|
5. raw per-subject weighted estimating function
|
|
1151
1377
|
stacks.
|
|
1378
|
+
or jnp.ndarray:
|
|
1379
|
+
A 1-D JAX NumPy array representing the subject's weighted estimating function
|
|
1380
|
+
stack.
|
|
1381
|
+
depending on the value of include_auxiliary_outputs.
|
|
1152
1382
|
"""
|
|
1153
1383
|
|
|
1154
1384
|
# 1. Collect estimating functions by differentiating the loss functions if needed.
|
|
@@ -1248,7 +1478,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1248
1478
|
)
|
|
1249
1479
|
|
|
1250
1480
|
# 5. Now we can compute the weighted estimating function stacks for all subjects
|
|
1251
|
-
# as well as collect related values used to construct the
|
|
1481
|
+
# as well as collect related values used to construct the adjusted and classical
|
|
1252
1482
|
# sandwich variances.
|
|
1253
1483
|
results = [
|
|
1254
1484
|
single_subject_weighted_estimating_function_stacker(
|
|
@@ -1271,16 +1501,21 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1271
1501
|
]
|
|
1272
1502
|
|
|
1273
1503
|
stacks = jnp.array([result[0] for result in results])
|
|
1504
|
+
|
|
1505
|
+
if not include_auxiliary_outputs:
|
|
1506
|
+
return jnp.mean(stacks, axis=0)
|
|
1507
|
+
|
|
1274
1508
|
outer_products = jnp.array([result[1] for result in results])
|
|
1275
1509
|
inference_only_outer_products = jnp.array([result[2] for result in results])
|
|
1276
1510
|
inference_hessians = jnp.array([result[3] for result in results])
|
|
1277
1511
|
|
|
1278
1512
|
# 6. Note this strange return structure! We will differentiate the first output,
|
|
1279
1513
|
# but the second tuple will be passed along without modification via has_aux=True and then used
|
|
1280
|
-
# for the
|
|
1281
|
-
#
|
|
1514
|
+
# for the estimating functions sum check, per_subject_classical_bread_contributions, and
|
|
1515
|
+
# classical meat and inverse read matrices. The raw per-subject stacks are also returned for
|
|
1516
|
+
# debugging purposes.
|
|
1282
1517
|
|
|
1283
|
-
# Note that returning the raw stacks here as the first
|
|
1518
|
+
# Note that returning the raw stacks here as the first argument is potentially
|
|
1284
1519
|
# memory-intensive when combined with differentiation. Keep this in mind if the per-subject bread
|
|
1285
1520
|
# inverse contributions are needed for something like CR2/CR3 small-sample corrections.
|
|
1286
1521
|
return jnp.mean(stacks, axis=0), (
|
|
@@ -1330,7 +1565,7 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1330
1565
|
suppress_interactive_data_checks: bool,
|
|
1331
1566
|
small_sample_correction: str,
|
|
1332
1567
|
form_adjusted_meat_adjustments_explicitly: bool,
|
|
1333
|
-
|
|
1568
|
+
stabilize_joint_bread: bool,
|
|
1334
1569
|
analysis_df: pd.DataFrame | None,
|
|
1335
1570
|
active_col_name: str | None,
|
|
1336
1571
|
action_col_name: str | None,
|
|
@@ -1352,11 +1587,11 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1352
1587
|
jnp.ndarray[jnp.float32],
|
|
1353
1588
|
]:
|
|
1354
1589
|
"""
|
|
1355
|
-
Constructs the classical and
|
|
1590
|
+
Constructs the classical and adjusted sandwich matrices, as well as various
|
|
1356
1591
|
intermediate pieces in their consruction.
|
|
1357
1592
|
|
|
1358
1593
|
This is done by computing and differentiating the average weighted estimating function stack
|
|
1359
|
-
with respect to the betas and theta, using the resulting Jacobian to compute the
|
|
1594
|
+
with respect to the betas and theta, using the resulting Jacobian to compute the bread
|
|
1360
1595
|
and meat matrices, and then stably computing sandwiches.
|
|
1361
1596
|
|
|
1362
1597
|
Args:
|
|
@@ -1426,13 +1661,13 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1426
1661
|
The type of small sample correction to apply. See SmallSampleCorrections class for
|
|
1427
1662
|
options.
|
|
1428
1663
|
form_adjusted_meat_adjustments_explicitly (bool):
|
|
1429
|
-
If True, explicitly forms the per-subject meat adjustments that differentiate the
|
|
1664
|
+
If True, explicitly forms the per-subject meat adjustments that differentiate the adjusted
|
|
1430
1665
|
sandwich from the classical sandwich. This is for diagnostic purposes, as the
|
|
1431
|
-
|
|
1432
|
-
|
|
1433
|
-
If True, will apply various techniques to stabilize the joint
|
|
1666
|
+
adjusted sandwich is formed without doing this.
|
|
1667
|
+
stabilize_joint_bread (bool):
|
|
1668
|
+
If True, will apply various techniques to stabilize the joint bread if necessary.
|
|
1434
1669
|
analysis_df (pd.DataFrame):
|
|
1435
|
-
The full analysis dataframe, needed if forming the
|
|
1670
|
+
The full analysis dataframe, needed if forming the adjusted meat adjustments explicitly.
|
|
1436
1671
|
active_col_name (str):
|
|
1437
1672
|
The name of the column in analysis_df indicating whether a subject is active at a given decision time.
|
|
1438
1673
|
action_col_name (str):
|
|
@@ -1443,25 +1678,25 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1443
1678
|
The name of the column in analysis_df indicating the subject ID.
|
|
1444
1679
|
action_prob_func_args (tuple):
|
|
1445
1680
|
The arguments to be passed to the action probability function, needed if forming the
|
|
1446
|
-
|
|
1681
|
+
adjusted meat adjustments explicitly.
|
|
1447
1682
|
action_prob_col_name (str):
|
|
1448
1683
|
The name of the column in analysis_df indicating the action probability of the action taken,
|
|
1449
|
-
needed if forming the
|
|
1684
|
+
needed if forming the adjusted meat adjustments explicitly.
|
|
1450
1685
|
Returns:
|
|
1451
1686
|
tuple[jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32]]:
|
|
1452
1687
|
A tuple containing:
|
|
1453
|
-
- The raw joint
|
|
1454
|
-
- The (possibly) stabilized joint
|
|
1455
|
-
- The joint
|
|
1456
|
-
- The joint
|
|
1457
|
-
- The classical
|
|
1688
|
+
- The raw joint bread matrix.
|
|
1689
|
+
- The (possibly) stabilized joint bread matrix.
|
|
1690
|
+
- The joint meat matrix.
|
|
1691
|
+
- The joint sandwich matrix.
|
|
1692
|
+
- The classical bread matrix.
|
|
1458
1693
|
- The classical meat matrix.
|
|
1459
1694
|
- The classical sandwich matrix.
|
|
1460
1695
|
- The average weighted estimating function stack.
|
|
1461
1696
|
- All per-subject weighted estimating function stacks.
|
|
1462
|
-
- The per-subject
|
|
1697
|
+
- The per-subject adjusted meat small-sample corrections.
|
|
1463
1698
|
- The per-subject classical meat small-sample corrections.
|
|
1464
|
-
- The per-subject
|
|
1699
|
+
- The per-subject adjusted meat adjustments, if form_adjusted_meat_adjustments_explicitly
|
|
1465
1700
|
is True, otherwise an array of NaNs.
|
|
1466
1701
|
"""
|
|
1467
1702
|
logger.info(
|
|
@@ -1470,11 +1705,11 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1470
1705
|
theta_dim = theta_est.shape[0]
|
|
1471
1706
|
beta_dim = all_post_update_betas.shape[1]
|
|
1472
1707
|
# Note that these "contributions" are per-subject Jacobians of the weighted estimating function stack.
|
|
1473
|
-
|
|
1708
|
+
raw_joint_bread_matrix, (
|
|
1474
1709
|
avg_estimating_function_stack,
|
|
1475
1710
|
per_subject_joint_adjusted_meat_contributions,
|
|
1476
1711
|
per_subject_classical_meat_contributions,
|
|
1477
|
-
|
|
1712
|
+
per_subject_classical_bread_contributions,
|
|
1478
1713
|
per_subject_estimating_function_stacks,
|
|
1479
1714
|
) = jax.jacrev(
|
|
1480
1715
|
get_avg_weighted_estimating_function_stacks_and_aux_values, has_aux=True
|
|
@@ -1521,40 +1756,38 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1521
1756
|
small_sample_correction,
|
|
1522
1757
|
per_subject_joint_adjusted_meat_contributions,
|
|
1523
1758
|
per_subject_classical_meat_contributions,
|
|
1524
|
-
|
|
1759
|
+
per_subject_classical_bread_contributions,
|
|
1525
1760
|
num_subjects,
|
|
1526
1761
|
theta_dim,
|
|
1527
1762
|
)
|
|
1528
1763
|
|
|
1529
1764
|
# Increase diagonal block dominance possibly improve conditioning of diagonal
|
|
1530
|
-
# blocks as necessary, to ensure mathematical stability of joint bread
|
|
1531
|
-
|
|
1765
|
+
# blocks as necessary, to ensure mathematical stability of joint bread
|
|
1766
|
+
stabilized_joint_bread_matrix = (
|
|
1532
1767
|
(
|
|
1533
|
-
|
|
1534
|
-
|
|
1768
|
+
stabilize_joint_bread_if_necessary(
|
|
1769
|
+
raw_joint_bread_matrix,
|
|
1535
1770
|
beta_dim,
|
|
1536
1771
|
theta_dim,
|
|
1537
1772
|
)
|
|
1538
1773
|
)
|
|
1539
|
-
if
|
|
1540
|
-
else
|
|
1774
|
+
if stabilize_joint_bread
|
|
1775
|
+
else raw_joint_bread_matrix
|
|
1541
1776
|
)
|
|
1542
1777
|
|
|
1543
1778
|
# Now stably (no explicit inversion) form our sandwiches.
|
|
1544
|
-
|
|
1545
|
-
|
|
1779
|
+
joint_sandwich = form_sandwich_from_bread_and_meat(
|
|
1780
|
+
stabilized_joint_bread_matrix,
|
|
1546
1781
|
joint_adjusted_meat_matrix,
|
|
1547
1782
|
num_subjects,
|
|
1548
|
-
method=SandwichFormationMethods.
|
|
1549
|
-
)
|
|
1550
|
-
classical_bread_inverse_matrix = jnp.mean(
|
|
1551
|
-
per_subject_classical_bread_inverse_contributions, axis=0
|
|
1783
|
+
method=SandwichFormationMethods.BREAD_T_QR,
|
|
1552
1784
|
)
|
|
1553
|
-
|
|
1554
|
-
|
|
1785
|
+
classical_bread_matrix = jnp.mean(per_subject_classical_bread_contributions, axis=0)
|
|
1786
|
+
classical_sandwich = form_sandwich_from_bread_and_meat(
|
|
1787
|
+
classical_bread_matrix,
|
|
1555
1788
|
classical_meat_matrix,
|
|
1556
1789
|
num_subjects,
|
|
1557
|
-
method=SandwichFormationMethods.
|
|
1790
|
+
method=SandwichFormationMethods.BREAD_T_QR,
|
|
1558
1791
|
)
|
|
1559
1792
|
|
|
1560
1793
|
per_subject_adjusted_meat_adjustments = jnp.full(
|
|
@@ -1565,7 +1798,7 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1565
1798
|
form_adjusted_meat_adjustments_directly(
|
|
1566
1799
|
theta_dim,
|
|
1567
1800
|
all_post_update_betas.shape[1],
|
|
1568
|
-
|
|
1801
|
+
stabilized_joint_bread_matrix,
|
|
1569
1802
|
per_subject_estimating_function_stacks,
|
|
1570
1803
|
analysis_df,
|
|
1571
1804
|
active_col_name,
|
|
@@ -1582,9 +1815,9 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1582
1815
|
action_prob_col_name,
|
|
1583
1816
|
)
|
|
1584
1817
|
)
|
|
1585
|
-
# Validate that the
|
|
1586
|
-
# the theta-only
|
|
1587
|
-
# we get by taking a subset of the joint
|
|
1818
|
+
# Validate that the adjusted meat adjustments we just formed are accurate by constructing
|
|
1819
|
+
# the theta-only adjusted sandwich from them and checking that it matches the standard result
|
|
1820
|
+
# we get by taking a subset of the joint sandwich.
|
|
1588
1821
|
# First just apply any small-sample correction for parity.
|
|
1589
1822
|
(
|
|
1590
1823
|
_,
|
|
@@ -1595,19 +1828,19 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1595
1828
|
small_sample_correction,
|
|
1596
1829
|
per_subject_joint_adjusted_meat_contributions,
|
|
1597
1830
|
per_subject_adjusted_classical_meat_contributions,
|
|
1598
|
-
|
|
1831
|
+
per_subject_classical_bread_contributions,
|
|
1599
1832
|
num_subjects,
|
|
1600
1833
|
theta_dim,
|
|
1601
1834
|
)
|
|
1602
1835
|
theta_only_adjusted_sandwich_from_adjustments = (
|
|
1603
|
-
|
|
1604
|
-
|
|
1836
|
+
form_sandwich_from_bread_and_meat(
|
|
1837
|
+
classical_bread_matrix,
|
|
1605
1838
|
theta_only_adjusted_meat_matrix_v2,
|
|
1606
1839
|
num_subjects,
|
|
1607
|
-
method=SandwichFormationMethods.
|
|
1840
|
+
method=SandwichFormationMethods.BREAD_T_QR,
|
|
1608
1841
|
)
|
|
1609
1842
|
)
|
|
1610
|
-
theta_only_adjusted_sandwich =
|
|
1843
|
+
theta_only_adjusted_sandwich = joint_sandwich[-theta_dim:, -theta_dim:]
|
|
1611
1844
|
|
|
1612
1845
|
if not np.allclose(
|
|
1613
1846
|
theta_only_adjusted_sandwich,
|
|
@@ -1615,17 +1848,17 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1615
1848
|
rtol=3e-2,
|
|
1616
1849
|
):
|
|
1617
1850
|
logger.warning(
|
|
1618
|
-
"There may be a bug in the explicit meat adjustment calculation (this doesn't affect the actual calculation, just diagnostics). We've calculated the theta-only
|
|
1851
|
+
"There may be a bug in the explicit meat adjustment calculation (this doesn't affect the actual calculation, just diagnostics). We've calculated the theta-only adjusted sandwich two different ways and they do not match sufficiently."
|
|
1619
1852
|
)
|
|
1620
1853
|
|
|
1621
|
-
# Stack the joint
|
|
1622
|
-
# values too. The joint
|
|
1854
|
+
# Stack the joint bread pieces together horizontally and return the auxiliary
|
|
1855
|
+
# values too. The joint bread should always be block lower triangular.
|
|
1623
1856
|
return (
|
|
1624
|
-
|
|
1625
|
-
|
|
1857
|
+
raw_joint_bread_matrix,
|
|
1858
|
+
stabilized_joint_bread_matrix,
|
|
1626
1859
|
joint_adjusted_meat_matrix,
|
|
1627
|
-
|
|
1628
|
-
|
|
1860
|
+
joint_sandwich,
|
|
1861
|
+
classical_bread_matrix,
|
|
1629
1862
|
classical_meat_matrix,
|
|
1630
1863
|
classical_sandwich,
|
|
1631
1864
|
avg_estimating_function_stack,
|
|
@@ -1639,25 +1872,25 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1639
1872
|
# TODO: I think there should be interaction to confirm stabilization. It is
|
|
1640
1873
|
# important for the subject to know if this is happening. Even if enabled, it is important
|
|
1641
1874
|
# that the subject know it actually kicks in.
|
|
1642
|
-
def
|
|
1643
|
-
|
|
1875
|
+
def stabilize_joint_bread_if_necessary(
|
|
1876
|
+
joint_bread_matrix: jnp.ndarray,
|
|
1644
1877
|
beta_dim: int,
|
|
1645
1878
|
theta_dim: int,
|
|
1646
1879
|
) -> jnp.ndarray:
|
|
1647
1880
|
"""
|
|
1648
|
-
Stabilizes the joint
|
|
1881
|
+
Stabilizes the joint bread matrix if necessary by increasing diagonal block
|
|
1649
1882
|
dominance and/or adding a small ridge penalty to the diagonal blocks.
|
|
1650
1883
|
|
|
1651
1884
|
Args:
|
|
1652
|
-
|
|
1653
|
-
A 2-D JAX NumPy array representing the joint
|
|
1885
|
+
joint_bread_matrix (jnp.ndarray):
|
|
1886
|
+
A 2-D JAX NumPy array representing the joint bread matrix.
|
|
1654
1887
|
beta_dim (int):
|
|
1655
1888
|
The dimension of each beta parameter.
|
|
1656
1889
|
theta_dim (int):
|
|
1657
1890
|
The dimension of the theta parameter.
|
|
1658
1891
|
Returns:
|
|
1659
1892
|
jnp.ndarray:
|
|
1660
|
-
A 2-D NumPy array representing the stabilized joint
|
|
1893
|
+
A 2-D NumPy array representing the stabilized joint bread matrix.
|
|
1661
1894
|
"""
|
|
1662
1895
|
|
|
1663
1896
|
# TODO: come up with more sophisticated settings here. These are maybe a little loose,
|
|
@@ -1670,7 +1903,7 @@ def stabilize_joint_adjusted_bread_inverse_if_necessary(
|
|
|
1670
1903
|
|
|
1671
1904
|
# Grab just the RL block and convert numpy array for easier manipulation.
|
|
1672
1905
|
RL_stack_beta_derivatives_block = np.array(
|
|
1673
|
-
|
|
1906
|
+
joint_bread_matrix[:-theta_dim, :-theta_dim]
|
|
1674
1907
|
)
|
|
1675
1908
|
num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
|
|
1676
1909
|
for i in range(1, num_updates + 1):
|
|
@@ -1789,31 +2022,31 @@ def stabilize_joint_adjusted_bread_inverse_if_necessary(
|
|
|
1789
2022
|
[
|
|
1790
2023
|
[
|
|
1791
2024
|
RL_stack_beta_derivatives_block,
|
|
1792
|
-
|
|
2025
|
+
joint_bread_matrix[:-theta_dim, -theta_dim:],
|
|
1793
2026
|
],
|
|
1794
2027
|
[
|
|
1795
|
-
|
|
1796
|
-
|
|
2028
|
+
joint_bread_matrix[-theta_dim:, :-theta_dim],
|
|
2029
|
+
joint_bread_matrix[-theta_dim:, -theta_dim:],
|
|
1797
2030
|
],
|
|
1798
2031
|
]
|
|
1799
2032
|
)
|
|
1800
2033
|
|
|
1801
2034
|
|
|
1802
|
-
def
|
|
1803
|
-
|
|
2035
|
+
def form_sandwich_from_bread_and_meat(
|
|
2036
|
+
bread: jnp.ndarray,
|
|
1804
2037
|
meat: jnp.ndarray,
|
|
1805
2038
|
num_subjects: int,
|
|
1806
|
-
method: str = SandwichFormationMethods.
|
|
2039
|
+
method: str = SandwichFormationMethods.BREAD_T_QR,
|
|
1807
2040
|
) -> jnp.ndarray:
|
|
1808
2041
|
"""
|
|
1809
|
-
Forms a sandwich variance matrix from the provided bread
|
|
2042
|
+
Forms a sandwich variance matrix from the provided bread and meat matrices.
|
|
1810
2043
|
|
|
1811
|
-
Attempts to do so STABLY without ever forming the bread matrix itself
|
|
2044
|
+
Attempts to do so STABLY without ever forming the bread inverse matrix itself
|
|
1812
2045
|
(except with naive option).
|
|
1813
2046
|
|
|
1814
2047
|
Args:
|
|
1815
|
-
|
|
1816
|
-
A 2-D JAX NumPy array representing the bread
|
|
2048
|
+
bread (jnp.ndarray):
|
|
2049
|
+
A 2-D JAX NumPy array representing the bread matrix.
|
|
1817
2050
|
meat (jnp.ndarray):
|
|
1818
2051
|
A 2-D JAX NumPy array representing the meat matrix.
|
|
1819
2052
|
num_subjects (int):
|
|
@@ -1821,12 +2054,12 @@ def form_sandwich_from_bread_inverse_and_meat(
|
|
|
1821
2054
|
method (str):
|
|
1822
2055
|
The method to use for forming the sandwich.
|
|
1823
2056
|
|
|
1824
|
-
SandwichFormationMethods.
|
|
1825
|
-
of the bread
|
|
2057
|
+
SandwichFormationMethods.BREAD_T_QR uses the QR decomposition of the transpose
|
|
2058
|
+
of the bread matrix.
|
|
1826
2059
|
|
|
1827
2060
|
SandwichFormationMethods.MEAT_SVD_SOLVE uses a decomposition of the meat matrix.
|
|
1828
2061
|
|
|
1829
|
-
SandwichFormationMethods.NAIVE simply inverts the bread
|
|
2062
|
+
SandwichFormationMethods.NAIVE simply inverts the bread and forms the sandwich.
|
|
1830
2063
|
|
|
1831
2064
|
|
|
1832
2065
|
Returns:
|
|
@@ -1834,9 +2067,9 @@ def form_sandwich_from_bread_inverse_and_meat(
|
|
|
1834
2067
|
A 2-D JAX NumPy array representing the sandwich variance matrix.
|
|
1835
2068
|
"""
|
|
1836
2069
|
|
|
1837
|
-
if method == SandwichFormationMethods.
|
|
2070
|
+
if method == SandwichFormationMethods.BREAD_T_QR:
|
|
1838
2071
|
# QR of B^T → Q orthogonal, R upper triangular; L = R^T lower triangular
|
|
1839
|
-
Q, R = np.linalg.qr(
|
|
2072
|
+
Q, R = np.linalg.qr(bread.T, mode="reduced")
|
|
1840
2073
|
L = R.T
|
|
1841
2074
|
|
|
1842
2075
|
new_meat = scipy.linalg.solve_triangular(
|
|
@@ -1854,21 +2087,21 @@ def form_sandwich_from_bread_inverse_and_meat(
|
|
|
1854
2087
|
C_right = Vh.T * np.sqrt(s)
|
|
1855
2088
|
|
|
1856
2089
|
# Solve B W_left = C_left and B W_right = C_right (no explicit inverses).
|
|
1857
|
-
W_left = scipy.linalg.solve(
|
|
1858
|
-
W_right = scipy.linalg.solve(
|
|
2090
|
+
W_left = scipy.linalg.solve(bread, C_left)
|
|
2091
|
+
W_right = scipy.linalg.solve(bread, C_right)
|
|
1859
2092
|
|
|
1860
2093
|
# Return the exact sandwich: V = (B^{-1} C_left) (B^{-1} C_right)^T / num_subjects
|
|
1861
2094
|
return W_left @ W_right.T / num_subjects
|
|
1862
2095
|
|
|
1863
2096
|
elif method == SandwichFormationMethods.NAIVE:
|
|
1864
|
-
# Simply invert the bread
|
|
2097
|
+
# Simply invert the bread and form the sandwich directly.
|
|
1865
2098
|
# This is NOT numerically stable and is only included for comparison purposes.
|
|
1866
|
-
|
|
1867
|
-
return
|
|
2099
|
+
bread_inverse = np.linalg.inv(bread)
|
|
2100
|
+
return bread_inverse @ meat @ bread_inverse.T / num_subjects
|
|
1868
2101
|
|
|
1869
2102
|
else:
|
|
1870
2103
|
raise ValueError(
|
|
1871
|
-
f"Unknown sandwich method: {method}. Please use '
|
|
2104
|
+
f"Unknown sandwich method: {method}. Please use 'bread_t_qr' or 'meat_decomposition_solve'."
|
|
1872
2105
|
)
|
|
1873
2106
|
|
|
1874
2107
|
|