lifejacket 0.2.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/after_study_analysis.py +397 -387
- lifejacket/arg_threading_helpers.py +75 -69
- lifejacket/calculate_derivatives.py +19 -21
- lifejacket/{trial_conditioning_monitor.py → deployment_conditioning_monitor.py} +146 -128
- lifejacket/{form_adaptive_meat_adjustments_directly.py → form_adjusted_meat_adjustments_directly.py} +7 -7
- lifejacket/get_datum_for_blowup_supervised_learning.py +315 -307
- lifejacket/helper_functions.py +45 -38
- lifejacket/input_checks.py +263 -261
- lifejacket/small_sample_corrections.py +42 -40
- lifejacket-1.0.0.dist-info/METADATA +56 -0
- lifejacket-1.0.0.dist-info/RECORD +17 -0
- lifejacket-0.2.1.dist-info/METADATA +0 -100
- lifejacket-0.2.1.dist-info/RECORD +0 -17
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.0.dist-info}/WHEEL +0 -0
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.0.dist-info}/entry_points.txt +0 -0
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.0.dist-info}/top_level.txt +0 -0
lifejacket/helper_functions.py
CHANGED
|
@@ -378,9 +378,9 @@ def confirm_input_check_result(message, suppress_interaction, error=None):
|
|
|
378
378
|
print("\nPlease enter 'y' or 'n'.\n")
|
|
379
379
|
|
|
380
380
|
|
|
381
|
-
def
|
|
381
|
+
def get_active_df_column(analysis_df, col_name, active_col_name):
|
|
382
382
|
return jnp.array(
|
|
383
|
-
|
|
383
|
+
analysis_df.loc[analysis_df[active_col_name] == 1, col_name]
|
|
384
384
|
.to_numpy()
|
|
385
385
|
.reshape(-1, 1)
|
|
386
386
|
)
|
|
@@ -408,7 +408,7 @@ def get_radon_nikodym_weight(
|
|
|
408
408
|
action_prob_func: callable,
|
|
409
409
|
action_prob_func_args_beta_index: int,
|
|
410
410
|
action: int,
|
|
411
|
-
*
|
|
411
|
+
*action_prob_func_args_single_subject: tuple[Any, ...],
|
|
412
412
|
):
|
|
413
413
|
"""
|
|
414
414
|
Computes a ratio of action probabilities under two sets of algorithm parameters:
|
|
@@ -426,13 +426,13 @@ def get_radon_nikodym_weight(
|
|
|
426
426
|
The beta value to use in the denominator. NOT involved in differentiation!
|
|
427
427
|
action_prob_func (callable):
|
|
428
428
|
The function used to compute the probability of action 1 at a given decision time for
|
|
429
|
-
a particular
|
|
429
|
+
a particular subject given their state and the algorithm parameters.
|
|
430
430
|
action_prob_func_args_beta_index (int):
|
|
431
431
|
The index of the beta argument in the action probability function's arguments.
|
|
432
432
|
action (int):
|
|
433
433
|
The actual taken action at the relevant decision time.
|
|
434
|
-
*
|
|
435
|
-
The arguments to the action probability function for the relevant
|
|
434
|
+
*action_prob_func_args_single_subject (tuple[Any, ...]):
|
|
435
|
+
The arguments to the action probability function for the relevant subject at this time.
|
|
436
436
|
|
|
437
437
|
Returns:
|
|
438
438
|
jnp.float32: The Radon-Nikodym weight.
|
|
@@ -440,15 +440,17 @@ def get_radon_nikodym_weight(
|
|
|
440
440
|
"""
|
|
441
441
|
|
|
442
442
|
# numerator
|
|
443
|
-
pi_beta = action_prob_func(*
|
|
443
|
+
pi_beta = action_prob_func(*action_prob_func_args_single_subject)
|
|
444
444
|
|
|
445
445
|
# denominator, where we thread in beta_target so that differentiation with respect to the
|
|
446
446
|
# original beta in the arguments leaves this alone.
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
447
|
+
beta_target_action_prob_func_args_single_subject = [
|
|
448
|
+
*action_prob_func_args_single_subject
|
|
449
|
+
]
|
|
450
|
+
beta_target_action_prob_func_args_single_subject[
|
|
451
|
+
action_prob_func_args_beta_index
|
|
452
|
+
] = beta_target
|
|
453
|
+
pi_beta_target = action_prob_func(*beta_target_action_prob_func_args_single_subject)
|
|
452
454
|
|
|
453
455
|
return conditional_x_or_one_minus_x(pi_beta, action) / conditional_x_or_one_minus_x(
|
|
454
456
|
pi_beta_target, action
|
|
@@ -456,7 +458,7 @@ def get_radon_nikodym_weight(
|
|
|
456
458
|
|
|
457
459
|
|
|
458
460
|
def get_min_time_by_policy_num(
|
|
459
|
-
|
|
461
|
+
single_subject_policy_num_by_decision_time, beta_index_by_policy_num
|
|
460
462
|
):
|
|
461
463
|
"""
|
|
462
464
|
Returns a dictionary mapping each policy number to the first time it was applicable,
|
|
@@ -464,12 +466,12 @@ def get_min_time_by_policy_num(
|
|
|
464
466
|
"""
|
|
465
467
|
min_time_by_policy_num = {}
|
|
466
468
|
first_time_after_first_update = None
|
|
467
|
-
for decision_time, policy_num in
|
|
469
|
+
for decision_time, policy_num in single_subject_policy_num_by_decision_time.items():
|
|
468
470
|
if policy_num not in min_time_by_policy_num:
|
|
469
471
|
min_time_by_policy_num[policy_num] = decision_time
|
|
470
472
|
|
|
471
473
|
# Grab the first time where a non-initial, non-fallback policy is used.
|
|
472
|
-
# Assumes
|
|
474
|
+
# Assumes single_subject_policy_num_by_decision_time is sorted.
|
|
473
475
|
if (
|
|
474
476
|
policy_num in beta_index_by_policy_num
|
|
475
477
|
and first_time_after_first_update is None
|
|
@@ -494,10 +496,10 @@ def calculate_beta_dim(
|
|
|
494
496
|
int: The dimension of the beta vector.
|
|
495
497
|
"""
|
|
496
498
|
for decision_time in action_prob_func_args:
|
|
497
|
-
for
|
|
498
|
-
if action_prob_func_args[decision_time][
|
|
499
|
+
for subject_id in action_prob_func_args[decision_time]:
|
|
500
|
+
if action_prob_func_args[decision_time][subject_id]:
|
|
499
501
|
return len(
|
|
500
|
-
action_prob_func_args[decision_time][
|
|
502
|
+
action_prob_func_args[decision_time][subject_id][
|
|
501
503
|
action_prob_func_args_beta_index
|
|
502
504
|
]
|
|
503
505
|
)
|
|
@@ -507,7 +509,7 @@ def calculate_beta_dim(
|
|
|
507
509
|
|
|
508
510
|
|
|
509
511
|
def construct_beta_index_by_policy_num_map(
|
|
510
|
-
|
|
512
|
+
analysis_df: pd.DataFrame, policy_num_col_name: str, active_col_name: str
|
|
511
513
|
) -> tuple[dict[int | float, int], int | float]:
|
|
512
514
|
"""
|
|
513
515
|
Constructs a mapping from non-initial, non-fallback policy numbers to the index of the
|
|
@@ -524,8 +526,9 @@ def construct_beta_index_by_policy_num_map(
|
|
|
524
526
|
"""
|
|
525
527
|
|
|
526
528
|
unique_sorted_non_fallback_policy_nums = sorted(
|
|
527
|
-
|
|
528
|
-
(
|
|
529
|
+
analysis_df[
|
|
530
|
+
(analysis_df[policy_num_col_name] >= 0)
|
|
531
|
+
& (analysis_df[active_col_name] == 1)
|
|
529
532
|
][policy_num_col_name]
|
|
530
533
|
.unique()
|
|
531
534
|
.tolist()
|
|
@@ -550,10 +553,10 @@ def collect_all_post_update_betas(
|
|
|
550
553
|
"""
|
|
551
554
|
all_post_update_betas = []
|
|
552
555
|
for policy_num in sorted(beta_index_by_policy_num.keys()):
|
|
553
|
-
for
|
|
554
|
-
if alg_update_func_args[policy_num][
|
|
556
|
+
for subject_id in alg_update_func_args[policy_num]:
|
|
557
|
+
if alg_update_func_args[policy_num][subject_id]:
|
|
555
558
|
all_post_update_betas.append(
|
|
556
|
-
alg_update_func_args[policy_num][
|
|
559
|
+
alg_update_func_args[policy_num][subject_id][
|
|
557
560
|
alg_update_func_args_beta_index
|
|
558
561
|
]
|
|
559
562
|
)
|
|
@@ -561,27 +564,31 @@ def collect_all_post_update_betas(
|
|
|
561
564
|
return jnp.array(all_post_update_betas)
|
|
562
565
|
|
|
563
566
|
|
|
564
|
-
def
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
567
|
+
def extract_action_and_policy_by_decision_time_by_subject_id(
|
|
568
|
+
analysis_df,
|
|
569
|
+
subject_id_col_name,
|
|
570
|
+
active_col_name,
|
|
568
571
|
calendar_t_col_name,
|
|
569
572
|
action_col_name,
|
|
570
573
|
policy_num_col_name,
|
|
571
574
|
):
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
for
|
|
575
|
-
|
|
576
|
-
|
|
575
|
+
action_by_decision_time_by_subject_id = {}
|
|
576
|
+
policy_num_by_decision_time_by_subject_id = {}
|
|
577
|
+
for subject_id, subject_df in analysis_df.groupby(subject_id_col_name):
|
|
578
|
+
active_subject_df = subject_df[subject_df[active_col_name] == 1]
|
|
579
|
+
action_by_decision_time_by_subject_id[subject_id] = dict(
|
|
577
580
|
zip(
|
|
578
|
-
|
|
581
|
+
active_subject_df[calendar_t_col_name],
|
|
582
|
+
active_subject_df[action_col_name],
|
|
579
583
|
)
|
|
580
584
|
)
|
|
581
|
-
|
|
585
|
+
policy_num_by_decision_time_by_subject_id[subject_id] = dict(
|
|
582
586
|
zip(
|
|
583
|
-
|
|
584
|
-
|
|
587
|
+
active_subject_df[calendar_t_col_name],
|
|
588
|
+
active_subject_df[policy_num_col_name],
|
|
585
589
|
)
|
|
586
590
|
)
|
|
587
|
-
return
|
|
591
|
+
return (
|
|
592
|
+
action_by_decision_time_by_subject_id,
|
|
593
|
+
policy_num_by_decision_time_by_subject_id,
|
|
594
|
+
)
|