lifejacket-0.2.1-py3-none-any.whl → lifejacket-1.0.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/arg_threading_helpers.py +75 -69
- lifejacket/calculate_derivatives.py +19 -23
- lifejacket/constants.py +4 -16
- lifejacket/{trial_conditioning_monitor.py → deployment_conditioning_monitor.py} +163 -138
- lifejacket/{form_adaptive_meat_adjustments_directly.py → form_adjusted_meat_adjustments_directly.py} +32 -34
- lifejacket/get_datum_for_blowup_supervised_learning.py +341 -339
- lifejacket/helper_functions.py +60 -186
- lifejacket/input_checks.py +303 -302
- lifejacket/{after_study_analysis.py → post_deployment_analysis.py} +470 -457
- lifejacket/small_sample_corrections.py +49 -49
- lifejacket-1.0.2.dist-info/METADATA +56 -0
- lifejacket-1.0.2.dist-info/RECORD +17 -0
- lifejacket-1.0.2.dist-info/entry_points.txt +2 -0
- lifejacket-0.2.1.dist-info/METADATA +0 -100
- lifejacket-0.2.1.dist-info/RECORD +0 -17
- lifejacket-0.2.1.dist-info/entry_points.txt +0 -2
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.2.dist-info}/WHEEL +0 -0
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.2.dist-info}/top_level.txt +0 -0
lifejacket/input_checks.py
CHANGED
@@ -8,7 +8,7 @@ from jax import numpy as jnp
 import pandas as pd
 import plotext as plt
 
-from .constants import InverseStabilizationMethods, SmallSampleCorrections
+from .constants import SmallSampleCorrections
 from .helper_functions import (
     confirm_input_check_result,
 )
