lifejacket 0.2.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -378,9 +378,9 @@ def confirm_input_check_result(message, suppress_interaction, error=None):
378
378
  print("\nPlease enter 'y' or 'n'.\n")
379
379
 
380
380
 
381
- def get_in_study_df_column(study_df, col_name, in_study_col_name):
381
+ def get_active_df_column(analysis_df, col_name, active_col_name):
382
382
  return jnp.array(
383
- study_df.loc[study_df[in_study_col_name] == 1, col_name]
383
+ analysis_df.loc[analysis_df[active_col_name] == 1, col_name]
384
384
  .to_numpy()
385
385
  .reshape(-1, 1)
386
386
  )
@@ -408,7 +408,7 @@ def get_radon_nikodym_weight(
408
408
  action_prob_func: callable,
409
409
  action_prob_func_args_beta_index: int,
410
410
  action: int,
411
- *action_prob_func_args_single_user: tuple[Any, ...],
411
+ *action_prob_func_args_single_subject: tuple[Any, ...],
412
412
  ):
413
413
  """
414
414
  Computes a ratio of action probabilities under two sets of algorithm parameters:
@@ -426,13 +426,13 @@ def get_radon_nikodym_weight(
426
426
  The beta value to use in the denominator. NOT involved in differentation!
427
427
  action_prob_func (callable):
428
428
  The function used to compute the probability of action 1 at a given decision time for
429
- a particular user given their state and the algorithm parameters.
429
+ a particular subject given their state and the algorithm parameters.
430
430
  action_prob_func_args_beta_index (int):
431
431
  The index of the beta argument in the action probability function's arguments.
432
432
  action (int):
433
433
  The actual taken action at the relevant decision time.
434
- *action_prob_func_args_single_user (tuple[Any, ...]):
435
- The arguments to the action probability function for the relevant user at this time.
434
+ *action_prob_func_args_single_subject (tuple[Any, ...]):
435
+ The arguments to the action probability function for the relevant subject at this time.
436
436
 
437
437
  Returns:
438
438
  jnp.float32: The Radon-Nikodym weight.
@@ -440,15 +440,17 @@ def get_radon_nikodym_weight(
440
440
  """
441
441
 
442
442
  # numerator
443
- pi_beta = action_prob_func(*action_prob_func_args_single_user)
443
+ pi_beta = action_prob_func(*action_prob_func_args_single_subject)
444
444
 
445
445
  # denominator, where we thread in beta_target so that differentiation with respect to the
446
446
  # original beta in the arguments leaves this alone.
447
- beta_target_action_prob_func_args_single_user = [*action_prob_func_args_single_user]
448
- beta_target_action_prob_func_args_single_user[action_prob_func_args_beta_index] = (
449
- beta_target
450
- )
451
- pi_beta_target = action_prob_func(*beta_target_action_prob_func_args_single_user)
447
+ beta_target_action_prob_func_args_single_subject = [
448
+ *action_prob_func_args_single_subject
449
+ ]
450
+ beta_target_action_prob_func_args_single_subject[
451
+ action_prob_func_args_beta_index
452
+ ] = beta_target
453
+ pi_beta_target = action_prob_func(*beta_target_action_prob_func_args_single_subject)
452
454
 
