lifejacket-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/__init__.py +0 -0
- lifejacket/after_study_analysis.py +1845 -0
- lifejacket/arg_threading_helpers.py +354 -0
- lifejacket/calculate_derivatives.py +965 -0
- lifejacket/constants.py +28 -0
- lifejacket/form_adaptive_meat_adjustments_directly.py +333 -0
- lifejacket/get_datum_for_blowup_supervised_learning.py +1312 -0
- lifejacket/helper_functions.py +587 -0
- lifejacket/input_checks.py +1145 -0
- lifejacket/small_sample_corrections.py +125 -0
- lifejacket/trial_conditioning_monitor.py +870 -0
- lifejacket/vmap_helpers.py +71 -0
- lifejacket-0.1.0.dist-info/METADATA +100 -0
- lifejacket-0.1.0.dist-info/RECORD +17 -0
- lifejacket-0.1.0.dist-info/WHEEL +5 -0
- lifejacket-0.1.0.dist-info/entry_points.txt +2 -0
- lifejacket-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,870 @@
from __future__ import annotations

import collections.abc
import logging
import math
from typing import Any

from jax import numpy as jnp
import jax
import numpy as np
import pandas as pd

from .arg_threading_helpers import thread_action_prob_func_args, thread_update_func_args
from .constants import FunctionTypes
from .helper_functions import (
    calculate_beta_dim,
    collect_all_post_update_betas,
    construct_beta_index_by_policy_num_map,
    extract_action_and_policy_by_decision_time_by_user_id,
    flatten_params,
    get_min_time_by_policy_num,
    get_radon_nikodym_weight,
    unflatten_params,
)
from . import input_checks
from .vmap_helpers import stack_batched_arg_lists_into_tensors


logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s,%(msecs)03d %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)

class TrialConditioningMonitor:
    whole_RL_block_conditioning_threshold = None
    diagonal_RL_block_conditioning_threshold = None

    def __init__(
        self,
        whole_RL_block_conditioning_threshold: int = 1000,
        diagonal_RL_block_conditioning_threshold: int = 100,
    ):
        self.whole_RL_block_conditioning_threshold = (
            whole_RL_block_conditioning_threshold
        )
        self.diagonal_RL_block_conditioning_threshold = (
            diagonal_RL_block_conditioning_threshold
        )
        self.latest_phi_dot_bar = None

    def assess_update(
        self,
        proposed_policy_num: int | float,
        study_df: pd.DataFrame,
        action_prob_func: callable,
        action_prob_func_args: dict,
        action_prob_func_args_beta_index: int,
        alg_update_func: callable,
        alg_update_func_type: str,
        alg_update_func_args: dict,
        alg_update_func_args_beta_index: int,
        alg_update_func_args_action_prob_index: int,
        alg_update_func_args_action_prob_times_index: int,
        in_study_col_name: str,
        action_col_name: str,
        policy_num_col_name: str,
        calendar_t_col_name: str,
        user_id_col_name: str,
        action_prob_col_name: str,
        suppress_interactive_data_checks: bool,
        suppress_all_data_checks: bool,
        incremental: bool = True,
    ) -> dict:
        """
        Assesses a proposed algorithm update by checking the conditioning of the RL block of the
        adaptive sandwich bread inverse ("phi dot bar") formed from the study data so far. The update
        is marked as rejected if either the whole RL block or its newest diagonal block exceeds its
        conditioning threshold.

        Parameters:
            proposed_policy_num (int | float):
                The policy number of the proposed update.
            study_df (pd.DataFrame):
                DataFrame containing the study data.
            action_prob_func (callable):
                Action probability function.
            action_prob_func_args (dict):
                Arguments for the action probability function.
            action_prob_func_args_beta_index (int):
                Index for beta in action probability function arguments.
            alg_update_func (callable):
                Algorithm update function.
            alg_update_func_type (str):
                Type of the algorithm update function.
            alg_update_func_args (dict):
                Arguments for the algorithm update function.
            alg_update_func_args_beta_index (int):
                Index for beta in algorithm update function arguments.
            alg_update_func_args_action_prob_index (int):
                Index for action probability in algorithm update function arguments.
            alg_update_func_args_action_prob_times_index (int):
                Index for action probability times in algorithm update function arguments.
            in_study_col_name (str):
                Column name indicating if a user is in the study in the study dataframe.
            action_col_name (str):
                Column name for actions in the study dataframe.
            policy_num_col_name (str):
                Column name for policy numbers in the study dataframe.
            calendar_t_col_name (str):
                Column name for calendar time in the study dataframe.
            user_id_col_name (str):
                Column name for user IDs in the study dataframe.
            action_prob_col_name (str):
                Column name for action probabilities in the study dataframe.
            suppress_interactive_data_checks (bool):
                Whether to suppress interactive data checks. This should be used in simulations, for example.
            suppress_all_data_checks (bool):
                Whether to suppress all data checks. Not recommended.
            incremental (bool):
                Whether to form only the latest block row of phi dot bar and append it to the cached
                matrix from previous assessments, rather than recomputing the whole matrix.

        Returns:
            dict: The assessment results, including whether the update was rejected, the rejection
                reason (if any), and the whole-block and diagonal-block condition numbers along with
                their thresholds.
        """

        beta_dim = calculate_beta_dim(
            action_prob_func_args, action_prob_func_args_beta_index
        )

        if not suppress_all_data_checks:
            input_checks.perform_alg_only_input_checks(
                study_df,
                in_study_col_name,
                policy_num_col_name,
                calendar_t_col_name,
                user_id_col_name,
                action_prob_col_name,
                action_prob_func,
                action_prob_func_args,
                action_prob_func_args_beta_index,
                alg_update_func_args,
                alg_update_func_args_beta_index,
                alg_update_func_args_action_prob_index,
                alg_update_func_args_action_prob_times_index,
                suppress_interactive_data_checks,
            )

        beta_index_by_policy_num, initial_policy_num = (
            construct_beta_index_by_policy_num_map(
                study_df, policy_num_col_name, in_study_col_name
            )
        )
        # We augment the produced map to include the proposed policy num.
        # This is necessary because the logic above assumes all policies are present in the
        # study df, whereas for us we are only passing the data *used* for the current update,
        # i.e. up to the previous policy.
        beta_index_by_policy_num[proposed_policy_num] = len(beta_index_by_policy_num)

        all_post_update_betas = collect_all_post_update_betas(
            beta_index_by_policy_num,
            alg_update_func_args,
            alg_update_func_args_beta_index,
        )

        action_by_decision_time_by_user_id, policy_num_by_decision_time_by_user_id = (
            extract_action_and_policy_by_decision_time_by_user_id(
                study_df,
                user_id_col_name,
                in_study_col_name,
                calendar_t_col_name,
                action_col_name,
                policy_num_col_name,
            )
        )

        user_ids = jnp.array(study_df[user_id_col_name].unique())

        phi_dot_bar, avg_estimating_function_stack = self.construct_phi_dot_bar_so_far(
            all_post_update_betas,
            user_ids,
            action_prob_func,
            action_prob_func_args_beta_index,
            alg_update_func,
            alg_update_func_type,
            alg_update_func_args_beta_index,
            alg_update_func_args_action_prob_index,
            alg_update_func_args_action_prob_times_index,
            action_prob_func_args,
            policy_num_by_decision_time_by_user_id,
            initial_policy_num,
            beta_index_by_policy_num,
            alg_update_func_args,
            action_by_decision_time_by_user_id,
            suppress_all_data_checks,
            suppress_interactive_data_checks,
            incremental=incremental,
        )

        if not suppress_all_data_checks:
            input_checks.require_RL_estimating_functions_sum_to_zero(
                avg_estimating_function_stack,
                beta_dim,
                suppress_interactive_data_checks,
            )

        # Decide whether to accept or reject the update based on conditioning
        update_rejected = False
        rejection_reason = ""
        whole_RL_block_condition_number = np.linalg.cond(phi_dot_bar)
        new_diagonal_RL_block_condition_number = np.linalg.cond(
            phi_dot_bar[-beta_dim:, -beta_dim:]
        )

        if whole_RL_block_condition_number > self.whole_RL_block_conditioning_threshold:
            logger.warning(
                "The RL portion of the bread inverse up to this point exceeds the threshold set (condition number: %s, threshold: %s). Consider an alternative update strategy which produces less dependence on previous RL parameters (via the data they produced) and/or improves the conditioning of each update itself. Regularization may help with both of these.",
                whole_RL_block_condition_number,
                self.whole_RL_block_conditioning_threshold,
            )
            update_rejected = True
            rejection_reason = "whole_block_poor_conditioning"
        elif (
            new_diagonal_RL_block_condition_number
            > self.diagonal_RL_block_conditioning_threshold
        ):
            logger.warning(
                "The diagonal RL block of the bread inverse up to this point exceeds the threshold set (condition number: %s, threshold: %s). This may illustrate a fundamental problem with the conditioning of the RL update procedure.",
                new_diagonal_RL_block_condition_number,
                self.diagonal_RL_block_conditioning_threshold,
            )
            update_rejected = True
            rejection_reason = "diagonal_block_poor_conditioning"

        # TODO: Regression -> prediction over going over threshold? Take in estimated num updates if so.

        ans = {
            "update_rejected": update_rejected,
            "rejection_reason": rejection_reason,
            "whole_RL_block_condition_number": whole_RL_block_condition_number,
            "whole_RL_block_conditioning_threshold": self.whole_RL_block_conditioning_threshold,
            "new_diagonal_RL_block_condition_number": new_diagonal_RL_block_condition_number,
            "diagonal_RL_block_conditioning_threshold": self.diagonal_RL_block_conditioning_threshold,
        }
        logger.info("Update assessment results: %s", ans)
        return ans

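    # Illustrative usage sketch (not part of the released file): how a caller might gate an
    # algorithm update on the conditioning assessment above. The column names, functions, and
    # argument dictionaries are hypothetical placeholders for whatever the deployment actually
    # uses; only the keyword names and returned dict keys come from assess_update itself.
    #
    #   monitor = TrialConditioningMonitor(
    #       whole_RL_block_conditioning_threshold=1000,
    #       diagonal_RL_block_conditioning_threshold=100,
    #   )
    #   assessment = monitor.assess_update(
    #       proposed_policy_num=3,
    #       study_df=study_df,
    #       action_prob_func=my_action_prob_func,
    #       action_prob_func_args=my_action_prob_func_args,
    #       action_prob_func_args_beta_index=0,
    #       alg_update_func=my_update_loss,
    #       alg_update_func_type=FunctionTypes.LOSS,
    #       alg_update_func_args=my_update_func_args,
    #       alg_update_func_args_beta_index=0,
    #       alg_update_func_args_action_prob_index=-1,
    #       alg_update_func_args_action_prob_times_index=-1,
    #       in_study_col_name="in_study",
    #       action_col_name="action",
    #       policy_num_col_name="policy_num",
    #       calendar_t_col_name="calendar_t",
    #       user_id_col_name="user_id",
    #       action_prob_col_name="action_prob",
    #       suppress_interactive_data_checks=True,
    #       suppress_all_data_checks=False,
    #   )
    #   if assessment["update_rejected"]:
    #       logger.warning("Skipping update: %s", assessment["rejection_reason"])
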
    def construct_phi_dot_bar_so_far(
        self,
        all_post_update_betas: jnp.ndarray,
        user_ids: jnp.ndarray,
        action_prob_func: callable,
        action_prob_func_args_beta_index: int,
        alg_update_func: callable,
        alg_update_func_type: str,
        alg_update_func_args_beta_index: int,
        alg_update_func_args_action_prob_index: int,
        alg_update_func_args_action_prob_times_index: int,
        action_prob_func_args_by_user_id_by_decision_time: dict[
            collections.abc.Hashable, dict[int, tuple[Any, ...]]
        ],
        policy_num_by_decision_time_by_user_id: dict[
            collections.abc.Hashable, dict[int, int | float]
        ],
        initial_policy_num: int | float,
        beta_index_by_policy_num: dict[int | float, int],
        update_func_args_by_by_user_id_by_policy_num: dict[
            collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
        ],
        action_by_decision_time_by_user_id: dict[
            collections.abc.Hashable, dict[int, int]
        ],
        suppress_all_data_checks: bool,
        suppress_interactive_data_checks: bool,
        incremental: bool,
    ) -> tuple[
        jnp.ndarray,
        jnp.ndarray,
    ]:
        """
        Constructs the RL block of the adaptive sandwich bread inverse ("phi dot bar") from the data
        so far, as well as the average weighted estimating function stack.

        This is done by computing and differentiating the average weighted estimating function stack
        with respect to the betas. When incremental is True and a previous result has been cached,
        only the latest block row of the Jacobian is computed and appended.

        Args:
            all_post_update_betas (jnp.ndarray):
                A 2-D JAX NumPy array representing all parameter estimates for the algorithm updates.
            user_ids (jnp.ndarray):
                A 1-D JAX NumPy array holding all user IDs in the study.
            action_prob_func (callable):
                The action probability function.
            action_prob_func_args_beta_index (int):
                The index of beta in the action probability function arguments tuples.
            alg_update_func (callable):
                The algorithm update function.
            alg_update_func_type (str):
                The type of the algorithm update function (loss or estimating).
            alg_update_func_args_beta_index (int):
                The index of beta in the update function arguments tuples.
            alg_update_func_args_action_prob_index (int):
                The index of action probabilities in the update function arguments tuple, if
                applicable. -1 otherwise.
            alg_update_func_args_action_prob_times_index (int):
                The index in the update function arguments tuple where an array of times for which the
                given action probabilities apply is provided, if applicable. -1 otherwise.
            action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
                A dictionary mapping decision times to maps of user ids to the function arguments
                required to compute action probabilities for this user.
            policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
                A map of user ids to dictionaries mapping decision times to the policy number in use.
                Only applies to in-study decision times!
            initial_policy_num (int | float):
                The policy number of the initial policy before any updates.
            beta_index_by_policy_num (dict[int | float, int]):
                A dictionary mapping policy numbers to the index of the corresponding beta in
                all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
            update_func_args_by_by_user_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
                A dictionary where keys are policy numbers and values are dictionaries mapping user IDs
                to their respective update function arguments.
            action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
                A dictionary mapping user IDs to their respective actions taken at each decision time.
                Only applies to in-study decision times!
            suppress_all_data_checks (bool):
                If True, suppresses carrying out any data checks at all.
            suppress_interactive_data_checks (bool):
                If True, suppresses interactive data checks that would otherwise be performed to ensure
                the correctness of the threaded arguments. The checks are still performed, but
                any interactive prompts are suppressed.
            incremental (bool): Whether to form the whole phi-dot-bar so far, or just form the latest
                block row and add to the cached previous phi-dot-bar.

        Returns:
            tuple[jnp.ndarray, jnp.ndarray]: The phi-dot-bar matrix formed so far and the average
                weighted RL estimating function stack.
        """
        logger.info(
            "Differentiating average weighted estimating function stack and collecting auxiliary values."
        )
        beta_dim = all_post_update_betas.shape[1]

        if incremental and self.latest_phi_dot_bar is not None:
            # We only need to compute the latest block row of the Jacobian.
            (
                phi_dot_bar_latest_block,
                avg_RL_estimating_function_stack,
            ) = jax.jacrev(
                self.get_avg_weighted_RL_estimating_function_stacks_and_aux_values,
                has_aux=True,
            )(
                # While JAX can technically differentiate with respect to a list of JAX arrays,
                # it is apparently more efficient to flatten them into a single array. This is done
                # here to improve performance. We can simply unflatten them inside the function.
                flatten_params(all_post_update_betas, jnp.array([])),
                beta_dim,
                user_ids,
                action_prob_func,
                action_prob_func_args_beta_index,
                alg_update_func,
                alg_update_func_type,
                alg_update_func_args_beta_index,
                alg_update_func_args_action_prob_index,
                alg_update_func_args_action_prob_times_index,
                action_prob_func_args_by_user_id_by_decision_time,
                policy_num_by_decision_time_by_user_id,
                initial_policy_num,
                beta_index_by_policy_num,
                update_func_args_by_by_user_id_by_policy_num,
                action_by_decision_time_by_user_id,
                suppress_all_data_checks,
                suppress_interactive_data_checks,
                only_latest_block=True,
            )

            # Now we can just augment the cached previous phi-dot-bar with zeros
            # and the latest block row.
            phi_dot_bar = jnp.block(
                [
                    self.latest_phi_dot_bar,
                    jnp.zeros((beta_dim, beta_dim)),
                    phi_dot_bar_latest_block[-beta_dim:, :],
                ]
            )
        else:
            (
                phi_dot_bar,
                avg_RL_estimating_function_stack,
            ) = jax.jacrev(
                self.get_avg_weighted_RL_estimating_function_stacks_and_aux_values,
                has_aux=True,
            )(
                # While JAX can technically differentiate with respect to a list of JAX arrays,
                # it is apparently more efficient to flatten them into a single array. This is done
                # here to improve performance. We can simply unflatten them inside the function.
                flatten_params(all_post_update_betas, jnp.array([])),
                beta_dim,
                user_ids,
                action_prob_func,
                action_prob_func_args_beta_index,
                alg_update_func,
                alg_update_func_type,
                alg_update_func_args_beta_index,
                alg_update_func_args_action_prob_index,
                alg_update_func_args_action_prob_times_index,
                action_prob_func_args_by_user_id_by_decision_time,
                policy_num_by_decision_time_by_user_id,
                initial_policy_num,
                beta_index_by_policy_num,
                update_func_args_by_by_user_id_by_policy_num,
                action_by_decision_time_by_user_id,
                suppress_all_data_checks,
                suppress_interactive_data_checks,
            )

        self.latest_phi_dot_bar = phi_dot_bar
        return phi_dot_bar, avg_RL_estimating_function_stack

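    # Minimal sketch (not part of the released file) of the jax.jacrev(..., has_aux=True)
    # pattern used above: when the differentiated function returns (output, aux), jacrev returns
    # the Jacobian of the first output together with the auxiliary value, so the averaged
    # estimating function stack can be differentiated and inspected in one pass. The toy
    # function below is purely illustrative.
    #
    #   def toy_stack(params):
    #       stack = jnp.array([params[0] * params[1], params[0] + params[1]])
    #       return stack, stack
    #
    #   jacobian, stack = jax.jacrev(toy_stack, has_aux=True)(jnp.array([2.0, 3.0]))
    #   # jacobian has shape (2, 2); stack is the undifferentiated output of toy_stack.
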
    def get_avg_weighted_RL_estimating_function_stacks_and_aux_values(
        self,
        flattened_betas_and_theta: jnp.ndarray,
        beta_dim: int,
        user_ids: jnp.ndarray,
        action_prob_func: callable,
        action_prob_func_args_beta_index: int,
        alg_update_func: callable,
        alg_update_func_type: str,
        alg_update_func_args_beta_index: int,
        alg_update_func_args_action_prob_index: int,
        alg_update_func_args_action_prob_times_index: int,
        action_prob_func_args_by_user_id_by_decision_time: dict[
            collections.abc.Hashable, dict[int, tuple[Any, ...]]
        ],
        policy_num_by_decision_time_by_user_id: dict[
            collections.abc.Hashable, dict[int, int | float]
        ],
        initial_policy_num: int | float,
        beta_index_by_policy_num: dict[int | float, int],
        update_func_args_by_by_user_id_by_policy_num: dict[
            collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
        ],
        action_by_decision_time_by_user_id: dict[
            collections.abc.Hashable, dict[int, int]
        ],
        suppress_all_data_checks: bool,
        suppress_interactive_data_checks: bool,
        only_latest_block: bool = False,
    ) -> tuple[
        jnp.ndarray,
        jnp.ndarray,
    ]:
        """
        Computes the average weighted estimating function stack across all users, along with
        auxiliary values used to construct the adaptive and classical sandwich variances.

        If only_latest_block is True, only uses data from the most recent update.

        Args:
            flattened_betas_and_theta (jnp.ndarray):
                A 1D JAX NumPy array containing the betas produced by all updates followed by the
                theta value, in that order. It is important that this is a single flat array for
                efficiency reasons; we simply extract the betas and theta from this array below.
            beta_dim (int):
                The dimension of each of the beta parameters.
            user_ids (jnp.ndarray):
                A 1D JAX NumPy array of user IDs.
            action_prob_func (callable):
                The action probability function.
            action_prob_func_args_beta_index (int):
                The index of beta in the action probability function arguments tuples.
            alg_update_func (callable):
                The algorithm update function.
            alg_update_func_type (str):
                The type of the algorithm update function (loss or estimating).
            alg_update_func_args_beta_index (int):
                The index of beta in the update function arguments tuples.
            alg_update_func_args_action_prob_index (int):
                The index of action probabilities in the update function arguments tuple, if
                applicable. -1 otherwise.
            alg_update_func_args_action_prob_times_index (int):
                The index in the update function arguments tuple where an array of times for which the
                given action probabilities apply is provided, if applicable. -1 otherwise.
            action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
                A dictionary mapping decision times to maps of user ids to the function arguments
                required to compute action probabilities for this user.
            policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
                A map of user ids to dictionaries mapping decision times to the policy number in use.
                Only applies to in-study decision times!
            initial_policy_num (int | float):
                The policy number of the initial policy before any updates.
            beta_index_by_policy_num (dict[int | float, int]):
                A dictionary mapping policy numbers to the index of the corresponding beta in
                all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
            update_func_args_by_by_user_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
                A dictionary where keys are policy numbers and values are dictionaries mapping user IDs
                to their respective update function arguments.
            action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
                A dictionary mapping user IDs to their respective actions taken at each decision time.
                Only applies to in-study decision times!
            suppress_all_data_checks (bool):
                If True, suppresses carrying out any data checks at all.
            suppress_interactive_data_checks (bool):
                If True, suppresses interactive data checks that would otherwise be performed to ensure
                the correctness of the threaded arguments. The checks are still performed, but
                any interactive prompts are suppressed.
            only_latest_block (bool):
                If True, only uses data from the most recent update.

        Returns:
            jnp.ndarray:
                A 1D JAX NumPy array holding the average weighted estimating function stack.
            jnp.ndarray:
                The same, again. We will differentiate the first output.
        """

        algorithm_estimating_func = (
            jax.grad(alg_update_func, argnums=alg_update_func_args_beta_index)
            if (alg_update_func_type == FunctionTypes.LOSS)
            else alg_update_func
        )

        betas, _ = unflatten_params(
            flattened_betas_and_theta,
            beta_dim,
            0,
        )
        # 1. If only_latest_block is True, we need to filter all the arguments to only
        # include those relevant to the latest update. We still need action probabilities
        # from the beginning for the weights, but the update function args can be trimmed
        # to the max policy so that the loop in single_user_weighted_RL_estimating_function_stacker
        # is only over one policy.
        if only_latest_block:
            logger.info(
                "Filtering algorithm update function arguments to only include those relevant to the latest update."
            )
            max_policy_num = max(beta_index_by_policy_num)
            update_func_args_by_by_user_id_by_policy_num = {
                max_policy_num: update_func_args_by_by_user_id_by_policy_num[
                    max_policy_num
                ]
            }

        # 2. Thread in the betas and theta in all_post_update_betas_and_theta into the arguments
        # supplied for the above functions, so that differentiation works correctly. The existing
        # values should be the same, but not connected to the parameter we are differentiating
        # with respect to. Note we will also find it useful below to have the action probability args
        # nested dict structure flipped to be user_id -> decision_time -> args, so we do that here too.

        logger.info("Threading in betas to action probability arguments for all users.")
        (
            threaded_action_prob_func_args_by_decision_time_by_user_id,
            action_prob_func_args_by_decision_time_by_user_id,
        ) = thread_action_prob_func_args(
            action_prob_func_args_by_user_id_by_decision_time,
            policy_num_by_decision_time_by_user_id,
            initial_policy_num,
            betas,
            beta_index_by_policy_num,
            action_prob_func_args_beta_index,
        )

        # 3. Thread the central betas into the algorithm update function arguments
        # and replace any action probabilities with reconstructed ones from the above
        # arguments with the central betas introduced.
        logger.info(
            "Threading in betas and beta-dependent action probabilities to algorithm update "
            "function args for all users"
        )
        threaded_update_func_args_by_policy_num_by_user_id = thread_update_func_args(
            update_func_args_by_by_user_id_by_policy_num,
            betas,
            beta_index_by_policy_num,
            alg_update_func_args_beta_index,
            alg_update_func_args_action_prob_index,
            alg_update_func_args_action_prob_times_index,
            threaded_action_prob_func_args_by_decision_time_by_user_id,
            action_prob_func,
        )

        # If action probabilities are used in the algorithm estimating function, make
        # sure that substituting in the reconstructed action probabilities is approximately
        # equivalent to using the original action probabilities.
        if not suppress_all_data_checks and alg_update_func_args_action_prob_index >= 0:
            input_checks.require_threaded_algorithm_estimating_function_args_equivalent(
                algorithm_estimating_func,
                update_func_args_by_by_user_id_by_policy_num,
                threaded_update_func_args_by_policy_num_by_user_id,
                suppress_interactive_data_checks,
            )

        # 5. Now we can compute the weighted estimating function stacks for all users
        # as well as collect related values used to construct the adaptive and classical
        # sandwich variances.
        RL_stacks = jnp.array(
            [
                self.single_user_weighted_RL_estimating_function_stacker(
                    beta_dim,
                    user_id,
                    action_prob_func,
                    algorithm_estimating_func,
                    action_prob_func_args_beta_index,
                    action_prob_func_args_by_decision_time_by_user_id[user_id],
                    threaded_action_prob_func_args_by_decision_time_by_user_id[user_id],
                    threaded_update_func_args_by_policy_num_by_user_id[user_id],
                    policy_num_by_decision_time_by_user_id[user_id],
                    action_by_decision_time_by_user_id[user_id],
                    beta_index_by_policy_num,
                )
                for user_id in user_ids.tolist()
            ]
        )

        # 6. We will differentiate the first output, while the second will be used
        # for an estimating function sum check.
        mean_stack_across_users = jnp.mean(RL_stacks, axis=0)
        return mean_stack_across_users, mean_stack_across_users

    def single_user_weighted_RL_estimating_function_stacker(
        self,
        beta_dim: int,
        user_id: collections.abc.Hashable,
        action_prob_func: callable,
        algorithm_estimating_func: callable,
        action_prob_func_args_beta_index: int,
        action_prob_func_args_by_decision_time: dict[
            int, dict[collections.abc.Hashable, tuple[Any, ...]]
        ],
        threaded_action_prob_func_args_by_decision_time: dict[
            collections.abc.Hashable, dict[int, tuple[Any, ...]]
        ],
        threaded_update_func_args_by_policy_num: dict[
            collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
        ],
        policy_num_by_decision_time: dict[
            collections.abc.Hashable, dict[int, int | float]
        ],
        action_by_decision_time: dict[collections.abc.Hashable, dict[int, int]],
        beta_index_by_policy_num: dict[int | float, int],
    ) -> jnp.ndarray:
        """
        Computes the RL portion of the weighted estimating function stack for a single user, given an
        algorithm estimating function and arguments and an action probability function and arguments.

        Args:
            beta_dim (int):
                The dimension of each of the beta parameters.

            user_id (collections.abc.Hashable):
                The user ID for which to compute the weighted estimating function stack.

            action_prob_func (callable):
                The function used to compute the probability of action 1 at a given decision time for
                a particular user given their state and the algorithm parameters.

            algorithm_estimating_func (callable):
                The estimating function that corresponds to algorithm updates.

            action_prob_func_args_beta_index (int):
                The index of the beta argument in the action probability function's arguments.

            action_prob_func_args_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
                A map from decision times to tuples of arguments for this user for the action
                probability function. This is for all decision times (args are an empty
                tuple if they are not in the study). Should be sorted by decision time. NOTE THAT THESE
                ARGS DO NOT CONTAIN THE SHARED BETAS, making them impervious to the differentiation that
                will occur.

            threaded_action_prob_func_args_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
                A map from decision times to tuples of arguments for the action
                probability function, with the shared betas threaded in for differentiation. Decision
                times should be sorted.

            threaded_update_func_args_by_policy_num (dict[int | float, dict[collections.abc.Hashable, tuple[Any, ...]]]):
                A map from policy numbers to tuples containing the arguments for
                the corresponding estimating functions for this user, with the shared betas threaded in
                for differentiation. This is for all non-initial, non-fallback policies. Policy numbers
                should be sorted.

            policy_num_by_decision_time (dict[collections.abc.Hashable, dict[int, int | float]]):
                A dictionary mapping decision times to the policy number in use. This may be
                user-specific. Should be sorted by decision time. Only applies to in-study decision
                times!

            action_by_decision_time (dict[collections.abc.Hashable, dict[int, int]]):
                A dictionary mapping decision times to actions taken. Only applies to in-study decision
                times!

            beta_index_by_policy_num (dict[int | float, int]):
                A dictionary mapping policy numbers to the index of the corresponding beta in
                all_post_update_betas. Note that this is only for non-initial, non-fallback policies.

        Returns:
            jnp.ndarray: A 1-D JAX NumPy array representing the RL portion of the user's weighted
                estimating function stack.
        """

        logger.info(
            "Computing weighted estimating function stack for user %s.", user_id
        )

        # First, reformat the supplied data into more convenient structures.

        # 1. Form a dictionary mapping policy numbers to the first time they were
        # applicable (for this user). Note that this includes ALL policies, initial
        # fallbacks included.
        # Collect the first time after the first update separately for convenience.
        # These are both used to form the Radon-Nikodym weights for the right times.
        min_time_by_policy_num, first_time_after_first_update = (
            get_min_time_by_policy_num(
                policy_num_by_decision_time,
                beta_index_by_policy_num,
            )
        )

        # 2. Get the start and end times for this user.
        user_start_time = math.inf
        user_end_time = -math.inf
        for decision_time in action_by_decision_time:
            user_start_time = min(user_start_time, decision_time)
            user_end_time = max(user_end_time, decision_time)

        # 3. Form a stack of weighted estimating equations, one for each update of the algorithm.
        logger.info(
            "Computing the algorithm component of the weighted estimating function stack for user %s.",
            user_id,
        )

        in_study_action_prob_func_args = [
            args for args in action_prob_func_args_by_decision_time.values() if args
        ]
        in_study_betas_list_by_decision_time_index = jnp.array(
            [
                action_prob_func_args[action_prob_func_args_beta_index]
                for action_prob_func_args in in_study_action_prob_func_args
            ]
        )
        in_study_actions_list_by_decision_time_index = jnp.array(
            list(action_by_decision_time.values())
        )

        # Sort the threaded args by decision time to be cautious. We check that each decision
        # time is present in the threaded args dict because we may call this on a subset of the
        # args when we are batching arguments by shape.
        sorted_threaded_action_prob_args_by_decision_time = {
            decision_time: threaded_action_prob_func_args_by_decision_time[
                decision_time
            ]
            for decision_time in range(user_start_time, user_end_time + 1)
            if decision_time in threaded_action_prob_func_args_by_decision_time
        }

        num_args = None
        for args in sorted_threaded_action_prob_args_by_decision_time.values():
            if args:
                num_args = len(args)
                break

        # NOTE: Cannot do [[]] * num_args here! Then all lists point to the
        # same object...
        batched_threaded_arg_lists = [[] for _ in range(num_args)]
        for (
            decision_time,
            args,
        ) in sorted_threaded_action_prob_args_by_decision_time.items():
            if not args:
                continue
            for idx, arg in enumerate(args):
                batched_threaded_arg_lists[idx].append(arg)

        batched_threaded_arg_tensors, batch_axes = stack_batched_arg_lists_into_tensors(
            batched_threaded_arg_lists
        )

        # Note that we do NOT use the shared betas in the first arg to the weight function,
        # since we don't want differentiation to happen with respect to them.
        # Just grab the original beta from the update function arguments. This is the same
        # value, but impervious to differentiation with respect to all_post_update_betas. The
        # args, on the other hand, are a function of all_post_update_betas.
        in_study_weights = jax.vmap(
            fun=get_radon_nikodym_weight,
            in_axes=[0, None, None, 0] + batch_axes,
            out_axes=0,
        )(
            in_study_betas_list_by_decision_time_index,
            action_prob_func,
            action_prob_func_args_beta_index,
            in_study_actions_list_by_decision_time_index,
            *batched_threaded_arg_tensors,
        )

        in_study_index = 0
        decision_time_to_all_weights_index_offset = min(
            sorted_threaded_action_prob_args_by_decision_time
        )
        all_weights_raw = []
        for (
            decision_time,
            args,
        ) in sorted_threaded_action_prob_args_by_decision_time.items():
            all_weights_raw.append(in_study_weights[in_study_index] if args else 1.0)
            in_study_index += 1
        all_weights = jnp.array(all_weights_raw)

        algorithm_component = jnp.concatenate(
            [
                # Here we compute a product of Radon-Nikodym weights
                # for all decision times after the first update and before the update
                # under consideration took effect, for which the user was in the study.
                (
                    jnp.prod(
                        all_weights[
                            # The earliest time after the first update where the user was in
                            # the study
                            max(
                                first_time_after_first_update,
                                user_start_time,
                            )
                            - decision_time_to_all_weights_index_offset :
                            # One more than the latest time the user was in the study before the time
                            # the update under consideration first applied. Note the + 1 because range
                            # does not include the right endpoint.
                            min(
                                min_time_by_policy_num.get(policy_num, math.inf),
                                user_end_time + 1,
                            )
                            - decision_time_to_all_weights_index_offset,
                        ]
                        # If the user exited the study before there were any updates,
                        # this variable will be None and the above code to grab a weight would
                        # throw an error. Just use 1 to include the unweighted estimating function
                        # if they have data to contribute to the update.
                        if first_time_after_first_update is not None
                        else 1
                    )  # Now use the above to weight the alg estimating function for this update
                    * algorithm_estimating_func(*update_args)
                    # If there are no arguments for the update function, the user is not yet in the
                    # study, so we just add a zero vector contribution to the sum across users.
                    # Note that after they exit, they still contribute all their data to later
                    # updates.
                    if update_args
                    else jnp.zeros(beta_dim)
                )
                # vmapping over this would be tricky due to different shapes across updates
                for policy_num, update_args in threaded_update_func_args_by_policy_num.items()
            ]
        )

        if algorithm_component.size % beta_dim != 0:
            raise ValueError(
                "The algorithm component of the weighted estimating function stack does not have a "
                "size that is a multiple of the beta dimension. This likely means that the "
                "algorithm estimating function is not returning a vector of the correct size."
            )

        return algorithm_component
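# Illustrative sketch (not part of the released file) of the Radon-Nikodym weight idea the
# stacker above relies on: for a single binary decision, the weight is the probability of the
# observed action under the betas threaded in for differentiation divided by its probability
# under the betas actually used at that decision time. The sigmoid policy and the helper below
# are hypothetical stand-ins, not the package's get_radon_nikodym_weight.
#
#   def toy_action_prob(beta, state):
#       return jax.nn.sigmoid(jnp.dot(beta, state))
#
#   def toy_radon_nikodym_weight(threaded_beta, original_beta, action, state):
#       p_threaded = toy_action_prob(threaded_beta, state)
#       p_original = toy_action_prob(original_beta, state)
#       numerator = jnp.where(action == 1, p_threaded, 1.0 - p_threaded)
#       denominator = jnp.where(action == 1, p_original, 1.0 - p_original)
#       return numerator / denominator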