lifejacket 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/__init__.py +0 -0
- lifejacket/after_study_analysis.py +1845 -0
- lifejacket/arg_threading_helpers.py +354 -0
- lifejacket/calculate_derivatives.py +965 -0
- lifejacket/constants.py +28 -0
- lifejacket/form_adaptive_meat_adjustments_directly.py +333 -0
- lifejacket/get_datum_for_blowup_supervised_learning.py +1312 -0
- lifejacket/helper_functions.py +587 -0
- lifejacket/input_checks.py +1145 -0
- lifejacket/small_sample_corrections.py +125 -0
- lifejacket/trial_conditioning_monitor.py +870 -0
- lifejacket/vmap_helpers.py +71 -0
- lifejacket-0.1.0.dist-info/METADATA +100 -0
- lifejacket-0.1.0.dist-info/RECORD +17 -0
- lifejacket-0.1.0.dist-info/WHEEL +5 -0
- lifejacket-0.1.0.dist-info/entry_points.txt +2 -0
- lifejacket-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1145 @@
|
|
|
1
|
+
import collections
|
|
2
|
+
import logging
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import jax
|
|
7
|
+
from jax import numpy as jnp
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import plotext as plt
|
|
10
|
+
|
|
11
|
+
from .constants import InverseStabilizationMethods, SmallSampleCorrections
|
|
12
|
+
from .helper_functions import (
|
|
13
|
+
confirm_input_check_result,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
# When we print out objects for debugging, show the whole thing.
np.set_printoptions(threshold=np.inf)

# Module-level logger for all input checks below.
logger = logging.getLogger(__name__)
# NOTE(review): calling basicConfig at import time configures the root logger
# for the whole process; library modules usually leave this to the application.
# Confirm this side effect is intended before changing it.
logging.basicConfig(
    format="%(asctime)s,%(msecs)03d %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# TODO: any checks needed here about alg update function type?
|
|
28
|
+
def perform_first_wave_input_checks(
    study_df,
    in_study_col_name,
    action_col_name,
    policy_num_col_name,
    calendar_t_col_name,
    user_id_col_name,
    action_prob_col_name,
    reward_col_name,
    action_prob_func,
    action_prob_func_args,
    action_prob_func_args_beta_index,
    alg_update_func_args,
    alg_update_func_args_beta_index,
    alg_update_func_args_action_prob_index,
    alg_update_func_args_action_prob_times_index,
    theta_est,
    beta_dim,
    suppress_interactive_data_checks,
    small_sample_correction,
):
    """Run the full first-wave battery of input checks before analysis.

    Validates, in order: the algorithm update function args, the small-sample
    correction choice, the action probability function and its args, the study
    dataframe itself, and finally the theta estimate. Each `require_*` helper
    raises AssertionError on failure; each `confirm_*` helper may prompt the
    user interactively unless `suppress_interactive_data_checks` is set.

    Args:
        study_df: per-(user, calendar time) study dataframe.
        in_study_col_name / action_col_name / policy_num_col_name /
        calendar_t_col_name / user_id_col_name / action_prob_col_name /
        reward_col_name: column names in `study_df`.
        action_prob_func: callable that recomputes an action probability from
            the per-(time, user) args below.
        action_prob_func_args: dict keyed by decision time, then user id,
            mapping to the arg tuple for `action_prob_func` (empty when the
            user is out of study at that time).
        action_prob_func_args_beta_index: position of beta in those arg tuples.
        alg_update_func_args: dict keyed by policy number, then user id,
            mapping to the arg tuple for the algorithm update function.
        alg_update_func_args_beta_index: position of beta in those arg tuples.
        alg_update_func_args_action_prob_index: position of the action-prob
            array in the update args; negative means "not supplied".
        alg_update_func_args_action_prob_times_index: position of the matching
            times array; negative means "not supplied".
        theta_est: 1-D estimate of theta.
        beta_dim: dimension of beta (forwarded to the summary check).
        suppress_interactive_data_checks: skip interactive y/n confirmations.
        small_sample_correction: a SmallSampleCorrections member.

    Raises:
        AssertionError: when any individual check fails.
    """
    ### Validate algorithm loss/estimating function and args
    require_alg_update_args_given_for_all_users_at_each_update(
        study_df, user_id_col_name, alg_update_func_args
    )
    require_no_policy_numbers_present_in_alg_update_args_but_not_study_df(
        study_df, policy_num_col_name, alg_update_func_args
    )
    require_beta_is_1D_array_in_alg_update_args(
        alg_update_func_args, alg_update_func_args_beta_index
    )
    require_all_policy_numbers_in_study_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
        study_df, in_study_col_name, policy_num_col_name, alg_update_func_args
    )
    confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
        alg_update_func_args_action_prob_index, suppress_interactive_data_checks
    )
    require_action_prob_times_given_if_index_supplied(
        alg_update_func_args_action_prob_index,
        alg_update_func_args_action_prob_times_index,
    )
    require_action_prob_index_given_if_times_supplied(
        alg_update_func_args_action_prob_index,
        alg_update_func_args_action_prob_times_index,
    )
    require_betas_match_in_alg_update_args_each_update(
        alg_update_func_args, alg_update_func_args_beta_index
    )
    require_action_prob_args_in_alg_update_func_correspond_to_study_df(
        study_df,
        action_prob_col_name,
        calendar_t_col_name,
        user_id_col_name,
        alg_update_func_args,
        alg_update_func_args_action_prob_index,
        alg_update_func_args_action_prob_times_index,
    )
    require_valid_action_prob_times_given_if_index_supplied(
        study_df,
        calendar_t_col_name,
        alg_update_func_args,
        alg_update_func_args_action_prob_times_index,
    )

    confirm_no_small_sample_correction_desired_if_not_requested(
        small_sample_correction, suppress_interactive_data_checks
    )

    ### Validate action prob function and args
    require_action_prob_func_args_given_for_all_users_at_each_decision(
        study_df, user_id_col_name, action_prob_func_args
    )
    require_action_prob_func_args_given_for_all_decision_times(
        study_df, calendar_t_col_name, action_prob_func_args
    )
    require_action_probabilities_in_study_df_can_be_reconstructed(
        study_df,
        action_prob_col_name,
        calendar_t_col_name,
        user_id_col_name,
        in_study_col_name,
        action_prob_func_args,
        action_prob_func,
    )

    require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times(
        study_df,
        calendar_t_col_name,
        action_prob_func_args,
        in_study_col_name,
        user_id_col_name,
    )
    require_beta_is_1D_array_in_action_prob_args(
        action_prob_func_args, action_prob_func_args_beta_index
    )
    require_betas_match_in_action_prob_func_args_each_decision(
        action_prob_func_args, action_prob_func_args_beta_index
    )

    ### Validate study_df
    verify_study_df_summary_satisfactory(
        study_df,
        user_id_col_name,
        policy_num_col_name,
        calendar_t_col_name,
        in_study_col_name,
        action_prob_col_name,
        reward_col_name,
        beta_dim,
        len(theta_est),
        suppress_interactive_data_checks,
    )

    require_all_users_have_all_times_in_study_df(
        study_df, calendar_t_col_name, user_id_col_name
    )
    require_all_named_columns_present_in_study_df(
        study_df,
        in_study_col_name,
        action_col_name,
        policy_num_col_name,
        calendar_t_col_name,
        user_id_col_name,
        action_prob_col_name,
    )
    require_all_named_columns_not_object_type_in_study_df(
        study_df,
        in_study_col_name,
        action_col_name,
        policy_num_col_name,
        calendar_t_col_name,
        user_id_col_name,
        action_prob_col_name,
    )
    require_binary_actions(study_df, in_study_col_name, action_col_name)
    require_binary_in_study_indicators(study_df, in_study_col_name)
    require_consecutive_integer_policy_numbers(
        study_df, in_study_col_name, policy_num_col_name
    )
    require_consecutive_integer_calendar_times(study_df, calendar_t_col_name)
    require_hashable_user_ids(study_df, in_study_col_name, user_id_col_name)
    require_action_probabilities_in_range_0_to_1(study_df, action_prob_col_name)

    ### Validate theta estimation
    require_theta_is_1D_array(theta_est)
