lifejacket 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/__init__.py +0 -0
- lifejacket/after_study_analysis.py +1845 -0
- lifejacket/arg_threading_helpers.py +354 -0
- lifejacket/calculate_derivatives.py +965 -0
- lifejacket/constants.py +28 -0
- lifejacket/form_adaptive_meat_adjustments_directly.py +333 -0
- lifejacket/get_datum_for_blowup_supervised_learning.py +1312 -0
- lifejacket/helper_functions.py +587 -0
- lifejacket/input_checks.py +1145 -0
- lifejacket/small_sample_corrections.py +125 -0
- lifejacket/trial_conditioning_monitor.py +870 -0
- lifejacket/vmap_helpers.py +71 -0
- lifejacket-0.1.0.dist-info/METADATA +100 -0
- lifejacket-0.1.0.dist-info/RECORD +17 -0
- lifejacket-0.1.0.dist-info/WHEEL +5 -0
- lifejacket-0.1.0.dist-info/entry_points.txt +2 -0
- lifejacket-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,965 @@
|
|
|
1
|
+
import collections
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
import jax
|
|
5
|
+
from jax import numpy as jnp
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from .constants import FunctionTypes
|
|
9
|
+
from .helper_functions import (
|
|
10
|
+
conditional_x_or_one_minus_x,
|
|
11
|
+
load_function_from_same_named_file,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
# Module-level logger. NOTE: the basicConfig call below configures the *root*
# logger as an import-time side effect (the stdlib makes it a no-op when the
# root logger already has handlers).
logger = logging.getLogger(name=__name__)
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y-%m-%d:%H:%M:%S",
    format="%(asctime)s,%(msecs)03d %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s",
)
|
|
20
|
+
|
|
21
|
+
# TODO: Consolidate function loading logic
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_batched_arg_lists_and_involved_user_ids(func, sorted_user_ids, args_by_user_id):
    """
    Collect a dict of arg tuples by user id into a list of lists, each containing
    all the args at a particular index across users.

    Users are visited in ``sorted_user_ids`` order. Ids missing from
    ``args_by_user_id`` are ignored, because this may be called on a
    shape-grouped subset of the full dict. Users whose arg tuple is empty/falsy
    are skipped.

    The number of argument slots is taken from the args actually supplied, not
    from ``func``'s signature. Parameters with defaults that were not supplied
    therefore do not produce empty trailing slots.

    Args:
        func: The function the args will be passed to. Its positional
            parameter count is only used as a fallback slot count when no user
            supplied any args.
        sorted_user_ids: User ids in the order the output lists should follow.
        args_by_user_id: Mapping of user id to a tuple of positional args
            (empty for users with no data).

    Returns:
        A tuple ``(batched_arg_lists, involved_user_ids)``.
        ``batched_arg_lists[i]`` holds the i-th arg of every involved user, in
        ``sorted_user_ids`` order. ``involved_user_ids`` is the set of users
        that contributed args.

    Raises:
        ValueError: If the non-empty arg tuples do not all have the same length.
    """
    # Sort users to be cautious, keeping only users present in this (possibly
    # partial) args dict who actually have args.
    nonempty_args_by_user_id = {
        user_id: args_by_user_id[user_id]
        for user_id in sorted_user_ids
        if user_id in args_by_user_id and args_by_user_id[user_id]
    }

    if nonempty_args_by_user_id:
        # Count what was actually supplied, so omitted defaulted params don't
        # create empty slots that later stacking cannot handle.
        num_args = len(next(iter(nonempty_args_by_user_id.values())))
    else:
        # Nobody has data: fall back to the declared positional parameter
        # count so callers still get one (empty) list per argument slot.
        num_args = func.__code__.co_argcount

    # NOTE: Cannot do [[]] * num_args here! Then all lists point to the same
    # object...
    batched_arg_lists = [[] for _ in range(num_args)]
    for user_id, user_args in nonempty_args_by_user_id.items():
        if len(user_args) != num_args:
            raise ValueError(
                f"User {user_id!r} supplied {len(user_args)} args but other users "
                f"supplied {num_args}; all users must supply the same number of args."
            )
        for idx, arg in enumerate(user_args):
            batched_arg_lists[idx].append(arg)

    return batched_arg_lists, set(nonempty_args_by_user_id)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def get_shape(obj):
    """
    Describe the "shape" of a single argument for grouping purposes.

    Objects exposing a ``shape`` attribute report it. Strings report ``None``.
    Other objects report ``len(obj)``, or ``None`` when taking the length
    fails.
    """
    if not hasattr(obj, "shape"):
        if isinstance(obj, str):
            return None
        try:
            return len(obj)
        except Exception:
            return None
    return obj.shape
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def group_user_args_by_shape(user_arg_dict, empty_allowed=True):
    """
    Partition a user-id -> arg-tuple mapping into sub-dicts whose users share
    the same per-argument shapes, as reported by ``get_shape``.

    Args:
        user_arg_dict: Mapping of user id to a tuple of args (empty/falsy for
            users without data).
        empty_allowed: When False, a user with empty args raises instead of
            being skipped.

    Returns:
        The ``.values()`` view of the grouping dict: one sub-dict per distinct
        shape signature, in first-seen order.

    Raises:
        ValueError: If ``empty_allowed`` is False and some user has no args.
    """
    groups = collections.defaultdict(dict)
    for user_id, args in user_arg_dict.items():
        if args:
            signature = tuple(map(get_shape, args))
            groups[signature][user_id] = args
        elif not empty_allowed:
            raise ValueError("There shouldn't be a user with no data at this time")
    return groups.values()
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# TODO: Check for exactly the required types earlier
|
|
83
|
+
# TODO: Try except and nice error message
|
|
84
|
+
# TODO: This is complicated enough to deserve its own unit tests
|
|
85
|
+
def stack_batched_arg_lists_into_tensor(batched_arg_lists):
    """
    Stack a simple Python list of lists of function arguments (across all users
    for a specific arg position) into a list of jnp arrays that can be supplied
    to vmap as batch arguments.

    vmap requires all elements of such a batched array to have the same shape,
    as do the stacking functions used here. This must therefore be called on
    batches of users with the same data shape. We also return the axes one must
    iterate over to get each user's args in a batch.

    Handling per argument position, based on its first element:
      * non-string sequences are first converted element-wise with ``jnp.array``;
      * 0-d arrays and plain scalars become a ``(num_users,)`` array;
      * 1-D and 2-D arrays are stacked along a new leading user axis.

    Args:
        batched_arg_lists: A list with one inner list per argument position.
            Each inner list holds that argument for every user in the batch and
            must be non-empty.

    Returns:
        ``(batched_arg_tensors, batch_axes)``. ``batch_axes`` holds the vmap
        ``in_axes`` entry for each tensor.

    Raises:
        TypeError: If a sequence argument cannot be cast to a JAX array, or an
            array argument has more than 2 dimensions.
    """
    batched_arg_tensors = []

    # This ends up being all zeros because of the way we do the stacking, but
    # it is better to not assume that externally and send out what we've done.
    batch_axes = []

    for batched_arg_list in batched_arg_lists:
        first_arg = batched_arg_list[0]
        is_array_like = isinstance(
            first_arg,
            (collections.abc.Sequence, jnp.ndarray, np.ndarray),
        ) and not isinstance(first_arg, str)

        if not is_array_like:
            ########## A list of plain scalars: just turn into a jnp array.
            batched_arg_tensors.append(jnp.array(batched_arg_list))
            batch_axes.append(0)
            continue

        if not isinstance(first_arg, (jnp.ndarray, np.ndarray)):
            try:
                batched_arg_list = [jnp.array(x) for x in batched_arg_list]
            except Exception as e:
                raise TypeError(
                    "Argument of sequence type that cannot be cast to JAX numpy array"
                ) from e
            first_arg = batched_arg_list[0]

        # Explicit raise rather than assert so validation survives `python -O`.
        if first_arg.ndim > 2:
            raise TypeError("Arrays with dimension greater than 2 are not supported.")

        if first_arg.ndim == 0:
            ########## 0-d arrays are scalars: collect into a (num_users,) vector.
            batched_arg_tensors.append(jnp.array(batched_arg_list))
        else:
            ########## Vector (1D) or matrix (2D) arg: stack along a new axis 0.
            batched_arg_tensors.append(jnp.stack(batched_arg_list, 0))
        batch_axes.append(0)

    return (
        batched_arg_tensors,
        batch_axes,
    )
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
# TODO: Add clarity on why we need gradients at all times.
|
|
150
|
+
def pad_loss_gradient_pi_derivative_outside_supplied_action_probabilites(
    loss_gradient_pi_derivative,
    action_prob_times,
    first_time_after,
):
    """
    Fill in zero gradients for the times at which action probabilities were not
    supplied, for users currently in the study.

    Compare to the padding that fills in full sets of zero gradients for all
    users not currently in the study. This function instead fills in zero
    gradients for, e.g., times 1,2,3,4,9,10,11,12 when action probabilities
    are given for times 5,6,7,8.

    Args:
        loss_gradient_pi_derivative: Array whose axis 1 indexes the supplied
            action probability times, in order. Each column is read as
            ``[:, j, :]``. Presumably shaped (beta_dim, num_supplied_times, 1);
            TODO confirm.
        action_prob_times: Times at which action probabilities were supplied,
            in the same order as axis 1 of ``loss_gradient_pi_derivative``.
        first_time_after: Exclusive upper bound. The output covers times
            ``1 .. first_time_after - 1``.

    Returns:
        A numpy array with one column along axis 1 per time in
        ``1 .. first_time_after - 1``. That axis has length 0 when
        ``first_time_after <= 1``.
    """
    num_rows = loss_gradient_pi_derivative.shape[0]
    if first_time_after <= 1:
        # No times to cover; np.hstack([]) would raise on an empty list.
        return np.zeros((num_rows, 0, 1))

    zero_gradient = np.zeros((num_rows, 1, 1))
    # One set lookup per time instead of scanning action_prob_times each time.
    # tolist() turns numpy/JAX scalars into hashable Python numbers.
    supplied_times = set(np.asarray(action_prob_times).ravel().tolist())

    gradients_to_stack = []
    next_column_index_to_grab = 0
    for t in range(1, first_time_after):
        if t in supplied_times:
            gradients_to_stack.append(
                np.expand_dims(
                    loss_gradient_pi_derivative[:, next_column_index_to_grab, :], 2
                )
            )
            next_column_index_to_grab += 1
        else:
            gradients_to_stack.append(zero_gradient)

    return np.hstack(gradients_to_stack)
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def pad_in_study_derivatives_with_zeros(
    in_study_derivatives, sorted_user_ids, in_study_user_ids
):
    """
    Fill in zero gradients for users not currently in the study, given the
    derivatives computed for those in it.

    Args:
        in_study_derivatives: Derivatives for the in-study users, ordered as
            those users appear in ``sorted_user_ids``.
        sorted_user_ids: All user ids, in output order.
        in_study_user_ids: Collection of the ids that have an entry in
            ``in_study_derivatives``.

    Returns:
        A list aligned with ``sorted_user_ids``. In-study users get their
        derivative; every other user gets ``np.zeros_like`` of the first
        in-study derivative.

    Raises:
        ValueError: If some user needs zero padding but there is no in-study
            derivative to take the shape from.
    """
    # Set membership so the loop is linear even if a list was passed in.
    in_study_lookup = set(in_study_user_ids)

    all_derivatives = []
    in_study_next_idx = 0
    for user_id in sorted_user_ids:
        if user_id in in_study_lookup:
            all_derivatives.append(in_study_derivatives[in_study_next_idx])
            in_study_next_idx += 1
        else:
            if len(in_study_derivatives) == 0:
                raise ValueError(
                    "Cannot zero-pad derivatives for users not in the study because "
                    "no in-study derivatives were supplied to take the shape from."
                )
            all_derivatives.append(np.zeros_like(in_study_derivatives[0]))

    return all_derivatives
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def calculate_pi_and_weight_gradients(
    study_df,
    in_study_col_name,
    action_col_name,
    calendar_t_col_name,
    user_id_col_name,
    action_prob_func,
    action_prob_func_args,
    action_prob_func_args_beta_index,
):
    """
    For all decision times and all users, compute the gradient with respect to
    beta of two quantities:
      * the pi function, which takes a state and gives the probability of
        selecting action 1;
      * the Radon-Nikodym weight, derived from pi functions as described in
        the paper.

    Args:
        study_df: Study data used to look up the actions taken.
        in_study_col_name, action_col_name, calendar_t_col_name,
            user_id_col_name: Column names in ``study_df``.
        action_prob_func: The pi function.
        action_prob_func_args: Mapping of decision time to a mapping of user id
            to that user's pi-function arg tuple. Every decision time is
            expected to contain all user ids.
        action_prob_func_args_beta_index: Position of beta in the arg tuples.

    Returns:
        A dict keyed by decision time. Each value has the keys
        ``"pi_gradients_by_user_id"`` and ``"weight_gradients_by_user_id"``.
        The dict is empty if ``action_prob_func_args`` is empty.
    """

    logger.debug("Calculating pi and weight gradients with respect to beta.")

    pi_and_weight_gradients_by_calendar_t = {}

    if not action_prob_func_args:
        # next(iter(...)) below would otherwise leak a bare StopIteration.
        return pi_and_weight_gradients_by_calendar_t

    # This is a reliable way to get all user ids since we require all user ids
    # at all decision times
    user_ids = list(next(iter(action_prob_func_args.values())).keys())
    sorted_user_ids = sorted(user_ids)

    for calendar_t, args_by_user_id in action_prob_func_args.items():

        pi_gradients, weight_gradients = calculate_pi_and_weight_gradients_specific_t(
            study_df,
            in_study_col_name,
            action_col_name,
            calendar_t_col_name,
            user_id_col_name,
            action_prob_func,
            action_prob_func_args_beta_index,
            calendar_t,
            args_by_user_id,
            sorted_user_ids,
        )

        gradients_for_t = pi_and_weight_gradients_by_calendar_t.setdefault(
            calendar_t, {}
        )

        logger.debug("Collecting pi gradients into algorithm stats dictionary.")
        gradients_for_t["pi_gradients_by_user_id"] = {
            user_id: pi_gradients[i] for i, user_id in enumerate(sorted_user_ids)
        }

        logger.debug("Collecting weight gradients into algorithm stats dictionary.")
        gradients_for_t["weight_gradients_by_user_id"] = {
            user_id: weight_gradients[i] for i, user_id in enumerate(sorted_user_ids)
        }

    return pi_and_weight_gradients_by_calendar_t
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def calculate_pi_and_weight_gradients_specific_t(
    study_df,
    in_study_col_name,
    action_col_name,
    calendar_t_col_name,
    user_id_col_name,
    action_prob_func,
    action_prob_func_args_beta_index,
    calendar_t,
    args_by_user_id,
    sorted_user_ids,
):
    """
    Compute pi and Radon-Nikodym weight gradients with respect to beta for
    every user at one decision time.

    Users are grouped by the shapes of their pi-function args, and each group
    is processed with one vmap call. Users with no args at this time get zero
    gradients.

    Args:
        study_df: Study data. For each user, the row matching the user id and
            ``calendar_t`` with an in-study flag equal to 1 supplies the action
            used in the weight.
        in_study_col_name, action_col_name, calendar_t_col_name,
            user_id_col_name: Column names in ``study_df``.
        action_prob_func: The pi function (probability of action 1).
        action_prob_func_args_beta_index: Position of beta in the arg tuples.
        calendar_t: The decision time being processed.
        args_by_user_id: Mapping of user id to arg tuple (empty if no data).
        sorted_user_ids: All user ids, in output order.

    Returns:
        ``(pi_gradients, weight_gradients)``, two lists aligned with
        ``sorted_user_ids``.

    Raises:
        ValueError: If, within a shape group, the number of in-study actions
            found in ``study_df`` differs from the number of users with args.
    """
    logger.debug(
        "Calculating pi and weight gradients for decision time %d.",
        calendar_t,
    )
    # Get a list of subdicts of the user args dict, each united by having the
    # same shapes across all arguments. We vmap the gradients needed for each
    # subdict separately and later combine the results. In the worst case
    # there may be one batch per user (e.g. everyone starts on a different
    # date), which will be slow. If that is a problem, the data can be padded
    # with values that don't affect computations, producing one batch here.
    nontrivial_user_args_grouped_by_shape = group_user_args_by_shape(args_by_user_id)
    logger.debug(
        "Found %d set(s) of users with different arg shapes.",
        len(nontrivial_user_args_grouped_by_shape),
    )

    # Loop over each set of user args and vmap to get their pi and weight gradients
    in_study_pi_gradients_by_user_id = {}
    in_study_weight_gradients_by_user_id = {}
    all_involved_user_ids = set()
    for args_by_user_id_subset in nontrivial_user_args_grouped_by_shape:
        batched_arg_lists, involved_user_ids = (
            get_batched_arg_lists_and_involved_user_ids(
                action_prob_func, sorted_user_ids, args_by_user_id_subset
            )
        )
        all_involved_user_ids |= involved_user_ids

        if not batched_arg_lists[0]:
            continue

        # The users of this group, in the same order their args were batched.
        group_user_ids = [
            user_id for user_id in sorted_user_ids if user_id in involved_user_ids
        ]

        logger.debug("Reforming batched data lists into tensors.")
        batched_arg_tensors, batch_axes = stack_batched_arg_lists_into_tensor(
            batched_arg_lists
        )

        logger.debug("Forming pi gradients with respect to beta.")
        # Note that we care about the probability of action 1 specifically,
        # not the taken action.
        in_study_pi_gradients_subset = get_pi_gradients_batched(
            action_prob_func,
            action_prob_func_args_beta_index,
            batch_axes,
            batched_arg_tensors,
        )

        # TODO: betas should be verified to be the same across users now or earlier
        logger.debug("Forming weight gradients with respect to beta.")
        # Only this group's users. The actions must line up one-to-one with
        # the batched args, which cover just this shape group. Passing all
        # users here breaks vmap whenever there are several shape groups.
        in_study_batched_actions_tensor = collect_batched_in_study_actions(
            study_df,
            calendar_t,
            group_user_ids,
            in_study_col_name,
            action_col_name,
            calendar_t_col_name,
            user_id_col_name,
        )
        if in_study_batched_actions_tensor.shape[0] != len(group_user_ids):
            raise ValueError(
                f"At decision time {calendar_t}, found "
                f"{in_study_batched_actions_tensor.shape[0]} in-study actions for "
                f"{len(group_user_ids)} users with action probability function "
                "args; every such user needs an in-study row at this time."
            )
        # Note the first argument: we extract the betas to pass in again as the
        # "target" denominator betas, whereas we differentiate with respect to
        # the betas in the numerator. These betas are redundant across users:
        # the same thing repeated once per user.
        in_study_weight_gradients_subset = get_weight_gradients_batched(
            batched_arg_tensors[action_prob_func_args_beta_index],
            action_prob_func,
            action_prob_func_args_beta_index,
            in_study_batched_actions_tensor,
            batch_axes,
            batched_arg_tensors,
        )

        # Collect the gradients for the in-study users in this group into the
        # overall dicts.
        for in_batch_index, user_id in enumerate(group_user_ids):
            in_study_pi_gradients_by_user_id[user_id] = in_study_pi_gradients_subset[
                in_batch_index
            ]
            in_study_weight_gradients_by_user_id[user_id] = (
                in_study_weight_gradients_subset[in_batch_index]
            )

    involved_user_ids_in_order = [
        user_id for user_id in sorted_user_ids if user_id in all_involved_user_ids
    ]
    in_study_pi_gradients = [
        in_study_pi_gradients_by_user_id[user_id]
        for user_id in involved_user_ids_in_order
    ]
    in_study_weight_gradients = [
        in_study_weight_gradients_by_user_id[user_id]
        for user_id in involved_user_ids_in_order
    ]
    # TODO: These padding methods assume someone was in the study at this time.
    pi_gradients = pad_in_study_derivatives_with_zeros(
        in_study_pi_gradients, sorted_user_ids, all_involved_user_ids
    )
    weight_gradients = pad_in_study_derivatives_with_zeros(
        in_study_weight_gradients, sorted_user_ids, all_involved_user_ids
    )

    return pi_gradients, weight_gradients
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
# TODO: is it ok to get the action from the study df? No issues with actions taken
|
|
378
|
+
# but not known about?
|
|
379
|
+
# TODO: Test this at least with an incremental recruitment collect pi gradients
|
|
380
|
+
# case, possibly directly.
|
|
381
|
+
def collect_batched_in_study_actions(
    study_df,
    calendar_t,
    sorted_user_ids,
    in_study_col_name,
    action_col_name,
    calendar_t_col_name,
    user_id_col_name,
):
    """
    Gather, in ``sorted_user_ids`` order, the action each in-study user took at
    ``calendar_t``.

    A user contributes an entry only if ``study_df`` has a row for that user at
    ``calendar_t`` whose in-study column equals 1. Other users are skipped, so
    the result can be shorter than ``sorted_user_ids``. If several rows match,
    the first one in DataFrame order is used.

    Returns:
        A 1-D jnp array of actions.
    """
    # Filter the DataFrame once, rather than once per user (previously
    # O(users * rows)).
    rows_at_t = study_df.loc[
        (study_df[calendar_t_col_name] == calendar_t)
        & (study_df[in_study_col_name] == 1)
    ]

    first_action_by_user_id = {}
    for user_id, action in zip(
        rows_at_t[user_id_col_name].values, rows_at_t[action_col_name].values
    ):
        # Keep only the first matching row per user, as the per-user filter
        # followed by `.values[0]` did.
        first_action_by_user_id.setdefault(user_id, action)

    batched_actions_list = [
        first_action_by_user_id[user_id]
        for user_id in sorted_user_ids
        if user_id in first_action_by_user_id
    ]

    return jnp.array(batched_actions_list)
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
# TODO: Docstring
|
|
408
|
+
def get_radon_nikodym_weight(
    beta_target,
    action_prob_func,
    action_prob_func_args_beta_index,
    action,
    *action_prob_func_args_single_user,
):
    """
    Radon-Nikodym weight for one user at one decision time.

    The numerator evaluates ``action_prob_func`` on the supplied args. The
    denominator evaluates it on the same args with the beta entry (at
    ``action_prob_func_args_beta_index``) replaced by ``beta_target``. Each
    probability is passed through ``conditional_x_or_one_minus_x`` together
    with ``action`` (presumably pi when action is 1 and 1 - pi otherwise;
    see that helper).

    Gradients are taken with respect to the beta inside
    ``action_prob_func_args_single_user``; ``beta_target`` is held fixed.
    """
    args_with_target_beta = list(action_prob_func_args_single_user)
    args_with_target_beta[action_prob_func_args_beta_index] = beta_target

    pi_at_supplied_beta = action_prob_func(*action_prob_func_args_single_user)
    pi_at_target_beta = action_prob_func(*args_with_target_beta)

    numerator = conditional_x_or_one_minus_x(pi_at_supplied_beta, action)
    denominator = conditional_x_or_one_minus_x(pi_at_target_beta, action)
    return numerator / denominator
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
# TODO: Docstring
|
|
429
|
+
def get_pi_gradients_batched(
    action_prob_func,
    action_prob_func_args_beta_index,
    batch_axes,
    batched_arg_tensors,
):
    """
    Gradient of ``action_prob_func`` with respect to its beta argument,
    computed for every user in a batch with a single vmap.

    Args:
        action_prob_func: Scalar-valued, JAX-differentiable function of one
            user's positional args.
        action_prob_func_args_beta_index: Positional index of beta among those
            args.
        batch_axes: The vmap ``in_axes`` entry for each argument.
        batched_arg_tensors: One batched array per positional argument.

    Returns:
        Per-user gradients stacked along axis 0.
    """
    grad_for_one_user = jax.grad(action_prob_func, action_prob_func_args_beta_index)
    grad_for_batch = jax.vmap(fun=grad_for_one_user, in_axes=batch_axes, out_axes=0)
    return grad_for_batch(*batched_arg_tensors)
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
# TODO: Docstring
|
|
443
|
+
def get_weight_gradients_batched(
    batched_beta_target_tensor,
    action_prob_func,
    action_prob_func_args_beta_index,
    batched_actions_tensor,
    batch_axes,
    batched_arg_tensors,
):
    """
    Gradient of ``get_radon_nikodym_weight`` with respect to the (numerator)
    beta among the pi-function args, computed for every user in a batch.

    Args:
        batched_beta_target_tensor: Per-user "target" betas (denominator),
            batched along axis 0.
        action_prob_func: The pi function. It is shared by all users and not
            batched.
        action_prob_func_args_beta_index: Position of beta among the pi args.
            Shared by all users and not batched.
        batched_actions_tensor: Per-user actions, batched along axis 0.
        batch_axes: The vmap ``in_axes`` entry for each pi-function argument.
            Must be a list, since it is concatenated to a list below.
        batched_arg_tensors: One batched array per pi-function argument.

    Returns:
        Per-user weight gradients stacked along axis 0.
    """
    # get_radon_nikodym_weight takes four fixed leading parameters
    # (beta_target, action_prob_func, beta index, action) before the pi args,
    # so beta's position among its parameters is shifted by four.
    num_leading_params = 4
    weight_grad_for_one_user = jax.grad(
        get_radon_nikodym_weight,
        num_leading_params + action_prob_func_args_beta_index,
    )
    # Target betas and actions are mapped along axis 0; the function and the
    # index are passed through unmapped.
    leading_in_axes = [0, None, None, 0]
    weight_grad_for_batch = jax.vmap(
        fun=weight_grad_for_one_user,
        in_axes=leading_in_axes + batch_axes,
        out_axes=0,
    )
    return weight_grad_for_batch(
        batched_beta_target_tensor,
        action_prob_func,
        action_prob_func_args_beta_index,
        batched_actions_tensor,
        *batched_arg_tensors,
    )
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
# TODO: Docstring
|
|
468
|
+
# TODO: JIT whole function? or just gradient and hessian batch functions
|
|
469
|
+
# TODO: This is a hotspot for moving away from update times
|
|
470
|
+
def calculate_rl_update_derivatives(
    study_df,
    rl_update_func_filename,
    rl_update_func_args,
    rl_update_func_type,
    rl_update_func_args_beta_index,
    rl_update_func_args_action_prob_index,
    rl_update_func_args_action_prob_times_index,
    policy_num_col_name,
    calendar_t_col_name,
):
    """
    For every algorithm update, compute the per-user derivatives of the RL
    update function and key them by the first time the resulting parameters
    apply.

    Three quantities are computed:
      * loss gradients with respect to beta;
      * the user-averaged loss hessian;
      * the mixed beta / action-probability derivatives.

    Args:
        study_df: Study data, used to find each update's first applicable time.
        rl_update_func_filename: File the RL update function is loaded from.
        rl_update_func_args: Mapping of policy number to a mapping of user id
            to that user's update-function arg tuple.
        rl_update_func_type: A ``FunctionTypes`` value (loss or estimating
            function).
        rl_update_func_args_beta_index: Position of beta in the arg tuples.
        rl_update_func_args_action_prob_index: Position of the action
            probabilities in the arg tuples, or negative if there are none.
        rl_update_func_args_action_prob_times_index: Position of the times at
            which those action probabilities were supplied.
        policy_num_col_name, calendar_t_col_name: Column names in ``study_df``.

    Returns:
        A dict keyed by first applicable time. Each value has the keys
        ``"loss_gradients_by_user_id"``, ``"avg_loss_hessian"`` and
        ``"loss_gradient_pi_derivatives_by_user_id"``. The dict is empty if
        ``rl_update_func_args`` is empty.
    """
    logger.debug(
        "Calculating RL loss gradients and hessians with respect to beta and mixed beta/action probability derivatives for each user at all update times."
    )
    rl_update_func = load_function_from_same_named_file(rl_update_func_filename)

    rl_update_derivatives_by_calendar_t = {}
    if not rl_update_func_args:
        # next(iter(...)) below would otherwise leak a bare StopIteration.
        return rl_update_derivatives_by_calendar_t

    user_ids = list(next(iter(rl_update_func_args.values())).keys())
    sorted_user_ids = sorted(user_ids)
    for policy_num, args_by_user_id in rl_update_func_args.items():
        # We store these loss gradients by the first time the resulting
        # parameters apply to. Because algorithm updates are performed at the
        # *end* of a timestep, the first timestep they apply to is one more
        # than the time at which the update data is gathered.
        first_applicable_time = get_first_applicable_time(
            study_df, policy_num, policy_num_col_name, calendar_t_col_name
        )
        loss_gradients, loss_hessians, loss_gradient_pi_derivatives = (
            calculate_rl_update_derivatives_specific_update(
                rl_update_func,
                rl_update_func_type,
                rl_update_func_args_beta_index,
                rl_update_func_args_action_prob_index,
                rl_update_func_args_action_prob_times_index,
                args_by_user_id,
                sorted_user_ids,
                first_applicable_time,
            )
        )

        derivatives_for_time = rl_update_derivatives_by_calendar_t.setdefault(
            first_applicable_time, {}
        )
        derivatives_for_time["loss_gradients_by_user_id"] = {
            user_id: loss_gradients[i] for i, user_id in enumerate(sorted_user_ids)
        }
        derivatives_for_time["avg_loss_hessian"] = np.mean(loss_hessians, axis=0)
        derivatives_for_time["loss_gradient_pi_derivatives_by_user_id"] = {
            # NOTE the [..., 0] here: it is very important. Without it we have
            # a shape (x,y,z,1) array of gradients, while the use of these
            # probabilities assumes (x,y,z). This could arguably happen above,
            # but the vmap call produces a 4D array, so that is the most
            # natural return value. We don't simply squeeze, because that
            # would also remove the beta dimension if it were one.
            # TODO: This probably has to do with the dimension of the action
            # probabilities: they may need to be scalars in the loss function
            # args rather than 1-element vectors. Test both cases; we can
            # probably check dimensions and squeeze if necessary.
            user_id: loss_gradient_pi_derivatives[i][..., 0]
            for i, user_id in enumerate(sorted_user_ids)
        }
    return rl_update_derivatives_by_calendar_t
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
def calculate_rl_update_derivatives_specific_update(
    rl_update_func,
    rl_update_func_type,
    rl_update_func_args_beta_index,
    rl_update_func_args_action_prob_index,
    rl_update_func_args_action_prob_times_index,
    args_by_user_id,
    sorted_user_ids,
    first_applicable_time,
):
    """
    Compute per-user RL update derivatives for a single algorithm update.

    Users are grouped by the shapes of their update-function args, and each
    group is processed with one vmap call per derivative. Users with no args
    for this update get zero derivatives.

    Args:
        rl_update_func: The loss or estimating function.
        rl_update_func_type: A ``FunctionTypes`` value.
        rl_update_func_args_beta_index: Position of beta in the arg tuples.
        rl_update_func_args_action_prob_index: Position of the action
            probabilities in the arg tuples, or negative if there are none.
        rl_update_func_args_action_prob_times_index: Position of the times at
            which those action probabilities were supplied.
        args_by_user_id: Mapping of user id to arg tuple (empty if no data).
        sorted_user_ids: All user ids, in output order.
        first_applicable_time: First decision time the updated parameters
            apply to. The pi-derivatives cover times 1 .. this - 1.

    Returns:
        ``(loss_gradients, loss_hessians, loss_gradient_pi_derivatives)``.
        The first two are lists aligned with ``sorted_user_ids``. The third is
        a list in the same order when an action probability argument exists.
        Otherwise it is an all-zero array of shape
        (num_users, beta_dim, first_applicable_time - 1, 1).

    Raises:
        ValueError: If no user has args and the beta dimension is therefore
            unknown.
    """
    logger.debug(
        "Calculating RL update derivatives for the update that first applies at time %d.",
        first_applicable_time,
    )
    has_action_prob_arg = rl_update_func_args_action_prob_index >= 0

    # Get a list of subdicts of the user args dict, each united by having the
    # same shapes across all arguments. We vmap the derivatives needed for
    # each subdict separately and later combine the results. In the worst case
    # there may be one batch per user (e.g. everyone starts on a different
    # date), which will be slow. If that is a problem, the data can be padded
    # with values that don't affect computations, producing one batch here.
    # NOTE: Susan and Kelly think we might actually have uniqueish shapes pretty often
    nontrivial_user_args_grouped_by_shape = group_user_args_by_shape(args_by_user_id)
    logger.debug(
        "Found %d set(s) of users with different arg shapes.",
        len(nontrivial_user_args_grouped_by_shape),
    )

    in_study_loss_gradients_by_user_id = {}
    in_study_loss_hessians_by_user_id = {}
    in_study_loss_gradient_pi_derivatives_by_user_id = {}
    all_involved_user_ids = set()
    for args_by_user_id_subset in nontrivial_user_args_grouped_by_shape:
        # Pivot the loss args for the involved users into a list of lists,
        # each holding all the args at one index across users. Users not in
        # the study at this time are filtered out by their empty args.
        batched_arg_lists, involved_user_ids = (
            get_batched_arg_lists_and_involved_user_ids(
                rl_update_func, sorted_user_ids, args_by_user_id_subset
            )
        )
        all_involved_user_ids |= involved_user_ids

        if not batched_arg_lists[0]:
            continue

        logger.debug("Reforming batched data lists into tensors.")
        # Turn the list of lists into one jnp array per argument, for vmap.
        batched_arg_tensors, batch_axes = stack_batched_arg_lists_into_tensor(
            batched_arg_lists
        )

        logger.debug("Forming loss gradients with respect to beta.")
        in_study_loss_gradients_subset = get_loss_gradients_batched(
            rl_update_func,
            rl_update_func_type,
            rl_update_func_args_beta_index,
            batch_axes,
            *batched_arg_tensors,
        )

        logger.debug("Forming loss hessians with respect to beta.")
        in_study_loss_hessians_subset = get_loss_hessians_batched(
            rl_update_func,
            rl_update_func_type,
            rl_update_func_args_beta_index,
            batch_axes,
            *batched_arg_tensors,
        )
        logger.debug(
            "Forming derivatives of loss with respect to beta and then the action probabilites vector at each time."
        )
        if has_action_prob_arg:
            in_study_loss_gradient_pi_derivatives_subset = (
                get_loss_gradient_derivatives_wrt_pi_batched(
                    rl_update_func,
                    rl_update_func_type,
                    rl_update_func_args_beta_index,
                    rl_update_func_args_action_prob_index,
                    batch_axes,
                    *batched_arg_tensors,
                )
            )

        # Collect this group's results, in the order its args were batched.
        group_user_ids = [
            user_id for user_id in sorted_user_ids if user_id in involved_user_ids
        ]
        for in_batch_index, user_id in enumerate(group_user_ids):
            in_study_loss_gradients_by_user_id[user_id] = (
                in_study_loss_gradients_subset[in_batch_index]
            )
            in_study_loss_hessians_by_user_id[user_id] = in_study_loss_hessians_subset[
                in_batch_index
            ]
            if has_action_prob_arg:
                in_study_loss_gradient_pi_derivatives_by_user_id[user_id] = (
                    pad_loss_gradient_pi_derivative_outside_supplied_action_probabilites(
                        in_study_loss_gradient_pi_derivatives_subset[in_batch_index],
                        args_by_user_id[user_id][
                            rl_update_func_args_action_prob_times_index
                        ],
                        first_applicable_time,
                    )
                )

    involved_user_ids_in_order = [
        user_id for user_id in sorted_user_ids if user_id in all_involved_user_ids
    ]

    # TODO: These padding methods assume *someone* had study data at this time.
    loss_gradients = pad_in_study_derivatives_with_zeros(
        [in_study_loss_gradients_by_user_id[u] for u in involved_user_ids_in_order],
        sorted_user_ids,
        all_involved_user_ids,
    )
    loss_hessians = pad_in_study_derivatives_with_zeros(
        [in_study_loss_hessians_by_user_id[u] for u in involved_user_ids_in_order],
        sorted_user_ids,
        all_involved_user_ids,
    )

    # With an action probability argument, pad the computed derivatives with
    # zeros for users not currently in the study. Otherwise return all-zero
    # derivatives of the correct shape.
    if has_action_prob_arg:
        loss_gradient_pi_derivatives = pad_in_study_derivatives_with_zeros(
            [
                in_study_loss_gradient_pi_derivatives_by_user_id[u]
                for u in involved_user_ids_in_order
            ],
            sorted_user_ids,
            all_involved_user_ids,
        )
    else:
        if not involved_user_ids_in_order:
            raise ValueError(
                "Cannot determine the beta dimension: no user has RL update "
                "function args for this update."
            )
        # Read beta from an actual user rather than from the last loop
        # iteration's leftover variables. np.size also handles list-typed betas.
        beta_dim = np.size(
            args_by_user_id[involved_user_ids_in_order[0]][
                rl_update_func_args_beta_index
            ]
        )
        timesteps_included = first_applicable_time - 1

        loss_gradient_pi_derivatives = np.zeros(
            (len(sorted_user_ids), beta_dim, timesteps_included, 1)
        )

    return loss_gradients, loss_hessians, loss_gradient_pi_derivatives
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
def get_loss_gradients_batched(
    update_func,
    update_func_type,
    update_func_args_beta_index,
    batch_axes,
    *batched_arg_tensors,
):
    """
    Evaluate per-user loss gradients for one shape-homogeneous batch.

    For a ``FunctionTypes.LOSS`` function the gradient with respect to the
    argument at ``update_func_args_beta_index`` is taken with ``jax.grad``.
    For a ``FunctionTypes.ESTIMATING`` function the function itself is
    evaluated, since its output is already used in place of a gradient.
    Either way the per-user function is mapped over the batch with
    ``jax.vmap(in_axes=batch_axes)``, and results are stacked along axis 0.

    Raises:
        ValueError: if ``update_func_type`` is neither LOSS nor ESTIMATING.
    """
    if update_func_type == FunctionTypes.LOSS:
        per_user_fn = jax.grad(update_func, update_func_args_beta_index)
    elif update_func_type == FunctionTypes.ESTIMATING:
        per_user_fn = update_func
    else:
        raise ValueError("Unknown update function type.")

    batched_fn = jax.vmap(fun=per_user_fn, in_axes=batch_axes, out_axes=0)
    return batched_fn(*batched_arg_tensors)
