lifejacket-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,870 @@
1
+ from __future__ import annotations
2
+
3
+ import collections.abc
4
+ import logging
5
+ import math
6
+ from typing import Any
7
+
8
+ from jax import numpy as jnp
9
+ import jax
10
+ import numpy as np
11
+ import pandas as pd
12
+
13
+ from .arg_threading_helpers import thread_action_prob_func_args, thread_update_func_args
14
+ from .constants import FunctionTypes
15
+ from .helper_functions import (
16
+ calculate_beta_dim,
17
+ collect_all_post_update_betas,
18
+ construct_beta_index_by_policy_num_map,
19
+ extract_action_and_policy_by_decision_time_by_user_id,
20
+ flatten_params,
21
+ get_min_time_by_policy_num,
22
+ get_radon_nikodym_weight,
23
+ unflatten_params,
24
+ )
25
+ from . import input_checks
26
+ from .vmap_helpers import stack_batched_arg_lists_into_tensors
27
+
28
+
29
+ logger = logging.getLogger(__name__)
30
+ logging.basicConfig(
31
+ format="%(asctime)s,%(msecs)03d %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s",
32
+ datefmt="%Y-%m-%d:%H:%M:%S",
33
+ level=logging.INFO,
34
+ )
35
+
36
+
37
+ class TrialConditioningMonitor:
38
+ whole_RL_block_conditioning_threshold = None
39
+ diagonal_RL_block_conditioning_threshold = None
40
+
41
+ def __init__(
42
+ self,
43
+ whole_RL_block_conditioning_threshold: int = 1000,
44
+ diagonal_RL_block_conditioning_threshold: int = 100,
45
+ ):
46
+ self.whole_RL_block_conditioning_threshold = (
47
+ whole_RL_block_conditioning_threshold
48
+ )
49
+ self.diagonal_RL_block_conditioning_threshold = (
50
+ diagonal_RL_block_conditioning_threshold
51
+ )
52
+ self.latest_phi_dot_bar = None
53
+
54
+ def assess_update(
55
+ self,
56
+ proposed_policy_num: int | float,
57
+ study_df: pd.DataFrame,
58
+ action_prob_func: callable,
59
+ action_prob_func_args: dict,
60
+ action_prob_func_args_beta_index: int,
61
+ alg_update_func: callable,
62
+ alg_update_func_type: str,
63
+ alg_update_func_args: dict,
64
+ alg_update_func_args_beta_index: int,
65
+ alg_update_func_args_action_prob_index: int,
66
+ alg_update_func_args_action_prob_times_index: int,
67
+ in_study_col_name: str,
68
+ action_col_name: str,
69
+ policy_num_col_name: str,
70
+ calendar_t_col_name: str,
71
+ user_id_col_name: str,
72
+ action_prob_col_name: str,
73
+ suppress_interactive_data_checks: bool,
74
+ suppress_all_data_checks: bool,
75
+ incremental: bool = True,
76
+ ) -> dict:
77
+ """
78
+ Analyzes a dataset to estimate parameters and variance using adaptive and classical sandwich estimators.
79
+
80
+ Parameters:
81
+ proposed_policy_num (int | float):
82
+ The policy number of the proposed update.
83
+ study_df (pd.DataFrame):
84
+ DataFrame containing the study data.
85
+ action_prob_func (str):
86
+ Action probability function.
87
+ action_prob_func_args (dict):
88
+ Arguments for the action probability function.
89
+ action_prob_func_args_beta_index (int):
90
+ Index for beta in action probability function arguments.
91
+ alg_update_func (str):
92
+ Algorithm update function.
93
+ alg_update_func_type (str):
94
+ Type of the algorithm update function.
95
+ alg_update_func_args (dict):
96
+ Arguments for the algorithm update function.
97
+ alg_update_func_args_beta_index (int):
98
+ Index for beta in algorithm update function arguments.
99
+ alg_update_func_args_action_prob_index (int):
100
+ Index for action probability in algorithm update function arguments.
101
+ alg_update_func_args_action_prob_times_index (int):
102
+ Index for action probability times in algorithm update function arguments.
103
+ in_study_col_name (str):
104
+ Column name indicating if a user is in the study in the study dataframe.
105
+ action_col_name (str):
106
+ Column name for actions in the study dataframe.
107
+ policy_num_col_name (str):
108
+ Column name for policy numbers in the study dataframe.
109
+ calendar_t_col_name (str):
110
+ Column name for calendar time in the study dataframe.
111
+ user_id_col_name (str):
112
+ Column name for user IDs in the study dataframe.
113
+ action_prob_col_name (str):
114
+ Column name for action probabilities in the study dataframe.
115
+ reward_col_name (str):
116
+ Column name for rewards in the study dataframe.
117
+ suppress_interactive_data_checks (bool):
118
+ Whether to suppress interactive data checks. This should be used in simulations, for example.
119
+ suppress_all_data_checks (bool):
120
+ Whether to suppress all data checks. Not recommended.
121
+ small_sample_correction (str):
122
+ Type of small sample correction to apply.
123
+ collect_data_for_blowup_supervised_learning (bool):
124
+ Whether to collect data for doing supervised learning about adaptive sandwich blowup.
125
+ form_adaptive_meat_adjustments_explicitly (bool):
126
+ If True, explicitly forms the per-user meat adjustments that differentiate the adaptive
127
+ sandwich from the classical sandwich. This is for diagnostic purposes, as the
128
+ adaptive sandwich is formed without doing this.
129
+ stabilize_joint_adaptive_bread_inverse (bool):
130
+ If True, stabilizes the joint adaptive bread inverse matrix if it does not meet conditioning
131
+ thresholds.
132
+
133
+ Returns:
134
+ None: The function writes analysis results and debug pieces to files in the same directory as
135
+ the input files.
136
+ """
137
+
138
+ beta_dim = calculate_beta_dim(
139
+ action_prob_func_args, action_prob_func_args_beta_index
140
+ )
141
+
142
+ if not suppress_all_data_checks:
143
+ input_checks.perform_alg_only_input_checks(
144
+ study_df,
145
+ in_study_col_name,
146
+ policy_num_col_name,
147
+ calendar_t_col_name,
148
+ user_id_col_name,
149
+ action_prob_col_name,
150
+ action_prob_func,
151
+ action_prob_func_args,
152
+ action_prob_func_args_beta_index,
153
+ alg_update_func_args,
154
+ alg_update_func_args_beta_index,
155
+ alg_update_func_args_action_prob_index,
156
+ alg_update_func_args_action_prob_times_index,
157
+ suppress_interactive_data_checks,
158
+ )
159
+
160
+ beta_index_by_policy_num, initial_policy_num = (
161
+ construct_beta_index_by_policy_num_map(
162
+ study_df, policy_num_col_name, in_study_col_name
163
+ )
164
+ )
165
+ # We augment the produced map to include the proposed policy num.
166
+ # This is necessary because the logic above assumes all policies are present in the
167
+ # study df, whereas here we pass only the data *used* for the current update,
+ # i.e. data generated up to and including the previous policy.
169
+ beta_index_by_policy_num[proposed_policy_num] = len(beta_index_by_policy_num)
170
+
171
+ all_post_update_betas = collect_all_post_update_betas(
172
+ beta_index_by_policy_num,
173
+ alg_update_func_args,
174
+ alg_update_func_args_beta_index,
175
+ )
176
+
177
+ action_by_decision_time_by_user_id, policy_num_by_decision_time_by_user_id = (
178
+ extract_action_and_policy_by_decision_time_by_user_id(
179
+ study_df,
180
+ user_id_col_name,
181
+ in_study_col_name,
182
+ calendar_t_col_name,
183
+ action_col_name,
184
+ policy_num_col_name,
185
+ )
186
+ )
187
+
188
+ user_ids = jnp.array(study_df[user_id_col_name].unique())
189
+
190
+ phi_dot_bar, avg_estimating_function_stack = self.construct_phi_dot_bar_so_far(
191
+ all_post_update_betas,
192
+ user_ids,
193
+ action_prob_func,
194
+ action_prob_func_args_beta_index,
195
+ alg_update_func,
196
+ alg_update_func_type,
197
+ alg_update_func_args_beta_index,
198
+ alg_update_func_args_action_prob_index,
199
+ alg_update_func_args_action_prob_times_index,
200
+ action_prob_func_args,
201
+ policy_num_by_decision_time_by_user_id,
202
+ initial_policy_num,
203
+ beta_index_by_policy_num,
204
+ alg_update_func_args,
205
+ action_by_decision_time_by_user_id,
206
+ suppress_all_data_checks,
207
+ suppress_interactive_data_checks,
208
+ incremental=incremental,
209
+ )
210
+
211
+ if not suppress_all_data_checks:
212
+ input_checks.require_RL_estimating_functions_sum_to_zero(
213
+ avg_estimating_function_stack,
214
+ beta_dim,
215
+ suppress_interactive_data_checks,
216
+ )
217
+
218
+ # Decide whether to accept or reject the update based on conditioning
219
+ update_rejected = False
220
+ rejection_reason = ""
221
+ whole_RL_block_condition_number = np.linalg.cond(phi_dot_bar)
222
+ new_diagonal_RL_block_condition_number = np.linalg.cond(
223
+ phi_dot_bar[-beta_dim:, -beta_dim:]
224
+ )
225
+
226
+ if whole_RL_block_condition_number > self.whole_RL_block_conditioning_threshold:
227
+ logger.warning(
228
+ "The RL portion of the bread inverse up to this point exceeds the threshold set (condition number: %s, threshold: %s). Consider an alternative update strategy which produces less dependence on previous RL parameters (via the data they produced) and/or improves the conditioning of each update itself. Regularization may help with both of these.",
229
+ whole_RL_block_condition_number,
230
+ self.whole_RL_block_conditioning_threshold,
231
+ )
232
+ update_rejected = True
233
+ rejection_reason = "whole_block_poor_conditioning"
234
+ elif (
235
+ new_diagonal_RL_block_condition_number
236
+ > self.diagonal_RL_block_conditioning_threshold
237
+ ):
238
+ logger.warning(
239
+ "The diagonal RL block of the bread inverse up to this point exceeds the threshold set (condition number: %s, threshold: %s). This may illustrate a fundamental problem with the conditioning of the RL update procedure.",
240
+ new_diagonal_RL_block_condition_number,
241
+ self.diagonal_RL_block_conditioning_threshold,
242
+ )
243
+ update_rejected = True
244
+ rejection_reason = "diagonal_block_poor_conditioning"
245
+
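# Illustrative sketch (not part of the released module): the accept/reject decision above
# rests on two condition numbers, one for the whole RL block of the bread inverse and one for
# its trailing beta_dim x beta_dim diagonal block. A minimal NumPy analogue of those checks,
# using the default thresholds from __init__ and a made-up block lower-triangular matrix:
import numpy as np

beta_dim_demo = 2
rng = np.random.default_rng(0)
phi_dot_bar_demo = np.tril(rng.normal(size=(3 * beta_dim_demo, 3 * beta_dim_demo)))
whole_cond_demo = np.linalg.cond(phi_dot_bar_demo)
diag_cond_demo = np.linalg.cond(phi_dot_bar_demo[-beta_dim_demo:, -beta_dim_demo:])
rejected_demo = whole_cond_demo > 1000 or diag_cond_demo > 100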
246
+ # TODO: Regression -> prediction over going over threshold? Take in estimated num updates if so.
247
+
248
+ ans = {
249
+ "update_rejected": update_rejected,
250
+ "rejection_reason": rejection_reason,
251
+ "whole_RL_block_condition_number": whole_RL_block_condition_number,
252
+ "whole_RL_block_conditioning_threshold": self.whole_RL_block_conditioning_threshold,
253
+ "new_diagonal_RL_block_condition_number": new_diagonal_RL_block_condition_number,
254
+ "diagonal_RL_block_conditioning_threshold": self.diagonal_RL_block_conditioning_threshold,
255
+ }
256
+ logger.info("Update assessment results: %s", ans)
257
+ return ans
258
+
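# Illustrative sketch (not part of the released module): assess_update expects an action
# probability function whose argument tuple carries the policy parameters beta at
# action_prob_func_args_beta_index, and an update function that is either a loss
# (FunctionTypes.LOSS, differentiated internally with jax.grad) or an estimating function.
# The two functions below are hypothetical stand-ins with that shape, using the module's
# jax/jnp imports; the real argument structures depend on the deployed RL algorithm.

def demo_action_prob_func(beta, state):
    # P(action = 1 | state) under a simple logistic policy.
    return jax.nn.sigmoid(jnp.dot(beta, state))


def demo_alg_update_loss(beta, states, actions, rewards):
    # A least-squares style loss whose gradient in beta acts as the estimating function.
    preds = states @ beta
    return jnp.mean((rewards - actions * preds) ** 2)

# monitor = TrialConditioningMonitor()
# result = monitor.assess_update(...)  # args omitted; returns the dict described above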
259
+ def construct_phi_dot_bar_so_far(
260
+ self,
261
+ all_post_update_betas: jnp.ndarray,
262
+ user_ids: jnp.ndarray,
263
+ action_prob_func: callable,
264
+ action_prob_func_args_beta_index: int,
265
+ alg_update_func: callable,
266
+ alg_update_func_type: str,
267
+ alg_update_func_args_beta_index: int,
268
+ alg_update_func_args_action_prob_index: int,
269
+ alg_update_func_args_action_prob_times_index: int,
270
+ action_prob_func_args_by_user_id_by_decision_time: dict[
271
+ collections.abc.Hashable, dict[int, tuple[Any, ...]]
272
+ ],
273
+ policy_num_by_decision_time_by_user_id: dict[
274
+ collections.abc.Hashable, dict[int, int | float]
275
+ ],
276
+ initial_policy_num: int | float,
277
+ beta_index_by_policy_num: dict[int | float, int],
278
+ update_func_args_by_by_user_id_by_policy_num: dict[
279
+ collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
280
+ ],
281
+ action_by_decision_time_by_user_id: dict[
282
+ collections.abc.Hashable, dict[int, int]
283
+ ],
284
+ suppress_all_data_checks: bool,
285
+ suppress_interactive_data_checks: bool,
286
+ incremental: bool,
287
+ ) -> tuple[
288
+ jnp.ndarray[jnp.float32],
289
+ jnp.ndarray[jnp.float32],
290
+ ]:
291
+ """
292
+ Constructs the RL block of the bread inverse (phi-dot-bar) accumulated up to this point, along
+ with the average estimating function stack.
+
+ This is done by computing the average weighted RL estimating function stack and differentiating
+ it with respect to the stacked betas; the resulting Jacobian is the phi-dot-bar block whose
+ conditioning is assessed in assess_update.
298
+
299
+ Args:
300
+ all_post_update_betas (jnp.ndarray):
301
+ A 2-D JAX NumPy array representing all parameter estimates for the algorithm updates.
302
+ user_ids (jnp.ndarray):
303
+ A 1-D JAX NumPy array holding all user IDs in the study.
304
+ action_prob_func (callable):
305
+ The action probability function.
306
+ action_prob_func_args_beta_index (int):
307
+ The index of beta in the action probability function arguments tuples.
308
+ alg_update_func (callable):
309
+ The algorithm update function.
310
+ alg_update_func_type (str):
311
+ The type of the algorithm update function (loss or estimating).
312
+ alg_update_func_args_beta_index (int):
313
+ The index of beta in the update function arguments tuples.
314
+ alg_update_func_args_action_prob_index (int):
315
+ The index of action probabilities in the update function arguments tuple, if
316
+ applicable. -1 otherwise.
317
+ alg_update_func_args_action_prob_times_index (int):
318
+ The index in the update function arguments tuple where an array of times for which the
319
+ given action probabilities apply is provided, if applicable. -1 otherwise.
320
+ action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
321
+ A dictionary mapping decision times to maps of user ids to the function arguments
322
+ required to compute action probabilities for this user.
323
+ policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
324
+ A map of user ids to dictionaries mapping decision times to the policy number in use.
325
+ Only applies to in-study decision times!
326
+ initial_policy_num (int | float):
327
+ The policy number of the initial policy before any updates.
328
+ beta_index_by_policy_num (dict[int | float, int]):
329
+ A dictionary mapping policy numbers to the index of the corresponding beta in
330
+ all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
331
+ update_func_args_by_by_user_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
332
+ A dictionary where keys are policy numbers and values are dictionaries mapping user IDs
333
+ to their respective update function arguments.
334
+ action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
335
+ A dictionary mapping user IDs to their respective actions taken at each decision time.
336
+ Only applies to in-study decision times!
337
+ suppress_all_data_checks (bool):
338
+ If True, suppresses carrying out any data checks at all.
339
+ suppress_interactive_data_checks (bool):
340
+ If True, suppresses interactive data checks that would otherwise be performed to ensure
341
+ the correctness of the threaded arguments. The checks are still performed, but
342
+ any interactive prompts are suppressed.
343
+ incremental (bool): Whether to form the whole phi-dot-bar so far, or just form the latest
344
+ block row and add to the cached previous phi-dot-bar.
345
+ """
346
+ logger.info(
347
+ "Differentiating average weighted estimating function stack and collecting auxiliary values."
348
+ )
349
+ beta_dim = all_post_update_betas.shape[1]
350
+
351
+ if incremental and self.latest_phi_dot_bar is not None:
352
+ # We only need to compute the latest block row of the Jacobian.
353
+ (
354
+ phi_dot_bar_latest_block,
355
+ avg_RL_estimating_function_stack,
356
+ ) = jax.jacrev(
357
+ self.get_avg_weighted_RL_estimating_function_stacks_and_aux_values,
358
+ has_aux=True,
359
+ )(
360
+ # While JAX can technically differentiate with respect to a list of JAX arrays,
361
+ # it is apparently more efficient to flatten them into a single array. This is done
362
+ # here to improve performance. We can simply unflatten them inside the function.
363
+ flatten_params(all_post_update_betas, jnp.array([])),
364
+ beta_dim,
365
+ user_ids,
366
+ action_prob_func,
367
+ action_prob_func_args_beta_index,
368
+ alg_update_func,
369
+ alg_update_func_type,
370
+ alg_update_func_args_beta_index,
371
+ alg_update_func_args_action_prob_index,
372
+ alg_update_func_args_action_prob_times_index,
373
+ action_prob_func_args_by_user_id_by_decision_time,
374
+ policy_num_by_decision_time_by_user_id,
375
+ initial_policy_num,
376
+ beta_index_by_policy_num,
377
+ update_func_args_by_by_user_id_by_policy_num,
378
+ action_by_decision_time_by_user_id,
379
+ suppress_all_data_checks,
380
+ suppress_interactive_data_checks,
381
+ only_latest_block=True,
382
+ )
383
+
384
+ # Now we can just augment the cached previous phi-dot-bar with zeros
385
+ # and the latest block row.
386
+ phi_dot_bar = jnp.block(
+ [
+ # Pad the cached square block with a zero block column on the right, then append
+ # the newly computed block row at the bottom.
+ [self.latest_phi_dot_bar, jnp.zeros((self.latest_phi_dot_bar.shape[0], beta_dim))],
+ [phi_dot_bar_latest_block[-beta_dim:, :]],
+ ]
+ )
393
+ else:
394
+
395
+ (
396
+ phi_dot_bar,
397
+ avg_RL_estimating_function_stack,
398
+ ) = jax.jacrev(
399
+ self.get_avg_weighted_RL_estimating_function_stacks_and_aux_values,
400
+ has_aux=True,
401
+ )(
402
+ # While JAX can technically differentiate with respect to a list of JAX arrays,
403
+ # it is apparently more efficient to flatten them into a single array. This is done
404
+ # here to improve performance. We can simply unflatten them inside the function.
405
+ flatten_params(all_post_update_betas, jnp.array([])),
406
+ beta_dim,
407
+ user_ids,
408
+ action_prob_func,
409
+ action_prob_func_args_beta_index,
410
+ alg_update_func,
411
+ alg_update_func_type,
412
+ alg_update_func_args_beta_index,
413
+ alg_update_func_args_action_prob_index,
414
+ alg_update_func_args_action_prob_times_index,
415
+ action_prob_func_args_by_user_id_by_decision_time,
416
+ policy_num_by_decision_time_by_user_id,
417
+ initial_policy_num,
418
+ beta_index_by_policy_num,
419
+ update_func_args_by_by_user_id_by_policy_num,
420
+ action_by_decision_time_by_user_id,
421
+ suppress_all_data_checks,
422
+ suppress_interactive_data_checks,
423
+ )
424
+
425
+ self.latest_phi_dot_bar = phi_dot_bar
426
+ return phi_dot_bar, avg_RL_estimating_function_stack
427
+
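# Illustrative sketch (not part of the released module): a toy version of the incremental
# pattern above. Differentiate a function of a single flattened parameter vector with
# jax.jacrev(..., has_aux=True), then append the resulting block row to a previously cached
# square block with jnp.block. All names here are hypothetical.

def demo_avg_stack(flat_params, beta_dim_demo=2):
    # Pretend estimating-function stack for the latest update only (length beta_dim_demo).
    stack = jnp.tanh(flat_params[-beta_dim_demo:]) + 0.1 * jnp.sum(flat_params)
    return stack, stack  # the second output is the aux value, as in the method above


def demo_augment(cached_phi_dot_bar, flat_params, beta_dim_demo=2):
    latest_block_row, _ = jax.jacrev(demo_avg_stack, has_aux=True)(flat_params)
    return jnp.block(
        [
            [cached_phi_dot_bar, jnp.zeros((cached_phi_dot_bar.shape[0], beta_dim_demo))],
            [latest_block_row],
        ]
    )

# demo_augment(jnp.eye(4), jnp.arange(6.0)).shape == (6, 6)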
428
+ def get_avg_weighted_RL_estimating_function_stacks_and_aux_values(
429
+ self,
430
+ flattened_betas_and_theta: jnp.ndarray,
431
+ beta_dim: int,
432
+ user_ids: jnp.ndarray,
433
+ action_prob_func: callable,
434
+ action_prob_func_args_beta_index: int,
435
+ alg_update_func: callable,
436
+ alg_update_func_type: str,
437
+ alg_update_func_args_beta_index: int,
438
+ alg_update_func_args_action_prob_index: int,
439
+ alg_update_func_args_action_prob_times_index: int,
440
+ action_prob_func_args_by_user_id_by_decision_time: dict[
441
+ collections.abc.Hashable, dict[int, tuple[Any, ...]]
442
+ ],
443
+ policy_num_by_decision_time_by_user_id: dict[
444
+ collections.abc.Hashable, dict[int, int | float]
445
+ ],
446
+ initial_policy_num: int | float,
447
+ beta_index_by_policy_num: dict[int | float, int],
448
+ update_func_args_by_by_user_id_by_policy_num: dict[
449
+ collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
450
+ ],
451
+ action_by_decision_time_by_user_id: dict[
452
+ collections.abc.Hashable, dict[int, int]
453
+ ],
454
+ suppress_all_data_checks: bool,
455
+ suppress_interactive_data_checks: bool,
456
+ only_latest_block: bool = False,
457
+ ) -> tuple[
+ jnp.ndarray,
+ jnp.ndarray,
+ ]:
461
+ """
462
+ Computes the average weighted estimating function stack across all users, along with
463
+ auxiliary values used to construct the adaptive and classical sandwich variances.
464
+
465
+ If only_latest_block is True, only uses data from the most recent update.
466
+
467
+ Args:
468
+ flattened_betas_and_theta (jnp.ndarray):
469
+ A single 1D JAX NumPy array containing the betas produced by all updates followed by the
+ theta value, in that order. Keeping this flat (rather than a list of arrays) is important
+ for efficiency; the betas and theta are simply extracted from this array below.
472
+ beta_dim (int):
473
+ The dimension of each of the beta parameters.
474
+ user_ids (jnp.ndarray):
475
+ A 1D JAX NumPy array of user IDs.
476
+ action_prob_func (callable):
477
+ The action probability function.
478
+ action_prob_func_args_beta_index (int):
479
+ The index of beta in the action probability function arguments tuples.
480
+ alg_update_func (callable):
481
+ The algorithm update function.
482
+ alg_update_func_type (str):
483
+ The type of the algorithm update function (loss or estimating).
484
+ alg_update_func_args_beta_index (int):
485
+ The index of beta in the update function arguments tuples.
486
+ alg_update_func_args_action_prob_index (int):
487
+ The index of action probabilities in the update function arguments tuple, if
488
+ applicable. -1 otherwise.
489
+ alg_update_func_args_action_prob_times_index (int):
490
+ The index in the update function arguments tuple where an array of times for which the
491
+ given action probabilities apply is provided, if applicable. -1 otherwise.
492
+ action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
493
+ A dictionary mapping decision times to maps of user ids to the function arguments
494
+ required to compute action probabilities for this user.
495
+ policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
496
+ A map of user ids to dictionaries mapping decision times to the policy number in use.
497
+ Only applies to in-study decision times!
498
+ initial_policy_num (int | float):
499
+ The policy number of the initial policy before any updates.
500
+ beta_index_by_policy_num (dict[int | float, int]):
501
+ A dictionary mapping policy numbers to the index of the corresponding beta in
502
+ all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
503
+ update_func_args_by_by_user_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
504
+ A dictionary where keys are policy numbers and values are dictionaries mapping user IDs
505
+ to their respective update function arguments.
506
+ action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
507
+ A dictionary mapping user IDs to their respective actions taken at each decision time.
508
+ Only applies to in-study decision times!
509
+ suppress_all_data_checks (bool):
510
+ If True, suppresses carrying out any data checks at all.
511
+ suppress_interactive_data_checks (bool):
512
+ If True, suppresses interactive data checks that would otherwise be performed to ensure
513
+ the correctness of the threaded arguments. The checks are still performed, but
514
+ any interactive prompts are suppressed.
515
+ only_latest_block (bool):
516
+ If True, only uses data from the most recent update.
517
+
518
+ Returns:
519
+ jnp.ndarray:
520
+ A 1D JAX NumPy array holding the average weighted estimating function stack across users.
521
+ jnp.ndarray:
522
+ The same, again. We will differentiate the first output.
523
+ """
524
+
525
+ algorithm_estimating_func = (
526
+ jax.grad(alg_update_func, argnums=alg_update_func_args_beta_index)
527
+ if (alg_update_func_type == FunctionTypes.LOSS)
528
+ else alg_update_func
529
+ )
530
+
531
+ betas, _ = unflatten_params(
532
+ flattened_betas_and_theta,
533
+ beta_dim,
534
+ 0,
535
+ )
536
+ # 1. If only_latest_block is True, we need to filter all the arguments to only
537
+ # include those relevant to the latest update. We still need action probabilities
538
+ # from the beginning for the weights, but the update function args can be trimmed
539
+ # to the max policy so that the loop single_user_weighted_RL_estimating_function_stacker
540
+ # is only over one policy.
541
+ if only_latest_block:
542
+ logger.info(
543
+ "Filtering algorithm update function arguments to only include those relevant to the latest update."
544
+ )
545
+ max_policy_num = max(beta_index_by_policy_num)
546
+ update_func_args_by_by_user_id_by_policy_num = {
547
+ max_policy_num: update_func_args_by_by_user_id_by_policy_num[
548
+ max_policy_num
549
+ ]
550
+ }
551
+
552
+ # 2. Thread the betas and theta from flattened_betas_and_theta into the arguments
553
+ # supplied for the above functions, so that differentiation works correctly. The existing
554
+ # values should be the same, but not connected to the parameter we are differentiating
555
+ # with respect to. Note we will also find it useful below to have the action probability args
556
+ # nested dict structure flipped to be user_id -> decision_time -> args, so we do that here too.
557
+
558
+ logger.info("Threading in betas to action probability arguments for all users.")
559
+ (
560
+ threaded_action_prob_func_args_by_decision_time_by_user_id,
561
+ action_prob_func_args_by_decision_time_by_user_id,
562
+ ) = thread_action_prob_func_args(
563
+ action_prob_func_args_by_user_id_by_decision_time,
564
+ policy_num_by_decision_time_by_user_id,
565
+ initial_policy_num,
566
+ betas,
567
+ beta_index_by_policy_num,
568
+ action_prob_func_args_beta_index,
569
+ )
570
+
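# Illustrative sketch (not part of the released module): thread_action_prob_func_args also
# returns the args re-keyed as user_id -> decision_time -> args (the "flipped" nesting the
# comment above mentions). Flipping such a nested dict is just:
demo_args_by_user_by_time = {
    1: {"u1": ("state_a",), "u2": ("state_b",)},
    2: {"u1": ("state_c",)},
}
demo_args_by_time_by_user = {}
for demo_time, demo_by_user in demo_args_by_user_by_time.items():
    for demo_user, demo_args in demo_by_user.items():
        demo_args_by_time_by_user.setdefault(demo_user, {})[demo_time] = demo_args
# {'u1': {1: ('state_a',), 2: ('state_c',)}, 'u2': {1: ('state_b',)}}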
571
+ # 3. Thread the central betas into the algorithm update function arguments
572
+ # and replace any action probabilities with reconstructed ones from the above
573
+ # arguments with the central betas introduced.
574
+ logger.info(
575
+ "Threading in betas and beta-dependent action probabilities to algorithm update "
576
+ "function args for all users"
577
+ )
578
+ threaded_update_func_args_by_policy_num_by_user_id = thread_update_func_args(
579
+ update_func_args_by_by_user_id_by_policy_num,
580
+ betas,
581
+ beta_index_by_policy_num,
582
+ alg_update_func_args_beta_index,
583
+ alg_update_func_args_action_prob_index,
584
+ alg_update_func_args_action_prob_times_index,
585
+ threaded_action_prob_func_args_by_decision_time_by_user_id,
586
+ action_prob_func,
587
+ )
588
+
589
+ # If action probabilities are used in the algorithm estimating function, make
590
+ # sure that substituting in the reconstructed action probabilities is approximately
591
+ # equivalent to using the original action probabilities.
592
+ if not suppress_all_data_checks and alg_update_func_args_action_prob_index >= 0:
593
+ input_checks.require_threaded_algorithm_estimating_function_args_equivalent(
594
+ algorithm_estimating_func,
595
+ update_func_args_by_by_user_id_by_policy_num,
596
+ threaded_update_func_args_by_policy_num_by_user_id,
597
+ suppress_interactive_data_checks,
598
+ )
599
+
600
+ # 5. Now we can compute the weighted estimating function stacks for all users
601
+ # as well as collect related values used to construct the adaptive and classical
602
+ # sandwich variances.
603
+ RL_stacks = jnp.array(
604
+ [
605
+ self.single_user_weighted_RL_estimating_function_stacker(
606
+ beta_dim,
607
+ user_id,
608
+ action_prob_func,
609
+ algorithm_estimating_func,
610
+ action_prob_func_args_beta_index,
611
+ action_prob_func_args_by_decision_time_by_user_id[user_id],
612
+ threaded_action_prob_func_args_by_decision_time_by_user_id[user_id],
613
+ threaded_update_func_args_by_policy_num_by_user_id[user_id],
614
+ policy_num_by_decision_time_by_user_id[user_id],
615
+ action_by_decision_time_by_user_id[user_id],
616
+ beta_index_by_policy_num,
617
+ )
618
+ for user_id in user_ids.tolist()
619
+ ]
620
+ )
621
+
622
+ # 6. We will differentiate the first output, while the second will be used
623
+ # for an estimating function sum check.
624
+ mean_stack_across_users = jnp.mean(RL_stacks, axis=0)
625
+ return mean_stack_across_users, mean_stack_across_users
626
+
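# Illustrative sketch (not part of the released module): when alg_update_func_type is
# FunctionTypes.LOSS, the method above converts the loss into an estimating function by taking
# its gradient in beta (jax.grad with argnums pointing at the beta argument). A hypothetical,
# self-contained example of that conversion:

def demo_loss(X, beta, y):
    return jnp.mean((y - X @ beta) ** 2)

# beta is argument 1 of demo_loss, mirroring alg_update_func_args_beta_index.
demo_estimating_func = jax.grad(demo_loss, argnums=1)
# demo_estimating_func(jnp.ones((5, 2)), jnp.zeros(2), jnp.arange(5.0)) -> beta_dim-length vector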
627
+ def single_user_weighted_RL_estimating_function_stacker(
628
+ self,
629
+ beta_dim: int,
630
+ user_id: collections.abc.Hashable,
631
+ action_prob_func: callable,
632
+ algorithm_estimating_func: callable,
633
+ action_prob_func_args_beta_index: int,
634
+ action_prob_func_args_by_decision_time: dict[int, tuple[Any, ...]],
+ threaded_action_prob_func_args_by_decision_time: dict[int, tuple[Any, ...]],
+ threaded_update_func_args_by_policy_num: dict[int | float, tuple[Any, ...]],
+ policy_num_by_decision_time: dict[int, int | float],
+ action_by_decision_time: dict[int, int],
647
+ beta_index_by_policy_num: dict[int | float, int],
648
+ ) -> jnp.ndarray:
654
+ """
655
+ Computes a single user's weighted estimating function stack for a given algorithm estimating
+ function and arguments and action probability function and arguments.
658
+
659
+ Args:
660
+ beta_dim (int):
+ The dimension of each of the beta parameters.
662
+
663
+ user_id (collections.abc.Hashable):
664
+ The user ID for which to compute the weighted estimating function stack.
665
+
666
+ action_prob_func (callable):
667
+ The function used to compute the probability of action 1 at a given decision time for
668
+ a particular user given their state and the algorithm parameters.
669
+
670
+ algorithm_estimating_func (callable):
671
+ The estimating function that corresponds to algorithm updates.
672
+
673
+ action_prob_func_args_beta_index (int):
674
+ The index of the beta argument in the action probability function's arguments.
675
+
676
+ action_prob_func_args_by_decision_time (dict[int, tuple[Any, ...]]):
677
+ A map from decision times to tuples of arguments for this user for the action
678
+ probability function. This is for all decision times (args are an empty
679
+ tuple if they are not in the study). Should be sorted by decision time. NOTE THAT THESE
680
+ ARGS DO NOT CONTAIN THE SHARED BETAS, making them impervious to the differentiation that
681
+ will occur.
682
+
683
+ threaded_action_prob_func_args_by_decision_time (dict[int, tuple[Any, ...]]):
+ A map from decision times to tuples of arguments for the action
+ probability function, with the shared betas threaded in for differentiation. Decision
+ times should be sorted.
687
+
688
+ threaded_update_func_args_by_policy_num (dict[int | float, tuple[Any, ...]]):
689
+ A map from policy numbers to tuples containing the arguments for
690
+ the corresponding estimating functions for this user, with the shared betas threaded in
691
+ for differentiation. This is for all non-initial, non-fallback policies. Policy numbers
692
+ should be sorted.
693
+
694
+ policy_num_by_decision_time (dict[int, int | float]):
695
+ A dictionary mapping decision times to the policy number in use. This may be
696
+ user-specific. Should be sorted by decision time. Only applies to in-study decision
697
+ times!
698
+
699
+ action_by_decision_time (dict[int, int]):
700
+ A dictionary mapping decision times to actions taken. Only applies to in-study decision
701
+ times!
702
+
703
+ beta_index_by_policy_num (dict[int | float, int]):
704
+ A dictionary mapping policy numbers to the index of the corresponding beta in
705
+ all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
706
+
707
+ Returns:
708
+ jnp.ndarray: A 1-D JAX NumPy array representing the RL portion of the user's weighted
709
+ estimating function stack.
710
+ """
711
+
712
+ logger.info(
713
+ "Computing weighted estimating function stack for user %s.", user_id
714
+ )
715
+
716
+ # First, reformat the supplied data into more convenient structures.
717
+
718
+ # 1. Form a dictionary mapping policy numbers to the first time they were
719
+ # applicable (for this user). Note that this includes ALL policies, the initial and
+ # fallback policies included.
721
+ # Collect the first time after the first update separately for convenience.
722
+ # These are both used to form the Radon-Nikodym weights for the right times.
723
+ min_time_by_policy_num, first_time_after_first_update = (
724
+ get_min_time_by_policy_num(
725
+ policy_num_by_decision_time,
726
+ beta_index_by_policy_num,
727
+ )
728
+ )
729
+
730
+ # 2. Get the start and end times for this user.
731
+ user_start_time = math.inf
732
+ user_end_time = -math.inf
733
+ for decision_time in action_by_decision_time:
734
+ user_start_time = min(user_start_time, decision_time)
735
+ user_end_time = max(user_end_time, decision_time)
736
+
737
+ # 3. Form a stack of weighted estimating equations, one for each update of the algorithm.
738
+ logger.info(
739
+ "Computing the algorithm component of the weighted estimating function stack for user %s.",
740
+ user_id,
741
+ )
742
+
743
+ in_study_action_prob_func_args = [
744
+ args for args in action_prob_func_args_by_decision_time.values() if args
745
+ ]
746
+ in_study_betas_list_by_decision_time_index = jnp.array(
747
+ [
748
+ action_prob_func_args[action_prob_func_args_beta_index]
749
+ for action_prob_func_args in in_study_action_prob_func_args
750
+ ]
751
+ )
752
+ in_study_actions_list_by_decision_time_index = jnp.array(
753
+ list(action_by_decision_time.values())
754
+ )
755
+
756
+ # Sort the threaded args by decision time to be cautious. We check that each
+ # decision time is present in the args dict because this may be called on a
+ # subset of the args when batching arguments by shape.
759
+ sorted_threaded_action_prob_args_by_decision_time = {
760
+ decision_time: threaded_action_prob_func_args_by_decision_time[
761
+ decision_time
762
+ ]
763
+ for decision_time in range(user_start_time, user_end_time + 1)
764
+ if decision_time in threaded_action_prob_func_args_by_decision_time
765
+ }
766
+
767
+ num_args = None
768
+ for args in sorted_threaded_action_prob_args_by_decision_time.values():
769
+ if args:
770
+ num_args = len(args)
771
+ break
772
+
773
+ # NOTE: Cannot do [[]] * num_args here! Then all the inner lists would point to the
+ # same object...
775
+ batched_threaded_arg_lists = [[] for _ in range(num_args)]
776
+ for (
777
+ decision_time,
778
+ args,
779
+ ) in sorted_threaded_action_prob_args_by_decision_time.items():
780
+ if not args:
781
+ continue
782
+ for idx, arg in enumerate(args):
783
+ batched_threaded_arg_lists[idx].append(arg)
784
+
785
+ batched_threaded_arg_tensors, batch_axes = stack_batched_arg_lists_into_tensors(
786
+ batched_threaded_arg_lists
787
+ )
788
+
789
+ # Note that we do NOT use the shared betas in the first arg to the weight function,
790
+ # since we don't want differentiation to happen with respect to them.
791
+ # Just grab the original beta from the update function arguments. This is the same
792
+ # value, but impervious to differentiation with respect to all_post_update_betas. The
793
+ # args, on the other hand, are a function of all_post_update_betas.
794
+ in_study_weights = jax.vmap(
795
+ fun=get_radon_nikodym_weight,
796
+ in_axes=[0, None, None, 0] + batch_axes,
797
+ out_axes=0,
798
+ )(
799
+ in_study_betas_list_by_decision_time_index,
800
+ action_prob_func,
801
+ action_prob_func_args_beta_index,
802
+ in_study_actions_list_by_decision_time_index,
803
+ *batched_threaded_arg_tensors,
804
+ )
805
+
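# Illustrative sketch (not part of the released module): the vmap call above maps over the
# per-decision-time betas, actions, and batched argument tensors (in_axes 0) while broadcasting
# the action probability function and the beta index (in_axes None). A tiny analogue with a
# hypothetical weight function:

def demo_weight(beta, offset, action, state):
    p_one = jax.nn.sigmoid(jnp.dot(beta, state) + offset)
    return jnp.where(action == 1, p_one, 1.0 - p_one)

demo_weights = jax.vmap(demo_weight, in_axes=(0, None, 0, 0))(
    jnp.zeros((3, 2)),     # one beta per in-study decision time
    0.5,                   # shared scalar, broadcast to every decision time
    jnp.array([1, 0, 1]),  # actions
    jnp.ones((3, 2)),      # states
)
# demo_weights has shape (3,), one weight per in-study decision time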
806
+ in_study_index = 0
807
+ decision_time_to_all_weights_index_offset = min(
808
+ sorted_threaded_action_prob_args_by_decision_time
809
+ )
810
+ all_weights_raw = []
811
+ for (
+ decision_time,
+ args,
+ ) in sorted_threaded_action_prob_args_by_decision_time.items():
+ # Only advance the in-study weight index for decision times that actually have in-study
+ # args; out-of-study times within the user's range contribute a weight of 1.
+ if args:
+ all_weights_raw.append(in_study_weights[in_study_index])
+ in_study_index += 1
+ else:
+ all_weights_raw.append(1.0)
817
+ all_weights = jnp.array(all_weights_raw)
818
+
819
+ algorithm_component = jnp.concatenate(
820
+ [
821
+ # Here we compute a product of Radon-Nikodym weights
822
+ # for all decision times after the first update and before the update
+ # under consideration took effect, for which the user was in the study.
824
+ (
825
+ jnp.prod(
826
+ all_weights[
827
+ # The earliest time after the first update where the user was in
828
+ # the study
829
+ max(
830
+ first_time_after_first_update,
831
+ user_start_time,
832
+ )
833
+ - decision_time_to_all_weights_index_offset :
834
+ # One more than the latest time the user was in the study before the time
835
+ # the update under consideration first applied. Note the + 1 because range
836
+ # does not include the right endpoint.
837
+ min(
838
+ min_time_by_policy_num.get(policy_num, math.inf),
839
+ user_end_time + 1,
840
+ )
841
+ - decision_time_to_all_weights_index_offset,
842
+ ]
843
+ # If the user exited the study before there were any updates,
844
+ # this variable will be None and the above code to grab a weight would
845
+ # throw an error. Just use 1 to include the unweighted estimating function
846
+ # if they have data to contribute to the update.
847
+ if first_time_after_first_update is not None
848
+ else 1
849
+ ) # Now use the above to weight the alg estimating function for this update
850
+ * algorithm_estimating_func(*update_args)
851
+ # If there are no arguments for the update function, the user is not yet in the
852
+ # study, so we just add a zero vector contribution to the sum across users.
853
+ # Note that after they exit, they still contribute all their data to later
854
+ # updates.
855
+ if update_args
856
+ else jnp.zeros(beta_dim)
857
+ )
858
+ # vmapping over this would be tricky due to different shapes across updates
859
+ for policy_num, update_args in threaded_update_func_args_by_policy_num.items()
860
+ ]
861
+ )
862
+
863
+ if algorithm_component.size % beta_dim != 0:
864
+ raise ValueError(
865
+ "The algorithm component of the weighted estimating function stack does not have a "
866
+ "size that is a multiple of the beta dimension. This likely means that the "
867
+ "algorithm estimating function is not returning a vector of the correct size."
868
+ )
869
+
870
+ return algorithm_component
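# Illustrative sketch (not part of the released module): conceptually, each Radon-Nikodym
# weight multiplied into the stack above is the probability of the observed action under the
# betas being differentiated, divided by its probability under the betas actually used when
# the action was collected. get_radon_nikodym_weight is imported from .helper_functions and
# its exact signature is not shown in this file; the function below is only a hypothetical
# stand-in for that ratio.

def demo_radon_nikodym_weight(action, state, beta_new, beta_logged):
    def prob_of_action(beta):
        p_one = jax.nn.sigmoid(jnp.dot(beta, state))
        return jnp.where(action == 1, p_one, 1.0 - p_one)

    return prob_of_action(beta_new) / prob_of_action(beta_logged)

# demo_radon_nikodym_weight(1, jnp.array([1.0, 0.5]), jnp.array([0.2, -0.1]), jnp.zeros(2))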