|
174
|
+
|
|
175
|
+
def perform_alg_only_input_checks(
    study_df,
    in_study_col_name,
    policy_num_col_name,
    calendar_t_col_name,
    user_id_col_name,
    action_prob_col_name,
    action_prob_func,
    action_prob_func_args,
    action_prob_func_args_beta_index,
    alg_update_func_args,
    alg_update_func_args_beta_index,
    alg_update_func_args_action_prob_index,
    alg_update_func_args_action_prob_times_index,
    suppress_interactive_data_checks,
):
    """Run the algorithm-only subset of the first-wave input checks.

    Same checks as `perform_first_wave_input_checks` for the algorithm update
    function args and the action probability function/args, but without the
    study-dataframe-wide, theta, reward, or small-sample-correction checks.
    Parameters mirror `perform_first_wave_input_checks`.

    Raises:
        AssertionError: when any individual check fails.
    """
    ### Validate algorithm loss/estimating function and args
    require_alg_update_args_given_for_all_users_at_each_update(
        study_df, user_id_col_name, alg_update_func_args
    )
    require_beta_is_1D_array_in_alg_update_args(
        alg_update_func_args, alg_update_func_args_beta_index
    )
    require_all_policy_numbers_in_study_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
        study_df, in_study_col_name, policy_num_col_name, alg_update_func_args
    )
    confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
        alg_update_func_args_action_prob_index, suppress_interactive_data_checks
    )
    require_action_prob_times_given_if_index_supplied(
        alg_update_func_args_action_prob_index,
        alg_update_func_args_action_prob_times_index,
    )
    require_action_prob_index_given_if_times_supplied(
        alg_update_func_args_action_prob_index,
        alg_update_func_args_action_prob_times_index,
    )
    require_betas_match_in_alg_update_args_each_update(
        alg_update_func_args, alg_update_func_args_beta_index
    )
    require_action_prob_args_in_alg_update_func_correspond_to_study_df(
        study_df,
        action_prob_col_name,
        calendar_t_col_name,
        user_id_col_name,
        alg_update_func_args,
        alg_update_func_args_action_prob_index,
        alg_update_func_args_action_prob_times_index,
    )
    require_valid_action_prob_times_given_if_index_supplied(
        study_df,
        calendar_t_col_name,
        alg_update_func_args,
        alg_update_func_args_action_prob_times_index,
    )

    ### Validate action prob function and args
    require_action_prob_func_args_given_for_all_users_at_each_decision(
        study_df, user_id_col_name, action_prob_func_args
    )
    require_action_prob_func_args_given_for_all_decision_times(
        study_df, calendar_t_col_name, action_prob_func_args
    )
    require_action_probabilities_in_study_df_can_be_reconstructed(
        study_df,
        action_prob_col_name,
        calendar_t_col_name,
        user_id_col_name,
        in_study_col_name,
        action_prob_func_args,
        action_prob_func=action_prob_func,
    )

    require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times(
        study_df,
        calendar_t_col_name,
        action_prob_func_args,
        in_study_col_name,
        user_id_col_name,
    )
    require_beta_is_1D_array_in_action_prob_args(
        action_prob_func_args, action_prob_func_args_beta_index
    )
    require_betas_match_in_action_prob_func_args_each_decision(
        action_prob_func_args, action_prob_func_args_beta_index
    )
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
# TODO: Give a hard-to-use option to loosen this check somehow
|
|
264
|
+
def require_action_probabilities_in_study_df_can_be_reconstructed(
    study_df,
    action_prob_col_name,
    calendar_t_col_name,
    user_id_col_name,
    in_study_col_name,
    action_prob_func_args,
    action_prob_func,
):
    """
    Check that the action probabilities in the study dataframe can be reconstructed from the supplied
    action probability function and its arguments.

    NOTE THAT THIS IS A HARD FAILURE IF THE RECONSTRUCTION DOESN'T PASS.
    """
    logger.info("Reconstructing action probabilities from function and arguments.")

    in_study_df = study_df[study_df[in_study_col_name] == 1]

    # Recompute each in-study probability by applying the supplied function to
    # the args recorded for that (decision time, user) pair, in row order.
    recomputed_probs = [
        action_prob_func(*action_prob_func_args[decision_time][user_id])
        for decision_time, user_id in zip(
            in_study_df[calendar_t_col_name], in_study_df[user_id_col_name]
        )
    ]

    # Hard failure: assert_allclose raises if any probability disagrees.
    np.testing.assert_allclose(
        in_study_df[action_prob_col_name].to_numpy(dtype="float64"),
        np.asarray(recomputed_probs, dtype="float64"),
        atol=1e-6,
    )
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def require_all_users_have_all_times_in_study_df(
    study_df, calendar_t_col_name, user_id_col_name
):
    """Assert that every user appears at exactly the full set of calendar times."""
    logger.info("Checking that all users have the same set of unique calendar times.")
    # The full set of calendar times seen anywhere in the dataframe.
    expected_times = set(study_df[calendar_t_col_name].unique())

    # The set of calendar times observed for each individual user.
    times_per_user = study_df.groupby(user_id_col_name)[calendar_t_col_name].apply(set)

    # Every user must cover exactly the expected set.
    if not all(user_times == expected_times for user_times in times_per_user):
        raise AssertionError(
            "Not all users have all calendar times in the study dataframe. Please see the contract for details."
        )
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def require_alg_update_args_given_for_all_users_at_each_update(
    study_df, user_id_col_name, alg_update_func_args
):
    """Assert the update args contain an entry for every study user at every policy number."""
    logger.info(
        "Checking that algorithm update function args are given for all users at each update."
    )
    expected_users = set(study_df[user_id_col_name].unique())
    for policy_num, args_by_user in alg_update_func_args.items():
        present_users = set(args_by_user)
        assert present_users == expected_users, (
            f"Not all users present in algorithm update function args for policy number {policy_num}."
            " Please see the contract for details."
        )
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def require_action_prob_args_in_alg_update_func_correspond_to_study_df(
    study_df,
    action_prob_col_name,
    calendar_t_col_name,
    user_id_col_name,
    alg_update_func_args,
    alg_update_func_args_action_prob_index,
    alg_update_func_args_action_prob_times_index,
):
    """Cross-check action probabilities embedded in the update args against the dataframe.

    For each (policy, user) arg tuple that includes an action-prob array, look
    up the probability recorded in `study_df` at each of the supplied decision
    times and assert the two agree (within np.allclose tolerance).

    Raises:
        AssertionError: on any mismatch for any user/policy.
    """
    logger.info(
        "Checking that the action probabilities supplied in the algorithm update function args, if"
        " any, correspond to those in the study dataframe for the corresponding users and decision"
        " times."
    )
    # A negative index means the update function takes no action probabilities,
    # so there is nothing to cross-check.
    if alg_update_func_args_action_prob_index < 0:
        return

    # Precompute a lookup dictionary for faster access
    study_df_lookup = study_df.set_index([calendar_t_col_name, user_id_col_name])[
        action_prob_col_name
    ].to_dict()

    for policy_num, user_args in alg_update_func_args.items():
        for user_id, args in user_args.items():
            # Empty tuple means this user has no args for this policy; skip.
            if not args:
                continue
            # Both entries are flattened below via .flatten()/.item(), so they
            # are presumably numpy arrays — confirm against the contract.
            arg_action_probs = args[alg_update_func_args_action_prob_index]
            action_prob_times = args[
                alg_update_func_args_action_prob_times_index
            ].flatten()

            # Use the precomputed lookup dictionary
            # (.item() converts each numpy scalar time to the plain Python
            # value used as the dataframe index key).
            study_df_action_probs = [
                study_df_lookup[(decision_time.item(), user_id)]
                for decision_time in action_prob_times
            ]

            assert np.allclose(
                arg_action_probs.flatten(),
                study_df_action_probs,
            ), (
                f"There is a mismatch for user {user_id} between the action probabilities supplied"
                f" in the args to the algorithm update function at policy {policy_num} and those in"
                " the study dataframe for the supplied times. Please see the contract for details."
            )
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def require_action_prob_func_args_given_for_all_users_at_each_decision(
    study_df,
    user_id_col_name,
    action_prob_func_args,
):
    """Assert the action-prob args contain an entry for every study user at each decision time.

    Raises:
        AssertionError: if any decision time's args are missing (or have extra) users.
    """
    logger.info(
        "Checking that action prob function args are given for all users at each decision time."
    )
    all_user_ids = set(study_df[user_id_col_name].unique())
    for decision_time in action_prob_func_args:
        # Fixed copy-paste error: the message previously referred to
        # "algorithm update function args" though this checks action prob args.
        assert (
            set(action_prob_func_args[decision_time].keys()) == all_user_ids
        ), f"Not all users present in action prob function args for decision time {decision_time}. Please see the contract for details."
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
def require_action_prob_func_args_given_for_all_decision_times(
    study_df, calendar_t_col_name, action_prob_func_args
):
    """Assert the action-prob args are keyed by exactly the dataframe's decision times."""
    logger.info(
        "Checking that action prob function args are given for all decision times."
    )
    expected_times = set(study_df[calendar_t_col_name].unique())
    supplied_times = set(action_prob_func_args)

    assert supplied_times == expected_times, (
        "Not all decision times present in action prob function args. Please see the contract for details."
    )
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
def require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times(
    study_df: pd.DataFrame,
    calendar_t_col_name: str,
    action_prob_func_args: dict[str, dict[str, tuple[Any, ...]]],
    in_study_col_name,
    user_id_col_name,
):
    """Assert each user's empty-arg decision times equal their out-of-study times."""
    logger.info(
        "Checking that action probability function args are blank for exactly the times each user"
        " is not in the study according to the study dataframe."
    )
    # Times each user is out of the study, per the dataframe.
    df_blank_times_by_user = (
        study_df[study_df[in_study_col_name] == 0]
        .groupby(user_id_col_name)[calendar_t_col_name]
        .apply(set)
        .to_dict()
    )

    # Times each user has empty ("blank") action-prob args.
    args_blank_times_by_user = {}
    for decision_time, args_by_user in action_prob_func_args.items():
        for user_id, user_args in args_by_user.items():
            if not user_args:
                args_blank_times_by_user.setdefault(user_id, set()).add(decision_time)

    assert df_blank_times_by_user == args_blank_times_by_user, (
        "Out-of-study decision times according to the study dataframe do not match up with the"
        " times for which action probability arguments are blank for all users. Please see the"
        " contract for details."
    )
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
def require_all_named_columns_present_in_study_df(
    study_df,
    in_study_col_name,
    action_col_name,
    policy_num_col_name,
    calendar_t_col_name,
    user_id_col_name,
    action_prob_col_name,
):
    """Assert each of the six named columns exists in the study dataframe."""
    logger.info("Checking that all named columns are present in the study dataframe.")
    # Check in the same order as the parameter list; fail on the first absentee.
    for required_col in (
        in_study_col_name,
        action_col_name,
        policy_num_col_name,
        calendar_t_col_name,
        user_id_col_name,
        action_prob_col_name,
    ):
        assert required_col in study_df.columns, f"{required_col} not in study df."
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
def require_all_named_columns_not_object_type_in_study_df(
    study_df,
    in_study_col_name,
    action_col_name,
    policy_num_col_name,
    calendar_t_col_name,
    user_id_col_name,
    action_prob_col_name,
):
    """Assert none of the six named columns has pandas ``object`` dtype."""
    logger.info("Checking that all named columns are not type object.")
    named_cols = [
        in_study_col_name,
        action_col_name,
        policy_num_col_name,
        calendar_t_col_name,
        user_id_col_name,
        action_prob_col_name,
    ]
    for named_col in named_cols:
        col_dtype = study_df[named_col].dtype
        assert col_dtype != "object", f"At least {named_col} is of object type in study df."
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
def require_binary_actions(study_df, in_study_col_name, action_col_name):
    """Assert every in-study action is exactly 0 or 1.

    Raises:
        AssertionError: if any in-study action is not 0 or 1.
    """
    logger.info("Checking that actions are binary.")
    # Compare the raw values: the previous astype("int64") cast truncated
    # fractional actions (e.g. 0.5 -> 0), silently masking invalid data, and
    # raised on NaN. isin([0, 1]) matches 0/1 in int, float, or bool dtypes.
    in_study_actions = study_df[study_df[in_study_col_name] == 1][action_col_name]
    assert in_study_actions.isin([0, 1]).all(), "Actions are not binary."
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
def require_binary_in_study_indicators(study_df, in_study_col_name):
    """Assert every in-study indicator in the dataframe is 0 or 1.

    Raises:
        AssertionError: if any indicator value is not 0 or 1.
    """
    logger.info("Checking that in-study indicators are binary.")
    # Check the whole column: the previous version filtered to rows where the
    # indicator == 1 and then tested those values against {0, 1}, which was
    # vacuously true and enforced nothing. Also drop the int64 cast, which
    # truncated fractional values and raised on NaN.
    assert (
        study_df[in_study_col_name].isin([0, 1]).all()
    ), "In-study indicators are not binary."
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
def require_consecutive_integer_policy_numbers(
    study_df, in_study_col_name, policy_num_col_name
):
    """Assert the in-study, non-fallback (>= 0) policy numbers form a consecutive integer range.

    Raises:
        AssertionError: if the set of nonnegative policy numbers has gaps.
    """
    # TODO: This is a somewhat rough check of this, could also check nondecreasing temporally

    logger.info(
        "Checking that in-study, non-fallback policy numbers are consecutive integers."
    )

    in_study_df = study_df[study_df[in_study_col_name] == 1]
    policy_nums = in_study_df[in_study_df[policy_num_col_name] >= 0][
        policy_num_col_name
    ]
    # Ideally we actually have integers, but for legacy reasons we will support
    # floats as well. Reassign instead of mutating the filtered slice (the
    # previous in-place assignment triggered SettingWithCopyWarning).
    if policy_nums.dtype == "float64":
        policy_nums = policy_nums.astype("int64")
    # Sort the unique values: Series.unique() returns values in order of
    # appearance, so consecutive-but-unsorted data previously failed the check.
    unique_policy_nums = np.sort(policy_nums.unique())
    # Nothing to check when there are no nonnegative policy numbers (the
    # previous min()/max() on an empty series produced NaN and crashed range()).
    if unique_policy_nums.size == 0:
        return
    assert np.array_equal(
        unique_policy_nums,
        np.arange(unique_policy_nums[0], unique_policy_nums[-1] + 1),
    ), "Policy numbers are not consecutive integers."
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
def require_consecutive_integer_calendar_times(study_df, calendar_t_col_name):
    """Assert the calendar times cover an unbroken integer range.

    This is a somewhat rough check: it only verifies there are no gaps in the
    integers covered. Other checks verify all users have the same times, etc.
    Note these times should be well-formed even when the user is not in the study.

    Raises:
        AssertionError: if the set of calendar times has gaps.
    """
    logger.info("Checking that calendar times are consecutive integers.")
    # Sort the unique values: Series.unique() preserves order of appearance,
    # so a gap-free but unsorted dataframe previously failed the check.
    unique_times = np.sort(study_df[calendar_t_col_name].unique())
    if unique_times.size == 0:
        return
    assert np.array_equal(
        unique_times,
        np.arange(unique_times[0], unique_times[-1] + 1),
    ), "Calendar times are not consecutive integers."
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
def require_hashable_user_ids(study_df, in_study_col_name, user_id_col_name):
    """Assert the user IDs are hashable (they are used as dictionary keys).

    Raises:
        AssertionError: if the first in-study user ID is not hashable.
    """
    logger.info("Checking that user IDs are hashable.")
    in_study_ids = study_df[study_df[in_study_col_name] == 1][user_id_col_name]
    if in_study_ids.empty:
        return
    # The previous version called isinstance(...) and discarded the result, so
    # the check never enforced anything; it also indexed with the label 0,
    # which raises KeyError when row 0 is filtered out. Use .iloc and assert.
    assert isinstance(
        in_study_ids.iloc[0], collections.abc.Hashable
    ), "User IDs are not hashable."
