lifejacket 0.2.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
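In substance, 1.0.0 renames the package's public vocabulary rather than changing the check logic: the study dataframe becomes an analysis dataframe, users become subjects, the in-study indicator becomes an active indicator, and the require_*/verify_* helpers are renamed to match. As a rough orientation aid only (this mapping is read off the hunks below, not an official migration table shipped with lifejacket), the recurring renames are:

    # Illustrative rename map inferred from the hunks below; not an official
    # migration table shipped with lifejacket.
    RENAMES_0_2_1_TO_1_0_0 = {
        "study_df": "analysis_df",
        "in_study_col_name": "active_col_name",
        "user_id_col_name": "subject_id_col_name",
        "require_binary_in_study_indicators": "require_binary_active_indicators",
        "require_hashable_user_ids": "require_hashable_subject_ids",
        "verify_study_df_summary_satisfactory": "verify_analysis_df_summary_satisfactory",
    }

Two substantive changes ride along with the renames: the default atol in the estimating-function zero check loosens from 1e-5 to 5e-4, and a TODO about emailing rather than prompting the user is dropped.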
@@ -26,12 +26,12 @@ logging.basicConfig(
 
 # TODO: any checks needed here about alg update function type?
 def perform_first_wave_input_checks(
-    study_df,
-    in_study_col_name,
+    analysis_df,
+    active_col_name,
     action_col_name,
     policy_num_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     action_prob_col_name,
     reward_col_name,
     action_prob_func,
@@ -48,11 +48,11 @@ def perform_first_wave_input_checks(
     small_sample_correction,
 ):
     ### Validate algorithm loss/estimating function and args
-    require_alg_update_args_given_for_all_users_at_each_update(
-        study_df, user_id_col_name, alg_update_func_args
+    require_alg_update_args_given_for_all_subjects_at_each_update(
+        analysis_df, subject_id_col_name, alg_update_func_args
     )
-    require_no_policy_numbers_present_in_alg_update_args_but_not_study_df(
-        study_df, policy_num_col_name, alg_update_func_args
+    require_no_policy_numbers_present_in_alg_update_args_but_not_analysis_df(
+        analysis_df, policy_num_col_name, alg_update_func_args
     )
     require_beta_is_1D_array_in_alg_update_args(
         alg_update_func_args, alg_update_func_args_beta_index
@@ -60,8 +60,8 @@ def perform_first_wave_input_checks(
     require_previous_betas_is_2D_array_in_alg_update_args(
         alg_update_func_args, alg_update_func_args_previous_betas_index
     )
-    require_all_policy_numbers_in_study_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
-        study_df, in_study_col_name, policy_num_col_name, alg_update_func_args
+    require_all_policy_numbers_in_analysis_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
+        analysis_df, active_col_name, policy_num_col_name, alg_update_func_args
     )
     confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
         alg_update_func_args_action_prob_index, suppress_interactive_data_checks
@@ -80,17 +80,17 @@ def perform_first_wave_input_checks(
     require_previous_betas_match_in_alg_update_args_each_update(
         alg_update_func_args, alg_update_func_args_previous_betas_index
     )
-    require_action_prob_args_in_alg_update_func_correspond_to_study_df(
-        study_df,
+    require_action_prob_args_in_alg_update_func_correspond_to_analysis_df(
+        analysis_df,
         action_prob_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         alg_update_func_args,
         alg_update_func_args_action_prob_index,
         alg_update_func_args_action_prob_times_index,
     )
     require_valid_action_prob_times_given_if_index_supplied(
-        study_df,
+        analysis_df,
         calendar_t_col_name,
         alg_update_func_args,
         alg_update_func_args_action_prob_times_index,
@@ -101,28 +101,28 @@ def perform_first_wave_input_checks(
     )
 
     ### Validate action prob function and args
-    require_action_prob_func_args_given_for_all_users_at_each_decision(
-        study_df, user_id_col_name, action_prob_func_args
+    require_action_prob_func_args_given_for_all_subjects_at_each_decision(
+        analysis_df, subject_id_col_name, action_prob_func_args
     )
     require_action_prob_func_args_given_for_all_decision_times(
-        study_df, calendar_t_col_name, action_prob_func_args
+        analysis_df, calendar_t_col_name, action_prob_func_args
     )
-    require_action_probabilities_in_study_df_can_be_reconstructed(
-        study_df,
+    require_action_probabilities_in_analysis_df_can_be_reconstructed(
+        analysis_df,
         action_prob_col_name,
         calendar_t_col_name,
-        user_id_col_name,
-        in_study_col_name,
+        subject_id_col_name,
+        active_col_name,
         action_prob_func_args,
         action_prob_func,
     )
 
     require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times(
-        study_df,
+        analysis_df,
         calendar_t_col_name,
         action_prob_func_args,
-        in_study_col_name,
-        user_id_col_name,
+        active_col_name,
+        subject_id_col_name,
     )
     require_beta_is_1D_array_in_action_prob_args(
         action_prob_func_args, action_prob_func_args_beta_index
@@ -131,13 +131,13 @@ def perform_first_wave_input_checks(
         action_prob_func_args, action_prob_func_args_beta_index
     )
 
-    ### Validate study_df
-    verify_study_df_summary_satisfactory(
-        study_df,
-        user_id_col_name,
+    ### Validate analysis_df
+    verify_analysis_df_summary_satisfactory(
+        analysis_df,
+        subject_id_col_name,
         policy_num_col_name,
         calendar_t_col_name,
-        in_study_col_name,
+        active_col_name,
         action_prob_col_name,
         reward_col_name,
         beta_dim,
@@ -145,46 +145,46 @@ def perform_first_wave_input_checks(
         suppress_interactive_data_checks,
     )
 
-    require_all_users_have_all_times_in_study_df(
-        study_df, calendar_t_col_name, user_id_col_name
+    require_all_subjects_have_all_times_in_analysis_df(
+        analysis_df, calendar_t_col_name, subject_id_col_name
     )
-    require_all_named_columns_present_in_study_df(
-        study_df,
-        in_study_col_name,
+    require_all_named_columns_present_in_analysis_df(
+        analysis_df,
+        active_col_name,
         action_col_name,
         policy_num_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         action_prob_col_name,
     )
-    require_all_named_columns_not_object_type_in_study_df(
-        study_df,
-        in_study_col_name,
+    require_all_named_columns_not_object_type_in_analysis_df(
+        analysis_df,
+        active_col_name,
         action_col_name,
         policy_num_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         action_prob_col_name,
     )
-    require_binary_actions(study_df, in_study_col_name, action_col_name)
-    require_binary_in_study_indicators(study_df, in_study_col_name)
+    require_binary_actions(analysis_df, active_col_name, action_col_name)
+    require_binary_active_indicators(analysis_df, active_col_name)
     require_consecutive_integer_policy_numbers(
-        study_df, in_study_col_name, policy_num_col_name
+        analysis_df, active_col_name, policy_num_col_name
     )
-    require_consecutive_integer_calendar_times(study_df, calendar_t_col_name)
-    require_hashable_user_ids(study_df, in_study_col_name, user_id_col_name)
-    require_action_probabilities_in_range_0_to_1(study_df, action_prob_col_name)
+    require_consecutive_integer_calendar_times(analysis_df, calendar_t_col_name)
+    require_hashable_subject_ids(analysis_df, active_col_name, subject_id_col_name)
+    require_action_probabilities_in_range_0_to_1(analysis_df, action_prob_col_name)
 
     ### Validate theta estimation
     require_theta_is_1D_array(theta_est)
 
 
 def perform_alg_only_input_checks(
-    study_df,
-    in_study_col_name,
+    analysis_df,
+    active_col_name,
     policy_num_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     action_prob_col_name,
     action_prob_func,
     action_prob_func_args,
@@ -196,14 +196,14 @@ def perform_alg_only_input_checks(
     suppress_interactive_data_checks,
 ):
     ### Validate algorithm loss/estimating function and args
-    require_alg_update_args_given_for_all_users_at_each_update(
-        study_df, user_id_col_name, alg_update_func_args
+    require_alg_update_args_given_for_all_subjects_at_each_update(
+        analysis_df, subject_id_col_name, alg_update_func_args
     )
     require_beta_is_1D_array_in_alg_update_args(
         alg_update_func_args, alg_update_func_args_beta_index
     )
-    require_all_policy_numbers_in_study_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
-        study_df, in_study_col_name, policy_num_col_name, alg_update_func_args
+    require_all_policy_numbers_in_analysis_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
+        analysis_df, active_col_name, policy_num_col_name, alg_update_func_args
     )
     confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
         alg_update_func_args_action_prob_index, suppress_interactive_data_checks
@@ -219,45 +219,45 @@ def perform_alg_only_input_checks(
     require_betas_match_in_alg_update_args_each_update(
         alg_update_func_args, alg_update_func_args_beta_index
     )
-    require_action_prob_args_in_alg_update_func_correspond_to_study_df(
-        study_df,
+    require_action_prob_args_in_alg_update_func_correspond_to_analysis_df(
+        analysis_df,
         action_prob_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         alg_update_func_args,
         alg_update_func_args_action_prob_index,
         alg_update_func_args_action_prob_times_index,
     )
     require_valid_action_prob_times_given_if_index_supplied(
-        study_df,
+        analysis_df,
         calendar_t_col_name,
         alg_update_func_args,
         alg_update_func_args_action_prob_times_index,
     )
 
     ### Validate action prob function and args
-    require_action_prob_func_args_given_for_all_users_at_each_decision(
-        study_df, user_id_col_name, action_prob_func_args
+    require_action_prob_func_args_given_for_all_subjects_at_each_decision(
+        analysis_df, subject_id_col_name, action_prob_func_args
     )
     require_action_prob_func_args_given_for_all_decision_times(
-        study_df, calendar_t_col_name, action_prob_func_args
+        analysis_df, calendar_t_col_name, action_prob_func_args
    )
-    require_action_probabilities_in_study_df_can_be_reconstructed(
-        study_df,
+    require_action_probabilities_in_analysis_df_can_be_reconstructed(
+        analysis_df,
         action_prob_col_name,
         calendar_t_col_name,
-        user_id_col_name,
-        in_study_col_name,
+        subject_id_col_name,
+        active_col_name,
         action_prob_func_args,
         action_prob_func=action_prob_func,
     )
 
     require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times(
-        study_df,
+        analysis_df,
         calendar_t_col_name,
         action_prob_func_args,
-        in_study_col_name,
-        user_id_col_name,
+        active_col_name,
+        subject_id_col_name,
     )
     require_beta_is_1D_array_in_action_prob_args(
         action_prob_func_args, action_prob_func_args_beta_index
@@ -268,12 +268,12 @@ def perform_alg_only_input_checks(
 
 
 # TODO: Give a hard-to-use option to loosen this check somehow
-def require_action_probabilities_in_study_df_can_be_reconstructed(
-    study_df,
+def require_action_probabilities_in_analysis_df_can_be_reconstructed(
+    analysis_df,
     action_prob_col_name,
     calendar_t_col_name,
-    user_id_col_name,
-    in_study_col_name,
+    subject_id_col_name,
+    active_col_name,
     action_prob_func_args,
     action_prob_func,
 ):
@@ -285,77 +285,79 @@ def require_action_probabilities_in_study_df_can_be_reconstructed(
     """
     logger.info("Reconstructing action probabilities from function and arguments.")
 
-    in_study_df = study_df[study_df[in_study_col_name] == 1]
-    reconstructed_action_probs = in_study_df.apply(
+    active_df = analysis_df[analysis_df[active_col_name] == 1]
+    reconstructed_action_probs = active_df.apply(
         lambda row: action_prob_func(
-            *action_prob_func_args[row[calendar_t_col_name]][row[user_id_col_name]]
+            *action_prob_func_args[row[calendar_t_col_name]][row[subject_id_col_name]]
         ),
         axis=1,
     )
 
     np.testing.assert_allclose(
-        in_study_df[action_prob_col_name].to_numpy(dtype="float64"),
+        active_df[action_prob_col_name].to_numpy(dtype="float64"),
         reconstructed_action_probs.to_numpy(dtype="float64"),
         atol=1e-6,
    )
 
 
-def require_all_users_have_all_times_in_study_df(
-    study_df, calendar_t_col_name, user_id_col_name
+def require_all_subjects_have_all_times_in_analysis_df(
+    analysis_df, calendar_t_col_name, subject_id_col_name
 ):
-    logger.info("Checking that all users have the same set of unique calendar times.")
+    logger.info(
+        "Checking that all subjects have the same set of unique calendar times."
+    )
     # Get the unique calendar times
-    unique_calendar_times = set(study_df[calendar_t_col_name].unique())
+    unique_calendar_times = set(analysis_df[calendar_t_col_name].unique())
 
-    # Group by user ID and aggregate the unique calendar times for each user
-    user_calendar_times = study_df.groupby(user_id_col_name)[calendar_t_col_name].apply(
-        set
-    )
+    # Group by subject ID and aggregate the unique calendar times for each subject
+    subject_calendar_times = analysis_df.groupby(subject_id_col_name)[
+        calendar_t_col_name
+    ].apply(set)
 
-    # Check if all users have the same set of unique calendar times
-    if not user_calendar_times.apply(lambda x: x == unique_calendar_times).all():
+    # Check if all subjects have the same set of unique calendar times
+    if not subject_calendar_times.apply(lambda x: x == unique_calendar_times).all():
         raise AssertionError(
-            "Not all users have all calendar times in the study dataframe. Please see the contract for details."
+            "Not all subjects have all calendar times in the study dataframe. Please see the contract for details."
        )
 
 
-def require_alg_update_args_given_for_all_users_at_each_update(
-    study_df, user_id_col_name, alg_update_func_args
+def require_alg_update_args_given_for_all_subjects_at_each_update(
+    analysis_df, subject_id_col_name, alg_update_func_args
 ):
     logger.info(
-        "Checking that algorithm update function args are given for all users at each update."
+        "Checking that algorithm update function args are given for all subjects at each update."
     )
-    all_user_ids = set(study_df[user_id_col_name].unique())
+    all_subject_ids = set(analysis_df[subject_id_col_name].unique())
     for policy_num in alg_update_func_args:
         assert (
-            set(alg_update_func_args[policy_num].keys()) == all_user_ids
-        ), f"Not all users present in algorithm update function args for policy number {policy_num}. Please see the contract for details."
+            set(alg_update_func_args[policy_num].keys()) == all_subject_ids
+        ), f"Not all subjects present in algorithm update function args for policy number {policy_num}. Please see the contract for details."
 
 
-def require_action_prob_args_in_alg_update_func_correspond_to_study_df(
-    study_df,
+def require_action_prob_args_in_alg_update_func_correspond_to_analysis_df(
+    analysis_df,
     action_prob_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     alg_update_func_args,
     alg_update_func_args_action_prob_index,
     alg_update_func_args_action_prob_times_index,
 ):
     logger.info(
         "Checking that the action probabilities supplied in the algorithm update function args, if"
-        " any, correspond to those in the study dataframe for the corresponding users and decision"
+        " any, correspond to those in the study dataframe for the corresponding subjects and decision"
         " times."
     )
     if alg_update_func_args_action_prob_index < 0:
         return
 
     # Precompute a lookup dictionary for faster access
-    study_df_lookup = study_df.set_index([calendar_t_col_name, user_id_col_name])[
-        action_prob_col_name
-    ].to_dict()
+    analysis_df_lookup = analysis_df.set_index(
+        [calendar_t_col_name, subject_id_col_name]
+    )[action_prob_col_name].to_dict()
 
-    for policy_num, user_args in alg_update_func_args.items():
-        for user_id, args in user_args.items():
+    for policy_num, subject_args in alg_update_func_args.items():
+        for subject_id, args in subject_args.items():
            if not args:
                continue
            arg_action_probs = args[alg_update_func_args_action_prob_index]
@@ -364,43 +366,43 @@ def require_action_prob_args_in_alg_update_func_correspond_to_study_df(
             ].flatten()
 
             # Use the precomputed lookup dictionary
-            study_df_action_probs = [
-                study_df_lookup[(decision_time.item(), user_id)]
+            analysis_df_action_probs = [
+                analysis_df_lookup[(decision_time.item(), subject_id)]
                 for decision_time in action_prob_times
             ]
 
             assert np.allclose(
                 arg_action_probs.flatten(),
-                study_df_action_probs,
+                analysis_df_action_probs,
             ), (
-                f"There is a mismatch for user {user_id} between the action probabilities supplied"
+                f"There is a mismatch for subject {subject_id} between the action probabilities supplied"
                 f" in the args to the algorithm update function at policy {policy_num} and those in"
                 " the study dataframe for the supplied times. Please see the contract for details."
             )
 
 
-def require_action_prob_func_args_given_for_all_users_at_each_decision(
-    study_df,
-    user_id_col_name,
+def require_action_prob_func_args_given_for_all_subjects_at_each_decision(
+    analysis_df,
+    subject_id_col_name,
     action_prob_func_args,
 ):
     logger.info(
-        "Checking that action prob function args are given for all users at each decision time."
+        "Checking that action prob function args are given for all subjects at each decision time."
     )
-    all_user_ids = set(study_df[user_id_col_name].unique())
+    all_subject_ids = set(analysis_df[subject_id_col_name].unique())
     for decision_time in action_prob_func_args:
         assert (
-            set(action_prob_func_args[decision_time].keys()) == all_user_ids
-        ), f"Not all users present in algorithm update function args for decision time {decision_time}. Please see the contract for details."
+            set(action_prob_func_args[decision_time].keys()) == all_subject_ids
+        ), f"Not all subjects present in algorithm update function args for decision time {decision_time}. Please see the contract for details."
 
 
 def require_action_prob_func_args_given_for_all_decision_times(
-    study_df, calendar_t_col_name, action_prob_func_args
+    analysis_df, calendar_t_col_name, action_prob_func_args
 ):
     logger.info(
         "Checking that action prob function args are given for all decision times."
     )
-    all_times = set(study_df[calendar_t_col_name].unique())
+    all_times = set(analysis_df[calendar_t_col_name].unique())
 
     assert (
         set(action_prob_func_args.keys()) == all_times
@@ -408,106 +410,106 @@ def require_action_prob_func_args_given_for_all_decision_times(
 
 
 def require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times(
-    study_df: pd.DataFrame,
+    analysis_df: pd.DataFrame,
     calendar_t_col_name: str,
     action_prob_func_args: dict[str, dict[str, tuple[Any, ...]]],
-    in_study_col_name,
-    user_id_col_name,
+    active_col_name,
+    subject_id_col_name,
 ):
     logger.info(
-        "Checking that action probability function args are blank for exactly the times each user"
-        " is not in the study according to the study dataframe."
+        "Checking that action probability function args are blank for exactly the times each subject"
+        " is not in the study according to the study dataframe."
     )
-    out_of_study_df = study_df[study_df[in_study_col_name] == 0]
-    out_of_study_times_by_user_according_to_study_df = (
-        out_of_study_df.groupby(user_id_col_name)[calendar_t_col_name]
+    inactive_df = analysis_df[analysis_df[active_col_name] == 0]
+    inactive_times_by_subject_according_to_analysis_df = (
+        inactive_df.groupby(subject_id_col_name)[calendar_t_col_name]
         .apply(set)
         .to_dict()
     )
 
-    out_of_study_times_by_user_according_to_action_prob_func_args = (
+    inactive_times_by_subject_according_to_action_prob_func_args = (
         collections.defaultdict(set)
     )
-    for decision_time, action_prob_args_by_user in action_prob_func_args.items():
-        for user_id, action_prob_args in action_prob_args_by_user.items():
+    for decision_time, action_prob_args_by_subject in action_prob_func_args.items():
+        for subject_id, action_prob_args in action_prob_args_by_subject.items():
             if not action_prob_args:
-                out_of_study_times_by_user_according_to_action_prob_func_args[
-                    user_id
+                inactive_times_by_subject_according_to_action_prob_func_args[
+                    subject_id
                 ].add(decision_time)
 
     assert (
-        out_of_study_times_by_user_according_to_study_df
-        == out_of_study_times_by_user_according_to_action_prob_func_args
+        inactive_times_by_subject_according_to_analysis_df
+        == inactive_times_by_subject_according_to_action_prob_func_args
     ), (
         "Out-of-study decision times according to the study dataframe do not match up with the"
-        " times for which action probability arguments are blank for all users. Please see the"
+        " times for which action probability arguments are blank for all subjects. Please see the"
         " contract for details."
     )
 
 
-def require_all_named_columns_present_in_study_df(
-    study_df,
-    in_study_col_name,
+def require_all_named_columns_present_in_analysis_df(
+    analysis_df,
+    active_col_name,
     action_col_name,
     policy_num_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     action_prob_col_name,
 ):
     logger.info("Checking that all named columns are present in the study dataframe.")
+    assert active_col_name in analysis_df.columns, f"{active_col_name} not in study df."
+    assert action_col_name in analysis_df.columns, f"{action_col_name} not in study df."
     assert (
-        in_study_col_name in study_df.columns
-    ), f"{in_study_col_name} not in study df."
-    assert action_col_name in study_df.columns, f"{action_col_name} not in study df."
-    assert (
-        policy_num_col_name in study_df.columns
+        policy_num_col_name in analysis_df.columns
     ), f"{policy_num_col_name} not in study df."
     assert (
-        calendar_t_col_name in study_df.columns
+        calendar_t_col_name in analysis_df.columns
     ), f"{calendar_t_col_name} not in study df."
-    assert user_id_col_name in study_df.columns, f"{user_id_col_name} not in study df."
     assert (
-        action_prob_col_name in study_df.columns
+        subject_id_col_name in analysis_df.columns
+    ), f"{subject_id_col_name} not in study df."
+    assert (
+        action_prob_col_name in analysis_df.columns
     ), f"{action_prob_col_name} not in study df."
 
 
-def require_all_named_columns_not_object_type_in_study_df(
-    study_df,
-    in_study_col_name,
+def require_all_named_columns_not_object_type_in_analysis_df(
+    analysis_df,
+    active_col_name,
     action_col_name,
     policy_num_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     action_prob_col_name,
 ):
     logger.info("Checking that all named columns are not type object.")
     for colname in (
-        in_study_col_name,
+        active_col_name,
         action_col_name,
         policy_num_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         action_prob_col_name,
     ):
         assert (
-            study_df[colname].dtype != "object"
+            analysis_df[colname].dtype != "object"
         ), f"At least {colname} is of object type in study df."
 
 
-def require_binary_actions(study_df, in_study_col_name, action_col_name):
+def require_binary_actions(analysis_df, active_col_name, action_col_name):
     logger.info("Checking that actions are binary.")
     assert (
-        study_df[study_df[in_study_col_name] == 1][action_col_name]
+        analysis_df[analysis_df[active_col_name] == 1][action_col_name]
         .astype("int64")
         .isin([0, 1])
         .all()
     ), "Actions are not binary."
 
 
-def require_binary_in_study_indicators(study_df, in_study_col_name):
-    logger.info("Checking that in-study indicators are binary.")
+def require_binary_active_indicators(analysis_df, active_col_name):
+    logger.info("Checking that active indicators are binary.")
     assert (
-        study_df[study_df[in_study_col_name] == 1][in_study_col_name]
+        analysis_df[analysis_df[active_col_name] == 1][active_col_name]
         .astype("int64")
         .isin([0, 1])
         .all()
@@ -515,7 +517,7 @@ def require_binary_in_study_indicators(study_df, in_study_col_name):
 
 
 def require_consecutive_integer_policy_numbers(
-    study_df, in_study_col_name, policy_num_col_name
+    analysis_df, active_col_name, policy_num_col_name
 ):
     # TODO: This is a somewhat rough check of this, could also check nondecreasing temporally
 
@@ -523,8 +525,8 @@ def require_consecutive_integer_policy_numbers(
         "Checking that in-study, non-fallback policy numbers are consecutive integers."
     )
 
-    in_study_df = study_df[study_df[in_study_col_name] == 1]
-    nonnegative_policy_df = in_study_df[in_study_df[policy_num_col_name] >= 0]
+    active_df = analysis_df[analysis_df[active_col_name] == 1]
+    nonnegative_policy_df = active_df[active_df[policy_num_col_name] >= 0]
     # Ideally we actually have integers, but for legacy reasons we will support
     # floats as well.
     if nonnegative_policy_df[policy_num_col_name].dtype == "float64":
@@ -540,66 +542,67 @@
     ), "Policy numbers are not consecutive integers."
 
 
-def require_consecutive_integer_calendar_times(study_df, calendar_t_col_name):
+def require_consecutive_integer_calendar_times(analysis_df, calendar_t_col_name):
     # This is a somewhat rough check of this, more like checking there are no
-    # gaps in the integers covered. But we have other checks that all users
+    # gaps in the integers covered. But we have other checks that all subjects
     # have same times, etc.
-    # Note these times should be well-formed even when the user is not in the study.
+    # Note these times should be well-formed even when the subject is not in the study.
     logger.info("Checking that calendar times are consecutive integers.")
     assert np.array_equal(
-        study_df[calendar_t_col_name].unique(),
+        analysis_df[calendar_t_col_name].unique(),
         range(
-            study_df[calendar_t_col_name].min(), study_df[calendar_t_col_name].max() + 1
+            analysis_df[calendar_t_col_name].min(),
+            analysis_df[calendar_t_col_name].max() + 1,
         ),
     ), "Calendar times are not consecutive integers."
 
 
-def require_hashable_user_ids(study_df, in_study_col_name, user_id_col_name):
-    logger.info("Checking that user IDs are hashable.")
+def require_hashable_subject_ids(analysis_df, active_col_name, subject_id_col_name):
+    logger.info("Checking that subject IDs are hashable.")
     isinstance(
-        study_df[study_df[in_study_col_name] == 1][user_id_col_name][0],
+        analysis_df[analysis_df[active_col_name] == 1][subject_id_col_name][0],
         collections.abc.Hashable,
     )
 
 
-def require_action_probabilities_in_range_0_to_1(study_df, action_prob_col_name):
+def require_action_probabilities_in_range_0_to_1(analysis_df, action_prob_col_name):
     logger.info("Checking that action probabilities are in the interval (0, 1).")
-    study_df[action_prob_col_name].between(0, 1, inclusive="neither").all()
+    analysis_df[action_prob_col_name].between(0, 1, inclusive="neither").all()
 
 
-def require_no_policy_numbers_present_in_alg_update_args_but_not_study_df(
-    study_df, policy_num_col_name, alg_update_func_args
+def require_no_policy_numbers_present_in_alg_update_args_but_not_analysis_df(
+    analysis_df, policy_num_col_name, alg_update_func_args
 ):
     logger.info(
         "Checking that policy numbers in algorithm update function args are present in the study dataframe."
     )
     alg_update_policy_nums = sorted(alg_update_func_args.keys())
-    study_df_policy_nums = sorted(study_df[policy_num_col_name].unique())
-    assert set(alg_update_policy_nums).issubset(set(study_df_policy_nums)), (
+    analysis_df_policy_nums = sorted(analysis_df[policy_num_col_name].unique())
+    assert set(alg_update_policy_nums).issubset(set(analysis_df_policy_nums)), (
         f"There are policy numbers present in algorithm update function args but not in the study dataframe. "
         f"\nalg_update_func_args policy numbers: {alg_update_policy_nums}"
-        f"\nstudy_df policy numbers: {study_df_policy_nums}.\nPlease see the contract for details."
+        f"\nanalysis_df policy numbers: {analysis_df_policy_nums}.\nPlease see the contract for details."
     )
 
 
-def require_all_policy_numbers_in_study_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
-    study_df, in_study_col_name, policy_num_col_name, alg_update_func_args
+def require_all_policy_numbers_in_analysis_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
+    analysis_df, active_col_name, policy_num_col_name, alg_update_func_args
 ):
     logger.info(
         "Checking that all policy numbers in the study dataframe are present in the algorithm update function args."
     )
-    in_study_df = study_df[study_df[in_study_col_name] == 1]
+    active_df = analysis_df[analysis_df[active_col_name] == 1]
     # Get the number of the initial policy. 0 is recommended but not required.
-    min_nonnegative_policy_number = in_study_df[in_study_df[policy_num_col_name] >= 0][
+    min_nonnegative_policy_number = active_df[active_df[policy_num_col_name] >= 0][
         policy_num_col_name
     ]
     assert set(
-        in_study_df[in_study_df[policy_num_col_name] > min_nonnegative_policy_number][
+        active_df[active_df[policy_num_col_name] > min_nonnegative_policy_number][
            policy_num_col_name
        ].unique()
     ).issubset(
        alg_update_func_args.keys()
-    ), f"There are non-fallback, non-initial policy numbers in the study dataframe that are not in the update function args: {set(in_study_df[in_study_df[policy_num_col_name] > 0][policy_num_col_name].unique()) - set(alg_update_func_args.keys())}. Please see the contract for details."
+    ), f"There are non-fallback, non-initial policy numbers in the study dataframe that are not in the update function args: {set(active_df[active_df[policy_num_col_name] > 0][policy_num_col_name].unique()) - set(alg_update_func_args.keys())}. Please see the contract for details."
 
 
 def confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
@@ -672,28 +675,29 @@ def require_beta_is_1D_array_in_alg_update_args(
     alg_update_func_args, alg_update_func_args_beta_index
 ):
     for policy_num in alg_update_func_args:
-        for user_id in alg_update_func_args[policy_num]:
-            if not alg_update_func_args[policy_num][user_id]:
+        for subject_id in alg_update_func_args[policy_num]:
+            if not alg_update_func_args[policy_num][subject_id]:
                 continue
             assert (
-                alg_update_func_args[policy_num][user_id][
+                alg_update_func_args[policy_num][subject_id][
                     alg_update_func_args_beta_index
                 ].ndim
                 == 1
             ), "Beta is not a 1D array in the algorithm update function args."
 
+
 def require_previous_betas_is_2D_array_in_alg_update_args(
     alg_update_func_args, alg_update_func_args_previous_betas_index
 ):
     if alg_update_func_args_previous_betas_index < 0:
         return
-
+
     for policy_num in alg_update_func_args:
-        for user_id in alg_update_func_args[policy_num]:
-            if not alg_update_func_args[policy_num][user_id]:
+        for subject_id in alg_update_func_args[policy_num]:
+            if not alg_update_func_args[policy_num][subject_id]:
                 continue
             assert (
-                alg_update_func_args[policy_num][user_id][
+                alg_update_func_args[policy_num][subject_id][
                     alg_update_func_args_previous_betas_index
                 ].ndim
                 == 2
@@ -704,11 +708,11 @@ def require_beta_is_1D_array_in_action_prob_args(
     action_prob_func_args, action_prob_func_args_beta_index
 ):
     for decision_time in action_prob_func_args:
-        for user_id in action_prob_func_args[decision_time]:
-            if not action_prob_func_args[decision_time][user_id]:
+        for subject_id in action_prob_func_args[decision_time]:
+            if not action_prob_func_args[decision_time][subject_id]:
                 continue
             assert (
-                action_prob_func_args[decision_time][user_id][
+                action_prob_func_args[decision_time][subject_id][
                     action_prob_func_args_beta_index
                 ].ndim
                 == 1
@@ -719,12 +723,12 @@ def require_theta_is_1D_array(theta_est):
     assert theta_est.ndim == 1, "Theta is not a 1D array."
 
 
-def verify_study_df_summary_satisfactory(
-    study_df,
-    user_id_col_name,
+def verify_analysis_df_summary_satisfactory(
+    analysis_df,
+    subject_id_col_name,
     policy_num_col_name,
     calendar_t_col_name,
-    in_study_col_name,
+    active_col_name,
     action_prob_col_name,
     reward_col_name,
     beta_dim,
@@ -732,43 +736,41 @@ def verify_study_df_summary_satisfactory(
     suppress_interactive_data_checks,
 ):
 
-    in_study_df = study_df[study_df[in_study_col_name] == 1]
-    num_users = in_study_df[user_id_col_name].nunique()
-    num_non_initial_or_fallback_policies = in_study_df[
-        in_study_df[policy_num_col_name] > 0
+    active_df = analysis_df[analysis_df[active_col_name] == 1]
+    num_subjects = active_df[subject_id_col_name].nunique()
+    num_non_initial_or_fallback_policies = active_df[
+        active_df[policy_num_col_name] > 0
     ][policy_num_col_name].nunique()
     num_decision_times_with_fallback_policies = len(
-        in_study_df[in_study_df[policy_num_col_name] < 0]
+        active_df[active_df[policy_num_col_name] < 0]
     )
-    num_decision_times = in_study_df[calendar_t_col_name].nunique()
-    avg_decisions_per_user = len(in_study_df) / num_users
+    num_decision_times = active_df[calendar_t_col_name].nunique()
+    avg_decisions_per_subject = len(active_df) / num_subjects
     num_decision_times_with_multiple_policies = (
-        in_study_df[in_study_df[policy_num_col_name] >= 0]
+        active_df[active_df[policy_num_col_name] >= 0]
         .groupby(calendar_t_col_name)[policy_num_col_name]
         .nunique()
         > 1
     ).sum()
-    min_action_prob = in_study_df[action_prob_col_name].min()
-    max_action_prob = in_study_df[action_prob_col_name].max()
-    min_non_fallback_policy_num = in_study_df[in_study_df[policy_num_col_name] >= 0][
+    min_action_prob = active_df[action_prob_col_name].min()
+    max_action_prob = active_df[action_prob_col_name].max()
+    min_non_fallback_policy_num = active_df[active_df[policy_num_col_name] >= 0][
         policy_num_col_name
     ].min()
     num_data_points_before_first_update = len(
-        in_study_df[in_study_df[policy_num_col_name] == min_non_fallback_policy_num]
+        active_df[active_df[policy_num_col_name] == min_non_fallback_policy_num]
     )
 
     median_action_probabilities = (
-        in_study_df.groupby(calendar_t_col_name)[action_prob_col_name]
-        .median()
-        .to_numpy()
+        active_df.groupby(calendar_t_col_name)[action_prob_col_name].median().to_numpy()
     )
-    quartiles = in_study_df.groupby(calendar_t_col_name)[action_prob_col_name].quantile(
+    quartiles = active_df.groupby(calendar_t_col_name)[action_prob_col_name].quantile(
         [0.25, 0.75]
     )
     q25_action_probabilities = quartiles.xs(0.25, level=1).to_numpy()
     q75_action_probabilities = quartiles.xs(0.75, level=1).to_numpy()
 
-    avg_rewards = in_study_df.groupby(calendar_t_col_name)[reward_col_name].mean()
+    avg_rewards = active_df.groupby(calendar_t_col_name)[reward_col_name].mean()
 
     # Plot action probability quartile trajectories
     plt.clear_figure()
@@ -808,10 +810,10 @@
 
     confirm_input_check_result(
         f"\nYou provided a study dataframe reflecting a study with"
-        f"\n* {num_users} users"
+        f"\n* {num_subjects} subjects"
         f"\n* {num_non_initial_or_fallback_policies} policy updates"
-        f"\n* {num_decision_times} decision times, for an average of {avg_decisions_per_user}"
-        f" decisions per user"
+        f"\n* {num_decision_times} decision times, for an average of {avg_decisions_per_subject}"
+        f" decisions per subject"
         f"\n* RL parameters of dimension {beta_dim} per update"
         f"\n* Inferential target of dimension {theta_dim}"
         f"\n* {num_data_points_before_first_update} data points before the first update"
@@ -834,14 +836,14 @@ def require_betas_match_in_alg_update_args_each_update(
     alg_update_func_args, alg_update_func_args_beta_index
 ):
     logger.info(
-        "Checking that betas match across users for each update in the algorithm update function args."
+        "Checking that betas match across subjects for each update in the algorithm update function args."
     )
     for policy_num in alg_update_func_args:
         first_beta = None
-        for user_id in alg_update_func_args[policy_num]:
-            if not alg_update_func_args[policy_num][user_id]:
+        for subject_id in alg_update_func_args[policy_num]:
+            if not alg_update_func_args[policy_num][subject_id]:
                 continue
-            beta = alg_update_func_args[policy_num][user_id][
+            beta = alg_update_func_args[policy_num][subject_id][
                 alg_update_func_args_beta_index
             ]
             if first_beta is None:
@@ -849,23 +851,24 @@
             else:
                 assert np.array_equal(
                     beta, first_beta
-                ), f"Betas do not match across users in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
+                ), f"Betas do not match across subjects in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
+
 
 def require_previous_betas_match_in_alg_update_args_each_update(
     alg_update_func_args, alg_update_func_args_previous_betas_index
 ):
     logger.info(
-        "Checking that previous betas match across users for each update in the algorithm update function args."
+        "Checking that previous betas match across subjects for each update in the algorithm update function args."
     )
     if alg_update_func_args_previous_betas_index < 0:
-        return
-
+        return
+
     for policy_num in alg_update_func_args:
         first_previous_betas = None
-        for user_id in alg_update_func_args[policy_num]:
-            if not alg_update_func_args[policy_num][user_id]:
+        for subject_id in alg_update_func_args[policy_num]:
+            if not alg_update_func_args[policy_num][subject_id]:
                 continue
-            previous_betas = alg_update_func_args[policy_num][user_id][
+            previous_betas = alg_update_func_args[policy_num][subject_id][
                 alg_update_func_args_previous_betas_index
             ]
             if first_previous_betas is None:
@@ -873,21 +876,21 @@ def require_previous_betas_match_in_alg_update_args_each_update(
             else:
                 assert np.array_equal(
                     previous_betas, first_previous_betas
-                ), f"Previous betas do not match across users in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
+                ), f"Previous betas do not match across subjects in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
 
 
 def require_betas_match_in_action_prob_func_args_each_decision(
     action_prob_func_args, action_prob_func_args_beta_index
 ):
     logger.info(
-        "Checking that betas match across users for each decision time in the action prob args."
+        "Checking that betas match across subjects for each decision time in the action prob args."
     )
     for decision_time in action_prob_func_args:
         first_beta = None
-        for user_id in action_prob_func_args[decision_time]:
-            if not action_prob_func_args[decision_time][user_id]:
+        for subject_id in action_prob_func_args[decision_time]:
+            if not action_prob_func_args[decision_time][subject_id]:
                 continue
-            beta = action_prob_func_args[decision_time][user_id][
+            beta = action_prob_func_args[decision_time][subject_id][
                 action_prob_func_args_beta_index
             ]
             if first_beta is None:
@@ -895,11 +898,11 @@ def require_betas_match_in_action_prob_func_args_each_decision(
             else:
                 assert np.array_equal(
                     beta, first_beta
-                ), f"Betas do not match across users in the action prob args for decision_time {decision_time}. Please see the contract for details."
+                ), f"Betas do not match across subjects in the action prob args for decision_time {decision_time}. Please see the contract for details."
 
 
 def require_valid_action_prob_times_given_if_index_supplied(
-    study_df,
+    analysis_df,
     calendar_t_col_name,
     alg_update_func_args,
     alg_update_func_args_action_prob_times_index,
@@ -909,19 +912,19 @@
     if alg_update_func_args_action_prob_times_index < 0:
         return
 
-    min_time = study_df[calendar_t_col_name].min()
-    max_time = study_df[calendar_t_col_name].max()
-    for policy_idx, args_by_user in alg_update_func_args.items():
-        for user_id, args in args_by_user.items():
+    min_time = analysis_df[calendar_t_col_name].min()
+    max_time = analysis_df[calendar_t_col_name].max()
+    for policy_idx, args_by_subject in alg_update_func_args.items():
+        for subject_id, args in args_by_subject.items():
             if not args:
                 continue
             times = args[alg_update_func_args_action_prob_times_index]
             assert (
                 times[i] > times[i - 1] for i in range(1, len(times))
-            ), f"Non-strictly-increasing times were given for action probabilities in the algorithm update function args for user {user_id} and policy {policy_idx}. Please see the contract for details."
+            ), f"Non-strictly-increasing times were given for action probabilities in the algorithm update function args for subject {subject_id} and policy {policy_idx}. Please see the contract for details."
             assert (
                 times[0] >= min_time and times[-1] <= max_time
-            ), f"Times not present in the study were given for action probabilities in the algorithm update function args. The min and max times in the study dataframe are {min_time} and {max_time}, while user {user_id} has times {times} supplied for policy {policy_idx}. Please see the contract for details."
+            ), f"Times not present in the study were given for action probabilities in the algorithm update function args. The min and max times in the study dataframe are {min_time} and {max_time}, while subject {subject_id} has times {times} supplied for policy {policy_idx}. Please see the contract for details."
 
 
 def require_estimating_functions_sum_to_zero(
@@ -933,15 +936,15 @@ def require_estimating_functions_sum_to_zero(
     """
     This is a test that the correct loss/estimating functions have
     been given for both the algorithm updates and inference. If that is true, then the
-    loss/estimating functions when evaluated should sum to approximately zero across users. These
-    values have been stacked and averaged across users in mean_estimating_function_stack, which
+    loss/estimating functions when evaluated should sum to approximately zero across subjects. These
+    values have been stacked and averaged across subjects in mean_estimating_function_stack, which
     we simply compare to the zero vector. We can isolate components for each update and inference
     by considering the dimensions of the beta vectors and the theta vector.
 
     Inputs:
         mean_estimating_function_stack:
             The mean of the estimating function stack (a component for each algorithm update and
-            inference) across users. This should be a 1D array.
+            inference) across subjects. This should be a 1D array.
         beta_dim:
             The dimension of the beta vectors that parameterize the algorithm.
         theta_dim:
@@ -951,7 +954,7 @@
         None
     """
 
-    logger.info("Checking that estimating functions average to zero across users")
+    logger.info("Checking that estimating functions average to zero across subjects")
 
     # Have a looser hard failure cutoff before the typical interactive check
     try:
@@ -962,7 +965,7 @@
         )
     except AssertionError as e:
         logger.info(
-            "Estimating function stacks do not average to within loose tolerance of zero across users. Drilling in to specific updates and inference component."
+            "Estimating function stacks do not average to within loose tolerance of zero across subjects. Drilling in to specific updates and inference component."
         )
         # If this is not true there is an internal problem in the package.
         assert (mean_estimating_function_stack.size - theta_dim) % beta_dim == 0
@@ -987,11 +990,11 @@
         np.testing.assert_allclose(
             mean_estimating_function_stack,
             jnp.zeros(mean_estimating_function_stack.size),
-            atol=1e-5,
+            atol=5e-4,
         )
     except AssertionError as e:
         logger.info(
-            "Estimating function stacks do not average to within specified tolerance of zero across users. Drilling in to specific updates and inference component."
+            "Estimating function stacks do not average to within specified tolerance of zero across subjects. Drilling in to specific updates and inference component."
         )
         # If this is not true there is an internal problem in the package.
         assert (mean_estimating_function_stack.size - theta_dim) % beta_dim == 0
@@ -1021,15 +1024,15 @@
     """
     This is a test that the correct loss/estimating functions have
     been given for both the algorithm updates and inference. If that is true, then the
-    loss/estimating functions when evaluated should sum to approximately zero across users. These
-    values have been stacked and averaged across users in mean_estimating_function_stack, which
+    loss/estimating functions when evaluated should sum to approximately zero across subjects. These
+    values have been stacked and averaged across subjects in mean_estimating_function_stack, which
     we simply compare to the zero vector. We can isolate components for each update and inference
     by considering the dimensions of the beta vectors and the theta vector.
 
     Inputs:
         mean_estimating_function_stack:
             The mean of the estimating function stack (a component for each algorithm update and
-            inference) across users. This should be a 1D array.
+            inference) across subjects. This should be a 1D array.
         beta_dim:
             The dimension of the beta vectors that parameterize the algorithm.
         theta_dim:
@@ -1039,7 +1042,7 @@ def require_RL_estimating_functions_sum_to_zero(
         None
     """
 
-    logger.info("Checking that RL estimating functions average to zero across users")
+    logger.info("Checking that RL estimating functions average to zero across subjects")
 
     # Have a looser hard failure cutoff before the typical interactive check
     try:
@@ -1050,7 +1053,7 @@ def require_RL_estimating_functions_sum_to_zero(
         )
     except AssertionError as e:
         logger.info(
-            "RL estimating function stacks do not average to zero across users. Drilling in to specific updates and inference component."
+            "RL estimating function stacks do not average to zero across subjects. Drilling in to specific updates and inference component."
         )
         num_updates = (mean_estimating_function_stack.size) // beta_dim
         for i in range(num_updates):
@@ -1070,7 +1073,7 @@
         )
     except AssertionError as e:
         logger.info(
-            "RL estimating function stacks do not average to zero across users. Drilling in to specific updates and inference component."
+            "RL estimating function stacks do not average to zero across subjects. Drilling in to specific updates and inference component."
        )
         num_updates = (mean_estimating_function_stack.size) // beta_dim
         for i in range(num_updates):
@@ -1079,7 +1082,6 @@
                 i + 1,
                 mean_estimating_function_stack[i * beta_dim : (i + 1) * beta_dim],
             )
-        # TODO: Email instead of requiring user input for monitoring alg.
         confirm_input_check_result(
             f"\nEstimating functions do not average to within default tolerance of zero vector. Please decide if the following is a reasonable result, taking into account the above breakdown by update number and inference. If not, there are several possible reasons for failure mentioned in the contract. Results:\n{str(e)}\n\nContinue? (y/n)\n",
             suppress_interactive_data_checks,
@@ -1133,8 +1135,8 @@
 
 def require_threaded_algorithm_estimating_function_args_equivalent(
     algorithm_estimating_func,
-    update_func_args_by_by_user_id_by_policy_num,
-    threaded_update_func_args_by_policy_num_by_user_id,
+    update_func_args_by_by_subject_id_by_policy_num,
+    threaded_update_func_args_by_policy_num_by_subject_id,
     suppress_interactive_data_checks,
 ):
     """
@@ -1144,12 +1146,12 @@ def require_threaded_algorithm_estimating_function_args_equivalent(
     """
     for (
         policy_num,
-        update_func_args_by_user_id,
-    ) in update_func_args_by_by_user_id_by_policy_num.items():
+        update_func_args_by_subject_id,
+    ) in update_func_args_by_by_subject_id_by_policy_num.items():
         for (
-            user_id,
+            subject_id,
             unthreaded_args,
-        ) in update_func_args_by_user_id.items():
+        ) in update_func_args_by_subject_id.items():
             if not unthreaded_args:
                 continue
             np.testing.assert_allclose(
@@ -1157,9 +1159,9 @@
                 # Need to stop gradient here because we can't convert a traced value to np array
                 jax.lax.stop_gradient(
                     algorithm_estimating_func(
-                        *threaded_update_func_args_by_policy_num_by_user_id[user_id][
-                            policy_num
-                        ]
+                        *threaded_update_func_args_by_policy_num_by_subject_id[
+                            subject_id
+                        ][policy_num]
                     )
                 ),
                 atol=1e-7,
@@ -1169,8 +1171,8 @@
 
 def require_threaded_inference_estimating_function_args_equivalent(
     inference_estimating_func,
-    inference_func_args_by_user_id,
-    threaded_inference_func_args_by_user_id,
+    inference_func_args_by_subject_id,
+    threaded_inference_func_args_by_subject_id,
     suppress_interactive_data_checks,
 ):
     """
@@ -1178,7 +1180,7 @@ def require_threaded_inference_estimating_function_args_equivalent(
     when called with the original arguments and when called with the
     reconstructed action probabilities substituted in.
     """
-    for user_id, unthreaded_args in inference_func_args_by_user_id.items():
+    for subject_id, unthreaded_args in inference_func_args_by_subject_id.items():
         if not unthreaded_args:
             continue
         np.testing.assert_allclose(
@@ -1186,7 +1188,7 @@
             # Need to stop gradient here because we can't convert a traced value to np array
             jax.lax.stop_gradient(
                 inference_estimating_func(
-                    *threaded_inference_func_args_by_user_id[user_id]
+                    *threaded_inference_func_args_by_subject_id[subject_id]
                 )
             ),
             rtol=1e-2,
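
For orientation, the checks on both sides of this diff iterate over the same nested containers: algorithm update args keyed by policy number and then subject ID (user ID in 0.2.1), and action probability args keyed by decision time and then subject ID, with an empty tuple marking a subject with no data at that key, which the loops skip via "if not args". A minimal sketch of that shape, with illustrative values only (the beta dimension and the string IDs below are assumptions for the example, not package defaults):

    import numpy as np

    # Sketch of the nested argument containers the checks above iterate over.
    # Only the shape is taken from the checks; the values are made up.
    beta = np.zeros(2)  # require_beta_is_1D_array_* asserts beta.ndim == 1

    # action_prob_func_args[decision_time][subject_id] -> positional args tuple;
    # an empty tuple means the subject is out of the study at that time.
    action_prob_func_args = {
        1: {"subj_a": (beta,), "subj_b": (beta,)},
        2: {"subj_a": (beta,), "subj_b": ()},
    }

    # alg_update_func_args[policy_num][subject_id] -> positional args tuple;
    # the betas must match across subjects within each update.
    alg_update_func_args = {
        1: {"subj_a": (beta,), "subj_b": (beta,)},
    }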