lifejacket 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1145 @@
1
+ import collections
2
+ import logging
3
+ from typing import Any
4
+
5
+ import numpy as np
6
+ import jax
7
+ from jax import numpy as jnp
8
+ import pandas as pd
9
+ import plotext as plt
10
+
11
+ from .constants import InverseStabilizationMethods, SmallSampleCorrections
12
+ from .helper_functions import (
13
+ confirm_input_check_result,
14
+ )
15
+
16
+ # When we print out objects for debugging, show the whole thing.
17
+ np.set_printoptions(threshold=np.inf)
18
+
19
+ logger = logging.getLogger(__name__)
20
+ logging.basicConfig(
21
+ format="%(asctime)s,%(msecs)03d %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s",
22
+ datefmt="%Y-%m-%d:%H:%M:%S",
23
+ level=logging.INFO,
24
+ )
25
+
26
+
27
+ # TODO: any checks needed here about alg update function type?
28
+ def perform_first_wave_input_checks(
29
+ study_df,
30
+ in_study_col_name,
31
+ action_col_name,
32
+ policy_num_col_name,
33
+ calendar_t_col_name,
34
+ user_id_col_name,
35
+ action_prob_col_name,
36
+ reward_col_name,
37
+ action_prob_func,
38
+ action_prob_func_args,
39
+ action_prob_func_args_beta_index,
40
+ alg_update_func_args,
41
+ alg_update_func_args_beta_index,
42
+ alg_update_func_args_action_prob_index,
43
+ alg_update_func_args_action_prob_times_index,
44
+ theta_est,
45
+ beta_dim,
46
+ suppress_interactive_data_checks,
47
+ small_sample_correction,
48
+ ):
49
+ ### Validate algorithm loss/estimating function and args
50
+ require_alg_update_args_given_for_all_users_at_each_update(
51
+ study_df, user_id_col_name, alg_update_func_args
52
+ )
53
+ require_no_policy_numbers_present_in_alg_update_args_but_not_study_df(
54
+ study_df, policy_num_col_name, alg_update_func_args
55
+ )
56
+ require_beta_is_1D_array_in_alg_update_args(
57
+ alg_update_func_args, alg_update_func_args_beta_index
58
+ )
59
+ require_all_policy_numbers_in_study_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
60
+ study_df, in_study_col_name, policy_num_col_name, alg_update_func_args
61
+ )
62
+ confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
63
+ alg_update_func_args_action_prob_index, suppress_interactive_data_checks
64
+ )
65
+ require_action_prob_times_given_if_index_supplied(
66
+ alg_update_func_args_action_prob_index,
67
+ alg_update_func_args_action_prob_times_index,
68
+ )
69
+ require_action_prob_index_given_if_times_supplied(
70
+ alg_update_func_args_action_prob_index,
71
+ alg_update_func_args_action_prob_times_index,
72
+ )
73
+ require_betas_match_in_alg_update_args_each_update(
74
+ alg_update_func_args, alg_update_func_args_beta_index
75
+ )
76
+ require_action_prob_args_in_alg_update_func_correspond_to_study_df(
77
+ study_df,
78
+ action_prob_col_name,
79
+ calendar_t_col_name,
80
+ user_id_col_name,
81
+ alg_update_func_args,
82
+ alg_update_func_args_action_prob_index,
83
+ alg_update_func_args_action_prob_times_index,
84
+ )
85
+ require_valid_action_prob_times_given_if_index_supplied(
86
+ study_df,
87
+ calendar_t_col_name,
88
+ alg_update_func_args,
89
+ alg_update_func_args_action_prob_times_index,
90
+ )
91
+
92
+ confirm_no_small_sample_correction_desired_if_not_requested(
93
+ small_sample_correction, suppress_interactive_data_checks
94
+ )
95
+
96
+ ### Validate action prob function and args
97
+ require_action_prob_func_args_given_for_all_users_at_each_decision(
98
+ study_df, user_id_col_name, action_prob_func_args
99
+ )
100
+ require_action_prob_func_args_given_for_all_decision_times(
101
+ study_df, calendar_t_col_name, action_prob_func_args
102
+ )
103
+ require_action_probabilities_in_study_df_can_be_reconstructed(
104
+ study_df,
105
+ action_prob_col_name,
106
+ calendar_t_col_name,
107
+ user_id_col_name,
108
+ in_study_col_name,
109
+ action_prob_func_args,
110
+ action_prob_func,
111
+ )
112
+
113
+ require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times(
114
+ study_df,
115
+ calendar_t_col_name,
116
+ action_prob_func_args,
117
+ in_study_col_name,
118
+ user_id_col_name,
119
+ )
120
+ require_beta_is_1D_array_in_action_prob_args(
121
+ action_prob_func_args, action_prob_func_args_beta_index
122
+ )
123
+ require_betas_match_in_action_prob_func_args_each_decision(
124
+ action_prob_func_args, action_prob_func_args_beta_index
125
+ )
126
+
127
+ ### Validate study_df
128
+ verify_study_df_summary_satisfactory(
129
+ study_df,
130
+ user_id_col_name,
131
+ policy_num_col_name,
132
+ calendar_t_col_name,
133
+ in_study_col_name,
134
+ action_prob_col_name,
135
+ reward_col_name,
136
+ beta_dim,
137
+ len(theta_est),
138
+ suppress_interactive_data_checks,
139
+ )
140
+
141
+ require_all_users_have_all_times_in_study_df(
142
+ study_df, calendar_t_col_name, user_id_col_name
143
+ )
144
+ require_all_named_columns_present_in_study_df(
145
+ study_df,
146
+ in_study_col_name,
147
+ action_col_name,
148
+ policy_num_col_name,
149
+ calendar_t_col_name,
150
+ user_id_col_name,
151
+ action_prob_col_name,
152
+ )
153
+ require_all_named_columns_not_object_type_in_study_df(
154
+ study_df,
155
+ in_study_col_name,
156
+ action_col_name,
157
+ policy_num_col_name,
158
+ calendar_t_col_name,
159
+ user_id_col_name,
160
+ action_prob_col_name,
161
+ )
162
+ require_binary_actions(study_df, in_study_col_name, action_col_name)
163
+ require_binary_in_study_indicators(study_df, in_study_col_name)
164
+ require_consecutive_integer_policy_numbers(
165
+ study_df, in_study_col_name, policy_num_col_name
166
+ )
167
+ require_consecutive_integer_calendar_times(study_df, calendar_t_col_name)
168
+ require_hashable_user_ids(study_df, in_study_col_name, user_id_col_name)
169
+ require_action_probabilities_in_range_0_to_1(study_df, action_prob_col_name)
170
+
171
+ ### Validate theta estimation
172
+ require_theta_is_1D_array(theta_est)
173
+
174
+
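
# Illustrative sketch (not part of the validation logic above): the nested
# argument structures the first-wave checks assume. All names and values below
# are hypothetical placeholders rather than part of the package contract.
def _example_first_wave_arg_structures():
    # action_prob_func_args: decision time -> user id -> tuple of positional
    # args for action_prob_func. An empty tuple marks a user who is out of the
    # study at that decision time.
    action_prob_func_args = {
        1: {"user_a": (np.array([0.0, 0.1]), np.array([1.0, 2.0])), "user_b": ()},
        2: {"user_a": (np.array([0.2, 0.3]), np.array([1.0, 2.0])), "user_b": ()},
    }
    # alg_update_func_args: policy number -> user id -> tuple of positional
    # args for the algorithm update loss/estimating function. The beta entry
    # (at alg_update_func_args_beta_index) must be a 1D array shared by all
    # users for a given update.
    alg_update_func_args = {
        1: {"user_a": (np.array([0.0, 0.1]), np.array([[1.0, 2.0]])), "user_b": ()},
    }
    return action_prob_func_args, alg_update_func_args
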
175
+ def perform_alg_only_input_checks(
176
+ study_df,
177
+ in_study_col_name,
178
+ policy_num_col_name,
179
+ calendar_t_col_name,
180
+ user_id_col_name,
181
+ action_prob_col_name,
182
+ action_prob_func,
183
+ action_prob_func_args,
184
+ action_prob_func_args_beta_index,
185
+ alg_update_func_args,
186
+ alg_update_func_args_beta_index,
187
+ alg_update_func_args_action_prob_index,
188
+ alg_update_func_args_action_prob_times_index,
189
+ suppress_interactive_data_checks,
190
+ ):
191
+ ### Validate algorithm loss/estimating function and args
192
+ require_alg_update_args_given_for_all_users_at_each_update(
193
+ study_df, user_id_col_name, alg_update_func_args
194
+ )
195
+ require_beta_is_1D_array_in_alg_update_args(
196
+ alg_update_func_args, alg_update_func_args_beta_index
197
+ )
198
+ require_all_policy_numbers_in_study_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
199
+ study_df, in_study_col_name, policy_num_col_name, alg_update_func_args
200
+ )
201
+ confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
202
+ alg_update_func_args_action_prob_index, suppress_interactive_data_checks
203
+ )
204
+ require_action_prob_times_given_if_index_supplied(
205
+ alg_update_func_args_action_prob_index,
206
+ alg_update_func_args_action_prob_times_index,
207
+ )
208
+ require_action_prob_index_given_if_times_supplied(
209
+ alg_update_func_args_action_prob_index,
210
+ alg_update_func_args_action_prob_times_index,
211
+ )
212
+ require_betas_match_in_alg_update_args_each_update(
213
+ alg_update_func_args, alg_update_func_args_beta_index
214
+ )
215
+ require_action_prob_args_in_alg_update_func_correspond_to_study_df(
216
+ study_df,
217
+ action_prob_col_name,
218
+ calendar_t_col_name,
219
+ user_id_col_name,
220
+ alg_update_func_args,
221
+ alg_update_func_args_action_prob_index,
222
+ alg_update_func_args_action_prob_times_index,
223
+ )
224
+ require_valid_action_prob_times_given_if_index_supplied(
225
+ study_df,
226
+ calendar_t_col_name,
227
+ alg_update_func_args,
228
+ alg_update_func_args_action_prob_times_index,
229
+ )
230
+
231
+ ### Validate action prob function and args
232
+ require_action_prob_func_args_given_for_all_users_at_each_decision(
233
+ study_df, user_id_col_name, action_prob_func_args
234
+ )
235
+ require_action_prob_func_args_given_for_all_decision_times(
236
+ study_df, calendar_t_col_name, action_prob_func_args
237
+ )
238
+ require_action_probabilities_in_study_df_can_be_reconstructed(
239
+ study_df,
240
+ action_prob_col_name,
241
+ calendar_t_col_name,
242
+ user_id_col_name,
243
+ in_study_col_name,
244
+ action_prob_func_args,
245
+ action_prob_func=action_prob_func,
246
+ )
247
+
248
+ require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times(
249
+ study_df,
250
+ calendar_t_col_name,
251
+ action_prob_func_args,
252
+ in_study_col_name,
253
+ user_id_col_name,
254
+ )
255
+ require_beta_is_1D_array_in_action_prob_args(
256
+ action_prob_func_args, action_prob_func_args_beta_index
257
+ )
258
+ require_betas_match_in_action_prob_func_args_each_decision(
259
+ action_prob_func_args, action_prob_func_args_beta_index
260
+ )
261
+
262
+
263
+ # TODO: Give a hard-to-use option to loosen this check somehow
264
+ def require_action_probabilities_in_study_df_can_be_reconstructed(
265
+ study_df,
266
+ action_prob_col_name,
267
+ calendar_t_col_name,
268
+ user_id_col_name,
269
+ in_study_col_name,
270
+ action_prob_func_args,
271
+ action_prob_func,
272
+ ):
273
+ """
274
+ Check that the action probabilities in the study dataframe can be reconstructed from the supplied
275
+ action probability function and its arguments.
276
+
277
+ NOTE THAT THIS IS A HARD FAILURE IF THE RECONSTRUCTION DOESN'T PASS.
278
+ """
279
+ logger.info("Reconstructing action probabilities from function and arguments.")
280
+
281
+ in_study_df = study_df[study_df[in_study_col_name] == 1]
282
+ reconstructed_action_probs = in_study_df.apply(
283
+ lambda row: action_prob_func(
284
+ *action_prob_func_args[row[calendar_t_col_name]][row[user_id_col_name]]
285
+ ),
286
+ axis=1,
287
+ )
288
+
289
+ np.testing.assert_allclose(
290
+ in_study_df[action_prob_col_name].to_numpy(dtype="float64"),
291
+ reconstructed_action_probs.to_numpy(dtype="float64"),
292
+ atol=1e-6,
293
+ )
294
+
295
+
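
# A minimal sketch of the reconstruction contract enforced above, using a
# hypothetical sigmoid policy and hypothetical column names; it is illustrative
# only and makes no assumptions about the real study schema.
def _example_action_prob_reconstruction():
    def action_prob_func(beta, state):
        return 1.0 / (1.0 + np.exp(-np.dot(beta, state)))

    beta = np.array([0.5, -0.25])
    states = {"u1": np.array([1.0, 2.0]), "u2": np.array([0.0, 1.0])}
    action_prob_func_args = {1: {uid: (beta, s) for uid, s in states.items()}}
    study_df = pd.DataFrame(
        {
            "user_id": ["u1", "u2"],
            "calendar_t": [1, 1],
            "in_study": [1, 1],
            "action_prob": [
                action_prob_func(beta, states["u1"]),
                action_prob_func(beta, states["u2"]),
            ],
        }
    )
    # The check recomputes each probability from the stored args and requires
    # it to match the action_prob column to within atol=1e-6.
    require_action_probabilities_in_study_df_can_be_reconstructed(
        study_df,
        "action_prob",
        "calendar_t",
        "user_id",
        "in_study",
        action_prob_func_args,
        action_prob_func,
    )
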
296
+ def require_all_users_have_all_times_in_study_df(
297
+ study_df, calendar_t_col_name, user_id_col_name
298
+ ):
299
+ logger.info("Checking that all users have the same set of unique calendar times.")
300
+ # Get the unique calendar times
301
+ unique_calendar_times = set(study_df[calendar_t_col_name].unique())
302
+
303
+ # Group by user ID and aggregate the unique calendar times for each user
304
+ user_calendar_times = study_df.groupby(user_id_col_name)[calendar_t_col_name].apply(
305
+ set
306
+ )
307
+
308
+ # Check if all users have the same set of unique calendar times
309
+ if not user_calendar_times.apply(lambda x: x == unique_calendar_times).all():
310
+ raise AssertionError(
311
+ "Not all users have all calendar times in the study dataframe. Please see the contract for details."
312
+ )
313
+
314
+
315
+ def require_alg_update_args_given_for_all_users_at_each_update(
316
+ study_df, user_id_col_name, alg_update_func_args
317
+ ):
318
+ logger.info(
319
+ "Checking that algorithm update function args are given for all users at each update."
320
+ )
321
+ all_user_ids = set(study_df[user_id_col_name].unique())
322
+ for policy_num in alg_update_func_args:
323
+ assert (
324
+ set(alg_update_func_args[policy_num].keys()) == all_user_ids
325
+ ), f"Not all users present in algorithm update function args for policy number {policy_num}. Please see the contract for details."
326
+
327
+
328
+ def require_action_prob_args_in_alg_update_func_correspond_to_study_df(
329
+ study_df,
330
+ action_prob_col_name,
331
+ calendar_t_col_name,
332
+ user_id_col_name,
333
+ alg_update_func_args,
334
+ alg_update_func_args_action_prob_index,
335
+ alg_update_func_args_action_prob_times_index,
336
+ ):
337
+ logger.info(
338
+ "Checking that the action probabilities supplied in the algorithm update function args, if"
339
+ " any, correspond to those in the study dataframe for the corresponding users and decision"
340
+ " times."
341
+ )
342
+ if alg_update_func_args_action_prob_index < 0:
343
+ return
344
+
345
+ # Precompute a lookup dictionary for faster access
346
+ study_df_lookup = study_df.set_index([calendar_t_col_name, user_id_col_name])[
347
+ action_prob_col_name
348
+ ].to_dict()
349
+
350
+ for policy_num, user_args in alg_update_func_args.items():
351
+ for user_id, args in user_args.items():
352
+ if not args:
353
+ continue
354
+ arg_action_probs = args[alg_update_func_args_action_prob_index]
355
+ action_prob_times = args[
356
+ alg_update_func_args_action_prob_times_index
357
+ ].flatten()
358
+
359
+ # Use the precomputed lookup dictionary
360
+ study_df_action_probs = [
361
+ study_df_lookup[(decision_time.item(), user_id)]
362
+ for decision_time in action_prob_times
363
+ ]
364
+
365
+ assert np.allclose(
366
+ arg_action_probs.flatten(),
367
+ study_df_action_probs,
368
+ ), (
369
+ f"There is a mismatch for user {user_id} between the action probabilities supplied"
370
+ f" in the args to the algorithm update function at policy {policy_num} and those in"
371
+ " the study dataframe for the supplied times. Please see the contract for details."
372
+ )
373
+
374
+
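
# A short sketch of the (calendar time, user id) -> action probability lookup
# built above via set_index(...).to_dict(), on a tiny hypothetical frame.
def _example_action_prob_lookup():
    df = pd.DataFrame(
        {
            "calendar_t": [1, 1, 2],
            "user_id": ["u1", "u2", "u1"],
            "action_prob": [0.4, 0.6, 0.5],
        }
    )
    lookup = df.set_index(["calendar_t", "user_id"])["action_prob"].to_dict()
    return lookup[(2, "u1")]  # 0.5
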
375
+ def require_action_prob_func_args_given_for_all_users_at_each_decision(
376
+ study_df,
377
+ user_id_col_name,
378
+ action_prob_func_args,
379
+ ):
380
+ logger.info(
381
+ "Checking that action prob function args are given for all users at each decision time."
382
+ )
383
+ all_user_ids = set(study_df[user_id_col_name].unique())
384
+ for decision_time in action_prob_func_args:
385
+ assert (
386
+ set(action_prob_func_args[decision_time].keys()) == all_user_ids
387
+ ), f"Not all users present in algorithm update function args for decision time {decision_time}. Please see the contract for details."
388
+
389
+
390
+ def require_action_prob_func_args_given_for_all_decision_times(
391
+ study_df, calendar_t_col_name, action_prob_func_args
392
+ ):
393
+ logger.info(
394
+ "Checking that action prob function args are given for all decision times."
395
+ )
396
+ all_times = set(study_df[calendar_t_col_name].unique())
397
+
398
+ assert (
399
+ set(action_prob_func_args.keys()) == all_times
400
+ ), "Not all decision times present in action prob function args. Please see the contract for details."
401
+
402
+
403
+ def require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times(
404
+ study_df: pd.DataFrame,
405
+ calendar_t_col_name: str,
406
+ action_prob_func_args: dict[str, dict[str, tuple[Any, ...]]],
407
+ in_study_col_name,
408
+ user_id_col_name,
409
+ ):
410
+ logger.info(
411
+ "Checking that action probability function args are blank for exactly the times each user"
412
+ " is not in the study according to the study dataframe."
413
+ )
414
+ out_of_study_df = study_df[study_df[in_study_col_name] == 0]
415
+ out_of_study_times_by_user_according_to_study_df = (
416
+ out_of_study_df.groupby(user_id_col_name)[calendar_t_col_name]
417
+ .apply(set)
418
+ .to_dict()
419
+ )
420
+
421
+ out_of_study_times_by_user_according_to_action_prob_func_args = (
422
+ collections.defaultdict(set)
423
+ )
424
+ for decision_time, action_prob_args_by_user in action_prob_func_args.items():
425
+ for user_id, action_prob_args in action_prob_args_by_user.items():
426
+ if not action_prob_args:
427
+ out_of_study_times_by_user_according_to_action_prob_func_args[
428
+ user_id
429
+ ].add(decision_time)
430
+
431
+ assert (
432
+ out_of_study_times_by_user_according_to_study_df
433
+ == out_of_study_times_by_user_according_to_action_prob_func_args
434
+ ), (
435
+ "Out-of-study decision times according to the study dataframe do not match up with the"
436
+ " times for which action probability arguments are blank for all users. Please see the"
437
+ " contract for details."
438
+ )
439
+
440
+
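
# A small sketch of the blank-args convention enforced above: a user is treated
# as out of the study at exactly the decision times whose argument tuple is
# empty. All names below are hypothetical.
def _example_blank_args_mark_out_of_study_times():
    action_prob_func_args = {
        1: {"u1": (np.array([0.1]),), "u2": ()},  # u2 not yet in the study
        2: {"u1": (np.array([0.2]),), "u2": (np.array([0.3]),)},
        3: {"u1": (), "u2": (np.array([0.4]),)},  # u1 has left the study
    }
    out_of_study_times = collections.defaultdict(set)
    for decision_time, args_by_user in action_prob_func_args.items():
        for user_id, args in args_by_user.items():
            if not args:
                out_of_study_times[user_id].add(decision_time)
    # Expected: {"u2": {1}, "u1": {3}}; the study dataframe's in-study column
    # must imply exactly the same mapping for the check above to pass.
    return dict(out_of_study_times)
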
441
+ def require_all_named_columns_present_in_study_df(
442
+ study_df,
443
+ in_study_col_name,
444
+ action_col_name,
445
+ policy_num_col_name,
446
+ calendar_t_col_name,
447
+ user_id_col_name,
448
+ action_prob_col_name,
449
+ ):
450
+ logger.info("Checking that all named columns are present in the study dataframe.")
451
+ assert (
452
+ in_study_col_name in study_df.columns
453
+ ), f"{in_study_col_name} not in study df."
454
+ assert action_col_name in study_df.columns, f"{action_col_name} not in study df."
455
+ assert (
456
+ policy_num_col_name in study_df.columns
457
+ ), f"{policy_num_col_name} not in study df."
458
+ assert (
459
+ calendar_t_col_name in study_df.columns
460
+ ), f"{calendar_t_col_name} not in study df."
461
+ assert user_id_col_name in study_df.columns, f"{user_id_col_name} not in study df."
462
+ assert (
463
+ action_prob_col_name in study_df.columns
464
+ ), f"{action_prob_col_name} not in study df."
465
+
466
+
467
+ def require_all_named_columns_not_object_type_in_study_df(
468
+ study_df,
469
+ in_study_col_name,
470
+ action_col_name,
471
+ policy_num_col_name,
472
+ calendar_t_col_name,
473
+ user_id_col_name,
474
+ action_prob_col_name,
475
+ ):
476
+ logger.info("Checking that all named columns are not type object.")
477
+ for colname in (
478
+ in_study_col_name,
479
+ action_col_name,
480
+ policy_num_col_name,
481
+ calendar_t_col_name,
482
+ user_id_col_name,
483
+ action_prob_col_name,
484
+ ):
485
+ assert (
486
+ study_df[colname].dtype != "object"
487
+ ), f"At least {colname} is of object type in study df."
488
+
489
+
490
+ def require_binary_actions(study_df, in_study_col_name, action_col_name):
491
+ logger.info("Checking that actions are binary.")
492
+ assert (
493
+ study_df[study_df[in_study_col_name] == 1][action_col_name]
494
+ .astype("int64")
495
+ .isin([0, 1])
496
+ .all()
497
+ ), "Actions are not binary."
498
+
499
+
500
+ def require_binary_in_study_indicators(study_df, in_study_col_name):
501
+ logger.info("Checking that in-study indicators are binary.")
502
+ assert (
503
+ study_df[in_study_col_name]
504
+ .astype("int64")
505
+ .isin([0, 1])
506
+ .all()
507
+ ), "In-study indicators are not binary."
508
+
509
+
510
+ def require_consecutive_integer_policy_numbers(
511
+ study_df, in_study_col_name, policy_num_col_name
512
+ ):
513
+ # TODO: This is a somewhat rough check of this, could also check nondecreasing temporally
514
+
515
+ logger.info(
516
+ "Checking that in-study, non-fallback policy numbers are consecutive integers."
517
+ )
518
+
519
+ in_study_df = study_df[study_df[in_study_col_name] == 1]
520
+ nonnegative_policy_df = in_study_df[in_study_df[policy_num_col_name] >= 0].copy()
521
+ # Ideally we actually have integers, but for legacy reasons we will support
522
+ # floats as well.
523
+ if nonnegative_policy_df[policy_num_col_name].dtype == "float64":
524
+ nonnegative_policy_df[policy_num_col_name] = nonnegative_policy_df[
525
+ policy_num_col_name
526
+ ].astype("int64")
527
+ assert np.array_equal(
528
+ np.sort(nonnegative_policy_df[policy_num_col_name].unique()),
529
+ range(
530
+ nonnegative_policy_df[policy_num_col_name].min(),
531
+ nonnegative_policy_df[policy_num_col_name].max() + 1,
532
+ ),
533
+ ), "Policy numbers are not consecutive integers."
534
+
535
+
536
+ def require_consecutive_integer_calendar_times(study_df, calendar_t_col_name):
537
+ # This is a somewhat rough check of this, more like checking there are no
538
+ # gaps in the integers covered. But we have other checks that all users
539
+ # have same times, etc.
540
+ # Note these times should be well-formed even when the user is not in the study.
541
+ logger.info("Checking that calendar times are consecutive integers.")
542
+ assert np.array_equal(
543
+ np.sort(study_df[calendar_t_col_name].unique()),
544
+ range(
545
+ study_df[calendar_t_col_name].min(), study_df[calendar_t_col_name].max() + 1
546
+ ),
547
+ ), "Calendar times are not consecutive integers."
548
+
549
+
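
# A tiny sketch of the consecutive-integer pattern used by the two checks
# above: the sorted unique values must equal the full integer range between
# their minimum and maximum.
def _example_consecutive_integer_check():
    values = pd.Series([3, 1, 2, 2, 4])
    unique_sorted = np.sort(values.unique())
    return np.array_equal(unique_sorted, range(values.min(), values.max() + 1))
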
550
+ def require_hashable_user_ids(study_df, in_study_col_name, user_id_col_name):
551
+ logger.info("Checking that user IDs are hashable.")
552
+ assert isinstance(
553
+ study_df[study_df[in_study_col_name] == 1][user_id_col_name].iloc[0],
554
+ collections.abc.Hashable,
555
+ ), "User IDs are not hashable."
556
+
557
+
558
+ def require_action_probabilities_in_range_0_to_1(study_df, action_prob_col_name):
559
+ logger.info("Checking that action probabilities are in the interval (0, 1).")
560
+ assert study_df[action_prob_col_name].between(0, 1, inclusive="neither").all(), "Action probabilities are not all strictly between 0 and 1."
561
+
562
+
563
+ def require_no_policy_numbers_present_in_alg_update_args_but_not_study_df(
564
+ study_df, policy_num_col_name, alg_update_func_args
565
+ ):
566
+ logger.info(
567
+ "Checking that policy numbers in algorithm update function args are present in the study dataframe."
568
+ )
569
+ alg_update_policy_nums = sorted(alg_update_func_args.keys())
570
+ study_df_policy_nums = sorted(study_df[policy_num_col_name].unique())
571
+ assert set(alg_update_policy_nums).issubset(set(study_df_policy_nums)), (
572
+ f"There are policy numbers present in algorithm update function args but not in the study dataframe. "
573
+ f"\nalg_update_func_args policy numbers: {alg_update_policy_nums}"
574
+ f"\nstudy_df policy numbers: {study_df_policy_nums}.\nPlease see the contract for details."
575
+ )
576
+
577
+
578
+ def require_all_policy_numbers_in_study_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
579
+ study_df, in_study_col_name, policy_num_col_name, alg_update_func_args
580
+ ):
581
+ logger.info(
582
+ "Checking that all policy numbers in the study dataframe are present in the algorithm update function args."
583
+ )
584
+ in_study_df = study_df[study_df[in_study_col_name] == 1]
585
+ # Get the number of the initial policy. 0 is recommended but not required.
586
+ min_nonnegative_policy_number = in_study_df[in_study_df[policy_num_col_name] >= 0][
587
+ policy_num_col_name
588
+ ].min()
589
+ assert set(
590
+ in_study_df[in_study_df[policy_num_col_name] > min_nonnegative_policy_number][
591
+ policy_num_col_name
592
+ ].unique()
593
+ ).issubset(
594
+ alg_update_func_args.keys()
595
+ ), f"There are non-fallback, non-initial policy numbers in the study dataframe that are not in the update function args: {set(in_study_df[in_study_df[policy_num_col_name] > 0][policy_num_col_name].unique()) - set(alg_update_func_args.keys())}. Please see the contract for details."
596
+
597
+
598
+ def confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
599
+ alg_update_func_args_action_prob_index,
600
+ suppress_interactive_data_checks,
601
+ ):
602
+ logger.info(
603
+ "Confirming that action probabilities are not in algorithm update function args IF their index is not specified"
604
+ )
605
+ if alg_update_func_args_action_prob_index < 0:
606
+ confirm_input_check_result(
607
+ "\nYou specified that the algorithm update function supplied does not have action probabilities as one of its arguments. Please verify this is correct.\n\nContinue? (y/n)\n",
608
+ suppress_interactive_data_checks,
609
+ )
610
+
611
+
612
+ def confirm_no_small_sample_correction_desired_if_not_requested(
613
+ small_sample_correction,
614
+ suppress_interactive_data_checks,
615
+ ):
616
+ logger.info(
617
+ "Confirming that no small sample correction is desired if it's not requested."
618
+ )
619
+ if small_sample_correction == SmallSampleCorrections.NONE:
620
+ confirm_input_check_result(
621
+ "\nYou specified that you would not like to perform any small-sample corrections. Please verify that this is correct.\n\nContinue? (y/n)\n",
622
+ suppress_interactive_data_checks,
623
+ )
624
+
625
+
626
+ def confirm_no_adaptive_bread_inverse_stabilization_method_desired_if_not_requested(
627
+ adaptive_bread_inverse_stabilization_method,
628
+ suppress_interactive_data_checks,
629
+ ):
630
+ logger.info(
631
+ "Confirming that no adaptive bread inverse stabilization method is desired if it's not requested."
632
+ )
633
+ if adaptive_bread_inverse_stabilization_method == InverseStabilizationMethods.NONE:
634
+ confirm_input_check_result(
635
+ "\nYou specified that you would not like to perform any inverse stabilization while forming the adaptive variance. This is not usually recommended. Please verify that it is correct or select one of the available options.\n\nContinue? (y/n)\n",
636
+ suppress_interactive_data_checks,
637
+ )
638
+
639
+
640
+ def require_action_prob_times_given_if_index_supplied(
641
+ alg_update_func_args_action_prob_index,
642
+ alg_update_func_args_action_prob_times_index,
643
+ ):
644
+ logger.info("Checking that action prob times are given if index is supplied.")
645
+ if alg_update_func_args_action_prob_index >= 0:
646
+ assert alg_update_func_args_action_prob_times_index >= 0 and (
647
+ alg_update_func_args_action_prob_times_index
648
+ != alg_update_func_args_action_prob_index
649
+ ), "An action probability index was supplied for the algorithm update args, but no distinct action probability times index was given."
650
+
651
+
652
+ def require_action_prob_index_given_if_times_supplied(
653
+ alg_update_func_args_action_prob_index,
654
+ alg_update_func_args_action_prob_times_index,
655
+ ):
656
+ logger.info("Checking that action prob index is given if times are supplied.")
657
+ if alg_update_func_args_action_prob_times_index >= 0:
658
+ assert alg_update_func_args_action_prob_index >= 0 and (
659
+ alg_update_func_args_action_prob_times_index
660
+ != alg_update_func_args_action_prob_index
661
+ )
662
+
663
+
664
+ def require_beta_is_1D_array_in_alg_update_args(
665
+ alg_update_func_args, alg_update_func_args_beta_index
666
+ ):
667
+ for policy_num in alg_update_func_args:
668
+ for user_id in alg_update_func_args[policy_num]:
669
+ if not alg_update_func_args[policy_num][user_id]:
670
+ continue
671
+ assert (
672
+ alg_update_func_args[policy_num][user_id][
673
+ alg_update_func_args_beta_index
674
+ ].ndim
675
+ == 1
676
+ ), "Beta is not a 1D array in the algorithm update function args."
677
+
678
+
679
+ def require_beta_is_1D_array_in_action_prob_args(
680
+ action_prob_func_args, action_prob_func_args_beta_index
681
+ ):
682
+ for decision_time in action_prob_func_args:
683
+ for user_id in action_prob_func_args[decision_time]:
684
+ if not action_prob_func_args[decision_time][user_id]:
685
+ continue
686
+ assert (
687
+ action_prob_func_args[decision_time][user_id][
688
+ action_prob_func_args_beta_index
689
+ ].ndim
690
+ == 1
691
+ ), "Beta is not a 1D array in the action probability function args."
692
+
693
+
694
+ def require_theta_is_1D_array(theta_est):
695
+ assert theta_est.ndim == 1, "Theta is not a 1D array."
696
+
697
+
698
+ def verify_study_df_summary_satisfactory(
699
+ study_df,
700
+ user_id_col_name,
701
+ policy_num_col_name,
702
+ calendar_t_col_name,
703
+ in_study_col_name,
704
+ action_prob_col_name,
705
+ reward_col_name,
706
+ beta_dim,
707
+ theta_dim,
708
+ suppress_interactive_data_checks,
709
+ ):
710
+
711
+ in_study_df = study_df[study_df[in_study_col_name] == 1]
712
+ num_users = in_study_df[user_id_col_name].nunique()
713
+ num_non_initial_or_fallback_policies = in_study_df[
714
+ in_study_df[policy_num_col_name] > 0
715
+ ][policy_num_col_name].nunique()
716
+ num_decision_times_with_fallback_policies = in_study_df[
717
+ in_study_df[policy_num_col_name] < 0
718
+ ][calendar_t_col_name].nunique()
719
+ num_decision_times = in_study_df[calendar_t_col_name].nunique()
720
+ avg_decisions_per_user = len(in_study_df) / num_users
721
+ num_decision_times_with_multiple_policies = (
722
+ in_study_df[in_study_df[policy_num_col_name] >= 0]
723
+ .groupby(calendar_t_col_name)[policy_num_col_name]
724
+ .nunique()
725
+ > 1
726
+ ).sum()
727
+ min_action_prob = in_study_df[action_prob_col_name].min()
728
+ max_action_prob = in_study_df[action_prob_col_name].max()
729
+ min_non_fallback_policy_num = in_study_df[in_study_df[policy_num_col_name] >= 0][
730
+ policy_num_col_name
731
+ ].min()
732
+ num_data_points_before_first_update = len(
733
+ in_study_df[in_study_df[policy_num_col_name] == min_non_fallback_policy_num]
734
+ )
735
+
736
+ median_action_probabilities = (
737
+ in_study_df.groupby(calendar_t_col_name)[action_prob_col_name]
738
+ .median()
739
+ .to_numpy()
740
+ )
741
+ quartiles = in_study_df.groupby(calendar_t_col_name)[action_prob_col_name].quantile(
742
+ [0.25, 0.75]
743
+ )
744
+ q25_action_probabilities = quartiles.xs(0.25, level=1).to_numpy()
745
+ q75_action_probabilities = quartiles.xs(0.75, level=1).to_numpy()
746
+
747
+ avg_rewards = in_study_df.groupby(calendar_t_col_name)[reward_col_name].mean()
748
+
749
+ # Plot action probability quartile trajectories
750
+ plt.clear_figure()
751
+ plt.title("Action 1 Probability 25/50/75 Quantile Trajectories")
752
+ plt.xlabel("Decision Time")
753
+ plt.ylabel("Action 1 Probability Quantiles")
754
+ plt.error(
755
+ median_action_probabilities,
756
+ yerr=q75_action_probabilities - q25_action_probabilities,
757
+ color="blue+",
758
+ )
759
+ plt.grid(True)
760
+ plt.xticks(
761
+ range(
762
+ 0,
763
+ len(median_action_probabilities),
764
+ max(1, len(median_action_probabilities) // 10),
765
+ )
766
+ )
767
+ action_prob_trajectory_plot = plt.build()
768
+
769
+ # Plot avg reward trajectory
770
+ plt.clear_figure()
771
+ plt.title("Avg Reward Trajectory")
772
+ plt.xlabel("Decision Time")
773
+ plt.ylabel("Avg Reward")
774
+ plt.scatter(avg_rewards, color="blue+", marker="*")
775
+ plt.grid(True)
776
+ plt.xticks(
777
+ range(
778
+ 0,
779
+ len(avg_rewards),
780
+ max(1, len(avg_rewards) // 10),
781
+ )
782
+ )
783
+ avg_reward_trajectory_plot = plt.build()
784
+
785
+ confirm_input_check_result(
786
+ f"\nYou provided a study dataframe reflecting a study with"
787
+ f"\n* {num_users} users"
788
+ f"\n* {num_non_initial_or_fallback_policies} policy updates"
789
+ f"\n* {num_decision_times} decision times, for an average of {avg_decisions_per_user}"
790
+ f" decisions per user"
791
+ f"\n* RL parameters of dimension {beta_dim} per update"
792
+ f"\n* Inferential target of dimension {theta_dim}"
793
+ f"\n* {num_data_points_before_first_update} data points before the first update"
794
+ f"\n* {num_decision_times_with_fallback_policies} decision times"
795
+ f" ({num_decision_times_with_fallback_policies * 100 / num_decision_times}%) for which"
796
+ f" fallback policies were used"
797
+ f"\n* {num_decision_times_with_multiple_policies} decision times"
798
+ f" ({num_decision_times_with_multiple_policies * 100 / num_decision_times}%)"
799
+ f" for which multiple non-fallback policies were used"
800
+ f"\n* Minimum action probability {min_action_prob}"
801
+ f"\n* Maximum action probability {max_action_prob}"
802
+ f"\n* The following trajectories of action probability quartiles over time:\n {action_prob_trajectory_plot}"
803
+ f"\n* The following average reward trajectory over time:\n {avg_reward_trajectory_plot}"
804
+ f" \n\nDoes this meet expectations? (y/n)\n",
805
+ suppress_interactive_data_checks,
806
+ )
807
+
808
+
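
# A brief sketch of the quartile-trajectory computation used in the summary
# above, on a tiny hypothetical frame, showing how the (time, quantile)
# MultiIndex produced by groupby(...).quantile([...]) is split with xs().
def _example_action_prob_quartiles():
    df = pd.DataFrame(
        {
            "calendar_t": [1, 1, 1, 2, 2, 2],
            "action_prob": [0.2, 0.5, 0.8, 0.4, 0.5, 0.6],
        }
    )
    quartiles = df.groupby("calendar_t")["action_prob"].quantile([0.25, 0.75])
    q25 = quartiles.xs(0.25, level=1).to_numpy()  # one value per decision time
    q75 = quartiles.xs(0.75, level=1).to_numpy()
    return q25, q75
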
809
+ def require_betas_match_in_alg_update_args_each_update(
810
+ alg_update_func_args, alg_update_func_args_beta_index
811
+ ):
812
+ logger.info(
813
+ "Checking that betas match across users for each update in the algorithm update function args."
814
+ )
815
+ for policy_num in alg_update_func_args:
816
+ first_beta = None
817
+ for user_id in alg_update_func_args[policy_num]:
818
+ if not alg_update_func_args[policy_num][user_id]:
819
+ continue
820
+ beta = alg_update_func_args[policy_num][user_id][
821
+ alg_update_func_args_beta_index
822
+ ]
823
+ if first_beta is None:
824
+ first_beta = beta
825
+ else:
826
+ assert np.array_equal(
827
+ beta, first_beta
828
+ ), f"Betas do not match across users in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
829
+
830
+
831
+ def require_betas_match_in_action_prob_func_args_each_decision(
832
+ action_prob_func_args, action_prob_func_args_beta_index
833
+ ):
834
+ logger.info(
835
+ "Checking that betas match across users for each decision time in the action prob args."
836
+ )
837
+ for decision_time in action_prob_func_args:
838
+ first_beta = None
839
+ for user_id in action_prob_func_args[decision_time]:
840
+ if not action_prob_func_args[decision_time][user_id]:
841
+ continue
842
+ beta = action_prob_func_args[decision_time][user_id][
843
+ action_prob_func_args_beta_index
844
+ ]
845
+ if first_beta is None:
846
+ first_beta = beta
847
+ else:
848
+ assert np.array_equal(
849
+ beta, first_beta
850
+ ), f"Betas do not match across users in the action prob args for decision_time {decision_time}. Please see the contract for details."
851
+
852
+
853
+ def require_valid_action_prob_times_given_if_index_supplied(
854
+ study_df,
855
+ calendar_t_col_name,
856
+ alg_update_func_args,
857
+ alg_update_func_args_action_prob_times_index,
858
+ ):
859
+ logger.info("Checking that action prob times are valid if index is supplied.")
860
+
861
+ if alg_update_func_args_action_prob_times_index < 0:
862
+ return
863
+
864
+ min_time = study_df[calendar_t_col_name].min()
865
+ max_time = study_df[calendar_t_col_name].max()
866
+ for policy_idx, args_by_user in alg_update_func_args.items():
867
+ for user_id, args in args_by_user.items():
868
+ if not args:
869
+ continue
870
+ times = args[alg_update_func_args_action_prob_times_index]
871
+ assert all(
872
+ times[i] > times[i - 1] for i in range(1, len(times))
873
+ ), f"Non-strictly-increasing times were given for action probabilities in the algorithm update function args for user {user_id} and policy {policy_idx}. Please see the contract for details."
874
+ assert (
875
+ times[0] >= min_time and times[-1] <= max_time
876
+ ), f"Times not present in the study were given for action probabilities in the algorithm update function args. The min and max times in the study dataframe are {min_time} and {max_time}, while user {user_id} has times {times} supplied for policy {policy_idx}. Please see the contract for details."
877
+
878
+
879
+ def require_estimating_functions_sum_to_zero(
880
+ mean_estimating_function_stack: jnp.ndarray,
881
+ beta_dim: int,
882
+ theta_dim: int,
883
+ suppress_interactive_data_checks: bool,
884
+ ):
885
+ """
886
+ This is a test that the correct loss/estimating functions have
887
+ been given for both the algorithm updates and inference. If that is true, then the
888
+ loss/estimating functions when evaluated should sum to approximately zero across users. These
889
+ values have been stacked and averaged across users in mean_estimating_function_stack, which
890
+ we simply compare to the zero vector. We can isolate components for each update and inference
891
+ by considering the dimensions of the beta vectors and the theta vector.
892
+
893
+ Inputs:
894
+ mean_estimating_function_stack:
895
+ The mean of the estimating function stack (a component for each algorithm update and
896
+ inference) across users. This should be a 1D array.
897
+ beta_dim:
898
+ The dimension of the beta vectors that parameterize the algorithm.
899
+ theta_dim:
900
+ The dimension of the theta vector that we estimate during after-study analysis.
901
+
902
+ Returns:
903
+ None
904
+ """
905
+
906
+ logger.info("Checking that estimating functions average to zero across users")
907
+
908
+ # Have a looser hard failure cutoff before the typical interactive check
909
+ try:
910
+ np.testing.assert_allclose(
911
+ mean_estimating_function_stack,
912
+ jnp.zeros(mean_estimating_function_stack.size),
913
+ atol=1e-2,
914
+ )
915
+ except AssertionError as e:
916
+ logger.info(
917
+ "Estimating function stacks do not average to within loose tolerance of zero across users. Drilling in to specific updates and inference component."
918
+ )
919
+ # If this is not true there is an internal problem in the package.
920
+ assert (mean_estimating_function_stack.size - theta_dim) % beta_dim == 0
921
+ num_updates = (mean_estimating_function_stack.size - theta_dim) // beta_dim
922
+ for i in range(num_updates):
923
+ logger.info(
924
+ "Mean estimating function contribution for update %s:\n%s",
925
+ i + 1,
926
+ mean_estimating_function_stack[i * beta_dim : (i + 1) * beta_dim],
927
+ )
928
+ logger.info(
929
+ "Mean estimating function contribution for inference:\n%s",
930
+ mean_estimating_function_stack[-theta_dim:],
931
+ )
932
+
933
+ raise e
934
+
935
+ logger.info(
936
+ "Estimating functions pass loose tolerance check, proceeding to tighter check."
937
+ )
938
+ try:
939
+ np.testing.assert_allclose(
940
+ mean_estimating_function_stack,
941
+ jnp.zeros(mean_estimating_function_stack.size),
942
+ atol=1e-5,
943
+ )
944
+ except AssertionError as e:
945
+ logger.info(
946
+ "Estimating function stacks do not average to within specified tolerance of zero across users. Drilling in to specific updates and inference component."
947
+ )
948
+ # If this is not true there is an internal problem in the package.
949
+ assert (mean_estimating_function_stack.size - theta_dim) % beta_dim == 0
950
+ num_updates = (mean_estimating_function_stack.size - theta_dim) // beta_dim
951
+ for i in range(num_updates):
952
+ logger.info(
953
+ "Mean estimating function contribution for update %s:\n%s",
954
+ i + 1,
955
+ mean_estimating_function_stack[i * beta_dim : (i + 1) * beta_dim],
956
+ )
957
+ logger.info(
958
+ "Mean estimating function contribution for inference:\n%s",
959
+ mean_estimating_function_stack[-theta_dim:],
960
+ )
961
+ confirm_input_check_result(
962
+ f"\nEstimating functions do not average to within default tolerance of zero vector. Please decide if the following is a reasonable result, taking into account the above breakdown by update number and inference. If not, there are several possible reasons for failure mentioned in the contract. Results:\n{str(e)}\n\nContinue? (y/n)\n",
963
+ suppress_interactive_data_checks,
964
+ e,
965
+ )
966
+
967
+
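
# A small numeric sketch of how the stacked mean estimating function is
# partitioned in the drill-down above: consecutive beta_dim-sized blocks, one
# per algorithm update, followed by a single theta_dim-sized inference block.
# The values are arbitrary.
def _example_estimating_function_stack_layout():
    beta_dim, theta_dim = 2, 3
    stack = jnp.arange(2 * beta_dim + theta_dim, dtype=jnp.float32)
    num_updates = (stack.size - theta_dim) // beta_dim
    update_blocks = [
        stack[i * beta_dim : (i + 1) * beta_dim] for i in range(num_updates)
    ]
    inference_block = stack[-theta_dim:]
    return update_blocks, inference_block
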
968
+ def require_RL_estimating_functions_sum_to_zero(
969
+ mean_estimating_function_stack: jnp.ndarray,
970
+ beta_dim: int,
971
+ suppress_interactive_data_checks: bool,
972
+ ):
973
+ """
974
+ This is a test that the correct loss/estimating functions have
975
+ been given for the algorithm updates. If that is true, then the
976
+ loss/estimating functions when evaluated should sum to approximately zero across users. These
977
+ values have been stacked and averaged across users in mean_estimating_function_stack, which
978
+ we simply compare to the zero vector. We can isolate the component for each update
979
+ by considering the dimension of the beta vectors.
980
+
981
+ Inputs:
982
+ mean_estimating_function_stack:
983
+ The mean of the estimating function stack (a component for each algorithm
984
+ update) across users. This should be a 1D array.
985
+ beta_dim:
986
+ The dimension of the beta vectors that parameterize the algorithm.
989
+
990
+ Returns:
991
+ None
992
+ """
993
+
994
+ logger.info("Checking that RL estimating functions average to zero across users")
995
+
996
+ # Have a looser hard failure cutoff before the typical interactive check
997
+ try:
998
+ np.testing.assert_allclose(
999
+ mean_estimating_function_stack,
1000
+ jnp.zeros(mean_estimating_function_stack.size),
1001
+ atol=1e-2,
1002
+ )
1003
+ except AssertionError as e:
1004
+ logger.info(
1005
+ "RL estimating function stacks do not average to zero across users. Drilling in to specific updates and inference component."
1006
+ )
1007
+ num_updates = (mean_estimating_function_stack.size) // beta_dim
1008
+ for i in range(num_updates):
1009
+ logger.info(
1010
+ "Mean estimating function contribution for update %s:\n%s",
1011
+ i + 1,
1012
+ mean_estimating_function_stack[i * beta_dim : (i + 1) * beta_dim],
1013
+ )
1014
+ # TODO: We may need to email instead of failing here for monitoring algorithm.
1015
+ raise e
1016
+
1017
+ try:
1018
+ np.testing.assert_allclose(
1019
+ mean_estimating_function_stack,
1020
+ jnp.zeros(mean_estimating_function_stack.size),
1021
+ atol=1e-5,
1022
+ )
1023
+ except AssertionError as e:
1024
+ logger.info(
1025
+ "RL estimating function stacks do not average to zero across users. Drilling in to specific updates and inference component."
1026
+ )
1027
+ num_updates = (mean_estimating_function_stack.size) // beta_dim
1028
+ for i in range(num_updates):
1029
+ logger.info(
1030
+ "Mean estimating function contribution for update %s:\n%s",
1031
+ i + 1,
1032
+ mean_estimating_function_stack[i * beta_dim : (i + 1) * beta_dim],
1033
+ )
1034
+ # TODO: Email instead of requiring user input for monitoring alg.
1035
+ confirm_input_check_result(
1036
+ f"\nEstimating functions do not average to within default tolerance of zero vector. Please decide if the following is a reasonable result, taking into account the above breakdown by update number and inference. If not, there are several possible reasons for failure mentioned in the contract. Results:\n{str(e)}\n\nContinue? (y/n)\n",
1037
+ suppress_interactive_data_checks,
1038
+ e,
1039
+ )
1040
+
1041
+
1042
+ def require_adaptive_bread_inverse_is_true_inverse(
1043
+ joint_adaptive_bread_matrix,
1044
+ joint_adaptive_bread_inverse_matrix,
1045
+ suppress_interactive_data_checks,
1046
+ ):
1047
+ """
1048
+ Check that the product of the joint adaptive bread matrix and its inverse is
1049
+ sufficiently close to the identity matrix. This is a direct check that the
1050
+ joint_adaptive_bread_inverse_matrix we create is "well-conditioned".
1051
+ """
1052
+ should_be_identity = (
1053
+ joint_adaptive_bread_matrix @ joint_adaptive_bread_inverse_matrix
1054
+ )
1055
+ identity = np.eye(joint_adaptive_bread_matrix.shape[0])
1056
+ try:
1057
+ np.testing.assert_allclose(
1058
+ should_be_identity,
1059
+ identity,
1060
+ rtol=1e-5,
1061
+ atol=1e-5,
1062
+ )
1063
+ except AssertionError as e:
1064
+ confirm_input_check_result(
1065
+ f"\nJoint adaptive bread is not exact inverse of the constructed matrix that was inverted to form it. This likely illustrates poor conditioning:\n{str(e)}\n\nContinue? (y/n)\n",
1066
+ suppress_interactive_data_checks,
1067
+ e,
1068
+ )
1069
+
1070
+ # If we haven't already errored out, return some measures of how far off we are from identity
1071
+ diff = should_be_identity - identity
1072
+ logger.debug(
1073
+ "Difference between should-be-identity produced by multiplying joint adaptive bread inverse and its computed inverse and actual identity:\n%s",
1074
+ diff,
1075
+ )
1076
+
1077
+ diff_abs_max = np.max(np.abs(diff))
1078
+ diff_frobenius_norm = np.linalg.norm(diff, "fro")
1079
+
1080
+ logger.info("Maximum abs element of difference: %s", diff_abs_max)
1081
+ logger.info("Frobenius norm of difference: %s", diff_frobenius_norm)
1082
+
1083
+ return diff_abs_max, diff_frobenius_norm
1084
+
1085
+
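
# A quick sketch of the identity check performed above, on a small,
# well-conditioned matrix; with an ill-conditioned bread matrix the max-abs and
# Frobenius deviations reported by the check grow accordingly.
def _example_bread_inverse_identity_check():
    bread = np.array([[2.0, 0.5], [0.5, 1.0]])
    bread_inverse = np.linalg.inv(bread)
    should_be_identity = bread @ bread_inverse
    diff = should_be_identity - np.eye(bread.shape[0])
    return np.max(np.abs(diff)), np.linalg.norm(diff, "fro")
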
1086
+ def require_threaded_algorithm_estimating_function_args_equivalent(
1087
+ algorithm_estimating_func,
1088
+ update_func_args_by_by_user_id_by_policy_num,
1089
+ threaded_update_func_args_by_policy_num_by_user_id,
1090
+ suppress_interactive_data_checks,
1091
+ ):
1092
+ """
1093
+ Check that the algorithm estimating function returns the same values
1094
+ when called with the original arguments and when called with the
1095
+ reconstructed action probabilities substituted in.
1096
+ """
1097
+ for (
1098
+ policy_num,
1099
+ update_func_args_by_user_id,
1100
+ ) in update_func_args_by_by_user_id_by_policy_num.items():
1101
+ for (
1102
+ user_id,
1103
+ unthreaded_args,
1104
+ ) in update_func_args_by_user_id.items():
1105
+ if not unthreaded_args:
1106
+ continue
1107
+ np.testing.assert_allclose(
1108
+ algorithm_estimating_func(*unthreaded_args),
1109
+ # Need to stop gradient here because we can't convert a traced value to np array
1110
+ jax.lax.stop_gradient(
1111
+ algorithm_estimating_func(
1112
+ *threaded_update_func_args_by_policy_num_by_user_id[user_id][
1113
+ policy_num
1114
+ ]
1115
+ )
1116
+ ),
1117
+ atol=1e-7,
1118
+ rtol=1e-3,
1119
+ )
1120
+
1121
+
1122
+ def require_threaded_inference_estimating_function_args_equivalent(
1123
+ inference_estimating_func,
1124
+ inference_func_args_by_user_id,
1125
+ threaded_inference_func_args_by_user_id,
1126
+ suppress_interactive_data_checks,
1127
+ ):
1128
+ """
1129
+ Check that the inference estimating function returns the same values
1130
+ when called with the original arguments and when called with the
1131
+ reconstructed action probabilities substituted in.
1132
+ """
1133
+ for user_id, unthreaded_args in inference_func_args_by_user_id.items():
1134
+ if not unthreaded_args:
1135
+ continue
1136
+ np.testing.assert_allclose(
1137
+ inference_estimating_func(*unthreaded_args),
1138
+ # Need to stop gradient here because we can't convert a traced value to np array
1139
+ jax.lax.stop_gradient(
1140
+ inference_estimating_func(
1141
+ *threaded_inference_func_args_by_user_id[user_id]
1142
+ )
1143
+ ),
1144
+ rtol=1e-2,
1145
+ )