lifejacket 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/calculate_derivatives.py +0 -2
- lifejacket/constants.py +4 -16
- lifejacket/deployment_conditioning_monitor.py +19 -12
- lifejacket/form_adjusted_meat_adjustments_directly.py +25 -27
- lifejacket/get_datum_for_blowup_supervised_learning.py +71 -77
- lifejacket/helper_functions.py +15 -148
- lifejacket/input_checks.py +49 -50
- lifejacket/{after_study_analysis.py → post_deployment_analysis.py} +377 -144
- lifejacket/small_sample_corrections.py +11 -13
- {lifejacket-1.0.0.dist-info → lifejacket-1.1.0.dist-info}/METADATA +1 -1
- lifejacket-1.1.0.dist-info/RECORD +17 -0
- {lifejacket-1.0.0.dist-info → lifejacket-1.1.0.dist-info}/WHEEL +1 -1
- lifejacket-1.1.0.dist-info/entry_points.txt +2 -0
- lifejacket-1.0.0.dist-info/RECORD +0 -17
- lifejacket-1.0.0.dist-info/entry_points.txt +0 -2
- {lifejacket-1.0.0.dist-info → lifejacket-1.1.0.dist-info}/top_level.txt +0 -0
lifejacket/input_checks.py
CHANGED
|
@@ -8,7 +8,7 @@ from jax import numpy as jnp
|
|
|
8
8
|
import pandas as pd
|
|
9
9
|
import plotext as plt
|
|
10
10
|
|
|
11
|
-
from .constants import
|
|
11
|
+
from .constants import SmallSampleCorrections
|
|
12
12
|
from .helper_functions import (
|
|
13
13
|
confirm_input_check_result,
|
|
14
14
|
)
|
|
@@ -64,7 +64,9 @@ def perform_first_wave_input_checks(
|
|
|
64
64
|
analysis_df, active_col_name, policy_num_col_name, alg_update_func_args
|
|
65
65
|
)
|
|
66
66
|
confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
|
|
67
|
-
alg_update_func_args_action_prob_index,
|
|
67
|
+
alg_update_func_args_action_prob_index,
|
|
68
|
+
alg_update_func_args_previous_betas_index,
|
|
69
|
+
suppress_interactive_data_checks,
|
|
68
70
|
)
|
|
69
71
|
require_action_prob_times_given_if_index_supplied(
|
|
70
72
|
alg_update_func_args_action_prob_index,
|
|
@@ -193,6 +195,7 @@ def perform_alg_only_input_checks(
|
|
|
193
195
|
alg_update_func_args_beta_index,
|
|
194
196
|
alg_update_func_args_action_prob_index,
|
|
195
197
|
alg_update_func_args_action_prob_times_index,
|
|
198
|
+
alg_update_func_args_previous_betas_index,
|
|
196
199
|
suppress_interactive_data_checks,
|
|
197
200
|
):
|
|
198
201
|
### Validate algorithm loss/estimating function and args
|
|
@@ -206,7 +209,9 @@ def perform_alg_only_input_checks(
|
|
|
206
209
|
analysis_df, active_col_name, policy_num_col_name, alg_update_func_args
|
|
207
210
|
)
|
|
208
211
|
confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
|
|
209
|
-
alg_update_func_args_action_prob_index,
|
|
212
|
+
alg_update_func_args_action_prob_index,
|
|
213
|
+
alg_update_func_args_previous_betas_index,
|
|
214
|
+
suppress_interactive_data_checks,
|
|
210
215
|
)
|
|
211
216
|
require_action_prob_times_given_if_index_supplied(
|
|
212
217
|
alg_update_func_args_action_prob_index,
|
|
@@ -278,7 +283,7 @@ def require_action_probabilities_in_analysis_df_can_be_reconstructed(
|
|
|
278
283
|
action_prob_func,
|
|
279
284
|
):
|
|
280
285
|
"""
|
|
281
|
-
Check that the action probabilities in the
|
|
286
|
+
Check that the action probabilities in the analysis DataFrame can be reconstructed from the supplied
|
|
282
287
|
action probability function and its arguments.
|
|
283
288
|
|
|
284
289
|
NOTE THAT THIS IS A HARD FAILURE IF THE RECONSTRUCTION DOESN'T PASS.
|
|
@@ -317,7 +322,7 @@ def require_all_subjects_have_all_times_in_analysis_df(
|
|
|
317
322
|
# Check if all subjects have the same set of unique calendar times
|
|
318
323
|
if not subject_calendar_times.apply(lambda x: x == unique_calendar_times).all():
|
|
319
324
|
raise AssertionError(
|
|
320
|
-
"Not all subjects have all calendar times in the
|
|
325
|
+
"Not all subjects have all calendar times in the analysis DataFrame. Please see the contract for details."
|
|
321
326
|
)
|
|
322
327
|
|
|
323
328
|
|
|
@@ -345,7 +350,7 @@ def require_action_prob_args_in_alg_update_func_correspond_to_analysis_df(
|
|
|
345
350
|
):
|
|
346
351
|
logger.info(
|
|
347
352
|
"Checking that the action probabilities supplied in the algorithm update function args, if"
|
|
348
|
-
" any, correspond to those in the
|
|
353
|
+
" any, correspond to those in the analysis DataFrame for the corresponding subjects and decision"
|
|
349
354
|
" times."
|
|
350
355
|
)
|
|
351
356
|
if alg_update_func_args_action_prob_index < 0:
|
|
@@ -377,7 +382,7 @@ def require_action_prob_args_in_alg_update_func_correspond_to_analysis_df(
|
|
|
377
382
|
), (
|
|
378
383
|
f"There is a mismatch for subject {subject_id} between the action probabilities supplied"
|
|
379
384
|
f" in the args to the algorithm update function at policy {policy_num} and those in"
|
|
380
|
-
" the
|
|
385
|
+
" the analysis DataFrame for the supplied times. Please see the contract for details."
|
|
381
386
|
)
|
|
382
387
|
|
|
383
388
|
|
|
@@ -418,7 +423,7 @@ def require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times
|
|
|
418
423
|
):
|
|
419
424
|
logger.info(
|
|
420
425
|
"Checking that action probability function args are blank for exactly the times each subject"
|
|
421
|
-
"is not in the study according to the
|
|
426
|
+
"is not in the study according to the analysis DataFrame."
|
|
422
427
|
)
|
|
423
428
|
inactive_df = analysis_df[analysis_df[active_col_name] == 0]
|
|
424
429
|
inactive_times_by_subject_according_to_analysis_df = (
|
|
@@ -441,7 +446,7 @@ def require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times
|
|
|
441
446
|
inactive_times_by_subject_according_to_analysis_df
|
|
442
447
|
== inactive_times_by_subject_according_to_action_prob_func_args
|
|
443
448
|
), (
|
|
444
|
-
"
|
|
449
|
+
"Inactive decision times according to the analysis DataFrame do not match up with the"
|
|
445
450
|
" times for which action probability arguments are blank for all subjects. Please see the"
|
|
446
451
|
" contract for details."
|
|
447
452
|
)
|
|
@@ -456,21 +461,27 @@ def require_all_named_columns_present_in_analysis_df(
|
|
|
456
461
|
subject_id_col_name,
|
|
457
462
|
action_prob_col_name,
|
|
458
463
|
):
|
|
459
|
-
logger.info(
|
|
460
|
-
|
|
461
|
-
|
|
464
|
+
logger.info(
|
|
465
|
+
"Checking that all named columns are present in the analysis DataFrame."
|
|
466
|
+
)
|
|
467
|
+
assert (
|
|
468
|
+
active_col_name in analysis_df.columns
|
|
469
|
+
), f"{active_col_name} not in analysis DataFrame."
|
|
470
|
+
assert (
|
|
471
|
+
action_col_name in analysis_df.columns
|
|
472
|
+
), f"{action_col_name} not in analysis DataFrame."
|
|
462
473
|
assert (
|
|
463
474
|
policy_num_col_name in analysis_df.columns
|
|
464
|
-
), f"{policy_num_col_name} not in
|
|
475
|
+
), f"{policy_num_col_name} not in analysis DataFrame."
|
|
465
476
|
assert (
|
|
466
477
|
calendar_t_col_name in analysis_df.columns
|
|
467
|
-
), f"{calendar_t_col_name} not in
|
|
478
|
+
), f"{calendar_t_col_name} not in analysis DataFrame."
|
|
468
479
|
assert (
|
|
469
480
|
subject_id_col_name in analysis_df.columns
|
|
470
|
-
), f"{subject_id_col_name} not in
|
|
481
|
+
), f"{subject_id_col_name} not in analysis DataFrame."
|
|
471
482
|
assert (
|
|
472
483
|
action_prob_col_name in analysis_df.columns
|
|
473
|
-
), f"{action_prob_col_name} not in
|
|
484
|
+
), f"{action_prob_col_name} not in analysis DataFrame."
|
|
474
485
|
|
|
475
486
|
|
|
476
487
|
def require_all_named_columns_not_object_type_in_analysis_df(
|
|
@@ -493,7 +504,7 @@ def require_all_named_columns_not_object_type_in_analysis_df(
|
|
|
493
504
|
):
|
|
494
505
|
assert (
|
|
495
506
|
analysis_df[colname].dtype != "object"
|
|
496
|
-
), f"At least {colname} is of object type in
|
|
507
|
+
), f"At least {colname} is of object type in analysis DataFrame."
|
|
497
508
|
|
|
498
509
|
|
|
499
510
|
def require_binary_actions(analysis_df, active_col_name, action_col_name):
|
|
@@ -574,12 +585,12 @@ def require_no_policy_numbers_present_in_alg_update_args_but_not_analysis_df(
|
|
|
574
585
|
analysis_df, policy_num_col_name, alg_update_func_args
|
|
575
586
|
):
|
|
576
587
|
logger.info(
|
|
577
|
-
"Checking that policy numbers in algorithm update function args are present in the
|
|
588
|
+
"Checking that policy numbers in algorithm update function args are present in the analysis DataFrame."
|
|
578
589
|
)
|
|
579
590
|
alg_update_policy_nums = sorted(alg_update_func_args.keys())
|
|
580
591
|
analysis_df_policy_nums = sorted(analysis_df[policy_num_col_name].unique())
|
|
581
592
|
assert set(alg_update_policy_nums).issubset(set(analysis_df_policy_nums)), (
|
|
582
|
-
f"There are policy numbers present in algorithm update function args but not in the
|
|
593
|
+
f"There are policy numbers present in algorithm update function args but not in the analysis DataFrame. "
|
|
583
594
|
f"\nalg_update_func_args policy numbers: {alg_update_policy_nums}"
|
|
584
595
|
f"\nanalysis_df policy numbers: {analysis_df_policy_nums}.\nPlease see the contract for details."
|
|
585
596
|
)
|
|
@@ -589,7 +600,7 @@ def require_all_policy_numbers_in_analysis_df_except_possibly_initial_and_fallba
|
|
|
589
600
|
analysis_df, active_col_name, policy_num_col_name, alg_update_func_args
|
|
590
601
|
):
|
|
591
602
|
logger.info(
|
|
592
|
-
"Checking that all policy numbers in the
|
|
603
|
+
"Checking that all policy numbers in the analysis DataFrame are present in the algorithm update function args."
|
|
593
604
|
)
|
|
594
605
|
active_df = analysis_df[analysis_df[active_col_name] == 1]
|
|
595
606
|
# Get the number of the initial policy. 0 is recommended but not required.
|
|
@@ -602,19 +613,23 @@ def require_all_policy_numbers_in_analysis_df_except_possibly_initial_and_fallba
|
|
|
602
613
|
].unique()
|
|
603
614
|
).issubset(
|
|
604
615
|
alg_update_func_args.keys()
|
|
605
|
-
), f"There are non-fallback, non-initial policy numbers in the
|
|
616
|
+
), f"There are non-fallback, non-initial policy numbers in the analysis DataFrame that are not in the update function args: {set(active_df[active_df[policy_num_col_name] > 0][policy_num_col_name].unique()) - set(alg_update_func_args.keys())}. Please see the contract for details."
|
|
606
617
|
|
|
607
618
|
|
|
608
619
|
def confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
|
|
609
620
|
alg_update_func_args_action_prob_index,
|
|
621
|
+
alg_update_func_args_previous_betas_index,
|
|
610
622
|
suppress_interactive_data_checks,
|
|
611
623
|
):
|
|
612
624
|
logger.info(
|
|
613
625
|
"Confirming that action probabilities are not in algorithm update function args IF their index is not specified"
|
|
614
626
|
)
|
|
615
|
-
if
|
|
627
|
+
if (
|
|
628
|
+
alg_update_func_args_action_prob_index < 0
|
|
629
|
+
and alg_update_func_args_previous_betas_index < 0
|
|
630
|
+
):
|
|
616
631
|
confirm_input_check_result(
|
|
617
|
-
"\nYou specified that the algorithm update function supplied does not have action probabilities
|
|
632
|
+
"\nYou specified that the algorithm update function supplied does not have action probabilities or previous betas in its arguments. Please verify this is correct.\n\nContinue? (y/n)\n",
|
|
618
633
|
suppress_interactive_data_checks,
|
|
619
634
|
)
|
|
620
635
|
|
|
@@ -633,20 +648,6 @@ def confirm_no_small_sample_correction_desired_if_not_requested(
|
|
|
633
648
|
)
|
|
634
649
|
|
|
635
650
|
|
|
636
|
-
def confirm_no_adaptive_bread_inverse_stabilization_method_desired_if_not_requested(
|
|
637
|
-
adaptive_bread_inverse_stabilization_method,
|
|
638
|
-
suppress_interactive_data_checks,
|
|
639
|
-
):
|
|
640
|
-
logger.info(
|
|
641
|
-
"Confirming that no adaptive bread inverse stabilization method is desired if it's not requested."
|
|
642
|
-
)
|
|
643
|
-
if adaptive_bread_inverse_stabilization_method == InverseStabilizationMethods.NONE:
|
|
644
|
-
confirm_input_check_result(
|
|
645
|
-
"\nYou specified that you would not like to perform any inverse stabilization while forming the adaptive variance. This is not usually recommended. Please verify that it is correct or select one of the available options.\n\nContinue? (y/n)\n",
|
|
646
|
-
suppress_interactive_data_checks,
|
|
647
|
-
)
|
|
648
|
-
|
|
649
|
-
|
|
650
651
|
def require_action_prob_times_given_if_index_supplied(
|
|
651
652
|
alg_update_func_args_action_prob_index,
|
|
652
653
|
alg_update_func_args_action_prob_times_index,
|
|
@@ -809,7 +810,7 @@ def verify_analysis_df_summary_satisfactory(
|
|
|
809
810
|
avg_reward_trajectory_plot = plt.build()
|
|
810
811
|
|
|
811
812
|
confirm_input_check_result(
|
|
812
|
-
f"\nYou provided
|
|
813
|
+
f"\nYou provided an analysis DataFrame reflecting a study with"
|
|
813
814
|
f"\n* {num_subjects} subjects"
|
|
814
815
|
f"\n* {num_non_initial_or_fallback_policies} policy updates"
|
|
815
816
|
f"\n* {num_decision_times} decision times, for an average of {avg_decisions_per_subject}"
|
|
@@ -924,7 +925,7 @@ def require_valid_action_prob_times_given_if_index_supplied(
|
|
|
924
925
|
), f"Non-strictly-increasing times were given for action probabilities in the algorithm update function args for subject {subject_id} and policy {policy_idx}. Please see the contract for details."
|
|
925
926
|
assert (
|
|
926
927
|
times[0] >= min_time and times[-1] <= max_time
|
|
927
|
-
), f"Times not present in the study were given for action probabilities in the algorithm update function args. The min and max times in the
|
|
928
|
+
), f"Times not present in the study were given for action probabilities in the algorithm update function args. The min and max times in the analysis DataFrame are {min_time} and {max_time}, while subject {subject_id} has times {times} supplied for policy {policy_idx}. Please see the contract for details."
|
|
928
929
|
|
|
929
930
|
|
|
930
931
|
def require_estimating_functions_sum_to_zero(
|
|
@@ -1089,20 +1090,18 @@ def require_RL_estimating_functions_sum_to_zero(
|
|
|
1089
1090
|
)
|
|
1090
1091
|
|
|
1091
1092
|
|
|
1092
|
-
def
|
|
1093
|
-
|
|
1094
|
-
|
|
1093
|
+
def require_joint_bread_inverse_is_true_inverse(
|
|
1094
|
+
joint_bread_inverse_matrix,
|
|
1095
|
+
joint_bread_matrix,
|
|
1095
1096
|
suppress_interactive_data_checks,
|
|
1096
1097
|
):
|
|
1097
1098
|
"""
|
|
1098
|
-
Check that the product of the joint
|
|
1099
|
+
Check that the product of the joint bread matrix and its inverse is
|
|
1099
1100
|
sufficiently close to the identity matrix. This is a direct check that the
|
|
1100
|
-
|
|
1101
|
+
joint_bread_matrix we create is "well-conditioned".
|
|
1101
1102
|
"""
|
|
1102
|
-
should_be_identity =
|
|
1103
|
-
|
|
1104
|
-
)
|
|
1105
|
-
identity = np.eye(joint_adaptive_bread_matrix.shape[0])
|
|
1103
|
+
should_be_identity = joint_bread_inverse_matrix @ joint_bread_matrix
|
|
1104
|
+
identity = np.eye(joint_bread_matrix.shape[0])
|
|
1106
1105
|
try:
|
|
1107
1106
|
np.testing.assert_allclose(
|
|
1108
1107
|
should_be_identity,
|
|
@@ -1112,7 +1111,7 @@ def require_adaptive_bread_inverse_is_true_inverse(
|
|
|
1112
1111
|
)
|
|
1113
1112
|
except AssertionError as e:
|
|
1114
1113
|
confirm_input_check_result(
|
|
1115
|
-
f"\nJoint
|
|
1114
|
+
f"\nJoint bread inverse is not exact inverse of the constructed matrix that was inverted to form it. This likely illustrates poor conditioning:\n{str(e)}\n\nContinue? (y/n)\n",
|
|
1116
1115
|
suppress_interactive_data_checks,
|
|
1117
1116
|
e,
|
|
1118
1117
|
)
|
|
@@ -1120,7 +1119,7 @@ def require_adaptive_bread_inverse_is_true_inverse(
|
|
|
1120
1119
|
# If we haven't already errored out, return some measures of how far off we are from identity
|
|
1121
1120
|
diff = should_be_identity - identity
|
|
1122
1121
|
logger.debug(
|
|
1123
|
-
"Difference between should-be-identity produced by multiplying joint
|
|
1122
|
+
"Difference between should-be-identity produced by multiplying joint bread and its computed inverse and actual identity:\n%s",
|
|
1124
1123
|
diff,
|
|
1125
1124
|
)
|
|
1126
1125
|
|