|
|
713
|
+
|
|
714
|
+
|
|
715
|
+
def get_loss_hessians_batched(
    update_func,
    update_func_type,
    update_func_args_beta_index,
    batch_axes,
    *batched_arg_tensors,
):
    """
    Evaluate per-user second derivatives with respect to beta for one batch.

    For a ``FunctionTypes.LOSS`` function this is ``jax.hessian`` with respect
    to the argument at ``update_func_args_beta_index``. For a
    ``FunctionTypes.ESTIMATING`` function it is the Jacobian
    (``jax.jacrev``) of the function with respect to that same argument.
    The per-user derivative is mapped over the batch with
    ``jax.vmap(in_axes=batch_axes)``, and results are stacked along axis 0.

    Raises:
        ValueError: if ``update_func_type`` is neither LOSS nor ESTIMATING.
    """
    if update_func_type == FunctionTypes.LOSS:
        second_derivative = jax.hessian(update_func, update_func_args_beta_index)
    elif update_func_type == FunctionTypes.ESTIMATING:
        second_derivative = jax.jacrev(update_func, update_func_args_beta_index)
    else:
        raise ValueError("Unknown update function type.")

    vectorized = jax.vmap(fun=second_derivative, in_axes=batch_axes, out_axes=0)
    return vectorized(*batched_arg_tensors)
