lifejacket 1.0.0__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/calculate_derivatives.py +0 -2
- lifejacket/constants.py +4 -16
- lifejacket/deployment_conditioning_monitor.py +19 -12
- lifejacket/form_adjusted_meat_adjustments_directly.py +25 -27
- lifejacket/get_datum_for_blowup_supervised_learning.py +71 -77
- lifejacket/helper_functions.py +15 -148
- lifejacket/input_checks.py +49 -50
- lifejacket/{after_study_analysis.py → post_deployment_analysis.py} +127 -124
- lifejacket/small_sample_corrections.py +11 -13
- {lifejacket-1.0.0.dist-info → lifejacket-1.0.2.dist-info}/METADATA +1 -1
- lifejacket-1.0.2.dist-info/RECORD +17 -0
- lifejacket-1.0.2.dist-info/entry_points.txt +2 -0
- lifejacket-1.0.0.dist-info/RECORD +0 -17
- lifejacket-1.0.0.dist-info/entry_points.txt +0 -2
- {lifejacket-1.0.0.dist-info → lifejacket-1.0.2.dist-info}/WHEEL +0 -0
- {lifejacket-1.0.0.dist-info → lifejacket-1.0.2.dist-info}/top_level.txt +0 -0
|
@@ -217,9 +217,9 @@ def cli():
|
|
|
217
217
|
type=click.Choice(
|
|
218
218
|
[
|
|
219
219
|
SmallSampleCorrections.NONE,
|
|
220
|
-
SmallSampleCorrections.
|
|
221
|
-
SmallSampleCorrections.
|
|
222
|
-
SmallSampleCorrections.
|
|
220
|
+
SmallSampleCorrections.Z1theta,
|
|
221
|
+
SmallSampleCorrections.Z2theta,
|
|
222
|
+
SmallSampleCorrections.Z3theta,
|
|
223
223
|
]
|
|
224
224
|
),
|
|
225
225
|
default=SmallSampleCorrections.NONE,
|
|
@@ -235,13 +235,13 @@ def cli():
|
|
|
235
235
|
"--form_adjusted_meat_adjustments_explicitly",
|
|
236
236
|
type=bool,
|
|
237
237
|
default=False,
|
|
238
|
-
help="If True, explicitly forms the per-subject meat adjustments that differentiate the
|
|
238
|
+
help="If True, explicitly forms the per-subject meat adjustments that differentiate the adjusted sandwich from the classical sandwich. This is for diagnostic purposes, as the adjusted sandwich is formed without doing this.",
|
|
239
239
|
)
|
|
240
240
|
@click.option(
|
|
241
|
-
"--
|
|
241
|
+
"--stabilize_joint_bread",
|
|
242
242
|
type=bool,
|
|
243
243
|
default=True,
|
|
244
|
-
help="If True, stabilizes the joint
|
|
244
|
+
help="If True, stabilizes the joint bread matrix if it does not meet conditioning thresholds.",
|
|
245
245
|
)
|
|
246
246
|
def analyze_dataset_wrapper(**kwargs):
|
|
247
247
|
"""
|
|
@@ -324,15 +324,15 @@ def analyze_dataset(
|
|
|
324
324
|
small_sample_correction: str,
|
|
325
325
|
collect_data_for_blowup_supervised_learning: bool,
|
|
326
326
|
form_adjusted_meat_adjustments_explicitly: bool,
|
|
327
|
-
|
|
327
|
+
stabilize_joint_bread: bool,
|
|
328
328
|
) -> None:
|
|
329
329
|
"""
|
|
330
|
-
Analyzes a dataset to provide a parameter estimate and an estimate of its variance using
|
|
330
|
+
Analyzes a dataset to provide a parameter estimate and an estimate of its variance using and classical sandwich estimators.
|
|
331
331
|
|
|
332
332
|
There are two modes of use for this function.
|
|
333
333
|
|
|
334
334
|
First, it may be called indirectly from the command line by passing through
|
|
335
|
-
|
|
335
|
+
analyze_dataset_wrapper.
|
|
336
336
|
|
|
337
337
|
Second, it may be called directly from Python code with in-memory objects.
|
|
338
338
|
|
|
@@ -388,17 +388,17 @@ def analyze_dataset(
|
|
|
388
388
|
small_sample_correction (str):
|
|
389
389
|
Type of small sample correction to apply.
|
|
390
390
|
collect_data_for_blowup_supervised_learning (bool):
|
|
391
|
-
Whether to collect data for doing supervised learning about
|
|
391
|
+
Whether to collect data for doing supervised learning about adjusted sandwich blowup.
|
|
392
392
|
form_adjusted_meat_adjustments_explicitly (bool):
|
|
393
|
-
If True, explicitly forms the per-subject meat adjustments that differentiate the
|
|
393
|
+
If True, explicitly forms the per-subject meat adjustments that differentiate the
|
|
394
394
|
sandwich from the classical sandwich. This is for diagnostic purposes, as the
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
If True, stabilizes the joint
|
|
395
|
+
adjusted sandwich is formed without doing this.
|
|
396
|
+
stabilize_joint_bread (bool):
|
|
397
|
+
If True, stabilizes the joint bread matrix if it does not meet conditioning
|
|
398
398
|
thresholds.
|
|
399
399
|
|
|
400
400
|
Returns:
|
|
401
|
-
dict: A dictionary containing the theta estimate,
|
|
401
|
+
dict: A dictionary containing the theta estimate, adjusted sandwich variance estimate, and
|
|
402
402
|
classical sandwich variance estimate.
|
|
403
403
|
"""
|
|
404
404
|
|
|
@@ -438,7 +438,6 @@ def analyze_dataset(
|
|
|
438
438
|
)
|
|
439
439
|
|
|
440
440
|
### Begin collecting data structures that will be used to compute the joint bread matrix.
|
|
441
|
-
|
|
442
441
|
beta_index_by_policy_num, initial_policy_num = (
|
|
443
442
|
construct_beta_index_by_policy_num_map(
|
|
444
443
|
analysis_df, policy_num_col_name, active_col_name
|
|
@@ -475,20 +474,20 @@ def analyze_dataset(
|
|
|
475
474
|
active_col_name,
|
|
476
475
|
)
|
|
477
476
|
|
|
478
|
-
# Use a per-subject weighted estimating function stacking
|
|
479
|
-
#
|
|
477
|
+
# Use a per-subject weighted estimating function stacking function to derive classical and joint
|
|
478
|
+
# meat and bread matrices. This is facilitated because the *value* of the
|
|
480
479
|
# weighted and unweighted stacks are the same, as the weights evaluate to 1 pre-differentiation.
|
|
481
480
|
logger.info(
|
|
482
|
-
"Constructing joint
|
|
481
|
+
"Constructing joint bread matrix, joint meat matrix, the classical analogs, and the avg estimating function stack across subjects."
|
|
483
482
|
)
|
|
484
483
|
|
|
485
484
|
subject_ids = jnp.array(analysis_df[subject_id_col_name].unique())
|
|
486
485
|
(
|
|
487
|
-
|
|
488
|
-
|
|
486
|
+
stabilized_joint_adjusted_bread_matrix,
|
|
487
|
+
raw_joint_adjusted_bread_matrix,
|
|
489
488
|
joint_adjusted_meat_matrix,
|
|
490
489
|
joint_adjusted_sandwich_matrix,
|
|
491
|
-
|
|
490
|
+
classical_bread_matrix,
|
|
492
491
|
classical_meat_matrix,
|
|
493
492
|
classical_sandwich_var_estimate,
|
|
494
493
|
avg_estimating_function_stack,
|
|
@@ -524,7 +523,7 @@ def analyze_dataset(
|
|
|
524
523
|
suppress_interactive_data_checks,
|
|
525
524
|
small_sample_correction,
|
|
526
525
|
form_adjusted_meat_adjustments_explicitly,
|
|
527
|
-
|
|
526
|
+
stabilize_joint_bread,
|
|
528
527
|
analysis_df,
|
|
529
528
|
active_col_name,
|
|
530
529
|
action_col_name,
|
|
@@ -550,17 +549,16 @@ def analyze_dataset(
|
|
|
550
549
|
]
|
|
551
550
|
|
|
552
551
|
# Check for negative diagonal elements and set them to zero if found
|
|
553
|
-
|
|
554
|
-
if np.any(
|
|
552
|
+
adjusted_diagonal = np.diag(adjusted_sandwich_var_estimate)
|
|
553
|
+
if np.any(adjusted_diagonal < 0):
|
|
555
554
|
logger.warning(
|
|
556
|
-
"Found negative diagonal elements in
|
|
555
|
+
"Found negative diagonal elements in adjusted sandwich variance estimate. Setting them to zero."
|
|
557
556
|
)
|
|
558
557
|
np.fill_diagonal(
|
|
559
|
-
adjusted_sandwich_var_estimate, np.maximum(
|
|
558
|
+
adjusted_sandwich_var_estimate, np.maximum(adjusted_diagonal, 0)
|
|
560
559
|
)
|
|
561
560
|
|
|
562
561
|
logger.info("Writing results to file...")
|
|
563
|
-
# Write analysis results to same directory that input files are in
|
|
564
562
|
output_folder_abs_path = pathlib.Path(output_dir).resolve()
|
|
565
563
|
|
|
566
564
|
analysis_dict = {
|
|
@@ -574,25 +572,31 @@ def analyze_dataset(
|
|
|
574
572
|
f,
|
|
575
573
|
)
|
|
576
574
|
|
|
577
|
-
|
|
578
|
-
|
|
575
|
+
joint_adjusted_bread_cond = jnp.linalg.cond(raw_joint_adjusted_bread_matrix)
|
|
576
|
+
logger.info(
|
|
577
|
+
"Joint adjusted bread condition number: %f",
|
|
578
|
+
joint_adjusted_bread_cond,
|
|
579
579
|
)
|
|
580
|
+
|
|
581
|
+
# calculate the max eigenvalue of the joint adjusted sandwich
|
|
582
|
+
max_eigenvalue = scipy.linalg.eigvalsh(joint_adjusted_sandwich_matrix).max()
|
|
580
583
|
logger.info(
|
|
581
|
-
"
|
|
582
|
-
|
|
584
|
+
"Max eigenvalue of joint adjusted sandwich matrix: %f",
|
|
585
|
+
max_eigenvalue,
|
|
583
586
|
)
|
|
584
587
|
|
|
585
588
|
debug_pieces_dict = {
|
|
586
589
|
"theta_est": theta_est,
|
|
587
590
|
"adjusted_sandwich_var_estimate": adjusted_sandwich_var_estimate,
|
|
588
591
|
"classical_sandwich_var_estimate": classical_sandwich_var_estimate,
|
|
589
|
-
"
|
|
590
|
-
"
|
|
592
|
+
"raw_joint_bread_matrix": raw_joint_adjusted_bread_matrix,
|
|
593
|
+
"stabilized_joint_bread_matrix": stabilized_joint_adjusted_bread_matrix,
|
|
591
594
|
"joint_meat_matrix": joint_adjusted_meat_matrix,
|
|
592
|
-
"
|
|
595
|
+
"classical_bread_matrix": classical_bread_matrix,
|
|
593
596
|
"classical_meat_matrix": classical_meat_matrix,
|
|
594
597
|
"all_estimating_function_stacks": per_subject_estimating_function_stacks,
|
|
595
|
-
"
|
|
598
|
+
"joint_bread_condition_number": joint_adjusted_bread_cond,
|
|
599
|
+
"max_eigenvalue_joint_adjusted_sandwich": max_eigenvalue,
|
|
596
600
|
"all_post_update_betas": all_post_update_betas,
|
|
597
601
|
"per_subject_adjusted_corrections": per_subject_adjusted_corrections,
|
|
598
602
|
"per_subject_classical_corrections": per_subject_classical_corrections,
|
|
@@ -606,8 +610,8 @@ def analyze_dataset(
|
|
|
606
610
|
|
|
607
611
|
if collect_data_for_blowup_supervised_learning:
|
|
608
612
|
datum_and_label_dict = get_datum_for_blowup_supervised_learning.get_datum_for_blowup_supervised_learning(
|
|
609
|
-
|
|
610
|
-
|
|
613
|
+
raw_joint_adjusted_bread_matrix,
|
|
614
|
+
joint_adjusted_bread_cond,
|
|
611
615
|
avg_estimating_function_stack,
|
|
612
616
|
per_subject_estimating_function_stacks,
|
|
613
617
|
all_post_update_betas,
|
|
@@ -824,7 +828,7 @@ def single_subject_weighted_estimating_function_stacker(
|
|
|
824
828
|
Returns:
|
|
825
829
|
jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted estimating function
|
|
826
830
|
stack.
|
|
827
|
-
jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's
|
|
831
|
+
jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's adjusted meat contribution.
|
|
828
832
|
jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical meat contribution.
|
|
829
833
|
jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical bread contribution.
|
|
830
834
|
"""
|
|
@@ -1008,10 +1012,10 @@ def single_subject_weighted_estimating_function_stacker(
|
|
|
1008
1012
|
|
|
1009
1013
|
# 6. Return the following outputs:
|
|
1010
1014
|
# a. The first is simply the weighted estimating function stack for this subject. The average
|
|
1011
|
-
# of these is what we differentiate with respect to theta to form the
|
|
1015
|
+
# of these is what we differentiate with respect to theta to form the joint
|
|
1012
1016
|
# bread matrix, and we also compare that average to zero to check the estimating functions'
|
|
1013
1017
|
# fidelity.
|
|
1014
|
-
# b. The average outer product of these per-subject stacks across subjects is the
|
|
1018
|
+
# b. The average outer product of these per-subject stacks across subjects is the adjusted joint meat
|
|
1015
1019
|
# matrix, hence the second output.
|
|
1016
1020
|
# c. The third output is averaged across subjects to obtain the classical meat matrix.
|
|
1017
1021
|
# d. The fourth output is averaged across subjects to obtain the inverse classical bread
|
|
@@ -1068,7 +1072,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1068
1072
|
]:
|
|
1069
1073
|
"""
|
|
1070
1074
|
Computes the average weighted estimating function stack across all subjects, along with
|
|
1071
|
-
auxiliary values used to construct the
|
|
1075
|
+
auxiliary values used to construct the adjusted and classical sandwich variances.
|
|
1072
1076
|
|
|
1073
1077
|
Args:
|
|
1074
1078
|
flattened_betas_and_theta (jnp.ndarray):
|
|
@@ -1144,7 +1148,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1144
1148
|
tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
|
|
1145
1149
|
A tuple containing
|
|
1146
1150
|
1. the average weighted estimating function stack
|
|
1147
|
-
2. the subject-level
|
|
1151
|
+
2. the subject-level adjusted meat matrix contributions
|
|
1148
1152
|
3. the subject-level classical meat matrix contributions
|
|
1149
1153
|
4. the subject-level inverse classical bread matrix contributions
|
|
1150
1154
|
5. raw per-subject weighted estimating function
|
|
@@ -1248,7 +1252,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1248
1252
|
)
|
|
1249
1253
|
|
|
1250
1254
|
# 5. Now we can compute the weighted estimating function stacks for all subjects
|
|
1251
|
-
# as well as collect related values used to construct the
|
|
1255
|
+
# as well as collect related values used to construct the adjusted and classical
|
|
1252
1256
|
# sandwich variances.
|
|
1253
1257
|
results = [
|
|
1254
1258
|
single_subject_weighted_estimating_function_stacker(
|
|
@@ -1277,10 +1281,11 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1277
1281
|
|
|
1278
1282
|
# 6. Note this strange return structure! We will differentiate the first output,
|
|
1279
1283
|
# but the second tuple will be passed along without modification via has_aux=True and then used
|
|
1280
|
-
# for the
|
|
1281
|
-
#
|
|
1284
|
+
# for the estimating functions sum check, per_subject_classical_bread_contributions, and
|
|
1285
|
+
# classical meat and inverse read matrices. The raw per-subject stacks are also returned for
|
|
1286
|
+
# debugging purposes.
|
|
1282
1287
|
|
|
1283
|
-
# Note that returning the raw stacks here as the first
|
|
1288
|
+
# Note that returning the raw stacks here as the first argument is potentially
|
|
1284
1289
|
# memory-intensive when combined with differentiation. Keep this in mind if the per-subject bread
|
|
1285
1290
|
# inverse contributions are needed for something like CR2/CR3 small-sample corrections.
|
|
1286
1291
|
return jnp.mean(stacks, axis=0), (
|
|
@@ -1330,7 +1335,7 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1330
1335
|
suppress_interactive_data_checks: bool,
|
|
1331
1336
|
small_sample_correction: str,
|
|
1332
1337
|
form_adjusted_meat_adjustments_explicitly: bool,
|
|
1333
|
-
|
|
1338
|
+
stabilize_joint_bread: bool,
|
|
1334
1339
|
analysis_df: pd.DataFrame | None,
|
|
1335
1340
|
active_col_name: str | None,
|
|
1336
1341
|
action_col_name: str | None,
|
|
@@ -1352,11 +1357,11 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1352
1357
|
jnp.ndarray[jnp.float32],
|
|
1353
1358
|
]:
|
|
1354
1359
|
"""
|
|
1355
|
-
Constructs the classical and
|
|
1360
|
+
Constructs the classical and adjusted sandwich matrices, as well as various
|
|
1356
1361
|
intermediate pieces in their consruction.
|
|
1357
1362
|
|
|
1358
1363
|
This is done by computing and differentiating the average weighted estimating function stack
|
|
1359
|
-
with respect to the betas and theta, using the resulting Jacobian to compute the
|
|
1364
|
+
with respect to the betas and theta, using the resulting Jacobian to compute the bread
|
|
1360
1365
|
and meat matrices, and then stably computing sandwiches.
|
|
1361
1366
|
|
|
1362
1367
|
Args:
|
|
@@ -1426,13 +1431,13 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1426
1431
|
The type of small sample correction to apply. See SmallSampleCorrections class for
|
|
1427
1432
|
options.
|
|
1428
1433
|
form_adjusted_meat_adjustments_explicitly (bool):
|
|
1429
|
-
If True, explicitly forms the per-subject meat adjustments that differentiate the
|
|
1434
|
+
If True, explicitly forms the per-subject meat adjustments that differentiate the adjusted
|
|
1430
1435
|
sandwich from the classical sandwich. This is for diagnostic purposes, as the
|
|
1431
|
-
|
|
1432
|
-
|
|
1433
|
-
If True, will apply various techniques to stabilize the joint
|
|
1436
|
+
adjusted sandwich is formed without doing this.
|
|
1437
|
+
stabilize_joint_bread (bool):
|
|
1438
|
+
If True, will apply various techniques to stabilize the joint bread if necessary.
|
|
1434
1439
|
analysis_df (pd.DataFrame):
|
|
1435
|
-
The full analysis dataframe, needed if forming the
|
|
1440
|
+
The full analysis dataframe, needed if forming the adjusted meat adjustments explicitly.
|
|
1436
1441
|
active_col_name (str):
|
|
1437
1442
|
The name of the column in analysis_df indicating whether a subject is active at a given decision time.
|
|
1438
1443
|
action_col_name (str):
|
|
@@ -1443,25 +1448,25 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1443
1448
|
The name of the column in analysis_df indicating the subject ID.
|
|
1444
1449
|
action_prob_func_args (tuple):
|
|
1445
1450
|
The arguments to be passed to the action probability function, needed if forming the
|
|
1446
|
-
|
|
1451
|
+
adjusted meat adjustments explicitly.
|
|
1447
1452
|
action_prob_col_name (str):
|
|
1448
1453
|
The name of the column in analysis_df indicating the action probability of the action taken,
|
|
1449
|
-
needed if forming the
|
|
1454
|
+
needed if forming the adjusted meat adjustments explicitly.
|
|
1450
1455
|
Returns:
|
|
1451
1456
|
tuple[jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32]]:
|
|
1452
1457
|
A tuple containing:
|
|
1453
|
-
- The raw joint
|
|
1454
|
-
- The (possibly) stabilized joint
|
|
1455
|
-
- The joint
|
|
1456
|
-
- The joint
|
|
1457
|
-
- The classical
|
|
1458
|
+
- The raw joint bread matrix.
|
|
1459
|
+
- The (possibly) stabilized joint bread matrix.
|
|
1460
|
+
- The joint meat matrix.
|
|
1461
|
+
- The joint sandwich matrix.
|
|
1462
|
+
- The classical bread matrix.
|
|
1458
1463
|
- The classical meat matrix.
|
|
1459
1464
|
- The classical sandwich matrix.
|
|
1460
1465
|
- The average weighted estimating function stack.
|
|
1461
1466
|
- All per-subject weighted estimating function stacks.
|
|
1462
|
-
- The per-subject
|
|
1467
|
+
- The per-subject adjusted meat small-sample corrections.
|
|
1463
1468
|
- The per-subject classical meat small-sample corrections.
|
|
1464
|
-
- The per-subject
|
|
1469
|
+
- The per-subject adjusted meat adjustments, if form_adjusted_meat_adjustments_explicitly
|
|
1465
1470
|
is True, otherwise an array of NaNs.
|
|
1466
1471
|
"""
|
|
1467
1472
|
logger.info(
|
|
@@ -1470,11 +1475,11 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1470
1475
|
theta_dim = theta_est.shape[0]
|
|
1471
1476
|
beta_dim = all_post_update_betas.shape[1]
|
|
1472
1477
|
# Note that these "contributions" are per-subject Jacobians of the weighted estimating function stack.
|
|
1473
|
-
|
|
1478
|
+
raw_joint_adjusted_bread_matrix, (
|
|
1474
1479
|
avg_estimating_function_stack,
|
|
1475
1480
|
per_subject_joint_adjusted_meat_contributions,
|
|
1476
1481
|
per_subject_classical_meat_contributions,
|
|
1477
|
-
|
|
1482
|
+
per_subject_classical_bread_contributions,
|
|
1478
1483
|
per_subject_estimating_function_stacks,
|
|
1479
1484
|
) = jax.jacrev(
|
|
1480
1485
|
get_avg_weighted_estimating_function_stacks_and_aux_values, has_aux=True
|
|
@@ -1521,40 +1526,38 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1521
1526
|
small_sample_correction,
|
|
1522
1527
|
per_subject_joint_adjusted_meat_contributions,
|
|
1523
1528
|
per_subject_classical_meat_contributions,
|
|
1524
|
-
|
|
1529
|
+
per_subject_classical_bread_contributions,
|
|
1525
1530
|
num_subjects,
|
|
1526
1531
|
theta_dim,
|
|
1527
1532
|
)
|
|
1528
1533
|
|
|
1529
1534
|
# Increase diagonal block dominance possibly improve conditioning of diagonal
|
|
1530
|
-
# blocks as necessary, to ensure mathematical stability of joint bread
|
|
1531
|
-
|
|
1535
|
+
# blocks as necessary, to ensure mathematical stability of joint bread
|
|
1536
|
+
stabilized_joint_adjusted_bread_matrix = (
|
|
1532
1537
|
(
|
|
1533
|
-
|
|
1534
|
-
|
|
1538
|
+
stabilize_joint_bread_if_necessary(
|
|
1539
|
+
raw_joint_adjusted_bread_matrix,
|
|
1535
1540
|
beta_dim,
|
|
1536
1541
|
theta_dim,
|
|
1537
1542
|
)
|
|
1538
1543
|
)
|
|
1539
|
-
if
|
|
1540
|
-
else
|
|
1544
|
+
if stabilize_joint_bread
|
|
1545
|
+
else raw_joint_adjusted_bread_matrix
|
|
1541
1546
|
)
|
|
1542
1547
|
|
|
1543
1548
|
# Now stably (no explicit inversion) form our sandwiches.
|
|
1544
|
-
joint_adjusted_sandwich =
|
|
1545
|
-
|
|
1549
|
+
joint_adjusted_sandwich = form_sandwich_from_bread_and_meat(
|
|
1550
|
+
stabilized_joint_adjusted_bread_matrix,
|
|
1546
1551
|
joint_adjusted_meat_matrix,
|
|
1547
1552
|
num_subjects,
|
|
1548
|
-
method=SandwichFormationMethods.
|
|
1549
|
-
)
|
|
1550
|
-
classical_bread_inverse_matrix = jnp.mean(
|
|
1551
|
-
per_subject_classical_bread_inverse_contributions, axis=0
|
|
1553
|
+
method=SandwichFormationMethods.BREAD_T_QR,
|
|
1552
1554
|
)
|
|
1553
|
-
|
|
1554
|
-
|
|
1555
|
+
classical_bread_matrix = jnp.mean(per_subject_classical_bread_contributions, axis=0)
|
|
1556
|
+
classical_sandwich = form_sandwich_from_bread_and_meat(
|
|
1557
|
+
classical_bread_matrix,
|
|
1555
1558
|
classical_meat_matrix,
|
|
1556
1559
|
num_subjects,
|
|
1557
|
-
method=SandwichFormationMethods.
|
|
1560
|
+
method=SandwichFormationMethods.BREAD_T_QR,
|
|
1558
1561
|
)
|
|
1559
1562
|
|
|
1560
1563
|
per_subject_adjusted_meat_adjustments = jnp.full(
|
|
@@ -1565,7 +1568,7 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1565
1568
|
form_adjusted_meat_adjustments_directly(
|
|
1566
1569
|
theta_dim,
|
|
1567
1570
|
all_post_update_betas.shape[1],
|
|
1568
|
-
|
|
1571
|
+
stabilized_joint_adjusted_bread_matrix,
|
|
1569
1572
|
per_subject_estimating_function_stacks,
|
|
1570
1573
|
analysis_df,
|
|
1571
1574
|
active_col_name,
|
|
@@ -1582,9 +1585,9 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1582
1585
|
action_prob_col_name,
|
|
1583
1586
|
)
|
|
1584
1587
|
)
|
|
1585
|
-
# Validate that the
|
|
1586
|
-
# the theta-only
|
|
1587
|
-
# we get by taking a subset of the joint
|
|
1588
|
+
# Validate that the adjusted meat adjustments we just formed are accurate by constructing
|
|
1589
|
+
# the theta-only adjusted sandwich from them and checking that it matches the standard result
|
|
1590
|
+
# we get by taking a subset of the joint sandwich.
|
|
1588
1591
|
# First just apply any small-sample correction for parity.
|
|
1589
1592
|
(
|
|
1590
1593
|
_,
|
|
@@ -1595,16 +1598,16 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1595
1598
|
small_sample_correction,
|
|
1596
1599
|
per_subject_joint_adjusted_meat_contributions,
|
|
1597
1600
|
per_subject_adjusted_classical_meat_contributions,
|
|
1598
|
-
|
|
1601
|
+
per_subject_classical_bread_contributions,
|
|
1599
1602
|
num_subjects,
|
|
1600
1603
|
theta_dim,
|
|
1601
1604
|
)
|
|
1602
1605
|
theta_only_adjusted_sandwich_from_adjustments = (
|
|
1603
|
-
|
|
1604
|
-
|
|
1606
|
+
form_sandwich_from_bread_and_meat(
|
|
1607
|
+
classical_bread_matrix,
|
|
1605
1608
|
theta_only_adjusted_meat_matrix_v2,
|
|
1606
1609
|
num_subjects,
|
|
1607
|
-
method=SandwichFormationMethods.
|
|
1610
|
+
method=SandwichFormationMethods.BREAD_T_QR,
|
|
1608
1611
|
)
|
|
1609
1612
|
)
|
|
1610
1613
|
theta_only_adjusted_sandwich = joint_adjusted_sandwich[-theta_dim:, -theta_dim:]
|
|
@@ -1615,17 +1618,17 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1615
1618
|
rtol=3e-2,
|
|
1616
1619
|
):
|
|
1617
1620
|
logger.warning(
|
|
1618
|
-
"There may be a bug in the explicit meat adjustment calculation (this doesn't affect the actual calculation, just diagnostics). We've calculated the theta-only
|
|
1621
|
+
"There may be a bug in the explicit meat adjustment calculation (this doesn't affect the actual calculation, just diagnostics). We've calculated the theta-only adjusted sandwich two different ways and they do not match sufficiently."
|
|
1619
1622
|
)
|
|
1620
1623
|
|
|
1621
|
-
# Stack the joint
|
|
1622
|
-
# values too. The joint
|
|
1624
|
+
# Stack the joint bread pieces together horizontally and return the auxiliary
|
|
1625
|
+
# values too. The joint bread should always be block lower triangular.
|
|
1623
1626
|
return (
|
|
1624
|
-
|
|
1625
|
-
|
|
1627
|
+
raw_joint_adjusted_bread_matrix,
|
|
1628
|
+
stabilized_joint_adjusted_bread_matrix,
|
|
1626
1629
|
joint_adjusted_meat_matrix,
|
|
1627
1630
|
joint_adjusted_sandwich,
|
|
1628
|
-
|
|
1631
|
+
classical_bread_matrix,
|
|
1629
1632
|
classical_meat_matrix,
|
|
1630
1633
|
classical_sandwich,
|
|
1631
1634
|
avg_estimating_function_stack,
|
|
@@ -1639,25 +1642,25 @@ def construct_classical_and_adjusted_sandwiches(
|
|
|
1639
1642
|
# TODO: I think there should be interaction to confirm stabilization. It is
|
|
1640
1643
|
# important for the subject to know if this is happening. Even if enabled, it is important
|
|
1641
1644
|
# that the subject know it actually kicks in.
|
|
1642
|
-
def
|
|
1643
|
-
|
|
1645
|
+
def stabilize_joint_bread_if_necessary(
|
|
1646
|
+
joint_adjusted_bread_matrix: jnp.ndarray,
|
|
1644
1647
|
beta_dim: int,
|
|
1645
1648
|
theta_dim: int,
|
|
1646
1649
|
) -> jnp.ndarray:
|
|
1647
1650
|
"""
|
|
1648
|
-
Stabilizes the joint
|
|
1651
|
+
Stabilizes the joint bread matrix if necessary by increasing diagonal block
|
|
1649
1652
|
dominance and/or adding a small ridge penalty to the diagonal blocks.
|
|
1650
1653
|
|
|
1651
1654
|
Args:
|
|
1652
|
-
|
|
1653
|
-
A 2-D JAX NumPy array representing the joint
|
|
1655
|
+
joint_adjusted_bread_matrix (jnp.ndarray):
|
|
1656
|
+
A 2-D JAX NumPy array representing the joint bread matrix.
|
|
1654
1657
|
beta_dim (int):
|
|
1655
1658
|
The dimension of each beta parameter.
|
|
1656
1659
|
theta_dim (int):
|
|
1657
1660
|
The dimension of the theta parameter.
|
|
1658
1661
|
Returns:
|
|
1659
1662
|
jnp.ndarray:
|
|
1660
|
-
A 2-D NumPy array representing the stabilized joint
|
|
1663
|
+
A 2-D NumPy array representing the stabilized joint bread matrix.
|
|
1661
1664
|
"""
|
|
1662
1665
|
|
|
1663
1666
|
# TODO: come up with more sophisticated settings here. These are maybe a little loose,
|
|
@@ -1670,7 +1673,7 @@ def stabilize_joint_adjusted_bread_inverse_if_necessary(
|
|
|
1670
1673
|
|
|
1671
1674
|
# Grab just the RL block and convert numpy array for easier manipulation.
|
|
1672
1675
|
RL_stack_beta_derivatives_block = np.array(
|
|
1673
|
-
|
|
1676
|
+
joint_adjusted_bread_matrix[:-theta_dim, :-theta_dim]
|
|
1674
1677
|
)
|
|
1675
1678
|
num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
|
|
1676
1679
|
for i in range(1, num_updates + 1):
|
|
@@ -1789,31 +1792,31 @@ def stabilize_joint_adjusted_bread_inverse_if_necessary(
|
|
|
1789
1792
|
[
|
|
1790
1793
|
[
|
|
1791
1794
|
RL_stack_beta_derivatives_block,
|
|
1792
|
-
|
|
1795
|
+
joint_adjusted_bread_matrix[:-theta_dim, -theta_dim:],
|
|
1793
1796
|
],
|
|
1794
1797
|
[
|
|
1795
|
-
|
|
1796
|
-
|
|
1798
|
+
joint_adjusted_bread_matrix[-theta_dim:, :-theta_dim],
|
|
1799
|
+
joint_adjusted_bread_matrix[-theta_dim:, -theta_dim:],
|
|
1797
1800
|
],
|
|
1798
1801
|
]
|
|
1799
1802
|
)
|
|
1800
1803
|
|
|
1801
1804
|
|
|
1802
|
-
def
|
|
1803
|
-
|
|
1805
|
+
def form_sandwich_from_bread_and_meat(
|
|
1806
|
+
bread: jnp.ndarray,
|
|
1804
1807
|
meat: jnp.ndarray,
|
|
1805
1808
|
num_subjects: int,
|
|
1806
|
-
method: str = SandwichFormationMethods.
|
|
1809
|
+
method: str = SandwichFormationMethods.BREAD_T_QR,
|
|
1807
1810
|
) -> jnp.ndarray:
|
|
1808
1811
|
"""
|
|
1809
|
-
Forms a sandwich variance matrix from the provided bread
|
|
1812
|
+
Forms a sandwich variance matrix from the provided bread and meat matrices.
|
|
1810
1813
|
|
|
1811
|
-
Attempts to do so STABLY without ever forming the bread matrix itself
|
|
1814
|
+
Attempts to do so STABLY without ever forming the bread inverse matrix itself
|
|
1812
1815
|
(except with naive option).
|
|
1813
1816
|
|
|
1814
1817
|
Args:
|
|
1815
|
-
|
|
1816
|
-
A 2-D JAX NumPy array representing the bread
|
|
1818
|
+
bread (jnp.ndarray):
|
|
1819
|
+
A 2-D JAX NumPy array representing the bread matrix.
|
|
1817
1820
|
meat (jnp.ndarray):
|
|
1818
1821
|
A 2-D JAX NumPy array representing the meat matrix.
|
|
1819
1822
|
num_subjects (int):
|
|
@@ -1821,12 +1824,12 @@ def form_sandwich_from_bread_inverse_and_meat(
|
|
|
1821
1824
|
method (str):
|
|
1822
1825
|
The method to use for forming the sandwich.
|
|
1823
1826
|
|
|
1824
|
-
SandwichFormationMethods.
|
|
1825
|
-
of the bread
|
|
1827
|
+
SandwichFormationMethods.BREAD_T_QR uses the QR decomposition of the transpose
|
|
1828
|
+
of the bread matrix.
|
|
1826
1829
|
|
|
1827
1830
|
SandwichFormationMethods.MEAT_SVD_SOLVE uses a decomposition of the meat matrix.
|
|
1828
1831
|
|
|
1829
|
-
SandwichFormationMethods.NAIVE simply inverts the bread
|
|
1832
|
+
SandwichFormationMethods.NAIVE simply inverts the bread and forms the sandwich.
|
|
1830
1833
|
|
|
1831
1834
|
|
|
1832
1835
|
Returns:
|
|
@@ -1834,9 +1837,9 @@ def form_sandwich_from_bread_inverse_and_meat(
|
|
|
1834
1837
|
A 2-D JAX NumPy array representing the sandwich variance matrix.
|
|
1835
1838
|
"""
|
|
1836
1839
|
|
|
1837
|
-
if method == SandwichFormationMethods.
|
|
1840
|
+
if method == SandwichFormationMethods.BREAD_T_QR:
|
|
1838
1841
|
# QR of B^T → Q orthogonal, R upper triangular; L = R^T lower triangular
|
|
1839
|
-
Q, R = np.linalg.qr(
|
|
1842
|
+
Q, R = np.linalg.qr(bread.T, mode="reduced")
|
|
1840
1843
|
L = R.T
|
|
1841
1844
|
|
|
1842
1845
|
new_meat = scipy.linalg.solve_triangular(
|
|
@@ -1854,21 +1857,21 @@ def form_sandwich_from_bread_inverse_and_meat(
|
|
|
1854
1857
|
C_right = Vh.T * np.sqrt(s)
|
|
1855
1858
|
|
|
1856
1859
|
# Solve B W_left = C_left and B W_right = C_right (no explicit inverses).
|
|
1857
|
-
W_left = scipy.linalg.solve(
|
|
1858
|
-
W_right = scipy.linalg.solve(
|
|
1860
|
+
W_left = scipy.linalg.solve(bread, C_left)
|
|
1861
|
+
W_right = scipy.linalg.solve(bread, C_right)
|
|
1859
1862
|
|
|
1860
1863
|
# Return the exact sandwich: V = (B^{-1} C_left) (B^{-1} C_right)^T / num_subjects
|
|
1861
1864
|
return W_left @ W_right.T / num_subjects
|
|
1862
1865
|
|
|
1863
1866
|
elif method == SandwichFormationMethods.NAIVE:
|
|
1864
|
-
# Simply invert the bread
|
|
1867
|
+
# Simply invert the bread and form the sandwich directly.
|
|
1865
1868
|
# This is NOT numerically stable and is only included for comparison purposes.
|
|
1866
|
-
|
|
1867
|
-
return
|
|
1869
|
+
bread_inverse = np.linalg.inv(bread)
|
|
1870
|
+
return bread_inverse @ meat @ bread_inverse.T / num_subjects
|
|
1868
1871
|
|
|
1869
1872
|
else:
|
|
1870
1873
|
raise ValueError(
|
|
1871
|
-
f"Unknown sandwich method: {method}. Please use '
|
|
1874
|
+
f"Unknown sandwich method: {method}. Please use 'bread_t_qr' or 'meat_decomposition_solve'."
|
|
1872
1875
|
)
|
|
1873
1876
|
|
|
1874
1877
|
|