lifejacket 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,965 @@
1
+ import collections
2
+ import logging
3
+
4
+ import jax
5
+ from jax import numpy as jnp
6
+ import numpy as np
7
+
8
+ from .constants import FunctionTypes
9
+ from .helper_functions import (
10
+ conditional_x_or_one_minus_x,
11
+ load_function_from_same_named_file,
12
+ )
13
+
14
# Module-level logger; all functions below log progress at DEBUG level.
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig at import time configures the root logger for the
# whole process; consumers of this library may prefer to configure logging
# themselves — confirm this side effect is intended.
logging.basicConfig(
    format="%(asctime)s,%(msecs)03d %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)

# TODO: Consolidate function loading logic
22
+
23
+
24
def get_batched_arg_lists_and_involved_user_ids(func, sorted_user_ids, args_by_user_id):
    """
    Pivot a dict of per-user argument tuples into one list per argument
    position, each holding that argument across users, ordered according to
    sorted_user_ids.  Also report which user ids actually contributed args.

    Users absent from args_by_user_id (we may be handed a shape-batched
    subset) or with empty/falsy argument tuples are skipped.
    """
    # Quick way to get the arg count instead of scanning for the first truthy
    # tuple.
    # TODO: If there are arguments with defaults and not supplied, this will
    # break; should probably iterate to the first truthy tuple instead.
    arg_count = func.__code__.co_argcount

    # NOTE: Cannot do [[]] * arg_count here — all entries would alias the
    # same list object.
    per_position_lists = [[] for _ in range(arg_count)]
    involved = set()

    # Walk users in sorted order to keep batch ordering deterministic.
    for user_id in sorted_user_ids:
        user_args = args_by_user_id.get(user_id)
        if not user_args:
            continue
        involved.add(user_id)
        for position, value in enumerate(user_args):
            per_position_lists[position].append(value)

    return per_position_lists, involved
57
+
58
+
59
def get_shape(obj):
    """
    Best-effort shape fingerprint used to group compatible arguments:
    the object's ``.shape`` attribute when present, ``len()`` for sized
    non-string objects, otherwise None (scalars, strings, opaque objects).
    """
    if hasattr(obj, "shape"):
        return obj.shape
    if not isinstance(obj, str):
        try:
            return len(obj)
        except Exception:
            pass
    return None
68
+
69
+
70
def group_user_args_by_shape(user_arg_dict, empty_allowed=True):
    """
    Partition users into groups whose argument tuples share identical
    per-position shapes (as reported by get_shape), so each group can be
    vmapped together.

    Users with empty argument tuples are dropped when empty_allowed is True,
    and rejected with ValueError otherwise.  Returns the dict-of-users
    values, one dict per distinct shape signature.
    """
    grouped = collections.defaultdict(dict)
    for user_id, args in user_arg_dict.items():
        if args:
            signature = tuple(get_shape(arg) for arg in args)
            grouped[signature][user_id] = args
        elif not empty_allowed:
            raise ValueError("There shouldn't be a user with no data at this time")
    return grouped.values()
80
+
81
+
82
# TODO: Check for exactly the required types earlier
# TODO: Try except and nice error message
# TODO: This is complicated enough to deserve its own unit tests
def stack_batched_arg_lists_into_tensor(batched_arg_lists):
    """
    Stack per-argument Python lists (one entry per user) into jnp arrays that
    can be supplied to vmap as batch arguments, plus the batch axis for each.

    vmap requires all elements of a batched array to be the same shape, as do
    the stacking functions used here, so this must be called on batches of
    users with the same data shape (see group_user_args_by_shape).

    Returns:
        (batched_arg_tensors, batch_axes): parallel lists.  Every axis ends
        up 0 given the stacking done here, but callers should not assume it.

    Raises:
        TypeError: for array arguments with ndim > 2, or sequence arguments
            that cannot be cast to a JAX numpy array.
    """
    batched_arg_tensors = []

    # All zeros because of the way we stack below, but better not to assume
    # that externally; report what we actually did.
    batch_axes = []

    for batched_arg_list in batched_arg_lists:
        # Inspect the first element only: callers guarantee homogeneous
        # shapes within a batch.
        first = batched_arg_list[0]
        first_is_array = isinstance(first, (jnp.ndarray, np.ndarray))

        if first_is_array and first.ndim > 2:
            # Fix for original message typo: "greater that 2".
            raise TypeError("Arrays with dimension greater than 2 are not supported.")

        if first_is_array and first.ndim == 2:
            # Matrix (2D array) arg: stack along a new leading batch axis.
            batched_arg_tensors.append(jnp.stack(batched_arg_list, 0))
        elif (
            isinstance(first, (collections.abc.Sequence, jnp.ndarray, np.ndarray))
            and not isinstance(first, str)
        ):
            # Vector (1D) arg, possibly a plain Python sequence: cast to JAX
            # arrays first, then vstack into a (users, dim) matrix.
            if not first_is_array:
                try:
                    batched_arg_list = [jnp.array(x) for x in batched_arg_list]
                except Exception as e:
                    raise TypeError(
                        "Argument of sequence type that cannot be cast to JAX numpy array"
                    ) from e
            assert batched_arg_list[0].ndim == 1
            batched_arg_tensors.append(jnp.vstack(batched_arg_list))
        else:
            # Otherwise a list of scalars: just turn into a jnp array.
            batched_arg_tensors.append(jnp.array(batched_arg_list))
        batch_axes.append(0)

    return (
        batched_arg_tensors,
        batch_axes,
    )
147
+
148
+
149
# TODO: Add clarity on why we need gradients at all times.
def pad_loss_gradient_pi_derivative_outside_supplied_action_probabilites(
    loss_gradient_pi_derivative,
    action_prob_times,
    first_time_after,
):
    """
    Insert zero-gradient columns for every decision time in
    [1, first_time_after) at which no action probability was supplied,
    keeping the supplied columns in order.  E.g. if probabilities are given
    for times 5-8, this zero-fills times 1-4 and 9..first_time_after-1.

    Compare pad_in_study_derivatives_with_zeros, which pads whole users
    (those not in the study) rather than individual times.
    """
    zeros_column = np.zeros((loss_gradient_pi_derivative.shape[0], 1, 1))
    columns = []
    supplied_idx = 0
    for t in range(1, first_time_after):
        if t not in action_prob_times:
            columns.append(zeros_column)
            continue
        columns.append(
            np.expand_dims(loss_gradient_pi_derivative[:, supplied_idx, :], 2)
        )
        supplied_idx += 1

    return np.hstack(columns)
178
+
179
+
180
def pad_in_study_derivatives_with_zeros(
    in_study_derivatives, sorted_user_ids, in_study_user_ids
):
    """
    Expand derivatives computed only for in-study users to cover every user
    in sorted_user_ids, substituting a fresh zero array (shaped like the
    first in-study derivative) for each out-of-study user.
    """
    supplied = iter(in_study_derivatives)
    padded = []
    for user_id in sorted_user_ids:
        padded.append(
            next(supplied)
            if user_id in in_study_user_ids
            else np.zeros_like(in_study_derivatives[0])
        )
    return padded
197
+
198
+
199
def calculate_pi_and_weight_gradients(
    study_df,
    in_study_col_name,
    action_col_name,
    calendar_t_col_name,
    user_id_col_name,
    action_prob_func,
    action_prob_func_args,
    action_prob_func_args_beta_index,
):
    """
    For all decision times and all users, compute the gradient with respect
    to beta of both the pi function (which takes a state and gives the
    probability of selecting action 1) and the Radon-Nikodym weight (derived
    from pi functions as described in the paper).

    Returns a dict keyed by calendar time, each entry holding
    "pi_gradients_by_user_id" and "weight_gradients_by_user_id" sub-dicts.
    """
    logger.debug("Calculating pi and weight gradients with respect to beta.")

    # All user ids are required at every decision time, so any one entry of
    # the args dict reliably lists them all.
    sorted_user_ids = sorted(next(iter(action_prob_func_args.values())).keys())

    gradients_by_calendar_t = {}
    for calendar_t, args_by_user_id in action_prob_func_args.items():
        pi_gradients, weight_gradients = calculate_pi_and_weight_gradients_specific_t(
            study_df,
            in_study_col_name,
            action_col_name,
            calendar_t_col_name,
            user_id_col_name,
            action_prob_func,
            action_prob_func_args_beta_index,
            calendar_t,
            args_by_user_id,
            sorted_user_ids,
        )

        t_entry = gradients_by_calendar_t.setdefault(calendar_t, {})

        logger.debug("Collecting pi gradients into algorithm stats dictionary.")
        t_entry["pi_gradients_by_user_id"] = {
            uid: pi_gradients[i] for i, uid in enumerate(sorted_user_ids)
        }

        logger.debug("Collecting weight gradients into algorithm stats dictionary.")
        t_entry["weight_gradients_by_user_id"] = {
            uid: weight_gradients[i] for i, uid in enumerate(sorted_user_ids)
        }

    return gradients_by_calendar_t
251
+
252
+
253
def calculate_pi_and_weight_gradients_specific_t(
    study_df,
    in_study_col_name,
    action_col_name,
    calendar_t_col_name,
    user_id_col_name,
    action_prob_func,
    action_prob_func_args_beta_index,
    calendar_t,
    args_by_user_id,
    sorted_user_ids,
):
    """
    Compute pi and Radon-Nikodym weight gradients w.r.t. beta for a single
    decision time.

    Users are grouped by argument shape so each group can be vmapped as one
    batch; the per-batch results are then reassembled in sorted_user_ids
    order, with zero gradients filled in for users not in the study at this
    time.

    Returns:
        (pi_gradients, weight_gradients): lists aligned with sorted_user_ids.
    """
    logger.debug(
        "Calculating pi and weight gradients for decision time %d.",
        calendar_t,
    )
    # Get a list of subdicts of the user args dict, with each united by having
    # the same shapes across all arguments. We will then vmap the gradients needed
    # for each subdict separately, and later combine the results. In the worst
    # case we may have a batch per user, if, say, everyone starts on a different
    # date, and this will be slow. If this is problematic we can pad the data
    # with some values that don't affect computations, producing one batch here.
    # This also supports very large simulations by making things fast as long
    # as there is 1 or a small number of shape batches.
    nontrivial_user_args_grouped_by_shape = group_user_args_by_shape(args_by_user_id)
    logger.debug(
        "Found %d set(s) of users with different arg shapes.",
        len(nontrivial_user_args_grouped_by_shape),
    )

    # Loop over each set of user args and vmap to get their pi and weight gradients
    in_study_pi_gradients_by_user_id = {}
    in_study_weight_gradients_by_user_id = {}
    all_involved_user_ids = set()
    for args_by_user_id_subset in nontrivial_user_args_grouped_by_shape:
        # Now that we are grouping by arg shape and excluding the out of study
        # group, all the users should be involved in the study in this loop,
        # but... just keep this logic that works for heterogeneous-shaped
        # batches too.
        batched_arg_lists, involved_user_ids = (
            get_batched_arg_lists_and_involved_user_ids(
                action_prob_func, sorted_user_ids, args_by_user_id_subset
            )
        )
        all_involved_user_ids |= involved_user_ids

        # Nothing to batch for this shape group (no user contributed args).
        if not batched_arg_lists[0]:
            continue

        logger.debug("Reforming batched data lists into tensors.")
        batched_arg_tensors, batch_axes = stack_batched_arg_lists_into_tensor(
            batched_arg_lists
        )

        logger.debug("Forming pi gradients with respect to beta.")
        # Note that we care about the probability of action 1 specifically,
        # not the taken action.
        in_study_pi_gradients_subset = get_pi_gradients_batched(
            action_prob_func,
            action_prob_func_args_beta_index,
            batch_axes,
            batched_arg_tensors,
        )

        # TODO: betas should be verified to be the same across users now or earlier
        logger.debug("Forming weight gradients with respect to beta.")
        in_study_batched_actions_tensor = collect_batched_in_study_actions(
            study_df,
            calendar_t,
            sorted_user_ids,
            in_study_col_name,
            action_col_name,
            calendar_t_col_name,
            user_id_col_name,
        )
        # Note the first argument here: we extract the betas to pass in
        # again as the "target" denominator betas, whereas we differentiate with
        # respect to the betas in the numerator. Also note that these betas are
        # redundant across users: it's just the same thing repeated num users
        # times.
        in_study_weight_gradients_subset = get_weight_gradients_batched(
            batched_arg_tensors[action_prob_func_args_beta_index],
            action_prob_func,
            action_prob_func_args_beta_index,
            in_study_batched_actions_tensor,
            batch_axes,
            batched_arg_tensors,
        )

        # Collect the gradients for the in-study users in this group into the
        # overall dict.  in_batch_index tracks position within this vmapped
        # subset, which is ordered by sorted_user_ids.
        in_batch_index = 0
        for user_id in sorted_user_ids:
            if user_id not in involved_user_ids:
                continue
            in_study_pi_gradients_by_user_id[user_id] = in_study_pi_gradients_subset[
                in_batch_index
            ]
            in_study_weight_gradients_by_user_id[user_id] = (
                in_study_weight_gradients_subset[in_batch_index]
            )
            in_batch_index += 1

    # Flatten the per-user dicts back into sorted order before padding.
    in_study_pi_gradients = [
        in_study_pi_gradients_by_user_id[user_id]
        for user_id in sorted_user_ids
        if user_id in all_involved_user_ids
    ]
    in_study_weight_gradients = [
        in_study_weight_gradients_by_user_id[user_id]
        for user_id in sorted_user_ids
        if user_id in all_involved_user_ids
    ]
    # TODO: These padding methods assume someone was in the study at this time.
    pi_gradients = pad_in_study_derivatives_with_zeros(
        in_study_pi_gradients, sorted_user_ids, all_involved_user_ids
    )
    weight_gradients = pad_in_study_derivatives_with_zeros(
        in_study_weight_gradients, sorted_user_ids, all_involved_user_ids
    )

    return pi_gradients, weight_gradients
375
+
376
+
377
# TODO: is it ok to get the action from the study df? No issues with actions taken
# but not known about?
# TODO: Test this at least with an incremental recruitment collect pi gradients
# case, possibly directly.
def collect_batched_in_study_actions(
    study_df,
    calendar_t,
    sorted_user_ids,
    in_study_col_name,
    action_col_name,
    calendar_t_col_name,
    user_id_col_name,
):
    """
    Gather the action of each in-study user at calendar_t into a jnp vector,
    ordered by sorted_user_ids.  Users with no in-study row at this time are
    omitted.
    """
    # TODO: This loop can be removed, just grabbing the actions col after
    # filtering and sorting, and converting to jnp array. It's an artifact
    # from when the loop used to be more complicated.
    actions = []
    for user_id in sorted_user_ids:
        row_mask = (
            (study_df[user_id_col_name] == user_id)
            & (study_df[calendar_t_col_name] == calendar_t)
            & (study_df[in_study_col_name] == 1)
        )
        matching_rows = study_df.loc[row_mask]
        if len(matching_rows):
            actions.append(matching_rows[action_col_name].values[0])

    return jnp.array(actions)
405
+
406
+
407
def get_radon_nikodym_weight(
    beta_target,
    action_prob_func,
    action_prob_func_args_beta_index,
    action,
    *action_prob_func_args_single_user,
):
    """
    Radon-Nikodym weight for a single user: the probability of the observed
    action under the supplied (numerator) beta, divided by its probability
    under beta_target (the denominator policy's beta).

    The probability of the observed action is pi when action == 1 and
    1 - pi otherwise (via conditional_x_or_one_minus_x).
    """
    # Same args, with the beta slot swapped for the target beta.
    denominator_args = list(action_prob_func_args_single_user)
    denominator_args[action_prob_func_args_beta_index] = beta_target

    pi_numerator = action_prob_func(*action_prob_func_args_single_user)
    pi_denominator = action_prob_func(*denominator_args)
    return conditional_x_or_one_minus_x(
        pi_numerator, action
    ) / conditional_x_or_one_minus_x(pi_denominator, action)
426
+
427
+
428
def get_pi_gradients_batched(
    action_prob_func,
    action_prob_func_args_beta_index,
    batch_axes,
    batched_arg_tensors,
):
    """
    vmap the gradient of the action probability function with respect to its
    beta argument over a batch of users' argument tensors.
    """
    grad_wrt_beta = jax.grad(action_prob_func, action_prob_func_args_beta_index)
    batched_grad = jax.vmap(fun=grad_wrt_beta, in_axes=batch_axes, out_axes=0)
    return batched_grad(*batched_arg_tensors)
440
+
441
+
442
def get_weight_gradients_batched(
    batched_beta_target_tensor,
    action_prob_func,
    action_prob_func_args_beta_index,
    batched_actions_tensor,
    batch_axes,
    batched_arg_tensors,
):
    """
    vmap the gradient of the Radon-Nikodym weight with respect to the
    numerator beta over a batch of users.

    get_radon_nikodym_weight takes four fixed leading args (beta_target,
    func, beta index, action) before the action-prob args — hence the
    4 + index grad argnum and the [0, None, None, 0] prefix on in_axes.
    """
    numerator_beta_argnum = 4 + action_prob_func_args_beta_index
    weight_grad = jax.grad(get_radon_nikodym_weight, numerator_beta_argnum)
    batched_weight_grad = jax.vmap(
        fun=weight_grad,
        in_axes=[0, None, None, 0] + batch_axes,
        out_axes=0,
    )
    return batched_weight_grad(
        batched_beta_target_tensor,
        action_prob_func,
        action_prob_func_args_beta_index,
        batched_actions_tensor,
        *batched_arg_tensors,
    )
465
+
466
+
467
# TODO: JIT whole function? or just gradient and hessian batch functions
# TODO: This is a hotspot for moving away from update times
def calculate_rl_update_derivatives(
    study_df,
    rl_update_func_filename,
    rl_update_func_args,
    rl_update_func_type,
    rl_update_func_args_beta_index,
    rl_update_func_args_action_prob_index,
    rl_update_func_args_action_prob_times_index,
    policy_num_col_name,
    calendar_t_col_name,
):
    """
    For each policy update, compute per-user loss gradients and hessians with
    respect to beta, plus mixed beta/action-probability derivatives.

    The RL update function is loaded from rl_update_func_filename.  Results
    are keyed by the first calendar time each updated policy applied to, each
    entry holding "loss_gradients_by_user_id", "avg_loss_hessian" (mean over
    users), and "loss_gradient_pi_derivatives_by_user_id".
    """
    logger.debug(
        "Calculating RL loss gradients and hessians with respect to beta and mixed beta/action probability derivatives for each user at all update times."
    )
    rl_update_func = load_function_from_same_named_file(rl_update_func_filename)

    rl_update_derivatives_by_calendar_t = {}
    # All user ids are present in every update's args dict, so any entry works.
    user_ids = list(next(iter(rl_update_func_args.values())).keys())
    sorted_user_ids = sorted(user_ids)
    for policy_num, args_by_user_id in rl_update_func_args.items():
        # We store these loss gradients by the first time the resulting parameters
        # apply to, so determine this time.
        # Because we perform algorithm updates at the *end* of a timestep, the
        # first timestep they apply to is one more than the time at which the
        # update data is gathered.
        first_applicable_time = get_first_applicable_time(
            study_df, policy_num, policy_num_col_name, calendar_t_col_name
        )
        loss_gradients, loss_hessians, loss_gradient_pi_derivatives = (
            calculate_rl_update_derivatives_specific_update(
                rl_update_func,
                rl_update_func_type,
                rl_update_func_args_beta_index,
                rl_update_func_args_action_prob_index,
                rl_update_func_args_action_prob_times_index,
                args_by_user_id,
                sorted_user_ids,
                first_applicable_time,
            )
        )
        rl_update_derivatives_by_calendar_t.setdefault(first_applicable_time, {})[
            "loss_gradients_by_user_id"
        ] = {user_id: loss_gradients[i] for i, user_id in enumerate(sorted_user_ids)}
        rl_update_derivatives_by_calendar_t[first_applicable_time][
            "avg_loss_hessian"
        ] = np.mean(loss_hessians, axis=0)

        rl_update_derivatives_by_calendar_t[first_applicable_time][
            "loss_gradient_pi_derivatives_by_user_id"
        ] = {
            # NOTE the [..., 0] here... it is very important. Without it we have
            # a shape (x,y,z,1) array of gradients, and the use of these
            # probabilities assumes (x,y,z). This should arguably
            # happen above, but the vmap call spits out a 4D array so in that
            # sense that's the most natural return value. Note that we don't
            # simply squeeze because that would remove the beta dimension
            # if it were one.
            # TODO: This probably has to do with the dimension of the action
            # probabilities... we may need to specify that they are scalars in the
            # loss function args, rather than 1-element vectors. Or one will
            # have to say so. Test both of these cases. Can probably check
            # dimensions and squeeze if necessary.
            user_id: loss_gradient_pi_derivatives[i][..., 0]
            for i, user_id in enumerate(sorted_user_ids)
        }
    return rl_update_derivatives_by_calendar_t
536
+
537
+
538
def calculate_rl_update_derivatives_specific_update(
    rl_update_func,
    rl_update_func_type,
    rl_update_func_args_beta_index,
    rl_update_func_args_action_prob_index,
    rl_update_func_args_action_prob_times_index,
    args_by_user_id,
    sorted_user_ids,
    first_applicable_time,
):
    """
    Compute per-user loss gradients, hessians, and (optionally) mixed
    beta/action-probability derivatives for one policy update.

    Users are grouped by argument shape and vmapped per group; results are
    reassembled in sorted_user_ids order with zero padding for users not in
    the study.  A negative rl_update_func_args_action_prob_index means the
    update function takes no action-probability argument, in which case the
    mixed derivatives are all-zero arrays of the appropriate shape.

    Returns:
        (loss_gradients, loss_hessians, loss_gradient_pi_derivatives):
        lists/arrays aligned with sorted_user_ids.
    """
    logger.debug(
        "Calculating RL update derivatives for the update that first applies at time %d.",
        first_applicable_time,
    )
    # Get a list of subdicts of the user args dict, with each united by having
    # the same shapes across all arguments. We will then vmap the gradients needed
    # for each subdict separately, and later combine the results. In the worst
    # case we may have a batch per user, if, say, everyone starts on a different
    # date, and this will be slow. If this is problematic we can pad the data
    # with some values that don't affect computations, producing one batch here.
    # This also supports very large simulations by making things fast as long
    # as there is 1 or a small number of shape batches.
    # NOTE: Susan and Kelly think we might actually have uniqueish shapes pretty often
    nontrivial_user_args_grouped_by_shape = group_user_args_by_shape(args_by_user_id)
    logger.debug(
        "Found %d set(s) of users with different arg shapes.",
        len(nontrivial_user_args_grouped_by_shape),
    )

    # Loop over each set of user args and vmap to get their pi and weight gradients
    in_study_loss_gradients_by_user_id = {}
    in_study_loss_hessians_by_user_id = {}
    in_study_loss_gradient_pi_derivatives_by_user_id = {}
    all_involved_user_ids = set()
    for args_by_user_id_subset in nontrivial_user_args_grouped_by_shape:
        # Pivot the loss args for the involved users into a list of lists, each
        # representing all the args at a particular index across users. Note
        # that users not in the study at this time are filtered out by this
        # function when it checks for truthiness of the supplied args.
        batched_arg_lists, involved_user_ids = (
            get_batched_arg_lists_and_involved_user_ids(
                rl_update_func, sorted_user_ids, args_by_user_id_subset
            )
        )
        all_involved_user_ids |= involved_user_ids

        # No user contributed args for this shape group; nothing to batch.
        if not batched_arg_lists[0]:
            continue

        logger.debug("Reforming batched data lists into tensors.")
        # Now just transform the previous list of lists into a jnp array for each
        # index (a tensor for each argument). This is for passing to vmap.
        batched_arg_tensors, batch_axes = stack_batched_arg_lists_into_tensor(
            batched_arg_lists
        )

        logger.debug("Forming loss gradients with respect to beta.")
        in_study_loss_gradients_subset = get_loss_gradients_batched(
            rl_update_func,
            rl_update_func_type,
            rl_update_func_args_beta_index,
            batch_axes,
            *batched_arg_tensors,
        )

        logger.debug("Forming loss hessians with respect to beta.")
        in_study_loss_hessians_subset = get_loss_hessians_batched(
            rl_update_func,
            rl_update_func_type,
            rl_update_func_args_beta_index,
            batch_axes,
            *batched_arg_tensors,
        )
        logger.debug(
            "Forming derivatives of loss with respect to beta and then the action probabilites vector at each time."
        )
        # Mixed derivatives only exist if the update function actually takes
        # an action-probability argument.
        if rl_update_func_args_action_prob_index >= 0:
            in_study_loss_gradient_pi_derivatives_subset = (
                get_loss_gradient_derivatives_wrt_pi_batched(
                    rl_update_func,
                    rl_update_func_type,
                    rl_update_func_args_beta_index,
                    rl_update_func_args_action_prob_index,
                    batch_axes,
                    *batched_arg_tensors,
                )
            )
        # Collect the gradients for the in-study users in this group into the
        # overall dict.  in_batch_index tracks position within this vmapped
        # subset, which is ordered by sorted_user_ids.
        in_batch_index = 0
        for user_id in sorted_user_ids:
            if user_id not in involved_user_ids:
                continue
            in_study_loss_gradients_by_user_id[user_id] = (
                in_study_loss_gradients_subset[in_batch_index]
            )
            in_study_loss_hessians_by_user_id[user_id] = in_study_loss_hessians_subset[
                in_batch_index
            ]
            if rl_update_func_args_action_prob_index >= 0:
                # Zero-fill derivative columns for times without supplied
                # action probabilities.
                in_study_loss_gradient_pi_derivatives_by_user_id[user_id] = (
                    pad_loss_gradient_pi_derivative_outside_supplied_action_probabilites(
                        in_study_loss_gradient_pi_derivatives_subset[in_batch_index],
                        args_by_user_id[user_id][
                            rl_update_func_args_action_prob_times_index
                        ],
                        first_applicable_time,
                    )
                )
            in_batch_index += 1
    in_study_loss_gradients = [
        in_study_loss_gradients_by_user_id[user_id]
        for user_id in sorted_user_ids
        if user_id in all_involved_user_ids
    ]
    in_study_loss_hessians = [
        in_study_loss_hessians_by_user_id[user_id]
        for user_id in sorted_user_ids
        if user_id in all_involved_user_ids
    ]
    if rl_update_func_args_action_prob_index >= 0:
        in_study_loss_gradient_pi_derivatives = [
            in_study_loss_gradient_pi_derivatives_by_user_id[user_id]
            for user_id in sorted_user_ids
            if user_id in all_involved_user_ids
        ]
    # TODO: These padding methods assume *someone* had study data at this time.
    loss_gradients = pad_in_study_derivatives_with_zeros(
        in_study_loss_gradients, sorted_user_ids, all_involved_user_ids
    )
    loss_hessians = pad_in_study_derivatives_with_zeros(
        in_study_loss_hessians, sorted_user_ids, all_involved_user_ids
    )

    # If there is an action probability argument in the RL update function, we need to
    # pad the derivatives calculated already with zeros for those users not currently
    # in the study. Otherwise simply return all zero gradients of the correct shape.
    if rl_update_func_args_action_prob_index >= 0:
        loss_gradient_pi_derivatives = pad_in_study_derivatives_with_zeros(
            in_study_loss_gradient_pi_derivatives,
            sorted_user_ids,
            all_involved_user_ids,
        )
    else:
        num_users = len(sorted_user_ids)
        # NOTE(review): batched_arg_lists here is the value left over from the
        # final loop iteration (and is undefined if every shape group was
        # skipped) — looks intentional since betas are expected identical
        # across users, but confirm.
        beta_dim = batched_arg_lists[rl_update_func_args_beta_index][0].size
        timesteps_included = first_applicable_time - 1

        loss_gradient_pi_derivatives = np.zeros(
            (num_users, beta_dim, timesteps_included, 1)
        )

    return loss_gradients, loss_hessians, loss_gradient_pi_derivatives
691
+
692
+
693
def get_loss_gradients_batched(
    update_func,
    update_func_type,
    update_func_args_beta_index,
    batch_axes,
    *batched_arg_tensors,
):
    """
    Per-user first-order quantities of the update criterion w.r.t. beta.

    For a LOSS-type function this is the vmapped gradient; for an
    ESTIMATING-type function the function value itself already plays the
    role of the score, so it is vmapped directly.

    Raises:
        ValueError: for an unrecognized update function type.
    """
    if update_func_type == FunctionTypes.LOSS:
        per_user_fn = jax.grad(update_func, update_func_args_beta_index)
    elif update_func_type == FunctionTypes.ESTIMATING:
        per_user_fn = update_func
    else:
        raise ValueError("Unknown update function type.")
    return jax.vmap(fun=per_user_fn, in_axes=batch_axes, out_axes=0)(
        *batched_arg_tensors
    )
713
+
714
+
715
def get_loss_hessians_batched(
    update_func,
    update_func_type,
    update_func_args_beta_index,
    batch_axes,
    *batched_arg_tensors,
):
    """
    Per-user second-order quantities of the update criterion w.r.t. beta.

    For a LOSS-type function this is the vmapped hessian; for an
    ESTIMATING-type function it is the Jacobian of the estimating function.

    Raises:
        ValueError: for an unrecognized update function type.
    """
    if update_func_type == FunctionTypes.LOSS:
        per_user_fn = jax.hessian(update_func, update_func_args_beta_index)
    elif update_func_type == FunctionTypes.ESTIMATING:
        per_user_fn = jax.jacrev(update_func, update_func_args_beta_index)
    else:
        raise ValueError("Unknown update function type.")
    return jax.vmap(fun=per_user_fn, in_axes=batch_axes, out_axes=0)(
        *batched_arg_tensors
    )
735
+
736
+
737
+ def get_loss_gradient_derivatives_wrt_pi_batched(
738
+ update_func,
739
+ update_func_type,
740
+ update_func_args_beta_index,
741
+ update_func_args_action_prob_index,
742
+ batch_axes,
743
+ *batched_arg_tensors,
744
+ ):
745
+ if update_func_type == FunctionTypes.LOSS:
746
+ return jax.jit(
747
+ jax.vmap(
748
+ fun=jax.jacrev(
749
+ jax.grad(update_func, update_func_args_beta_index),
750
+ update_func_args_action_prob_index,
751
+ ),
752
+ in_axes=batch_axes,
753
+ out_axes=0,
754
+ )
755
+ )(*batched_arg_tensors)
756
+ if update_func_type == FunctionTypes.ESTIMATING:
757
+ return jax.jit(
758
+ jax.vmap(
759
+ fun=jax.jacrev(
760
+ update_func,
761
+ update_func_args_action_prob_index,
762
+ ),
763
+ in_axes=batch_axes,
764
+ out_axes=0,
765
+ )
766
+ )(*batched_arg_tensors)
767
+ raise ValueError("Unknown update function type.")
768
+
769
+
770
# TODO: Is there a better way to calculate this? This seems like it should
# be reliable, not messing up when a policy was actually available. If study
# df says policy was used, that should be correct. May not play nicely with
# pure exploration phase though.
def get_first_applicable_time(
    study_df, policy_num, policy_num_col_name, calendar_t_col_name
):
    """
    First calendar time at which policy_num appears in the study dataframe,
    i.e. the earliest decision time the policy's parameters applied to.
    """
    rows_for_policy = study_df.loc[study_df[policy_num_col_name] == policy_num]
    return rows_for_policy[calendar_t_col_name].min()
780
+
781
+
782
+ def calculate_inference_loss_derivatives(
783
+ study_df,
784
+ theta_est,
785
+ inference_func,
786
+ inference_func_args_theta_index,
787
+ user_ids,
788
+ user_id_col_name,
789
+ action_prob_col_name,
790
+ in_study_col_name,
791
+ calendar_t_col_name,
792
+ inference_func_type=FunctionTypes.LOSS,
793
+ ):
794
+ logger.debug("Calculating inference loss derivatives.")
795
+
796
+ # Convert to list if needed (from jnp array, etc)
797
+ try:
798
+ user_ids = user_ids.tolist()
799
+ except Exception:
800
+ pass
801
+
802
+ num_args = inference_func.__code__.co_argcount
803
+ inference_func_arg_names = inference_func.__code__.co_varnames[:num_args]
804
+ # NOTE: Cannot do [[]] * num_args here! Then all lists point
805
+ # same object...
806
+ batched_arg_lists = [[] for _ in range(num_args)]
807
+
808
+ # We begin by constructing a dict of loss function arg tuples of the type we get from file
809
+ # for the RL data; because we have to group user args by shape anyway, we
810
+ # might as well collect them in this format and then use previous machinery
811
+ # to process them. There are a few extra loops but more shared code this way.
812
+ args_by_user_id = {}
813
+ using_action_probs = action_prob_col_name in inference_func_arg_names
814
+ if using_action_probs:
815
+ inference_func_args_action_prob_index = inference_func_arg_names.index(
816
+ action_prob_col_name
817
+ )
818
+ action_prob_decision_times_by_user_id = {}
819
+ max_calendar_time = study_df[calendar_t_col_name].max()
820
+ for user_id in user_ids:
821
+ user_args_list = []
822
+ filtered_user_data = study_df.loc[study_df[user_id_col_name] == user_id]
823
+ for idx, col_name in enumerate(inference_func_arg_names):
824
+ if idx == inference_func_args_theta_index:
825
+ user_args_list.append(theta_est)
826
+ else:
827
+ user_args_list.append(
828
+ get_study_df_column(filtered_user_data, col_name, in_study_col_name)
829
+ )
830
+ args_by_user_id[user_id] = tuple(user_args_list)
831
+ if using_action_probs:
832
+ action_prob_decision_times_by_user_id[user_id] = get_study_df_column(
833
+ filtered_user_data, calendar_t_col_name, in_study_col_name
834
+ )
835
+
836
+ # Get a list of subdicts of the user args dict, with each united by having
837
+ # the same shapes across all arguments. We will then vmap the gradients needed
838
+ # for each subdict separately, and later combine the results. In the worst
839
+ # case we may have a batch per user, if, say, everyone starts on a different
840
+ # date, and this will be slow. If this is problematic we can pad the data
841
+ # with some values that don't affect computations, producing one batch here.
842
+ # This also supports very large simulations by making things fast as long
843
+ # as there is 1 or a small number of shape batches.
844
+ # NOTE: As opposed to the RL updates, we should expect a small number of
845
+ # batches here. It is only users having different numbers of decision times
846
+ # that contributes additional batches.
847
+ nontrivial_user_args_grouped_by_shape = group_user_args_by_shape(
848
+ args_by_user_id, empty_allowed=False
849
+ )
850
+ logger.debug(
851
+ "Found %d set(s) of users with different arg shapes.",
852
+ len(nontrivial_user_args_grouped_by_shape),
853
+ )
854
+
855
+ loss_gradients_by_user_id = {}
856
+ loss_hessians_by_user_id = {}
857
+ loss_gradient_pi_derivatives_by_user_id = {}
858
+ all_involved_user_ids = set()
859
+ sorted_user_ids = sorted(user_ids)
860
+ for args_by_user_id_subset in nontrivial_user_args_grouped_by_shape:
861
+ batched_arg_lists, involved_user_ids = (
862
+ get_batched_arg_lists_and_involved_user_ids(
863
+ inference_func, sorted_user_ids, args_by_user_id_subset
864
+ )
865
+ )
866
+ all_involved_user_ids |= involved_user_ids
867
+
868
+ logger.debug("Reforming batched data lists into tensors.")
869
+ batched_arg_tensors, batch_axes = stack_batched_arg_lists_into_tensor(
870
+ batched_arg_lists
871
+ )
872
+
873
+ logger.debug("Forming loss gradients with respect to theta.")
874
+ loss_gradients_subset = get_loss_gradients_batched(
875
+ inference_func,
876
+ inference_func_type,
877
+ inference_func_args_theta_index,
878
+ batch_axes,
879
+ *batched_arg_tensors,
880
+ )
881
+
882
+ logger.debug("Forming loss hessians with respect to theta.")
883
+ loss_hessians_subset = get_loss_hessians_batched(
884
+ inference_func,
885
+ inference_func_type,
886
+ inference_func_args_theta_index,
887
+ batch_axes,
888
+ *batched_arg_tensors,
889
+ )
890
+ logger.debug(
891
+ "Forming derivatives of loss with respect to theta and then the action probabilities vector at each time."
892
+ )
893
+ # If there is an action probability argument in the loss,
894
+ # actually differentiate with respect to action probabilities
895
+ if using_action_probs:
896
+ loss_gradient_pi_derivatives_subset = (
897
+ get_loss_gradient_derivatives_wrt_pi_batched(
898
+ inference_func,
899
+ inference_func_type,
900
+ inference_func_args_theta_index,
901
+ inference_func_args_action_prob_index,
902
+ batch_axes,
903
+ *batched_arg_tensors,
904
+ )
905
+ )
906
+ # Collect the gradients for the in-study users in this group into the
907
+ # overall dict.
908
+ in_batch_index = 0
909
+ for user_id in sorted_user_ids:
910
+ if user_id not in involved_user_ids:
911
+ continue
912
+ loss_gradients_by_user_id[user_id] = loss_gradients_subset[in_batch_index]
913
+ loss_hessians_by_user_id[user_id] = loss_hessians_subset[in_batch_index]
914
+ if using_action_probs:
915
+ loss_gradient_pi_derivatives_by_user_id[user_id] = (
916
+ pad_loss_gradient_pi_derivative_outside_supplied_action_probabilites(
917
+ loss_gradient_pi_derivatives_subset[in_batch_index],
918
+ action_prob_decision_times_by_user_id[user_id],
919
+ max_calendar_time + 1,
920
+ )
921
+ )
922
+ in_batch_index += 1
923
+ loss_gradients = np.array(
924
+ [
925
+ loss_gradients_by_user_id[user_id]
926
+ for user_id in sorted_user_ids
927
+ if user_id in all_involved_user_ids
928
+ ]
929
+ )
930
+ loss_hessians = np.array(
931
+ [
932
+ loss_hessians_by_user_id[user_id]
933
+ for user_id in sorted_user_ids
934
+ if user_id in all_involved_user_ids
935
+ ]
936
+ )
937
+ # If using action probs, collect the mixed theta pi derivatives computed
938
+ # so far.
939
+ if using_action_probs:
940
+ loss_gradient_pi_derivatives = np.array(
941
+ [
942
+ loss_gradient_pi_derivatives_by_user_id[user_id]
943
+ for user_id in sorted_user_ids
944
+ if user_id in all_involved_user_ids
945
+ ]
946
+ )
947
+ # Otherwise, we need to simply return zero gradients of the correct shape.
948
+ else:
949
+ num_users = len(user_ids)
950
+ theta_dim = theta_est.size
951
+ timesteps_included = study_df[calendar_t_col_name].nunique()
952
+
953
+ loss_gradient_pi_derivatives = np.zeros(
954
+ (num_users, theta_dim, timesteps_included, 1)
955
+ )
956
+
957
+ return loss_gradients, loss_hessians, loss_gradient_pi_derivatives
958
+
959
+
960
def get_study_df_column(study_df, col_name, in_study_col_name):
    """Return one column of the study dataframe, restricted to in-study rows,
    as a column vector.

    Rows are kept only where ``in_study_col_name`` equals 1; the selected
    column is converted to a numpy array, reshaped to shape (n, 1), and
    wrapped in a jax array.
    """
    in_study_mask = study_df[in_study_col_name] == 1
    column_values = study_df.loc[in_study_mask, col_name].to_numpy()
    return jnp.array(column_values.reshape(-1, 1))