|
|
556
|
+
|
|
557
|
+
|
|
558
|
+
def require_action_probabilities_in_range_0_to_1(study_df, action_prob_col_name):
    """Assert all action probabilities lie strictly inside the open interval (0, 1).

    Raises:
        AssertionError: if any probability is <= 0, >= 1, or NaN.
    """
    logger.info("Checking that action probabilities are in the interval (0, 1).")
    # The previous version computed .between(...).all() but discarded the
    # result, so the check never enforced anything.
    # NOTE(review): this checks the entire column, including out-of-study rows
    # — confirm against the contract whether those rows must also be valid.
    assert (
        study_df[action_prob_col_name].between(0, 1, inclusive="neither").all()
    ), "Action probabilities are not all strictly between 0 and 1."
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
def require_no_policy_numbers_present_in_alg_update_args_but_not_study_df(
    study_df, policy_num_col_name, alg_update_func_args
):
    """Assert every policy number keyed in the update args also occurs in the dataframe."""
    logger.info(
        "Checking that policy numbers in algorithm update function args are present in the study dataframe."
    )
    alg_update_policy_nums = sorted(alg_update_func_args.keys())
    study_df_policy_nums = sorted(study_df[policy_num_col_name].unique())
    # Any args-only policy numbers indicate an inconsistency in the inputs.
    extra_policy_nums = set(alg_update_policy_nums) - set(study_df_policy_nums)
    assert not extra_policy_nums, (
        f"There are policy numbers present in algorithm update function args but not in the study dataframe. "
        f"\nalg_update_func_args policy numbers: {alg_update_policy_nums}"
        f"\nstudy_df policy numbers: {study_df_policy_nums}.\nPlease see the contract for details."
    )
