lifejacket-1.0.0-py3-none-any.whl → lifejacket-1.0.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,7 +12,7 @@ import jax
 from jax import numpy as jnp
 import pandas as pd
 
-from . import after_study_analysis
+from . import post_deployment_analysis
 from .constants import FunctionTypes
 from .vmap_helpers import stack_batched_arg_lists_into_tensors
 
@@ -25,8 +25,8 @@ logging.basicConfig(
 
 
 def get_datum_for_blowup_supervised_learning(
-    joint_adjusted_bread_inverse_matrix,
-    joint_adjusted_bread_inverse_cond,
+    joint_adjusted_bread_matrix,
+    joint_adjusted_bread_cond,
     avg_estimating_function_stack,
     per_subject_estimating_function_stacks,
     all_post_update_betas,
@@ -63,10 +63,10 @@ def get_datum_for_blowup_supervised_learning(
     A few plots are produced along the way to help visualize the data.
 
     Args:
-        joint_adjusted_bread_inverse_matrix (jnp.ndarray):
-            The joint adjusted bread inverse matrix.
-        joint_adjusted_bread_inverse_cond (float):
-            The condition number of the joint adjusted bread inverse matrix.
+        joint_adjusted_bread_matrix (jnp.ndarray):
+            The joint adjusted bread matrix.
+        joint_adjusted_bread_cond (float):
+            The condition number of the joint adjusted bread matrix.
         avg_estimating_function_stack (jnp.ndarray):
             The average estimating function stack across subjects.
         per_subject_estimating_function_stacks (jnp.ndarray):
@@ -125,7 +125,7 @@ def get_datum_for_blowup_supervised_learning(
         dict[str, Any]: A dictionary containing features and the label for supervised learning.
     """
     num_diagonal_blocks = (
-        (joint_adjusted_bread_inverse_matrix.shape[0] - theta_dim) // beta_dim
+        (joint_adjusted_bread_matrix.shape[0] - theta_dim) // beta_dim
     ) + 1
     diagonal_block_sizes = ([beta_dim] * (num_diagonal_blocks - 1)) + [theta_dim]
 
@@ -144,7 +144,7 @@ def get_datum_for_blowup_supervised_learning(
             row_slice = slice(block_bounds[i], block_bounds[i + 1])
             col_slice = slice(block_bounds[j], block_bounds[j + 1])
             block_norm = np.linalg.norm(
-                joint_adjusted_bread_inverse_matrix[row_slice, col_slice],
+                joint_adjusted_bread_matrix[row_slice, col_slice],
                 ord="fro",
             )
             # We will sum here and take the square root later
@@ -155,9 +155,9 @@ def get_datum_for_blowup_supervised_learning(
         # handle diagonal blocks
         sl = slice(block_bounds[i], block_bounds[i + 1])
         diag_norms.append(
-            np.linalg.norm(joint_adjusted_bread_inverse_matrix[sl, sl], ord="fro")
+            np.linalg.norm(joint_adjusted_bread_matrix[sl, sl], ord="fro")
         )
-        diag_conds.append(np.linalg.cond(joint_adjusted_bread_inverse_matrix[sl, sl]))
+        diag_conds.append(np.linalg.cond(joint_adjusted_bread_matrix[sl, sl]))
 
     # Sqrt each row/col sum to truly get row/column norms.
     # Perhaps not necessary for learning, but more natural
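
For context on the bookkeeping these hunks rename: the function partitions the joint bread matrix into per-update beta blocks plus a final theta block, and records each diagonal block's Frobenius norm and condition number. A minimal standalone sketch, assuming `block_bounds` is the cumulative sum of the block sizes (its construction falls outside this diff):

```python
import numpy as np

beta_dim, theta_dim = 2, 3
rng = np.random.default_rng(0)
# Stand-in for joint_adjusted_bread_matrix: two beta blocks plus one theta block.
bread = rng.normal(size=(2 * beta_dim + theta_dim, 2 * beta_dim + theta_dim))

num_diagonal_blocks = ((bread.shape[0] - theta_dim) // beta_dim) + 1
diagonal_block_sizes = ([beta_dim] * (num_diagonal_blocks - 1)) + [theta_dim]
block_bounds = np.concatenate([[0], np.cumsum(diagonal_block_sizes)])  # assumed form

diag_norms, diag_conds = [], []
for i in range(num_diagonal_blocks):
    sl = slice(block_bounds[i], block_bounds[i + 1])
    diag_norms.append(np.linalg.norm(bread[sl, sl], ord="fro"))
    diag_conds.append(np.linalg.cond(bread[sl, sl]))
```
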
@@ -214,8 +214,8 @@ def get_datum_for_blowup_supervised_learning(
     reward_means_by_t = grouped_reward.mean().values
     reward_stds_by_t = grouped_reward.std().values
 
-    joint_bread_inverse_min_singular_value = np.linalg.svd(
-        joint_adjusted_bread_inverse_matrix, compute_uv=False
+    joint_bread_min_singular_value = np.linalg.svd(
+        joint_adjusted_bread_matrix, compute_uv=False
     )[-1]
 
     max_reward = analysis_df.loc[in_study_mask][reward_col_name].max()
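
The minimum singular value extraction relies on `np.linalg.svd` returning singular values sorted in descending order when `compute_uv=False`, so the last entry is the smallest:

```python
import numpy as np

matrix = np.array([[3.0, 1.0], [1.0, 0.5]])
singular_values = np.linalg.svd(matrix, compute_uv=False)  # sorted descending
min_singular_value = singular_values[-1]
```
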
@@ -227,7 +227,7 @@ def get_datum_for_blowup_supervised_learning(
         premature_thetas,
         premature_adjusted_sandwiches,
         premature_classical_sandwiches,
-        premature_joint_adjusted_bread_inverse_condition_numbers,
+        premature_joint_adjusted_bread_condition_numbers,
         premature_avg_inference_estimating_functions,
     ) = calculate_sequence_of_premature_adjusted_estimates(
         analysis_df,
@@ -250,7 +250,7 @@ def get_datum_for_blowup_supervised_learning(
         inference_action_prob_decision_times_by_subject_id,
         action_prob_func_args,
         action_by_decision_time_by_subject_id,
-        joint_adjusted_bread_inverse_matrix,
+        joint_adjusted_bread_matrix,
         per_subject_estimating_function_stacks,
         beta_dim,
     )
@@ -261,23 +261,23 @@ def get_datum_for_blowup_supervised_learning(
         atol=1e-3,
     )
 
-    # Plot premature joint adjusted bread inverse log condition numbers
+    # Plot premature joint adjusted bread log condition numbers
     plt.clear_figure()
-    plt.title("Premature Joint Adaptive Bread Inverse Log Condition Numbers")
+    plt.title("Premature Joint Adjusted Bread Inverse Log Condition Numbers")
     plt.xlabel("Premature Update Index")
     plt.ylabel("Log Condition Number")
     plt.scatter(
-        np.log(premature_joint_adjusted_bread_inverse_condition_numbers),
+        np.log(premature_joint_adjusted_bread_condition_numbers),
         color="blue+",
     )
     plt.grid(True)
     plt.xticks(
         range(
             0,
-            len(premature_joint_adjusted_bread_inverse_condition_numbers),
+            len(premature_joint_adjusted_bread_condition_numbers),
             max(
                 1,
-                len(premature_joint_adjusted_bread_inverse_condition_numbers) // 10,
+                len(premature_joint_adjusted_bread_condition_numbers) // 10,
             ),
         )
     )
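
The plotting idioms in this hunk (`plt.clear_figure()`, marker colors like `"blue+"`) match the terminal plotting library plotext rather than matplotlib. A standalone sketch of the same plot with hypothetical condition numbers:

```python
import numpy as np
import plotext as plt  # terminal plotting; matches the clear_figure()/"blue+" idioms above

condition_numbers = np.array([1e2, 8e2, 4e4, 2e6])  # hypothetical values
plt.clear_figure()
plt.title("Premature Joint Adjusted Bread Inverse Log Condition Numbers")
plt.xlabel("Premature Update Index")
plt.ylabel("Log Condition Number")
plt.scatter(np.log(condition_numbers), color="blue+")
plt.grid(True)
plt.xticks(range(0, len(condition_numbers), max(1, len(condition_numbers) // 10)))
plt.show()
```
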
@@ -287,7 +287,7 @@ def get_datum_for_blowup_supervised_learning(
     num_diag = premature_adjusted_sandwiches.shape[-1]
     for i in range(num_diag):
         plt.clear_figure()
-        plt.title(f"Premature Adaptive Sandwich Diagonal Element {i}")
+        plt.title(f"Premature Adjusted Sandwich Diagonal Element {i}")
         plt.xlabel("Premature Update Index")
         plt.ylabel(f"Variance (Diagonal {i})")
         plt.scatter(np.array(premature_adjusted_sandwiches[:, i, i]), color="blue+")
@@ -303,7 +303,7 @@ def get_datum_for_blowup_supervised_learning(
 
         plt.clear_figure()
         plt.title(
-            f"Premature Adaptive Sandwich Diagonal Element {i} Ratio to Classical"
+            f"Premature Adjusted Sandwich Diagonal Element {i} Ratio to Classical"
         )
         plt.xlabel("Premature Update Index")
         plt.ylabel(f"Variance (Diagonal {i})")
@@ -338,7 +338,7 @@ def get_datum_for_blowup_supervised_learning(
         plt.show()
 
     # Grab predictors related to premature Phi-dot-bars
-    RL_stack_beta_derivatives_block = joint_adjusted_bread_inverse_matrix[
+    RL_stack_beta_derivatives_block = joint_adjusted_bread_matrix[
         :-theta_dim, :-theta_dim
     ]
     num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
@@ -397,8 +397,8 @@ def get_datum_for_blowup_supervised_learning(
     )
     return {
         **{
-            "joint_bread_inverse_condition_number": joint_adjusted_bread_inverse_cond,
-            "joint_bread_inverse_min_singular_value": joint_bread_inverse_min_singular_value,
+            "joint_bread_condition_number": joint_adjusted_bread_cond,
+            "joint_bread_min_singular_value": joint_bread_min_singular_value,
             "max_reward": max_reward,
             "norm_avg_estimating_function_stack": norm_avg_estimating_function_stack,
             "max_estimating_function_stack_norm": max_estimating_function_stack_norm,
@@ -455,12 +455,10 @@ def get_datum_for_blowup_supervised_learning(
         },
         **{f"theta_est_{i}": theta_est[i].item() for i in range(len(theta_est))},
         **{
-            f"premature_joint_adjusted_bread_inverse_condition_number_{i}": premature_joint_adjusted_bread_inverse_condition_numbers[
+            f"premature_joint_adjusted_bread_condition_number_{i}": premature_joint_adjusted_bread_condition_numbers[
                 i
             ]
-            for i in range(
-                len(premature_joint_adjusted_bread_inverse_condition_numbers)
-            )
+            for i in range(len(premature_joint_adjusted_bread_condition_numbers))
         },
         **{
             f"premature_adjusted_sandwich_update_{i}_diag_position_{j}": premature_adjusted_sandwich[
@@ -526,7 +524,7 @@ def calculate_sequence_of_premature_adjusted_estimates(
     action_by_decision_time_by_subject_id: dict[
         collections.abc.Hashable, dict[int, int]
     ],
-    full_joint_adjusted_bread_inverse_matrix: jnp.ndarray,
+    full_joint_adjusted_bread_matrix: jnp.ndarray,
     per_subject_estimating_function_stacks: jnp.ndarray,
     beta_dim: int,
 ) -> jnp.ndarray:
@@ -584,8 +582,8 @@ def calculate_sequence_of_premature_adjusted_estimates(
         action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
             A dictionary mapping subject IDs to their respective actions taken at each decision time.
             Only applies to in-study decision times!
-        full_joint_adjusted_bread_inverse_matrix (jnp.ndarray):
-            The full joint adjusted bread inverse matrix as a NumPy array.
+        full_joint_adjusted_bread_matrix (jnp.ndarray):
+            The full joint adjusted bread matrix as a NumPy array.
         per_subject_estimating_function_stacks (jnp.ndarray):
             A NumPy array containing all per-subject (weighted) estimating function stacks.
         beta_dim (int):
@@ -598,7 +596,7 @@ def calculate_sequence_of_premature_adjusted_estimates(
     # variance estimates pretending that each was the final policy.
     premature_adjusted_sandwiches = []
     premature_thetas = []
-    premature_joint_adjusted_bread_inverse_condition_numbers = []
+    premature_joint_adjusted_bread_condition_numbers = []
     premature_avg_inference_estimating_functions = []
     premature_classical_sandwiches = []
     logger.info(
@@ -611,12 +609,10 @@ def calculate_sequence_of_premature_adjusted_estimates(
         )
         pretend_max_policy = policy_num
 
-        truncated_joint_adjusted_bread_inverse_matrix = (
-            full_joint_adjusted_bread_inverse_matrix[
-                : (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
-                : (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
-            ]
-        )
+        truncated_joint_adjusted_bread_matrix = full_joint_adjusted_bread_matrix[
+            : (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
+            : (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
+        ]
 
         max_decision_time = analysis_df[
             analysis_df["policy_num"] == pretend_max_policy
@@ -643,7 +639,7 @@ def calculate_sequence_of_premature_adjusted_estimates(
         }
 
         truncated_inference_func_args_by_subject_id, _, _ = (
-            after_study_analysis.process_inference_func_args(
+            post_deployment_analysis.process_inference_func_args(
                 inference_func,
                 inference_func_args_theta_index,
                 truncated_analysis_df,
@@ -690,7 +686,7 @@ def calculate_sequence_of_premature_adjusted_estimates(
             premature_classical_sandwich,
             premature_avg_inference_estimating_function,
         ) = construct_premature_classical_and_adjusted_sandwiches(
-            truncated_joint_adjusted_bread_inverse_matrix,
+            truncated_joint_adjusted_bread_matrix,
             truncated_per_subject_estimating_function_stacks,
             premature_theta,
             truncated_all_post_update_betas,
@@ -720,13 +716,13 @@ def calculate_sequence_of_premature_adjusted_estimates(
         jnp.array(premature_thetas),
         jnp.array(premature_adjusted_sandwiches),
         jnp.array(premature_classical_sandwiches),
-        jnp.array(premature_joint_adjusted_bread_inverse_condition_numbers),
+        jnp.array(premature_joint_adjusted_bread_condition_numbers),
         jnp.array(premature_avg_inference_estimating_functions),
     )
 
 
 def construct_premature_classical_and_adjusted_sandwiches(
-    truncated_joint_adjusted_bread_inverse_matrix: jnp.ndarray,
+    truncated_joint_adjusted_bread_matrix: jnp.ndarray,
     per_subject_truncated_estimating_function_stacks: jnp.ndarray,
     theta: jnp.ndarray,
     all_post_update_betas: jnp.ndarray,
@@ -769,15 +765,15 @@ def construct_premature_classical_and_adjusted_sandwiches(
 
     This is done by computing and differentiating the new average inference estimating function
     with respect to the betas and theta, and stitching this together with the existing
-    adjusted bread inverse matrix portion (corresponding to the updates still under consideration)
-    to form the new premature joint adjusted bread inverse matrix.
+    adjusted bread matrix portion (corresponding to the updates still under consideration)
+    to form the new premature joint adjusted bread matrix.
 
     Args:
-        truncated_joint_adjusted_bread_inverse_matrix (jnp.ndarray):
-            A 2-D JAX NumPy array holding the existing joint adjusted bread inverse but
+        truncated_joint_adjusted_bread_matrix (jnp.ndarray):
+            A 2-D JAX NumPy array holding the existing joint adjusted bread but
             with rows corresponding to updates not under consideration and inference dropped.
             We will stitch this together with the newly computed inference portion to form
-            our "premature" joint adjusted bread inverse matrix.
+            our "premature" joint adjusted bread matrix.
         per_subject_truncated_estimating_function_stacks (jnp.ndarray):
             A 2-D JAX NumPy array holding the existing per-subject weighted estimating function
             stacks but with rows corresponding to updates not under consideration dropped.
@@ -828,14 +824,14 @@ def construct_premature_classical_and_adjusted_sandwiches(
         jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32],
         jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32]]:
         A tuple containing:
-            - The joint adjusted inverse bread matrix.
+            - The joint adjusted bread matrix.
             - The joint adjusted bread matrix.
             - The joint adjusted meat matrix.
-            - The classical inverse bread matrix.
+            - The classical bread matrix.
             - The classical bread matrix.
             - The classical meat matrix.
             - The average (weighted) inference estimating function.
-            - The joint adjusted inverse bread matrix condition number.
+            - The joint adjusted bread matrix condition number.
     """
     logger.info(
         "Differentiating average weighted inference estimating function stack and collecting auxiliary values."
@@ -847,12 +843,12 @@ def construct_premature_classical_and_adjusted_sandwiches(
         per_subject_inference_estimating_functions,
         avg_inference_estimating_function,
         per_subject_classical_meat_contributions,
-        per_subject_classical_bread_inverse_contributions,
+        per_subject_classical_bread_contributions,
     ) = jax.jacrev(get_weighted_inference_estimating_functions_only, has_aux=True)(
         # While JAX can technically differentiate with respect to a list of JAX arrays,
         # it is more efficient to flatten them into a single array. This is done
         # here to improve performance. We can simply unflatten them inside the function.
-        after_study_analysis.flatten_params(all_post_update_betas, theta),
+        post_deployment_analysis.flatten_params(all_post_update_betas, theta),
        all_post_update_betas.shape[1],
         theta.shape[0],
         subject_ids,
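
For readers unfamiliar with the `jax.jacrev(..., has_aux=True)` pattern used here: the differentiated function returns a `(primal, aux)` pair, the Jacobian is taken of the primal output only, and the auxiliary values pass through untouched. A minimal standalone sketch (the function below is a stand-in, not the package's estimating function):

```python
import jax
import jax.numpy as jnp

def estimating_fn(params):
    primal = jnp.stack([params.sum(), (params ** 2).sum()])  # differentiated output
    aux = {"per_subject_values": params ** 2}                # passed through, not differentiated
    return primal, aux

jacobian, aux = jax.jacrev(estimating_fn, has_aux=True)(jnp.arange(3.0))
print(jacobian.shape)  # (2, 3): rows are primal outputs, columns are params
```
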
@@ -871,13 +867,13 @@ def construct_premature_classical_and_adjusted_sandwiches(
         action_by_decision_time_by_subject_id,
     )
 
-    joint_adjusted_bread_inverse_matrix = jnp.block(
+    joint_adjusted_bread_matrix = jnp.block(
         [
             [
-                truncated_joint_adjusted_bread_inverse_matrix,
+                truncated_joint_adjusted_bread_matrix,
                 np.zeros(
                     (
-                        truncated_joint_adjusted_bread_inverse_matrix.shape[0],
+                        truncated_joint_adjusted_bread_matrix.shape[0],
                         new_inference_block_row.shape[0],
                     )
                 ),
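
This `jnp.block` call stitches the existing bread matrix and the new inference block row into a larger block lower triangular matrix, zero-padding the top-right corner. A minimal sketch with stand-in shapes:

```python
import numpy as np
import jax.numpy as jnp

truncated = jnp.eye(4)                      # stand-in truncated bread matrix
new_inference_block_row = jnp.ones((2, 6))  # stand-in derivatives w.r.t. (betas, theta)

joint = jnp.block(
    [
        [truncated, np.zeros((truncated.shape[0], new_inference_block_row.shape[0]))],
        [new_inference_block_row],
    ]
)
print(joint.shape)  # (6, 6); the top-right 4x2 corner is all zeros
```
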
@@ -902,34 +898,30 @@ def construct_premature_classical_and_adjusted_sandwiches(
         per_subject_adjusted_meat_contributions, axis=0
     )
 
-    classical_bread_inverse_matrix = jnp.mean(
-        per_subject_classical_bread_inverse_contributions, axis=0
-    )
+    classical_bread_matrix = jnp.mean(per_subject_classical_bread_contributions, axis=0)
     classical_meat_matrix = jnp.mean(per_subject_classical_meat_contributions, axis=0)
 
     num_subjects = subject_ids.shape[0]
     joint_adjusted_sandwich = (
-        after_study_analysis.form_sandwich_from_bread_inverse_and_meat(
-            joint_adjusted_bread_inverse_matrix,
+        post_deployment_analysis.form_sandwich_from_bread_and_meat(
+            joint_adjusted_bread_matrix,
             joint_adjusted_meat_matrix,
             num_subjects,
-            method="bread_inverse_T_qr",
+            method="bread_T_qr",
         )
     )
     adjusted_sandwich = joint_adjusted_sandwich[-theta.shape[0] :, -theta.shape[0] :]
 
-    classical_bread_inverse_matrix = jnp.mean(
-        per_subject_classical_bread_inverse_contributions, axis=0
-    )
-    classical_sandwich = after_study_analysis.form_sandwich_from_bread_inverse_and_meat(
-        classical_bread_inverse_matrix,
+    classical_bread_matrix = jnp.mean(per_subject_classical_bread_contributions, axis=0)
+    classical_sandwich = post_deployment_analysis.form_sandwich_from_bread_and_meat(
+        classical_bread_matrix,
         classical_meat_matrix,
         num_subjects,
-        method="bread_inverse_T_qr",
+        method="bread_T_qr",
     )
 
-    # Stack the joint adjusted inverse bread pieces together horizontally and return the auxiliary
-    # values too. The joint adjusted bread inverse should always be block lower triangular.
+    # Stack the joint adjusted bread pieces together horizontally and return the auxiliary
+    # values too. The joint adjusted bread should always be block lower triangular.
     return (
         adjusted_sandwich,
         classical_sandwich,
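
`form_sandwich_from_bread_and_meat` itself is not shown in this diff. For orientation, a sandwich variance estimator combines a bread matrix B and a meat matrix M as V = B^{-1} M B^{-T} / n; the `method="bread_T_qr"` name suggests the package factors B^T (e.g., via QR) rather than inverting B explicitly. A hypothetical minimal sketch using plain linear solves instead:

```python
import jax.numpy as jnp

def form_sandwich_sketch(bread, meat, num_subjects):
    # V = B^{-1} M B^{-T} / n, without forming B^{-1} explicitly.
    bread_inv_meat = jnp.linalg.solve(bread, meat)          # B^{-1} M
    sandwich = jnp.linalg.solve(bread, bread_inv_meat.T).T  # B^{-1} M B^{-T}
    return sandwich / num_subjects

bread = jnp.array([[2.0, 0.0], [1.0, 3.0]])
meat = jnp.array([[1.0, 0.2], [0.2, 1.0]])  # symmetric, as a meat matrix should be
variance = form_sandwich_sketch(bread, meat, num_subjects=50)
```
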
@@ -1036,7 +1028,7 @@ def get_weighted_inference_estimating_functions_only(
         else inference_func
     )
 
-    betas, theta = after_study_analysis.unflatten_params(
+    betas, theta = post_deployment_analysis.unflatten_params(
         flattened_betas_and_theta,
         beta_dim,
         theta_dim,
@@ -1052,7 +1044,7 @@ def get_weighted_inference_estimating_functions_only(
     (
         threaded_action_prob_func_args_by_decision_time_by_subject_id,
         action_prob_func_args_by_decision_time_by_subject_id,
-    ) = after_study_analysis.thread_action_prob_func_args(
+    ) = post_deployment_analysis.thread_action_prob_func_args(
         action_prob_func_args_by_subject_id_by_decision_time,
         policy_num_by_decision_time_by_subject_id,
         initial_policy_num,
@@ -1069,7 +1061,7 @@ def get_weighted_inference_estimating_functions_only(
         "function args for all subjects"
     )
     threaded_inference_func_args_by_subject_id = (
-        after_study_analysis.thread_inference_func_args(
+        post_deployment_analysis.thread_inference_func_args(
             inference_func_args_by_subject_id,
             inference_func_args_theta_index,
             theta,
@@ -1205,9 +1197,11 @@ def single_subject_weighted_inference_estimating_function(
 
     # 1. Get the first time after the first update for convenience.
     # This is used to form the Radon-Nikodym weights for the right times.
-    _, first_time_after_first_update = after_study_analysis.get_min_time_by_policy_num(
-        policy_num_by_decision_time,
-        beta_index_by_policy_num,
+    _, first_time_after_first_update = (
+        post_deployment_analysis.get_min_time_by_policy_num(
+            policy_num_by_decision_time,
+            beta_index_by_policy_num,
+        )
     )
 
     # 2. Get the start and end times for this subject.
@@ -1268,7 +1262,7 @@ def single_subject_weighted_inference_estimating_function(
     # value, but impervious to differentiation with respect to all_post_update_betas. The
     # args, on the other hand, are a function of all_post_update_betas.
     in_study_weights = jax.vmap(
-        fun=after_study_analysis.get_radon_nikodym_weight,
+        fun=post_deployment_analysis.get_radon_nikodym_weight,
         in_axes=[0, None, None, 0] + batch_axes,
         out_axes=0,
     )(
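
The `jax.vmap` call in this hunk batches the weight function over per-decision-time arrays (`in_axes` entry 0) while broadcasting shared arguments (`in_axes` entry None). A standalone sketch of the pattern, with a hypothetical stand-in for the weight function:

```python
import jax
import jax.numpy as jnp

def weight_fn(behavior_prob, target_prob, power, action):
    # Hypothetical stand-in for get_radon_nikodym_weight: ratio of target to
    # behavior probability of the observed binary action.
    p = jnp.where(action == 1, behavior_prob, 1.0 - behavior_prob)
    q = jnp.where(action == 1, target_prob, 1.0 - target_prob)
    return (q / p) ** power

behavior_probs = jnp.array([0.3, 0.6, 0.8])  # batched over decision times
actions = jnp.array([1, 0, 1])               # batched over decision times
weights = jax.vmap(weight_fn, in_axes=[0, None, None, 0], out_axes=0)(
    behavior_probs, 0.5, 1.0, actions
)
```
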
@@ -11,8 +11,6 @@ import numpy as np
 import jax.numpy as jnp
 import pandas as pd
 
-from .constants import InverseStabilizationMethods
-
 logger = logging.getLogger(__name__)
 logging.basicConfig(
     format="%(asctime)s,%(msecs)03d %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s",
@@ -27,11 +25,7 @@ def conditional_x_or_one_minus_x(x, condition):
 
 def invert_matrix_and_check_conditioning(
     matrix: np.ndarray,
-    inverse_stabilization_method: str = InverseStabilizationMethods.NONE,
     condition_num_threshold: float = 10**4,
-    ridge_median_singular_value_fraction: str = 0.01,
-    beta_dim: int = None,
-    theta_dim: int = None,
 ):
     """
     Check a matrix's condition number and invert it. If the condition number is
@@ -39,139 +33,15 @@ def invert_matrix_and_check_conditioning(
     Parameters
     """
     inverse = None
-    pre_inversion_condition_number = np.linalg.cond(matrix)
-    if pre_inversion_condition_number > condition_num_threshold:
+    condition_number = np.linalg.cond(matrix)
+    if condition_number > condition_num_threshold:
         logger.warning(
-            "You are inverting a matrix with a large condition number: %s",
-            pre_inversion_condition_number,
+            "You are inverting a matrix with a potentially large condition number: %s",
+            condition_number,
         )
-    if (
-        inverse_stabilization_method
-        == InverseStabilizationMethods.TRIM_SMALL_SINGULAR_VALUES
-    ):
-        logger.info("Trimming small singular values to improve conditioning.")
-        u, s, vT = np.linalg.svd(matrix, full_matrices=False)
-        logger.info(
-            " Sorted singular values: %s",
-            s,
-        )
-        sing_values_above_threshold_cond = s > s.max() / condition_num_threshold
-        if not np.any(sing_values_above_threshold_cond):
-            raise RuntimeError(
-                f"All singular values are below the threshold of {s.max() / condition_num_threshold}. Singular value trimming will not work.",
-            )
-        trimmed_pseudoinverse = (
-            vT.T[:, sing_values_above_threshold_cond]
-            / s[sing_values_above_threshold_cond]
-        ) @ u[:, sing_values_above_threshold_cond].T
-        inverse = trimmed_pseudoinverse
-        pre_inversion_condition_number = (
-            s[sing_values_above_threshold_cond].max()
-            / s[sing_values_above_threshold_cond].min()
-        )
-
-        logger.info(
-            "Kept %s out of %s singular values. Condition number of resulting lower-rank-approximation before inversion: %s",
-            sum(sing_values_above_threshold_cond),
-            len(s),
-            pre_inversion_condition_number,
-        )
-    elif (
-        inverse_stabilization_method
-        == InverseStabilizationMethods.ADD_RIDGE_FIXED_CONDITION_NUMBER
-    ):
-        logger.info("Adding ridge/Tikhonov regularization to improve conditioning.")
-        _, singular_values, _ = np.linalg.svd(matrix, full_matrices=False)
-        logger.info(
-            "Using fixed condition number threshold of %s to determine lambda.",
-            condition_num_threshold,
-        )
-        lambda_ = (
-            singular_values.max() / condition_num_threshold - singular_values.min()
-        )
-        logger.info("Lambda for ridge regularization: %s", lambda_)
-        new_matrix = matrix + lambda_ * np.eye(matrix.shape[0])
-        pre_inversion_condition_number = np.linalg.cond(new_matrix)
-        logger.info(
-            "Condition number of matrix after ridge regularization: %s",
-            pre_inversion_condition_number,
-        )
-        inverse = np.linalg.solve(new_matrix, np.eye(matrix.shape[0]))
-    elif (
-        inverse_stabilization_method
-        == InverseStabilizationMethods.ADD_RIDGE_MEDIAN_SINGULAR_VALUE_FRACTION
-    ):
-        logger.info("Adding ridge/Tikhonov regularization to improve conditioning.")
-        _, singular_values, _ = np.linalg.svd(matrix, full_matrices=False)
-        logger.info(
-            "Using median singular value times %s as lambda.",
-            ridge_median_singular_value_fraction,
-        )
-        lambda_ = ridge_median_singular_value_fraction * np.median(singular_values)
-        logger.info("Lambda for ridge regularization: %s", lambda_)
-        new_matrix = matrix + lambda_ * np.eye(matrix.shape[0])
-        pre_inversion_condition_number = np.linalg.cond(new_matrix)
-        logger.info(
-            "Condition number of matrix after ridge regularization: %s",
-            pre_inversion_condition_number,
-        )
-        inverse = np.linalg.solve(new_matrix, np.eye(matrix.shape[0]))
-    elif (
-        inverse_stabilization_method
-        == InverseStabilizationMethods.INVERSE_BREAD_STRUCTURE_AWARE_INVERSION
-    ):
-        if not beta_dim or not theta_dim:
-            raise ValueError(
-                "When using structure-aware inversion, beta_dim and theta_dim must be provided."
-            )
-        logger.info(
-            "Using inverse bread's block lower triangular structure to invert only diagonal blocks."
-        )
-        pre_inversion_condition_number = np.linalg.cond(matrix)
-        inverse = invert_inverse_bread_matrix(
-            matrix,
-            beta_dim,
-            theta_dim,
-            InverseStabilizationMethods.ADD_RIDGE_FIXED_CONDITION_NUMBER,
-        )
-    elif (
-        inverse_stabilization_method
-        == InverseStabilizationMethods.ZERO_OUT_SMALL_OFF_DIAGONALS
-    ):
-        if not beta_dim or not theta_dim:
-            raise ValueError(
-                "When zeroing out small off diagonals, beta_dim and theta_dim must be provided."
-            )
-        logger.info(
-            "Zeroing out small off-diagonal blocks to improve conditioning."
-        )
-        zeroed_matrix = zero_small_off_diagonal_blocks(
-            matrix,
-            ([beta_dim] * (matrix.shape[0] // beta_dim)) + [theta_dim],
-        )
-        pre_inversion_condition_number = np.linalg.cond(zeroed_matrix)
-        logger.info(
-            "Condition number of matrix after zeroing out small off-diagonal blocks: %s",
-            pre_inversion_condition_number,
-        )
-        inverse = np.linalg.solve(zeroed_matrix, np.eye(zeroed_matrix.shape[0]))
-    elif (
-        inverse_stabilization_method
-        == InverseStabilizationMethods.ALL_METHODS_COMPETITION
-    ):
-        # TODO: Choose right metric for competition... identity diff might not be it.
-        raise NotImplementedError(
-            "All methods competition is not implemented yet. Please choose a specific method."
-        )
-    elif inverse_stabilization_method == InverseStabilizationMethods.NONE:
-        logger.info("No inverse stabilization method applied. Inverting directly.")
-    else:
-        raise ValueError(
-            f"Unknown inverse stabilization method: {inverse_stabilization_method}"
-        )
     if inverse is None:
         inverse = np.linalg.solve(matrix, np.eye(matrix.shape[0]))
-    return inverse, pre_inversion_condition_number
+    return inverse, condition_number
 
 
 def zero_small_off_diagonal_blocks(
@@ -183,7 +53,7 @@ def zero_small_off_diagonal_blocks(
     Zero off-diagonal blocks whose Frobenius norm is < frobenius_norm_threshold_fraction x
     Frobenius norm of the diagonal block in the same ROW. One could compare to
     the same column or both the row and column, but we choose row here since
-    rows correspond to a single RL update or inference step in the adaptive bread
+    rows correspond to a single RL update or inference step in the bread
     inverse matrices this method is designed for.
 
     Args:
@@ -237,18 +107,17 @@ def zero_small_off_diagonal_blocks(
     return J_trim
 
 
-def invert_inverse_bread_matrix(
-    inverse_bread,
+def invert_bread_matrix(
+    bread,
     beta_dim,
     theta_dim,
-    diag_inverse_stabilization_method=InverseStabilizationMethods.TRIM_SMALL_SINGULAR_VALUES,
 ):
     """
-    Invert the inverse bread matrix to get the bread matrix. This is a special
+    Invert the bread matrix to get the inverse bread matrix. This is a special
     function in order to take advantage of the block lower triangular structure.
 
     The procedure is as follows:
-    1. Initialize the inverse matrix B = A^{-1} as a block lower triangular matrix
+    1. Initialize the matrix B = A^{-1} as a block lower triangular matrix
        with the same block structure as A.
 
     2. Compute the diagonal blocks B_{ii}:
@@ -260,24 +129,23 @@ def invert_inverse_bread_matrix(
        B_{ij} = -A_{ii}^{-1} * sum(A_{ik} * B_{kj} for k in range(j, i))
     """
     blocks = []
-    num_beta_block_rows = (inverse_bread.shape[0] - theta_dim) // beta_dim
+    num_beta_block_rows = (bread.shape[0] - theta_dim) // beta_dim
 
     # Create upper rows of block of bread (just the beta portion)
     for i in range(0, num_beta_block_rows):
         beta_block_row = []
         beta_diag_inverse = invert_matrix_and_check_conditioning(
-            inverse_bread[
+            bread[
                 beta_dim * i : beta_dim * (i + 1),
                 beta_dim * i : beta_dim * (i + 1),
             ],
-            diag_inverse_stabilization_method,
         )[0]
         for j in range(0, num_beta_block_rows):
             if i > j:
                 beta_block_row.append(
                     -beta_diag_inverse
                     @ sum(
-                        inverse_bread[
+                        bread[
                             beta_dim * i : beta_dim * (i + 1),
                             beta_dim * k : beta_dim * (k + 1),
                         ]
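
The block lower triangular inversion procedure the docstring describes can be sanity-checked in isolation. A hypothetical standalone sketch (not the package's implementation, which reuses `invert_matrix_and_check_conditioning` for the diagonal blocks):

```python
import numpy as np

def invert_block_lower_triangular(A, block_sizes):
    # B = A^{-1} for block lower triangular A:
    #   B_ii = A_ii^{-1}
    #   B_ij = -A_ii^{-1} @ sum(A_ik @ B_kj for k in range(j, i))
    bounds = np.concatenate([[0], np.cumsum(block_sizes)])
    blk = lambda M, i, j: M[bounds[i] : bounds[i + 1], bounds[j] : bounds[j + 1]]
    B = np.zeros_like(A)
    for i in range(len(block_sizes)):
        A_ii_inv = np.linalg.inv(blk(A, i, i))
        blk(B, i, i)[:] = A_ii_inv
        for j in range(i):
            acc = sum(blk(A, i, k) @ blk(B, k, j) for k in range(j, i))
            blk(B, i, j)[:] = -A_ii_inv @ acc
    return B

A = np.tril(np.random.default_rng(1).normal(size=(7, 7))) + 5 * np.eye(7)
B = invert_block_lower_triangular(A, [2, 2, 3])
print(np.allclose(B @ A, np.eye(7)))  # True
```
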
@@ -299,17 +167,16 @@ def invert_inverse_bread_matrix(
     # Create the bottom block row of bread (the theta portion)
     theta_block_row = []
     theta_diag_inverse = invert_matrix_and_check_conditioning(
-        inverse_bread[
+        bread[
             -theta_dim:,
             -theta_dim:,
         ],
-        diag_inverse_stabilization_method,
     )[0]
     for k in range(0, num_beta_block_rows):
         theta_block_row.append(
             -theta_diag_inverse
             @ sum(
-                inverse_bread[
+                bread[
                     -theta_dim:,
                     beta_dim * h : beta_dim * (h + 1),
                 ]