lifejacket-0.2.1-py3-none-any.whl → lifejacket-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/after_study_analysis.py +397 -387
- lifejacket/arg_threading_helpers.py +75 -69
- lifejacket/calculate_derivatives.py +19 -21
- lifejacket/{trial_conditioning_monitor.py → deployment_conditioning_monitor.py} +146 -128
- lifejacket/{form_adaptive_meat_adjustments_directly.py → form_adjusted_meat_adjustments_directly.py} +7 -7
- lifejacket/get_datum_for_blowup_supervised_learning.py +315 -307
- lifejacket/helper_functions.py +45 -38
- lifejacket/input_checks.py +263 -261
- lifejacket/small_sample_corrections.py +42 -40
- lifejacket-1.0.0.dist-info/METADATA +56 -0
- lifejacket-1.0.0.dist-info/RECORD +17 -0
- lifejacket-0.2.1.dist-info/METADATA +0 -100
- lifejacket-0.2.1.dist-info/RECORD +0 -17
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.0.dist-info}/WHEEL +0 -0
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.0.dist-info}/entry_points.txt +0 -0
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.0.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,7 @@ from .helper_functions import (
     calculate_beta_dim,
     collect_all_post_update_betas,
     construct_beta_index_by_policy_num_map,
-
+    extract_action_and_policy_by_decision_time_by_subject_id,
     flatten_params,
     get_min_time_by_policy_num,
     get_radon_nikodym_weight,

@@ -34,7 +34,7 @@ logging.basicConfig(
 )


-class TrialConditioningMonitor:
+class DeploymentConditioningMonitor:
     whole_RL_block_conditioning_threshold = None
     diagonal_RL_block_conditioning_threshold = None


@@ -54,7 +54,7 @@ class TrialConditioningMonitor:
     def assess_update(
         self,
         proposed_policy_num: int | float,
-
+        analysis_df: pd.DataFrame,
         action_prob_func: callable,
         action_prob_func_args: dict,
         action_prob_func_args_beta_index: int,

@@ -64,11 +64,12 @@ class TrialConditioningMonitor:
         alg_update_func_args_beta_index: int,
         alg_update_func_args_action_prob_index: int,
         alg_update_func_args_action_prob_times_index: int,
-
+        alg_update_func_args_previous_betas_index: int,
+        active_col_name: str,
         action_col_name: str,
         policy_num_col_name: str,
         calendar_t_col_name: str,
-
+        subject_id_col_name: str,
         action_prob_col_name: str,
         suppress_interactive_data_checks: bool,
         suppress_all_data_checks: bool,

@@ -80,7 +81,7 @@ class TrialConditioningMonitor:
         Parameters:
             proposed_policy_num (int | float):
                 The policy number of the proposed update.
-
+            analysis_df (pd.DataFrame):
                 DataFrame containing the study data.
             action_prob_func (str):
                 Action probability function.

@@ -100,16 +101,18 @@ class TrialConditioningMonitor:
                 Index for action probability in algorithm update function arguments.
             alg_update_func_args_action_prob_times_index (int):
                 Index for action probability times in algorithm update function arguments.
-
-
+            alg_update_func_args_previous_betas_index (int):
+                Index for previous betas in algorithm update function arguments.
+            active_col_name (str):
+                Column name indicating if a subject is in the study in the study dataframe.
             action_col_name (str):
                 Column name for actions in the study dataframe.
             policy_num_col_name (str):
                 Column name for policy numbers in the study dataframe.
             calendar_t_col_name (str):
                 Column name for calendar time in the study dataframe.
-
-                Column name for
+            subject_id_col_name (str):
+                Column name for subject IDs in the study dataframe.
             action_prob_col_name (str):
                 Column name for action probabilities in the study dataframe.
             reward_col_name (str):
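For orientation only: the 1.0.0 signature adds active_col_name and subject_id_col_name alongside the existing column-name parameters, so the monitor can locate per-subject, in-study rows in the caller's DataFrame. Below is a minimal sketch of what such an analysis_df might look like; the column labels are hypothetical, only the *_col_name parameter names come from the diff.

```python
# Illustrative sketch, not from the package: a toy DataFrame whose columns the
# caller would pass via the *_col_name parameters of assess_update.
import pandas as pd

toy_analysis_df = pd.DataFrame(
    {
        "subject_id": [1, 1, 2, 2],            # passed as subject_id_col_name
        "calendar_t": [0, 1, 0, 1],            # passed as calendar_t_col_name
        "active": [1, 1, 1, 0],                # passed as active_col_name (in study?)
        "action": [1, 0, 1, 0],                # passed as action_col_name
        "action_prob": [0.6, 0.4, 0.5, 0.5],   # passed as action_prob_col_name
        "policy_num": [0.0, 1.0, 0.0, 1.0],    # passed as policy_num_col_name
        "reward": [1.2, 0.3, 0.8, 0.0],        # passed as reward_col_name
    }
)
```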
@@ -122,11 +125,11 @@ class TrialConditioningMonitor:
                 Type of small sample correction to apply.
             collect_data_for_blowup_supervised_learning (bool):
                 Whether to collect data for doing supervised learning about adaptive sandwich blowup.
-
-                If True, explicitly forms the per-
+            form_adjusted_meat_adjustments_explicitly (bool):
+                If True, explicitly forms the per-subject meat adjustments that differentiate the adaptive
                 sandwich from the classical sandwich. This is for diagnostic purposes, as the
                 adaptive sandwich is formed without doing this.
-
+            stabilize_joint_adjusted_bread_inverse (bool):
                 If True, stabilizes the joint adaptive bread inverse matrix if it does not meet conditioning
                 thresholds.


@@ -141,11 +144,11 @@ class TrialConditioningMonitor:

         if not suppress_all_data_checks:
             input_checks.perform_alg_only_input_checks(
-
-
+                analysis_df,
+                active_col_name,
                 policy_num_col_name,
                 calendar_t_col_name,
-
+                subject_id_col_name,
                 action_prob_col_name,
                 action_prob_func,
                 action_prob_func_args,

@@ -159,7 +162,7 @@ class TrialConditioningMonitor:

         beta_index_by_policy_num, initial_policy_num = (
             construct_beta_index_by_policy_num_map(
-
+                analysis_df, policy_num_col_name, active_col_name
             )
         )
         # We augment the produced map to include the proposed policy num.

@@ -174,22 +177,23 @@ class TrialConditioningMonitor:
             alg_update_func_args_beta_index,
         )

-
-
-
-
-
-
-
-
-
+        (
+            action_by_decision_time_by_subject_id,
+            policy_num_by_decision_time_by_subject_id,
+        ) = extract_action_and_policy_by_decision_time_by_subject_id(
+            analysis_df,
+            subject_id_col_name,
+            active_col_name,
+            calendar_t_col_name,
+            action_col_name,
+            policy_num_col_name,
         )

-
+        subject_ids = jnp.array(analysis_df[subject_id_col_name].unique())

         phi_dot_bar, avg_estimating_function_stack = self.construct_phi_dot_bar_so_far(
             all_post_update_betas,
-
+            subject_ids,
             action_prob_func,
             action_prob_func_args_beta_index,
             alg_update_func,

@@ -197,12 +201,13 @@ class TrialConditioningMonitor:
             alg_update_func_args_beta_index,
             alg_update_func_args_action_prob_index,
             alg_update_func_args_action_prob_times_index,
+            alg_update_func_args_previous_betas_index,
             action_prob_func_args,
-
+            policy_num_by_decision_time_by_subject_id,
             initial_policy_num,
             beta_index_by_policy_num,
             alg_update_func_args,
-
+            action_by_decision_time_by_subject_id,
             suppress_all_data_checks,
             suppress_interactive_data_checks,
             incremental=incremental,
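As a rough sketch of the data shapes involved here (not package output): per the docstrings later in this diff, the new extraction helper yields two per-subject nested dictionaries, and subject_ids collects the unique IDs as a JAX array, as in the added line above. Subject IDs and values below are invented for illustration.

```python
# Hypothetical shapes, inferred from the docstrings in this diff.
action_by_decision_time_by_subject_id = {
    1: {0: 1, 1: 0},      # subject_id -> {decision_time: action taken}
    2: {0: 1},            # subject 2 was only in the study at time 0
}
policy_num_by_decision_time_by_subject_id = {
    1: {0: 0.0, 1: 1.0},  # subject_id -> {decision_time: policy number in use}
    2: {0: 0.0},
}

# Mirrors the added line in the hunk above: unique subject IDs as a JAX array.
import jax.numpy as jnp
import pandas as pd

analysis_df = pd.DataFrame({"subject_id": [1, 1, 2, 2]})
subject_ids = jnp.array(analysis_df["subject_id"].unique())
```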
@@ -259,7 +264,7 @@ class TrialConditioningMonitor:
     def construct_phi_dot_bar_so_far(
         self,
         all_post_update_betas: jnp.ndarray,
-
+        subject_ids: jnp.ndarray,
         action_prob_func: callable,
         action_prob_func_args_beta_index: int,
         alg_update_func: callable,

@@ -267,18 +272,19 @@ class TrialConditioningMonitor:
         alg_update_func_args_beta_index: int,
         alg_update_func_args_action_prob_index: int,
         alg_update_func_args_action_prob_times_index: int,
-
+        alg_update_func_args_previous_betas_index: int,
+        action_prob_func_args_by_subject_id_by_decision_time: dict[
             collections.abc.Hashable, dict[int, tuple[Any, ...]]
         ],
-
+        policy_num_by_decision_time_by_subject_id: dict[
             collections.abc.Hashable, dict[int, int | float]
         ],
         initial_policy_num: int | float,
         beta_index_by_policy_num: dict[int | float, int],
-
+        update_func_args_by_by_subject_id_by_policy_num: dict[
             collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
         ],
-
+        action_by_decision_time_by_subject_id: dict[
             collections.abc.Hashable, dict[int, int]
         ],
         suppress_all_data_checks: bool,

@@ -299,8 +305,8 @@ class TrialConditioningMonitor:
         Args:
             all_post_update_betas (jnp.ndarray):
                 A 2-D JAX NumPy array representing all parameter estimates for the algorithm updates.
-
-                A 1-D JAX NumPy array holding all
+            subject_ids (jnp.ndarray):
+                A 1-D JAX NumPy array holding all subject IDs in the study.
             action_prob_func (callable):
                 The action probability function.
             action_prob_func_args_beta_index (int):

@@ -317,22 +323,24 @@ class TrialConditioningMonitor:
             alg_update_func_args_action_prob_times_index (int):
                 The index in the update function arguments tuple where an array of times for which the
                 given action probabilities apply is provided, if applicable. -1 otherwise.
-
-
-
-
-
+            alg_update_func_args_previous_betas_index (int):
+                The index in the update function arguments tuple where the previous betas are provided, if applicable. -1 otherwise.
+            action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
+                A dictionary mapping decision times to maps of subject ids to the function arguments
+                required to compute action probabilities for this subject.
+            policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
+                A map of subject ids to dictionaries mapping decision times to the policy number in use.
                 Only applies to in-study decision times!
             initial_policy_num (int | float):
                 The policy number of the initial policy before any updates.
             beta_index_by_policy_num (dict[int | float, int]):
                 A dictionary mapping policy numbers to the index of the corresponding beta in
                 all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
-
-                A dictionary where keys are policy numbers and values are dictionaries mapping
+            update_func_args_by_by_subject_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
+                A dictionary where keys are policy numbers and values are dictionaries mapping subject IDs
                 to their respective update function arguments.
-
-                A dictionary mapping
+            action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
+                A dictionary mapping subject IDs to their respective actions taken at each decision time.
                 Only applies to in-study decision times!
             suppress_all_data_checks (bool):
                 If True, suppresses carrying out any data checks at all.
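A small orientation sketch of the nested argument dictionaries documented above; only the outer nesting follows the docstrings, while the keys and the contents of the argument tuples are hypothetical and depend entirely on the caller's action probability and update functions.

```python
# Purely illustrative shapes; not produced by the package.
action_prob_func_args_by_subject_id_by_decision_time = {
    0: {1: ("state_s1_t0",), 2: ("state_s2_t0",)},  # decision_time -> {subject_id: args tuple}
    1: {1: ("state_s1_t1",)},
}
update_func_args_by_by_subject_id_by_policy_num = {
    1.0: {1: ("history_s1",), 2: ("history_s2",)},  # policy_num -> {subject_id: update args tuple}
}
```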
@@ -362,7 +370,7 @@ class TrialConditioningMonitor:
             # here to improve performance. We can simply unflatten them inside the function.
             flatten_params(all_post_update_betas, jnp.array([])),
             beta_dim,
-
+            subject_ids,
             action_prob_func,
             action_prob_func_args_beta_index,
             alg_update_func,

@@ -370,12 +378,13 @@ class TrialConditioningMonitor:
             alg_update_func_args_beta_index,
             alg_update_func_args_action_prob_index,
             alg_update_func_args_action_prob_times_index,
-
-
+            alg_update_func_args_previous_betas_index,
+            action_prob_func_args_by_subject_id_by_decision_time,
+            policy_num_by_decision_time_by_subject_id,
             initial_policy_num,
             beta_index_by_policy_num,
-
-
+            update_func_args_by_by_subject_id_by_policy_num,
+            action_by_decision_time_by_subject_id,
             suppress_all_data_checks,
             suppress_interactive_data_checks,
             only_latest_block=True,

@@ -404,7 +413,7 @@ class TrialConditioningMonitor:
             # here to improve performance. We can simply unflatten them inside the function.
             flatten_params(all_post_update_betas, jnp.array([])),
             beta_dim,
-
+            subject_ids,
             action_prob_func,
             action_prob_func_args_beta_index,
             alg_update_func,

@@ -412,12 +421,13 @@ class TrialConditioningMonitor:
             alg_update_func_args_beta_index,
             alg_update_func_args_action_prob_index,
             alg_update_func_args_action_prob_times_index,
-
-
+            alg_update_func_args_previous_betas_index,
+            action_prob_func_args_by_subject_id_by_decision_time,
+            policy_num_by_decision_time_by_subject_id,
             initial_policy_num,
             beta_index_by_policy_num,
-
-
+            update_func_args_by_by_subject_id_by_policy_num,
+            action_by_decision_time_by_subject_id,
             suppress_all_data_checks,
             suppress_interactive_data_checks,
         )

@@ -429,7 +439,7 @@ class TrialConditioningMonitor:
         self,
         flattened_betas_and_theta: jnp.ndarray,
         beta_dim: int,
-
+        subject_ids: jnp.ndarray,
         action_prob_func: callable,
         action_prob_func_args_beta_index: int,
         alg_update_func: callable,

@@ -437,18 +447,19 @@ class TrialConditioningMonitor:
         alg_update_func_args_beta_index: int,
         alg_update_func_args_action_prob_index: int,
         alg_update_func_args_action_prob_times_index: int,
-
+        alg_update_func_args_previous_betas_index: int,
+        action_prob_func_args_by_subject_id_by_decision_time: dict[
             collections.abc.Hashable, dict[int, tuple[Any, ...]]
         ],
-
+        policy_num_by_decision_time_by_subject_id: dict[
             collections.abc.Hashable, dict[int, int | float]
         ],
         initial_policy_num: int | float,
         beta_index_by_policy_num: dict[int | float, int],
-
+        update_func_args_by_by_subject_id_by_policy_num: dict[
             collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
         ],
-
+        action_by_decision_time_by_subject_id: dict[
             collections.abc.Hashable, dict[int, int]
         ],
         suppress_all_data_checks: bool,

@@ -459,7 +470,7 @@ class TrialConditioningMonitor:
         tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray],
     ]:
         """
-        Computes the average weighted estimating function stack across all
+        Computes the average weighted estimating function stack across all subjects, along with
         auxiliary values used to construct the adaptive and classical sandwich variances.

         If only_latest_block is True, only uses data from the most recent update.

@@ -471,8 +482,8 @@ class TrialConditioningMonitor:
                 We simply extract the betas and theta from this array below.
             beta_dim (int):
                 The dimension of each of the beta parameters.
-
-                A 1D JAX NumPy array of
+            subject_ids (jnp.ndarray):
+                A 1D JAX NumPy array of subject IDs.
             action_prob_func (callable):
                 The action probability function.
             action_prob_func_args_beta_index (int):

@@ -489,22 +500,24 @@ class TrialConditioningMonitor:
             alg_update_func_args_action_prob_times_index (int):
                 The index in the update function arguments tuple where an array of times for which the
                 given action probabilities apply is provided, if applicable. -1 otherwise.
-
-
-
-
-
+            alg_update_func_args_previous_betas_index (int):
+                The index in the update function arguments tuple where the previous betas are provided, if applicable. -1 otherwise.
+            action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
+                A dictionary mapping decision times to maps of subject ids to the function arguments
+                required to compute action probabilities for this subject.
+            policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
+                A map of subject ids to dictionaries mapping decision times to the policy number in use.
                 Only applies to in-study decision times!
             initial_policy_num (int | float):
                 The policy number of the initial policy before any updates.
             beta_index_by_policy_num (dict[int | float, int]):
                 A dictionary mapping policy numbers to the index of the corresponding beta in
                 all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
-
-                A dictionary where keys are policy numbers and values are dictionaries mapping
+            update_func_args_by_by_subject_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
+                A dictionary where keys are policy numbers and values are dictionaries mapping subject IDs
                 to their respective update function arguments.
-
-                A dictionary mapping
+            action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
+                A dictionary mapping subject IDs to their respective actions taken at each decision time.
                 Only applies to in-study decision times!
             suppress_all_data_checks (bool):
                 If True, suppresses carrying out any data checks at all.

@@ -536,15 +549,15 @@ class TrialConditioningMonitor:
         # 1. If only_latest_block is True, we need to filter all the arguments to only
         # include those relevant to the latest update. We still need action probabilities
         # from the beginning for the weights, but the update function args can be trimmed
-        # to the max policy so that the loop
+        # to the max policy so that the loop single_subject_weighted_RL_estimating_function_stacker
         # is only over one policy.
         if only_latest_block:
             logger.info(
                 "Filtering algorithm update function arguments to only include those relevant to the latest update."
             )
             max_policy_num = max(beta_index_by_policy_num)
-
-                max_policy_num:
+            update_func_args_by_by_subject_id_by_policy_num = {
+                max_policy_num: update_func_args_by_by_subject_id_by_policy_num[
                     max_policy_num
                 ]
             }

@@ -553,15 +566,17 @@ class TrialConditioningMonitor:
         # supplied for the above functions, so that differentiation works correctly. The existing
         # values should be the same, but not connected to the parameter we are differentiating
        # with respect to. Note we will also find it useful below to have the action probability args
-        # nested dict structure flipped to be
+        # nested dict structure flipped to be subject_id -> decision_time -> args, so we do that here too.

-        logger.info(
+        logger.info(
+            "Threading in betas to action probability arguments for all subjects."
+        )
         (
-
-
+            threaded_action_prob_func_args_by_decision_time_by_subject_id,
+            action_prob_func_args_by_decision_time_by_subject_id,
         ) = thread_action_prob_func_args(
-
-
+            action_prob_func_args_by_subject_id_by_decision_time,
+            policy_num_by_decision_time_by_subject_id,
             initial_policy_num,
             betas,
             beta_index_by_policy_num,
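The comment in the hunk above mentions flipping the nesting from decision_time -> subject_id -> args over to subject_id -> decision_time -> args. The threading helper presumably handles that internally; as a generic reference (not the package's code), a nested-dict flip of that kind can be written as:

```python
def flip_nesting(value_by_inner_by_outer):
    """Turn {outer: {inner: value}} into {inner: {outer: value}}."""
    flipped = {}
    for outer_key, inner_map in value_by_inner_by_outer.items():
        for inner_key, value in inner_map.items():
            flipped.setdefault(inner_key, {})[outer_key] = value
    return flipped

# e.g. decision_time -> subject_id -> args  becomes  subject_id -> decision_time -> args
by_subject = flip_nesting({0: {1: ("a",), 2: ("b",)}, 1: {1: ("c",)}})
# {1: {0: ("a",), 1: ("c",)}, 2: {0: ("b",)}}
```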
@@ -573,16 +588,17 @@ class TrialConditioningMonitor:
         # arguments with the central betas introduced.
         logger.info(
             "Threading in betas and beta-dependent action probabilities to algorithm update "
-            "function args for all
+            "function args for all subjects"
         )
-
-
+        threaded_update_func_args_by_policy_num_by_subject_id = thread_update_func_args(
+            update_func_args_by_by_subject_id_by_policy_num,
             betas,
             beta_index_by_policy_num,
             alg_update_func_args_beta_index,
             alg_update_func_args_action_prob_index,
             alg_update_func_args_action_prob_times_index,
-
+            alg_update_func_args_previous_betas_index,
+            threaded_action_prob_func_args_by_decision_time_by_subject_id,
             action_prob_func,
         )

@@ -592,42 +608,44 @@ class TrialConditioningMonitor:
         if not suppress_all_data_checks and alg_update_func_args_action_prob_index >= 0:
             input_checks.require_threaded_algorithm_estimating_function_args_equivalent(
                 algorithm_estimating_func,
-
-
+                update_func_args_by_by_subject_id_by_policy_num,
+                threaded_update_func_args_by_policy_num_by_subject_id,
                 suppress_interactive_data_checks,
             )

-        # 5. Now we can compute the weighted estimating function stacks for all
+        # 5. Now we can compute the weighted estimating function stacks for all subjects
         # as well as collect related values used to construct the adaptive and classical
         # sandwich variances.
         RL_stacks = jnp.array(
             [
-                self.
+                self.single_subject_weighted_RL_estimating_function_stacker(
                     beta_dim,
-
+                    subject_id,
                     action_prob_func,
                     algorithm_estimating_func,
                     action_prob_func_args_beta_index,
-
-
-
-
-
+                    action_prob_func_args_by_decision_time_by_subject_id[subject_id],
+                    threaded_action_prob_func_args_by_decision_time_by_subject_id[
+                        subject_id
+                    ],
+                    threaded_update_func_args_by_policy_num_by_subject_id[subject_id],
+                    policy_num_by_decision_time_by_subject_id[subject_id],
+                    action_by_decision_time_by_subject_id[subject_id],
                     beta_index_by_policy_num,
                 )
-                for
+                for subject_id in subject_ids.tolist()
             ]
         )

         # 6. We will differentiate the first output, while the second will be used
         # for an estimating function sum check.
-
-        return
+        mean_stack_across_subjects = jnp.mean(RL_stacks, axis=0)
+        return mean_stack_across_subjects, mean_stack_across_subjects

-    def
+    def single_subject_weighted_RL_estimating_function_stacker(
         self,
         beta_dim: int,
-
+        subject_id: collections.abc.Hashable,
         action_prob_func: callable,
         algorithm_estimating_func: callable,
         action_prob_func_args_beta_index: int,
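The loop added above builds one weighted estimating-function stack per subject and then averages across subjects with jnp.mean(..., axis=0). A stripped-down sketch of that pattern follows; the per-subject stacker here is a stand-in, not the package's implementation.

```python
import jax.numpy as jnp

def toy_single_subject_stack(subject_id):
    # Stand-in for single_subject_weighted_RL_estimating_function_stacker:
    # returns a 1-D stack of weighted estimating-function values for one subject.
    return jnp.array([1.0, 2.0]) * subject_id

subject_ids = jnp.array([1, 2, 3])
RL_stacks = jnp.array([toy_single_subject_stack(s) for s in subject_ids.tolist()])
mean_stack_across_subjects = jnp.mean(RL_stacks, axis=0)  # averaged over the subject axis
```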
@@ -660,12 +678,12 @@ class TrialConditioningMonitor:
             beta_dim (list[jnp.ndarray]):
                 A list of 1D JAX NumPy arrays corresponding to the betas produced by all updates.

-
-                The
+            subject_id (collections.abc.Hashable):
+                The subject ID for which to compute the weighted estimating function stack.

             action_prob_func (callable):
                 The function used to compute the probability of action 1 at a given decision time for
-                a particular
+                a particular subject given their state and the algorithm parameters.

             algorithm_estimating_func (callable):
                 The estimating function that corresponds to algorithm updates.

@@ -674,7 +692,7 @@ class TrialConditioningMonitor:
                 The index of the beta argument in the action probability function's arguments.

             action_prob_func_args_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
-                A map from decision times to tuples of arguments for this
+                A map from decision times to tuples of arguments for this subject for the action
                 probability function. This is for all decision times (args are an empty
                 tuple if they are not in the study). Should be sorted by decision time. NOTE THAT THESE
                 ARGS DO NOT CONTAIN THE SHARED BETAS, making them impervious to the differentiation that

@@ -687,13 +705,13 @@ class TrialConditioningMonitor:

             threaded_update_func_args_by_policy_num (dict[int | float, dict[collections.abc.Hashable, tuple[Any, ...]]]):
                 A map from policy numbers to tuples containing the arguments for
-                the corresponding estimating functions for this
+                the corresponding estimating functions for this subject, with the shared betas threaded in
                 for differentiation. This is for all non-initial, non-fallback policies. Policy numbers
                 should be sorted.

             policy_num_by_decision_time (dict[collections.abc.Hashable, dict[int, int | float]]):
                 A dictionary mapping decision times to the policy number in use. This may be
-
+                subject-specific. Should be sorted by decision time. Only applies to in-study decision
                 times!

             action_by_decision_time (dict[collections.abc.Hashable, dict[int, int]]):

@@ -705,18 +723,18 @@ class TrialConditioningMonitor:
                 all_post_update_betas. Note that this is only for non-initial, non-fallback policies.

         Returns:
-            jnp.ndarray: A 1-D JAX NumPy array representing the RL portion of the
+            jnp.ndarray: A 1-D JAX NumPy array representing the RL portion of the subject's weighted
             estimating function stack.
         """

         logger.info(
-            "Computing weighted estimating function stack for
+            "Computing weighted estimating function stack for subject %s.", subject_id
         )

         # First, reformat the supplied data into more convenient structures.

         # 1. Form a dictionary mapping policy numbers to the first time they were
-        # applicable (for this
+        # applicable (for this subject). Note that this includes ALL policies, initial
         # fallbacks included.
         # Collect the first time after the first update separately for convenience.
         # These are both used to form the Radon-Nikodym weights for the right times.

@@ -727,17 +745,17 @@ class TrialConditioningMonitor:
             )
         )

-        # 2. Get the start and end times for this
-
-
+        # 2. Get the start and end times for this subject.
+        subject_start_time = math.inf
+        subject_end_time = -math.inf
         for decision_time in action_by_decision_time:
-
-
+            subject_start_time = min(subject_start_time, decision_time)
+            subject_end_time = max(subject_end_time, decision_time)

         # 3. Form a stack of weighted estimating equations, one for each update of the algorithm.
         logger.info(
-            "Computing the algorithm component of the weighted estimating function stack for
-
+            "Computing the algorithm component of the weighted estimating function stack for subject %s.",
+            subject_id,
         )

         in_study_action_prob_func_args = [

@@ -754,13 +772,13 @@ class TrialConditioningMonitor:
         )

         # Sort the threaded args by decision time to be cautious. We check if the
-        #
-        # subset of the
+        # subject id is present in the subject args dict because we may call this on a
+        # subset of the subject arg dict when we are batching arguments by shape
         sorted_threaded_action_prob_args_by_decision_time = {
             decision_time: threaded_action_prob_func_args_by_decision_time[
                 decision_time
             ]
-            for decision_time in range(
+            for decision_time in range(subject_start_time, subject_end_time + 1)
             if decision_time in threaded_action_prob_func_args_by_decision_time
         }

@@ -820,27 +838,27 @@ class TrialConditioningMonitor:
             [
                 # Here we compute a product of Radon-Nikodym weights
                 # for all decision times after the first update and before the update
-                # update under consideration took effect, for which the
+                # update under consideration took effect, for which the subject was in the study.
                 (
                     jnp.prod(
                         all_weights[
-                            # The earliest time after the first update where the
+                            # The earliest time after the first update where the subject was in
                             # the study
                             max(
                                 first_time_after_first_update,
-
+                                subject_start_time,
                             )
                             - decision_time_to_all_weights_index_offset :
-                            # One more than the latest time the
+                            # One more than the latest time the subject was in the study before the time
                             # the update under consideration first applied. Note the + 1 because range
                             # does not include the right endpoint.
                             min(
                                 min_time_by_policy_num.get(policy_num, math.inf),
-
+                                subject_end_time + 1,
                             )
                             - decision_time_to_all_weights_index_offset,
                         ]
-                    # If the
+                    # If the subject exited the study before there were any updates,
                     # this variable will be None and the above code to grab a weight would
                     # throw an error. Just use 1 to include the unweighted estimating function
                     # if they have data to contribute to the update.
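The weight in the hunk above is a product of per-decision-time Radon-Nikodym weights over the window where the subject was in the study and before the update under consideration took effect, taken as jnp.prod over a slice of all_weights. A simplified sketch of that windowed product follows; variable names mirror the diff, but the values and the index offset are made up for illustration.

```python
import math
import jax.numpy as jnp

all_weights = jnp.array([1.1, 0.9, 1.05, 0.95])  # one weight per post-first-update decision time
decision_time_to_all_weights_index_offset = 10   # decision times 10..13 map to indices 0..3

first_time_after_first_update = 10
subject_start_time, subject_end_time = 11, 13
min_time_by_policy_num = {2.0: 12}               # first time policy 2.0 applied
policy_num = 2.0

# Window: from when the subject was in the study after the first update,
# up to (but not including) the time the update under consideration took effect.
start = max(first_time_after_first_update, subject_start_time) - decision_time_to_all_weights_index_offset
stop = int(min(min_time_by_policy_num.get(policy_num, math.inf), subject_end_time + 1)) - decision_time_to_all_weights_index_offset
weight = jnp.prod(all_weights[start:stop])       # here: just the weight for decision time 11
```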
@@ -848,8 +866,8 @@ class TrialConditioningMonitor:
                     else 1
                 )  # Now use the above to weight the alg estimating function for this update
                 * algorithm_estimating_func(*update_args)
-                # If there are no arguments for the update function, the
-                # study, so we just add a zero vector contribution to the sum across
+                # If there are no arguments for the update function, the subject is not yet in the
+                # study, so we just add a zero vector contribution to the sum across subjects.
                 # Note that after they exit, they still contribute all their data to later
                 # updates.
                 if update_args