lifejacket 0.2.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/after_study_analysis.py +397 -387
- lifejacket/arg_threading_helpers.py +75 -69
- lifejacket/calculate_derivatives.py +19 -21
- lifejacket/{trial_conditioning_monitor.py → deployment_conditioning_monitor.py} +146 -128
- lifejacket/{form_adaptive_meat_adjustments_directly.py → form_adjusted_meat_adjustments_directly.py} +7 -7
- lifejacket/get_datum_for_blowup_supervised_learning.py +315 -307
- lifejacket/helper_functions.py +45 -38
- lifejacket/input_checks.py +263 -261
- lifejacket/small_sample_corrections.py +42 -40
- lifejacket-1.0.0.dist-info/METADATA +56 -0
- lifejacket-1.0.0.dist-info/RECORD +17 -0
- lifejacket-0.2.1.dist-info/METADATA +0 -100
- lifejacket-0.2.1.dist-info/RECORD +0 -17
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.0.dist-info}/WHEEL +0 -0
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.0.dist-info}/entry_points.txt +0 -0
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -24,8 +24,8 @@ from .constants import (
|
|
|
24
24
|
SandwichFormationMethods,
|
|
25
25
|
SmallSampleCorrections,
|
|
26
26
|
)
|
|
27
|
-
from .
|
|
28
|
-
|
|
27
|
+
from .form_adjusted_meat_adjustments_directly import (
|
|
28
|
+
form_adjusted_meat_adjustments_directly,
|
|
29
29
|
)
|
|
30
30
|
from . import input_checks
|
|
31
31
|
from . import get_datum_for_blowup_supervised_learning
|
|
@@ -37,9 +37,9 @@ from .helper_functions import (
|
|
|
37
37
|
calculate_beta_dim,
|
|
38
38
|
collect_all_post_update_betas,
|
|
39
39
|
construct_beta_index_by_policy_num_map,
|
|
40
|
-
|
|
40
|
+
extract_action_and_policy_by_decision_time_by_subject_id,
|
|
41
41
|
flatten_params,
|
|
42
|
-
|
|
42
|
+
get_active_df_column,
|
|
43
43
|
get_min_time_by_policy_num,
|
|
44
44
|
get_radon_nikodym_weight,
|
|
45
45
|
load_function_from_same_named_file,
|
|
@@ -61,7 +61,7 @@ def cli():
|
|
|
61
61
|
|
|
62
62
|
# TODO: Check all help strings for accuracy.
|
|
63
63
|
# TODO: Deal with NA, -1, etc policy numbers
|
|
64
|
-
# TODO: Make sure in
|
|
64
|
+
# TODO: Make sure in deployment is never on for more than one stretch EDIT: unclear if
|
|
65
65
|
# this will remain an invariant as we deal with more complicated data missingness
|
|
66
66
|
# TODO: I think I'm agnostic to indexing of calendar times but should check because
|
|
67
67
|
# otherwise need to add a check here to verify required format.
|
|
@@ -69,7 +69,7 @@ def cli():
|
|
|
69
69
|
# Higher dimensional objects not supported. Not entirely sure what kind of "scalars" apply.
|
|
70
70
|
@cli.command(name="analyze")
|
|
71
71
|
@click.option(
|
|
72
|
-
"--
|
|
72
|
+
"--analysis_df_pickle",
|
|
73
73
|
type=click.File("rb"),
|
|
74
74
|
help="Pickled pandas dataframe in correct format (see contract/readme).",
|
|
75
75
|
required=True,
|
|
@@ -83,7 +83,7 @@ def cli():
|
|
|
83
83
|
@click.option(
|
|
84
84
|
"--action_prob_func_args_pickle",
|
|
85
85
|
type=click.File("rb"),
|
|
86
|
-
help="Pickled dictionary that contains the action probability function arguments for all decision times for all
|
|
86
|
+
help="Pickled dictionary that contains the action probability function arguments for all decision times for all subjects.",
|
|
87
87
|
required=True,
|
|
88
88
|
)
|
|
89
89
|
@click.option(
|
|
@@ -95,7 +95,7 @@ def cli():
|
|
|
95
95
|
@click.option(
|
|
96
96
|
"--alg_update_func_filename",
|
|
97
97
|
type=click.Path(exists=True),
|
|
98
|
-
help="File that contains the per-
|
|
98
|
+
help="File that contains the per-subject update function used to determine the algorithm parameters at each update and relevant imports. May be a loss or estimating function, specified in a separate argument. The filename without its extension will be assumed to match the function name.",
|
|
99
99
|
required=True,
|
|
100
100
|
)
|
|
101
101
|
@click.option(
|
|
@@ -107,7 +107,7 @@ def cli():
|
|
|
107
107
|
@click.option(
|
|
108
108
|
"--alg_update_func_args_pickle",
|
|
109
109
|
type=click.File("rb"),
|
|
110
|
-
help="Pickled dictionary that contains the algorithm update function arguments for all update times for all
|
|
110
|
+
help="Pickled dictionary that contains the algorithm update function arguments for all update times for all subjects.",
|
|
111
111
|
required=True,
|
|
112
112
|
)
|
|
113
113
|
@click.option(
|
|
@@ -137,7 +137,7 @@ def cli():
|
|
|
137
137
|
@click.option(
|
|
138
138
|
"--inference_func_filename",
|
|
139
139
|
type=click.Path(exists=True),
|
|
140
|
-
help="File that contains the per-
|
|
140
|
+
help="File that contains the per-subject loss/estimating function used to determine the inference estimate and relevant imports. The filename without its extension will be assumed to match the function name.",
|
|
141
141
|
required=True,
|
|
142
142
|
)
|
|
143
143
|
@click.option(
|
|
@@ -155,56 +155,56 @@ def cli():
|
|
|
155
155
|
@click.option(
|
|
156
156
|
"--theta_calculation_func_filename",
|
|
157
157
|
type=click.Path(exists=True),
|
|
158
|
-
help="Path to file that allows one to actually calculate a theta estimate given the
|
|
158
|
+
help="Path to file that allows one to actually calculate a theta estimate given the analysis dataframe only. One must supply either this or a precomputed theta estimate. The filename without its extension will be assumed to match the function name.",
|
|
159
159
|
required=True,
|
|
160
160
|
)
|
|
161
161
|
@click.option(
|
|
162
|
-
"--
|
|
162
|
+
"--active_col_name",
|
|
163
163
|
type=str,
|
|
164
164
|
required=True,
|
|
165
|
-
help="Name of the binary column in the
|
|
165
|
+
help="Name of the binary column in the analysis dataframe that indicates whether a subject is in the deployment.",
|
|
166
166
|
)
|
|
167
167
|
@click.option(
|
|
168
168
|
"--action_col_name",
|
|
169
169
|
type=str,
|
|
170
170
|
required=True,
|
|
171
|
-
help="Name of the binary column in the
|
|
171
|
+
help="Name of the binary column in the analysis dataframe that indicates which action was taken.",
|
|
172
172
|
)
|
|
173
173
|
@click.option(
|
|
174
174
|
"--policy_num_col_name",
|
|
175
175
|
type=str,
|
|
176
176
|
required=True,
|
|
177
|
-
help="Name of the column in the
|
|
177
|
+
help="Name of the column in the analysis dataframe that indicates the policy number in use.",
|
|
178
178
|
)
|
|
179
179
|
@click.option(
|
|
180
180
|
"--calendar_t_col_name",
|
|
181
181
|
type=str,
|
|
182
182
|
required=True,
|
|
183
|
-
help="Name of the column in the
|
|
183
|
+
help="Name of the column in the analysis dataframe that indicates calendar time (shared integer index across subjects).",
|
|
184
184
|
)
|
|
185
185
|
@click.option(
|
|
186
|
-
"--
|
|
186
|
+
"--subject_id_col_name",
|
|
187
187
|
type=str,
|
|
188
188
|
required=True,
|
|
189
|
-
help="Name of the column in the
|
|
189
|
+
help="Name of the column in the analysis dataframe that indicates subject id.",
|
|
190
190
|
)
|
|
191
191
|
@click.option(
|
|
192
192
|
"--action_prob_col_name",
|
|
193
193
|
type=str,
|
|
194
194
|
required=True,
|
|
195
|
-
help="Name of the column in the
|
|
195
|
+
help="Name of the column in the analysis dataframe that gives action one probabilities.",
|
|
196
196
|
)
|
|
197
197
|
@click.option(
|
|
198
198
|
"--reward_col_name",
|
|
199
199
|
type=str,
|
|
200
200
|
required=True,
|
|
201
|
-
help="Name of the column in the
|
|
201
|
+
help="Name of the column in the analysis dataframe that gives rewards.",
|
|
202
202
|
)
|
|
203
203
|
@click.option(
|
|
204
204
|
"--suppress_interactive_data_checks",
|
|
205
205
|
type=bool,
|
|
206
206
|
default=False,
|
|
207
|
-
help="Flag to suppress any data checks that require
|
|
207
|
+
help="Flag to suppress any data checks that require user input. This is suitable for tests and large simulations.",
|
|
208
208
|
)
|
|
209
209
|
@click.option(
|
|
210
210
|
"--suppress_all_data_checks",
|
|
@@ -232,13 +232,13 @@ def cli():
|
|
|
232
232
|
help="Flag to collect data for supervised learning blowup detection. This will write a single datum and label to a file in the same directory as the input files.",
|
|
233
233
|
)
|
|
234
234
|
@click.option(
|
|
235
|
-
"--
|
|
235
|
+
"--form_adjusted_meat_adjustments_explicitly",
|
|
236
236
|
type=bool,
|
|
237
237
|
default=False,
|
|
238
|
-
help="If True, explicitly forms the per-
|
|
238
|
+
help="If True, explicitly forms the per-subject meat adjustments that differentiate the adaptive sandwich from the classical sandwich. This is for diagnostic purposes, as the adaptive sandwich is formed without doing this.",
|
|
239
239
|
)
|
|
240
240
|
@click.option(
|
|
241
|
-
"--
|
|
241
|
+
"--stabilize_joint_adjusted_bread_inverse",
|
|
242
242
|
type=bool,
|
|
243
243
|
default=True,
|
|
244
244
|
help="If True, stabilizes the joint adaptive bread inverse matrix if it does not meet conditioning thresholds.",
|
|
@@ -248,7 +248,7 @@ def analyze_dataset_wrapper(**kwargs):
|
|
|
248
248
|
This function is a wrapper around analyze_dataset to facilitate command line use.
|
|
249
249
|
|
|
250
250
|
From the command line, we will take pickles and filenames for Python objects.
|
|
251
|
-
|
|
251
|
+
We unpickle/load files here for passing to the implementation function, which
|
|
252
252
|
may also be called in its own right with in-memory objects.
|
|
253
253
|
|
|
254
254
|
See analyze_dataset for the underlying details.
|
|
@@ -256,18 +256,20 @@ def analyze_dataset_wrapper(**kwargs):
|
|
|
256
256
|
Returns: None
|
|
257
257
|
"""
|
|
258
258
|
|
|
259
|
-
# Pass along the folder the
|
|
260
|
-
# Do it now because we will be removing the
|
|
261
|
-
kwargs["output_dir"] = pathlib.Path(
|
|
259
|
+
# Pass along the folder the analysis dataframe is in as the output folder.
|
|
260
|
+
# Do it now because we will be removing the analysis dataframe pickle from kwargs.
|
|
261
|
+
kwargs["output_dir"] = pathlib.Path(
|
|
262
|
+
kwargs["analysis_df_pickle"].name
|
|
263
|
+
).parent.resolve()
|
|
262
264
|
|
|
263
265
|
# Unpickle pickles and replace those args in kwargs
|
|
264
|
-
kwargs["
|
|
266
|
+
kwargs["analysis_df"] = pickle.load(kwargs["analysis_df_pickle"])
|
|
265
267
|
kwargs["action_prob_func_args"] = pickle.load(
|
|
266
268
|
kwargs["action_prob_func_args_pickle"]
|
|
267
269
|
)
|
|
268
270
|
kwargs["alg_update_func_args"] = pickle.load(kwargs["alg_update_func_args_pickle"])
|
|
269
271
|
|
|
270
|
-
kwargs.pop("
|
|
272
|
+
kwargs.pop("analysis_df_pickle")
|
|
271
273
|
kwargs.pop("action_prob_func_args_pickle")
|
|
272
274
|
kwargs.pop("alg_update_func_args_pickle")
|
|
273
275
|
|
|
@@ -295,7 +297,7 @@ def analyze_dataset_wrapper(**kwargs):
|
|
|
295
297
|
|
|
296
298
|
def analyze_dataset(
|
|
297
299
|
output_dir: pathlib.Path | str,
|
|
298
|
-
|
|
300
|
+
analysis_df: pd.DataFrame,
|
|
299
301
|
action_prob_func: Callable,
|
|
300
302
|
action_prob_func_args: dict[int, Any],
|
|
301
303
|
action_prob_func_args_beta_index: int,
|
|
@@ -310,19 +312,19 @@ def analyze_dataset(
|
|
|
310
312
|
inference_func_type: str,
|
|
311
313
|
inference_func_args_theta_index: int,
|
|
312
314
|
theta_calculation_func: Callable[[pd.DataFrame], jnp.ndarray],
|
|
313
|
-
|
|
315
|
+
active_col_name: str,
|
|
314
316
|
action_col_name: str,
|
|
315
317
|
policy_num_col_name: str,
|
|
316
318
|
calendar_t_col_name: str,
|
|
317
|
-
|
|
319
|
+
subject_id_col_name: str,
|
|
318
320
|
action_prob_col_name: str,
|
|
319
321
|
reward_col_name: str,
|
|
320
322
|
suppress_interactive_data_checks: bool,
|
|
321
323
|
suppress_all_data_checks: bool,
|
|
322
324
|
small_sample_correction: str,
|
|
323
325
|
collect_data_for_blowup_supervised_learning: bool,
|
|
324
|
-
|
|
325
|
-
|
|
326
|
+
form_adjusted_meat_adjustments_explicitly: bool,
|
|
327
|
+
stabilize_joint_adjusted_bread_inverse: bool,
|
|
326
328
|
) -> None:
|
|
327
329
|
"""
|
|
328
330
|
Analyzes a dataset to provide a parameter estimate and an estimate of its variance using adaptive and classical sandwich estimators.
|
|
@@ -337,8 +339,8 @@ def analyze_dataset(
|
|
|
337
339
|
Parameters:
|
|
338
340
|
output_dir (pathlib.Path | str):
|
|
339
341
|
Directory in which to save output files.
|
|
340
|
-
|
|
341
|
-
DataFrame containing the
|
|
342
|
+
analysis_df (pd.DataFrame):
|
|
343
|
+
DataFrame containing the deployment data.
|
|
342
344
|
action_prob_func (callable):
|
|
343
345
|
Action probability function.
|
|
344
346
|
action_prob_func_args (dict[int, Any]):
|
|
@@ -364,21 +366,21 @@ def analyze_dataset(
|
|
|
364
366
|
inference_func_args_theta_index (int):
|
|
365
367
|
Index for theta in inference function arguments.
|
|
366
368
|
theta_calculation_func (callable):
|
|
367
|
-
Function to estimate theta from the
|
|
368
|
-
|
|
369
|
-
Column name indicating if a
|
|
369
|
+
Function to estimate theta from the analysis dataframe.
|
|
370
|
+
active_col_name (str):
|
|
371
|
+
Column name indicating if a subject is active in the analysis dataframe.
|
|
370
372
|
action_col_name (str):
|
|
371
|
-
Column name for actions in the
|
|
373
|
+
Column name for actions in the analysis dataframe.
|
|
372
374
|
policy_num_col_name (str):
|
|
373
|
-
Column name for policy numbers in the
|
|
375
|
+
Column name for policy numbers in the analysis dataframe.
|
|
374
376
|
calendar_t_col_name (str):
|
|
375
|
-
Column name for calendar time in the
|
|
376
|
-
|
|
377
|
-
Column name for
|
|
377
|
+
Column name for calendar time in the analysis dataframe.
|
|
378
|
+
subject_id_col_name (str):
|
|
379
|
+
Column name for subject IDs in the analysis dataframe.
|
|
378
380
|
action_prob_col_name (str):
|
|
379
|
-
Column name for action probabilities in the
|
|
381
|
+
Column name for action probabilities in the analysis dataframe.
|
|
380
382
|
reward_col_name (str):
|
|
381
|
-
Column name for rewards in the
|
|
383
|
+
Column name for rewards in the analysis dataframe.
|
|
382
384
|
suppress_interactive_data_checks (bool):
|
|
383
385
|
Whether to suppress interactive data checks. This should be used in simulations, for example.
|
|
384
386
|
suppress_all_data_checks (bool):
|
|
@@ -387,11 +389,11 @@ def analyze_dataset(
|
|
|
387
389
|
Type of small sample correction to apply.
|
|
388
390
|
collect_data_for_blowup_supervised_learning (bool):
|
|
389
391
|
Whether to collect data for doing supervised learning about adaptive sandwich blowup.
|
|
390
|
-
|
|
391
|
-
If True, explicitly forms the per-
|
|
392
|
+
form_adjusted_meat_adjustments_explicitly (bool):
|
|
393
|
+
If True, explicitly forms the per-subject meat adjustments that differentiate the adaptive
|
|
392
394
|
sandwich from the classical sandwich. This is for diagnostic purposes, as the
|
|
393
395
|
adaptive sandwich is formed without doing this.
|
|
394
|
-
|
|
396
|
+
stabilize_joint_adjusted_bread_inverse (bool):
|
|
395
397
|
If True, stabilizes the joint adaptive bread inverse matrix if it does not meet conditioning
|
|
396
398
|
thresholds.
|
|
397
399
|
|
|
@@ -406,19 +408,19 @@ def analyze_dataset(
|
|
|
406
408
|
level=logging.INFO,
|
|
407
409
|
)
|
|
408
410
|
|
|
409
|
-
theta_est = jnp.array(theta_calculation_func(
|
|
411
|
+
theta_est = jnp.array(theta_calculation_func(analysis_df))
|
|
410
412
|
|
|
411
413
|
beta_dim = calculate_beta_dim(
|
|
412
414
|
action_prob_func_args, action_prob_func_args_beta_index
|
|
413
415
|
)
|
|
414
416
|
if not suppress_all_data_checks:
|
|
415
417
|
input_checks.perform_first_wave_input_checks(
|
|
416
|
-
|
|
417
|
-
|
|
418
|
+
analysis_df,
|
|
419
|
+
active_col_name,
|
|
418
420
|
action_col_name,
|
|
419
421
|
policy_num_col_name,
|
|
420
422
|
calendar_t_col_name,
|
|
421
|
-
|
|
423
|
+
subject_id_col_name,
|
|
422
424
|
action_prob_col_name,
|
|
423
425
|
reward_col_name,
|
|
424
426
|
action_prob_func,
|
|
@@ -439,7 +441,7 @@ def analyze_dataset(
|
|
|
439
441
|
|
|
440
442
|
beta_index_by_policy_num, initial_policy_num = (
|
|
441
443
|
construct_beta_index_by_policy_num_map(
|
|
442
|
-
|
|
444
|
+
analysis_df, policy_num_col_name, active_col_name
|
|
443
445
|
)
|
|
444
446
|
)
|
|
445
447
|
|
|
@@ -447,11 +449,11 @@ def analyze_dataset(
|
|
|
447
449
|
beta_index_by_policy_num, alg_update_func_args, alg_update_func_args_beta_index
|
|
448
450
|
)
|
|
449
451
|
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
452
|
+
action_by_decision_time_by_subject_id, policy_num_by_decision_time_by_subject_id = (
|
|
453
|
+
extract_action_and_policy_by_decision_time_by_subject_id(
|
|
454
|
+
analysis_df,
|
|
455
|
+
subject_id_col_name,
|
|
456
|
+
active_col_name,
|
|
455
457
|
calendar_t_col_name,
|
|
456
458
|
action_col_name,
|
|
457
459
|
policy_num_col_name,
|
|
@@ -459,45 +461,45 @@ def analyze_dataset(
|
|
|
459
461
|
)
|
|
460
462
|
|
|
461
463
|
(
|
|
462
|
-
|
|
464
|
+
inference_func_args_by_subject_id,
|
|
463
465
|
inference_func_args_action_prob_index,
|
|
464
|
-
|
|
466
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
465
467
|
) = process_inference_func_args(
|
|
466
468
|
inference_func,
|
|
467
469
|
inference_func_args_theta_index,
|
|
468
|
-
|
|
470
|
+
analysis_df,
|
|
469
471
|
theta_est,
|
|
470
472
|
action_prob_col_name,
|
|
471
473
|
calendar_t_col_name,
|
|
472
|
-
|
|
473
|
-
|
|
474
|
+
subject_id_col_name,
|
|
475
|
+
active_col_name,
|
|
474
476
|
)
|
|
475
477
|
|
|
476
|
-
# Use a per-
|
|
478
|
+
# Use a per-subject weighted estimating function stacking function to derive classical and joint
|
|
477
479
|
# adaptive meat and inverse bread matrices. This is facilitated because the *value* of the
|
|
478
480
|
# weighted and unweighted stacks are the same, as the weights evaluate to 1 pre-differentiation.
|
|
479
481
|
logger.info(
|
|
480
|
-
"Constructing joint adaptive bread inverse matrix, joint adaptive meat matrix, the classical analogs, and the avg estimating function stack across
|
|
482
|
+
"Constructing joint adaptive bread inverse matrix, joint adaptive meat matrix, the classical analogs, and the avg estimating function stack across subjects."
|
|
481
483
|
)
|
|
482
484
|
|
|
483
|
-
|
|
485
|
+
subject_ids = jnp.array(analysis_df[subject_id_col_name].unique())
|
|
484
486
|
(
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
487
|
+
stabilized_joint_adjusted_bread_inverse_matrix,
|
|
488
|
+
raw_joint_adjusted_bread_inverse_matrix,
|
|
489
|
+
joint_adjusted_meat_matrix,
|
|
490
|
+
joint_adjusted_sandwich_matrix,
|
|
489
491
|
classical_bread_inverse_matrix,
|
|
490
492
|
classical_meat_matrix,
|
|
491
493
|
classical_sandwich_var_estimate,
|
|
492
494
|
avg_estimating_function_stack,
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
) =
|
|
495
|
+
per_subject_estimating_function_stacks,
|
|
496
|
+
per_subject_adjusted_corrections,
|
|
497
|
+
per_subject_classical_corrections,
|
|
498
|
+
per_subject_adjusted_meat_adjustments,
|
|
499
|
+
) = construct_classical_and_adjusted_sandwiches(
|
|
498
500
|
theta_est,
|
|
499
501
|
all_post_update_betas,
|
|
500
|
-
|
|
502
|
+
subject_ids,
|
|
501
503
|
action_prob_func,
|
|
502
504
|
action_prob_func_args_beta_index,
|
|
503
505
|
alg_update_func,
|
|
@@ -511,23 +513,23 @@ def analyze_dataset(
|
|
|
511
513
|
inference_func_args_theta_index,
|
|
512
514
|
inference_func_args_action_prob_index,
|
|
513
515
|
action_prob_func_args,
|
|
514
|
-
|
|
516
|
+
policy_num_by_decision_time_by_subject_id,
|
|
515
517
|
initial_policy_num,
|
|
516
518
|
beta_index_by_policy_num,
|
|
517
|
-
|
|
518
|
-
|
|
519
|
+
inference_func_args_by_subject_id,
|
|
520
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
519
521
|
alg_update_func_args,
|
|
520
|
-
|
|
522
|
+
action_by_decision_time_by_subject_id,
|
|
521
523
|
suppress_all_data_checks,
|
|
522
524
|
suppress_interactive_data_checks,
|
|
523
525
|
small_sample_correction,
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
526
|
+
form_adjusted_meat_adjustments_explicitly,
|
|
527
|
+
stabilize_joint_adjusted_bread_inverse,
|
|
528
|
+
analysis_df,
|
|
529
|
+
active_col_name,
|
|
528
530
|
action_col_name,
|
|
529
531
|
calendar_t_col_name,
|
|
530
|
-
|
|
532
|
+
subject_id_col_name,
|
|
531
533
|
action_prob_func_args,
|
|
532
534
|
action_prob_col_name,
|
|
533
535
|
)
|
|
@@ -543,18 +545,18 @@ def analyze_dataset(
|
|
|
543
545
|
|
|
544
546
|
# This bottom right corner of the joint (betas and theta) variance matrix is the portion
|
|
545
547
|
# corresponding to just theta.
|
|
546
|
-
|
|
548
|
+
adjusted_sandwich_var_estimate = joint_adjusted_sandwich_matrix[
|
|
547
549
|
-theta_dim:, -theta_dim:
|
|
548
550
|
]
|
|
549
551
|
|
|
550
552
|
# Check for negative diagonal elements and set them to zero if found
|
|
551
|
-
adaptive_diagonal = np.diag(
|
|
553
|
+
adaptive_diagonal = np.diag(adjusted_sandwich_var_estimate)
|
|
552
554
|
if np.any(adaptive_diagonal < 0):
|
|
553
555
|
logger.warning(
|
|
554
556
|
"Found negative diagonal elements in adaptive sandwich variance estimate. Setting them to zero."
|
|
555
557
|
)
|
|
556
558
|
np.fill_diagonal(
|
|
557
|
-
|
|
559
|
+
adjusted_sandwich_var_estimate, np.maximum(adaptive_diagonal, 0)
|
|
558
560
|
)
|
|
559
561
|
|
|
560
562
|
logger.info("Writing results to file...")
|
|
@@ -563,7 +565,7 @@ def analyze_dataset(
|
|
|
563
565
|
|
|
564
566
|
analysis_dict = {
|
|
565
567
|
"theta_est": theta_est,
|
|
566
|
-
"
|
|
568
|
+
"adjusted_sandwich_var_estimate": adjusted_sandwich_var_estimate,
|
|
567
569
|
"classical_sandwich_var_estimate": classical_sandwich_var_estimate,
|
|
568
570
|
}
|
|
569
571
|
with open(output_folder_abs_path / "analysis.pkl", "wb") as f:
|
|
@@ -572,29 +574,29 @@ def analyze_dataset(
|
|
|
572
574
|
f,
|
|
573
575
|
)
|
|
574
576
|
|
|
575
|
-
|
|
576
|
-
|
|
577
|
+
joint_adjusted_bread_inverse_cond = jnp.linalg.cond(
|
|
578
|
+
raw_joint_adjusted_bread_inverse_matrix
|
|
577
579
|
)
|
|
578
580
|
logger.info(
|
|
579
|
-
"Joint
|
|
580
|
-
|
|
581
|
+
"Joint adjusted bread inverse condition number: %f",
|
|
582
|
+
joint_adjusted_bread_inverse_cond,
|
|
581
583
|
)
|
|
582
584
|
|
|
583
585
|
debug_pieces_dict = {
|
|
584
586
|
"theta_est": theta_est,
|
|
585
|
-
"
|
|
587
|
+
"adjusted_sandwich_var_estimate": adjusted_sandwich_var_estimate,
|
|
586
588
|
"classical_sandwich_var_estimate": classical_sandwich_var_estimate,
|
|
587
|
-
"raw_joint_bread_inverse_matrix":
|
|
588
|
-
"stabilized_joint_bread_inverse_matrix":
|
|
589
|
-
"joint_meat_matrix":
|
|
589
|
+
"raw_joint_bread_inverse_matrix": raw_joint_adjusted_bread_inverse_matrix,
|
|
590
|
+
"stabilized_joint_bread_inverse_matrix": stabilized_joint_adjusted_bread_inverse_matrix,
|
|
591
|
+
"joint_meat_matrix": joint_adjusted_meat_matrix,
|
|
590
592
|
"classical_bread_inverse_matrix": classical_bread_inverse_matrix,
|
|
591
593
|
"classical_meat_matrix": classical_meat_matrix,
|
|
592
|
-
"all_estimating_function_stacks":
|
|
593
|
-
"joint_bread_inverse_condition_number":
|
|
594
|
+
"all_estimating_function_stacks": per_subject_estimating_function_stacks,
|
|
595
|
+
"joint_bread_inverse_condition_number": joint_adjusted_bread_inverse_cond,
|
|
594
596
|
"all_post_update_betas": all_post_update_betas,
|
|
595
|
-
"
|
|
596
|
-
"
|
|
597
|
-
"
|
|
597
|
+
"per_subject_adjusted_corrections": per_subject_adjusted_corrections,
|
|
598
|
+
"per_subject_classical_corrections": per_subject_classical_corrections,
|
|
599
|
+
"per_subject_adjusted_meat_adjustments": per_subject_adjusted_meat_adjustments,
|
|
598
600
|
}
|
|
599
601
|
with open(output_folder_abs_path / "debug_pieces.pkl", "wb") as f:
|
|
600
602
|
pickle.dump(
|
|
@@ -604,25 +606,25 @@ def analyze_dataset(
|
|
|
604
606
|
|
|
605
607
|
if collect_data_for_blowup_supervised_learning:
|
|
606
608
|
datum_and_label_dict = get_datum_for_blowup_supervised_learning.get_datum_for_blowup_supervised_learning(
|
|
607
|
-
|
|
608
|
-
|
|
609
|
+
raw_joint_adjusted_bread_inverse_matrix,
|
|
610
|
+
joint_adjusted_bread_inverse_cond,
|
|
609
611
|
avg_estimating_function_stack,
|
|
610
|
-
|
|
612
|
+
per_subject_estimating_function_stacks,
|
|
611
613
|
all_post_update_betas,
|
|
612
|
-
|
|
613
|
-
|
|
614
|
+
analysis_df,
|
|
615
|
+
active_col_name,
|
|
614
616
|
calendar_t_col_name,
|
|
615
617
|
action_prob_col_name,
|
|
616
|
-
|
|
618
|
+
subject_id_col_name,
|
|
617
619
|
reward_col_name,
|
|
618
620
|
theta_est,
|
|
619
|
-
|
|
620
|
-
|
|
621
|
+
adjusted_sandwich_var_estimate,
|
|
622
|
+
subject_ids,
|
|
621
623
|
beta_dim,
|
|
622
624
|
theta_dim,
|
|
623
625
|
initial_policy_num,
|
|
624
626
|
beta_index_by_policy_num,
|
|
625
|
-
|
|
627
|
+
policy_num_by_decision_time_by_subject_id,
|
|
626
628
|
theta_calculation_func,
|
|
627
629
|
action_prob_func,
|
|
628
630
|
action_prob_func_args_beta_index,
|
|
@@ -630,16 +632,16 @@ def analyze_dataset(
|
|
|
630
632
|
inference_func_type,
|
|
631
633
|
inference_func_args_theta_index,
|
|
632
634
|
inference_func_args_action_prob_index,
|
|
633
|
-
|
|
635
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
634
636
|
action_prob_func_args,
|
|
635
|
-
|
|
637
|
+
action_by_decision_time_by_subject_id,
|
|
636
638
|
)
|
|
637
639
|
|
|
638
640
|
with open(output_folder_abs_path / "supervised_learning_datum.pkl", "wb") as f:
|
|
639
641
|
pickle.dump(datum_and_label_dict, f)
|
|
640
642
|
|
|
641
643
|
print(f"\nParameter estimate:\n {theta_est}")
|
|
642
|
-
print(f"\
|
|
644
|
+
print(f"\nAdjusted sandwich variance estimate:\n {adjusted_sandwich_var_estimate}")
|
|
643
645
|
print(
|
|
644
646
|
f"\nClassical sandwich variance estimate:\n {classical_sandwich_var_estimate}\n"
|
|
645
647
|
)
|
|
@@ -650,15 +652,15 @@ def analyze_dataset(
|
|
|
650
652
|
def process_inference_func_args(
|
|
651
653
|
inference_func: callable,
|
|
652
654
|
inference_func_args_theta_index: int,
|
|
653
|
-
|
|
655
|
+
analysis_df: pd.DataFrame,
|
|
654
656
|
theta_est: jnp.ndarray,
|
|
655
657
|
action_prob_col_name: str,
|
|
656
658
|
calendar_t_col_name: str,
|
|
657
|
-
|
|
658
|
-
|
|
659
|
+
subject_id_col_name: str,
|
|
660
|
+
active_col_name: str,
|
|
659
661
|
) -> tuple[dict[collections.abc.Hashable, tuple[Any, ...]], int]:
|
|
660
662
|
"""
|
|
661
|
-
Collects the inference function arguments for each
|
|
663
|
+
Collects the inference function arguments for each subject from the analysis DataFrame.
|
|
662
664
|
|
|
663
665
|
Note that theta and action probabilities, if present, will be replaced later
|
|
664
666
|
so that the function can be differentiated with respect to shared versions
|
|
@@ -669,32 +671,32 @@ def process_inference_func_args(
|
|
|
669
671
|
The inference function to be used.
|
|
670
672
|
inference_func_args_theta_index (int):
|
|
671
673
|
The index of the theta parameter in the inference function's arguments.
|
|
672
|
-
|
|
673
|
-
The
|
|
674
|
+
analysis_df (pandas.DataFrame):
|
|
675
|
+
The analysis DataFrame.
|
|
674
676
|
theta_est (jnp.ndarray):
|
|
675
677
|
The estimate of the parameter vector.
|
|
676
678
|
action_prob_col_name (str):
|
|
677
|
-
The name of the column in the
|
|
679
|
+
The name of the column in the analysis DataFrame that gives action probabilities.
|
|
678
680
|
calendar_t_col_name (str):
|
|
679
|
-
The name of the column in the
|
|
680
|
-
|
|
681
|
-
The name of the column in the
|
|
682
|
-
|
|
683
|
-
The name of the binary column in the
|
|
681
|
+
The name of the column in the analysis DataFrame that indicates calendar time.
|
|
682
|
+
subject_id_col_name (str):
|
|
683
|
+
The name of the column in the analysis DataFrame that indicates subject ID.
|
|
684
|
+
active_col_name (str):
|
|
685
|
+
The name of the binary column in the analysis DataFrame that indicates whether a subject is in the deployment.
|
|
684
686
|
Returns:
|
|
685
687
|
tuple[dict[collections.abc.Hashable, tuple[Any, ...]], int, dict[collections.abc.Hashable, jnp.ndarray[int]]]:
|
|
686
688
|
A tuple containing
|
|
687
|
-
- the inference function arguments dictionary for each
|
|
689
|
+
- the inference function arguments dictionary for each subject
|
|
688
690
|
- the index of the action probabilities argument
|
|
689
|
-
- a dictionary mapping
|
|
691
|
+
- a dictionary mapping subject IDs to the decision times to which action probabilities correspond
|
|
690
692
|
"""
|
|
691
693
|
|
|
692
694
|
num_args = inference_func.__code__.co_argcount
|
|
693
695
|
inference_func_arg_names = inference_func.__code__.co_varnames[:num_args]
|
|
694
|
-
|
|
696
|
+
inference_func_args_by_subject_id = {}
|
|
695
697
|
|
|
696
698
|
inference_func_args_action_prob_index = -1
|
|
697
|
-
|
|
699
|
+
inference_action_prob_decision_times_by_subject_id = {}
|
|
698
700
|
|
|
699
701
|
using_action_probs = action_prob_col_name in inference_func_arg_names
|
|
700
702
|
if using_action_probs:
|
|
@@ -702,34 +704,36 @@ def process_inference_func_args(
|
|
|
702
704
|
action_prob_col_name
|
|
703
705
|
)
|
|
704
706
|
|
|
705
|
-
for
|
|
706
|
-
|
|
707
|
-
|
|
707
|
+
for subject_id in analysis_df[subject_id_col_name].unique():
|
|
708
|
+
subject_args_list = []
|
|
709
|
+
filtered_subject_data = analysis_df.loc[
|
|
710
|
+
analysis_df[subject_id_col_name] == subject_id
|
|
711
|
+
]
|
|
708
712
|
for idx, col_name in enumerate(inference_func_arg_names):
|
|
709
713
|
if idx == inference_func_args_theta_index:
|
|
710
|
-
|
|
714
|
+
subject_args_list.append(theta_est)
|
|
711
715
|
continue
|
|
712
|
-
|
|
713
|
-
|
|
716
|
+
subject_args_list.append(
|
|
717
|
+
get_active_df_column(filtered_subject_data, col_name, active_col_name)
|
|
714
718
|
)
|
|
715
|
-
|
|
719
|
+
inference_func_args_by_subject_id[subject_id] = tuple(subject_args_list)
|
|
716
720
|
if using_action_probs:
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
721
|
+
inference_action_prob_decision_times_by_subject_id[subject_id] = (
|
|
722
|
+
get_active_df_column(
|
|
723
|
+
filtered_subject_data, calendar_t_col_name, active_col_name
|
|
720
724
|
)
|
|
721
725
|
)
|
|
722
726
|
|
|
723
727
|
return (
|
|
724
|
-
|
|
728
|
+
inference_func_args_by_subject_id,
|
|
725
729
|
inference_func_args_action_prob_index,
|
|
726
|
-
|
|
730
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
727
731
|
)
|
|
728
732
|
|
|
729
733
|
|
|
730
|
-
def
|
|
734
|
+
def single_subject_weighted_estimating_function_stacker(
|
|
731
735
|
beta_dim: int,
|
|
732
|
-
|
|
736
|
+
subject_id: collections.abc.Hashable,
|
|
733
737
|
action_prob_func: callable,
|
|
734
738
|
algorithm_estimating_func: callable,
|
|
735
739
|
inference_estimating_func: callable,
|
|
@@ -763,12 +767,12 @@ def single_user_weighted_estimating_function_stacker(
|
|
|
763
767
|
beta_dim (int):
|
|
764
768
|
The dimension of each beta (algorithm parameter) vector.
|
|
765
769
|
|
|
766
|
-
|
|
767
|
-
The
|
|
770
|
+
subject_id (collections.abc.Hashable):
|
|
771
|
+
The subject ID for which to compute the weighted estimating function stack.
|
|
768
772
|
|
|
769
773
|
action_prob_func (callable):
|
|
770
774
|
The function used to compute the probability of action 1 at a given decision time for
|
|
771
|
-
a particular
|
|
775
|
+
a particular subject given their state and the algorithm parameters.
|
|
772
776
|
|
|
773
777
|
algorithm_estimating_func (callable):
|
|
774
778
|
The estimating function that corresponds to algorithm updates.
|
|
@@ -783,9 +787,9 @@ def single_user_weighted_estimating_function_stacker(
|
|
|
783
787
|
The index of the theta parameter in the inference loss or estimating function arguments.
|
|
784
788
|
|
|
785
789
|
action_prob_func_args_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
|
|
786
|
-
A map from decision times to tuples of arguments for this
|
|
790
|
+
A map from decision times to tuples of arguments for this subject for the action
|
|
787
791
|
probability function. This is for all decision times (args are an empty
|
|
788
|
-
tuple if they are not in the
|
|
792
|
+
tuple if they are not in the deployment). Should be sorted by decision time. NOTE THAT THESE
|
|
789
793
|
ARGS DO NOT CONTAIN THE SHARED BETAS, making them impervious to the differentiation that
|
|
790
794
|
will occur.
|
|
791
795
|
|
|
@@ -796,21 +800,21 @@ def single_user_weighted_estimating_function_stacker(
|
|
|
796
800
|
|
|
797
801
|
threaded_update_func_args_by_policy_num (dict[int | float, dict[collections.abc.Hashable, tuple[Any, ...]]]):
|
|
798
802
|
A map from policy numbers to tuples containing the arguments for
|
|
799
|
-
the corresponding estimating functions for this
|
|
803
|
+
the corresponding estimating functions for this subject, with the shared betas threaded in
|
|
800
804
|
for differentiation. This is for all non-initial, non-fallback policies. Policy numbers
|
|
801
805
|
should be sorted.
|
|
802
806
|
|
|
803
807
|
threaded_inference_func_args (dict[collections.abc.Hashable, tuple[Any, ...]]):
|
|
804
808
|
A tuple containing the arguments for the inference
|
|
805
|
-
estimating function for this
|
|
809
|
+
estimating function for this subject, with the shared betas threaded in for differentiation.
|
|
806
810
|
|
|
807
811
|
policy_num_by_decision_time (dict[collections.abc.Hashable, dict[int, int | float]]):
|
|
808
812
|
A dictionary mapping decision times to the policy number in use. This may be
|
|
809
|
-
|
|
813
|
+
subject-specific. Should be sorted by decision time. Only applies to active decision
|
|
810
814
|
times!
|
|
811
815
|
|
|
812
816
|
action_by_decision_time (dict[collections.abc.Hashable, dict[int, int]]):
|
|
813
|
-
A dictionary mapping decision times to actions taken. Only applies to
|
|
817
|
+
A dictionary mapping decision times to actions taken. Only applies to active decision
|
|
814
818
|
times!
|
|
815
819
|
|
|
816
820
|
beta_index_by_policy_num (dict[int | float, int]):
|
|
@@ -818,19 +822,21 @@ def single_user_weighted_estimating_function_stacker(
|
|
|
818
822
|
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
819
823
|
|
|
820
824
|
Returns:
|
|
821
|
-
jnp.ndarray: A 1-D JAX NumPy array representing the
|
|
825
|
+
jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted estimating function
|
|
822
826
|
stack.
|
|
823
|
-
jnp.ndarray: A 2-D JAX NumPy matrix representing the
|
|
824
|
-
jnp.ndarray: A 2-D JAX NumPy matrix representing the
|
|
825
|
-
jnp.ndarray: A 2-D JAX NumPy matrix representing the
|
|
827
|
+
jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's adaptive meat contribution.
|
|
828
|
+
jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical meat contribution.
|
|
829
|
+
jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical bread contribution.
|
|
826
830
|
"""
|
|
827
831
|
|
|
828
|
-
logger.info(
|
|
832
|
+
logger.info(
|
|
833
|
+
"Computing weighted estimating function stack for subject %s.", subject_id
|
|
834
|
+
)
|
|
829
835
|
|
|
830
836
|
# First, reformat the supplied data into more convenient structures.
|
|
831
837
|
|
|
832
838
|
# 1. Form a dictionary mapping policy numbers to the first time they were
|
|
833
|
-
# applicable (for this
|
|
839
|
+
# applicable (for this subject). Note that this includes ALL policies, initial
|
|
834
840
|
# fallbacks included.
|
|
835
841
|
# Collect the first time after the first update separately for convenience.
|
|
836
842
|
# These are both used to form the Radon-Nikodym weights for the right times.
|
|
@@ -839,38 +845,38 @@ def single_user_weighted_estimating_function_stacker(
|
|
|
839
845
|
beta_index_by_policy_num,
|
|
840
846
|
)
|
|
841
847
|
|
|
842
|
-
# 2. Get the start and end times for this
|
|
843
|
-
|
|
844
|
-
|
|
848
|
+
# 2. Get the start and end times for this subject.
|
|
849
|
+
subject_start_time = math.inf
|
|
850
|
+
subject_end_time = -math.inf
|
|
845
851
|
for decision_time in action_by_decision_time:
|
|
846
|
-
|
|
847
|
-
|
|
852
|
+
subject_start_time = min(subject_start_time, decision_time)
|
|
853
|
+
subject_end_time = max(subject_end_time, decision_time)
|
|
848
854
|
|
|
849
855
|
# 3. Form a stack of weighted estimating equations, one for each update of the algorithm.
|
|
850
856
|
logger.info(
|
|
851
|
-
"Computing the algorithm component of the weighted estimating function stack for
|
|
852
|
-
|
|
857
|
+
"Computing the algorithm component of the weighted estimating function stack for subject %s.",
|
|
858
|
+
subject_id,
|
|
853
859
|
)
|
|
854
860
|
|
|
855
|
-
|
|
861
|
+
active_action_prob_func_args = [
|
|
856
862
|
args for args in action_prob_func_args_by_decision_time.values() if args
|
|
857
863
|
]
|
|
858
|
-
|
|
864
|
+
active_betas_list_by_decision_time_index = jnp.array(
|
|
859
865
|
[
|
|
860
866
|
action_prob_func_args[action_prob_func_args_beta_index]
|
|
861
|
-
for action_prob_func_args in
|
|
867
|
+
for action_prob_func_args in active_action_prob_func_args
|
|
862
868
|
]
|
|
863
869
|
)
|
|
864
|
-
|
|
870
|
+
active_actions_list_by_decision_time_index = jnp.array(
|
|
865
871
|
list(action_by_decision_time.values())
|
|
866
872
|
)
|
|
867
873
|
|
|
868
874
|
# Sort the threaded args by decision time to be cautious. We check if the
|
|
869
|
-
#
|
|
870
|
-
# subset of the
|
|
875
|
+
# subject id is present in the subject args dict because we may call this on a
|
|
876
|
+
# subset of the subject arg dict when we are batching arguments by shape
|
|
871
877
|
sorted_threaded_action_prob_args_by_decision_time = {
|
|
872
878
|
decision_time: threaded_action_prob_func_args_by_decision_time[decision_time]
|
|
873
|
-
for decision_time in range(
|
|
879
|
+
for decision_time in range(subject_start_time, subject_end_time + 1)
|
|
874
880
|
if decision_time in threaded_action_prob_func_args_by_decision_time
|
|
875
881
|
}
|
|
876
882
|
|
|
@@ -901,19 +907,19 @@ def single_user_weighted_estimating_function_stacker(
|
|
|
901
907
|
# Just grab the original beta from the update function arguments. This is the same
|
|
902
908
|
# value, but impervious to differentiation with respect to all_post_update_betas. The
|
|
903
909
|
# args, on the other hand, are a function of all_post_update_betas.
|
|
904
|
-
|
|
910
|
+
active_weights = jax.vmap(
|
|
905
911
|
fun=get_radon_nikodym_weight,
|
|
906
912
|
in_axes=[0, None, None, 0] + batch_axes,
|
|
907
913
|
out_axes=0,
|
|
908
914
|
)(
|
|
909
|
-
|
|
915
|
+
active_betas_list_by_decision_time_index,
|
|
910
916
|
action_prob_func,
|
|
911
917
|
action_prob_func_args_beta_index,
|
|
912
|
-
|
|
918
|
+
active_actions_list_by_decision_time_index,
|
|
913
919
|
*batched_threaded_arg_tensors,
|
|
914
920
|
)
|
|
915
921
|
|
|
916
|
-
|
|
922
|
+
active_index = 0
|
|
917
923
|
decision_time_to_all_weights_index_offset = min(
|
|
918
924
|
sorted_threaded_action_prob_args_by_decision_time
|
|
919
925
|
)
|
|
@@ -922,35 +928,35 @@ def single_user_weighted_estimating_function_stacker(
|
|
|
922
928
|
decision_time,
|
|
923
929
|
args,
|
|
924
930
|
) in sorted_threaded_action_prob_args_by_decision_time.items():
|
|
925
|
-
all_weights_raw.append(
|
|
926
|
-
|
|
931
|
+
all_weights_raw.append(active_weights[active_index] if args else 1.0)
|
|
932
|
+
active_index += 1
|
|
927
933
|
all_weights = jnp.array(all_weights_raw)
|
|
928
934
|
|
|
929
935
|
algorithm_component = jnp.concatenate(
|
|
930
936
|
[
|
|
931
937
|
# Here we compute a product of Radon-Nikodym weights
|
|
932
938
|
# for all decision times after the first update and before the update
|
|
933
|
-
# update under consideration took effect, for which the
|
|
939
|
+
# under consideration took effect, for which the subject was in the deployment.
|
|
934
940
|
(
|
|
935
941
|
jnp.prod(
|
|
936
942
|
all_weights[
|
|
937
|
-
# The earliest time after the first update where the
|
|
938
|
-
# the
|
|
943
|
+
# The earliest time after the first update where the subject was in
|
|
944
|
+
# the deployment
|
|
939
945
|
max(
|
|
940
946
|
first_time_after_first_update,
|
|
941
|
-
|
|
947
|
+
subject_start_time,
|
|
942
948
|
)
|
|
943
949
|
- decision_time_to_all_weights_index_offset :
|
|
944
|
-
# One more than the latest time the
|
|
950
|
+
# One more than the latest time the subject was in the deployment before the time
|
|
945
951
|
# the update under consideration first applied. Note the + 1 because range
|
|
946
952
|
# does not include the right endpoint.
|
|
947
953
|
min(
|
|
948
954
|
min_time_by_policy_num.get(policy_num, math.inf),
|
|
949
|
-
|
|
955
|
+
subject_end_time + 1,
|
|
950
956
|
)
|
|
951
957
|
- decision_time_to_all_weights_index_offset,
|
|
952
958
|
]
|
|
953
|
-
# If the
|
|
959
|
+
# If the subject exited the deployment before there were any updates,
|
|
954
960
|
# this variable will be None and the above code to grab a weight would
|
|
955
961
|
# throw an error. Just use 1 to include the unweighted estimating function
|
|
956
962
|
# if they have data to contribute to the update.
|
|
@@ -958,8 +964,8 @@ def single_user_weighted_estimating_function_stacker(
|
|
|
958
964
|
else 1
|
|
959
965
|
) # Now use the above to weight the alg estimating function for this update
|
|
960
966
|
* algorithm_estimating_func(*update_args)
|
|
961
|
-
# If there are no arguments for the update function, the
|
|
962
|
-
#
|
|
967
|
+
# If there are no arguments for the update function, the subject is not yet in the
|
|
968
|
+
# deployment, so we just add a zero vector contribution to the sum across subjects.
|
|
963
969
|
# Note that after they exit, they still contribute all their data to later
|
|
964
970
|
# updates.
|
|
965
971
|
if update_args
|
|
@@ -978,17 +984,17 @@ def single_user_weighted_estimating_function_stacker(
|
|
|
978
984
|
)
|
|
979
985
|
# 4. Form the weighted inference estimating equation.
|
|
980
986
|
logger.info(
|
|
981
|
-
"Computing the inference component of the weighted estimating function stack for
|
|
982
|
-
|
|
987
|
+
"Computing the inference component of the weighted estimating function stack for subject %s.",
|
|
988
|
+
subject_id,
|
|
983
989
|
)
|
|
984
990
|
inference_component = jnp.prod(
|
|
985
991
|
all_weights[
|
|
986
|
-
max(first_time_after_first_update,
|
|
987
|
-
- decision_time_to_all_weights_index_offset :
|
|
992
|
+
max(first_time_after_first_update, subject_start_time)
|
|
993
|
+
- decision_time_to_all_weights_index_offset : subject_end_time
|
|
988
994
|
+ 1
|
|
989
995
|
- decision_time_to_all_weights_index_offset,
|
|
990
996
|
]
|
|
991
|
-
# If the
|
|
997
|
+
# If the subject exited the deployment before there were any updates,
|
|
992
998
|
# this variable will be None and the above code to grab a weight would
|
|
993
999
|
# throw an error. Just use 1 to include the unweighted estimating function
|
|
994
1000
|
# if they have data to contribute here (pretty sure everyone should?)
|
|
@@ -997,18 +1003,18 @@ def single_user_weighted_estimating_function_stacker(
|
|
|
997
1003
|
) * inference_estimating_func(*threaded_inference_func_args)
|
|
998
1004
|
|
|
999
1005
|
# 5. Concatenate the two components to form the weighted estimating function stack for this
|
|
1000
|
-
#
|
|
1006
|
+
# subject.
|
|
1001
1007
|
weighted_stack = jnp.concatenate([algorithm_component, inference_component])
|
|
1002
1008
|
|
|
1003
1009
|
# 6. Return the following outputs:
|
|
1004
|
-
# a. The first is simply the weighted estimating function stack for this
|
|
1010
|
+
# a. The first is simply the weighted estimating function stack for this subject. The average
|
|
1005
1011
|
# of these is what we differentiate with respect to theta to form the inverse adaptive joint
|
|
1006
1012
|
# bread matrix, and we also compare that average to zero to check the estimating functions'
|
|
1007
1013
|
# fidelity.
|
|
1008
|
-
# b. The average outer product of these per-
|
|
1014
|
+
# b. The average outer product of these per-subject stacks across subjects is the adaptive joint meat
|
|
1009
1015
|
# matrix, hence the second output.
|
|
1010
|
-
# c. The third output is averaged across
|
|
1011
|
-
# d. The fourth output is averaged across
|
|
1016
|
+
# c. The third output is averaged across subjects to obtain the classical meat matrix.
|
|
1017
|
+
# d. The fourth output is averaged across subjects to obtain the inverse classical bread
|
|
1012
1018
|
# matrix.
|
|
1013
1019
|
return (
|
|
1014
1020
|
weighted_stack,
|
|
@@ -1024,7 +1030,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1024
1030
|
flattened_betas_and_theta: jnp.ndarray,
|
|
1025
1031
|
beta_dim: int,
|
|
1026
1032
|
theta_dim: int,
|
|
1027
|
-
|
|
1033
|
+
subject_ids: jnp.ndarray,
|
|
1028
1034
|
action_prob_func: callable,
|
|
1029
1035
|
action_prob_func_args_beta_index: int,
|
|
1030
1036
|
alg_update_func: callable,
|
|
@@ -1037,29 +1043,31 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1037
1043
|
inference_func_type: str,
|
|
1038
1044
|
inference_func_args_theta_index: int,
|
|
1039
1045
|
inference_func_args_action_prob_index: int,
|
|
1040
|
-
|
|
1046
|
+
action_prob_func_args_by_subject_id_by_decision_time: dict[
|
|
1041
1047
|
collections.abc.Hashable, dict[int, tuple[Any, ...]]
|
|
1042
1048
|
],
|
|
1043
|
-
|
|
1049
|
+
policy_num_by_decision_time_by_subject_id: dict[
|
|
1044
1050
|
collections.abc.Hashable, dict[int, int | float]
|
|
1045
1051
|
],
|
|
1046
1052
|
initial_policy_num: int | float,
|
|
1047
1053
|
beta_index_by_policy_num: dict[int | float, int],
|
|
1048
|
-
|
|
1049
|
-
|
|
1054
|
+
inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
|
|
1055
|
+
inference_action_prob_decision_times_by_subject_id: dict[
|
|
1050
1056
|
collections.abc.Hashable, list[int]
|
|
1051
1057
|
],
|
|
1052
|
-
|
|
1058
|
+
update_func_args_by_by_subject_id_by_policy_num: dict[
|
|
1053
1059
|
collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
|
|
1054
1060
|
],
|
|
1055
|
-
|
|
1061
|
+
action_by_decision_time_by_subject_id: dict[
|
|
1062
|
+
collections.abc.Hashable, dict[int, int]
|
|
1063
|
+
],
|
|
1056
1064
|
suppress_all_data_checks: bool,
|
|
1057
1065
|
suppress_interactive_data_checks: bool,
|
|
1058
1066
|
) -> tuple[
|
|
1059
1067
|
jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
|
|
1060
1068
|
]:
|
|
1061
1069
|
"""
|
|
1062
|
-
Computes the average weighted estimating function stack across all
|
|
1070
|
+
Computes the average weighted estimating function stack across all subjects, along with
|
|
1063
1071
|
auxiliary values used to construct the adaptive and classical sandwich variances.
|
|
1064
1072
|
|
|
1065
1073
|
Args:
|
|
@@ -1071,8 +1079,8 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1071
1079
|
The dimension of each of the beta parameters.
|
|
1072
1080
|
theta_dim (int):
|
|
1073
1081
|
The dimension of the theta parameter.
|
|
1074
|
-
|
|
1075
|
-
A 1D JAX NumPy array of
|
|
1082
|
+
subject_ids (jnp.ndarray):
|
|
1083
|
+
A 1D JAX NumPy array of subject IDs.
|
|
1076
1084
|
action_prob_func (callable):
|
|
1077
1085
|
The action probability function.
|
|
1078
1086
|
action_prob_func_args_beta_index (int):
|
|
@@ -1100,29 +1108,29 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1100
1108
|
inference_func_args_action_prob_index (int):
|
|
1101
1109
|
The index of action probabilities in the inference function arguments tuple, if
|
|
1102
1110
|
applicable. -1 otherwise.
|
|
1103
|
-
|
|
1104
|
-
A dictionary mapping decision times to maps of
|
|
1105
|
-
required to compute action probabilities for this
|
|
1106
|
-
|
|
1107
|
-
A map of
|
|
1108
|
-
Only applies to
|
|
1111
|
+
action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
|
|
1112
|
+
A dictionary mapping decision times to maps of subject ids to the function arguments
|
|
1113
|
+
required to compute action probabilities for this subject.
|
|
1114
|
+
policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
|
|
1115
|
+
A map of subject ids to dictionaries mapping decision times to the policy number in use.
|
|
1116
|
+
Only applies to active decision times!
|
|
1109
1117
|
initial_policy_num (int | float):
|
|
1110
1118
|
The policy number of the initial policy before any updates.
|
|
1111
1119
|
beta_index_by_policy_num (dict[int | float, int]):
|
|
1112
1120
|
A dictionary mapping policy numbers to the index of the corresponding beta in
|
|
1113
1121
|
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
1114
|
-
|
|
1115
|
-
A dictionary mapping
|
|
1116
|
-
|
|
1117
|
-
For each
|
|
1118
|
-
provided. Typically just
|
|
1122
|
+
inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
|
|
1123
|
+
A dictionary mapping subject IDs to their respective inference function arguments.
|
|
1124
|
+
inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
|
|
1125
|
+
For each subject, a list of decision times to which action probabilities correspond if
|
|
1126
|
+
provided. Typically just active times if action probabilities are used in the inference
|
|
1119
1127
|
loss or estimating function.
|
|
1120
|
-
|
|
1121
|
-
A dictionary where keys are policy numbers and values are dictionaries mapping
|
|
1128
|
+
update_func_args_by_by_subject_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
|
|
1129
|
+
A dictionary where keys are policy numbers and values are dictionaries mapping subject IDs
|
|
1122
1130
|
to their respective update function arguments.
|
|
1123
|
-
|
|
1124
|
-
A dictionary mapping
|
|
1125
|
-
Only applies to
|
|
1131
|
+
action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
|
|
1132
|
+
A dictionary mapping subject IDs to their respective actions taken at each decision time.
|
|
1133
|
+
Only applies to active decision times!
|
|
1126
1134
|
suppress_all_data_checks (bool):
|
|
1127
1135
|
If True, suppresses carrying out any data checks at all.
|
|
1128
1136
|
suppress_interactive_data_checks (bool):
|
|
@@ -1136,10 +1144,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1136
1144
|
tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
|
|
1137
1145
|
A tuple containing
|
|
1138
1146
|
1. the average weighted estimating function stack
|
|
1139
|
-
2. the
|
|
1140
|
-
3. the
|
|
1141
|
-
4. the
|
|
1142
|
-
5. raw per-
|
|
1147
|
+
2. the subject-level adaptive meat matrix contributions
|
|
1148
|
+
3. the subject-level classical meat matrix contributions
|
|
1149
|
+
4. the subject-level inverse classical bread matrix contributions
|
|
1150
|
+
5. raw per-subject weighted estimating function
|
|
1143
1151
|
stacks.
|
|
1144
1152
|
"""
|
|
1145
1153
|
|
|
@@ -1166,15 +1174,15 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1166
1174
|
# supplied for the above functions, so that differentiation works correctly. The existing
|
|
1167
1175
|
# values should be the same, but not connected to the parameter we are differentiating
|
|
1168
1176
|
# with respect to. Note we will also find it useful below to have the action probability args
|
|
1169
|
-
# nested dict structure flipped to be
|
|
1177
|
+
# nested dict structure flipped to be subject_id -> decision_time -> args, so we do that here too.
|
|
1170
1178
|
|
|
1171
|
-
logger.info("Threading in betas to action probability arguments for all
|
|
1179
|
+
logger.info("Threading in betas to action probability arguments for all subjects.")
|
|
1172
1180
|
(
|
|
1173
|
-
|
|
1174
|
-
|
|
1181
|
+
threaded_action_prob_func_args_by_decision_time_by_subject_id,
|
|
1182
|
+
action_prob_func_args_by_decision_time_by_subject_id,
|
|
1175
1183
|
) = thread_action_prob_func_args(
|
|
1176
|
-
|
|
1177
|
-
|
|
1184
|
+
action_prob_func_args_by_subject_id_by_decision_time,
|
|
1185
|
+
policy_num_by_decision_time_by_subject_id,
|
|
1178
1186
|
initial_policy_num,
|
|
1179
1187
|
betas,
|
|
1180
1188
|
beta_index_by_policy_num,
|
|
@@ -1186,17 +1194,17 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1186
1194
|
# arguments with the central betas introduced.
|
|
1187
1195
|
logger.info(
|
|
1188
1196
|
"Threading in betas and beta-dependent action probabilities to algorithm update "
|
|
1189
|
-
"function args for all
|
|
1197
|
+
"function args for all subjects"
|
|
1190
1198
|
)
|
|
1191
|
-
|
|
1192
|
-
|
|
1199
|
+
threaded_update_func_args_by_policy_num_by_subject_id = thread_update_func_args(
|
|
1200
|
+
update_func_args_by_by_subject_id_by_policy_num,
|
|
1193
1201
|
betas,
|
|
1194
1202
|
beta_index_by_policy_num,
|
|
1195
1203
|
alg_update_func_args_beta_index,
|
|
1196
1204
|
alg_update_func_args_action_prob_index,
|
|
1197
1205
|
alg_update_func_args_action_prob_times_index,
|
|
1198
1206
|
alg_update_func_args_previous_betas_index,
|
|
1199
|
-
|
|
1207
|
+
threaded_action_prob_func_args_by_decision_time_by_subject_id,
|
|
1200
1208
|
action_prob_func,
|
|
1201
1209
|
)
|
|
1202
1210
|
|
|
@@ -1206,8 +1214,8 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1206
1214
|
if not suppress_all_data_checks and alg_update_func_args_action_prob_index >= 0:
|
|
1207
1215
|
input_checks.require_threaded_algorithm_estimating_function_args_equivalent(
|
|
1208
1216
|
algorithm_estimating_func,
|
|
1209
|
-
|
|
1210
|
-
|
|
1217
|
+
update_func_args_by_by_subject_id_by_policy_num,
|
|
1218
|
+
threaded_update_func_args_by_policy_num_by_subject_id,
|
|
1211
1219
|
suppress_interactive_data_checks,
|
|
1212
1220
|
)
|
|
1213
1221
|
|
|
@@ -1216,15 +1224,15 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1216
1224
|
# arguments with the central betas introduced.
|
|
1217
1225
|
logger.info(
|
|
1218
1226
|
"Threading in theta and beta-dependent action probabilities to inference update "
|
|
1219
|
-
"function args for all
|
|
1227
|
+
"function args for all subjects"
|
|
1220
1228
|
)
|
|
1221
|
-
|
|
1222
|
-
|
|
1229
|
+
threaded_inference_func_args_by_subject_id = thread_inference_func_args(
|
|
1230
|
+
inference_func_args_by_subject_id,
|
|
1223
1231
|
inference_func_args_theta_index,
|
|
1224
1232
|
theta,
|
|
1225
1233
|
inference_func_args_action_prob_index,
|
|
1226
|
-
|
|
1227
|
-
|
|
1234
|
+
threaded_action_prob_func_args_by_decision_time_by_subject_id,
|
|
1235
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
1228
1236
|
action_prob_func,
|
|
1229
1237
|
)
|
|
1230
1238
|
|
|
@@ -1234,32 +1242,32 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1234
1242
|
if not suppress_all_data_checks and inference_func_args_action_prob_index >= 0:
|
|
1235
1243
|
input_checks.require_threaded_inference_estimating_function_args_equivalent(
|
|
1236
1244
|
inference_estimating_func,
|
|
1237
|
-
|
|
1238
|
-
|
|
1245
|
+
inference_func_args_by_subject_id,
|
|
1246
|
+
threaded_inference_func_args_by_subject_id,
|
|
1239
1247
|
suppress_interactive_data_checks,
|
|
1240
1248
|
)
|
|
1241
1249
|
|
|
1242
|
-
# 5. Now we can compute the weighted estimating function stacks for all
|
|
1250
|
+
# 5. Now we can compute the weighted estimating function stacks for all subjects
|
|
1243
1251
|
# as well as collect related values used to construct the adaptive and classical
|
|
1244
1252
|
# sandwich variances.
|
|
1245
1253
|
results = [
|
|
1246
|
-
|
|
1254
|
+
single_subject_weighted_estimating_function_stacker(
|
|
1247
1255
|
beta_dim,
|
|
1248
|
-
|
|
1256
|
+
subject_id,
|
|
1249
1257
|
action_prob_func,
|
|
1250
1258
|
algorithm_estimating_func,
|
|
1251
1259
|
inference_estimating_func,
|
|
1252
1260
|
action_prob_func_args_beta_index,
|
|
1253
1261
|
inference_func_args_theta_index,
|
|
1254
|
-
|
|
1255
|
-
|
|
1256
|
-
|
|
1257
|
-
|
|
1258
|
-
|
|
1259
|
-
|
|
1262
|
+
action_prob_func_args_by_decision_time_by_subject_id[subject_id],
|
|
1263
|
+
threaded_action_prob_func_args_by_decision_time_by_subject_id[subject_id],
|
|
1264
|
+
threaded_update_func_args_by_policy_num_by_subject_id[subject_id],
|
|
1265
|
+
threaded_inference_func_args_by_subject_id[subject_id],
|
|
1266
|
+
policy_num_by_decision_time_by_subject_id[subject_id],
|
|
1267
|
+
action_by_decision_time_by_subject_id[subject_id],
|
|
1260
1268
|
beta_index_by_policy_num,
|
|
1261
1269
|
)
|
|
1262
|
-
for
|
|
1270
|
+
for subject_id in subject_ids.tolist()
|
|
1263
1271
|
]
|
|
1264
1272
|
|
|
1265
1273
|
stacks = jnp.array([result[0] for result in results])
|
|
@@ -1270,10 +1278,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1270
1278
|
# 6. Note this strange return structure! We will differentiate the first output,
|
|
1271
1279
|
# but the second tuple will be passed along without modification via has_aux=True and then used
|
|
1272
1280
|
# for the adaptive meat matrix, estimating functions sum check, and classical meat and inverse
|
|
1273
|
-
# bread matrices. The raw per-
|
|
1281
|
+
# bread matrices. The raw per-subject stacks are also returned for debugging purposes.
|
|
1274
1282
|
|
|
1275
1283
|
# Note that returning the raw stacks here as the first arguments is potentially
|
|
1276
|
-
# memory-intensive when combined with differentiation. Keep this in mind if the per-
|
|
1284
|
+
# memory-intensive when combined with differentiation. Keep this in mind if the per-subject bread
|
|
1277
1285
|
# inverse contributions are needed for something like CR2/CR3 small-sample corrections.
|
|
1278
1286
|
return jnp.mean(stacks, axis=0), (
|
|
1279
1287
|
jnp.mean(stacks, axis=0),
|
|
@@ -1284,10 +1292,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1284
1292
|
)
|
|
1285
1293
|
|
|
1286
1294
|
|
|
1287
|
-
def
|
|
1295
|
+
def construct_classical_and_adjusted_sandwiches(
|
|
1288
1296
|
theta_est: jnp.ndarray,
|
|
1289
1297
|
all_post_update_betas: jnp.ndarray,
|
|
1290
|
-
|
|
1298
|
+
subject_ids: jnp.ndarray,
|
|
1291
1299
|
action_prob_func: callable,
|
|
1292
1300
|
action_prob_func_args_beta_index: int,
|
|
1293
1301
|
alg_update_func: callable,
|
|
@@ -1300,32 +1308,34 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1300
1308
|
inference_func_type: str,
|
|
1301
1309
|
inference_func_args_theta_index: int,
|
|
1302
1310
|
inference_func_args_action_prob_index: int,
|
|
1303
|
-
|
|
1311
|
+
action_prob_func_args_by_subject_id_by_decision_time: dict[
|
|
1304
1312
|
collections.abc.Hashable, dict[int, tuple[Any, ...]]
|
|
1305
1313
|
],
|
|
1306
|
-
|
|
1314
|
+
policy_num_by_decision_time_by_subject_id: dict[
|
|
1307
1315
|
collections.abc.Hashable, dict[int, int | float]
|
|
1308
1316
|
],
|
|
1309
1317
|
initial_policy_num: int | float,
|
|
1310
1318
|
beta_index_by_policy_num: dict[int | float, int],
|
|
1311
|
-
|
|
1312
|
-
|
|
1319
|
+
inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
|
|
1320
|
+
inference_action_prob_decision_times_by_subject_id: dict[
|
|
1313
1321
|
collections.abc.Hashable, list[int]
|
|
1314
1322
|
],
|
|
1315
|
-
|
|
1323
|
+
update_func_args_by_by_subject_id_by_policy_num: dict[
|
|
1316
1324
|
collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
|
|
1317
1325
|
],
|
|
1318
|
-
|
|
1326
|
+
action_by_decision_time_by_subject_id: dict[
|
|
1327
|
+
collections.abc.Hashable, dict[int, int]
|
|
1328
|
+
],
|
|
1319
1329
|
suppress_all_data_checks: bool,
|
|
1320
1330
|
suppress_interactive_data_checks: bool,
|
|
1321
1331
|
small_sample_correction: str,
|
|
1322
|
-
|
|
1323
|
-
|
|
1324
|
-
|
|
1325
|
-
|
|
1332
|
+
form_adjusted_meat_adjustments_explicitly: bool,
|
|
1333
|
+
stabilize_joint_adjusted_bread_inverse: bool,
|
|
1334
|
+
analysis_df: pd.DataFrame | None,
|
|
1335
|
+
active_col_name: str | None,
|
|
1326
1336
|
action_col_name: str | None,
|
|
1327
1337
|
calendar_t_col_name: str | None,
|
|
1328
|
-
|
|
1338
|
+
subject_id_col_name: str | None,
|
|
1329
1339
|
action_prob_func_args: tuple | None,
|
|
1330
1340
|
action_prob_col_name: str | None,
|
|
1331
1341
|
) -> tuple[
|
|
@@ -1354,8 +1364,8 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1354
1364
|
A 1-D JAX NumPy array representing the parameter estimate for inference.
|
|
1355
1365
|
all_post_update_betas (jnp.ndarray):
|
|
1356
1366
|
A 2-D JAX NumPy array representing all parameter estimates for the algorithm updates.
|
|
1357
|
-
|
|
1358
|
-
A 1-D JAX NumPy array holding all
|
|
1367
|
+
subject_ids (jnp.ndarray):
|
|
1368
|
+
A 1-D JAX NumPy array holding all subject IDs in the deployment.
|
|
1359
1369
|
action_prob_func (callable):
|
|
1360
1370
|
The action probability function.
|
|
1361
1371
|
action_prob_func_args_beta_index (int):
|
|
@@ -1383,29 +1393,29 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1383
1393
|
inference_func_args_action_prob_index (int):
|
|
1384
1394
|
The index of action probabilities in the inference function arguments tuple, if
|
|
1385
1395
|
applicable. -1 otherwise.
|
|
1386
|
-
|
|
1387
|
-
A dictionary mapping decision times to maps of
|
|
1388
|
-
required to compute action probabilities for this
|
|
1389
|
-
|
|
1390
|
-
A map of
|
|
1391
|
-
Only applies to
|
|
1396
|
+
action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
|
|
1397
|
+
A dictionary mapping decision times to maps of subject ids to the function arguments
|
|
1398
|
+
required to compute action probabilities for this subject.
|
|
1399
|
+
policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
|
|
1400
|
+
A map of subject ids to dictionaries mapping decision times to the policy number in use.
|
|
1401
|
+
Only applies to active decision times!
|
|
1392
1402
|
initial_policy_num (int | float):
|
|
1393
1403
|
The policy number of the initial policy before any updates.
|
|
1394
1404
|
beta_index_by_policy_num (dict[int | float, int]):
|
|
1395
1405
|
A dictionary mapping policy numbers to the index of the corresponding beta in
|
|
1396
1406
|
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
1397
|
-
|
|
1398
|
-
A dictionary mapping
|
|
1399
|
-
|
|
1400
|
-
For each
|
|
1401
|
-
provided. Typically just
|
|
1407
|
+
inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
|
|
1408
|
+
A dictionary mapping subject IDs to their respective inference function arguments.
|
|
1409
|
+
inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
|
|
1410
|
+
For each subject, a list of decision times to which action probabilities correspond if
|
|
1411
|
+
provided. Typically just active times if action probabilities are used in the inference
|
|
1402
1412
|
loss or estimating function.
|
|
1403
|
-
|
|
1404
|
-
A dictionary where keys are policy numbers and values are dictionaries mapping
|
|
1413
|
+
update_func_args_by_by_subject_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
|
|
1414
|
+
A dictionary where keys are policy numbers and values are dictionaries mapping subject IDs
|
|
1405
1415
|
to their respective update function arguments.
|
|
1406
|
-
|
|
1407
|
-
A dictionary mapping
|
|
1408
|
-
Only applies to
|
|
1416
|
+
action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
|
|
1417
|
+
A dictionary mapping subject IDs to their respective actions taken at each decision time.
|
|
1418
|
+
Only applies to active decision times!
|
|
1409
1419
|
suppress_all_data_checks (bool):
|
|
1410
1420
|
If True, suppresses carrying out any data checks at all.
|
|
1411
1421
|
suppress_interactive_data_checks (bool):
|
|
@@ -1415,27 +1425,27 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1415
1425
|
small_sample_correction (str):
|
|
1416
1426
|
The type of small sample correction to apply. See SmallSampleCorrections class for
|
|
1417
1427
|
options.
|
|
1418
|
-
|
|
1419
|
-
If True, explicitly forms the per-
|
|
1428
|
+
form_adjusted_meat_adjustments_explicitly (bool):
|
|
1429
|
+
If True, explicitly forms the per-subject meat adjustments that differentiate the adaptive
|
|
1420
1430
|
sandwich from the classical sandwich. This is for diagnostic purposes, as the
|
|
1421
1431
|
adaptive sandwich is formed without doing this.
|
|
1422
|
-
|
|
1432
|
+
stabilize_joint_adjusted_bread_inverse (bool):
|
|
1423
1433
|
If True, will apply various techniques to stabilize the joint adaptive bread inverse if necessary.
|
|
1424
|
-
|
|
1425
|
-
The full
|
|
1426
|
-
|
|
1427
|
-
The name of the column in
|
|
1434
|
+
analysis_df (pd.DataFrame):
|
|
1435
|
+
The full analysis dataframe, needed if forming the adaptive meat adjustments explicitly.
|
|
1436
|
+
active_col_name (str):
|
|
1437
|
+
The name of the column in analysis_df indicating whether a subject is active at a given decision time.
|
|
1428
1438
|
action_col_name (str):
|
|
1429
|
-
The name of the column in
|
|
1439
|
+
The name of the column in analysis_df indicating the action taken at a given decision time.
|
|
1430
1440
|
calendar_t_col_name (str):
|
|
1431
|
-
The name of the column in
|
|
1432
|
-
|
|
1433
|
-
The name of the column in
|
|
1441
|
+
The name of the column in analysis_df indicating the calendar time of a given decision time.
|
|
1442
|
+
subject_id_col_name (str):
|
|
1443
|
+
The name of the column in analysis_df indicating the subject ID.
|
|
1434
1444
|
action_prob_func_args (tuple):
|
|
1435
1445
|
The arguments to be passed to the action probability function, needed if forming the
|
|
1436
1446
|
adaptive meat adjustments explicitly.
|
|
1437
1447
|
action_prob_col_name (str):
|
|
1438
|
-
The name of the column in
|
|
1448
|
+
The name of the column in analysis_df indicating the action probability of the action taken,
|
|
1439
1449
|
needed if forming the adaptive meat adjustments explicitly.
|
|
1440
1450
|
Returns:
|
|
1441
1451
|
tuple[jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32]]:
|
|
@@ -1448,10 +1458,10 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1448
1458
|
- The classical meat matrix.
|
|
1449
1459
|
- The classical sandwich matrix.
|
|
1450
1460
|
- The average weighted estimating function stack.
|
|
1451
|
-
- All per-
|
|
1452
|
-
- The per-
|
|
1453
|
-
- The per-
|
|
1454
|
-
- The per-
|
|
1461
|
+
- All per-subject weighted estimating function stacks.
|
|
1462
|
+
- The per-subject adaptive meat small-sample corrections.
|
|
1463
|
+
- The per-subject classical meat small-sample corrections.
|
|
1464
|
+
- The per-subject adaptive meat adjustments, if form_adjusted_meat_adjustments_explicitly
|
|
1455
1465
|
is True, otherwise an array of NaNs.
|
|
1456
1466
|
"""
|
|
1457
1467
|
logger.info(
|
|
@@ -1459,13 +1469,13 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1459
1469
|
)
|
|
1460
1470
|
theta_dim = theta_est.shape[0]
|
|
1461
1471
|
beta_dim = all_post_update_betas.shape[1]
|
|
1462
|
-
# Note that these "contributions" are per-
|
|
1463
|
-
|
|
1472
|
+
# Note that these "contributions" are per-subject Jacobians of the weighted estimating function stack.
|
|
1473
|
+
raw_joint_adjusted_bread_inverse_matrix, (
|
|
1464
1474
|
avg_estimating_function_stack,
|
|
1465
|
-
|
|
1466
|
-
|
|
1467
|
-
|
|
1468
|
-
|
|
1475
|
+
per_subject_joint_adjusted_meat_contributions,
|
|
1476
|
+
per_subject_classical_meat_contributions,
|
|
1477
|
+
per_subject_classical_bread_inverse_contributions,
|
|
1478
|
+
per_subject_estimating_function_stacks,
|
|
1469
1479
|
) = jax.jacrev(
|
|
1470
1480
|
get_avg_weighted_estimating_function_stacks_and_aux_values, has_aux=True
|
|
1471
1481
|
)(
|
|
@@ -1475,7 +1485,7 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1475
1485
|
flatten_params(all_post_update_betas, theta_est),
|
|
1476
1486
|
beta_dim,
|
|
1477
1487
|
theta_dim,
|
|
1478
|
-
|
|
1488
|
+
subject_ids,
|
|
1479
1489
|
action_prob_func,
|
|
1480
1490
|
action_prob_func_args_beta_index,
|
|
1481
1491
|
alg_update_func,
|
|
@@ -1488,87 +1498,87 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1488
1498
|
inference_func_type,
|
|
1489
1499
|
inference_func_args_theta_index,
|
|
1490
1500
|
inference_func_args_action_prob_index,
|
|
1491
|
-
|
|
1492
|
-
|
|
1501
|
+
action_prob_func_args_by_subject_id_by_decision_time,
|
|
1502
|
+
policy_num_by_decision_time_by_subject_id,
|
|
1493
1503
|
initial_policy_num,
|
|
1494
1504
|
beta_index_by_policy_num,
|
|
1495
|
-
|
|
1496
|
-
|
|
1497
|
-
|
|
1498
|
-
|
|
1505
|
+
inference_func_args_by_subject_id,
|
|
1506
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
1507
|
+
update_func_args_by_by_subject_id_by_policy_num,
|
|
1508
|
+
action_by_decision_time_by_subject_id,
|
|
1499
1509
|
suppress_all_data_checks,
|
|
1500
1510
|
suppress_interactive_data_checks,
|
|
1501
1511
|
)
|
|
1502
1512
|
|
|
1503
|
-
|
|
1513
|
+
num_subjects = len(subject_ids)
|
|
1504
1514
|
|
|
1505
1515
|
(
|
|
1506
|
-
|
|
1516
|
+
joint_adjusted_meat_matrix,
|
|
1507
1517
|
classical_meat_matrix,
|
|
1508
|
-
|
|
1509
|
-
|
|
1518
|
+
per_subject_adjusted_corrections,
|
|
1519
|
+
per_subject_classical_corrections,
|
|
1510
1520
|
) = perform_desired_small_sample_correction(
|
|
1511
1521
|
small_sample_correction,
|
|
1512
|
-
|
|
1513
|
-
|
|
1514
|
-
|
|
1515
|
-
|
|
1522
|
+
per_subject_joint_adjusted_meat_contributions,
|
|
1523
|
+
per_subject_classical_meat_contributions,
|
|
1524
|
+
per_subject_classical_bread_inverse_contributions,
|
|
1525
|
+
num_subjects,
|
|
1516
1526
|
theta_dim,
|
|
1517
1527
|
)
|
|
1518
1528
|
|
|
1519
1529
|
# Increase diagonal block dominance and possibly improve conditioning of diagonal
|
|
1520
1530
|
# blocks as necessary, to ensure mathematical stability of joint bread inverse
|
|
1521
|
-
|
|
1531
|
+
stabilized_joint_adjusted_bread_inverse_matrix = (
|
|
1522
1532
|
(
|
|
1523
|
-
|
|
1524
|
-
|
|
1533
|
+
stabilize_joint_adjusted_bread_inverse_if_necessary(
|
|
1534
|
+
raw_joint_adjusted_bread_inverse_matrix,
|
|
1525
1535
|
beta_dim,
|
|
1526
1536
|
theta_dim,
|
|
1527
1537
|
)
|
|
1528
1538
|
)
|
|
1529
|
-
if
|
|
1530
|
-
else
|
|
1539
|
+
if stabilize_joint_adjusted_bread_inverse
|
|
1540
|
+
else raw_joint_adjusted_bread_inverse_matrix
|
|
1531
1541
|
)
|
|
1532
1542
|
|
|
1533
1543
|
# Now stably (no explicit inversion) form our sandwiches.
|
|
1534
|
-
|
|
1535
|
-
|
|
1536
|
-
|
|
1537
|
-
|
|
1544
|
+
joint_adjusted_sandwich = form_sandwich_from_bread_inverse_and_meat(
|
|
1545
|
+
stabilized_joint_adjusted_bread_inverse_matrix,
|
|
1546
|
+
joint_adjusted_meat_matrix,
|
|
1547
|
+
num_subjects,
|
|
1538
1548
|
method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
|
|
1539
1549
|
)
|
|
1540
1550
|
classical_bread_inverse_matrix = jnp.mean(
|
|
1541
|
-
|
|
1551
|
+
per_subject_classical_bread_inverse_contributions, axis=0
|
|
1542
1552
|
)
|
|
1543
1553
|
classical_sandwich = form_sandwich_from_bread_inverse_and_meat(
|
|
1544
1554
|
classical_bread_inverse_matrix,
|
|
1545
1555
|
classical_meat_matrix,
|
|
1546
|
-
|
|
1556
|
+
num_subjects,
|
|
1547
1557
|
method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
|
|
1548
1558
|
)
|
|
1549
1559
|
|
|
1550
|
-
|
|
1551
|
-
(len(
|
|
1560
|
+
per_subject_adjusted_meat_adjustments = jnp.full(
|
|
1561
|
+
(len(subject_ids), theta_dim, theta_dim), jnp.nan
|
|
1552
1562
|
)
|
|
1553
|
-
if
|
|
1554
|
-
|
|
1555
|
-
|
|
1563
|
+
if form_adjusted_meat_adjustments_explicitly:
|
|
1564
|
+
per_subject_adjusted_classical_meat_contributions = (
|
|
1565
|
+
form_adjusted_meat_adjustments_directly(
|
|
1556
1566
|
theta_dim,
|
|
1557
1567
|
all_post_update_betas.shape[1],
|
|
1558
|
-
|
|
1559
|
-
|
|
1560
|
-
|
|
1561
|
-
|
|
1568
|
+
stabilized_joint_adjusted_bread_inverse_matrix,
|
|
1569
|
+
per_subject_estimating_function_stacks,
|
|
1570
|
+
analysis_df,
|
|
1571
|
+
active_col_name,
|
|
1562
1572
|
action_col_name,
|
|
1563
1573
|
calendar_t_col_name,
|
|
1564
|
-
|
|
1574
|
+
subject_id_col_name,
|
|
1565
1575
|
action_prob_func,
|
|
1566
1576
|
action_prob_func_args,
|
|
1567
1577
|
action_prob_func_args_beta_index,
|
|
1568
1578
|
theta_est,
|
|
1569
1579
|
inference_func,
|
|
1570
1580
|
inference_func_args_theta_index,
|
|
1571
|
-
|
|
1581
|
+
subject_ids,
|
|
1572
1582
|
action_prob_col_name,
|
|
1573
1583
|
)
|
|
1574
1584
|
)
|
|
@@ -1578,30 +1588,30 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1578
1588
|
# First just apply any small-sample correction for parity.
|
|
1579
1589
|
(
|
|
1580
1590
|
_,
|
|
1581
|
-
|
|
1591
|
+
theta_only_adjusted_meat_matrix_v2,
|
|
1582
1592
|
_,
|
|
1583
1593
|
_,
|
|
1584
1594
|
) = perform_desired_small_sample_correction(
|
|
1585
1595
|
small_sample_correction,
|
|
1586
|
-
|
|
1587
|
-
|
|
1588
|
-
|
|
1589
|
-
|
|
1596
|
+
per_subject_joint_adjusted_meat_contributions,
|
|
1597
|
+
per_subject_adjusted_classical_meat_contributions,
|
|
1598
|
+
per_subject_classical_bread_inverse_contributions,
|
|
1599
|
+
num_subjects,
|
|
1590
1600
|
theta_dim,
|
|
1591
1601
|
)
|
|
1592
|
-
|
|
1602
|
+
theta_only_adjusted_sandwich_from_adjustments = (
|
|
1593
1603
|
form_sandwich_from_bread_inverse_and_meat(
|
|
1594
1604
|
classical_bread_inverse_matrix,
|
|
1595
|
-
|
|
1596
|
-
|
|
1605
|
+
theta_only_adjusted_meat_matrix_v2,
|
|
1606
|
+
num_subjects,
|
|
1597
1607
|
method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
|
|
1598
1608
|
)
|
|
1599
1609
|
)
|
|
1600
|
-
|
|
1610
|
+
theta_only_adjusted_sandwich = joint_adjusted_sandwich[-theta_dim:, -theta_dim:]
|
|
1601
1611
|
|
|
1602
1612
|
if not np.allclose(
|
|
1603
|
-
|
|
1604
|
-
|
|
1613
|
+
theta_only_adjusted_sandwich,
|
|
1614
|
+
theta_only_adjusted_sandwich_from_adjustments,
|
|
1605
1615
|
rtol=3e-2,
|
|
1606
1616
|
):
|
|
1607
1617
|
logger.warning(
|
|
@@ -1611,26 +1621,26 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1611
1621
|
# Stack the joint adaptive inverse bread pieces together horizontally and return the auxiliary
|
|
1612
1622
|
# values too. The joint adaptive bread inverse should always be block lower triangular.
|
|
1613
1623
|
return (
|
|
1614
|
-
|
|
1615
|
-
|
|
1616
|
-
|
|
1617
|
-
|
|
1624
|
+
raw_joint_adjusted_bread_inverse_matrix,
|
|
1625
|
+
stabilized_joint_adjusted_bread_inverse_matrix,
|
|
1626
|
+
joint_adjusted_meat_matrix,
|
|
1627
|
+
joint_adjusted_sandwich,
|
|
1618
1628
|
classical_bread_inverse_matrix,
|
|
1619
1629
|
classical_meat_matrix,
|
|
1620
1630
|
classical_sandwich,
|
|
1621
1631
|
avg_estimating_function_stack,
|
|
1622
|
-
|
|
1623
|
-
|
|
1624
|
-
|
|
1625
|
-
|
|
1632
|
+
per_subject_estimating_function_stacks,
|
|
1633
|
+
per_subject_adjusted_corrections,
|
|
1634
|
+
per_subject_classical_corrections,
|
|
1635
|
+
per_subject_adjusted_meat_adjustments,
|
|
1626
1636
|
)
|
|
1627
1637
|
|
|
1628
1638
|
|
|
1629
1639
|
# TODO: I think there should be interaction to confirm stabilization. It is
|
|
1630
|
-
# important for the
|
|
1631
|
-
# that the
|
|
1632
|
-
def
|
|
1633
|
-
|
|
1640
|
+
# important for the subject to know if this is happening. Even if enabled, it is important
|
|
1641
|
+
# that the subject know it actually kicks in.
|
|
1642
|
+
def stabilize_joint_adjusted_bread_inverse_if_necessary(
|
|
1643
|
+
joint_adjusted_bread_inverse_matrix: jnp.ndarray,
|
|
1634
1644
|
beta_dim: int,
|
|
1635
1645
|
theta_dim: int,
|
|
1636
1646
|
) -> jnp.ndarray:
|
|
@@ -1639,7 +1649,7 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
|
|
|
1639
1649
|
dominance and/or adding a small ridge penalty to the diagonal blocks.
|
|
1640
1650
|
|
|
1641
1651
|
Args:
|
|
1642
|
-
|
|
1652
|
+
joint_adjusted_bread_inverse_matrix (jnp.ndarray):
|
|
1643
1653
|
A 2-D JAX NumPy array representing the joint adaptive bread inverse matrix.
|
|
1644
1654
|
beta_dim (int):
|
|
1645
1655
|
The dimension of each beta parameter.
|
|
@@ -1660,7 +1670,7 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
|
|
|
1660
1670
|
|
|
1661
1671
|
# Grab just the RL block and convert numpy array for easier manipulation.
|
|
1662
1672
|
RL_stack_beta_derivatives_block = np.array(
|
|
1663
|
-
|
|
1673
|
+
joint_adjusted_bread_inverse_matrix[:-theta_dim, :-theta_dim]
|
|
1664
1674
|
)
|
|
1665
1675
|
num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
|
|
1666
1676
|
for i in range(1, num_updates + 1):
|
|
@@ -1688,7 +1698,7 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
|
|
|
1688
1698
|
RL_stack_beta_derivatives_block[
|
|
1689
1699
|
diagonal_block_slice, diagonal_block_slice
|
|
1690
1700
|
] = diagonal_block + ridge_penalty * np.eye(beta_dim)
|
|
1691
|
-
# TODO: Require
|
|
1701
|
+
# TODO: Require subject input here in interactive settings?
|
|
1692
1702
|
logger.info(
|
|
1693
1703
|
"Added ridge penalty of %s to diagonal block for update %s to improve conditioning from %s to %s",
|
|
1694
1704
|
ridge_penalty,
|
|
@@ -1779,11 +1789,11 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
|
|
|
1779
1789
|
[
|
|
1780
1790
|
[
|
|
1781
1791
|
RL_stack_beta_derivatives_block,
|
|
1782
|
-
|
|
1792
|
+
joint_adjusted_bread_inverse_matrix[:-theta_dim, -theta_dim:],
|
|
1783
1793
|
],
|
|
1784
1794
|
[
|
|
1785
|
-
|
|
1786
|
-
|
|
1795
|
+
joint_adjusted_bread_inverse_matrix[-theta_dim:, :-theta_dim],
|
|
1796
|
+
joint_adjusted_bread_inverse_matrix[-theta_dim:, -theta_dim:],
|
|
1787
1797
|
],
|
|
1788
1798
|
]
|
|
1789
1799
|
)
|
|
@@ -1792,7 +1802,7 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
|
|
|
1792
1802
|
def form_sandwich_from_bread_inverse_and_meat(
|
|
1793
1803
|
bread_inverse: jnp.ndarray,
|
|
1794
1804
|
meat: jnp.ndarray,
|
|
1795
|
-
|
|
1805
|
+
num_subjects: int,
|
|
1796
1806
|
method: str = SandwichFormationMethods.BREAD_INVERSE_T_QR,
|
|
1797
1807
|
) -> jnp.ndarray:
|
|
1798
1808
|
"""
|
|
@@ -1806,8 +1816,8 @@ def form_sandwich_from_bread_inverse_and_meat(
|
|
|
1806
1816
|
A 2-D JAX NumPy array representing the bread inverse matrix.
|
|
1807
1817
|
meat (jnp.ndarray):
|
|
1808
1818
|
A 2-D JAX NumPy array representing the meat matrix.
|
|
1809
|
-
|
|
1810
|
-
The number of
|
|
1819
|
+
num_subjects (int):
|
|
1820
|
+
The number of subjects in the deployment, used to scale the sandwich appropriately.
|
|
1811
1821
|
method (str):
|
|
1812
1822
|
The method to use for forming the sandwich.
|
|
1813
1823
|
|
|
@@ -1833,7 +1843,7 @@ def form_sandwich_from_bread_inverse_and_meat(
|
|
|
1833
1843
|
L, scipy.linalg.solve_triangular(L, meat.T, lower=True).T, lower=True
|
|
1834
1844
|
)
|
|
1835
1845
|
|
|
1836
|
-
return Q @ new_meat @ Q.T /
|
|
1846
|
+
return Q @ new_meat @ Q.T / num_subjects
|
|
1837
1847
|
elif method == SandwichFormationMethods.MEAT_SVD_SOLVE:
|
|
1838
1848
|
# Factor the meat via SVD without any symmetrization or truncation.
|
|
1839
1849
|
# For general (possibly slightly nonsymmetric) M, SVD gives M = U @ diag(s) @ Vh.
|
|
@@ -1847,14 +1857,14 @@ def form_sandwich_from_bread_inverse_and_meat(
|
|
|
1847
1857
|
W_left = scipy.linalg.solve(bread_inverse, C_left)
|
|
1848
1858
|
W_right = scipy.linalg.solve(bread_inverse, C_right)
|
|
1849
1859
|
|
|
1850
|
-
# Return the exact sandwich: V = (B^{-1} C_left) (B^{-1} C_right)^T /
|
|
1851
|
-
return W_left @ W_right.T /
|
|
1860
|
+
# Return the exact sandwich: V = (B^{-1} C_left) (B^{-1} C_right)^T / num_subjects
|
|
1861
|
+
return W_left @ W_right.T / num_subjects
|
|
1852
1862
|
|
|
1853
1863
|
elif method == SandwichFormationMethods.NAIVE:
|
|
1854
1864
|
# Simply invert the bread inverse and form the sandwich directly.
|
|
1855
1865
|
# This is NOT numerically stable and is only included for comparison purposes.
|
|
1856
1866
|
bread = np.linalg.inv(bread_inverse)
|
|
1857
|
-
return bread @ meat @ meat.T /
|
|
1867
|
+
return bread @ meat @ meat.T / num_subjects
|
|
1858
1868
|
|
|
1859
1869
|
else:
|
|
1860
1870
|
raise ValueError(
|