lifejacket 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/__init__.py +0 -0
- lifejacket/after_study_analysis.py +1845 -0
- lifejacket/arg_threading_helpers.py +354 -0
- lifejacket/calculate_derivatives.py +965 -0
- lifejacket/constants.py +28 -0
- lifejacket/form_adaptive_meat_adjustments_directly.py +333 -0
- lifejacket/get_datum_for_blowup_supervised_learning.py +1312 -0
- lifejacket/helper_functions.py +587 -0
- lifejacket/input_checks.py +1145 -0
- lifejacket/small_sample_corrections.py +125 -0
- lifejacket/trial_conditioning_monitor.py +870 -0
- lifejacket/vmap_helpers.py +71 -0
- lifejacket-0.1.0.dist-info/METADATA +100 -0
- lifejacket-0.1.0.dist-info/RECORD +17 -0
- lifejacket-0.1.0.dist-info/WHEEL +5 -0
- lifejacket-0.1.0.dist-info/entry_points.txt +2 -0
- lifejacket-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,354 @@
|
|
|
1
|
+
from __future__ import annotations

from typing import Any
import collections
import logging

import jax
import jax.numpy as jnp

# Module-level logger used by the arg-threading helpers below.
logger = logging.getLogger(__name__)
# NOTE(review): calling logging.basicConfig() at import time configures the
# *root* logger as a side effect of importing this library module, which can
# clobber or conflict with the host application's logging setup. Consider
# moving this configuration to the application entry point.
logging.basicConfig(
    format="%(asctime)s,%(msecs)03d %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def replace_tuple_index(tupl, index, value):
    """Return a copy of ``tupl`` with the element at ``index`` replaced by ``value``.

    Fixes a silent-corruption bug in the original slicing-only implementation:
    for a negative index such as -1, ``tupl[index + 1:]`` evaluates to the whole
    tuple, yielding a longer, wrong result; an index of ``len(tupl)`` appended
    instead of replacing. Valid non-negative indices behave exactly as before.

    Args:
        tupl: The tuple to copy.
        index: Position to replace. Negative indices count from the end,
            matching normal Python indexing semantics.
        value: The replacement element.

    Returns:
        A new tuple identical to ``tupl`` except at ``index``.

    Raises:
        IndexError: If ``index`` is out of range for ``tupl``.
    """
    n = len(tupl)
    if not -n <= index < n:
        raise IndexError(f"tuple index {index} out of range for length {n}")
    if index < 0:
        # Normalize so the slice arithmetic below is always correct.
        index += n
    return tupl[:index] + (value,) + tupl[index + 1 :]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def thread_action_prob_func_args(
    action_prob_func_args_by_user_id_by_decision_time: dict[
        int, dict[collections.abc.Hashable, tuple[Any, ...]]
    ],
    policy_num_by_decision_time_by_user_id: dict[
        collections.abc.Hashable, dict[int, int | float]
    ],
    initial_policy_num: int | float,
    all_post_update_betas: jnp.ndarray,
    beta_index_by_policy_num: dict[int | float, int],
    action_prob_func_args_beta_index: int,
) -> tuple[
    dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]],
    dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]],
]:
    """Thread the shared betas into each user's action probability arguments.

    The threaded beta has the same value as the one it replaces; substituting it
    creates a direct dependence on the parameter we later differentiate with
    respect to.

    Args:
        action_prob_func_args_by_user_id_by_decision_time: For every decision
            time, a map from user id to that user's action probability function
            argument tuple (empty tuple if the user is not in the study at that
            time). Expected to be sorted by decision time.
        policy_num_by_decision_time_by_user_id: For every user, a map from
            decision time to the policy number in use. Expected to be sorted by
            decision time.
        initial_policy_num: Policy number of the initial (pre-update) policy.
        all_post_update_betas: 2D array of shared beta values, one row per
            post-update policy.
        beta_index_by_policy_num: Maps each policy number to its row index in
            ``all_post_update_betas``.
        action_prob_func_args_beta_index: Position within each argument tuple
            where the shared beta is substituted.

    Returns:
        A pair of dictionaries, both keyed user-first (note the key-order swap
        relative to the input):
            1. user id -> decision time -> argument tuple with the shared beta
               threaded in;
            2. user id -> decision time -> the original, untouched argument
               tuple.
    """
    threaded_args_by_time_by_user = collections.defaultdict(dict)
    reordered_args_by_time_by_user = collections.defaultdict(dict)

    beta_slot = action_prob_func_args_beta_index
    for (
        decision_time,
        args_by_user,
    ) in action_prob_func_args_by_user_id_by_decision_time.items():
        for user_id, args in args_by_user.items():
            # Record the key-order-swapped copy of the raw args unconditionally.
            reordered_args_by_time_by_user[user_id][decision_time] = args

            if not args:
                # Out-of-study decision time: nothing to thread.
                threaded = ()
            elif (
                policy_num := policy_num_by_decision_time_by_user_id[user_id][
                    decision_time
                ]
            ) == initial_policy_num:
                # The initial policy has no post-update beta; fallback policies
                # are expected to arrive with empty args, so only the initial
                # policy reaches this branch without a beta index.
                threaded = args
            else:
                shared_beta = all_post_update_betas[
                    beta_index_by_policy_num[policy_num]
                ]
                threaded = args[:beta_slot] + (shared_beta,) + args[beta_slot + 1 :]

            threaded_args_by_time_by_user[user_id][decision_time] = threaded

    return (
        threaded_args_by_time_by_user,
        reordered_args_by_time_by_user,
    )
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _reconstruct_action_probs(
    action_prob_times,
    threaded_args_by_decision_time,
    action_prob_func,
    target_shape,
):
    """Recompute action probabilities in batch from beta-threaded arguments.

    Args:
        action_prob_times: Array (or array-like) of decision times whose action
            probabilities should be reconstructed.
        threaded_args_by_decision_time: Map from decision time to the action
            probability function argument tuple with the shared betas threaded
            in.
        action_prob_func: Function computing an action 1 probability from one
            argument tuple.
        target_shape: Shape the resulting probability array must match.

    Returns:
        jnp.ndarray: Reconstructed action probabilities reshaped to
        ``target_shape``.
    """
    # jnp.asarray tolerates plain lists as well as jnp arrays.
    flattened_times = jnp.asarray(action_prob_times).flatten()
    args_list = [
        threaded_args_by_decision_time[int(t)] for t in flattened_times.tolist()
    ]
    if not args_list:
        # No applicable decision times: empty array of the expected shape
        # (valid only when target_shape itself contains a zero dimension).
        return jnp.array([]).reshape(target_shape)

    # Stack the per-time argument tuples position-wise into batched tensors.
    batched_tensors = []
    for arg_group in zip(*args_list):
        arr = jnp.array(arg_group)
        if arr.ndim == 1:
            # Promote scalars to column vectors so every input carries an
            # explicit batch axis for vmap.
            arr = arr[:, None]
        batched_tensors.append(arr)
    vmapped_func = jax.vmap(
        action_prob_func, in_axes=tuple(0 for _ in batched_tensors)
    )
    return vmapped_func(*batched_tensors).reshape(target_shape)


