lifejacket-0.2.1-py3-none-any.whl → lifejacket-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/after_study_analysis.py +397 -387
- lifejacket/arg_threading_helpers.py +75 -69
- lifejacket/calculate_derivatives.py +19 -21
- lifejacket/{trial_conditioning_monitor.py → deployment_conditioning_monitor.py} +146 -128
- lifejacket/{form_adaptive_meat_adjustments_directly.py → form_adjusted_meat_adjustments_directly.py} +7 -7
- lifejacket/get_datum_for_blowup_supervised_learning.py +315 -307
- lifejacket/helper_functions.py +45 -38
- lifejacket/input_checks.py +263 -261
- lifejacket/small_sample_corrections.py +42 -40
- lifejacket-1.0.0.dist-info/METADATA +56 -0
- lifejacket-1.0.0.dist-info/RECORD +17 -0
- lifejacket-0.2.1.dist-info/METADATA +0 -100
- lifejacket-0.2.1.dist-info/RECORD +0 -17
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.0.dist-info}/WHEEL +0 -0
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.0.dist-info}/entry_points.txt +0 -0
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.0.dist-info}/top_level.txt +0 -0
lifejacket/input_checks.py
CHANGED
@@ -26,12 +26,12 @@ logging.basicConfig(
 
 # TODO: any checks needed here about alg update function type?
 def perform_first_wave_input_checks(
-    study_df,
-    in_study_col_name,
+    analysis_df,
+    active_col_name,
     action_col_name,
     policy_num_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     action_prob_col_name,
     reward_col_name,
     action_prob_func,
@@ -48,11 +48,11 @@ def perform_first_wave_input_checks(
     small_sample_correction,
 ):
     ### Validate algorithm loss/estimating function and args
-    require_alg_update_args_given_for_all_users_at_each_update(
-        study_df, user_id_col_name, alg_update_func_args
+    require_alg_update_args_given_for_all_subjects_at_each_update(
+        analysis_df, subject_id_col_name, alg_update_func_args
     )
-    require_no_policy_numbers_present_in_alg_update_args_but_not_study_df(
-        study_df, policy_num_col_name, alg_update_func_args
+    require_no_policy_numbers_present_in_alg_update_args_but_not_analysis_df(
+        analysis_df, policy_num_col_name, alg_update_func_args
     )
     require_beta_is_1D_array_in_alg_update_args(
         alg_update_func_args, alg_update_func_args_beta_index
@@ -60,8 +60,8 @@ def perform_first_wave_input_checks(
     require_previous_betas_is_2D_array_in_alg_update_args(
         alg_update_func_args, alg_update_func_args_previous_betas_index
     )
-    require_all_policy_numbers_in_study_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
-        study_df, in_study_col_name, policy_num_col_name, alg_update_func_args
+    require_all_policy_numbers_in_analysis_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
+        analysis_df, active_col_name, policy_num_col_name, alg_update_func_args
     )
     confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
         alg_update_func_args_action_prob_index, suppress_interactive_data_checks
@@ -80,17 +80,17 @@ def perform_first_wave_input_checks(
     require_previous_betas_match_in_alg_update_args_each_update(
         alg_update_func_args, alg_update_func_args_previous_betas_index
     )
-    require_action_prob_args_in_alg_update_func_correspond_to_study_df(
-        study_df,
+    require_action_prob_args_in_alg_update_func_correspond_to_analysis_df(
+        analysis_df,
         action_prob_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         alg_update_func_args,
         alg_update_func_args_action_prob_index,
         alg_update_func_args_action_prob_times_index,
     )
     require_valid_action_prob_times_given_if_index_supplied(
-        study_df,
+        analysis_df,
         calendar_t_col_name,
         alg_update_func_args,
         alg_update_func_args_action_prob_times_index,
@@ -101,28 +101,28 @@ def perform_first_wave_input_checks(
     )
 
     ### Validate action prob function and args
-    require_action_prob_func_args_given_for_all_users_at_each_decision(
-        study_df, user_id_col_name, action_prob_func_args
+    require_action_prob_func_args_given_for_all_subjects_at_each_decision(
+        analysis_df, subject_id_col_name, action_prob_func_args
     )
     require_action_prob_func_args_given_for_all_decision_times(
-        study_df, calendar_t_col_name, action_prob_func_args
+        analysis_df, calendar_t_col_name, action_prob_func_args
     )
-    require_action_probabilities_in_study_df_can_be_reconstructed(
-        study_df,
+    require_action_probabilities_in_analysis_df_can_be_reconstructed(
+        analysis_df,
         action_prob_col_name,
         calendar_t_col_name,
-        user_id_col_name,
-        in_study_col_name,
+        subject_id_col_name,
+        active_col_name,
         action_prob_func_args,
         action_prob_func,
     )
 
     require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times(
-        study_df,
+        analysis_df,
         calendar_t_col_name,
         action_prob_func_args,
-        in_study_col_name,
-        user_id_col_name,
+        active_col_name,
+        subject_id_col_name,
     )
     require_beta_is_1D_array_in_action_prob_args(
         action_prob_func_args, action_prob_func_args_beta_index
@@ -131,13 +131,13 @@ def perform_first_wave_input_checks(
         action_prob_func_args, action_prob_func_args_beta_index
     )
 
-    ### Validate study_df
-    verify_study_df_summary_satisfactory(
-        study_df,
-        user_id_col_name,
+    ### Validate analysis_df
+    verify_analysis_df_summary_satisfactory(
+        analysis_df,
+        subject_id_col_name,
         policy_num_col_name,
         calendar_t_col_name,
-        in_study_col_name,
+        active_col_name,
         action_prob_col_name,
         reward_col_name,
         beta_dim,
@@ -145,46 +145,46 @@ def perform_first_wave_input_checks(
         suppress_interactive_data_checks,
     )
 
-    require_all_users_have_all_times_in_study_df(
-        study_df, calendar_t_col_name, user_id_col_name
+    require_all_subjects_have_all_times_in_analysis_df(
+        analysis_df, calendar_t_col_name, subject_id_col_name
     )
-    require_all_named_columns_present_in_study_df(
-        study_df,
-        in_study_col_name,
+    require_all_named_columns_present_in_analysis_df(
+        analysis_df,
+        active_col_name,
         action_col_name,
         policy_num_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         action_prob_col_name,
     )
-    require_all_named_columns_not_object_type_in_study_df(
-        study_df,
-        in_study_col_name,
+    require_all_named_columns_not_object_type_in_analysis_df(
+        analysis_df,
+        active_col_name,
         action_col_name,
         policy_num_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         action_prob_col_name,
     )
-    require_binary_actions(study_df, in_study_col_name, action_col_name)
-    require_binary_in_study_indicators(study_df, in_study_col_name)
+    require_binary_actions(analysis_df, active_col_name, action_col_name)
+    require_binary_active_indicators(analysis_df, active_col_name)
     require_consecutive_integer_policy_numbers(
-        study_df, in_study_col_name, policy_num_col_name
+        analysis_df, active_col_name, policy_num_col_name
     )
-    require_consecutive_integer_calendar_times(study_df, calendar_t_col_name)
-    require_hashable_user_ids(study_df, in_study_col_name, user_id_col_name)
-    require_action_probabilities_in_range_0_to_1(study_df, action_prob_col_name)
+    require_consecutive_integer_calendar_times(analysis_df, calendar_t_col_name)
+    require_hashable_subject_ids(analysis_df, active_col_name, subject_id_col_name)
+    require_action_probabilities_in_range_0_to_1(analysis_df, action_prob_col_name)
 
     ### Validate theta estimation
     require_theta_is_1D_array(theta_est)
 
 
 def perform_alg_only_input_checks(
-    study_df,
-    in_study_col_name,
+    analysis_df,
+    active_col_name,
     policy_num_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     action_prob_col_name,
     action_prob_func,
     action_prob_func_args,
@@ -196,14 +196,14 @@ def perform_alg_only_input_checks(
     suppress_interactive_data_checks,
 ):
     ### Validate algorithm loss/estimating function and args
-    require_alg_update_args_given_for_all_users_at_each_update(
-        study_df, user_id_col_name, alg_update_func_args
+    require_alg_update_args_given_for_all_subjects_at_each_update(
+        analysis_df, subject_id_col_name, alg_update_func_args
     )
     require_beta_is_1D_array_in_alg_update_args(
         alg_update_func_args, alg_update_func_args_beta_index
     )
-    require_all_policy_numbers_in_study_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
-        study_df, in_study_col_name, policy_num_col_name, alg_update_func_args
+    require_all_policy_numbers_in_analysis_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
+        analysis_df, active_col_name, policy_num_col_name, alg_update_func_args
     )
     confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
         alg_update_func_args_action_prob_index, suppress_interactive_data_checks
@@ -219,45 +219,45 @@ def perform_alg_only_input_checks(
     require_betas_match_in_alg_update_args_each_update(
         alg_update_func_args, alg_update_func_args_beta_index
     )
-    require_action_prob_args_in_alg_update_func_correspond_to_study_df(
-        study_df,
+    require_action_prob_args_in_alg_update_func_correspond_to_analysis_df(
+        analysis_df,
         action_prob_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         alg_update_func_args,
         alg_update_func_args_action_prob_index,
         alg_update_func_args_action_prob_times_index,
     )
     require_valid_action_prob_times_given_if_index_supplied(
-        study_df,
+        analysis_df,
         calendar_t_col_name,
         alg_update_func_args,
         alg_update_func_args_action_prob_times_index,
     )
 
     ### Validate action prob function and args
-    require_action_prob_func_args_given_for_all_users_at_each_decision(
-        study_df, user_id_col_name, action_prob_func_args
+    require_action_prob_func_args_given_for_all_subjects_at_each_decision(
+        analysis_df, subject_id_col_name, action_prob_func_args
     )
     require_action_prob_func_args_given_for_all_decision_times(
-        study_df, calendar_t_col_name, action_prob_func_args
+        analysis_df, calendar_t_col_name, action_prob_func_args
     )
-    require_action_probabilities_in_study_df_can_be_reconstructed(
-        study_df,
+    require_action_probabilities_in_analysis_df_can_be_reconstructed(
+        analysis_df,
         action_prob_col_name,
         calendar_t_col_name,
-        user_id_col_name,
-        in_study_col_name,
+        subject_id_col_name,
+        active_col_name,
         action_prob_func_args,
         action_prob_func=action_prob_func,
     )
 
     require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times(
-        study_df,
+        analysis_df,
         calendar_t_col_name,
         action_prob_func_args,
-        in_study_col_name,
-        user_id_col_name,
+        active_col_name,
+        subject_id_col_name,
     )
     require_beta_is_1D_array_in_action_prob_args(
         action_prob_func_args, action_prob_func_args_beta_index
@@ -268,12 +268,12 @@ def perform_alg_only_input_checks(
 
 
 # TODO: Give a hard-to-use option to loosen this check somehow
-def require_action_probabilities_in_study_df_can_be_reconstructed(
-    study_df,
+def require_action_probabilities_in_analysis_df_can_be_reconstructed(
+    analysis_df,
     action_prob_col_name,
     calendar_t_col_name,
-    user_id_col_name,
-    in_study_col_name,
+    subject_id_col_name,
+    active_col_name,
     action_prob_func_args,
     action_prob_func,
 ):
@@ -285,77 +285,79 @@ def require_action_probabilities_in_study_df_can_be_reconstructed(
     """
     logger.info("Reconstructing action probabilities from function and arguments.")
 
-    in_study_df = study_df[study_df[in_study_col_name] == 1]
-    reconstructed_action_probs = in_study_df.apply(
+    active_df = analysis_df[analysis_df[active_col_name] == 1]
+    reconstructed_action_probs = active_df.apply(
         lambda row: action_prob_func(
-            *action_prob_func_args[row[calendar_t_col_name]][row[user_id_col_name]]
+            *action_prob_func_args[row[calendar_t_col_name]][row[subject_id_col_name]]
         ),
         axis=1,
     )
 
     np.testing.assert_allclose(
-        in_study_df[action_prob_col_name].to_numpy(dtype="float64"),
+        active_df[action_prob_col_name].to_numpy(dtype="float64"),
         reconstructed_action_probs.to_numpy(dtype="float64"),
         atol=1e-6,
    )
 
 
-def require_all_users_have_all_times_in_study_df(
-    study_df, calendar_t_col_name, user_id_col_name
+def require_all_subjects_have_all_times_in_analysis_df(
+    analysis_df, calendar_t_col_name, subject_id_col_name
 ):
-    logger.info("Checking that all users have the same set of unique calendar times.")
+    logger.info(
+        "Checking that all subjects have the same set of unique calendar times."
+    )
     # Get the unique calendar times
-    unique_calendar_times = set(study_df[calendar_t_col_name].unique())
+    unique_calendar_times = set(analysis_df[calendar_t_col_name].unique())
 
-    # Group by user ID and aggregate the unique calendar times for each user
-    user_calendar_times = study_df.groupby(user_id_col_name)[calendar_t_col_name].apply(
-        set
-    )
+    # Group by subject ID and aggregate the unique calendar times for each subject
+    subject_calendar_times = analysis_df.groupby(subject_id_col_name)[
+        calendar_t_col_name
+    ].apply(set)
 
-    # Check if all users have the same set of unique calendar times
-    if not user_calendar_times.apply(lambda x: x == unique_calendar_times).all():
+    # Check if all subjects have the same set of unique calendar times
+    if not subject_calendar_times.apply(lambda x: x == unique_calendar_times).all():
         raise AssertionError(
-            "Not all users have all calendar times in the study dataframe. Please see the contract for details."
+            "Not all subjects have all calendar times in the study dataframe. Please see the contract for details."
         )
 
 
-def require_alg_update_args_given_for_all_users_at_each_update(
-    study_df, user_id_col_name, alg_update_func_args
+def require_alg_update_args_given_for_all_subjects_at_each_update(
+    analysis_df, subject_id_col_name, alg_update_func_args
 ):
     logger.info(
-        "Checking that algorithm update function args are given for all users at each update."
+        "Checking that algorithm update function args are given for all subjects at each update."
     )
-    all_user_ids = set(study_df[user_id_col_name].unique())
+    all_subject_ids = set(analysis_df[subject_id_col_name].unique())
     for policy_num in alg_update_func_args:
         assert (
-            set(alg_update_func_args[policy_num].keys()) == all_user_ids
-        ), f"Not all users present in algorithm update function args for policy number {policy_num}. Please see the contract for details."
+            set(alg_update_func_args[policy_num].keys()) == all_subject_ids
+        ), f"Not all subjects present in algorithm update function args for policy number {policy_num}. Please see the contract for details."
 
 
-def require_action_prob_args_in_alg_update_func_correspond_to_study_df(
-    study_df,
+def require_action_prob_args_in_alg_update_func_correspond_to_analysis_df(
+    analysis_df,
     action_prob_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     alg_update_func_args,
     alg_update_func_args_action_prob_index,
     alg_update_func_args_action_prob_times_index,
 ):
     logger.info(
         "Checking that the action probabilities supplied in the algorithm update function args, if"
-        " any, correspond to those in the study dataframe for the corresponding users and decision"
+        " any, correspond to those in the study dataframe for the corresponding subjects and decision"
         " times."
     )
     if alg_update_func_args_action_prob_index < 0:
         return
 
     # Precompute a lookup dictionary for faster access
-    study_df_lookup = study_df.set_index([calendar_t_col_name, user_id_col_name])[
-        action_prob_col_name
-    ].to_dict()
+    analysis_df_lookup = analysis_df.set_index(
+        [calendar_t_col_name, subject_id_col_name]
+    )[action_prob_col_name].to_dict()
 
-    for policy_num, user_args in alg_update_func_args.items():
-        for user_id, args in user_args.items():
+    for policy_num, subject_args in alg_update_func_args.items():
+        for subject_id, args in subject_args.items():
             if not args:
                 continue
             arg_action_probs = args[alg_update_func_args_action_prob_index]
@@ -364,43 +366,43 @@ def require_action_prob_args_in_alg_update_func_correspond_to_study_df(
             ].flatten()
 
             # Use the precomputed lookup dictionary
-            study_df_action_probs = [
-                study_df_lookup[(decision_time.item(), user_id)]
+            analysis_df_action_probs = [
+                analysis_df_lookup[(decision_time.item(), subject_id)]
                 for decision_time in action_prob_times
             ]
 
             assert np.allclose(
                 arg_action_probs.flatten(),
-                study_df_action_probs,
+                analysis_df_action_probs,
             ), (
-                f"There is a mismatch for user {user_id} between the action probabilities supplied"
+                f"There is a mismatch for subject {subject_id} between the action probabilities supplied"
                 f" in the args to the algorithm update function at policy {policy_num} and those in"
                 " the study dataframe for the supplied times. Please see the contract for details."
             )
 
 
-def require_action_prob_func_args_given_for_all_users_at_each_decision(
-    study_df,
-    user_id_col_name,
+def require_action_prob_func_args_given_for_all_subjects_at_each_decision(
+    analysis_df,
+    subject_id_col_name,
     action_prob_func_args,
 ):
     logger.info(
-        "Checking that action prob function args are given for all users at each decision time."
+        "Checking that action prob function args are given for all subjects at each decision time."
     )
-    all_user_ids = set(study_df[user_id_col_name].unique())
+    all_subject_ids = set(analysis_df[subject_id_col_name].unique())
     for decision_time in action_prob_func_args:
         assert (
-            set(action_prob_func_args[decision_time].keys()) == all_user_ids
-        ), f"Not all users present in algorithm update function args for decision time {decision_time}. Please see the contract for details."
+            set(action_prob_func_args[decision_time].keys()) == all_subject_ids
+        ), f"Not all subjects present in algorithm update function args for decision time {decision_time}. Please see the contract for details."
 
 
 def require_action_prob_func_args_given_for_all_decision_times(
-    study_df, calendar_t_col_name, action_prob_func_args
+    analysis_df, calendar_t_col_name, action_prob_func_args
 ):
     logger.info(
         "Checking that action prob function args are given for all decision times."
     )
-    all_times = set(study_df[calendar_t_col_name].unique())
+    all_times = set(analysis_df[calendar_t_col_name].unique())
 
     assert (
         set(action_prob_func_args.keys()) == all_times
@@ -408,106 +410,106 @@ def require_action_prob_func_args_given_for_all_decision_times(
 
 
 def require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times(
-    study_df: pd.DataFrame,
+    analysis_df: pd.DataFrame,
     calendar_t_col_name: str,
     action_prob_func_args: dict[str, dict[str, tuple[Any, ...]]],
-    in_study_col_name,
-    user_id_col_name,
+    active_col_name,
+    subject_id_col_name,
 ):
     logger.info(
-        "Checking that action probability function args are blank for exactly the times each user"
-        "is not in the study according to the study dataframe."
+        "Checking that action probability function args are blank for exactly the times each subject"
+        "is not in the study according to the study dataframe."
     )
-    out_of_study_df = study_df[study_df[in_study_col_name] == 0]
-    out_of_study_times_by_user_according_to_study_df = (
-        out_of_study_df.groupby(user_id_col_name)[calendar_t_col_name]
+    inactive_df = analysis_df[analysis_df[active_col_name] == 0]
+    inactive_times_by_subject_according_to_analysis_df = (
+        inactive_df.groupby(subject_id_col_name)[calendar_t_col_name]
         .apply(set)
         .to_dict()
     )
 
-    out_of_study_times_by_user_according_to_action_prob_func_args = (
+    inactive_times_by_subject_according_to_action_prob_func_args = (
         collections.defaultdict(set)
     )
-    for decision_time, action_prob_args_by_user in action_prob_func_args.items():
-        for user_id, action_prob_args in action_prob_args_by_user.items():
+    for decision_time, action_prob_args_by_subject in action_prob_func_args.items():
+        for subject_id, action_prob_args in action_prob_args_by_subject.items():
             if not action_prob_args:
-                out_of_study_times_by_user_according_to_action_prob_func_args[
-                    user_id
+                inactive_times_by_subject_according_to_action_prob_func_args[
+                    subject_id
                 ].add(decision_time)
 
     assert (
-        out_of_study_times_by_user_according_to_study_df
-        == out_of_study_times_by_user_according_to_action_prob_func_args
+        inactive_times_by_subject_according_to_analysis_df
+        == inactive_times_by_subject_according_to_action_prob_func_args
     ), (
         "Out-of-study decision times according to the study dataframe do not match up with the"
-        " times for which action probability arguments are blank for all users. Please see the"
+        " times for which action probability arguments are blank for all subjects. Please see the"
         " contract for details."
     )
 
 
-def require_all_named_columns_present_in_study_df(
-    study_df,
-    in_study_col_name,
+def require_all_named_columns_present_in_analysis_df(
+    analysis_df,
+    active_col_name,
     action_col_name,
     policy_num_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     action_prob_col_name,
 ):
     logger.info("Checking that all named columns are present in the study dataframe.")
+    assert active_col_name in analysis_df.columns, f"{active_col_name} not in study df."
+    assert action_col_name in analysis_df.columns, f"{action_col_name} not in study df."
     assert (
-        in_study_col_name in study_df.columns
-    ), f"{in_study_col_name} not in study df."
-    assert action_col_name in study_df.columns, f"{action_col_name} not in study df."
-    assert (
-        policy_num_col_name in study_df.columns
+        policy_num_col_name in analysis_df.columns
     ), f"{policy_num_col_name} not in study df."
     assert (
-        calendar_t_col_name in study_df.columns
+        calendar_t_col_name in analysis_df.columns
     ), f"{calendar_t_col_name} not in study df."
-    assert user_id_col_name in study_df.columns, f"{user_id_col_name} not in study df."
     assert (
-        action_prob_col_name in study_df.columns
+        subject_id_col_name in analysis_df.columns
+    ), f"{subject_id_col_name} not in study df."
+    assert (
+        action_prob_col_name in analysis_df.columns
     ), f"{action_prob_col_name} not in study df."
 
 
-def require_all_named_columns_not_object_type_in_study_df(
-    study_df,
-    in_study_col_name,
+def require_all_named_columns_not_object_type_in_analysis_df(
+    analysis_df,
+    active_col_name,
     action_col_name,
     policy_num_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     action_prob_col_name,
 ):
     logger.info("Checking that all named columns are not type object.")
     for colname in (
-        in_study_col_name,
+        active_col_name,
         action_col_name,
         policy_num_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         action_prob_col_name,
     ):
         assert (
-            study_df[colname].dtype != "object"
+            analysis_df[colname].dtype != "object"
        ), f"At least {colname} is of object type in study df."
 
 
-def require_binary_actions(study_df, in_study_col_name, action_col_name):
+def require_binary_actions(analysis_df, active_col_name, action_col_name):
     logger.info("Checking that actions are binary.")
     assert (
-        study_df[study_df[in_study_col_name] == 1][action_col_name]
+        analysis_df[analysis_df[active_col_name] == 1][action_col_name]
         .astype("int64")
         .isin([0, 1])
         .all()
     ), "Actions are not binary."
 
 
-def require_binary_in_study_indicators(study_df, in_study_col_name):
-    logger.info("Checking that in-study indicators are binary.")
+def require_binary_active_indicators(analysis_df, active_col_name):
+    logger.info("Checking that active indicators are binary.")
     assert (
-        study_df[study_df[in_study_col_name] == 1][in_study_col_name]
+        analysis_df[analysis_df[active_col_name] == 1][active_col_name]
         .astype("int64")
         .isin([0, 1])
         .all()
@@ -515,7 +517,7 @@ def require_binary_in_study_indicators(study_df, in_study_col_name):
 
 
 def require_consecutive_integer_policy_numbers(
-    study_df, in_study_col_name, policy_num_col_name
+    analysis_df, active_col_name, policy_num_col_name
 ):
     # TODO: This is a somewhat rough check of this, could also check nondecreasing temporally
 
@@ -523,8 +525,8 @@ def require_consecutive_integer_policy_numbers(
         "Checking that in-study, non-fallback policy numbers are consecutive integers."
     )
 
-    in_study_df = study_df[study_df[in_study_col_name] == 1]
-    nonnegative_policy_df = in_study_df[in_study_df[policy_num_col_name] >= 0]
+    active_df = analysis_df[analysis_df[active_col_name] == 1]
+    nonnegative_policy_df = active_df[active_df[policy_num_col_name] >= 0]
     # Ideally we actually have integers, but for legacy reasons we will support
     # floats as well.
     if nonnegative_policy_df[policy_num_col_name].dtype == "float64":
@@ -540,66 +542,67 @@ def require_consecutive_integer_policy_numbers(
     ), "Policy numbers are not consecutive integers."
 
 
-def require_consecutive_integer_calendar_times(study_df, calendar_t_col_name):
+def require_consecutive_integer_calendar_times(analysis_df, calendar_t_col_name):
     # This is a somewhat rough check of this, more like checking there are no
-    # gaps in the integers covered. But we have other checks that all users
+    # gaps in the integers covered. But we have other checks that all subjects
     # have same times, etc.
-    # Note these times should be well-formed even when the user is not in the study.
+    # Note these times should be well-formed even when the subject is not in the study.
     logger.info("Checking that calendar times are consecutive integers.")
     assert np.array_equal(
-        study_df[calendar_t_col_name].unique(),
+        analysis_df[calendar_t_col_name].unique(),
         range(
-            study_df[calendar_t_col_name].min(), study_df[calendar_t_col_name].max() + 1
+            analysis_df[calendar_t_col_name].min(),
+            analysis_df[calendar_t_col_name].max() + 1,
         ),
     ), "Calendar times are not consecutive integers."
 
 
-def require_hashable_user_ids(study_df, in_study_col_name, user_id_col_name):
-    logger.info("Checking that user IDs are hashable.")
+def require_hashable_subject_ids(analysis_df, active_col_name, subject_id_col_name):
+    logger.info("Checking that subject IDs are hashable.")
     isinstance(
-        study_df[study_df[in_study_col_name] == 1][user_id_col_name][0],
+        analysis_df[analysis_df[active_col_name] == 1][subject_id_col_name][0],
         collections.abc.Hashable,
     )
 
 
-def require_action_probabilities_in_range_0_to_1(study_df, action_prob_col_name):
+def require_action_probabilities_in_range_0_to_1(analysis_df, action_prob_col_name):
     logger.info("Checking that action probabilities are in the interval (0, 1).")
-    study_df[action_prob_col_name].between(0, 1, inclusive="neither").all()
+    analysis_df[action_prob_col_name].between(0, 1, inclusive="neither").all()
 
 
-def require_no_policy_numbers_present_in_alg_update_args_but_not_study_df(
-    study_df, policy_num_col_name, alg_update_func_args
+def require_no_policy_numbers_present_in_alg_update_args_but_not_analysis_df(
+    analysis_df, policy_num_col_name, alg_update_func_args
 ):
     logger.info(
         "Checking that policy numbers in algorithm update function args are present in the study dataframe."
     )
     alg_update_policy_nums = sorted(alg_update_func_args.keys())
-    study_df_policy_nums = sorted(study_df[policy_num_col_name].unique())
-    assert set(alg_update_policy_nums).issubset(set(study_df_policy_nums)), (
+    analysis_df_policy_nums = sorted(analysis_df[policy_num_col_name].unique())
+    assert set(alg_update_policy_nums).issubset(set(analysis_df_policy_nums)), (
         f"There are policy numbers present in algorithm update function args but not in the study dataframe. "
         f"\nalg_update_func_args policy numbers: {alg_update_policy_nums}"
-        f"\nstudy_df policy numbers: {study_df_policy_nums}.\nPlease see the contract for details."
+        f"\nanalysis_df policy numbers: {analysis_df_policy_nums}.\nPlease see the contract for details."
     )
 
 
-def require_all_policy_numbers_in_study_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
-    study_df, in_study_col_name, policy_num_col_name, alg_update_func_args
+def require_all_policy_numbers_in_analysis_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
+    analysis_df, active_col_name, policy_num_col_name, alg_update_func_args
 ):
     logger.info(
         "Checking that all policy numbers in the study dataframe are present in the algorithm update function args."
     )
-    in_study_df = study_df[study_df[in_study_col_name] == 1]
+    active_df = analysis_df[analysis_df[active_col_name] == 1]
     # Get the number of the initial policy. 0 is recommended but not required.
-    min_nonnegative_policy_number = in_study_df[in_study_df[policy_num_col_name] >= 0][
+    min_nonnegative_policy_number = active_df[active_df[policy_num_col_name] >= 0][
         policy_num_col_name
     ]
     assert set(
-        in_study_df[in_study_df[policy_num_col_name] > min_nonnegative_policy_number][
+        active_df[active_df[policy_num_col_name] > min_nonnegative_policy_number][
             policy_num_col_name
         ].unique()
     ).issubset(
         alg_update_func_args.keys()
-    ), f"There are non-fallback, non-initial policy numbers in the study dataframe that are not in the update function args: {set(in_study_df[in_study_df[policy_num_col_name] > 0][policy_num_col_name].unique()) - set(alg_update_func_args.keys())}. Please see the contract for details."
+    ), f"There are non-fallback, non-initial policy numbers in the study dataframe that are not in the update function args: {set(active_df[active_df[policy_num_col_name] > 0][policy_num_col_name].unique()) - set(alg_update_func_args.keys())}. Please see the contract for details."
 
 
 def confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
@@ -672,28 +675,29 @@ def require_beta_is_1D_array_in_alg_update_args(
     alg_update_func_args, alg_update_func_args_beta_index
 ):
     for policy_num in alg_update_func_args:
-        for user_id in alg_update_func_args[policy_num]:
-            if not alg_update_func_args[policy_num][user_id]:
+        for subject_id in alg_update_func_args[policy_num]:
+            if not alg_update_func_args[policy_num][subject_id]:
                 continue
             assert (
-                alg_update_func_args[policy_num][user_id][
+                alg_update_func_args[policy_num][subject_id][
                     alg_update_func_args_beta_index
                 ].ndim
                 == 1
             ), "Beta is not a 1D array in the algorithm update function args."
 
+
 def require_previous_betas_is_2D_array_in_alg_update_args(
     alg_update_func_args, alg_update_func_args_previous_betas_index
 ):
     if alg_update_func_args_previous_betas_index < 0:
         return
-
+
     for policy_num in alg_update_func_args:
-        for user_id in alg_update_func_args[policy_num]:
-            if not alg_update_func_args[policy_num][user_id]:
+        for subject_id in alg_update_func_args[policy_num]:
+            if not alg_update_func_args[policy_num][subject_id]:
                 continue
             assert (
-                alg_update_func_args[policy_num][user_id][
+                alg_update_func_args[policy_num][subject_id][
                     alg_update_func_args_previous_betas_index
                 ].ndim
                 == 2
@@ -704,11 +708,11 @@ def require_beta_is_1D_array_in_action_prob_args(
     action_prob_func_args, action_prob_func_args_beta_index
 ):
     for decision_time in action_prob_func_args:
-        for user_id in action_prob_func_args[decision_time]:
-            if not action_prob_func_args[decision_time][user_id]:
+        for subject_id in action_prob_func_args[decision_time]:
+            if not action_prob_func_args[decision_time][subject_id]:
                 continue
             assert (
-                action_prob_func_args[decision_time][user_id][
+                action_prob_func_args[decision_time][subject_id][
                     action_prob_func_args_beta_index
                 ].ndim
                 == 1
@@ -719,12 +723,12 @@ def require_theta_is_1D_array(theta_est):
     assert theta_est.ndim == 1, "Theta is not a 1D array."
 
 
-def verify_study_df_summary_satisfactory(
-    study_df,
-    user_id_col_name,
+def verify_analysis_df_summary_satisfactory(
+    analysis_df,
+    subject_id_col_name,
     policy_num_col_name,
     calendar_t_col_name,
-    in_study_col_name,
+    active_col_name,
     action_prob_col_name,
     reward_col_name,
     beta_dim,
@@ -732,43 +736,41 @@ def verify_study_df_summary_satisfactory(
     suppress_interactive_data_checks,
 ):
 
-    in_study_df = study_df[study_df[in_study_col_name] == 1]
-    num_users = in_study_df[user_id_col_name].nunique()
-    num_non_initial_or_fallback_policies = in_study_df[
-        in_study_df[policy_num_col_name] > 0
+    active_df = analysis_df[analysis_df[active_col_name] == 1]
+    num_subjects = active_df[subject_id_col_name].nunique()
+    num_non_initial_or_fallback_policies = active_df[
+        active_df[policy_num_col_name] > 0
     ][policy_num_col_name].nunique()
     num_decision_times_with_fallback_policies = len(
-        in_study_df[in_study_df[policy_num_col_name] < 0]
+        active_df[active_df[policy_num_col_name] < 0]
     )
-    num_decision_times = in_study_df[calendar_t_col_name].nunique()
-    avg_decisions_per_user = len(in_study_df) / num_users
+    num_decision_times = active_df[calendar_t_col_name].nunique()
+    avg_decisions_per_subject = len(active_df) / num_subjects
     num_decision_times_with_multiple_policies = (
-        in_study_df[in_study_df[policy_num_col_name] >= 0]
+        active_df[active_df[policy_num_col_name] >= 0]
         .groupby(calendar_t_col_name)[policy_num_col_name]
         .nunique()
         > 1
     ).sum()
-    min_action_prob = in_study_df[action_prob_col_name].min()
-    max_action_prob = in_study_df[action_prob_col_name].max()
-    min_non_fallback_policy_num = in_study_df[in_study_df[policy_num_col_name] >= 0][
+    min_action_prob = active_df[action_prob_col_name].min()
+    max_action_prob = active_df[action_prob_col_name].max()
+    min_non_fallback_policy_num = active_df[active_df[policy_num_col_name] >= 0][
         policy_num_col_name
     ].min()
     num_data_points_before_first_update = len(
-        in_study_df[in_study_df[policy_num_col_name] == min_non_fallback_policy_num]
+        active_df[active_df[policy_num_col_name] == min_non_fallback_policy_num]
     )
 
     median_action_probabilities = (
-        in_study_df.groupby(calendar_t_col_name)[action_prob_col_name]
-        .median()
-        .to_numpy()
+        active_df.groupby(calendar_t_col_name)[action_prob_col_name].median().to_numpy()
     )
-    quartiles = in_study_df.groupby(calendar_t_col_name)[action_prob_col_name].quantile(
+    quartiles = active_df.groupby(calendar_t_col_name)[action_prob_col_name].quantile(
         [0.25, 0.75]
     )
     q25_action_probabilities = quartiles.xs(0.25, level=1).to_numpy()
     q75_action_probabilities = quartiles.xs(0.75, level=1).to_numpy()
 
-    avg_rewards = in_study_df.groupby(calendar_t_col_name)[reward_col_name].mean()
+    avg_rewards = active_df.groupby(calendar_t_col_name)[reward_col_name].mean()
 
     # Plot action probability quartile trajectories
     plt.clear_figure()
@@ -808,10 +810,10 @@ def verify_study_df_summary_satisfactory(
 
     confirm_input_check_result(
         f"\nYou provided a study dataframe reflecting a study with"
-        f"\n* {num_users} users"
+        f"\n* {num_subjects} subjects"
         f"\n* {num_non_initial_or_fallback_policies} policy updates"
-        f"\n* {num_decision_times} decision times, for an average of {avg_decisions_per_user}"
-        f" decisions per user"
+        f"\n* {num_decision_times} decision times, for an average of {avg_decisions_per_subject}"
+        f" decisions per subject"
         f"\n* RL parameters of dimension {beta_dim} per update"
         f"\n* Inferential target of dimension {theta_dim}"
         f"\n* {num_data_points_before_first_update} data points before the first update"
@@ -834,14 +836,14 @@ def require_betas_match_in_alg_update_args_each_update(
     alg_update_func_args, alg_update_func_args_beta_index
 ):
     logger.info(
-        "Checking that betas match across users for each update in the algorithm update function args."
+        "Checking that betas match across subjects for each update in the algorithm update function args."
     )
     for policy_num in alg_update_func_args:
         first_beta = None
-        for user_id in alg_update_func_args[policy_num]:
-            if not alg_update_func_args[policy_num][user_id]:
+        for subject_id in alg_update_func_args[policy_num]:
+            if not alg_update_func_args[policy_num][subject_id]:
                 continue
-            beta = alg_update_func_args[policy_num][user_id][
+            beta = alg_update_func_args[policy_num][subject_id][
                 alg_update_func_args_beta_index
             ]
             if first_beta is None:
@@ -849,23 +851,24 @@ def require_betas_match_in_alg_update_args_each_update(
             else:
                 assert np.array_equal(
                     beta, first_beta
-                ), f"Betas do not match across users in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
+                ), f"Betas do not match across subjects in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
+
 
 def require_previous_betas_match_in_alg_update_args_each_update(
     alg_update_func_args, alg_update_func_args_previous_betas_index
 ):
     logger.info(
-        "Checking that previous betas match across users for each update in the algorithm update function args."
+        "Checking that previous betas match across subjects for each update in the algorithm update function args."
     )
     if alg_update_func_args_previous_betas_index < 0:
-        return
-
+        return
+
     for policy_num in alg_update_func_args:
         first_previous_betas = None
-        for user_id in alg_update_func_args[policy_num]:
-            if not alg_update_func_args[policy_num][user_id]:
+        for subject_id in alg_update_func_args[policy_num]:
+            if not alg_update_func_args[policy_num][subject_id]:
                 continue
-            previous_betas = alg_update_func_args[policy_num][user_id][
+            previous_betas = alg_update_func_args[policy_num][subject_id][
                 alg_update_func_args_previous_betas_index
             ]
             if first_previous_betas is None:
@@ -873,21 +876,21 @@ def require_previous_betas_match_in_alg_update_args_each_update(
             else:
                 assert np.array_equal(
                     previous_betas, first_previous_betas
-                ), f"Previous betas do not match across users in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
+                ), f"Previous betas do not match across subjects in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
 
 
 def require_betas_match_in_action_prob_func_args_each_decision(
     action_prob_func_args, action_prob_func_args_beta_index
 ):
     logger.info(
-        "Checking that betas match across users for each decision time in the action prob args."
+        "Checking that betas match across subjects for each decision time in the action prob args."
     )
     for decision_time in action_prob_func_args:
         first_beta = None
-        for user_id in action_prob_func_args[decision_time]:
-            if not action_prob_func_args[decision_time][user_id]:
+        for subject_id in action_prob_func_args[decision_time]:
+            if not action_prob_func_args[decision_time][subject_id]:
                 continue
-            beta = action_prob_func_args[decision_time][user_id][
+            beta = action_prob_func_args[decision_time][subject_id][
                 action_prob_func_args_beta_index
             ]
             if first_beta is None:
@@ -895,11 +898,11 @@ def require_betas_match_in_action_prob_func_args_each_decision(
             else:
                 assert np.array_equal(
                     beta, first_beta
-                ), f"Betas do not match across users in the action prob args for decision_time {decision_time}. Please see the contract for details."
+                ), f"Betas do not match across subjects in the action prob args for decision_time {decision_time}. Please see the contract for details."
 
 
 def require_valid_action_prob_times_given_if_index_supplied(
-    study_df,
+    analysis_df,
     calendar_t_col_name,
     alg_update_func_args,
     alg_update_func_args_action_prob_times_index,
@@ -909,19 +912,19 @@ def require_valid_action_prob_times_given_if_index_supplied(
     if alg_update_func_args_action_prob_times_index < 0:
         return
 
-    min_time = study_df[calendar_t_col_name].min()
-    max_time = study_df[calendar_t_col_name].max()
-    for policy_idx, args_by_user in alg_update_func_args.items():
-        for user_id, args in args_by_user.items():
+    min_time = analysis_df[calendar_t_col_name].min()
+    max_time = analysis_df[calendar_t_col_name].max()
+    for policy_idx, args_by_subject in alg_update_func_args.items():
+        for subject_id, args in args_by_subject.items():
             if not args:
                 continue
             times = args[alg_update_func_args_action_prob_times_index]
             assert (
                 times[i] > times[i - 1] for i in range(1, len(times))
-            ), f"Non-strictly-increasing times were given for action probabilities in the algorithm update function args for user {user_id} and policy {policy_idx}. Please see the contract for details."
+            ), f"Non-strictly-increasing times were given for action probabilities in the algorithm update function args for subject {subject_id} and policy {policy_idx}. Please see the contract for details."
             assert (
                 times[0] >= min_time and times[-1] <= max_time
-            ), f"Times not present in the study were given for action probabilities in the algorithm update function args. The min and max times in the study dataframe are {min_time} and {max_time}, while user {user_id} has times {times} supplied for policy {policy_idx}. Please see the contract for details."
+            ), f"Times not present in the study were given for action probabilities in the algorithm update function args. The min and max times in the study dataframe are {min_time} and {max_time}, while subject {subject_id} has times {times} supplied for policy {policy_idx}. Please see the contract for details."
 
 
 def require_estimating_functions_sum_to_zero(
@@ -933,15 +936,15 @@ def require_estimating_functions_sum_to_zero(
     """
     This is a test that the correct loss/estimating functions have
     been given for both the algorithm updates and inference. If that is true, then the
-    loss/estimating functions when evaluated should sum to approximately zero across users. These
-    values have been stacked and averaged across users in mean_estimating_function_stack, which
+    loss/estimating functions when evaluated should sum to approximately zero across subjects. These
+    values have been stacked and averaged across subjects in mean_estimating_function_stack, which
     we simply compare to the zero vector. We can isolate components for each update and inference
     by considering the dimensions of the beta vectors and the theta vector.
 
     Inputs:
         mean_estimating_function_stack:
             The mean of the estimating function stack (a component for each algorithm update and
-            inference) across users. This should be a 1D array.
+            inference) across subjects. This should be a 1D array.
         beta_dim:
             The dimension of the beta vectors that parameterize the algorithm.
         theta_dim:
@@ -951,7 +954,7 @@ def require_estimating_functions_sum_to_zero(
         None
     """
 
-    logger.info("Checking that estimating functions average to zero across users")
+    logger.info("Checking that estimating functions average to zero across subjects")
 
     # Have a looser hard failure cutoff before the typical interactive check
     try:
@@ -962,7 +965,7 @@ def require_estimating_functions_sum_to_zero(
         )
     except AssertionError as e:
         logger.info(
-            "Estimating function stacks do not average to within loose tolerance of zero across users. Drilling in to specific updates and inference component."
+            "Estimating function stacks do not average to within loose tolerance of zero across subjects. Drilling in to specific updates and inference component."
         )
         # If this is not true there is an internal problem in the package.
         assert (mean_estimating_function_stack.size - theta_dim) % beta_dim == 0
@@ -987,11 +990,11 @@ def require_estimating_functions_sum_to_zero(
         np.testing.assert_allclose(
             mean_estimating_function_stack,
             jnp.zeros(mean_estimating_function_stack.size),
-            atol=
+            atol=5e-4,
         )
     except AssertionError as e:
         logger.info(
-            "Estimating function stacks do not average to within specified tolerance of zero across users. Drilling in to specific updates and inference component."
+            "Estimating function stacks do not average to within specified tolerance of zero across subjects. Drilling in to specific updates and inference component."
         )
         # If this is not true there is an internal problem in the package.
         assert (mean_estimating_function_stack.size - theta_dim) % beta_dim == 0
@@ -1021,15 +1024,15 @@ def require_RL_estimating_functions_sum_to_zero(
     """
     This is a test that the correct loss/estimating functions have
     been given for both the algorithm updates and inference. If that is true, then the
-    loss/estimating functions when evaluated should sum to approximately zero across users. These
-    values have been stacked and averaged across users in mean_estimating_function_stack, which
+    loss/estimating functions when evaluated should sum to approximately zero across subjects. These
+    values have been stacked and averaged across subjects in mean_estimating_function_stack, which
     we simply compare to the zero vector. We can isolate components for each update and inference
     by considering the dimensions of the beta vectors and the theta vector.
 
     Inputs:
         mean_estimating_function_stack:
             The mean of the estimating function stack (a component for each algorithm update and
-            inference) across users. This should be a 1D array.
+            inference) across subjects. This should be a 1D array.
         beta_dim:
             The dimension of the beta vectors that parameterize the algorithm.
         theta_dim:
@@ -1039,7 +1042,7 @@ def require_RL_estimating_functions_sum_to_zero(
         None
     """
 
-    logger.info("Checking that RL estimating functions average to zero across users")
+    logger.info("Checking that RL estimating functions average to zero across subjects")
 
     # Have a looser hard failure cutoff before the typical interactive check
     try:
@@ -1050,7 +1053,7 @@ def require_RL_estimating_functions_sum_to_zero(
         )
     except AssertionError as e:
         logger.info(
-            "RL estimating function stacks do not average to zero across users. Drilling in to specific updates and inference component."
+            "RL estimating function stacks do not average to zero across subjects. Drilling in to specific updates and inference component."
        )
         num_updates = (mean_estimating_function_stack.size) // beta_dim
         for i in range(num_updates):
@@ -1070,7 +1073,7 @@ def require_RL_estimating_functions_sum_to_zero(
         )
     except AssertionError as e:
         logger.info(
-            "RL estimating function stacks do not average to zero across users. Drilling in to specific updates and inference component."
+            "RL estimating function stacks do not average to zero across subjects. Drilling in to specific updates and inference component."
         )
         num_updates = (mean_estimating_function_stack.size) // beta_dim
         for i in range(num_updates):
@@ -1079,7 +1082,6 @@ def require_RL_estimating_functions_sum_to_zero(
                 i + 1,
                 mean_estimating_function_stack[i * beta_dim : (i + 1) * beta_dim],
             )
-        # TODO: Email instead of requiring user input for monitoring alg.
         confirm_input_check_result(
             f"\nEstimating functions do not average to within default tolerance of zero vector. Please decide if the following is a reasonable result, taking into account the above breakdown by update number and inference. If not, there are several possible reasons for failure mentioned in the contract. Results:\n{str(e)}\n\nContinue? (y/n)\n",
             suppress_interactive_data_checks,
@@ -1133,8 +1135,8 @@ def require_adaptive_bread_inverse_is_true_inverse(
 
 def require_threaded_algorithm_estimating_function_args_equivalent(
     algorithm_estimating_func,
-    update_func_args_by_by_user_id_by_policy_num,
-    threaded_update_func_args_by_policy_num_by_user_id,
+    update_func_args_by_by_subject_id_by_policy_num,
+    threaded_update_func_args_by_policy_num_by_subject_id,
     suppress_interactive_data_checks,
 ):
     """
@@ -1144,12 +1146,12 @@ def require_threaded_algorithm_estimating_function_args_equivalent(
     """
     for (
         policy_num,
-        update_func_args_by_user_id,
-    ) in update_func_args_by_by_user_id_by_policy_num.items():
+        update_func_args_by_subject_id,
+    ) in update_func_args_by_by_subject_id_by_policy_num.items():
         for (
-            user_id,
+            subject_id,
             unthreaded_args,
-        ) in update_func_args_by_user_id.items():
+        ) in update_func_args_by_subject_id.items():
             if not unthreaded_args:
                 continue
             np.testing.assert_allclose(
@@ -1157,9 +1159,9 @@ def require_threaded_algorithm_estimating_function_args_equivalent(
                 # Need to stop gradient here because we can't convert a traced value to np array
                 jax.lax.stop_gradient(
                     algorithm_estimating_func(
-                        *threaded_update_func_args_by_policy_num_by_user_id[user_id][
-                            policy_num
-                        ]
+                        *threaded_update_func_args_by_policy_num_by_subject_id[
+                            subject_id
+                        ][policy_num]
                     )
                 ),
                 atol=1e-7,
@@ -1169,8 +1171,8 @@ def require_threaded_algorithm_estimating_function_args_equivalent(
 
 def require_threaded_inference_estimating_function_args_equivalent(
     inference_estimating_func,
-    inference_func_args_by_user_id,
-    threaded_inference_func_args_by_user_id,
+    inference_func_args_by_subject_id,
+    threaded_inference_func_args_by_subject_id,
     suppress_interactive_data_checks,
 ):
     """
@@ -1178,7 +1180,7 @@ def require_threaded_inference_estimating_function_args_equivalent(
     when called with the original arguments and when called with the
     reconstructed action probabilities substituted in.
     """
-    for user_id, unthreaded_args in inference_func_args_by_user_id.items():
+    for subject_id, unthreaded_args in inference_func_args_by_subject_id.items():
         if not unthreaded_args:
             continue
         np.testing.assert_allclose(
@@ -1186,7 +1188,7 @@ def require_threaded_inference_estimating_function_args_equivalent(
             # Need to stop gradient here because we can't convert a traced value to np array
             jax.lax.stop_gradient(
                 inference_estimating_func(
-                    *threaded_inference_func_args_by_user_id[user_id]
+                    *threaded_inference_func_args_by_subject_id[subject_id]
                 )
             ),
             rtol=1e-2,