lifejacket 0.2.1__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/arg_threading_helpers.py +75 -69
- lifejacket/calculate_derivatives.py +19 -23
- lifejacket/constants.py +4 -16
- lifejacket/{trial_conditioning_monitor.py → deployment_conditioning_monitor.py} +163 -138
- lifejacket/{form_adaptive_meat_adjustments_directly.py → form_adjusted_meat_adjustments_directly.py} +32 -34
- lifejacket/get_datum_for_blowup_supervised_learning.py +341 -339
- lifejacket/helper_functions.py +60 -186
- lifejacket/input_checks.py +303 -302
- lifejacket/{after_study_analysis.py → post_deployment_analysis.py} +470 -457
- lifejacket/small_sample_corrections.py +49 -49
- lifejacket-1.0.2.dist-info/METADATA +56 -0
- lifejacket-1.0.2.dist-info/RECORD +17 -0
- lifejacket-1.0.2.dist-info/entry_points.txt +2 -0
- lifejacket-0.2.1.dist-info/METADATA +0 -100
- lifejacket-0.2.1.dist-info/RECORD +0 -17
- lifejacket-0.2.1.dist-info/entry_points.txt +0 -2
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.2.dist-info}/WHEEL +0 -0
- {lifejacket-0.2.1.dist-info → lifejacket-1.0.2.dist-info}/top_level.txt +0 -0
|
@@ -16,7 +16,7 @@ from .helper_functions import (
|
|
|
16
16
|
calculate_beta_dim,
|
|
17
17
|
collect_all_post_update_betas,
|
|
18
18
|
construct_beta_index_by_policy_num_map,
|
|
19
|
-
|
|
19
|
+
extract_action_and_policy_by_decision_time_by_subject_id,
|
|
20
20
|
flatten_params,
|
|
21
21
|
get_min_time_by_policy_num,
|
|
22
22
|
get_radon_nikodym_weight,
|
|
@@ -34,7 +34,13 @@ logging.basicConfig(
|
|
|
34
34
|
)
|
|
35
35
|
|
|
36
36
|
|
|
37
|
-
class
|
|
37
|
+
class DeploymentConditioningMonitor:
|
|
38
|
+
"""
|
|
39
|
+
Experimental feature. Monitors the conditioning of the RL portion of the bread matrix.
|
|
40
|
+
Repeats more logic from post_deployment_analysis.py than is ideal, but this is for experimental use only.
|
|
41
|
+
Unit tests should be unskipped and expanded if this is to be used more broadly.
|
|
42
|
+
"""
|
|
43
|
+
|
|
38
44
|
whole_RL_block_conditioning_threshold = None
|
|
39
45
|
diagonal_RL_block_conditioning_threshold = None
|
|
40
46
|
|
|
@@ -54,7 +60,7 @@ class TrialConditioningMonitor:
|
|
|
54
60
|
def assess_update(
|
|
55
61
|
self,
|
|
56
62
|
proposed_policy_num: int | float,
|
|
57
|
-
|
|
63
|
+
analysis_df: pd.DataFrame,
|
|
58
64
|
action_prob_func: callable,
|
|
59
65
|
action_prob_func_args: dict,
|
|
60
66
|
action_prob_func_args_beta_index: int,
|
|
@@ -64,23 +70,24 @@ class TrialConditioningMonitor:
|
|
|
64
70
|
alg_update_func_args_beta_index: int,
|
|
65
71
|
alg_update_func_args_action_prob_index: int,
|
|
66
72
|
alg_update_func_args_action_prob_times_index: int,
|
|
67
|
-
|
|
73
|
+
alg_update_func_args_previous_betas_index: int,
|
|
74
|
+
active_col_name: str,
|
|
68
75
|
action_col_name: str,
|
|
69
76
|
policy_num_col_name: str,
|
|
70
77
|
calendar_t_col_name: str,
|
|
71
|
-
|
|
78
|
+
subject_id_col_name: str,
|
|
72
79
|
action_prob_col_name: str,
|
|
73
80
|
suppress_interactive_data_checks: bool,
|
|
74
81
|
suppress_all_data_checks: bool,
|
|
75
82
|
incremental: bool = True,
|
|
76
83
|
) -> None:
|
|
77
84
|
"""
|
|
78
|
-
Analyzes a dataset to estimate parameters and variance using
|
|
85
|
+
Analyzes a dataset to estimate parameters and variance using adjusted and classical sandwich estimators.
|
|
79
86
|
|
|
80
87
|
Parameters:
|
|
81
88
|
proposed_policy_num (int | float):
|
|
82
89
|
The policy number of the proposed update.
|
|
83
|
-
|
|
90
|
+
analysis_df (pd.DataFrame):
|
|
84
91
|
DataFrame containing the study data.
|
|
85
92
|
action_prob_func (str):
|
|
86
93
|
Action probability function.
|
|
@@ -100,16 +107,18 @@ class TrialConditioningMonitor:
|
|
|
100
107
|
Index for action probability in algorithm update function arguments.
|
|
101
108
|
alg_update_func_args_action_prob_times_index (int):
|
|
102
109
|
Index for action probability times in algorithm update function arguments.
|
|
103
|
-
|
|
104
|
-
|
|
110
|
+
alg_update_func_args_previous_betas_index (int):
|
|
111
|
+
Index for previous betas in algorithm update function arguments.
|
|
112
|
+
active_col_name (str):
|
|
113
|
+
Column name indicating if a subject is in the study in the study dataframe.
|
|
105
114
|
action_col_name (str):
|
|
106
115
|
Column name for actions in the study dataframe.
|
|
107
116
|
policy_num_col_name (str):
|
|
108
117
|
Column name for policy numbers in the study dataframe.
|
|
109
118
|
calendar_t_col_name (str):
|
|
110
119
|
Column name for calendar time in the study dataframe.
|
|
111
|
-
|
|
112
|
-
Column name for
|
|
120
|
+
subject_id_col_name (str):
|
|
121
|
+
Column name for subject IDs in the study dataframe.
|
|
113
122
|
action_prob_col_name (str):
|
|
114
123
|
Column name for action probabilities in the study dataframe.
|
|
115
124
|
reward_col_name (str):
|
|
@@ -121,13 +130,13 @@ class TrialConditioningMonitor:
|
|
|
121
130
|
small_sample_correction (str):
|
|
122
131
|
Type of small sample correction to apply.
|
|
123
132
|
collect_data_for_blowup_supervised_learning (bool):
|
|
124
|
-
Whether to collect data for doing supervised learning about
|
|
125
|
-
|
|
126
|
-
If True, explicitly forms the per-
|
|
133
|
+
Whether to collect data for doing supervised learning about adjusted sandwich blowup.
|
|
134
|
+
form_adjusted_meat_adjustments_explicitly (bool):
|
|
135
|
+
If True, explicitly forms the per-subject meat adjustments that differentiate the adjusted
|
|
127
136
|
sandwich from the classical sandwich. This is for diagnostic purposes, as the
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
If True, stabilizes the joint
|
|
137
|
+
adjusted sandwich is formed without doing this.
|
|
138
|
+
stabilize_joint_bread (bool):
|
|
139
|
+
If True, stabilizes the joint bread matrix if it does not meet conditioning
|
|
131
140
|
thresholds.
|
|
132
141
|
|
|
133
142
|
Returns:
|
|
@@ -141,11 +150,11 @@ class TrialConditioningMonitor:
|
|
|
141
150
|
|
|
142
151
|
if not suppress_all_data_checks:
|
|
143
152
|
input_checks.perform_alg_only_input_checks(
|
|
144
|
-
|
|
145
|
-
|
|
153
|
+
analysis_df,
|
|
154
|
+
active_col_name,
|
|
146
155
|
policy_num_col_name,
|
|
147
156
|
calendar_t_col_name,
|
|
148
|
-
|
|
157
|
+
subject_id_col_name,
|
|
149
158
|
action_prob_col_name,
|
|
150
159
|
action_prob_func,
|
|
151
160
|
action_prob_func_args,
|
|
@@ -154,12 +163,13 @@ class TrialConditioningMonitor:
|
|
|
154
163
|
alg_update_func_args_beta_index,
|
|
155
164
|
alg_update_func_args_action_prob_index,
|
|
156
165
|
alg_update_func_args_action_prob_times_index,
|
|
166
|
+
alg_update_func_args_previous_betas_index,
|
|
157
167
|
suppress_interactive_data_checks,
|
|
158
168
|
)
|
|
159
169
|
|
|
160
170
|
beta_index_by_policy_num, initial_policy_num = (
|
|
161
171
|
construct_beta_index_by_policy_num_map(
|
|
162
|
-
|
|
172
|
+
analysis_df, policy_num_col_name, active_col_name
|
|
163
173
|
)
|
|
164
174
|
)
|
|
165
175
|
# We augment the produced map to include the proposed policy num.
|
|
@@ -174,22 +184,23 @@ class TrialConditioningMonitor:
|
|
|
174
184
|
alg_update_func_args_beta_index,
|
|
175
185
|
)
|
|
176
186
|
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
187
|
+
(
|
|
188
|
+
action_by_decision_time_by_subject_id,
|
|
189
|
+
policy_num_by_decision_time_by_subject_id,
|
|
190
|
+
) = extract_action_and_policy_by_decision_time_by_subject_id(
|
|
191
|
+
analysis_df,
|
|
192
|
+
subject_id_col_name,
|
|
193
|
+
active_col_name,
|
|
194
|
+
calendar_t_col_name,
|
|
195
|
+
action_col_name,
|
|
196
|
+
policy_num_col_name,
|
|
186
197
|
)
|
|
187
198
|
|
|
188
|
-
|
|
199
|
+
subject_ids = jnp.array(analysis_df[subject_id_col_name].unique())
|
|
189
200
|
|
|
190
201
|
phi_dot_bar, avg_estimating_function_stack = self.construct_phi_dot_bar_so_far(
|
|
191
202
|
all_post_update_betas,
|
|
192
|
-
|
|
203
|
+
subject_ids,
|
|
193
204
|
action_prob_func,
|
|
194
205
|
action_prob_func_args_beta_index,
|
|
195
206
|
alg_update_func,
|
|
@@ -197,12 +208,13 @@ class TrialConditioningMonitor:
|
|
|
197
208
|
alg_update_func_args_beta_index,
|
|
198
209
|
alg_update_func_args_action_prob_index,
|
|
199
210
|
alg_update_func_args_action_prob_times_index,
|
|
211
|
+
alg_update_func_args_previous_betas_index,
|
|
200
212
|
action_prob_func_args,
|
|
201
|
-
|
|
213
|
+
policy_num_by_decision_time_by_subject_id,
|
|
202
214
|
initial_policy_num,
|
|
203
215
|
beta_index_by_policy_num,
|
|
204
216
|
alg_update_func_args,
|
|
205
|
-
|
|
217
|
+
action_by_decision_time_by_subject_id,
|
|
206
218
|
suppress_all_data_checks,
|
|
207
219
|
suppress_interactive_data_checks,
|
|
208
220
|
incremental=incremental,
|
|
@@ -225,7 +237,7 @@ class TrialConditioningMonitor:
|
|
|
225
237
|
|
|
226
238
|
if whole_RL_block_condition_number > self.whole_RL_block_conditioning_threshold:
|
|
227
239
|
logger.warning(
|
|
228
|
-
"The RL portion of the bread
|
|
240
|
+
"The RL portion of the bread up to this point exceeds the threshold set (condition number: %s, threshold: %s). Consider an alternative update strategy which produces less dependence on previous RL parameters (via the data they produced) and/or improves the conditioning of each update itself. Regularization may help with both of these.",
|
|
229
241
|
whole_RL_block_condition_number,
|
|
230
242
|
self.whole_RL_block_conditioning_threshold,
|
|
231
243
|
)
|
|
@@ -236,7 +248,7 @@ class TrialConditioningMonitor:
|
|
|
236
248
|
> self.diagonal_RL_block_conditioning_threshold
|
|
237
249
|
):
|
|
238
250
|
logger.warning(
|
|
239
|
-
"The diagonal RL block of the bread
|
|
251
|
+
"The diagonal RL block of the bread up to this point exceeds the threshold set (condition number: %s, threshold: %s). This may illustrate a fundamental problem with the conditioning of the RL update procedure.",
|
|
240
252
|
new_diagonal_RL_block_condition_number,
|
|
241
253
|
self.diagonal_RL_block_conditioning_threshold,
|
|
242
254
|
)
|
|
@@ -259,7 +271,7 @@ class TrialConditioningMonitor:
|
|
|
259
271
|
def construct_phi_dot_bar_so_far(
|
|
260
272
|
self,
|
|
261
273
|
all_post_update_betas: jnp.ndarray,
|
|
262
|
-
|
|
274
|
+
subject_ids: jnp.ndarray,
|
|
263
275
|
action_prob_func: callable,
|
|
264
276
|
action_prob_func_args_beta_index: int,
|
|
265
277
|
alg_update_func: callable,
|
|
@@ -267,18 +279,19 @@ class TrialConditioningMonitor:
|
|
|
267
279
|
alg_update_func_args_beta_index: int,
|
|
268
280
|
alg_update_func_args_action_prob_index: int,
|
|
269
281
|
alg_update_func_args_action_prob_times_index: int,
|
|
270
|
-
|
|
282
|
+
alg_update_func_args_previous_betas_index: int,
|
|
283
|
+
action_prob_func_args_by_subject_id_by_decision_time: dict[
|
|
271
284
|
collections.abc.Hashable, dict[int, tuple[Any, ...]]
|
|
272
285
|
],
|
|
273
|
-
|
|
286
|
+
policy_num_by_decision_time_by_subject_id: dict[
|
|
274
287
|
collections.abc.Hashable, dict[int, int | float]
|
|
275
288
|
],
|
|
276
289
|
initial_policy_num: int | float,
|
|
277
290
|
beta_index_by_policy_num: dict[int | float, int],
|
|
278
|
-
|
|
291
|
+
update_func_args_by_by_subject_id_by_policy_num: dict[
|
|
279
292
|
collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
|
|
280
293
|
],
|
|
281
|
-
|
|
294
|
+
action_by_decision_time_by_subject_id: dict[
|
|
282
295
|
collections.abc.Hashable, dict[int, int]
|
|
283
296
|
],
|
|
284
297
|
suppress_all_data_checks: bool,
|
|
@@ -289,18 +302,18 @@ class TrialConditioningMonitor:
|
|
|
289
302
|
jnp.ndarray[jnp.float32],
|
|
290
303
|
]:
|
|
291
304
|
"""
|
|
292
|
-
Constructs the classical and
|
|
305
|
+
Constructs the classical and bread and meat matrices, as well as the average
|
|
293
306
|
estimating function stack and some other intermediate pieces.
|
|
294
307
|
|
|
295
308
|
This is done by computing and differentiating the average weighted estimating function stack
|
|
296
|
-
with respect to the betas and theta, using the resulting Jacobian to compute the
|
|
309
|
+
with respect to the betas and theta, using the resulting Jacobian to compute the bread
|
|
297
310
|
and meat matrices, and then stably computing sandwiches.
|
|
298
311
|
|
|
299
312
|
Args:
|
|
300
313
|
all_post_update_betas (jnp.ndarray):
|
|
301
314
|
A 2-D JAX NumPy array representing all parameter estimates for the algorithm updates.
|
|
302
|
-
|
|
303
|
-
A 1-D JAX NumPy array holding all
|
|
315
|
+
subject_ids (jnp.ndarray):
|
|
316
|
+
A 1-D JAX NumPy array holding all subject IDs in the study.
|
|
304
317
|
action_prob_func (callable):
|
|
305
318
|
The action probability function.
|
|
306
319
|
action_prob_func_args_beta_index (int):
|
|
@@ -317,22 +330,24 @@ class TrialConditioningMonitor:
|
|
|
317
330
|
alg_update_func_args_action_prob_times_index (int):
|
|
318
331
|
The index in the update function arguments tuple where an array of times for which the
|
|
319
332
|
given action probabilities apply is provided, if applicable. -1 otherwise.
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
333
|
+
alg_update_func_args_previous_betas_index (int):
|
|
334
|
+
The index in the update function arguments tuple where the previous betas are provided, if applicable. -1 otherwise.
|
|
335
|
+
action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
|
|
336
|
+
A dictionary mapping decision times to maps of subject ids to the function arguments
|
|
337
|
+
required to compute action probabilities for this subject.
|
|
338
|
+
policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
|
|
339
|
+
A map of subject ids to dictionaries mapping decision times to the policy number in use.
|
|
325
340
|
Only applies to in-study decision times!
|
|
326
341
|
initial_policy_num (int | float):
|
|
327
342
|
The policy number of the initial policy before any updates.
|
|
328
343
|
beta_index_by_policy_num (dict[int | float, int]):
|
|
329
344
|
A dictionary mapping policy numbers to the index of the corresponding beta in
|
|
330
345
|
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
331
|
-
|
|
332
|
-
A dictionary where keys are policy numbers and values are dictionaries mapping
|
|
346
|
+
update_func_args_by_by_subject_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
|
|
347
|
+
A dictionary where keys are policy numbers and values are dictionaries mapping subject IDs
|
|
333
348
|
to their respective update function arguments.
|
|
334
|
-
|
|
335
|
-
A dictionary mapping
|
|
349
|
+
action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
|
|
350
|
+
A dictionary mapping subject IDs to their respective actions taken at each decision time.
|
|
336
351
|
Only applies to in-study decision times!
|
|
337
352
|
suppress_all_data_checks (bool):
|
|
338
353
|
If True, suppresses carrying out any data checks at all.
|
|
@@ -362,7 +377,7 @@ class TrialConditioningMonitor:
|
|
|
362
377
|
# here to improve performance. We can simply unflatten them inside the function.
|
|
363
378
|
flatten_params(all_post_update_betas, jnp.array([])),
|
|
364
379
|
beta_dim,
|
|
365
|
-
|
|
380
|
+
subject_ids,
|
|
366
381
|
action_prob_func,
|
|
367
382
|
action_prob_func_args_beta_index,
|
|
368
383
|
alg_update_func,
|
|
@@ -370,12 +385,13 @@ class TrialConditioningMonitor:
|
|
|
370
385
|
alg_update_func_args_beta_index,
|
|
371
386
|
alg_update_func_args_action_prob_index,
|
|
372
387
|
alg_update_func_args_action_prob_times_index,
|
|
373
|
-
|
|
374
|
-
|
|
388
|
+
alg_update_func_args_previous_betas_index,
|
|
389
|
+
action_prob_func_args_by_subject_id_by_decision_time,
|
|
390
|
+
policy_num_by_decision_time_by_subject_id,
|
|
375
391
|
initial_policy_num,
|
|
376
392
|
beta_index_by_policy_num,
|
|
377
|
-
|
|
378
|
-
|
|
393
|
+
update_func_args_by_by_subject_id_by_policy_num,
|
|
394
|
+
action_by_decision_time_by_subject_id,
|
|
379
395
|
suppress_all_data_checks,
|
|
380
396
|
suppress_interactive_data_checks,
|
|
381
397
|
only_latest_block=True,
|
|
@@ -404,7 +420,7 @@ class TrialConditioningMonitor:
|
|
|
404
420
|
# here to improve performance. We can simply unflatten them inside the function.
|
|
405
421
|
flatten_params(all_post_update_betas, jnp.array([])),
|
|
406
422
|
beta_dim,
|
|
407
|
-
|
|
423
|
+
subject_ids,
|
|
408
424
|
action_prob_func,
|
|
409
425
|
action_prob_func_args_beta_index,
|
|
410
426
|
alg_update_func,
|
|
@@ -412,12 +428,13 @@ class TrialConditioningMonitor:
|
|
|
412
428
|
alg_update_func_args_beta_index,
|
|
413
429
|
alg_update_func_args_action_prob_index,
|
|
414
430
|
alg_update_func_args_action_prob_times_index,
|
|
415
|
-
|
|
416
|
-
|
|
431
|
+
alg_update_func_args_previous_betas_index,
|
|
432
|
+
action_prob_func_args_by_subject_id_by_decision_time,
|
|
433
|
+
policy_num_by_decision_time_by_subject_id,
|
|
417
434
|
initial_policy_num,
|
|
418
435
|
beta_index_by_policy_num,
|
|
419
|
-
|
|
420
|
-
|
|
436
|
+
update_func_args_by_by_subject_id_by_policy_num,
|
|
437
|
+
action_by_decision_time_by_subject_id,
|
|
421
438
|
suppress_all_data_checks,
|
|
422
439
|
suppress_interactive_data_checks,
|
|
423
440
|
)
|
|
@@ -429,7 +446,7 @@ class TrialConditioningMonitor:
|
|
|
429
446
|
self,
|
|
430
447
|
flattened_betas_and_theta: jnp.ndarray,
|
|
431
448
|
beta_dim: int,
|
|
432
|
-
|
|
449
|
+
subject_ids: jnp.ndarray,
|
|
433
450
|
action_prob_func: callable,
|
|
434
451
|
action_prob_func_args_beta_index: int,
|
|
435
452
|
alg_update_func: callable,
|
|
@@ -437,18 +454,19 @@ class TrialConditioningMonitor:
|
|
|
437
454
|
alg_update_func_args_beta_index: int,
|
|
438
455
|
alg_update_func_args_action_prob_index: int,
|
|
439
456
|
alg_update_func_args_action_prob_times_index: int,
|
|
440
|
-
|
|
457
|
+
alg_update_func_args_previous_betas_index: int,
|
|
458
|
+
action_prob_func_args_by_subject_id_by_decision_time: dict[
|
|
441
459
|
collections.abc.Hashable, dict[int, tuple[Any, ...]]
|
|
442
460
|
],
|
|
443
|
-
|
|
461
|
+
policy_num_by_decision_time_by_subject_id: dict[
|
|
444
462
|
collections.abc.Hashable, dict[int, int | float]
|
|
445
463
|
],
|
|
446
464
|
initial_policy_num: int | float,
|
|
447
465
|
beta_index_by_policy_num: dict[int | float, int],
|
|
448
|
-
|
|
466
|
+
update_func_args_by_by_subject_id_by_policy_num: dict[
|
|
449
467
|
collections.abc.Hashable, dict[int | float, tuple[Any, ...]]
|
|
450
468
|
],
|
|
451
|
-
|
|
469
|
+
action_by_decision_time_by_subject_id: dict[
|
|
452
470
|
collections.abc.Hashable, dict[int, int]
|
|
453
471
|
],
|
|
454
472
|
suppress_all_data_checks: bool,
|
|
@@ -459,8 +477,8 @@ class TrialConditioningMonitor:
|
|
|
459
477
|
tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray],
|
|
460
478
|
]:
|
|
461
479
|
"""
|
|
462
|
-
Computes the average weighted estimating function stack across all
|
|
463
|
-
auxiliary values used to construct the
|
|
480
|
+
Computes the average weighted estimating function stack across all subjects, along with
|
|
481
|
+
auxiliary values used to construct the adjusted and classical sandwich variances.
|
|
464
482
|
|
|
465
483
|
If only_latest_block is True, only uses data from the most recent update.
|
|
466
484
|
|
|
@@ -471,8 +489,8 @@ class TrialConditioningMonitor:
|
|
|
471
489
|
We simply extract the betas and theta from this array below.
|
|
472
490
|
beta_dim (int):
|
|
473
491
|
The dimension of each of the beta parameters.
|
|
474
|
-
|
|
475
|
-
A 1D JAX NumPy array of
|
|
492
|
+
subject_ids (jnp.ndarray):
|
|
493
|
+
A 1D JAX NumPy array of subject IDs.
|
|
476
494
|
action_prob_func (callable):
|
|
477
495
|
The action probability function.
|
|
478
496
|
action_prob_func_args_beta_index (int):
|
|
@@ -489,22 +507,24 @@ class TrialConditioningMonitor:
|
|
|
489
507
|
alg_update_func_args_action_prob_times_index (int):
|
|
490
508
|
The index in the update function arguments tuple where an array of times for which the
|
|
491
509
|
given action probabilities apply is provided, if applicable. -1 otherwise.
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
510
|
+
alg_update_func_args_previous_betas_index (int):
|
|
511
|
+
The index in the update function arguments tuple where the previous betas are provided, if applicable. -1 otherwise.
|
|
512
|
+
action_prob_func_args_by_subject_id_by_decision_time (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
|
|
513
|
+
A dictionary mapping decision times to maps of subject ids to the function arguments
|
|
514
|
+
required to compute action probabilities for this subject.
|
|
515
|
+
policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
|
|
516
|
+
A map of subject ids to dictionaries mapping decision times to the policy number in use.
|
|
497
517
|
Only applies to in-study decision times!
|
|
498
518
|
initial_policy_num (int | float):
|
|
499
519
|
The policy number of the initial policy before any updates.
|
|
500
520
|
beta_index_by_policy_num (dict[int | float, int]):
|
|
501
521
|
A dictionary mapping policy numbers to the index of the corresponding beta in
|
|
502
522
|
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
503
|
-
|
|
504
|
-
A dictionary where keys are policy numbers and values are dictionaries mapping
|
|
523
|
+
update_func_args_by_by_subject_id_by_policy_num (dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]):
|
|
524
|
+
A dictionary where keys are policy numbers and values are dictionaries mapping subject IDs
|
|
505
525
|
to their respective update function arguments.
|
|
506
|
-
|
|
507
|
-
A dictionary mapping
|
|
526
|
+
action_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int]]):
|
|
527
|
+
A dictionary mapping subject IDs to their respective actions taken at each decision time.
|
|
508
528
|
Only applies to in-study decision times!
|
|
509
529
|
suppress_all_data_checks (bool):
|
|
510
530
|
If True, suppresses carrying out any data checks at all.
|
|
@@ -536,15 +556,15 @@ class TrialConditioningMonitor:
|
|
|
536
556
|
# 1. If only_latest_block is True, we need to filter all the arguments to only
|
|
537
557
|
# include those relevant to the latest update. We still need action probabilities
|
|
538
558
|
# from the beginning for the weights, but the update function args can be trimmed
|
|
539
|
-
# to the max policy so that the loop
|
|
559
|
+
# to the max policy so that the loop single_subject_weighted_RL_estimating_function_stacker
|
|
540
560
|
# is only over one policy.
|
|
541
561
|
if only_latest_block:
|
|
542
562
|
logger.info(
|
|
543
563
|
"Filtering algorithm update function arguments to only include those relevant to the latest update."
|
|
544
564
|
)
|
|
545
565
|
max_policy_num = max(beta_index_by_policy_num)
|
|
546
|
-
|
|
547
|
-
max_policy_num:
|
|
566
|
+
update_func_args_by_by_subject_id_by_policy_num = {
|
|
567
|
+
max_policy_num: update_func_args_by_by_subject_id_by_policy_num[
|
|
548
568
|
max_policy_num
|
|
549
569
|
]
|
|
550
570
|
}
|
|
@@ -553,15 +573,17 @@ class TrialConditioningMonitor:
|
|
|
553
573
|
# supplied for the above functions, so that differentiation works correctly. The existing
|
|
554
574
|
# values should be the same, but not connected to the parameter we are differentiating
|
|
555
575
|
# with respect to. Note we will also find it useful below to have the action probability args
|
|
556
|
-
# nested dict structure flipped to be
|
|
576
|
+
# nested dict structure flipped to be subject_id -> decision_time -> args, so we do that here too.
|
|
557
577
|
|
|
558
|
-
logger.info(
|
|
578
|
+
logger.info(
|
|
579
|
+
"Threading in betas to action probability arguments for all subjects."
|
|
580
|
+
)
|
|
559
581
|
(
|
|
560
|
-
|
|
561
|
-
|
|
582
|
+
threaded_action_prob_func_args_by_decision_time_by_subject_id,
|
|
583
|
+
action_prob_func_args_by_decision_time_by_subject_id,
|
|
562
584
|
) = thread_action_prob_func_args(
|
|
563
|
-
|
|
564
|
-
|
|
585
|
+
action_prob_func_args_by_subject_id_by_decision_time,
|
|
586
|
+
policy_num_by_decision_time_by_subject_id,
|
|
565
587
|
initial_policy_num,
|
|
566
588
|
betas,
|
|
567
589
|
beta_index_by_policy_num,
|
|
@@ -573,16 +595,17 @@ class TrialConditioningMonitor:
|
|
|
573
595
|
# arguments with the central betas introduced.
|
|
574
596
|
logger.info(
|
|
575
597
|
"Threading in betas and beta-dependent action probabilities to algorithm update "
|
|
576
|
-
"function args for all
|
|
598
|
+
"function args for all subjects"
|
|
577
599
|
)
|
|
578
|
-
|
|
579
|
-
|
|
600
|
+
threaded_update_func_args_by_policy_num_by_subject_id = thread_update_func_args(
|
|
601
|
+
update_func_args_by_by_subject_id_by_policy_num,
|
|
580
602
|
betas,
|
|
581
603
|
beta_index_by_policy_num,
|
|
582
604
|
alg_update_func_args_beta_index,
|
|
583
605
|
alg_update_func_args_action_prob_index,
|
|
584
606
|
alg_update_func_args_action_prob_times_index,
|
|
585
|
-
|
|
607
|
+
alg_update_func_args_previous_betas_index,
|
|
608
|
+
threaded_action_prob_func_args_by_decision_time_by_subject_id,
|
|
586
609
|
action_prob_func,
|
|
587
610
|
)
|
|
588
611
|
|
|
@@ -592,42 +615,44 @@ class TrialConditioningMonitor:
|
|
|
592
615
|
if not suppress_all_data_checks and alg_update_func_args_action_prob_index >= 0:
|
|
593
616
|
input_checks.require_threaded_algorithm_estimating_function_args_equivalent(
|
|
594
617
|
algorithm_estimating_func,
|
|
595
|
-
|
|
596
|
-
|
|
618
|
+
update_func_args_by_by_subject_id_by_policy_num,
|
|
619
|
+
threaded_update_func_args_by_policy_num_by_subject_id,
|
|
597
620
|
suppress_interactive_data_checks,
|
|
598
621
|
)
|
|
599
622
|
|
|
600
|
-
# 5. Now we can compute the weighted estimating function stacks for all
|
|
601
|
-
# as well as collect related values used to construct the
|
|
623
|
+
# 5. Now we can compute the weighted estimating function stacks for all subjects
|
|
624
|
+
# as well as collect related values used to construct the adjusted and classical
|
|
602
625
|
# sandwich variances.
|
|
603
626
|
RL_stacks = jnp.array(
|
|
604
627
|
[
|
|
605
|
-
self.
|
|
628
|
+
self.single_subject_weighted_RL_estimating_function_stacker(
|
|
606
629
|
beta_dim,
|
|
607
|
-
|
|
630
|
+
subject_id,
|
|
608
631
|
action_prob_func,
|
|
609
632
|
algorithm_estimating_func,
|
|
610
633
|
action_prob_func_args_beta_index,
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
634
|
+
action_prob_func_args_by_decision_time_by_subject_id[subject_id],
|
|
635
|
+
threaded_action_prob_func_args_by_decision_time_by_subject_id[
|
|
636
|
+
subject_id
|
|
637
|
+
],
|
|
638
|
+
threaded_update_func_args_by_policy_num_by_subject_id[subject_id],
|
|
639
|
+
policy_num_by_decision_time_by_subject_id[subject_id],
|
|
640
|
+
action_by_decision_time_by_subject_id[subject_id],
|
|
616
641
|
beta_index_by_policy_num,
|
|
617
642
|
)
|
|
618
|
-
for
|
|
643
|
+
for subject_id in subject_ids.tolist()
|
|
619
644
|
]
|
|
620
645
|
)
|
|
621
646
|
|
|
622
647
|
# 6. We will differentiate the first output, while the second will be used
|
|
623
648
|
# for an estimating function sum check.
|
|
624
|
-
|
|
625
|
-
return
|
|
649
|
+
mean_stack_across_subjects = jnp.mean(RL_stacks, axis=0)
|
|
650
|
+
return mean_stack_across_subjects, mean_stack_across_subjects
|
|
626
651
|
|
|
627
|
-
def
|
|
652
|
+
def single_subject_weighted_RL_estimating_function_stacker(
|
|
628
653
|
self,
|
|
629
654
|
beta_dim: int,
|
|
630
|
-
|
|
655
|
+
subject_id: collections.abc.Hashable,
|
|
631
656
|
action_prob_func: callable,
|
|
632
657
|
algorithm_estimating_func: callable,
|
|
633
658
|
action_prob_func_args_beta_index: int,
|
|
@@ -660,12 +685,12 @@ class TrialConditioningMonitor:
|
|
|
660
685
|
beta_dim (list[jnp.ndarray]):
|
|
661
686
|
A list of 1D JAX NumPy arrays corresponding to the betas produced by all updates.
|
|
662
687
|
|
|
663
|
-
|
|
664
|
-
The
|
|
688
|
+
subject_id (collections.abc.Hashable):
|
|
689
|
+
The subject ID for which to compute the weighted estimating function stack.
|
|
665
690
|
|
|
666
691
|
action_prob_func (callable):
|
|
667
692
|
The function used to compute the probability of action 1 at a given decision time for
|
|
668
|
-
a particular
|
|
693
|
+
a particular subject given their state and the algorithm parameters.
|
|
669
694
|
|
|
670
695
|
algorithm_estimating_func (callable):
|
|
671
696
|
The estimating function that corresponds to algorithm updates.
|
|
@@ -674,7 +699,7 @@ class TrialConditioningMonitor:
|
|
|
674
699
|
The index of the beta argument in the action probability function's arguments.
|
|
675
700
|
|
|
676
701
|
action_prob_func_args_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
|
|
677
|
-
A map from decision times to tuples of arguments for this
|
|
702
|
+
A map from decision times to tuples of arguments for this subject for the action
|
|
678
703
|
probability function. This is for all decision times (args are an empty
|
|
679
704
|
tuple if they are not in the study). Should be sorted by decision time. NOTE THAT THESE
|
|
680
705
|
ARGS DO NOT CONTAIN THE SHARED BETAS, making them impervious to the differentiation that
|
|
@@ -687,13 +712,13 @@ class TrialConditioningMonitor:
|
|
|
687
712
|
|
|
688
713
|
threaded_update_func_args_by_policy_num (dict[int | float, dict[collections.abc.Hashable, tuple[Any, ...]]]):
|
|
689
714
|
A map from policy numbers to tuples containing the arguments for
|
|
690
|
-
the corresponding estimating functions for this
|
|
715
|
+
the corresponding estimating functions for this subject, with the shared betas threaded in
|
|
691
716
|
for differentiation. This is for all non-initial, non-fallback policies. Policy numbers
|
|
692
717
|
should be sorted.
|
|
693
718
|
|
|
694
719
|
policy_num_by_decision_time (dict[collections.abc.Hashable, dict[int, int | float]]):
|
|
695
720
|
A dictionary mapping decision times to the policy number in use. This may be
|
|
696
|
-
|
|
721
|
+
subject-specific. Should be sorted by decision time. Only applies to in-study decision
|
|
697
722
|
times!
|
|
698
723
|
|
|
699
724
|
action_by_decision_time (dict[collections.abc.Hashable, dict[int, int]]):
|
|
@@ -705,18 +730,18 @@ class TrialConditioningMonitor:
|
|
|
705
730
|
all_post_update_betas. Note that this is only for non-initial, non-fallback policies.
|
|
706
731
|
|
|
707
732
|
Returns:
|
|
708
|
-
jnp.ndarray: A 1-D JAX NumPy array representing the RL portion of the
|
|
733
|
+
jnp.ndarray: A 1-D JAX NumPy array representing the RL portion of the subject's weighted
|
|
709
734
|
estimating function stack.
|
|
710
735
|
"""
|
|
711
736
|
|
|
712
737
|
logger.info(
|
|
713
|
-
"Computing weighted estimating function stack for
|
|
738
|
+
"Computing weighted estimating function stack for subject %s.", subject_id
|
|
714
739
|
)
|
|
715
740
|
|
|
716
741
|
# First, reformat the supplied data into more convenient structures.
|
|
717
742
|
|
|
718
743
|
# 1. Form a dictionary mapping policy numbers to the first time they were
|
|
719
|
-
# applicable (for this
|
|
744
|
+
# applicable (for this subject). Note that this includes ALL policies, initial
|
|
720
745
|
# fallbacks included.
|
|
721
746
|
# Collect the first time after the first update separately for convenience.
|
|
722
747
|
# These are both used to form the Radon-Nikodym weights for the right times.
|
|
@@ -727,17 +752,17 @@ class TrialConditioningMonitor:
|
|
|
727
752
|
)
|
|
728
753
|
)
|
|
729
754
|
|
|
730
|
-
# 2. Get the start and end times for this
|
|
731
|
-
|
|
732
|
-
|
|
755
|
+
# 2. Get the start and end times for this subject.
|
|
756
|
+
subject_start_time = math.inf
|
|
757
|
+
subject_end_time = -math.inf
|
|
733
758
|
for decision_time in action_by_decision_time:
|
|
734
|
-
|
|
735
|
-
|
|
759
|
+
subject_start_time = min(subject_start_time, decision_time)
|
|
760
|
+
subject_end_time = max(subject_end_time, decision_time)
|
|
736
761
|
|
|
737
762
|
# 3. Form a stack of weighted estimating equations, one for each update of the algorithm.
|
|
738
763
|
logger.info(
|
|
739
|
-
"Computing the algorithm component of the weighted estimating function stack for
|
|
740
|
-
|
|
764
|
+
"Computing the algorithm component of the weighted estimating function stack for subject %s.",
|
|
765
|
+
subject_id,
|
|
741
766
|
)
|
|
742
767
|
|
|
743
768
|
in_study_action_prob_func_args = [
|
|
@@ -754,13 +779,13 @@ class TrialConditioningMonitor:
|
|
|
754
779
|
)
|
|
755
780
|
|
|
756
781
|
# Sort the threaded args by decision time to be cautious. We check if the
|
|
757
|
-
#
|
|
758
|
-
# subset of the
|
|
782
|
+
# subject id is present in the subject args dict because we may call this on a
|
|
783
|
+
# subset of the subject arg dict when we are batching arguments by shape
|
|
759
784
|
sorted_threaded_action_prob_args_by_decision_time = {
|
|
760
785
|
decision_time: threaded_action_prob_func_args_by_decision_time[
|
|
761
786
|
decision_time
|
|
762
787
|
]
|
|
763
|
-
for decision_time in range(
|
|
788
|
+
for decision_time in range(subject_start_time, subject_end_time + 1)
|
|
764
789
|
if decision_time in threaded_action_prob_func_args_by_decision_time
|
|
765
790
|
}
|
|
766
791
|
|
|
@@ -820,27 +845,27 @@ class TrialConditioningMonitor:
|
|
|
820
845
|
[
|
|
821
846
|
# Here we compute a product of Radon-Nikodym weights
|
|
822
847
|
# for all decision times after the first update and before the update
|
|
823
|
-
# update under consideration took effect, for which the
|
|
848
|
+
# under consideration took effect, for which the subject was in the study.
|
|
824
849
|
(
|
|
825
850
|
jnp.prod(
|
|
826
851
|
all_weights[
|
|
827
|
-
# The earliest time after the first update where the
|
|
852
|
+
# The earliest time after the first update where the subject was in
|
|
828
853
|
# the study
|
|
829
854
|
max(
|
|
830
855
|
first_time_after_first_update,
|
|
831
|
-
|
|
856
|
+
subject_start_time,
|
|
832
857
|
)
|
|
833
858
|
- decision_time_to_all_weights_index_offset :
|
|
834
|
-
# One more than the latest time the
|
|
859
|
+
# One more than the latest time the subject was in the study before the time
|
|
835
860
|
# the update under consideration first applied. Note the + 1 because range
|
|
836
861
|
# does not include the right endpoint.
|
|
837
862
|
min(
|
|
838
863
|
min_time_by_policy_num.get(policy_num, math.inf),
|
|
839
|
-
|
|
864
|
+
subject_end_time + 1,
|
|
840
865
|
)
|
|
841
866
|
- decision_time_to_all_weights_index_offset,
|
|
842
867
|
]
|
|
843
|
-
# If the
|
|
868
|
+
# If the subject exited the study before there were any updates,
|
|
844
869
|
# this variable will be None and the above code to grab a weight would
|
|
845
870
|
# throw an error. Just use 1 to include the unweighted estimating function
|
|
846
871
|
# if they have data to contribute to the update.
|
|
@@ -848,8 +873,8 @@ class TrialConditioningMonitor:
|
|
|
848
873
|
else 1
|
|
849
874
|
) # Now use the above to weight the alg estimating function for this update
|
|
850
875
|
* algorithm_estimating_func(*update_args)
|
|
851
|
-
# If there are no arguments for the update function, the
|
|
852
|
-
# study, so we just add a zero vector contribution to the sum across
|
|
876
|
+
# If there are no arguments for the update function, the subject is not yet in the
|
|
877
|
+
# study, so we just add a zero vector contribution to the sum across subjects.
|
|
853
878
|
# Note that after they exit, they still contribute all their data to later
|
|
854
879
|
# updates.
|
|
855
880
|
if update_args
|