lifejacket 0.2.1__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,7 +16,7 @@ from .helper_functions import (
16
16
  calculate_beta_dim,
17
17
  collect_all_post_update_betas,
18
18
  construct_beta_index_by_policy_num_map,
19
- extract_action_and_policy_by_decision_time_by_user_id,
19
+ extract_action_and_policy_by_decision_time_by_subject_id,
20
20
  flatten_params,
21
21
  get_min_time_by_policy_num,
22
22
  get_radon_nikodym_weight,
@@ -34,7 +34,13 @@ logging.basicConfig(
34
34
  )
35
35
 
36
36
 
37
- class TrialConditioningMonitor:
37
+ class DeploymentConditioningMonitor:
38
+ """
39
+ Experimental feature. Monitors the conditioning of the RL portion of the bread matrix.
40
+ Repeats more logic from post_deployment_analysis.py than is ideal, but this is for experimental use only.
41
+ Unit tests should be unskipped and expanded if this is to be used more broadly.
42
+ """
43
+
38
44
  whole_RL_block_conditioning_threshold = None
39
45
  diagonal_RL_block_conditioning_threshold = None
40
46
 
@@ -54,7 +60,7 @@ class TrialConditioningMonitor:
54
60
  def assess_update(
55
61
  self,
56
62
  proposed_policy_num: int | float,
57
- study_df: pd.DataFrame,
63
+ analysis_df: pd.DataFrame,
58
64
  action_prob_func: callable,
59
65
  action_prob_func_args: dict,
60
66
  action_prob_func_args_beta_index: int,
@@ -64,23 +70,24 @@ class TrialConditioningMonitor:
64
70
  alg_update_func_args_beta_index: int,
65
71
  alg_update_func_args_action_prob_index: int,
66
72
  alg_update_func_args_action_prob_times_index: int,
67
- in_study_col_name: str,
73
+ alg_update_func_args_previous_betas_index: int,
74
+ active_col_name: str,
68
75
  action_col_name: str,
69
76
  policy_num_col_name: str,
70
77
  calendar_t_col_name: str,
71
- user_id_col_name: str,
78
+ subject_id_col_name: str,
72
79
  action_prob_col_name: str,
73
80
  suppress_interactive_data_checks: bool,
74
81
  suppress_all_data_checks: bool,
75
82
  incremental: bool = True,
76
83
  ) -> None:
77
84
  """
78
- Analyzes a dataset to estimate parameters and variance using adaptive and classical sandwich estimators.
85
+ Analyzes a dataset to estimate parameters and variance using adjusted and classical sandwich estimators.
79
86
 
80
87
  Parameters:
81
88
  proposed_policy_num (int | float):
82
89
  The policy number of the proposed update.
83
- study_df (pd.DataFrame):
90
+ analysis_df (pd.DataFrame):
84
91
  DataFrame containing the study data.
85
92
  action_prob_func (str):
86
93
  Action probability function.
@@ -100,16 +107,18 @@ class TrialConditioningMonitor:
100
107
  Index for action probability in algorithm update function arguments.
101
108
  alg_update_func_args_action_prob_times_index (int):
102
109
  Index for action probability times in algorithm update function arguments.
103
- in_study_col_name (str):
104
- Column name indicating if a user is in the study in the study dataframe.
110
+ alg_update_func_args_previous_betas_index (int):
111
+ Index for previous betas in algorithm update function arguments.
112
+ active_col_name (str):
113
+ Column name indicating if a subject is in the study in the study dataframe.
105
114
  action_col_name (str):
106
115
  Column name for actions in the study dataframe.
107
116
  policy_num_col_name (str):
108
117
  Column name for policy numbers in the study dataframe.
109
118
  calendar_t_col_name (str):
110
119
  Column name for calendar time in the study dataframe.
111
- user_id_col_name (str):
112
- Column name for user IDs in the study dataframe.
120
+ subject_id_col_name (str):
121
+ Column name for subject IDs in the study dataframe.
113
122
  action_prob_col_name (str):
114
123
  Column name for action probabilities in the study dataframe.
115
124
  reward_col_name (str):
@@ -121,13 +130,13 @@ class TrialConditioningMonitor:
121
130
  small_sample_correction (str):
122
131
  Type of small sample correction to apply.
123
132
  collect_data_for_blowup_supervised_learning (bool):
124
- Whether to collect data for doing supervised learning about adaptive sandwich blowup.
125
- form_adaptive_meat_adjustments_explicitly (bool):
126
- If True, explicitly forms the per-user meat adjustments that differentiate the adaptive
133
+ Whether to collect data for doing supervised learning about adjusted sandwich blowup.
134
+ form_adjusted_meat_adjustments_explicitly (bool):
135
+ If True, explicitly forms the per-subject meat adjustments that differentiate the adjusted
127
136
  sandwich from the classical sandwich. This is for diagnostic purposes, as the
128
- adaptive sandwich is formed without doing this.
129
- stabilize_joint_adaptive_bread_inverse (bool):
130
- If True, stabilizes the joint adaptive bread inverse matrix if it does not meet conditioning
137
+ adjusted sandwich is formed without doing this.
138
+ stabilize_joint_bread (bool):
139
+ If True, stabilizes the joint bread matrix if it does not meet conditioning
131
140
  thresholds.
132
141
 
133
142
  Returns:
@@ -141,11 +150,11 @@ class TrialConditioningMonitor:
141
150
 
142
151
  if not suppress_all_data_checks:
143
152
  input_checks.perform_alg_only_input_checks(
144
- study_df,
145
- in_study_col_name,
153
+ analysis_df,
154
+ active_col_name,
146
155
  policy_num_col_name,
147
156
  calendar_t_col_name,
148
- user_id_col_name,
157
+ subject_id_col_name,
149
158
  action_prob_col_name,
150
159
  action_prob_func,
151
160
  action_prob_func_args,
@@ -154,12 +163,13 @@ class TrialConditioningMonitor:
154
163
  alg_update_func_args_beta_index,
155
164
  alg_update_func_args_action_prob_index,
156
165
  alg_update_func_args_action_prob_times_index,
166
+ alg_update_func_args_previous_betas_index,
157
167
  suppress_interactive_data_checks,
158
168
  )
159
169
 
160
170
  beta_index_by_policy_num, initial_policy_num = (
161
171
  construct_beta_index_by_policy_num_map(
162
- study_df, policy_num_col_name, in_study_col_name
172
+ analysis_df, policy_num_col_name, active_col_name
163
173
  )
164
174
  )
165
175
  # We augment the produced map to include the proposed policy num.
@@ -174,22 +184,23 @@ class TrialConditioningMonitor:
174
184
  alg_update_func_args_beta_index,
175
185
  )
176
186
 
177
- action_by_decision_time_by_user_id, policy_num_by_decision_time_by_user_id = (
178
- extract_action_and_policy_by_decision_time_by_user_id(
179
- study_df,
180
- user_id_col_name,
181
- in_study_col_name,
182
- calendar_t_col_name,
183
- action_col_name,
184
- policy_num_col_name,
185
- )
187
+ (
188
+ action_by_decision_time_by_subject_id,
189
+ policy_num_by_decision_time_by_subject_id,
190
+ ) = extract_action_and_policy_by_decision_time_by_subject_id(
191
+ analysis_df,
192
+ subject_id_col_name,
193
+ active_col_name,
194
+ calendar_t_col_name,
195
+ action_col_name,
196
+ policy_num_col_name,
186
197
  )
187
198
 
188
- user_ids = jnp.array(study_df[user_id_col_name].unique())
199
+ subject_ids = jnp.array(analysis_df[subject_id_col_name].unique())
189
200
 
190
201
  phi_dot_bar, avg_estimating_function_stack = self.construct_phi_dot_bar_so_far(
191
202
  all_post_update_betas,
192
- user_ids,
203
+ subject_ids,
193
204
  action_prob_func,
194
205
  action_prob_func_args_beta_index,
195
206
  alg_update_func,
@@ -197,12 +208,13 @@ class TrialConditioningMonitor:
197
208
  alg_update_func_args_beta_index,
198
209
  alg_update_func_args_action_prob_index,
199
210
  alg_update_func_args_action_prob_times_index,
211
+ alg_update_func_args_previous_betas_index,
200
212
  action_prob_func_args,
201
- policy_num_by_decision_time_by_user_id,
213
+ policy_num_by_decision_time_by_subject_id,
202
214
  initial_policy_num,
203
215
  beta_index_by_policy_num,
204
216
  alg_update_func_args,
205
- action_by_decision_time_by_user_id,
217
+ action_by_decision_time_by_subject_id,
206
218
  suppress_all_data_checks,
207
219
  suppress_interactive_data_checks,
208
220
  incremental=incremental,
@@ -225,7 +237,7 @@ class TrialConditioningMonitor:
225
237
 
226
238
  if whole_RL_block_condition_number > self.whole_RL_block_conditioning_threshold:
227
239
  logger.warning(
228
- "The RL portion of the bread inverse up to this point exceeds the threshold set (condition number: %s, threshold: %s). Consider an alternative update strategy which produces less dependence on previous RL parameters (via the data they produced) and/or improves the conditioning of each update itself. Regularization may help with both of these.",
240
+ "The RL portion of the bread up to this point exceeds the threshold set (condition number: %s, threshold: %s). Consider an alternative update strategy which produces less dependence on previous RL parameters (via the data they produced) and/or improves the conditioning of each update itself. Regularization may help with both of these.",
229
241
  whole_RL_block_condition_number,
230
242
  self.whole_RL_block_conditioning_threshold,
231
243
  )
@@ -236,7 +248,7 @@ class TrialConditioningMonitor:
236
248
  > self.diagonal_RL_block_conditioning_threshold
237
249
  ):
238
250
  logger.warning(
239
- "The diagonal RL block of the bread inverse up to this point exceeds the threshold set (condition number: %s, threshold: %s). This may illustrate a fundamental problem with the conditioning of the RL update procedure.",
251
+ "The diagonal RL block of the bread up to this point exceeds the threshold set (condition number: %s, threshold: %s). This may illustrate a fundamental problem with the conditioning of the RL update procedure.",
240
252
  new_diagonal_RL_block_condition_number,
241
253
  self.diagonal_RL_block_conditioning_threshold,
242
254
  )
@@ -259,7 +271,7 @@ class TrialConditioningMonitor:
259
271
  def construct_phi_dot_bar_so_far(
260
272
  self,
261
273
  all_post_update_betas: jnp.ndarray,
262
- user_ids: jnp.ndarray,
274
+ subject_ids: jnp.ndarray,
263
275
  action_prob_func: callable,
264
276
  action_prob_func_args_beta_index: int,
265
277
  alg_update_func: callable,
@@ -267,18 +279,19 @@ class TrialConditioningMonitor:
267
279
  alg_update_func_args_beta_index: int,
268
280
  alg_update_func_args_action_prob_index: int,
269
281
  alg_update_func_args_action_prob_times_index: int,
270
- action_prob_func_args_by_user_id_by_decision_time: dict[
282
+ alg_update_func_args_previous_betas_index: int,
283
+ action_prob_func_args_by_subject_id_by_decision_time: dict[
271
284
  collections.abc.Hashable, dict[int, tuple[Any, ...]]
272
285
  ],
273
- policy_num_by_decision_time_by_user_id: dict[
286
+ policy_num_by_decision_time_by_subject_id: dict[
274
287
  collections.abc.Hashable, dict[int, int | float]
275
288
  ],
276
289
  initial_policy_num: int | float,
277
290
  beta_index_by_policy_num: dict[int | float, int],
278
- update_func_args_by_by_user_id_by_policy_num: dict[
291
+ update_func_args_by_by_subject_id_by_policy_num: dict[
279
292
  collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
280
293
  ],
281
- action_by_decision_time_by_user_id: dict[
294
+ action_by_decision_time_by_subject_id: dict[
282
295
  collections.abc.Hashable, dict[int, int]
283
296
  ],
284
297
  suppress_all_data_checks: bool,
@@ -289,18 +302,18 @@ class TrialConditioningMonitor:
289
302
  jnp.ndarray[jnp.float32],
290
303
  ]:
291
304
  """
292
- Constructs the classical and adaptive inverse bread and meat matrices, as well as the average
305
+ Constructs the classical and adjusted bread and meat matrices, as well as the average
293
306
  estimating function stack and some other intermediate pieces.
294
307
 
295
308
  This is done by computing and differentiating the average weighted estimating function stack
296
- with respect to the betas and theta, using the resulting Jacobian to compute the inverse bread
309
+ with respect to the betas and theta, using the resulting Jacobian to compute the bread
297
310
  and meat matrices, and then stably computing sandwiches.
298
311
 
299
312
  Args:
300
313
  all_post_update_betas (jnp.ndarray):
301
314
  A 2-D JAX NumPy array representing all parameter estimates for the algorithm updates.
302
- user_ids (jnp.ndarray):
303
- A 1-D JAX NumPy array holding all user IDs in the study.
315
+ subject_ids (jnp.ndarray):
316
+ A 1-D JAX NumPy array holding all subject IDs in the study.
304
317
  action_prob_func (callable):
305
318
  The action probability function.
306
319
  action_prob_func_args_beta_index (int):
@@ -317,22 +330,24 @@ class TrialConditioningMonitor:
317
330
  alg_update_func_args_action_prob_times_index (int):
318
331
  The index in the update function arguments tuple where an array of times for which the
319
332
  given action probabilities apply is provided, if applicable. -1 otherwise.
320
- action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
321
- A dictionary mapping decision times to maps of user ids to the function arguments
322
- required to compute action probabilities for this user.
323
- policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
324
- A map of user ids to dictionaries mapping decision times to the policy number in use.
333
+ alg_update_func_args_previous_betas_index (int):
334
+ The index in the update function arguments tuple where the previous betas are provided, if applicable. -1 otherwise.
335
+ action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
336
+ A dictionary mapping decision times to maps of subject ids to the function arguments
337
+ required to compute action probabilities for this subject.
338
+ policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
339
+ A map of subject ids to dictionaries mapping decision times to the policy number in use.
325
340
  Only applies to in-study decision times!
326
341
  initial_policy_num (int | float):
327
342
  The policy number of the initial policy before any updates.
328
343
  beta_index_by_policy_num (dict[int | float, int]):
329
344
  A dictionary mapping policy numbers to the index of the corresponding beta in
330
345
  all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
331
- update_func_args_by_by_user_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
332
- A dictionary where keys are policy numbers and values are dictionaries mapping user IDs
346
+ update_func_args_by_by_subject_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
347
+ A dictionary where keys are policy numbers and values are dictionaries mapping subject IDs
333
348
  to their respective update function arguments.
334
- action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
335
- A dictionary mapping user IDs to their respective actions taken at each decision time.
349
+ action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
350
+ A dictionary mapping subject IDs to their respective actions taken at each decision time.
336
351
  Only applies to in-study decision times!
337
352
  suppress_all_data_checks (bool):
338
353
  If True, suppresses carrying out any data checks at all.
@@ -362,7 +377,7 @@ class TrialConditioningMonitor:
362
377
  # here to improve performance. We can simply unflatten them inside the function.
363
378
  flatten_params(all_post_update_betas, jnp.array([])),
364
379
  beta_dim,
365
- user_ids,
380
+ subject_ids,
366
381
  action_prob_func,
367
382
  action_prob_func_args_beta_index,
368
383
  alg_update_func,
@@ -370,12 +385,13 @@ class TrialConditioningMonitor:
370
385
  alg_update_func_args_beta_index,
371
386
  alg_update_func_args_action_prob_index,
372
387
  alg_update_func_args_action_prob_times_index,
373
- action_prob_func_args_by_user_id_by_decision_time,
374
- policy_num_by_decision_time_by_user_id,
388
+ alg_update_func_args_previous_betas_index,
389
+ action_prob_func_args_by_subject_id_by_decision_time,
390
+ policy_num_by_decision_time_by_subject_id,
375
391
  initial_policy_num,
376
392
  beta_index_by_policy_num,
377
- update_func_args_by_by_user_id_by_policy_num,
378
- action_by_decision_time_by_user_id,
393
+ update_func_args_by_by_subject_id_by_policy_num,
394
+ action_by_decision_time_by_subject_id,
379
395
  suppress_all_data_checks,
380
396
  suppress_interactive_data_checks,
381
397
  only_latest_block=True,
@@ -404,7 +420,7 @@ class TrialConditioningMonitor:
404
420
  # here to improve performance. We can simply unflatten them inside the function.
405
421
  flatten_params(all_post_update_betas, jnp.array([])),
406
422
  beta_dim,
407
- user_ids,
423
+ subject_ids,
408
424
  action_prob_func,
409
425
  action_prob_func_args_beta_index,
410
426
  alg_update_func,
@@ -412,12 +428,13 @@ class TrialConditioningMonitor:
412
428
  alg_update_func_args_beta_index,
413
429
  alg_update_func_args_action_prob_index,
414
430
  alg_update_func_args_action_prob_times_index,
415
- action_prob_func_args_by_user_id_by_decision_time,
416
- policy_num_by_decision_time_by_user_id,
431
+ alg_update_func_args_previous_betas_index,
432
+ action_prob_func_args_by_subject_id_by_decision_time,
433
+ policy_num_by_decision_time_by_subject_id,
417
434
  initial_policy_num,
418
435
  beta_index_by_policy_num,
419
- update_func_args_by_by_user_id_by_policy_num,
420
- action_by_decision_time_by_user_id,
436
+ update_func_args_by_by_subject_id_by_policy_num,
437
+ action_by_decision_time_by_subject_id,
421
438
  suppress_all_data_checks,
422
439
  suppress_interactive_data_checks,
423
440
  )
