lifejacket 0.2.1__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,8 +11,6 @@ import numpy as np
11
11
  import jax.numpy as jnp
12
12
  import pandas as pd
13
13
 
14
- from .constants import InverseStabilizationMethods
15
-
16
14
  logger = logging.getLogger(__name__)
17
15
  logging.basicConfig(
18
16
  format="%(asctime)s,%(msecs)03d %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s",
@@ -27,11 +25,7 @@ def conditional_x_or_one_minus_x(x, condition):
27
25
 
28
26
  def invert_matrix_and_check_conditioning(
29
27
  matrix: np.ndarray,
30
- inverse_stabilization_method: str = InverseStabilizationMethods.NONE,
31
28
  condition_num_threshold: float = 10**4,
32
- ridge_median_singular_value_fraction: str = 0.01,
33
- beta_dim: int = None,
34
- theta_dim: int = None,
35
29
  ):
36
30
  """
37
31
  Check a matrix's condition number and invert it. If the condition number is
@@ -39,139 +33,15 @@ def invert_matrix_and_check_conditioning(
39
33
  Parameters
40
34
  """
41
35
  inverse = None
42
- pre_inversion_condition_number = np.linalg.cond(matrix)
43
- if pre_inversion_condition_number > condition_num_threshold:
36
+ condition_number = np.linalg.cond(matrix)
37
+ if condition_number > condition_num_threshold:
44
38
  logger.warning(
45
- "You are inverting a matrix with a large condition number: %s",
46
- pre_inversion_condition_number,
39
+ "You are inverting a matrix with a potentially large condition number: %s",
40
+ condition_number,
47
41
  )
48
- if (
49
- inverse_stabilization_method
50
- == InverseStabilizationMethods.TRIM_SMALL_SINGULAR_VALUES
51
- ):
52
- logger.info("Trimming small singular values to improve conditioning.")
53
- u, s, vT = np.linalg.svd(matrix, full_matrices=False)
54
- logger.info(
55
- " Sorted singular values: %s",
56
- s,
57
- )
58
- sing_values_above_threshold_cond = s > s.max() / condition_num_threshold
59
- if not np.any(sing_values_above_threshold_cond):
60
- raise RuntimeError(
61
- f"All singular values are below the threshold of {s.max() / condition_num_threshold}. Singular value trimming will not work.",
62
- )
63
- trimmed_pseudoinverse = (
64
- vT.T[:, sing_values_above_threshold_cond]
65
- / s[sing_values_above_threshold_cond]
66
- ) @ u[:, sing_values_above_threshold_cond].T
67
- inverse = trimmed_pseudoinverse
68
- pre_inversion_condition_number = (
69
- s[sing_values_above_threshold_cond].max()
70
- / s[sing_values_above_threshold_cond].min()
71
- )
72
-
73
- logger.info(
74
- "Kept %s out of %s singular values. Condition number of resulting lower-rank-approximation before inversion: %s",
75
- sum(sing_values_above_threshold_cond),
76
- len(s),
77
- pre_inversion_condition_number,
78
- )
79
- elif (
80
- inverse_stabilization_method
81
- == InverseStabilizationMethods.ADD_RIDGE_FIXED_CONDITION_NUMBER
82
- ):
83
- logger.info("Adding ridge/Tikhonov regularization to improve conditioning.")
84
- _, singular_values, _ = np.linalg.svd(matrix, full_matrices=False)
85
- logger.info(
86
- "Using fixed condition number threshold of %s to determine lambda.",
87
- condition_num_threshold,
88
- )
89
- lambda_ = (
90
- singular_values.max() / condition_num_threshold - singular_values.min()
91
- )
92
- logger.info("Lambda for ridge regularization: %s", lambda_)
93
- new_matrix = matrix + lambda_ * np.eye(matrix.shape[0])
94
- pre_inversion_condition_number = np.linalg.cond(new_matrix)
95
- logger.info(
96
- "Condition number of matrix after ridge regularization: %s",
97
- pre_inversion_condition_number,
98
- )
99
- inverse = np.linalg.solve(new_matrix, np.eye(matrix.shape[0]))
100
- elif (
101
- inverse_stabilization_method
102
- == InverseStabilizationMethods.ADD_RIDGE_MEDIAN_SINGULAR_VALUE_FRACTION
103
- ):
104
- logger.info("Adding ridge/Tikhonov regularization to improve conditioning.")
105
- _, singular_values, _ = np.linalg.svd(matrix, full_matrices=False)
106
- logger.info(
107
- "Using median singular value times %s as lambda.",
108
- ridge_median_singular_value_fraction,
109
- )
110
- lambda_ = ridge_median_singular_value_fraction * np.median(singular_values)
111
- logger.info("Lambda for ridge regularization: %s", lambda_)
112
- new_matrix = matrix + lambda_ * np.eye(matrix.shape[0])
113
- pre_inversion_condition_number = np.linalg.cond(new_matrix)
114
- logger.info(
115
- "Condition number of matrix after ridge regularization: %s",
116
- pre_inversion_condition_number,
117
- )
118
- inverse = np.linalg.solve(new_matrix, np.eye(matrix.shape[0]))
119
- elif (
120
- inverse_stabilization_method
121
- == InverseStabilizationMethods.INVERSE_BREAD_STRUCTURE_AWARE_INVERSION
122
- ):
123
- if not beta_dim or not theta_dim:
124
- raise ValueError(
125
- "When using structure-aware inversion, beta_dim and theta_dim must be provided."
126
- )
127
- logger.info(
128
- "Using inverse bread's block lower triangular structure to invert only diagonal blocks."
129
- )
130
- pre_inversion_condition_number = np.linalg.cond(matrix)
131
- inverse = invert_inverse_bread_matrix(
132
- matrix,
133
- beta_dim,
134
- theta_dim,
135
- InverseStabilizationMethods.ADD_RIDGE_FIXED_CONDITION_NUMBER,
136
- )
137
- elif (
138
- inverse_stabilization_method
139
- == InverseStabilizationMethods.ZERO_OUT_SMALL_OFF_DIAGONALS
140
- ):
141
- if not beta_dim or not theta_dim:
142
- raise ValueError(
143
- "When zeroing out small off diagonals, beta_dim and theta_dim must be provided."
144
- )
145
- logger.info(
146
- "Zeroing out small off-diagonal blocks to improve conditioning."
147
- )
148
- zeroed_matrix = zero_small_off_diagonal_blocks(
149
- matrix,
150
- ([beta_dim] * (matrix.shape[0] // beta_dim)) + [theta_dim],
151
- )
152
- pre_inversion_condition_number = np.linalg.cond(zeroed_matrix)
153
- logger.info(
154
- "Condition number of matrix after zeroing out small off-diagonal blocks: %s",
155
- pre_inversion_condition_number,
156
- )
157
- inverse = np.linalg.solve(zeroed_matrix, np.eye(zeroed_matrix.shape[0]))
158
- elif (
159
- inverse_stabilization_method
160
- == InverseStabilizationMethods.ALL_METHODS_COMPETITION
161
- ):
162
- # TODO: Choose right metric for competition... identity diff might not be it.
163
- raise NotImplementedError(
164
- "All methods competition is not implemented yet. Please choose a specific method."
165
- )
166
- elif inverse_stabilization_method == InverseStabilizationMethods.NONE:
167
- logger.info("No inverse stabilization method applied. Inverting directly.")
168
- else:
169
- raise ValueError(
170
- f"Unknown inverse stabilization method: {inverse_stabilization_method}"
171
- )
172
42
  if inverse is None:
173
43
  inverse = np.linalg.solve(matrix, np.eye(matrix.shape[0]))
174
- return inverse, pre_inversion_condition_number
44
+ return inverse, condition_number
175
45
 
176
46
 
177
47
  def zero_small_off_diagonal_blocks(
@@ -183,7 +53,7 @@ def zero_small_off_diagonal_blocks(
183
53
  Zero off-diagonal blocks whose Frobenius norm is < frobenius_norm_threshold_fraction x
184
54
  Frobenius norm of the diagonal block in the same ROW. One could compare to
185
55
  the same column or both the row and column, but we choose row here since
186
- rows correspond to a single RL update or inference step in the adaptive bread
56
+ rows correspond to a single RL update or inference step in the bread
187
57
  inverse matrices this method is designed for.
188
58
 
189
59
  Args:
@@ -237,18 +107,17 @@ def zero_small_off_diagonal_blocks(
237
107
  return J_trim
238
108
 
239
109
 
240
- def invert_inverse_bread_matrix(
241
- inverse_bread,
110
+ def invert_bread_matrix(
111
+ bread,
242
112
  beta_dim,
243
113
  theta_dim,
244
- diag_inverse_stabilization_method=InverseStabilizationMethods.TRIM_SMALL_SINGULAR_VALUES,
245
114
  ):
246
115
  """
247
- Invert the inverse bread matrix to get the bread matrix. This is a special
116
+ Invert the bread matrix to get the inverse bread matrix. This is a special
248
117
  function in order to take advantage of the block lower triangular structure.
249
118
 
250
119
  The procedure is as follows:
251
- 1. Initialize the inverse matrix B = A^{-1} as a block lower triangular matrix
120
+ 1. Initialize the matrix B = A^{-1} as a block lower triangular matrix
252
121
  with the same block structure as A.
253
122
 
254
123
  2. Compute the diagonal blocks B_{ii}:
@@ -260,24 +129,23 @@ def invert_inverse_bread_matrix(
260
129
  B_{ij} = -A_{ii}^{-1} * sum(A_{ik} * B_{kj} for k in range(j, i))
261
130
  """
262
131
  blocks = []
263
- num_beta_block_rows = (inverse_bread.shape[0] - theta_dim) // beta_dim
132
+ num_beta_block_rows = (bread.shape[0] - theta_dim) // beta_dim
264
133
 
265
134
  # Create upper rows of block of bread (just the beta portion)
266
135
  for i in range(0, num_beta_block_rows):
267
136
  beta_block_row = []
268
137
  beta_diag_inverse = invert_matrix_and_check_conditioning(
269
- inverse_bread[
138
+ bread[
270
139
  beta_dim * i : beta_dim * (i + 1),
271
140
  beta_dim * i : beta_dim * (i + 1),
272
141
  ],
273
- diag_inverse_stabilization_method,
274
142
  )[0]
275
143
  for j in range(0, num_beta_block_rows):
276
144
  if i > j:
277
145
  beta_block_row.append(
278
146
  -beta_diag_inverse
279
147
  @ sum(
280
- inverse_bread[
148
+ bread[
281
149
  beta_dim * i : beta_dim * (i + 1),
282
150
  beta_dim * k : beta_dim * (k + 1),
283
151
  ]
@@ -299,17 +167,16 @@ def invert_inverse_bread_matrix(
299
167
  # Create the bottom block row of bread (the theta portion)
300
168
  theta_block_row = []
301
169
  theta_diag_inverse = invert_matrix_and_check_conditioning(
302
- inverse_bread[
170
+ bread[
303
171
  -theta_dim:,
304
172
  -theta_dim:,
305
173
  ],
306
- diag_inverse_stabilization_method,
307
174
  )[0]
308
175
  for k in range(0, num_beta_block_rows):
309
176
  theta_block_row.append(
310
177
  -theta_diag_inverse
311
178
  @ sum(
312
- inverse_bread[
179
+ bread[
313
180
  -theta_dim:,
314
181
  beta_dim * h : beta_dim * (h + 1),
315
182
  ]
@@ -378,9 +245,9 @@ def confirm_input_check_result(message, suppress_interaction, error=None):
378
245
  print("\nPlease enter 'y' or 'n'.\n")
379
246
 
380
247
 
381
- def get_in_study_df_column(study_df, col_name, in_study_col_name):
248
+ def get_active_df_column(analysis_df, col_name, active_col_name):
382
249
  return jnp.array(
383
- study_df.loc[study_df[in_study_col_name] == 1, col_name]
250
+ analysis_df.loc[analysis_df[active_col_name] == 1, col_name]
384
251
  .to_numpy()
385
252
  .reshape(-1, 1)
386
253
  )
@@ -408,7 +275,7 @@ def get_radon_nikodym_weight(
408
275
  action_prob_func: callable,
409
276
  action_prob_func_args_beta_index: int,
410
277
  action: int,
411
- *action_prob_func_args_single_user: tuple[Any, ...],
278
+ *action_prob_func_args_single_subject: tuple[Any, ...],
412
279
  ):
413
280
  """
414
281
  Computes a ratio of action probabilities under two sets of algorithm parameters:
@@ -426,13 +293,13 @@ def get_radon_nikodym_weight(
426
293
  The beta value to use in the denominator. NOT involved in differentation!
427
294
  action_prob_func (callable):
428
295
  The function used to compute the probability of action 1 at a given decision time for
429
- a particular user given their state and the algorithm parameters.
296
+ a particular subject given their state and the algorithm parameters.
430
297
  action_prob_func_args_beta_index (int):
431
298
  The index of the beta argument in the action probability function's arguments.
432
299
  action (int):
433
300
  The actual taken action at the relevant decision time.
434
- *action_prob_func_args_single_user (tuple[Any, ...]):
435
- The arguments to the action probability function for the relevant user at this time.
301
+ *action_prob_func_args_single_subject (tuple[Any, ...]):
302
+ The arguments to the action probability function for the relevant subject at this time.
436
303
 
437
304
  Returns:
438
305
  jnp.float32: The Radon-Nikodym weight.
@@ -440,15 +307,17 @@ def get_radon_nikodym_weight(
440
307
  """
441
308
 
442
309
  # numerator
443
- pi_beta = action_prob_func(*action_prob_func_args_single_user)
310
+ pi_beta = action_prob_func(*action_prob_func_args_single_subject)
444
311
 
445
312
  # denominator, where we thread in beta_target so that differentiation with respect to the
446
313
  # original beta in the arguments leaves this alone.
447
- beta_target_action_prob_func_args_single_user = [*action_prob_func_args_single_user]
448
- beta_target_action_prob_func_args_single_user[action_prob_func_args_beta_index] = (
449
- beta_target
450
- )
451
- pi_beta_target = action_prob_func(*beta_target_action_prob_func_args_single_user)
314
+ beta_target_action_prob_func_args_single_subject = [
315
+ *action_prob_func_args_single_subject
316
+ ]
317
+ beta_target_action_prob_func_args_single_subject[
318
+ action_prob_func_args_beta_index
319
+ ] = beta_target
320
+ pi_beta_target = action_prob_func(*beta_target_action_prob_func_args_single_subject)
452
321
 
453
322
  return conditional_x_or_one_minus_x(pi_beta, action) / conditional_x_or_one_minus_x(
454
323
  pi_beta_target, action
@@ -456,7 +325,7 @@ def get_radon_nikodym_weight(
456
325
 
457
326
 
458
327
  def get_min_time_by_policy_num(
459
- single_user_policy_num_by_decision_time, beta_index_by_policy_num
328
+ single_subject_policy_num_by_decision_time, beta_index_by_policy_num
460
329
  ):
461
330
  """
462
331
  Returns a dictionary mapping each policy number to the first time it was applicable,
@@ -464,12 +333,12 @@ def get_min_time_by_policy_num(
464
333
  """
465
334
  min_time_by_policy_num = {}
466
335
  first_time_after_first_update = None
467
- for decision_time, policy_num in single_user_policy_num_by_decision_time.items():
336
+ for decision_time, policy_num in single_subject_policy_num_by_decision_time.items():
468
337
  if policy_num not in min_time_by_policy_num:
469
338
  min_time_by_policy_num[policy_num] = decision_time
470
339
 
471
340
  # Grab the first time where a non-initial, non-fallback policy is used.
472
- # Assumes single_user_policy_num_by_decision_time is sorted.
341
+ # Assumes single_subject_policy_num_by_decision_time is sorted.
473
342
  if (
474
343
  policy_num in beta_index_by_policy_num
475
344
  and first_time_after_first_update is None
@@ -494,10 +363,10 @@ def calculate_beta_dim(
494
363
  int: The dimension of the beta vector.
495
364
  """
496
365
  for decision_time in action_prob_func_args:
497
- for user_id in action_prob_func_args[decision_time]:
498
- if action_prob_func_args[decision_time][user_id]:
366
+ for subject_id in action_prob_func_args[decision_time]:
367
+ if action_prob_func_args[decision_time][subject_id]:
499
368
  return len(
500
- action_prob_func_args[decision_time][user_id][
369
+ action_prob_func_args[decision_time][subject_id][
501
370
  action_prob_func_args_beta_index
502
371
  ]
503
372
  )
@@ -507,7 +376,7 @@ def calculate_beta_dim(
507
376
 
508
377
 
509
378
  def construct_beta_index_by_policy_num_map(
510
- study_df: pd.DataFrame, policy_num_col_name: str, in_study_col_name: str
379
+ analysis_df: pd.DataFrame, policy_num_col_name: str, active_col_name: str
511
380
  ) -> tuple[dict[int | float, int], int | float]:
512
381
  """
513
382
  Constructs a mapping from non-initial, non-fallback policy numbers to the index of the
@@ -524,8 +393,9 @@ def construct_beta_index_by_policy_num_map(
524
393
  """
525
394
 
526
395
  unique_sorted_non_fallback_policy_nums = sorted(
527
- study_df[
528
- (study_df[policy_num_col_name] >= 0) & (study_df[in_study_col_name] == 1)
396
+ analysis_df[
397
+ (analysis_df[policy_num_col_name] >= 0)
398
+ & (analysis_df[active_col_name] == 1)
529
399
  ][policy_num_col_name]
530
400
  .unique()
531
401
  .tolist()
@@ -550,10 +420,10 @@ def collect_all_post_update_betas(
550
420
  """
551
421
  all_post_update_betas = []
552
422
  for policy_num in sorted(beta_index_by_policy_num.keys()):
553
- for user_id in alg_update_func_args[policy_num]:
554
- if alg_update_func_args[policy_num][user_id]:
423
+ for subject_id in alg_update_func_args[policy_num]:
424
+ if alg_update_func_args[policy_num][subject_id]:
555
425
  all_post_update_betas.append(
556
- alg_update_func_args[policy_num][user_id][
426
+ alg_update_func_args[policy_num][subject_id][
557
427
  alg_update_func_args_beta_index
558
428
  ]
559
429
  )
@@ -561,27 +431,31 @@ def collect_all_post_update_betas(
561
431
  return jnp.array(all_post_update_betas)
562
432
 
563
433
 
564
- def extract_action_and_policy_by_decision_time_by_user_id(
565
- study_df,
566
- user_id_col_name,
567
- in_study_col_name,
434
+ def extract_action_and_policy_by_decision_time_by_subject_id(
435
+ analysis_df,
436
+ subject_id_col_name,
437
+ active_col_name,
568
438
  calendar_t_col_name,
569
439
  action_col_name,
570
440
  policy_num_col_name,
571
441
  ):
572
- action_by_decision_time_by_user_id = {}
573
- policy_num_by_decision_time_by_user_id = {}
574
- for user_id, user_df in study_df.groupby(user_id_col_name):
575
- in_study_user_df = user_df[user_df[in_study_col_name] == 1]
576
- action_by_decision_time_by_user_id[user_id] = dict(
442
+ action_by_decision_time_by_subject_id = {}
443
+ policy_num_by_decision_time_by_subject_id = {}
444
+ for subject_id, subject_df in analysis_df.groupby(subject_id_col_name):
445
+ active_subject_df = subject_df[subject_df[active_col_name] == 1]
446
+ action_by_decision_time_by_subject_id[subject_id] = dict(
577
447
  zip(
578
- in_study_user_df[calendar_t_col_name], in_study_user_df[action_col_name]
448
+ active_subject_df[calendar_t_col_name],
449
+ active_subject_df[action_col_name],
579
450
  )
580
451
  )
581
- policy_num_by_decision_time_by_user_id[user_id] = dict(
452
+ policy_num_by_decision_time_by_subject_id[subject_id] = dict(
582
453
  zip(
583
- in_study_user_df[calendar_t_col_name],
584
- in_study_user_df[policy_num_col_name],
454
+ active_subject_df[calendar_t_col_name],
455
+ active_subject_df[policy_num_col_name],
585
456
  )
586
457
  )
587
- return action_by_decision_time_by_user_id, policy_num_by_decision_time_by_user_id
458
+ return (
459
+ action_by_decision_time_by_subject_id,
460
+ policy_num_by_decision_time_by_subject_id,
461
+ )