lifejacket 0.2.1__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
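The diff below is dominated by a systematic renaming of the package's public vocabulary (study_df → analysis_df, in_study_col_name → active_col_name, user_id_col_name → subject_id_col_name, with check functions renamed to match, e.g. require_binary_active_indicators), plus a signature change: perform_alg_only_input_checks and confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied now also take alg_update_func_args_previous_betas_index. As a minimal sketch of the data layout these checks validate — the column names and toy values are illustrative assumptions, not taken from the package's documentation; only the check logic mirrors the diff — the analysis DataFrame holds one row per (subject, calendar time) pair with an active indicator:

import pandas as pd

# One row per (subject, calendar time) pair; "active" flags whether the
# subject was in the study at that decision time. Column names are hypothetical.
analysis_df = pd.DataFrame(
    {
        "subject_id": [1, 1, 2, 2],
        "calendar_t": [0, 1, 0, 1],
        "active": [1, 1, 1, 0],
        "action": [0, 1, 1, 0],
        "action_prob": [0.5, 0.6, 0.5, 0.4],
        "policy_num": [0, 1, 0, 1],
        "reward": [0.2, 0.0, 1.0, 0.3],
    }
)

# Mirrors require_all_subjects_have_all_times_in_analysis_df below: every
# subject must carry every calendar time, even for inactive decision times.
unique_calendar_times = set(analysis_df["calendar_t"].unique())
times_per_subject = analysis_df.groupby("subject_id")["calendar_t"].apply(set)
assert times_per_subject.apply(lambda s: s == unique_calendar_times).all()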
@@ -8,7 +8,7 @@ from jax import numpy as jnp
 import pandas as pd
 import plotext as plt
 
-from .constants import InverseStabilizationMethods, SmallSampleCorrections
+from .constants import SmallSampleCorrections
 from .helper_functions import (
     confirm_input_check_result,
 )
@@ -26,12 +26,12 @@ logging.basicConfig(
 
 # TODO: any checks needed here about alg update function type?
 def perform_first_wave_input_checks(
-    study_df,
-    in_study_col_name,
+    analysis_df,
+    active_col_name,
     action_col_name,
     policy_num_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     action_prob_col_name,
     reward_col_name,
     action_prob_func,
@@ -48,11 +48,11 @@ def perform_first_wave_input_checks(
     small_sample_correction,
 ):
     ### Validate algorithm loss/estimating function and args
-    require_alg_update_args_given_for_all_users_at_each_update(
-        study_df, user_id_col_name, alg_update_func_args
+    require_alg_update_args_given_for_all_subjects_at_each_update(
+        analysis_df, subject_id_col_name, alg_update_func_args
     )
-    require_no_policy_numbers_present_in_alg_update_args_but_not_study_df(
-        study_df, policy_num_col_name, alg_update_func_args
+    require_no_policy_numbers_present_in_alg_update_args_but_not_analysis_df(
+        analysis_df, policy_num_col_name, alg_update_func_args
     )
     require_beta_is_1D_array_in_alg_update_args(
         alg_update_func_args, alg_update_func_args_beta_index
@@ -60,11 +60,13 @@ def perform_first_wave_input_checks(
     require_previous_betas_is_2D_array_in_alg_update_args(
         alg_update_func_args, alg_update_func_args_previous_betas_index
     )
-    require_all_policy_numbers_in_study_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
-        study_df, in_study_col_name, policy_num_col_name, alg_update_func_args
+    require_all_policy_numbers_in_analysis_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
+        analysis_df, active_col_name, policy_num_col_name, alg_update_func_args
     )
     confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
-        alg_update_func_args_action_prob_index, suppress_interactive_data_checks
+        alg_update_func_args_action_prob_index,
+        alg_update_func_args_previous_betas_index,
+        suppress_interactive_data_checks,
     )
     require_action_prob_times_given_if_index_supplied(
         alg_update_func_args_action_prob_index,
@@ -80,17 +82,17 @@ def perform_first_wave_input_checks(
     require_previous_betas_match_in_alg_update_args_each_update(
         alg_update_func_args, alg_update_func_args_previous_betas_index
     )
-    require_action_prob_args_in_alg_update_func_correspond_to_study_df(
-        study_df,
+    require_action_prob_args_in_alg_update_func_correspond_to_analysis_df(
+        analysis_df,
         action_prob_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         alg_update_func_args,
         alg_update_func_args_action_prob_index,
         alg_update_func_args_action_prob_times_index,
     )
     require_valid_action_prob_times_given_if_index_supplied(
-        study_df,
+        analysis_df,
         calendar_t_col_name,
         alg_update_func_args,
         alg_update_func_args_action_prob_times_index,
@@ -101,28 +103,28 @@ def perform_first_wave_input_checks(
     )
 
     ### Validate action prob function and args
-    require_action_prob_func_args_given_for_all_users_at_each_decision(
-        study_df, user_id_col_name, action_prob_func_args
+    require_action_prob_func_args_given_for_all_subjects_at_each_decision(
+        analysis_df, subject_id_col_name, action_prob_func_args
     )
     require_action_prob_func_args_given_for_all_decision_times(
-        study_df, calendar_t_col_name, action_prob_func_args
+        analysis_df, calendar_t_col_name, action_prob_func_args
     )
-    require_action_probabilities_in_study_df_can_be_reconstructed(
-        study_df,
+    require_action_probabilities_in_analysis_df_can_be_reconstructed(
+        analysis_df,
         action_prob_col_name,
         calendar_t_col_name,
-        user_id_col_name,
-        in_study_col_name,
+        subject_id_col_name,
+        active_col_name,
         action_prob_func_args,
         action_prob_func,
     )
 
     require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times(
-        study_df,
+        analysis_df,
         calendar_t_col_name,
         action_prob_func_args,
-        in_study_col_name,
-        user_id_col_name,
+        active_col_name,
+        subject_id_col_name,
     )
     require_beta_is_1D_array_in_action_prob_args(
         action_prob_func_args, action_prob_func_args_beta_index
@@ -131,13 +133,13 @@ def perform_first_wave_input_checks(
         action_prob_func_args, action_prob_func_args_beta_index
     )
 
-    ### Validate study_df
-    verify_study_df_summary_satisfactory(
-        study_df,
-        user_id_col_name,
+    ### Validate analysis_df
+    verify_analysis_df_summary_satisfactory(
+        analysis_df,
+        subject_id_col_name,
         policy_num_col_name,
         calendar_t_col_name,
-        in_study_col_name,
+        active_col_name,
         action_prob_col_name,
         reward_col_name,
         beta_dim,
@@ -145,46 +147,46 @@ def perform_first_wave_input_checks(
         suppress_interactive_data_checks,
     )
 
-    require_all_users_have_all_times_in_study_df(
-        study_df, calendar_t_col_name, user_id_col_name
+    require_all_subjects_have_all_times_in_analysis_df(
+        analysis_df, calendar_t_col_name, subject_id_col_name
     )
-    require_all_named_columns_present_in_study_df(
-        study_df,
-        in_study_col_name,
+    require_all_named_columns_present_in_analysis_df(
+        analysis_df,
+        active_col_name,
         action_col_name,
         policy_num_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         action_prob_col_name,
     )
-    require_all_named_columns_not_object_type_in_study_df(
-        study_df,
-        in_study_col_name,
+    require_all_named_columns_not_object_type_in_analysis_df(
+        analysis_df,
+        active_col_name,
         action_col_name,
         policy_num_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         action_prob_col_name,
     )
-    require_binary_actions(study_df, in_study_col_name, action_col_name)
-    require_binary_in_study_indicators(study_df, in_study_col_name)
+    require_binary_actions(analysis_df, active_col_name, action_col_name)
+    require_binary_active_indicators(analysis_df, active_col_name)
     require_consecutive_integer_policy_numbers(
-        study_df, in_study_col_name, policy_num_col_name
+        analysis_df, active_col_name, policy_num_col_name
     )
-    require_consecutive_integer_calendar_times(study_df, calendar_t_col_name)
-    require_hashable_user_ids(study_df, in_study_col_name, user_id_col_name)
-    require_action_probabilities_in_range_0_to_1(study_df, action_prob_col_name)
+    require_consecutive_integer_calendar_times(analysis_df, calendar_t_col_name)
+    require_hashable_subject_ids(analysis_df, active_col_name, subject_id_col_name)
+    require_action_probabilities_in_range_0_to_1(analysis_df, action_prob_col_name)
 
     ### Validate theta estimation
     require_theta_is_1D_array(theta_est)
 
 
 def perform_alg_only_input_checks(
-    study_df,
-    in_study_col_name,
+    analysis_df,
+    active_col_name,
     policy_num_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     action_prob_col_name,
     action_prob_func,
     action_prob_func_args,
@@ -193,20 +195,23 @@ def perform_alg_only_input_checks(
     alg_update_func_args_beta_index,
     alg_update_func_args_action_prob_index,
     alg_update_func_args_action_prob_times_index,
+    alg_update_func_args_previous_betas_index,
     suppress_interactive_data_checks,
 ):
     ### Validate algorithm loss/estimating function and args
-    require_alg_update_args_given_for_all_users_at_each_update(
-        study_df, user_id_col_name, alg_update_func_args
+    require_alg_update_args_given_for_all_subjects_at_each_update(
+        analysis_df, subject_id_col_name, alg_update_func_args
     )
     require_beta_is_1D_array_in_alg_update_args(
         alg_update_func_args, alg_update_func_args_beta_index
     )
-    require_all_policy_numbers_in_study_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
-        study_df, in_study_col_name, policy_num_col_name, alg_update_func_args
+    require_all_policy_numbers_in_analysis_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
+        analysis_df, active_col_name, policy_num_col_name, alg_update_func_args
     )
     confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
-        alg_update_func_args_action_prob_index, suppress_interactive_data_checks
+        alg_update_func_args_action_prob_index,
+        alg_update_func_args_previous_betas_index,
+        suppress_interactive_data_checks,
     )
     require_action_prob_times_given_if_index_supplied(
         alg_update_func_args_action_prob_index,
@@ -219,45 +224,45 @@ def perform_alg_only_input_checks(
     require_betas_match_in_alg_update_args_each_update(
         alg_update_func_args, alg_update_func_args_beta_index
     )
-    require_action_prob_args_in_alg_update_func_correspond_to_study_df(
-        study_df,
+    require_action_prob_args_in_alg_update_func_correspond_to_analysis_df(
+        analysis_df,
         action_prob_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         alg_update_func_args,
         alg_update_func_args_action_prob_index,
         alg_update_func_args_action_prob_times_index,
     )
     require_valid_action_prob_times_given_if_index_supplied(
-        study_df,
+        analysis_df,
         calendar_t_col_name,
         alg_update_func_args,
         alg_update_func_args_action_prob_times_index,
     )
 
     ### Validate action prob function and args
-    require_action_prob_func_args_given_for_all_users_at_each_decision(
-        study_df, user_id_col_name, action_prob_func_args
+    require_action_prob_func_args_given_for_all_subjects_at_each_decision(
+        analysis_df, subject_id_col_name, action_prob_func_args
     )
     require_action_prob_func_args_given_for_all_decision_times(
-        study_df, calendar_t_col_name, action_prob_func_args
+        analysis_df, calendar_t_col_name, action_prob_func_args
     )
-    require_action_probabilities_in_study_df_can_be_reconstructed(
-        study_df,
+    require_action_probabilities_in_analysis_df_can_be_reconstructed(
+        analysis_df,
         action_prob_col_name,
         calendar_t_col_name,
-        user_id_col_name,
-        in_study_col_name,
+        subject_id_col_name,
+        active_col_name,
         action_prob_func_args,
         action_prob_func=action_prob_func,
     )
 
     require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times(
-        study_df,
+        analysis_df,
         calendar_t_col_name,
         action_prob_func_args,
-        in_study_col_name,
-        user_id_col_name,
+        active_col_name,
+        subject_id_col_name,
     )
     require_beta_is_1D_array_in_action_prob_args(
         action_prob_func_args, action_prob_func_args_beta_index
@@ -268,94 +273,96 @@ def perform_alg_only_input_checks(
 
 
 # TODO: Give a hard-to-use option to loosen this check somehow
-def require_action_probabilities_in_study_df_can_be_reconstructed(
-    study_df,
+def require_action_probabilities_in_analysis_df_can_be_reconstructed(
+    analysis_df,
     action_prob_col_name,
     calendar_t_col_name,
-    user_id_col_name,
-    in_study_col_name,
+    subject_id_col_name,
+    active_col_name,
     action_prob_func_args,
     action_prob_func,
 ):
     """
-    Check that the action probabilities in the study dataframe can be reconstructed from the supplied
+    Check that the action probabilities in the analysis DataFrame can be reconstructed from the supplied
     action probability function and its arguments.
 
     NOTE THAT THIS IS A HARD FAILURE IF THE RECONSTRUCTION DOESN'T PASS.
     """
     logger.info("Reconstructing action probabilities from function and arguments.")
 
-    in_study_df = study_df[study_df[in_study_col_name] == 1]
-    reconstructed_action_probs = in_study_df.apply(
+    active_df = analysis_df[analysis_df[active_col_name] == 1]
+    reconstructed_action_probs = active_df.apply(
         lambda row: action_prob_func(
-            *action_prob_func_args[row[calendar_t_col_name]][row[user_id_col_name]]
+            *action_prob_func_args[row[calendar_t_col_name]][row[subject_id_col_name]]
         ),
         axis=1,
     )
 
     np.testing.assert_allclose(
-        in_study_df[action_prob_col_name].to_numpy(dtype="float64"),
+        active_df[action_prob_col_name].to_numpy(dtype="float64"),
         reconstructed_action_probs.to_numpy(dtype="float64"),
         atol=1e-6,
    )
 
 
-def require_all_users_have_all_times_in_study_df(
-    study_df, calendar_t_col_name, user_id_col_name
+def require_all_subjects_have_all_times_in_analysis_df(
+    analysis_df, calendar_t_col_name, subject_id_col_name
 ):
-    logger.info("Checking that all users have the same set of unique calendar times.")
+    logger.info(
+        "Checking that all subjects have the same set of unique calendar times."
+    )
     # Get the unique calendar times
-    unique_calendar_times = set(study_df[calendar_t_col_name].unique())
+    unique_calendar_times = set(analysis_df[calendar_t_col_name].unique())
 
-    # Group by user ID and aggregate the unique calendar times for each user
-    user_calendar_times = study_df.groupby(user_id_col_name)[calendar_t_col_name].apply(
-        set
-    )
+    # Group by subject ID and aggregate the unique calendar times for each subject
+    subject_calendar_times = analysis_df.groupby(subject_id_col_name)[
+        calendar_t_col_name
+    ].apply(set)
 
-    # Check if all users have the same set of unique calendar times
-    if not user_calendar_times.apply(lambda x: x == unique_calendar_times).all():
+    # Check if all subjects have the same set of unique calendar times
+    if not subject_calendar_times.apply(lambda x: x == unique_calendar_times).all():
         raise AssertionError(
-            "Not all users have all calendar times in the study dataframe. Please see the contract for details."
+            "Not all subjects have all calendar times in the analysis DataFrame. Please see the contract for details."
         )
 
 
-def require_alg_update_args_given_for_all_users_at_each_update(
-    study_df, user_id_col_name, alg_update_func_args
+def require_alg_update_args_given_for_all_subjects_at_each_update(
+    analysis_df, subject_id_col_name, alg_update_func_args
 ):
     logger.info(
-        "Checking that algorithm update function args are given for all users at each update."
+        "Checking that algorithm update function args are given for all subjects at each update."
     )
-    all_user_ids = set(study_df[user_id_col_name].unique())
+    all_subject_ids = set(analysis_df[subject_id_col_name].unique())
     for policy_num in alg_update_func_args:
         assert (
-            set(alg_update_func_args[policy_num].keys()) == all_user_ids
-        ), f"Not all users present in algorithm update function args for policy number {policy_num}. Please see the contract for details."
+            set(alg_update_func_args[policy_num].keys()) == all_subject_ids
+        ), f"Not all subjects present in algorithm update function args for policy number {policy_num}. Please see the contract for details."
 
 
-def require_action_prob_args_in_alg_update_func_correspond_to_study_df(
-    study_df,
+def require_action_prob_args_in_alg_update_func_correspond_to_analysis_df(
+    analysis_df,
     action_prob_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     alg_update_func_args,
     alg_update_func_args_action_prob_index,
     alg_update_func_args_action_prob_times_index,
 ):
     logger.info(
         "Checking that the action probabilities supplied in the algorithm update function args, if"
-        " any, correspond to those in the study dataframe for the corresponding users and decision"
+        " any, correspond to those in the analysis DataFrame for the corresponding subjects and decision"
         " times."
     )
     if alg_update_func_args_action_prob_index < 0:
         return
 
     # Precompute a lookup dictionary for faster access
-    study_df_lookup = study_df.set_index([calendar_t_col_name, user_id_col_name])[
-        action_prob_col_name
-    ].to_dict()
+    analysis_df_lookup = analysis_df.set_index(
+        [calendar_t_col_name, subject_id_col_name]
+    )[action_prob_col_name].to_dict()
 
-    for policy_num, user_args in alg_update_func_args.items():
-        for user_id, args in user_args.items():
+    for policy_num, subject_args in alg_update_func_args.items():
+        for subject_id, args in subject_args.items():
             if not args:
                 continue
             arg_action_probs = args[alg_update_func_args_action_prob_index]
@@ -364,43 +371,43 @@ def require_action_prob_args_in_alg_update_func_correspond_to_study_df(
             ].flatten()
 
             # Use the precomputed lookup dictionary
-            study_df_action_probs = [
-                study_df_lookup[(decision_time.item(), user_id)]
+            analysis_df_action_probs = [
+                analysis_df_lookup[(decision_time.item(), subject_id)]
                 for decision_time in action_prob_times
             ]
 
             assert np.allclose(
                 arg_action_probs.flatten(),
-                study_df_action_probs,
+                analysis_df_action_probs,
             ), (
-                f"There is a mismatch for user {user_id} between the action probabilities supplied"
+                f"There is a mismatch for subject {subject_id} between the action probabilities supplied"
                 f" in the args to the algorithm update function at policy {policy_num} and those in"
-                " the study dataframe for the supplied times. Please see the contract for details."
+                " the analysis DataFrame for the supplied times. Please see the contract for details."
             )
 
 
-def require_action_prob_func_args_given_for_all_users_at_each_decision(
-    study_df,
-    user_id_col_name,
+def require_action_prob_func_args_given_for_all_subjects_at_each_decision(
+    analysis_df,
+    subject_id_col_name,
     action_prob_func_args,
 ):
     logger.info(
-        "Checking that action prob function args are given for all users at each decision time."
+        "Checking that action prob function args are given for all subjects at each decision time."
     )
-    all_user_ids = set(study_df[user_id_col_name].unique())
+    all_subject_ids = set(analysis_df[subject_id_col_name].unique())
     for decision_time in action_prob_func_args:
         assert (
-            set(action_prob_func_args[decision_time].keys()) == all_user_ids
-        ), f"Not all users present in algorithm update function args for decision time {decision_time}. Please see the contract for details."
+            set(action_prob_func_args[decision_time].keys()) == all_subject_ids
+        ), f"Not all subjects present in algorithm update function args for decision time {decision_time}. Please see the contract for details."
 
 
 def require_action_prob_func_args_given_for_all_decision_times(
-    study_df, calendar_t_col_name, action_prob_func_args
+    analysis_df, calendar_t_col_name, action_prob_func_args
 ):
     logger.info(
         "Checking that action prob function args are given for all decision times."
     )
-    all_times = set(study_df[calendar_t_col_name].unique())
+    all_times = set(analysis_df[calendar_t_col_name].unique())
 
     assert (
         set(action_prob_func_args.keys()) == all_times
@@ -408,106 +415,112 @@ def require_action_prob_func_args_given_for_all_decision_times(
 
 
 def require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times(
-    study_df: pd.DataFrame,
+    analysis_df: pd.DataFrame,
     calendar_t_col_name: str,
     action_prob_func_args: dict[str, dict[str, tuple[Any, ...]]],
-    in_study_col_name,
-    user_id_col_name,
+    active_col_name,
+    subject_id_col_name,
 ):
     logger.info(
-        "Checking that action probability function args are blank for exactly the times each user"
-        " is not in the study according to the study dataframe."
+        "Checking that action probability function args are blank for exactly the times each subject"
+        " is not in the study according to the analysis DataFrame."
     )
-    out_of_study_df = study_df[study_df[in_study_col_name] == 0]
-    out_of_study_times_by_user_according_to_study_df = (
-        out_of_study_df.groupby(user_id_col_name)[calendar_t_col_name]
+    inactive_df = analysis_df[analysis_df[active_col_name] == 0]
+    inactive_times_by_subject_according_to_analysis_df = (
+        inactive_df.groupby(subject_id_col_name)[calendar_t_col_name]
         .apply(set)
         .to_dict()
     )
 
-    out_of_study_times_by_user_according_to_action_prob_func_args = (
+    inactive_times_by_subject_according_to_action_prob_func_args = (
         collections.defaultdict(set)
     )
-    for decision_time, action_prob_args_by_user in action_prob_func_args.items():
-        for user_id, action_prob_args in action_prob_args_by_user.items():
+    for decision_time, action_prob_args_by_subject in action_prob_func_args.items():
+        for subject_id, action_prob_args in action_prob_args_by_subject.items():
             if not action_prob_args:
-                out_of_study_times_by_user_according_to_action_prob_func_args[
-                    user_id
+                inactive_times_by_subject_according_to_action_prob_func_args[
+                    subject_id
                 ].add(decision_time)
 
     assert (
-        out_of_study_times_by_user_according_to_study_df
-        == out_of_study_times_by_user_according_to_action_prob_func_args
+        inactive_times_by_subject_according_to_analysis_df
+        == inactive_times_by_subject_according_to_action_prob_func_args
     ), (
-        "Out-of-study decision times according to the study dataframe do not match up with the"
-        " times for which action probability arguments are blank for all users. Please see the"
+        "Inactive decision times according to the analysis DataFrame do not match up with the"
+        " times for which action probability arguments are blank for all subjects. Please see the"
         " contract for details."
     )
 
 
-def require_all_named_columns_present_in_study_df(
-    study_df,
-    in_study_col_name,
+def require_all_named_columns_present_in_analysis_df(
+    analysis_df,
+    active_col_name,
     action_col_name,
     policy_num_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     action_prob_col_name,
 ):
-    logger.info("Checking that all named columns are present in the study dataframe.")
+    logger.info(
+        "Checking that all named columns are present in the analysis DataFrame."
+    )
+    assert (
+        active_col_name in analysis_df.columns
+    ), f"{active_col_name} not in analysis DataFrame."
+    assert (
+        action_col_name in analysis_df.columns
+    ), f"{action_col_name} not in analysis DataFrame."
     assert (
-        in_study_col_name in study_df.columns
-    ), f"{in_study_col_name} not in study df."
-    assert action_col_name in study_df.columns, f"{action_col_name} not in study df."
+        policy_num_col_name in analysis_df.columns
+    ), f"{policy_num_col_name} not in analysis DataFrame."
     assert (
-        policy_num_col_name in study_df.columns
-    ), f"{policy_num_col_name} not in study df."
+        calendar_t_col_name in analysis_df.columns
+    ), f"{calendar_t_col_name} not in analysis DataFrame."
     assert (
-        calendar_t_col_name in study_df.columns
-    ), f"{calendar_t_col_name} not in study df."
-    assert user_id_col_name in study_df.columns, f"{user_id_col_name} not in study df."
+        subject_id_col_name in analysis_df.columns
+    ), f"{subject_id_col_name} not in analysis DataFrame."
     assert (
-        action_prob_col_name in study_df.columns
-    ), f"{action_prob_col_name} not in study df."
+        action_prob_col_name in analysis_df.columns
+    ), f"{action_prob_col_name} not in analysis DataFrame."
 
 
-def require_all_named_columns_not_object_type_in_study_df(
-    study_df,
-    in_study_col_name,
+def require_all_named_columns_not_object_type_in_analysis_df(
+    analysis_df,
+    active_col_name,
     action_col_name,
     policy_num_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     action_prob_col_name,
 ):
     logger.info("Checking that all named columns are not type object.")
     for colname in (
-        in_study_col_name,
+        active_col_name,
         action_col_name,
         policy_num_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         action_prob_col_name,
     ):
         assert (
-            study_df[colname].dtype != "object"
-        ), f"At least {colname} is of object type in study df."
+            analysis_df[colname].dtype != "object"
+        ), f"At least {colname} is of object type in analysis DataFrame."
 
 
-def require_binary_actions(study_df, in_study_col_name, action_col_name):
+def require_binary_actions(analysis_df, active_col_name, action_col_name):
     logger.info("Checking that actions are binary.")
     assert (
-        study_df[study_df[in_study_col_name] == 1][action_col_name]
+        analysis_df[analysis_df[active_col_name] == 1][action_col_name]
         .astype("int64")
         .isin([0, 1])
         .all()
     ), "Actions are not binary."
 
 
-def require_binary_in_study_indicators(study_df, in_study_col_name):
-    logger.info("Checking that in-study indicators are binary.")
+def require_binary_active_indicators(analysis_df, active_col_name):
+    logger.info("Checking that active indicators are binary.")
     assert (
-        study_df[study_df[in_study_col_name] == 1][in_study_col_name]
+        analysis_df[analysis_df[active_col_name] == 1][active_col_name]
        .astype("int64")
        .isin([0, 1])
        .all()
@@ -515,7 +528,7 @@ def require_binary_in_study_indicators(study_df, in_study_col_name):
 
 
 def require_consecutive_integer_policy_numbers(
-    study_df, in_study_col_name, policy_num_col_name
+    analysis_df, active_col_name, policy_num_col_name
 ):
     # TODO: This is a somewhat rough check of this, could also check nondecreasing temporally
 
@@ -523,8 +536,8 @@ def require_consecutive_integer_policy_numbers(
         "Checking that in-study, non-fallback policy numbers are consecutive integers."
     )
 
-    in_study_df = study_df[study_df[in_study_col_name] == 1]
-    nonnegative_policy_df = in_study_df[in_study_df[policy_num_col_name] >= 0]
+    active_df = analysis_df[analysis_df[active_col_name] == 1]
+    nonnegative_policy_df = active_df[active_df[policy_num_col_name] >= 0]
     # Ideally we actually have integers, but for legacy reasons we will support
     # floats as well.
     if nonnegative_policy_df[policy_num_col_name].dtype == "float64":
@@ -540,78 +553,83 @@ def require_consecutive_integer_policy_numbers(
     ), "Policy numbers are not consecutive integers."
 
 
-def require_consecutive_integer_calendar_times(study_df, calendar_t_col_name):
+def require_consecutive_integer_calendar_times(analysis_df, calendar_t_col_name):
     # This is a somewhat rough check of this, more like checking there are no
-    # gaps in the integers covered. But we have other checks that all users
+    # gaps in the integers covered. But we have other checks that all subjects
     # have same times, etc.
-    # Note these times should be well-formed even when the user is not in the study.
+    # Note these times should be well-formed even when the subject is not in the study.
     logger.info("Checking that calendar times are consecutive integers.")
     assert np.array_equal(
-        study_df[calendar_t_col_name].unique(),
+        analysis_df[calendar_t_col_name].unique(),
         range(
-            study_df[calendar_t_col_name].min(), study_df[calendar_t_col_name].max() + 1
+            analysis_df[calendar_t_col_name].min(),
+            analysis_df[calendar_t_col_name].max() + 1,
        ),
    ), "Calendar times are not consecutive integers."
 
 
-def require_hashable_user_ids(study_df, in_study_col_name, user_id_col_name):
-    logger.info("Checking that user IDs are hashable.")
+def require_hashable_subject_ids(analysis_df, active_col_name, subject_id_col_name):
+    logger.info("Checking that subject IDs are hashable.")
     isinstance(
-        study_df[study_df[in_study_col_name] == 1][user_id_col_name][0],
+        analysis_df[analysis_df[active_col_name] == 1][subject_id_col_name][0],
         collections.abc.Hashable,
     )
 
 
-def require_action_probabilities_in_range_0_to_1(study_df, action_prob_col_name):
+def require_action_probabilities_in_range_0_to_1(analysis_df, action_prob_col_name):
     logger.info("Checking that action probabilities are in the interval (0, 1).")
-    study_df[action_prob_col_name].between(0, 1, inclusive="neither").all()
+    analysis_df[action_prob_col_name].between(0, 1, inclusive="neither").all()
 
 
-def require_no_policy_numbers_present_in_alg_update_args_but_not_study_df(
-    study_df, policy_num_col_name, alg_update_func_args
+def require_no_policy_numbers_present_in_alg_update_args_but_not_analysis_df(
+    analysis_df, policy_num_col_name, alg_update_func_args
 ):
     logger.info(
-        "Checking that policy numbers in algorithm update function args are present in the study dataframe."
+        "Checking that policy numbers in algorithm update function args are present in the analysis DataFrame."
     )
     alg_update_policy_nums = sorted(alg_update_func_args.keys())
-    study_df_policy_nums = sorted(study_df[policy_num_col_name].unique())
-    assert set(alg_update_policy_nums).issubset(set(study_df_policy_nums)), (
-        f"There are policy numbers present in algorithm update function args but not in the study dataframe. "
+    analysis_df_policy_nums = sorted(analysis_df[policy_num_col_name].unique())
+    assert set(alg_update_policy_nums).issubset(set(analysis_df_policy_nums)), (
+        f"There are policy numbers present in algorithm update function args but not in the analysis DataFrame. "
         f"\nalg_update_func_args policy numbers: {alg_update_policy_nums}"
-        f"\nstudy_df policy numbers: {study_df_policy_nums}.\nPlease see the contract for details."
+        f"\nanalysis_df policy numbers: {analysis_df_policy_nums}.\nPlease see the contract for details."
     )
 
 
-def require_all_policy_numbers_in_study_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
-    study_df, in_study_col_name, policy_num_col_name, alg_update_func_args
+def require_all_policy_numbers_in_analysis_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
+    analysis_df, active_col_name, policy_num_col_name, alg_update_func_args
 ):
     logger.info(
-        "Checking that all policy numbers in the study dataframe are present in the algorithm update function args."
+        "Checking that all policy numbers in the analysis DataFrame are present in the algorithm update function args."
     )
-    in_study_df = study_df[study_df[in_study_col_name] == 1]
+    active_df = analysis_df[analysis_df[active_col_name] == 1]
     # Get the number of the initial policy. 0 is recommended but not required.
-    min_nonnegative_policy_number = in_study_df[in_study_df[policy_num_col_name] >= 0][
+    min_nonnegative_policy_number = active_df[active_df[policy_num_col_name] >= 0][
         policy_num_col_name
     ].min()
     assert set(
-        in_study_df[in_study_df[policy_num_col_name] > min_nonnegative_policy_number][
+        active_df[active_df[policy_num_col_name] > min_nonnegative_policy_number][
             policy_num_col_name
         ].unique()
     ).issubset(
         alg_update_func_args.keys()
-    ), f"There are non-fallback, non-initial policy numbers in the study dataframe that are not in the update function args: {set(in_study_df[in_study_df[policy_num_col_name] > 0][policy_num_col_name].unique()) - set(alg_update_func_args.keys())}. Please see the contract for details."
+    ), f"There are non-fallback, non-initial policy numbers in the analysis DataFrame that are not in the update function args: {set(active_df[active_df[policy_num_col_name] > 0][policy_num_col_name].unique()) - set(alg_update_func_args.keys())}. Please see the contract for details."
 
 
 def confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
     alg_update_func_args_action_prob_index,
+    alg_update_func_args_previous_betas_index,
     suppress_interactive_data_checks,
 ):
     logger.info(
         "Confirming that action probabilities are not in algorithm update function args IF their index is not specified"
     )
-    if alg_update_func_args_action_prob_index < 0:
+    if (
+        alg_update_func_args_action_prob_index < 0
+        and alg_update_func_args_previous_betas_index < 0
+    ):
         confirm_input_check_result(
-            "\nYou specified that the algorithm update function supplied does not have action probabilities as one of its arguments. Please verify this is correct.\n\nContinue? (y/n)\n",
+            "\nYou specified that the algorithm update function supplied does not have action probabilities or previous betas in its arguments. Please verify this is correct.\n\nContinue? (y/n)\n",
             suppress_interactive_data_checks,
        )
 
@@ -630,20 +648,6 @@ def confirm_no_small_sample_correction_desired_if_not_requested(
         )
 
 
-def confirm_no_adaptive_bread_inverse_stabilization_method_desired_if_not_requested(
-    adaptive_bread_inverse_stabilization_method,
-    suppress_interactive_data_checks,
-):
-    logger.info(
-        "Confirming that no adaptive bread inverse stabilization method is desired if it's not requested."
-    )
-    if adaptive_bread_inverse_stabilization_method == InverseStabilizationMethods.NONE:
-        confirm_input_check_result(
-            "\nYou specified that you would not like to perform any inverse stabilization while forming the adaptive variance. This is not usually recommended. Please verify that it is correct or select one of the available options.\n\nContinue? (y/n)\n",
-            suppress_interactive_data_checks,
-        )
-
-
 def require_action_prob_times_given_if_index_supplied(
     alg_update_func_args_action_prob_index,
     alg_update_func_args_action_prob_times_index,
@@ -672,28 +676,29 @@ def require_beta_is_1D_array_in_alg_update_args(
     alg_update_func_args, alg_update_func_args_beta_index
 ):
     for policy_num in alg_update_func_args:
-        for user_id in alg_update_func_args[policy_num]:
-            if not alg_update_func_args[policy_num][user_id]:
+        for subject_id in alg_update_func_args[policy_num]:
+            if not alg_update_func_args[policy_num][subject_id]:
                 continue
             assert (
-                alg_update_func_args[policy_num][user_id][
+                alg_update_func_args[policy_num][subject_id][
                     alg_update_func_args_beta_index
                 ].ndim
                 == 1
             ), "Beta is not a 1D array in the algorithm update function args."
 
+
 def require_previous_betas_is_2D_array_in_alg_update_args(
     alg_update_func_args, alg_update_func_args_previous_betas_index
 ):
     if alg_update_func_args_previous_betas_index < 0:
         return
-
+
     for policy_num in alg_update_func_args:
-        for user_id in alg_update_func_args[policy_num]:
-            if not alg_update_func_args[policy_num][user_id]:
+        for subject_id in alg_update_func_args[policy_num]:
+            if not alg_update_func_args[policy_num][subject_id]:
                 continue
             assert (
-                alg_update_func_args[policy_num][user_id][
+                alg_update_func_args[policy_num][subject_id][
                     alg_update_func_args_previous_betas_index
                 ].ndim
                 == 2
@@ -704,11 +709,11 @@ def require_beta_is_1D_array_in_action_prob_args(
     action_prob_func_args, action_prob_func_args_beta_index
 ):
     for decision_time in action_prob_func_args:
-        for user_id in action_prob_func_args[decision_time]:
-            if not action_prob_func_args[decision_time][user_id]:
+        for subject_id in action_prob_func_args[decision_time]:
+            if not action_prob_func_args[decision_time][subject_id]:
                 continue
             assert (
-                action_prob_func_args[decision_time][user_id][
+                action_prob_func_args[decision_time][subject_id][
                     action_prob_func_args_beta_index
                 ].ndim
                 == 1
@@ -719,12 +724,12 @@ def require_theta_is_1D_array(theta_est):
     assert theta_est.ndim == 1, "Theta is not a 1D array."
 
 
-def verify_study_df_summary_satisfactory(
-    study_df,
-    user_id_col_name,
+def verify_analysis_df_summary_satisfactory(
+    analysis_df,
+    subject_id_col_name,
     policy_num_col_name,
     calendar_t_col_name,
-    in_study_col_name,
+    active_col_name,
     action_prob_col_name,
     reward_col_name,
     beta_dim,
@@ -732,43 +737,41 @@ def verify_study_df_summary_satisfactory(
     suppress_interactive_data_checks,
 ):
 
-    in_study_df = study_df[study_df[in_study_col_name] == 1]
-    num_users = in_study_df[user_id_col_name].nunique()
-    num_non_initial_or_fallback_policies = in_study_df[
-        in_study_df[policy_num_col_name] > 0
+    active_df = analysis_df[analysis_df[active_col_name] == 1]
+    num_subjects = active_df[subject_id_col_name].nunique()
+    num_non_initial_or_fallback_policies = active_df[
+        active_df[policy_num_col_name] > 0
     ][policy_num_col_name].nunique()
     num_decision_times_with_fallback_policies = len(
-        in_study_df[in_study_df[policy_num_col_name] < 0]
+        active_df[active_df[policy_num_col_name] < 0]
     )
-    num_decision_times = in_study_df[calendar_t_col_name].nunique()
-    avg_decisions_per_user = len(in_study_df) / num_users
+    num_decision_times = active_df[calendar_t_col_name].nunique()
+    avg_decisions_per_subject = len(active_df) / num_subjects
     num_decision_times_with_multiple_policies = (
-        in_study_df[in_study_df[policy_num_col_name] >= 0]
+        active_df[active_df[policy_num_col_name] >= 0]
         .groupby(calendar_t_col_name)[policy_num_col_name]
         .nunique()
         > 1
     ).sum()
-    min_action_prob = in_study_df[action_prob_col_name].min()
-    max_action_prob = in_study_df[action_prob_col_name].max()
-    min_non_fallback_policy_num = in_study_df[in_study_df[policy_num_col_name] >= 0][
+    min_action_prob = active_df[action_prob_col_name].min()
+    max_action_prob = active_df[action_prob_col_name].max()
+    min_non_fallback_policy_num = active_df[active_df[policy_num_col_name] >= 0][
         policy_num_col_name
     ].min()
     num_data_points_before_first_update = len(
-        in_study_df[in_study_df[policy_num_col_name] == min_non_fallback_policy_num]
+        active_df[active_df[policy_num_col_name] == min_non_fallback_policy_num]
     )
 
     median_action_probabilities = (
-        in_study_df.groupby(calendar_t_col_name)[action_prob_col_name]
-        .median()
-        .to_numpy()
+        active_df.groupby(calendar_t_col_name)[action_prob_col_name].median().to_numpy()
     )
-    quartiles = in_study_df.groupby(calendar_t_col_name)[action_prob_col_name].quantile(
+    quartiles = active_df.groupby(calendar_t_col_name)[action_prob_col_name].quantile(
         [0.25, 0.75]
     )
     q25_action_probabilities = quartiles.xs(0.25, level=1).to_numpy()
     q75_action_probabilities = quartiles.xs(0.75, level=1).to_numpy()
 
-    avg_rewards = in_study_df.groupby(calendar_t_col_name)[reward_col_name].mean()
+    avg_rewards = active_df.groupby(calendar_t_col_name)[reward_col_name].mean()
 
     # Plot action probability quartile trajectories
     plt.clear_figure()
@@ -807,11 +810,11 @@ def verify_study_df_summary_satisfactory(
     avg_reward_trajectory_plot = plt.build()
 
     confirm_input_check_result(
-        f"\nYou provided a study dataframe reflecting a study with"
-        f"\n* {num_users} users"
+        f"\nYou provided an analysis DataFrame reflecting a study with"
+        f"\n* {num_subjects} subjects"
         f"\n* {num_non_initial_or_fallback_policies} policy updates"
-        f"\n* {num_decision_times} decision times, for an average of {avg_decisions_per_user}"
-        f" decisions per user"
+        f"\n* {num_decision_times} decision times, for an average of {avg_decisions_per_subject}"
+        f" decisions per subject"
         f"\n* RL parameters of dimension {beta_dim} per update"
         f"\n* Inferential target of dimension {theta_dim}"
         f"\n* {num_data_points_before_first_update} data points before the first update"
@@ -834,14 +837,14 @@ def require_betas_match_in_alg_update_args_each_update(
     alg_update_func_args, alg_update_func_args_beta_index
 ):
     logger.info(
-        "Checking that betas match across users for each update in the algorithm update function args."
+        "Checking that betas match across subjects for each update in the algorithm update function args."
     )
     for policy_num in alg_update_func_args:
         first_beta = None
-        for user_id in alg_update_func_args[policy_num]:
-            if not alg_update_func_args[policy_num][user_id]:
+        for subject_id in alg_update_func_args[policy_num]:
+            if not alg_update_func_args[policy_num][subject_id]:
                 continue
-            beta = alg_update_func_args[policy_num][user_id][
+            beta = alg_update_func_args[policy_num][subject_id][
                 alg_update_func_args_beta_index
             ]
             if first_beta is None:
@@ -849,23 +852,24 @@ def require_betas_match_in_alg_update_args_each_update(
             else:
                 assert np.array_equal(
                     beta, first_beta
-                ), f"Betas do not match across users in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
+                ), f"Betas do not match across subjects in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
+
 
 def require_previous_betas_match_in_alg_update_args_each_update(
     alg_update_func_args, alg_update_func_args_previous_betas_index
 ):
     logger.info(
-        "Checking that previous betas match across users for each update in the algorithm update function args."
+        "Checking that previous betas match across subjects for each update in the algorithm update function args."
     )
     if alg_update_func_args_previous_betas_index < 0:
-        return
-
+        return
+
     for policy_num in alg_update_func_args:
         first_previous_betas = None
-        for user_id in alg_update_func_args[policy_num]:
-            if not alg_update_func_args[policy_num][user_id]:
+        for subject_id in alg_update_func_args[policy_num]:
+            if not alg_update_func_args[policy_num][subject_id]:
                 continue
-            previous_betas = alg_update_func_args[policy_num][user_id][
+            previous_betas = alg_update_func_args[policy_num][subject_id][
                 alg_update_func_args_previous_betas_index
             ]
             if first_previous_betas is None:
@@ -873,21 +877,21 @@ def require_previous_betas_match_in_alg_update_args_each_update(
             else:
                 assert np.array_equal(
                     previous_betas, first_previous_betas
-                ), f"Previous betas do not match across users in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
+                ), f"Previous betas do not match across subjects in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
 
 
 def require_betas_match_in_action_prob_func_args_each_decision(
     action_prob_func_args, action_prob_func_args_beta_index
 ):
     logger.info(
-        "Checking that betas match across users for each decision time in the action prob args."
+        "Checking that betas match across subjects for each decision time in the action prob args."
     )
     for decision_time in action_prob_func_args:
         first_beta = None
-        for user_id in action_prob_func_args[decision_time]:
-            if not action_prob_func_args[decision_time][user_id]:
+        for subject_id in action_prob_func_args[decision_time]:
+            if not action_prob_func_args[decision_time][subject_id]:
                 continue
-            beta = action_prob_func_args[decision_time][user_id][
+            beta = action_prob_func_args[decision_time][subject_id][
                 action_prob_func_args_beta_index
             ]
             if first_beta is None:
@@ -895,11 +899,11 @@ def require_betas_match_in_action_prob_func_args_each_decision(
             else:
                 assert np.array_equal(
                     beta, first_beta
-                ), f"Betas do not match across users in the action prob args for decision_time {decision_time}. Please see the contract for details."
+                ), f"Betas do not match across subjects in the action prob args for decision_time {decision_time}. Please see the contract for details."
 
 
 def require_valid_action_prob_times_given_if_index_supplied(
-    study_df,
+    analysis_df,
     calendar_t_col_name,
     alg_update_func_args,
     alg_update_func_args_action_prob_times_index,
@@ -909,19 +913,19 @@ def require_valid_action_prob_times_given_if_index_supplied(
     if alg_update_func_args_action_prob_times_index < 0:
         return
 
-    min_time = study_df[calendar_t_col_name].min()
-    max_time = study_df[calendar_t_col_name].max()
-    for policy_idx, args_by_user in alg_update_func_args.items():
-        for user_id, args in args_by_user.items():
+    min_time = analysis_df[calendar_t_col_name].min()
+    max_time = analysis_df[calendar_t_col_name].max()
+    for policy_idx, args_by_subject in alg_update_func_args.items():
+        for subject_id, args in args_by_subject.items():
             if not args:
                 continue
             times = args[alg_update_func_args_action_prob_times_index]
             assert all(
                 times[i] > times[i - 1] for i in range(1, len(times))
-            ), f"Non-strictly-increasing times were given for action probabilities in the algorithm update function args for user {user_id} and policy {policy_idx}. Please see the contract for details."
+            ), f"Non-strictly-increasing times were given for action probabilities in the algorithm update function args for subject {subject_id} and policy {policy_idx}. Please see the contract for details."
             assert (
                 times[0] >= min_time and times[-1] <= max_time
-            ), f"Times not present in the study were given for action probabilities in the algorithm update function args. The min and max times in the study dataframe are {min_time} and {max_time}, while user {user_id} has times {times} supplied for policy {policy_idx}. Please see the contract for details."
+            ), f"Times not present in the study were given for action probabilities in the algorithm update function args. The min and max times in the analysis DataFrame are {min_time} and {max_time}, while subject {subject_id} has times {times} supplied for policy {policy_idx}. Please see the contract for details."
 
 
 def require_estimating_functions_sum_to_zero(
@@ -933,15 +937,15 @@ def require_estimating_functions_sum_to_zero(
     """
     This is a test that the correct loss/estimating functions have
     been given for both the algorithm updates and inference. If that is true, then the
-    loss/estimating functions when evaluated should sum to approximately zero across users. These
-    values have been stacked and averaged across users in mean_estimating_function_stack, which
+    loss/estimating functions when evaluated should sum to approximately zero across subjects. These
+    values have been stacked and averaged across subjects in mean_estimating_function_stack, which
     we simply compare to the zero vector. We can isolate components for each update and inference
     by considering the dimensions of the beta vectors and the theta vector.
 
     Inputs:
         mean_estimating_function_stack:
             The mean of the estimating function stack (a component for each algorithm update and
-            inference) across users. This should be a 1D array.
+            inference) across subjects. This should be a 1D array.
         beta_dim:
             The dimension of the beta vectors that parameterize the algorithm.
         theta_dim:
@@ -951,7 +955,7 @@ def require_estimating_functions_sum_to_zero(
             None
     """
 
-    logger.info("Checking that estimating functions average to zero across users")
+    logger.info("Checking that estimating functions average to zero across subjects")
 
     # Have a looser hard failure cutoff before the typical interactive check
     try:
@@ -962,7 +966,7 @@ def require_estimating_functions_sum_to_zero(
         )
     except AssertionError as e:
         logger.info(
-            "Estimating function stacks do not average to within loose tolerance of zero across users. Drilling in to specific updates and inference component."
+            "Estimating function stacks do not average to within loose tolerance of zero across subjects. Drilling in to specific updates and inference component."
         )
         # If this is not true there is an internal problem in the package.
         assert (mean_estimating_function_stack.size - theta_dim) % beta_dim == 0
@@ -987,11 +991,11 @@ def require_estimating_functions_sum_to_zero(
         np.testing.assert_allclose(
             mean_estimating_function_stack,
             jnp.zeros(mean_estimating_function_stack.size),
-            atol=1e-5,
+            atol=5e-4,
         )
     except AssertionError as e:
         logger.info(
-            "Estimating function stacks do not average to within specified tolerance of zero across users. Drilling in to specific updates and inference component."
+            "Estimating function stacks do not average to within specified tolerance of zero across subjects. Drilling in to specific updates and inference component."
         )
         # If this is not true there is an internal problem in the package.
         assert (mean_estimating_function_stack.size - theta_dim) % beta_dim == 0
@@ -1021,15 +1025,15 @@ def require_RL_estimating_functions_sum_to_zero(
     """
     This is a test that the correct loss/estimating functions have
     been given for both the algorithm updates and inference. If that is true, then the
-    loss/estimating functions when evaluated should sum to approximately zero across users. These
-    values have been stacked and averaged across users in mean_estimating_function_stack, which
+    loss/estimating functions when evaluated should sum to approximately zero across subjects. These
+    values have been stacked and averaged across subjects in mean_estimating_function_stack, which
     we simply compare to the zero vector. We can isolate components for each update and inference
     by considering the dimensions of the beta vectors and the theta vector.
 
     Inputs:
         mean_estimating_function_stack:
             The mean of the estimating function stack (a component for each algorithm update and
-            inference) across users. This should be a 1D array.
+            inference) across subjects. This should be a 1D array.
         beta_dim:
             The dimension of the beta vectors that parameterize the algorithm.
         theta_dim:
@@ -1039,7 +1043,7 @@ def require_RL_estimating_functions_sum_to_zero(
             None
     """
 
-    logger.info("Checking that RL estimating functions average to zero across users")
+    logger.info("Checking that RL estimating functions average to zero across subjects")
 
     # Have a looser hard failure cutoff before the typical interactive check
     try:
@@ -1050,7 +1054,7 @@ def require_RL_estimating_functions_sum_to_zero(
         )
     except AssertionError as e:
         logger.info(
-            "RL estimating function stacks do not average to zero across users. Drilling in to specific updates and inference component."
+            "RL estimating function stacks do not average to zero across subjects. Drilling in to specific updates and inference component."
         )
         num_updates = (mean_estimating_function_stack.size) // beta_dim
         for i in range(num_updates):
@@ -1070,7 +1074,7 @@ def require_RL_estimating_functions_sum_to_zero(
         )
     except AssertionError as e:
         logger.info(
-            "RL estimating function stacks do not average to zero across users. Drilling in to specific updates and inference component."
+            "RL estimating function stacks do not average to zero across subjects. Drilling in to specific updates and inference component."
         )
         num_updates = (mean_estimating_function_stack.size) // beta_dim
         for i in range(num_updates):
@@ -1079,7 +1083,6 @@ def require_RL_estimating_functions_sum_to_zero(
                 i + 1,
                 mean_estimating_function_stack[i * beta_dim : (i + 1) * beta_dim],
             )
-        # TODO: Email instead of requiring user input for monitoring alg.
         confirm_input_check_result(
             f"\nEstimating functions do not average to within default tolerance of zero vector. Please decide if the following is a reasonable result, taking into account the above breakdown by update number and inference. If not, there are several possible reasons for failure mentioned in the contract. Results:\n{str(e)}\n\nContinue? (y/n)\n",
             suppress_interactive_data_checks,
@@ -1087,20 +1090,18 @@
         )
 
 
-def require_adaptive_bread_inverse_is_true_inverse(
-    joint_adaptive_bread_matrix,
-    joint_adaptive_bread_inverse_matrix,
+def require_joint_bread_inverse_is_true_inverse(
+    joint_bread_inverse_matrix,
+    joint_bread_matrix,
     suppress_interactive_data_checks,
 ):
     """
-    Check that the product of the joint adaptive bread matrix and its inverse is
+    Check that the product of the joint bread matrix and its inverse is
     sufficiently close to the identity matrix. This is a direct check that the
-    joint_adaptive_bread_inverse_matrix we create is "well-conditioned".
+    joint_bread_matrix we create is "well-conditioned".
     """
-    should_be_identity = (
-        joint_adaptive_bread_matrix @ joint_adaptive_bread_inverse_matrix
-    )
-    identity = np.eye(joint_adaptive_bread_matrix.shape[0])
+    should_be_identity = joint_bread_inverse_matrix @ joint_bread_matrix
+    identity = np.eye(joint_bread_matrix.shape[0])
     try:
         np.testing.assert_allclose(
             should_be_identity,
@@ -1110,7 +1111,7 @@ def require_adaptive_bread_inverse_is_true_inverse(
         )
     except AssertionError as e:
         confirm_input_check_result(
-            f"\nJoint adaptive bread is not exact inverse of the constructed matrix that was inverted to form it. This likely illustrates poor conditioning:\n{str(e)}\n\nContinue? (y/n)\n",
+            f"\nJoint bread inverse is not exact inverse of the constructed matrix that was inverted to form it. This likely illustrates poor conditioning:\n{str(e)}\n\nContinue? (y/n)\n",
             suppress_interactive_data_checks,
             e,
         )
@@ -1118,7 +1119,7 @@ def require_adaptive_bread_inverse_is_true_inverse(
     # If we haven't already errored out, return some measures of how far off we are from identity
     diff = should_be_identity - identity
     logger.debug(
-        "Difference between should-be-identity produced by multiplying joint adaptive bread inverse and its computed inverse and actual identity:\n%s",
+        "Difference between should-be-identity produced by multiplying joint bread and its computed inverse and actual identity:\n%s",
         diff,
     )
 
@@ -1133,8 +1134,8 @@
 
 def require_threaded_algorithm_estimating_function_args_equivalent(
     algorithm_estimating_func,
-    update_func_args_by_by_user_id_by_policy_num,
-    threaded_update_func_args_by_policy_num_by_user_id,
+    update_func_args_by_by_subject_id_by_policy_num,
+    threaded_update_func_args_by_policy_num_by_subject_id,
     suppress_interactive_data_checks,
 ):
     """
@@ -1144,12 +1145,12 @@ def require_threaded_algorithm_estimating_function_args_equivalent(
     """
     for (
         policy_num,
-        update_func_args_by_user_id,
-    ) in update_func_args_by_by_user_id_by_policy_num.items():
+        update_func_args_by_subject_id,
+    ) in update_func_args_by_by_subject_id_by_policy_num.items():
         for (
-            user_id,
+            subject_id,
             unthreaded_args,
-        ) in update_func_args_by_user_id.items():
+        ) in update_func_args_by_subject_id.items():
             if not unthreaded_args:
                 continue
             np.testing.assert_allclose(
@@ -1157,9 +1158,9 @@ def require_threaded_algorithm_estimating_function_args_equivalent(
                 # Need to stop gradient here because we can't convert a traced value to np array
                 jax.lax.stop_gradient(
                     algorithm_estimating_func(
-                        *threaded_update_func_args_by_policy_num_by_user_id[user_id][
-                            policy_num
-                        ]
+                        *threaded_update_func_args_by_policy_num_by_subject_id[
+                            subject_id
+                        ][policy_num]
                     )
                 ),
                 atol=1e-7,
@@ -1169,8 +1170,8 @@ def require_threaded_algorithm_estimating_function_args_equivalent(
 
 def require_threaded_inference_estimating_function_args_equivalent(
     inference_estimating_func,
-    inference_func_args_by_user_id,
-    threaded_inference_func_args_by_user_id,
+    inference_func_args_by_subject_id,
+    threaded_inference_func_args_by_subject_id,
     suppress_interactive_data_checks,
 ):
     """
@@ -1178,7 +1179,7 @@ def require_threaded_inference_estimating_function_args_equivalent(
     when called with the original arguments and when called with the
     reconstructed action probabilities substituted in.
     """
-    for user_id, unthreaded_args in inference_func_args_by_user_id.items():
+    for subject_id, unthreaded_args in inference_func_args_by_subject_id.items():
         if not unthreaded_args:
             continue
         np.testing.assert_allclose(
@@ -1186,7 +1187,7 @@ def require_threaded_inference_estimating_function_args_equivalent(
             # Need to stop gradient here because we can't convert a traced value to np array
             jax.lax.stop_gradient(
                 inference_estimating_func(
-                    *threaded_inference_func_args_by_user_id[user_id]
+                    *threaded_inference_func_args_by_subject_id[subject_id]
                 )
             ),
             rtol=1e-2,
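
The renamed bread-matrix check above (require_joint_bread_inverse_is_true_inverse) multiplies the computed inverse by the matrix it was derived from and compares the product to the identity. A generic numpy illustration of that idea, using synthetic well-conditioned matrices rather than package code:

import numpy as np

rng = np.random.default_rng(0)
# Adding a scaled identity keeps the synthetic matrix well-conditioned.
joint_bread_matrix = rng.normal(size=(4, 4)) + 4.0 * np.eye(4)
joint_bread_inverse_matrix = np.linalg.inv(joint_bread_matrix)

# The product of the inverse and the original matrix should be numerically
# close to the identity; a large residual signals poor conditioning.
np.testing.assert_allclose(
    joint_bread_inverse_matrix @ joint_bread_matrix,
    np.eye(4),
    atol=1e-10,
)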