|
|
576
|
+
|
|
577
|
+
|
|
578
|
+
def require_all_policy_numbers_in_study_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
    study_df, in_study_col_name, policy_num_col_name, alg_update_func_args
):
    """Assert every non-initial, non-fallback policy number in the dataframe has update args.

    Fallback policies are negative; the initial policy is the smallest
    nonnegative policy number (0 is recommended but not required). Neither
    needs update args; every other policy number must appear as a key in
    ``alg_update_func_args``.

    Raises:
        AssertionError: if any such policy number is missing from the args.
    """
    logger.info(
        "Checking that all policy numbers in the study dataframe are present in the algorithm update function args."
    )
    in_study_df = study_df[study_df[in_study_col_name] == 1]
    nonnegative_policy_nums = in_study_df[in_study_df[policy_num_col_name] >= 0][
        policy_num_col_name
    ]
    if nonnegative_policy_nums.empty:
        return
    # Get the number of the initial policy. 0 is recommended but not required.
    # Fixed: the previous version omitted .min(), leaving a Series and making
    # the ">" comparison below a mis-aligned elementwise Series comparison.
    min_nonnegative_policy_number = nonnegative_policy_nums.min()
    required_policy_nums = set(
        nonnegative_policy_nums[
            nonnegative_policy_nums > min_nonnegative_policy_number
        ].unique()
    )
    # Also fixed: the error message previously recomputed the set with "> 0"
    # rather than "> min", so it could list the wrong policies.
    missing_policy_nums = required_policy_nums - set(alg_update_func_args.keys())
    assert (
        not missing_policy_nums
    ), f"There are non-fallback, non-initial policy numbers in the study dataframe that are not in the update function args: {missing_policy_nums}. Please see the contract for details."
