lifejacket-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,587 @@
+ from __future__ import annotations
+
+ import collections
+ import importlib.machinery
+ import importlib.util
+ import logging
+ import os
+ from typing import Any
+
+ import jax.numpy as jnp
+ import numpy as np
+ import pandas as pd
+
+ from .constants import InverseStabilizationMethods
+
+ logger = logging.getLogger(__name__)
+ logging.basicConfig(
+     format="%(asctime)s,%(msecs)03d %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s",
+     datefmt="%Y-%m-%d:%H:%M:%S",
+     level=logging.INFO,
+ )
+
+
+ def conditional_x_or_one_minus_x(x, condition):
+     """Return x when condition == 1 and 1 - x when condition == 0, differentiably."""
+     return (1 - condition) + (2 * condition - 1) * x
+
+
+ def invert_matrix_and_check_conditioning(
+     matrix: np.ndarray,
+     inverse_stabilization_method: str = InverseStabilizationMethods.NONE,
+     condition_num_threshold: float = 10**4,
+     ridge_median_singular_value_fraction: float = 0.01,
+     beta_dim: int | None = None,
+     theta_dim: int | None = None,
+ ):
+     """
+     Check a matrix's condition number and invert it. If the condition number is
+     above a threshold, apply the requested stabilization method to improve
+     conditioning before inverting.
+
+     Args:
+         matrix (np.ndarray): The square matrix to invert.
+         inverse_stabilization_method (str): One of InverseStabilizationMethods.
+         condition_num_threshold (float): Condition number above which stabilization
+             is applied; also the target condition number for singular value
+             trimming and fixed-condition-number ridge regularization.
+         ridge_median_singular_value_fraction (float): Fraction of the median
+             singular value used as lambda by ADD_RIDGE_MEDIAN_SINGULAR_VALUE_FRACTION.
+         beta_dim (int | None): Dimension of each beta block; required by the
+             structure-aware and off-diagonal-zeroing methods.
+         theta_dim (int | None): Dimension of the theta block; required by the
+             structure-aware and off-diagonal-zeroing methods.
+
+     Returns:
+         tuple: The inverse and the condition number of the matrix that was
+         actually inverted (after any stabilization).
+     """
+     inverse = None
+     pre_inversion_condition_number = np.linalg.cond(matrix)
+     if pre_inversion_condition_number > condition_num_threshold:
+         logger.warning(
+             "You are inverting a matrix with a large condition number: %s",
+             pre_inversion_condition_number,
+         )
+         if (
+             inverse_stabilization_method
+             == InverseStabilizationMethods.TRIM_SMALL_SINGULAR_VALUES
+         ):
+             logger.info("Trimming small singular values to improve conditioning.")
+             u, s, vT = np.linalg.svd(matrix, full_matrices=False)
+             logger.info("Sorted singular values: %s", s)
+             sing_values_above_threshold_cond = s > s.max() / condition_num_threshold
+             if not np.any(sing_values_above_threshold_cond):
+                 raise RuntimeError(
+                     f"All singular values are below the threshold of {s.max() / condition_num_threshold}. Singular value trimming will not work."
+                 )
+             trimmed_pseudoinverse = (
+                 vT.T[:, sing_values_above_threshold_cond]
+                 / s[sing_values_above_threshold_cond]
+             ) @ u[:, sing_values_above_threshold_cond].T
+             inverse = trimmed_pseudoinverse
+             pre_inversion_condition_number = (
+                 s[sing_values_above_threshold_cond].max()
+                 / s[sing_values_above_threshold_cond].min()
+             )
+             logger.info(
+                 "Kept %s out of %s singular values. Condition number of resulting lower-rank approximation before inversion: %s",
+                 sum(sing_values_above_threshold_cond),
+                 len(s),
+                 pre_inversion_condition_number,
+             )
+         elif (
+             inverse_stabilization_method
+             == InverseStabilizationMethods.ADD_RIDGE_FIXED_CONDITION_NUMBER
+         ):
+             logger.info("Adding ridge/Tikhonov regularization to improve conditioning.")
+             _, singular_values, _ = np.linalg.svd(matrix, full_matrices=False)
+             logger.info(
+                 "Using fixed condition number threshold of %s to determine lambda.",
+                 condition_num_threshold,
+             )
+             lambda_ = (
+                 singular_values.max() / condition_num_threshold - singular_values.min()
+             )
+             logger.info("Lambda for ridge regularization: %s", lambda_)
+             new_matrix = matrix + lambda_ * np.eye(matrix.shape[0])
+             pre_inversion_condition_number = np.linalg.cond(new_matrix)
+             logger.info(
+                 "Condition number of matrix after ridge regularization: %s",
+                 pre_inversion_condition_number,
+             )
+             inverse = np.linalg.solve(new_matrix, np.eye(matrix.shape[0]))
+         elif (
+             inverse_stabilization_method
+             == InverseStabilizationMethods.ADD_RIDGE_MEDIAN_SINGULAR_VALUE_FRACTION
+         ):
+             logger.info("Adding ridge/Tikhonov regularization to improve conditioning.")
+             _, singular_values, _ = np.linalg.svd(matrix, full_matrices=False)
+             logger.info(
+                 "Using median singular value times %s as lambda.",
+                 ridge_median_singular_value_fraction,
+             )
+             lambda_ = ridge_median_singular_value_fraction * np.median(singular_values)
+             logger.info("Lambda for ridge regularization: %s", lambda_)
+             new_matrix = matrix + lambda_ * np.eye(matrix.shape[0])
+             pre_inversion_condition_number = np.linalg.cond(new_matrix)
+             logger.info(
+                 "Condition number of matrix after ridge regularization: %s",
+                 pre_inversion_condition_number,
+             )
+             inverse = np.linalg.solve(new_matrix, np.eye(matrix.shape[0]))
+         elif (
+             inverse_stabilization_method
+             == InverseStabilizationMethods.INVERSE_BREAD_STRUCTURE_AWARE_INVERSION
+         ):
+             if not beta_dim or not theta_dim:
+                 raise ValueError(
+                     "When using structure-aware inversion, beta_dim and theta_dim must be provided."
+                 )
+             logger.info(
+                 "Using inverse bread's block lower triangular structure to invert only diagonal blocks."
+             )
+             inverse = invert_inverse_bread_matrix(
+                 matrix,
+                 beta_dim,
+                 theta_dim,
+                 InverseStabilizationMethods.ADD_RIDGE_FIXED_CONDITION_NUMBER,
+             )
+         elif (
+             inverse_stabilization_method
+             == InverseStabilizationMethods.ZERO_OUT_SMALL_OFF_DIAGONALS
+         ):
+             if not beta_dim or not theta_dim:
+                 raise ValueError(
+                     "When zeroing out small off diagonals, beta_dim and theta_dim must be provided."
+                 )
+             logger.info("Zeroing out small off-diagonal blocks to improve conditioning.")
+             # One beta-sized block per update, plus the final theta block.
+             zeroed_matrix = zero_small_off_diagonal_blocks(
+                 matrix,
+                 ([beta_dim] * ((matrix.shape[0] - theta_dim) // beta_dim)) + [theta_dim],
+             )
+             pre_inversion_condition_number = np.linalg.cond(zeroed_matrix)
+             logger.info(
+                 "Condition number of matrix after zeroing out small off-diagonal blocks: %s",
+                 pre_inversion_condition_number,
+             )
+             inverse = np.linalg.solve(zeroed_matrix, np.eye(zeroed_matrix.shape[0]))
+         elif (
+             inverse_stabilization_method
+             == InverseStabilizationMethods.ALL_METHODS_COMPETITION
+         ):
+             # TODO: Choose right metric for competition... identity diff might not be it.
+             raise NotImplementedError(
+                 "All methods competition is not implemented yet. Please choose a specific method."
+             )
+         elif inverse_stabilization_method == InverseStabilizationMethods.NONE:
+             logger.info("No inverse stabilization method applied. Inverting directly.")
+         else:
+             raise ValueError(
+                 f"Unknown inverse stabilization method: {inverse_stabilization_method}"
+             )
+     if inverse is None:
+         inverse = np.linalg.solve(matrix, np.eye(matrix.shape[0]))
+     return inverse, pre_inversion_condition_number
+
+
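+ # Illustrative usage sketch (not part of the module): stabilizing the inverse of
+ # a nearly singular Gram matrix with the median-singular-value ridge method. The
+ # data below is synthetic; the condition number is far above the default
+ # threshold, so the ridge path is taken.
+ #
+ #     rng = np.random.default_rng(0)
+ #     X = rng.normal(size=(50, 5))
+ #     X[:, 4] = X[:, 3] + 1e-8 * rng.normal(size=50)  # near-duplicate column
+ #     gram = X.T @ X
+ #     inverse, cond_num = invert_matrix_and_check_conditioning(
+ #         gram,
+ #         InverseStabilizationMethods.ADD_RIDGE_MEDIAN_SINGULAR_VALUE_FRACTION,
+ #     )
+
+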
+ def zero_small_off_diagonal_blocks(
+     matrix: jnp.ndarray,
+     block_sizes: list[int],
+     frobenius_norm_threshold_fraction: float = 1e-3,
+ ):
+     """
+     Zero off-diagonal blocks whose Frobenius norm is < frobenius_norm_threshold_fraction x
+     the Frobenius norm of the diagonal block in the same ROW. One could compare to
+     the same column, or to both the row and column, but we choose the row here since
+     rows correspond to a single RL update or inference step in the adaptive bread
+     inverse matrices this method is designed for.
+
+     Args:
+         matrix (jnp.ndarray):
+             2-D ndarray, square (q_total x q_total)
+         block_sizes (list[int]):
+             list like [p1, p2, ..., pT]
+         frobenius_norm_threshold_fraction (float):
+             Frobenius norm fraction, relative to the same-row diagonal block, under
+             which we zero a block
+
+     Returns:
+         ndarray with selected off-diagonal blocks zeroed
+     """
+
+     bounds = np.cumsum([0] + list(block_sizes))
+     num_block_rows_cols = len(block_sizes)
+     J_trim = matrix.copy()
+
+     # 1. Collect Frobenius norms of every diagonal block in one pass.
+     diag_norm = np.empty(num_block_rows_cols)
+     for t in range(num_block_rows_cols):
+         sl = slice(bounds[t], bounds[t + 1])
+         diag_norm[t] = np.linalg.norm(matrix[sl, sl], ord="fro")
+
+     # 2. Zero all sufficiently small off-diagonal blocks.
+     for t in range(num_block_rows_cols):
+         source_norm = diag_norm[t]
+         r0, r1 = bounds[t], bounds[t + 1]  # rows belonging to block t
+
+         # Blocks to the left of the diagonal in block-row t (the lower-triangular
+         # part, which is where the nonzero off-diagonal blocks live in a block
+         # lower triangular matrix).
+         for tau in range(t):
+             c0, c1 = bounds[tau], bounds[tau + 1]
+             block = J_trim[r0:r1, c0:c1]
+             block_norm = np.linalg.norm(block, ord="fro")
+             if (
+                 block_norm
+                 and block_norm < frobenius_norm_threshold_fraction * source_norm
+             ):
+                 logger.info(
+                     "Zeroing out block [%s:%s, %s:%s] with Frobenius norm %s < %s * %s",
+                     r0,
+                     r1,
+                     c0,
+                     c1,
+                     block_norm,
+                     frobenius_norm_threshold_fraction,
+                     source_norm,
+                 )
+                 J_trim = J_trim.at[r0:r1, c0:c1].set(0.0)
+
+     return J_trim
+
+
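+ # Illustrative sketch (not part of the module): a two-block lower triangular
+ # matrix whose tiny off-diagonal block is zeroed relative to its row's diagonal
+ # block.
+ #
+ #     A = jnp.block([
+ #         [jnp.eye(2), jnp.zeros((2, 2))],
+ #         [1e-6 * jnp.ones((2, 2)), jnp.eye(2)],
+ #     ])
+ #     A_trimmed = zero_small_off_diagonal_blocks(A, [2, 2])
+ #     # The [2:4, 0:2] block is zeroed: its norm 2e-6 < 1e-3 * ||I_2||_F.
+
+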
+ def invert_inverse_bread_matrix(
+     inverse_bread,
+     beta_dim,
+     theta_dim,
+     diag_inverse_stabilization_method=InverseStabilizationMethods.TRIM_SMALL_SINGULAR_VALUES,
+ ):
+     """
+     Invert the inverse bread matrix to get the bread matrix. This is a special
+     function in order to take advantage of the block lower triangular structure.
+
+     The procedure is as follows:
+     1. Initialize the inverse matrix B = A^{-1} as a block lower triangular matrix
+        with the same block structure as A.
+
+     2. Compute the diagonal blocks B_{ii}:
+        For each diagonal block A_{ii}, calculate:
+        B_{ii} = A_{ii}^{-1}
+
+     3. Compute the off-diagonal blocks B_{ij} for i > j:
+        For each off-diagonal block B_{ij} (where i > j), compute:
+        B_{ij} = -A_{ii}^{-1} * sum(A_{ik} * B_{kj} for k in range(j, i))
+     """
+     blocks = []
+     num_beta_block_rows = (inverse_bread.shape[0] - theta_dim) // beta_dim
+
+     # Create the upper block rows of bread (just the beta portion).
+     for i in range(0, num_beta_block_rows):
+         beta_block_row = []
+         beta_diag_inverse = invert_matrix_and_check_conditioning(
+             inverse_bread[
+                 beta_dim * i : beta_dim * (i + 1),
+                 beta_dim * i : beta_dim * (i + 1),
+             ],
+             diag_inverse_stabilization_method,
+         )[0]
+         for j in range(0, num_beta_block_rows):
+             if i > j:
+                 beta_block_row.append(
+                     -beta_diag_inverse
+                     @ sum(
+                         inverse_bread[
+                             beta_dim * i : beta_dim * (i + 1),
+                             beta_dim * k : beta_dim * (k + 1),
+                         ]
+                         @ blocks[k][j]
+                         for k in range(j, i)
+                     )
+                 )
+             elif i == j:
+                 beta_block_row.append(beta_diag_inverse)
+             else:
+                 beta_block_row.append(np.zeros((beta_dim, beta_dim)).astype(np.float32))
+
+         # Extra beta x theta zero block. This is the last block of the row.
+         # Any other zeros in the row have already been handled above.
+         beta_block_row.append(np.zeros((beta_dim, theta_dim)))
+
+         blocks.append(beta_block_row)
+
+     # Create the bottom block row of bread (the theta portion).
+     theta_block_row = []
+     theta_diag_inverse = invert_matrix_and_check_conditioning(
+         inverse_bread[
+             -theta_dim:,
+             -theta_dim:,
+         ],
+         diag_inverse_stabilization_method,
+     )[0]
+     for k in range(0, num_beta_block_rows):
+         theta_block_row.append(
+             -theta_diag_inverse
+             @ sum(
+                 inverse_bread[
+                     -theta_dim:,
+                     beta_dim * h : beta_dim * (h + 1),
+                 ]
+                 @ blocks[h][k]
+                 for h in range(k, num_beta_block_rows)
+             )
+         )
+
+     theta_block_row.append(theta_diag_inverse)
+     blocks.append(theta_block_row)
+
+     return np.block(blocks)
+
+
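+ # Illustrative sketch (not part of the module): for a small block lower
+ # triangular matrix with well-conditioned diagonal blocks, the block-wise
+ # inverse matches a direct inverse. Here beta_dim = theta_dim = 1.
+ #
+ #     A = np.array([
+ #         [2.0, 0.0, 0.0],
+ #         [1.0, 4.0, 0.0],
+ #         [3.0, 5.0, 8.0],
+ #     ])
+ #     B = invert_inverse_bread_matrix(A, beta_dim=1, theta_dim=1)
+ #     assert np.allclose(B, np.linalg.inv(A))
+
+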
+ def matrix_inv_sqrt(mat: np.ndarray, eps: float = 1e-12) -> np.ndarray:
+     """Return mat^{-1/2} for a symmetric matrix, with eigenvalues clipped at `eps`."""
+     eigval, eigvec = np.linalg.eigh(mat)
+     eigval = np.clip(eigval, eps, None)  # ensure strictly positive
+     return eigvec @ np.diag(eigval**-0.5) @ eigvec.T
+
+
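+ # Illustrative sketch (not part of the module): whitening with the inverse
+ # square root of a covariance matrix.
+ #
+ #     cov = np.array([[4.0, 0.0], [0.0, 9.0]])
+ #     W = matrix_inv_sqrt(cov)  # diag(1/2, 1/3)
+ #     assert np.allclose(W @ cov @ W, np.eye(2))
+
+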
+ def load_module_from_source_file(modname, filename):
+     loader = importlib.machinery.SourceFileLoader(modname, filename)
+     spec = importlib.util.spec_from_file_location(modname, filename, loader=loader)
+     module = importlib.util.module_from_spec(spec)
+     # The module is always executed and not cached in sys.modules.
+     # Uncomment the following line to cache the module.
+     # sys.modules[module.__name__] = module
+     loader.exec_module(module)
+     return module
+
+
+ def load_function_from_same_named_file(filename):
+     module = load_module_from_source_file(filename, filename)
+     try:
+         # A missing name surfaces as a KeyError from the __dict__ lookup.
+         return module.__dict__[os.path.basename(filename).split(".")[0]]
+     except KeyError as e:
+         raise ValueError(
+             f"Unable to import function from {filename}. Please verify the file has the same name as the function of interest (ignoring the extension)."
+         ) from e
+
+
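+ # Illustrative sketch (not part of the module): loading a function from a file
+ # named after it. Assumes a hypothetical file reward_func.py on disk that
+ # defines `def reward_func(...)`.
+ #
+ #     reward_func = load_function_from_same_named_file("reward_func.py")
+
+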
+ def confirm_input_check_result(message, suppress_interaction, error=None):
+     if suppress_interaction:
+         logger.info(
+             "Skipping the following interactive data check, as requested:\n%s", message
+         )
+         return
+     answer = None
+     while answer != "y":
+         # pylint: disable=bad-builtin
+         answer = input(message).lower()
+         # pylint: enable=bad-builtin
+         if answer == "y":
+             print("\nOk, proceeding.\n")
+         elif answer == "n":
+             if error:
+                 raise SystemExit from error
+             raise SystemExit
+         else:
+             print("\nPlease enter 'y' or 'n'.\n")
+
+
+ def get_in_study_df_column(study_df, col_name, in_study_col_name):
+     return jnp.array(
+         study_df.loc[study_df[in_study_col_name] == 1, col_name]
+         .to_numpy()
+         .reshape(-1, 1)
+     )
+
+
+ def flatten_params(betas: jnp.ndarray, theta: jnp.ndarray) -> jnp.ndarray:
+     return jnp.concatenate(list(betas) + [theta])
+
+
+ def unflatten_params(
+     flat: jnp.ndarray, beta_dim: int, theta_dim: int
+ ) -> tuple[jnp.ndarray, jnp.ndarray]:
+     theta = flat[-theta_dim:]
+     betas = jnp.array(
+         [
+             flat[i * beta_dim : (i + 1) * beta_dim]
+             for i in range((len(flat) - theta_dim) // beta_dim)
+         ]
+     )
+     return betas, theta
+
+
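+ # Illustrative sketch (not part of the module): flatten_params and
+ # unflatten_params are inverses of each other.
+ #
+ #     betas = jnp.arange(6.0).reshape(3, 2)  # three betas of dimension 2
+ #     theta = jnp.array([10.0, 11.0, 12.0])
+ #     flat = flatten_params(betas, theta)  # shape (9,)
+ #     betas2, theta2 = unflatten_params(flat, beta_dim=2, theta_dim=3)
+ #     assert jnp.allclose(betas2, betas) and jnp.allclose(theta2, theta)
+
+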
+ def get_radon_nikodym_weight(
+     beta_target: jnp.ndarray[jnp.float32],
+     action_prob_func: callable,
+     action_prob_func_args_beta_index: int,
+     action: int,
+     *action_prob_func_args_single_user: tuple[Any, ...],
+ ):
+     """
+     Computes a ratio of action probabilities under two sets of algorithm parameters:
+     in the denominator, beta_target is substituted in with the rest of the supplied action
+     probability function arguments, and in the numerator the original value is used. The action
+     actually taken at the relevant decision time is also supplied, which is used to determine
+     whether to use action 1 probabilities or action 0 probabilities in the ratio.
+
+     Even though in practice we call this in such a way that the beta value is the same in numerator
+     and denominator, it is important to define the function this way so that differentiation, which
+     is with respect to the numerator beta, is done correctly.
+
+     Args:
+         beta_target (jnp.ndarray[jnp.float32]):
+             The beta value to use in the denominator. NOT involved in differentiation!
+         action_prob_func (callable):
+             The function used to compute the probability of action 1 at a given decision time for
+             a particular user given their state and the algorithm parameters.
+         action_prob_func_args_beta_index (int):
+             The index of the beta argument in the action probability function's arguments.
+         action (int):
+             The action actually taken at the relevant decision time.
+         *action_prob_func_args_single_user (tuple[Any, ...]):
+             The arguments to the action probability function for the relevant user at this time.
+
+     Returns:
+         jnp.float32: The Radon-Nikodym weight.
+     """
+
+     # numerator
+     pi_beta = action_prob_func(*action_prob_func_args_single_user)
+
+     # denominator, where we thread in beta_target so that differentiation with respect to the
+     # original beta in the arguments leaves this alone.
+     beta_target_action_prob_func_args_single_user = [*action_prob_func_args_single_user]
+     beta_target_action_prob_func_args_single_user[action_prob_func_args_beta_index] = (
+         beta_target
+     )
+     pi_beta_target = action_prob_func(*beta_target_action_prob_func_args_single_user)
+
+     return conditional_x_or_one_minus_x(pi_beta, action) / conditional_x_or_one_minus_x(
+         pi_beta_target, action
+     )
+
+
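+ # Illustrative sketch (not part of the module): a Radon-Nikodym weight under a
+ # hypothetical sigmoid policy, with beta as argument 0 and the state as argument
+ # 1. When beta_target equals the numerator beta, the weight is exactly 1.
+ #
+ #     def action_prob_func(beta, state):
+ #         return 1.0 / (1.0 + jnp.exp(-(beta @ state)))
+ #
+ #     beta = jnp.array([0.5, -0.2])
+ #     state = jnp.array([1.0, 2.0])
+ #     weight = get_radon_nikodym_weight(beta, action_prob_func, 0, 1, beta, state)
+ #     assert jnp.isclose(weight, 1.0)
+
+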
+ def get_min_time_by_policy_num(
+     single_user_policy_num_by_decision_time, beta_index_by_policy_num
+ ):
+     """
+     Returns a dictionary mapping each policy number to the first time it was
+     applicable, along with the first decision time after the first update (i.e.,
+     the first time a non-initial, non-fallback policy was used).
+     """
+     min_time_by_policy_num = {}
+     first_time_after_first_update = None
+     for decision_time, policy_num in single_user_policy_num_by_decision_time.items():
+         if policy_num not in min_time_by_policy_num:
+             min_time_by_policy_num[policy_num] = decision_time
+
+         # Grab the first time where a non-initial, non-fallback policy is used.
+         # Assumes single_user_policy_num_by_decision_time is sorted.
+         if (
+             policy_num in beta_index_by_policy_num
+             and first_time_after_first_update is None
+         ):
+             first_time_after_first_update = decision_time
+
+     return min_time_by_policy_num, first_time_after_first_update
+
+
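+ # Illustrative sketch (not part of the module): with a sorted per-user map of
+ # decision time -> policy number and post-update policies {2: 0}, we get the
+ # first applicable time per policy and the first post-update time.
+ #
+ #     min_times, first_post_update_t = get_min_time_by_policy_num(
+ #         {1: 0, 2: 0, 3: 2, 4: 2}, {2: 0}
+ #     )
+ #     # min_times == {0: 1, 2: 3}; first_post_update_t == 3
+
+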
+ def calculate_beta_dim(
+     action_prob_func_args: dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]],
+     action_prob_func_args_beta_index: int,
+ ) -> int:
+     """
+     Calculates the dimension of the beta vector based on the action probability function arguments.
+
+     Args:
+         action_prob_func_args (dict): Dictionary containing the action probability function arguments.
+         action_prob_func_args_beta_index (int): Index of the beta parameter in the action probability function arguments.
+
+     Returns:
+         int: The dimension of the beta vector.
+     """
+     for decision_time in action_prob_func_args:
+         for user_id in action_prob_func_args[decision_time]:
+             if action_prob_func_args[decision_time][user_id]:
+                 return len(
+                     action_prob_func_args[decision_time][user_id][
+                         action_prob_func_args_beta_index
+                     ]
+                 )
+     raise ValueError(
+         "No valid beta vector found in action probability function arguments. Please check the input data."
+     )
+
+
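+ # Illustrative sketch (not part of the module): the expected nested structure is
+ # {decision_time: {user_id: (args...)}}, with beta at the supplied index.
+ #
+ #     args = {1: {"user_a": (jnp.array([0.1, 0.2, 0.3]), "state_a")}}
+ #     assert calculate_beta_dim(args, 0) == 3
+
+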
+ def construct_beta_index_by_policy_num_map(
+     study_df: pd.DataFrame, policy_num_col_name: str, in_study_col_name: str
+ ) -> tuple[dict[int | float, int], int | float]:
+     """
+     Constructs a mapping from non-initial, non-fallback policy numbers to the index of the
+     corresponding beta in all_post_update_betas.
+
+     This is useful because differentiating the stacked estimating functions with respect to all the
+     betas is simplest if they are passed in a single list. This auxiliary data then allows us to
+     route the right beta to the right policy number at each time.
+
+     If we keep the enforcement of consecutive policy numbers, we don't actually need all of this
+     logic and could just pass around the initial policy number. We keep it anyway so that this
+     handles the merely increasing (non-fallback) case, even though upstream we currently require
+     no gaps.
+     """
+
+     unique_sorted_non_fallback_policy_nums = sorted(
+         study_df[
+             (study_df[policy_num_col_name] >= 0) & (study_df[in_study_col_name] == 1)
+         ][policy_num_col_name]
+         .unique()
+         .tolist()
+     )
+     # This assumes only the first policy is an initial policy not produced by an update.
+     # Hence the [1:] slice.
+     return {
+         policy_num: i
+         for i, policy_num in enumerate(unique_sorted_non_fallback_policy_nums[1:])
+     }, unique_sorted_non_fallback_policy_nums[0]
+
+
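+ # Illustrative sketch (not part of the module): with in-study policy numbers
+ # [-1, 0, 2, 3] (negative numbers mark fallback policies), the non-fallback
+ # numbers are [0, 2, 3], policy 0 is treated as the initial policy, and the map
+ # is {2: 0, 3: 1}.
+ #
+ #     df = pd.DataFrame({"policy_num": [-1, 0, 2, 3], "in_study": [1, 1, 1, 1]})
+ #     beta_index_by_policy_num, initial_policy = construct_beta_index_by_policy_num_map(
+ #         df, "policy_num", "in_study"
+ #     )
+ #     # beta_index_by_policy_num == {2: 0, 3: 1}; initial_policy == 0
+
+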
+ def collect_all_post_update_betas(
+     beta_index_by_policy_num, alg_update_func_args, alg_update_func_args_beta_index
+ ):
+     """
+     Collects all betas produced by the algorithm updates in an ordered list.
+
+     This data structure is chosen because it makes for the most convenient
+     differentiation of the stacked estimating functions with respect to all the
+     betas. Otherwise a dictionary keyed on policy number would be more natural.
+     """
+     all_post_update_betas = []
+     for policy_num in sorted(beta_index_by_policy_num.keys()):
+         # Take the beta from the first user with non-empty update args.
+         for user_id in alg_update_func_args[policy_num]:
+             if alg_update_func_args[policy_num][user_id]:
+                 all_post_update_betas.append(
+                     alg_update_func_args[policy_num][user_id][
+                         alg_update_func_args_beta_index
+                     ]
+                 )
+                 break
+     return jnp.array(all_post_update_betas)
+
+
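+ # Illustrative sketch (not part of the module): collecting one beta per
+ # post-update policy, in policy-number order, with beta at argument index 0.
+ #
+ #     alg_update_func_args = {
+ #         2: {"u1": (jnp.array([0.1, 0.2]), "other_arg")},
+ #         3: {"u1": (jnp.array([0.3, 0.4]), "other_arg")},
+ #     }
+ #     betas = collect_all_post_update_betas({2: 0, 3: 1}, alg_update_func_args, 0)
+ #     # betas == jnp.array([[0.1, 0.2], [0.3, 0.4]])
+
+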
+ def extract_action_and_policy_by_decision_time_by_user_id(
+     study_df,
+     user_id_col_name,
+     in_study_col_name,
+     calendar_t_col_name,
+     action_col_name,
+     policy_num_col_name,
+ ):
+     action_by_decision_time_by_user_id = {}
+     policy_num_by_decision_time_by_user_id = {}
+     for user_id, user_df in study_df.groupby(user_id_col_name):
+         in_study_user_df = user_df[user_df[in_study_col_name] == 1]
+         action_by_decision_time_by_user_id[user_id] = dict(
+             zip(
+                 in_study_user_df[calendar_t_col_name], in_study_user_df[action_col_name]
+             )
+         )
+         policy_num_by_decision_time_by_user_id[user_id] = dict(
+             zip(
+                 in_study_user_df[calendar_t_col_name],
+                 in_study_user_df[policy_num_col_name],
+             )
+         )
+     return action_by_decision_time_by_user_id, policy_num_by_decision_time_by_user_id
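+
+
+ # Illustrative sketch (not part of the module): extracting per-user action and
+ # policy maps from a tiny study DataFrame.
+ #
+ #     df = pd.DataFrame({
+ #         "user_id": ["u1", "u1", "u2"],
+ #         "in_study": [1, 1, 1],
+ #         "calendar_t": [1, 2, 1],
+ #         "action": [0, 1, 1],
+ #         "policy_num": [0, 2, 0],
+ #     })
+ #     actions, policies = extract_action_and_policy_by_decision_time_by_user_id(
+ #         df, "user_id", "in_study", "calendar_t", "action", "policy_num"
+ #     )
+ #     # actions == {"u1": {1: 0, 2: 1}, "u2": {1: 1}}
+ #     # policies == {"u1": {1: 0, 2: 2}, "u2": {1: 0}}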