lifejacket 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/after_study_analysis.py +17 -0
- lifejacket/arg_threading_helpers.py +23 -0
- lifejacket/calculate_derivatives.py +2 -2
- lifejacket/form_adaptive_meat_adjustments_directly.py +1 -1
- lifejacket/input_checks.py +48 -0
- {lifejacket-0.1.0.dist-info → lifejacket-0.2.0.dist-info}/METADATA +1 -1
- lifejacket-0.2.0.dist-info/RECORD +17 -0
- lifejacket-0.1.0.dist-info/RECORD +0 -17
- {lifejacket-0.1.0.dist-info → lifejacket-0.2.0.dist-info}/WHEEL +0 -0
- {lifejacket-0.1.0.dist-info → lifejacket-0.2.0.dist-info}/entry_points.txt +0 -0
- {lifejacket-0.1.0.dist-info → lifejacket-0.2.0.dist-info}/top_level.txt +0 -0
|
@@ -128,6 +128,12 @@ def cli():
|
|
|
128
128
|
default=-1000,
|
|
129
129
|
help="Index of the argument holding the decision times the action probabilities correspond to in the tuple of algorithm update func args, if applicable.",
|
|
130
130
|
)
|
|
131
|
+
@click.option(
|
|
132
|
+
"--alg_update_func_args_previous_betas_index",
|
|
133
|
+
type=int,
|
|
134
|
+
default=-1000,
|
|
135
|
+
help="Index of the previous betas array in the tuple of algorithm update func args, if applicable. Note that these are only post-update betas. Sometimes a beta_0 may be defined pre-update; this should not be in here.",
|
|
136
|
+
)
|
|
131
137
|
@click.option(
|
|
132
138
|
"--inference_func_filename",
|
|
133
139
|
type=click.Path(exists=True),
|
|
@@ -299,6 +305,7 @@ def analyze_dataset(
|
|
|
299
305
|
alg_update_func_args_beta_index: int,
|
|
300
306
|
alg_update_func_args_action_prob_index: int,
|
|
301
307
|
alg_update_func_args_action_prob_times_index: int,
|
|
308
|
+
alg_update_func_args_previous_betas_index: int,
|
|
302
309
|
inference_func: Callable,
|
|
303
310
|
inference_func_type: str,
|
|
304
311
|
inference_func_args_theta_index: int,
|
|
@@ -421,6 +428,7 @@ def analyze_dataset(
|
|
|
421
428
|
alg_update_func_args_beta_index,
|
|
422
429
|
alg_update_func_args_action_prob_index,
|
|
423
430
|
alg_update_func_args_action_prob_times_index,
|
|
431
|
+
alg_update_func_args_previous_betas_index,
|
|
424
432
|
theta_est,
|
|
425
433
|
beta_dim,
|
|
426
434
|
suppress_interactive_data_checks,
|
|
@@ -497,6 +505,7 @@ def analyze_dataset(
|
|
|
497
505
|
alg_update_func_args_beta_index,
|
|
498
506
|
alg_update_func_args_action_prob_index,
|
|
499
507
|
alg_update_func_args_action_prob_times_index,
|
|
508
|
+
alg_update_func_args_previous_betas_index,
|
|
500
509
|
inference_func,
|
|
501
510
|
inference_func_type,
|
|
502
511
|
inference_func_args_theta_index,
|
|
@@ -1019,6 +1028,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1019
1028
|
alg_update_func_args_beta_index: int,
|
|
1020
1029
|
alg_update_func_args_action_prob_index: int,
|
|
1021
1030
|
alg_update_func_args_action_prob_times_index: int,
|
|
1031
|
+
alg_update_func_args_previous_betas_index: int,
|
|
1022
1032
|
inference_func: callable,
|
|
1023
1033
|
inference_func_type: str,
|
|
1024
1034
|
inference_func_args_theta_index: int,
|
|
@@ -1075,6 +1085,8 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1075
1085
|
alg_update_func_args_action_prob_times_index (int):
|
|
1076
1086
|
The index in the update function arguments tuple where an array of times for which the
|
|
1077
1087
|
given action probabilities apply is provided, if applicable. -1 otherwise.
|
|
1088
|
+
alg_update_func_args_previous_betas_index (int):
|
|
1089
|
+
The index in the update function arguments tuple where previous betas are provided.
|
|
1078
1090
|
inference_func (callable):
|
|
1079
1091
|
The inference loss or estimating function.
|
|
1080
1092
|
inference_func_type (str):
|
|
@@ -1179,6 +1191,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
|
|
|
1179
1191
|
alg_update_func_args_beta_index,
|
|
1180
1192
|
alg_update_func_args_action_prob_index,
|
|
1181
1193
|
alg_update_func_args_action_prob_times_index,
|
|
1194
|
+
alg_update_func_args_previous_betas_index,
|
|
1182
1195
|
threaded_action_prob_func_args_by_decision_time_by_user_id,
|
|
1183
1196
|
action_prob_func,
|
|
1184
1197
|
)
|
|
@@ -1278,6 +1291,7 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1278
1291
|
alg_update_func_args_beta_index: int,
|
|
1279
1292
|
alg_update_func_args_action_prob_index: int,
|
|
1280
1293
|
alg_update_func_args_action_prob_times_index: int,
|
|
1294
|
+
alg_update_func_args_previous_betas_index: int,
|
|
1281
1295
|
inference_func: callable,
|
|
1282
1296
|
inference_func_type: str,
|
|
1283
1297
|
inference_func_args_theta_index: int,
|
|
@@ -1354,6 +1368,8 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1354
1368
|
alg_update_func_args_action_prob_times_index (int):
|
|
1355
1369
|
The index in the update function arguments tuple where an array of times for which the
|
|
1356
1370
|
given action probabilities apply is provided, if applicable. -1 otherwise.
|
|
1371
|
+
alg_update_func_args_previous_betas_index (int):
|
|
1372
|
+
The index in the update function arguments tuple where the previous betas are provided.
|
|
1357
1373
|
inference_func (callable):
|
|
1358
1374
|
The inference loss or estimating function.
|
|
1359
1375
|
inference_func_type (str):
|
|
@@ -1463,6 +1479,7 @@ def construct_classical_and_adaptive_sandwiches(
|
|
|
1463
1479
|
alg_update_func_args_beta_index,
|
|
1464
1480
|
alg_update_func_args_action_prob_index,
|
|
1465
1481
|
alg_update_func_args_action_prob_times_index,
|
|
1482
|
+
alg_update_func_args_previous_betas_index,
|
|
1466
1483
|
inference_func,
|
|
1467
1484
|
inference_func_type,
|
|
1468
1485
|
inference_func_args_theta_index,
|
|
@@ -63,6 +63,7 @@ def thread_action_prob_func_args(
|
|
|
63
63
|
action_prob_func_args_beta_index (int):
|
|
64
64
|
The index in the action probability function arguments tuple
|
|
65
65
|
where the beta value should be inserted.
|
|
66
|
+
|
|
66
67
|
Returns:
|
|
67
68
|
dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]:
|
|
68
69
|
A map from user ids to maps of decision times to action probability function
|
|
@@ -125,6 +126,7 @@ def thread_update_func_args(
|
|
|
125
126
|
alg_update_func_args_beta_index: int,
|
|
126
127
|
alg_update_func_args_action_prob_index: int,
|
|
127
128
|
alg_update_func_args_action_prob_times_index: int,
|
|
129
|
+
alg_update_func_args_previous_betas_index: int,
|
|
128
130
|
threaded_action_prob_func_args_by_decision_time_by_user_id: dict[
|
|
129
131
|
collections.abc.Hashable, dict[int, tuple[Any, ...]]
|
|
130
132
|
],
|
|
@@ -165,6 +167,9 @@ def thread_update_func_args(
|
|
|
165
167
|
to the update function, this is the index in the arguments where an array of times for
|
|
166
168
|
which the given action probabilities apply is provided.
|
|
167
169
|
|
|
170
|
+
alg_update_func_args_previous_betas_index (int):
|
|
171
|
+
The index in the update function with previous beta parameters
|
|
172
|
+
|
|
168
173
|
threaded_action_prob_func_args_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
|
|
169
174
|
A dictionary mapping decision times to the function arguments required to compute action
|
|
170
175
|
probabilities for this user, and with the shared betas thread in.
|
|
@@ -206,6 +211,24 @@ def thread_update_func_args(
|
|
|
206
211
|
beta_to_introduce,
|
|
207
212
|
)
|
|
208
213
|
)
|
|
214
|
+
if alg_update_func_args_previous_betas_index >= 0:
|
|
215
|
+
previous_betas_to_introduce = all_post_update_betas[
|
|
216
|
+
: len(
|
|
217
|
+
update_func_args_by_user_id[user_id][
|
|
218
|
+
alg_update_func_args_previous_betas_index
|
|
219
|
+
]
|
|
220
|
+
)
|
|
221
|
+
]
|
|
222
|
+
if previous_betas_to_introduce.size > 0:
|
|
223
|
+
threaded_update_func_args_by_policy_num_by_user_id[user_id][
|
|
224
|
+
policy_num
|
|
225
|
+
] = replace_tuple_index(
|
|
226
|
+
threaded_update_func_args_by_policy_num_by_user_id[user_id][
|
|
227
|
+
policy_num
|
|
228
|
+
],
|
|
229
|
+
alg_update_func_args_previous_betas_index,
|
|
230
|
+
previous_betas_to_introduce,
|
|
231
|
+
)
|
|
209
232
|
|
|
210
233
|
if alg_update_func_args_action_prob_index >= 0:
|
|
211
234
|
logger.debug(
|
|
@@ -743,7 +743,7 @@ def get_loss_gradient_derivatives_wrt_pi_batched(
|
|
|
743
743
|
*batched_arg_tensors,
|
|
744
744
|
):
|
|
745
745
|
if update_func_type == FunctionTypes.LOSS:
|
|
746
|
-
return jax.jit(
|
|
746
|
+
return jax.jit( # pylint: disable=not-callable
|
|
747
747
|
jax.vmap(
|
|
748
748
|
fun=jax.jacrev(
|
|
749
749
|
jax.grad(update_func, update_func_args_beta_index),
|
|
@@ -754,7 +754,7 @@ def get_loss_gradient_derivatives_wrt_pi_batched(
|
|
|
754
754
|
)
|
|
755
755
|
)(*batched_arg_tensors)
|
|
756
756
|
if update_func_type == FunctionTypes.ESTIMATING:
|
|
757
|
-
return jax.jit(
|
|
757
|
+
return jax.jit( # pylint: disable=not-callable
|
|
758
758
|
jax.vmap(
|
|
759
759
|
fun=jax.jacrev(
|
|
760
760
|
update_func,
|
|
@@ -62,7 +62,7 @@ def form_adaptive_meat_adjustments_directly(
|
|
|
62
62
|
# 3. Split the effective M blocks into (theta_dim, beta_dim) blocks and the
|
|
63
63
|
# estimating function stacks into (num_updates, beta_dim) stacks.
|
|
64
64
|
|
|
65
|
-
#
|
|
65
|
+
# effective_M_blocks_together is shape (theta_dim, num_updates * beta_dim)
|
|
66
66
|
# We want to split it into a list of (theta_dim, beta_dim) arrays
|
|
67
67
|
M_blocks = np.split(
|
|
68
68
|
effective_M_blocks_together,
|
lifejacket/input_checks.py
CHANGED
|
@@ -41,6 +41,7 @@ def perform_first_wave_input_checks(
|
|
|
41
41
|
alg_update_func_args_beta_index,
|
|
42
42
|
alg_update_func_args_action_prob_index,
|
|
43
43
|
alg_update_func_args_action_prob_times_index,
|
|
44
|
+
alg_update_func_args_previous_betas_index,
|
|
44
45
|
theta_est,
|
|
45
46
|
beta_dim,
|
|
46
47
|
suppress_interactive_data_checks,
|
|
@@ -56,6 +57,9 @@ def perform_first_wave_input_checks(
|
|
|
56
57
|
require_beta_is_1D_array_in_alg_update_args(
|
|
57
58
|
alg_update_func_args, alg_update_func_args_beta_index
|
|
58
59
|
)
|
|
60
|
+
require_previous_betas_is_2D_array_in_alg_update_args(
|
|
61
|
+
alg_update_func_args, alg_update_func_args_previous_betas_index
|
|
62
|
+
)
|
|
59
63
|
require_all_policy_numbers_in_study_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
|
|
60
64
|
study_df, in_study_col_name, policy_num_col_name, alg_update_func_args
|
|
61
65
|
)
|
|
@@ -73,6 +77,9 @@ def perform_first_wave_input_checks(
|
|
|
73
77
|
require_betas_match_in_alg_update_args_each_update(
|
|
74
78
|
alg_update_func_args, alg_update_func_args_beta_index
|
|
75
79
|
)
|
|
80
|
+
require_previous_betas_match_in_alg_update_args_each_update(
|
|
81
|
+
alg_update_func_args, alg_update_func_args_previous_betas_index
|
|
82
|
+
)
|
|
76
83
|
require_action_prob_args_in_alg_update_func_correspond_to_study_df(
|
|
77
84
|
study_df,
|
|
78
85
|
action_prob_col_name,
|
|
@@ -675,6 +682,23 @@ def require_beta_is_1D_array_in_alg_update_args(
|
|
|
675
682
|
== 1
|
|
676
683
|
), "Beta is not a 1D array in the algorithm update function args."
|
|
677
684
|
|
|
685
|
+
def require_previous_betas_is_2D_array_in_alg_update_args(
    alg_update_func_args, alg_update_func_args_previous_betas_index
):
    """Assert that every supplied previous-betas update argument is a 2D array.

    Args:
        alg_update_func_args:
            Nested mapping, policy number -> user id -> tuple of algorithm
            update function arguments. An empty (falsy) tuple is skipped;
            presumably it marks a user with no data for that update
            (TODO confirm against the caller's contract).
        alg_update_func_args_previous_betas_index (int):
            Index of the previous betas array within each argument tuple.
            A negative value means previous betas are not supplied, in which
            case nothing is checked.

    Raises:
        AssertionError: If any non-empty argument tuple holds a previous-betas
            value that is not a 2D array. The message names the offending
            policy number and user id.
    """
    if alg_update_func_args_previous_betas_index < 0:
        return

    for policy_num, args_by_user_id in alg_update_func_args.items():
        for user_id, user_args in args_by_user_id.items():
            if not user_args:
                continue
            previous_betas = user_args[alg_update_func_args_previous_betas_index]
            # getattr keeps a non-array value (no ``ndim``) on the validation
            # path, so it fails with the message below rather than a bare
            # AttributeError.
            assert getattr(previous_betas, "ndim", None) == 2, (
                "Previous betas is not a 2D array in the algorithm update "
                f"function args for policy number {policy_num}, user id {user_id}."
            )
|
|
701
|
+
|
|
678
702
|
|
|
679
703
|
def require_beta_is_1D_array_in_action_prob_args(
|
|
680
704
|
action_prob_func_args, action_prob_func_args_beta_index
|
|
@@ -827,6 +851,30 @@ def require_betas_match_in_alg_update_args_each_update(
|
|
|
827
851
|
beta, first_beta
|
|
828
852
|
), f"Betas do not match across users in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
|
|
829
853
|
|
|
854
|
+
def require_previous_betas_match_in_alg_update_args_each_update(
    alg_update_func_args, alg_update_func_args_previous_betas_index
):
    """Assert that, within each update, all users share the same previous betas.

    Args:
        alg_update_func_args:
            Nested mapping, policy number -> user id -> tuple of algorithm
            update function arguments. Empty (falsy) tuples are skipped.
        alg_update_func_args_previous_betas_index (int):
            Index of the previous betas array within each argument tuple.
            A negative value means previous betas are not supplied, in which
            case nothing is checked and nothing is logged.

    Raises:
        AssertionError: If, for some policy number, a user's previous betas
            differ (per ``np.array_equal``) from those of the first user with
            non-empty arguments for that policy number.
    """
    # Guard first, so the log line below is only emitted when the check runs.
    if alg_update_func_args_previous_betas_index < 0:
        return

    logger.info(
        "Checking that previous betas match across users for each update in the algorithm update function args."
    )

    for policy_num, args_by_user_id in alg_update_func_args.items():
        # Reference value: the first non-empty user's previous betas for this policy.
        first_previous_betas = None
        for user_args in args_by_user_id.values():
            if not user_args:
                continue
            previous_betas = user_args[alg_update_func_args_previous_betas_index]
            if first_previous_betas is None:
                first_previous_betas = previous_betas
                continue
            assert np.array_equal(
                previous_betas, first_previous_betas
            ), f"Previous betas do not match across users in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
|
|
877
|
+
|
|
830
878
|
|
|
831
879
|
def require_betas_match_in_action_prob_func_args_each_decision(
|
|
832
880
|
action_prob_func_args, action_prob_func_args_beta_index
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
lifejacket/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
lifejacket/after_study_analysis.py,sha256=_Weeca51EXWlQkx3IWd2t1jImBxODOTe9-8gtKBSlus,82168
|
|
3
|
+
lifejacket/arg_threading_helpers.py,sha256=7HV7qkiJtm-E-cpAhtv4n_BpCVjz9tC0nGENbc090h8,17282
|
|
4
|
+
lifejacket/calculate_derivatives.py,sha256=SceXFWtK56uCCdXGD7v8JijgYz0UCBKzcnrPH_nAqNE,37536
|
|
5
|
+
lifejacket/constants.py,sha256=2L05p6NJ7l3qRZ1hD2KlrvzWF1ReSmWRUkULPIhdvJo,842
|
|
6
|
+
lifejacket/form_adaptive_meat_adjustments_directly.py,sha256=bSLrVYLZR1-Qlm5yIdktzv8ZQTVhHTlhVVL2wEjLTmw,13737
|
|
7
|
+
lifejacket/get_datum_for_blowup_supervised_learning.py,sha256=V8H4PE49dQwsKjj93QEu2BKbhwPr56QMtx2jhan39-c,58357
|
|
8
|
+
lifejacket/helper_functions.py,sha256=xOhRG-Cm4ZdRNm-O0faHna583d74pLWY5_jfnokegWc,23295
|
|
9
|
+
lifejacket/input_checks.py,sha256=KcDdfsdCVCKKcx07FfOKJb3KVX6xFuWwAufGJ3msAuc,46972
|
|
10
|
+
lifejacket/small_sample_corrections.py,sha256=f8jmg9U9ZN77WadJud70tt6NMxCTsSGtlsdF_-mfws4,5543
|
|
11
|
+
lifejacket/trial_conditioning_monitor.py,sha256=qNTHh0zY2P7zJxox_OwhEEK8Ed1l0TPOjGDqNxMNoIQ,42164
|
|
12
|
+
lifejacket/vmap_helpers.py,sha256=pZqYN3p9Ty9DPOeeY9TKbRJXR2AV__HBwwDFOvdOQ84,2688
|
|
13
|
+
lifejacket-0.2.0.dist-info/METADATA,sha256=XzZQXRKzohTg0FOv0sIh7Ii6Snj2m5HVIdeXVYUphe4,7274
|
|
14
|
+
lifejacket-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
15
|
+
lifejacket-0.2.0.dist-info/entry_points.txt,sha256=4k8ibVIUT-OHxPaaDv-QwWpC64ErzhdemHpTAXCnb8w,67
|
|
16
|
+
lifejacket-0.2.0.dist-info/top_level.txt,sha256=vKl8m7jOQ4pkbzVuHCJsq-8LcXRrOAWnok3bBo9qpsE,11
|
|
17
|
+
lifejacket-0.2.0.dist-info/RECORD,,
|
|
@@ -1,17 +0,0 @@
|
|
|
1
|
-
lifejacket/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
lifejacket/after_study_analysis.py,sha256=HtAHRYCni8Lh7HCB9A0SsRkyAol2FbczrkxFn1ww1HA,81170
|
|
3
|
-
lifejacket/arg_threading_helpers.py,sha256=kxtGlN_B1C1WKcWtJMlgWJwzzI1THytflmWsL5ZML7k,16228
|
|
4
|
-
lifejacket/calculate_derivatives.py,sha256=3rYukD1wbjDof7d4_3QdQ-A4GSK9H8z8HJsbNQh0DzA,37472
|
|
5
|
-
lifejacket/constants.py,sha256=2L05p6NJ7l3qRZ1hD2KlrvzWF1ReSmWRUkULPIhdvJo,842
|
|
6
|
-
lifejacket/form_adaptive_meat_adjustments_directly.py,sha256=_BaziGfYjEySN78nU3lCrVtf2KWIuZ8PmzfMZypAaWI,13728
|
|
7
|
-
lifejacket/get_datum_for_blowup_supervised_learning.py,sha256=V8H4PE49dQwsKjj93QEu2BKbhwPr56QMtx2jhan39-c,58357
|
|
8
|
-
lifejacket/helper_functions.py,sha256=xOhRG-Cm4ZdRNm-O0faHna583d74pLWY5_jfnokegWc,23295
|
|
9
|
-
lifejacket/input_checks.py,sha256=A0f2owqRUjeBAh5jLULKu1nXW1SgZDR5eK7xBm1ahZw,44878
|
|
10
|
-
lifejacket/small_sample_corrections.py,sha256=f8jmg9U9ZN77WadJud70tt6NMxCTsSGtlsdF_-mfws4,5543
|
|
11
|
-
lifejacket/trial_conditioning_monitor.py,sha256=qNTHh0zY2P7zJxox_OwhEEK8Ed1l0TPOjGDqNxMNoIQ,42164
|
|
12
|
-
lifejacket/vmap_helpers.py,sha256=pZqYN3p9Ty9DPOeeY9TKbRJXR2AV__HBwwDFOvdOQ84,2688
|
|
13
|
-
lifejacket-0.1.0.dist-info/METADATA,sha256=VT6H9TNcYleRhp-wLda4HXrbI7EJYj6ZV0_7K5fraI4,7274
|
|
14
|
-
lifejacket-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
15
|
-
lifejacket-0.1.0.dist-info/entry_points.txt,sha256=4k8ibVIUT-OHxPaaDv-QwWpC64ErzhdemHpTAXCnb8w,67
|
|
16
|
-
lifejacket-0.1.0.dist-info/top_level.txt,sha256=vKl8m7jOQ4pkbzVuHCJsq-8LcXRrOAWnok3bBo9qpsE,11
|
|
17
|
-
lifejacket-0.1.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|