lifejacket-1.0.0-py3-none-any.whl → lifejacket-1.0.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,7 +8,7 @@ from jax import numpy as jnp
 import pandas as pd
 import plotext as plt
 
-from .constants import InverseStabilizationMethods, SmallSampleCorrections
+from .constants import SmallSampleCorrections
 from .helper_functions import (
     confirm_input_check_result,
 )
@@ -64,7 +64,9 @@ def perform_first_wave_input_checks(
         analysis_df, active_col_name, policy_num_col_name, alg_update_func_args
     )
     confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
-        alg_update_func_args_action_prob_index, suppress_interactive_data_checks
+        alg_update_func_args_action_prob_index,
+        alg_update_func_args_previous_betas_index,
+        suppress_interactive_data_checks,
     )
     require_action_prob_times_given_if_index_supplied(
         alg_update_func_args_action_prob_index,
@@ -193,6 +195,7 @@ def perform_alg_only_input_checks(
     alg_update_func_args_beta_index,
     alg_update_func_args_action_prob_index,
     alg_update_func_args_action_prob_times_index,
+    alg_update_func_args_previous_betas_index,
     suppress_interactive_data_checks,
 ):
     ### Validate algorithm loss/estimating function and args
@@ -206,7 +209,9 @@ def perform_alg_only_input_checks(
         analysis_df, active_col_name, policy_num_col_name, alg_update_func_args
     )
     confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
-        alg_update_func_args_action_prob_index, suppress_interactive_data_checks
+        alg_update_func_args_action_prob_index,
+        alg_update_func_args_previous_betas_index,
+        suppress_interactive_data_checks,
     )
     require_action_prob_times_given_if_index_supplied(
         alg_update_func_args_action_prob_index,
@@ -278,7 +283,7 @@ def require_action_probabilities_in_analysis_df_can_be_reconstructed(
     action_prob_func,
 ):
     """
-    Check that the action probabilities in the study dataframe can be reconstructed from the supplied
+    Check that the action probabilities in the analysis DataFrame can be reconstructed from the supplied
     action probability function and its arguments.
 
     NOTE THAT THIS IS A HARD FAILURE IF THE RECONSTRUCTION DOESN'T PASS.
@@ -317,7 +322,7 @@ def require_all_subjects_have_all_times_in_analysis_df(
     # Check if all subjects have the same set of unique calendar times
     if not subject_calendar_times.apply(lambda x: x == unique_calendar_times).all():
         raise AssertionError(
-            "Not all subjects have all calendar times in the study dataframe. Please see the contract for details."
+            "Not all subjects have all calendar times in the analysis DataFrame. Please see the contract for details."
         )
 
 
@@ -345,7 +350,7 @@ def require_action_prob_args_in_alg_update_func_correspond_to_analysis_df(
 ):
     logger.info(
         "Checking that the action probabilities supplied in the algorithm update function args, if"
-        " any, correspond to those in the study dataframe for the corresponding subjects and decision"
+        " any, correspond to those in the analysis DataFrame for the corresponding subjects and decision"
         " times."
     )
     if alg_update_func_args_action_prob_index < 0:
@@ -377,7 +382,7 @@ def require_action_prob_args_in_alg_update_func_correspond_to_analysis_df(
             ), (
                 f"There is a mismatch for subject {subject_id} between the action probabilities supplied"
                 f" in the args to the algorithm update function at policy {policy_num} and those in"
-                " the study dataframe for the supplied times. Please see the contract for details."
+                " the analysis DataFrame for the supplied times. Please see the contract for details."
             )
 
 
@@ -418,7 +423,7 @@ def require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times
 ):
     logger.info(
         "Checking that action probability function args are blank for exactly the times each subject"
-        "is not in the study according to the study dataframe."
+        "is not in the study according to the analysis DataFrame."
     )
     inactive_df = analysis_df[analysis_df[active_col_name] == 0]
     inactive_times_by_subject_according_to_analysis_df = (
@@ -441,7 +446,7 @@ def require_out_of_study_decision_times_are_exactly_blank_action_prob_args_times
         inactive_times_by_subject_according_to_analysis_df
         == inactive_times_by_subject_according_to_action_prob_func_args
     ), (
-        "Out-of-study decision times according to the study dataframe do not match up with the"
+        "Inactive decision times according to the analysis DataFrame do not match up with the"
         " times for which action probability arguments are blank for all subjects. Please see the"
         " contract for details."
     )
@@ -456,21 +461,27 @@ def require_all_named_columns_present_in_analysis_df(
     subject_id_col_name,
     action_prob_col_name,
 ):
-    logger.info("Checking that all named columns are present in the study dataframe.")
-    assert active_col_name in analysis_df.columns, f"{active_col_name} not in study df."
-    assert action_col_name in analysis_df.columns, f"{action_col_name} not in study df."
+    logger.info(
+        "Checking that all named columns are present in the analysis DataFrame."
+    )
+    assert (
+        active_col_name in analysis_df.columns
+    ), f"{active_col_name} not in analysis DataFrame."
+    assert (
+        action_col_name in analysis_df.columns
+    ), f"{action_col_name} not in analysis DataFrame."
     assert (
         policy_num_col_name in analysis_df.columns
-    ), f"{policy_num_col_name} not in study df."
+    ), f"{policy_num_col_name} not in analysis DataFrame."
     assert (
         calendar_t_col_name in analysis_df.columns
-    ), f"{calendar_t_col_name} not in study df."
+    ), f"{calendar_t_col_name} not in analysis DataFrame."
     assert (
         subject_id_col_name in analysis_df.columns
-    ), f"{subject_id_col_name} not in study df."
+    ), f"{subject_id_col_name} not in analysis DataFrame."
     assert (
         action_prob_col_name in analysis_df.columns
-    ), f"{action_prob_col_name} not in study df."
+    ), f"{action_prob_col_name} not in analysis DataFrame."
 
 
 def require_all_named_columns_not_object_type_in_analysis_df(
@@ -493,7 +504,7 @@ def require_all_named_columns_not_object_type_in_analysis_df(
     ):
         assert (
             analysis_df[colname].dtype != "object"
-        ), f"At least {colname} is of object type in study df."
+        ), f"At least {colname} is of object type in analysis DataFrame."
 
 
 def require_binary_actions(analysis_df, active_col_name, action_col_name):
@@ -574,12 +585,12 @@ def require_no_policy_numbers_present_in_alg_update_args_but_not_analysis_df(
     analysis_df, policy_num_col_name, alg_update_func_args
 ):
     logger.info(
-        "Checking that policy numbers in algorithm update function args are present in the study dataframe."
+        "Checking that policy numbers in algorithm update function args are present in the analysis DataFrame."
     )
     alg_update_policy_nums = sorted(alg_update_func_args.keys())
     analysis_df_policy_nums = sorted(analysis_df[policy_num_col_name].unique())
     assert set(alg_update_policy_nums).issubset(set(analysis_df_policy_nums)), (
-        f"There are policy numbers present in algorithm update function args but not in the study dataframe. "
+        f"There are policy numbers present in algorithm update function args but not in the analysis DataFrame. "
         f"\nalg_update_func_args policy numbers: {alg_update_policy_nums}"
         f"\nanalysis_df policy numbers: {analysis_df_policy_nums}.\nPlease see the contract for details."
     )
@@ -589,7 +600,7 @@ def require_all_policy_numbers_in_analysis_df_except_possibly_initial_and_fallba
     analysis_df, active_col_name, policy_num_col_name, alg_update_func_args
 ):
     logger.info(
-        "Checking that all policy numbers in the study dataframe are present in the algorithm update function args."
+        "Checking that all policy numbers in the analysis DataFrame are present in the algorithm update function args."
     )
     active_df = analysis_df[analysis_df[active_col_name] == 1]
     # Get the number of the initial policy. 0 is recommended but not required.
@@ -602,19 +613,23 @@ def require_all_policy_numbers_in_analysis_df_except_possibly_initial_and_fallba
         ].unique()
     ).issubset(
         alg_update_func_args.keys()
-    ), f"There are non-fallback, non-initial policy numbers in the study dataframe that are not in the update function args: {set(active_df[active_df[policy_num_col_name] > 0][policy_num_col_name].unique()) - set(alg_update_func_args.keys())}. Please see the contract for details."
+    ), f"There are non-fallback, non-initial policy numbers in the analysis DataFrame that are not in the update function args: {set(active_df[active_df[policy_num_col_name] > 0][policy_num_col_name].unique()) - set(alg_update_func_args.keys())}. Please see the contract for details."
 
 
 def confirm_action_probabilities_not_in_alg_update_args_if_index_not_supplied(
     alg_update_func_args_action_prob_index,
+    alg_update_func_args_previous_betas_index,
     suppress_interactive_data_checks,
 ):
     logger.info(
         "Confirming that action probabilities are not in algorithm update function args IF their index is not specified"
     )
-    if alg_update_func_args_action_prob_index < 0:
+    if (
+        alg_update_func_args_action_prob_index < 0
+        and alg_update_func_args_previous_betas_index < 0
+    ):
         confirm_input_check_result(
-            "\nYou specified that the algorithm update function supplied does not have action probabilities as one of its arguments. Please verify this is correct.\n\nContinue? (y/n)\n",
+            "\nYou specified that the algorithm update function supplied does not have action probabilities or previous betas in its arguments. Please verify this is correct.\n\nContinue? (y/n)\n",
             suppress_interactive_data_checks,
         )
 
@@ -633,20 +648,6 @@ def confirm_no_small_sample_correction_desired_if_not_requested(
     )
 
 
-def confirm_no_adaptive_bread_inverse_stabilization_method_desired_if_not_requested(
-    adaptive_bread_inverse_stabilization_method,
-    suppress_interactive_data_checks,
-):
-    logger.info(
-        "Confirming that no adaptive bread inverse stabilization method is desired if it's not requested."
-    )
-    if adaptive_bread_inverse_stabilization_method == InverseStabilizationMethods.NONE:
-        confirm_input_check_result(
-            "\nYou specified that you would not like to perform any inverse stabilization while forming the adaptive variance. This is not usually recommended. Please verify that it is correct or select one of the available options.\n\nContinue? (y/n)\n",
-            suppress_interactive_data_checks,
-        )
-
-
 def require_action_prob_times_given_if_index_supplied(
     alg_update_func_args_action_prob_index,
     alg_update_func_args_action_prob_times_index,
@@ -809,7 +810,7 @@ def verify_analysis_df_summary_satisfactory(
     avg_reward_trajectory_plot = plt.build()
 
     confirm_input_check_result(
-        f"\nYou provided a study dataframe reflecting a study with"
+        f"\nYou provided an analysis DataFrame reflecting a study with"
         f"\n* {num_subjects} subjects"
         f"\n* {num_non_initial_or_fallback_policies} policy updates"
         f"\n* {num_decision_times} decision times, for an average of {avg_decisions_per_subject}"
@@ -924,7 +925,7 @@ def require_valid_action_prob_times_given_if_index_supplied(
             ), f"Non-strictly-increasing times were given for action probabilities in the algorithm update function args for subject {subject_id} and policy {policy_idx}. Please see the contract for details."
             assert (
                 times[0] >= min_time and times[-1] <= max_time
-            ), f"Times not present in the study were given for action probabilities in the algorithm update function args. The min and max times in the study dataframe are {min_time} and {max_time}, while subject {subject_id} has times {times} supplied for policy {policy_idx}. Please see the contract for details."
+            ), f"Times not present in the study were given for action probabilities in the algorithm update function args. The min and max times in the analysis DataFrame are {min_time} and {max_time}, while subject {subject_id} has times {times} supplied for policy {policy_idx}. Please see the contract for details."
 
 
 def require_estimating_functions_sum_to_zero(
@@ -1089,20 +1090,18 @@ def require_RL_estimating_functions_sum_to_zero(
     )
 
 
-def require_adaptive_bread_inverse_is_true_inverse(
-    joint_adaptive_bread_matrix,
-    joint_adaptive_bread_inverse_matrix,
+def require_joint_bread_inverse_is_true_inverse(
+    joint_bread_inverse_matrix,
+    joint_bread_matrix,
     suppress_interactive_data_checks,
 ):
     """
-    Check that the product of the joint adaptive bread matrix and its inverse is
+    Check that the product of the joint bread matrix and its inverse is
     sufficiently close to the identity matrix. This is a direct check that the
-    joint_adaptive_bread_inverse_matrix we create is "well-conditioned".
+    joint_bread_matrix we create is "well-conditioned".
     """
-    should_be_identity = (
-        joint_adaptive_bread_matrix @ joint_adaptive_bread_inverse_matrix
-    )
-    identity = np.eye(joint_adaptive_bread_matrix.shape[0])
+    should_be_identity = joint_bread_inverse_matrix @ joint_bread_matrix
+    identity = np.eye(joint_bread_matrix.shape[0])
     try:
         np.testing.assert_allclose(
             should_be_identity,
@@ -1112,7 +1111,7 @@ def require_adaptive_bread_inverse_is_true_inverse(
         )
     except AssertionError as e:
         confirm_input_check_result(
-            f"\nJoint adaptive bread is not exact inverse of the constructed matrix that was inverted to form it. This likely illustrates poor conditioning:\n{str(e)}\n\nContinue? (y/n)\n",
+            f"\nJoint bread inverse is not exact inverse of the constructed matrix that was inverted to form it. This likely illustrates poor conditioning:\n{str(e)}\n\nContinue? (y/n)\n",
             suppress_interactive_data_checks,
             e,
         )
@@ -1120,7 +1119,7 @@ def require_adaptive_bread_inverse_is_true_inverse(
     # If we haven't already errored out, return some measures of how far off we are from identity
     diff = should_be_identity - identity
     logger.debug(
-        "Difference between should-be-identity produced by multiplying joint adaptive bread inverse and its computed inverse and actual identity:\n%s",
+        "Difference between should-be-identity produced by multiplying joint bread and its computed inverse and actual identity:\n%s",
         diff,
     )
 
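
For orientation, the renamed check require_joint_bread_inverse_is_true_inverse verifies that the computed bread inverse really behaves as an inverse: the product of the matrix and its inverse should be numerically close to the identity, otherwise the inversion was likely poorly conditioned. The sketch below is not part of lifejacket; the function name check_inverse_is_true_inverse and the tolerances are illustrative, and it only shows the core idea using plain NumPy.

import numpy as np


def check_inverse_is_true_inverse(matrix, inverse, rtol=1e-5, atol=1e-8):
    # Multiply the matrix by its purported inverse; for a true, well-conditioned
    # inverse this product should be (numerically) the identity matrix.
    should_be_identity = inverse @ matrix
    identity = np.eye(matrix.shape[0])
    # Raises AssertionError with a detailed report if the product is not close
    # to the identity within the given (illustrative) tolerances.
    np.testing.assert_allclose(should_be_identity, identity, rtol=rtol, atol=atol)
    # Return the largest entry-wise deviation from the identity as a diagnostic.
    return np.abs(should_be_identity - identity).max()


# Example: a well-conditioned matrix passes the check.
A = np.array([[2.0, 0.0], [0.0, 3.0]])
print(check_inverse_is_true_inverse(A, np.linalg.inv(A)))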