lifejacket 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/__init__.py +0 -0
- lifejacket/after_study_analysis.py +1845 -0
- lifejacket/arg_threading_helpers.py +354 -0
- lifejacket/calculate_derivatives.py +965 -0
- lifejacket/constants.py +28 -0
- lifejacket/form_adaptive_meat_adjustments_directly.py +333 -0
- lifejacket/get_datum_for_blowup_supervised_learning.py +1312 -0
- lifejacket/helper_functions.py +587 -0
- lifejacket/input_checks.py +1145 -0
- lifejacket/small_sample_corrections.py +125 -0
- lifejacket/trial_conditioning_monitor.py +870 -0
- lifejacket/vmap_helpers.py +71 -0
- lifejacket-0.1.0.dist-info/METADATA +100 -0
- lifejacket-0.1.0.dist-info/RECORD +17 -0
- lifejacket-0.1.0.dist-info/WHEEL +5 -0
- lifejacket-0.1.0.dist-info/entry_points.txt +2 -0
- lifejacket-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1312 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import math
|
|
5
|
+
from typing import Any
|
|
6
|
+
import collections
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from scipy.special import logit
|
|
10
|
+
import plotext as plt
|
|
11
|
+
import jax
|
|
12
|
+
from jax import numpy as jnp
|
|
13
|
+
import pandas as pd
|
|
14
|
+
|
|
15
|
+
from . import after_study_analysis
|
|
16
|
+
from .constants import FunctionTypes
|
|
17
|
+
from .vmap_helpers import stack_batched_arg_lists_into_tensors
|
|
18
|
+
|
|
19
|
+
# Module-level logger for this file.
# NOTE(review): calling logging.basicConfig at import time configures the *root*
# logger as a side effect of importing this module — unusual for library code;
# confirm this is intentional for the package's CLI usage.
logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s,%(msecs)03d %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def get_datum_for_blowup_supervised_learning(
    joint_adaptive_bread_inverse_matrix,
    joint_adaptive_bread_inverse_cond,
    avg_estimating_function_stack,
    per_user_estimating_function_stacks,
    all_post_update_betas,
    study_df,
    in_study_col_name,
    calendar_t_col_name,
    action_prob_col_name,
    user_id_col_name,
    reward_col_name,
    theta_est,
    adaptive_sandwich_var_estimate,
    user_ids,
    beta_dim,
    theta_dim,
    initial_policy_num,
    beta_index_by_policy_num,
    policy_num_by_decision_time_by_user_id,
    theta_calculation_func,
    action_prob_func,
    action_prob_func_args_beta_index,
    inference_func,
    inference_func_type,
    inference_func_args_theta_index,
    inference_func_args_action_prob_index,
    inference_action_prob_decision_times_by_user_id,
    action_prob_func_args,
    action_by_decision_time_by_user_id,
) -> dict[str, Any]:
    """
    Collects a datum for supervised learning about adaptive sandwich blowup.

    The datum consists of features and the raw adaptive sandwich variance estimate as a label.

    A few plots are produced along the way to help visualize the data.

    Note: mutates ``study_df`` by adding an ``"action_prob_logit"`` column.

    Args:
        joint_adaptive_bread_inverse_matrix (jnp.ndarray):
            The joint adaptive bread inverse matrix.
        joint_adaptive_bread_inverse_cond (float):
            The condition number of the joint adaptive bread inverse matrix.
        avg_estimating_function_stack (jnp.ndarray):
            The average estimating function stack across users.
        per_user_estimating_function_stacks (jnp.ndarray):
            The estimating function stacks for each user.
        all_post_update_betas (jnp.ndarray):
            All post-update beta parameters.
        study_df (pd.DataFrame):
            The study DataFrame.
        in_study_col_name (str):
            Column name indicating if a user is in the study in the study dataframe.
        calendar_t_col_name (str):
            Column name for calendar time in the study dataframe.
        action_prob_col_name (str):
            Column name for action probabilities in the study dataframe.
        user_id_col_name (str):
            Column name for user IDs in the study dataframe
        reward_col_name (str):
            Column name for rewards in the study dataframe.
        theta_est (jnp.ndarray):
            The estimate of the parameter vector theta.
        adaptive_sandwich_var_estimate (jnp.ndarray):
            The adaptive sandwich variance estimate for theta.
        user_ids (jnp.ndarray):
            Array of unique user IDs.
        beta_dim (int):
            Dimension of the beta parameter vector.
        theta_dim (int):
            Dimension of the theta parameter vector.
        initial_policy_num (int | float):
            The initial policy number used in the study.
        beta_index_by_policy_num (dict[int | float, int]):
            Mapping from policy numbers to indices in all_post_update_betas.
        policy_num_by_decision_time_by_user_id (dict):
            Mapping from user IDs to their policy numbers by decision time.
        theta_calculation_func (callable):
            The theta calculation function.
        action_prob_func (callable):
            The action probability function.
        action_prob_func_args_beta_index (int):
            Index for beta in action probability function arguments.
        inference_func (callable):
            The inference function.
        inference_func_type (str):
            Type of the inference function.
        inference_func_args_theta_index (int):
            Index for theta in inference function arguments.
        inference_func_args_action_prob_index (int):
            Index for action probability in inference function arguments.
        inference_action_prob_decision_times_by_user_id (dict):
            Mapping from user IDs to decision times for action probabilities used in inference.
        action_prob_func_args (dict):
            Arguments for the action probability function.
        action_by_decision_time_by_user_id (dict):
            Mapping from user IDs to their actions by decision time.
    Returns:
        dict[str, Any]: A dictionary containing features and the label for supervised learning.
    """
    # The joint matrix is block-structured: one beta_dim block per policy
    # update, followed by a final theta_dim block for inference.
    num_diagonal_blocks = (
        (joint_adaptive_bread_inverse_matrix.shape[0] - theta_dim) // beta_dim
    ) + 1
    diagonal_block_sizes = ([beta_dim] * (num_diagonal_blocks - 1)) + [theta_dim]

    block_bounds = np.cumsum([0] + list(diagonal_block_sizes))
    num_block_rows_cols = len(diagonal_block_sizes)

    # collect diagonal and sub-diagonal block norms and diagonal condition numbers
    off_diag_block_norms = {}
    diag_norms = []
    diag_conds = []
    off_diag_row_norms = np.zeros(num_block_rows_cols)
    off_diag_col_norms = np.zeros(num_block_rows_cols)
    for i in range(num_block_rows_cols):
        for j in range(num_block_rows_cols):
            if i > j:  # below-diagonal blocks
                row_slice = slice(block_bounds[i], block_bounds[i + 1])
                col_slice = slice(block_bounds[j], block_bounds[j + 1])
                block_norm = np.linalg.norm(
                    joint_adaptive_bread_inverse_matrix[row_slice, col_slice],
                    ord="fro",
                )
                # Accumulate *squared* block norms and take the square root
                # later: the Frobenius norm of a block row/column is
                # sqrt(sum of squared block Frobenius norms). (Fixed: the
                # previous code summed unsquared norms, which is not a norm.)
                off_diag_row_norms[i] += block_norm**2
                off_diag_col_norms[j] += block_norm**2
                off_diag_block_norms[(i, j)] = block_norm

        # handle diagonal blocks
        sl = slice(block_bounds[i], block_bounds[i + 1])
        diag_norms.append(
            np.linalg.norm(joint_adaptive_bread_inverse_matrix[sl, sl], ord="fro")
        )
        diag_conds.append(np.linalg.cond(joint_adaptive_bread_inverse_matrix[sl, sl]))

    # Sqrt each row/col sum of squares to truly get row/column norms.
    # Perhaps not necessary for learning, but more natural
    off_diag_row_norms = np.sqrt(off_diag_row_norms)
    off_diag_col_norms = np.sqrt(off_diag_col_norms)

    # Get the per-person estimating function stack norms
    estimating_function_stack_norms = np.linalg.norm(
        per_user_estimating_function_stacks, axis=1
    )

    # Get the average estimating function stack norms by update/inference.
    # Use the bounds variable from above to split the estimating function stacks
    # into blocks corresponding to the updates and inference.
    avg_estimating_function_stack_norms_per_segment = [
        np.mean(
            np.linalg.norm(
                per_user_estimating_function_stacks[
                    :, block_bounds[i] : block_bounds[i + 1]
                ],
                axis=1,
            )
        )
        for i in range(len(block_bounds) - 1)
    ]

    # Compute the norms of each successive difference in all_post_update_betas.
    successive_beta_diffs = np.diff(np.array(all_post_update_betas), axis=0)
    successive_beta_diff_norms = np.linalg.norm(successive_beta_diffs, axis=1)
    max_successive_beta_diff_norm = np.max(successive_beta_diff_norms)
    std_successive_beta_diff_norm = np.std(successive_beta_diff_norms)

    # Add a column with logits of the action probabilities.
    # Only compute logits for rows where the user is in the study; set others
    # to NaN so they are excluded from grouped statistics below.
    in_study_mask = study_df[in_study_col_name] == 1
    study_df["action_prob_logit"] = np.where(
        in_study_mask,
        logit(study_df[action_prob_col_name]),
        np.nan,
    )
    # Mean/stddev of the action-probability logits at each decision time.
    grouped_action_prob_logit = study_df.loc[in_study_mask].groupby(
        calendar_t_col_name
    )["action_prob_logit"]
    action_prob_logit_means_by_t = grouped_action_prob_logit.mean().values
    action_prob_logit_stds_by_t = grouped_action_prob_logit.std().values

    # Mean/stddev of the rewards at each decision time, in-study rows only.
    grouped_reward = study_df.loc[in_study_mask].groupby(calendar_t_col_name)[
        reward_col_name
    ]
    reward_means_by_t = grouped_reward.mean().values
    reward_stds_by_t = grouped_reward.std().values

    # Singular values from np.linalg.svd are sorted descending, so the last
    # one is the smallest.
    joint_bread_inverse_min_singular_value = np.linalg.svd(
        joint_adaptive_bread_inverse_matrix, compute_uv=False
    )[-1]

    max_reward = study_df.loc[in_study_mask][reward_col_name].max()

    norm_avg_estimating_function_stack = np.linalg.norm(avg_estimating_function_stack)
    max_estimating_function_stack_norm = np.max(estimating_function_stack_norms)

    (
        premature_thetas,
        premature_adaptive_sandwiches,
        premature_classical_sandwiches,
        premature_joint_adaptive_bread_inverse_condition_numbers,
        premature_avg_inference_estimating_functions,
    ) = calculate_sequence_of_premature_adaptive_estimates(
        study_df,
        initial_policy_num,
        beta_index_by_policy_num,
        policy_num_by_decision_time_by_user_id,
        theta_calculation_func,
        calendar_t_col_name,
        action_prob_col_name,
        user_id_col_name,
        in_study_col_name,
        all_post_update_betas,
        user_ids,
        action_prob_func,
        action_prob_func_args_beta_index,
        inference_func,
        inference_func_type,
        inference_func_args_theta_index,
        inference_func_args_action_prob_index,
        inference_action_prob_decision_times_by_user_id,
        action_prob_func_args,
        action_by_decision_time_by_user_id,
        joint_adaptive_bread_inverse_matrix,
        per_user_estimating_function_stacks,
        beta_dim,
    )

    # Sanity check: the average inference estimating functions should be ~0 at
    # each premature solution (they were solved to be zero).
    np.testing.assert_allclose(
        np.zeros_like(premature_avg_inference_estimating_functions),
        premature_avg_inference_estimating_functions,
        atol=1e-3,
    )

    # Plot premature joint adaptive bread inverse log condition numbers
    plt.clear_figure()
    plt.title("Premature Joint Adaptive Bread Inverse Log Condition Numbers")
    plt.xlabel("Premature Update Index")
    plt.ylabel("Log Condition Number")
    plt.scatter(
        np.log(premature_joint_adaptive_bread_inverse_condition_numbers),
        color="blue+",
    )
    plt.grid(True)
    plt.xticks(
        range(
            0,
            len(premature_joint_adaptive_bread_inverse_condition_numbers),
            max(
                1,
                len(premature_joint_adaptive_bread_inverse_condition_numbers) // 10,
            ),
        )
    )
    plt.show()

    # Plot each diagonal element of premature adaptive sandwiches
    num_diag = premature_adaptive_sandwiches.shape[-1]
    for i in range(num_diag):
        plt.clear_figure()
        plt.title(f"Premature Adaptive Sandwich Diagonal Element {i}")
        plt.xlabel("Premature Update Index")
        plt.ylabel(f"Variance (Diagonal {i})")
        plt.scatter(np.array(premature_adaptive_sandwiches[:, i, i]), color="blue+")
        plt.grid(True)
        plt.xticks(
            range(
                0,
                int(premature_adaptive_sandwiches.shape[0]),
                max(1, int(premature_adaptive_sandwiches.shape[0]) // 10),
            )
        )
        plt.show()

        plt.clear_figure()
        plt.title(
            f"Premature Adaptive Sandwich Diagonal Element {i} Ratio to Classical"
        )
        plt.xlabel("Premature Update Index")
        plt.ylabel(f"Variance (Diagonal {i})")
        plt.scatter(
            np.array(premature_adaptive_sandwiches[:, i, i])
            / np.array(premature_classical_sandwiches[:, i, i]),
            color="red+",
        )
        plt.grid(True)
        plt.xticks(
            range(
                0,
                int(premature_adaptive_sandwiches.shape[0]),
                max(1, int(premature_adaptive_sandwiches.shape[0]) // 10),
            )
        )
        plt.show()

        plt.clear_figure()
        plt.title(f"Premature Theta Estimates At Index {i}")
        plt.xlabel("Premature Update Index")
        plt.ylabel(f"Theta element {i}")
        plt.scatter(np.array(premature_thetas[:, i]), color="green+")
        plt.grid(True)
        plt.xticks(
            range(
                0,
                int(premature_adaptive_sandwiches.shape[0]),
                max(1, int(premature_adaptive_sandwiches.shape[0]) // 10),
            )
        )
        plt.show()

    # Grab predictors related to premature Phi-dot-bars
    RL_stack_beta_derivatives_block = joint_adaptive_bread_inverse_matrix[
        :-theta_dim, :-theta_dim
    ]
    num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
    premature_RL_block_condition_numbers = []
    premature_RL_block_inverse_norms = []
    diagonal_RL_block_condition_numbers = []
    off_diagonal_RL_scaled_block_norm_sums = []
    for i in range(1, num_updates + 1):
        whole_block_size = i * beta_dim
        whole_block = RL_stack_beta_derivatives_block[
            :whole_block_size, :whole_block_size
        ]
        whole_RL_block_cond_number = np.linalg.cond(whole_block)
        premature_RL_block_condition_numbers.append(whole_RL_block_cond_number)
        logger.info(
            "Condition number of whole RL_stack_beta_derivatives_block (after update %s): %s",
            i,
            whole_RL_block_cond_number,
        )
        diagonal_block = RL_stack_beta_derivatives_block[
            (i - 1) * beta_dim : i * beta_dim, (i - 1) * beta_dim : i * beta_dim
        ]
        diagonal_RL_block_cond_number = np.linalg.cond(diagonal_block)
        diagonal_RL_block_condition_numbers.append(diagonal_RL_block_cond_number)
        logger.info(
            "Condition number of just RL_stack_beta_derivatives_block *diagonal block* for update %s: %s",
            i,
            diagonal_RL_block_cond_number,
        )

        premature_RL_block_inverse_norms.append(
            np.linalg.norm(np.linalg.inv(whole_block))
        )
        logger.info(
            "Norm of inverse of whole RL_stack_beta_derivatives_block (after update %s): %s",
            i,
            premature_RL_block_inverse_norms[-1],
        )

        off_diagonal_RL_scaled_block_norm_sum = 0
        for j in range(1, i):
            off_diagonal_block = RL_stack_beta_derivatives_block[
                (i - 1) * beta_dim : i * beta_dim, (j - 1) * beta_dim : j * beta_dim
            ]
            off_diagonal_scaled_block_norm = np.linalg.norm(
                np.linalg.solve(diagonal_block, off_diagonal_block)
            )
            off_diagonal_RL_scaled_block_norm_sum += off_diagonal_scaled_block_norm
        off_diagonal_RL_scaled_block_norm_sums.append(
            off_diagonal_RL_scaled_block_norm_sum
        )
        logger.info(
            "Sum of norms of off-diagonal blocks in row %s scaled by inverse of diagonal block: %s",
            i,
            off_diagonal_RL_scaled_block_norm_sum,
        )
    return {
        **{
            "joint_bread_inverse_condition_number": joint_adaptive_bread_inverse_cond,
            "joint_bread_inverse_min_singular_value": joint_bread_inverse_min_singular_value,
            "max_reward": max_reward,
            "norm_avg_estimating_function_stack": norm_avg_estimating_function_stack,
            "max_estimating_function_stack_norm": max_estimating_function_stack_norm,
            "max_successive_beta_diff_norm": max_successive_beta_diff_norm,
            "std_successive_beta_diff_norm": std_successive_beta_diff_norm,
            "label": adaptive_sandwich_var_estimate,
        },
        **{
            f"off_diag_block_{i}_{j}_norm": off_diag_block_norms[(i, j)]
            for i in range(num_block_rows_cols)
            for j in range(i)
        },
        **{f"diag_block_{i}_norm": diag_norms[i] for i in range(num_block_rows_cols)},
        **{f"diag_block_{i}_cond": diag_conds[i] for i in range(num_block_rows_cols)},
        **{
            f"off_diag_row_{i}_norm": off_diag_row_norms[i]
            for i in range(num_block_rows_cols)
        },
        **{
            f"off_diag_col_{i}_norm": off_diag_col_norms[i]
            for i in range(num_block_rows_cols)
        },
        **{
            f"estimating_function_stack_norm_user_{user_id}": estimating_function_stack_norms[
                i
            ]
            for i, user_id in enumerate(user_ids)
        },
        **{
            f"avg_estimating_function_stack_norm_segment_{i}": avg_estimating_function_stack_norms_per_segment[
                i
            ]
            for i in range(len(avg_estimating_function_stack_norms_per_segment))
        },
        **{
            f"successive_beta_diff_norm_{i}": successive_beta_diff_norms[i]
            for i in range(len(successive_beta_diff_norms))
        },
        **{
            f"action_prob_logit_mean_t_{t}": action_prob_logit_means_by_t[t]
            for t in range(len(action_prob_logit_means_by_t))
        },
        **{
            f"action_prob_logit_std_t_{t}": action_prob_logit_stds_by_t[t]
            for t in range(len(action_prob_logit_stds_by_t))
        },
        **{
            f"reward_mean_t_{t}": reward_means_by_t[t]
            for t in range(len(reward_means_by_t))
        },
        **{
            f"reward_std_t_{t}": reward_stds_by_t[t]
            for t in range(len(reward_stds_by_t))
        },
        **{f"theta_est_{i}": theta_est[i].item() for i in range(len(theta_est))},
        **{
            f"premature_joint_adaptive_bread_inverse_condition_number_{i}": premature_joint_adaptive_bread_inverse_condition_numbers[
                i
            ]
            for i in range(
                len(premature_joint_adaptive_bread_inverse_condition_numbers)
            )
        },
        # BUGFIX: bind i via enumerate. Previously `i` leaked from the
        # RL-block loop above, so every update produced the same key and all
        # but the last update's diagonals were silently dropped.
        **{
            f"premature_adaptive_sandwich_update_{i}_diag_position_{j}": premature_adaptive_sandwich[
                j, j
            ]
            for i, premature_adaptive_sandwich in enumerate(
                premature_adaptive_sandwiches
            )
            for j in range(theta_dim)
        },
        **{
            f"premature_classical_sandwich_update_{i}_diag_position_{j}": premature_classical_sandwich[
                j, j
            ]
            for i, premature_classical_sandwich in enumerate(
                premature_classical_sandwiches
            )
            for j in range(theta_dim)
        },
        **{
            f"off_diagonal_RL_scaled_block_norm_sum_for_update_{i}": off_diagonal_RL_scaled_block_norm_sums[
                i
            ]
            for i in range(len(off_diagonal_RL_scaled_block_norm_sums))
        },
        **{
            f"premature_RL_block_condition_number_after_update_{i}": premature_RL_block_condition_numbers[
                i
            ]
            for i in range(len(premature_RL_block_condition_numbers))
        },
        **{
            f"premature_RL_block_inverse_norm_after_update_{i}": premature_RL_block_inverse_norms[
                i
            ]
            for i in range(len(premature_RL_block_inverse_norms))
        },
    }
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
def calculate_sequence_of_premature_adaptive_estimates(
    study_df: pd.DataFrame,
    initial_policy_num: int | float,
    beta_index_by_policy_num: dict[int | float, int],
    policy_num_by_decision_time_by_user_id: dict[
        collections.abc.Hashable, dict[int, int | float]
    ],
    theta_calculation_func: str,
    calendar_t_col_name: str,
    action_prob_col_name: str,
    user_id_col_name: str,
    in_study_col_name: str,
    all_post_update_betas: jnp.ndarray,
    user_ids: jnp.ndarray,
    action_prob_func: str,
    action_prob_func_args_beta_index: int,
    inference_func: str,
    inference_func_type: str,
    inference_func_args_theta_index: int,
    inference_func_args_action_prob_index: int,
    inference_action_prob_decision_times_by_user_id: dict[
        collections.abc.Hashable, list[int]
    ],
    action_prob_func_args_by_user_id_by_decision_time: dict[
        int, dict[collections.abc.Hashable, tuple[Any, ...]]
    ],
    action_by_decision_time_by_user_id: dict[collections.abc.Hashable, dict[int, int]],
    full_joint_adaptive_bread_inverse_matrix: jnp.ndarray,
    per_user_estimating_function_stacks: jnp.ndarray,
    beta_dim: int,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Calculates a sequence of premature adaptive estimates for the given study DataFrame, where we
    pretend the study ended after each update in sequence. The behavior of this sequence may provide
    insight into the stability of the final adaptive estimate.

    Args:
        study_df (pandas.DataFrame):
            The DataFrame containing the study data.
        initial_policy_num (int | float):
            The policy number of the initial policy before any updates.
        beta_index_by_policy_num (dict[int | float, int]):
            A dictionary mapping policy numbers to the index of the corresponding beta in
            all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
        policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
            A map of user ids to dictionaries mapping decision times to the policy number in use.
            Only applies to in-study decision times!
        theta_calculation_func (callable):
            The theta calculation function.
        calendar_t_col_name (str):
            The name of the column in study_df representing calendar time.
        action_prob_col_name (str):
            The name of the column in study_df representing action probabilities.
        user_id_col_name (str):
            The name of the column in study_df representing user IDs.
        in_study_col_name (str):
            The name of the column in study_df indicating whether the user is in the study at that time.
        all_post_update_betas (jnp.ndarray):
            A NumPy array containing all post-update beta values.
        user_ids (jnp.ndarray):
            A NumPy array containing all user IDs in the study.
        action_prob_func (callable):
            The action probability function.
        action_prob_func_args_beta_index (int):
            The index of beta in the action probability function arguments tuples.
        inference_func (callable):
            The inference function.
        inference_func_type (str):
            The type of the inference function (loss or estimating).
        inference_func_args_theta_index (int):
            The index of the theta parameter in the inference function arguments tuples.
        inference_func_args_action_prob_index (int):
            The index of action probabilities in the inference function arguments tuple, if
            applicable. -1 otherwise.
        inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
            For each user, a list of decision times to which action probabilities correspond if
            provided. Typically just in-study times if action probabilities are used in the inference
            loss or estimating function.
        action_prob_func_args_by_user_id_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
            A dictionary mapping decision times to maps of user ids to the function arguments
            required to compute action probabilities for this user.
        action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
            A dictionary mapping user IDs to their respective actions taken at each decision time.
            Only applies to in-study decision times!
        full_joint_adaptive_bread_inverse_matrix (jnp.ndarray):
            The full joint adaptive bread inverse matrix as a NumPy array.
        per_user_estimating_function_stacks (jnp.ndarray):
            A NumPy array containing all per-user (weighted) estimating function stacks.
        beta_dim (int):
            The dimension of the beta parameters.
    Returns:
        tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
            Premature thetas, premature adaptive sandwiches, premature classical sandwiches,
            condition numbers of the truncated joint adaptive bread inverse matrices, and the
            average inference estimating functions at each premature solution.
    """

    # Loop through the non-initial (ie not before an update has occurred), non-final policy numbers in sorted order, forming adaptive and classical
    # variance estimates pretending that each was the final policy.
    premature_adaptive_sandwiches = []
    premature_thetas = []
    premature_joint_adaptive_bread_inverse_condition_numbers = []
    premature_avg_inference_estimating_functions = []
    premature_classical_sandwiches = []
    logger.info(
        "Calculating sequence of premature adaptive estimates by pretending the study ended after each update in sequence."
    )
    for policy_num in sorted(beta_index_by_policy_num):
        logger.info(
            "Calculating premature adaptive estimate assuming policy %s is the final one.",
            policy_num,
        )
        pretend_max_policy = policy_num

        # Keep only the rows/columns of the bread inverse corresponding to
        # updates up to (and including) the pretend-final policy.
        truncated_joint_adaptive_bread_inverse_matrix = (
            full_joint_adaptive_bread_inverse_matrix[
                : (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
                : (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
            ]
        )

        # BUGFIX: this list was previously never appended to, so the function
        # always returned an empty array of condition numbers even though the
        # caller plots them and builds features from them.
        premature_joint_adaptive_bread_inverse_condition_numbers.append(
            np.linalg.cond(truncated_joint_adaptive_bread_inverse_matrix)
        )

        max_decision_time = study_df[study_df["policy_num"] == pretend_max_policy][
            calendar_t_col_name
        ].max()

        truncated_study_df = study_df[
            study_df[calendar_t_col_name] <= max_decision_time
        ].copy()

        truncated_beta_index_by_policy_num = {
            k: v for k, v in beta_index_by_policy_num.items() if k <= pretend_max_policy
        }

        max_beta_index = max(truncated_beta_index_by_policy_num.values())

        truncated_all_post_update_betas = all_post_update_betas[: max_beta_index + 1, :]

        premature_theta = jnp.array(theta_calculation_func(truncated_study_df))

        truncated_action_prob_func_args_by_user_id_by_decision_time = {
            decision_time: args_by_user_id
            for decision_time, args_by_user_id in action_prob_func_args_by_user_id_by_decision_time.items()
            if decision_time <= max_decision_time
        }

        truncated_inference_func_args_by_user_id, _, _ = (
            after_study_analysis.process_inference_func_args(
                inference_func,
                inference_func_args_theta_index,
                truncated_study_df,
                premature_theta,
                action_prob_col_name,
                calendar_t_col_name,
                user_id_col_name,
                in_study_col_name,
            )
        )

        truncated_inference_action_prob_decision_times_by_user_id = {
            user_id: [
                decision_time
                for decision_time in inference_action_prob_decision_times_by_user_id[
                    user_id
                ]
                if decision_time <= max_decision_time
            ]
            # writing this way is important, handles empty dicts correctly
            for user_id in inference_action_prob_decision_times_by_user_id
        }

        truncated_action_by_decision_time_by_user_id = {
            user_id: {
                decision_time: action
                for decision_time, action in action_by_decision_time_by_user_id[
                    user_id
                ].items()
                if decision_time <= max_decision_time
            }
            for user_id in action_by_decision_time_by_user_id
        }

        truncated_per_user_estimating_function_stacks = (
            per_user_estimating_function_stacks[
                :,
                : (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
            ]
        )

        (
            premature_adaptive_sandwich,
            premature_classical_sandwich,
            premature_avg_inference_estimating_function,
        ) = construct_premature_classical_and_adaptive_sandwiches(
            truncated_joint_adaptive_bread_inverse_matrix,
            truncated_per_user_estimating_function_stacks,
            premature_theta,
            truncated_all_post_update_betas,
            user_ids,
            action_prob_func,
            action_prob_func_args_beta_index,
            inference_func,
            inference_func_type,
            inference_func_args_theta_index,
            inference_func_args_action_prob_index,
            truncated_action_prob_func_args_by_user_id_by_decision_time,
            policy_num_by_decision_time_by_user_id,
            initial_policy_num,
            truncated_beta_index_by_policy_num,
            truncated_inference_func_args_by_user_id,
            truncated_inference_action_prob_decision_times_by_user_id,
            truncated_action_by_decision_time_by_user_id,
        )

        premature_adaptive_sandwiches.append(premature_adaptive_sandwich)
        premature_classical_sandwiches.append(premature_classical_sandwich)
        premature_thetas.append(premature_theta)
        premature_avg_inference_estimating_functions.append(
            premature_avg_inference_estimating_function
        )
    return (
        jnp.array(premature_thetas),
        jnp.array(premature_adaptive_sandwiches),
        jnp.array(premature_classical_sandwiches),
        jnp.array(premature_joint_adaptive_bread_inverse_condition_numbers),
        jnp.array(premature_avg_inference_estimating_functions),
    )
|
|
724
|
+
|
|
725
|
+
|
|
726
|
+
def construct_premature_classical_and_adaptive_sandwiches(
    truncated_joint_adaptive_bread_inverse_matrix: jnp.ndarray,
    per_user_truncated_estimating_function_stacks: jnp.ndarray,
    theta: jnp.ndarray,
    all_post_update_betas: jnp.ndarray,
    user_ids: jnp.ndarray,
    action_prob_func: callable,
    action_prob_func_args_beta_index: int,
    inference_func: callable,
    inference_func_type: str,
    inference_func_args_theta_index: int,
    inference_func_args_action_prob_index: int,
    action_prob_func_args_by_user_id_by_decision_time: dict[
        collections.abc.Hashable, dict[int, tuple[Any, ...]]
    ],
    policy_num_by_decision_time_by_user_id: dict[
        collections.abc.Hashable, dict[int, int | float]
    ],
    initial_policy_num: int | float,
    beta_index_by_policy_num: dict[int | float, int],
    inference_func_args_by_user_id: dict[collections.abc.Hashable, tuple[Any, ...]],
    inference_action_prob_decision_times_by_user_id: dict[
        collections.abc.Hashable, list[int]
    ],
    action_by_decision_time_by_user_id: dict[collections.abc.Hashable, dict[int, int]],
) -> tuple[
    jnp.ndarray,
    jnp.ndarray,
    jnp.ndarray,
]:
    """
    Constructs the "premature" adaptive and classical sandwich variance estimates, along
    with the average weighted inference estimating function.

    This is done by computing and differentiating the new average inference estimating
    function with respect to the betas and theta, and stitching this together with the
    existing adaptive bread inverse matrix portion (corresponding to the updates still
    under consideration) to form the new premature joint adaptive bread inverse matrix.
    The adaptive meat is then built from the per-user estimating function stacks with
    the new inference estimating functions appended.

    Args:
        truncated_joint_adaptive_bread_inverse_matrix (jnp.ndarray):
            A 2-D JAX NumPy array holding the existing joint adaptive bread inverse but
            with rows corresponding to updates not under consideration and inference dropped.
            We will stitch this together with the newly computed inference portion to form
            our "premature" joint adaptive bread inverse matrix.
        per_user_truncated_estimating_function_stacks (jnp.ndarray):
            A 2-D JAX NumPy array holding the existing per-user weighted estimating function
            stacks but with rows corresponding to updates not under consideration dropped.
            We will stitch this together with the newly computed inference estimating functions
            to form our "premature" joint adaptive estimating function stacks from which the new
            adaptive meat matrix can be computed.
        theta (jnp.ndarray):
            A 1-D JAX NumPy array representing the parameter estimate for inference.
        all_post_update_betas (jnp.ndarray):
            A 2-D JAX NumPy array representing all parameter estimates for the algorithm updates.
        user_ids (jnp.ndarray):
            A 1-D JAX NumPy array holding all user IDs in the study.
        action_prob_func (callable):
            The action probability function.
        action_prob_func_args_beta_index (int):
            The index of beta in the action probability function arguments tuples.
        inference_func (callable):
            The inference loss or estimating function.
        inference_func_type (str):
            The type of the inference function (loss or estimating).
        inference_func_args_theta_index (int):
            The index of the theta parameter in the inference function arguments tuples.
        inference_func_args_action_prob_index (int):
            The index of action probabilities in the inference function arguments tuple, if
            applicable. -1 otherwise.
        action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
            A dictionary mapping decision times to maps of user ids to the function arguments
            required to compute action probabilities for this user.
        policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
            A map of user ids to dictionaries mapping decision times to the policy number in use.
            Only applies to in-study decision times!
        initial_policy_num (int | float):
            The policy number of the initial policy before any updates.
        beta_index_by_policy_num (dict[int | float, int]):
            A dictionary mapping policy numbers to the index of the corresponding beta in
            all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
        inference_func_args_by_user_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
            A dictionary mapping user IDs to their respective inference function arguments.
        inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
            For each user, a list of decision times to which action probabilities correspond if
            provided. Typically just in-study times if action probabilites are used in the inference
            loss or estimating function.
        action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
            A dictionary mapping user IDs to their respective actions taken at each decision time.
            Only applies to in-study decision times!

    Returns:
        tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
            A tuple containing:
            - The premature adaptive sandwich variance (the theta block of the joint
              adaptive sandwich).
            - The premature classical sandwich variance.
            - The average (weighted) inference estimating function.
    """
    logger.info(
        "Differentiating average weighted inference estimating function stack and collecting auxiliary values."
    )
    # jax.jacobian may perform worse here--seemed to hang indefinitely while jacrev is merely very
    # slow.
    # Note that these "contributions" are per-user Jacobians of the weighted estimating function stack.
    new_inference_block_row, (
        per_user_inference_estimating_functions,
        avg_inference_estimating_function,
        per_user_classical_meat_contributions,
        per_user_classical_bread_inverse_contributions,
    ) = jax.jacrev(get_weighted_inference_estimating_functions_only, has_aux=True)(
        # While JAX can technically differentiate with respect to a list of JAX arrays,
        # it is more efficient to flatten them into a single array. This is done
        # here to improve performance. We can simply unflatten them inside the function.
        after_study_analysis.flatten_params(all_post_update_betas, theta),
        all_post_update_betas.shape[1],
        theta.shape[0],
        user_ids,
        action_prob_func,
        action_prob_func_args_beta_index,
        inference_func,
        inference_func_type,
        inference_func_args_theta_index,
        inference_func_args_action_prob_index,
        action_prob_func_args_by_user_id_by_decision_time,
        policy_num_by_decision_time_by_user_id,
        initial_policy_num,
        beta_index_by_policy_num,
        inference_func_args_by_user_id,
        inference_action_prob_decision_times_by_user_id,
        action_by_decision_time_by_user_id,
    )

    # Stitch the new inference Jacobian row onto the truncated bread inverse. The joint
    # adaptive bread inverse should always be block lower triangular, hence the zero
    # block in the upper right.
    joint_adaptive_bread_inverse_matrix = jnp.block(
        [
            [
                truncated_joint_adaptive_bread_inverse_matrix,
                np.zeros(
                    (
                        truncated_joint_adaptive_bread_inverse_matrix.shape[0],
                        new_inference_block_row.shape[0],
                    )
                ),
            ],
            [new_inference_block_row],
        ]
    )
    # Append the new per-user inference estimating functions to the truncated stacks.
    per_user_estimating_function_stacks = jnp.concatenate(
        [
            per_user_truncated_estimating_function_stacks,
            per_user_inference_estimating_functions,
        ],
        axis=1,
    )
    # Per-user outer products of the stacked estimating functions.
    per_user_adaptive_meat_contributions = jnp.einsum(
        "ni,nj->nij",
        per_user_estimating_function_stacks,
        per_user_estimating_function_stacks,
    )

    joint_adaptive_meat_matrix = jnp.mean(per_user_adaptive_meat_contributions, axis=0)

    classical_bread_inverse_matrix = jnp.mean(
        per_user_classical_bread_inverse_contributions, axis=0
    )
    classical_meat_matrix = jnp.mean(per_user_classical_meat_contributions, axis=0)

    num_users = user_ids.shape[0]
    joint_adaptive_sandwich = (
        after_study_analysis.form_sandwich_from_bread_inverse_and_meat(
            joint_adaptive_bread_inverse_matrix,
            joint_adaptive_meat_matrix,
            num_users,
            method="bread_inverse_T_qr",
        )
    )
    # Only the theta block (bottom-right corner) is the variance of interest.
    adaptive_sandwich = joint_adaptive_sandwich[-theta.shape[0] :, -theta.shape[0] :]

    classical_sandwich = after_study_analysis.form_sandwich_from_bread_inverse_and_meat(
        classical_bread_inverse_matrix,
        classical_meat_matrix,
        num_users,
        method="bread_inverse_T_qr",
    )

    return (
        adaptive_sandwich,
        classical_sandwich,
        avg_inference_estimating_function,
    )
|
|
934
|
+
def get_weighted_inference_estimating_functions_only(
    flattened_betas_and_theta: jnp.ndarray,
    beta_dim: int,
    theta_dim: int,
    user_ids: jnp.ndarray,
    action_prob_func: callable,
    action_prob_func_args_beta_index: int,
    inference_func: callable,
    inference_func_type: str,
    inference_func_args_theta_index: int,
    inference_func_args_action_prob_index: int,
    action_prob_func_args_by_user_id_by_decision_time: dict[
        collections.abc.Hashable, dict[int, tuple[Any, ...]]
    ],
    policy_num_by_decision_time_by_user_id: dict[
        collections.abc.Hashable, dict[int, int | float]
    ],
    initial_policy_num: int | float,
    beta_index_by_policy_num: dict[int | float, int],
    inference_func_args_by_user_id: dict[collections.abc.Hashable, tuple[Any, ...]],
    inference_action_prob_decision_times_by_user_id: dict[
        collections.abc.Hashable, list[int]
    ],
    action_by_decision_time_by_user_id: dict[collections.abc.Hashable, dict[int, int]],
) -> tuple[
    jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
]:
    """
    Computes the average weighted inference estimating function across users, along with
    auxiliary values used to construct the adaptive and classical sandwich variances.

    Note that input data should have been adjusted to only correspond to updates/decision times
    that are being considered for the current "premature" variance estimation procedure.

    Args:
        flattened_betas_and_theta (jnp.ndarray):
            A single flat 1-D JAX NumPy array holding the betas produced by all updates
            followed by the theta value, in that order. Important that this is a 1D array
            for efficiency reasons. We simply extract the betas and theta from this
            array below.
        beta_dim (int):
            The dimension of each of the beta parameters.
        theta_dim (int):
            The dimension of the theta parameter.
        user_ids (jnp.ndarray):
            A 1D JAX NumPy array of user IDs.
        action_prob_func (callable):
            The action probability function.
        action_prob_func_args_beta_index (int):
            The index of beta in the action probability function arguments tuples.
        inference_func (callable):
            The inference loss or estimating function.
        inference_func_type (str):
            The type of the inference function (loss or estimating).
        inference_func_args_theta_index (int):
            The index of the theta parameter in the inference function arguments tuples.
        inference_func_args_action_prob_index (int):
            The index of action probabilities in the inference function arguments tuple, if
            applicable. -1 otherwise.
        action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
            A dictionary mapping decision times to maps of user ids to the function arguments
            required to compute action probabilities for this user.
        policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
            A map of user ids to dictionaries mapping decision times to the policy number in use.
            Only applies to in-study decision times!
        initial_policy_num (int | float):
            The policy number of the initial policy before any updates.
        beta_index_by_policy_num (dict[int | float, int]):
            A dictionary mapping policy numbers to the index of the corresponding beta in
            all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
        inference_func_args_by_user_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
            A dictionary mapping user IDs to their respective inference function arguments.
        inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
            For each user, a list of decision times to which action probabilities correspond if
            provided. Typically just in-study times if action probabilites are used in the inference
            loss or estimating function.
        action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
            A dictionary mapping user IDs to their respective actions taken at each decision time.
            Only applies to in-study decision times!

    Returns:
        jnp.ndarray:
            The average weighted inference estimating function across users.
        tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
            A tuple containing
            1. the per-user weighted inference estimating function stacks
            2. the average weighted inference estimating function
            3. the user-level classical meat matrix contributions
            4. the user-level inverse classical bread matrix contributions
    """
    # 1. If the inference function is a loss, its gradient with respect to theta is the
    # estimating function; otherwise it already is the estimating function.
    inference_estimating_func = (
        jax.grad(inference_func, argnums=inference_func_args_theta_index)
        if (inference_func_type == FunctionTypes.LOSS)
        else inference_func
    )

    betas, theta = after_study_analysis.unflatten_params(
        flattened_betas_and_theta,
        beta_dim,
        theta_dim,
    )

    # 2. Thread in the betas and theta into the arguments supplied for the above
    # functions, so that differentiation works correctly. The existing values should be
    # the same, but not connected to the parameter we are differentiating with respect
    # to. Note we will also find it useful below to have the action probability args
    # nested dict structure flipped to be user_id -> decision_time -> args, so we do
    # that here too.
    logger.info("Threading in betas to action probability arguments for all users.")
    (
        threaded_action_prob_func_args_by_decision_time_by_user_id,
        action_prob_func_args_by_decision_time_by_user_id,
    ) = after_study_analysis.thread_action_prob_func_args(
        action_prob_func_args_by_user_id_by_decision_time,
        policy_num_by_decision_time_by_user_id,
        initial_policy_num,
        betas,
        beta_index_by_policy_num,
        action_prob_func_args_beta_index,
    )

    # 3. Thread the central theta into the inference function arguments
    # and replace any action probabilities with reconstructed ones from the above
    # arguments with the central betas introduced.
    logger.info(
        "Threading in theta and beta-dependent action probabilities to inference update "
        "function args for all users"
    )
    threaded_inference_func_args_by_user_id = (
        after_study_analysis.thread_inference_func_args(
            inference_func_args_by_user_id,
            inference_func_args_theta_index,
            theta,
            inference_func_args_action_prob_index,
            threaded_action_prob_func_args_by_decision_time_by_user_id,
            inference_action_prob_decision_times_by_user_id,
            action_prob_func,
        )
    )

    # 4. Now we can compute the weighted inference estimating functions for all users
    # as well as collect related values used to construct the adaptive and classical
    # sandwich variances.
    results = [
        single_user_weighted_inference_estimating_function(
            user_id,
            action_prob_func,
            inference_estimating_func,
            action_prob_func_args_beta_index,
            inference_func_args_theta_index,
            action_prob_func_args_by_decision_time_by_user_id[user_id],
            threaded_action_prob_func_args_by_decision_time_by_user_id[user_id],
            threaded_inference_func_args_by_user_id[user_id],
            policy_num_by_decision_time_by_user_id[user_id],
            action_by_decision_time_by_user_id[user_id],
            beta_index_by_policy_num,
        )
        for user_id in user_ids.tolist()
    ]

    weighted_inference_estimating_functions = jnp.array(
        [result[0] for result in results]
    )
    inference_only_outer_products = jnp.array([result[1] for result in results])
    inference_hessians = jnp.array([result[2] for result in results])

    avg_weighted_inference_estimating_function = jnp.mean(
        weighted_inference_estimating_functions, axis=0
    )

    # 5. Note this strange return structure! We will differentiate the first output,
    # but the second tuple will be passed along without modification via has_aux=True and then used
    # for the adaptive meat matrix, estimating functions sum check, and classical meat and inverse
    # bread matrices. The raw per-user estimating functions are also returned again for debugging
    # purposes.
    return avg_weighted_inference_estimating_function, (
        weighted_inference_estimating_functions,
        avg_weighted_inference_estimating_function,
        inference_only_outer_products,
        inference_hessians,
    )
|
1114
|
+
def single_user_weighted_inference_estimating_function(
    user_id: collections.abc.Hashable,
    action_prob_func: callable,
    inference_estimating_func: callable,
    action_prob_func_args_beta_index: int,
    inference_func_args_theta_index: int,
    action_prob_func_args_by_decision_time: dict[
        int, dict[collections.abc.Hashable, tuple[Any, ...]]
    ],
    threaded_action_prob_func_args_by_decision_time: dict[
        collections.abc.Hashable, dict[int, tuple[Any, ...]]
    ],
    threaded_inference_func_args: dict[collections.abc.Hashable, tuple[Any, ...]],
    policy_num_by_decision_time: dict[collections.abc.Hashable, dict[int, int | float]],
    action_by_decision_time: dict[collections.abc.Hashable, dict[int, int]],
    beta_index_by_policy_num: dict[int | float, int],
) -> tuple[
    jnp.ndarray,
    jnp.ndarray,
    jnp.ndarray,
]:
    """
    Computes a weighted inference estimating function for a given inference estimating function and arguments
    and action probability function and arguments if applicable.

    Args:
        user_id (collections.abc.Hashable):
            The user ID for which to compute the weighted estimating function stack.

        action_prob_func (callable):
            The function used to compute the probability of action 1 at a given decision time for
            a particular user given their state and the algorithm parameters.

        inference_estimating_func (callable):
            The estimating function that corresponds to inference.

        action_prob_func_args_beta_index (int):
            The index of the beta argument in the action probability function's arguments.

        inference_func_args_theta_index (int):
            The index of the theta parameter in the inference loss or estimating function arguments.

        action_prob_func_args_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
            A map from decision times to tuples of arguments for this user for the action
            probability function. This is for all decision times (args are an empty
            tuple if they are not in the study). Should be sorted by decision time. NOTE THAT THESE
            ARGS DO NOT CONTAIN THE SHARED BETAS, making them impervious to the differentiation that
            will occur.

        threaded_action_prob_func_args_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
            A map from decision times to tuples of arguments for the action
            probability function, with the shared betas threaded in for differentation. Decision
            times should be sorted.

        threaded_inference_func_args (dict[collections.abc.Hashable, tuple[Any, ...]]):
            A tuple containing the arguments for the inference
            estimating function for this user, with the shared betas threaded in for differentiation.

        policy_num_by_decision_time (dict[collections.abc.Hashable, dict[int, int | float]]):
            A dictionary mapping decision times to the policy number in use. This may be
            user-specific. Should be sorted by decision time. Only applies to in-study decision
            times!

        action_by_decision_time (dict[collections.abc.Hashable, dict[int, int]]):
            A dictionary mapping decision times to actions taken. Only applies to in-study decision
            times!

        beta_index_by_policy_num (dict[int | float, int]):
            A dictionary mapping policy numbers to the index of the corresponding beta in
            all_post_update_betas. Note that this is only for non-initial, non-fallback policies.

    Returns:
        jnp.ndarray: A 1-D JAX NumPy array representing the user's weighted inference estimating function.
        jnp.ndarray: A 2-D JAX NumPy matrix representing the user's classical meat contribution.
        jnp.ndarray: A 2-D JAX NumPy matrix representing the user's classical bread contribution.
    """

    logger.info(
        "Computing only weighted inference estimating function stack for user %s.",
        user_id,
    )

    # First, reformat the supplied data into more convenient structures.

    # 1. Get the first time after the first update for convenience.
    # This is used to form the Radon-Nikodym weights for the right times.
    _, first_time_after_first_update = after_study_analysis.get_min_time_by_policy_num(
        policy_num_by_decision_time,
        beta_index_by_policy_num,
    )

    # 2. Get the start and end times for this user.
    # NOTE(review): assumes action_by_decision_time is non-empty; otherwise the
    # range() below would receive infinities — TODO confirm callers guarantee this.
    user_start_time = math.inf
    user_end_time = -math.inf
    for decision_time in action_by_decision_time:
        user_start_time = min(user_start_time, decision_time)
        user_end_time = max(user_end_time, decision_time)

    # 3. Calculate the Radon-Nikodym weights for the inference estimating function.
    in_study_action_prob_func_args = [
        args for args in action_prob_func_args_by_decision_time.values() if args
    ]
    in_study_betas_list_by_decision_time_index = jnp.array(
        [
            action_prob_func_args[action_prob_func_args_beta_index]
            for action_prob_func_args in in_study_action_prob_func_args
        ]
    )
    in_study_actions_list_by_decision_time_index = jnp.array(
        list(action_by_decision_time.values())
    )

    # Sort the threaded args by decision time to be cautious. We check if the
    # user id is present in the user args dict because we may call this on a
    # subset of the user arg dict when we are batching arguments by shape
    sorted_threaded_action_prob_args_by_decision_time = {
        decision_time: threaded_action_prob_func_args_by_decision_time[decision_time]
        for decision_time in range(user_start_time, user_end_time + 1)
        if decision_time in threaded_action_prob_func_args_by_decision_time
    }

    # NOTE(review): num_args stays None if no decision time has non-empty args;
    # assumed at least one in-study arg tuple exists — TODO confirm.
    num_args = None
    for args in sorted_threaded_action_prob_args_by_decision_time.values():
        if args:
            num_args = len(args)
            break

    # NOTE: Cannot do [[]] * num_args here! Then all lists point
    # same object...
    batched_threaded_arg_lists = [[] for _ in range(num_args)]
    for args in sorted_threaded_action_prob_args_by_decision_time.values():
        if not args:
            continue
        for idx, arg in enumerate(args):
            batched_threaded_arg_lists[idx].append(arg)

    batched_threaded_arg_tensors, batch_axes = stack_batched_arg_lists_into_tensors(
        batched_threaded_arg_lists
    )

    # Note that we do NOT use the shared betas in the first arg to the weight function,
    # since we don't want differentiation to happen with respect to them.
    # Just grab the original beta from the update function arguments. This is the same
    # value, but impervious to differentiation with respect to all_post_update_betas. The
    # args, on the other hand, are a function of all_post_update_betas.
    in_study_weights = jax.vmap(
        fun=after_study_analysis.get_radon_nikodym_weight,
        in_axes=[0, None, None, 0] + batch_axes,
        out_axes=0,
    )(
        in_study_betas_list_by_decision_time_index,
        action_prob_func,
        action_prob_func_args_beta_index,
        in_study_actions_list_by_decision_time_index,
        *batched_threaded_arg_tensors,
    )

    # in_study_weights only has entries for decision times with non-empty args (those
    # skipped above contribute a weight of 1), so the running index must advance only
    # when args are present. Advancing it unconditionally would desynchronize the
    # weights whenever an in-range decision time has empty args.
    in_study_index = 0
    decision_time_to_all_weights_index_offset = min(
        sorted_threaded_action_prob_args_by_decision_time
    )
    all_weights_raw = []
    for args in sorted_threaded_action_prob_args_by_decision_time.values():
        if args:
            all_weights_raw.append(in_study_weights[in_study_index])
            in_study_index += 1
        else:
            all_weights_raw.append(1.0)
    all_weights = jnp.array(all_weights_raw)

    # 4. Form the weighted inference estimating equation.
    weighted_inference_estimating_function = jnp.prod(
        all_weights[
            max(first_time_after_first_update, user_start_time)
            - decision_time_to_all_weights_index_offset : user_end_time
            + 1
            - decision_time_to_all_weights_index_offset,
        ]
        # If the user exited the study before there were any updates,
        # this variable will be None and the above code to grab a weight would
        # throw an error. Just use 1 to include the unweighted estimating function
        # if they have data to contribute here (pretty sure everyone should?)
        if first_time_after_first_update is not None
        else 1
    ) * inference_estimating_func(*threaded_inference_func_args)

    return (
        weighted_inference_estimating_function,
        # Classical meat contribution: outer product of the estimating function with itself.
        jnp.outer(
            weighted_inference_estimating_function,
            weighted_inference_estimating_function,
        ),
        # Classical bread (inverse) contribution: Jacobian of the unweighted estimating
        # function with respect to theta.
        jax.jacrev(inference_estimating_func, argnums=inference_func_args_theta_index)(
            *threaded_inference_func_args
        ),
    )