def thread_update_func_args(
    update_func_args_by_by_user_id_by_policy_num: dict[
        int | float, dict[collections.abc.Hashable, tuple[Any, ...]]
    ],
    all_post_update_betas: jnp.ndarray,
    beta_index_by_policy_num: dict[int | float, int],
    alg_update_func_args_beta_index: int,
    alg_update_func_args_action_prob_index: int,
    alg_update_func_args_action_prob_times_index: int,
    threaded_action_prob_func_args_by_decision_time_by_user_id: dict[
        collections.abc.Hashable, dict[int, tuple[Any, ...]]
    ],
    action_prob_func: callable,
) -> dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]:
    """Thread shared betas into the algorithm update function arguments.

    Replaces the beta in each user's update function arguments with the shared
    beta for that policy, and, when the update function consumes action
    probabilities, replaces those with probabilities reconstructed from the
    beta-threaded action probability arguments so differentiation flows through
    them as well.

    Args:
        update_func_args_by_by_user_id_by_policy_num: Maps policy numbers to
            maps of user ids to that user's update function argument tuple
            (empty tuple when the user contributed nothing to the update).
        all_post_update_betas: 2D array of shared beta values, one row per
            post-update policy.
        beta_index_by_policy_num: Maps each policy number to its row index in
            ``all_post_update_betas``.
        alg_update_func_args_beta_index: Position within each argument tuple
            where the shared beta is substituted.
        alg_update_func_args_action_prob_index: Position where reconstructed
            action probabilities are substituted; -1 when the update function
            takes no action probabilities.
        alg_update_func_args_action_prob_times_index: Position of the array of
            decision times the supplied action probabilities correspond to
            (only consulted when the previous index is >= 0).
        threaded_action_prob_func_args_by_decision_time_by_user_id: Per user,
            decision time -> beta-threaded action probability function
            arguments (as produced by ``thread_action_prob_func_args``).
        action_prob_func: Function computing an action 1 probability from one
            argument tuple.

    Returns:
        Map from user id to policy number to the threaded update function
        argument tuple. Note the key-order swap relative to the input.
    """
    # Same instance as the module-level ``logger``; bound locally so this
    # function is self-contained.
    log = logging.getLogger(__name__)
    threaded_update_func_args_by_policy_num_by_user_id = collections.defaultdict(dict)
    for (
        policy_num,
        update_func_args_by_user_id,
    ) in update_func_args_by_by_user_id_by_policy_num.items():
        for user_id, args in update_func_args_by_user_id.items():
            if not args:
                # Nothing to thread; preserve the empty placeholder.
                threaded_update_func_args_by_policy_num_by_user_id[user_id][
                    policy_num
                ] = ()
                continue

            log.debug(
                "Threading in shared betas to update function arguments for user %s and policy number %s.",
                user_id,
                policy_num,
            )

            beta_to_introduce = all_post_update_betas[
                beta_index_by_policy_num[policy_num]
            ]
            i = alg_update_func_args_beta_index
            threaded_args = args[:i] + (beta_to_introduce,) + args[i + 1 :]

            if alg_update_func_args_action_prob_index >= 0:
                log.debug(
                    "Action probabilities are used in the algorithm update function. Reconstructing them using the shared betas."
                )
                action_probs_to_introduce = _reconstruct_action_probs(
                    args[alg_update_func_args_action_prob_times_index],
                    threaded_action_prob_func_args_by_decision_time_by_user_id[
                        user_id
                    ],
                    action_prob_func,
                    args[alg_update_func_args_action_prob_index].shape,
                )
                j = alg_update_func_args_action_prob_index
                threaded_args = (
                    threaded_args[:j]
                    + (action_probs_to_introduce,)
                    + threaded_args[j + 1 :]
                )

            threaded_update_func_args_by_policy_num_by_user_id[user_id][
                policy_num
            ] = threaded_args

    return threaded_update_func_args_by_policy_num_by_user_id
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def thread_inference_func_args(
|
|
261
|
+
inference_func_args_by_user_id: dict[collections.abc.Hashable, tuple[Any, ...]],
|
|
262
|
+
inference_func_args_theta_index: int,
|
|
263
|
+
theta: jnp.ndarray,
|
|
264
|
+
inference_func_args_action_prob_index: int,
|
|
265
|
+
threaded_action_prob_func_args_by_decision_time_by_user_id: dict[
|
|
266
|
+
collections.abc.Hashable, dict[int, tuple[Any, ...]]
|
|
267
|
+
],
|
|
268
|
+
inference_action_prob_decision_times_by_user_id: dict[
|
|
269
|
+
collections.abc.Hashable, list[int]
|
|
270
|
+
],
|
|
271
|
+
action_prob_func: callable,
|
|
272
|
+
) -> dict[collections.abc.Hashable, tuple[Any, ...]]:
|
|
273
|
+
"""
|
|
274
|
+
Threads the shared theta into the inference function arguments for each user to enable correct
|
|
275
|
+
differentiation. This is done by replacing the theta in the inference function arguments with
|
|
276
|
+
theta. If applicable, action probabilities are also replaced with reconstructed action
|
|
277
|
+
probabilities computed using the shared betas.
|
|
278
|
+
|
|
279
|
+
Args:
|
|
280
|
+
inference_func_args_by_user_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
|
|
281
|
+
A dictionary mapping user IDs to their respective inference function arguments.
|
|
282
|
+
|
|
283
|
+
inference_func_args_theta_index (int):
|
|
284
|
+
The index in the inference function arguments tuple
|
|
285
|
+
where the theta value should be inserted.
|
|
286
|
+
|
|
287
|
+
theta (jnp.ndarray):
|
|
288
|
+
The theta value to be threaded into the inference function arguments.
|
|
289
|
+
|
|
290
|
+
inference_func_args_action_prob_index (int):
|
|
291
|
+
The index in the inference function arguments
|
|
292
|
+
tuple where new beta-threaded action probabilities should be inserted, if applicable.
|
|
293
|
+
-1 otherwise.
|
|
294
|
+
|
|
295
|
+
threaded_action_prob_func_args_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
|
|
296
|
+
A dictionary mapping decision times to the function arguments required to compute action
|
|
297
|
+
probabilities for this user, and with the shared betas thread in.
|
|
298
|
+
|
|
299
|
+
inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
|
|
300
|
+
For each user, a list of decision times to which action probabilities correspond if
|
|
301
|
+
provided. Typically just in-study times if action probabilites are used in the inference
|
|
302
|
+
loss or estimating function.
|
|
303
|
+
|
|
304
|
+
action_prob_func (callable):
|
|
305
|
+
A function that computes an action 1 probability given the appropriate arguments.
|
|
306
|
+
Returns:
|
|
307
|
+
dict[collections.abc.Hashable, tuple[Any, ...]]:
|
|
308
|
+
A map from user ids to tuples of inference function arguments with the shared theta
|
|
309
|
+
threaded in.
|
|
310
|
+
"""
|
|
311
|
+
|
|
312
|
+
threaded_inference_func_args_by_user_id = {}
|
|
313
|
+
for user_id, args in inference_func_args_by_user_id.items():
|
|
314
|
+
threaded_inference_func_args_by_user_id[user_id] = replace_tuple_index(
|
|
315
|
+
args,
|
|
316
|
+
inference_func_args_theta_index,
|
|
317
|
+
theta,
|
|
318
|
+
)
|
|
319
|
+
|
|
320
|
+
if inference_func_args_action_prob_index >= 0:
|
|
321
|
+
# Use a vmap-like pattern to compute action probabilities in batch.
|
|
322
|
+
action_prob_times_flattened = (
|
|
323
|
+
inference_action_prob_decision_times_by_user_id[user_id].flatten()
|
|
324
|
+
)
|
|
325
|
+
args_list = [
|
|
326
|
+
threaded_action_prob_func_args_by_decision_time_by_user_id[user_id][
|
|
327
|
+
int(t)
|
|
328
|
+
]
|
|
329
|
+
for t in action_prob_times_flattened.tolist()
|
|
330
|
+
]
|
|
331
|
+
if len(args_list) == 0:
|
|
332
|
+
action_probs_to_introduce = jnp.array([]).reshape(
|
|
333
|
+
args[inference_func_args_action_prob_index].shape
|
|
334
|
+
)
|
|
335
|
+
else:
|
|
336
|
+
batched_args = list(zip(*args_list))
|
|
337
|
+
batched_tensors = []
|
|
338
|
+
for arg_group in batched_args:
|
|
339
|
+
arr = jnp.array(arg_group)
|
|
340
|
+
if arr.ndim == 1:
|
|
341
|
+
arr = arr[:, None]
|
|
342
|
+
batched_tensors.append(arr)
|
|
343
|
+
vmapped_func = jax.vmap(
|
|
344
|
+
action_prob_func, in_axes=tuple(0 for _ in batched_tensors)
|
|
345
|
+
)
|
|
346
|
+
action_probs_to_introduce = vmapped_func(*batched_tensors).reshape(
|
|
347
|
+
args[inference_func_args_action_prob_index].shape
|
|
348
|
+
)
|
|
349
|
+
threaded_inference_func_args_by_user_id[user_id] = replace_tuple_index(
|
|
350
|
+
threaded_inference_func_args_by_user_id[user_id],
|
|
351
|
+
inference_func_args_action_prob_index,
|
|
352
|
+
action_probs_to_introduce,
|
|
353
|
+
)
|
|
354
|
+
return threaded_inference_func_args_by_user_id
|