lifejacket 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1312 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import math
5
+ from typing import Any
6
+ import collections
7
+
8
+ import numpy as np
9
+ from scipy.special import logit
10
+ import plotext as plt
11
+ import jax
12
+ from jax import numpy as jnp
13
+ import pandas as pd
14
+
15
+ from . import after_study_analysis
16
+ from .constants import FunctionTypes
17
+ from .vmap_helpers import stack_batched_arg_lists_into_tensors
18
+
19
+ logger = logging.getLogger(__name__)
20
+ logging.basicConfig(
21
+ format="%(asctime)s,%(msecs)03d %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s",
22
+ datefmt="%Y-%m-%d:%H:%M:%S",
23
+ level=logging.INFO,
24
+ )
25
+
26
+
27
+ def get_datum_for_blowup_supervised_learning(
28
+ joint_adaptive_bread_inverse_matrix,
29
+ joint_adaptive_bread_inverse_cond,
30
+ avg_estimating_function_stack,
31
+ per_user_estimating_function_stacks,
32
+ all_post_update_betas,
33
+ study_df,
34
+ in_study_col_name,
35
+ calendar_t_col_name,
36
+ action_prob_col_name,
37
+ user_id_col_name,
38
+ reward_col_name,
39
+ theta_est,
40
+ adaptive_sandwich_var_estimate,
41
+ user_ids,
42
+ beta_dim,
43
+ theta_dim,
44
+ initial_policy_num,
45
+ beta_index_by_policy_num,
46
+ policy_num_by_decision_time_by_user_id,
47
+ theta_calculation_func,
48
+ action_prob_func,
49
+ action_prob_func_args_beta_index,
50
+ inference_func,
51
+ inference_func_type,
52
+ inference_func_args_theta_index,
53
+ inference_func_args_action_prob_index,
54
+ inference_action_prob_decision_times_by_user_id,
55
+ action_prob_func_args,
56
+ action_by_decision_time_by_user_id,
57
+ ) -> dict[str, Any]:
58
+ """
59
+ Collects a datum for supervised learning about adaptive sandwich blowup.
60
+
61
+ The datum consists of features and the raw adaptive sandwich variance estimate as a label.
62
+
63
+ A few plots are produced along the way to help visualize the data.
64
+
65
+ Args:
66
+ joint_adaptive_bread_inverse_matrix (jnp.ndarray):
67
+ The joint adaptive bread inverse matrix.
68
+ joint_adaptive_bread_inverse_cond (float):
69
+ The condition number of the joint adaptive bread inverse matrix.
70
+ avg_estimating_function_stack (jnp.ndarray):
71
+ The average estimating function stack across users.
72
+ per_user_estimating_function_stacks (jnp.ndarray):
73
+ The estimating function stacks for each user.
74
+ all_post_update_betas (jnp.ndarray):
75
+ All post-update beta parameters.
76
+ study_df (pd.DataFrame):
77
+ The study DataFrame.
78
+ in_study_col_name (str):
79
+ Column name in the study dataframe indicating whether a user is in the study at that time.
80
+ calendar_t_col_name (str):
81
+ Column name for calendar time in the study dataframe.
82
+ action_prob_col_name (str):
83
+ Column name for action probabilities in the study dataframe.
84
+ user_id_col_name (str):
85
+ Column name for user IDs in the study dataframe
86
+ reward_col_name (str):
87
+ Column name for rewards in the study dataframe.
88
+ theta_est (jnp.ndarray):
89
+ The estimate of the parameter vector theta.
90
+ adaptive_sandwich_var_estimate (jnp.ndarray):
91
+ The adaptive sandwich variance estimate for theta.
92
+ user_ids (jnp.ndarray):
93
+ Array of unique user IDs.
94
+ beta_dim (int):
95
+ Dimension of the beta parameter vector.
96
+ theta_dim (int):
97
+ Dimension of the theta parameter vector.
98
+ initial_policy_num (int | float):
99
+ The initial policy number used in the study.
100
+ beta_index_by_policy_num (dict[int | float, int]):
101
+ Mapping from policy numbers to indices in all_post_update_betas.
102
+ policy_num_by_decision_time_by_user_id (dict):
103
+ Mapping from user IDs to their policy numbers by decision time.
104
+ theta_calculation_func (callable):
105
+ The theta calculation function.
106
+ action_prob_func (callable):
107
+ The action probability function.
108
+ action_prob_func_args_beta_index (int):
109
+ Index for beta in action probability function arguments.
110
+ inference_func (callable):
111
+ The inference function.
112
+ inference_func_type (str):
113
+ Type of the inference function.
114
+ inference_func_args_theta_index (int):
115
+ Index for theta in inference function arguments.
116
+ inference_func_args_action_prob_index (int):
117
+ Index for action probability in inference function arguments.
118
+ inference_action_prob_decision_times_by_user_id (dict):
119
+ Mapping from user IDs to decision times for action probabilities used in inference.
120
+ action_prob_func_args (dict):
121
+ Arguments for the action probability function.
122
+ action_by_decision_time_by_user_id (dict):
123
+ Mapping from user IDs to their actions by decision time.
124
+ Returns:
125
+ dict[str, Any]: A dictionary containing features and the label for supervised learning.
126
+ """
127
+ num_diagonal_blocks = (
128
+ (joint_adaptive_bread_inverse_matrix.shape[0] - theta_dim) // beta_dim
129
+ ) + 1
130
+ diagonal_block_sizes = ([beta_dim] * (num_diagonal_blocks - 1)) + [theta_dim]
131
+
132
+ block_bounds = np.cumsum([0] + list(diagonal_block_sizes))
133
+ num_block_rows_cols = len(diagonal_block_sizes)
134
+
135
+ # collect diagonal and sub-diagonal block norms and diagonal condition numbers
136
+ off_diag_block_norms = {}
137
+ diag_norms = []
138
+ diag_conds = []
139
+ off_diag_row_norms = np.zeros(num_block_rows_cols)
140
+ off_diag_col_norms = np.zeros(num_block_rows_cols)
141
+ for i in range(num_block_rows_cols):
142
+ for j in range(num_block_rows_cols):
143
+ if i > j: # below-diagonal blocks
144
+ row_slice = slice(block_bounds[i], block_bounds[i + 1])
145
+ col_slice = slice(block_bounds[j], block_bounds[j + 1])
146
+ block_norm = np.linalg.norm(
147
+ joint_adaptive_bread_inverse_matrix[row_slice, col_slice],
148
+ ord="fro",
149
+ )
150
+ # We will sum squared block norms here and take the square root later
151
+ off_diag_row_norms[i] += block_norm**2
152
+ off_diag_col_norms[j] += block_norm**2
153
+ off_diag_block_norms[(i, j)] = block_norm
154
+
155
+ # handle diagonal blocks
156
+ sl = slice(block_bounds[i], block_bounds[i + 1])
157
+ diag_norms.append(
158
+ np.linalg.norm(joint_adaptive_bread_inverse_matrix[sl, sl], ord="fro")
159
+ )
160
+ diag_conds.append(np.linalg.cond(joint_adaptive_bread_inverse_matrix[sl, sl]))
161
+
162
+ # Sqrt each row/col sum to truly get row/column norms.
163
+ # Perhaps not necessary for learning, but more natural
164
+ off_diag_row_norms = np.sqrt(off_diag_row_norms)
165
+ off_diag_col_norms = np.sqrt(off_diag_col_norms)
166
+
167
+ # Get the per-person estimating function stack norms
168
+ estimating_function_stack_norms = np.linalg.norm(
169
+ per_user_estimating_function_stacks, axis=1
170
+ )
171
+
172
+ # Get the average estimating function stack norms by update/inference
173
+ # Use the bounds variable from above to split the estimating function stacks
174
+ # into blocks corresponding to the updates and inference.
175
+ avg_estimating_function_stack_norms_per_segment = [
176
+ np.mean(
177
+ np.linalg.norm(
178
+ per_user_estimating_function_stacks[
179
+ :, block_bounds[i] : block_bounds[i + 1]
180
+ ],
181
+ axis=1,
182
+ )
183
+ )
184
+ for i in range(len(block_bounds) - 1)
185
+ ]
186
+
187
+ # Compute the norms of each successive difference in all_post_update_betas.
188
+ successive_beta_diffs = np.diff(np.array(all_post_update_betas), axis=0)
189
+ successive_beta_diff_norms = np.linalg.norm(successive_beta_diffs, axis=1)
190
+ max_successive_beta_diff_norm = np.max(successive_beta_diff_norms)
191
+ std_successive_beta_diff_norm = np.std(successive_beta_diff_norms)
192
+
193
+ # Add a column with logits of the action probabilities
194
+ # Compute the average and standard deviation of the logits of the action probabilities at each decision time using study_df
195
+ # action_prob_logit_means and action_prob_logit_stds are numpy arrays of mean and stddev at each decision time
196
+ # Only compute logits for rows where user is in the study; set others to NaN
197
+ in_study_mask = study_df[in_study_col_name] == 1
198
+ study_df["action_prob_logit"] = np.where(
199
+ in_study_mask,
200
+ logit(study_df[action_prob_col_name]),
201
+ np.nan,
202
+ )
203
+ grouped_action_prob_logit = study_df.loc[in_study_mask].groupby(
204
+ calendar_t_col_name
205
+ )["action_prob_logit"]
206
+ action_prob_logit_means_by_t = grouped_action_prob_logit.mean().values
207
+ action_prob_logit_stds_by_t = grouped_action_prob_logit.std().values
208
+
209
+ # Compute the average and standard deviation of the rewards at each decision time using study_df
210
+ # reward_means and reward_stds are numpy arrays of mean and stddev at each decision time
211
+ grouped_reward = study_df.loc[in_study_mask].groupby(calendar_t_col_name)[
212
+ reward_col_name
213
+ ]
214
+ reward_means_by_t = grouped_reward.mean().values
215
+ reward_stds_by_t = grouped_reward.std().values
216
+
217
+ joint_bread_inverse_min_singular_value = np.linalg.svd(
218
+ joint_adaptive_bread_inverse_matrix, compute_uv=False
219
+ )[-1]
220
+
221
+ max_reward = study_df.loc[in_study_mask][reward_col_name].max()
222
+
223
+ norm_avg_estimating_function_stack = np.linalg.norm(avg_estimating_function_stack)
224
+ max_estimating_function_stack_norm = np.max(estimating_function_stack_norms)
225
+
226
+ (
227
+ premature_thetas,
228
+ premature_adaptive_sandwiches,
229
+ premature_classical_sandwiches,
230
+ premature_joint_adaptive_bread_inverse_condition_numbers,
231
+ premature_avg_inference_estimating_functions,
232
+ ) = calculate_sequence_of_premature_adaptive_estimates(
233
+ study_df,
234
+ initial_policy_num,
235
+ beta_index_by_policy_num,
236
+ policy_num_by_decision_time_by_user_id,
237
+ theta_calculation_func,
238
+ calendar_t_col_name,
239
+ action_prob_col_name,
240
+ user_id_col_name,
241
+ in_study_col_name,
242
+ all_post_update_betas,
243
+ user_ids,
244
+ action_prob_func,
245
+ action_prob_func_args_beta_index,
246
+ inference_func,
247
+ inference_func_type,
248
+ inference_func_args_theta_index,
249
+ inference_func_args_action_prob_index,
250
+ inference_action_prob_decision_times_by_user_id,
251
+ action_prob_func_args,
252
+ action_by_decision_time_by_user_id,
253
+ joint_adaptive_bread_inverse_matrix,
254
+ per_user_estimating_function_stacks,
255
+ beta_dim,
256
+ )
257
+
258
+ np.testing.assert_allclose(
259
+ np.zeros_like(premature_avg_inference_estimating_functions),
260
+ premature_avg_inference_estimating_functions,
261
+ atol=1e-3,
262
+ )
263
+
264
+ # Plot premature joint adaptive bread inverse log condition numbers
265
+ plt.clear_figure()
266
+ plt.title("Premature Joint Adaptive Bread Inverse Log Condition Numbers")
267
+ plt.xlabel("Premature Update Index")
268
+ plt.ylabel("Log Condition Number")
269
+ plt.scatter(
270
+ np.log(premature_joint_adaptive_bread_inverse_condition_numbers),
271
+ color="blue+",
272
+ )
273
+ plt.grid(True)
274
+ plt.xticks(
275
+ range(
276
+ 0,
277
+ len(premature_joint_adaptive_bread_inverse_condition_numbers),
278
+ max(
279
+ 1,
280
+ len(premature_joint_adaptive_bread_inverse_condition_numbers) // 10,
281
+ ),
282
+ )
283
+ )
284
+ plt.show()
285
+
286
+ # Plot each diagonal element of premature adaptive sandwiches
287
+ num_diag = premature_adaptive_sandwiches.shape[-1]
288
+ for i in range(num_diag):
289
+ plt.clear_figure()
290
+ plt.title(f"Premature Adaptive Sandwich Diagonal Element {i}")
291
+ plt.xlabel("Premature Update Index")
292
+ plt.ylabel(f"Variance (Diagonal {i})")
293
+ plt.scatter(np.array(premature_adaptive_sandwiches[:, i, i]), color="blue+")
294
+ plt.grid(True)
295
+ plt.xticks(
296
+ range(
297
+ 0,
298
+ int(premature_adaptive_sandwiches.shape[0]),
299
+ max(1, int(premature_adaptive_sandwiches.shape[0]) // 10),
300
+ )
301
+ )
302
+ plt.show()
303
+
304
+ plt.clear_figure()
305
+ plt.title(
306
+ f"Premature Adaptive Sandwich Diagonal Element {i} Ratio to Classical"
307
+ )
308
+ plt.xlabel("Premature Update Index")
309
+ plt.ylabel(f"Variance (Diagonal {i})")
310
+ plt.scatter(
311
+ np.array(premature_adaptive_sandwiches[:, i, i])
312
+ / np.array(premature_classical_sandwiches[:, i, i]),
313
+ color="red+",
314
+ )
315
+ plt.grid(True)
316
+ plt.xticks(
317
+ range(
318
+ 0,
319
+ int(premature_adaptive_sandwiches.shape[0]),
320
+ max(1, int(premature_adaptive_sandwiches.shape[0]) // 10),
321
+ )
322
+ )
323
+ plt.show()
324
+
325
+ plt.clear_figure()
326
+ plt.title(f"Premature Theta Estimates At Index {i}")
327
+ plt.xlabel("Premature Update Index")
328
+ plt.ylabel(f"Theta element {i}")
329
+ plt.scatter(np.array(premature_thetas[:, i]), color="green+")
330
+ plt.grid(True)
331
+ plt.xticks(
332
+ range(
333
+ 0,
334
+ int(premature_adaptive_sandwiches.shape[0]),
335
+ max(1, int(premature_adaptive_sandwiches.shape[0]) // 10),
336
+ )
337
+ )
338
+ plt.show()
339
+
340
+ # Grab predictors related to premature Phi-dot-bars
341
+ RL_stack_beta_derivatives_block = joint_adaptive_bread_inverse_matrix[
342
+ :-theta_dim, :-theta_dim
343
+ ]
344
+ num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
345
+ premature_RL_block_condition_numbers = []
346
+ premature_RL_block_inverse_norms = []
347
+ diagonal_RL_block_condition_numbers = []
348
+ off_diagonal_RL_scaled_block_norm_sums = []
349
+ for i in range(1, num_updates + 1):
350
+ whole_block_size = i * beta_dim
351
+ whole_block = RL_stack_beta_derivatives_block[
352
+ :whole_block_size, :whole_block_size
353
+ ]
354
+ whole_RL_block_cond_number = np.linalg.cond(whole_block)
355
+ premature_RL_block_condition_numbers.append(whole_RL_block_cond_number)
356
+ logger.info(
357
+ "Condition number of whole RL_stack_beta_derivatives_block (after update %s): %s",
358
+ i,
359
+ whole_RL_block_cond_number,
360
+ )
361
+ diagonal_block = RL_stack_beta_derivatives_block[
362
+ (i - 1) * beta_dim : i * beta_dim, (i - 1) * beta_dim : i * beta_dim
363
+ ]
364
+ diagonal_RL_block_cond_number = np.linalg.cond(diagonal_block)
365
+ diagonal_RL_block_condition_numbers.append(diagonal_RL_block_cond_number)
366
+ logger.info(
367
+ "Condition number of just RL_stack_beta_derivatives_block *diagonal block* for update %s: %s",
368
+ i,
369
+ diagonal_RL_block_cond_number,
370
+ )
371
+
372
+ premature_RL_block_inverse_norms.append(
373
+ np.linalg.norm(np.linalg.inv(whole_block))
374
+ )
375
+ logger.info(
376
+ "Norm of inverse of whole RL_stack_beta_derivatives_block (after update %s): %s",
377
+ i,
378
+ premature_RL_block_inverse_norms[-1],
379
+ )
380
+
381
+ off_diagonal_RL_scaled_block_norm_sum = 0
382
+ for j in range(1, i):
383
+ off_diagonal_block = RL_stack_beta_derivatives_block[
384
+ (i - 1) * beta_dim : i * beta_dim, (j - 1) * beta_dim : j * beta_dim
385
+ ]
386
+ off_diagonal_scaled_block_norm = np.linalg.norm(
387
+ np.linalg.solve(diagonal_block, off_diagonal_block)
388
+ )
389
+ off_diagonal_RL_scaled_block_norm_sum += off_diagonal_scaled_block_norm
390
+ off_diagonal_RL_scaled_block_norm_sums.append(
391
+ off_diagonal_RL_scaled_block_norm_sum
392
+ )
393
+ logger.info(
394
+ "Sum of norms of off-diagonal blocks in row %s scaled by inverse of diagonal block: %s",
395
+ i,
396
+ off_diagonal_RL_scaled_block_norm_sum,
397
+ )
398
+ return {
399
+ **{
400
+ "joint_bread_inverse_condition_number": joint_adaptive_bread_inverse_cond,
401
+ "joint_bread_inverse_min_singular_value": joint_bread_inverse_min_singular_value,
402
+ "max_reward": max_reward,
403
+ "norm_avg_estimating_function_stack": norm_avg_estimating_function_stack,
404
+ "max_estimating_function_stack_norm": max_estimating_function_stack_norm,
405
+ "max_successive_beta_diff_norm": max_successive_beta_diff_norm,
406
+ "std_successive_beta_diff_norm": std_successive_beta_diff_norm,
407
+ "label": adaptive_sandwich_var_estimate,
408
+ },
409
+ **{
410
+ f"off_diag_block_{i}_{j}_norm": off_diag_block_norms[(i, j)]
411
+ for i in range(num_block_rows_cols)
412
+ for j in range(i)
413
+ },
414
+ **{f"diag_block_{i}_norm": diag_norms[i] for i in range(num_block_rows_cols)},
415
+ **{f"diag_block_{i}_cond": diag_conds[i] for i in range(num_block_rows_cols)},
416
+ **{
417
+ f"off_diag_row_{i}_norm": off_diag_row_norms[i]
418
+ for i in range(num_block_rows_cols)
419
+ },
420
+ **{
421
+ f"off_diag_col_{i}_norm": off_diag_col_norms[i]
422
+ for i in range(num_block_rows_cols)
423
+ },
424
+ **{
425
+ f"estimating_function_stack_norm_user_{user_id}": estimating_function_stack_norms[
426
+ i
427
+ ]
428
+ for i, user_id in enumerate(user_ids)
429
+ },
430
+ **{
431
+ f"avg_estimating_function_stack_norm_segment_{i}": avg_estimating_function_stack_norms_per_segment[
432
+ i
433
+ ]
434
+ for i in range(len(avg_estimating_function_stack_norms_per_segment))
435
+ },
436
+ **{
437
+ f"successive_beta_diff_norm_{i}": successive_beta_diff_norms[i]
438
+ for i in range(len(successive_beta_diff_norms))
439
+ },
440
+ **{
441
+ f"action_prob_logit_mean_t_{t}": action_prob_logit_means_by_t[t]
442
+ for t in range(len(action_prob_logit_means_by_t))
443
+ },
444
+ **{
445
+ f"action_prob_logit_std_t_{t}": action_prob_logit_stds_by_t[t]
446
+ for t in range(len(action_prob_logit_stds_by_t))
447
+ },
448
+ **{
449
+ f"reward_mean_t_{t}": reward_means_by_t[t]
450
+ for t in range(len(reward_means_by_t))
451
+ },
452
+ **{
453
+ f"reward_std_t_{t}": reward_stds_by_t[t]
454
+ for t in range(len(reward_stds_by_t))
455
+ },
456
+ **{f"theta_est_{i}": theta_est[i].item() for i in range(len(theta_est))},
457
+ **{
458
+ f"premature_joint_adaptive_bread_inverse_condition_number_{i}": premature_joint_adaptive_bread_inverse_condition_numbers[
459
+ i
460
+ ]
461
+ for i in range(
462
+ len(premature_joint_adaptive_bread_inverse_condition_numbers)
463
+ )
464
+ },
465
+ **{
466
+ f"premature_adaptive_sandwich_update_{i}_diag_position_{j}": premature_adaptive_sandwich[
467
+ j, j
468
+ ]
469
+ for i, premature_adaptive_sandwich in enumerate(premature_adaptive_sandwiches)
470
+ for j in range(theta_dim)
471
+ },
472
+ **{
473
+ f"premature_classical_sandwich_update_{i}_diag_position_{j}": premature_classical_sandwich[
474
+ j, j
475
+ ]
476
+ for i, premature_classical_sandwich in enumerate(premature_classical_sandwiches)
477
+ for j in range(theta_dim)
478
+ },
479
+ **{
480
+ f"off_diagonal_RL_scaled_block_norm_sum_for_update_{i}": off_diagonal_RL_scaled_block_norm_sums[
481
+ i
482
+ ]
483
+ for i in range(len(off_diagonal_RL_scaled_block_norm_sums))
484
+ },
485
+ **{
486
+ f"premature_RL_block_condition_number_after_update_{i}": premature_RL_block_condition_numbers[
487
+ i
488
+ ]
489
+ for i in range(len(premature_RL_block_condition_numbers))
490
+ },
491
+ **{
492
+ f"premature_RL_block_inverse_norm_after_update_{i}": premature_RL_block_inverse_norms[
493
+ i
494
+ ]
495
+ for i in range(len(premature_RL_block_inverse_norms))
496
+ },
497
+ }
498
+
499
+
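For orientation, here is a minimal sketch (not part of the package) of how datums returned by get_datum_for_blowup_supervised_learning might be pooled and modeled downstream; the feature names, values, and scalar label below are hypothetical stand-ins.

import numpy as np
import pandas as pd

# Hypothetical datums; real ones hold many more features plus the variance label.
datums = [
    {"joint_bread_inverse_condition_number": 1.2e3, "max_reward": 4.0, "label": 0.8},
    {"joint_bread_inverse_condition_number": 5.6e7, "max_reward": 3.5, "label": 42.0},
    {"joint_bread_inverse_condition_number": 9.0e2, "max_reward": 4.2, "label": 0.6},
]
df = pd.DataFrame(datums)

# Separate features from the blowup label and fit a simple log-scale least-squares baseline.
X = np.log1p(df.drop(columns=["label"]).to_numpy())
y = np.log1p(df["label"].to_numpy())
coeffs, *_ = np.linalg.lstsq(X, y, rcond=None)
print(coeffs)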
500
+ def calculate_sequence_of_premature_adaptive_estimates(
501
+ study_df: pd.DataFrame,
502
+ initial_policy_num: int | float,
503
+ beta_index_by_policy_num: dict[int | float, int],
504
+ policy_num_by_decision_time_by_user_id: dict[
505
+ collections.abc.Hashable, dict[int, int | float]
506
+ ],
507
+ theta_calculation_func: callable,
508
+ calendar_t_col_name: str,
509
+ action_prob_col_name: str,
510
+ user_id_col_name: str,
511
+ in_study_col_name: str,
512
+ all_post_update_betas: jnp.ndarray,
513
+ user_ids: jnp.ndarray,
514
+ action_prob_func: callable,
515
+ action_prob_func_args_beta_index: int,
516
+ inference_func: callable,
517
+ inference_func_type: str,
518
+ inference_func_args_theta_index: int,
519
+ inference_func_args_action_prob_index: int,
520
+ inference_action_prob_decision_times_by_user_id: dict[
521
+ collections.abc.Hashable, list[int]
522
+ ],
523
+ action_prob_func_args_by_user_id_by_decision_time: dict[
524
+ int, dict[collections.abc.Hashable, tuple[Any, ...]]
525
+ ],
526
+ action_by_decision_time_by_user_id: dict[collections.abc.Hashable, dict[int, int]],
527
+ full_joint_adaptive_bread_inverse_matrix: jnp.ndarray,
528
+ per_user_estimating_function_stacks: jnp.ndarray,
529
+ beta_dim: int,
530
+ ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
531
+ """
532
+ Calculates a sequence of premature adaptive estimates for the given study DataFrame, where we
533
+ pretend the study ended after each update in sequence. The behavior of this sequence may provide
534
+ insight into the stability of the final adaptive estimate.
535
+
536
+ Args:
537
+ study_df (pandas.DataFrame):
538
+ The DataFrame containing the study data.
539
+ initial_policy_num (int | float):
+ The policy number of the initial policy before any updates.
542
+ beta_index_by_policy_num (dict[int | float, int]):
543
+ A dictionary mapping policy numbers to the index of the corresponding beta in
544
+ all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
545
+ policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
546
+ A map of user ids to dictionaries mapping decision times to the policy number in use.
547
+ Only applies to in-study decision times!
548
+ theta_calculation_func (callable):
549
+ The theta calculation function.
550
+ calendar_t_col_name (str):
551
+ The name of the column in study_df representing calendar time.
552
+ action_prob_col_name (str):
553
+ The name of the column in study_df representing action probabilities.
554
+ user_id_col_name (str):
555
+ The name of the column in study_df representing user IDs.
556
+ in_study_col_name (str):
557
+ The name of the column in study_df indicating whether the user is in the study at that time.
558
+ all_post_update_betas (jnp.ndarray):
559
+ A NumPy array containing all post-update beta values.
560
+ user_ids (jnp.ndarray):
561
+ A NumPy array containing all user IDs in the study.
562
+ action_prob_func (callable):
563
+ The action probability function.
564
+ action_prob_func_args_beta_index (int):
565
+ The index of beta in the action probability function arguments tuples.
566
+ inference_func (callable):
567
+ The inference function.
568
+ inference_func_type (str):
569
+ The type of the inference function (loss or estimating).
570
+ inference_func_args_theta_index (int):
571
+ The index of the theta parameter in the inference function arguments tuples.
572
+ inference_func_args_action_prob_index (int):
573
+ The index of action probabilities in the inference function arguments tuple, if
574
+ applicable. -1 otherwise.
575
+ inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
576
+ For each user, a list of decision times to which action probabilities correspond if
577
+ provided. Typically just in-study times if action probabilities are used in the inference
578
+ loss or estimating function.
579
+ action_prob_func_args_by_user_id_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
580
+ A dictionary mapping decision times to maps of user ids to the function arguments
581
+ required to compute action probabilities for this user.
582
+ action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
583
+ A dictionary mapping user IDs to their respective actions taken at each decision time.
584
+ Only applies to in-study decision times!
585
+ full_joint_adaptive_bread_inverse_matrix (jnp.ndarray):
586
+ The full joint adaptive bread inverse matrix as a NumPy array.
587
+ per_user_estimating_function_stacks (jnp.ndarray):
588
+ A NumPy array containing all per-user (weighted) estimating function stacks.
589
+ beta_dim (int):
590
+ The dimension of the beta parameters.
591
+ Returns:
592
+ tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: The sequences of premature theta estimates, adaptive sandwiches, classical sandwiches, joint adaptive bread inverse condition numbers, and average inference estimating functions, each stacked into a JAX NumPy array.
593
+ """
594
+
595
+ # Loop through the non-initial (i.e., post-update) policy numbers in sorted order, forming adaptive and classical
596
+ # variance estimates pretending that each was the final policy.
597
+ premature_adaptive_sandwiches = []
598
+ premature_thetas = []
599
+ premature_joint_adaptive_bread_inverse_condition_numbers = []
600
+ premature_avg_inference_estimating_functions = []
601
+ premature_classical_sandwiches = []
602
+ logger.info(
603
+ "Calculating sequence of premature adaptive estimates by pretending the study ended after each update in sequence."
604
+ )
605
+ for policy_num in sorted(beta_index_by_policy_num):
606
+ logger.info(
607
+ "Calculating premature adaptive estimate assuming policy %s is the final one.",
608
+ policy_num,
609
+ )
610
+ pretend_max_policy = policy_num
611
+
612
+ truncated_joint_adaptive_bread_inverse_matrix = (
613
+ full_joint_adaptive_bread_inverse_matrix[
614
+ : (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
615
+ : (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
616
+ ]
617
+ )
618
+
619
+ max_decision_time = study_df[study_df["policy_num"] == pretend_max_policy][
620
+ calendar_t_col_name
621
+ ].max()
622
+
623
+ truncated_study_df = study_df[
624
+ study_df[calendar_t_col_name] <= max_decision_time
625
+ ].copy()
626
+
627
+ truncated_beta_index_by_policy_num = {
628
+ k: v for k, v in beta_index_by_policy_num.items() if k <= pretend_max_policy
629
+ }
630
+
631
+ max_beta_index = max(truncated_beta_index_by_policy_num.values())
632
+
633
+ truncated_all_post_update_betas = all_post_update_betas[: max_beta_index + 1, :]
634
+
635
+ premature_theta = jnp.array(theta_calculation_func(truncated_study_df))
636
+
637
+ truncated_action_prob_func_args_by_user_id_by_decision_time = {
638
+ decision_time: args_by_user_id
639
+ for decision_time, args_by_user_id in action_prob_func_args_by_user_id_by_decision_time.items()
640
+ if decision_time <= max_decision_time
641
+ }
642
+
643
+ truncated_inference_func_args_by_user_id, _, _ = (
644
+ after_study_analysis.process_inference_func_args(
645
+ inference_func,
646
+ inference_func_args_theta_index,
647
+ truncated_study_df,
648
+ premature_theta,
649
+ action_prob_col_name,
650
+ calendar_t_col_name,
651
+ user_id_col_name,
652
+ in_study_col_name,
653
+ )
654
+ )
655
+
656
+ truncated_inference_action_prob_decision_times_by_user_id = {
657
+ user_id: [
658
+ decision_time
659
+ for decision_time in inference_action_prob_decision_times_by_user_id[
660
+ user_id
661
+ ]
662
+ if decision_time <= max_decision_time
663
+ ]
664
+ # writing this way is important, handles empty dicts correctly
665
+ for user_id in inference_action_prob_decision_times_by_user_id
666
+ }
667
+
668
+ truncated_action_by_decision_time_by_user_id = {
669
+ user_id: {
670
+ decision_time: action
671
+ for decision_time, action in action_by_decision_time_by_user_id[
672
+ user_id
673
+ ].items()
674
+ if decision_time <= max_decision_time
675
+ }
676
+ for user_id in action_by_decision_time_by_user_id
677
+ }
678
+
679
+ truncated_per_user_estimating_function_stacks = (
680
+ per_user_estimating_function_stacks[
681
+ :,
682
+ : (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
683
+ ]
684
+ )
685
+
686
+ (
687
+ premature_adaptive_sandwich,
688
+ premature_classical_sandwich,
689
+ premature_avg_inference_estimating_function,
690
+ ) = construct_premature_classical_and_adaptive_sandwiches(
691
+ truncated_joint_adaptive_bread_inverse_matrix,
692
+ truncated_per_user_estimating_function_stacks,
693
+ premature_theta,
694
+ truncated_all_post_update_betas,
695
+ user_ids,
696
+ action_prob_func,
697
+ action_prob_func_args_beta_index,
698
+ inference_func,
699
+ inference_func_type,
700
+ inference_func_args_theta_index,
701
+ inference_func_args_action_prob_index,
702
+ truncated_action_prob_func_args_by_user_id_by_decision_time,
703
+ policy_num_by_decision_time_by_user_id,
704
+ initial_policy_num,
705
+ truncated_beta_index_by_policy_num,
706
+ truncated_inference_func_args_by_user_id,
707
+ truncated_inference_action_prob_decision_times_by_user_id,
708
+ truncated_action_by_decision_time_by_user_id,
709
+ )
710
+
711
+ # The full premature joint bread inverse is internal to the helper above, so record the
+ # condition number of the truncated (update-only) block so this sequence is populated.
+ premature_joint_adaptive_bread_inverse_condition_numbers.append(
+ np.linalg.cond(truncated_joint_adaptive_bread_inverse_matrix)
+ )
+ premature_adaptive_sandwiches.append(premature_adaptive_sandwich)
712
+ premature_classical_sandwiches.append(premature_classical_sandwich)
713
+ premature_thetas.append(premature_theta)
714
+ premature_avg_inference_estimating_functions.append(
715
+ premature_avg_inference_estimating_function
716
+ )
717
+ return (
718
+ jnp.array(premature_thetas),
719
+ jnp.array(premature_adaptive_sandwiches),
720
+ jnp.array(premature_classical_sandwiches),
721
+ jnp.array(premature_joint_adaptive_bread_inverse_condition_numbers),
722
+ jnp.array(premature_avg_inference_estimating_functions),
723
+ )
724
+
725
+
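A minimal sketch (not from the package) of the truncation idea behind calculate_sequence_of_premature_adaptive_estimates: after each update, keep only the leading principal block of the update portion of the joint bread inverse and inspect its conditioning. The matrix and dimensions below are hypothetical.

import numpy as np

beta_dim, num_updates = 2, 3
update_block_dim = num_updates * beta_dim
rng = np.random.default_rng(0)

# Stand-in for the update-only portion of a joint adaptive bread inverse,
# which is block lower triangular by construction.
full_block = np.tril(rng.normal(size=(update_block_dim, update_block_dim)))
full_block += np.eye(update_block_dim)

for k in range(1, num_updates + 1):
    # Pretend the study ended after update k: keep the first k beta blocks only.
    cutoff = k * beta_dim
    truncated = full_block[:cutoff, :cutoff]
    print(f"after update {k}: condition number {np.linalg.cond(truncated):.2f}")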
726
+ def construct_premature_classical_and_adaptive_sandwiches(
727
+ truncated_joint_adaptive_bread_inverse_matrix: jnp.ndarray,
728
+ per_user_truncated_estimating_function_stacks: jnp.ndarray,
729
+ theta: jnp.ndarray,
730
+ all_post_update_betas: jnp.ndarray,
731
+ user_ids: jnp.ndarray,
732
+ action_prob_func: callable,
733
+ action_prob_func_args_beta_index: int,
734
+ inference_func: callable,
735
+ inference_func_type: str,
736
+ inference_func_args_theta_index: int,
737
+ inference_func_args_action_prob_index: int,
738
+ action_prob_func_args_by_user_id_by_decision_time: dict[
739
+ collections.abc.Hashable, dict[int, tuple[Any, ...]]
740
+ ],
741
+ policy_num_by_decision_time_by_user_id: dict[
742
+ collections.abc.Hashable, dict[int, int | float]
743
+ ],
744
+ initial_policy_num: int | float,
745
+ beta_index_by_policy_num: dict[int | float, int],
746
+ inference_func_args_by_user_id: dict[collections.abc.Hashable, tuple[Any, ...]],
747
+ inference_action_prob_decision_times_by_user_id: dict[
748
+ collections.abc.Hashable, list[int]
749
+ ],
750
+ action_by_decision_time_by_user_id: dict[collections.abc.Hashable, dict[int, int]],
751
+ ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
761
+ """
762
+ Constructs the premature adaptive and classical sandwich variance matrices, as well as the
763
+ average weighted inference estimating function, for the premature variance estimation
764
+ procedure.
765
+
766
+ This is done by computing and differentiating the new average inference estimating function
767
+ with respect to the betas and theta, and stitching this together with the existing
768
+ adaptive bread inverse matrix portion (corresponding to the updates still under consideration)
769
+ to form the new premature joint adaptive bread inverse matrix.
770
+
771
+ Args:
772
+ truncated_joint_adaptive_bread_inverse_matrix (jnp.ndarray):
773
+ A 2-D JAX NumPy array holding the existing joint adaptive bread inverse but
774
+ with rows corresponding to updates not under consideration and inference dropped.
775
+ We will stitch this together with the newly computed inference portion to form
776
+ our "premature" joint adaptive bread inverse matrix.
777
+ per_user_truncated_estimating_function_stacks (jnp.ndarray):
778
+ A 2-D JAX NumPy array holding the existing per-user weighted estimating function
779
+ stacks but with rows corresponding to updates not under consideration dropped.
780
+ We will stitch this together with the newly computed inference estimating functions
781
+ to form our "premature" joint adaptive estimating function stacks from which the new
782
+ adaptive meat matrix can be computed.
783
+ theta (jnp.ndarray):
784
+ A 1-D JAX NumPy array representing the parameter estimate for inference.
785
+ all_post_update_betas (jnp.ndarray):
786
+ A 2-D JAX NumPy array representing all parameter estimates for the algorithm updates.
787
+ user_ids (jnp.ndarray):
788
+ A 1-D JAX NumPy array holding all user IDs in the study.
789
+ action_prob_func (callable):
790
+ The action probability function.
791
+ action_prob_func_args_beta_index (int):
792
+ The index of beta in the action probability function arguments tuples.
793
+ inference_func (callable):
794
+ The inference loss or estimating function.
795
+ inference_func_type (str):
796
+ The type of the inference function (loss or estimating).
797
+ inference_func_args_theta_index (int):
798
+ The index of the theta parameter in the inference function arguments tuples.
799
+ inference_func_args_action_prob_index (int):
800
+ The index of action probabilities in the inference function arguments tuple, if
801
+ applicable. -1 otherwise.
802
+ action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
803
+ A dictionary mapping decision times to maps of user ids to the function arguments
804
+ required to compute action probabilities for this user.
805
+ policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
806
+ A map of user ids to dictionaries mapping decision times to the policy number in use.
807
+ Only applies to in-study decision times!
808
+ initial_policy_num (int | float):
809
+ The policy number of the initial policy before any updates.
810
+ beta_index_by_policy_num (dict[int | float, int]):
811
+ A dictionary mapping policy numbers to the index of the corresponding beta in
812
+ all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
813
+ inference_func_args_by_user_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
814
+ A dictionary mapping user IDs to their respective inference function arguments.
815
+ inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
816
+ For each user, a list of decision times to which action probabilities correspond if
817
+ provided. Typically just in-study times if action probabilities are used in the inference
818
+ loss or estimating function.
819
+ action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
820
+ A dictionary mapping user IDs to their respective actions taken at each decision time.
821
+ Only applies to in-study decision times!
822
+ Returns:
823
+ tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
+ A tuple containing:
+ - The premature adaptive sandwich variance matrix (theta block only).
+ - The premature classical sandwich variance matrix.
+ - The average (weighted) inference estimating function.
835
+ """
836
+ logger.info(
837
+ "Differentiating average weighted inference estimating function stack and collecting auxiliary values."
838
+ )
839
+ # jax.jacobian may perform worse here--seemed to hang indefinitely while jacrev is merely very
840
+ # slow.
841
+ # Note that these "contributions" are per-user Jacobians of the weighted estimating function stack.
842
+ new_inference_block_row, (
843
+ per_user_inference_estimating_functions,
844
+ avg_inference_estimating_function,
845
+ per_user_classical_meat_contributions,
846
+ per_user_classical_bread_inverse_contributions,
847
+ ) = jax.jacrev(get_weighted_inference_estimating_functions_only, has_aux=True)(
848
+ # While JAX can technically differentiate with respect to a list of JAX arrays,
849
+ # it is more efficient to flatten them into a single array. This is done
850
+ # here to improve performance. We can simply unflatten them inside the function.
851
+ after_study_analysis.flatten_params(all_post_update_betas, theta),
852
+ all_post_update_betas.shape[1],
853
+ theta.shape[0],
854
+ user_ids,
855
+ action_prob_func,
856
+ action_prob_func_args_beta_index,
857
+ inference_func,
858
+ inference_func_type,
859
+ inference_func_args_theta_index,
860
+ inference_func_args_action_prob_index,
861
+ action_prob_func_args_by_user_id_by_decision_time,
862
+ policy_num_by_decision_time_by_user_id,
863
+ initial_policy_num,
864
+ beta_index_by_policy_num,
865
+ inference_func_args_by_user_id,
866
+ inference_action_prob_decision_times_by_user_id,
867
+ action_by_decision_time_by_user_id,
868
+ )
869
+
870
+ joint_adaptive_bread_inverse_matrix = jnp.block(
871
+ [
872
+ [
873
+ truncated_joint_adaptive_bread_inverse_matrix,
874
+ np.zeros(
875
+ (
876
+ truncated_joint_adaptive_bread_inverse_matrix.shape[0],
877
+ new_inference_block_row.shape[0],
878
+ )
879
+ ),
880
+ ],
881
+ [new_inference_block_row],
882
+ ]
883
+ )
884
+ per_user_estimating_function_stacks = jnp.concatenate(
885
+ [
886
+ per_user_truncated_estimating_function_stacks,
887
+ per_user_inference_estimating_functions,
888
+ ],
889
+ axis=1,
890
+ )
891
+ per_user_adaptive_meat_contributions = jnp.einsum(
892
+ "ni,nj->nij",
893
+ per_user_estimating_function_stacks,
894
+ per_user_estimating_function_stacks,
895
+ )
896
+
897
+ joint_adaptive_meat_matrix = jnp.mean(per_user_adaptive_meat_contributions, axis=0)
898
+
899
+ classical_bread_inverse_matrix = jnp.mean(
900
+ per_user_classical_bread_inverse_contributions, axis=0
901
+ )
902
+ classical_meat_matrix = jnp.mean(per_user_classical_meat_contributions, axis=0)
903
+
904
+ num_users = user_ids.shape[0]
905
+ joint_adaptive_sandwich = (
906
+ after_study_analysis.form_sandwich_from_bread_inverse_and_meat(
907
+ joint_adaptive_bread_inverse_matrix,
908
+ joint_adaptive_meat_matrix,
909
+ num_users,
910
+ method="bread_inverse_T_qr",
911
+ )
912
+ )
913
+ adaptive_sandwich = joint_adaptive_sandwich[-theta.shape[0] :, -theta.shape[0] :]
914
+
918
+ classical_sandwich = after_study_analysis.form_sandwich_from_bread_inverse_and_meat(
919
+ classical_bread_inverse_matrix,
920
+ classical_meat_matrix,
921
+ num_users,
922
+ method="bread_inverse_T_qr",
923
+ )
924
+
925
+ # Return the premature adaptive and classical sandwiches along with the average weighted
927
+ # inference estimating function. The joint adaptive bread inverse formed above is block lower triangular.
927
+ return (
928
+ adaptive_sandwich,
929
+ classical_sandwich,
930
+ avg_inference_estimating_function,
931
+ )
932
+
933
+
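For reference, a minimal sketch of the sandwich form that after_study_analysis.form_sandwich_from_bread_inverse_and_meat is assumed to compute above, namely inv(bread_inverse) @ meat @ inv(bread_inverse).T / num_users; the helper's actual signature and its QR-based "bread_inverse_T_qr" method are not reproduced here.

import jax.numpy as jnp

def sandwich_from_bread_inverse_and_meat(bread_inverse, meat, num_users):
    # Solve rather than explicitly inverting the (possibly ill-conditioned) bread inverse.
    bread = jnp.linalg.solve(bread_inverse, jnp.eye(bread_inverse.shape[0]))
    return bread @ meat @ bread.T / num_users

bread_inverse = jnp.array([[2.0, 0.0], [1.0, 3.0]])
meat = jnp.array([[1.0, 0.2], [0.2, 1.5]])
print(sandwich_from_bread_inverse_and_meat(bread_inverse, meat, num_users=50))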
934
+ def get_weighted_inference_estimating_functions_only(
935
+ flattened_betas_and_theta: jnp.ndarray,
936
+ beta_dim: int,
937
+ theta_dim: int,
938
+ user_ids: jnp.ndarray,
939
+ action_prob_func: callable,
940
+ action_prob_func_args_beta_index: int,
941
+ inference_func: callable,
942
+ inference_func_type: str,
943
+ inference_func_args_theta_index: int,
944
+ inference_func_args_action_prob_index: int,
945
+ action_prob_func_args_by_user_id_by_decision_time: dict[
946
+ collections.abc.Hashable, dict[int, tuple[Any, ...]]
947
+ ],
948
+ policy_num_by_decision_time_by_user_id: dict[
949
+ collections.abc.Hashable, dict[int, int | float]
950
+ ],
951
+ initial_policy_num: int | float,
952
+ beta_index_by_policy_num: dict[int | float, int],
953
+ inference_func_args_by_user_id: dict[collections.abc.Hashable, tuple[Any, ...]],
954
+ inference_action_prob_decision_times_by_user_id: dict[
955
+ collections.abc.Hashable, list[int]
956
+ ],
957
+ action_by_decision_time_by_user_id: dict[collections.abc.Hashable, dict[int, int]],
958
+ ) -> tuple[
959
+ jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
960
+ ]:
961
+ """
962
+ Computes the average weighted inference estimating function across users, along with
963
+ auxiliary values used to construct the adaptive and classical sandwich variances.
964
+
965
+ Note that input data should have been adjusted to only correspond to updates/decision times
966
+ that are being considered for the current "premature" variance estimation procedure.
967
+
968
+ Args:
969
+ flattened_betas_and_theta (jnp.ndarray):
970
+ A single 1-D JAX NumPy array containing the betas produced by all updates followed by the
971
+ theta value. Important that this is a 1D array for efficiency reasons.
972
+ We simply extract the betas and theta from this array below.
973
+ beta_dim (int):
974
+ The dimension of each of the beta parameters.
975
+ theta_dim (int):
976
+ The dimension of the theta parameter.
977
+ user_ids (jnp.ndarray):
978
+ A 1D JAX NumPy array of user IDs.
979
+ action_prob_func (callable):
980
+ The action probability function.
981
+ action_prob_func_args_beta_index (int):
982
+ The index of beta in the action probability function arguments tuples.
983
+ inference_func (callable):
984
+ The inference loss or estimating function.
985
+ inference_func_type (str):
986
+ The type of the inference function (loss or estimating).
987
+ inference_func_args_theta_index (int):
988
+ The index of the theta parameter in the inference function arguments tuples.
989
+ inference_func_args_action_prob_index (int):
990
+ The index of action probabilities in the inference function arguments tuple, if
991
+ applicable. -1 otherwise.
992
+ action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
993
+ A dictionary mapping decision times to maps of user ids to the function arguments
994
+ required to compute action probabilities for this user.
995
+ policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
996
+ A map of user ids to dictionaries mapping decision times to the policy number in use.
997
+ Only applies to in-study decision times!
998
+ initial_policy_num (int | float):
999
+ The policy number of the initial policy before any updates.
1000
+ beta_index_by_policy_num (dict[int | float, int]):
1001
+ A dictionary mapping policy numbers to the index of the corresponding beta in
1002
+ all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
1003
+ inference_func_args_by_user_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
1004
+ A dictionary mapping user IDs to their respective inference function arguments.
1005
+ inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
1006
+ For each user, a list of decision times to which action probabilities correspond if
1007
+ provided. Typically just in-study times if action probabilities are used in the inference
1008
+ loss or estimating function.
1009
+ action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
1010
+ A dictionary mapping user IDs to their respective actions taken at each decision time.
1011
+ Only applies to in-study decision times!
1012
+
1013
+ Returns:
1014
+ jnp.ndarray:
+ A 1-D JAX NumPy array holding the average weighted inference estimating function.
+ tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
+ A tuple containing
+ 1. the per-user weighted inference estimating function stacks
+ 2. the average weighted inference estimating function
+ 3. the user-level classical meat matrix contributions
+ 4. the user-level inverse classical bread matrix contributions
1023
+ """
1024
+
1025
+ inference_estimating_func = (
1026
+ jax.grad(inference_func, argnums=inference_func_args_theta_index)
1027
+ if (inference_func_type == FunctionTypes.LOSS)
1028
+ else inference_func
1029
+ )
1030
+
1031
+ betas, theta = after_study_analysis.unflatten_params(
1032
+ flattened_betas_and_theta,
1033
+ beta_dim,
1034
+ theta_dim,
1035
+ )
1036
+
1037
+ # 2. Thread the betas and theta extracted from flattened_betas_and_theta into the arguments
1038
+ # supplied for the above functions, so that differentiation works correctly. The existing
1039
+ # values should be the same, but not connected to the parameter we are differentiating
1040
+ # with respect to. Note we will also find it useful below to have the action probability args
1041
+ # nested dict structure flipped to be user_id -> decision_time -> args, so we do that here too.
1042
+
1043
+ logger.info("Threading in betas to action probability arguments for all users.")
1044
+ (
1045
+ threaded_action_prob_func_args_by_decision_time_by_user_id,
1046
+ action_prob_func_args_by_decision_time_by_user_id,
1047
+ ) = after_study_analysis.thread_action_prob_func_args(
1048
+ action_prob_func_args_by_user_id_by_decision_time,
1049
+ policy_num_by_decision_time_by_user_id,
1050
+ initial_policy_num,
1051
+ betas,
1052
+ beta_index_by_policy_num,
1053
+ action_prob_func_args_beta_index,
1054
+ )
1055
+
1056
+ # 4. Thread the central theta into the inference function arguments
1057
+ # and replace any action probabilities with reconstructed ones from the above
1058
+ # arguments with the central betas introduced.
1059
+ logger.info(
1060
+ "Threading in theta and beta-dependent action probabilities to inference update "
1061
+ "function args for all users"
1062
+ )
1063
+ threaded_inference_func_args_by_user_id = (
1064
+ after_study_analysis.thread_inference_func_args(
1065
+ inference_func_args_by_user_id,
1066
+ inference_func_args_theta_index,
1067
+ theta,
1068
+ inference_func_args_action_prob_index,
1069
+ threaded_action_prob_func_args_by_decision_time_by_user_id,
1070
+ inference_action_prob_decision_times_by_user_id,
1071
+ action_prob_func,
1072
+ )
1073
+ )
1074
+
1075
+ # 5. Now we can compute the weighted inference estimating functions for all users
1076
+ # as well as collect related values used to construct the adaptive and classical
1077
+ # sandwich variances.
1078
+ results = [
1079
+ single_user_weighted_inference_estimating_function(
1080
+ user_id,
1081
+ action_prob_func,
1082
+ inference_estimating_func,
1083
+ action_prob_func_args_beta_index,
1084
+ inference_func_args_theta_index,
1085
+ action_prob_func_args_by_decision_time_by_user_id[user_id],
1086
+ threaded_action_prob_func_args_by_decision_time_by_user_id[user_id],
1087
+ threaded_inference_func_args_by_user_id[user_id],
1088
+ policy_num_by_decision_time_by_user_id[user_id],
1089
+ action_by_decision_time_by_user_id[user_id],
1090
+ beta_index_by_policy_num,
1091
+ )
1092
+ for user_id in user_ids.tolist()
1093
+ ]
1094
+
1095
+ weighted_inference_estimating_functions = jnp.array(
1096
+ [result[0] for result in results]
1097
+ )
1098
+ inference_only_outer_products = jnp.array([result[1] for result in results])
1099
+ inference_hessians = jnp.array([result[2] for result in results])
1100
+
1101
+ # 6. Note this strange return structure! We will differentiate the first output,
1102
+ # but the second tuple will be passed along without modification via has_aux=True and then used
1103
+ # for the adaptive meat matrix, estimating functions sum check, and classical meat and inverse
1104
+ # bread matrices. The raw per-user estimating functions are also returned again for debugging
1105
+ # purposes.
1106
+ return jnp.mean(weighted_inference_estimating_functions, axis=0), (
1107
+ weighted_inference_estimating_functions,
1108
+ jnp.mean(weighted_inference_estimating_functions, axis=0),
1109
+ inference_only_outer_products,
1110
+ inference_hessians,
1111
+ )
1112
+
1113
+
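A minimal sketch (separate from the package) of the jax.jacrev(..., has_aux=True) pattern this function is built around: the Jacobian of the first output is taken with respect to the flattened parameters, while the auxiliary tuple is passed through untouched.

import jax
import jax.numpy as jnp

def avg_estimating_function(flattened_params):
    # Stand-in per-user estimating functions; the real ones come from study data.
    per_user = jnp.stack([1.0 * flattened_params, 2.0 * flattened_params])
    avg = per_user.mean(axis=0)
    return avg, (per_user, avg)

jacobian, (per_user, avg) = jax.jacrev(avg_estimating_function, has_aux=True)(
    jnp.array([0.5, -1.0])
)
print(jacobian.shape, per_user.shape, avg)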
1114
+ def single_user_weighted_inference_estimating_function(
1115
+ user_id: collections.abc.Hashable,
1116
+ action_prob_func: callable,
1117
+ inference_estimating_func: callable,
1118
+ action_prob_func_args_beta_index: int,
1119
+ inference_func_args_theta_index: int,
1120
+ action_prob_func_args_by_decision_time: dict[int, tuple[Any, ...]],
+ threaded_action_prob_func_args_by_decision_time: dict[int, tuple[Any, ...]],
+ threaded_inference_func_args: tuple[Any, ...],
+ policy_num_by_decision_time: dict[int, int | float],
+ action_by_decision_time: dict[int, int],
1129
+ beta_index_by_policy_num: dict[int | float, int],
1130
+ ) -> tuple[
1131
+ jnp.ndarray[jnp.float32],
1132
+ jnp.ndarray[jnp.float32],
1133
+ jnp.ndarray[jnp.float32],
1134
+ ]:
1135
+ """
1136
+ Computes a weighted inference estimating function for a given inference estimating function and arguments
1137
+ and action probability function and arguments if applicable.
1138
+
1139
+ Args:
1140
+ user_id (collections.abc.Hashable):
1141
+ The user ID for which to compute the weighted estimating function stack.
1142
+
1143
+ action_prob_func (callable):
1144
+ The function used to compute the probability of action 1 at a given decision time for
1145
+ a particular user given their state and the algorithm parameters.
1146
+
1147
+ inference_estimating_func (callable):
1148
+ The estimating function that corresponds to inference.
1149
+
1150
+ action_prob_func_args_beta_index (int):
1151
+ The index of the beta argument in the action probability function's arguments.
1152
+
1153
+ inference_func_args_theta_index (int):
1154
+ The index of the theta parameter in the inference loss or estimating function arguments.
1155
+
1156
+ action_prob_func_args_by_decision_time (dict[int, tuple[Any, ...]]):
1157
+ A map from decision times to tuples of arguments for this user for the action
1158
+ probability function. This is for all decision times (args are an empty
1159
+ tuple if they are not in the study). Should be sorted by decision time. NOTE THAT THESE
1160
+ ARGS DO NOT CONTAIN THE SHARED BETAS, making them impervious to the differentiation that
1161
+ will occur.
1162
+
1163
+ threaded_action_prob_func_args_by_decision_time (dict[int, tuple[Any, ...]]):
1164
+ A map from decision times to tuples of arguments for the action
1165
+ probability function, with the shared betas threaded in for differentiation. Decision
1166
+ times should be sorted.
1167
+
1168
+ threaded_inference_func_args (tuple[Any, ...]):
1169
+ A tuple containing the arguments for the inference
1170
+ estimating function for this user, with the shared theta and beta-dependent action probabilities threaded in for differentiation.
1171
+
1172
+ policy_num_by_decision_time (dict[int, int | float]):
1173
+ A dictionary mapping decision times to the policy number in use. This may be
1174
+ user-specific. Should be sorted by decision time. Only applies to in-study decision
1175
+ times!
1176
+
1177
+ action_by_decision_time (dict[int, int]):
1178
+ A dictionary mapping decision times to actions taken. Only applies to in-study decision
1179
+ times!
1180
+
1181
+ beta_index_by_policy_num (dict[int | float, int]):
1182
+ A dictionary mapping policy numbers to the index of the corresponding beta in
1183
+ all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
1184
+
1185
+ Returns:
1186
+ jnp.ndarray: A 1-D JAX NumPy array representing the user's weighted inference estimating function.
1187
+ jnp.ndarray: A 2-D JAX NumPy matrix representing the user's classical meat contribution.
1188
+ jnp.ndarray: A 2-D JAX NumPy matrix representing the user's classical bread contribution.
1189
+ """
1190
+
1191
+ logger.info(
1192
+ "Computing only weighted inference estimating function stack for user %s.",
1193
+ user_id,
1194
+ )
1195
+
1196
+ # First, reformat the supplied data into more convenient structures.
1197
+
1198
+ # 1. Get the first time after the first update for convenience.
1199
+ # This is used to form the Radon-Nikodym weights for the right times.
1200
+ _, first_time_after_first_update = after_study_analysis.get_min_time_by_policy_num(
1201
+ policy_num_by_decision_time,
1202
+ beta_index_by_policy_num,
1203
+ )
1204
+
1205
+ # 2. Get the start and end times for this user.
1206
+ user_start_time = math.inf
1207
+ user_end_time = -math.inf
1208
+ for decision_time in action_by_decision_time:
1209
+ user_start_time = min(user_start_time, decision_time)
1210
+ user_end_time = max(user_end_time, decision_time)
1211
+
1212
+ # 3. Calculate the Radon-Nikodym weights for the inference estimating function.
1213
+ in_study_action_prob_func_args = [
1214
+ args for args in action_prob_func_args_by_decision_time.values() if args
1215
+ ]
1216
+ in_study_betas_list_by_decision_time_index = jnp.array(
1217
+ [
1218
+ action_prob_func_args[action_prob_func_args_beta_index]
1219
+ for action_prob_func_args in in_study_action_prob_func_args
1220
+ ]
1221
+ )
1222
+ in_study_actions_list_by_decision_time_index = jnp.array(
1223
+ list(action_by_decision_time.values())
1224
+ )
1225
+
1226
+ # Sort the threaded args by decision time to be cautious. We check that each
1227
+ # decision time is present in the threaded args dict because we may call this on a
1228
+ # subset of the arg dict when we are batching arguments by shape.
1229
+ sorted_threaded_action_prob_args_by_decision_time = {
1230
+ decision_time: threaded_action_prob_func_args_by_decision_time[decision_time]
1231
+ for decision_time in range(user_start_time, user_end_time + 1)
1232
+ if decision_time in threaded_action_prob_func_args_by_decision_time
1233
+ }
1234
+
1235
+ num_args = None
1236
+ for args in sorted_threaded_action_prob_args_by_decision_time.values():
1237
+ if args:
1238
+ num_args = len(args)
1239
+ break
1240
+
1241
+ # NOTE: Cannot do [[]] * num_args here! Then all lists point
1242
+ # same object...
1243
+ batched_threaded_arg_lists = [[] for _ in range(num_args)]
1244
+ for (
1245
+ decision_time,
1246
+ args,
1247
+ ) in sorted_threaded_action_prob_args_by_decision_time.items():
1248
+ if not args:
1249
+ continue
1250
+ for idx, arg in enumerate(args):
1251
+ batched_threaded_arg_lists[idx].append(arg)
1252
+
1253
+ batched_threaded_arg_tensors, batch_axes = stack_batched_arg_lists_into_tensors(
1254
+ batched_threaded_arg_lists
1255
+ )
1256
+
1257
+ # Note that we do NOT use the shared betas in the first arg to the weight function,
1258
+ # since we don't want differentiation to happen with respect to them.
1259
+ # Just grab the original beta from the update function arguments. This is the same
1260
+ # value, but impervious to differentiation with respect to all_post_update_betas. The
1261
+ # args, on the other hand, are a function of all_post_update_betas.
1262
+ in_study_weights = jax.vmap(
1263
+ fun=after_study_analysis.get_radon_nikodym_weight,
1264
+ in_axes=[0, None, None, 0] + batch_axes,
1265
+ out_axes=0,
1266
+ )(
1267
+ in_study_betas_list_by_decision_time_index,
1268
+ action_prob_func,
1269
+ action_prob_func_args_beta_index,
1270
+ in_study_actions_list_by_decision_time_index,
1271
+ *batched_threaded_arg_tensors,
1272
+ )
1273
+
1274
+ in_study_index = 0
1275
+ decision_time_to_all_weights_index_offset = min(
1276
+ sorted_threaded_action_prob_args_by_decision_time
1277
+ )
1278
+ all_weights_raw = []
1279
+ for (
1280
+ decision_time,
1281
+ args,
1282
+ ) in sorted_threaded_action_prob_args_by_decision_time.items():
1283
+ if args:
+ # Only advance through the in-study weights for in-study decision times.
+ all_weights_raw.append(in_study_weights[in_study_index])
+ in_study_index += 1
+ else:
+ all_weights_raw.append(1.0)
1285
+ all_weights = jnp.array(all_weights_raw)
1286
+
1287
+ # 4. Form the weighted inference estimating equation.
1288
+ weighted_inference_estimating_function = jnp.prod(
1289
+ all_weights[
1290
+ max(first_time_after_first_update, user_start_time)
1291
+ - decision_time_to_all_weights_index_offset : user_end_time
1292
+ + 1
1293
+ - decision_time_to_all_weights_index_offset,
1294
+ ]
1295
+ # If the user exited the study before there were any updates,
1296
+ # this variable will be None and the above code to grab a weight would
1297
+ # throw an error. Just use 1 to include the unweighted estimating function
1298
+ # if they have data to contribute here (pretty sure everyone should?)
1299
+ if first_time_after_first_update is not None
1300
+ else 1
1301
+ ) * inference_estimating_func(*threaded_inference_func_args)
1302
+
1303
+ return (
1304
+ weighted_inference_estimating_function,
1305
+ jnp.outer(
1306
+ weighted_inference_estimating_function,
1307
+ weighted_inference_estimating_function,
1308
+ ),
1309
+ jax.jacrev(inference_estimating_func, argnums=inference_func_args_theta_index)(
1310
+ *threaded_inference_func_args
1311
+ ),
1312
+ )
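Finally, a minimal sketch of the kind of Radon-Nikodym weight assumed to come from after_study_analysis.get_radon_nikodym_weight: for a binary action, the probability of the observed action under candidate (threaded) parameters divided by its probability under the parameters actually deployed. The logistic policy below is hypothetical.

import jax.numpy as jnp
from jax.nn import sigmoid

def action_prob(beta, state):
    # Hypothetical Bernoulli policy: P(action = 1 | state, beta).
    return sigmoid(jnp.dot(beta, state))

def radon_nikodym_weight(deployed_beta, candidate_beta, state, action):
    p_candidate = action_prob(candidate_beta, state)
    p_deployed = action_prob(deployed_beta, state)
    numerator = jnp.where(action == 1, p_candidate, 1.0 - p_candidate)
    denominator = jnp.where(action == 1, p_deployed, 1.0 - p_deployed)
    return numerator / denominator

state = jnp.array([1.0, 0.3])
print(radon_nikodym_weight(jnp.array([0.1, -0.2]), jnp.array([0.4, 0.0]), state, action=1))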