lifejacket 0.1.0__py3-none-any.whl
- lifejacket/__init__.py +0 -0
- lifejacket/after_study_analysis.py +1845 -0
- lifejacket/arg_threading_helpers.py +354 -0
- lifejacket/calculate_derivatives.py +965 -0
- lifejacket/constants.py +28 -0
- lifejacket/form_adaptive_meat_adjustments_directly.py +333 -0
- lifejacket/get_datum_for_blowup_supervised_learning.py +1312 -0
- lifejacket/helper_functions.py +587 -0
- lifejacket/input_checks.py +1145 -0
- lifejacket/small_sample_corrections.py +125 -0
- lifejacket/trial_conditioning_monitor.py +870 -0
- lifejacket/vmap_helpers.py +71 -0
- lifejacket-0.1.0.dist-info/METADATA +100 -0
- lifejacket-0.1.0.dist-info/RECORD +17 -0
- lifejacket-0.1.0.dist-info/WHEEL +5 -0
- lifejacket-0.1.0.dist-info/entry_points.txt +2 -0
- lifejacket-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1845 @@
from __future__ import annotations

import collections
import pathlib
import pickle
import logging
import math
from typing import Any, Callable

import click
import jax
import numpy as np
from jax import numpy as jnp
import scipy
import pandas as pd

from .arg_threading_helpers import (
    thread_action_prob_func_args,
    thread_inference_func_args,
    thread_update_func_args,
)
from .constants import (
    FunctionTypes,
    SandwichFormationMethods,
    SmallSampleCorrections,
)
from .form_adaptive_meat_adjustments_directly import (
    form_adaptive_meat_adjustments_directly,
)
from . import input_checks
from . import get_datum_for_blowup_supervised_learning
from .small_sample_corrections import perform_desired_small_sample_correction
from .vmap_helpers import stack_batched_arg_lists_into_tensors


from .helper_functions import (
    calculate_beta_dim,
    collect_all_post_update_betas,
    construct_beta_index_by_policy_num_map,
    extract_action_and_policy_by_decision_time_by_user_id,
    flatten_params,
    get_in_study_df_column,
    get_min_time_by_policy_num,
    get_radon_nikodym_weight,
    load_function_from_same_named_file,
    unflatten_params,
)

logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s,%(msecs)03d %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)


@click.group()
def cli():
    pass


# TODO: Check all help strings for accuracy.
# TODO: Deal with NA, -1, etc policy numbers
# TODO: Make sure in study is never on for more than one stretch EDIT: unclear if
# this will remain an invariant as we deal with more complicated data missingness
# TODO: I think I'm agnostic to indexing of calendar times but should check because
# otherwise need to add a check here to verify required format.
# TODO: Currently assuming function args can be placed in a numpy array. Must be scalar, 1d or 2d array.
# Higher dimensional objects not supported. Not entirely sure what kind of "scalars" apply.
@cli.command(name="analyze")
@click.option(
    "--study_df_pickle",
    type=click.File("rb"),
    help="Pickled pandas dataframe in correct format (see contract/readme).",
    required=True,
)
@click.option(
    "--action_prob_func_filename",
    type=click.Path(exists=True),
    help="File that contains the action probability function and relevant imports. The filename without its extension will be assumed to match the function name.",
    required=True,
)
@click.option(
    "--action_prob_func_args_pickle",
    type=click.File("rb"),
    help="Pickled dictionary that contains the action probability function arguments for all decision times for all users.",
    required=True,
)
@click.option(
    "--action_prob_func_args_beta_index",
    type=int,
    required=True,
    help="Index of the algorithm parameter vector beta in the tuple of action probability func args.",
)
@click.option(
    "--alg_update_func_filename",
    type=click.Path(exists=True),
    help="File that contains the per-user update function used to determine the algorithm parameters at each update and relevant imports. May be a loss or estimating function, specified in a separate argument. The filename without its extension will be assumed to match the function name.",
    required=True,
)
@click.option(
    "--alg_update_func_type",
    type=click.Choice([FunctionTypes.LOSS, FunctionTypes.ESTIMATING]),
    help="Type of function used to summarize the algorithm updates. If loss, an update should correspond to choosing parameters to minimize it. If estimating, an update should correspond to setting the function equal to zero and solving for the parameters.",
    required=True,
)
@click.option(
    "--alg_update_func_args_pickle",
    type=click.File("rb"),
    help="Pickled dictionary that contains the algorithm update function arguments for all update times for all users.",
    required=True,
)
@click.option(
    "--alg_update_func_args_beta_index",
    type=int,
    required=True,
    help="Index of the algorithm parameter vector beta in the tuple of algorithm update func args.",
)
@click.option(
    "--alg_update_func_args_action_prob_index",
    type=int,
    default=-1000,
    help="Index of the action probability in the tuple of algorithm update func args, if applicable.",
)
@click.option(
    "--alg_update_func_args_action_prob_times_index",
    type=int,
    default=-1000,
    help="Index of the argument holding the decision times the action probabilities correspond to in the tuple of algorithm update func args, if applicable.",
)
@click.option(
    "--inference_func_filename",
    type=click.Path(exists=True),
    help="File that contains the per-user loss/estimating function used to determine the inference estimate and relevant imports. The filename without its extension will be assumed to match the function name.",
    required=True,
)
@click.option(
    "--inference_func_type",
    type=click.Choice([FunctionTypes.LOSS, FunctionTypes.ESTIMATING]),
    help="Type of function used to summarize inference. If loss, inference should correspond to choosing parameters to minimize it. If estimating, inference should correspond to setting the function equal to zero and solving for the parameters.",
    required=True,
)
@click.option(
    "--inference_func_args_theta_index",
    type=int,
    required=True,
    help="Index of the inference parameter vector theta in the tuple of inference loss/estimating func args.",
)
@click.option(
    "--theta_calculation_func_filename",
    type=click.Path(exists=True),
    help="Path to file that allows one to actually calculate a theta estimate given the study dataframe only. One must supply either this or a precomputed theta estimate. The filename without its extension will be assumed to match the function name.",
    required=True,
)
@click.option(
    "--in_study_col_name",
    type=str,
    required=True,
    help="Name of the binary column in the study dataframe that indicates whether a user is in the study.",
)
@click.option(
    "--action_col_name",
    type=str,
    required=True,
    help="Name of the binary column in the study dataframe that indicates which action was taken.",
)
@click.option(
    "--policy_num_col_name",
    type=str,
    required=True,
    help="Name of the column in the study dataframe that indicates the policy number in use.",
)
@click.option(
    "--calendar_t_col_name",
    type=str,
    required=True,
    help="Name of the column in the study dataframe that indicates calendar time (shared integer index across users).",
)
@click.option(
    "--user_id_col_name",
    type=str,
    required=True,
    help="Name of the column in the study dataframe that indicates user id.",
)
@click.option(
    "--action_prob_col_name",
    type=str,
    required=True,
    help="Name of the column in the study dataframe that gives action one probabilities.",
)
@click.option(
    "--reward_col_name",
    type=str,
    required=True,
    help="Name of the column in the study dataframe that gives rewards.",
)
@click.option(
    "--suppress_interactive_data_checks",
    type=bool,
    default=False,
    help="Flag to suppress any data checks that require user input. This is suitable for tests and large simulations.",
)
@click.option(
    "--suppress_all_data_checks",
    type=bool,
    default=False,
    help="Flag to suppress all data checks. Not usually recommended, as suppressing only interactive checks suffices to keep tests/simulations running and is safer.",
)
@click.option(
    "--small_sample_correction",
    type=click.Choice(
        [
            SmallSampleCorrections.NONE,
            SmallSampleCorrections.HC1theta,
            SmallSampleCorrections.HC2theta,
            SmallSampleCorrections.HC3theta,
        ]
    ),
    default=SmallSampleCorrections.NONE,
    help="Type of small sample correction to apply to the variance estimate.",
)
@click.option(
    "--collect_data_for_blowup_supervised_learning",
    type=bool,
    default=False,
    help="Flag to collect data for supervised learning blowup detection. This will write a single datum and label to a file in the same directory as the input files.",
)
@click.option(
    "--form_adaptive_meat_adjustments_explicitly",
    type=bool,
    default=False,
    help="If True, explicitly forms the per-user meat adjustments that differentiate the adaptive sandwich from the classical sandwich. This is for diagnostic purposes, as the adaptive sandwich is formed without doing this.",
)
@click.option(
    "--stabilize_joint_adaptive_bread_inverse",
    type=bool,
    default=True,
    help="If True, stabilizes the joint adaptive bread inverse matrix if it does not meet conditioning thresholds.",
)
def analyze_dataset_wrapper(**kwargs):
    """
    This function is a wrapper around analyze_dataset to facilitate command line use.

    From the command line, we take pickles and filenames for Python objects. We
    unpickle/load those files here before passing them to the implementation
    function, which may also be called in its own right with in-memory objects.

    See analyze_dataset for the underlying details.

    Returns: None
    """

    # Pass along the folder the study dataframe is in as the output folder.
    # Do it now because we will be removing the study dataframe pickle from kwargs.
    kwargs["output_dir"] = pathlib.Path(kwargs["study_df_pickle"].name).parent.resolve()

    # Unpickle pickles and replace those args in kwargs
    kwargs["study_df"] = pickle.load(kwargs["study_df_pickle"])
    kwargs["action_prob_func_args"] = pickle.load(
        kwargs["action_prob_func_args_pickle"]
    )
    kwargs["alg_update_func_args"] = pickle.load(kwargs["alg_update_func_args_pickle"])

    kwargs.pop("study_df_pickle")
    kwargs.pop("action_prob_func_args_pickle")
    kwargs.pop("alg_update_func_args_pickle")

    # Load functions from filenames and replace those args in kwargs
    kwargs["action_prob_func"] = load_function_from_same_named_file(
        kwargs["action_prob_func_filename"]
    )
    kwargs["alg_update_func"] = load_function_from_same_named_file(
        kwargs["alg_update_func_filename"]
    )
    kwargs["inference_func"] = load_function_from_same_named_file(
        kwargs["inference_func_filename"]
    )
    kwargs["theta_calculation_func"] = load_function_from_same_named_file(
        kwargs["theta_calculation_func_filename"]
    )

    kwargs.pop("action_prob_func_filename")
    kwargs.pop("alg_update_func_filename")
    kwargs.pop("inference_func_filename")
    kwargs.pop("theta_calculation_func_filename")

    analyze_dataset(**kwargs)

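# Example: a minimal invocation of the command defined above, assuming the
# installed console entry point is named `lifejacket` (see entry_points.txt).
# The entry point name, paths, index values, and column names here are all
# illustrative placeholders, and the *_func_type values are whichever strings
# FunctionTypes.LOSS / FunctionTypes.ESTIMATING define.
#
#   lifejacket analyze \
#       --study_df_pickle ./study_df.pkl \
#       --action_prob_func_filename ./action_prob_func.py \
#       --action_prob_func_args_pickle ./action_prob_func_args.pkl \
#       --action_prob_func_args_beta_index 0 \
#       --alg_update_func_filename ./alg_update_func.py \
#       --alg_update_func_type <FunctionTypes.LOSS or .ESTIMATING> \
#       --alg_update_func_args_pickle ./alg_update_func_args.pkl \
#       --alg_update_func_args_beta_index 0 \
#       --inference_func_filename ./inference_func.py \
#       --inference_func_type <FunctionTypes.LOSS or .ESTIMATING> \
#       --inference_func_args_theta_index 0 \
#       --theta_calculation_func_filename ./theta_calculation_func.py \
#       --in_study_col_name in_study \
#       --action_col_name action \
#       --policy_num_col_name policy_num \
#       --calendar_t_col_name calendar_t \
#       --user_id_col_name user_id \
#       --action_prob_col_name action_prob \
#       --reward_col_name reward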


def analyze_dataset(
    output_dir: pathlib.Path | str,
    study_df: pd.DataFrame,
    action_prob_func: Callable,
    action_prob_func_args: dict[int, Any],
    action_prob_func_args_beta_index: int,
    alg_update_func: Callable,
    alg_update_func_type: str,
    alg_update_func_args: dict[int, Any],
    alg_update_func_args_beta_index: int,
    alg_update_func_args_action_prob_index: int,
    alg_update_func_args_action_prob_times_index: int,
    inference_func: Callable,
    inference_func_type: str,
    inference_func_args_theta_index: int,
    theta_calculation_func: Callable[[pd.DataFrame], jnp.ndarray],
    in_study_col_name: str,
    action_col_name: str,
    policy_num_col_name: str,
    calendar_t_col_name: str,
    user_id_col_name: str,
    action_prob_col_name: str,
    reward_col_name: str,
    suppress_interactive_data_checks: bool,
    suppress_all_data_checks: bool,
    small_sample_correction: str,
    collect_data_for_blowup_supervised_learning: bool,
    form_adaptive_meat_adjustments_explicitly: bool,
    stabilize_joint_adaptive_bread_inverse: bool,
) -> dict:
    """
    Analyzes a dataset to provide a parameter estimate and an estimate of its variance using adaptive and classical sandwich estimators.

    There are two modes of use for this function.

    First, it may be called indirectly from the command line by passing through
    analyze_dataset_wrapper.

    Second, it may be called directly from Python code with in-memory objects.

    Parameters:
        output_dir (pathlib.Path | str):
            Directory in which to save output files.
        study_df (pd.DataFrame):
            DataFrame containing the study data.
        action_prob_func (callable):
            Action probability function.
        action_prob_func_args (dict[int, Any]):
            Arguments for the action probability function.
        action_prob_func_args_beta_index (int):
            Index for beta in action probability function arguments.
        alg_update_func (callable):
            Algorithm update function.
        alg_update_func_type (str):
            Type of the algorithm update function.
        alg_update_func_args (dict[int, Any]):
            Arguments for the algorithm update function.
        alg_update_func_args_beta_index (int):
            Index for beta in algorithm update function arguments.
        alg_update_func_args_action_prob_index (int):
            Index for action probability in algorithm update function arguments.
        alg_update_func_args_action_prob_times_index (int):
            Index for action probability times in algorithm update function arguments.
        inference_func (callable):
            Inference loss or estimating function.
        inference_func_type (str):
            Type of the inference function.
        inference_func_args_theta_index (int):
            Index for theta in inference function arguments.
        theta_calculation_func (callable):
            Function to estimate theta from the study DataFrame.
        in_study_col_name (str):
            Column name indicating if a user is in the study in the study dataframe.
        action_col_name (str):
            Column name for actions in the study dataframe.
        policy_num_col_name (str):
            Column name for policy numbers in the study dataframe.
        calendar_t_col_name (str):
            Column name for calendar time in the study dataframe.
        user_id_col_name (str):
            Column name for user IDs in the study dataframe.
        action_prob_col_name (str):
            Column name for action probabilities in the study dataframe.
        reward_col_name (str):
            Column name for rewards in the study dataframe.
        suppress_interactive_data_checks (bool):
            Whether to suppress interactive data checks. This should be used in simulations, for example.
        suppress_all_data_checks (bool):
            Whether to suppress all data checks. Not recommended.
        small_sample_correction (str):
            Type of small sample correction to apply.
        collect_data_for_blowup_supervised_learning (bool):
            Whether to collect data for doing supervised learning about adaptive sandwich blowup.
        form_adaptive_meat_adjustments_explicitly (bool):
            If True, explicitly forms the per-user meat adjustments that differentiate the adaptive
            sandwich from the classical sandwich. This is for diagnostic purposes, as the
            adaptive sandwich is formed without doing this.
        stabilize_joint_adaptive_bread_inverse (bool):
            If True, stabilizes the joint adaptive bread inverse matrix if it does not meet
            conditioning thresholds.

    Returns:
        dict: A dictionary containing the theta estimate, adaptive sandwich variance estimate, and
            classical sandwich variance estimate.
    """

    logging.basicConfig(
        format="%(asctime)s,%(msecs)03d %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s",
        datefmt="%Y-%m-%d:%H:%M:%S",
        level=logging.INFO,
    )

    theta_est = jnp.array(theta_calculation_func(study_df))

    beta_dim = calculate_beta_dim(
        action_prob_func_args, action_prob_func_args_beta_index
    )
    if not suppress_all_data_checks:
        input_checks.perform_first_wave_input_checks(
            study_df,
            in_study_col_name,
            action_col_name,
            policy_num_col_name,
            calendar_t_col_name,
            user_id_col_name,
            action_prob_col_name,
            reward_col_name,
            action_prob_func,
            action_prob_func_args,
            action_prob_func_args_beta_index,
            alg_update_func_args,
            alg_update_func_args_beta_index,
            alg_update_func_args_action_prob_index,
            alg_update_func_args_action_prob_times_index,
            theta_est,
            beta_dim,
            suppress_interactive_data_checks,
            small_sample_correction,
        )

    ### Begin collecting data structures that will be used to compute the joint bread matrix.

    beta_index_by_policy_num, initial_policy_num = (
        construct_beta_index_by_policy_num_map(
            study_df, policy_num_col_name, in_study_col_name
        )
    )

    all_post_update_betas = collect_all_post_update_betas(
        beta_index_by_policy_num, alg_update_func_args, alg_update_func_args_beta_index
    )

    action_by_decision_time_by_user_id, policy_num_by_decision_time_by_user_id = (
        extract_action_and_policy_by_decision_time_by_user_id(
            study_df,
            user_id_col_name,
            in_study_col_name,
            calendar_t_col_name,
            action_col_name,
            policy_num_col_name,
        )
    )

    (
        inference_func_args_by_user_id,
        inference_func_args_action_prob_index,
        inference_action_prob_decision_times_by_user_id,
    ) = process_inference_func_args(
        inference_func,
        inference_func_args_theta_index,
        study_df,
        theta_est,
        action_prob_col_name,
        calendar_t_col_name,
        user_id_col_name,
        in_study_col_name,
    )

    # Use a per-user weighted estimating function stacking function to derive classical and joint
    # adaptive meat and inverse bread matrices. This is facilitated because the *values* of the
    # weighted and unweighted stacks are the same, as the weights evaluate to 1 pre-differentiation.
    logger.info(
        "Constructing joint adaptive bread inverse matrix, joint adaptive meat matrix, the classical analogs, and the avg estimating function stack across users."
    )

    user_ids = jnp.array(study_df[user_id_col_name].unique())
    (
        stabilized_joint_adaptive_bread_inverse_matrix,
        raw_joint_adaptive_bread_inverse_matrix,
        joint_adaptive_meat_matrix,
        joint_adaptive_sandwich_matrix,
        classical_bread_inverse_matrix,
        classical_meat_matrix,
        classical_sandwich_var_estimate,
        avg_estimating_function_stack,
        per_user_estimating_function_stacks,
        per_user_adaptive_corrections,
        per_user_classical_corrections,
        per_user_adaptive_meat_adjustments,
    ) = construct_classical_and_adaptive_sandwiches(
        theta_est,
        all_post_update_betas,
        user_ids,
        action_prob_func,
        action_prob_func_args_beta_index,
        alg_update_func,
        alg_update_func_type,
        alg_update_func_args_beta_index,
        alg_update_func_args_action_prob_index,
        alg_update_func_args_action_prob_times_index,
        inference_func,
        inference_func_type,
        inference_func_args_theta_index,
        inference_func_args_action_prob_index,
        action_prob_func_args,
        policy_num_by_decision_time_by_user_id,
        initial_policy_num,
        beta_index_by_policy_num,
        inference_func_args_by_user_id,
        inference_action_prob_decision_times_by_user_id,
        alg_update_func_args,
        action_by_decision_time_by_user_id,
        suppress_all_data_checks,
        suppress_interactive_data_checks,
        small_sample_correction,
        form_adaptive_meat_adjustments_explicitly,
        stabilize_joint_adaptive_bread_inverse,
        study_df,
        in_study_col_name,
        action_col_name,
        calendar_t_col_name,
        user_id_col_name,
        action_prob_func_args,
        action_prob_col_name,
    )

    joint_adaptive_bread_inverse_cond = jnp.linalg.cond(
        stabilized_joint_adaptive_bread_inverse_matrix
    )

    theta_dim = len(theta_est)
    if not suppress_all_data_checks:
        input_checks.require_estimating_functions_sum_to_zero(
            avg_estimating_function_stack,
            beta_dim,
            theta_dim,
            suppress_interactive_data_checks,
        )

    # This bottom right corner of the joint (betas and theta) variance matrix is the portion
    # corresponding to just theta.
    adaptive_sandwich_var_estimate = joint_adaptive_sandwich_matrix[
        -theta_dim:, -theta_dim:
    ]

    # Check for negative diagonal elements and set them to zero if found.
    adaptive_diagonal = jnp.diag(adaptive_sandwich_var_estimate)
    if jnp.any(adaptive_diagonal < 0):
        logger.warning(
            "Found negative diagonal elements in adaptive sandwich variance estimate. Setting them to zero."
        )
        # JAX arrays are immutable, so np.fill_diagonal would raise here; use a
        # functional .at[] update instead.
        diagonal_indices = jnp.arange(theta_dim)
        adaptive_sandwich_var_estimate = adaptive_sandwich_var_estimate.at[
            diagonal_indices, diagonal_indices
        ].set(jnp.maximum(adaptive_diagonal, 0))

    logger.info("Writing results to file...")
    # Write analysis results to same directory that input files are in
    output_folder_abs_path = pathlib.Path(output_dir).resolve()

    analysis_dict = {
        "theta_est": theta_est,
        "adaptive_sandwich_var_estimate": adaptive_sandwich_var_estimate,
        "classical_sandwich_var_estimate": classical_sandwich_var_estimate,
    }
    with open(output_folder_abs_path / "analysis.pkl", "wb") as f:
        pickle.dump(
            analysis_dict,
            f,
        )

    debug_pieces_dict = {
        "theta_est": theta_est,
        "adaptive_sandwich_var_estimate": adaptive_sandwich_var_estimate,
        "classical_sandwich_var_estimate": classical_sandwich_var_estimate,
        "raw_joint_bread_inverse_matrix": raw_joint_adaptive_bread_inverse_matrix,
        "stabilized_joint_bread_inverse_matrix": stabilized_joint_adaptive_bread_inverse_matrix,
        "joint_meat_matrix": joint_adaptive_meat_matrix,
        "classical_bread_inverse_matrix": classical_bread_inverse_matrix,
        "classical_meat_matrix": classical_meat_matrix,
        "all_estimating_function_stacks": per_user_estimating_function_stacks,
        "joint_bread_inverse_condition_number": joint_adaptive_bread_inverse_cond,
        "all_post_update_betas": all_post_update_betas,
        "per_user_adaptive_corrections": per_user_adaptive_corrections,
        "per_user_classical_corrections": per_user_classical_corrections,
        "per_user_adaptive_meat_adjustments": per_user_adaptive_meat_adjustments,
    }
    with open(output_folder_abs_path / "debug_pieces.pkl", "wb") as f:
        pickle.dump(
            debug_pieces_dict,
            f,
        )

    if collect_data_for_blowup_supervised_learning:
        datum_and_label_dict = get_datum_for_blowup_supervised_learning.get_datum_for_blowup_supervised_learning(
            raw_joint_adaptive_bread_inverse_matrix,
            joint_adaptive_bread_inverse_cond,
            avg_estimating_function_stack,
            per_user_estimating_function_stacks,
            all_post_update_betas,
            study_df,
            in_study_col_name,
            calendar_t_col_name,
            action_prob_col_name,
            user_id_col_name,
            reward_col_name,
            theta_est,
            adaptive_sandwich_var_estimate,
            user_ids,
            beta_dim,
            theta_dim,
            initial_policy_num,
            beta_index_by_policy_num,
            policy_num_by_decision_time_by_user_id,
            theta_calculation_func,
            action_prob_func,
            action_prob_func_args_beta_index,
            inference_func,
            inference_func_type,
            inference_func_args_theta_index,
            inference_func_args_action_prob_index,
            inference_action_prob_decision_times_by_user_id,
            action_prob_func_args,
            action_by_decision_time_by_user_id,
        )

        with open(output_folder_abs_path / "supervised_learning_datum.pkl", "wb") as f:
            pickle.dump(datum_and_label_dict, f)

    print(f"\nParameter estimate:\n {theta_est}")
    print(f"\nAdaptive sandwich variance estimate:\n {adaptive_sandwich_var_estimate}")
    print(
        f"\nClassical sandwich variance estimate:\n {classical_sandwich_var_estimate}\n"
    )

    return analysis_dict

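
# Sketch: reading back the pickled results written by analyze_dataset above. The
# file name and dictionary keys come from analysis_dict; the helper itself is
# illustrative only and not part of the package API.
def _example_read_analysis(output_dir: pathlib.Path | str):
    with open(pathlib.Path(output_dir).resolve() / "analysis.pkl", "rb") as f:
        analysis = pickle.load(f)
    return analysis["theta_est"], analysis["adaptive_sandwich_var_estimate"]
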

def process_inference_func_args(
    inference_func: Callable,
    inference_func_args_theta_index: int,
    study_df: pd.DataFrame,
    theta_est: jnp.ndarray,
    action_prob_col_name: str,
    calendar_t_col_name: str,
    user_id_col_name: str,
    in_study_col_name: str,
) -> tuple[
    dict[collections.abc.Hashable, tuple[Any, ...]],
    int,
    dict[collections.abc.Hashable, jnp.ndarray],
]:
"""
|
|
648
|
+
Collects the inference function arguments for each user from the study DataFrame.
|
|
649
|
+
|
|
650
|
+
Note that theta and action probabilities, if present, will be replaced later
|
|
651
|
+
so that the function can be differentiated with respect to shared versions
|
|
652
|
+
of them.
|
|
653
|
+
|
|
654
|
+
Args:
|
|
655
|
+
inference_func (callable):
|
|
656
|
+
The inference function to be used.
|
|
657
|
+
inference_func_args_theta_index (int):
|
|
658
|
+
The index of the theta parameter in the inference function's arguments.
|
|
659
|
+
study_df (pandas.DataFrame):
|
|
660
|
+
The study DataFrame.
|
|
661
|
+
theta_est (jnp.ndarray):
|
|
662
|
+
The estimate of the parameter vector.
|
|
663
|
+
action_prob_col_name (str):
|
|
664
|
+
The name of the column in the study DataFrame that gives action probabilities.
|
|
665
|
+
calendar_t_col_name (str):
|
|
666
|
+
The name of the column in the study DataFrame that indicates calendar time.
|
|
667
|
+
user_id_col_name (str):
|
|
668
|
+
The name of the column in the study DataFrame that indicates user ID.
|
|
669
|
+
in_study_col_name (str):
|
|
670
|
+
The name of the binary column in the study DataFrame that indicates whether a user is in the study.
|
|
671
|
+
Returns:
|
|
672
|
+
tuple[dict[collections.abc.Hashable, tuple[Any, ...]], int, dict[collections.abc.Hashable, jnp.ndarray[int]]]:
|
|
673
|
+
A tuple containing
|
|
674
|
+
- the inference function arguments dictionary for each user
|
|
675
|
+
- the index of the action probabilities argument
|
|
676
|
+
- a dictionary mapping user IDs to the decision times to which action probabilities correspond
|
|
677
|
+
"""

    num_args = inference_func.__code__.co_argcount
    inference_func_arg_names = inference_func.__code__.co_varnames[:num_args]
    inference_func_args_by_user_id = {}

    inference_func_args_action_prob_index = -1
    inference_action_prob_decision_times_by_user_id = {}

    using_action_probs = action_prob_col_name in inference_func_arg_names
    if using_action_probs:
        inference_func_args_action_prob_index = inference_func_arg_names.index(
            action_prob_col_name
        )

    for user_id in study_df[user_id_col_name].unique():
        user_args_list = []
        filtered_user_data = study_df.loc[study_df[user_id_col_name] == user_id]
        for idx, col_name in enumerate(inference_func_arg_names):
            if idx == inference_func_args_theta_index:
                user_args_list.append(theta_est)
                continue
            user_args_list.append(
                get_in_study_df_column(filtered_user_data, col_name, in_study_col_name)
            )
        inference_func_args_by_user_id[user_id] = tuple(user_args_list)
        if using_action_probs:
            inference_action_prob_decision_times_by_user_id[user_id] = (
                get_in_study_df_column(
                    filtered_user_data, calendar_t_col_name, in_study_col_name
                )
            )

    return (
        inference_func_args_by_user_id,
        inference_func_args_action_prob_index,
        inference_action_prob_decision_times_by_user_id,
    )

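
# Sketch of the argument convention that process_inference_func_args relies on:
# each positional parameter name of the inference function must match a study
# dataframe column, except at inference_func_args_theta_index, where theta is
# threaded in. The function and column names here are hypothetical.
def _example_inference_estimating_func(theta, reward, action_prob):
    # With parameter names ("theta", "reward", "action_prob") and theta index 0,
    # the per-user args built above would be
    # (theta_est, <"reward" column values>, <"action_prob" column values>).
    return jnp.array([jnp.mean(reward - action_prob * theta[0])])
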

def single_user_weighted_estimating_function_stacker(
    beta_dim: int,
    user_id: collections.abc.Hashable,
    action_prob_func: Callable,
    algorithm_estimating_func: Callable,
    inference_estimating_func: Callable,
    action_prob_func_args_beta_index: int,
    inference_func_args_theta_index: int,
    action_prob_func_args_by_decision_time: dict[int, tuple[Any, ...]],
    threaded_action_prob_func_args_by_decision_time: dict[int, tuple[Any, ...]],
    threaded_update_func_args_by_policy_num: dict[int | float, tuple[Any, ...]],
    threaded_inference_func_args: tuple[Any, ...],
    policy_num_by_decision_time: dict[int, int | float],
    action_by_decision_time: dict[int, int],
    beta_index_by_policy_num: dict[int | float, int],
) -> tuple[
    jnp.ndarray[jnp.float32],
    jnp.ndarray[jnp.float32],
    jnp.ndarray[jnp.float32],
    jnp.ndarray[jnp.float32],
]:
"""
|
|
745
|
+
Computes a weighted estimating function stack for a given algorithm estimating function
|
|
746
|
+
and arguments, inference estimating functio and arguments, and action probability function and
|
|
747
|
+
arguments.
|
|
748
|
+
|
|
749
|
+
Args:
|
|
750
|
+
beta_dim (list[jnp.ndarray]):
|
|
751
|
+
A list of 1D JAX NumPy arrays corresponding to the betas produced by all updates.
|
|
752
|
+
|
|
753
|
+
user_id (collections.abc.Hashable):
|
|
754
|
+
The user ID for which to compute the weighted estimating function stack.
|
|
755
|
+
|
|
756
|
+
action_prob_func (callable):
|
|
757
|
+
The function used to compute the probability of action 1 at a given decision time for
|
|
758
|
+
a particular user given their state and the algorithm parameters.
|
|
759
|
+
|
|
760
|
+
algorithm_estimating_func (callable):
|
|
761
|
+
The estimating function that corresponds to algorithm updates.
|
|
762
|
+
|
|
763
|
+
inference_estimating_func (callable):
|
|
764
|
+
The estimating function that corresponds to inference.
|
|
765
|
+
|
|
766
|
+
action_prob_func_args_beta_index (int):
|
|
767
|
+
The index of the beta argument in the action probability function's arguments.
|
|
768
|
+
|
|
769
|
+
inference_func_args_theta_index (int):
|
|
770
|
+
The index of the theta parameter in the inference loss or estimating function arguments.
|
|
771
|
+
|
|
772
|
+
action_prob_func_args_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
|
|
773
|
+
A map from decision times to tuples of arguments for this user for the action
|
|
774
|
+
probability function. This is for all decision times (args are an empty
|
|
775
|
+
tuple if they are not in the study). Should be sorted by decision time. NOTE THAT THESE
|
|
776
|
+
ARGS DO NOT CONTAIN THE SHARED BETAS, making them impervious to the differentiation that
|
|
777
|
+
will occur.
|
|
778
|
+
|
|
779
|
+
threaded_action_prob_func_args_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
|
|
780
|
+
A map from decision times to tuples of arguments for the action
|
|
781
|
+
probability function, with the shared betas threaded in for differentation. Decision
|
|
782
|
+
times should be sorted.
|
|
783
|
+
|
|
784
|
+
threaded_update_func_args_by_policy_num (dict[int | float, dict[collections.abc.Hashable, tuple[Any, ...]]]):
|
|
785
|
+
A map from policy numbers to tuples containing the arguments for
|
|
786
|
+
the corresponding estimating functions for this user, with the shared betas threaded in
|
|
787
|
+
for differentiation. This is for all non-initial, non-fallback policies. Policy numbers
|
|
788
|
+
should be sorted.
|
|
789
|
+
|
|
790
|
+
threaded_inference_func_args (dict[collections.abc.Hashable, tuple[Any, ...]]):
|
|
791
|
+
A tuple containing the arguments for the inference
|
|
792
|
+
estimating function for this user, with the shared betas threaded in for differentiation.
|
|
793
|
+
|
|
794
|
+
policy_num_by_decision_time (dict[collections.abc.Hashable, dict[int, int | float]]):
|
|
795
|
+
A dictionary mapping decision times to the policy number in use. This may be
|
|
796
|
+
user-specific. Should be sorted by decision time. Only applies to in-study decision
|
|
797
|
+
times!
|
|
798
|
+
|
|
799
|
+
action_by_decision_time (dict[collections.abc.Hashable, dict[int, int]]):
|
|
800
|
+
A dictionary mapping decision times to actions taken. Only applies to in-study decision
|
|
801
|
+
times!
|
|
802
|
+
|
|
803
|
+
beta_index_by_policy_num (dict[int | float, int]):
|
|
804
|
+
A dictionary mapping policy numbers to the index of the corresponding beta in
|
|
805
|
+
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
806
|
+
|
|
807
|
+
Returns:
|
|
808
|
+
jnp.ndarray: A 1-D JAX NumPy array representing the user's weighted estimating function
|
|
809
|
+
stack.
|
|
810
|
+
jnp.ndarray: A 2-D JAX NumPy matrix representing the user's adaptive meat contribution.
|
|
811
|
+
jnp.ndarray: A 2-D JAX NumPy matrix representing the user's classical meat contribution.
|
|
812
|
+
jnp.ndarray: A 2-D JAX NumPy matrix representing the user's classical bread contribution.
|
|
813
|
+
"""
|
|
814
|
+

    logger.info("Computing weighted estimating function stack for user %s.", user_id)

    # First, reformat the supplied data into more convenient structures.

    # 1. Form a dictionary mapping policy numbers to the first time they were
    # applicable (for this user). Note that this includes ALL policies, initial
    # and fallback policies included.
    # Collect the first time after the first update separately for convenience.
    # These are both used to form the Radon-Nikodym weights for the right times.
    min_time_by_policy_num, first_time_after_first_update = get_min_time_by_policy_num(
        policy_num_by_decision_time,
        beta_index_by_policy_num,
    )
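    # For intuition (hypothetical values, assuming the helper behaves as described
    # above): if policy_num_by_decision_time == {1: 0, 2: 0, 3: 1, 4: 2}, with
    # policy 0 initial and policies 1 and 2 produced by updates, one would get
    # min_time_by_policy_num == {1: 3, 2: 4} and first_time_after_first_update == 3.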

    # 2. Get the start and end times for this user.
    user_start_time = math.inf
    user_end_time = -math.inf
    for decision_time in action_by_decision_time:
        user_start_time = min(user_start_time, decision_time)
        user_end_time = max(user_end_time, decision_time)

    # 3. Form a stack of weighted estimating equations, one for each update of the algorithm.
    logger.info(
        "Computing the algorithm component of the weighted estimating function stack for user %s.",
        user_id,
    )

    in_study_action_prob_func_args = [
        args for args in action_prob_func_args_by_decision_time.values() if args
    ]
    in_study_betas_list_by_decision_time_index = jnp.array(
        [
            action_prob_func_args[action_prob_func_args_beta_index]
            for action_prob_func_args in in_study_action_prob_func_args
        ]
    )
    in_study_actions_list_by_decision_time_index = jnp.array(
        list(action_by_decision_time.values())
    )

    # Sort the threaded args by decision time to be cautious. We check that each
    # decision time is present in the threaded args dict because we may call this
    # on a subset of the args when we are batching arguments by shape.
    sorted_threaded_action_prob_args_by_decision_time = {
        decision_time: threaded_action_prob_func_args_by_decision_time[decision_time]
        for decision_time in range(user_start_time, user_end_time + 1)
        if decision_time in threaded_action_prob_func_args_by_decision_time
    }

    num_args = None
    for args in sorted_threaded_action_prob_args_by_decision_time.values():
        if args:
            num_args = len(args)
            break

    # NOTE: Cannot do [[]] * num_args here! Then all lists point to the
    # same object...
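    # Demonstration of the pitfall in plain Python:
    #   rows = [[]] * 3; rows[0].append(1)                # rows == [[1], [1], [1]]
    #   rows = [[] for _ in range(3)]; rows[0].append(1)  # rows == [[1], [], []]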
    batched_threaded_arg_lists = [[] for _ in range(num_args)]
    for (
        decision_time,
        args,
    ) in sorted_threaded_action_prob_args_by_decision_time.items():
        if not args:
            continue
        for idx, arg in enumerate(args):
            batched_threaded_arg_lists[idx].append(arg)

    batched_threaded_arg_tensors, batch_axes = stack_batched_arg_lists_into_tensors(
        batched_threaded_arg_lists
    )

    # Note that we do NOT use the shared betas in the first arg to the weight function,
    # since we don't want differentiation to happen with respect to them.
    # Just grab the original beta from the update function arguments. This is the same
    # value, but impervious to differentiation with respect to all_post_update_betas. The
    # args, on the other hand, are a function of all_post_update_betas.
    in_study_weights = jax.vmap(
        fun=get_radon_nikodym_weight,
        in_axes=[0, None, None, 0] + batch_axes,
        out_axes=0,
    )(
        in_study_betas_list_by_decision_time_index,
        action_prob_func,
        action_prob_func_args_beta_index,
        in_study_actions_list_by_decision_time_index,
        *batched_threaded_arg_tensors,
    )
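
    # Shape sketch for the vmap above: with N in-study decision times,
    #   in_study_betas_list_by_decision_time_index has shape (N, beta_dim)  -> in_axes 0
    #   action_prob_func and its beta index are shared across times         -> in_axes None
    #   in_study_actions_list_by_decision_time_index has shape (N,)         -> in_axes 0
    #   each batched arg tensor is mapped or broadcast according to batch_axes
    # so in_study_weights is a length-N vector, one weight per in-study time.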

    in_study_index = 0
    decision_time_to_all_weights_index_offset = min(
        sorted_threaded_action_prob_args_by_decision_time
    )
    all_weights_raw = []
    for (
        decision_time,
        args,
    ) in sorted_threaded_action_prob_args_by_decision_time.items():
        # Out-of-study decision times get a weight of 1. Only advance the
        # in-study index when a weight is actually consumed, so weights stay
        # aligned with their decision times even if there are in-study gaps.
        if args:
            all_weights_raw.append(in_study_weights[in_study_index])
            in_study_index += 1
        else:
            all_weights_raw.append(1.0)
    all_weights = jnp.array(all_weights_raw)

    algorithm_component = jnp.concatenate(
        [
            # Here we compute a product of Radon-Nikodym weights
            # for all decision times after the first update and before the
            # update under consideration took effect, for which the user was in the study.
            (
                jnp.prod(
                    all_weights[
                        # The earliest time after the first update where the user was in
                        # the study
                        max(
                            first_time_after_first_update,
                            user_start_time,
                        )
                        - decision_time_to_all_weights_index_offset :
                        # One more than the latest time the user was in the study before the time
                        # the update under consideration first applied. Note the + 1 because the
                        # slice does not include the right endpoint.
                        min(
                            min_time_by_policy_num.get(policy_num, math.inf),
                            user_end_time + 1,
                        )
                        - decision_time_to_all_weights_index_offset,
                    ]
                    # If the user exited the study before there were any updates,
                    # this variable will be None and the above code to grab a weight would
                    # throw an error. Just use 1 to include the unweighted estimating function
                    # if they have data to contribute to the update.
                    if first_time_after_first_update is not None
                    else 1
                )  # Now use the above to weight the alg estimating function for this update
                * algorithm_estimating_func(*update_args)
                # If there are no arguments for the update function, the user is not yet in the
                # study, so we just add a zero vector contribution to the sum across users.
                # Note that after they exit, they still contribute all their data to later
                # updates.
                if update_args
                else jnp.zeros(beta_dim)
            )
            # vmapping over this would be tricky due to different shapes across updates
            for policy_num, update_args in threaded_update_func_args_by_policy_num.items()
        ]
    )

    if algorithm_component.size % beta_dim != 0:
        raise ValueError(
            "The algorithm component of the weighted estimating function stack does not have a "
            "size that is a multiple of the beta dimension. This likely means that the "
            "algorithm estimating function is not returning a vector of the correct size."
        )
    # 4. Form the weighted inference estimating equation.
    logger.info(
        "Computing the inference component of the weighted estimating function stack for user %s.",
        user_id,
    )
    inference_component = jnp.prod(
        all_weights[
            max(first_time_after_first_update, user_start_time)
            - decision_time_to_all_weights_index_offset : user_end_time
            + 1
            - decision_time_to_all_weights_index_offset,
        ]
        # If the user exited the study before there were any updates,
        # this variable will be None and the above code to grab a weight would
        # throw an error. Just use 1 to include the unweighted estimating function
        # if they have data to contribute here (pretty sure everyone should?)
        if first_time_after_first_update is not None
        else 1
    ) * inference_estimating_func(*threaded_inference_func_args)

    # 5. Concatenate the two components to form the weighted estimating function stack for this
    # user.
    weighted_stack = jnp.concatenate([algorithm_component, inference_component])

    # 6. Return the following outputs:
    # a. The first is simply the weighted estimating function stack for this user. The average
    # of these is what we differentiate with respect to theta to form the inverse adaptive joint
    # bread matrix, and we also compare that average to zero to check the estimating functions'
    # fidelity.
    # b. The average outer product of these per-user stacks across users is the adaptive joint meat
    # matrix, hence the second output.
    # c. The third output is averaged across users to obtain the classical meat matrix.
    # d. The fourth output is averaged across users to obtain the inverse classical bread
    # matrix.
    return (
        weighted_stack,
        jnp.outer(weighted_stack, weighted_stack),
        jnp.outer(inference_component, inference_component),
        jax.jacrev(inference_estimating_func, argnums=inference_func_args_theta_index)(
            *threaded_inference_func_args
        ),
    )

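
# Minimal sketch of the Radon-Nikodym weight that gets vmapped above, assuming
# (for illustration only) that the action-1 probability is logistic in a
# state/beta inner product; the real computation lives in
# get_radon_nikodym_weight from helper_functions.
def _example_radon_nikodym_weight(candidate_beta, deployed_beta, action, state):
    def prob_of_observed_action(beta):
        p_one = jax.nn.sigmoid(jnp.dot(state, beta))
        return jnp.where(action == 1, p_one, 1.0 - p_one)

    # Ratio of the observed action's probability under the candidate beta to its
    # probability under the beta actually deployed at that decision time.
    return prob_of_observed_action(candidate_beta) / prob_of_observed_action(
        deployed_beta
    )
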

def get_avg_weighted_estimating_function_stacks_and_aux_values(
    flattened_betas_and_theta: jnp.ndarray,
    beta_dim: int,
    theta_dim: int,
    user_ids: jnp.ndarray,
    action_prob_func: Callable,
    action_prob_func_args_beta_index: int,
    alg_update_func: Callable,
    alg_update_func_type: str,
    alg_update_func_args_beta_index: int,
    alg_update_func_args_action_prob_index: int,
    alg_update_func_args_action_prob_times_index: int,
    inference_func: Callable,
    inference_func_type: str,
    inference_func_args_theta_index: int,
    inference_func_args_action_prob_index: int,
    action_prob_func_args_by_user_id_by_decision_time: dict[
        int, dict[collections.abc.Hashable, tuple[Any, ...]]
    ],
    policy_num_by_decision_time_by_user_id: dict[
        collections.abc.Hashable, dict[int, int | float]
    ],
    initial_policy_num: int | float,
    beta_index_by_policy_num: dict[int | float, int],
    inference_func_args_by_user_id: dict[collections.abc.Hashable, tuple[Any, ...]],
    inference_action_prob_decision_times_by_user_id: dict[
        collections.abc.Hashable, list[int]
    ],
    update_func_args_by_by_user_id_by_policy_num: dict[
        int | float, dict[collections.abc.Hashable, tuple[Any, ...]]
    ],
    action_by_decision_time_by_user_id: dict[collections.abc.Hashable, dict[int, int]],
    suppress_all_data_checks: bool,
    suppress_interactive_data_checks: bool,
) -> tuple[
    jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
]:
"""
|
|
1048
|
+
Computes the average weighted estimating function stack across all users, along with
|
|
1049
|
+
auxiliary values used to construct the adaptive and classical sandwich variances.
|
|
1050
|
+
|
|
1051
|
+
Args:
|
|
1052
|
+
flattened_betas_and_theta (jnp.ndarray):
|
|
1053
|
+
A list of JAX NumPy arrays representing the betas produced by all updates and the
|
|
1054
|
+
theta value, in that order. Important that this is a 1D array for efficiency reasons.
|
|
1055
|
+
We simply extract the betas and theta from this array below.
|
|
1056
|
+
beta_dim (int):
|
|
1057
|
+
The dimension of each of the beta parameters.
|
|
1058
|
+
theta_dim (int):
|
|
1059
|
+
The dimension of the theta parameter.
|
|
1060
|
+
user_ids (jnp.ndarray):
|
|
1061
|
+
A 1D JAX NumPy array of user IDs.
|
|
1062
|
+
action_prob_func (callable):
|
|
1063
|
+
The action probability function.
|
|
1064
|
+
action_prob_func_args_beta_index (int):
|
|
1065
|
+
The index of beta in the action probability function arguments tuples.
|
|
1066
|
+
alg_update_func (callable):
|
|
1067
|
+
The algorithm update estimating or loss function.
|
|
1068
|
+
alg_update_func_type (str):
|
|
1069
|
+
The type of the algorithm update function (loss or estimating).
|
|
1070
|
+
alg_update_func_args_beta_index (int):
|
|
1071
|
+
The index of beta in the update function arguments tuples.
|
|
1072
|
+
alg_update_func_args_action_prob_index (int):
|
|
1073
|
+
The index of action probabilities in the update function arguments tuple, if
|
|
1074
|
+
applicable. -1 otherwise.
|
|
1075
|
+
alg_update_func_args_action_prob_times_index (int):
|
|
1076
|
+
The index in the update function arguments tuple where an array of times for which the
|
|
1077
|
+
given action probabilities apply is provided, if applicable. -1 otherwise.
|
|
1078
|
+
inference_func (callable):
|
|
1079
|
+
The inference loss or estimating function.
|
|
1080
|
+
inference_func_type (str):
|
|
1081
|
+
The type of the inference function (loss or estimating).
|
|
1082
|
+
inference_func_args_theta_index (int):
|
|
1083
|
+
The index of the theta parameter in the inference function arguments tuples.
|
|
1084
|
+
inference_func_args_action_prob_index (int):
|
|
1085
|
+
The index of action probabilities in the inference function arguments tuple, if
|
|
1086
|
+
applicable. -1 otherwise.
|
|
1087
|
+
action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
|
|
1088
|
+
A dictionary mapping decision times to maps of user ids to the function arguments
|
|
1089
|
+
required to compute action probabilities for this user.
|
|
1090
|
+
policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
|
|
1091
|
+
A map of user ids to dictionaries mapping decision times to the policy number in use.
|
|
1092
|
+
Only applies to in-study decision times!
|
|
1093
|
+
initial_policy_num (int | float):
|
|
1094
|
+
The policy number of the initial policy before any updates.
|
|
1095
|
+
beta_index_by_policy_num (dict[int | float, int]):
|
|
1096
|
+
A dictionary mapping policy numbers to the index of the corresponding beta in
|
|
1097
|
+
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
1098
|
+
inference_func_args_by_user_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
|
|
1099
|
+
A dictionary mapping user IDs to their respective inference function arguments.
|
|
1100
|
+
inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
|
|
1101
|
+
For each user, a list of decision times to which action probabilities correspond if
|
|
1102
|
+
provided. Typically just in-study times if action probabilites are used in the inference
|
|
1103
|
+
loss or estimating function.
|
|
1104
|
+
update_func_args_by_by_user_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
|
|
1105
|
+
A dictionary where keys are policy numbers and values are dictionaries mapping user IDs
|
|
1106
|
+
to their respective update function arguments.
|
|
1107
|
+
action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
|
|
1108
|
+
A dictionary mapping user IDs to their respective actions taken at each decision time.
|
|
1109
|
+
Only applies to in-study decision times!
|
|
1110
|
+
suppress_all_data_checks (bool):
|
|
1111
|
+
If True, suppresses carrying out any data checks at all.
|
|
1112
|
+
suppress_interactive_data_checks (bool):
|
|
1113
|
+
If True, suppresses interactive data checks that would otherwise be performed to ensure
|
|
1114
|
+
the correctness of the threaded arguments. The checks are still performed, but
|
|
1115
|
+
any interactive prompts are suppressed.
|
|
1116
|
+
|
|
1117
|
+
Returns:
|
|
1118
|
+
jnp.ndarray:
|
|
1119
|
+
A 2D JAX NumPy array holding the average weighted estimating function stack.
|
|
1120
|
+
tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
|
|
1121
|
+
A tuple containing
|
|
1122
|
+
1. the average weighted estimating function stack
|
|
1123
|
+
2. the user-level adaptive meat matrix contributions
|
|
1124
|
+
3. the user-level classical meat matrix contributions
|
|
1125
|
+
4. the user-level inverse classical bread matrix contributions
|
|
1126
|
+
5. raw per-user weighted estimating function
|
|
1127
|
+
stacks.
|
|
1128
|
+
"""

    # 1. Collect estimating functions by differentiating the loss functions if needed.
    algorithm_estimating_func = (
        jax.grad(alg_update_func, argnums=alg_update_func_args_beta_index)
        if (alg_update_func_type == FunctionTypes.LOSS)
        else alg_update_func
    )

    inference_estimating_func = (
        jax.grad(inference_func, argnums=inference_func_args_theta_index)
        if (inference_func_type == FunctionTypes.LOSS)
        else inference_func
    )
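
    # For example, if a user-level loss were L(beta, X, y) = 0.5 * ||y - X @ beta||^2
    # (illustrative, not package API), jax.grad(L, argnums=0) would return the
    # estimating function beta -> -X.T @ (y - X @ beta), whose root is the
    # least-squares solution.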

    betas, theta = unflatten_params(
        flattened_betas_and_theta,
        beta_dim,
        theta_dim,
    )
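
    # Assumed unflattening semantics (see flatten_params/unflatten_params in
    # helper_functions): e.g. with beta_dim == 2 and theta_dim == 3, the vector
    # [b11, b12, b21, b22, t1, t2, t3] would split into betas
    # [[b11, b12], [b21, b22]] and theta [t1, t2, t3].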
|
|
1149
|
+
# 2. Thread in the betas and theta in all_post_update_betas_and_theta into the arguments
|
|
1150
|
+
# supplied for the above functions, so that differentiation works correctly. The existing
|
|
1151
|
+
# values should be the same, but not connected to the parameter we are differentiating
|
|
1152
|
+
# with respect to. Note we will also find it useful below to have the action probability args
|
|
1153
|
+
# nested dict structure flipped to be user_id -> decision_time -> args, so we do that here too.
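The nesting flip described in the comment above is a generic dictionary transformation. A hypothetical helper sketching the idea (the package's actual flipping happens inside thread_action_prob_func_args):

def flip_nesting(value_by_inner_by_outer):
    # {outer_key: {inner_key: value}} -> {inner_key: {outer_key: value}}
    flipped = {}
    for outer_key, inner_map in value_by_inner_by_outer.items():
        for inner_key, value in inner_map.items():
            flipped.setdefault(inner_key, {})[outer_key] = value
    return flipped

args_by_user_by_time = {1: {"u1": ("a",)}, 2: {"u1": ("b",), "u2": ("c",)}}
print(flip_nesting(args_by_user_by_time))
# {'u1': {1: ('a',), 2: ('b',)}, 'u2': {2: ('c',)}}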
|
|
1154
|
+
|
|
1155
|
+
logger.info("Threading in betas to action probability arguments for all users.")
|
|
1156
|
+
(
|
|
1157
|
+
threaded_action_prob_func_args_by_decision_time_by_user_id,
|
|
1158
|
+
action_prob_func_args_by_decision_time_by_user_id,
|
|
1159
|
+
) = thread_action_prob_func_args(
|
|
1160
|
+
action_prob_func_args_by_user_id_by_decision_time,
|
|
1161
|
+
policy_num_by_decision_time_by_user_id,
|
|
1162
|
+
initial_policy_num,
|
|
1163
|
+
betas,
|
|
1164
|
+
beta_index_by_policy_num,
|
|
1165
|
+
action_prob_func_args_beta_index,
|
|
1166
|
+
)
|
|
1167
|
+
|
|
1168
|
+
# 3. Thread the central betas into the algorithm update function arguments
|
|
1169
|
+
# and replace any action probabilities with reconstructed ones from the above
|
|
1170
|
+
# arguments with the central betas introduced.
|
|
1171
|
+
logger.info(
|
|
1172
|
+
"Threading in betas and beta-dependent action probabilities to algorithm update "
|
|
1173
|
+
"function args for all users"
|
|
1174
|
+
)
|
|
1175
|
+
threaded_update_func_args_by_policy_num_by_user_id = thread_update_func_args(
|
|
1176
|
+
update_func_args_by_by_user_id_by_policy_num,
|
|
1177
|
+
betas,
|
|
1178
|
+
beta_index_by_policy_num,
|
|
1179
|
+
alg_update_func_args_beta_index,
|
|
1180
|
+
alg_update_func_args_action_prob_index,
|
|
1181
|
+
alg_update_func_args_action_prob_times_index,
|
|
1182
|
+
threaded_action_prob_func_args_by_decision_time_by_user_id,
|
|
1183
|
+
action_prob_func,
|
|
1184
|
+
)
|
|
1185
|
+
|
|
1186
|
+
# If action probabilities are used in the algorithm estimating function, make
|
|
1187
|
+
# sure that substituting in the reconstructed action probabilities is approximately
|
|
1188
|
+
# equivalent to using the original action probabilities.
|
|
1189
|
+
if not suppress_all_data_checks and alg_update_func_args_action_prob_index >= 0:
|
|
1190
|
+
input_checks.require_threaded_algorithm_estimating_function_args_equivalent(
|
|
1191
|
+
algorithm_estimating_func,
|
|
1192
|
+
update_func_args_by_by_user_id_by_policy_num,
|
|
1193
|
+
threaded_update_func_args_by_policy_num_by_user_id,
|
|
1194
|
+
suppress_interactive_data_checks,
|
|
1195
|
+
)
|
|
1196
|
+
|
|
1197
|
+
# 4. Thread the central theta into the inference function arguments
|
|
1198
|
+
# and replace any action probabilities with reconstructed ones from the above
|
|
1199
|
+
# arguments with the central betas introduced.
|
|
1200
|
+
logger.info(
|
|
1201
|
+
"Threading in theta and beta-dependent action probabilities to inference update "
|
|
1202
|
+
"function args for all users"
|
|
1203
|
+
)
|
|
1204
|
+
threaded_inference_func_args_by_user_id = thread_inference_func_args(
|
|
1205
|
+
inference_func_args_by_user_id,
|
|
1206
|
+
inference_func_args_theta_index,
|
|
1207
|
+
theta,
|
|
1208
|
+
inference_func_args_action_prob_index,
|
|
1209
|
+
threaded_action_prob_func_args_by_decision_time_by_user_id,
|
|
1210
|
+
inference_action_prob_decision_times_by_user_id,
|
|
1211
|
+
action_prob_func,
|
|
1212
|
+
)
|
|
1213
|
+
|
|
1214
|
+
# If action probabilities are used in the inference estimating function, make
|
|
1215
|
+
# sure that substituting in the reconstructed action probabilities is approximately
|
|
1216
|
+
# equivalent to using the original action probabilities.
|
|
1217
|
+
if not suppress_all_data_checks and inference_func_args_action_prob_index >= 0:
|
|
1218
|
+
input_checks.require_threaded_inference_estimating_function_args_equivalent(
|
|
1219
|
+
inference_estimating_func,
|
|
1220
|
+
inference_func_args_by_user_id,
|
|
1221
|
+
threaded_inference_func_args_by_user_id,
|
|
1222
|
+
suppress_interactive_data_checks,
|
|
1223
|
+
)
|
|
1224
|
+
|
|
1225
|
+
# 5. Now we can compute the weighted estimating function stacks for all users
|
|
1226
|
+
# as well as collect related values used to construct the adaptive and classical
|
|
1227
|
+
# sandwich variances.
|
|
1228
|
+
results = [
|
|
1229
|
+
single_user_weighted_estimating_function_stacker(
|
|
1230
|
+
beta_dim,
|
|
1231
|
+
user_id,
|
|
1232
|
+
action_prob_func,
|
|
1233
|
+
algorithm_estimating_func,
|
|
1234
|
+
inference_estimating_func,
|
|
1235
|
+
action_prob_func_args_beta_index,
|
|
1236
|
+
inference_func_args_theta_index,
|
|
1237
|
+
action_prob_func_args_by_decision_time_by_user_id[user_id],
|
|
1238
|
+
threaded_action_prob_func_args_by_decision_time_by_user_id[user_id],
|
|
1239
|
+
threaded_update_func_args_by_policy_num_by_user_id[user_id],
|
|
1240
|
+
threaded_inference_func_args_by_user_id[user_id],
|
|
1241
|
+
policy_num_by_decision_time_by_user_id[user_id],
|
|
1242
|
+
action_by_decision_time_by_user_id[user_id],
|
|
1243
|
+
beta_index_by_policy_num,
|
|
1244
|
+
)
|
|
1245
|
+
for user_id in user_ids.tolist()
|
|
1246
|
+
]
|
|
1247
|
+
|
|
1248
|
+
stacks = jnp.array([result[0] for result in results])
|
|
1249
|
+
outer_products = jnp.array([result[1] for result in results])
|
|
1250
|
+
inference_only_outer_products = jnp.array([result[2] for result in results])
|
|
1251
|
+
inference_hessians = jnp.array([result[3] for result in results])
|
|
1252
|
+
|
|
1253
|
+
# 6. Note this strange return structure! We will differentiate the first output,
|
|
1254
|
+
# but the second tuple will be passed along without modification via has_aux=True and then used
|
|
1255
|
+
# for the adaptive meat matrix, estimating functions sum check, and classical meat and inverse
|
|
1256
|
+
# bread matrices. The raw per-user stacks are also returned for debugging purposes.
|
|
1257
|
+
|
|
1258
|
+
# Note that returning the raw per-user stacks here is potentially
|
|
1259
|
+
# memory-intensive when combined with differentiation. Keep this in mind if the per-user bread
|
|
1260
|
+
# inverse contributions are needed for something like CR2/CR3 small-sample corrections.
|
|
1261
|
+
avg_stack = jnp.mean(stacks, axis=0)
return avg_stack, (
|
|
1262
|
+
avg_stack,
|
|
1263
|
+
outer_products,
|
|
1264
|
+
inference_only_outer_products,
|
|
1265
|
+
inference_hessians,
|
|
1266
|
+
stacks,
|
|
1267
|
+
)
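The has_aux mechanics noted in step 6 can be seen on a toy function (names invented here): jax.jacrev differentiates only the first output, while the auxiliary payload passes through untouched.

import jax
import jax.numpy as jnp

def mean_with_aux(params, data):
    residuals = data - params
    # Only the first output is differentiated; the tuple after it is aux.
    return jnp.mean(residuals), (residuals,)

jac, (residuals,) = jax.jacrev(mean_with_aux, has_aux=True)(
    jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])
)
print(jac)        # d(mean)/d(params): [-0.5, -0.5]
print(residuals)  # carried through unchanged: [2.0, 2.0]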
|
|
1268
|
+
|
|
1269
|
+
|
|
1270
|
+
def construct_classical_and_adaptive_sandwiches(
|
|
1271
|
+
theta_est: jnp.ndarray,
|
|
1272
|
+
all_post_update_betas: jnp.ndarray,
|
|
1273
|
+
user_ids: jnp.ndarray,
|
|
1274
|
+
action_prob_func: callable,
|
|
1275
|
+
action_prob_func_args_beta_index: int,
|
|
1276
|
+
alg_update_func: callable,
|
|
1277
|
+
alg_update_func_type: str,
|
|
1278
|
+
alg_update_func_args_beta_index: int,
|
|
1279
|
+
alg_update_func_args_action_prob_index: int,
|
|
1280
|
+
alg_update_func_args_action_prob_times_index: int,
|
|
1281
|
+
inference_func: callable,
|
|
1282
|
+
inference_func_type: str,
|
|
1283
|
+
inference_func_args_theta_index: int,
|
|
1284
|
+
inference_func_args_action_prob_index: int,
|
|
1285
|
+
action_prob_func_args_by_user_id_by_decision_time: dict[
|
|
1286
|
+
int, dict[collections.abc.Hashable, tuple[Any, ...]]
|
|
1287
|
+
],
|
|
1288
|
+
policy_num_by_decision_time_by_user_id: dict[
|
|
1289
|
+
collections.abc.Hashable, dict[int, int | float]
|
|
1290
|
+
],
|
|
1291
|
+
initial_policy_num: int | float,
|
|
1292
|
+
beta_index_by_policy_num: dict[int | float, int],
|
|
1293
|
+
inference_func_args_by_user_id: dict[collections.abc.Hashable, tuple[Any, ...]],
|
|
1294
|
+
inference_action_prob_decision_times_by_user_id: dict[
|
|
1295
|
+
collections.abc.Hashable, list[int]
|
|
1296
|
+
],
|
|
1297
|
+
update_func_args_by_by_user_id_by_policy_num: dict[
|
|
1298
|
+
int | float, dict[collections.abc.Hashable, tuple[Any, ...]]
|
|
1299
|
+
],
|
|
1300
|
+
action_by_decision_time_by_user_id: dict[collections.abc.Hashable, dict[int, int]],
|
|
1301
|
+
suppress_all_data_checks: bool,
|
|
1302
|
+
suppress_interactive_data_checks: bool,
|
|
1303
|
+
small_sample_correction: str,
|
|
1304
|
+
form_adaptive_meat_adjustments_explicitly: bool,
|
|
1305
|
+
stabilize_joint_adaptive_bread_inverse: bool,
|
|
1306
|
+
study_df: pd.DataFrame | None,
|
|
1307
|
+
in_study_col_name: str | None,
|
|
1308
|
+
action_col_name: str | None,
|
|
1309
|
+
calendar_t_col_name: str | None,
|
|
1310
|
+
user_id_col_name: str | None,
|
|
1311
|
+
action_prob_func_args: tuple | None,
|
|
1312
|
+
action_prob_col_name: str | None,
|
|
1313
|
+
) -> tuple[
|
|
1314
|
+
jnp.ndarray[jnp.float32],
|
|
1315
|
+
jnp.ndarray[jnp.float32],
|
|
1316
|
+
jnp.ndarray[jnp.float32],
|
|
1317
|
+
jnp.ndarray[jnp.float32],
|
|
1318
|
+
jnp.ndarray[jnp.float32],
|
|
1319
|
+
jnp.ndarray[jnp.float32],
|
|
1320
|
+
jnp.ndarray[jnp.float32],
|
|
1321
|
+
jnp.ndarray[jnp.float32],
|
|
1322
|
+
jnp.ndarray[jnp.float32],
|
|
1323
|
+
jnp.ndarray[jnp.float32],
|
|
1324
|
+
jnp.ndarray[jnp.float32],
jnp.ndarray[jnp.float32],
|
|
1325
|
+
]:
|
|
1326
|
+
"""
|
|
1327
|
+
Constructs the classical and adaptive sandwich matrices, as well as various
|
|
1328
|
+
intermediate pieces in their construction.
|
|
1329
|
+
|
|
1330
|
+
This is done by computing and differentiating the average weighted estimating function stack
|
|
1331
|
+
with respect to the betas and theta, using the resulting Jacobian to compute the inverse bread
|
|
1332
|
+
and meat matrices, and then stably computing sandwiches.
|
|
1333
|
+
|
|
1334
|
+
Args:
|
|
1335
|
+
theta_est (jnp.ndarray):
|
|
1336
|
+
A 1-D JAX NumPy array representing the parameter estimate for inference.
|
|
1337
|
+
all_post_update_betas (jnp.ndarray):
|
|
1338
|
+
A 2-D JAX NumPy array representing all parameter estimates for the algorithm updates.
|
|
1339
|
+
user_ids (jnp.ndarray):
|
|
1340
|
+
A 1-D JAX NumPy array holding all user IDs in the study.
|
|
1341
|
+
action_prob_func (callable):
|
|
1342
|
+
The action probability function.
|
|
1343
|
+
action_prob_func_args_beta_index (int):
|
|
1344
|
+
The index of beta in the action probability function arguments tuples.
|
|
1345
|
+
alg_update_func (callable):
|
|
1346
|
+
The algorithm update loss/estimating function.
|
|
1347
|
+
alg_update_func_type (str):
|
|
1348
|
+
The type of the algorithm update function (loss or estimating).
|
|
1349
|
+
alg_update_func_args_beta_index (int):
|
|
1350
|
+
The index of beta in the update function arguments tuples.
|
|
1351
|
+
alg_update_func_args_action_prob_index (int):
|
|
1352
|
+
The index of action probabilities in the update function arguments tuple, if
|
|
1353
|
+
applicable. -1 otherwise.
|
|
1354
|
+
alg_update_func_args_action_prob_times_index (int):
|
|
1355
|
+
The index in the update function arguments tuple where an array of times for which the
|
|
1356
|
+
given action probabilities apply is provided, if applicable. -1 otherwise.
|
|
1357
|
+
inference_func (callable):
|
|
1358
|
+
The inference loss or estimating function.
|
|
1359
|
+
inference_func_type (str):
|
|
1360
|
+
The type of the inference function (loss or estimating).
|
|
1361
|
+
inference_func_args_theta_index (int):
|
|
1362
|
+
The index of the theta parameter in the inference function arguments tuples.
|
|
1363
|
+
inference_func_args_action_prob_index (int):
|
|
1364
|
+
The index of action probabilities in the inference function arguments tuple, if
|
|
1365
|
+
applicable. -1 otherwise.
|
|
1366
|
+
action_prob_func_args_by_user_id_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
|
|
1367
|
+
A dictionary mapping decision times to maps of user ids to the function arguments
|
|
1368
|
+
required to compute action probabilities for the given user.
|
|
1369
|
+
policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
|
|
1370
|
+
A map of user ids to dictionaries mapping decision times to the policy number in use.
|
|
1371
|
+
Only applies to in-study decision times!
|
|
1372
|
+
initial_policy_num (int | float):
|
|
1373
|
+
The policy number of the initial policy before any updates.
|
|
1374
|
+
beta_index_by_policy_num (dict[int | float, int]):
|
|
1375
|
+
A dictionary mapping policy numbers to the index of the corresponding beta in
|
|
1376
|
+
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
1377
|
+
inference_func_args_by_user_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
|
|
1378
|
+
A dictionary mapping user IDs to their respective inference function arguments.
|
|
1379
|
+
inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
|
|
1380
|
+
For each user, a list of decision times to which action probabilities correspond if
|
|
1381
|
+
provided. Typically just in-study times if action probabilities are used in the inference
|
|
1382
|
+
loss or estimating function.
|
|
1383
|
+
update_func_args_by_by_user_id_by_policy_num (dict[int | float, dict[collections.abc.Hashable, tuple[Any, ...]]]):
|
|
1384
|
+
A dictionary where keys are policy numbers and values are dictionaries mapping user IDs
|
|
1385
|
+
to their respective update function arguments.
|
|
1386
|
+
action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
|
|
1387
|
+
A dictionary mapping user IDs to their respective actions taken at each decision time.
|
|
1388
|
+
Only applies to in-study decision times!
|
|
1389
|
+
suppress_all_data_checks (bool):
|
|
1390
|
+
If True, skips all data checks entirely.
|
|
1391
|
+
suppress_interactive_data_checks (bool):
|
|
1392
|
+
If True, suppresses interactive data checks that would otherwise be performed to ensure
|
|
1393
|
+
the correctness of the threaded arguments. The checks are still performed, but
|
|
1394
|
+
any interactive prompts are suppressed.
|
|
1395
|
+
small_sample_correction (str):
|
|
1396
|
+
The type of small sample correction to apply. See SmallSampleCorrections class for
|
|
1397
|
+
options.
|
|
1398
|
+
form_adaptive_meat_adjustments_explicitly (bool):
|
|
1399
|
+
If True, explicitly forms the per-user meat adjustments that differentiate the adaptive
|
|
1400
|
+
sandwich from the classical sandwich. This is for diagnostic purposes, as the
|
|
1401
|
+
adaptive sandwich is formed without doing this.
|
|
1402
|
+
stabilize_joint_adaptive_bread_inverse (bool):
|
|
1403
|
+
If True, will apply various techniques to stabilize the joint adaptive bread inverse if necessary.
|
|
1404
|
+
study_df (pd.DataFrame):
|
|
1405
|
+
The full study dataframe, needed if forming the adaptive meat adjustments explicitly.
|
|
1406
|
+
in_study_col_name (str):
|
|
1407
|
+
The name of the column in study_df indicating whether a user is in-study at a given decision time.
|
|
1408
|
+
action_col_name (str):
|
|
1409
|
+
The name of the column in study_df indicating the action taken at a given decision time.
|
|
1410
|
+
calendar_t_col_name (str):
|
|
1411
|
+
The name of the column in study_df indicating the calendar time of a given decision time.
|
|
1412
|
+
user_id_col_name (str):
|
|
1413
|
+
The name of the column in study_df indicating the user ID.
|
|
1414
|
+
action_prob_func_args (tuple):
|
|
1415
|
+
The arguments to be passed to the action probability function, needed if forming the
|
|
1416
|
+
adaptive meat adjustments explicitly.
|
|
1417
|
+
action_prob_col_name (str):
|
|
1418
|
+
The name of the column in study_df indicating the action probability of the action taken,
|
|
1419
|
+
needed if forming the adaptive meat adjustments explicitly.
|
|
1420
|
+
Returns:
|
|
1421
|
+
tuple[jnp.ndarray[jnp.float32], ...]:
|
|
1422
|
+
A tuple containing:
|
|
1423
|
+
- The raw joint adaptive inverse bread matrix.
|
|
1424
|
+
- The (possibly) stabilized joint adaptive inverse bread matrix.
|
|
1425
|
+
- The joint adaptive meat matrix.
|
|
1426
|
+
- The joint adaptive sandwich matrix.
|
|
1427
|
+
- The classical inverse bread matrix.
|
|
1428
|
+
- The classical meat matrix.
|
|
1429
|
+
- The classical sandwich matrix.
|
|
1430
|
+
- The average weighted estimating function stack.
|
|
1431
|
+
- All per-user weighted estimating function stacks.
|
|
1432
|
+
- The per-user adaptive meat small-sample corrections.
|
|
1433
|
+
- The per-user classical meat small-sample corrections.
|
|
1434
|
+
- The per-user adaptive meat adjustments, if form_adaptive_meat_adjustments_explicitly
|
|
1435
|
+
is True, otherwise an array of NaNs.
|
|
1436
|
+
"""
|
|
1437
|
+
logger.info(
|
|
1438
|
+
"Differentiating average weighted estimating function stack and collecting auxiliary values."
|
|
1439
|
+
)
|
|
1440
|
+
theta_dim = theta_est.shape[0]
|
|
1441
|
+
beta_dim = all_post_update_betas.shape[1]
|
|
1442
|
+
# Note that these "contributions" are per-user Jacobians of the weighted estimating function stack.
|
|
1443
|
+
raw_joint_adaptive_bread_inverse_matrix, (
|
|
1444
|
+
avg_estimating_function_stack,
|
|
1445
|
+
per_user_joint_adaptive_meat_contributions,
|
|
1446
|
+
per_user_classical_meat_contributions,
|
|
1447
|
+
per_user_classical_bread_inverse_contributions,
|
|
1448
|
+
per_user_estimating_function_stacks,
|
|
1449
|
+
) = jax.jacrev(
|
|
1450
|
+
get_avg_weighted_estimating_function_stacks_and_aux_values, has_aux=True
|
|
1451
|
+
)(
|
|
1452
|
+
# While JAX can technically differentiate with respect to a list of JAX arrays,
|
|
1453
|
+
# it is apparently more efficient to flatten them into a single array. This is done
|
|
1454
|
+
# here to improve performance. We can simply unflatten them inside the function.
|
|
1455
|
+
flatten_params(all_post_update_betas, theta_est),
|
|
1456
|
+
beta_dim,
|
|
1457
|
+
theta_dim,
|
|
1458
|
+
user_ids,
|
|
1459
|
+
action_prob_func,
|
|
1460
|
+
action_prob_func_args_beta_index,
|
|
1461
|
+
alg_update_func,
|
|
1462
|
+
alg_update_func_type,
|
|
1463
|
+
alg_update_func_args_beta_index,
|
|
1464
|
+
alg_update_func_args_action_prob_index,
|
|
1465
|
+
alg_update_func_args_action_prob_times_index,
|
|
1466
|
+
inference_func,
|
|
1467
|
+
inference_func_type,
|
|
1468
|
+
inference_func_args_theta_index,
|
|
1469
|
+
inference_func_args_action_prob_index,
|
|
1470
|
+
action_prob_func_args_by_user_id_by_decision_time,
|
|
1471
|
+
policy_num_by_decision_time_by_user_id,
|
|
1472
|
+
initial_policy_num,
|
|
1473
|
+
beta_index_by_policy_num,
|
|
1474
|
+
inference_func_args_by_user_id,
|
|
1475
|
+
inference_action_prob_decision_times_by_user_id,
|
|
1476
|
+
update_func_args_by_by_user_id_by_policy_num,
|
|
1477
|
+
action_by_decision_time_by_user_id,
|
|
1478
|
+
suppress_all_data_checks,
|
|
1479
|
+
suppress_interactive_data_checks,
|
|
1480
|
+
)
|
|
1481
|
+
|
|
1482
|
+
num_users = len(user_ids)
|
|
1483
|
+
|
|
1484
|
+
(
|
|
1485
|
+
joint_adaptive_meat_matrix,
|
|
1486
|
+
classical_meat_matrix,
|
|
1487
|
+
per_user_adaptive_corrections,
|
|
1488
|
+
per_user_classical_corrections,
|
|
1489
|
+
) = perform_desired_small_sample_correction(
|
|
1490
|
+
small_sample_correction,
|
|
1491
|
+
per_user_joint_adaptive_meat_contributions,
|
|
1492
|
+
per_user_classical_meat_contributions,
|
|
1493
|
+
per_user_classical_bread_inverse_contributions,
|
|
1494
|
+
num_users,
|
|
1495
|
+
theta_dim,
|
|
1496
|
+
)
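The available corrections are defined by SmallSampleCorrections and implemented in small_sample_corrections.py. As a hedged illustration of the general shape of such a correction (not necessarily the package's formula), a generic HC1-style degrees-of-freedom rescaling of averaged meat contributions looks like this:

import jax.numpy as jnp

def hc1_style_meat(per_user_meat_contributions, num_users, num_params):
    # Rescale the averaged meat by n / (n - p); purely illustrative.
    scale = num_users / max(num_users - num_params, 1)
    return scale * jnp.mean(per_user_meat_contributions, axis=0)

contribs = jnp.stack([(i + 1.0) * jnp.eye(2) for i in range(5)])
print(hc1_style_meat(contribs, num_users=5, num_params=2))  # (5/3) * 3*I = 5*I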
|
|
1497
|
+
|
|
1498
|
+
# Increase diagonal block dominance and possibly improve conditioning of diagonal
|
|
1499
|
+
# blocks as necessary, to ensure numerical stability of the joint bread inverse.
|
|
1500
|
+
stabilized_joint_adaptive_bread_inverse_matrix = (
|
|
1501
|
+
(
|
|
1502
|
+
stabilize_joint_adaptive_bread_inverse_if_necessary(
|
|
1503
|
+
raw_joint_adaptive_bread_inverse_matrix,
|
|
1504
|
+
beta_dim,
|
|
1505
|
+
theta_dim,
|
|
1506
|
+
)
|
|
1507
|
+
)
|
|
1508
|
+
if stabilize_joint_adaptive_bread_inverse
|
|
1509
|
+
else raw_joint_adaptive_bread_inverse_matrix
|
|
1510
|
+
)
|
|
1511
|
+
|
|
1512
|
+
# Now stably (no explicit inversion) form our sandwiches.
|
|
1513
|
+
joint_adaptive_sandwich = form_sandwich_from_bread_inverse_and_meat(
|
|
1514
|
+
stabilized_joint_adaptive_bread_inverse_matrix,
|
|
1515
|
+
joint_adaptive_meat_matrix,
|
|
1516
|
+
num_users,
|
|
1517
|
+
method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
|
|
1518
|
+
)
|
|
1519
|
+
classical_bread_inverse_matrix = jnp.mean(
|
|
1520
|
+
per_user_classical_bread_inverse_contributions, axis=0
|
|
1521
|
+
)
|
|
1522
|
+
classical_sandwich = form_sandwich_from_bread_inverse_and_meat(
|
|
1523
|
+
classical_bread_inverse_matrix,
|
|
1524
|
+
classical_meat_matrix,
|
|
1525
|
+
num_users,
|
|
1526
|
+
method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
|
|
1527
|
+
)
|
|
1528
|
+
|
|
1529
|
+
per_user_adaptive_meat_adjustments = jnp.full(
|
|
1530
|
+
(len(user_ids), theta_dim, theta_dim), jnp.nan
|
|
1531
|
+
)
|
|
1532
|
+
if form_adaptive_meat_adjustments_explicitly:
|
|
1533
|
+
per_user_adjusted_classical_meat_contributions = (
|
|
1534
|
+
form_adaptive_meat_adjustments_directly(
|
|
1535
|
+
theta_dim,
|
|
1536
|
+
all_post_update_betas.shape[1],
|
|
1537
|
+
stabilized_joint_adaptive_bread_inverse_matrix,
|
|
1538
|
+
per_user_estimating_function_stacks,
|
|
1539
|
+
study_df,
|
|
1540
|
+
in_study_col_name,
|
|
1541
|
+
action_col_name,
|
|
1542
|
+
calendar_t_col_name,
|
|
1543
|
+
user_id_col_name,
|
|
1544
|
+
action_prob_func,
|
|
1545
|
+
action_prob_func_args,
|
|
1546
|
+
action_prob_func_args_beta_index,
|
|
1547
|
+
theta_est,
|
|
1548
|
+
inference_func,
|
|
1549
|
+
inference_func_args_theta_index,
|
|
1550
|
+
user_ids,
|
|
1551
|
+
action_prob_col_name,
|
|
1552
|
+
)
|
|
1553
|
+
)
|
|
1554
|
+
# Validate that the adaptive meat adjustments we just formed are accurate by constructing
|
|
1555
|
+
# the theta-only adaptive sandwich from them and checking that it matches the standard result
|
|
1556
|
+
# we get by taking a subset of the joint adaptive sandwich.
|
|
1557
|
+
# First just apply any small-sample correction for parity.
|
|
1558
|
+
(
|
|
1559
|
+
_,
|
|
1560
|
+
theta_only_adaptive_meat_matrix_v2,
|
|
1561
|
+
_,
|
|
1562
|
+
_,
|
|
1563
|
+
) = perform_desired_small_sample_correction(
|
|
1564
|
+
small_sample_correction,
|
|
1565
|
+
per_user_joint_adaptive_meat_contributions,
|
|
1566
|
+
per_user_adjusted_classical_meat_contributions,
|
|
1567
|
+
per_user_classical_bread_inverse_contributions,
|
|
1568
|
+
num_users,
|
|
1569
|
+
theta_dim,
|
|
1570
|
+
)
|
|
1571
|
+
theta_only_adaptive_sandwich_from_adjustments = (
|
|
1572
|
+
form_sandwich_from_bread_inverse_and_meat(
|
|
1573
|
+
classical_bread_inverse_matrix,
|
|
1574
|
+
theta_only_adaptive_meat_matrix_v2,
|
|
1575
|
+
num_users,
|
|
1576
|
+
method=SandwichFormationMethods.BREAD_INVERSE_T_QR,
|
|
1577
|
+
)
|
|
1578
|
+
)
|
|
1579
|
+
theta_only_adaptive_sandwich = joint_adaptive_sandwich[-theta_dim:, -theta_dim:]
|
|
1580
|
+
|
|
1581
|
+
if not np.allclose(
|
|
1582
|
+
theta_only_adaptive_sandwich,
|
|
1583
|
+
theta_only_adaptive_sandwich_from_adjustments,
|
|
1584
|
+
rtol=3e-2,
|
|
1585
|
+
):
|
|
1586
|
+
logger.warning(
|
|
1587
|
+
"There may be a bug in the explicit meat adjustment calculation (this doesn't affect the actual calculation, just diagnostics). We've calculated the theta-only adaptive sandwich two different ways and they do not match sufficiently."
|
|
1588
|
+
)
|
|
1589
|
+
|
|
1590
|
+
# Stack the joint adaptive inverse bread pieces together horizontally and return the auxiliary
|
|
1591
|
+
# values too. The joint adaptive bread inverse should always be block lower triangular.
|
|
1592
|
+
return (
|
|
1593
|
+
raw_joint_adaptive_bread_inverse_matrix,
|
|
1594
|
+
stabilized_joint_adaptive_bread_inverse_matrix,
|
|
1595
|
+
joint_adaptive_meat_matrix,
|
|
1596
|
+
joint_adaptive_sandwich,
|
|
1597
|
+
classical_bread_inverse_matrix,
|
|
1598
|
+
classical_meat_matrix,
|
|
1599
|
+
classical_sandwich,
|
|
1600
|
+
avg_estimating_function_stack,
|
|
1601
|
+
per_user_estimating_function_stacks,
|
|
1602
|
+
per_user_adaptive_corrections,
|
|
1603
|
+
per_user_classical_corrections,
|
|
1604
|
+
per_user_adaptive_meat_adjustments,
|
|
1605
|
+
)
|
|
1606
|
+
|
|
1607
|
+
|
|
1608
|
+
# TODO: There should be interactive confirmation of stabilization. It is
|
|
1609
|
+
# important for the user to know when this is happening; even if enabled, the user
|
|
1610
|
+
# should know when it actually kicks in.
|
|
1611
|
+
def stabilize_joint_adaptive_bread_inverse_if_necessary(
|
|
1612
|
+
joint_adaptive_bread_inverse_matrix: jnp.ndarray,
|
|
1613
|
+
beta_dim: int,
|
|
1614
|
+
theta_dim: int,
|
|
1615
|
+
) -> jnp.ndarray:
|
|
1616
|
+
"""
|
|
1617
|
+
Stabilizes the joint adaptive bread inverse matrix if necessary by increasing diagonal block
|
|
1618
|
+
dominance and/or adding a small ridge penalty to the diagonal blocks.
|
|
1619
|
+
|
|
1620
|
+
Args:
|
|
1621
|
+
joint_adaptive_bread_inverse_matrix (jnp.ndarray):
|
|
1622
|
+
A 2-D JAX NumPy array representing the joint adaptive bread inverse matrix.
|
|
1623
|
+
beta_dim (int):
|
|
1624
|
+
The dimension of each beta parameter.
|
|
1625
|
+
theta_dim (int):
|
|
1626
|
+
The dimension of the theta parameter.
|
|
1627
|
+
Returns:
|
|
1628
|
+
jnp.ndarray:
|
|
1629
|
+
A 2-D NumPy array representing the stabilized joint adaptive bread inverse matrix.
|
|
1630
|
+
"""
|
|
1631
|
+
|
|
1632
|
+
# TODO: come up with more sophisticated settings here. These are maybe a little loose,
|
|
1633
|
+
# but I especially want to avoid adding ridge penalties if possible.
|
|
1634
|
+
# Would be interested in dividing each by 10, though.
|
|
1635
|
+
|
|
1636
|
+
# Set thresholds to guide stabilization.
|
|
1637
|
+
diagonal_block_cond_threshold = 2e2
|
|
1638
|
+
whole_RL_block_cond_threshold = 1e4
|
|
1639
|
+
|
|
1640
|
+
# Grab just the RL block and convert it to a NumPy array for easier manipulation.
|
|
1641
|
+
RL_stack_beta_derivatives_block = np.array(
|
|
1642
|
+
joint_adaptive_bread_inverse_matrix[:-theta_dim, :-theta_dim]
|
|
1643
|
+
)
|
|
1644
|
+
num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
|
|
1645
|
+
for i in range(1, num_updates + 1):
|
|
1646
|
+
|
|
1647
|
+
# Add ridge penalty to diagonal block to control its condition number if necessary.
|
|
1648
|
+
# Define the slice for the current diagonal block
|
|
1649
|
+
diagonal_block_slice = slice((i - 1) * beta_dim, i * beta_dim)
|
|
1650
|
+
diagonal_block = RL_stack_beta_derivatives_block[
|
|
1651
|
+
diagonal_block_slice, diagonal_block_slice
|
|
1652
|
+
]
|
|
1653
|
+
diagonal_block_cond_number = np.linalg.cond(diagonal_block)
|
|
1654
|
+
svs = np.linalg.svd(diagonal_block, compute_uv=False)
|
|
1655
|
+
max_sv = svs[0]
|
|
1656
|
+
min_sv = svs[-1]
|
|
1657
|
+
|
|
1658
|
+
ridge_penalty = max(
|
|
1659
|
+
0,
|
|
1660
|
+
(max_sv - diagonal_block_cond_threshold * min_sv)
|
|
1661
|
+
/ (diagonal_block_cond_threshold - 1),
|
|
1662
|
+
)
|
|
1663
|
+
|
|
1664
|
+
if ridge_penalty:
|
|
1665
|
+
new_block = diagonal_block + ridge_penalty * np.eye(beta_dim)
|
|
1666
|
+
new_diagonal_block_cond_number = np.linalg.cond(new_block)
|
|
1667
|
+
RL_stack_beta_derivatives_block[
|
|
1668
|
+
diagonal_block_slice, diagonal_block_slice
|
|
1669
|
+
] = new_block
|
|
1670
|
+
# TODO: Require user input here in interactive settings?
|
|
1671
|
+
logger.info(
|
|
1672
|
+
"Added ridge penalty of %s to diagonal block for update %s to improve conditioning from %s to %s",
|
|
1673
|
+
ridge_penalty,
|
|
1674
|
+
i,
|
|
1675
|
+
diagonal_block_cond_number,
|
|
1676
|
+
new_diagonal_block_cond_number,
|
|
1677
|
+
)
|
|
1678
|
+
|
|
1679
|
+
# Damp off-diagonal blocks to improve conditioning of the whole RL block if necessary.
|
|
1680
|
+
off_diagonal_block_row_slices = (
|
|
1681
|
+
slice((i - 1) * beta_dim, i * beta_dim),
|
|
1682
|
+
slice((i - 1) * beta_dim),
|
|
1683
|
+
)
|
|
1684
|
+
whole_block_cur_update_size = i * beta_dim
|
|
1685
|
+
initial_whole_block_cond_number = None
|
|
1686
|
+
incremental_damping_factor = 0.9
|
|
1687
|
+
max_iterations = 50
|
|
1688
|
+
damping_applied = 1
|
|
1689
|
+
|
|
1690
|
+
for _ in range(max_iterations):
|
|
1691
|
+
whole_block_cur_update = RL_stack_beta_derivatives_block[
|
|
1692
|
+
:whole_block_cur_update_size, :whole_block_cur_update_size
|
|
1693
|
+
]
|
|
1694
|
+
whole_block_cur_update_cond_number = np.linalg.cond(whole_block_cur_update)
|
|
1695
|
+
if initial_whole_block_cond_number is None:
|
|
1696
|
+
initial_whole_block_cond_number = whole_block_cur_update_cond_number
|
|
1697
|
+
|
|
1698
|
+
if whole_block_cur_update_cond_number <= whole_RL_block_cond_threshold:
|
|
1699
|
+
break
|
|
1700
|
+
|
|
1701
|
+
damping_applied *= incremental_damping_factor
|
|
1702
|
+
RL_stack_beta_derivatives_block[
|
|
1703
|
+
off_diagonal_block_row_slices
|
|
1704
|
+
] *= incremental_damping_factor
|
|
1705
|
+
else:
|
|
1706
|
+
damping_applied = 0
|
|
1707
|
+
RL_stack_beta_derivatives_block[off_diagonal_block_row_slices] = 0
# Recompute the condition number now that the off-diagonal blocks are zeroed.
whole_block_cur_update_cond_number = np.linalg.cond(whole_block_cur_update)
|
|
1708
|
+
|
|
1709
|
+
# TODO: Maybe in this case, roll back through previous rows and damp off diagonals
|
|
1710
|
+
# instead of adding ridge? Feels a little safer because if we zeroed everything
|
|
1711
|
+
# off-diagonal and didn't touch the diagonal, we'd get the classical sandwich.
|
|
1712
|
+
if whole_block_cur_update_cond_number > whole_RL_block_cond_threshold:
|
|
1713
|
+
logger.warning(
|
|
1714
|
+
"Off-diagonal blocks were zeroed for update %s, but conditioning is still poor: %s > %s. Adding extra ridge penalty to entire RL block so far.",
|
|
1715
|
+
i,
|
|
1716
|
+
whole_block_cur_update_cond_number,
|
|
1717
|
+
whole_RL_block_cond_threshold,
|
|
1718
|
+
)
|
|
1719
|
+
|
|
1720
|
+
svs = np.linalg.svd(whole_block_cur_update, compute_uv=False)
|
|
1721
|
+
max_sv = svs[0]
|
|
1722
|
+
min_sv = svs[-1]
|
|
1723
|
+
|
|
1724
|
+
ridge_penalty = max(
|
|
1725
|
+
0,
|
|
1726
|
+
(max_sv - whole_RL_block_cond_threshold * min_sv)
|
|
1727
|
+
/ (whole_RL_block_cond_threshold - 1),
|
|
1728
|
+
)
|
|
1729
|
+
|
|
1730
|
+
# TODO: This is highly questionable, potentially modifying the matrix very significantly.
|
|
1731
|
+
new_block = whole_block_cur_update + ridge_penalty * np.eye(
|
|
1732
|
+
whole_block_cur_update_size
|
|
1733
|
+
)
|
|
1734
|
+
new_whole_block_cond_number = np.linalg.cond(new_block)
|
|
1735
|
+
RL_stack_beta_derivatives_block[
|
|
1736
|
+
:whole_block_cur_update_size, :whole_block_cur_update_size
|
|
1737
|
+
] += ridge_penalty * np.eye(whole_block_cur_update_size)
|
|
1738
|
+
logger.info(
|
|
1739
|
+
"Added ridge penalty of %s to entire RL block up to update %s to improve conditioning from %s to %s",
|
|
1740
|
+
ridge_penalty,
|
|
1741
|
+
i,
|
|
1742
|
+
whole_block_cur_update_cond_number,
|
|
1743
|
+
new_whole_block_cond_number,
|
|
1744
|
+
)
|
|
1745
|
+
|
|
1746
|
+
# Log any damping applied to the off-diagonal blocks so the user is aware.
|
|
1747
|
+
|
|
1748
|
+
if damping_applied < 1:
|
|
1749
|
+
logger.info(
|
|
1750
|
+
"Applied damping factor of %s to off-diagonal blocks for update %s to improve conditioning of whole RL block up to that update from %s to %s",
|
|
1751
|
+
damping_applied,
|
|
1752
|
+
i,
|
|
1753
|
+
initial_whole_block_cond_number,
|
|
1754
|
+
whole_block_cur_update_cond_number,
|
|
1755
|
+
)
|
|
1756
|
+
|
|
1757
|
+
return np.block(
|
|
1758
|
+
[
|
|
1759
|
+
[
|
|
1760
|
+
RL_stack_beta_derivatives_block,
|
|
1761
|
+
joint_adaptive_bread_inverse_matrix[:-theta_dim, -theta_dim:],
|
|
1762
|
+
],
|
|
1763
|
+
[
|
|
1764
|
+
joint_adaptive_bread_inverse_matrix[-theta_dim:, :-theta_dim],
|
|
1765
|
+
joint_adaptive_bread_inverse_matrix[-theta_dim:, -theta_dim:],
|
|
1766
|
+
],
|
|
1767
|
+
]
|
|
1768
|
+
)
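To see why damping the off-diagonal blocks helps, consider a toy block lower-triangular matrix (constructed here purely for illustration): shrinking the off-diagonal block directly shrinks the spread of singular values of the whole block.

import numpy as np

eye2 = np.eye(2)
off = np.full((2, 2), 50.0)
M = np.block([[eye2, np.zeros((2, 2))], [off, eye2]])
print(np.linalg.cond(M))  # ~1e4: badly conditioned

M[2:, :2] *= 0.9 ** 20    # repeated damping, as in the loop above
print(np.linalg.cond(M))  # orders of magnitude smaller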
|
|
1769
|
+
|
|
1770
|
+
|
|
1771
|
+
def form_sandwich_from_bread_inverse_and_meat(
|
|
1772
|
+
bread_inverse: jnp.ndarray,
|
|
1773
|
+
meat: jnp.ndarray,
|
|
1774
|
+
num_users: int,
|
|
1775
|
+
method: str = SandwichFormationMethods.BREAD_INVERSE_T_QR,
|
|
1776
|
+
) -> jnp.ndarray:
|
|
1777
|
+
"""
|
|
1778
|
+
Forms a sandwich variance matrix from the provided bread inverse and meat matrices.
|
|
1779
|
+
|
|
1780
|
+
Attempts to do so STABLY without ever forming the bread matrix itself
|
|
1781
|
+
(except with naive option).
|
|
1782
|
+
|
|
1783
|
+
Args:
|
|
1784
|
+
bread_inverse (jnp.ndarray):
|
|
1785
|
+
A 2-D JAX NumPy array representing the bread inverse matrix.
|
|
1786
|
+
meat (jnp.ndarray):
|
|
1787
|
+
A 2-D JAX NumPy array representing the meat matrix.
|
|
1788
|
+
num_users (int):
|
|
1789
|
+
The number of users in the study, used to scale the sandwich appropriately.
|
|
1790
|
+
method (str):
|
|
1791
|
+
The method to use for forming the sandwich.
|
|
1792
|
+
|
|
1793
|
+
SandwichFormationMethods.BREAD_INVERSE_T_QR uses the QR decomposition of the transpose
|
|
1794
|
+
of the bread inverse matrix.
|
|
1795
|
+
|
|
1796
|
+
SandwichFormationMethods.MEAT_SVD_SOLVE uses a decomposition of the meat matrix.
|
|
1797
|
+
|
|
1798
|
+
SandwichFormationMethods.NAIVE simply inverts the bread inverse and forms the sandwich.
|
|
1799
|
+
|
|
1800
|
+
|
|
1801
|
+
Returns:
|
|
1802
|
+
jnp.ndarray:
|
|
1803
|
+
A 2-D JAX NumPy array representing the sandwich variance matrix.
|
|
1804
|
+
"""
|
|
1805
|
+
|
|
1806
|
+
if method == SandwichFormationMethods.BREAD_INVERSE_T_QR:
|
|
1807
|
+
# QR of B^T → Q orthogonal, R upper triangular; L = R^T lower triangular
|
|
1808
|
+
Q, R = np.linalg.qr(bread_inverse.T, mode="reduced")
|
|
1809
|
+
L = R.T
|
|
1810
|
+
|
|
1811
|
+
new_meat = scipy.linalg.solve_triangular(
|
|
1812
|
+
L, scipy.linalg.solve_triangular(L, meat.T, lower=True).T, lower=True
|
|
1813
|
+
)
|
|
1814
|
+
|
|
1815
|
+
return Q @ new_meat @ Q.T / num_users
|
|
1816
|
+
elif method == SandwichFormationMethods.MEAT_SVD_SOLVE:
|
|
1817
|
+
# Factor the meat via SVD without any symmetrization or truncation.
|
|
1818
|
+
# For general (possibly slightly nonsymmetric) M, SVD gives M = U @ diag(s) @ Vh.
|
|
1819
|
+
# We construct two square-root factors C_left = U * sqrt(s) and C_right = V * sqrt(s)
|
|
1820
|
+
# so that M = C_left @ C_right.T exactly, then solve once per factor.
|
|
1821
|
+
U, s, Vh = scipy.linalg.svd(meat, full_matrices=False)
|
|
1822
|
+
C_left = U * np.sqrt(s)
|
|
1823
|
+
C_right = Vh.T * np.sqrt(s)
|
|
1824
|
+
|
|
1825
|
+
# Solve B W_left = C_left and B W_right = C_right (no explicit inverses).
|
|
1826
|
+
W_left = scipy.linalg.solve(bread_inverse, C_left)
|
|
1827
|
+
W_right = scipy.linalg.solve(bread_inverse, C_right)
|
|
1828
|
+
|
|
1829
|
+
# Return the exact sandwich: V = (B^{-1} C_left) (B^{-1} C_right)^T / num_users
|
|
1830
|
+
return W_left @ W_right.T / num_users
|
|
1831
|
+
|
|
1832
|
+
elif method == SandwichFormationMethods.NAIVE:
|
|
1833
|
+
# Simply invert the bread inverse and form the sandwich directly.
|
|
1834
|
+
# This is NOT numerically stable and is only included for comparison purposes.
|
|
1835
|
+
bread = np.linalg.inv(bread_inverse)
|
|
1836
|
+
return bread @ meat @ bread.T / num_users
|
|
1837
|
+
|
|
1838
|
+
else:
|
|
1839
|
+
raise ValueError(
|
|
1840
|
+
f"Unknown sandwich method: {method}. Please use 'bread_inverse_t_qr' or 'meat_decomposition_solve'."
|
|
1841
|
+
)
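As a sanity check that could be run in this module's context (it uses the function and constants defined above), the QR route should agree with the explicit-inverse formula on a well-conditioned toy problem:

import numpy as np

rng = np.random.default_rng(2)
p, n = 4, 50
bread_inverse = rng.normal(size=(p, p)) + 5 * np.eye(p)
a = rng.normal(size=(p, p))
meat = a @ a.T

via_qr = form_sandwich_from_bread_inverse_and_meat(
    bread_inverse, meat, n, method=SandwichFormationMethods.BREAD_INVERSE_T_QR
)
via_inverse = np.linalg.inv(bread_inverse) @ meat @ np.linalg.inv(bread_inverse).T / n
assert np.allclose(via_qr, via_inverse)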
|
|
1842
|
+
|
|
1843
|
+
|
|
1844
|
+
if __name__ == "__main__":
|
|
1845
|
+
cli()
|