@@ -429,7 +446,7 @@ class TrialConditioningMonitor:
429
446
  self,
430
447
  flattened_betas_and_theta: jnp.ndarray,
431
448
  beta_dim: int,
432
- user_ids: jnp.ndarray,
449
+ subject_ids: jnp.ndarray,
433
450
  action_prob_func: callable,
434
451
  action_prob_func_args_beta_index: int,
435
452
  alg_update_func: callable,
@@ -437,18 +454,19 @@ class TrialConditioningMonitor:
437
454
  alg_update_func_args_beta_index: int,
438
455
  alg_update_func_args_action_prob_index: int,
439
456
  alg_update_func_args_action_prob_times_index: int,
440
- action_prob_func_args_by_user_id_by_decision_time: dict[
457
+ alg_update_func_args_previous_betas_index: int,
458
+ action_prob_func_args_by_subject_id_by_decision_time: dict[
441
459
  collections.abc.Hashable, dict[int, tuple[Any, ...]]
442
460
  ],
443
- policy_num_by_decision_time_by_user_id: dict[
461
+ policy_num_by_decision_time_by_subject_id: dict[
444
462
  collections.abc.Hashable, dict[int, int | float]
445
463
  ],
446
464
  initial_policy_num: int | float,
447
465
  beta_index_by_policy_num: dict[int | float, int],
448
- update_func_args_by_by_user_id_by_policy_num: dict[
466
+ update_func_args_by_by_subject_id_by_policy_num: dict[
449
467
  collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
450
468
  ],
451
- action_by_decision_time_by_user_id: dict[
469
+ action_by_decision_time_by_subject_id: dict[
452
470
  collections.abc.Hashable, dict[int, int]
453
471
  ],
454
472
  suppress_all_data_checks: bool,
@@ -459,8 +477,8 @@ class TrialConditioningMonitor:
459
477
  tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray],
460
478
  ]:
461
479
  """
462
- Computes the average weighted estimating function stack across all users, along with
463
- auxiliary values used to construct the adaptive and classical sandwich variances.
480
+ Computes the average weighted estimating function stack across all subjects, along with
481
+ auxiliary values used to construct the adjusted and classical sandwich variances.
464
482
 
465
483
  If only_latest_block is True, only uses data from the most recent update.
466
484
 
@@ -471,8 +489,8 @@ class TrialConditioningMonitor:
471
489
  We simply extract the betas and theta from this array below.
472
490
  beta_dim (int):
473
491
  The dimension of each of the beta parameters.
474
- user_ids (jnp.ndarray):
475
- A 1D JAX NumPy array of user IDs.
492
+ subject_ids (jnp.ndarray):
493
+ A 1D JAX NumPy array of subject IDs.
476
494
  action_prob_func (callable):
477
495
  The action probability function.
478
496
  action_prob_func_args_beta_index (int):
@@ -489,22 +507,24 @@ class TrialConditioningMonitor:
489
507
  alg_update_func_args_action_prob_times_index (int):
490
508
  The index in the update function arguments tuple where an array of times for which the
491
509
  given action probabilities apply is provided, if applicable. -1 otherwise.
492
- action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
493
- A dictionary mapping decision times to maps of user ids to the function arguments
494
- required to compute action probabilities for this user.
495
- policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
496
- A map of user ids to dictionaries mapping decision times to the policy number in use.
510
+ alg_update_func_args_previous_betas_index (int):
511
+ The index in the update function arguments tuple where the previous betas are provided, if applicable. -1 otherwise.
512
+ action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
513
+ A dictionary mapping decision times to maps of subject ids to the function arguments
514
+ required to compute action probabilities for this subject.
515
+ policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
516
+ A map of subject ids to dictionaries mapping decision times to the policy number in use.
497
517
  Only applies to in-study decision times!
498
518
  initial_policy_num (int | float):
499
519
  The policy number of the initial policy before any updates.
500
520
  beta_index_by_policy_num (dict[int | float, int]):
501
521
  A dictionary mapping policy numbers to the index of the corresponding beta in
502
522
  all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
503
- update_func_args_by_by_user_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
504
- A dictionary where keys are policy numbers and values are dictionaries mapping user IDs
523
+ update_func_args_by_by_subject_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
524
+ A dictionary where keys are policy numbers and values are dictionaries mapping subject IDs
505
525
  to their respective update function arguments.
506
- action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
507
- A dictionary mapping user IDs to their respective actions taken at each decision time.
526
+ action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
527
+ A dictionary mapping subject IDs to their respective actions taken at each decision time.
508
528
  Only applies to in-study decision times!
509
529
  suppress_all_data_checks (bool):
510
530
  If True, suppresses carrying out any data checks at all.
@@ -536,15 +556,15 @@ class TrialConditioningMonitor:
536
556
  # 1. If only_latest_block is True, we need to filter all the arguments to only
537
557
  # include those relevant to the latest update. We still need action probabilities
538
558
  # from the beginning for the weights, but the update function args can be trimmed
539
- # to the max policy so that the loop single_user_weighted_RL_estimating_function_stacker
559
+ # to the max policy so that the loop single_subject_weighted_RL_estimating_function_stacker
540
560
  # is only over one policy.
541
561
  if only_latest_block:
542
562
  logger.info(
543
563
  "Filtering algorithm update function arguments to only include those relevant to the latest update."
544
564
  )
545
565
  max_policy_num = max(beta_index_by_policy_num)
546
- update_func_args_by_by_user_id_by_policy_num = {
547
- max_policy_num: update_func_args_by_by_user_id_by_policy_num[
566
+ update_func_args_by_by_subject_id_by_policy_num = {
567
+ max_policy_num: update_func_args_by_by_subject_id_by_policy_num[
548
568
  max_policy_num
549
569
  ]
550
570
  }
@@ -553,15 +573,17 @@ class TrialConditioningMonitor:
553
573
  # supplied for the above functions, so that differentiation works correctly. The existing
554
574
  # values should be the same, but not connected to the parameter we are differentiating
555
575
  # with respect to. Note we will also find it useful below to have the action probability args
556
- # nested dict structure flipped to be user_id -> decision_time -> args, so we do that here too.
576
+ # nested dict structure flipped to be subject_id -> decision_time -> args, so we do that here too.
557
577
 
558
- logger.info("Threading in betas to action probability arguments for all users.")
578
+ logger.info(
579
+ "Threading in betas to action probability arguments for all subjects."
580
+ )
559
581
  (
560
- threaded_action_prob_func_args_by_decision_time_by_user_id,
561
- action_prob_func_args_by_decision_time_by_user_id,
582
+ threaded_action_prob_func_args_by_decision_time_by_subject_id,
583
+ action_prob_func_args_by_decision_time_by_subject_id,
562
584
  ) = thread_action_prob_func_args(
563
- action_prob_func_args_by_user_id_by_decision_time,
564
- policy_num_by_decision_time_by_user_id,
585
+ action_prob_func_args_by_subject_id_by_decision_time,
586
+ policy_num_by_decision_time_by_subject_id,
565
587
  initial_policy_num,
566
588
  betas,
567
589
  beta_index_by_policy_num,
@@ -573,16 +595,17 @@ class TrialConditioningMonitor:
573
595
  # arguments with the central betas introduced.
574
596
  logger.info(
575
597
  "Threading in betas and beta-dependent action probabilities to algorithm update "
576
- "function args for all users"
598
+ "function args for all subjects"
577
599
  )
578
- threaded_update_func_args_by_policy_num_by_user_id = thread_update_func_args(
579
- update_func_args_by_by_user_id_by_policy_num,
600
+ threaded_update_func_args_by_policy_num_by_subject_id = thread_update_func_args(
601
+ update_func_args_by_by_subject_id_by_policy_num,
580
602
  betas,
581
603
  beta_index_by_policy_num,
582
604
  alg_update_func_args_beta_index,
583
605
  alg_update_func_args_action_prob_index,
584
606
  alg_update_func_args_action_prob_times_index,
585
- threaded_action_prob_func_args_by_decision_time_by_user_id,
607
+ alg_update_func_args_previous_betas_index,
608
+ threaded_action_prob_func_args_by_decision_time_by_subject_id,
586
609
  action_prob_func,
587
610
  )
588
611
 
@@ -592,42 +615,44 @@ class TrialConditioningMonitor:
592
615
  if not suppress_all_data_checks and alg_update_func_args_action_prob_index >= 0:
593
616
  input_checks.require_threaded_algorithm_estimating_function_args_equivalent(
594
617
  algorithm_estimating_func,
595
- update_func_args_by_by_user_id_by_policy_num,
596
- threaded_update_func_args_by_policy_num_by_user_id,
618
+ update_func_args_by_by_subject_id_by_policy_num,
619
+ threaded_update_func_args_by_policy_num_by_subject_id,
597
620
  suppress_interactive_data_checks,
598
621
  )
599
622
 
600
- # 5. Now we can compute the weighted estimating function stacks for all users
601
- # as well as collect related values used to construct the adaptive and classical
623
+ # 5. Now we can compute the weighted estimating function stacks for all subjects
624
+ # as well as collect related values used to construct the adjusted and classical
602
625
  # sandwich variances.
603
626
  RL_stacks = jnp.array(
604
627
  [
605
- self.single_user_weighted_RL_estimating_function_stacker(
628
+ self.single_subject_weighted_RL_estimating_function_stacker(
606
629
  beta_dim,
607
- user_id,
630
+ subject_id,
608
631
  action_prob_func,
609
632
  algorithm_estimating_func,
610
633
  action_prob_func_args_beta_index,
611
- action_prob_func_args_by_decision_time_by_user_id[user_id],
612
- threaded_action_prob_func_args_by_decision_time_by_user_id[user_id],
613
- threaded_update_func_args_by_policy_num_by_user_id[user_id],
614
- policy_num_by_decision_time_by_user_id[user_id],
615
- action_by_decision_time_by_user_id[user_id],
634
+ action_prob_func_args_by_decision_time_by_subject_id[subject_id],
635
+ threaded_action_prob_func_args_by_decision_time_by_subject_id[
636
+ subject_id
637
+ ],
638
+ threaded_update_func_args_by_policy_num_by_subject_id[subject_id],
639
+ policy_num_by_decision_time_by_subject_id[subject_id],
640
+ action_by_decision_time_by_subject_id[subject_id],
616
641
  beta_index_by_policy_num,
617
642
  )
618
- for user_id in user_ids.tolist()
643
+ for subject_id in subject_ids.tolist()
619
644
  ]
620
645
  )
621
646
 
622
647
  # 6. We will differentiate the first output, while the second will be used
623
648
  # for an estimating function sum check.
624
- mean_stack_across_users = jnp.mean(RL_stacks, axis=0)
625
- return mean_stack_across_users, mean_stack_across_users
649
+ mean_stack_across_subjects = jnp.mean(RL_stacks, axis=0)
650
+ return mean_stack_across_subjects, mean_stack_across_subjects
626
651
 
627
- def single_user_weighted_RL_estimating_function_stacker(
652
+ def single_subject_weighted_RL_estimating_function_stacker(
628
653
  self,
629
654
  beta_dim: int,
630
- user_id: collections.abc.Hashable,
655
+ subject_id: collections.abc.Hashable,
631
656
  action_prob_func: callable,
632
657
  algorithm_estimating_func: callable,
633
658
  action_prob_func_args_beta_index: int,
@@ -660,12 +685,12 @@ class TrialConditioningMonitor:
660
685
  beta_dim (list[jnp.ndarray]):
661
686
  A list of 1D JAX NumPy arrays corresponding to the betas produced by all updates.
662
687
 
663
- user_id (collections.abc.Hashable):
664
- The user ID for which to compute the weighted estimating function stack.
688
+ subject_id (collections.abc.Hashable):
689
+ The subject ID for which to compute the weighted estimating function stack.
665
690
 
666
691
  action_prob_func (callable):
667
692
  The function used to compute the probability of action 1 at a given decision time for
668
- a particular user given their state and the algorithm parameters.
693
+ a particular subject given their state and the algorithm parameters.
669
694
 
670
695
  algorithm_estimating_func (callable):
671
696
  The estimating function that corresponds to algorithm updates.
@@ -674,7 +699,7 @@ class TrialConditioningMonitor:
674
699
  The index of the beta argument in the action probability function's arguments.
675
700
 
676
701
  action_prob_func_args_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
677
- A map from decision times to tuples of arguments for this user for the action
702
+ A map from decision times to tuples of arguments for this subject for the action
678
703
  probability function. This is for all decision times (args are an empty
679
704
  tuple if they are not in the study). Should be sorted by decision time. NOTE THAT THESE
680
705
  ARGS DO NOT CONTAIN THE SHARED BETAS, making them impervious to the differentiation that
@@ -687,13 +712,13 @@ class TrialConditioningMonitor:
687
712
 
688
713
  threaded_update_func_args_by_policy_num (dict[int | float, dict[collections.abc.Hashable, tuple[Any, ...]]]):
689
714
  A map from policy numbers to tuples containing the arguments for
690
- the corresponding estimating functions for this user, with the shared betas threaded in
715
+ the corresponding estimating functions for this subject, with the shared betas threaded in
691
716
  for differentiation. This is for all non-initial, non-fallback policies. Policy numbers
692
717
  should be sorted.
693
718
 
694
719
  policy_num_by_decision_time (dict[collections.abc.Hashable, dict[int, int | float]]):
695
720
  A dictionary mapping decision times to the policy number in use. This may be
696
- user-specific. Should be sorted by decision time. Only applies to in-study decision
721
+ subject-specific. Should be sorted by decision time. Only applies to in-study decision
697
722
  times!
698
723
 
699
724
  action_by_decision_time (dict[collections.abc.Hashable, dict[int, int]]):
@@ -705,18 +730,18 @@ class TrialConditioningMonitor:
705
730
  all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
706
731
 
707
732
  Returns:
708
- jnp.ndarray: A 1-D JAX NumPy array representing the RL portion of the user's weighted
733
+ jnp.ndarray: A 1-D JAX NumPy array representing the RL portion of the subject's weighted
709
734
  estimating function stack.
710
735
  """
711
736
 
712
737
  logger.info(
713
- "Computing weighted estimating function stack for user %s.", user_id
738
+ "Computing weighted estimating function stack for subject %s.", subject_id
714
739
  )
715
740
 
716
741
  # First, reformat the supplied data into more convenient structures.
717
742
 
718
743
  # 1. Form a dictionary mapping policy numbers to the first time they were
719
- # applicable (for this user). Note that this includes ALL policies, initial
744
+ # applicable (for this subject). Note that this includes ALL policies, initial
720
745
  # fallbacks included.
721
746
  # Collect the first time after the first update separately for convenience.
722
747
  # These are both used to form the Radon-Nikodym weights for the right times.
@@ -727,17 +752,17 @@ class TrialConditioningMonitor:
727
752
  )
728
753
  )
