lifejacket 0.2.0-py3-none-any.whl → 1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/after_study_analysis.py +401 -387
- lifejacket/arg_threading_helpers.py +75 -69
- lifejacket/calculate_derivatives.py +19 -21
- lifejacket/{trial_conditioning_monitor.py → deployment_conditioning_monitor.py} +146 -128
- lifejacket/{form_adaptive_meat_adjustments_directly.py → form_adjusted_meat_adjustments_directly.py} +7 -7
- lifejacket/get_datum_for_blowup_supervised_learning.py +315 -307
- lifejacket/helper_functions.py +45 -38
- lifejacket/input_checks.py +263 -261
- lifejacket/small_sample_corrections.py +42 -40
- lifejacket-1.0.0.dist-info/METADATA +56 -0
- lifejacket-1.0.0.dist-info/RECORD +17 -0
- lifejacket-0.2.0.dist-info/METADATA +0 -100
- lifejacket-0.2.0.dist-info/RECORD +0 -17
- {lifejacket-0.2.0.dist-info → lifejacket-1.0.0.dist-info}/WHEEL +0 -0
- {lifejacket-0.2.0.dist-info → lifejacket-1.0.0.dist-info}/entry_points.txt +0 -0
- {lifejacket-0.2.0.dist-info → lifejacket-1.0.0.dist-info}/top_level.txt +0 -0
lifejacket/arg_threading_helpers.py
@@ -20,10 +20,10 @@ def replace_tuple_index(tupl, index, value):
 
 
 def thread_action_prob_func_args(
-
+    action_prob_func_args_by_subject_id_by_decision_time: dict[
         int, dict[collections.abc.Hashable, tuple[Any, ...]]
     ],
-
+    policy_num_by_decision_time_by_subject_id: dict[
         collections.abc.Hashable, dict[int, int | float]
     ],
     initial_policy_num: int | float,
@@ -39,12 +39,12 @@ def thread_action_prob_func_args(
         decision time to enable correct differentiation.
 
     Args:
-
+        action_prob_func_args_by_subject_id_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
             A map from decision times to maps of user ids to tuples of arguments for action
             probability function. This is for all decision times for all users (args are an empty
             tuple if they are not in the study). Should be sorted by decision time.
 
-
+        policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
             A dictionary mapping decision times to the policy number in use. This may be user-specific.
             Should be sorted by decision time.
 
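The two Args entries above use opposite key orders: the action-probability args are keyed decision time → subject id, while the policy numbers are indexed subject id → decision time in the loop body that follows. Below is a minimal sketch of toy inputs matching the annotated types; the subject ids, decision times, and argument layouts are invented for illustration and are not taken from the package.

    import jax.numpy as jnp

    # Outer key: decision time; inner key: subject id; value: args tuple
    # (an empty tuple marks a subject not in the study at that time).
    action_prob_func_args_by_subject_id_by_decision_time = {
        1: {"subj_a": (jnp.zeros(2), jnp.array([1.0, 0.5])), "subj_b": ()},
        2: {"subj_a": (jnp.zeros(2), jnp.array([0.2, 0.1])),
            "subj_b": (jnp.zeros(2), jnp.array([1.0, 1.0]))},
    }

    # Outer key: subject id; inner key: decision time; value: policy number,
    # matching how the implementation indexes it: [subject_id][decision_time].
    policy_num_by_decision_time_by_subject_id = {
        "subj_a": {1: 0, 2: 1},
        "subj_b": {1: 0, 2: 1},
    }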
@@ -69,56 +69,58 @@ def thread_action_prob_func_args(
         A map from user ids to maps of decision times to action probability function
         arguments tuples with the shared betas threaded in. Note the key order switch.
     """
-
+    threaded_action_prob_func_args_by_decision_time_by_subject_id = (
         collections.defaultdict(dict)
     )
-
+    action_prob_func_args_by_decision_time_by_subject_id = collections.defaultdict(dict)
     for (
         decision_time,
-
-    ) in
-    for
+        action_prob_func_args_by_subject_id,
+    ) in action_prob_func_args_by_subject_id_by_decision_time.items():
+        for subject_id, args in action_prob_func_args_by_subject_id.items():
             # Always add a contribution to the reversed key order dictionary.
-
+            action_prob_func_args_by_decision_time_by_subject_id[subject_id][
                 decision_time
             ] = args
 
             # Now proceed with the threading, if necessary.
             if not args:
-
-
-                ] = ()
+                threaded_action_prob_func_args_by_decision_time_by_subject_id[
+                    subject_id
+                ][decision_time] = ()
                 continue
 
-            policy_num =
+            policy_num = policy_num_by_decision_time_by_subject_id[subject_id][
+                decision_time
+            ]
 
             # The expectation is that fallback policies have empty args, and the only other
             # policy not represented in beta_index_by_policy_num is the initial policy.
             if policy_num == initial_policy_num:
-
-
-                ] =
+                threaded_action_prob_func_args_by_decision_time_by_subject_id[
+                    subject_id
+                ][decision_time] = action_prob_func_args_by_subject_id[subject_id]
                 continue
 
             beta_to_introduce = all_post_update_betas[
                 beta_index_by_policy_num[policy_num]
             ]
-
+            threaded_action_prob_func_args_by_decision_time_by_subject_id[subject_id][
                 decision_time
             ] = replace_tuple_index(
-
+                action_prob_func_args_by_subject_id[subject_id],
                 action_prob_func_args_beta_index,
                 beta_to_introduce,
             )
 
     return (
-
-
+        threaded_action_prob_func_args_by_decision_time_by_subject_id,
+        action_prob_func_args_by_decision_time_by_subject_id,
     )
 
 
 def thread_update_func_args(
-
+    update_func_args_by_by_subject_id_by_policy_num: dict[
         int | float, dict[collections.abc.Hashable, tuple[Any, ...]]
     ],
     all_post_update_betas: jnp.ndarray,
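The threading above relies on replace_tuple_index, whose body is outside this diff; the hunk header shows only its signature. Here is a plausible one-line version consistent with how it is called, plus a toy call that threads a shared beta into an args tuple; the tuple layout and beta index are invented for illustration.

    import jax.numpy as jnp

    def replace_tuple_index(tupl, index, value):
        # Return a copy of `tupl` with the element at `index` swapped for `value`.
        return tupl[:index] + (value,) + tupl[index + 1 :]

    # Hypothetical layout: the beta vector sits at index 0 of the args tuple.
    action_prob_func_args_beta_index = 0
    args = (jnp.zeros(2), jnp.array([1.0, 0.5]))   # (beta placeholder, features)
    beta_to_introduce = jnp.array([0.3, -0.1])     # shared post-update beta
    threaded_args = replace_tuple_index(
        args, action_prob_func_args_beta_index, beta_to_introduce
    )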
@@ -127,7 +129,7 @@ def thread_update_func_args(
     alg_update_func_args_action_prob_index: int,
     alg_update_func_args_action_prob_times_index: int,
     alg_update_func_args_previous_betas_index: int,
-
+    threaded_action_prob_func_args_by_decision_time_by_subject_id: dict[
         collections.abc.Hashable, dict[int, tuple[Any, ...]]
     ],
     action_prob_func: callable,
@@ -139,7 +141,7 @@ def thread_update_func_args(
         with reconstructed action probabilities computed using the shared betas.
 
     Args:
-
+        update_func_args_by_by_subject_id_by_policy_num (dict[int | float, dict[collections.abc.Hashable, tuple[Any, ...]]]):
             A dictionary where keys are policy
             numbers and values are dictionaries mapping user IDs to their respective update function
             arguments.
@@ -170,7 +172,7 @@ def thread_update_func_args(
         alg_update_func_args_previous_betas_index (int):
             The index in the update function with previous beta parameters
 
-
+        threaded_action_prob_func_args_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
             A dictionary mapping decision times to the function arguments required to compute action
             probabilities for this user, and with the shared betas thread in.
 
@@ -183,49 +185,51 @@ def thread_update_func_args(
         arguments tuples for the specified user with the shared betas threaded in. Note the key
         order switch relative to the supplied args!
     """
-
+    threaded_update_func_args_by_policy_num_by_subject_id = collections.defaultdict(
+        dict
+    )
     for (
         policy_num,
-
-    ) in
-    for
+        update_func_args_by_subject_id,
+    ) in update_func_args_by_by_subject_id_by_policy_num.items():
+        for subject_id, args in update_func_args_by_subject_id.items():
             if not args:
-
+                threaded_update_func_args_by_policy_num_by_subject_id[subject_id][
                     policy_num
                 ] = ()
                 continue
 
             logger.debug(
                 "Threading in shared betas to update function arguments for user %s and policy number %s.",
-
+                subject_id,
                 policy_num,
             )
 
             beta_to_introduce = all_post_update_betas[
                 beta_index_by_policy_num[policy_num]
             ]
-
-
-
-
-
-
+            threaded_update_func_args_by_policy_num_by_subject_id[subject_id][
+                policy_num
+            ] = replace_tuple_index(
+                update_func_args_by_subject_id[subject_id],
+                alg_update_func_args_beta_index,
+                beta_to_introduce,
             )
             if alg_update_func_args_previous_betas_index >= 0:
                 previous_betas_to_introduce = all_post_update_betas[
                     : len(
-
+                        update_func_args_by_subject_id[subject_id][
                             alg_update_func_args_previous_betas_index
                         ]
                     )
                 ]
                 if previous_betas_to_introduce.size > 0:
-
+                    threaded_update_func_args_by_policy_num_by_subject_id[subject_id][
                         policy_num
                     ] = replace_tuple_index(
-
-
-                        ],
+                        threaded_update_func_args_by_policy_num_by_subject_id[
+                            subject_id
+                        ][policy_num],
                         alg_update_func_args_previous_betas_index,
                         previous_betas_to_introduce,
                     )
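In the previous-betas branch just above, the slice of all_post_update_betas keeps as many leading shared betas as the caller originally supplied in the previous-betas slot, and replace_tuple_index then swaps that slice into the already-threaded args. A small sketch of the slice with invented shapes (one beta vector per completed update):

    import jax.numpy as jnp

    # Hypothetical: three completed updates, each producing a 2-dimensional beta.
    all_post_update_betas = jnp.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])

    # The update args for this subject originally carried two previous betas,
    # so only the first two shared betas are threaded back in.
    previous_betas_supplied = jnp.zeros((2, 2))
    previous_betas_to_introduce = all_post_update_betas[: len(previous_betas_supplied)]
    assert previous_betas_to_introduce.shape == (2, 2)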
@@ -234,20 +238,20 @@ def thread_update_func_args(
                 logger.debug(
                     "Action probabilities are used in the algorithm update function. Reconstructing them using the shared betas."
                 )
-                action_prob_times =
+                action_prob_times = update_func_args_by_subject_id[subject_id][
                     alg_update_func_args_action_prob_times_index
                 ]
                 # Vectorized computation of action_probs_to_introduce using jax.vmap
                 flattened_times = action_prob_times.flatten()
                 args_list = [
-
-
-                ]
+                    threaded_action_prob_func_args_by_decision_time_by_subject_id[
+                        subject_id
+                    ][int(t)]
                     for t in flattened_times.tolist()
                 ]
                 if len(args_list) == 0:
                     action_probs_to_introduce = jnp.array([]).reshape(
-
+                        update_func_args_by_subject_id[subject_id][
                             alg_update_func_args_action_prob_index
                         ].shape
                     )
@@ -264,31 +268,31 @@ def thread_update_func_args(
                         action_prob_func, in_axes=tuple(0 for _ in batched_tensors)
                     )
                     action_probs_to_introduce = vmapped_func(*batched_tensors).reshape(
-
+                        update_func_args_by_subject_id[subject_id][
                             alg_update_func_args_action_prob_index
                         ].shape
                     )
-
+                threaded_update_func_args_by_policy_num_by_subject_id[subject_id][
                     policy_num
                 ] = replace_tuple_index(
-
+                    threaded_update_func_args_by_policy_num_by_subject_id[subject_id][
                         policy_num
                     ],
                     alg_update_func_args_action_prob_index,
                     action_probs_to_introduce,
                 )
-    return
+    return threaded_update_func_args_by_policy_num_by_subject_id
 
 
 def thread_inference_func_args(
-
+    inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
     inference_func_args_theta_index: int,
     theta: jnp.ndarray,
     inference_func_args_action_prob_index: int,
-
+    threaded_action_prob_func_args_by_decision_time_by_subject_id: dict[
         collections.abc.Hashable, dict[int, tuple[Any, ...]]
     ],
-
+    inference_action_prob_decision_times_by_subject_id: dict[
        collections.abc.Hashable, list[int]
     ],
     action_prob_func: callable,
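The two hunks above reconstruct action probabilities in batch: the per-decision-time threaded args are looked up by int(t), each argument position is stacked, and jax.vmap evaluates the action-probability function once over the whole batch. A self-contained sketch of that pattern; the action-probability function, the toy threaded-args dict, and the way batched_tensors is built are assumptions for illustration, since the stacking step is outside the shown hunks.

    import jax
    import jax.numpy as jnp

    # Hypothetical action-probability function: sigmoid of a linear score.
    def action_prob_func(beta, features):
        return jax.nn.sigmoid(jnp.dot(features, beta))

    # Toy threaded args for one subject, keyed by decision time.
    threaded_args_by_decision_time = {
        1: (jnp.array([0.3, -0.1]), jnp.array([1.0, 0.5])),
        2: (jnp.array([0.3, -0.1]), jnp.array([0.2, 2.0])),
    }
    flattened_times = jnp.array([1, 2])

    args_list = [threaded_args_by_decision_time[int(t)] for t in flattened_times.tolist()]
    # Stack each positional argument along axis 0 so vmap can batch over decision times.
    batched_tensors = tuple(jnp.stack(position) for position in zip(*args_list))
    vmapped_func = jax.vmap(action_prob_func, in_axes=tuple(0 for _ in batched_tensors))
    action_probs = vmapped_func(*batched_tensors)  # one probability per decision time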
@@ -300,7 +304,7 @@ def thread_inference_func_args(
         probabilities computed using the shared betas.
 
     Args:
-
+        inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
             A dictionary mapping user IDs to their respective inference function arguments.
 
         inference_func_args_theta_index (int):
@@ -315,11 +319,11 @@ def thread_inference_func_args(
             tuple where new beta-threaded action probabilities should be inserted, if applicable.
             -1 otherwise.
 
-
+        threaded_action_prob_func_args_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
             A dictionary mapping decision times to the function arguments required to compute action
             probabilities for this user, and with the shared betas thread in.
 
-
+        inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
             For each user, a list of decision times to which action probabilities correspond if
             provided. Typically just in-study times if action probabilites are used in the inference
             loss or estimating function.
@@ -332,9 +336,9 @@ def thread_inference_func_args(
         threaded in.
     """
 
-
-    for
-
+    threaded_inference_func_args_by_subject_id = {}
+    for subject_id, args in inference_func_args_by_subject_id.items():
+        threaded_inference_func_args_by_subject_id[subject_id] = replace_tuple_index(
             args,
             inference_func_args_theta_index,
             theta,
@@ -343,12 +347,12 @@ def thread_inference_func_args(
         if inference_func_args_action_prob_index >= 0:
             # Use a vmap-like pattern to compute action probabilities in batch.
             action_prob_times_flattened = (
-
+                inference_action_prob_decision_times_by_subject_id[subject_id].flatten()
             )
             args_list = [
-
-
-            ]
+                threaded_action_prob_func_args_by_decision_time_by_subject_id[
+                    subject_id
+                ][int(t)]
                 for t in action_prob_times_flattened.tolist()
             ]
             if len(args_list) == 0:
@@ -369,9 +373,11 @@ def thread_inference_func_args(
                 action_probs_to_introduce = vmapped_func(*batched_tensors).reshape(
                     args[inference_func_args_action_prob_index].shape
                 )
-
-
-
-
+            threaded_inference_func_args_by_subject_id[subject_id] = (
+                replace_tuple_index(
+                    threaded_inference_func_args_by_subject_id[subject_id],
+                    inference_func_args_action_prob_index,
+                    action_probs_to_introduce,
+                )
             )
-    return
+    return threaded_inference_func_args_by_subject_id

lifejacket/calculate_derivatives.py
@@ -198,10 +198,10 @@ def pad_in_study_derivatives_with_zeros(
 
 def calculate_pi_and_weight_gradients(
     study_df,
-
+    active_col_name,
     action_col_name,
     calendar_t_col_name,
-
+    subject_id_col_name,
     action_prob_func,
     action_prob_func_args,
     action_prob_func_args_beta_index,
@@ -226,10 +226,10 @@ def calculate_pi_and_weight_gradients(
 
         pi_gradients, weight_gradients = calculate_pi_and_weight_gradients_specific_t(
             study_df,
-
+            active_col_name,
             action_col_name,
             calendar_t_col_name,
-
+            subject_id_col_name,
             action_prob_func,
             action_prob_func_args_beta_index,
             calendar_t,
@@ -252,10 +252,10 @@ def calculate_pi_and_weight_gradients(
 
 def calculate_pi_and_weight_gradients_specific_t(
     study_df,
-
+    active_col_name,
     action_col_name,
     calendar_t_col_name,
-
+    subject_id_col_name,
     action_prob_func,
     action_prob_func_args_beta_index,
     calendar_t,
@@ -320,10 +320,10 @@ def calculate_pi_and_weight_gradients_specific_t(
         study_df,
         calendar_t,
         sorted_user_ids,
-
+        active_col_name,
         action_col_name,
         calendar_t_col_name,
-
+        subject_id_col_name,
     )
     # Note the first argument here: we extract the betas to pass in
     # again as the "target" denominator betas, whereas we differentiate with
@@ -382,10 +382,10 @@ def collect_batched_in_study_actions(
     study_df,
     calendar_t,
     sorted_user_ids,
-
+    active_col_name,
     action_col_name,
     calendar_t_col_name,
-
+    subject_id_col_name,
 ):
 
     # TODO: This for loop can be removed, just grabbing the actions col after
@@ -394,9 +394,9 @@ def collect_batched_in_study_actions(
     batched_actions_list = []
     for user_id in sorted_user_ids:
         filtered_user_data = study_df.loc[
-            (study_df[
+            (study_df[subject_id_col_name] == user_id)
             & (study_df[calendar_t_col_name] == calendar_t)
-            & (study_df[
+            & (study_df[active_col_name] == 1)
         ]
         if not filtered_user_data.empty:
             batched_actions_list.append(filtered_user_data[action_col_name].values[0])
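The hunk above selects rows with a combined boolean mask: one user, one calendar time, and only rows flagged active. A toy reproduction of the same selection; the column names and dataframe contents are invented stand-ins for whatever is passed via the *_col_name parameters.

    import pandas as pd

    # Hypothetical study dataframe.
    study_df = pd.DataFrame(
        {
            "subject_id": ["a", "a", "b"],
            "calendar_t": [1, 2, 1],
            "active": [1, 1, 0],
            "action": [1, 0, 1],
        }
    )

    # Same mask as in the hunk: this user, this calendar time, and active.
    filtered = study_df.loc[
        (study_df["subject_id"] == "a")
        & (study_df["calendar_t"] == 1)
        & (study_df["active"] == 1)
    ]
    if not filtered.empty:
        action = filtered["action"].values[0]  # -> 1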
@@ -785,9 +785,9 @@ def calculate_inference_loss_derivatives(
     inference_func,
     inference_func_args_theta_index,
     user_ids,
-
+    subject_id_col_name,
     action_prob_col_name,
-
+    active_col_name,
     calendar_t_col_name,
     inference_func_type=FunctionTypes.LOSS,
 ):
@@ -819,18 +819,18 @@ def calculate_inference_loss_derivatives(
     max_calendar_time = study_df[calendar_t_col_name].max()
     for user_id in user_ids:
         user_args_list = []
-        filtered_user_data = study_df.loc[study_df[
+        filtered_user_data = study_df.loc[study_df[subject_id_col_name] == user_id]
         for idx, col_name in enumerate(inference_func_arg_names):
             if idx == inference_func_args_theta_index:
                 user_args_list.append(theta_est)
             else:
                 user_args_list.append(
-                    get_study_df_column(filtered_user_data, col_name,
+                    get_study_df_column(filtered_user_data, col_name, active_col_name)
                 )
         args_by_user_id[user_id] = tuple(user_args_list)
         if using_action_probs:
             action_prob_decision_times_by_user_id[user_id] = get_study_df_column(
-                filtered_user_data, calendar_t_col_name,
+                filtered_user_data, calendar_t_col_name, active_col_name
             )
 
     # Get a list of subdicts of the user args dict, with each united by having
@@ -957,9 +957,7 @@ def calculate_inference_loss_derivatives(
     return loss_gradients, loss_hessians, loss_gradient_pi_derivatives
 
 
-def get_study_df_column(study_df, col_name,
+def get_study_df_column(study_df, col_name, active_col_name):
     return jnp.array(
-        study_df.loc[study_df[
-        .to_numpy()
-        .reshape(-1, 1)
+        study_df.loc[study_df[active_col_name] == 1, col_name].to_numpy().reshape(-1, 1)
     )
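The rewritten get_study_df_column keeps only rows whose active flag is 1 and returns the requested column as an (n, 1) array. A toy call, with the function body taken from the hunk above and an invented dataframe and column names:

    import jax.numpy as jnp
    import pandas as pd

    def get_study_df_column(study_df, col_name, active_col_name):
        return jnp.array(
            study_df.loc[study_df[active_col_name] == 1, col_name].to_numpy().reshape(-1, 1)
        )

    # Hypothetical data: two active rows out of three.
    study_df = pd.DataFrame({"active": [1, 0, 1], "calendar_t": [1, 2, 3]})
    times = get_study_df_column(study_df, "calendar_t", "active")  # shape (2, 1)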