|
|
735
|
+
|
|
736
|
+
|
|
737
|
+
def get_loss_gradient_derivatives_wrt_pi_batched(
    update_func,
    update_func_type,
    update_func_args_beta_index,
    update_func_args_action_prob_index,
    batch_axes,
    *batched_arg_tensors,
):
    """
    Evaluate per-user mixed derivatives (beta gradient, then action probs).

    The "gradient" is built first:
    - For a ``FunctionTypes.LOSS`` function, it is ``jax.grad`` with respect
      to the argument at ``update_func_args_beta_index``.
    - For a ``FunctionTypes.ESTIMATING`` function, it is the function itself;
      ``update_func_args_beta_index`` is not used in that case.

    ``jax.jacrev`` with respect to the argument at
    ``update_func_args_action_prob_index`` is then applied to that gradient.
    The result is vmapped over the batch (``in_axes=batch_axes``, output
    axis 0), jit-compiled, and evaluated on ``batched_arg_tensors``.

    Raises:
        ValueError: if ``update_func_type`` is neither LOSS nor ESTIMATING.
    """
    if update_func_type == FunctionTypes.LOSS:
        beta_gradient = jax.grad(update_func, update_func_args_beta_index)
    elif update_func_type == FunctionTypes.ESTIMATING:
        beta_gradient = update_func
    else:
        raise ValueError("Unknown update function type.")

    mixed_derivative = jax.jacrev(beta_gradient, update_func_args_action_prob_index)
    batched_mixed_derivative = jax.vmap(
        fun=mixed_derivative, in_axes=batch_axes, out_axes=0
    )
    return jax.jit(batched_mixed_derivative)(*batched_arg_tensors)