729
754
 
730
- # 2. Get the start and end times for this user.
731
- user_start_time = math.inf
732
- user_end_time = -math.inf
755
+ # 2. Get the start and end times for this subject.
756
+ subject_start_time = math.inf
757
+ subject_end_time = -math.inf
733
758
  for decision_time in action_by_decision_time:
734
- user_start_time = min(user_start_time, decision_time)
735
- user_end_time = max(user_end_time, decision_time)
759
+ subject_start_time = min(subject_start_time, decision_time)
760
+ subject_end_time = max(subject_end_time, decision_time)
736
761
 
737
762
  # 3. Form a stack of weighted estimating equations, one for each update of the algorithm.
738
763
  logger.info(
739
- "Computing the algorithm component of the weighted estimating function stack for user %s.",
740
- user_id,
764
+ "Computing the algorithm component of the weighted estimating function stack for subject %s.",
765
+ subject_id,
741
766
  )
742
767
 
743
768
  in_study_action_prob_func_args = [
@@ -754,13 +779,13 @@ class TrialConditioningMonitor:
754
779
  )
755
780
 
756
781
  # Sort the threaded args by decision time to be cautious. We check if the
757
- # user id is present in the user args dict because we may call this on a
758
- # subset of the user arg dict when we are batching arguments by shape
782
+ # subject id is present in the subject args dict because we may call this on a
783
+ # subset of the subject arg dict when we are batching arguments by shape
759
784
  sorted_threaded_action_prob_args_by_decision_time = {
760
785
  decision_time: threaded_action_prob_func_args_by_decision_time[
761
786
  decision_time
762
787
  ]
763
- for decision_time in range(user_start_time, user_end_time + 1)
788
+ for decision_time in range(subject_start_time, subject_end_time + 1)
764
789
  if decision_time in threaded_action_prob_func_args_by_decision_time
765
790
  }
