lifejacket 0.2.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/after_study_analysis.py +401 -387
- lifejacket/arg_threading_helpers.py +75 -69
- lifejacket/calculate_derivatives.py +19 -21
- lifejacket/{trial_conditioning_monitor.py → deployment_conditioning_monitor.py} +146 -128
- lifejacket/{form_adaptive_meat_adjustments_directly.py → form_adjusted_meat_adjustments_directly.py} +7 -7
- lifejacket/get_datum_for_blowup_supervised_learning.py +315 -307
- lifejacket/helper_functions.py +45 -38
- lifejacket/input_checks.py +263 -261
- lifejacket/small_sample_corrections.py +42 -40
- lifejacket-1.0.0.dist-info/METADATA +56 -0
- lifejacket-1.0.0.dist-info/RECORD +17 -0
- lifejacket-0.2.0.dist-info/METADATA +0 -100
- lifejacket-0.2.0.dist-info/RECORD +0 -17
- {lifejacket-0.2.0.dist-info → lifejacket-1.0.0.dist-info}/WHEEL +0 -0
- {lifejacket-0.2.0.dist-info → lifejacket-1.0.0.dist-info}/entry_points.txt +0 -0
- {lifejacket-0.2.0.dist-info → lifejacket-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -24,8 +24,8 @@ from .constants import (
|
|
|
24
24
|
SandwichFormationMethods,
|
|
25
25
|
SmallSampleCorrections,
|
|
26
26
|
)
|
|
27
|
-
from .
|
|
28
|
-
|
|
27
|
+
from .form_adjusted_meat_adjustments_directly import (
|
|
28
|
+
form_adjusted_meat_adjustments_directly,
|
|
29
29
|
)
|
|
30
30
|
from . import input_checks
|
|
31
31
|
from . import get_datum_for_blowup_supervised_learning
|
|
@@ -37,9 +37,9 @@ from .helper_functions import (
|
|
|
37
37
|
calculate_beta_dim,
|
|
38
38
|
collect_all_post_update_betas,
|
|
39
39
|
construct_beta_index_by_policy_num_map,
|
|
40
|
-
|
|
40
|
+
extract_action_and_policy_by_decision_time_by_subject_id,
|
|
41
41
|
flatten_params,
|
|
42
|
-
|
|
42
|
+
get_active_df_column,
|
|
43
43
|
get_min_time_by_policy_num,
|
|
44
44
|
get_radon_nikodym_weight,
|
|
45
45
|
load_function_from_same_named_file,
|
|
@@ -61,7 +61,7 @@ def cli():
|
|
|
61
61
|
|
|
62
62
|
# TODO: Check all help strings for accuracy.
|
|
63
63
|
# TODO: Deal with NA, -1, etc policy numbers
|
|
64
|
-
# TODO: Make sure in
|
|
64
|
+
# TODO: Make sure in deployment is never on for more than one stretch EDIT: unclear if
|
|
65
65
|
# this will remain an invariant as we deal with more complicated data missingness
|
|
66
66
|
# TODO: I think I'm agnostic to indexing of calendar times but should check because
|
|
67
67
|
# otherwise need to add a check here to verify required format.
|
|
@@ -69,7 +69,7 @@ def cli():
|
|
|
69
69
|
# Higher dimensional objects not supported. Not entirely sure what kind of "scalars" apply.
|
|
70
70
|
@cli.command(name="analyze")
|
|
71
71
|
@click.option(
|
|
72
|
-
"--
|
|
72
|
+
"--analysis_df_pickle",
|
|
73
73
|
type=click.File("rb"),
|
|
74
74
|
help="Pickled pandas dataframe in correct format (see contract/readme).",
|
|
75
75
|
required=True,
|
|
@@ -83,7 +83,7 @@ def cli():
|
|
|
83
83
|
@click.option(
|
|
84
84
|
"--action_prob_func_args_pickle",
|
|
85
85
|
type=click.File("rb"),
|
|
86
|
-
help="Pickled dictionary that contains the action probability function arguments for all decision times for all
|
|
86
|
+
help="Pickled dictionary that contains the action probability function arguments for all decision times for all subjects.",
|
|
87
87
|
required=True,
|
|
88
88
|
)
|
|
89
89
|
@click.option(
|
|
@@ -95,7 +95,7 @@ def cli():
|
|
|
95
95
|
@click.option(
|
|
96
96
|
"--alg_update_func_filename",
|
|
97
97
|
type=click.Path(exists=True),
|
|
98
|
-
help="File that contains the per-
|
|
98
|
+
help="File that contains the per-subject update function used to determine the algorithm parameters at each update and relevant imports. May be a loss or estimating function, specified in a separate argument. The filename without its extension will be assumed to match the function name.",
|
|
99
99
|
required=True,
|
|
100
100
|
)
|
|
101
101
|
@click.option(
|
|
@@ -107,7 +107,7 @@ def cli():
|
|
|
107
107
|
@click.option(
|
|
108
108
|
"--alg_update_func_args_pickle",
|
|
109
109
|
type=click.File("rb"),
|
|
110
|
-
help="Pickled dictionary that contains the algorithm update function arguments for all update times for all
|
|
110
|
+
help="Pickled dictionary that contains the algorithm update function arguments for all update times for all subjects.",
|
|
111
111
|
required=True,
|
|
112
112
|
)
|
|
113
113
|
@click.option(
|
|
@@ -137,7 +137,7 @@ def cli():
|
|
|
137
137
|
@click.option(
|
|
138
138
|
"--inference_func_filename",
|
|
139
139
|
type=click.Path(exists=True),
|
|
140
|
-
help="File that contains the per-
|
|
140
|
+
help="File that contains the per-subject loss/estimating function used to determine the inference estimate and relevant imports. The filename without its extension will be assumed to match the function name.",
|
|
141
141
|
required=True,
|
|
142
142
|
)
|
|
143
143
|
@click.option(
|
|
@@ -155,56 +155,56 @@ def cli():
|
|
|
155
155
|
@click.option(
|
|
156
156
|
"--theta_calculation_func_filename",
|
|
157
157
|
type=click.Path(exists=True),
|
|
158
|
-
help="Path to file that allows one to actually calculate a theta estimate given the
|
|
158
|
+
help="Path to file that allows one to actually calculate a theta estimate given the analysis dataframe only. One must supply either this or a precomputed theta estimate. The filename without its extension will be assumed to match the function name.",
|
|
159
159
|
required=True,
|
|
160
160
|
)
|
|
161
161
|
@click.option(
|
|
162
|
-
"--
|
|
162
|
+
"--active_col_name",
|
|
163
163
|
type=str,
|
|
164
164
|
required=True,
|
|
165
|
-
help="Name of the binary column in the
|
|
165
|
+
help="Name of the binary column in the analysis dataframe that indicates whether a subject is in the deployment.",
|
|
166
166
|
)
|
|
167
167
|
@click.option(
|
|
168
168
|
"--action_col_name",
|
|
169
169
|
type=str,
|
|
170
170
|
required=True,
|
|
171
|
-
help="Name of the binary column in the
|
|
171
|
+
help="Name of the binary column in the analysis dataframe that indicates which action was taken.",
|
|
172
172
|
)
|
|
173
173
|
@click.option(
|
|
174
174
|
"--policy_num_col_name",
|
|
175
175
|
type=str,
|
|
176
176
|
required=True,
|
|
177
|
-
help="Name of the column in the
|
|
177
|
+
help="Name of the column in the analysis dataframe that indicates the policy number in use.",
|
|
178
178
|
)
|
|
179
179
|
@click.option(
|
|
180
180
|
"--calendar_t_col_name",
|
|
181
181
|
type=str,
|
|
182
182
|
required=True,
|
|
183
|
-
help="Name of the column in the
|
|
183
|
+
help="Name of the column in the analysis dataframe that indicates calendar time (shared integer index across subjects).",
|
|
184
184
|
)
|
|
185
185
|
@click.option(
|
|
186
|
-
"--
|
|
186
|
+
"--subject_id_col_name",
|
|
187
187
|
type=str,
|
|
188
188
|
required=True,
|
|
189
|
-
help="Name of the column in the
|
|
189
|
+
help="Name of the column in the analysis dataframe that indicates subject id.",
|
|
190
190
|
)
|
|
191
191
|
@click.option(
|
|
192
192
|
"--action_prob_col_name",
|
|
193
193
|
type=str,
|
|
194
194
|
required=True,
|
|
195
|
-
help="Name of the column in the
|
|
195
|
+
help="Name of the column in the analysis dataframe that gives action one probabilities.",
|
|
196
196
|
)
|
|
197
197
|
@click.option(
|
|
198
198
|
"--reward_col_name",
|
|
199
199
|
type=str,
|
|
200
200
|
required=True,
|
|
201
|
-
help="Name of the column in the
|
|
201
|
+
help="Name of the column in the analysis dataframe that gives rewards.",
|
|
202
202
|
)
|
|
203
203
|
@click.option(
|
|
204
204
|
"--suppress_interactive_data_checks",
|
|
205
205
|
type=bool,
|
|
206
206
|
default=False,
|
|
207
|
-
help="Flag to suppress any data checks that require
|
|
207
|
+
help="Flag to suppress any data checks that require subject input. This is suitable for tests and large simulations",
|
|
208
208
|
)
|
|
209
209
|
@click.option(
|
|
210
210
|
"--suppress_all_data_checks",
|
|
@@ -232,13 +232,13 @@ def cli():
|
|
|
232
232
|
help="Flag to collect data for supervised learning blowup detection. This will write a single datum and label to a file in the same directory as the input files.",
|
|
233
233
|
)
|
|
234
234
|
@click.option(
|
|
235
|
-
"--
|
|
235
|
+
"--form_adjusted_meat_adjustments_explicitly",
|
|
236
236
|
type=bool,
|
|
237
237
|
default=False,
|
|
238
|
-
help="If True, explicitly forms the per-
|
|
238
|
+
help="If True, explicitly forms the per-subject meat adjustments that differentiate the adaptive sandwich from the classical sandwich. This is for diagnostic purposes, as the adaptive sandwich is formed without doing this.",
|
|
239
239
|
)
|
|
240
240
|
@click.option(
|
|
241
|
-
"--
|
|
241
|
+
"--stabilize_joint_adjusted_bread_inverse",
|
|
242
242
|
type=bool,
|
|
243
243
|
default=True,
|
|
244
244
|
help="If True, stabilizes the joint adaptive bread inverse matrix if it does not meet conditioning thresholds.",
|
|
@@ -248,7 +248,7 @@ def analyze_dataset_wrapper(**kwargs):
|
|
|
248
248
|
This function is a wrapper around analyze_dataset to facilitate command line use.
|
|
249
249
|
|
|
250
250
|
From the command line, we will take pickles and filenames for Python objects.
|
|
251
|
-
|
|
251
|
+
We unpickle/load files here for passing to the implementation function, which
|
|
252
252
|
may also be called in its own right with in-memory objects.
|
|
253
253
|
|
|
254
254
|
See analyze_dataset for the underlying details.
|
|
@@ -256,18 +256,20 @@ def analyze_dataset_wrapper(**kwargs):
|
|
|
256
256
|
Returns: None
|
|
257
257
|
"""
|
|
258
258
|
|
|
259
|
-
# Pass along the folder the
|
|
260
|
-
# Do it now because we will be removing the
|
|
261
|
-
kwargs["output_dir"] = pathlib.Path(
|
|
259
|
+
# Pass along the folder the analysis dataframe is in as the output folder.
|
|
260
|
+
# Do it now because we will be removing the analysis dataframe pickle from kwargs.
|
|
261
|
+
kwargs["output_dir"] = pathlib.Path(
|
|
262
|
+
kwargs["analysis_df_pickle"].name
|
|
263
|
+
).parent.resolve()
|
|
262
264
|
|
|
263
265
|
# Unpickle pickles and replace those args in kwargs
|
|
264
|
-
kwargs["
|
|
266
|
+
kwargs["analysis_df"] = pickle.load(kwargs["analysis_df_pickle"])
|
|
265
267
|
kwargs["action_prob_func_args"] = pickle.load(
|
|
266
268
|
kwargs["action_prob_func_args_pickle"]
|
|
267
269
|
)
|
|
268
270
|
kwargs["alg_update_func_args"] = pickle.load(kwargs["alg_update_func_args_pickle"])
|
|
269
271
|
|
|
270
|
-
kwargs.pop("
|
|
272
|
+
kwargs.pop("analysis_df_pickle")
|
|
271
273
|
kwargs.pop("action_prob_func_args_pickle")
|
|
272
274
|
kwargs.pop("alg_update_func_args_pickle")
|
|
273
275
|
|
|
@@ -295,7 +297,7 @@ def analyze_dataset_wrapper(**kwargs):
|
|
|
295
297
|
|
|
296
298
|
def analyze_dataset(
|
|
297
299
|
output_dir: pathlib.Path | str,
|
|
298
|
-
|
|
300
|
+
analysis_df: pd.DataFrame,
|
|
299
301
|
action_prob_func: Callable,
|
|
300
302
|
action_prob_func_args: dict[int, Any],
|
|
301
303
|
action_prob_func_args_beta_index: int,
|
|
@@ -310,19 +312,19 @@ def analyze_dataset(
|
|
|
310
312
|
inference_func_type: str,
|
|
311
313
|
inference_func_args_theta_index: int,
|
|
312
314
|
theta_calculation_func: Callable[[pd.DataFrame], jnp.ndarray],
|
|
313
|
-
|
|
315
|
+
active_col_name: str,
|
|
314
316
|
action_col_name: str,
|
|
315
317
|
policy_num_col_name: str,
|
|
316
318
|
calendar_t_col_name: str,
|
|
317
|
-
|
|
319
|
+
subject_id_col_name: str,
|
|
318
320
|
action_prob_col_name: str,
|
|
319
321
|
reward_col_name: str,
|
|
320
322
|
suppress_interactive_data_checks: bool,
|
|
321
323
|
suppress_all_data_checks: bool,
|
|
322
324
|
small_sample_correction: str,
|
|
323
325
|
collect_data_for_blowup_supervised_learning: bool,
|
|
324
|
-
|
|
325
|
-
|
|
326
|
+
form_adjusted_meat_adjustments_explicitly: bool,
|
|
327
|
+
stabilize_joint_adjusted_bread_inverse: bool,
|
|
326
328
|
) -> None:
|
|
327
329
|
"""
|
|
328
330
|
Analyzes a dataset to provide a parameter estimate and an estimate of its variance using adaptive and classical sandwich estimators.
|
|
@@ -337,8 +339,8 @@ def analyze_dataset(
|
|
|
337
339
|
Parameters:
|
|
338
340
|
output_dir (pathlib.Path | str):
|
|
339
341
|
Directory in which to save output files.
|
|
340
|
-
|
|
341
|
-
DataFrame containing the
|
|
342
|
+
analysis_df (pd.DataFrame):
|
|
343
|
+
DataFrame containing the deployment data.
|
|
342
344
|
action_prob_func (callable):
|
|
343
345
|
Action probability function.
|
|
344
346
|
action_prob_func_args (dict[int, Any]):
|
|
@@ -364,21 +366,21 @@ def analyze_dataset(
|
|
|
364
366
|
inference_func_args_theta_index (int):
|
|
365
367
|
Index for theta in inference function arguments.
|
|
366
368
|
theta_calculation_func (callable):
|
|
367
|
-
Function to estimate theta from the
|
|
368
|
-
|
|
369
|
-
Column name indicating if a
|
|
369
|
+
Function to estimate theta from the analysis dataframe.
|
|
370
|
+
active_col_name (str):
|
|
371
|
+
Column name indicating if a subject is active in the analysis dataframe.
|
|
370
372
|
action_col_name (str):
|
|
371
|
-
Column name for actions in the
|
|
373
|
+
Column name for actions in the analysis dataframe.
|
|
372
374
|
policy_num_col_name (str):
|
|
373
|
-
Column name for policy numbers in the
|
|
375
|
+
Column name for policy numbers in the analysis dataframe.
|
|
374
376
|
calendar_t_col_name (str):
|
|
375
|
-
Column name for calendar time in the
|
|
376
|
-
|
|
377
|
-
Column name for
|
|
377
|
+
Column name for calendar time in the analysis dataframe.
|
|
378
|
+
subject_id_col_name (str):
|
|
379
|
+
Column name for subject IDs in the analysis dataframe.
|
|
378
380
|
action_prob_col_name (str):
|
|
379
|
-
Column name for action probabilities in the
|
|
381
|
+
Column name for action probabilities in the analysis dataframe.
|
|
380
382
|
reward_col_name (str):
|
|
381
|
-
Column name for rewards in the
|
|
383
|
+
Column name for rewards in the analysis dataframe.
|
|
382
384
|
suppress_interactive_data_checks (bool):
|
|
383
385
|
Whether to suppress interactive data checks. This should be used in simulations, for example.
|
|
384
386
|
suppress_all_data_checks (bool):
|
|
@@ -387,11 +389,11 @@ def analyze_dataset(
|
|
|
387
389
|
Type of small sample correction to apply.
|
|
388
390
|
collect_data_for_blowup_supervised_learning (bool):
|
|
389
391
|
Whether to collect data for doing supervised learning about adaptive sandwich blowup.
|
|
390
|
-
|
|
391
|
-
If True, explicitly forms the per-
|
|
392
|
+
form_adjusted_meat_adjustments_explicitly (bool):
|
|
393
|
+
If True, explicitly forms the per-subject meat adjustments that differentiate the adaptive
|
|
392
394
|
sandwich from the classical sandwich. This is for diagnostic purposes, as the
|
|
393
395
|
adaptive sandwich is formed without doing this.
|
|
394
|
-
|
|
396
|
+
stabilize_joint_adjusted_bread_inverse (bool):
|
|
395
397
|
If True, stabilizes the joint adaptive bread inverse matrix if it does not meet conditioning
|
|
396
398
|
thresholds.
|
|
397
399
|
|
|
@@ -406,19 +408,19 @@ def analyze_dataset(
|
|
|
406
408
|
level=logging.INFO,
|
|
407
409
|
)
|
|
408
410
|
|
|
409
|
-
theta_est = jnp.array(theta_calculation_func(
|
|
411
|
+
theta_est = jnp.array(theta_calculation_func(analysis_df))
|
|
410
412
|
|
|
411
413
|
beta_dim = calculate_beta_dim(
|
|
412
414
|
action_prob_func_args, action_prob_func_args_beta_index
|
|
413
415
|
)
|
|
414
416
|
if not suppress_all_data_checks:
|
|
415
417
|
input_checks.perform_first_wave_input_checks(
|
|
416
|
-
|
|
417
|
-
|
|
418
|
+
analysis_df,
|
|
419
|
+
active_col_name,
|
|
418
420
|
action_col_name,
|
|
419
421
|
policy_num_col_name,
|
|
420
422
|
calendar_t_col_name,
|
|
421
|
-
|
|
423
|
+
subject_id_col_name,
|
|
422
424
|
action_prob_col_name,
|
|
423
425
|
reward_col_name,
|
|
424
426
|
action_prob_func,
|
|
@@ -439,7 +441,7 @@ def analyze_dataset(
|
|
|
439
441
|
|
|
440
442
|
beta_index_by_policy_num, initial_policy_num = (
|
|
441
443
|
construct_beta_index_by_policy_num_map(
|
|
442
|
-
|
|
444
|
+
analysis_df, policy_num_col_name, active_col_name
|
|
443
445
|
)
|
|
444
446
|
)
|
|
445
447
|
|
|
@@ -447,11 +449,11 @@ def analyze_dataset(
|
|
|
447
449
|
beta_index_by_policy_num, alg_update_func_args, alg_update_func_args_beta_index
|
|
448
450
|
)
|
|
449
451
|
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
452
|
+
action_by_decision_time_by_subject_id, policy_num_by_decision_time_by_subject_id = (
|
|
453
|
+
extract_action_and_policy_by_decision_time_by_subject_id(
|
|
454
|
+
analysis_df,
|
|
455
|
+
subject_id_col_name,
|
|
456
|
+
active_col_name,
|
|
455
457
|
calendar_t_col_name,
|
|
456
458
|
action_col_name,
|
|
457
459
|
policy_num_col_name,
|
|
@@ -459,45 +461,45 @@ def analyze_dataset(
|
|
|
459
461
|
)
|
|
460
462
|
|
|
461
463
|
(
|
|
462
|
-
|
|
464
|
+
inference_func_args_by_subject_id,
|
|
463
465
|
inference_func_args_action_prob_index,
|
|
464
|
-
|
|
466
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
465
467
|
) = process_inference_func_args(
|
|
466
468
|
inference_func,
|
|
467
469
|
inference_func_args_theta_index,
|
|
468
|
-
|
|
470
|
+
analysis_df,
|
|
469
471
|
theta_est,
|
|
470
472
|
action_prob_col_name,
|
|
471
473
|
calendar_t_col_name,
|
|
472
|
-
|
|
473
|
-
|
|
474
|
+
subject_id_col_name,
|
|
475
|
+
active_col_name,
|
|
474
476
|
)
|
|
475
477
|
|
|
476
|
-
# Use a per-
|
|
478
|
+
# Use a per-subject weighted estimating function stacking function to derive classical and joint
|
|
477
479
|
# adaptive meat and inverse bread matrices. This is facilitated because the *value* of the
|
|
478
480
|
# weighted and unweighted stacks are the same, as the weights evaluate to 1 pre-differentiation.
|
|
479
481
|
logger.info(
|
|
480
|
-
"Constructing joint adaptive bread inverse matrix, joint adaptive meat matrix, the classical analogs, and the avg estimating function stack across
|
|
482
|
+
"Constructing joint adaptive bread inverse matrix, joint adaptive meat matrix, the classical analogs, and the avg estimating function stack across subjects."
|
|
481
483
|
)
|
|
482
484
|
|
|
483
|
-
|
|
485
|
+
subject_ids = jnp.array(analysis_df[subject_id_col_name].unique())
|
|
484
486
|
(
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
487
|
+
stabilized_joint_adjusted_bread_inverse_matrix,
|
|
488
|
+
raw_joint_adjusted_bread_inverse_matrix,
|
|
489
|
+
joint_adjusted_meat_matrix,
|
|
490
|
+
joint_adjusted_sandwich_matrix,
|
|
489
491
|
classical_bread_inverse_matrix,
|
|
490
492
|
classical_meat_matrix,
|
|
491
493
|
classical_sandwich_var_estimate,
|
|
492
494
|
avg_estimating_function_stack,
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
) =
|
|
495
|
+
per_subject_estimating_function_stacks,
|
|
496
|
+
per_subject_adjusted_corrections,
|
|
497
|
+
per_subject_classical_corrections,
|
|
498
|
+
per_subject_adjusted_meat_adjustments,
|
|
499
|
+
) = construct_classical_and_adjusted_sandwiches(
|
|
498
500
|
theta_est,
|
|
499
501
|
all_post_update_betas,
|
|
500
|
-
|
|
502
|
+
subject_ids,
|
|
501
503
|
action_prob_func,
|
|
502
504
|
action_prob_func_args_beta_index,
|
|
503
505
|
alg_update_func,
|
|
@@ -511,31 +513,27 @@ def analyze_dataset(
|
|
|
511
513
|
inference_func_args_theta_index,
|
|
512
514
|
inference_func_args_action_prob_index,
|
|
513
515
|
action_prob_func_args,
|
|
514
|
-
|
|
516
|
+
policy_num_by_decision_time_by_subject_id,
|
|
515
517
|
initial_policy_num,
|
|
516
518
|
beta_index_by_policy_num,
|
|
517
|
-
|
|
518
|
-
|
|
519
|
+
inference_func_args_by_subject_id,
|
|
520
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
519
521
|
alg_update_func_args,
|
|
520
|
-
|
|
522
|
+
action_by_decision_time_by_subject_id,
|
|
521
523
|
suppress_all_data_checks,
|
|
522
524
|
suppress_interactive_data_checks,
|
|
523
525
|
small_sample_correction,
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
526
|
+
form_adjusted_meat_adjustments_explicitly,
|
|
527
|
+
stabilize_joint_adjusted_bread_inverse,
|
|
528
|
+
analysis_df,
|
|
529
|
+
active_col_name,
|
|
528
530
|
action_col_name,
|
|
529
531
|
calendar_t_col_name,
|
|
530
|
-
|
|
532
|
+
subject_id_col_name,
|
|
531
533
|
action_prob_func_args,
|
|
532
534
|
action_prob_col_name,
|
|
533
535
|
)
|
|
534
536
|
|
|
535
|
-
joint_adaptive_bread_inverse_cond = jnp.linalg.cond(
|
|
536
|
-
stabilized_joint_adaptive_bread_inverse_matrix
|
|
537
|
-
)
|
|
538
|
-
|
|
539
537
|
theta_dim = len(theta_est)
|
|
540
538
|
if not suppress_all_data_checks:
|
|
541
539
|
input_checks.require_estimating_functions_sum_to_zero(
|
|
@@ -547,18 +545,18 @@ def analyze_dataset(
|
|
|
547
545
|
|
|
548
546
|
# This bottom right corner of the joint (betas and theta) variance matrix is the portion
|
|
549
547
|
# corresponding to just theta.
|
|
550
|
-
|
|
548
|
+
adjusted_sandwich_var_estimate = joint_adjusted_sandwich_matrix[
|
|
551
549
|
-theta_dim:, -theta_dim:
|
|
552
550
|
]
|
|
553
551
|
|
|
554
552
|
# Check for negative diagonal elements and set them to zero if found
|
|
555
|
-
adaptive_diagonal = np.diag(
|
|
553
|
+
adaptive_diagonal = np.diag(adjusted_sandwich_var_estimate)
|
|
556
554
|
if np.any(adaptive_diagonal < 0):
|
|
557
555
|
logger.warning(
|
|
558
556
|
"Found negative diagonal elements in adaptive sandwich variance estimate. Setting them to zero."
|
|
559
557
|
)
|
|
560
558
|
np.fill_diagonal(
|
|
561
|
-
|
|
559
|
+
adjusted_sandwich_var_estimate, np.maximum(adaptive_diagonal, 0)
|
|
562
560
|
)
|
|
563
561
|
|
|
564
562
|
logger.info("Writing results to file...")
|
|
@@ -567,7 +565,7 @@ def analyze_dataset(
|
|
|
567
565
|
|
|
568
566
|
analysis_dict = {
|
|
569
567
|
"theta_est": theta_est,
|
|
570
|
-
"
|
|
568
|
+
"adjusted_sandwich_var_estimate": adjusted_sandwich_var_estimate,
|
|
571
569
|
"classical_sandwich_var_estimate": classical_sandwich_var_estimate,
|
|
572
570
|
}
|
|
573
571
|
with open(output_folder_abs_path / "analysis.pkl", "wb") as f:
|
|
@@ -576,21 +574,29 @@ def analyze_dataset(
|
|
|
576
574
|
f,
|
|
577
575
|
)
|
|
578
576
|
|
|
577
|
+
joint_adjusted_bread_inverse_cond = jnp.linalg.cond(
|
|
578
|
+
raw_joint_adjusted_bread_inverse_matrix
|
|
579
|
+
)
|
|
580
|
+
logger.info(
|
|
581
|
+
"Joint adjusted bread inverse condition number: %f",
|
|
582
|
+
joint_adjusted_bread_inverse_cond,
|
|
583
|
+
)
|
|
584
|
+
|
|
579
585
|
debug_pieces_dict = {
|
|
580
586
|
"theta_est": theta_est,
|
|
581
|
-
"
|
|
587
|
+
"adjusted_sandwich_var_estimate": adjusted_sandwich_var_estimate,
|
|
582
588
|
"classical_sandwich_var_estimate": classical_sandwich_var_estimate,
|
|
583
|
-
"raw_joint_bread_inverse_matrix":
|
|
584
|
-
"stabilized_joint_bread_inverse_matrix":
|
|
585
|
-
"joint_meat_matrix":
|
|
589
|
+
"raw_joint_bread_inverse_matrix": raw_joint_adjusted_bread_inverse_matrix,
|
|
590
|
+
"stabilized_joint_bread_inverse_matrix": stabilized_joint_adjusted_bread_inverse_matrix,
|
|
591
|
+
"joint_meat_matrix": joint_adjusted_meat_matrix,
|
|
586
592
|
"classical_bread_inverse_matrix": classical_bread_inverse_matrix,
|
|
587
593
|
"classical_meat_matrix": classical_meat_matrix,
|
|
588
|
-
"all_estimating_function_stacks":
|
|
589
|
-
"joint_bread_inverse_condition_number":
|
|
594
|
+
"all_estimating_function_stacks": per_subject_estimating_function_stacks,
|
|
595
|
+
"joint_bread_inverse_condition_number": joint_adjusted_bread_inverse_cond,
|
|
590
596
|
"all_post_update_betas": all_post_update_betas,
|
|
591
|
-
"
|
|
592
|
-
"
|
|
593
|
-
"
|
|
597
|
+
"per_subject_adjusted_corrections": per_subject_adjusted_corrections,
|
|
598
|
+
"per_subject_classical_corrections": per_subject_classical_corrections,
|
|
599
|
+
"per_subject_adjusted_meat_adjustments": per_subject_adjusted_meat_adjustments,
|
|
594
600
|
}
|
|
595
601
|
with open(output_folder_abs_path / "debug_pieces.pkl", "wb") as f:
|
|
596
602
|
pickle.dump(
|
|
@@ -600,25 +606,25 @@ def analyze_dataset(
|
|
|
600
606
|
|
|
601
607
|
if collect_data_for_blowup_supervised_learning:
|
|
602
608
|
datum_and_label_dict = get_datum_for_blowup_supervised_learning.get_datum_for_blowup_supervised_learning(
|
|
603
|
-
|
|
604
|
-
|
|
609
|
+
raw_joint_adjusted_bread_inverse_matrix,
|
|
610
|
+
joint_adjusted_bread_inverse_cond,
|
|
605
611
|
avg_estimating_function_stack,
|
|
606
|
-
|
|
612
|
+
per_subject_estimating_function_stacks,
|
|
607
613
|
all_post_update_betas,
|
|
608
|
-
|
|
609
|
-
|
|
614
|
+
analysis_df,
|
|
615
|
+
active_col_name,
|
|
610
616
|
calendar_t_col_name,
|
|
611
617
|
action_prob_col_name,
|
|
612
|
-
|
|
618
|
+
subject_id_col_name,
|
|
613
619
|
reward_col_name,
|
|
614
620
|
theta_est,
|
|
615
|
-
|
|
616
|
-
|
|
621
|
+
adjusted_sandwich_var_estimate,
|
|
622
|
+
subject_ids,
|
|
617
623
|
beta_dim,
|
|
618
624
|
theta_dim,
|
|
619
625
|
initial_policy_num,
|
|
620
626
|
beta_index_by_policy_num,
|
|
621
|
-
|
|
627
|
+
policy_num_by_decision_time_by_subject_id,
|
|
622
628
|
theta_calculation_func,
|
|
623
629
|
action_prob_func,
|
|
624
630
|
action_prob_func_args_beta_index,
|
|
@@ -626,16 +632,16 @@ def analyze_dataset(
|
|
|
626
632
|
inference_func_type,
|
|
627
633
|
inference_func_args_theta_index,
|
|
628
634
|
inference_func_args_action_prob_index,
|
|
629
|
-
|
|
635
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
630
636
|
action_prob_func_args,
|
|
631
|
-
|
|
637
|
+
action_by_decision_time_by_subject_id,
|
|
632
638
|
)
|
|
633
639
|
|
|
634
640
|
with open(output_folder_abs_path / "supervised_learning_datum.pkl", "wb") as f:
|
|
635
641
|
pickle.dump(datum_and_label_dict, f)
|
|
636
642
|
|
|
637
643
|
print(f"\nParameter estimate:\n {theta_est}")
|
|
638
|
-
print(f"\
|
|
644
|
+
print(f"\nAdjusted sandwich variance estimate:\n {adjusted_sandwich_var_estimate}")
|
|
639
645
|
print(
|
|
640
646
|
f"\nClassical sandwich variance estimate:\n {classical_sandwich_var_estimate}\n"
|
|
641
647
|
)
|
|
@@ -646,15 +652,15 @@ def analyze_dataset(
|
|
|
646
652
|
def process_inference_func_args(
|
|
647
653
|
inference_func: callable,
|
|
648
654
|
inference_func_args_theta_index: int,
|
|
649
|
-
|
|
655
|
+
analysis_df: pd.DataFrame,
|
|
650
656
|
theta_est: jnp.ndarray,
|
|
651
657
|
action_prob_col_name: str,
|
|
652
658
|
calendar_t_col_name: str,
|
|
653
|
-
|
|
654
|
-
|
|
659
|
+
subject_id_col_name: str,
|
|
660
|
+
active_col_name: str,
|
|
655
661
|
) -> tuple[dict[collections.abc.Hashable, tuple[Any, ...]], int]:
|
|
656
662
|
"""
|
|
657
|
-
Collects the inference function arguments for each
|
|
663
|
+
Collects the inference function arguments for each subject from the analysis DataFrame.
|
|
658
664
|
|
|
659
665
|
Note that theta and action probabilities, if present, will be replaced later
|
|
660
666
|
so that the function can be differentiated with respect to shared versions
|
|
@@ -665,32 +671,32 @@ def process_inference_func_args(
|
|
|
665
671
|
The inference function to be used.
|
|
666
672
|
inference_func_args_theta_index (int):
|
|
667
673
|
The index of the theta parameter in the inference function's arguments.
|
|
668
|
-
|
|
669
|
-
The
|
|
674
|
+
analysis_df (pandas.DataFrame):
|
|
675
|
+
The analysis DataFrame.
|
|
670
676
|
theta_est (jnp.ndarray):
|
|
671
677
|
The estimate of the parameter vector.
|
|
672
678
|
action_prob_col_name (str):
|
|
673
|
-
The name of the column in the
|
|
679
|
+
The name of the column in the analysis DataFrame that gives action probabilities.
|
|
674
680
|
calendar_t_col_name (str):
|
|
675
|
-
The name of the column in the
|
|
676
|
-
|
|
677
|
-
The name of the column in the
|
|
678
|
-
|
|
679
|
-
The name of the binary column in the
|
|
681
|
+
The name of the column in the analysis DataFrame that indicates calendar time.
|
|
682
|
+
subject_id_col_name (str):
|
|
683
|
+
The name of the column in the analysis DataFrame that indicates subject ID.
|
|
684
|
+
active_col_name (str):
|
|
685
|
+
The name of the binary column in the analysis DataFrame that indicates whether a subject is in the deployment.
|
|
680
686
|
Returns:
|
|
681
687
|
tuple[dict[collections.abc.Hashable, tuple[Any, ...]], int, dict[collections.abc.Hashable, jnp.ndarray[int]]]:
|
|
682
688
|
A tuple containing
|
|
683
|
-
- the inference function arguments dictionary for each
|
|
689
|
+
- the inference function arguments dictionary for each subject
|
|
684
690
|
- the index of the action probabilities argument
|
|
685
|
-
- a dictionary mapping
|
|
691
|
+
- a dictionary mapping subject IDs to the decision times to which action probabilities correspond
|
|
686
692
|
"""
|
|
687
693
|
|
|
688
694
|
num_args = inference_func.__code__.co_argcount
|
|
689
695
|
inference_func_arg_names = inference_func.__code__.co_varnames[:num_args]
|
|
690
|
-
|
|
696
|
+
inference_func_args_by_subject_id = {}
|
|
691
697
|
|
|
692
698
|
inference_func_args_action_prob_index = -1
|
|
693
|
-
|
|
699
|
+
inference_action_prob_decision_times_by_subject_id = {}
|
|
694
700
|
|
|
695
701
|
using_action_probs = action_prob_col_name in inference_func_arg_names
|
|
696
702
|
if using_action_probs:
|
|
@@ -698,34 +704,36 @@ def process_inference_func_args(
|
|
|
698
704
|
action_prob_col_name
|
|
699
705
|
)
|
|
700
706
|
|
|
701
|
-
for
|
|
702
|
-
|
|
703
|
-
|
|
707
|
+
for subject_id in analysis_df[subject_id_col_name].unique():
|
|
708
|
+
subject_args_list = []
|
|
709
|
+
filtered_subject_data = analysis_df.loc[
|
|
710
|
+
analysis_df[subject_id_col_name] == subject_id
|
|
711
|
+
]
|
|
704
712
|
for idx, col_name in enumerate(inference_func_arg_names):
|
|
705
713
|
if idx == inference_func_args_theta_index:
|
|
706
|
-
|
|
714
|
+
subject_args_list.append(theta_est)
|
|
707
715
|
continue
|
|
708
|
-
|
|
709
|
-
|
|
716
|
+
subject_args_list.append(
|
|
717
|
+
get_active_df_column(filtered_subject_data, col_name, active_col_name)
|
|
710
718
|
)
|
|
711
|
-
|
|
719
|
+
inference_func_args_by_subject_id[subject_id] = tuple(subject_args_list)
|
|
712
720
|
if using_action_probs:
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
721
|
+
inference_action_prob_decision_times_by_subject_id[subject_id] = (
|
|
722
|
+
get_active_df_column(
|
|
723
|
+
filtered_subject_data, calendar_t_col_name, active_col_name
|
|
716
724
|
)
|
|
717
725
|
)
|
|
718
726
|
|
|
719
727
|
return (
|
|
720
|
-
|
|
728
|
+
inference_func_args_by_subject_id,
|
|
721
729
|
inference_func_args_action_prob_index,
|
|
722
|
-
|
|
730
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
723
731
|
)
|
|
724
732
|
|
|
725
733
|
|
|
726
|
-
def
|
|
734
|
+
def single_subject_weighted_estimating_function_stacker(
|
|
727
735
|
beta_dim: int,
|
|
728
|
-
|
|
736
|
+
subject_id: collections.abc.Hashable,
|
|
729
737
|
action_prob_func: callable,
|
|
730
738
|
algorithm_estimating_func: callable,
|
|
731
739
|
inference_estimating_func: callable,
|
|
@@ -759,12 +767,12 @@ def single_user_weighted_estimating_function_stacker(
|
|
|
759
767
|
beta_dim (int):
|
|
760
768
|
The dimension of each beta (algorithm parameter) vector.
|
|
761
769
|
|
|
762
|
-
|
|
763
|
-
The
|
|
770
|
+
subject_id (collections.abc.Hashable):
|
|
771
|
+
The subject ID for which to compute the weighted estimating function stack.
|
|
764
772
|
|
|
765
773
|
action_prob_func (callable):
|
|
766
774
|
The function used to compute the probability of action 1 at a given decision time for
|
|
767
|
-
a particular
|
|
775
|
+
a particular subject given their state and the algorithm parameters.
|
|
768
776
|
|
|
769
777
|
algorithm_estimating_func (callable):
|
|
770
778
|
The estimating function that corresponds to algorithm updates.
|
|
@@ -779,9 +787,9 @@ def single_user_weighted_estimating_function_stacker(
|
|
|
779
787
|
The index of the theta parameter in the inference loss or estimating function arguments.
|
|
780
788
|
|
|
781
789
|
action_prob_func_args_by_decision_time (dict[int, tuple[Any, ...]]):
|
|
782
|
-
A map from decision times to tuples of arguments for this
|
|
790
|
+
A map from decision times to tuples of arguments for this subject for the action
|
|
783
791
|
probability function. This is for all decision times (args are an empty
|
|
784
|
-
tuple if they are not in the
|
|
792
|
+
tuple if they are not in the deployment). Should be sorted by decision time. NOTE THAT THESE
|
|
785
793
|
ARGS DO NOT CONTAIN THE SHARED BETAS, making them impervious to the differentiation that
|
|
786
794
|
will occur.
|
|
787
795
|
|
|
@@ -792,21 +800,21 @@ def single_user_weighted_estimating_function_stacker(
|
|
|
792
800
|
|
|
793
801
|
threaded_update_func_args_by_policy_num (dict[int | float, tuple[Any, ...]]):
|
|
794
802
|
A map from policy numbers to tuples containing the arguments for
|
|
795
|
-
the corresponding estimating functions for this
|
|
803
|
+
the corresponding estimating functions for this subject, with the shared betas threaded in
|
|
796
804
|
for differentiation. This is for all non-initial, non-fallback policies. Policy numbers
|
|
797
805
|
should be sorted.
|
|
798
806
|
|
|
799
807
|
threaded_inference_func_args (tuple[Any, ...]):
|
|
800
808
|
A tuple containing the arguments for the inference
|
|
801
|
-
estimating function for this
|
|
809
|
+
estimating function for this subject, with the shared betas threaded in for differentiation.
|
|
802
810
|
|
|
803
811
|
policy_num_by_decision_time (dict[int, int | float]):
|
|
804
812
|
A dictionary mapping decision times to the policy number in use. This may be
|
|
805
|
-
|
|
813
|
+
subject-specific. Should be sorted by decision time. Only applies to active decision
|
|
806
814
|
times!
|
|
807
815
|
|
|
808
816
|
action_by_decision_time (dict[int, int]):
|
|
809
|
-
A dictionary mapping decision times to actions taken. Only applies to
|
|
817
|
+
A dictionary mapping decision times to actions taken. Only applies to active decision
|
|
810
818
|
times!
|
|
811
819
|
|
|
812
820
|
beta_index_by_policy_num (dict[int | float, int]):
|
|
@@ -814,19 +822,21 @@ def single_user_weighted_estimating_function_stacker(
|
|
|
814
822
|
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
815
823
|
|
|
816
824
|
Returns:
|
|
817
|
-
jnp.ndarray: A 1-D JAX NumPy array representing the
|
|
825
|
+
jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted estimating function
|
|
818
826
|
stack.
|
|
819
|
-
jnp.ndarray: A 2-D JAX NumPy matrix representing the
|
|
820
|
-
jnp.ndarray: A 2-D JAX NumPy matrix representing the
|
|
821
|
-
jnp.ndarray: A 2-D JAX NumPy matrix representing the
|
|
827
|
+
jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's adaptive meat contribution.
|
|
828
|
+
jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical meat contribution.
|
|
829
|
+
jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical bread contribution.
|
|
822
830
|
"""
|
|
823
831
|
|
|
824
|
-
logger.info(
|
|
832
|
+
logger.info(
|
|
833
|
+
"Computing weighted estimating function stack for subject %s.", subject_id
|
|
834
|
+
)
|
|
825
835
|
|
|
826
836
|
# First, reformat the supplied data into more convenient structures.
|
|
827
837
|
|
|
828
838
|
# 1. Form a dictionary mapping policy numbers to the first time they were
|
|
829
|
-
# applicable (for this
|
|
839
|
+
# applicable (for this subject). Note that this includes ALL policies, initial
|
|
830
840
|
# fallbacks included.
|
|
831
841
|
# Collect the first time after the first update separately for convenience.
|
|
832
842
|
# These are both used to form the Radon-Nikodym weights for the right times.
|
|
@@ -835,38 +845,38 @@ def single_user_weighted_estimating_function_stacker(
|
|
|
835
845
|
beta_index_by_policy_num,
|
|
836
846
|
)
|
|
837
847
|
|
|
838
|
-
# 2. Get the start and end times for this
|
|
839
|
-
|
|
840
|
-
|
|
848
|
+
# 2. Get the start and end times for this subject.
|
|
849
|
+
subject_start_time = math.inf
|
|
850
|
+
subject_end_time = -math.inf
|
|
841
851
|
for decision_time in action_by_decision_time:
|
|
842
|
-
|
|
843
|
-
|
|
852
|
+
subject_start_time = min(subject_start_time, decision_time)
|
|
853
|
+
subject_end_time = max(subject_end_time, decision_time)
|
|
844
854
|
|
|
845
855
|
# 3. Form a stack of weighted estimating equations, one for each update of the algorithm.
|
|
846
856
|
logger.info(
|
|
847
|
-
"Computing the algorithm component of the weighted estimating function stack for
|
|
848
|
-
|
|
857
|
+
"Computing the algorithm component of the weighted estimating function stack for subject %s.",
|
|
858
|
+
subject_id,
|
|
849
859
|
)
|
|
850
860
|
|
|
851
|
-
|
|
861
|
+
active_action_prob_func_args = [
|
|
852
862
|
args for args in action_prob_func_args_by_decision_time.values() if args
|
|
853
863
|
]
|
|
854
|
-
|
|
864
|
+
active_betas_list_by_decision_time_index = jnp.array(
|
|
855
865
|
[
|
|
856
866
|
action_prob_func_args[action_prob_func_args_beta_index]
|
|
857
|
-
for action_prob_func_args in
|
|
867
|
+
for action_prob_func_args in active_action_prob_func_args
|
|
858
868
|
]
|
|
859
869
|
)
|
|
860
|
-
|
|
870
|
+
active_actions_list_by_decision_time_index = jnp.array(
|
|
861
871
|
list(action_by_decision_time.values())
|
|
862
872
|
)
|
|
863
873
|
|
|
864
874
|
# Sort the threaded args by decision time to be cautious. We check if the
|
|
865
|
-
#
|
|
866
|
-
# subset of the
|
|
875
|
+
# subject id is present in the subject args dict because we may call this on a
|
|
876
|
+
# subset of the subject arg dict when we are batching arguments by shape
|
|
867
877
|
sorted_threaded_action_prob_args_by_decision_time = {
|
|
868
878
|
decision_time: threaded_action_prob_func_args_by_decision_time[decision_time]
|
|
869
|
-
for decision_time in range(
|
|
879
|
+
for decision_time in range(subject_start_time, subject_end_time + 1)
|
|
870
880
|
if decision_time in threaded_action_prob_func_args_by_decision_time
|
|
871
881
|
}
|
|
872
882
|
|
|
@@ -897,19 +907,19 @@ def single_user_weighted_estimating_function_stacker(
|
|
|
897
907
|
# Just grab the original beta from the update function arguments. This is the same
|
|
898
908
|
# value, but impervious to differentiation with respect to all_post_update_betas. The
|
|
899
909
|
# args, on the other hand, are a function of all_post_update_betas.
|
|
900
|
-
|
|
910
|
+
active_weights = jax.vmap(
|
|
901
911
|
fun=get_radon_nikodym_weight,
|
|
902
912
|
in_axes=[0, None, None, 0] + batch_axes,
|
|
903
913
|
out_axes=0,
|
|
904
914
|
)(
|
|
905
|
-
|
|
915
|
+
active_betas_list_by_decision_time_index,
|
|
906
916
|
action_prob_func,
|
|
907
917
|
action_prob_func_args_beta_index,
|
|
908
|
-
|
|
918
|
+
active_actions_list_by_decision_time_index,
|
|
909
919
|
*batched_threaded_arg_tensors,
|
|
910
920
|
)
|
|
911
921
|
|
|
912
|
-
|
|
922
|
+
active_index = 0
|
|
913
923
|
decision_time_to_all_weights_index_offset = min(
|
|
914
924
|
sorted_threaded_action_prob_args_by_decision_time
|
|
915
925
|
)
|
|
@@ -918,35 +928,35 @@ def single_user_weighted_estimating_function_stacker(
|
|
|
918
928
|
decision_time,
|
|
919
929
|
args,
|
|
920
930
|
) in sorted_threaded_action_prob_args_by_decision_time.items():
|
|
921
|
-
all_weights_raw.append(
|
|
922
|
-
|
|
931
|
+
all_weights_raw.append(active_weights[active_index] if args else 1.0)
|
|
932
|
+
active_index += 1
|
|
923
933
|
all_weights = jnp.array(all_weights_raw)
|
|
924
934
|
|
|
925
935
|
algorithm_component = jnp.concatenate(
|
|
926
936
|
[
|
|
927
937
|
# Here we compute a product of Radon-Nikodym weights
|
|
928
938
|
# for all decision times after the first update and before the update
|
|
929
|
-
# update under consideration took effect, for which the
|
|
939
|
+
# under consideration took effect, for which the subject was in the deployment.
|
|
930
940
|
(
|
|
931
941
|
jnp.prod(
|
|
932
942
|
all_weights[
|
|
933
|
-
# The earliest time after the first update where the
|
|
934
|
-
# the
|
|
943
|
+
# The earliest time after the first update where the subject was in
|
|
944
|
+
# the deployment
|
|
935
945
|
max(
|
|
936
946
|
first_time_after_first_update,
|
|
937
|
-
|
|
947
|
+
subject_start_time,
|
|
938
948
|
)
|
|
939
949
|
- decision_time_to_all_weights_index_offset :
|
|
940
|
-
# One more than the latest time the
|
|
950
|
+
# One more than the latest time the subject was in the deployment before the time
|
|
941
951
|
# the update under consideration first applied. Note the + 1 because range
|
|
942
952
|
# does not include the right endpoint.
|
|
943
953
|
min(
|
|
944
954
|
min_time_by_policy_num.get(policy_num, math.inf),
|
|
945
|
-
|
|
955
|
+
subject_end_time + 1,
|
|
946
956
|
)
|
|
947
957
|
- decision_time_to_all_weights_index_offset,
|
|
948
958
|
]
|
|
949
|
-
# If the
|
|
959
|
+
# If the subject exited the deployment before there were any updates,
|
|
950
960
|
# this variable will be None and the above code to grab a weight would
|
|
951
961
|
# throw an error. Just use 1 to include the unweighted estimating function
|
|
952
962
|
# if they have data to contribute to the update.
|
|
@@ -954,8 +964,8 @@ def single_user_weighted_estimating_function_stacker(
|
|
|
954
964
|
else 1
|
|
955
965
|
) # Now use the above to weight the alg estimating function for this update
|
|
956
966
|
* algorithm_estimating_func(*update_args)
|
|
957
|
-
# If there are no arguments for the update function, the
|
|
958
|
-
#
|
|
967
|
+
# If there are no arguments for the update function, the subject is not yet in the
|
|
968
|
+
# deployment, so we just add a zero vector contribution to the sum across subjects.
|
|
959
969
|
# Note that after they exit, they still contribute all their data to later
|
|
960
970
|
# updates.
|
|
961
971
|
if update_args
|
|
@@ -974,17 +984,17 @@ def single_user_weighted_estimating_function_stacker(
|
|
|
974
984
|
)
|
|
975
985
|
# 4. Form the weighted inference estimating equation.
|
|
976
986
|
logger.info(
|
|
977
|
-
"Computing the inference component of the weighted estimating function stack for
|
|
978
|
-
|
|
987
|
+
"Computing the inference component of the weighted estimating function stack for subject %s.",
|
|
988
|
+
subject_id,
|
|
979
989
|
)
|
|
980
990
|
inference_component = jnp.prod(
|
|
981
991
|
all_weights[
|
|
982
|
-
max(first_time_after_first_update,
|
|
983
|
-
- decision_time_to_all_weights_index_offset :
|
|
992
|
+
max(first_time_after_first_update, subject_start_time)
|
|
993
|
+
- decision_time_to_all_weights_index_offset : subject_end_time
|
|
984
994
|
+ 1
|
|
985
995
|
- decision_time_to_all_weights_index_offset,
|
|
986
996
|
]
|
|
987
|
-
# If the
|
|
997
|
+
# If the subject exited the deployment before there were any updates,
|
|
988
998
|
# this variable will be None and the above code to grab a weight would
|
|
989
999
|
# throw an error. Just use 1 to include the unweighted estimating function
|
|
990
1000
|
# if they have data to contribute here (pretty sure everyone should?)
|
|
@@ -993,18 +1003,18 @@ def single_user_weighted_estimating_function_stacker(
|
|
|
993
1003
|
) * inference_estimating_func(*threaded_inference_func_args)
|
|
994
1004
|
|
|
995
1005
|
# 5. Concatenate the two components to form the weighted estimating function stack for this
|
|
996
|
-
#
|
|
1006
|
+
# subject.
|
|
997
1007
|
weighted_stack = jnp.concatenate([algorithm_component, inference_component])
|
|
998
1008
|
|
|
999
1009
|
# 6. Return the following outputs:
|
|
1000
|
-
# a. The first is simply the weighted estimating function stack for this
|
|
1010
|
+
# a. The first is simply the weighted estimating function stack for this subject. The average
|
|
1001
1011
|
# of these is what we differentiate with respect to theta to form the inverse adaptive joint
|
|
1002
1012
|
# bread matrix, and we also compare that average to zero to check the estimating functions'
|
|
1003
1013
|
# fidelity.
|
|
1004
|
-
# b. The average outer product of these per-
|
|
1014
|
+
# b. The average outer product of these per-subject stacks across subjects is the adaptive joint meat
|
|
1005
1015
|
# matrix, hence the second output.
|
|
1006
|
-
# c. The third output is averaged across
|
|
1007
|
-
# d. The fourth output is averaged across
|
|
1016
|
+
# c. The third output is averaged across subjects to obtain the classical meat matrix.
|
|
1017
|
+
# d. The fourth output is averaged across subjects to obtain the inverse classical bread
|
|
1008
1018
|
# matrix.
|
|
1009
1019
|
return (
|
|
1010
1020
|
weighted_stack,
|
|
@@ -1020,7 +1030,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1020
1030
|
flattened_betas_and_theta: jnp.ndarray,
|
|
1021
1031
|
beta_dim: int,
|
|
1022
1032
|
theta_dim: int,
|
|
1023
|
-
|
|
1033
|
+
subject_ids: jnp.ndarray,
|
|
1024
1034
|
action_prob_func: callable,
|
|
1025
1035
|
action_prob_func_args_beta_index: int,
|
|
1026
1036
|
alg_update_func: callable,
|
|
@@ -1033,29 +1043,31 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1033
1043
|
inference_func_type: str,
|
|
1034
1044
|
inference_func_args_theta_index: int,
|
|
1035
1045
|
inference_func_args_action_prob_index: int,
|
|
1036
|
-
|
|
1046
|
+
action_prob_func_args_by_subject_id_by_decision_time: dict[
|
|
1037
1047
|
collections.abc.Hashable, dict[int, tuple[Any, ...]]
|
|
1038
1048
|
],
|
|
1039
|
-
|
|
1049
|
+
policy_num_by_decision_time_by_subject_id: dict[
|
|
1040
1050
|
collections.abc.Hashable, dict[int, int | float]
|
|
1041
1051
|
],
|
|
1042
1052
|
initial_policy_num: int | float,
|
|
1043
1053
|
beta_index_by_policy_num: dict[int | float, int],
|
|
1044
|
-
|
|
1045
|
-
|
|
1054
|
+
inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
|
|
1055
|
+
inference_action_prob_decision_times_by_subject_id: dict[
|
|
1046
1056
|
collections.abc.Hashable, list[int]
|
|
1047
1057
|
],
|
|
1048
|
-
|
|
1058
|
+
update_func_args_by_by_subject_id_by_policy_num: dict[
|
|
1049
1059
|
collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
|
|
1050
1060
|
],
|
|
1051
|
-
|
|
1061
|
+
action_by_decision_time_by_subject_id: dict[
|
|
1062
|
+
collections.abc.Hashable, dict[int, int]
|
|
1063
|
+
],
|
|
1052
1064
|
suppress_all_data_checks: bool,
|
|
1053
1065
|
suppress_interactive_data_checks: bool,
|
|
1054
1066
|
) -> tuple[
|
|
1055
1067
|
jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
|
|
1056
1068
|
]:
|
|
1057
1069
|
"""
|
|
1058
|
-
Computes the average weighted estimating function stack across all
|
|
1070
|
+
Computes the average weighted estimating function stack across all subjects, along with
|
|
1059
1071
|
auxiliary values used to construct the adaptive and classical sandwich variances.
|
|
1060
1072
|
|
|
1061
1073
|
Args:
|
|
@@ -1067,8 +1079,8 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1067
1079
|
The dimension of each of the beta parameters.
|
|
1068
1080
|
theta_dim (int):
|
|
1069
1081
|
The dimension of the theta parameter.
|
|
1070
|
-
|
|
1071
|
-
A 1D JAX NumPy array of
|
|
1082
|
+
subject_ids (jnp.ndarray):
|
|
1083
|
+
A 1D JAX NumPy array of subject IDs.
|
|
1072
1084
|
action_prob_func (callable):
|
|
1073
1085
|
The action probability function.
|
|
1074
1086
|
action_prob_func_args_beta_index (int):
|
|
@@ -1096,29 +1108,29 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1096
1108
|
inference_func_args_action_prob_index (int):
|
|
1097
1109
|
The index of action probabilities in the inference function arguments tuple, if
|
|
1098
1110
|
applicable. -1 otherwise.
|
|
1099
|
-
|
|
1100
|
-
A dictionary mapping decision times to maps of
|
|
1101
|
-
required to compute action probabilities for this
|
|
1102
|
-
|
|
1103
|
-
A map of
|
|
1104
|
-
Only applies to
|
|
1111
|
+
action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
|
|
1112
|
+
A dictionary mapping decision times to maps of subject ids to the function arguments
|
|
1113
|
+
required to compute action probabilities for this subject.
|
|
1114
|
+
policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
|
|
1115
|
+
A map of subject ids to dictionaries mapping decision times to the policy number in use.
|
|
1116
|
+
Only applies to active decision times!
|
|
1105
1117
|
initial_policy_num (int | float):
|
|
1106
1118
|
The policy number of the initial policy before any updates.
|
|
1107
1119
|
beta_index_by_policy_num (dict[int | float, int]):
|
|
1108
1120
|
A dictionary mapping policy numbers to the index of the corresponding beta in
|
|
1109
1121
|
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
1110
|
-
|
|
1111
|
-
A dictionary mapping
|
|
1112
|
-
|
|
1113
|
-
For each
|
|
1114
|
-
provided. Typically just
|
|
1122
|
+
inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
|
|
1123
|
+
A dictionary mapping subject IDs to their respective inference function arguments.
|
|
1124
|
+
inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
|
|
1125
|
+
For each subject, a list of decision times to which action probabilities correspond if
|
|
1126
|
+
provided. Typically just active times if action probabilities are used in the inference
|
|
1115
1127
|
loss or estimating function.
|
|
1116
|
-
|
|
1117
|
-
A dictionary where keys are policy numbers and values are dictionaries mapping
|
|
1128
|
+
update_func_args_by_by_subject_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
|
|
1129
|
+
A dictionary where keys are policy numbers and values are dictionaries mapping subject IDs
|
|
1118
1130
|
to their respective update function arguments.
|
|
1119
|
-
|
|
1120
|
-
A dictionary mapping
|
|
1121
|
-
Only applies to
|
|
1131
|
+
action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
|
|
1132
|
+
A dictionary mapping subject IDs to their respective actions taken at each decision time.
|
|
1133
|
+
Only applies to active decision times!
|
|
1122
1134
|
suppress_all_data_checks (bool):
|
|
1123
1135
|
If True, suppresses carrying out any data checks at all.
|
|
1124
1136
|
suppress_interactive_data_checks (bool):
|
|
@@ -1132,10 +1144,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1132
1144
|
tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
|
|
1133
1145
|
A tuple containing
|
|
1134
1146
|
1. the average weighted estimating function stack
|
|
1135
|
-
2. the
|
|
1136
|
-
3. the
|
|
1137
|
-
4. the
|
|
1138
|
-
5. raw per-
|
|
1147
|
+
2. the subject-level adaptive meat matrix contributions
|
|
1148
|
+
3. the subject-level classical meat matrix contributions
|
|
1149
|
+
4. the subject-level inverse classical bread matrix contributions
|
|
1150
|
+
5. raw per-subject weighted estimating function
|
|
1139
1151
|
stacks.
|
|
1140
1152
|
"""
|
|
1141
1153
|
|
|
@@ -1162,15 +1174,15 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1162
1174
|
# supplied for the above functions, so that differentiation works correctly. The existing
|
|
1163
1175
|
# values should be the same, but not connected to the parameter we are differentiating
|
|
1164
1176
|
# with respect to. Note we will also find it useful below to have the action probability args
|
|
1165
|
-
# nested dict structure flipped to be
|
|
1177
|
+
# nested dict structure flipped to be subject_id -> decision_time -> args, so we do that here too.
|
|
1166
1178
|
|
|
1167
|
-
logger.info("Threading in betas to action probability arguments for all
|
|
1179
|
+
logger.info("Threading in betas to action probability arguments for all subjects.")
|
|
1168
1180
|
(
|
|
1169
|
-
|
|
1170
|
-
|
|
1181
|
+
threaded_action_prob_func_args_by_decision_time_by_subject_id,
|
|
1182
|
+
action_prob_func_args_by_decision_time_by_subject_id,
|
|
1171
1183
|
) = thread_action_prob_func_args(
|
|
1172
|
-
|
|
1173
|
-
|
|
1184
|
+
action_prob_func_args_by_subject_id_by_decision_time,
|
|
1185
|
+
policy_num_by_decision_time_by_subject_id,
|
|
1174
1186
|
initial_policy_num,
|
|
1175
1187
|
betas,
|
|
1176
1188
|
beta_index_by_policy_num,
|
|
@@ -1182,17 +1194,17 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1182
1194
|
# arguments with the central betas introduced.
|
|
1183
1195
|
logger.info(
|
|
1184
1196
|
"Threading in betas and beta-dependent action probabilities to algorithm update "
|
|
1185
|
-
"function args for all
|
|
1197
|
+
"function args for all subjects"
|
|
1186
1198
|
)
|
|
1187
|
-
|
|
1188
|
-
|
|
1199
|
+
threaded_update_func_args_by_policy_num_by_subject_id = thread_update_func_args(
|
|
1200
|
+
update_func_args_by_by_subject_id_by_policy_num,
|
|
1189
1201
|
betas,
|
|
1190
1202
|
beta_index_by_policy_num,
|
|
1191
1203
|
alg_update_func_args_beta_index,
|
|
1192
1204
|
alg_update_func_args_action_prob_index,
|
|
1193
1205
|
alg_update_func_args_action_prob_times_index,
|
|
1194
1206
|
alg_update_func_args_previous_betas_index,
|
|
1195
|
-
|
|
1207
|
+
threaded_action_prob_func_args_by_decision_time_by_subject_id,
|
|
1196
1208
|
action_prob_func,
|
|
1197
1209
|
)
|
|
1198
1210
|
|
|
@@ -1202,8 +1214,8 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1202
1214
|
if not suppress_all_data_checks and alg_update_func_args_action_prob_index >= 0:
|
|
1203
1215
|
input_checks.require_threaded_algorithm_estimating_function_args_equivalent(
|
|
1204
1216
|
algorithm_estimating_func,
|
|
1205
|
-
|
|
1206
|
-
|
|
1217
|
+
update_func_args_by_by_subject_id_by_policy_num,
|
|
1218
|
+
threaded_update_func_args_by_policy_num_by_subject_id,
|
|
1207
1219
|
suppress_interactive_data_checks,
|
|
1208
1220
|
)
|
|
1209
1221
|
|
|
@@ -1212,15 +1224,15 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1212
1224
|
# arguments with the central betas introduced.
|
|
1213
1225
|
logger.info(
|
|
1214
1226
|
"Threading in theta and beta-dependent action probabilities to inference update "
|
|
1215
|
-
"function args for all
|
|
1227
|
+
"function args for all subjects"
|
|
1216
1228
|
)
|
|
1217
|
-
|
|
1218
|
-
|
|
1229
|
+
threaded_inference_func_args_by_subject_id = thread_inference_func_args(
|
|
1230
|
+
inference_func_args_by_subject_id,
|
|
1219
1231
|
inference_func_args_theta_index,
|
|
1220
1232
|
theta,
|
|
1221
1233
|
inference_func_args_action_prob_index,
|
|
1222
|
-
|
|
1223
|
-
|
|
1234
|
+
threaded_action_prob_func_args_by_decision_time_by_subject_id,
|
|
1235
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
1224
1236
|
action_prob_func,
|
|
1225
1237
|
)
|
|
1226
1238
|
|
|
@@ -1230,32 +1242,32 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1230
1242
|
if not suppress_all_data_checks and inference_func_args_action_prob_index >= 0:
|
|
1231
1243
|
input_checks.require_threaded_inference_estimating_function_args_equivalent(
|
|
1232
1244
|
inference_estimating_func,
|
|
1233
|
-
|
|
1234
|
-
|
|
1245
|
+
inference_func_args_by_subject_id,
|
|
1246
|
+
threaded_inference_func_args_by_subject_id,
|
|
1235
1247
|
suppress_interactive_data_checks,
|
|
1236
1248
|
)
|
|
1237
1249
|
|
|
1238
|
-
# 5. Now we can compute the weighted estimating function stacks for all
|
|
1250
|
+
# 5. Now we can compute the weighted estimating function stacks for all subjects
|
|
1239
1251
|
# as well as collect related values used to construct the adaptive and classical
|
|
1240
1252
|
# sandwich variances.
|
|
1241
1253
|
results = [
|
|
1242
|
-
|
|
1254
|
+
single_subject_weighted_estimating_function_stacker(
|
|
1243
1255
|
beta_dim,
|
|
1244
|
-
|
|
1256
|
+
subject_id,
|
|
1245
1257
|
action_prob_func,
|
|
1246
1258
|
algorithm_estimating_func,
|
|
1247
1259
|
inference_estimating_func,
|
|
1248
1260
|
action_prob_func_args_beta_index,
|
|
1249
1261
|
inference_func_args_theta_index,
|
|
1250
|
-
|
|
1251
|
-
|
|
1252
|
-
|
|
1253
|
-
|
|
1254
|
-
|
|
1255
|
-
|
|
1262
|
+
action_prob_func_args_by_decision_time_by_subject_id[subject_id],
|
|
1263
|
+
threaded_action_prob_func_args_by_decision_time_by_subject_id[subject_id],
|
|
1264
|
+
threaded_update_func_args_by_policy_num_by_subject_id[subject_id],
|
|
1265
|
+
threaded_inference_func_args_by_subject_id[subject_id],
|
|
1266
|
+
policy_num_by_decision_time_by_subject_id[subject_id],
|
|
1267
|
+
action_by_decision_time_by_subject_id[subject_id],
|
|
1256
1268
|
beta_index_by_policy_num,
|
|
1257
1269
|
)
|
|
1258
|
-
for
|
|
1270
|
+
for subject_id in subject_ids.tolist()
|
|
1259
1271
|
]
|
|
1260
1272
|
|
|
1261
1273
|
stacks = jnp.array([result[0] for result in results])
|
|
@@ -1266,10 +1278,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1266
1278
|
# 6. Note this strange return structure! We will differentiate the first output,
|
|
1267
1279
|
# but the second tuple will be passed along without modification via has_aux=True and then used
|
|
1268
1280
|
# for the adaptive meat matrix, estimating functions sum check, and classical meat and inverse
|
|
1269
|
-
# bread matrices. The raw per-
|
|
1281
|
+
# bread matrices. The raw per-subject stacks are also returned for debugging purposes.
|
|
1270
1282
|
|
|
1271
1283
|
# Note that returning the raw stacks here as the first arguments is potentially
|
|
1272
|
-
# memory-intensive when combined with differentiation. Keep this in mind if the per-
|
|
1284
|
+
# memory-intensive when combined with differentiation. Keep this in mind if the per-subject bread
|
|
1273
1285
|
# inverse contributions are needed for something like CR2/CR3 small-sample corrections.
|
|
1274
1286
|
return jnp.mean(stacks, axis=0), (
|
|
1275
1287
|
jnp.mean(stacks, axis=0),
|
|
@@ -1280,10 +1292,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1280
1292
|
)
|
|
1281
1293
|
|
|
1282
1294
|
|
|
1283
|
-
def
|
|
1295
|
+
def construct_classical_and_adjusted_sandwiches(
|
|
1284
1296
|
theta_est: jnp.ndarray,
|
|
1285
1297
|
all_post_update_betas: jnp.ndarray,
|
|
1286
|
-
|
|
1298
|
+
subject_ids: jnp.ndarray,
|
|
1287
1299
|
action_prob_func: callable,
|
|
1288
1300
|
action_prob_func_args_beta_index: int,
|
|
1289
1301
|
alg_update_func: callable,
|
|
@@ -1296,32 +1308,34 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1296
1308
|
inference_func_type: str,
|
|
1297
1309
|
inference_func_args_theta_index: int,
|
|
1298
1310
|
inference_func_args_action_prob_index: int,
|
|
1299
|
-
|
|
1311
|
+
action_prob_func_args_by_subject_id_by_decision_time: dict[
|
|
1300
1312
|
collections.abc.Hashable, dict[int, tuple[Any, ...]]
|
|
1301
1313
|
],
|
|
1302
|
-
|
|
1314
|
+
policy_num_by_decision_time_by_subject_id: dict[
|
|
1303
1315
|
collections.abc.Hashable, dict[int, int | float]
|
|
1304
1316
|
],
|
|
1305
1317
|
initial_policy_num: int | float,
|
|
1306
1318
|
beta_index_by_policy_num: dict[int | float, int],
|
|
1307
|
-
|
|
1308
|
-
|
|
1319
|
+
inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
|
|
1320
|
+
inference_action_prob_decision_times_by_subject_id: dict[
|
|
1309
1321
|
collections.abc.Hashable, list[int]
|
|
1310
1322
|
],
|
|
1311
|
-
|
|
1323
|
+
update_func_args_by_by_subject_id_by_policy_num: dict[
|
|
1312
1324
|
collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
|
|
1313
1325
|
],
|
|
1314
|
-
|
|
1326
|
+
action_by_decision_time_by_subject_id: dict[
|
|
1327
|
+
collections.abc.Hashable, dict[int, int]
|
|
1328
|
+
],
|
|
1315
1329
|
suppress_all_data_checks: bool,
|
|
1316
1330
|
suppress_interactive_data_checks: bool,
|
|
1317
1331
|
small_sample_correction: str,
|
|
1318
|
-
|
|
1319
|
-
|
|
1320
|
-
|
|
1321
|
-
|
|
1332
|
+
form_adjusted_meat_adjustments_explicitly: bool,
|
|
1333
|
+
stabilize_joint_adjusted_bread_inverse: bool,
|
|
1334
|
+
analysis_df: pd.DataFrame | None,
|
|
1335
|
+
active_col_name: str | None,
|
|
1322
1336
|
action_col_name: str | None,
|
|
1323
1337
|
calendar_t_col_name: str | None,
|
|
1324
|
-
|
|
1338
|
+
subject_id_col_name: str | None,
|
|
1325
1339
|
action_prob_func_args: tuple | None,
|
|
1326
1340
|
action_prob_col_name: str | None,
|
|
1327
1341
|
) -> tuple[
|
|
@@ -1350,8 +1364,8 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1350
1364
|
A 1-D JAX NumPy array representing the parameter estimate for inference.
|
|
1351
1365
|
all_post_update_betas (jnp.ndarray):
|
|
1352
1366
|
A 2-D JAX NumPy array representing all parameter estimates for the algorithm updates.
|
|
1353
|
-
|
|
1354
|
-
A 1-D JAX NumPy array holding all
|
|
1367
|
+
subject_ids (jnp.ndarray):
|
|
1368
|
+
A 1-D JAX NumPy array holding all subject IDs in the deployment.
|
|
1355
1369
|
action_prob_func (callable):
|
|
1356
1370
|
The action probability function.
|
|
1357
1371
|
action_prob_func_args_beta_index (int):
|
|
@@ -1379,29 +1393,29 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1379
1393
|
inference_func_args_action_prob_index (int):
|
|
1380
1394
|
The index of action probabilities in the inference function arguments tuple, if
|
|
1381
1395
|
applicable. -1 otherwise.
|
|
1382
|
-
|
|
1383
|
-
A dictionary mapping decision times to maps of
|
|
1384
|
-
required to compute action probabilities for this
|
|
1385
|
-
|
|
1386
|
-
A map of
|
|
1387
|
-
Only applies to
|
|
1396
|
+
action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
|
|
1397
|
+
A dictionary mapping decision times to maps of subject ids to the function arguments
|
|
1398
|
+
required to compute action probabilities for this subject.
|
|
1399
|
+
policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
|
|
1400
|
+
A map of subject ids to dictionaries mapping decision times to the policy number in use.
|
|
1401
|
+
Only applies to active decision times!
|
|
1388
1402
|
initial_policy_num (int | float):
|
|
1389
1403
|
The policy number of the initial policy before any updates.
|
|
1390
1404
|
beta_index_by_policy_num (dict[int | float, int]):
|
|
1391
1405
|
A dictionary mapping policy numbers to the index of the corresponding beta in
|
|
1392
1406
|
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
1393
|
-
|
|
1394
|
-
A dictionary mapping
|
|
1395
|
-
|
|
1396
|
-
For each
|
|
1397
|
-
provided. Typically just
|
|
1407
|
+
inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
|
|
1408
|
+
A dictionary mapping subject IDs to their respective inference function arguments.
|
|
1409
|
+
inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
|
|
1410
|
+
For each subject, a list of decision times to which action probabilities correspond if
|
|
1411
|
+
provided. Typically just active times if action probabilities are used in the inference
|
|
1398
1412
|
loss or estimating function.
|
|
1399
|
-
|
|
1400
|
-
A dictionary where keys are policy numbers and values are dictionaries mapping
|
|
1413
|
+
update_func_args_by_by_subject_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
|
|
1414
|
+
A dictionary where keys are policy numbers and values are dictionaries mapping subject IDs
|
|
1401
1415
|
to their respective update function arguments.
|
|
1402
|
-
|
|
1403
|
-
A dictionary mapping
|
|
1404
|
-
Only applies to
|
|
1416
|
+
action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
|
|
1417
|
+
A dictionary mapping subject IDs to their respective actions taken at each decision time.
|
|
1418
|
+
Only applies to active decision times!
|
|
1405
1419
|
suppress_all_data_checks (bool):
|
|
1406
1420
|
If True, suppresses carrying out any data checks at all.
|
|
1407
1421
|
suppress_interactive_data_checks (bool):
|
|
@@ -1411,27 +1425,27 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1411
1425
|
small_sample_correction (str):
|
|
1412
1426
|
The type of small sample correction to apply. See SmallSampleCorrections class for
|
|
1413
1427
|
options.
|
|
1414
|
-
|
|
1415
|
-
If True, explicitly forms the per-
|
|
1428
|
+
form_adjusted_meat_adjustments_explicitly (bool):
|
|
1429
|
+
If True, explicitly forms the per-subject meat adjustments that differentiate the adaptive
|
|
1416
1430
|
sandwich from the classical sandwich. This is for diagnostic purposes, as the
|
|
1417
1431
|
adaptive sandwich is formed without doing this.
|
|
1418
|
-
|
|
1432
|
+
stabilize_joint_adjusted_bread_inverse (bool):
|
|
1419
1433
|
If True, will apply various techniques to stabilize the joint adaptive bread inverse if necessary.
|
|
1420
|
-
|
|
1421
|
-
The full
|
|
1422
|
-
|
|
1423
|
-
The name of the column in
|
|
1434
|
+
analysis_df (pd.DataFrame):
|
|
1435
|
+
The full analysis dataframe, needed if forming the adaptive meat adjustments explicitly.
|
|
1436
|
+
active_col_name (str):
|
|
1437
|
+
The name of the column in analysis_df indicating whether a subject is active at a given decision time.
|
|
1424
1438
|
action_col_name (str):
|
|
1425
|
-
The name of the column in
|
|
1439
|
+
The name of the column in analysis_df indicating the action taken at a given decision time.
|
|
1426
1440
|
calendar_t_col_name (str):
|
|
1427
|
-
The name of the column in
|
|
1428
|
-
|
|
1429
|
-
The name of the column in
|
|
1441
|
+
The name of the column in analysis_df indicating the calendar time of a given decision time.
|
|
1442
|
+
subject_id_col_name (str):
|
|
1443
|
+
The name of the column in analysis_df indicating the subject ID.
|
|
1430
1444
|
action_prob_func_args (tuple):
|
|
1431
1445
|
The arguments to be passed to the action probability function, needed if forming the
|
|
1432
1446
|
adaptive meat adjustments explicitly.
|
|
1433
1447
|
action_prob_col_name (str):
|
|
1434
|
-
The name of the column in
|
|
1448
|
+
The name of the column in analysis_df indicating the action probability of the action taken,
|
|
1435
1449
|
needed if forming the adaptive meat adjustments explicitly.
|
|
1436
1450
|
Returns:
|
|
1437
1451
|
tuple[jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32]]:
|
|
@@ -1444,10 +1458,10 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1444
1458
|
- The classical meat matrix.
|
|
1445
1459
|
- The classical sandwich matrix.
|
|
1446
1460
|
- The average weighted estimating function stack.
|
|
1447
|
-
- All per-
|
|
1448
|
-
- The per-
|
|
1449
|
-
- The per-
|
|
1450
|
-
- The per-
|
|
1461
|
+
- All per-subject weighted estimating function stacks.
|
|
1462
|
+
- The per-subject adaptive meat small-sample corrections.
|
|
1463
|
+
- The per-subject classical meat small-sample corrections.
|
|
1464
|
+
- The per-subject adaptive meat adjustments, if form_adjusted_meat_adjustments_explicitly
|
|
1451
1465
|
is True, otherwise an array of NaNs.
|
|
1452
1466
|
"""
|
|
1453
1467
|
logger.info(
|
|
@@ -1455,13 +1469,13 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1455
1469
|
)
|
|
1456
1470
|
theta_dim = theta_est.shape[0]
|
|
1457
1471
|
beta_dim = all_post_update_betas.shape[1]
|
|
1458
|
-
# Note that these "contributions" are per-
|
|
1459
|
-
|
|
1472
|
+
# Note that these "contributions" are per-subject Jacobians of the weighted estimating function stack.
|
|
1473
|
+
raw_joint_adjusted_bread_inverse_matrix, (
|
|
1460
1474
|
avg_estimating_function_stack,
|
|
1461
|
-
|
|
1462
|
-
|
|
1463
|
-
|
|
1464
|
-
|
|
1475
|
+
per_subject_joint_adjusted_meat_contributions,
|
|
1476
|
+
per_subject_classical_meat_contributions,
|
|
1477
|
+
per_subject_classical_bread_inverse_contributions,
|
|
1478
|
+
per_subject_estimating_function_stacks,
|
|
1465
1479
|
) = jax.jacrev(
|
|
1466
1480
|
get_avg_weighted_estimating_function_stacks_and_aux_values, has_aux=True
|
|
1467
1481
|
)(
|
|
@@ -1471,7 +1485,7 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1471
1485
|
flatten_params(all_post_update_betas, theta_est),
|
|
1472
1486
|
beta_dim,
|
|
1473
1487
|
theta_dim,
|
|
1474
|
-
|
|
1488
|
+
subject_ids,
|
|
1475
1489
|
action_prob_func,
|
|
1476
1490
|
action_prob_func_args_beta_index,
|
|
1477
1491
|
alg_update_func,
|
|
@@ -1484,87 +1498,87 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1484
1498
|
inference_func_type,
|
|
1485
1499
|
inference_func_args_theta_index,
|
|
1486
1500
|
inference_func_args_action_prob_index,
|
|
1487
|
-
|
|
1488
|
-
|
|
1501
|
+
action_prob_func_args_by_subject_id_by_decision_time,
|
|
1502
|
+
policy_num_by_decision_time_by_subject_id,
|
|
1489
1503
|
initial_policy_num,
|
|
1490
1504
|
beta_index_by_policy_num,
|
|
1491
|
-
|
|
1492
|
-
|
|
1493
|
-
|
|
1494
|
-
|
|
1505
|
+
inference_func_args_by_subject_id,
|
|
1506
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
1507
|
+
update_func_args_by_by_subject_id_by_policy_num,
|
|
1508
|
+
action_by_decision_time_by_subject_id,
|
|
1495
1509
|
suppress_all_data_checks,
|
|
1496
1510
|
suppress_interactive_data_checks,
|
|
1497
1511
|
)
|
|
1498
1512
|
|
|
1499
|
-
|
|
1513
|
+
num_subjects = len(subject_ids)
|
|
1500
1514
|
|
|
1501
1515
|
(
|
|
1502
|
-
|
|
1516
|
+
joint_adjusted_meat_matrix,
|
|
1503
1517
|
classical_meat_matrix,
|
|
1504
|
-
|
|
1505
|
-
|
|
1518
|
+
per_subject_adjusted_corrections,
|
|
1519
|
+
per_subject_classical_corrections,
|
|
1506
1520
|
) = perform_desired_small_sample_correction(
|
|
1507
1521
|
small_sample_correction,
|
|
1508
|
-
|
|
1509
|
-
|
|
1510
|
-
|
|
1511
|
-
|
|
1522
|
+
per_subject_joint_adjusted_meat_contributions,
|
|
1523
|
+
per_subject_classical_meat_contributions,
|
|
1524
|
+
per_subject_classical_bread_inverse_contributions,
|
|
1525
|
+
num_subjects,
|
|
1512
1526
|
theta_dim,
|
|
1513
1527
|
)
|
|
1514
1528
|
|
|
1515
1529
|
# Increase diagonal block dominance and possibly improve conditioning of diagonal
|
|
1516
1530
|
# blocks as necessary, to ensure mathematical stability of joint bread inverse
|
|
1517
|
-
|
|
1531
|
+
stabilized_joint_adjusted_bread_inverse_matrix = (
|
|
1518
1532
|
(
|
|
1519
|
-
|
|
1520
|
-
|
|
1533
|
+
stabilize_joint_adjusted_bread_inverse_if_necessary(
|
|
1534
|
+
raw_joint_adjusted_bread_inverse_matrix,
|
|
1521
1535
|
beta_dim,
|
|
1522
1536
|
theta_dim,
|
|
1523
1537
|
)
|
|
1524
1538
|
)
|
|
1525
|
-
if
|
|
1526
|
-
else
|
|
1539
|
+
if stabilize_joint_adjusted_bread_inverse
|
|
1540
|
+
else raw_joint_adjusted_bread_inverse_matrix
|
|
1527
1541
|
)
|
|
1528
1542
|
|
|
1529
1543
|
# Now stably (no explicit inversion) form our sandwiches.
|
|
1530
|
-
|
|
1531
|
-
|
|
1532
|
-
|
|
1533
|
-
|
|
1544
|
+
joint_adjusted_sandwich = form_sandwich_from_bread_inverse_and_meat(
|
|
1545
|
+
stabilized_joint_adjusted_bread_inverse_matrix,
|
|
1546
|
+
joint_adjusted_meat_matrix,
|
|
1547
|
+
num_subjects,
|
|
1534
1548
|
method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
|
|
1535
1549
|
)
|
|
1536
1550
|
classical_bread_inverse_matrix = jnp.mean(
|
|
1537
|
-
|
|
1551
|
+
per_subject_classical_bread_inverse_contributions, axis=0
|
|
1538
1552
|
)
|
|
1539
1553
|
classical_sandwich = form_sandwich_from_bread_inverse_and_meat(
|
|
1540
1554
|
classical_bread_inverse_matrix,
|
|
1541
1555
|
classical_meat_matrix,
|
|
1542
|
-
|
|
1556
|
+
num_subjects,
|
|
1543
1557
|
method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
|
|
1544
1558
|
)
|
|
1545
1559
|
|
|
1546
|
-
|
|
1547
|
-
(len(
|
|
1560
|
+
per_subject_adjusted_meat_adjustments = jnp.full(
|
|
1561
|
+
(len(subject_ids), theta_dim, theta_dim), jnp.nan
|
|
1548
1562
|
)
|
|
1549
|
-
if
|
|
1550
|
-
|
|
1551
|
-
|
|
1563
|
+
if form_adjusted_meat_adjustments_explicitly:
|
|
1564
|
+
per_subject_adjusted_classical_meat_contributions = (
|
|
1565
|
+
form_adjusted_meat_adjustments_directly(
|
|
1552
1566
|
theta_dim,
|
|
1553
1567
|
all_post_update_betas.shape[1],
|
|
1554
|
-
|
|
1555
|
-
|
|
1556
|
-
|
|
1557
|
-
|
|
1568
|
+
stabilized_joint_adjusted_bread_inverse_matrix,
|
|
1569
|
+
per_subject_estimating_function_stacks,
|
|
1570
|
+
analysis_df,
|
|
1571
|
+
active_col_name,
|
|
1558
1572
|
action_col_name,
|
|
1559
1573
|
calendar_t_col_name,
|
|
1560
|
-
|
|
1574
|
+
subject_id_col_name,
|
|
1561
1575
|
action_prob_func,
|
|
1562
1576
|
action_prob_func_args,
|
|
1563
1577
|
action_prob_func_args_beta_index,
|
|
1564
1578
|
theta_est,
|
|
1565
1579
|
inference_func,
|
|
1566
1580
|
inference_func_args_theta_index,
|
|
1567
|
-
|
|
1581
|
+
subject_ids,
|
|
1568
1582
|
action_prob_col_name,
|
|
1569
1583
|
)
|
|
1570
1584
|
)
|
|
@@ -1574,30 +1588,30 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1574
1588
|
# First just apply any small-sample correction for parity.
|
|
1575
1589
|
(
|
|
1576
1590
|
_,
|
|
1577
|
-
|
|
1591
|
+
theta_only_adjusted_meat_matrix_v2,
|
|
1578
1592
|
_,
|
|
1579
1593
|
_,
|
|
1580
1594
|
) = perform_desired_small_sample_correction(
|
|
1581
1595
|
small_sample_correction,
|
|
1582
|
-
|
|
1583
|
-
|
|
1584
|
-
|
|
1585
|
-
|
|
1596
|
+
per_subject_joint_adjusted_meat_contributions,
|
|
1597
|
+
per_subject_adjusted_classical_meat_contributions,
|
|
1598
|
+
per_subject_classical_bread_inverse_contributions,
|
|
1599
|
+
num_subjects,
|
|
1586
1600
|
theta_dim,
|
|
1587
1601
|
)
|
|
1588
|
-
|
|
1602
|
+
theta_only_adjusted_sandwich_from_adjustments = (
|
|
1589
1603
|
form_sandwich_from_bread_inverse_and_meat(
|
|
1590
1604
|
classical_bread_inverse_matrix,
|
|
1591
|
-
|
|
1592
|
-
|
|
1605
|
+
theta_only_adjusted_meat_matrix_v2,
|
|
1606
|
+
num_subjects,
|
|
1593
1607
|
method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
|
|
1594
1608
|
)
|
|
1595
1609
|
)
|
|
1596
|
-
|
|
1610
|
+
theta_only_adjusted_sandwich = joint_adjusted_sandwich[-theta_dim:, -theta_dim:]
|
|
1597
1611
|
|
|
1598
1612
|
if not np.allclose(
|
|
1599
|
-
|
|
1600
|
-
|
|
1613
|
+
theta_only_adjusted_sandwich,
|
|
1614
|
+
theta_only_adjusted_sandwich_from_adjustments,
|
|
1601
1615
|
rtol=3e-2,
|
|
1602
1616
|
):
|
|
1603
1617
|
logger.warning(
|
|
@@ -1607,26 +1621,26 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1607
1621
|
# Stack the joint adaptive inverse bread pieces together horizontally and return the auxiliary
|
|
1608
1622
|
# values too. The joint adaptive bread inverse should always be block lower triangular.
|
|
1609
1623
|
return (
|
|
1610
|
-
|
|
1611
|
-
|
|
1612
|
-
|
|
1613
|
-
|
|
1624
|
+
raw_joint_adjusted_bread_inverse_matrix,
|
|
1625
|
+
stabilized_joint_adjusted_bread_inverse_matrix,
|
|
1626
|
+
joint_adjusted_meat_matrix,
|
|
1627
|
+
joint_adjusted_sandwich,
|
|
1614
1628
|
classical_bread_inverse_matrix,
|
|
1615
1629
|
classical_meat_matrix,
|
|
1616
1630
|
classical_sandwich,
|
|
1617
1631
|
avg_estimating_function_stack,
|
|
1618
|
-
|
|
1619
|
-
|
|
1620
|
-
|
|
1621
|
-
|
|
1632
|
+
per_subject_estimating_function_stacks,
|
|
1633
|
+
per_subject_adjusted_corrections,
|
|
1634
|
+
per_subject_classical_corrections,
|
|
1635
|
+
per_subject_adjusted_meat_adjustments,
|
|
1622
1636
|
)
|
|
1623
1637
|
|
|
1624
1638
|
|
|
1625
1639
|
# TODO: I think there should be interaction to confirm stabilization. It is
|
|
1626
|
-
# important for the
|
|
1627
|
-
# that the
|
|
1628
|
-
def
|
|
1629
|
-
|
|
1640
|
+
# important for the user to know if this is happening. Even if enabled, it is important
|
|
1641
|
+
# that the user know it actually kicks in.
|
|
1642
|
+
def stabilize_joint_adjusted_bread_inverse_if_necessary(
|
|
1643
|
+
joint_adjusted_bread_inverse_matrix: jnp.ndarray,
|
|
1630
1644
|
beta_dim: int,
|
|
1631
1645
|
theta_dim: int,
|
|
1632
1646
|
) -> jnp.ndarray:
|
|
@@ -1635,7 +1649,7 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
|
|
|
1635
1649
|
dominance and/or adding a small ridge penalty to the diagonal blocks.
|
|
1636
1650
|
|
|
1637
1651
|
Args:
|
|
1638
|
-
|
|
1652
|
+
joint_adjusted_bread_inverse_matrix (jnp.ndarray):
|
|
1639
1653
|
A 2-D JAX NumPy array representing the joint adaptive bread inverse matrix.
|
|
1640
1654
|
beta_dim (int):
|
|
1641
1655
|
The dimension of each beta parameter.
|
|
@@ -1656,7 +1670,7 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
|
|
|
1656
1670
|
|
|
1657
1671
|
# Grab just the RL block and convert it to a numpy array for easier manipulation.
|
|
1658
1672
|
RL_stack_beta_derivatives_block = np.array(
|
|
1659
|
-
|
|
1673
|
+
joint_adjusted_bread_inverse_matrix[:-theta_dim, :-theta_dim]
|
|
1660
1674
|
)
|
|
1661
1675
|
num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
|
|
1662
1676
|
for i in range(1, num_updates + 1):
|
|
@@ -1684,7 +1698,7 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
|
|
|
1684
1698
|
RL_stack_beta_derivatives_block[
|
|
1685
1699
|
diagonal_block_slice, diagonal_block_slice
|
|
1686
1700
|
] = diagonal_block + ridge_penalty * np.eye(beta_dim)
|
|
1687
|
-
# TODO: Require
|
|
1701
|
+
# TODO: Require user input here in interactive settings?
|
|
1688
1702
|
logger.info(
|
|
1689
1703
|
"Added ridge penalty of %s to diagonal block for update %s to improve conditioning from %s to %s",
|
|
1690
1704
|
ridge_penalty,
|
|
@@ -1775,11 +1789,11 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
|
|
|
1775
1789
|
[
|
|
1776
1790
|
[
|
|
1777
1791
|
RL_stack_beta_derivatives_block,
|
|
1778
|
-
|
|
1792
|
+
joint_adjusted_bread_inverse_matrix[:-theta_dim, -theta_dim:],
|
|
1779
1793
|
],
|
|
1780
1794
|
[
|
|
1781
|
-
|
|
1782
|
-
|
|
1795
|
+
joint_adjusted_bread_inverse_matrix[-theta_dim:, :-theta_dim],
|
|
1796
|
+
joint_adjusted_bread_inverse_matrix[-theta_dim:, -theta_dim:],
|
|
1783
1797
|
],
|
|
1784
1798
|
]
|
|
1785
1799
|
)
|
|
@@ -1788,7 +1802,7 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
|
|
|
1788
1802
|
def form_sandwich_from_bread_inverse_and_meat(
|
|
1789
1803
|
bread_inverse: jnp.ndarray,
|
|
1790
1804
|
meat: jnp.ndarray,
|
|
1791
|
-
|
|
1805
|
+
num_subjects: int,
|
|
1792
1806
|
method: str = SandwichFormationMethods.BREAD_INVERSE_T_QR,
|
|
1793
1807
|
) -> jnp.ndarray:
|
|
1794
1808
|
"""
|
|
@@ -1802,8 +1816,8 @@ def form_sandwich_from_bread_inverse_and_meat(
|
|
|
1802
1816
|
A 2-D JAX NumPy array representing the bread inverse matrix.
|
|
1803
1817
|
meat (jnp.ndarray):
|
|
1804
1818
|
A 2-D JAX NumPy array representing the meat matrix.
|
|
1805
|
-
|
|
1806
|
-
The number of
|
|
1819
|
+
num_subjects (int):
|
|
1820
|
+
The number of subjects in the deployment, used to scale the sandwich appropriately.
|
|
1807
1821
|
method (str):
|
|
1808
1822
|
The method to use for forming the sandwich.
|
|
1809
1823
|
|
|
@@ -1829,7 +1843,7 @@ def form_sandwich_from_bread_inverse_and_meat(
|
|
|
1829
1843
|
L, scipy.linalg.solve_triangular(L, meat.T, lower=True).T, lower=True
|
|
1830
1844
|
)
|
|
1831
1845
|
|
|
1832
|
-
return Q @ new_meat @ Q.T /
|
|
1846
|
+
return Q @ new_meat @ Q.T / num_subjects
|
|
1833
1847
|
elif method == SandwichFormationMethods.MEAT_SVD_SOLVE:
|
|
1834
1848
|
# Factor the meat via SVD without any symmetrization or truncation.
|
|
1835
1849
|
# For general (possibly slightly nonsymmetric) M, SVD gives M = U @ diag(s) @ Vh.
|
|
@@ -1843,14 +1857,14 @@ def form_sandwich_from_bread_inverse_and_meat(
|
|
|
1843
1857
|
W_left = scipy.linalg.solve(bread_inverse, C_left)
|
|
1844
1858
|
W_right = scipy.linalg.solve(bread_inverse, C_right)
|
|
1845
1859
|
|
|
1846
|
-
# Return the exact sandwich: V = (B^{-1} C_left) (B^{-1} C_right)^T /
|
|
1847
|
-
return W_left @ W_right.T /
|
|
1860
|
+
# Return the exact sandwich: V = (B^{-1} C_left) (B^{-1} C_right)^T / num_subjects
|
|
1861
|
+
return W_left @ W_right.T / num_subjects
|
|
1848
1862
|
|
|
1849
1863
|
elif method == SandwichFormationMethods.NAIVE:
|
|
1850
1864
|
# Simply invert the bread inverse and form the sandwich directly.
|
|
1851
1865
|
# This is NOT numerically stable and is only included for comparison purposes.
|
|
1852
1866
|
bread = np.linalg.inv(bread_inverse)
|
|
1853
|
-
return bread @ meat @ meat.T /
|
|
1867
|
+
return bread @ meat @ meat.T / num_subjects
|
|
1854
1868
|
|
|
1855
1869
|
else:
|
|
1856
1870
|
raise ValueError(
|