453
455
  return conditional_x_or_one_minus_x(pi_beta, action) / conditional_x_or_one_minus_x(
454
456
  pi_beta_target, action
@@ -456,7 +458,7 @@ def get_radon_nikodym_weight(
456
458
 
457
459
 
458
460
  def get_min_time_by_policy_num(
459
- single_user_policy_num_by_decision_time, beta_index_by_policy_num
461
+ single_subject_policy_num_by_decision_time, beta_index_by_policy_num
460
462
  ):
461
463
  """
462
464
  Returns a dictionary mapping each policy number to the first time it was applicable,
@@ -464,12 +466,12 @@ def get_min_time_by_policy_num(
464
466
  """
465
467
  min_time_by_policy_num = {}
466
468
  first_time_after_first_update = None
467
- for decision_time, policy_num in single_user_policy_num_by_decision_time.items():
469
+ for decision_time, policy_num in single_subject_policy_num_by_decision_time.items():
468
470
  if policy_num not in min_time_by_policy_num:
469
471
  min_time_by_policy_num[policy_num] = decision_time
470
472
 
471
473
  # Grab the first time where a non-initial, non-fallback policy is used.
472
- # Assumes single_user_policy_num_by_decision_time is sorted.
474
+ # Assumes single_subject_policy_num_by_decision_time is sorted.
473
475
  if (
474
476
  policy_num in beta_index_by_policy_num
475
477
  and first_time_after_first_update is None
@@ -494,10 +496,10 @@ def calculate_beta_dim(
494
496
  int: The dimension of the beta vector.
495
497
  """
496
498
  for decision_time in action_prob_func_args:
497
- for user_id in action_prob_func_args[decision_time]:
498
- if action_prob_func_args[decision_time][user_id]:
499
+ for subject_id in action_prob_func_args[decision_time]:
500
+ if action_prob_func_args[decision_time][subject_id]:
499
501
  return len(
500
- action_prob_func_args[decision_time][user_id][
502
+ action_prob_func_args[decision_time][subject_id][
501
503
  action_prob_func_args_beta_index
502
504
  ]
503
505
  )
@@ -507,7 +509,7 @@ def calculate_beta_dim(
507
509
 
508
510
 
509
511
  def construct_beta_index_by_policy_num_map(
510
- study_df: pd.DataFrame, policy_num_col_name: str, in_study_col_name: str
512
+ analysis_df: pd.DataFrame, policy_num_col_name: str, active_col_name: str
511
513
  ) -> tuple[dict[int | float, int], int | float]:
512
514
  """
513
515
  Constructs a mapping from non-initial, non-fallback policy numbers to the index of the
@@ -524,8 +526,9 @@ def construct_beta_index_by_policy_num_map(
524
526
  """
525
527
 
526
528
  unique_sorted_non_fallback_policy_nums = sorted(
527
- study_df[
528
- (study_df[policy_num_col_name] >= 0) & (study_df[in_study_col_name] == 1)
529
+ analysis_df[
530
+ (analysis_df[policy_num_col_name] >= 0)
531
+ & (analysis_df[active_col_name] == 1)
529
532
  ][policy_num_col_name]
530
533
  .unique()
531
534
  .tolist()
@@ -550,10 +553,10 @@ def collect_all_post_update_betas(
550
553
  """
551
554
  all_post_update_betas = []
552
555
  for policy_num in sorted(beta_index_by_policy_num.keys()):
553
- for user_id in alg_update_func_args[policy_num]:
554
- if alg_update_func_args[policy_num][user_id]:
556
+ for subject_id in alg_update_func_args[policy_num]:
557
+ if alg_update_func_args[policy_num][subject_id]:
555
558
  all_post_update_betas.append(
556
- alg_update_func_args[policy_num][user_id][
559
+ alg_update_func_args[policy_num][subject_id][
557
560
  alg_update_func_args_beta_index
558
561
  ]
559
562
  )
@@ -561,27 +564,31 @@ def collect_all_post_update_betas(
561
564
  return jnp.array(all_post_update_betas)
562
565
 
563
566
 
564
- def extract_action_and_policy_by_decision_time_by_user_id(
565
- study_df,
566
- user_id_col_name,
567
- in_study_col_name,
567
+ def extract_action_and_policy_by_decision_time_by_subject_id(
568
+ analysis_df,
569
+ subject_id_col_name,
570
+ active_col_name,
568
571
  calendar_t_col_name,
569
572
  action_col_name,
570
573
  policy_num_col_name,
571
574
  ):
572
- action_by_decision_time_by_user_id = {}
573
- policy_num_by_decision_time_by_user_id = {}
574
- for user_id, user_df in study_df.groupby(user_id_col_name):
575
- in_study_user_df = user_df[user_df[in_study_col_name] == 1]
576
- action_by_decision_time_by_user_id[user_id] = dict(
575
+ action_by_decision_time_by_subject_id = {}
576
+ policy_num_by_decision_time_by_subject_id = {}
577
+ for subject_id, subject_df in analysis_df.groupby(subject_id_col_name):
578
+ active_subject_df = subject_df[subject_df[active_col_name] == 1]
579
+ action_by_decision_time_by_subject_id[subject_id] = dict(
577
580
  zip(
578
- in_study_user_df[calendar_t_col_name], in_study_user_df[action_col_name]
581
+ active_subject_df[calendar_t_col_name],
582
+ active_subject_df[action_col_name],
579
583
  )
580
584
  )
581
- policy_num_by_decision_time_by_user_id[user_id] = dict(
585
+ policy_num_by_decision_time_by_subject_id[subject_id] = dict(
582
586
  zip(
583
- in_study_user_df[calendar_t_col_name],
584
- in_study_user_df[policy_num_col_name],
587
+ active_subject_df[calendar_t_col_name],
588
+ active_subject_df[policy_num_col_name],
585
589
  )
586
590
  )
587
- return action_by_decision_time_by_user_id, policy_num_by_decision_time_by_user_id
591
+ return (
592
+ action_by_decision_time_by_subject_id,
593
+ policy_num_by_decision_time_by_subject_id,
594
+ )