766
791
 
@@ -820,27 +845,27 @@ class TrialConditioningMonitor:
820
845
  [
821
846
  # Here we compute a product of Radon-Nikodym weights
822
847
  # for all decision times after the first update and before the update
823
- # update under consideration took effect, for which the user was in the study.
848
+ # update under consideration took effect, for which the subject was in the study.
824
849
  (
825
850
  jnp.prod(
826
851
  all_weights[
827
- # The earliest time after the first update where the user was in
852
+ # The earliest time after the first update where the subject was in
828
853
  # the study
829
854
  max(
830
855
  first_time_after_first_update,
831
- user_start_time,
856
+ subject_start_time,
832
857
  )
833
858
  - decision_time_to_all_weights_index_offset :
834
- # One more than the latest time the user was in the study before the time
859
+ # One more than the latest time the subject was in the study before the time
835
860
  # the update under consideration first applied. Note the + 1 because range
836
861
  # does not include the right endpoint.
837
862
  min(
838
863
  min_time_by_policy_num.get(policy_num, math.inf),
839
- user_end_time + 1,
864
+ subject_end_time + 1,
840
865
  )
841
866
  - decision_time_to_all_weights_index_offset,
842
867
  ]
843
- # If the user exited the study before there were any updates,
868
+ # If the subject exited the study before there were any updates,
844
869
  # this variable will be None and the above code to grab a weight would
845
870
  # throw an error. Just use 1 to include the unweighted estimating function
846
871
  # if they have data to contribute to the update.
@@ -848,8 +873,8 @@ class TrialConditioningMonitor:
848
873
  else 1
849
874
  ) # Now use the above to weight the alg estimating function for this update
850
875
  * algorithm_estimating_func(*update_args)
851
- # If there are no arguments for the update function, the user is not yet in the
852
- # study, so we just add a zero vector contribution to the sum across users.
876
+ # If there are no arguments for the update function, the subject is not yet in the
877
+ # study, so we just add a zero vector contribution to the sum across subjects.
853
878
  # Note that after they exit, they still contribute all their data to later
854
879
  # updates.
855
880
  if update_args