|
|
596
|
+
|
|
597
|
+
|
|
598
|
+
def confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
    alg_update_func_args_action_prob_index,
    suppress_interactive_data_checks,
):
    """Interactively confirm the omission of action probabilities when no index was supplied."""
    logger.info(
        "Confirming that action probabilities are not in algorithm update function args IF their index is not specified"
    )
    # A nonnegative index means probabilities ARE supplied; nothing to confirm.
    if alg_update_func_args_action_prob_index >= 0:
        return
    confirm_input_check_result(
        "\nYou specified that the algorithm update function supplied does not have action probabilities as one of its arguments. Please verify this is correct.\n\nContinue? (y/n)\n",
        suppress_interactive_data_checks,
    )
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
def confirm_no_small_sample_correction_desired_if_not_requested(
    small_sample_correction,
    suppress_interactive_data_checks,
):
    """Interactively confirm that skipping small-sample corrections is intentional."""
    logger.info(
        "Confirming that no small sample correction is desired if it's not requested."
    )
    # Only prompt when the caller explicitly opted out of corrections.
    if small_sample_correction != SmallSampleCorrections.NONE:
        return
    confirm_input_check_result(
        "\nYou specified that you would not like to perform any small-sample corrections. Please verify that this is correct.\n\nContinue? (y/n)\n",
        suppress_interactive_data_checks,
    )
|
|
624
|
+
|
|
625
|
+
|
|
626
|
+
def confirm_no_adaptive_bread_inverse_stabilization_method_desired_if_not_requested(
    adaptive_bread_inverse_stabilization_method,
    suppress_interactive_data_checks,
):
    """Interactively confirm that skipping inverse stabilization is intentional."""
    logger.info(
        "Confirming that no adaptive bread inverse stabilization method is desired if it's not requested."
    )
    # Only prompt when the caller explicitly opted out of stabilization.
    if adaptive_bread_inverse_stabilization_method != InverseStabilizationMethods.NONE:
        return
    confirm_input_check_result(
        "\nYou specified that you would not like to perform any inverse stabilization while forming the adaptive variance. This is not usually recommended. Please verify that it is correct or select one of the available options.\n\nContinue? (y/n)\n",
        suppress_interactive_data_checks,
    )
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
def require_action_prob_times_given_if_index_supplied(
    alg_update_func_args_action_prob_index,
    alg_update_func_args_action_prob_times_index,
):
    """If an action-prob index is supplied (>= 0), require a distinct times index too.

    Raises:
        AssertionError: if the times index is missing or equals the prob index.
    """
    logger.info("Checking that action prob times are given if index is supplied.")
    if alg_update_func_args_action_prob_index >= 0:
        # Added a diagnostic message (the original bare assert gave no hint),
        # consistent with the other checks in this module.
        assert alg_update_func_args_action_prob_times_index >= 0 and (
            alg_update_func_args_action_prob_times_index
            != alg_update_func_args_action_prob_index
        ), (
            "An action probability index was supplied for the algorithm update function args, but"
            " no distinct action probability times index was given. Please see the contract for"
            " details."
        )
|
|
650
|
+
|
|
651
|
+
|
|
652
|
+
def require_action_prob_index_given_if_times_supplied(
    alg_update_func_args_action_prob_index,
    alg_update_func_args_action_prob_times_index,
):
    """If an action-prob times index is supplied (>= 0), require a distinct prob index too.

    Raises:
        AssertionError: if the prob index is missing or equals the times index.
    """
    logger.info("Checking that action prob index is given if times are supplied.")
    if alg_update_func_args_action_prob_times_index >= 0:
        # Added a diagnostic message (the original bare assert gave no hint),
        # consistent with the other checks in this module.
        assert alg_update_func_args_action_prob_index >= 0 and (
            alg_update_func_args_action_prob_times_index
            != alg_update_func_args_action_prob_index
        ), (
            "An action probability times index was supplied for the algorithm update function"
            " args, but no distinct action probability index was given. Please see the contract"
            " for details."
        )
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
def require_beta_is_1D_array_in_alg_update_args(
    alg_update_func_args, alg_update_func_args_beta_index
):
    """Assert the beta entry of every non-empty per-user update arg tuple is a 1-D array."""
    for args_by_user in alg_update_func_args.values():
        for user_args in args_by_user.values():
            # Empty tuples (user absent for this policy) carry no beta.
            if not user_args:
                continue
            beta = user_args[alg_update_func_args_beta_index]
            assert (
                beta.ndim == 1
            ), "Beta is not a 1D array in the algorithm update function args."
|
|
677
|
+
|
|
678
|
+
|
|
679
|
+
def require_beta_is_1D_array_in_action_prob_args(
    action_prob_func_args, action_prob_func_args_beta_index
):
    """Assert the beta entry of every non-empty per-user action-prob arg tuple is a 1-D array."""
    for args_by_user in action_prob_func_args.values():
        for user_args in args_by_user.values():
            # Empty tuples (user out of study at this time) carry no beta.
            if not user_args:
                continue
            beta = user_args[action_prob_func_args_beta_index]
            assert (
                beta.ndim == 1
            ), "Beta is not a 1D array in the action probability function args."
