lifejacket-0.2.1-py3-none-any.whl → lifejacket-1.0.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/arg_threading_helpers.py +75 -69
- lifejacket/calculate_derivatives.py +19 -23
- lifejacket/constants.py +4 -16
- lifejacket/{trial_conditioning_monitor.py → deployment_conditioning_monitor.py} +163 -138
- lifejacket/{form_adaptive_meat_adjustments_directly.py → form_adjusted_meat_adjustments_directly.py} +32 -34
- lifejacket/get_datum_for_blowup_supervised_learning.py +341 -339
- lifejacket/helper_functions.py +60 -186
- lifejacket/input_checks.py +303 -302
- lifejacket/{after_study_analysis.py → post_deployment_analysis.py} +470 -457
- lifejacket/small_sample_corrections.py +49 -49
- lifejacket-1.0.2.dist-info/METADATA +56 -0
- lifejacket-1.0.2.dist-info/RECORD +17 -0
- lifejacket-1.0.2.dist-info/entry_points.txt +2 -0
- lifejacket-0.2.1.dist-info/METADATA +0 -100
- lifejacket-0.2.1.dist-info/RECORD +0 -17
- lifejacket-0.2.1.dist-info/entry_points.txt +0 -2
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.2.dist-info}/WHEEL +0 -0
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.2.dist-info}/top_level.txt +0 -0
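The hunks below rename the package's public vocabulary between 0.2.1 and 1.0.2: trial/study becomes deployment, user becomes subject, and the adaptive sandwich becomes the adjusted sandwich, which cascades through module names, CLI options, function signatures, and docstrings. Per the help strings shown further down, the `analyze` command consumes a pickled pandas DataFrame plus pickled argument dictionaries. A minimal sketch of preparing the DataFrame input; the column names here are illustrative assumptions, since the CLI lets you declare your own via the --*_col_name options:

    # Hypothetical input preparation for `lifejacket analyze` (column names are
    # assumptions; the --*_col_name options below let you point at your own columns).
    import pickle
    import pandas as pd

    analysis_df = pd.DataFrame({
        "subject_id": [1, 1, 2, 2],           # passed via --subject_id_col_name
        "calendar_t": [0, 1, 0, 1],           # passed via --calendar_t_col_name
        "in_deployment": [1, 1, 1, 1],        # binary, passed via --active_col_name
        "action": [0, 1, 1, 0],               # binary, passed via --action_col_name
        "action_prob": [0.5, 0.6, 0.5, 0.4],  # passed via --action_prob_col_name
        "policy_num": [0, 1, 0, 1],           # passed via --policy_num_col_name
        "reward": [0.2, 1.0, 0.0, 0.7],       # passed via --reward_col_name
    })
    with open("analysis_df.pkl", "wb") as f:
        pickle.dump(analysis_df, f)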
@@ -24,8 +24,8 @@ from .constants import (
 SandwichFormationMethods,
 SmallSampleCorrections,
 )
-from .
-
+from .form_adjusted_meat_adjustments_directly import (
+form_adjusted_meat_adjustments_directly,
 )
 from . import input_checks
 from . import get_datum_for_blowup_supervised_learning
@@ -37,9 +37,9 @@ from .helper_functions import (
 calculate_beta_dim,
 collect_all_post_update_betas,
 construct_beta_index_by_policy_num_map,
-
+extract_action_and_policy_by_decision_time_by_subject_id,
 flatten_params,
-
+get_active_df_column,
 get_min_time_by_policy_num,
 get_radon_nikodym_weight,
 load_function_from_same_named_file,
@@ -61,7 +61,7 @@ def cli():

 # TODO: Check all help strings for accuracy.
 # TODO: Deal with NA, -1, etc policy numbers
-# TODO: Make sure in
+# TODO: Make sure in deployment is never on for more than one stretch EDIT: unclear if
 # this will remain an invariant as we deal with more complicated data missingness
 # TODO: I think I'm agnostic to indexing of calendar times but should check because
 # otherwise need to add a check here to verify required format.
@@ -69,7 +69,7 @@ def cli():
 # Higher dimensional objects not supported. Not entirely sure what kind of "scalars" apply.
 @cli.command(name="analyze")
 @click.option(
-"--
+"--analysis_df_pickle",
 type=click.File("rb"),
 help="Pickled pandas dataframe in correct format (see contract/readme).",
 required=True,
@@ -83,7 +83,7 @@ def cli():
 @click.option(
 "--action_prob_func_args_pickle",
 type=click.File("rb"),
-help="Pickled dictionary that contains the action probability function arguments for all decision times for all
+help="Pickled dictionary that contains the action probability function arguments for all decision times for all subjects.",
 required=True,
 )
 @click.option(
@@ -95,7 +95,7 @@ def cli():
 @click.option(
 "--alg_update_func_filename",
 type=click.Path(exists=True),
-help="File that contains the per-
+help="File that contains the per-subject update function used to determine the algorithm parameters at each update and relevant imports. May be a loss or estimating function, specified in a separate argument. The filename without its extension will be assumed to match the function name.",
 required=True,
 )
 @click.option(
@@ -107,7 +107,7 @@ def cli():
 @click.option(
 "--alg_update_func_args_pickle",
 type=click.File("rb"),
-help="Pickled dictionary that contains the algorithm update function arguments for all update times for all
+help="Pickled dictionary that contains the algorithm update function arguments for all update times for all subjects.",
 required=True,
 )
 @click.option(
@@ -137,7 +137,7 @@ def cli():
 @click.option(
 "--inference_func_filename",
 type=click.Path(exists=True),
-help="File that contains the per-
+help="File that contains the per-subject loss/estimating function used to determine the inference estimate and relevant imports. The filename without its extension will be assumed to match the function name.",
 required=True,
 )
 @click.option(
@@ -155,56 +155,56 @@ def cli():
 @click.option(
 "--theta_calculation_func_filename",
 type=click.Path(exists=True),
-help="Path to file that allows one to actually calculate a theta estimate given the
+help="Path to file that allows one to actually calculate a theta estimate given the analysis dataframe only. One must supply either this or a precomputed theta estimate. The filename without its extension will be assumed to match the function name.",
 required=True,
 )
 @click.option(
-"--
+"--active_col_name",
 type=str,
 required=True,
-help="Name of the binary column in the
+help="Name of the binary column in the analysis dataframe that indicates whether a subject is in the deployment.",
 )
 @click.option(
 "--action_col_name",
 type=str,
 required=True,
-help="Name of the binary column in the
+help="Name of the binary column in the analysis dataframe that indicates which action was taken.",
 )
 @click.option(
 "--policy_num_col_name",
 type=str,
 required=True,
-help="Name of the column in the
+help="Name of the column in the analysis dataframe that indicates the policy number in use.",
 )
 @click.option(
 "--calendar_t_col_name",
 type=str,
 required=True,
-help="Name of the column in the
+help="Name of the column in the analysis dataframe that indicates calendar time (shared integer index across subjects).",
 )
 @click.option(
-"--
+"--subject_id_col_name",
 type=str,
 required=True,
-help="Name of the column in the
+help="Name of the column in the analysis dataframe that indicates subject id.",
 )
 @click.option(
 "--action_prob_col_name",
 type=str,
 required=True,
-help="Name of the column in the
+help="Name of the column in the analysis dataframe that gives action one probabilities.",
 )
 @click.option(
 "--reward_col_name",
 type=str,
 required=True,
-help="Name of the column in the
+help="Name of the column in the analysis dataframe that gives rewards.",
 )
 @click.option(
 "--suppress_interactive_data_checks",
 type=bool,
 default=False,
-help="Flag to suppress any data checks that require
+help="Flag to suppress any data checks that require subject input. This is suitable for tests and large simulations",
 )
 @click.option(
 "--suppress_all_data_checks",
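Taken together, the renamed options above define the new 1.0.2 command-line surface. A minimal sketch of exercising a subset of them through click's test runner; the import path and option values are assumptions, and a real run must supply every option marked required=True:

    # Sketch of driving the renamed options programmatically via click's test runner.
    # Only a subset of the required options is shown.
    from click.testing import CliRunner
    from lifejacket.post_deployment_analysis import cli  # import path is an assumption

    runner = CliRunner()
    result = runner.invoke(
        cli,
        [
            "analyze",
            "--analysis_df_pickle", "analysis_df.pkl",
            "--active_col_name", "in_deployment",
            "--subject_id_col_name", "subject_id",
            "--action_col_name", "action",
            # ... remaining required options elided ...
        ],
    )
    print(result.output)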
@@ -217,9 +217,9 @@ def cli():
 type=click.Choice(
 [
 SmallSampleCorrections.NONE,
-SmallSampleCorrections.
-SmallSampleCorrections.
-SmallSampleCorrections.
+SmallSampleCorrections.Z1theta,
+SmallSampleCorrections.Z2theta,
+SmallSampleCorrections.Z3theta,
 ]
 ),
 default=SmallSampleCorrections.NONE,
@@ -232,23 +232,23 @@ def cli():
 help="Flag to collect data for supervised learning blowup detection. This will write a single datum and label to a file in the same directory as the input files.",
 )
 @click.option(
-"--
+"--form_adjusted_meat_adjustments_explicitly",
 type=bool,
 default=False,
-help="If True, explicitly forms the per-
+help="If True, explicitly forms the per-subject meat adjustments that differentiate the adjusted sandwich from the classical sandwich. This is for diagnostic purposes, as the adjusted sandwich is formed without doing this.",
 )
 @click.option(
-"--
+"--stabilize_joint_bread",
 type=bool,
 default=True,
-help="If True, stabilizes the joint
+help="If True, stabilizes the joint bread matrix if it does not meet conditioning thresholds.",
 )
 def analyze_dataset_wrapper(**kwargs):
 """
 This function is a wrapper around analyze_dataset to facilitate command line use.

 From the command line, we will take pickles and filenames for Python objects.
-
+We unpickle/load files here for passing to the implementation function, which
 may also be called in its own right with in-memory objects.

 See analyze_dataset for the underlying details.
@@ -256,18 +256,20 @@ def analyze_dataset_wrapper(**kwargs):
 Returns: None
 """

-# Pass along the folder the
-# Do it now because we will be removing the
-kwargs["output_dir"] = pathlib.Path(
+# Pass along the folder the analysis dataframe is in as the output folder.
+# Do it now because we will be removing the analysis dataframe pickle from kwargs.
+kwargs["output_dir"] = pathlib.Path(
+kwargs["analysis_df_pickle"].name
+).parent.resolve()

 # Unpickle pickles and replace those args in kwargs
-kwargs["
+kwargs["analysis_df"] = pickle.load(kwargs["analysis_df_pickle"])
 kwargs["action_prob_func_args"] = pickle.load(
 kwargs["action_prob_func_args_pickle"]
 )
 kwargs["alg_update_func_args"] = pickle.load(kwargs["alg_update_func_args_pickle"])

-kwargs.pop("
+kwargs.pop("analysis_df_pickle")
 kwargs.pop("action_prob_func_args_pickle")
 kwargs.pop("alg_update_func_args_pickle")

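The rewritten wrapper above derives the output directory from the pickle's path rather than a separate option: click.File("rb") hands the command an open file object whose .name attribute is the path the user passed. The same pattern in isolation, using an ordinary file object (here, the pickle from the earlier sketch) as a stand-in:

    # Standalone illustration of the output-directory derivation used above.
    import pathlib

    with open("analysis_df.pkl", "rb") as analysis_df_pickle:  # stands in for the click.File object
        output_dir = pathlib.Path(analysis_df_pickle.name).parent.resolve()
    print(output_dir)  # absolute path to the folder containing the pickle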
@@ -295,7 +297,7 @@ def analyze_dataset_wrapper(**kwargs):

 def analyze_dataset(
 output_dir: pathlib.Path | str,
-
+analysis_df: pd.DataFrame,
 action_prob_func: Callable,
 action_prob_func_args: dict[int, Any],
 action_prob_func_args_beta_index: int,
@@ -310,35 +312,35 @@ def analyze_dataset(
 inference_func_type: str,
 inference_func_args_theta_index: int,
 theta_calculation_func: Callable[[pd.DataFrame], jnp.ndarray],
-
+active_col_name: str,
 action_col_name: str,
 policy_num_col_name: str,
 calendar_t_col_name: str,
-
+subject_id_col_name: str,
 action_prob_col_name: str,
 reward_col_name: str,
 suppress_interactive_data_checks: bool,
 suppress_all_data_checks: bool,
 small_sample_correction: str,
 collect_data_for_blowup_supervised_learning: bool,
-
-
+form_adjusted_meat_adjustments_explicitly: bool,
+stabilize_joint_bread: bool,
 ) -> None:
 """
-Analyzes a dataset to provide a parameter estimate and an estimate of its variance using
+Analyzes a dataset to provide a parameter estimate and an estimate of its variance using and classical sandwich estimators.

 There are two modes of use for this function.

 First, it may be called indirectly from the command line by passing through
-
+analyze_dataset_wrapper.

 Second, it may be called directly from Python code with in-memory objects.

 Parameters:
 output_dir (pathlib.Path | str):
 Directory in which to save output files.
-
-DataFrame containing the
+analysis_df (pd.DataFrame):
+DataFrame containing the deployment data.
 action_prob_func (callable):
 Action probability function.
 action_prob_func_args (dict[int, Any]):
@@ -364,21 +366,21 @@ def analyze_dataset(
 inference_func_args_theta_index (int):
 Index for theta in inference function arguments.
 theta_calculation_func (callable):
-Function to estimate theta from the
-
-Column name indicating if a
+Function to estimate theta from the analysis dataframe.
+active_col_name (str):
+Column name indicating if a subject is active in the analysis dataframe.
 action_col_name (str):
-Column name for actions in the
+Column name for actions in the analysis dataframe.
 policy_num_col_name (str):
-Column name for policy numbers in the
+Column name for policy numbers in the analysis dataframe.
 calendar_t_col_name (str):
-Column name for calendar time in the
-
-Column name for
+Column name for calendar time in the analysis dataframe.
+subject_id_col_name (str):
+Column name for subject IDs in the analysis dataframe.
 action_prob_col_name (str):
-Column name for action probabilities in the
+Column name for action probabilities in the analysis dataframe.
 reward_col_name (str):
-Column name for rewards in the
+Column name for rewards in the analysis dataframe.
 suppress_interactive_data_checks (bool):
 Whether to suppress interactive data checks. This should be used in simulations, for example.
 suppress_all_data_checks (bool):
@@ -386,17 +388,17 @@ def analyze_dataset(
 small_sample_correction (str):
 Type of small sample correction to apply.
 collect_data_for_blowup_supervised_learning (bool):
-Whether to collect data for doing supervised learning about
-
-If True, explicitly forms the per-
+Whether to collect data for doing supervised learning about adjusted sandwich blowup.
+form_adjusted_meat_adjustments_explicitly (bool):
+If True, explicitly forms the per-subject meat adjustments that differentiate the
 sandwich from the classical sandwich. This is for diagnostic purposes, as the
-
-
-If True, stabilizes the joint
+adjusted sandwich is formed without doing this.
+stabilize_joint_bread (bool):
+If True, stabilizes the joint bread matrix if it does not meet conditioning
 thresholds.

 Returns:
-dict: A dictionary containing the theta estimate,
+dict: A dictionary containing the theta estimate, adjusted sandwich variance estimate, and
 classical sandwich variance estimate.
 """

@@ -406,19 +408,19 @@ def analyze_dataset(
 level=logging.INFO,
 )

-theta_est = jnp.array(theta_calculation_func(
+theta_est = jnp.array(theta_calculation_func(analysis_df))

 beta_dim = calculate_beta_dim(
 action_prob_func_args, action_prob_func_args_beta_index
 )
 if not suppress_all_data_checks:
 input_checks.perform_first_wave_input_checks(
-
-
+analysis_df,
+active_col_name,
 action_col_name,
 policy_num_col_name,
 calendar_t_col_name,
-
+subject_id_col_name,
 action_prob_col_name,
 reward_col_name,
 action_prob_func,
@@ -436,10 +438,9 @@ def analyze_dataset(
 )

 ### Begin collecting data structures that will be used to compute the joint bread matrix.
-
 beta_index_by_policy_num, initial_policy_num = (
 construct_beta_index_by_policy_num_map(
-
+analysis_df, policy_num_col_name, active_col_name
 )
 )

@@ -447,11 +448,11 @@ def analyze_dataset(
 beta_index_by_policy_num, alg_update_func_args, alg_update_func_args_beta_index
 )

-
-
-
-
-
+action_by_decision_time_by_subject_id, policy_num_by_decision_time_by_subject_id = (
+extract_action_and_policy_by_decision_time_by_subject_id(
+analysis_df,
+subject_id_col_name,
+active_col_name,
 calendar_t_col_name,
 action_col_name,
 policy_num_col_name,
@@ -459,45 +460,45 @@ def analyze_dataset(
 )

 (
-
+inference_func_args_by_subject_id,
 inference_func_args_action_prob_index,
-
+inference_action_prob_decision_times_by_subject_id,
 ) = process_inference_func_args(
 inference_func,
 inference_func_args_theta_index,
-
+analysis_df,
 theta_est,
 action_prob_col_name,
 calendar_t_col_name,
-
-
+subject_id_col_name,
+active_col_name,
 )

-# Use a per-
-#
+# Use a per-subject weighted estimating function stacking function to derive classical and joint
+# meat and bread matrices. This is facilitated because the *value* of the
 # weighted and unweighted stacks are the same, as the weights evaluate to 1 pre-differentiation.
 logger.info(
-"Constructing joint
+"Constructing joint bread matrix, joint meat matrix, the classical analogs, and the avg estimating function stack across subjects."
 )

-
+subject_ids = jnp.array(analysis_df[subject_id_col_name].unique())
 (
-
-
-
-
-
+stabilized_joint_adjusted_bread_matrix,
+raw_joint_adjusted_bread_matrix,
+joint_adjusted_meat_matrix,
+joint_adjusted_sandwich_matrix,
+classical_bread_matrix,
 classical_meat_matrix,
 classical_sandwich_var_estimate,
 avg_estimating_function_stack,
-
-
-
-
-) =
+per_subject_estimating_function_stacks,
+per_subject_adjusted_corrections,
+per_subject_classical_corrections,
+per_subject_adjusted_meat_adjustments,
+) = construct_classical_and_adjusted_sandwiches(
 theta_est,
 all_post_update_betas,
-
+subject_ids,
 action_prob_func,
 action_prob_func_args_beta_index,
 alg_update_func,
|
|
|
511
512
|
inference_func_args_theta_index,
|
|
512
513
|
inference_func_args_action_prob_index,
|
|
513
514
|
action_prob_func_args,
|
|
514
|
-
|
|
515
|
+
policy_num_by_decision_time_by_subject_id,
|
|
515
516
|
initial_policy_num,
|
|
516
517
|
beta_index_by_policy_num,
|
|
517
|
-
|
|
518
|
-
|
|
518
|
+
inference_func_args_by_subject_id,
|
|
519
|
+
inference_action_prob_decision_times_by_subject_id,
|
|
519
520
|
alg_update_func_args,
|
|
520
|
-
|
|
521
|
+
action_by_decision_time_by_subject_id,
|
|
521
522
|
suppress_all_data_checks,
|
|
522
523
|
suppress_interactive_data_checks,
|
|
523
524
|
small_sample_correction,
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
525
|
+
form_adjusted_meat_adjustments_explicitly,
|
|
526
|
+
stabilize_joint_bread,
|
|
527
|
+
analysis_df,
|
|
528
|
+
active_col_name,
|
|
528
529
|
action_col_name,
|
|
529
530
|
calendar_t_col_name,
|
|
530
|
-
|
|
531
|
+
subject_id_col_name,
|
|
531
532
|
action_prob_func_args,
|
|
532
533
|
action_prob_col_name,
|
|
533
534
|
)
|
|
@@ -543,27 +544,26 @@ def analyze_dataset(

 # This bottom right corner of the joint (betas and theta) variance matrix is the portion
 # corresponding to just theta.
-
+adjusted_sandwich_var_estimate = joint_adjusted_sandwich_matrix[
 -theta_dim:, -theta_dim:
 ]

 # Check for negative diagonal elements and set them to zero if found
-
-if np.any(
+adjusted_diagonal = np.diag(adjusted_sandwich_var_estimate)
+if np.any(adjusted_diagonal < 0):
 logger.warning(
-"Found negative diagonal elements in
+"Found negative diagonal elements in adjusted sandwich variance estimate. Setting them to zero."
 )
 np.fill_diagonal(
-
+adjusted_sandwich_var_estimate, np.maximum(adjusted_diagonal, 0)
 )

 logger.info("Writing results to file...")
-# Write analysis results to same directory that input files are in
 output_folder_abs_path = pathlib.Path(output_dir).resolve()

 analysis_dict = {
 "theta_est": theta_est,
-"
+"adjusted_sandwich_var_estimate": adjusted_sandwich_var_estimate,
 "classical_sandwich_var_estimate": classical_sandwich_var_estimate,
 }
 with open(output_folder_abs_path / "analysis.pkl", "wb") as f:
@@ -572,29 +572,35 @@ def analyze_dataset(
 f,
 )

-
-
+joint_adjusted_bread_cond = jnp.linalg.cond(raw_joint_adjusted_bread_matrix)
+logger.info(
+"Joint adjusted bread condition number: %f",
+joint_adjusted_bread_cond,
 )
+
+# calculate the max eigenvalue of the joint adjusted sandwich
+max_eigenvalue = scipy.linalg.eigvalsh(joint_adjusted_sandwich_matrix).max()
 logger.info(
-"
-
+"Max eigenvalue of joint adjusted sandwich matrix: %f",
+max_eigenvalue,
 )

 debug_pieces_dict = {
 "theta_est": theta_est,
-"
+"adjusted_sandwich_var_estimate": adjusted_sandwich_var_estimate,
 "classical_sandwich_var_estimate": classical_sandwich_var_estimate,
-"
-"
-"joint_meat_matrix":
-"
+"raw_joint_bread_matrix": raw_joint_adjusted_bread_matrix,
+"stabilized_joint_bread_matrix": stabilized_joint_adjusted_bread_matrix,
+"joint_meat_matrix": joint_adjusted_meat_matrix,
+"classical_bread_matrix": classical_bread_matrix,
 "classical_meat_matrix": classical_meat_matrix,
-"all_estimating_function_stacks":
-"
+"all_estimating_function_stacks": per_subject_estimating_function_stacks,
+"joint_bread_condition_number": joint_adjusted_bread_cond,
+"max_eigenvalue_joint_adjusted_sandwich": max_eigenvalue,
 "all_post_update_betas": all_post_update_betas,
-"
-"
-"
+"per_subject_adjusted_corrections": per_subject_adjusted_corrections,
+"per_subject_classical_corrections": per_subject_classical_corrections,
+"per_subject_adjusted_meat_adjustments": per_subject_adjusted_meat_adjustments,
 }
 with open(output_folder_abs_path / "debug_pieces.pkl", "wb") as f:
 pickle.dump(
@@ -604,25 +610,25 @@ def analyze_dataset(

 if collect_data_for_blowup_supervised_learning:
 datum_and_label_dict = get_datum_for_blowup_supervised_learning.get_datum_for_blowup_supervised_learning(
-
-
+raw_joint_adjusted_bread_matrix,
+joint_adjusted_bread_cond,
 avg_estimating_function_stack,
-
+per_subject_estimating_function_stacks,
 all_post_update_betas,
-
-
+analysis_df,
+active_col_name,
 calendar_t_col_name,
 action_prob_col_name,
-
+subject_id_col_name,
 reward_col_name,
 theta_est,
-
-
+adjusted_sandwich_var_estimate,
+subject_ids,
 beta_dim,
 theta_dim,
 initial_policy_num,
 beta_index_by_policy_num,
-
+policy_num_by_decision_time_by_subject_id,
 theta_calculation_func,
 action_prob_func,
 action_prob_func_args_beta_index,
@@ -630,16 +636,16 @@ def analyze_dataset(
 inference_func_type,
 inference_func_args_theta_index,
 inference_func_args_action_prob_index,
-
+inference_action_prob_decision_times_by_subject_id,
 action_prob_func_args,
-
+action_by_decision_time_by_subject_id,
 )

 with open(output_folder_abs_path / "supervised_learning_datum.pkl", "wb") as f:
 pickle.dump(datum_and_label_dict, f)

 print(f"\nParameter estimate:\n {theta_est}")
-print(f"\
+print(f"\nAdjusted sandwich variance estimate:\n {adjusted_sandwich_var_estimate}")
 print(
 f"\nClassical sandwich variance estimate:\n {classical_sandwich_var_estimate}\n"
 )
@@ -650,15 +656,15 @@ def analyze_dataset(
 def process_inference_func_args(
 inference_func: callable,
 inference_func_args_theta_index: int,
-
+analysis_df: pd.DataFrame,
 theta_est: jnp.ndarray,
 action_prob_col_name: str,
 calendar_t_col_name: str,
-
-
+subject_id_col_name: str,
+active_col_name: str,
 ) -> tuple[dict[collections.abc.Hashable, tuple[Any, ...]], int]:
 """
-Collects the inference function arguments for each
+Collects the inference function arguments for each subject from the analysis DataFrame.

 Note that theta and action probabilities, if present, will be replaced later
 so that the function can be differentiated with respect to shared versions
@@ -669,32 +675,32 @@ def process_inference_func_args(
 The inference function to be used.
 inference_func_args_theta_index (int):
 The index of the theta parameter in the inference function's arguments.
-
-The
+analysis_df (pandas.DataFrame):
+The analysis DataFrame.
 theta_est (jnp.ndarray):
 The estimate of the parameter vector.
 action_prob_col_name (str):
-The name of the column in the
+The name of the column in the analysis DataFrame that gives action probabilities.
 calendar_t_col_name (str):
-The name of the column in the
-
-The name of the column in the
-
-The name of the binary column in the
+The name of the column in the analysis DataFrame that indicates calendar time.
+subject_id_col_name (str):
+The name of the column in the analysis DataFrame that indicates subject ID.
+active_col_name (str):
+The name of the binary column in the analysis DataFrame that indicates whether a subject is in the deployment.
 Returns:
 tuple[dict[collections.abc.Hashable, tuple[Any, ...]], int, dict[collections.abc.Hashable, jnp.ndarray[int]]]:
 A tuple containing
-- the inference function arguments dictionary for each
+- the inference function arguments dictionary for each subject
 - the index of the action probabilities argument
-- a dictionary mapping
+- a dictionary mapping subject IDs to the decision times to which action probabilities correspond
 """

 num_args = inference_func.__code__.co_argcount
 inference_func_arg_names = inference_func.__code__.co_varnames[:num_args]
-
+inference_func_args_by_subject_id = {}

 inference_func_args_action_prob_index = -1
-
+inference_action_prob_decision_times_by_subject_id = {}

 using_action_probs = action_prob_col_name in inference_func_arg_names
 if using_action_probs:
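The function above matches DataFrame columns to the inference function's positional parameter names, which it reads off the function's code object (the pre-existing `co_argcount`/`co_varnames` lines appear as context above). A standalone sketch of that introspection with a hypothetical inference function:

    # How the column-to-argument matching works: positional parameter names are
    # read off the function's code object, then looked up as DataFrame columns.
    def example_inference_func(theta, action_prob, reward):  # hypothetical user function
        return theta

    num_args = example_inference_func.__code__.co_argcount
    arg_names = example_inference_func.__code__.co_varnames[:num_args]
    print(arg_names)  # ('theta', 'action_prob', 'reward')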
@@ -702,34 +708,36 @@ def process_inference_func_args(
 action_prob_col_name
 )

-for
-
-
+for subject_id in analysis_df[subject_id_col_name].unique():
+subject_args_list = []
+filtered_subject_data = analysis_df.loc[
+analysis_df[subject_id_col_name] == subject_id
+]
 for idx, col_name in enumerate(inference_func_arg_names):
 if idx == inference_func_args_theta_index:
-
+subject_args_list.append(theta_est)
 continue
-
-
+subject_args_list.append(
+get_active_df_column(filtered_subject_data, col_name, active_col_name)
 )
-
+inference_func_args_by_subject_id[subject_id] = tuple(subject_args_list)
 if using_action_probs:
-
-
-
+inference_action_prob_decision_times_by_subject_id[subject_id] = (
+get_active_df_column(
+filtered_subject_data, calendar_t_col_name, active_col_name
 )
 )

 return (
-
+inference_func_args_by_subject_id,
 inference_func_args_action_prob_index,
-
+inference_action_prob_decision_times_by_subject_id,
 )


-def
+def single_subject_weighted_estimating_function_stacker(
 beta_dim: int,
-
+subject_id: collections.abc.Hashable,
 action_prob_func: callable,
 algorithm_estimating_func: callable,
 inference_estimating_func: callable,
@@ -763,12 +771,12 @@ def single_user_weighted_estimating_function_stacker(
 beta_dim (list[jnp.ndarray]):
 A list of 1D JAX NumPy arrays corresponding to the betas produced by all updates.

-
-The
+subject_id (collections.abc.Hashable):
+The subject ID for which to compute the weighted estimating function stack.

 action_prob_func (callable):
 The function used to compute the probability of action 1 at a given decision time for
-a particular
+a particular subject given their state and the algorithm parameters.

 algorithm_estimating_func (callable):
 The estimating function that corresponds to algorithm updates.
@@ -783,9 +791,9 @@ def single_user_weighted_estimating_function_stacker(
 The index of the theta parameter in the inference loss or estimating function arguments.

 action_prob_func_args_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
-A map from decision times to tuples of arguments for this
+A map from decision times to tuples of arguments for this subject for the action
 probability function. This is for all decision times (args are an empty
-tuple if they are not in the
+tuple if they are not in the deployment). Should be sorted by decision time. NOTE THAT THESE
 ARGS DO NOT CONTAIN THE SHARED BETAS, making them impervious to the differentiation that
 will occur.

|
|
|
796
804
|
|
|
797
805
|
threaded_update_func_args_by_policy_num (dict[int | float, dict[collections.abc.Hashable, tuple[Any, ...]]]):
|
|
798
806
|
A map from policy numbers to tuples containing the arguments for
|
|
799
|
-
the corresponding estimating functions for this
|
|
807
|
+
the corresponding estimating functions for this subject, with the shared betas threaded in
|
|
800
808
|
for differentiation. This is for all non-initial, non-fallback policies. Policy numbers
|
|
801
809
|
should be sorted.
|
|
802
810
|
|
|
803
811
|
threaded_inference_func_args (dict[collections.abc.Hashable, tuple[Any, ...]]):
|
|
804
812
|
A tuple containing the arguments for the inference
|
|
805
|
-
estimating function for this
|
|
813
|
+
estimating function for this subject, with the shared betas threaded in for differentiation.
|
|
806
814
|
|
|
807
815
|
policy_num_by_decision_time (dict[collections.abc.Hashable, dict[int, int | float]]):
|
|
808
816
|
A dictionary mapping decision times to the policy number in use. This may be
|
|
809
|
-
|
|
817
|
+
subject-specific. Should be sorted by decision time. Only applies to active decision
|
|
810
818
|
times!
|
|
811
819
|
|
|
812
820
|
action_by_decision_time (dict[collections.abc.Hashable, dict[int, int]]):
|
|
813
|
-
A dictionary mapping decision times to actions taken. Only applies to
|
|
821
|
+
A dictionary mapping decision times to actions taken. Only applies to active decision
|
|
814
822
|
times!
|
|
815
823
|
|
|
816
824
|
beta_index_by_policy_num (dict[int | float, int]):
|
|
@@ -818,19 +826,21 @@ def single_user_weighted_estimating_function_stacker(
 all_post_update_betas. Note that this is only for non-initial, non-fallback policies.

 Returns:
-jnp.ndarray: A 1-D JAX NumPy array representing the
+jnp.ndarray: A 1-D JAX NumPy array representing the subject's weighted estimating function
 stack.
-jnp.ndarray: A 2-D JAX NumPy matrix representing the
-jnp.ndarray: A 2-D JAX NumPy matrix representing the
-jnp.ndarray: A 2-D JAX NumPy matrix representing the
+jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's adjusted meat contribution.
+jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical meat contribution.
+jnp.ndarray: A 2-D JAX NumPy matrix representing the subject's classical bread contribution.
 """

-logger.info(
+logger.info(
+"Computing weighted estimating function stack for subject %s.", subject_id
+)

 # First, reformat the supplied data into more convenient structures.

 # 1. Form a dictionary mapping policy numbers to the first time they were
-# applicable (for this
+# applicable (for this subject). Note that this includes ALL policies, initial
 # fallbacks included.
 # Collect the first time after the first update separately for convenience.
 # These are both used to form the Radon-Nikodym weights for the right times.
@@ -839,38 +849,38 @@ def single_user_weighted_estimating_function_stacker(
 beta_index_by_policy_num,
 )

-# 2. Get the start and end times for this
-
-
+# 2. Get the start and end times for this subject.
+subject_start_time = math.inf
+subject_end_time = -math.inf
 for decision_time in action_by_decision_time:
-
-
+subject_start_time = min(subject_start_time, decision_time)
+subject_end_time = max(subject_end_time, decision_time)

 # 3. Form a stack of weighted estimating equations, one for each update of the algorithm.
 logger.info(
-"Computing the algorithm component of the weighted estimating function stack for
-
+"Computing the algorithm component of the weighted estimating function stack for subject %s.",
+subject_id,
 )

-
+active_action_prob_func_args = [
 args for args in action_prob_func_args_by_decision_time.values() if args
 ]
-
+active_betas_list_by_decision_time_index = jnp.array(
 [
 action_prob_func_args[action_prob_func_args_beta_index]
-for action_prob_func_args in
+for action_prob_func_args in active_action_prob_func_args
 ]
 )
-
+active_actions_list_by_decision_time_index = jnp.array(
 list(action_by_decision_time.values())
 )

 # Sort the threaded args by decision time to be cautious. We check if the
-#
-# subset of the
+# subject id is present in the subject args dict because we may call this on a
+# subset of the subject arg dict when we are batching arguments by shape
 sorted_threaded_action_prob_args_by_decision_time = {
 decision_time: threaded_action_prob_func_args_by_decision_time[decision_time]
-for decision_time in range(
+for decision_time in range(subject_start_time, subject_end_time + 1)
 if decision_time in threaded_action_prob_func_args_by_decision_time
 }

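The weight computation in the next hunk vmaps `get_radon_nikodym_weight` over decision times: the `0` entries of `in_axes` map over the per-decision-time betas and actions, while the `None` entries broadcast the action probability function and its beta index unchanged. A generic sketch of the same mixed `in_axes` pattern with a stand-in weight function:

    # Generic illustration of mixed in_axes: axis-0 arguments are batched per
    # decision time, None arguments are broadcast to every call.
    import jax
    import jax.numpy as jnp

    def weight(beta, offset, action):  # stand-in for get_radon_nikodym_weight
        return beta[action] + offset

    betas = jnp.array([[0.1, 0.9], [0.3, 0.7]])  # batched over decision times
    actions = jnp.array([1, 0])                  # batched over decision times
    weights = jax.vmap(weight, in_axes=[0, None, 0], out_axes=0)(betas, 0.5, actions)
    print(weights)  # [1.4, 0.8]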
@@ -901,19 +911,19 @@ def single_user_weighted_estimating_function_stacker(
 # Just grab the original beta from the update function arguments. This is the same
 # value, but impervious to differentiation with respect to all_post_update_betas. The
 # args, on the other hand, are a function of all_post_update_betas.
-
+active_weights = jax.vmap(
 fun=get_radon_nikodym_weight,
 in_axes=[0, None, None, 0] + batch_axes,
 out_axes=0,
 )(
-
+active_betas_list_by_decision_time_index,
 action_prob_func,
 action_prob_func_args_beta_index,
-
+active_actions_list_by_decision_time_index,
 *batched_threaded_arg_tensors,
 )

-
+active_index = 0
 decision_time_to_all_weights_index_offset = min(
 sorted_threaded_action_prob_args_by_decision_time
 )
@@ -922,35 +932,35 @@ def single_user_weighted_estimating_function_stacker(
 decision_time,
 args,
 ) in sorted_threaded_action_prob_args_by_decision_time.items():
-all_weights_raw.append(
-
+all_weights_raw.append(active_weights[active_index] if args else 1.0)
+active_index += 1
 all_weights = jnp.array(all_weights_raw)

 algorithm_component = jnp.concatenate(
 [
 # Here we compute a product of Radon-Nikodym weights
 # for all decision times after the first update and before the update
-# update under consideration took effect, for which the
+# update under consideration took effect, for which the subject was in the deployment.
 (
 jnp.prod(
 all_weights[
-# The earliest time after the first update where the
-# the
+# The earliest time after the first update where the subject was in
+# the deployment
 max(
 first_time_after_first_update,
-
+subject_start_time,
 )
 - decision_time_to_all_weights_index_offset :
-# One more than the latest time the
+# One more than the latest time the subject was in the deployment before the time
 # the update under consideration first applied. Note the + 1 because range
 # does not include the right endpoint.
 min(
 min_time_by_policy_num.get(policy_num, math.inf),
-
+subject_end_time + 1,
 )
 - decision_time_to_all_weights_index_offset,
 ]
-# If the
+# If the subject exited the deployment before there were any updates,
 # this variable will be None and the above code to grab a weight would
 # throw an error. Just use 1 to include the unweighted estimating function
 # if they have data to contribute to the update.
@@ -958,8 +968,8 @@ def single_user_weighted_estimating_function_stacker(
 else 1
 ) # Now use the above to weight the alg estimating function for this update
 * algorithm_estimating_func(*update_args)
-# If there are no arguments for the update function, the
-#
+# If there are no arguments for the update function, the subject is not yet in the
+# deployment, so we just add a zero vector contribution to the sum across subjects.
 # Note that after they exit, they still contribute all their data to later
 # updates.
 if update_args
@@ -978,17 +988,17 @@ def single_user_weighted_estimating_function_stacker(
 )
 # 4. Form the weighted inference estimating equation.
 logger.info(
-"Computing the inference component of the weighted estimating function stack for
-
+"Computing the inference component of the weighted estimating function stack for subject %s.",
+subject_id,
 )
 inference_component = jnp.prod(
 all_weights[
-max(first_time_after_first_update,
-- decision_time_to_all_weights_index_offset :
+max(first_time_after_first_update, subject_start_time)
+- decision_time_to_all_weights_index_offset : subject_end_time
 + 1
 - decision_time_to_all_weights_index_offset,
 ]
-# If the
+# If the subject exited the deployment before there were any updates,
 # this variable will be None and the above code to grab a weight would
 # throw an error. Just use 1 to include the unweighted estimating function
 # if they have data to contribute here (pretty sure everyone should?)
@@ -997,18 +1007,18 @@ def single_user_weighted_estimating_function_stacker(
 ) * inference_estimating_func(*threaded_inference_func_args)

 # 5. Concatenate the two components to form the weighted estimating function stack for this
-#
+# subject.
 weighted_stack = jnp.concatenate([algorithm_component, inference_component])

 # 6. Return the following outputs:
-# a. The first is simply the weighted estimating function stack for this
-# of these is what we differentiate with respect to theta to form the
+# a. The first is simply the weighted estimating function stack for this subject. The average
+# of these is what we differentiate with respect to theta to form the joint
 # bread matrix, and we also compare that average to zero to check the estimating functions'
 # fidelity.
-# b. The average outer product of these per-
+# b. The average outer product of these per-subject stacks across subjects is the adjusted joint meat
 # matrix, hence the second output.
-# c. The third output is averaged across
-# d. The fourth output is averaged across
+# c. The third output is averaged across subjects to obtain the classical meat matrix.
+# d. The fourth output is averaged across subjects to obtain the inverse classical bread
 # matrix.
 return (
 weighted_stack,
@@ -1024,7 +1034,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
 flattened_betas_and_theta: jnp.ndarray,
 beta_dim: int,
 theta_dim: int,
-
+subject_ids: jnp.ndarray,
 action_prob_func: callable,
 action_prob_func_args_beta_index: int,
 alg_update_func: callable,
@@ -1037,30 +1047,32 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
 inference_func_type: str,
 inference_func_args_theta_index: int,
 inference_func_args_action_prob_index: int,
-
+action_prob_func_args_by_subject_id_by_decision_time: dict[
 collections.abc.Hashable, dict[int, tuple[Any, ...]]
 ],
-
+policy_num_by_decision_time_by_subject_id: dict[
 collections.abc.Hashable, dict[int, int | float]
 ],
 initial_policy_num: int | float,
 beta_index_by_policy_num: dict[int | float, int],
-
-
+inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
+inference_action_prob_decision_times_by_subject_id: dict[
 collections.abc.Hashable, list[int]
 ],
-
+update_func_args_by_by_subject_id_by_policy_num: dict[
 collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
 ],
-
+action_by_decision_time_by_subject_id: dict[
+collections.abc.Hashable, dict[int, int]
+],
 suppress_all_data_checks: bool,
 suppress_interactive_data_checks: bool,
 ) -> tuple[
 jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
 ]:
 """
-Computes the average weighted estimating function stack across all
-auxiliary values used to construct the
+Computes the average weighted estimating function stack across all subjects, along with
+auxiliary values used to construct the adjusted and classical sandwich variances.

 Args:
 flattened_betas_and_theta (jnp.ndarray):
|
|
|
1071
1083
|
The dimension of each of the beta parameters.
|
|
1072
1084
|
theta_dim (int):
|
|
1073
1085
|
The dimension of the theta parameter.
|
|
1074
|
-
|
|
1075
|
-
A 1D JAX NumPy array of
|
|
1086
|
+
subject_ids (jnp.ndarray):
|
|
1087
|
+
A 1D JAX NumPy array of subject IDs.
|
|
1076
1088
|
action_prob_func (callable):
|
|
1077
1089
|
The action probability function.
|
|
1078
1090
|
action_prob_func_args_beta_index (int):
|
|
@@ -1100,29 +1112,29 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1100
1112
|
inference_func_args_action_prob_index (int):
|
|
1101
1113
|
The index of action probabilities in the inference function arguments tuple, if
|
|
1102
1114
|
applicable. -1 otherwise.
|
|
1103
|
-
|
|
1104
|
-
A dictionary mapping decision times to maps of
|
|
1105
|
-
required to compute action probabilities for this
|
|
1106
|
-
|
|
1107
|
-
A map of
|
|
1108
|
-
Only applies to
|
|
1115
|
+
action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
|
|
1116
|
+
A dictionary mapping decision times to maps of subject ids to the function arguments
|
|
1117
|
+
required to compute action probabilities for this subject.
|
|
1118
|
+
policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
|
|
1119
|
+
A map of subject ids to dictionaries mapping decision times to the policy number in use.
|
|
1120
|
+
Only applies to active decision times!
|
|
1109
1121
|
initial_policy_num (int | float):
|
|
1110
1122
|
The policy number of the initial policy before any updates.
|
|
1111
1123
|
beta_index_by_policy_num (dict[int | float, int]):
|
|
1112
1124
|
A dictionary mapping policy numbers to the index of the corresponding beta in
|
|
1113
1125
|
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
1114
|
-
|
|
1115
|
-
A dictionary mapping
|
|
1116
|
-
|
|
1117
|
-
For each
|
|
1118
|
-
provided. Typically just
|
|
1126
|
+
inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
|
|
1127
|
+
A dictionary mapping subject IDs to their respective inference function arguments.
|
|
1128
|
+
inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
|
|
1129
|
+
For each subject, a list of decision times to which action probabilities correspond if
|
|
1130
|
+
provided. Typically just active times if action probabilites are used in the inference
|
|
1119
1131
|
loss or estimating function.
|
|
1120
|
-
|
|
1121
|
-
A dictionary where keys are policy numbers and values are dictionaries mapping
|
|
1132
|
+
update_func_args_by_by_subject_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
|
|
1133
|
+
A dictionary where keys are policy numbers and values are dictionaries mapping subject IDs
|
|
1122
1134
|
to their respective update function arguments.
|
|
1123
|
-
|
|
1124
|
-
A dictionary mapping
|
|
1125
|
-
Only applies to
|
|
1135
|
+
action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
|
|
1136
|
+
A dictionary mapping subject IDs to their respective actions taken at each decision time.
|
|
1137
|
+
Only applies to active decision times!
|
|
1126
1138
|
suppress_all_data_checks (bool):
|
|
1127
1139
|
If True, suppresses carrying out any data checks at all.
|
|
1128
1140
|
suppress_interactive_data_checks (bool):
|
|
@@ -1136,10 +1148,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
 tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
 A tuple containing
 1. the average weighted estimating function stack
-2. the
-3. the
-4. the
-5. raw per-
+2. the subject-level adjusted meat matrix contributions
+3. the subject-level classical meat matrix contributions
+4. the subject-level inverse classical bread matrix contributions
+5. raw per-subject weighted estimating function
 stacks.
 """

@@ -1166,15 +1178,15 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
 # supplied for the above functions, so that differentiation works correctly. The existing
 # values should be the same, but not connected to the parameter we are differentiating
 # with respect to. Note we will also find it useful below to have the action probability args
-# nested dict structure flipped to be
+# nested dict structure flipped to be subject_id -> decision_time -> args, so we do that here too.

-logger.info("Threading in betas to action probability arguments for all
+logger.info("Threading in betas to action probability arguments for all subjects.")
 (
-
-
+threaded_action_prob_func_args_by_decision_time_by_subject_id,
+action_prob_func_args_by_decision_time_by_subject_id,
 ) = thread_action_prob_func_args(
-
-
+action_prob_func_args_by_subject_id_by_decision_time,
+policy_num_by_decision_time_by_subject_id,
 initial_policy_num,
 betas,
 beta_index_by_policy_num,
@@ -1186,17 +1198,17 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
 # arguments with the central betas introduced.
 logger.info(
 "Threading in betas and beta-dependent action probabilities to algorithm update "
-"function args for all
+"function args for all subjects"
 )
-
-
+threaded_update_func_args_by_policy_num_by_subject_id = thread_update_func_args(
+update_func_args_by_by_subject_id_by_policy_num,
 betas,
 beta_index_by_policy_num,
 alg_update_func_args_beta_index,
 alg_update_func_args_action_prob_index,
 alg_update_func_args_action_prob_times_index,
 alg_update_func_args_previous_betas_index,
-
+threaded_action_prob_func_args_by_decision_time_by_subject_id,
 action_prob_func,
 )

@@ -1206,8 +1218,8 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
 if not suppress_all_data_checks and alg_update_func_args_action_prob_index >= 0:
 input_checks.require_threaded_algorithm_estimating_function_args_equivalent(
 algorithm_estimating_func,
-
-
+update_func_args_by_by_subject_id_by_policy_num,
+threaded_update_func_args_by_policy_num_by_subject_id,
 suppress_interactive_data_checks,
 )

@@ -1216,15 +1228,15 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
 # arguments with the central betas introduced.
 logger.info(
 "Threading in theta and beta-dependent action probabilities to inference update "
-"function args for all
+"function args for all subjects"
 )
-
-
+threaded_inference_func_args_by_subject_id = thread_inference_func_args(
+inference_func_args_by_subject_id,
 inference_func_args_theta_index,
 theta,
 inference_func_args_action_prob_index,
-
-
+threaded_action_prob_func_args_by_decision_time_by_subject_id,
+inference_action_prob_decision_times_by_subject_id,
 action_prob_func,
 )

@@ -1234,32 +1246,32 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
 if not suppress_all_data_checks and inference_func_args_action_prob_index >= 0:
 input_checks.require_threaded_inference_estimating_function_args_equivalent(
 inference_estimating_func,
-
-
+inference_func_args_by_subject_id,
+threaded_inference_func_args_by_subject_id,
 suppress_interactive_data_checks,
 )

-# 5. Now we can compute the weighted estimating function stacks for all
-# as well as collect related values used to construct the
+# 5. Now we can compute the weighted estimating function stacks for all subjects
+# as well as collect related values used to construct the adjusted and classical
 # sandwich variances.
 results = [
-
+single_subject_weighted_estimating_function_stacker(
 beta_dim,
-
+subject_id,
 action_prob_func,
 algorithm_estimating_func,
 inference_estimating_func,
 action_prob_func_args_beta_index,
 inference_func_args_theta_index,
-
-
-
-
-
-
+action_prob_func_args_by_decision_time_by_subject_id[subject_id],
+threaded_action_prob_func_args_by_decision_time_by_subject_id[subject_id],
+threaded_update_func_args_by_policy_num_by_subject_id[subject_id],
+threaded_inference_func_args_by_subject_id[subject_id],
+policy_num_by_decision_time_by_subject_id[subject_id],
+action_by_decision_time_by_subject_id[subject_id],
 beta_index_by_policy_num,
 )
-for
+for subject_id in subject_ids.tolist()
 ]

 stacks = jnp.array([result[0] for result in results])
@@ -1269,11 +1281,12 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
 
     # 6. Note this strange return structure! We will differentiate the first output,
     # but the second tuple will be passed along without modification via has_aux=True and then used
-    # for the
-    #
+    # for the estimating functions sum check, per_subject_classical_bread_contributions, and
+    # classical meat and inverse bread matrices. The raw per-subject stacks are also returned for
+    # debugging purposes.
 
-    # Note that returning the raw stacks here as the first
-    # memory-intensive when combined with differentiation. Keep this in mind if the per-user bread
+    # Note that returning the raw stacks here as the first argument is potentially
+    # memory-intensive when combined with differentiation. Keep this in mind if the per-subject bread
     # inverse contributions are needed for something like CR2/CR3 small-sample corrections.
     return jnp.mean(stacks, axis=0), (
         jnp.mean(stacks, axis=0),
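The "strange return structure" comment refers to the standard `jax.jacrev(..., has_aux=True)` idiom: JAX differentiates the first output and passes the auxiliary tuple through unmodified. A self-contained toy version of the pattern (an invented function, not lifejacket's estimating machinery):

```python
import jax
import jax.numpy as jnp

def avg_stack_and_aux(params):
    # Toy per-subject estimating function stacks that depend on params.
    stacks = jnp.stack([params * 1.0, params * 2.0])
    avg = jnp.mean(stacks, axis=0)
    # The first output is differentiated; the aux tuple rides along untouched.
    return avg, (avg, stacks)

jacobian, (avg, per_subject_stacks) = jax.jacrev(avg_stack_and_aux, has_aux=True)(
    jnp.array([1.0, 2.0])
)
print(jacobian)            # d(avg)/d(params): 1.5 * identity for this toy function
print(per_subject_stacks)  # raw stacks, passed through for diagnostics only
```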
@@ -1284,10 +1297,10 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
     )
 
 
-def construct_classical_and_adaptive_sandwiches(
+def construct_classical_and_adjusted_sandwiches(
     theta_est: jnp.ndarray,
     all_post_update_betas: jnp.ndarray,
-    user_ids: jnp.ndarray,
+    subject_ids: jnp.ndarray,
     action_prob_func: callable,
     action_prob_func_args_beta_index: int,
     alg_update_func: callable,
@@ -1300,32 +1313,34 @@ def construct_classical_and_adaptive_sandwiches(
     inference_func_type: str,
     inference_func_args_theta_index: int,
     inference_func_args_action_prob_index: int,
-    action_prob_func_args_by_user_id_by_decision_time: dict[
+    action_prob_func_args_by_subject_id_by_decision_time: dict[
         collections.abc.Hashable, dict[int, tuple[Any, ...]]
     ],
-    policy_num_by_decision_time_by_user_id: dict[
+    policy_num_by_decision_time_by_subject_id: dict[
         collections.abc.Hashable, dict[int, int | float]
     ],
     initial_policy_num: int | float,
     beta_index_by_policy_num: dict[int | float, int],
-    inference_func_args_by_user_id: dict[collections.abc.Hashable, tuple[Any, ...]],
-    inference_action_prob_decision_times_by_user_id: dict[
+    inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
+    inference_action_prob_decision_times_by_subject_id: dict[
         collections.abc.Hashable, list[int]
     ],
-    update_func_args_by_by_user_id_by_policy_num: dict[
+    update_func_args_by_by_subject_id_by_policy_num: dict[
         collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
     ],
-    action_by_decision_time_by_user_id: dict[collections.abc.Hashable, dict[int, int]],
+    action_by_decision_time_by_subject_id: dict[
+        collections.abc.Hashable, dict[int, int]
+    ],
     suppress_all_data_checks: bool,
     suppress_interactive_data_checks: bool,
     small_sample_correction: str,
-    form_adaptive_meat_adjustments_explicitly: bool,
-    stabilize_joint_bread_inverse: bool,
-
-
+    form_adjusted_meat_adjustments_explicitly: bool,
+    stabilize_joint_bread: bool,
+    analysis_df: pd.DataFrame | None,
+    active_col_name: str | None,
     action_col_name: str | None,
     calendar_t_col_name: str | None,
-    user_id_col_name: str | None,
+    subject_id_col_name: str | None,
     action_prob_func_args: tuple | None,
     action_prob_col_name: str | None,
 ) -> tuple[
@@ -1342,11 +1357,11 @@ def construct_classical_and_adaptive_sandwiches(
     jnp.ndarray[jnp.float32],
 ]:
     """
-    Constructs the classical and adaptive sandwich matrices, as well as various
+    Constructs the classical and adjusted sandwich matrices, as well as various
     intermediate pieces in their construction.
 
     This is done by computing and differentiating the average weighted estimating function stack
-    with respect to the betas and theta, using the resulting Jacobian to compute the bread inverse
+    with respect to the betas and theta, using the resulting Jacobian to compute the bread
     and meat matrices, and then stably computing sandwiches.
 
     Args:
@@ -1354,8 +1369,8 @@ def construct_classical_and_adaptive_sandwiches(
             A 1-D JAX NumPy array representing the parameter estimate for inference.
         all_post_update_betas (jnp.ndarray):
             A 2-D JAX NumPy array representing all parameter estimates for the algorithm updates.
-        user_ids (jnp.ndarray):
-            A 1-D JAX NumPy array holding all
+        subject_ids (jnp.ndarray):
+            A 1-D JAX NumPy array holding all subject IDs in the deployment.
         action_prob_func (callable):
             The action probability function.
         action_prob_func_args_beta_index (int):
@@ -1383,29 +1398,29 @@ def construct_classical_and_adaptive_sandwiches(
         inference_func_args_action_prob_index (int):
             The index of action probabilities in the inference function arguments tuple, if
             applicable. -1 otherwise.
-        action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
-            A dictionary mapping decision times to maps of user ids to the function arguments
-            required to compute action probabilities for this user.
-        policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
-            A map of user ids to dictionaries mapping decision times to the policy number in use.
-            Only applies to
+        action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
+            A dictionary mapping decision times to maps of subject ids to the function arguments
+            required to compute action probabilities for this subject.
+        policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
+            A map of subject ids to dictionaries mapping decision times to the policy number in use.
+            Only applies to active decision times!
         initial_policy_num (int | float):
             The policy number of the initial policy before any updates.
         beta_index_by_policy_num (dict[int | float, int]):
             A dictionary mapping policy numbers to the index of the corresponding beta in
             all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
-        inference_func_args_by_user_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
-            A dictionary mapping user IDs to their respective inference function arguments.
-        inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
-            For each user, a list of decision times to which action probabilities correspond if
-            provided. Typically just
+        inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
+            A dictionary mapping subject IDs to their respective inference function arguments.
+        inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
+            For each subject, a list of decision times to which action probabilities correspond if
+            provided. Typically just active times if action probabilities are used in the inference
             loss or estimating function.
-        update_func_args_by_by_user_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
-            A dictionary where keys are policy numbers and values are dictionaries mapping user IDs
+        update_func_args_by_by_subject_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
+            A dictionary where keys are policy numbers and values are dictionaries mapping subject IDs
             to their respective update function arguments.
-        action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
-            A dictionary mapping user IDs to their respective actions taken at each decision time.
-            Only applies to
+        action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
+            A dictionary mapping subject IDs to their respective actions taken at each decision time.
+            Only applies to active decision times!
         suppress_all_data_checks (bool):
             If True, suppresses carrying out any data checks at all.
         suppress_interactive_data_checks (bool):
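To make the nesting in these argument descriptions concrete, here is a hypothetical instance of two of them; the subject IDs, decision times, and values are invented purely for illustration:

```python
# subject_id -> decision_time -> policy number in use at that time
policy_num_by_decision_time_by_subject_id = {
    "subject_1": {1: 0, 2: 0, 3: 1},
    "subject_2": {1: 0, 3: 2},  # no entry at time 2: only active times appear
}

# subject_id -> decision_time -> action taken (again, only active times)
action_by_decision_time_by_subject_id = {
    "subject_1": {1: 1, 2: 0, 3: 1},
    "subject_2": {1: 0, 3: 1},
}
```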
@@ -1415,43 +1430,43 @@ def construct_classical_and_adaptive_sandwiches(
         small_sample_correction (str):
             The type of small sample correction to apply. See SmallSampleCorrections class for
             options.
-        form_adaptive_meat_adjustments_explicitly (bool):
-            If True, explicitly forms the per-user meat adjustments that differentiate the adaptive
+        form_adjusted_meat_adjustments_explicitly (bool):
+            If True, explicitly forms the per-subject meat adjustments that differentiate the adjusted
             sandwich from the classical sandwich. This is for diagnostic purposes, as the
-            adaptive sandwich is formed without doing this.
-        stabilize_joint_bread_inverse (bool):
-            If True, will apply various techniques to stabilize the joint bread inverse if necessary.
-
-            The full
-
-            The name of the column in
+            adjusted sandwich is formed without doing this.
+        stabilize_joint_bread (bool):
+            If True, will apply various techniques to stabilize the joint bread if necessary.
+        analysis_df (pd.DataFrame):
+            The full analysis dataframe, needed if forming the adjusted meat adjustments explicitly.
+        active_col_name (str):
+            The name of the column in analysis_df indicating whether a subject is active at a given decision time.
         action_col_name (str):
-            The name of the column in
+            The name of the column in analysis_df indicating the action taken at a given decision time.
         calendar_t_col_name (str):
-            The name of the column in
-        user_id_col_name (str):
-            The name of the column in
+            The name of the column in analysis_df indicating the calendar time of a given decision time.
+        subject_id_col_name (str):
+            The name of the column in analysis_df indicating the subject ID.
         action_prob_func_args (tuple):
             The arguments to be passed to the action probability function, needed if forming the
-            adaptive meat adjustments explicitly.
+            adjusted meat adjustments explicitly.
         action_prob_col_name (str):
-            The name of the column in
-            needed if forming the adaptive meat adjustments explicitly.
+            The name of the column in analysis_df indicating the action probability of the action taken,
+            needed if forming the adjusted meat adjustments explicitly.
     Returns:
         tuple[jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32]]:
         A tuple containing:
-            - The raw joint adaptive bread inverse matrix.
-            - The (possibly) stabilized joint adaptive bread inverse matrix.
-            - The joint adaptive meat matrix.
-            - The joint adaptive sandwich matrix.
-            - The classical bread inverse matrix.
+            - The raw joint bread matrix.
+            - The (possibly) stabilized joint bread matrix.
+            - The joint meat matrix.
+            - The joint sandwich matrix.
+            - The classical bread matrix.
             - The classical meat matrix.
             - The classical sandwich matrix.
             - The average weighted estimating function stack.
-            - All per-user weighted estimating function stacks.
-            - The per-user adaptive meat small-sample corrections.
-            - The per-user classical meat small-sample corrections.
-            - The per-user adaptive meat adjustments, if form_adaptive_meat_adjustments_explicitly
+            - All per-subject weighted estimating function stacks.
+            - The per-subject adjusted meat small-sample corrections.
+            - The per-subject classical meat small-sample corrections.
+            - The per-subject adjusted meat adjustments, if form_adjusted_meat_adjustments_explicitly
               is True, otherwise an array of NaNs.
     """
     logger.info(
@@ -1459,13 +1474,13 @@ def construct_classical_and_adaptive_sandwiches(
     )
     theta_dim = theta_est.shape[0]
     beta_dim = all_post_update_betas.shape[1]
-    # Note that these "contributions" are per-user Jacobians of the weighted estimating function stack.
-    raw_joint_adaptive_bread_inverse_matrix, (
+    # Note that these "contributions" are per-subject Jacobians of the weighted estimating function stack.
+    raw_joint_adjusted_bread_matrix, (
         avg_estimating_function_stack,
-        per_user_joint_adaptive_meat_contributions,
-        per_user_classical_meat_contributions,
-        per_user_classical_bread_inverse_contributions,
-        per_user_estimating_function_stacks,
+        per_subject_joint_adjusted_meat_contributions,
+        per_subject_classical_meat_contributions,
+        per_subject_classical_bread_contributions,
+        per_subject_estimating_function_stacks,
     ) = jax.jacrev(
         get_avg_weighted_estimating_function_stacks_and_aux_values, has_aux=True
     )(
@@ -1475,7 +1490,7 @@ def construct_classical_and_adaptive_sandwiches(
         flatten_params(all_post_update_betas, theta_est),
         beta_dim,
         theta_dim,
-        user_ids,
+        subject_ids,
         action_prob_func,
         action_prob_func_args_beta_index,
         alg_update_func,
@@ -1488,166 +1503,164 @@ def construct_classical_and_adaptive_sandwiches(
         inference_func_type,
         inference_func_args_theta_index,
         inference_func_args_action_prob_index,
-        action_prob_func_args_by_user_id_by_decision_time,
-        policy_num_by_decision_time_by_user_id,
+        action_prob_func_args_by_subject_id_by_decision_time,
+        policy_num_by_decision_time_by_subject_id,
         initial_policy_num,
         beta_index_by_policy_num,
-        inference_func_args_by_user_id,
-        inference_action_prob_decision_times_by_user_id,
-        update_func_args_by_by_user_id_by_policy_num,
-        action_by_decision_time_by_user_id,
+        inference_func_args_by_subject_id,
+        inference_action_prob_decision_times_by_subject_id,
+        update_func_args_by_by_subject_id_by_policy_num,
+        action_by_decision_time_by_subject_id,
         suppress_all_data_checks,
         suppress_interactive_data_checks,
     )
 
-    num_users = len(user_ids)
+    num_subjects = len(subject_ids)
 
     (
-        joint_adaptive_meat_matrix,
+        joint_adjusted_meat_matrix,
         classical_meat_matrix,
-        per_user_adaptive_corrections,
-        per_user_classical_corrections,
+        per_subject_adjusted_corrections,
+        per_subject_classical_corrections,
     ) = perform_desired_small_sample_correction(
         small_sample_correction,
-        per_user_joint_adaptive_meat_contributions,
-        per_user_classical_meat_contributions,
-        per_user_classical_bread_inverse_contributions,
-        num_users,
+        per_subject_joint_adjusted_meat_contributions,
+        per_subject_classical_meat_contributions,
+        per_subject_classical_bread_contributions,
+        num_subjects,
         theta_dim,
     )
 
     # Increase diagonal block dominance and possibly improve conditioning of diagonal
-    # blocks as necessary, to ensure mathematical stability of joint bread inverse
-    stabilized_joint_adaptive_bread_inverse_matrix = (
+    # blocks as necessary, to ensure mathematical stability of joint bread
+    stabilized_joint_adjusted_bread_matrix = (
         (
-            stabilize_joint_adaptive_bread_inverse_if_necessary(
-                raw_joint_adaptive_bread_inverse_matrix,
+            stabilize_joint_bread_if_necessary(
+                raw_joint_adjusted_bread_matrix,
                 beta_dim,
                 theta_dim,
             )
         )
-        if stabilize_joint_bread_inverse
-        else raw_joint_adaptive_bread_inverse_matrix
+        if stabilize_joint_bread
+        else raw_joint_adjusted_bread_matrix
     )
 
     # Now stably (no explicit inversion) form our sandwiches.
-    joint_adaptive_sandwich = form_sandwich_from_bread_inverse_and_meat(
-        stabilized_joint_adaptive_bread_inverse_matrix,
-        joint_adaptive_meat_matrix,
-        num_users,
-        method=SandwichFormationMethods.
-    )
-    classical_bread_inverse_matrix = jnp.mean(
-        per_user_classical_bread_inverse_contributions, axis=0
+    joint_adjusted_sandwich = form_sandwich_from_bread_and_meat(
+        stabilized_joint_adjusted_bread_matrix,
+        joint_adjusted_meat_matrix,
+        num_subjects,
+        method=SandwichFormationMethods.BREAD_T_QR,
     )
-    classical_sandwich = form_sandwich_from_bread_inverse_and_meat(
-        classical_bread_inverse_matrix,
+    classical_bread_matrix = jnp.mean(per_subject_classical_bread_contributions, axis=0)
+    classical_sandwich = form_sandwich_from_bread_and_meat(
+        classical_bread_matrix,
         classical_meat_matrix,
-        num_users,
-        method=SandwichFormationMethods.
+        num_subjects,
+        method=SandwichFormationMethods.BREAD_T_QR,
     )
 
-    per_user_adaptive_meat_adjustments = jnp.full(
-        (len(user_ids), theta_dim, theta_dim), jnp.nan
+    per_subject_adjusted_meat_adjustments = jnp.full(
+        (len(subject_ids), theta_dim, theta_dim), jnp.nan
     )
-    if form_adaptive_meat_adjustments_explicitly:
-        per_user_adaptive_classical_meat_contributions = (
-            form_adaptive_meat_adjustments_directly(
+    if form_adjusted_meat_adjustments_explicitly:
+        per_subject_adjusted_classical_meat_contributions = (
+            form_adjusted_meat_adjustments_directly(
                 theta_dim,
                 all_post_update_betas.shape[1],
-                stabilized_joint_adaptive_bread_inverse_matrix,
-                per_user_estimating_function_stacks,
-
-
+                stabilized_joint_adjusted_bread_matrix,
+                per_subject_estimating_function_stacks,
+                analysis_df,
+                active_col_name,
                 action_col_name,
                 calendar_t_col_name,
-                user_id_col_name,
+                subject_id_col_name,
                 action_prob_func,
                 action_prob_func_args,
                 action_prob_func_args_beta_index,
                 theta_est,
                 inference_func,
                 inference_func_args_theta_index,
-                user_ids,
+                subject_ids,
                 action_prob_col_name,
             )
         )
-        # Validate that the adaptive meat adjustments we just formed are accurate by constructing
-        # the theta-only adaptive sandwich from them and checking that it matches the standard result
-        # we get by taking a subset of the joint adaptive sandwich.
+        # Validate that the adjusted meat adjustments we just formed are accurate by constructing
+        # the theta-only adjusted sandwich from them and checking that it matches the standard result
+        # we get by taking a subset of the joint sandwich.
         # First just apply any small-sample correction for parity.
         (
            _,
-            theta_only_adaptive_meat_matrix_v2,
+            theta_only_adjusted_meat_matrix_v2,
            _,
            _,
         ) = perform_desired_small_sample_correction(
            small_sample_correction,
-            per_user_joint_adaptive_meat_contributions,
-            per_user_adaptive_classical_meat_contributions,
-            per_user_classical_bread_inverse_contributions,
-            num_users,
+            per_subject_joint_adjusted_meat_contributions,
+            per_subject_adjusted_classical_meat_contributions,
+            per_subject_classical_bread_contributions,
+            num_subjects,
            theta_dim,
         )
-        theta_only_adaptive_sandwich_from_adjustments = (
-            form_sandwich_from_bread_inverse_and_meat(
-                classical_bread_inverse_matrix,
-                theta_only_adaptive_meat_matrix_v2,
-                num_users,
-                method=SandwichFormationMethods.
+        theta_only_adjusted_sandwich_from_adjustments = (
+            form_sandwich_from_bread_and_meat(
+                classical_bread_matrix,
+                theta_only_adjusted_meat_matrix_v2,
+                num_subjects,
+                method=SandwichFormationMethods.BREAD_T_QR,
             )
         )
-        theta_only_adaptive_sandwich = joint_adaptive_sandwich[-theta_dim:, -theta_dim:]
+        theta_only_adjusted_sandwich = joint_adjusted_sandwich[-theta_dim:, -theta_dim:]
 
         if not np.allclose(
-            theta_only_adaptive_sandwich,
-            theta_only_adaptive_sandwich_from_adjustments,
+            theta_only_adjusted_sandwich,
+            theta_only_adjusted_sandwich_from_adjustments,
             rtol=3e-2,
         ):
             logger.warning(
-                "There may be a bug in the explicit meat adjustment calculation (this doesn't affect the actual calculation, just diagnostics). We've calculated the theta-only adaptive sandwich two different ways and they do not match sufficiently."
+                "There may be a bug in the explicit meat adjustment calculation (this doesn't affect the actual calculation, just diagnostics). We've calculated the theta-only adjusted sandwich two different ways and they do not match sufficiently."
             )
 
-    # Stack the joint bread inverse pieces together horizontally and return the auxiliary
-    # values too. The joint bread inverse should always be block lower triangular.
+    # Stack the joint bread pieces together horizontally and return the auxiliary
+    # values too. The joint bread should always be block lower triangular.
     return (
-        raw_joint_adaptive_bread_inverse_matrix,
-        stabilized_joint_adaptive_bread_inverse_matrix,
-        joint_adaptive_meat_matrix,
-        joint_adaptive_sandwich,
-        classical_bread_inverse_matrix,
+        raw_joint_adjusted_bread_matrix,
+        stabilized_joint_adjusted_bread_matrix,
+        joint_adjusted_meat_matrix,
+        joint_adjusted_sandwich,
+        classical_bread_matrix,
         classical_meat_matrix,
         classical_sandwich,
         avg_estimating_function_stack,
-        per_user_estimating_function_stacks,
-        per_user_adaptive_corrections,
-        per_user_classical_corrections,
-        per_user_adaptive_meat_adjustments,
+        per_subject_estimating_function_stacks,
+        per_subject_adjusted_corrections,
+        per_subject_classical_corrections,
+        per_subject_adjusted_meat_adjustments,
     )
 
 
 # TODO: I think there should be interaction to confirm stabilization. It is
-# important for the user to know if this is happening. Even if enabled, it is important
-# that the user know it actually kicks in.
-def stabilize_joint_adaptive_bread_inverse_if_necessary(
-    joint_adaptive_bread_inverse_matrix: jnp.ndarray,
+# important for the subject to know if this is happening. Even if enabled, it is important
+# that the subject know it actually kicks in.
+def stabilize_joint_bread_if_necessary(
+    joint_adjusted_bread_matrix: jnp.ndarray,
     beta_dim: int,
     theta_dim: int,
 ) -> jnp.ndarray:
     """
-    Stabilizes the joint adaptive bread inverse matrix if necessary by increasing diagonal block
+    Stabilizes the joint bread matrix if necessary by increasing diagonal block
     dominance and/or adding a small ridge penalty to the diagonal blocks.
 
     Args:
-        joint_adaptive_bread_inverse_matrix (jnp.ndarray):
-            A 2-D JAX NumPy array representing the joint adaptive bread inverse matrix.
+        joint_adjusted_bread_matrix (jnp.ndarray):
+            A 2-D JAX NumPy array representing the joint bread matrix.
         beta_dim (int):
             The dimension of each beta parameter.
         theta_dim (int):
             The dimension of the theta parameter.
     Returns:
         jnp.ndarray:
-            A 2-D NumPy array representing the stabilized joint adaptive bread inverse matrix.
+            A 2-D NumPy array representing the stabilized joint bread matrix.
     """
 
     # TODO: come up with more sophisticated settings here. These are maybe a little loose,
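For orientation, the classical half of what this hunk assembles follows the textbook M-estimation recipe: bread B = average per-subject Jacobian of the estimating function, meat M = average outer product of per-subject estimating-function values, variance B^{-1} M B^{-T} / n. A generic NumPy sketch under those standard assumptions (not lifejacket's exact code path, which avoids the explicit inverse):

```python
import numpy as np

def classical_sandwich(per_subject_bread_contributions, per_subject_stacks):
    """V = B^{-1} M B^{-T} / n, with B and M averaged over subjects."""
    n = per_subject_stacks.shape[0]
    # Bread: average per-subject Jacobian of the estimating function.
    bread = per_subject_bread_contributions.mean(axis=0)
    # Meat: average outer product of per-subject estimating function values.
    meat = np.mean(
        [np.outer(stack, stack) for stack in per_subject_stacks], axis=0
    )
    bread_inv = np.linalg.inv(bread)  # fine for a sketch; a solve is stabler
    return bread_inv @ meat @ bread_inv.T / n

rng = np.random.default_rng(0)
breads = rng.normal(size=(50, 3, 3)) + 4 * np.eye(3)  # keep B well-conditioned
stacks = rng.normal(size=(50, 3))
print(classical_sandwich(breads, stacks).shape)  # (3, 3)
```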
@@ -1660,7 +1673,7 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
 
     # Grab just the RL block and convert to a numpy array for easier manipulation.
     RL_stack_beta_derivatives_block = np.array(
-        joint_adaptive_bread_inverse_matrix[:-theta_dim, :-theta_dim]
+        joint_adjusted_bread_matrix[:-theta_dim, :-theta_dim]
     )
     num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
     for i in range(1, num_updates + 1):
@@ -1688,7 +1701,7 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
             RL_stack_beta_derivatives_block[
                 diagonal_block_slice, diagonal_block_slice
             ] = diagonal_block + ridge_penalty * np.eye(beta_dim)
-            # TODO: Require user input here in interactive settings?
+            # TODO: Require subject input here in interactive settings?
             logger.info(
                 "Added ridge penalty of %s to diagonal block for update %s to improve conditioning from %s to %s",
                 ridge_penalty,
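The ridge update in this hunk is the usual conditioning trick: when a diagonal block's condition number is too large, add a small multiple of the identity before relying on solves against it. A minimal check-then-ridge sketch (the thresholds are illustrative, not the package's settings):

```python
import numpy as np

def ridge_if_ill_conditioned(block, max_condition_number=1e8, ridge_penalty=1e-6):
    """Add ridge_penalty * I to a square block whose condition number is too large."""
    if np.linalg.cond(block) > max_condition_number:
        block = block + ridge_penalty * np.eye(block.shape[0])
    return block

nearly_singular = np.array([[1.0, 1.0], [1.0, 1.0 + 1e-12]])
print(np.linalg.cond(nearly_singular))                            # astronomically large
print(np.linalg.cond(ridge_if_ill_conditioned(nearly_singular)))  # orders of magnitude smaller
```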
@@ -1779,44 +1792,44 @@ def stabilize_joint_adaptive_bread_inverse_if_necessary(
         [
             [
                 RL_stack_beta_derivatives_block,
-                joint_adaptive_bread_inverse_matrix[:-theta_dim, -theta_dim:],
+                joint_adjusted_bread_matrix[:-theta_dim, -theta_dim:],
             ],
             [
-                joint_adaptive_bread_inverse_matrix[-theta_dim:, :-theta_dim],
-                joint_adaptive_bread_inverse_matrix[-theta_dim:, -theta_dim:],
+                joint_adjusted_bread_matrix[-theta_dim:, :-theta_dim],
+                joint_adjusted_bread_matrix[-theta_dim:, -theta_dim:],
             ],
         ]
     )
 
 
-def form_sandwich_from_bread_inverse_and_meat(
-    bread_inverse: jnp.ndarray,
+def form_sandwich_from_bread_and_meat(
+    bread: jnp.ndarray,
     meat: jnp.ndarray,
-    num_users: int,
-    method: str = SandwichFormationMethods.
+    num_subjects: int,
+    method: str = SandwichFormationMethods.BREAD_T_QR,
 ) -> jnp.ndarray:
     """
-    Forms a sandwich variance matrix from the provided bread inverse and meat matrices.
+    Forms a sandwich variance matrix from the provided bread and meat matrices.
 
-    Attempts to do so STABLY without ever forming the bread matrix itself
+    Attempts to do so STABLY without ever forming the bread inverse matrix itself
     (except with naive option).
 
     Args:
-        bread_inverse (jnp.ndarray):
-            A 2-D JAX NumPy array representing the bread inverse matrix.
+        bread (jnp.ndarray):
+            A 2-D JAX NumPy array representing the bread matrix.
         meat (jnp.ndarray):
             A 2-D JAX NumPy array representing the meat matrix.
-        num_users (int):
-            The number of
+        num_subjects (int):
+            The number of subjects in the deployment, used to scale the sandwich appropriately.
         method (str):
             The method to use for forming the sandwich.
 
-            SandwichFormationMethods.
-            of the bread
+            SandwichFormationMethods.BREAD_T_QR uses the QR decomposition of the transpose
+            of the bread matrix.
 
             SandwichFormationMethods.MEAT_SVD_SOLVE uses a decomposition of the meat matrix.
 
-            SandwichFormationMethods.NAIVE simply inverts the bread
+            SandwichFormationMethods.NAIVE simply inverts the bread and forms the sandwich.
 
 
     Returns:
@@ -1824,16 +1837,16 @@ def form_sandwich_from_bread_inverse_and_meat(
             A 2-D JAX NumPy array representing the sandwich variance matrix.
     """
 
-    if method == SandwichFormationMethods.
+    if method == SandwichFormationMethods.BREAD_T_QR:
         # QR of B^T → Q orthogonal, R upper triangular; L = R^T lower triangular
-        Q, R = np.linalg.qr(
+        Q, R = np.linalg.qr(bread.T, mode="reduced")
         L = R.T
 
         new_meat = scipy.linalg.solve_triangular(
             L, scipy.linalg.solve_triangular(L, meat.T, lower=True).T, lower=True
         )
 
-        return Q @ new_meat @ Q.T /
+        return Q @ new_meat @ Q.T / num_subjects
     elif method == SandwichFormationMethods.MEAT_SVD_SOLVE:
         # Factor the meat via SVD without any symmetrization or truncation.
         # For general (possibly slightly nonsymmetric) M, SVD gives M = U @ diag(s) @ Vh.
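The BREAD_T_QR branch rests on a short identity: if B^T = QR with Q orthogonal and L = R^T, then B = LQ^T, so B^{-1} M B^{-T} = Q (L^{-1} M L^{-T}) Q^T, which is exactly what the two triangular solves compute. A quick numerical confirmation of that algebra on toy matrices:

```python
import numpy as np
import scipy.linalg

rng = np.random.default_rng(1)
num_subjects, dim = 40, 4
bread = rng.normal(size=(dim, dim)) + 3 * np.eye(dim)  # generic, well-conditioned
meat = rng.normal(size=(dim, dim))

# QR route: B^T = Q R, L = R^T, so B = L Q^T and B^{-1} = Q L^{-1}.
Q, R = np.linalg.qr(bread.T, mode="reduced")
L = R.T
new_meat = scipy.linalg.solve_triangular(
    L, scipy.linalg.solve_triangular(L, meat.T, lower=True).T, lower=True
)
qr_sandwich = Q @ new_meat @ Q.T / num_subjects

# Naive route for comparison only.
bread_inv = np.linalg.inv(bread)
print(np.allclose(qr_sandwich, bread_inv @ meat @ bread_inv.T / num_subjects))  # True
```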
@@ -1844,21 +1857,21 @@ def form_sandwich_from_bread_inverse_and_meat(
         C_right = Vh.T * np.sqrt(s)
 
         # Solve B W_left = C_left and B W_right = C_right (no explicit inverses).
-        W_left = scipy.linalg.solve(
-        W_right = scipy.linalg.solve(
+        W_left = scipy.linalg.solve(bread, C_left)
+        W_right = scipy.linalg.solve(bread, C_right)
 
-        # Return the exact sandwich: V = (B^{-1} C_left) (B^{-1} C_right)^T / num_users
-        return W_left @ W_right.T / num_users
+        # Return the exact sandwich: V = (B^{-1} C_left) (B^{-1} C_right)^T / num_subjects
+        return W_left @ W_right.T / num_subjects
 
     elif method == SandwichFormationMethods.NAIVE:
-        # Simply invert the bread
+        # Simply invert the bread and form the sandwich directly.
         # This is NOT numerically stable and is only included for comparison purposes.
-
-        return
+        bread_inverse = np.linalg.inv(bread)
+        return bread_inverse @ meat @ bread_inverse.T / num_subjects
 
     else:
         raise ValueError(
-            f"Unknown sandwich method: {method}. Please use '
+            f"Unknown sandwich method: {method}. Please use 'bread_t_qr' or 'meat_decomposition_solve'."
         )
 
 
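Similarly, the MEAT_SVD_SOLVE branch factors M = U diag(s) Vh = C_left @ C_right.T and replaces each inverse with a linear solve against the bread, so B^{-1} is never formed. The same style of numerical check (toy matrices again):

```python
import numpy as np
import scipy.linalg

rng = np.random.default_rng(2)
num_subjects, dim = 40, 4
bread = rng.normal(size=(dim, dim)) + 3 * np.eye(dim)
meat = rng.normal(size=(dim, dim))

# Factor the meat: M = U diag(s) Vh = C_left @ C_right.T.
U, s, Vh = np.linalg.svd(meat)
C_left = U * np.sqrt(s)
C_right = Vh.T * np.sqrt(s)

# Solve B W = C instead of forming B^{-1} explicitly.
W_left = scipy.linalg.solve(bread, C_left)
W_right = scipy.linalg.solve(bread, C_right)
svd_sandwich = W_left @ W_right.T / num_subjects

bread_inv = np.linalg.inv(bread)
print(np.allclose(svd_sandwich, bread_inv @ meat @ bread_inv.T / num_subjects))  # True
```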