@@ -26,12 +26,12 @@ logging.basicConfig(
 
 # TODO: any checks needed here about alg update function type?
 def perform_first_wave_input_checks(
-    study_df,
-    in_study_col_name,
+    analysis_df,
+    active_col_name,
     action_col_name,
     policy_num_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     action_prob_col_name,
     reward_col_name,
     action_prob_func,
@@ -48,11 +48,11 @@ def perform_first_wave_input_checks(
     small_sample_correction,
 ):
     ### Validate algorithm loss/estimating function and args
-    require_alg_update_args_given_for_all_users_at_each_update(
-        study_df, user_id_col_name, alg_update_func_args
+    require_alg_update_args_given_for_all_subjects_at_each_update(
+        analysis_df, subject_id_col_name, alg_update_func_args
     )
-    require_no_policy_numbers_present_in_alg_update_args_but_not_study_df(
-        study_df, policy_num_col_name, alg_update_func_args
+    require_no_policy_numbers_present_in_alg_update_args_but_not_analysis_df(
+        analysis_df, policy_num_col_name, alg_update_func_args
     )
     require_beta_is_1D_array_in_alg_update_args(
         alg_update_func_args, alg_update_func_args_beta_index
@@ -60,11 +60,13 @@ def perform_first_wave_input_checks(
     require_previous_betas_is_2D_array_in_alg_update_args(
         alg_update_func_args, alg_update_func_args_previous_betas_index
     )
-    require_all_policy_numbers_in_study_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
-        study_df, in_study_col_name, policy_num_col_name, alg_update_func_args
+    require_all_policy_numbers_in_analysis_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
+        analysis_df, active_col_name, policy_num_col_name, alg_update_func_args
     )
     confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
-        alg_update_func_args_action_prob_index,
+        alg_update_func_args_action_prob_index,
+        alg_update_func_args_previous_betas_index,
+        suppress_interactive_data_checks,
     )
     require_action_prob_times_given_if_index_supplied(
         alg_update_func_args_action_prob_index,
@@ -80,17 +82,17 @@ def perform_first_wave_input_checks(
     require_previous_betas_match_in_alg_update_args_each_update(
         alg_update_func_args, alg_update_func_args_previous_betas_index
     )
-    require_action_prob_args_in_alg_update_func_correspond_to_study_df(
-        study_df,
+    require_action_prob_args_in_alg_update_func_correspond_to_analysis_df(
+        analysis_df,
         action_prob_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         alg_update_func_args,
         alg_update_func_args_action_prob_index,
         alg_update_func_args_action_prob_times_index,
     )
     require_valid_action_prob_times_given_if_index_supplied(
-        study_df,
+        analysis_df,
         calendar_t_col_name,
         alg_update_func_args,
         alg_update_func_args_action_prob_times_index,
@@ -101,28 +103,28 @@ def perform_first_wave_input_checks(
     )
 
     ### Validate action prob function and args
-    require_action_prob_func_args_given_for_all_users_at_each_decision(
-        study_df, user_id_col_name, action_prob_func_args
+    require_action_prob_func_args_given_for_all_subjects_at_each_decision(
+        analysis_df, subject_id_col_name, action_prob_func_args
     )
     require_action_prob_func_args_given_for_all_decision_times(
-        study_df, calendar_t_col_name, action_prob_func_args
+        analysis_df, calendar_t_col_name, action_prob_func_args
     )
-    require_action_probabilities_in_study_df_can_be_reconstructed(
-        study_df,
+    require_action_probabilities_in_analysis_df_can_be_reconstructed(
+        analysis_df,
         action_prob_col_name,
         calendar_t_col_name,
-        user_id_col_name,
-        in_study_col_name,
+        subject_id_col_name,
+        active_col_name,
         action_prob_func_args,
         action_prob_func,
     )
 
     require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times(
-        study_df,
+        analysis_df,
         calendar_t_col_name,
         action_prob_func_args,
-        in_study_col_name,
-        user_id_col_name,
+        active_col_name,
+        subject_id_col_name,
     )
     require_beta_is_1D_array_in_action_prob_args(
         action_prob_func_args, action_prob_func_args_beta_index
@@ -131,13 +133,13 @@ def perform_first_wave_input_checks(
         action_prob_func_args, action_prob_func_args_beta_index
     )
 
-    ### Validate study_df
-    verify_study_df_summary_satisfactory(
-        study_df,
-        user_id_col_name,
+    ### Validate analysis_df
+    verify_analysis_df_summary_satisfactory(
+        analysis_df,
+        subject_id_col_name,
         policy_num_col_name,
         calendar_t_col_name,
-        in_study_col_name,
+        active_col_name,
         action_prob_col_name,
         reward_col_name,
         beta_dim,
@@ -145,46 +147,46 @@ def perform_first_wave_input_checks(
         suppress_interactive_data_checks,
     )
 
-    require_all_users_have_all_times_in_study_df(
-        study_df, calendar_t_col_name, user_id_col_name
+    require_all_subjects_have_all_times_in_analysis_df(
+        analysis_df, calendar_t_col_name, subject_id_col_name
     )
-    require_all_named_columns_present_in_study_df(
-        study_df,
-        in_study_col_name,
+    require_all_named_columns_present_in_analysis_df(
+        analysis_df,
+        active_col_name,
         action_col_name,
         policy_num_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         action_prob_col_name,
     )
-    require_all_named_columns_not_object_type_in_study_df(
-        study_df,
-        in_study_col_name,
+    require_all_named_columns_not_object_type_in_analysis_df(
+        analysis_df,
+        active_col_name,
         action_col_name,
         policy_num_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         action_prob_col_name,
     )
-    require_binary_actions(study_df, in_study_col_name, action_col_name)
-    require_binary_in_study_indicators(study_df, in_study_col_name)
+    require_binary_actions(analysis_df, active_col_name, action_col_name)
+    require_binary_active_indicators(analysis_df, active_col_name)
     require_consecutive_integer_policy_numbers(
-        study_df, in_study_col_name, policy_num_col_name
+        analysis_df, active_col_name, policy_num_col_name
     )
-    require_consecutive_integer_calendar_times(study_df, calendar_t_col_name)
-    require_hashable_user_ids(study_df, in_study_col_name, user_id_col_name)
-    require_action_probabilities_in_range_0_to_1(study_df, action_prob_col_name)
+    require_consecutive_integer_calendar_times(analysis_df, calendar_t_col_name)
+    require_hashable_subject_ids(analysis_df, active_col_name, subject_id_col_name)
+    require_action_probabilities_in_range_0_to_1(analysis_df, action_prob_col_name)
 
     ### Validate theta estimation
     require_theta_is_1D_array(theta_est)
 
 
 def perform_alg_only_input_checks(
-    study_df,
-    in_study_col_name,
+    analysis_df,
+    active_col_name,
     policy_num_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     action_prob_col_name,
     action_prob_func,
     action_prob_func_args,
@@ -193,20 +195,23 @@ def perform_alg_only_input_checks(
     alg_update_func_args_beta_index,
     alg_update_func_args_action_prob_index,
     alg_update_func_args_action_prob_times_index,
+    alg_update_func_args_previous_betas_index,
     suppress_interactive_data_checks,
 ):
     ### Validate algorithm loss/estimating function and args
-    require_alg_update_args_given_for_all_users_at_each_update(
-        study_df, user_id_col_name, alg_update_func_args
+    require_alg_update_args_given_for_all_subjects_at_each_update(
+        analysis_df, subject_id_col_name, alg_update_func_args
     )
     require_beta_is_1D_array_in_alg_update_args(
         alg_update_func_args, alg_update_func_args_beta_index
     )
-    require_all_policy_numbers_in_study_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
-        study_df, in_study_col_name, policy_num_col_name, alg_update_func_args
+    require_all_policy_numbers_in_analysis_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
+        analysis_df, active_col_name, policy_num_col_name, alg_update_func_args
     )
     confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
-        alg_update_func_args_action_prob_index,
+        alg_update_func_args_action_prob_index,
+        alg_update_func_args_previous_betas_index,
+        suppress_interactive_data_checks,
     )
     require_action_prob_times_given_if_index_supplied(
         alg_update_func_args_action_prob_index,
@@ -219,45 +224,45 @@ def perform_alg_only_input_checks(
     require_betas_match_in_alg_update_args_each_update(
         alg_update_func_args, alg_update_func_args_beta_index
     )
-    require_action_prob_args_in_alg_update_func_correspond_to_study_df(
-        study_df,
+    require_action_prob_args_in_alg_update_func_correspond_to_analysis_df(
+        analysis_df,
         action_prob_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         alg_update_func_args,
         alg_update_func_args_action_prob_index,
         alg_update_func_args_action_prob_times_index,
     )
     require_valid_action_prob_times_given_if_index_supplied(
-        study_df,
+        analysis_df,
         calendar_t_col_name,
         alg_update_func_args,
         alg_update_func_args_action_prob_times_index,
     )
 
     ### Validate action prob function and args
-    require_action_prob_func_args_given_for_all_users_at_each_decision(
-        study_df, user_id_col_name, action_prob_func_args
+    require_action_prob_func_args_given_for_all_subjects_at_each_decision(
+        analysis_df, subject_id_col_name, action_prob_func_args
     )
     require_action_prob_func_args_given_for_all_decision_times(
-        study_df, calendar_t_col_name, action_prob_func_args
+        analysis_df, calendar_t_col_name, action_prob_func_args
     )
-    require_action_probabilities_in_study_df_can_be_reconstructed(
-        study_df,
+    require_action_probabilities_in_analysis_df_can_be_reconstructed(
+        analysis_df,
         action_prob_col_name,
         calendar_t_col_name,
-        user_id_col_name,
-        in_study_col_name,
+        subject_id_col_name,
+        active_col_name,
         action_prob_func_args,
         action_prob_func=action_prob_func,
     )
 
     require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times(
-        study_df,
+        analysis_df,
         calendar_t_col_name,
         action_prob_func_args,
-        in_study_col_name,
-        user_id_col_name,
+        active_col_name,
+        subject_id_col_name,
     )
     require_beta_is_1D_array_in_action_prob_args(
         action_prob_func_args, action_prob_func_args_beta_index
@@ -268,94 +273,96 @@ def perform_alg_only_input_checks(
 
 
 # TODO: Give a hard-to-use option to loosen this check somehow
-def require_action_probabilities_in_study_df_can_be_reconstructed(
-    study_df,
+def require_action_probabilities_in_analysis_df_can_be_reconstructed(
+    analysis_df,
     action_prob_col_name,
     calendar_t_col_name,
-    user_id_col_name,
-    in_study_col_name,
+    subject_id_col_name,
+    active_col_name,
     action_prob_func_args,
     action_prob_func,
 ):
     """
-    Check that the action probabilities in the study df can be reconstructed from the supplied
+    Check that the action probabilities in the analysis DataFrame can be reconstructed from the supplied
     action probability function and its arguments.
 
     NOTE THAT THIS IS A HARD FAILURE IF THE RECONSTRUCTION DOESN'T PASS.
     """
     logger.info("Reconstructing action probabilities from function and arguments.")
 
-    in_study_df = study_df[study_df[in_study_col_name] == 1]
-    reconstructed_action_probs = in_study_df.apply(
+    active_df = analysis_df[analysis_df[active_col_name] == 1]
+    reconstructed_action_probs = active_df.apply(
         lambda row: action_prob_func(
-            *action_prob_func_args[row[calendar_t_col_name]][row[user_id_col_name]]
+            *action_prob_func_args[row[calendar_t_col_name]][row[subject_id_col_name]]
        ),
        axis=1,
    )
 
     np.testing.assert_allclose(
-        in_study_df[action_prob_col_name].to_numpy(dtype="float64"),
+        active_df[action_prob_col_name].to_numpy(dtype="float64"),
         reconstructed_action_probs.to_numpy(dtype="float64"),
         atol=1e-6,
     )
 
 
-def require_all_users_have_all_times_in_study_df(
-    study_df, calendar_t_col_name, user_id_col_name
+def require_all_subjects_have_all_times_in_analysis_df(
+    analysis_df, calendar_t_col_name, subject_id_col_name
 ):
-    logger.info("Checking that all users have the same set of unique calendar times.")
+    logger.info(
+        "Checking that all subjects have the same set of unique calendar times."
+    )
     # Get the unique calendar times
-    unique_calendar_times = set(study_df[calendar_t_col_name].unique())
+    unique_calendar_times = set(analysis_df[calendar_t_col_name].unique())
 
-    # Group by user ID and aggregate the unique calendar times for each user
-    user_calendar_times = study_df.groupby(user_id_col_name)[calendar_t_col_name].apply(
-        set
-    )
+    # Group by subject ID and aggregate the unique calendar times for each subject
+    subject_calendar_times = analysis_df.groupby(subject_id_col_name)[
+        calendar_t_col_name
+    ].apply(set)
 
-    # Check if all users have the same set of unique calendar times
-    if not user_calendar_times.apply(lambda x: x == unique_calendar_times).all():
+    # Check if all subjects have the same set of unique calendar times
+    if not subject_calendar_times.apply(lambda x: x == unique_calendar_times).all():
         raise AssertionError(
-            "Not all users have all calendar times in the study df. Please see the contract for details."
+            "Not all subjects have all calendar times in the analysis DataFrame. Please see the contract for details."
         )
 
 
-def require_alg_update_args_given_for_all_users_at_each_update(
-    study_df, user_id_col_name, alg_update_func_args
+def require_alg_update_args_given_for_all_subjects_at_each_update(
+    analysis_df, subject_id_col_name, alg_update_func_args
 ):
     logger.info(
-        "Checking that algorithm update function args are given for all users at each update."
+        "Checking that algorithm update function args are given for all subjects at each update."
     )
-    all_user_ids = set(study_df[user_id_col_name].unique())
+    all_subject_ids = set(analysis_df[subject_id_col_name].unique())
     for policy_num in alg_update_func_args:
         assert (
-            set(alg_update_func_args[policy_num].keys()) == all_user_ids
-        ), f"Not all users present in algorithm update function args for policy number {policy_num}. Please see the contract for details."
+            set(alg_update_func_args[policy_num].keys()) == all_subject_ids
+        ), f"Not all subjects present in algorithm update function args for policy number {policy_num}. Please see the contract for details."
 
 
-def require_action_prob_args_in_alg_update_func_correspond_to_study_df(
-    study_df,
+def require_action_prob_args_in_alg_update_func_correspond_to_analysis_df(
+    analysis_df,
     action_prob_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     alg_update_func_args,
     alg_update_func_args_action_prob_index,
     alg_update_func_args_action_prob_times_index,
 ):
     logger.info(
         "Checking that the action probabilities supplied in the algorithm update function args, if"
-        " any, correspond to those in the study df for the corresponding users and decision"
+        " any, correspond to those in the analysis DataFrame for the corresponding subjects and decision"
         " times."
     )
     if alg_update_func_args_action_prob_index < 0:
         return
 
     # Precompute a lookup dictionary for faster access
-    study_df_lookup = study_df.set_index([calendar_t_col_name, user_id_col_name])[
-        action_prob_col_name
-    ].to_dict()
+    analysis_df_lookup = analysis_df.set_index(
+        [calendar_t_col_name, subject_id_col_name]
+    )[action_prob_col_name].to_dict()
 
-    for policy_num, user_args in alg_update_func_args.items():
-        for user_id, args in user_args.items():
+    for policy_num, subject_args in alg_update_func_args.items():
+        for subject_id, args in subject_args.items():
             if not args:
                 continue
             arg_action_probs = args[alg_update_func_args_action_prob_index]
@@ -364,43 +371,43 @@ def require_action_prob_args_in_alg_update_func_correspond_to_study_df(
             ].flatten()
 
             # Use the precomputed lookup dictionary
-            study_df_action_probs = [
-                study_df_lookup[(decision_time.item(), user_id)]
+            analysis_df_action_probs = [
+                analysis_df_lookup[(decision_time.item(), subject_id)]
                 for decision_time in action_prob_times
             ]
 
             assert np.allclose(
                 arg_action_probs.flatten(),
-                study_df_action_probs,
+                analysis_df_action_probs,
             ), (
-                f"There is a mismatch for user {user_id} between the action probabilities supplied"
+                f"There is a mismatch for subject {subject_id} between the action probabilities supplied"
                 f" in the args to the algorithm update function at policy {policy_num} and those in"
-                " the study df for the supplied times. Please see the contract for details."
+                " the analysis DataFrame for the supplied times. Please see the contract for details."
             )
 
 
-def require_action_prob_func_args_given_for_all_users_at_each_decision(
-    study_df,
-    user_id_col_name,
+def require_action_prob_func_args_given_for_all_subjects_at_each_decision(
+    analysis_df,
+    subject_id_col_name,
     action_prob_func_args,
 ):
     logger.info(
-        "Checking that action prob function args are given for all users at each decision time."
+        "Checking that action prob function args are given for all subjects at each decision time."
     )
-    all_user_ids = set(study_df[user_id_col_name].unique())
+    all_subject_ids = set(analysis_df[subject_id_col_name].unique())
     for decision_time in action_prob_func_args:
         assert (
-            set(action_prob_func_args[decision_time].keys()) == all_user_ids
-        ), f"Not all users present in algorithm update function args for decision time {decision_time}. Please see the contract for details."
+            set(action_prob_func_args[decision_time].keys()) == all_subject_ids
+        ), f"Not all subjects present in algorithm update function args for decision time {decision_time}. Please see the contract for details."
 
 
 def require_action_prob_func_args_given_for_all_decision_times(
-    study_df, calendar_t_col_name, action_prob_func_args
+    analysis_df, calendar_t_col_name, action_prob_func_args
 ):
     logger.info(
         "Checking that action prob function args are given for all decision times."
     )
-    all_times = set(study_df[calendar_t_col_name].unique())
+    all_times = set(analysis_df[calendar_t_col_name].unique())
 
     assert (
         set(action_prob_func_args.keys()) == all_times
@@ -408,106 +415,112 @@ def require_action_prob_func_args_given_for_all_decision_times(
 
 
 def require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times(
-    study_df: pd.DataFrame,
+    analysis_df: pd.DataFrame,
     calendar_t_col_name: str,
     action_prob_func_args: dict[str, dict[str, tuple[Any, ...]]],
-    in_study_col_name,
-    user_id_col_name,
+    active_col_name,
+    subject_id_col_name,
 ):
     logger.info(
-        "Checking that action probability function args are blank for exactly the times each user"
-        "is not in the study according to the study df."
+        "Checking that action probability function args are blank for exactly the times each subject"
+        "is not in the study according to the analysis DataFrame."
     )
-    out_of_study_df = study_df[study_df[in_study_col_name] == 0]
-    out_of_study_times_by_user_according_to_study_df = (
-        out_of_study_df.groupby(user_id_col_name)[calendar_t_col_name]
+    inactive_df = analysis_df[analysis_df[active_col_name] == 0]
+    inactive_times_by_subject_according_to_analysis_df = (
+        inactive_df.groupby(subject_id_col_name)[calendar_t_col_name]
         .apply(set)
         .to_dict()
     )
 
-    out_of_study_times_by_user_according_to_action_prob_func_args = (
+    inactive_times_by_subject_according_to_action_prob_func_args = (
         collections.defaultdict(set)
     )
-    for decision_time, action_prob_args_by_user in action_prob_func_args.items():
-        for user_id, action_prob_args in action_prob_args_by_user.items():
+    for decision_time, action_prob_args_by_subject in action_prob_func_args.items():
+        for subject_id, action_prob_args in action_prob_args_by_subject.items():
             if not action_prob_args:
-                out_of_study_times_by_user_according_to_action_prob_func_args[
-                    user_id
+                inactive_times_by_subject_according_to_action_prob_func_args[
+                    subject_id
                 ].add(decision_time)
 
     assert (
-        out_of_study_times_by_user_according_to_study_df
-        == out_of_study_times_by_user_according_to_action_prob_func_args
+        inactive_times_by_subject_according_to_analysis_df
+        == inactive_times_by_subject_according_to_action_prob_func_args
    ), (
-        "Out-of-study decision times according to the study df do not match up with the"
-        " times for which action probability arguments are blank for all users. Please see the"
+        "Inactive decision times according to the analysis DataFrame do not match up with the"
+        " times for which action probability arguments are blank for all subjects. Please see the"
         " contract for details."
     )
 
 
-def require_all_named_columns_present_in_study_df(
-    study_df,
-    in_study_col_name,
+def require_all_named_columns_present_in_analysis_df(
+    analysis_df,
+    active_col_name,
     action_col_name,
     policy_num_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     action_prob_col_name,
 ):
-    logger.info("Checking that all named columns are present in the study df.")
+    logger.info(
+        "Checking that all named columns are present in the analysis DataFrame."
+    )
+    assert (
+        active_col_name in analysis_df.columns
+    ), f"{active_col_name} not in analysis DataFrame."
+    assert (
+        action_col_name in analysis_df.columns
+    ), f"{action_col_name} not in analysis DataFrame."
     assert (
-        policy_num_col_name in study_df.columns
-    ), f"{policy_num_col_name} not in study df."
-    assert action_col_name in study_df.columns, f"{action_col_name} not in study df."
+        policy_num_col_name in analysis_df.columns
+    ), f"{policy_num_col_name} not in analysis DataFrame."
     assert (
-        calendar_t_col_name in study_df.columns
-    ), f"{calendar_t_col_name} not in study df."
+        calendar_t_col_name in analysis_df.columns
+    ), f"{calendar_t_col_name} not in analysis DataFrame."
     assert (
-        in_study_col_name in study_df.columns
-    ), f"{in_study_col_name} not in study df."
-    assert user_id_col_name in study_df.columns, f"{user_id_col_name} not in study df."
+        subject_id_col_name in analysis_df.columns
+    ), f"{subject_id_col_name} not in analysis DataFrame."
     assert (
-        action_prob_col_name in study_df.columns
-    ), f"{action_prob_col_name} not in study df."
+        action_prob_col_name in analysis_df.columns
+    ), f"{action_prob_col_name} not in analysis DataFrame."
 
 
-def require_all_named_columns_not_object_type_in_study_df(
-    study_df,
-    in_study_col_name,
+def require_all_named_columns_not_object_type_in_analysis_df(
+    analysis_df,
+    active_col_name,
     action_col_name,
     policy_num_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     action_prob_col_name,
 ):
     logger.info("Checking that all named columns are not type object.")
     for colname in (
-        in_study_col_name,
+        active_col_name,
         action_col_name,
         policy_num_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         action_prob_col_name,
     ):
         assert (
-            study_df[colname].dtype != "object"
-        ), f"At least {colname} is of object type in study df."
+            analysis_df[colname].dtype != "object"
+        ), f"At least {colname} is of object type in analysis DataFrame."
 
 
-def require_binary_actions(study_df, in_study_col_name, action_col_name):
+def require_binary_actions(analysis_df, active_col_name, action_col_name):
     logger.info("Checking that actions are binary.")
     assert (
-        study_df[study_df[in_study_col_name] == 1][action_col_name]
+        analysis_df[analysis_df[active_col_name] == 1][action_col_name]
         .astype("int64")
         .isin([0, 1])
         .all()
     ), "Actions are not binary."
 
 
-def require_binary_in_study_indicators(study_df, in_study_col_name):
-    logger.info("Checking that in-study indicators are binary.")
+def require_binary_active_indicators(analysis_df, active_col_name):
+    logger.info("Checking that active indicators are binary.")
     assert (
-        study_df[study_df[in_study_col_name] == 1][in_study_col_name]
+        analysis_df[analysis_df[active_col_name] == 1][active_col_name]
        .astype("int64")
        .isin([0, 1])
        .all()
@@ -515,7 +528,7 @@ def require_binary_in_study_indicators(study_df, in_study_col_name):
 
 
 def require_consecutive_integer_policy_numbers(
-    study_df, in_study_col_name, policy_num_col_name
+    analysis_df, active_col_name, policy_num_col_name
 ):
     # TODO: This is a somewhat rough check of this, could also check nondecreasing temporally
 
@@ -523,8 +536,8 @@ def require_consecutive_integer_policy_numbers(
         "Checking that in-study, non-fallback policy numbers are consecutive integers."
     )
 
-    in_study_df = study_df[study_df[in_study_col_name] == 1]
-    nonnegative_policy_df = in_study_df[in_study_df[policy_num_col_name] >= 0]
+    active_df = analysis_df[analysis_df[active_col_name] == 1]
+    nonnegative_policy_df = active_df[active_df[policy_num_col_name] >= 0]
     # Ideally we actually have integers, but for legacy reasons we will support
     # floats as well.
     if nonnegative_policy_df[policy_num_col_name].dtype == "float64":
@@ -540,78 +553,83 @@ def require_consecutive_integer_policy_numbers(
     ), "Policy numbers are not consecutive integers."
 
 
-def require_consecutive_integer_calendar_times(study_df, calendar_t_col_name):
+def require_consecutive_integer_calendar_times(analysis_df, calendar_t_col_name):
     # This is a somewhat rough check of this, more like checking there are no
-    # gaps in the integers covered. But we have other checks that all users
+    # gaps in the integers covered. But we have other checks that all subjects
     # have same times, etc.
-    # Note these times should be well-formed even when the user is not in the study.
+    # Note these times should be well-formed even when the subject is not in the study.
     logger.info("Checking that calendar times are consecutive integers.")
     assert np.array_equal(
-        study_df[calendar_t_col_name].unique(),
+        analysis_df[calendar_t_col_name].unique(),
         range(
-            study_df[calendar_t_col_name].min(), study_df[calendar_t_col_name].max() + 1
+            analysis_df[calendar_t_col_name].min(),
+            analysis_df[calendar_t_col_name].max() + 1,
         ),
     ), "Calendar times are not consecutive integers."
 
 
-def require_hashable_user_ids(study_df, in_study_col_name, user_id_col_name):
-    logger.info("Checking that user IDs are hashable.")
+def require_hashable_subject_ids(analysis_df, active_col_name, subject_id_col_name):
+    logger.info("Checking that subject IDs are hashable.")
     isinstance(
-        study_df[study_df[in_study_col_name] == 1][user_id_col_name][0],
+        analysis_df[analysis_df[active_col_name] == 1][subject_id_col_name][0],
         collections.abc.Hashable,
     )
 
 
-def require_action_probabilities_in_range_0_to_1(study_df, action_prob_col_name):
+def require_action_probabilities_in_range_0_to_1(analysis_df, action_prob_col_name):
     logger.info("Checking that action probabilities are in the interval (0, 1).")
-    study_df[action_prob_col_name].between(0, 1, inclusive="neither").all()
+    analysis_df[action_prob_col_name].between(0, 1, inclusive="neither").all()
 
 
-def require_no_policy_numbers_present_in_alg_update_args_but_not_study_df(
-    study_df, policy_num_col_name, alg_update_func_args
+def require_no_policy_numbers_present_in_alg_update_args_but_not_analysis_df(
+    analysis_df, policy_num_col_name, alg_update_func_args
 ):
     logger.info(
-        "Checking that policy numbers in algorithm update function args are present in the study df."
+        "Checking that policy numbers in algorithm update function args are present in the analysis DataFrame."
     )
     alg_update_policy_nums = sorted(alg_update_func_args.keys())
-    study_df_policy_nums = sorted(study_df[policy_num_col_name].unique())
-    assert set(alg_update_policy_nums).issubset(set(study_df_policy_nums)), (
-        f"There are policy numbers present in algorithm update function args but not in the study df. "
+    analysis_df_policy_nums = sorted(analysis_df[policy_num_col_name].unique())
+    assert set(alg_update_policy_nums).issubset(set(analysis_df_policy_nums)), (
+        f"There are policy numbers present in algorithm update function args but not in the analysis DataFrame. "
         f"\nalg_update_func_args policy numbers: {alg_update_policy_nums}"
-        f"\nstudy_df policy numbers: {study_df_policy_nums}.\nPlease see the contract for details."
+        f"\nanalysis_df policy numbers: {analysis_df_policy_nums}.\nPlease see the contract for details."
     )
 
 
-def require_all_policy_numbers_in_study_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
-    study_df, in_study_col_name, policy_num_col_name, alg_update_func_args
+def require_all_policy_numbers_in_analysis_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
+    analysis_df, active_col_name, policy_num_col_name, alg_update_func_args
 ):
     logger.info(
-        "Checking that all policy numbers in the study df are present in the algorithm update function args."
+        "Checking that all policy numbers in the analysis DataFrame are present in the algorithm update function args."
     )
-    in_study_df = study_df[study_df[in_study_col_name] == 1]
+    active_df = analysis_df[analysis_df[active_col_name] == 1]
     # Get the number of the initial policy. 0 is recommended but not required.
-    min_nonnegative_policy_number = in_study_df[in_study_df[policy_num_col_name] >= 0][
+    min_nonnegative_policy_number = active_df[active_df[policy_num_col_name] >= 0][
         policy_num_col_name
     ]
     assert set(
-        in_study_df[in_study_df[policy_num_col_name] > min_nonnegative_policy_number][
+        active_df[active_df[policy_num_col_name] > min_nonnegative_policy_number][
            policy_num_col_name
        ].unique()
    ).issubset(
        alg_update_func_args.keys()
-    ), f"There are non-fallback, non-initial policy numbers in the study df that are not in the update function args: {set(in_study_df[in_study_df[policy_num_col_name] > 0][policy_num_col_name].unique()) - set(alg_update_func_args.keys())}. Please see the contract for details."
+    ), f"There are non-fallback, non-initial policy numbers in the analysis DataFrame that are not in the update function args: {set(active_df[active_df[policy_num_col_name] > 0][policy_num_col_name].unique()) - set(alg_update_func_args.keys())}. Please see the contract for details."
 
 
 def confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
     alg_update_func_args_action_prob_index,
+    alg_update_func_args_previous_betas_index,
     suppress_interactive_data_checks,
 ):
     logger.info(
         "Confirming that action probabilities are not in algorithm update function args IF their index is not specified"
     )
-    if alg_update_func_args_action_prob_index < 0:
+    if (
+        alg_update_func_args_action_prob_index < 0
+        and alg_update_func_args_previous_betas_index < 0
+    ):
         confirm_input_check_result(
-            "\nYou specified that the algorithm update function supplied does not have action probabilities in its arguments. Please verify this is correct.\n\nContinue? (y/n)\n",
+            "\nYou specified that the algorithm update function supplied does not have action probabilities or previous betas in its arguments. Please verify this is correct.\n\nContinue? (y/n)\n",
             suppress_interactive_data_checks,
         )
 
@@ -630,20 +648,6 @@ def confirm_no_small_sample_correction_desired_if_not_requested(
     )
 
 
-def confirm_no_adaptive_bread_inverse_stabilization_method_desired_if_not_requested(
-    adaptive_bread_inverse_stabilization_method,
-    suppress_interactive_data_checks,
-):
-    logger.info(
-        "Confirming that no adaptive bread inverse stabilization method is desired if it's not requested."
-    )
-    if adaptive_bread_inverse_stabilization_method == InverseStabilizationMethods.NONE:
-        confirm_input_check_result(
-            "\nYou specified that you would not like to perform any inverse stabilization while forming the adaptive variance. This is not usually recommended. Please verify that it is correct or select one of the available options.\n\nContinue? (y/n)\n",
-            suppress_interactive_data_checks,
-        )
-
-
 def require_action_prob_times_given_if_index_supplied(
     alg_update_func_args_action_prob_index,
     alg_update_func_args_action_prob_times_index,
@@ -672,28 +676,29 @@ def require_beta_is_1D_array_in_alg_update_args(
     alg_update_func_args, alg_update_func_args_beta_index
 ):
     for policy_num in alg_update_func_args:
-        for user_id in alg_update_func_args[policy_num]:
-            if not alg_update_func_args[policy_num][user_id]:
+        for subject_id in alg_update_func_args[policy_num]:
+            if not alg_update_func_args[policy_num][subject_id]:
                 continue
             assert (
-                alg_update_func_args[policy_num][user_id][
+                alg_update_func_args[policy_num][subject_id][
                     alg_update_func_args_beta_index
                 ].ndim
                 == 1
             ), "Beta is not a 1D array in the algorithm update function args."
 
+
 def require_previous_betas_is_2D_array_in_alg_update_args(
     alg_update_func_args, alg_update_func_args_previous_betas_index
 ):
     if alg_update_func_args_previous_betas_index < 0:
         return
-
+
     for policy_num in alg_update_func_args:
-        for user_id in alg_update_func_args[policy_num]:
-            if not alg_update_func_args[policy_num][user_id]:
+        for subject_id in alg_update_func_args[policy_num]:
+            if not alg_update_func_args[policy_num][subject_id]:
                 continue
             assert (
-                alg_update_func_args[policy_num][user_id][
+                alg_update_func_args[policy_num][subject_id][
                     alg_update_func_args_previous_betas_index
                 ].ndim
                 == 2
@@ -704,11 +709,11 @@ def require_beta_is_1D_array_in_action_prob_args(
     action_prob_func_args, action_prob_func_args_beta_index
 ):
     for decision_time in action_prob_func_args:
-        for user_id in action_prob_func_args[decision_time]:
-            if not action_prob_func_args[decision_time][user_id]:
+        for subject_id in action_prob_func_args[decision_time]:
+            if not action_prob_func_args[decision_time][subject_id]:
                 continue
             assert (
-                action_prob_func_args[decision_time][user_id][
+                action_prob_func_args[decision_time][subject_id][
                     action_prob_func_args_beta_index
                 ].ndim
                 == 1
@@ -719,12 +724,12 @@ def require_theta_is_1D_array(theta_est):
     assert theta_est.ndim == 1, "Theta is not a 1D array."
 
 
-def verify_study_df_summary_satisfactory(
-    study_df,
-    user_id_col_name,
+def verify_analysis_df_summary_satisfactory(
+    analysis_df,
+    subject_id_col_name,
     policy_num_col_name,
     calendar_t_col_name,
-    in_study_col_name,
+    active_col_name,
     action_prob_col_name,
     reward_col_name,
     beta_dim,
@@ -732,43 +737,41 @@ def verify_study_df_summary_satisfactory(
     suppress_interactive_data_checks,
 ):
 
-    in_study_df = study_df[study_df[in_study_col_name] == 1]
-    num_users = in_study_df[user_id_col_name].nunique()
-    num_non_initial_or_fallback_policies = in_study_df[
-        in_study_df[policy_num_col_name] > 0
+    active_df = analysis_df[analysis_df[active_col_name] == 1]
+    num_subjects = active_df[subject_id_col_name].nunique()
+    num_non_initial_or_fallback_policies = active_df[
+        active_df[policy_num_col_name] > 0
     ][policy_num_col_name].nunique()
     num_decision_times_with_fallback_policies = len(
-        in_study_df[in_study_df[policy_num_col_name] < 0]
+        active_df[active_df[policy_num_col_name] < 0]
     )
-    num_decision_times = in_study_df[calendar_t_col_name].nunique()
-    avg_decisions_per_user = len(in_study_df) / num_users
+    num_decision_times = active_df[calendar_t_col_name].nunique()
+    avg_decisions_per_subject = len(active_df) / num_subjects
     num_decision_times_with_multiple_policies = (
-        in_study_df[in_study_df[policy_num_col_name] >= 0]
+        active_df[active_df[policy_num_col_name] >= 0]
         .groupby(calendar_t_col_name)[policy_num_col_name]
         .nunique()
         > 1
     ).sum()
-    min_action_prob = in_study_df[action_prob_col_name].min()
-    max_action_prob = in_study_df[action_prob_col_name].max()
-    min_non_fallback_policy_num = in_study_df[in_study_df[policy_num_col_name] >= 0][
+    min_action_prob = active_df[action_prob_col_name].min()
+    max_action_prob = active_df[action_prob_col_name].max()
+    min_non_fallback_policy_num = active_df[active_df[policy_num_col_name] >= 0][
         policy_num_col_name
     ].min()
     num_data_points_before_first_update = len(
-        in_study_df[in_study_df[policy_num_col_name] == min_non_fallback_policy_num]
+        active_df[active_df[policy_num_col_name] == min_non_fallback_policy_num]
     )
 
     median_action_probabilities = (
-        in_study_df.groupby(calendar_t_col_name)[action_prob_col_name]
-        .median()
-        .to_numpy()
+        active_df.groupby(calendar_t_col_name)[action_prob_col_name].median().to_numpy()
     )
-    quartiles = in_study_df.groupby(calendar_t_col_name)[action_prob_col_name].quantile(
+    quartiles = active_df.groupby(calendar_t_col_name)[action_prob_col_name].quantile(
         [0.25, 0.75]
     )
     q25_action_probabilities = quartiles.xs(0.25, level=1).to_numpy()
     q75_action_probabilities = quartiles.xs(0.75, level=1).to_numpy()
 
-    avg_rewards = in_study_df.groupby(calendar_t_col_name)[reward_col_name].mean()
+    avg_rewards = active_df.groupby(calendar_t_col_name)[reward_col_name].mean()
 
     # Plot action probability quartile trajectories
     plt.clear_figure()
@@ -807,11 +810,11 @@ def verify_study_df_summary_satisfactory(
     avg_reward_trajectory_plot = plt.build()
 
     confirm_input_check_result(
-        f"\nYou provided a study df reflecting a study with"
-        f"\n* {num_users} users"
+        f"\nYou provided an analysis DataFrame reflecting a study with"
+        f"\n* {num_subjects} subjects"
         f"\n* {num_non_initial_or_fallback_policies} policy updates"
-        f"\n* {num_decision_times} decision times, for an average of {avg_decisions_per_user}"
-        f" decisions per user"
+        f"\n* {num_decision_times} decision times, for an average of {avg_decisions_per_subject}"
+        f" decisions per subject"
         f"\n* RL parameters of dimension {beta_dim} per update"
         f"\n* Inferential target of dimension {theta_dim}"
         f"\n* {num_data_points_before_first_update} data points before the first update"
@@ -834,14 +837,14 @@ def require_betas_match_in_alg_update_args_each_update(
     alg_update_func_args, alg_update_func_args_beta_index
 ):
     logger.info(
-        "Checking that betas match across users for each update in the algorithm update function args."
+        "Checking that betas match across subjects for each update in the algorithm update function args."
     )
     for policy_num in alg_update_func_args:
         first_beta = None
-        for user_id in alg_update_func_args[policy_num]:
-            if not alg_update_func_args[policy_num][user_id]:
+        for subject_id in alg_update_func_args[policy_num]:
+            if not alg_update_func_args[policy_num][subject_id]:
                 continue
-            beta = alg_update_func_args[policy_num][user_id][
+            beta = alg_update_func_args[policy_num][subject_id][
                 alg_update_func_args_beta_index
             ]
             if first_beta is None:
@@ -849,23 +852,24 @@ def require_betas_match_in_alg_update_args_each_update(
             else:
                 assert np.array_equal(
                     beta, first_beta
-                ), f"Betas do not match across users in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
+                ), f"Betas do not match across subjects in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
+
 
 def require_previous_betas_match_in_alg_update_args_each_update(
     alg_update_func_args, alg_update_func_args_previous_betas_index
 ):
     logger.info(
-        "Checking that previous betas match across users for each update in the algorithm update function args."
+        "Checking that previous betas match across subjects for each update in the algorithm update function args."
     )
     if alg_update_func_args_previous_betas_index < 0:
-        return
-
+        return
+
     for policy_num in alg_update_func_args:
         first_previous_betas = None
-        for user_id in alg_update_func_args[policy_num]:
-            if not alg_update_func_args[policy_num][user_id]:
+        for subject_id in alg_update_func_args[policy_num]:
+            if not alg_update_func_args[policy_num][subject_id]:
                 continue
-            previous_betas = alg_update_func_args[policy_num][user_id][
+            previous_betas = alg_update_func_args[policy_num][subject_id][
                 alg_update_func_args_previous_betas_index
             ]
             if first_previous_betas is None:
@@ -873,21 +877,21 @@ def require_previous_betas_match_in_alg_update_args_each_update(
             else:
                 assert np.array_equal(
                     previous_betas, first_previous_betas
-                ), f"Previous betas do not match across users in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
+                ), f"Previous betas do not match across subjects in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
 
 
 def require_betas_match_in_action_prob_func_args_each_decision(
     action_prob_func_args, action_prob_func_args_beta_index
 ):
     logger.info(
-        "Checking that betas match across users for each decision time in the action prob args."
+        "Checking that betas match across subjects for each decision time in the action prob args."
     )
     for decision_time in action_prob_func_args:
         first_beta = None
-        for user_id in action_prob_func_args[decision_time]:
-            if not action_prob_func_args[decision_time][user_id]:
+        for subject_id in action_prob_func_args[decision_time]:
+            if not action_prob_func_args[decision_time][subject_id]:
                 continue
-            beta = action_prob_func_args[decision_time][user_id][
+            beta = action_prob_func_args[decision_time][subject_id][
                 action_prob_func_args_beta_index
             ]
             if first_beta is None:
@@ -895,11 +899,11 @@ def require_betas_match_in_action_prob_func_args_each_decision(
             else:
                 assert np.array_equal(
                     beta, first_beta
-                ), f"Betas do not match across users in the action prob args for decision_time {decision_time}. Please see the contract for details."
+                ), f"Betas do not match across subjects in the action prob args for decision_time {decision_time}. Please see the contract for details."
 
 
 def require_valid_action_prob_times_given_if_index_supplied(
-    study_df,
+    analysis_df,
     calendar_t_col_name,
     alg_update_func_args,
     alg_update_func_args_action_prob_times_index,
@@ -909,19 +913,19 @@ def require_valid_action_prob_times_given_if_index_supplied(
     if alg_update_func_args_action_prob_times_index < 0:
         return
 
-    min_time = study_df[calendar_t_col_name].min()
-    max_time = study_df[calendar_t_col_name].max()
-    for policy_idx, args_by_user in alg_update_func_args.items():
-        for user_id, args in args_by_user.items():
+    min_time = analysis_df[calendar_t_col_name].min()
+    max_time = analysis_df[calendar_t_col_name].max()
+    for policy_idx, args_by_subject in alg_update_func_args.items():
+        for subject_id, args in args_by_subject.items():
             if not args:
                 continue
             times = args[alg_update_func_args_action_prob_times_index]
             assert (
                 times[i] > times[i - 1] for i in range(1, len(times))
-            ), f"Non-strictly-increasing times were given for action probabilities in the algorithm update function args for user {user_id} and policy {policy_idx}. Please see the contract for details."
+            ), f"Non-strictly-increasing times were given for action probabilities in the algorithm update function args for subject {subject_id} and policy {policy_idx}. Please see the contract for details."
             assert (
                 times[0] >= min_time and times[-1] <= max_time
-            ), f"Times not present in the study were given for action probabilities in the algorithm update function args. The min and max times in the study df are {min_time} and {max_time}, while user {user_id} has times {times} supplied for policy {policy_idx}. Please see the contract for details."
+            ), f"Times not present in the study were given for action probabilities in the algorithm update function args. The min and max times in the analysis DataFrame are {min_time} and {max_time}, while subject {subject_id} has times {times} supplied for policy {policy_idx}. Please see the contract for details."
 
 
 def require_estimating_functions_sum_to_zero(
@@ -933,15 +937,15 @@ def require_estimating_functions_sum_to_zero(
     """
     This is a test that the correct loss/estimating functions have
     been given for both the algorithm updates and inference. If that is true, then the
-    loss/estimating functions when evaluated should sum to approximately zero across users. These
-    values have been stacked and averaged across users in mean_estimating_function_stack, which
+    loss/estimating functions when evaluated should sum to approximately zero across subjects. These
+    values have been stacked and averaged across subjects in mean_estimating_function_stack, which
     we simply compare to the zero vector. We can isolate components for each update and inference
     by considering the dimensions of the beta vectors and the theta vector.
 
     Inputs:
         mean_estimating_function_stack:
             The mean of the estimating function stack (a component for each algorithm update and
-            inference) across users. This should be a 1D array.
+            inference) across subjects. This should be a 1D array.
         beta_dim:
             The dimension of the beta vectors that parameterize the algorithm.
         theta_dim:
@@ -951,7 +955,7 @@ def require_estimating_functions_sum_to_zero(
         None
     """
 
-    logger.info("Checking that estimating functions average to zero across users")
+    logger.info("Checking that estimating functions average to zero across subjects")
 
     # Have a looser hard failure cutoff before the typical interactive check
     try:
@@ -962,7 +966,7 @@ def require_estimating_functions_sum_to_zero(
         )
     except AssertionError as e:
         logger.info(
-            "Estimating function stacks do not average to within loose tolerance of zero across users. Drilling in to specific updates and inference component."
+            "Estimating function stacks do not average to within loose tolerance of zero across subjects. Drilling in to specific updates and inference component."
         )
         # If this is not true there is an internal problem in the package.
         assert (mean_estimating_function_stack.size - theta_dim) % beta_dim == 0
@@ -987,11 +991,11 @@ def require_estimating_functions_sum_to_zero(
             np.testing.assert_allclose(
                 mean_estimating_function_stack,
                 jnp.zeros(mean_estimating_function_stack.size),
-                atol=
+                atol=5e-4,
             )
         except AssertionError as e:
             logger.info(
-                "Estimating function stacks do not average to within specified tolerance of zero across users. Drilling in to specific updates and inference component."
+                "Estimating function stacks do not average to within specified tolerance of zero across subjects. Drilling in to specific updates and inference component."
             )
             # If this is not true there is an internal problem in the package.
             assert (mean_estimating_function_stack.size - theta_dim) % beta_dim == 0
@@ -1021,15 +1025,15 @@ def require_RL_estimating_functions_sum_to_zero(
     """
     This is a test that the correct loss/estimating functions have
     been given for both the algorithm updates and inference. If that is true, then the
-    loss/estimating functions when evaluated should sum to approximately zero across users. These
-    values have been stacked and averaged across users in mean_estimating_function_stack, which
+    loss/estimating functions when evaluated should sum to approximately zero across subjects. These
+    values have been stacked and averaged across subjects in mean_estimating_function_stack, which
     we simply compare to the zero vector. We can isolate components for each update and inference
     by considering the dimensions of the beta vectors and the theta vector.
 
     Inputs:
         mean_estimating_function_stack:
             The mean of the estimating function stack (a component for each algorithm update and
-            inference) across users. This should be a 1D array.
+            inference) across subjects. This should be a 1D array.
         beta_dim:
             The dimension of the beta vectors that parameterize the algorithm.
         theta_dim:
@@ -1039,7 +1043,7 @@ def require_RL_estimating_functions_sum_to_zero(
         None
     """
 
-    logger.info("Checking that RL estimating functions average to zero across users")
+    logger.info("Checking that RL estimating functions average to zero across subjects")
 
     # Have a looser hard failure cutoff before the typical interactive check
     try:
@@ -1050,7 +1054,7 @@ def require_RL_estimating_functions_sum_to_zero(
         )
     except AssertionError as e:
         logger.info(
-            "RL estimating function stacks do not average to zero across users. Drilling in to specific updates and inference component."
+            "RL estimating function stacks do not average to zero across subjects. Drilling in to specific updates and inference component."
         )
         num_updates = (mean_estimating_function_stack.size) // beta_dim
         for i in range(num_updates):
@@ -1070,7 +1074,7 @@ def require_RL_estimating_functions_sum_to_zero(
             )
         except AssertionError as e:
             logger.info(
-                "RL estimating function stacks do not average to zero across users. Drilling in to specific updates and inference component."
+                "RL estimating function stacks do not average to zero across subjects. Drilling in to specific updates and inference component."
             )
             num_updates = (mean_estimating_function_stack.size) // beta_dim
             for i in range(num_updates):
@@ -1079,7 +1083,6 @@ def require_RL_estimating_functions_sum_to_zero(
                     i + 1,
                     mean_estimating_function_stack[i * beta_dim : (i + 1) * beta_dim],
                 )
-        # TODO: Email instead of requiring user input for monitoring alg.
         confirm_input_check_result(
             f"\nEstimating functions do not average to within default tolerance of zero vector. Please decide if the following is a reasonable result, taking into account the above breakdown by update number and inference. If not, there are several possible reasons for failure mentioned in the contract. Results:\n{str(e)}\n\nContinue? (y/n)\n",
             suppress_interactive_data_checks,
@@ -1087,20 +1090,18 @@ def require_RL_estimating_functions_sum_to_zero(
         )
 
 
-def require_adaptive_bread_inverse_is_true_inverse(
-    joint_adaptive_bread_inverse_matrix,
-    joint_adaptive_bread_matrix,
+def require_joint_bread_inverse_is_true_inverse(
+    joint_bread_inverse_matrix,
+    joint_bread_matrix,
     suppress_interactive_data_checks,
 ):
     """
-    Check that the product of the joint adaptive bread matrix and its inverse is
+    Check that the product of the joint bread matrix and its inverse is
     sufficiently close to the identity matrix. This is a direct check that the
-    joint_adaptive_bread_matrix we create is "well-conditioned".
+    joint_bread_matrix we create is "well-conditioned".
     """
-    should_be_identity = (
-        joint_adaptive_bread_inverse_matrix @ joint_adaptive_bread_matrix
-    )
-    identity = np.eye(joint_adaptive_bread_matrix.shape[0])
+    should_be_identity = joint_bread_inverse_matrix @ joint_bread_matrix
+    identity = np.eye(joint_bread_matrix.shape[0])
     try:
         np.testing.assert_allclose(
             should_be_identity,
@@ -1110,7 +1111,7 @@ def require_adaptive_bread_inverse_is_true_inverse(
         )
     except AssertionError as e:
         confirm_input_check_result(
-            f"\nJoint adaptive bread inverse is not exact inverse of the constructed matrix that was inverted to form it. This likely illustrates poor conditioning:\n{str(e)}\n\nContinue? (y/n)\n",
+            f"\nJoint bread inverse is not exact inverse of the constructed matrix that was inverted to form it. This likely illustrates poor conditioning:\n{str(e)}\n\nContinue? (y/n)\n",
             suppress_interactive_data_checks,
             e,
         )
@@ -1118,7 +1119,7 @@ def require_adaptive_bread_inverse_is_true_inverse(
     # If we haven't already errored out, return some measures of how far off we are from identity
     diff = should_be_identity - identity
     logger.debug(
-        "Difference between should-be-identity produced by multiplying joint adaptive bread and its computed inverse and actual identity:\n%s",
+        "Difference between should-be-identity produced by multiplying joint bread and its computed inverse and actual identity:\n%s",
         diff,
     )
@@ -1133,8 +1134,8 @@ def require_adaptive_bread_inverse_is_true_inverse(
 
 def require_threaded_algorithm_estimating_function_args_equivalent(
     algorithm_estimating_func,
-    update_func_args_by_by_user_id_by_policy_num,
-    threaded_update_func_args_by_policy_num_by_user_id,
+    update_func_args_by_by_subject_id_by_policy_num,
+    threaded_update_func_args_by_policy_num_by_subject_id,
     suppress_interactive_data_checks,
 ):
     """
@@ -1144,12 +1145,12 @@ def require_threaded_algorithm_estimating_function_args_equivalent(
     """
     for (
         policy_num,
-        update_func_args_by_user_id,
-    ) in update_func_args_by_by_user_id_by_policy_num.items():
+        update_func_args_by_subject_id,
+    ) in update_func_args_by_by_subject_id_by_policy_num.items():
         for (
-            user_id,
+            subject_id,
             unthreaded_args,
-        ) in update_func_args_by_user_id.items():
+        ) in update_func_args_by_subject_id.items():
             if not unthreaded_args:
                 continue
             np.testing.assert_allclose(
@@ -1157,9 +1158,9 @@ def require_threaded_algorithm_estimating_function_args_equivalent(
                 # Need to stop gradient here because we can't convert a traced value to np array
                 jax.lax.stop_gradient(
                     algorithm_estimating_func(
-                        *threaded_update_func_args_by_policy_num_by_user_id[
-                            user_id
-                        ][policy_num]
+                        *threaded_update_func_args_by_policy_num_by_subject_id[
+                            subject_id
+                        ][policy_num]
                     )
                 ),
                 atol=1e-7,
@@ -1169,8 +1170,8 @@ def require_threaded_algorithm_estimating_function_args_equivalent(
 
 def require_threaded_inference_estimating_function_args_equivalent(
     inference_estimating_func,
-    inference_func_args_by_user_id,
-    threaded_inference_func_args_by_user_id,
+    inference_func_args_by_subject_id,
+    threaded_inference_func_args_by_subject_id,
     suppress_interactive_data_checks,
 ):
     """
@@ -1178,7 +1179,7 @@ def require_threaded_inference_estimating_function_args_equivalent(
     when called with the original arguments and when called with the
     reconstructed action probabilities substituted in.
     """
-    for user_id, unthreaded_args in inference_func_args_by_user_id.items():
+    for subject_id, unthreaded_args in inference_func_args_by_subject_id.items():
         if not unthreaded_args:
             continue
         np.testing.assert_allclose(
@@ -1186,7 +1187,7 @@ def require_threaded_inference_estimating_function_args_equivalent(
             # Need to stop gradient here because we can't convert a traced value to np array
             jax.lax.stop_gradient(
                 inference_estimating_func(
-                    *threaded_inference_func_args_by_user_id[user_id]
+                    *threaded_inference_func_args_by_subject_id[subject_id]
                 )
             ),
             rtol=1e-2,