|
|
692
|
+
|
|
693
|
+
|
|
694
|
+
def require_theta_is_1D_array(theta_est):
    """Assert the supplied theta estimate is one-dimensional."""
    theta_ndim = theta_est.ndim
    assert theta_ndim == 1, "Theta is not a 1D array."
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
def verify_study_df_summary_satisfactory(
    study_df: pd.DataFrame,
    user_id_col_name,
    policy_num_col_name,
    calendar_t_col_name,
    in_study_col_name,
    action_prob_col_name,
    reward_col_name,
    beta_dim,
    theta_dim,
    suppress_interactive_data_checks: bool,
):
    """
    Summarize the supplied study dataframe and ask the user to confirm the
    summary matches expectations.

    Computes headline statistics over the in-study rows (user count, number of
    policy updates, fallback-policy usage, action probability extremes),
    renders terminal plots of the action probability quartile trajectories and
    the average reward trajectory, and hands the assembled report to
    confirm_input_check_result together with suppress_interactive_data_checks;
    the report ends with a y/n prompt.

    Returns:
        None
    """

    # Restrict every summary statistic below to rows flagged as in-study.
    in_study_df = study_df[study_df[in_study_col_name] == 1]
    num_users = in_study_df[user_id_col_name].nunique()
    # NOTE(review): positive policy numbers are counted as genuine updates,
    # negative values as fallback policies, and zero (implied) as the initial
    # policy -- confirm this numbering convention against the contract.
    num_non_initial_or_fallback_policies = in_study_df[
        in_study_df[policy_num_col_name] > 0
    ][policy_num_col_name].nunique()
    num_decision_times_with_fallback_policies = len(
        in_study_df[in_study_df[policy_num_col_name] < 0]
    )
    num_decision_times = in_study_df[calendar_t_col_name].nunique()
    avg_decisions_per_user = len(in_study_df) / num_users
    # Count decision times at which more than one non-fallback policy was in use.
    num_decision_times_with_multiple_policies = (
        in_study_df[in_study_df[policy_num_col_name] >= 0]
        .groupby(calendar_t_col_name)[policy_num_col_name]
        .nunique()
        > 1
    ).sum()
    min_action_prob = in_study_df[action_prob_col_name].min()
    max_action_prob = in_study_df[action_prob_col_name].max()
    min_non_fallback_policy_num = in_study_df[in_study_df[policy_num_col_name] >= 0][
        policy_num_col_name
    ].min()
    # Data points collected while the earliest non-fallback policy was still
    # active, i.e. before the first algorithm update took effect.
    num_data_points_before_first_update = len(
        in_study_df[in_study_df[policy_num_col_name] == min_non_fallback_policy_num]
    )

    # Per-decision-time quartiles of the action probabilities, ordered by
    # calendar time.
    median_action_probabilities = (
        in_study_df.groupby(calendar_t_col_name)[action_prob_col_name]
        .median()
        .to_numpy()
    )
    quartiles = in_study_df.groupby(calendar_t_col_name)[action_prob_col_name].quantile(
        [0.25, 0.75]
    )
    q25_action_probabilities = quartiles.xs(0.25, level=1).to_numpy()
    q75_action_probabilities = quartiles.xs(0.75, level=1).to_numpy()

    avg_rewards = in_study_df.groupby(calendar_t_col_name)[reward_col_name].mean()

    # Plot action probability quartile trajectories
    plt.clear_figure()
    plt.title("Action 1 Probability 25/50/75 Quantile Trajectories")
    plt.xlabel("Decision Time")
    plt.ylabel("Action 1 Probability Quantiles")
    # The interquartile range width is used as the error band around the
    # median trajectory.
    plt.error(
        median_action_probabilities,
        yerr=q75_action_probabilities - q25_action_probabilities,
        color="blue+",
    )
    plt.grid(True)
    # Cap the x-axis at roughly 10 tick labels regardless of study length.
    plt.xticks(
        range(
            0,
            len(median_action_probabilities),
            max(1, len(median_action_probabilities) // 10),
        )
    )
    action_prob_trajectory_plot = plt.build()

    # Plot avg reward trajectory
    plt.clear_figure()
    plt.title("Avg Reward Trajectory")
    plt.xlabel("Decision Time")
    plt.ylabel("Avg Reward")
    plt.scatter(avg_rewards, color="blue+", marker="*")
    plt.grid(True)
    # Same ~10-tick cap as above.
    plt.xticks(
        range(
            0,
            len(avg_rewards),
            max(1, len(avg_rewards) // 10),
        )
    )
    avg_reward_trajectory_plot = plt.build()

    # Present the full report; the message ends with a y/n prompt handled by
    # confirm_input_check_result (see helper_functions).
    confirm_input_check_result(
        f"\nYou provided a study dataframe reflecting a study with"
        f"\n* {num_users} users"
        f"\n* {num_non_initial_or_fallback_policies} policy updates"
        f"\n* {num_decision_times} decision times, for an average of {avg_decisions_per_user}"
        f" decisions per user"
        f"\n* RL parameters of dimension {beta_dim} per update"
        f"\n* Inferential target of dimension {theta_dim}"
        f"\n* {num_data_points_before_first_update} data points before the first update"
        f"\n* {num_decision_times_with_fallback_policies} decision times"
        f" ({num_decision_times_with_fallback_policies * 100 / num_decision_times}%) for which"
        f" fallback policies were used"
        f"\n* {num_decision_times_with_multiple_policies} decision times"
        f" ({num_decision_times_with_multiple_policies * 100 / num_decision_times}%)"
        f" for which multiple non-fallback policies were used"
        f"\n* Minimum action probability {min_action_prob}"
        f"\n* Maximum action probability {max_action_prob}"
        f"\n* The following trajectories of action probability quartiles over time:\n {action_prob_trajectory_plot}"
        f"\n* The following average reward trajectory over time:\n {avg_reward_trajectory_plot}"
        f" \n\nDoes this meet expectations? (y/n)\n",
        suppress_interactive_data_checks,
    )
|
|
807
|
+
|
|
808
|
+
|
|
809
|
+
def require_betas_match_in_alg_update_args_each_update(
    alg_update_func_args, alg_update_func_args_beta_index
):
    """
    Verify that, for every policy update, all users were handed the same beta.

    Walks the per-policy, per-user argument tuples and asserts that each
    supplied beta is array-equal to the first beta seen for that policy.
    Empty per-user arg tuples are skipped.
    """
    logger.info(
        "Checking that betas match across users for each update in the algorithm update function args."
    )
    for policy_num, args_by_user in alg_update_func_args.items():
        reference_beta = None
        for user_args in args_by_user.values():
            if not user_args:
                continue
            candidate = user_args[alg_update_func_args_beta_index]
            if reference_beta is None:
                reference_beta = candidate
                continue
            assert np.array_equal(
                candidate, reference_beta
            ), f"Betas do not match across users in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
|
|
829
|
+
|
|
830
|
+
|
|
831
|
+
def require_betas_match_in_action_prob_func_args_each_decision(
    action_prob_func_args, action_prob_func_args_beta_index
):
    """
    Verify that, at each decision time, all users were handed the same beta.

    Walks the per-decision-time, per-user argument tuples and asserts that
    each supplied beta is array-equal to the first beta seen at that decision
    time. Empty per-user arg tuples are skipped.
    """
    logger.info(
        "Checking that betas match across users for each decision time in the action prob args."
    )
    for decision_time, args_by_user in action_prob_func_args.items():
        reference_beta = None
        for user_args in args_by_user.values():
            if not user_args:
                continue
            candidate = user_args[action_prob_func_args_beta_index]
            if reference_beta is None:
                reference_beta = candidate
                continue
            assert np.array_equal(
                candidate, reference_beta
            ), f"Betas do not match across users in the action prob args for decision_time {decision_time}. Please see the contract for details."
|
|
851
|
+
|
|
852
|
+
|
|
853
|
+
def require_valid_action_prob_times_given_if_index_supplied(
    study_df,
    calendar_t_col_name,
    alg_update_func_args,
    alg_update_func_args_action_prob_times_index,
):
    """
    Check that any supplied action probability times are strictly increasing
    and fall within the calendar time range covered by the study dataframe.

    Inputs:
        study_df:
            The study dataframe; only its calendar time column is consulted.
        calendar_t_col_name:
            Name of the calendar time column in study_df.
        alg_update_func_args:
            Nested dict of per-policy, per-user argument tuples.
        alg_update_func_args_action_prob_times_index:
            Index of the action probability times within each argument tuple,
            or a negative value if times were not supplied.

    Returns:
        None
    """
    logger.info("Checking that action prob times are valid if index is supplied.")

    # Nothing to validate if no times index was supplied.
    if alg_update_func_args_action_prob_times_index < 0:
        return

    min_time = study_df[calendar_t_col_name].min()
    max_time = study_df[calendar_t_col_name].max()
    for policy_idx, args_by_user in alg_update_func_args.items():
        for user_id, args in args_by_user.items():
            if not args:
                continue
            times = args[alg_update_func_args_action_prob_times_index]
            # An empty times sequence has nothing to range-check (and would
            # crash the times[0] lookup below).
            if len(times) == 0:
                continue
            # Bug fix: previously this asserted a bare generator expression,
            # which is always truthy, so the monotonicity check never fired.
            # all(...) performs the real element-wise comparison.
            assert all(
                times[i] > times[i - 1] for i in range(1, len(times))
            ), f"Non-strictly-increasing times were given for action probabilities in the algorithm update function args for user {user_id} and policy {policy_idx}. Please see the contract for details."
            assert (
                times[0] >= min_time and times[-1] <= max_time
            ), f"Times not present in the study were given for action probabilities in the algorithm update function args. The min and max times in the study dataframe are {min_time} and {max_time}, while user {user_id} has times {times} supplied for policy {policy_idx}. Please see the contract for details."
|
|
877
|
+
|
|
878
|
+
|
|
879
|
+
def _log_estimating_function_breakdown(
    mean_estimating_function_stack, beta_dim, theta_dim
):
    """
    Log the per-update and inference components of the mean estimating
    function stack, slicing by beta_dim for each update and taking the final
    theta_dim entries as the inference component.
    """
    # If this is not true there is an internal problem in the package.
    assert (mean_estimating_function_stack.size - theta_dim) % beta_dim == 0
    num_updates = (mean_estimating_function_stack.size - theta_dim) // beta_dim
    for i in range(num_updates):
        logger.info(
            "Mean estimating function contribution for update %s:\n%s",
            i + 1,
            mean_estimating_function_stack[i * beta_dim : (i + 1) * beta_dim],
        )
    logger.info(
        "Mean estimating function contribution for inference:\n%s",
        mean_estimating_function_stack[-theta_dim:],
    )


def require_estimating_functions_sum_to_zero(
    mean_estimating_function_stack: jnp.ndarray,
    beta_dim: int,
    theta_dim: int,
    suppress_interactive_data_checks: bool,
):
    """
    This is a test that the correct loss/estimating functions have
    been given for both the algorithm updates and inference. If that is true, then the
    loss/estimating functions when evaluated should sum to approximately zero across users. These
    values have been stacked and averaged across users in mean_estimating_function_stack, which
    we simply compare to the zero vector. We can isolate components for each update and inference
    by considering the dimensions of the beta vectors and the theta vector.

    Inputs:
        mean_estimating_function_stack:
            The mean of the estimating function stack (a component for each algorithm update and
            inference) across users. This should be a 1D array.
        beta_dim:
            The dimension of the beta vectors that parameterize the algorithm.
        theta_dim:
            The dimension of the theta vector that we estimate during after-study analysis.
        suppress_interactive_data_checks:
            Forwarded to confirm_input_check_result on a soft (tight-tolerance)
            failure.

    Returns:
        None
    """

    logger.info("Checking that estimating functions average to zero across users")

    # Have a looser hard failure cutoff before the typical interactive check
    try:
        np.testing.assert_allclose(
            mean_estimating_function_stack,
            jnp.zeros(mean_estimating_function_stack.size),
            atol=1e-2,
        )
    except AssertionError as e:
        logger.info(
            "Estimating function stacks do not average to within loose tolerance of zero across users. Drilling in to specific updates and inference component."
        )
        # Previously this breakdown loop was duplicated verbatim in both
        # except branches; it now lives in one helper.
        _log_estimating_function_breakdown(
            mean_estimating_function_stack, beta_dim, theta_dim
        )
        raise e

    logger.info(
        "Estimating functions pass loose tolerance check, proceeding to tighter check."
    )
    try:
        np.testing.assert_allclose(
            mean_estimating_function_stack,
            jnp.zeros(mean_estimating_function_stack.size),
            atol=1e-5,
        )
    except AssertionError as e:
        logger.info(
            "Estimating function stacks do not average to within specified tolerance of zero across users. Drilling in to specific updates and inference component."
        )
        _log_estimating_function_breakdown(
            mean_estimating_function_stack, beta_dim, theta_dim
        )
        # Soft failure: let the user decide whether the residual is acceptable.
        confirm_input_check_result(
            f"\nEstimating functions do not average to within default tolerance of zero vector. Please decide if the following is a reasonable result, taking into account the above breakdown by update number and inference. If not, there are several possible reasons for failure mentioned in the contract. Results:\n{str(e)}\n\nContinue? (y/n)\n",
            suppress_interactive_data_checks,
            e,
        )
|
|
966
|
+
|
|
967
|
+
|
|
968
|
+
def _log_RL_estimating_function_breakdown(mean_estimating_function_stack, beta_dim):
    """Log the per-update components of the mean RL estimating function stack."""
    num_updates = (mean_estimating_function_stack.size) // beta_dim
    for i in range(num_updates):
        logger.info(
            "Mean estimating function contribution for update %s:\n%s",
            i + 1,
            mean_estimating_function_stack[i * beta_dim : (i + 1) * beta_dim],
        )


def require_RL_estimating_functions_sum_to_zero(
    mean_estimating_function_stack: jnp.ndarray,
    beta_dim: int,
    suppress_interactive_data_checks: bool,
):
    """
    This is a test that the correct loss/estimating functions have been given
    for the algorithm updates. If that is true, then the loss/estimating
    functions when evaluated should sum to approximately zero across users.
    These values have been stacked and averaged across users in
    mean_estimating_function_stack, which we simply compare to the zero
    vector. We can isolate components for each update by considering the
    dimension of the beta vectors.

    Inputs:
        mean_estimating_function_stack:
            The mean of the estimating function stack (a component for each
            algorithm update) across users. This should be a 1D array.
        beta_dim:
            The dimension of the beta vectors that parameterize the algorithm.
        suppress_interactive_data_checks:
            Forwarded to confirm_input_check_result on a soft (tight-tolerance)
            failure.

    Returns:
        None
    """

    logger.info("Checking that RL estimating functions average to zero across users")

    # Have a looser hard failure cutoff before the typical interactive check
    try:
        np.testing.assert_allclose(
            mean_estimating_function_stack,
            jnp.zeros(mean_estimating_function_stack.size),
            atol=1e-2,
        )
    except AssertionError as e:
        # Message fix: this RL-only stack has no inference component, so the
        # log no longer claims to drill into one.
        logger.info(
            "RL estimating function stacks do not average to zero across users. Drilling in to specific updates."
        )
        _log_RL_estimating_function_breakdown(mean_estimating_function_stack, beta_dim)
        # TODO: We may need to email instead of failing here for monitoring algorithm.
        raise e

    try:
        np.testing.assert_allclose(
            mean_estimating_function_stack,
            jnp.zeros(mean_estimating_function_stack.size),
            atol=1e-5,
        )
    except AssertionError as e:
        logger.info(
            "RL estimating function stacks do not average to zero across users. Drilling in to specific updates."
        )
        _log_RL_estimating_function_breakdown(mean_estimating_function_stack, beta_dim)
        # TODO: Email instead of requiring user input for monitoring alg.
        confirm_input_check_result(
            f"\nEstimating functions do not average to within default tolerance of zero vector. Please decide if the following is a reasonable result, taking into account the above breakdown by update number and inference. If not, there are several possible reasons for failure mentioned in the contract. Results:\n{str(e)}\n\nContinue? (y/n)\n",
            suppress_interactive_data_checks,
            e,
        )
|
|
1040
|
+
|
|
1041
|
+
|
|
1042
|
+
def require_adaptive_bread_inverse_is_true_inverse(
    joint_adaptive_bread_matrix,
    joint_adaptive_bread_inverse_matrix,
    suppress_interactive_data_checks,
):
    """
    Verify that multiplying the joint adaptive bread matrix by its computed
    inverse yields (approximately) the identity — a direct check that the
    inverse is well-conditioned — and report how far the product deviates.

    Returns:
        A tuple of (max absolute element, Frobenius norm) of the deviation of
        the product from the identity.
    """
    product = joint_adaptive_bread_matrix @ joint_adaptive_bread_inverse_matrix
    expected_identity = np.eye(joint_adaptive_bread_matrix.shape[0])
    try:
        np.testing.assert_allclose(
            product,
            expected_identity,
            rtol=1e-5,
            atol=1e-5,
        )
    except AssertionError as e:
        confirm_input_check_result(
            f"\nJoint adaptive bread is not exact inverse of the constructed matrix that was inverted to form it. This likely illustrates poor conditioning:\n{str(e)}\n\nContinue? (y/n)\n",
            suppress_interactive_data_checks,
            e,
        )

    # If we haven't already errored out, return some measures of how far off we are from identity
    deviation = product - expected_identity
    logger.debug(
        "Difference between should-be-identity produced by multiplying joint adaptive bread inverse and its computed inverse and actual identity:\n%s",
        deviation,
    )

    max_abs_deviation = np.max(np.abs(deviation))
    frobenius_deviation = np.linalg.norm(deviation, "fro")

    logger.info("Maximum abs element of difference: %s", max_abs_deviation)
    logger.info("Frobenius norm of difference: %s", frobenius_deviation)

    return max_abs_deviation, frobenius_deviation
|
|
1084
|
+
|
|
1085
|
+
|
|
1086
|
+
def require_threaded_algorithm_estimating_function_args_equivalent(
    algorithm_estimating_func,
    update_func_args_by_by_user_id_by_policy_num,
    threaded_update_func_args_by_policy_num_by_user_id,
    suppress_interactive_data_checks,
):
    """
    Check that the algorithm estimating function returns the same values
    when called with the original arguments and when called with the
    reconstructed action probabilities substituted in.
    """
    for (
        policy_num,
        args_by_user,
    ) in update_func_args_by_by_user_id_by_policy_num.items():
        for user_id, original_args in args_by_user.items():
            if not original_args:
                continue
            baseline = algorithm_estimating_func(*original_args)
            threaded_args = threaded_update_func_args_by_policy_num_by_user_id[
                user_id
            ][policy_num]
            # Need to stop gradient here because we can't convert a traced value to np array
            reconstructed = jax.lax.stop_gradient(
                algorithm_estimating_func(*threaded_args)
            )
            np.testing.assert_allclose(baseline, reconstructed, atol=1e-7, rtol=1e-3)
|
|
1120
|
+
|
|
1121
|
+
|
|
1122
|
+
def require_threaded_inference_estimating_function_args_equivalent(
    inference_estimating_func,
    inference_func_args_by_user_id,
    threaded_inference_func_args_by_user_id,
    suppress_interactive_data_checks,
):
    """
    Check that the inference estimating function returns the same values
    when called with the original arguments and when called with the
    reconstructed action probabilities substituted in.
    """
    for user_id, original_args in inference_func_args_by_user_id.items():
        if not original_args:
            continue
        baseline = inference_estimating_func(*original_args)
        # Need to stop gradient here because we can't convert a traced value to np array
        reconstructed = jax.lax.stop_gradient(
            inference_estimating_func(
                *threaded_inference_func_args_by_user_id[user_id]
            )
        )
        np.testing.assert_allclose(baseline, reconstructed, rtol=1e-2)
|