lifejacket-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,354 @@
+ from __future__ import annotations
+
+ from typing import Any
+ import collections
+ import logging
+
+ import jax
+ import jax.numpy as jnp
+
+ logger = logging.getLogger(__name__)
+ logging.basicConfig(
+     format="%(asctime)s,%(msecs)03d %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s",
+     datefmt="%Y-%m-%d:%H:%M:%S",
+     level=logging.INFO,
+ )
+
+
+ def replace_tuple_index(tupl, index, value):
+     return tupl[:index] + (value,) + tupl[index + 1 :]
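
For reference, a minimal illustration of how this helper behaves (the tuple values below are made up):

    replace_tuple_index(("a", "b", "c"), 1, "B")  # returns ("a", "B", "c")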
+
+
+ def thread_action_prob_func_args(
+     action_prob_func_args_by_user_id_by_decision_time: dict[
+         int, dict[collections.abc.Hashable, tuple[Any, ...]]
+     ],
+     policy_num_by_decision_time_by_user_id: dict[
+         collections.abc.Hashable, dict[int, int | float]
+     ],
+     initial_policy_num: int | float,
+     all_post_update_betas: jnp.ndarray,
+     beta_index_by_policy_num: dict[int | float, int],
+     action_prob_func_args_beta_index: int,
+ ) -> tuple[
+     dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]],
+     dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]],
+ ]:
+     """
+     Threads the shared betas into the action probability function arguments for each user and
+     decision time to enable correct differentiation.
+
+     Args:
+         action_prob_func_args_by_user_id_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
+             A map from decision times to maps of user ids to tuples of arguments for the action
+             probability function. This covers all decision times for all users (args are an empty
+             tuple if the user is not in the study at that time). Should be sorted by decision time.
+
+         policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
+             A dictionary mapping user ids to dictionaries mapping decision times to the policy
+             number in use for that user at that time. Should be sorted by decision time.
+
+         initial_policy_num (int | float): The policy number of the initial policy before any
+             updates.
+
+         all_post_update_betas (jnp.ndarray):
+             A 2D array of beta values to be introduced into arguments to
+             facilitate differentiation. They will be the same value as what they replace, but this
+             introduces direct dependence on the parameter we will differentiate with respect to.
+
+         beta_index_by_policy_num (dict[int | float, int]):
+             A dictionary mapping policy numbers to their respective
+             beta indices in all_post_update_betas.
+
+         action_prob_func_args_beta_index (int):
+             The index in the action probability function arguments tuple
+             where the beta value should be inserted.
+
+     Returns:
+         tuple[dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]],
+               dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]]:
+             A pair of maps from user ids to maps of decision times to action probability function
+             argument tuples: the first with the shared betas threaded in, the second with the
+             original arguments unchanged. Note the key order switch relative to the supplied args.
+     """
+     threaded_action_prob_func_args_by_decision_time_by_user_id = (
+         collections.defaultdict(dict)
+     )
+     action_prob_func_args_by_decision_time_by_user_id = collections.defaultdict(dict)
+     for (
+         decision_time,
+         action_prob_func_args_by_user_id,
+     ) in action_prob_func_args_by_user_id_by_decision_time.items():
+         for user_id, args in action_prob_func_args_by_user_id.items():
+             # Always add a contribution to the reversed key order dictionary.
+             action_prob_func_args_by_decision_time_by_user_id[user_id][
+                 decision_time
+             ] = args
+
+             # Now proceed with the threading, if necessary.
+             if not args:
+                 threaded_action_prob_func_args_by_decision_time_by_user_id[user_id][
+                     decision_time
+                 ] = ()
+                 continue
+
+             policy_num = policy_num_by_decision_time_by_user_id[user_id][decision_time]
+
+             # The expectation is that fallback policies have empty args, and the only other
+             # policy not represented in beta_index_by_policy_num is the initial policy.
+             if policy_num == initial_policy_num:
+                 threaded_action_prob_func_args_by_decision_time_by_user_id[user_id][
+                     decision_time
+                 ] = action_prob_func_args_by_user_id[user_id]
+                 continue
+
+             beta_to_introduce = all_post_update_betas[
+                 beta_index_by_policy_num[policy_num]
+             ]
+             threaded_action_prob_func_args_by_decision_time_by_user_id[user_id][
+                 decision_time
+             ] = replace_tuple_index(
+                 action_prob_func_args_by_user_id[user_id],
+                 action_prob_func_args_beta_index,
+                 beta_to_introduce,
+             )
+
+     return (
+         threaded_action_prob_func_args_by_decision_time_by_user_id,
+         action_prob_func_args_by_decision_time_by_user_id,
+     )
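
To make the expected input structure concrete, here is a minimal usage sketch. The user id, decision times, policy numbers, and the (beta, state) argument layout below are invented for illustration and are not part of the package:

    # One user "u1", two decision times; policy 1 is the only post-update policy and its
    # beta lives at row 0 of all_post_update_betas. Betas sit at index 0 of each args tuple.
    post_update_betas = jnp.array([[0.5, -0.2]])
    threaded_args, reordered_args = thread_action_prob_func_args(
        action_prob_func_args_by_user_id_by_decision_time={
            0: {"u1": (jnp.array([0.0, 0.0]), jnp.array([1.0, 2.0]))},  # initial policy
            1: {"u1": (jnp.array([0.5, -0.2]), jnp.array([1.0, 3.0]))},  # policy 1
        },
        policy_num_by_decision_time_by_user_id={"u1": {0: 0, 1: 1}},
        initial_policy_num=0,
        all_post_update_betas=post_update_betas,
        beta_index_by_policy_num={1: 0},
        action_prob_func_args_beta_index=0,
    )
    # threaded_args["u1"][1][0] is now post_update_betas[0], so derivatives taken with
    # respect to all_post_update_betas flow through that argument; the decision time under
    # the initial policy is left unchanged.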
+
+
+ def thread_update_func_args(
+     update_func_args_by_by_user_id_by_policy_num: dict[
+         int | float, dict[collections.abc.Hashable, tuple[Any, ...]]
+     ],
+     all_post_update_betas: jnp.ndarray,
+     beta_index_by_policy_num: dict[int | float, int],
+     alg_update_func_args_beta_index: int,
+     alg_update_func_args_action_prob_index: int,
+     alg_update_func_args_action_prob_times_index: int,
+     threaded_action_prob_func_args_by_decision_time_by_user_id: dict[
+         collections.abc.Hashable, dict[int, tuple[Any, ...]]
+     ],
+     action_prob_func: callable,
+ ) -> dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]:
+     """
+     Threads the shared betas into the algorithm update function arguments for each user and
+     policy update to enable correct differentiation. This is done by replacing the betas in the
+     update function arguments with the shared betas and, if necessary, replacing action
+     probabilities with reconstructed action probabilities computed using the shared betas.
+
+     Args:
+         update_func_args_by_by_user_id_by_policy_num (dict[int | float, dict[collections.abc.Hashable, tuple[Any, ...]]]):
+             A dictionary where keys are policy numbers and values are dictionaries mapping user
+             ids to their respective update function arguments.
+
+         all_post_update_betas (jnp.ndarray):
+             A 2D array of beta values to be introduced into arguments to
+             facilitate differentiation. They will be the same value as what they replace, but this
+             introduces direct dependence on the parameter we will differentiate with respect to.
+
+         beta_index_by_policy_num (dict[int | float, int]):
+             A dictionary mapping policy numbers to their respective
+             beta indices in all_post_update_betas.
+
+         alg_update_func_args_beta_index (int):
+             The index in the update function arguments tuple
+             where the beta value should be inserted.
+
+         alg_update_func_args_action_prob_index (int):
+             The index in the update function arguments
+             tuple where new beta-threaded action probabilities should be inserted, if applicable.
+             -1 otherwise.
+
+         alg_update_func_args_action_prob_times_index (int):
+             If action probabilities are supplied
+             to the update function, this is the index in the arguments where an array of times for
+             which the given action probabilities apply is provided.
+
+         threaded_action_prob_func_args_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
+             A dictionary mapping user ids to dictionaries mapping decision times to the function
+             arguments required to compute that user's action probabilities, with the shared betas
+             threaded in.
+
+         action_prob_func (callable):
+             A function that computes the probability of action 1 given the appropriate arguments.
+
+     Returns:
+         dict[collections.abc.Hashable, dict[int | float, tuple[Any, ...]]]:
+             A map from user ids to maps of policy numbers to update function
+             argument tuples for the specified user with the shared betas threaded in. Note the key
+             order switch relative to the supplied args!
+     """
+     threaded_update_func_args_by_policy_num_by_user_id = collections.defaultdict(dict)
+     for (
+         policy_num,
+         update_func_args_by_user_id,
+     ) in update_func_args_by_by_user_id_by_policy_num.items():
+         for user_id, args in update_func_args_by_user_id.items():
+             if not args:
+                 threaded_update_func_args_by_policy_num_by_user_id[user_id][
+                     policy_num
+                 ] = ()
+                 continue
+
+             logger.debug(
+                 "Threading in shared betas to update function arguments for user %s and policy number %s.",
+                 user_id,
+                 policy_num,
+             )
+
+             beta_to_introduce = all_post_update_betas[
+                 beta_index_by_policy_num[policy_num]
+             ]
+             threaded_update_func_args_by_policy_num_by_user_id[user_id][policy_num] = (
+                 replace_tuple_index(
+                     update_func_args_by_user_id[user_id],
+                     alg_update_func_args_beta_index,
+                     beta_to_introduce,
+                 )
+             )
+
+             if alg_update_func_args_action_prob_index >= 0:
+                 logger.debug(
+                     "Action probabilities are used in the algorithm update function. Reconstructing them using the shared betas."
+                 )
+                 action_prob_times = update_func_args_by_user_id[user_id][
+                     alg_update_func_args_action_prob_times_index
+                 ]
+                 # Vectorized computation of action_probs_to_introduce using jax.vmap.
+                 flattened_times = action_prob_times.flatten()
+                 args_list = [
+                     threaded_action_prob_func_args_by_decision_time_by_user_id[user_id][
+                         int(t)
+                     ]
+                     for t in flattened_times.tolist()
+                 ]
+                 if len(args_list) == 0:
+                     action_probs_to_introduce = jnp.array([]).reshape(
+                         update_func_args_by_user_id[user_id][
+                             alg_update_func_args_action_prob_index
+                         ].shape
+                     )
+                 else:
+                     batched_args = list(zip(*args_list))
+                     # Ensure each argument is at least 2D for batching, to avoid shape
+                     # issues with scalars.
+                     batched_tensors = []
+                     for arg_group in batched_args:
+                         arr = jnp.array(arg_group)
+                         if arr.ndim == 1:
+                             arr = arr[:, None]
+                         batched_tensors.append(arr)
+                     vmapped_func = jax.vmap(
+                         action_prob_func, in_axes=tuple(0 for _ in batched_tensors)
+                     )
+                     action_probs_to_introduce = vmapped_func(*batched_tensors).reshape(
+                         update_func_args_by_user_id[user_id][
+                             alg_update_func_args_action_prob_index
+                         ].shape
+                     )
+                 threaded_update_func_args_by_policy_num_by_user_id[user_id][
+                     policy_num
+                 ] = replace_tuple_index(
+                     threaded_update_func_args_by_policy_num_by_user_id[user_id][
+                         policy_num
+                     ],
+                     alg_update_func_args_action_prob_index,
+                     action_probs_to_introduce,
+                 )
+     return threaded_update_func_args_by_policy_num_by_user_id
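
The reconstruction step above hinges on stacking the per-decision-time argument tuples and evaluating the action probability function once with jax.vmap. A self-contained sketch of that pattern, using a hypothetical sigmoid action probability function that is not part of this package:

    def example_action_prob_func(beta, state):
        # Hypothetical: probability of action 1 as a sigmoid of the beta/state inner product.
        return jax.nn.sigmoid(jnp.dot(beta.flatten(), state.flatten()))

    # Per-decision-time (beta, state) argument tuples for a single user, as produced by
    # thread_action_prob_func_args.
    args_list = [
        (jnp.array([0.5, -0.2]), jnp.array([1.0, 3.0])),
        (jnp.array([0.5, -0.2]), jnp.array([1.0, 4.0])),
    ]
    batched = [jnp.stack(group) for group in zip(*args_list)]
    probs = jax.vmap(example_action_prob_func)(*batched)  # shape (2,)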
+
+
+ def thread_inference_func_args(
+     inference_func_args_by_user_id: dict[collections.abc.Hashable, tuple[Any, ...]],
+     inference_func_args_theta_index: int,
+     theta: jnp.ndarray,
+     inference_func_args_action_prob_index: int,
+     threaded_action_prob_func_args_by_decision_time_by_user_id: dict[
+         collections.abc.Hashable, dict[int, tuple[Any, ...]]
+     ],
+     inference_action_prob_decision_times_by_user_id: dict[
+         collections.abc.Hashable, list[int]
+     ],
+     action_prob_func: callable,
+ ) -> dict[collections.abc.Hashable, tuple[Any, ...]]:
+     """
+     Threads the shared theta into the inference function arguments for each user to enable correct
+     differentiation. This is done by replacing the theta in the inference function arguments with
+     the shared theta. If applicable, action probabilities are also replaced with reconstructed
+     action probabilities computed using the shared betas.
+
+     Args:
+         inference_func_args_by_user_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
+             A dictionary mapping user ids to their respective inference function arguments.
+
+         inference_func_args_theta_index (int):
+             The index in the inference function arguments tuple
+             where the theta value should be inserted.
+
+         theta (jnp.ndarray):
+             The theta value to be threaded into the inference function arguments.
+
+         inference_func_args_action_prob_index (int):
+             The index in the inference function arguments
+             tuple where new beta-threaded action probabilities should be inserted, if applicable.
+             -1 otherwise.
+
+         threaded_action_prob_func_args_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
+             A dictionary mapping user ids to dictionaries mapping decision times to the function
+             arguments required to compute that user's action probabilities, with the shared betas
+             threaded in.
+
+         inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
+             For each user, the decision times to which the supplied action probabilities
+             correspond, if provided. Typically just in-study times if action probabilities are
+             used in the inference loss or estimating function.
+
+         action_prob_func (callable):
+             A function that computes the probability of action 1 given the appropriate arguments.
+
+     Returns:
+         dict[collections.abc.Hashable, tuple[Any, ...]]:
+             A map from user ids to tuples of inference function arguments with the shared theta
+             threaded in.
+     """
+
+     threaded_inference_func_args_by_user_id = {}
+     for user_id, args in inference_func_args_by_user_id.items():
+         threaded_inference_func_args_by_user_id[user_id] = replace_tuple_index(
+             args,
+             inference_func_args_theta_index,
+             theta,
+         )
+
+         if inference_func_args_action_prob_index >= 0:
+             # Use a vmap-like pattern to compute action probabilities in batch. Accept either
+             # a list or an array of decision times.
+             action_prob_times_flattened = jnp.asarray(
+                 inference_action_prob_decision_times_by_user_id[user_id]
+             ).flatten()
+             args_list = [
+                 threaded_action_prob_func_args_by_decision_time_by_user_id[user_id][
+                     int(t)
+                 ]
+                 for t in action_prob_times_flattened.tolist()
+             ]
+             if len(args_list) == 0:
+                 action_probs_to_introduce = jnp.array([]).reshape(
+                     args[inference_func_args_action_prob_index].shape
+                 )
+             else:
+                 batched_args = list(zip(*args_list))
+                 batched_tensors = []
+                 for arg_group in batched_args:
+                     arr = jnp.array(arg_group)
+                     if arr.ndim == 1:
+                         arr = arr[:, None]
+                     batched_tensors.append(arr)
+                 vmapped_func = jax.vmap(
+                     action_prob_func, in_axes=tuple(0 for _ in batched_tensors)
+                 )
+                 action_probs_to_introduce = vmapped_func(*batched_tensors).reshape(
+                     args[inference_func_args_action_prob_index].shape
+                 )
+             threaded_inference_func_args_by_user_id[user_id] = replace_tuple_index(
+                 threaded_inference_func_args_by_user_id[user_id],
+                 inference_func_args_action_prob_index,
+                 action_probs_to_introduce,
+             )
+     return threaded_inference_func_args_by_user_id
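
The point of all three helpers is that downstream derivatives flow through the shared parameters rather than through stale copies stored inside the argument tuples. A hypothetical end-to-end use with jax.grad, where the inference loss and the (theta, outcomes, features) argument layout are invented for illustration:

    def example_inference_loss(args):
        # Hypothetical least-squares inference loss; args are assumed to be
        # (theta, outcomes, features) with theta at index 0.
        theta, outcomes, features = args
        preds = features @ theta
        return jnp.mean((outcomes - preds) ** 2)

    def total_loss(theta, inference_func_args_by_user_id):
        threaded = thread_inference_func_args(
            inference_func_args_by_user_id=inference_func_args_by_user_id,
            inference_func_args_theta_index=0,
            theta=theta,
            inference_func_args_action_prob_index=-1,  # no action probabilities in this loss
            threaded_action_prob_func_args_by_decision_time_by_user_id={},
            inference_action_prob_decision_times_by_user_id={},
            action_prob_func=None,
        )
        return sum(example_inference_loss(args) for args in threaded.values())

    # jax.grad(total_loss)(theta, inference_func_args_by_user_id) then differentiates
    # through the threaded theta instead of the copies originally stored in the tuples.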