|
|
768
|
+
|
|
769
|
+
|
|
770
|
+
# TODO: Is there a better way to calculate this? This seems like it should
# be reliable, not messing up when a policy was actually available. If study
# df says policy was used, that should be correct. May not play nicely with
# pure exploration phase though.
def get_first_applicable_time(
    study_df, policy_num, policy_num_col_name, calendar_t_col_name
):
    """
    Return the earliest calendar time at which ``policy_num`` was in use.

    Args:
        study_df: Study DataFrame.
        policy_num: Policy number to look up.
        policy_num_col_name: Column holding the policy number of each row.
        calendar_t_col_name: Column holding the calendar time of each row.

    Returns:
        The minimum of ``calendar_t_col_name`` over the rows whose
        ``policy_num_col_name`` equals ``policy_num``.
    """
    rows_using_policy = study_df[policy_num_col_name] == policy_num
    times_using_policy = study_df.loc[rows_using_policy, calendar_t_col_name]
    return times_using_policy.min()
|
|
780
|
+
|
|
781
|
+
|
|
782
|
+
def calculate_inference_loss_derivatives(
    study_df,
    theta_est,
    inference_func,
    inference_func_args_theta_index,
    user_ids,
    user_id_col_name,
    action_prob_col_name,
    in_study_col_name,
    calendar_t_col_name,
    inference_func_type=FunctionTypes.LOSS,
):
    """
    Compute per-user derivatives of the inference loss/estimating function.

    Argument assembly:
    - The positional argument names of ``inference_func`` are matched against
      ``study_df`` columns.
    - For every user, each argument other than theta is read from that
      user's rows via ``get_study_df_column``.
    - ``theta_est`` is substituted at ``inference_func_args_theta_index``.

    Users are grouped by argument shape with ``group_user_args_by_shape``.
    The derivatives are then computed batch by batch with vmap.

    Args:
        study_df: Study DataFrame. It must have a column for every non-theta
            argument name of ``inference_func``, plus the user id, in-study
            and calendar time columns.
        theta_est: Parameter estimate passed as the theta argument.
        inference_func: Loss or estimating function of the study data.
        inference_func_args_theta_index: Position of theta among the
            function's arguments.
        user_ids: User ids, as a list or any object exposing ``tolist()``
            (e.g. numpy / jax arrays).
        user_id_col_name: Column identifying users.
        action_prob_col_name: Name of the action-probability column. The
            derivatives with respect to action probabilities are only
            computed when ``inference_func`` has an argument with this name.
        in_study_col_name: In-study indicator column passed to
            ``get_study_df_column``.
        calendar_t_col_name: Calendar time column.
        inference_func_type: ``FunctionTypes.LOSS`` or
            ``FunctionTypes.ESTIMATING``.

    Returns:
        A tuple ``(loss_gradients, loss_hessians,
        loss_gradient_pi_derivatives)`` of numpy arrays. The first axis of
        each runs over the users reported as involved by
        ``get_batched_arg_lists_and_involved_user_ids``, in sorted user id
        order. When ``inference_func`` has no action-probability argument,
        ``loss_gradient_pi_derivatives`` is all zeros with shape
        ``(num_involved_users, theta_est.size, num_unique_calendar_times, 1)``.
    """
    logger.debug("Calculating inference loss derivatives.")

    # Convert to a list if needed (from a jnp/np array, etc.). Plain sequences
    # have no ``tolist`` and are used as-is.
    try:
        user_ids = user_ids.tolist()
    except AttributeError:
        pass

    num_args = inference_func.__code__.co_argcount
    inference_func_arg_names = inference_func.__code__.co_varnames[:num_args]
    using_action_probs = action_prob_col_name in inference_func_arg_names
    if using_action_probs:
        inference_func_args_action_prob_index = inference_func_arg_names.index(
            action_prob_col_name
        )

    # We begin by constructing a dict of loss function arg tuples of the type
    # we get from file for the RL data. Because we have to group user args by
    # shape anyway, we might as well collect them in this format and then use
    # the previous machinery to process them. There are a few extra loops, but
    # more code is shared this way.
    args_by_user_id = {}
    action_prob_decision_times_by_user_id = {}
    max_calendar_time = study_df[calendar_t_col_name].max()
    for user_id in user_ids:
        filtered_user_data = study_df.loc[study_df[user_id_col_name] == user_id]
        args_by_user_id[user_id] = tuple(
            theta_est
            if idx == inference_func_args_theta_index
            else get_study_df_column(filtered_user_data, col_name, in_study_col_name)
            for idx, col_name in enumerate(inference_func_arg_names)
        )
        if using_action_probs:
            action_prob_decision_times_by_user_id[user_id] = get_study_df_column(
                filtered_user_data, calendar_t_col_name, in_study_col_name
            )

    # Get a list of subdicts of the user args dict, each holding users whose
    # arguments all share the same shapes. The needed derivatives are vmapped
    # per subdict and combined afterwards.
    # NOTE: As opposed to the RL updates, we expect a small number of batches
    # here. Only users having different numbers of decision times contribute
    # additional batches. In the worst case (e.g. everyone starts on a
    # different date) there may be one batch per user, which is slow; padding
    # the data with values that do not affect computations could produce a
    # single batch instead.
    nontrivial_user_args_grouped_by_shape = group_user_args_by_shape(
        args_by_user_id, empty_allowed=False
    )
    logger.debug(
        "Found %d set(s) of users with different arg shapes.",
        len(nontrivial_user_args_grouped_by_shape),
    )

    loss_gradients_by_user_id = {}
    loss_hessians_by_user_id = {}
    loss_gradient_pi_derivatives_by_user_id = {}
    all_involved_user_ids = set()
    sorted_user_ids = sorted(user_ids)
    for args_by_user_id_subset in nontrivial_user_args_grouped_by_shape:
        batched_arg_lists, involved_user_ids = (
            get_batched_arg_lists_and_involved_user_ids(
                inference_func, sorted_user_ids, args_by_user_id_subset
            )
        )
        all_involved_user_ids |= involved_user_ids

        logger.debug("Reforming batched data lists into tensors.")
        batched_arg_tensors, batch_axes = stack_batched_arg_lists_into_tensor(
            batched_arg_lists
        )

        logger.debug("Forming loss gradients with respect to theta.")
        loss_gradients_subset = get_loss_gradients_batched(
            inference_func,
            inference_func_type,
            inference_func_args_theta_index,
            batch_axes,
            *batched_arg_tensors,
        )

        logger.debug("Forming loss hessians with respect to theta.")
        loss_hessians_subset = get_loss_hessians_batched(
            inference_func,
            inference_func_type,
            inference_func_args_theta_index,
            batch_axes,
            *batched_arg_tensors,
        )
        logger.debug(
            "Forming derivatives of loss with respect to theta and then the action probabilities vector at each time."
        )
        # If there is an action probability argument in the loss, actually
        # differentiate with respect to the action probabilities.
        if using_action_probs:
            loss_gradient_pi_derivatives_subset = (
                get_loss_gradient_derivatives_wrt_pi_batched(
                    inference_func,
                    inference_func_type,
                    inference_func_args_theta_index,
                    inference_func_args_action_prob_index,
                    batch_axes,
                    *batched_arg_tensors,
                )
            )

        # Collect this group's per-user results into the overall dicts. The
        # indexing relies on batch rows following the order of the involved
        # users within sorted_user_ids, which is what the batch was built from.
        batch_user_ids = [
            user_id for user_id in sorted_user_ids if user_id in involved_user_ids
        ]
        for in_batch_index, user_id in enumerate(batch_user_ids):
            loss_gradients_by_user_id[user_id] = loss_gradients_subset[in_batch_index]
            loss_hessians_by_user_id[user_id] = loss_hessians_subset[in_batch_index]
            if using_action_probs:
                loss_gradient_pi_derivatives_by_user_id[user_id] = (
                    pad_loss_gradient_pi_derivative_outside_supplied_action_probabilites(
                        loss_gradient_pi_derivatives_subset[in_batch_index],
                        action_prob_decision_times_by_user_id[user_id],
                        max_calendar_time + 1,
                    )
                )

    # Stack results for every involved user, in sorted user id order.
    ordered_involved_user_ids = [
        user_id for user_id in sorted_user_ids if user_id in all_involved_user_ids
    ]
    loss_gradients = np.array(
        [loss_gradients_by_user_id[user_id] for user_id in ordered_involved_user_ids]
    )
    loss_hessians = np.array(
        [loss_hessians_by_user_id[user_id] for user_id in ordered_involved_user_ids]
    )
    if using_action_probs:
        loss_gradient_pi_derivatives = np.array(
            [
                loss_gradient_pi_derivatives_by_user_id[user_id]
                for user_id in ordered_involved_user_ids
            ]
        )
    else:
        # The loss does not depend on the action probabilities, so return
        # zeros. The user axis is sized like loss_gradients / loss_hessians
        # (involved users only) so that the three returned arrays agree.
        loss_gradient_pi_derivatives = np.zeros(
            (
                len(ordered_involved_user_ids),
                theta_est.size,
                study_df[calendar_t_col_name].nunique(),
                1,
            )
        )

    return loss_gradients, loss_hessians, loss_gradient_pi_derivatives
|
|
958
|
+
|
|
959
|
+
|
|
960
|
+
def get_study_df_column(study_df, col_name, in_study_col_name):
    """
    Extract ``col_name`` from the in-study rows of ``study_df`` as a column vector.

    A row is kept when its ``in_study_col_name`` value equals 1. The kept
    values are returned as a JAX array of shape ``(num_kept_rows, 1)``.
    """
    in_study_mask = study_df[in_study_col_name] == 1
    column_values = study_df.loc[in_study_mask, col_name].to_numpy()
    return jnp.array(column_values.reshape(-1, 1))
|