lifejacket 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -128,6 +128,12 @@ def cli():
128
128
  default=-1000,
129
129
  help="Index of the argument holding the decision times the action probabilities correspond to in the tuple of algorithm update func args, if applicable.",
130
130
  )
131
+ @click.option(
132
+ "--alg_update_func_args_previous_betas_index",
133
+ type=int,
134
+ default=-1000,
135
+ help="Index of the previous betas array in the tuple of algorithm update func args, if applicable. Note that these are only post-update betas. Sometimes a beta_0 may be defined pre-update; this should not be in here.",
136
+ )
131
137
  @click.option(
132
138
  "--inference_func_filename",
133
139
  type=click.Path(exists=True),
@@ -299,6 +305,7 @@ def analyze_dataset(
299
305
  alg_update_func_args_beta_index: int,
300
306
  alg_update_func_args_action_prob_index: int,
301
307
  alg_update_func_args_action_prob_times_index: int,
308
+ alg_update_func_args_previous_betas_index: int,
302
309
  inference_func: Callable,
303
310
  inference_func_type: str,
304
311
  inference_func_args_theta_index: int,
@@ -421,6 +428,7 @@ def analyze_dataset(
421
428
  alg_update_func_args_beta_index,
422
429
  alg_update_func_args_action_prob_index,
423
430
  alg_update_func_args_action_prob_times_index,
431
+ alg_update_func_args_previous_betas_index,
424
432
  theta_est,
425
433
  beta_dim,
426
434
  suppress_interactive_data_checks,
@@ -497,6 +505,7 @@ def analyze_dataset(
497
505
  alg_update_func_args_beta_index,
498
506
  alg_update_func_args_action_prob_index,
499
507
  alg_update_func_args_action_prob_times_index,
508
+ alg_update_func_args_previous_betas_index,
500
509
  inference_func,
501
510
  inference_func_type,
502
511
  inference_func_args_theta_index,
@@ -523,10 +532,6 @@ def analyze_dataset(
523
532
  action_prob_col_name,
524
533
  )
525
534
 
526
- joint_adaptive_bread_inverse_cond = jnp.linalg.cond(
527
- stabilized_joint_adaptive_bread_inverse_matrix
528
- )
529
-
530
535
  theta_dim = len(theta_est)
531
536
  if not suppress_all_data_checks:
532
537
  input_checks.require_estimating_functions_sum_to_zero(
@@ -567,6 +572,14 @@ def analyze_dataset(
567
572
  f,
568
573
  )
569
574
 
575
+ joint_adaptive_bread_inverse_cond = jnp.linalg.cond(
576
+ raw_joint_adaptive_bread_inverse_matrix
577
+ )
578
+ logger.info(
579
+ "Joint adaptive bread inverse condition number: %f",
580
+ joint_adaptive_bread_inverse_cond,
581
+ )
582
+
570
583
  debug_pieces_dict = {
571
584
  "theta_est": theta_est,
572
585
  "adaptive_sandwich_var_estimate": adaptive_sandwich_var_estimate,
@@ -1019,6 +1032,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1019
1032
  alg_update_func_args_beta_index: int,
1020
1033
  alg_update_func_args_action_prob_index: int,
1021
1034
  alg_update_func_args_action_prob_times_index: int,
1035
+ alg_update_func_args_previous_betas_index: int,
1022
1036
  inference_func: callable,
1023
1037
  inference_func_type: str,
1024
1038
  inference_func_args_theta_index: int,
@@ -1075,6 +1089,8 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1075
1089
  alg_update_func_args_action_prob_times_index (int):
1076
1090
  The index in the update function arguments tuple where an array of times for which the
1077
1091
  given action probabilities apply is provided, if applicable. -1 otherwise.
1092
+ alg_update_func_args_previous_betas_index (int):
1093
+ The index in the update function arguments tuple where previous betas are provided.
1078
1094
  inference_func (callable):
1079
1095
  The inference loss or estimating function.
1080
1096
  inference_func_type (str):
@@ -1179,6 +1195,7 @@ def get_avg_weighted_estimating_function_stacks_and_aux_values(
1179
1195
  alg_update_func_args_beta_index,
1180
1196
  alg_update_func_args_action_prob_index,
1181
1197
  alg_update_func_args_action_prob_times_index,
1198
+ alg_update_func_args_previous_betas_index,
1182
1199
  threaded_action_prob_func_args_by_decision_time_by_user_id,
1183
1200
  action_prob_func,
1184
1201
  )
@@ -1278,6 +1295,7 @@ def construct_classical_and_adaptive_sandwiches(
1278
1295
  alg_update_func_args_beta_index: int,
1279
1296
  alg_update_func_args_action_prob_index: int,
1280
1297
  alg_update_func_args_action_prob_times_index: int,
1298
+ alg_update_func_args_previous_betas_index: int,
1281
1299
  inference_func: callable,
1282
1300
  inference_func_type: str,
1283
1301
  inference_func_args_theta_index: int,
@@ -1354,6 +1372,8 @@ def construct_classical_and_adaptive_sandwiches(
1354
1372
  alg_update_func_args_action_prob_times_index (int):
1355
1373
  The index in the update function arguments tuple where an array of times for which the
1356
1374
  given action probabilities apply is provided, if applicable. -1 otherwise.
1375
+ alg_update_func_args_previous_betas_index (int):
1376
+ The index in the update function arguments tuple where the previous betas are provided.
1357
1377
  inference_func (callable):
1358
1378
  The inference loss or estimating function.
1359
1379
  inference_func_type (str):
@@ -1463,6 +1483,7 @@ def construct_classical_and_adaptive_sandwiches(
1463
1483
  alg_update_func_args_beta_index,
1464
1484
  alg_update_func_args_action_prob_index,
1465
1485
  alg_update_func_args_action_prob_times_index,
1486
+ alg_update_func_args_previous_betas_index,
1466
1487
  inference_func,
1467
1488
  inference_func_type,
1468
1489
  inference_func_args_theta_index,
@@ -63,6 +63,7 @@ def thread_action_prob_func_args(
63
63
  action_prob_func_args_beta_index (int):
64
64
  The index in the action probability function arguments tuple
65
65
  where the beta value should be inserted.
66
+
66
67
  Returns:
67
68
  dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]:
68
69
  A map from user ids to maps of decision times to action probability function
@@ -125,6 +126,7 @@ def thread_update_func_args(
125
126
  alg_update_func_args_beta_index: int,
126
127
  alg_update_func_args_action_prob_index: int,
127
128
  alg_update_func_args_action_prob_times_index: int,
129
+ alg_update_func_args_previous_betas_index: int,
128
130
  threaded_action_prob_func_args_by_decision_time_by_user_id: dict[
129
131
  collections.abc.Hashable, dict[int, tuple[Any, ...]]
130
132
  ],
@@ -165,6 +167,9 @@ def thread_update_func_args(
165
167
  to the update function, this is the index in the arguments where an array of times for
166
168
  which the given action probabilities apply is provided.
167
169
 
170
+ alg_update_func_args_previous_betas_index (int):
171
+ The index in the update function arguments tuple where the previous betas are provided.
172
+
168
173
  threaded_action_prob_func_args_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
169
174
  A dictionary mapping decision times to the function arguments required to compute action
170
175
  probabilities for this user, and with the shared betas threaded in.
@@ -206,6 +211,24 @@ def thread_update_func_args(
206
211
  beta_to_introduce,
207
212
  )
208
213
  )
214
+ if alg_update_func_args_previous_betas_index >= 0:
215
+ previous_betas_to_introduce = all_post_update_betas[
216
+ : len(
217
+ update_func_args_by_user_id[user_id][
218
+ alg_update_func_args_previous_betas_index
219
+ ]
220
+ )
221
+ ]
222
+ if previous_betas_to_introduce.size > 0:
223
+ threaded_update_func_args_by_policy_num_by_user_id[user_id][
224
+ policy_num
225
+ ] = replace_tuple_index(
226
+ threaded_update_func_args_by_policy_num_by_user_id[user_id][
227
+ policy_num
228
+ ],
229
+ alg_update_func_args_previous_betas_index,
230
+ previous_betas_to_introduce,
231
+ )
209
232
 
210
233
  if alg_update_func_args_action_prob_index >= 0:
211
234
  logger.debug(
@@ -743,7 +743,7 @@ def get_loss_gradient_derivatives_wrt_pi_batched(
743
743
  *batched_arg_tensors,
744
744
  ):
745
745
  if update_func_type == FunctionTypes.LOSS:
746
- return jax.jit(
746
+ return jax.jit( # pylint: disable=not-callable
747
747
  jax.vmap(
748
748
  fun=jax.jacrev(
749
749
  jax.grad(update_func, update_func_args_beta_index),
@@ -754,7 +754,7 @@ def get_loss_gradient_derivatives_wrt_pi_batched(
754
754
  )
755
755
  )(*batched_arg_tensors)
756
756
  if update_func_type == FunctionTypes.ESTIMATING:
757
- return jax.jit(
757
+ return jax.jit( # pylint: disable=not-callable
758
758
  jax.vmap(
759
759
  fun=jax.jacrev(
760
760
  update_func,
@@ -62,7 +62,7 @@ def form_adaptive_meat_adjustments_directly(
62
62
  # 3. Split the effective M blocks into (theta_dim, beta_dim) blocks and the
63
63
  # estimating function stacks into (num_updates, beta_dim) stacks.
64
64
 
65
- # effective_M_blocks is shape (theta_dim, num_updates * beta_dim)
65
+ # effective_M_blocks_together is shape (theta_dim, num_updates * beta_dim)
66
66
  # We want to split it into a list of (theta_dim, beta_dim) arrays
67
67
  M_blocks = np.split(
68
68
  effective_M_blocks_together,
@@ -41,6 +41,7 @@ def perform_first_wave_input_checks(
41
41
  alg_update_func_args_beta_index,
42
42
  alg_update_func_args_action_prob_index,
43
43
  alg_update_func_args_action_prob_times_index,
44
+ alg_update_func_args_previous_betas_index,
44
45
  theta_est,
45
46
  beta_dim,
46
47
  suppress_interactive_data_checks,
@@ -56,6 +57,9 @@ def perform_first_wave_input_checks(
56
57
  require_beta_is_1D_array_in_alg_update_args(
57
58
  alg_update_func_args, alg_update_func_args_beta_index
58
59
  )
60
+ require_previous_betas_is_2D_array_in_alg_update_args(
61
+ alg_update_func_args, alg_update_func_args_previous_betas_index
62
+ )
59
63
  require_all_policy_numbers_in_study_df_except_possibly_initial_and_fallback_present_in_alg_update_args(
60
64
  study_df, in_study_col_name, policy_num_col_name, alg_update_func_args
61
65
  )
@@ -73,6 +77,9 @@ def perform_first_wave_input_checks(
73
77
  require_betas_match_in_alg_update_args_each_update(
74
78
  alg_update_func_args, alg_update_func_args_beta_index
75
79
  )
80
+ require_previous_betas_match_in_alg_update_args_each_update(
81
+ alg_update_func_args, alg_update_func_args_previous_betas_index
82
+ )
76
83
  require_action_prob_args_in_alg_update_func_correspond_to_study_df(
77
84
  study_df,
78
85
  action_prob_col_name,
@@ -675,6 +682,23 @@ def require_beta_is_1D_array_in_alg_update_args(
675
682
  == 1
676
683
  ), "Beta is not a 1D array in the algorithm update function args."
677
684
 
685
+ def require_previous_betas_is_2D_array_in_alg_update_args(
686
+ alg_update_func_args, alg_update_func_args_previous_betas_index
687
+ ):
688
+ if alg_update_func_args_previous_betas_index < 0:
689
+ return
690
+
691
+ for policy_num in alg_update_func_args:
692
+ for user_id in alg_update_func_args[policy_num]:
693
+ if not alg_update_func_args[policy_num][user_id]:
694
+ continue
695
+ assert (
696
+ alg_update_func_args[policy_num][user_id][
697
+ alg_update_func_args_previous_betas_index
698
+ ].ndim
699
+ == 2
700
+ ), "Previous betas is not a 2D array in the algorithm update function args."
701
+
678
702
 
679
703
  def require_beta_is_1D_array_in_action_prob_args(
680
704
  action_prob_func_args, action_prob_func_args_beta_index
@@ -827,6 +851,30 @@ def require_betas_match_in_alg_update_args_each_update(
827
851
  beta, first_beta
828
852
  ), f"Betas do not match across users in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
829
853
 
854
+ def require_previous_betas_match_in_alg_update_args_each_update(
855
+ alg_update_func_args, alg_update_func_args_previous_betas_index
856
+ ):
857
+ logger.info(
858
+ "Checking that previous betas match across users for each update in the algorithm update function args."
859
+ )
860
+ if alg_update_func_args_previous_betas_index < 0:
861
+ return
862
+
863
+ for policy_num in alg_update_func_args:
864
+ first_previous_betas = None
865
+ for user_id in alg_update_func_args[policy_num]:
866
+ if not alg_update_func_args[policy_num][user_id]:
867
+ continue
868
+ previous_betas = alg_update_func_args[policy_num][user_id][
869
+ alg_update_func_args_previous_betas_index
870
+ ]
871
+ if first_previous_betas is None:
872
+ first_previous_betas = previous_betas
873
+ else:
874
+ assert np.array_equal(
875
+ previous_betas, first_previous_betas
876
+ ), f"Previous betas do not match across users in the algorithm update function args for policy number {policy_num}. Please see the contract for details."
877
+
830
878
 
831
879
  def require_betas_match_in_action_prob_func_args_each_decision(
832
880
  action_prob_func_args, action_prob_func_args_beta_index
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lifejacket
3
- Version: 0.1.0
3
+ Version: 0.2.1
4
4
  Summary: A package for after-study analysis of adaptive experiments in which data is pooled across users.
5
5
  Author-email: Nowell Closser <nowellclosser@gmail.com>
6
6
  Requires-Python: >=3.10
@@ -28,7 +28,7 @@ Requires-Dist: flake8>=4.0; extra == "dev"
28
28
  |__/
29
29
  ```
30
30
 
31
- Save your standard errors from pooling in adaptive experiments.
31
+ Save your standard errors from pooling in online decision-making algorithms.
32
32
 
33
33
  ## Setup (if not using conda)
34
34
  ### Create and activate a virtual environment
@@ -0,0 +1,17 @@
1
+ lifejacket/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ lifejacket/after_study_analysis.py,sha256=jTOPDhDWThlE0pyVtprkYBJ_W7SHzeloKFrnI_WLBks,82288
3
+ lifejacket/arg_threading_helpers.py,sha256=7HV7qkiJtm-E-cpAhtv4n_BpCVjz9tC0nGENbc090h8,17282
4
+ lifejacket/calculate_derivatives.py,sha256=SceXFWtK56uCCdXGD7v8JijgYz0UCBKzcnrPH_nAqNE,37536
5
+ lifejacket/constants.py,sha256=2L05p6NJ7l3qRZ1hD2KlrvzWF1ReSmWRUkULPIhdvJo,842
6
+ lifejacket/form_adaptive_meat_adjustments_directly.py,sha256=bSLrVYLZR1-Qlm5yIdktzv8ZQTVhHTlhVVL2wEjLTmw,13737
7
+ lifejacket/get_datum_for_blowup_supervised_learning.py,sha256=V8H4PE49dQwsKjj93QEu2BKbhwPr56QMtx2jhan39-c,58357
8
+ lifejacket/helper_functions.py,sha256=xOhRG-Cm4ZdRNm-O0faHna583d74pLWY5_jfnokegWc,23295
9
+ lifejacket/input_checks.py,sha256=KcDdfsdCVCKKcx07FfOKJb3KVX6xFuWwAufGJ3msAuc,46972
10
+ lifejacket/small_sample_corrections.py,sha256=f8jmg9U9ZN77WadJud70tt6NMxCTsSGtlsdF_-mfws4,5543
11
+ lifejacket/trial_conditioning_monitor.py,sha256=qNTHh0zY2P7zJxox_OwhEEK8Ed1l0TPOjGDqNxMNoIQ,42164
12
+ lifejacket/vmap_helpers.py,sha256=pZqYN3p9Ty9DPOeeY9TKbRJXR2AV__HBwwDFOvdOQ84,2688
13
+ lifejacket-0.2.1.dist-info/METADATA,sha256=vFb90EnjvF_CxGN2XKlk6b8s1iK-aqF0p4Wr0dEIKxA,7287
14
+ lifejacket-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
15
+ lifejacket-0.2.1.dist-info/entry_points.txt,sha256=4k8ibVIUT-OHxPaaDv-QwWpC64ErzhdemHpTAXCnb8w,67
16
+ lifejacket-0.2.1.dist-info/top_level.txt,sha256=vKl8m7jOQ4pkbzVuHCJsq-8LcXRrOAWnok3bBo9qpsE,11
17
+ lifejacket-0.2.1.dist-info/RECORD,,
@@ -1,17 +0,0 @@
1
- lifejacket/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- lifejacket/after_study_analysis.py,sha256=HtAHRYCni8Lh7HCB9A0SsRkyAol2FbczrkxFn1ww1HA,81170
3
- lifejacket/arg_threading_helpers.py,sha256=kxtGlN_B1C1WKcWtJMlgWJwzzI1THytflmWsL5ZML7k,16228
4
- lifejacket/calculate_derivatives.py,sha256=3rYukD1wbjDof7d4_3QdQ-A4GSK9H8z8HJsbNQh0DzA,37472
5
- lifejacket/constants.py,sha256=2L05p6NJ7l3qRZ1hD2KlrvzWF1ReSmWRUkULPIhdvJo,842
6
- lifejacket/form_adaptive_meat_adjustments_directly.py,sha256=_BaziGfYjEySN78nU3lCrVtf2KWIuZ8PmzfMZypAaWI,13728
7
- lifejacket/get_datum_for_blowup_supervised_learning.py,sha256=V8H4PE49dQwsKjj93QEu2BKbhwPr56QMtx2jhan39-c,58357
8
- lifejacket/helper_functions.py,sha256=xOhRG-Cm4ZdRNm-O0faHna583d74pLWY5_jfnokegWc,23295
9
- lifejacket/input_checks.py,sha256=A0f2owqRUjeBAh5jLULKu1nXW1SgZDR5eK7xBm1ahZw,44878
10
- lifejacket/small_sample_corrections.py,sha256=f8jmg9U9ZN77WadJud70tt6NMxCTsSGtlsdF_-mfws4,5543
11
- lifejacket/trial_conditioning_monitor.py,sha256=qNTHh0zY2P7zJxox_OwhEEK8Ed1l0TPOjGDqNxMNoIQ,42164
12
- lifejacket/vmap_helpers.py,sha256=pZqYN3p9Ty9DPOeeY9TKbRJXR2AV__HBwwDFOvdOQ84,2688
13
- lifejacket-0.1.0.dist-info/METADATA,sha256=VT6H9TNcYleRhp-wLda4HXrbI7EJYj6ZV0_7K5fraI4,7274
14
- lifejacket-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
15
- lifejacket-0.1.0.dist-info/entry_points.txt,sha256=4k8ibVIUT-OHxPaaDv-QwWpC64ErzhdemHpTAXCnb8w,67
16
- lifejacket-0.1.0.dist-info/top_level.txt,sha256=vKl8m7jOQ4pkbzVuHCJsq-8LcXRrOAWnok3bBo9qpsE,11
17
- lifejacket-0.1.0.dist-info/RECORD,,