lifejacket 0.2.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,7 +16,7 @@ from .helper_functions import (
16
16
  calculate_beta_dim,
17
17
  collect_all_post_update_betas,
18
18
  construct_beta_index_by_policy_num_map,
19
- extract_action_and_policy_by_decision_time_by_user_id,
19
+ extract_action_and_policy_by_decision_time_by_subject_id,
20
20
  flatten_params,
21
21
  get_min_time_by_policy_num,
22
22
  get_radon_nikodym_weight,
@@ -34,7 +34,7 @@ logging.basicConfig(
34
34
  )
35
35
 
36
36
 
37
- class TrialConditioningMonitor:
37
+ class DeploymentConditioningMonitor:
38
38
  whole_RL_block_conditioning_threshold = None
39
39
  diagonal_RL_block_conditioning_threshold = None
40
40
 
@@ -54,7 +54,7 @@ class TrialConditioningMonitor:
54
54
  def assess_update(
55
55
  self,
56
56
  proposed_policy_num: int | float,
57
- study_df: pd.DataFrame,
57
+ analysis_df: pd.DataFrame,
58
58
  action_prob_func: callable,
59
59
  action_prob_func_args: dict,
60
60
  action_prob_func_args_beta_index: int,
@@ -64,11 +64,12 @@ class TrialConditioningMonitor:
64
64
  alg_update_func_args_beta_index: int,
65
65
  alg_update_func_args_action_prob_index: int,
66
66
  alg_update_func_args_action_prob_times_index: int,
67
- in_study_col_name: str,
67
+ alg_update_func_args_previous_betas_index: int,
68
+ active_col_name: str,
68
69
  action_col_name: str,
69
70
  policy_num_col_name: str,
70
71
  calendar_t_col_name: str,
71
- user_id_col_name: str,
72
+ subject_id_col_name: str,
72
73
  action_prob_col_name: str,
73
74
  suppress_interactive_data_checks: bool,
74
75
  suppress_all_data_checks: bool,
@@ -80,7 +81,7 @@ class TrialConditioningMonitor:
80
81
  Parameters:
81
82
  proposed_policy_num (int | float):
82
83
  The policy number of the proposed update.
83
- study_df (pd.DataFrame):
84
+ analysis_df (pd.DataFrame):
84
85
  DataFrame containing the study data.
85
86
  action_prob_func (str):
86
87
  Action probability function.
@@ -100,16 +101,18 @@ class TrialConditioningMonitor:
100
101
  Index for action probability in algorithm update function arguments.
101
102
  alg_update_func_args_action_prob_times_index (int):
102
103
  Index for action probability times in algorithm update function arguments.
103
- in_study_col_name (str):
104
- Column name indicating if a user is in the study in the study dataframe.
104
+ alg_update_func_args_previous_betas_index (int):
105
+ Index for previous betas in algorithm update function arguments.
106
+ active_col_name (str):
107
+ Column name indicating if a subject is in the study in the study dataframe.
105
108
  action_col_name (str):
106
109
  Column name for actions in the study dataframe.
107
110
  policy_num_col_name (str):
108
111
  Column name for policy numbers in the study dataframe.
109
112
  calendar_t_col_name (str):
110
113
  Column name for calendar time in the study dataframe.
111
- user_id_col_name (str):
112
- Column name for user IDs in the study dataframe.
114
+ subject_id_col_name (str):
115
+ Column name for subject IDs in the study dataframe.
113
116
  action_prob_col_name (str):
114
117
  Column name for action probabilities in the study dataframe.
115
118
  reward_col_name (str):
@@ -122,11 +125,11 @@ class TrialConditioningMonitor:
122
125
  Type of small sample correction to apply.
123
126
  collect_data_for_blowup_supervised_learning (bool):
124
127
  Whether to collect data for doing supervised learning about adaptive sandwich blowup.
125
- form_adaptive_meat_adjustments_explicitly (bool):
126
- If True, explicitly forms the per-user meat adjustments that differentiate the adaptive
128
+ form_adjusted_meat_adjustments_explicitly (bool):
129
+ If True, explicitly forms the per-subject meat adjustments that differentiate the adaptive
127
130
  sandwich from the classical sandwich. This is for diagnostic purposes, as the
128
131
  adaptive sandwich is formed without doing this.
129
- stabilize_joint_adaptive_bread_inverse (bool):
132
+ stabilize_joint_adjusted_bread_inverse (bool):
130
133
  If True, stabilizes the joint adaptive bread inverse matrix if it does not meet conditioning
131
134
  thresholds.
132
135
 
@@ -141,11 +144,11 @@ class TrialConditioningMonitor:
141
144
 
142
145
  if not suppress_all_data_checks:
143
146
  input_checks.perform_alg_only_input_checks(
144
- study_df,
145
- in_study_col_name,
147
+ analysis_df,
148
+ active_col_name,
146
149
  policy_num_col_name,
147
150
  calendar_t_col_name,
148
- user_id_col_name,
151
+ subject_id_col_name,
149
152
  action_prob_col_name,
150
153
  action_prob_func,
151
154
  action_prob_func_args,
@@ -159,7 +162,7 @@ class TrialConditioningMonitor:
159
162
 
160
163
  beta_index_by_policy_num, initial_policy_num = (
161
164
  construct_beta_index_by_policy_num_map(
162
- study_df, policy_num_col_name, in_study_col_name
165
+ analysis_df, policy_num_col_name, active_col_name
163
166
  )
164
167
  )
165
168
  # We augment the produced map to include the proposed policy num.
@@ -174,22 +177,23 @@ class TrialConditioningMonitor:
174
177
  alg_update_func_args_beta_index,
175
178
  )
176
179
 
177
- action_by_decision_time_by_user_id, policy_num_by_decision_time_by_user_id = (
178
- extract_action_and_policy_by_decision_time_by_user_id(
179
- study_df,
180
- user_id_col_name,
181
- in_study_col_name,
182
- calendar_t_col_name,
183
- action_col_name,
184
- policy_num_col_name,
185
- )
180
+ (
181
+ action_by_decision_time_by_subject_id,
182
+ policy_num_by_decision_time_by_subject_id,
183
+ ) = extract_action_and_policy_by_decision_time_by_subject_id(
184
+ analysis_df,
185
+ subject_id_col_name,
186
+ active_col_name,
187
+ calendar_t_col_name,
188
+ action_col_name,
189
+ policy_num_col_name,
186
190
  )
187
191
 
188
- user_ids = jnp.array(study_df[user_id_col_name].unique())
192
+ subject_ids = jnp.array(analysis_df[subject_id_col_name].unique())
189
193
 
190
194
  phi_dot_bar, avg_estimating_function_stack = self.construct_phi_dot_bar_so_far(
191
195
  all_post_update_betas,
192
- user_ids,
196
+ subject_ids,
193
197
  action_prob_func,
194
198
  action_prob_func_args_beta_index,
195
199
  alg_update_func,
@@ -197,12 +201,13 @@ class TrialConditioningMonitor:
197
201
  alg_update_func_args_beta_index,
198
202
  alg_update_func_args_action_prob_index,
199
203
  alg_update_func_args_action_prob_times_index,
204
+ alg_update_func_args_previous_betas_index,
200
205
  action_prob_func_args,
201
- policy_num_by_decision_time_by_user_id,
206
+ policy_num_by_decision_time_by_subject_id,
202
207
  initial_policy_num,
203
208
  beta_index_by_policy_num,
204
209
  alg_update_func_args,
205
- action_by_decision_time_by_user_id,
210
+ action_by_decision_time_by_subject_id,
206
211
  suppress_all_data_checks,
207
212
  suppress_interactive_data_checks,
208
213
  incremental=incremental,
@@ -259,7 +264,7 @@ class TrialConditioningMonitor:
259
264
  def construct_phi_dot_bar_so_far(
260
265
  self,
261
266
  all_post_update_betas: jnp.ndarray,
262
- user_ids: jnp.ndarray,
267
+ subject_ids: jnp.ndarray,
263
268
  action_prob_func: callable,
264
269
  action_prob_func_args_beta_index: int,
265
270
  alg_update_func: callable,
@@ -267,18 +272,19 @@ class TrialConditioningMonitor:
267
272
  alg_update_func_args_beta_index: int,
268
273
  alg_update_func_args_action_prob_index: int,
269
274
  alg_update_func_args_action_prob_times_index: int,
270
- action_prob_func_args_by_user_id_by_decision_time: dict[
275
+ alg_update_func_args_previous_betas_index: int,
276
+ action_prob_func_args_by_subject_id_by_decision_time: dict[
271
277
  collections.abc.Hashable, dict[int, tuple[Any, ...]]
272
278
  ],
273
- policy_num_by_decision_time_by_user_id: dict[
279
+ policy_num_by_decision_time_by_subject_id: dict[
274
280
  collections.abc.Hashable, dict[int, int | float]
275
281
  ],
276
282
  initial_policy_num: int | float,
277
283
  beta_index_by_policy_num: dict[int | float, int],
278
- update_func_args_by_by_user_id_by_policy_num: dict[
284
+ update_func_args_by_by_subject_id_by_policy_num: dict[
279
285
  collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
280
286
  ],
281
- action_by_decision_time_by_user_id: dict[
287
+ action_by_decision_time_by_subject_id: dict[
282
288
  collections.abc.Hashable, dict[int, int]
283
289
  ],
284
290
  suppress_all_data_checks: bool,
@@ -299,8 +305,8 @@ class TrialConditioningMonitor:
299
305
  Args:
300
306
  all_post_update_betas (jnp.ndarray):
301
307
  A 2-D JAX NumPy array representing all parameter estimates for the algorithm updates.
302
- user_ids (jnp.ndarray):
303
- A 1-D JAX NumPy array holding all user IDs in the study.
308
+ subject_ids (jnp.ndarray):
309
+ A 1-D JAX NumPy array holding all subject IDs in the study.
304
310
  action_prob_func (callable):
305
311
  The action probability function.
306
312
  action_prob_func_args_beta_index (int):
@@ -317,22 +323,24 @@ class TrialConditioningMonitor:
317
323
  alg_update_func_args_action_prob_times_index (int):
318
324
  The index in the update function arguments tuple where an array of times for which the
319
325
  given action probabilities apply is provided, if applicable. -1 otherwise.
320
- action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
321
- A dictionary mapping decision times to maps of user ids to the function arguments
322
- required to compute action probabilities for this user.
323
- policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
324
- A map of user ids to dictionaries mapping decision times to the policy number in use.
326
+ alg_update_func_args_previous_betas_index (int):
327
+ The index in the update function arguments tuple where the previous betas are provided, if applicable. -1 otherwise.
328
+ action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
329
+ A dictionary mapping decision times to maps of subject ids to the function arguments
330
+ required to compute action probabilities for this subject.
331
+ policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
332
+ A map of subject ids to dictionaries mapping decision times to the policy number in use.
325
333
  Only applies to in-study decision times!
326
334
  initial_policy_num (int | float):
327
335
  The policy number of the initial policy before any updates.
328
336
  beta_index_by_policy_num (dict[int | float, int]):
329
337
  A dictionary mapping policy numbers to the index of the corresponding beta in
330
338
  all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
331
- update_func_args_by_by_user_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
332
- A dictionary where keys are policy numbers and values are dictionaries mapping user IDs
339
+ update_func_args_by_by_subject_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
340
+ A dictionary where keys are policy numbers and values are dictionaries mapping subject IDs
333
341
  to their respective update function arguments.
334
- action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
335
- A dictionary mapping user IDs to their respective actions taken at each decision time.
342
+ action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
343
+ A dictionary mapping subject IDs to their respective actions taken at each decision time.
336
344
  Only applies to in-study decision times!
337
345
  suppress_all_data_checks (bool):
338
346
  If True, suppresses carrying out any data checks at all.
@@ -362,7 +370,7 @@ class TrialConditioningMonitor:
362
370
  # here to improve performance. We can simply unflatten them inside the function.
363
371
  flatten_params(all_post_update_betas, jnp.array([])),
364
372
  beta_dim,
365
- user_ids,
373
+ subject_ids,
366
374
  action_prob_func,
367
375
  action_prob_func_args_beta_index,
368
376
  alg_update_func,
@@ -370,12 +378,13 @@ class TrialConditioningMonitor:
370
378
  alg_update_func_args_beta_index,
371
379
  alg_update_func_args_action_prob_index,
372
380
  alg_update_func_args_action_prob_times_index,
373
- action_prob_func_args_by_user_id_by_decision_time,
374
- policy_num_by_decision_time_by_user_id,
381
+ alg_update_func_args_previous_betas_index,
382
+ action_prob_func_args_by_subject_id_by_decision_time,
383
+ policy_num_by_decision_time_by_subject_id,
375
384
  initial_policy_num,
376
385
  beta_index_by_policy_num,
377
- update_func_args_by_by_user_id_by_policy_num,
378
- action_by_decision_time_by_user_id,
386
+ update_func_args_by_by_subject_id_by_policy_num,
387
+ action_by_decision_time_by_subject_id,
379
388
  suppress_all_data_checks,
380
389
  suppress_interactive_data_checks,
381
390
  only_latest_block=True,
@@ -404,7 +413,7 @@ class TrialConditioningMonitor:
404
413
  # here to improve performance. We can simply unflatten them inside the function.
405
414
  flatten_params(all_post_update_betas, jnp.array([])),
406
415
  beta_dim,
407
- user_ids,
416
+ subject_ids,
408
417
  action_prob_func,
409
418
  action_prob_func_args_beta_index,
410
419
  alg_update_func,
@@ -412,12 +421,13 @@ class TrialConditioningMonitor:
412
421
  alg_update_func_args_beta_index,
413
422
  alg_update_func_args_action_prob_index,
414
423
  alg_update_func_args_action_prob_times_index,
415
- action_prob_func_args_by_user_id_by_decision_time,
416
- policy_num_by_decision_time_by_user_id,
424
+ alg_update_func_args_previous_betas_index,
425
+ action_prob_func_args_by_subject_id_by_decision_time,
426
+ policy_num_by_decision_time_by_subject_id,
417
427
  initial_policy_num,
418
428
  beta_index_by_policy_num,
419
- update_func_args_by_by_user_id_by_policy_num,
420
- action_by_decision_time_by_user_id,
429
+ update_func_args_by_by_subject_id_by_policy_num,
430
+ action_by_decision_time_by_subject_id,
421
431
  suppress_all_data_checks,
422
432
  suppress_interactive_data_checks,
423
433
  )
@@ -429,7 +439,7 @@ class TrialConditioningMonitor:
429
439
  self,
430
440
  flattened_betas_and_theta: jnp.ndarray,
431
441
  beta_dim: int,
432
- user_ids: jnp.ndarray,
442
+ subject_ids: jnp.ndarray,
433
443
  action_prob_func: callable,
434
444
  action_prob_func_args_beta_index: int,
435
445
  alg_update_func: callable,
@@ -437,18 +447,19 @@ class TrialConditioningMonitor:
437
447
  alg_update_func_args_beta_index: int,
438
448
  alg_update_func_args_action_prob_index: int,
439
449
  alg_update_func_args_action_prob_times_index: int,
440
- action_prob_func_args_by_user_id_by_decision_time: dict[
450
+ alg_update_func_args_previous_betas_index: int,
451
+ action_prob_func_args_by_subject_id_by_decision_time: dict[
441
452
  collections.abc.Hashable, dict[int, tuple[Any, ...]]
442
453
  ],
443
- policy_num_by_decision_time_by_user_id: dict[
454
+ policy_num_by_decision_time_by_subject_id: dict[
444
455
  collections.abc.Hashable, dict[int, int | float]
445
456
  ],
446
457
  initial_policy_num: int | float,
447
458
  beta_index_by_policy_num: dict[int | float, int],
448
- update_func_args_by_by_user_id_by_policy_num: dict[
459
+ update_func_args_by_by_subject_id_by_policy_num: dict[
449
460
  collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
450
461
  ],
451
- action_by_decision_time_by_user_id: dict[
462
+ action_by_decision_time_by_subject_id: dict[
452
463
  collections.abc.Hashable, dict[int, int]
453
464
  ],
454
465
  suppress_all_data_checks: bool,
@@ -459,7 +470,7 @@ class TrialConditioningMonitor:
459
470
  tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray],
460
471
  ]:
461
472
  """
462
- Computes the average weighted estimating function stack across all users, along with
473
+ Computes the average weighted estimating function stack across all subjects, along with
463
474
  auxiliary values used to construct the adaptive and classical sandwich variances.
464
475
 
465
476
  If only_latest_block is True, only uses data from the most recent update.
@@ -471,8 +482,8 @@ class TrialConditioningMonitor:
471
482
  We simply extract the betas and theta from this array below.
472
483
  beta_dim (int):
473
484
  The dimension of each of the beta parameters.
474
- user_ids (jnp.ndarray):
475
- A 1D JAX NumPy array of user IDs.
485
+ subject_ids (jnp.ndarray):
486
+ A 1D JAX NumPy array of subject IDs.
476
487
  action_prob_func (callable):
477
488
  The action probability function.
478
489
  action_prob_func_args_beta_index (int):
@@ -489,22 +500,24 @@ class TrialConditioningMonitor:
489
500
  alg_update_func_args_action_prob_times_index (int):
490
501
  The index in the update function arguments tuple where an array of times for which the
491
502
  given action probabilities apply is provided, if applicable. -1 otherwise.
492
- action_prob_func_args_by_user_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
493
- A dictionary mapping decision times to maps of user ids to the function arguments
494
- required to compute action probabilities for this user.
495
- policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
496
- A map of user ids to dictionaries mapping decision times to the policy number in use.
503
+ alg_update_func_args_previous_betas_index (int):
504
+ The index in the update function arguments tuple where the previous betas are provided, if applicable. -1 otherwise.
505
+ action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
506
+ A dictionary mapping decision times to maps of subject ids to the function arguments
507
+ required to compute action probabilities for this subject.
508
+ policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
509
+ A map of subject ids to dictionaries mapping decision times to the policy number in use.
497
510
  Only applies to in-study decision times!
498
511
  initial_policy_num (int | float):
499
512
  The policy number of the initial policy before any updates.
500
513
  beta_index_by_policy_num (dict[int | float, int]):
501
514
  A dictionary mapping policy numbers to the index of the corresponding beta in
502
515
  all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
503
- update_func_args_by_by_user_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
504
- A dictionary where keys are policy numbers and values are dictionaries mapping user IDs
516
+ update_func_args_by_by_subject_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
517
+ A dictionary where keys are policy numbers and values are dictionaries mapping subject IDs
505
518
  to their respective update function arguments.
506
- action_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int]]):
507
- A dictionary mapping user IDs to their respective actions taken at each decision time.
519
+ action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
520
+ A dictionary mapping subject IDs to their respective actions taken at each decision time.
508
521
  Only applies to in-study decision times!
509
522
  suppress_all_data_checks (bool):
510
523
  If True, suppresses carrying out any data checks at all.
@@ -536,15 +549,15 @@ class TrialConditioningMonitor:
536
549
  # 1. If only_latest_block is True, we need to filter all the arguments to only
537
550
  # include those relevant to the latest update. We still need action probabilities
538
551
  # from the beginning for the weights, but the update function args can be trimmed
539
- # to the max policy so that the loop single_user_weighted_RL_estimating_function_stacker
552
+ # to the max policy so that the loop single_subject_weighted_RL_estimating_function_stacker
540
553
  # is only over one policy.
541
554
  if only_latest_block:
542
555
  logger.info(
543
556
  "Filtering algorithm update function arguments to only include those relevant to the latest update."
544
557
  )
545
558
  max_policy_num = max(beta_index_by_policy_num)
546
- update_func_args_by_by_user_id_by_policy_num = {
547
- max_policy_num: update_func_args_by_by_user_id_by_policy_num[
559
+ update_func_args_by_by_subject_id_by_policy_num = {
560
+ max_policy_num: update_func_args_by_by_subject_id_by_policy_num[
548
561
  max_policy_num
549
562
  ]
550
563
  }
@@ -553,15 +566,17 @@ class TrialConditioningMonitor:
553
566
  # supplied for the above functions, so that differentiation works correctly. The existing
554
567
  # values should be the same, but not connected to the parameter we are differentiating
555
568
  # with respect to. Note we will also find it useful below to have the action probability args
556
- # nested dict structure flipped to be user_id -> decision_time -> args, so we do that here too.
569
+ # nested dict structure flipped to be subject_id -> decision_time -> args, so we do that here too.
557
570
 
558
- logger.info("Threading in betas to action probability arguments for all users.")
571
+ logger.info(
572
+ "Threading in betas to action probability arguments for all subjects."
573
+ )
559
574
  (
560
- threaded_action_prob_func_args_by_decision_time_by_user_id,
561
- action_prob_func_args_by_decision_time_by_user_id,
575
+ threaded_action_prob_func_args_by_decision_time_by_subject_id,
576
+ action_prob_func_args_by_decision_time_by_subject_id,
562
577
  ) = thread_action_prob_func_args(
563
- action_prob_func_args_by_user_id_by_decision_time,
564
- policy_num_by_decision_time_by_user_id,
578
+ action_prob_func_args_by_subject_id_by_decision_time,
579
+ policy_num_by_decision_time_by_subject_id,
565
580
  initial_policy_num,
566
581
  betas,
567
582
  beta_index_by_policy_num,
@@ -573,16 +588,17 @@ class TrialConditioningMonitor:
573
588
  # arguments with the central betas introduced.
574
589
  logger.info(
575
590
  "Threading in betas and beta-dependent action probabilities to algorithm update "
576
- "function args for all users"
591
+ "function args for all subjects"
577
592
  )
578
- threaded_update_func_args_by_policy_num_by_user_id = thread_update_func_args(
579
- update_func_args_by_by_user_id_by_policy_num,
593
+ threaded_update_func_args_by_policy_num_by_subject_id = thread_update_func_args(
594
+ update_func_args_by_by_subject_id_by_policy_num,
580
595
  betas,
581
596
  beta_index_by_policy_num,
582
597
  alg_update_func_args_beta_index,
583
598
  alg_update_func_args_action_prob_index,
584
599
  alg_update_func_args_action_prob_times_index,
585
- threaded_action_prob_func_args_by_decision_time_by_user_id,
600
+ alg_update_func_args_previous_betas_index,
601
+ threaded_action_prob_func_args_by_decision_time_by_subject_id,
586
602
  action_prob_func,
587
603
  )
588
604
 
@@ -592,42 +608,44 @@ class TrialConditioningMonitor:
592
608
  if not suppress_all_data_checks and alg_update_func_args_action_prob_index >= 0:
593
609
  input_checks.require_threaded_algorithm_estimating_function_args_equivalent(
594
610
  algorithm_estimating_func,
595
- update_func_args_by_by_user_id_by_policy_num,
596
- threaded_update_func_args_by_policy_num_by_user_id,
611
+ update_func_args_by_by_subject_id_by_policy_num,
612
+ threaded_update_func_args_by_policy_num_by_subject_id,
597
613
  suppress_interactive_data_checks,
598
614
  )
599
615
 
600
- # 5. Now we can compute the weighted estimating function stacks for all users
616
+ # 5. Now we can compute the weighted estimating function stacks for all subjects
601
617
  # as well as collect related values used to construct the adaptive and classical
602
618
  # sandwich variances.
603
619
  RL_stacks = jnp.array(
604
620
  [
605
- self.single_user_weighted_RL_estimating_function_stacker(
621
+ self.single_subject_weighted_RL_estimating_function_stacker(
606
622
  beta_dim,
607
- user_id,
623
+ subject_id,
608
624
  action_prob_func,
609
625
  algorithm_estimating_func,
610
626
  action_prob_func_args_beta_index,
611
- action_prob_func_args_by_decision_time_by_user_id[user_id],
612
- threaded_action_prob_func_args_by_decision_time_by_user_id[user_id],
613
- threaded_update_func_args_by_policy_num_by_user_id[user_id],
614
- policy_num_by_decision_time_by_user_id[user_id],
615
- action_by_decision_time_by_user_id[user_id],
627
+ action_prob_func_args_by_decision_time_by_subject_id[subject_id],
628
+ threaded_action_prob_func_args_by_decision_time_by_subject_id[
629
+ subject_id
630
+ ],
631
+ threaded_update_func_args_by_policy_num_by_subject_id[subject_id],
632
+ policy_num_by_decision_time_by_subject_id[subject_id],
633
+ action_by_decision_time_by_subject_id[subject_id],
616
634
  beta_index_by_policy_num,
617
635
  )
618
- for user_id in user_ids.tolist()
636
+ for subject_id in subject_ids.tolist()
619
637
  ]
620
638
  )
621
639
 
622
640
  # 6. We will differentiate the first output, while the second will be used
623
641
  # for an estimating function sum check.
624
- mean_stack_across_users = jnp.mean(RL_stacks, axis=0)
625
- return mean_stack_across_users, mean_stack_across_users
642
+ mean_stack_across_subjects = jnp.mean(RL_stacks, axis=0)
643
+ return mean_stack_across_subjects, mean_stack_across_subjects
626
644
 
627
- def single_user_weighted_RL_estimating_function_stacker(
645
+ def single_subject_weighted_RL_estimating_function_stacker(
628
646
  self,
629
647
  beta_dim: int,
630
- user_id: collections.abc.Hashable,
648
+ subject_id: collections.abc.Hashable,
631
649
  action_prob_func: callable,
632
650
  algorithm_estimating_func: callable,
633
651
  action_prob_func_args_beta_index: int,
@@ -660,12 +678,12 @@ class TrialConditioningMonitor:
660
678
  beta_dim (list[jnp.ndarray]):
661
679
  A list of 1D JAX NumPy arrays corresponding to the betas produced by all updates.
662
680
 
663
- user_id (collections.abc.Hashable):
664
- The user ID for which to compute the weighted estimating function stack.
681
+ subject_id (collections.abc.Hashable):
682
+ The subject ID for which to compute the weighted estimating function stack.
665
683
 
666
684
  action_prob_func (callable):
667
685
  The function used to compute the probability of action 1 at a given decision time for
668
- a particular user given their state and the algorithm parameters.
686
+ a particular subject given their state and the algorithm parameters.
669
687
 
670
688
  algorithm_estimating_func (callable):
671
689
  The estimating function that corresponds to algorithm updates.
@@ -674,7 +692,7 @@ class TrialConditioningMonitor:
674
692
  The index of the beta argument in the action probability function's arguments.
675
693
 
676
694
  action_prob_func_args_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
677
- A map from decision times to tuples of arguments for this user for the action
695
+ A map from decision times to tuples of arguments for this subject for the action
678
696
  probability function. This is for all decision times (args are an empty
679
697
  tuple if they are not in the study). Should be sorted by decision time. NOTE THAT THESE
680
698
  ARGS DO NOT CONTAIN THE SHARED BETAS, making them impervious to the differentiation that
@@ -687,13 +705,13 @@ class TrialConditioningMonitor:
687
705
 
688
706
  threaded_update_func_args_by_policy_num (dict[int | float, dict[collections.abc.Hashable, tuple[Any, ...]]]):
689
707
  A map from policy numbers to tuples containing the arguments for
690
- the corresponding estimating functions for this user, with the shared betas threaded in
708
+ the corresponding estimating functions for this subject, with the shared betas threaded in
691
709
  for differentiation. This is for all non-initial, non-fallback policies. Policy numbers
692
710
  should be sorted.
693
711
 
694
712
  policy_num_by_decision_time (dict[collections.abc.Hashable, dict[int, int | float]]):
695
713
  A dictionary mapping decision times to the policy number in use. This may be
696
- user-specific. Should be sorted by decision time. Only applies to in-study decision
714
+ subject-specific. Should be sorted by decision time. Only applies to in-study decision
697
715
  times!
698
716
 
699
717
  action_by_decision_time (dict[collections.abc.Hashable, dict[int, int]]):
@@ -705,18 +723,18 @@ class TrialConditioningMonitor:
705
723
  all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
706
724
 
707
725
  Returns:
708
- jnp.ndarray: A 1-D JAX NumPy array representing the RL portion of the user's weighted
726
+ jnp.ndarray: A 1-D JAX NumPy array representing the RL portion of the subject's weighted
709
727
  estimating function stack.
710
728
  """
711
729
 
712
730
  logger.info(
713
- "Computing weighted estimating function stack for user %s.", user_id
731
+ "Computing weighted estimating function stack for subject %s.", subject_id
714
732
  )
715
733
 
716
734
  # First, reformat the supplied data into more convenient structures.
717
735
 
718
736
  # 1. Form a dictionary mapping policy numbers to the first time they were
719
- # applicable (for this user). Note that this includes ALL policies, initial
737
+ # applicable (for this subject). Note that this includes ALL policies, initial
720
738
  # fallbacks included.
721
739
  # Collect the first time after the first update separately for convenience.
722
740
  # These are both used to form the Radon-Nikodym weights for the right times.
@@ -727,17 +745,17 @@ class TrialConditioningMonitor:
727
745
  )
728
746
  )
729
747
 
730
- # 2. Get the start and end times for this user.
731
- user_start_time = math.inf
732
- user_end_time = -math.inf
748
+ # 2. Get the start and end times for this subject.
749
+ subject_start_time = math.inf
750
+ subject_end_time = -math.inf
733
751
  for decision_time in action_by_decision_time:
734
- user_start_time = min(user_start_time, decision_time)
735
- user_end_time = max(user_end_time, decision_time)
752
+ subject_start_time = min(subject_start_time, decision_time)
753
+ subject_end_time = max(subject_end_time, decision_time)
736
754
 
737
755
  # 3. Form a stack of weighted estimating equations, one for each update of the algorithm.
738
756
  logger.info(
739
- "Computing the algorithm component of the weighted estimating function stack for user %s.",
740
- user_id,
757
+ "Computing the algorithm component of the weighted estimating function stack for subject %s.",
758
+ subject_id,
741
759
  )
742
760
 
743
761
  in_study_action_prob_func_args = [
@@ -754,13 +772,13 @@ class TrialConditioningMonitor:
754
772
  )
755
773
 
756
774
  # Sort the threaded args by decision time to be cautious. We check if the
757
- # user id is present in the user args dict because we may call this on a
758
- # subset of the user arg dict when we are batching arguments by shape
775
+ # subject id is present in the subject args dict because we may call this on a
776
+ # subset of the subject arg dict when we are batching arguments by shape
759
777
  sorted_threaded_action_prob_args_by_decision_time = {
760
778
  decision_time: threaded_action_prob_func_args_by_decision_time[
761
779
  decision_time
762
780
  ]
763
- for decision_time in range(user_start_time, user_end_time + 1)
781
+ for decision_time in range(subject_start_time, subject_end_time + 1)
764
782
  if decision_time in threaded_action_prob_func_args_by_decision_time
765
783
  }
766
784
 
@@ -820,27 +838,27 @@ class TrialConditioningMonitor:
820
838
  [
821
839
  # Here we compute a product of Radon-Nikodym weights
822
840
  # for all decision times after the first update and before the update
823
- # update under consideration took effect, for which the user was in the study.
841
+ # update under consideration took effect, for which the subject was in the study.
824
842
  (
825
843
  jnp.prod(
826
844
  all_weights[
827
- # The earliest time after the first update where the user was in
845
+ # The earliest time after the first update where the subject was in
828
846
  # the study
829
847
  max(
830
848
  first_time_after_first_update,
831
- user_start_time,
849
+ subject_start_time,
832
850
  )
833
851
  - decision_time_to_all_weights_index_offset :
834
- # One more than the latest time the user was in the study before the time
852
+ # One more than the latest time the subject was in the study before the time
835
853
  # the update under consideration first applied. Note the + 1 because range
836
854
  # does not include the right endpoint.
837
855
  min(
838
856
  min_time_by_policy_num.get(policy_num, math.inf),
839
- user_end_time + 1,
857
+ subject_end_time + 1,
840
858
  )
841
859
  - decision_time_to_all_weights_index_offset,
842
860
  ]
843
- # If the user exited the study before there were any updates,
861
+ # If the subject exited the study before there were any updates,
844
862
  # this variable will be None and the above code to grab a weight would
845
863
  # throw an error. Just use 1 to include the unweighted estimating function
846
864
  # if they have data to contribute to the update.
@@ -848,8 +866,8 @@ class TrialConditioningMonitor:
848
866
  else 1
849
867
  ) # Now use the above to weight the alg estimating function for this update
850
868
  * algorithm_estimating_func(*update_args)
851
- # If there are no arguments for the update function, the user is not yet in the
852
- # study, so we just add a zero vector contribution to the sum across users.
869
+ # If there are no arguments for the update function, the subject is not yet in the
870
+ # study, so we just add a zero vector contribution to the sum across subjects.
853
871
  # Note that after they exit, they still contribute all their data to later
854
872
  # updates.
855
873
  if update_args