lifejacket-1.0.0-py3-none-any.whl → lifejacket-1.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/calculate_derivatives.py +0 -2
- lifejacket/constants.py +4 -16
- lifejacket/deployment_conditioning_monitor.py +19 -12
- lifejacket/form_adjusted_meat_adjustments_directly.py +25 -27
- lifejacket/get_datum_for_blowup_supervised_learning.py +71 -77
- lifejacket/helper_functions.py +15 -148
- lifejacket/input_checks.py +49 -50
- lifejacket/{after_study_analysis.py → post_deployment_analysis.py} +377 -144
- lifejacket/small_sample_corrections.py +11 -13
- {lifejacket-1.0.0.dist-info → lifejacket-1.1.0.dist-info}/METADATA +1 -1
- lifejacket-1.1.0.dist-info/RECORD +17 -0
- {lifejacket-1.0.0.dist-info → lifejacket-1.1.0.dist-info}/WHEEL +1 -1
- lifejacket-1.1.0.dist-info/entry_points.txt +2 -0
- lifejacket-1.0.0.dist-info/RECORD +0 -17
- lifejacket-1.0.0.dist-info/entry_points.txt +0 -2
- {lifejacket-1.0.0.dist-info → lifejacket-1.1.0.dist-info}/top_level.txt +0 -0
lifejacket/get_datum_for_blowup_supervised_learning.py
CHANGED

@@ -12,7 +12,7 @@ import jax
 from jax import numpy as jnp
 import pandas as pd

-from . import after_study_analysis
+from . import post_deployment_analysis
 from .constants import FunctionTypes
 from .vmap_helpers import stack_batched_arg_lists_into_tensors

@@ -25,8 +25,8 @@ logging.basicConfig(


 def get_datum_for_blowup_supervised_learning(
-    …
-    …
+    joint_adjusted_bread_matrix,
+    joint_adjusted_bread_cond,
     avg_estimating_function_stack,
     per_subject_estimating_function_stacks,
     all_post_update_betas,
@@ -63,10 +63,10 @@ def get_datum_for_blowup_supervised_learning(
     A few plots are produced along the way to help visualize the data.

     Args:
-        …
-            The joint adjusted bread …
-        …
-            The condition number of the joint adjusted bread …
+        joint_adjusted_bread_matrix (jnp.ndarray):
+            The joint adjusted bread matrix.
+        joint_adjusted_bread_cond (float):
+            The condition number of the joint adjusted bread matrix.
         avg_estimating_function_stack (jnp.ndarray):
             The average estimating function stack across subjects.
         per_subject_estimating_function_stacks (jnp.ndarray):
@@ -125,7 +125,7 @@ def get_datum_for_blowup_supervised_learning(
         dict[str, Any]: A dictionary containing features and the label for supervised learning.
     """
     num_diagonal_blocks = (
-        (…
+        (joint_adjusted_bread_matrix.shape[0] - theta_dim) // beta_dim
     ) + 1
     diagonal_block_sizes = ([beta_dim] * (num_diagonal_blocks - 1)) + [theta_dim]

@@ -144,7 +144,7 @@ def get_datum_for_blowup_supervised_learning(
             row_slice = slice(block_bounds[i], block_bounds[i + 1])
             col_slice = slice(block_bounds[j], block_bounds[j + 1])
             block_norm = np.linalg.norm(
-                …
+                joint_adjusted_bread_matrix[row_slice, col_slice],
                 ord="fro",
             )
             # We will sum here and take the square root later
@@ -155,9 +155,9 @@ def get_datum_for_blowup_supervised_learning(
         # handle diagonal blocks
         sl = slice(block_bounds[i], block_bounds[i + 1])
         diag_norms.append(
-            np.linalg.norm(…
+            np.linalg.norm(joint_adjusted_bread_matrix[sl, sl], ord="fro")
         )
-        diag_conds.append(np.linalg.cond(…
+        diag_conds.append(np.linalg.cond(joint_adjusted_bread_matrix[sl, sl]))

         # Sqrt each row/col sum to truly get row/column norms.
         # Perhaps not necessary for learning, but more natural
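The hunks above compute per-block Frobenius norms and per-diagonal-block condition numbers of the joint adjusted bread matrix as learning features. A standalone sketch of that block walk, in plain NumPy with hypothetical names:

import numpy as np

def block_norm_features(matrix, diagonal_block_sizes):
    # Bounds of each block along the diagonal, e.g. sizes [2, 2, 3] -> [0, 2, 4, 7].
    block_bounds = np.concatenate([[0], np.cumsum(diagonal_block_sizes)])
    diag_norms, diag_conds = [], []
    for i in range(len(diagonal_block_sizes)):
        sl = slice(block_bounds[i], block_bounds[i + 1])
        diag_norms.append(np.linalg.norm(matrix[sl, sl], ord="fro"))
        diag_conds.append(np.linalg.cond(matrix[sl, sl]))
    return diag_norms, diag_conds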
@@ -214,8 +214,8 @@ def get_datum_for_blowup_supervised_learning(
     reward_means_by_t = grouped_reward.mean().values
     reward_stds_by_t = grouped_reward.std().values

-    …
-    …
+    joint_bread_min_singular_value = np.linalg.svd(
+        joint_adjusted_bread_matrix, compute_uv=False
     )[-1]

     max_reward = analysis_df.loc[in_study_mask][reward_col_name].max()
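For the new joint_bread_min_singular_value feature: np.linalg.svd with compute_uv=False returns singular values in descending order, so indexing with [-1] picks out the smallest one, and the ratio of largest to smallest is the condition number tracked elsewhere. A quick check:

import numpy as np

a = np.array([[3.0, 0.0], [4.0, 5.0]])
s = np.linalg.svd(a, compute_uv=False)  # descending: s[0] >= s[-1]
min_singular_value = s[-1]
# Default np.linalg.cond uses the 2-norm, i.e. s_max / s_min.
assert np.isclose(s[0] / s[-1], np.linalg.cond(a))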
@@ -227,7 +227,7 @@ def get_datum_for_blowup_supervised_learning(
         premature_thetas,
         premature_adjusted_sandwiches,
         premature_classical_sandwiches,
-        …
+        premature_joint_adjusted_bread_condition_numbers,
         premature_avg_inference_estimating_functions,
     ) = calculate_sequence_of_premature_adjusted_estimates(
         analysis_df,
@@ -250,7 +250,7 @@ def get_datum_for_blowup_supervised_learning(
         inference_action_prob_decision_times_by_subject_id,
         action_prob_func_args,
         action_by_decision_time_by_subject_id,
-        …
+        joint_adjusted_bread_matrix,
         per_subject_estimating_function_stacks,
         beta_dim,
     )
@@ -261,23 +261,23 @@ def get_datum_for_blowup_supervised_learning(
         atol=1e-3,
     )

-    # Plot premature joint adjusted bread …
+    # Plot premature joint adjusted bread log condition numbers
     plt.clear_figure()
-    plt.title("Premature Joint …
+    plt.title("Premature Joint Adjusted Bread Inverse Log Condition Numbers")
     plt.xlabel("Premature Update Index")
     plt.ylabel("Log Condition Number")
     plt.scatter(
-        np.log(…
+        np.log(premature_joint_adjusted_bread_condition_numbers),
         color="blue+",
     )
     plt.grid(True)
     plt.xticks(
         range(
             0,
-            len(…
+            len(premature_joint_adjusted_bread_condition_numbers),
             max(
                 1,
-                len(…
+                len(premature_joint_adjusted_bread_condition_numbers) // 10,
             ),
         )
     )
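The plotting calls here (clear_figure, the "blue+" marker color, a figure-free show) match the terminal plotting library plotext rather than matplotlib; that is an inference from the API, not stated in the diff. A minimal plotext sketch of the same condition-number plot, with made-up values:

import numpy as np
import plotext as plt

cond_numbers = np.array([1e2, 1e3, 5e4, 2e6])  # illustrative values only
plt.clear_figure()
plt.title("Premature Joint Adjusted Bread Log Condition Numbers")
plt.xlabel("Premature Update Index")
plt.ylabel("Log Condition Number")
plt.scatter(np.log(cond_numbers), color="blue+")
plt.grid(True)
plt.show()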
@@ -287,7 +287,7 @@ def get_datum_for_blowup_supervised_learning(
     num_diag = premature_adjusted_sandwiches.shape[-1]
     for i in range(num_diag):
         plt.clear_figure()
-        plt.title(f"Premature …
+        plt.title(f"Premature Adjusted Sandwich Diagonal Element {i}")
         plt.xlabel("Premature Update Index")
         plt.ylabel(f"Variance (Diagonal {i})")
         plt.scatter(np.array(premature_adjusted_sandwiches[:, i, i]), color="blue+")
@@ -303,7 +303,7 @@ def get_datum_for_blowup_supervised_learning(

         plt.clear_figure()
         plt.title(
-            f"Premature …
+            f"Premature Adjusted Sandwich Diagonal Element {i} Ratio to Classical"
         )
         plt.xlabel("Premature Update Index")
         plt.ylabel(f"Variance (Diagonal {i})")
@@ -338,7 +338,7 @@ def get_datum_for_blowup_supervised_learning(
     plt.show()

     # Grab predictors related to premature Phi-dot-bars
-    RL_stack_beta_derivatives_block = …
+    RL_stack_beta_derivatives_block = joint_adjusted_bread_matrix[
         :-theta_dim, :-theta_dim
     ]
     num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
@@ -397,8 +397,8 @@ def get_datum_for_blowup_supervised_learning(
     )
     return {
         **{
-            "…
-            "…
+            "joint_bread_condition_number": joint_adjusted_bread_cond,
+            "joint_bread_min_singular_value": joint_bread_min_singular_value,
             "max_reward": max_reward,
             "norm_avg_estimating_function_stack": norm_avg_estimating_function_stack,
             "max_estimating_function_stack_norm": max_estimating_function_stack_norm,
@@ -455,12 +455,10 @@ def get_datum_for_blowup_supervised_learning(
         },
         **{f"theta_est_{i}": theta_est[i].item() for i in range(len(theta_est))},
         **{
-            f"…
+            f"premature_joint_adjusted_bread_condition_number_{i}": premature_joint_adjusted_bread_condition_numbers[
                 i
             ]
-            for i in range(
-                len(premature_joint_adjusted_bread_inverse_condition_numbers)
-            )
+            for i in range(len(premature_joint_adjusted_bread_condition_numbers))
         },
         **{
             f"premature_adjusted_sandwich_update_{i}_diag_position_{j}": premature_adjusted_sandwich[
@@ -526,7 +524,7 @@ def calculate_sequence_of_premature_adjusted_estimates(
     action_by_decision_time_by_subject_id: dict[
         collections.abc.Hashable, dict[int, int]
     ],
-    …
+    full_joint_adjusted_bread_matrix: jnp.ndarray,
     per_subject_estimating_function_stacks: jnp.ndarray,
     beta_dim: int,
 ) -> jnp.ndarray:
@@ -584,8 +582,8 @@ def calculate_sequence_of_premature_adjusted_estimates(
         action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
             A dictionary mapping subject IDs to their respective actions taken at each decision time.
             Only applies to in-study decision times!
-        …
-            The full joint adjusted bread …
+        full_joint_adjusted_bread_matrix (jnp.ndarray):
+            The full joint adjusted bread matrix as a NumPy array.
         per_subject_estimating_function_stacks (jnp.ndarray):
             A NumPy array containing all per-subject (weighted) estimating function stacks.
         beta_dim (int):
@@ -598,7 +596,7 @@ def calculate_sequence_of_premature_adjusted_estimates(
     # variance estimates pretending that each was the final policy.
     premature_adjusted_sandwiches = []
     premature_thetas = []
-    …
+    premature_joint_adjusted_bread_condition_numbers = []
     premature_avg_inference_estimating_functions = []
     premature_classical_sandwiches = []
     logger.info(
@@ -611,12 +609,10 @@ def calculate_sequence_of_premature_adjusted_estimates(
         )
         pretend_max_policy = policy_num

-        …
-        …
-        …
-        …
-        ]
-        )
+        truncated_joint_adjusted_bread_matrix = full_joint_adjusted_bread_matrix[
+            : (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
+            : (beta_index_by_policy_num[pretend_max_policy] + 1) * beta_dim,
+        ]

         max_decision_time = analysis_df[
             analysis_df["policy_num"] == pretend_max_policy
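The added truncation keeps the leading square block of the full bread matrix: each of the first k updates owns a beta_dim-sized block row and column, so slicing to (beta_index + 1) * beta_dim pretends the later updates never happened. A toy version with hypothetical sizes:

import jax.numpy as jnp

beta_dim, num_updates, theta_dim = 2, 3, 4
full = jnp.arange(100.0).reshape(10, 10)  # 3 beta blocks of 2 plus a theta block of 4

# Pretend only the first two updates happened: keep the leading 2 * beta_dim rows/cols.
pretend_k = 2
truncated = full[: pretend_k * beta_dim, : pretend_k * beta_dim]
assert truncated.shape == (4, 4)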
@@ -643,7 +639,7 @@ def calculate_sequence_of_premature_adjusted_estimates(
         }

         truncated_inference_func_args_by_subject_id, _, _ = (
-            …
+            post_deployment_analysis.process_inference_func_args(
                 inference_func,
                 inference_func_args_theta_index,
                 truncated_analysis_df,
@@ -690,7 +686,7 @@ def calculate_sequence_of_premature_adjusted_estimates(
             premature_classical_sandwich,
             premature_avg_inference_estimating_function,
         ) = construct_premature_classical_and_adjusted_sandwiches(
-            …
+            truncated_joint_adjusted_bread_matrix,
             truncated_per_subject_estimating_function_stacks,
             premature_theta,
             truncated_all_post_update_betas,
@@ -720,13 +716,13 @@ def calculate_sequence_of_premature_adjusted_estimates(
         jnp.array(premature_thetas),
         jnp.array(premature_adjusted_sandwiches),
         jnp.array(premature_classical_sandwiches),
-        jnp.array(…
+        jnp.array(premature_joint_adjusted_bread_condition_numbers),
         jnp.array(premature_avg_inference_estimating_functions),
     )


 def construct_premature_classical_and_adjusted_sandwiches(
-    …
+    truncated_joint_adjusted_bread_matrix: jnp.ndarray,
     per_subject_truncated_estimating_function_stacks: jnp.ndarray,
     theta: jnp.ndarray,
     all_post_update_betas: jnp.ndarray,
@@ -769,15 +765,15 @@ def construct_premature_classical_and_adjusted_sandwiches(

     This is done by computing and differentiating the new average inference estimating function
     with respect to the betas and theta, and stitching this together with the existing
-    adjusted bread …
-    to form the new premature joint adjusted bread …
+    adjusted bread matrix portion (corresponding to the updates still under consideration)
+    to form the new premature joint adjusted bread matrix.

     Args:
-        …
-            A 2-D JAX NumPy array holding the existing joint adjusted bread …
+        truncated_joint_adjusted_bread_matrix (jnp.ndarray):
+            A 2-D JAX NumPy array holding the existing joint adjusted bread but
             with rows corresponding to updates not under consideration and inference dropped.
             We will stitch this together with the newly computed inference portion to form
-            our "premature" joint adjusted bread …
+            our "premature" joint adjusted bread matrix.
         per_subject_truncated_estimating_function_stacks (jnp.ndarray):
             A 2-D JAX NumPy array holding the existing per-subject weighted estimating function
             stacks but with rows corresponding to updates not under consideration dropped.
@@ -828,14 +824,14 @@ def construct_premature_classical_and_adjusted_sandwiches(
         jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32],
         jnp.ndarray[jnp.float32], jnp.ndarray[jnp.float32]]:
             A tuple containing:
-            - The joint adjusted …
+            - The joint adjusted bread matrix.
             - The joint adjusted bread matrix.
             - The joint adjusted meat matrix.
-            - The classical …
+            - The classical bread matrix.
             - The classical bread matrix.
             - The classical meat matrix.
             - The average (weighted) inference estimating function.
-            - The joint adjusted …
+            - The joint adjusted bread matrix condition number.
     """
     logger.info(
         "Differentiating average weighted inference estimating function stack and collecting auxiliary values."
@@ -847,12 +843,12 @@ def construct_premature_classical_and_adjusted_sandwiches(
         per_subject_inference_estimating_functions,
         avg_inference_estimating_function,
         per_subject_classical_meat_contributions,
-        …
+        per_subject_classical_bread_contributions,
     ) = jax.jacrev(get_weighted_inference_estimating_functions_only, has_aux=True)(
         # While JAX can technically differentiate with respect to a list of JAX arrays,
         # it is more efficient to flatten them into a single array. This is done
         # here to improve performance. We can simply unflatten them inside the function.
-        …
+        post_deployment_analysis.flatten_params(all_post_update_betas, theta),
         all_post_update_betas.shape[1],
         theta.shape[0],
         subject_ids,
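The call pattern above, jax.jacrev with has_aux=True applied to one flattened parameter vector, is the standard JAX recipe the flatten/unflatten helpers support: the Jacobian is taken with respect to a single flat array while extra outputs ride along undifferentiated. A self-contained sketch with hypothetical stand-ins for the lifejacket helpers:

import jax
import jax.numpy as jnp

def flatten_params(betas, theta):
    # Hypothetical analogue of post_deployment_analysis.flatten_params.
    return jnp.concatenate([betas.reshape(-1), theta])

def unflatten_params(flat, beta_dim, theta_dim):
    betas = flat[:-theta_dim].reshape(-1, beta_dim)
    return betas, flat[-theta_dim:]

def estimating_fn(flat, beta_dim, theta_dim):
    betas, theta = unflatten_params(flat, beta_dim, theta_dim)
    value = jnp.sum(betas**2) + jnp.sum(theta**3)
    aux = {"theta_norm": jnp.linalg.norm(theta)}  # auxiliary output, not differentiated
    return jnp.atleast_1d(value), aux

betas = jnp.ones((3, 2))
theta = jnp.arange(4.0)
flat = flatten_params(betas, theta)
jac, aux = jax.jacrev(estimating_fn, has_aux=True)(flat, 2, 4)
# jac has shape (1, flat.size): one row per output, one column per parameter.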
@@ -871,13 +867,13 @@ def construct_premature_classical_and_adjusted_sandwiches(
         action_by_decision_time_by_subject_id,
     )

-    …
+    joint_adjusted_bread_matrix = jnp.block(
         [
             [
-                …
+                truncated_joint_adjusted_bread_matrix,
                 np.zeros(
                     (
-                        …
+                        truncated_joint_adjusted_bread_matrix.shape[0],
                         new_inference_block_row.shape[0],
                     )
                 ),
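jnp.block stitches the truncated bread together with the newly differentiated inference row, padding with an explicit zero block so the result stays block lower triangular. In miniature (shapes are illustrative only):

import jax.numpy as jnp

truncated = jnp.eye(4)                   # existing (beta) portion
new_row = jnp.arange(8.0).reshape(2, 4)  # d(inference est. fn.) / d(betas)
new_diag = 2.0 * jnp.eye(2)              # d(inference est. fn.) / d(theta)

stitched = jnp.block(
    [
        [truncated, jnp.zeros((truncated.shape[0], new_diag.shape[0]))],
        [new_row, new_diag],
    ]
)
assert stitched.shape == (6, 6)  # block lower triangular by construction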
@@ -902,34 +898,30 @@ def construct_premature_classical_and_adjusted_sandwiches(
         per_subject_adjusted_meat_contributions, axis=0
     )

-    …
-        per_subject_classical_bread_inverse_contributions, axis=0
-    )
+    classical_bread_matrix = jnp.mean(per_subject_classical_bread_contributions, axis=0)
     classical_meat_matrix = jnp.mean(per_subject_classical_meat_contributions, axis=0)

     num_subjects = subject_ids.shape[0]
     joint_adjusted_sandwich = (
-        …
-        …
+        post_deployment_analysis.form_sandwich_from_bread_and_meat(
+            joint_adjusted_bread_matrix,
             joint_adjusted_meat_matrix,
             num_subjects,
-            method="…
+            method="bread_T_qr",
         )
     )
     adjusted_sandwich = joint_adjusted_sandwich[-theta.shape[0] :, -theta.shape[0] :]

-    …
-    …
-    …
-    classical_sandwich = after_study_analysis.form_sandwich_from_bread_inverse_and_meat(
-        classical_bread_inverse_matrix,
+    classical_bread_matrix = jnp.mean(per_subject_classical_bread_contributions, axis=0)
+    classical_sandwich = post_deployment_analysis.form_sandwich_from_bread_and_meat(
+        classical_bread_matrix,
         classical_meat_matrix,
         num_subjects,
-        method="…
+        method="bread_T_qr",
     )

-    # Stack the joint adjusted …
-    # values too. The joint adjusted bread …
+    # Stack the joint adjusted bread pieces together horizontally and return the auxiliary
+    # values too. The joint adjusted bread should always be block lower triangular.
     return (
         adjusted_sandwich,
         classical_sandwich,
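form_sandwich_from_bread_and_meat itself is not shown in this diff. The usual sandwich variance estimator it presumably implements is V = B^{-1} M B^{-T} / n, and the new method="bread_T_qr" plausibly refers to routing B^{-1} through a QR factorization of B^T for numerical stability; both are assumptions here. A sketch under those assumptions:

import numpy as np

def form_sandwich(bread, meat, num_subjects):
    # V = B^{-1} M B^{-T} / n, via QR of B^T: B^T = Q R, so B^{-1} = Q R^{-T}.
    q, r = np.linalg.qr(bread.T)
    x = np.linalg.solve(r.T, meat)        # R^{-T} M
    middle = np.linalg.solve(r.T, x.T).T  # R^{-T} M R^{-1}
    return q @ middle @ q.T / num_subjects

rng = np.random.default_rng(1)
bread = rng.normal(size=(3, 3)) + 3 * np.eye(3)
meat = np.eye(3)
direct = np.linalg.inv(bread) @ meat @ np.linalg.inv(bread).T / 10
assert np.allclose(form_sandwich(bread, meat, 10), direct)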
@@ -1036,7 +1028,7 @@ def get_weighted_inference_estimating_functions_only(
         else inference_func
     )

-    betas, theta = …
+    betas, theta = post_deployment_analysis.unflatten_params(
         flattened_betas_and_theta,
         beta_dim,
         theta_dim,
@@ -1052,7 +1044,7 @@ def get_weighted_inference_estimating_functions_only(
     (
         threaded_action_prob_func_args_by_decision_time_by_subject_id,
         action_prob_func_args_by_decision_time_by_subject_id,
-    ) = …
+    ) = post_deployment_analysis.thread_action_prob_func_args(
         action_prob_func_args_by_subject_id_by_decision_time,
         policy_num_by_decision_time_by_subject_id,
         initial_policy_num,
@@ -1069,7 +1061,7 @@ def get_weighted_inference_estimating_functions_only(
         "function args for all subjects"
     )
     threaded_inference_func_args_by_subject_id = (
-        …
+        post_deployment_analysis.thread_inference_func_args(
             inference_func_args_by_subject_id,
             inference_func_args_theta_index,
             theta,
@@ -1205,9 +1197,11 @@ def single_subject_weighted_inference_estimating_function(

     # 1. Get the first time after the first update for convenience.
     # This is used to form the Radon-Nikodym weights for the right times.
-    _, first_time_after_first_update = …
-        …
-        …
+    _, first_time_after_first_update = (
+        post_deployment_analysis.get_min_time_by_policy_num(
+            policy_num_by_decision_time,
+            beta_index_by_policy_num,
+        )
     )

     # 2. Get the start and end times for this subject.
@@ -1268,7 +1262,7 @@ def single_subject_weighted_inference_estimating_function(
     # value, but impervious to differentiation with respect to all_post_update_betas. The
     # args, on the other hand, are a function of all_post_update_betas.
     in_study_weights = jax.vmap(
-        fun=…
+        fun=post_deployment_analysis.get_radon_nikodym_weight,
         in_axes=[0, None, None, 0] + batch_axes,
         out_axes=0,
     )(
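get_radon_nikodym_weight is defined elsewhere in the package; the weight at each decision time is the probability of the observed action under the target policy divided by its probability under the logged behavior policy (the conditional_x_or_one_minus_x helper in helper_functions.py plays the probability-of-observed-action role). A hypothetical sketch of the vmapped weighting:

import jax
import jax.numpy as jnp

def radon_nikodym_weight(action, target_prob, behavior_prob):
    # Probability of the observed binary action under each policy.
    numerator = jnp.where(action == 1, target_prob, 1.0 - target_prob)
    denominator = jnp.where(action == 1, behavior_prob, 1.0 - behavior_prob)
    return numerator / denominator

actions = jnp.array([1, 0, 1])
target = jnp.array([0.6, 0.6, 0.7])
behavior = jnp.array([0.5, 0.5, 0.5])
weights = jax.vmap(radon_nikodym_weight)(actions, target, behavior)
# weights == [1.2, 0.8, 1.4]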
lifejacket/helper_functions.py
CHANGED

@@ -11,8 +11,6 @@ import numpy as np
 import jax.numpy as jnp
 import pandas as pd

-from .constants import InverseStabilizationMethods
-
 logger = logging.getLogger(__name__)
 logging.basicConfig(
     format="%(asctime)s,%(msecs)03d %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s",
@@ -27,11 +25,7 @@ def conditional_x_or_one_minus_x(x, condition):

 def invert_matrix_and_check_conditioning(
     matrix: np.ndarray,
-    inverse_stabilization_method: str = InverseStabilizationMethods.NONE,
     condition_num_threshold: float = 10**4,
-    ridge_median_singular_value_fraction: str = 0.01,
-    beta_dim: int = None,
-    theta_dim: int = None,
 ):
     """
     Check a matrix's condition number and invert it. If the condition number is
@@ -39,139 +33,15 @@ def invert_matrix_and_check_conditioning(
     Parameters
     """
     inverse = None
-
-    if …
+    condition_number = np.linalg.cond(matrix)
+    if condition_number > condition_num_threshold:
         logger.warning(
-            "You are inverting a matrix with a large condition number: %s",
-            …
+            "You are inverting a matrix with a potentially large condition number: %s",
+            condition_number,
         )
-    if (
-        inverse_stabilization_method
-        == InverseStabilizationMethods.TRIM_SMALL_SINGULAR_VALUES
-    ):
-        logger.info("Trimming small singular values to improve conditioning.")
-        u, s, vT = np.linalg.svd(matrix, full_matrices=False)
-        logger.info(
-            " Sorted singular values: %s",
-            s,
-        )
-        sing_values_above_threshold_cond = s > s.max() / condition_num_threshold
-        if not np.any(sing_values_above_threshold_cond):
-            raise RuntimeError(
-                f"All singular values are below the threshold of {s.max() / condition_num_threshold}. Singular value trimming will not work.",
-            )
-        trimmed_pseudoinverse = (
-            vT.T[:, sing_values_above_threshold_cond]
-            / s[sing_values_above_threshold_cond]
-        ) @ u[:, sing_values_above_threshold_cond].T
-        inverse = trimmed_pseudoinverse
-        pre_inversion_condition_number = (
-            s[sing_values_above_threshold_cond].max()
-            / s[sing_values_above_threshold_cond].min()
-        )
-
-        logger.info(
-            "Kept %s out of %s singular values. Condition number of resulting lower-rank-approximation before inversion: %s",
-            sum(sing_values_above_threshold_cond),
-            len(s),
-            pre_inversion_condition_number,
-        )
-    elif (
-        inverse_stabilization_method
-        == InverseStabilizationMethods.ADD_RIDGE_FIXED_CONDITION_NUMBER
-    ):
-        logger.info("Adding ridge/Tikhonov regularization to improve conditioning.")
-        _, singular_values, _ = np.linalg.svd(matrix, full_matrices=False)
-        logger.info(
-            "Using fixed condition number threshold of %s to determine lambda.",
-            condition_num_threshold,
-        )
-        lambda_ = (
-            singular_values.max() / condition_num_threshold - singular_values.min()
-        )
-        logger.info("Lambda for ridge regularization: %s", lambda_)
-        new_matrix = matrix + lambda_ * np.eye(matrix.shape[0])
-        pre_inversion_condition_number = np.linalg.cond(new_matrix)
-        logger.info(
-            "Condition number of matrix after ridge regularization: %s",
-            pre_inversion_condition_number,
-        )
-        inverse = np.linalg.solve(new_matrix, np.eye(matrix.shape[0]))
-    elif (
-        inverse_stabilization_method
-        == InverseStabilizationMethods.ADD_RIDGE_MEDIAN_SINGULAR_VALUE_FRACTION
-    ):
-        logger.info("Adding ridge/Tikhonov regularization to improve conditioning.")
-        _, singular_values, _ = np.linalg.svd(matrix, full_matrices=False)
-        logger.info(
-            "Using median singular value times %s as lambda.",
-            ridge_median_singular_value_fraction,
-        )
-        lambda_ = ridge_median_singular_value_fraction * np.median(singular_values)
-        logger.info("Lambda for ridge regularization: %s", lambda_)
-        new_matrix = matrix + lambda_ * np.eye(matrix.shape[0])
-        pre_inversion_condition_number = np.linalg.cond(new_matrix)
-        logger.info(
-            "Condition number of matrix after ridge regularization: %s",
-            pre_inversion_condition_number,
-        )
-        inverse = np.linalg.solve(new_matrix, np.eye(matrix.shape[0]))
-    elif (
-        inverse_stabilization_method
-        == InverseStabilizationMethods.INVERSE_BREAD_STRUCTURE_AWARE_INVERSION
-    ):
-        if not beta_dim or not theta_dim:
-            raise ValueError(
-                "When using structure-aware inversion, beta_dim and theta_dim must be provided."
-            )
-        logger.info(
-            "Using inverse bread's block lower triangular structure to invert only diagonal blocks."
-        )
-        pre_inversion_condition_number = np.linalg.cond(matrix)
-        inverse = invert_inverse_bread_matrix(
-            matrix,
-            beta_dim,
-            theta_dim,
-            InverseStabilizationMethods.ADD_RIDGE_FIXED_CONDITION_NUMBER,
-        )
-    elif (
-        inverse_stabilization_method
-        == InverseStabilizationMethods.ZERO_OUT_SMALL_OFF_DIAGONALS
-    ):
-        if not beta_dim or not theta_dim:
-            raise ValueError(
-                "When zeroing out small off diagonals, beta_dim and theta_dim must be provided."
-            )
-        logger.info(
-            "Zeroing out small off-diagonal blocks to improve conditioning."
-        )
-        zeroed_matrix = zero_small_off_diagonal_blocks(
-            matrix,
-            ([beta_dim] * (matrix.shape[0] // beta_dim)) + [theta_dim],
-        )
-        pre_inversion_condition_number = np.linalg.cond(zeroed_matrix)
-        logger.info(
-            "Condition number of matrix after zeroing out small off-diagonal blocks: %s",
-            pre_inversion_condition_number,
-        )
-        inverse = np.linalg.solve(zeroed_matrix, np.eye(zeroed_matrix.shape[0]))
-    elif (
-        inverse_stabilization_method
-        == InverseStabilizationMethods.ALL_METHODS_COMPETITION
-    ):
-        # TODO: Choose right metric for competition... identity diff might not be it.
-        raise NotImplementedError(
-            "All methods competition is not implemented yet. Please choose a specific method."
-        )
-    elif inverse_stabilization_method == InverseStabilizationMethods.NONE:
-        logger.info("No inverse stabilization method applied. Inverting directly.")
-    else:
-        raise ValueError(
-            f"Unknown inverse stabilization method: {inverse_stabilization_method}"
-        )
     if inverse is None:
         inverse = np.linalg.solve(matrix, np.eye(matrix.shape[0]))
-    return inverse, …
+    return inverse, condition_number


 def zero_small_off_diagonal_blocks(
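For reference, the deleted ADD_RIDGE_FIXED_CONDITION_NUMBER branch picked lambda = s_max / kappa - s_min so that, for a symmetric PSD matrix whose eigenvalues shift by lambda under matrix + lambda * I, the regularized condition number lands near the target kappa whenever s_min is much smaller than s_max. A sketch of that rule (the PSD assumption is mine, not stated in the code):

import numpy as np

def ridge_to_target_condition(matrix, target_cond=1e4):
    s = np.linalg.svd(matrix, compute_uv=False)
    lam = s.max() / target_cond - s.min()  # rule from the deleted branch
    regularized = matrix + lam * np.eye(matrix.shape[0])
    return regularized, np.linalg.cond(regularized)

# Symmetric PSD example: eigenvalues shift by lam, so the new condition number
# (s_max + lam) / (s_min + lam) is roughly target_cond when s_min << s_max.
m = np.diag([1e6, 10.0, 1e-3])
_, new_cond = ridge_to_target_condition(m, target_cond=1e4)
assert new_cond < 1.1e4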
@@ -183,7 +53,7 @@ def zero_small_off_diagonal_blocks(
     Zero off-diagonal blocks whose Frobenius norm is < frobenius_norm_threshold_fraction x
     Frobenius norm of the diagonal block in the same ROW. One could compare to
     the same column or both the row and column, but we choose row here since
-    rows correspond to a single RL update or inference step in the …
+    rows correspond to a single RL update or inference step in the bread
     inverse matrices this method is designed for.

     Args:
@@ -237,18 +107,17 @@ def zero_small_off_diagonal_blocks(
     return J_trim


-def invert_inverse_bread_matrix(
-    …
+def invert_bread_matrix(
+    bread,
     beta_dim,
     theta_dim,
-    diag_inverse_stabilization_method=InverseStabilizationMethods.TRIM_SMALL_SINGULAR_VALUES,
 ):
     """
-    Invert the …
+    Invert the bread matrix to get the inverse bread matrix. This is a special
     function in order to take advantage of the block lower triangular structure.

     The procedure is as follows:
-    1. Initialize the …
+    1. Initialize the matrix B = A^{-1} as a block lower triangular matrix
     with the same block structure as A.

     2. Compute the diagonal blocks B_{ii}:
@@ -260,24 +129,23 @@ def invert_inverse_bread_matrix(
        B_{ij} = -A_{ii}^{-1} * sum(A_{ik} * B_{kj} for k in range(j, i))
     """
     blocks = []
-    num_beta_block_rows = (…
+    num_beta_block_rows = (bread.shape[0] - theta_dim) // beta_dim

     # Create upper rows of block of bread (just the beta portion)
     for i in range(0, num_beta_block_rows):
         beta_block_row = []
         beta_diag_inverse = invert_matrix_and_check_conditioning(
-            …
+            bread[
                 beta_dim * i : beta_dim * (i + 1),
                 beta_dim * i : beta_dim * (i + 1),
             ],
-            diag_inverse_stabilization_method,
         )[0]
         for j in range(0, num_beta_block_rows):
             if i > j:
                 beta_block_row.append(
                     -beta_diag_inverse
                     @ sum(
-                        …
+                        bread[
                             beta_dim * i : beta_dim * (i + 1),
                             beta_dim * k : beta_dim * (k + 1),
                         ]
@@ -299,17 +167,16 @@ def invert_inverse_bread_matrix(
     # Create the bottom block row of bread (the theta portion)
     theta_block_row = []
     theta_diag_inverse = invert_matrix_and_check_conditioning(
-        …
+        bread[
             -theta_dim:,
             -theta_dim:,
         ],
-        diag_inverse_stabilization_method,
     )[0]
     for k in range(0, num_beta_block_rows):
         theta_block_row.append(
             -theta_diag_inverse
             @ sum(
-                …
+                bread[
                     -theta_dim:,
                     beta_dim * h : beta_dim * (h + 1),
                 ]