lifejacket-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/__init__.py +0 -0
- lifejacket/after_study_analysis.py +1845 -0
- lifejacket/arg_threading_helpers.py +354 -0
- lifejacket/calculate_derivatives.py +965 -0
- lifejacket/constants.py +28 -0
- lifejacket/form_adaptive_meat_adjustments_directly.py +333 -0
- lifejacket/get_datum_for_blowup_supervised_learning.py +1312 -0
- lifejacket/helper_functions.py +587 -0
- lifejacket/input_checks.py +1145 -0
- lifejacket/small_sample_corrections.py +125 -0
- lifejacket/trial_conditioning_monitor.py +870 -0
- lifejacket/vmap_helpers.py +71 -0
- lifejacket-0.1.0.dist-info/METADATA +100 -0
- lifejacket-0.1.0.dist-info/RECORD +17 -0
- lifejacket-0.1.0.dist-info/WHEEL +5 -0
- lifejacket-0.1.0.dist-info/entry_points.txt +2 -0
- lifejacket-0.1.0.dist-info/top_level.txt +1 -0
lifejacket/helper_functions.py
@@ -0,0 +1,587 @@
from __future__ import annotations

import collections
import importlib.machinery
import importlib.util
import logging
import os
from typing import Any

import jax.numpy as jnp
import numpy as np
import pandas as pd

from .constants import InverseStabilizationMethods

logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s,%(msecs)03d %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)

def conditional_x_or_one_minus_x(x, condition):
    # Returns x when condition == 1 and 1 - x when condition == 0, written
    # branch-free so it stays differentiable and JAX-traceable.
    return (1 - condition) + (2 * condition - 1) * x

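A quick numeric check of the branch-free selector above (a throwaway snippet, not part of the package):

import jax.numpy as jnp

x = jnp.array(0.3)
print(conditional_x_or_one_minus_x(x, 1))  # 0.3 (condition == 1 returns x)
print(conditional_x_or_one_minus_x(x, 0))  # 0.7 (condition == 0 returns 1 - x)
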
def invert_matrix_and_check_conditioning(
    matrix: np.ndarray,
    inverse_stabilization_method: str = InverseStabilizationMethods.NONE,
    condition_num_threshold: float = 10**4,
    ridge_median_singular_value_fraction: float = 0.01,
    beta_dim: int | None = None,
    theta_dim: int | None = None,
):
    """
    Check a matrix's condition number and invert it. If the condition number is
    above a threshold, apply the requested stabilization method to improve
    conditioning before inverting.

    Args:
        matrix (np.ndarray):
            The square matrix to invert.
        inverse_stabilization_method (str):
            One of the InverseStabilizationMethods constants.
        condition_num_threshold (float):
            Condition number above which a warning is logged; also sets the
            trimming/ridge strength for the relevant methods.
        ridge_median_singular_value_fraction (float):
            Fraction of the median singular value used as the ridge lambda for
            ADD_RIDGE_MEDIAN_SINGULAR_VALUE_FRACTION.
        beta_dim (int | None):
            Dimension of each beta block; required by the structure-aware methods.
        theta_dim (int | None):
            Dimension of the theta block; required by the structure-aware methods.

    Returns:
        tuple: (inverse, pre_inversion_condition_number), where the condition
        number refers to the (possibly stabilized) matrix that was inverted.
    """
    inverse = None
    pre_inversion_condition_number = np.linalg.cond(matrix)
    if pre_inversion_condition_number > condition_num_threshold:
        logger.warning(
            "You are inverting a matrix with a large condition number: %s",
            pre_inversion_condition_number,
        )
    if (
        inverse_stabilization_method
        == InverseStabilizationMethods.TRIM_SMALL_SINGULAR_VALUES
    ):
        logger.info("Trimming small singular values to improve conditioning.")
        u, s, vT = np.linalg.svd(matrix, full_matrices=False)
        logger.info("  Sorted singular values: %s", s)
        sing_values_above_threshold_cond = s > s.max() / condition_num_threshold
        if not np.any(sing_values_above_threshold_cond):
            raise RuntimeError(
                f"All singular values are below the threshold of {s.max() / condition_num_threshold}. Singular value trimming will not work.",
            )
        trimmed_pseudoinverse = (
            vT.T[:, sing_values_above_threshold_cond]
            / s[sing_values_above_threshold_cond]
        ) @ u[:, sing_values_above_threshold_cond].T
        inverse = trimmed_pseudoinverse
        pre_inversion_condition_number = (
            s[sing_values_above_threshold_cond].max()
            / s[sing_values_above_threshold_cond].min()
        )

        logger.info(
            "Kept %s out of %s singular values. Condition number of resulting lower-rank approximation before inversion: %s",
            sum(sing_values_above_threshold_cond),
            len(s),
            pre_inversion_condition_number,
        )
    elif (
        inverse_stabilization_method
        == InverseStabilizationMethods.ADD_RIDGE_FIXED_CONDITION_NUMBER
    ):
        logger.info("Adding ridge/Tikhonov regularization to improve conditioning.")
        _, singular_values, _ = np.linalg.svd(matrix, full_matrices=False)
        logger.info(
            "Using fixed condition number threshold of %s to determine lambda.",
            condition_num_threshold,
        )
        # Shift the spectrum so the smallest shifted singular value is
        # s_max / condition_num_threshold; for a symmetric PSD matrix this puts
        # the post-ridge condition number near the threshold.
        lambda_ = (
            singular_values.max() / condition_num_threshold - singular_values.min()
        )
        logger.info("Lambda for ridge regularization: %s", lambda_)
        new_matrix = matrix + lambda_ * np.eye(matrix.shape[0])
        pre_inversion_condition_number = np.linalg.cond(new_matrix)
        logger.info(
            "Condition number of matrix after ridge regularization: %s",
            pre_inversion_condition_number,
        )
        inverse = np.linalg.solve(new_matrix, np.eye(matrix.shape[0]))
    elif (
        inverse_stabilization_method
        == InverseStabilizationMethods.ADD_RIDGE_MEDIAN_SINGULAR_VALUE_FRACTION
    ):
        logger.info("Adding ridge/Tikhonov regularization to improve conditioning.")
        _, singular_values, _ = np.linalg.svd(matrix, full_matrices=False)
        logger.info(
            "Using median singular value times %s as lambda.",
            ridge_median_singular_value_fraction,
        )
        lambda_ = ridge_median_singular_value_fraction * np.median(singular_values)
        logger.info("Lambda for ridge regularization: %s", lambda_)
        new_matrix = matrix + lambda_ * np.eye(matrix.shape[0])
        pre_inversion_condition_number = np.linalg.cond(new_matrix)
        logger.info(
            "Condition number of matrix after ridge regularization: %s",
            pre_inversion_condition_number,
        )
        inverse = np.linalg.solve(new_matrix, np.eye(matrix.shape[0]))
    elif (
        inverse_stabilization_method
        == InverseStabilizationMethods.INVERSE_BREAD_STRUCTURE_AWARE_INVERSION
    ):
        if beta_dim is None or theta_dim is None:
            raise ValueError(
                "When using structure-aware inversion, beta_dim and theta_dim must be provided."
            )
        logger.info(
            "Using inverse bread's block lower triangular structure to invert only diagonal blocks."
        )
        pre_inversion_condition_number = np.linalg.cond(matrix)
        inverse = invert_inverse_bread_matrix(
            matrix,
            beta_dim,
            theta_dim,
            InverseStabilizationMethods.ADD_RIDGE_FIXED_CONDITION_NUMBER,
        )
    elif (
        inverse_stabilization_method
        == InverseStabilizationMethods.ZERO_OUT_SMALL_OFF_DIAGONALS
    ):
        if beta_dim is None or theta_dim is None:
            raise ValueError(
                "When zeroing out small off diagonals, beta_dim and theta_dim must be provided."
            )
        logger.info(
            "Zeroing out small off-diagonal blocks to improve conditioning."
        )
        # Block sizes: one beta-sized block per update, then the theta block.
        zeroed_matrix = zero_small_off_diagonal_blocks(
            matrix,
            ([beta_dim] * ((matrix.shape[0] - theta_dim) // beta_dim)) + [theta_dim],
        )
        pre_inversion_condition_number = np.linalg.cond(zeroed_matrix)
        logger.info(
            "Condition number of matrix after zeroing out small off-diagonal blocks: %s",
            pre_inversion_condition_number,
        )
        inverse = np.linalg.solve(zeroed_matrix, np.eye(zeroed_matrix.shape[0]))
    elif (
        inverse_stabilization_method
        == InverseStabilizationMethods.ALL_METHODS_COMPETITION
    ):
        # TODO: Choose the right metric for the competition... identity diff might not be it.
        raise NotImplementedError(
            "All methods competition is not implemented yet. Please choose a specific method."
        )
    elif inverse_stabilization_method == InverseStabilizationMethods.NONE:
        logger.info("No inverse stabilization method applied. Inverting directly.")
    else:
        raise ValueError(
            f"Unknown inverse stabilization method: {inverse_stabilization_method}"
        )
    if inverse is None:
        inverse = np.linalg.solve(matrix, np.eye(matrix.shape[0]))
    return inverse, pre_inversion_condition_number

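A hedged usage sketch for the dispatcher above; the matrix is invented for illustration, and InverseStabilizationMethods is the constants class imported at the top of this file:

import numpy as np

# Diagonal matrix with condition number 1e7, well above the 1e4 default threshold.
m = np.diag([1.0, 1e-3, 1e-7])
inv, cond = invert_matrix_and_check_conditioning(
    m, InverseStabilizationMethods.TRIM_SMALL_SINGULAR_VALUES
)
print(cond)       # ~1000: condition number of the rank-2 approximation actually inverted
print(inv.shape)  # (3, 3) pseudoinverse with the tiny singular value dropped
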
def zero_small_off_diagonal_blocks(
    matrix: jnp.ndarray,
    block_sizes: list[int],
    frobenius_norm_threshold_fraction: float = 1e-3,
):
    """
    Zero off-diagonal blocks whose Frobenius norm is less than
    frobenius_norm_threshold_fraction times the Frobenius norm of the diagonal
    block in the same ROW. One could instead compare to the same column, or to
    both the row and the column, but we choose the row here since rows
    correspond to a single RL update or inference step in the adaptive inverse
    bread matrices this method is designed for.

    Args:
        matrix (jnp.ndarray):
            2-D square array (q_total x q_total).
        block_sizes (list[int]):
            List like [p1, p2, ..., pT].
        frobenius_norm_threshold_fraction (float):
            Frobenius norm fraction, relative to the same-row diagonal block,
            under which a block is zeroed.

    Returns:
        jnp.ndarray with the selected off-diagonal blocks zeroed.
    """

    bounds = np.cumsum([0] + list(block_sizes))
    num_block_rows_cols = len(block_sizes)
    J_trim = jnp.asarray(matrix)

    # 1. Collect Frobenius norms of every diagonal block in one pass.
    diag_norm = np.empty(num_block_rows_cols)
    for t in range(num_block_rows_cols):
        sl = slice(bounds[t], bounds[t + 1])
        diag_norm[t] = np.linalg.norm(matrix[sl, sl], ord="fro")

    # 2. Zero all sufficiently small off-diagonal blocks.
    for t in range(num_block_rows_cols):
        source_norm = diag_norm[t]
        r0, r1 = bounds[t], bounds[t + 1]  # rows belonging to block t

        # Blocks to the left of the diagonal in block-row t (the
        # lower-triangular part, where the nonzero off-diagonal blocks of a
        # block lower triangular matrix live).
        for tau in range(t):
            c0, c1 = bounds[tau], bounds[tau + 1]
            block = J_trim[r0:r1, c0:c1]
            block_norm = np.linalg.norm(block, ord="fro")
            if (
                block_norm
                and block_norm < frobenius_norm_threshold_fraction * source_norm
            ):
                logger.info(
                    "Zeroing out block [%s:%s, %s:%s] with Frobenius norm %s < %s * %s",
                    r0,
                    r1,
                    c0,
                    c1,
                    block_norm,
                    frobenius_norm_threshold_fraction,
                    source_norm,
                )
                J_trim = J_trim.at[r0:r1, c0:c1].set(0.0)

    return J_trim

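A minimal illustration of the zeroing rule, with a hypothetical two-block matrix (block sizes chosen arbitrarily):

import jax.numpy as jnp
import numpy as np

big = np.eye(2)
tiny = 1e-6 * np.ones((2, 2))
# The lower-left block is tiny relative to the same-row diagonal block, so it is zeroed.
m = jnp.block([[big, np.zeros((2, 2))], [tiny, big]])
trimmed = zero_small_off_diagonal_blocks(m, [2, 2])
print(np.allclose(trimmed[2:, :2], 0.0))  # True
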
def invert_inverse_bread_matrix(
    inverse_bread,
    beta_dim,
    theta_dim,
    diag_inverse_stabilization_method=InverseStabilizationMethods.TRIM_SMALL_SINGULAR_VALUES,
):
    """
    Invert the inverse bread matrix to get the bread matrix. This is a special
    function in order to take advantage of the block lower triangular structure.

    The procedure is as follows:
    1. Initialize the inverse matrix B = A^{-1} as a block lower triangular matrix
       with the same block structure as A.

    2. Compute the diagonal blocks B_{ii}:
       For each diagonal block A_{ii}, calculate:
           B_{ii} = A_{ii}^{-1}

    3. Compute the off-diagonal blocks B_{ij} for i > j:
       For each off-diagonal block B_{ij} (where i > j), compute:
           B_{ij} = -A_{ii}^{-1} * sum(A_{ik} * B_{kj} for k in range(j, i))
    """
    blocks = []
    num_beta_block_rows = (inverse_bread.shape[0] - theta_dim) // beta_dim

    # Create the upper block rows of the bread (just the beta portion).
    for i in range(num_beta_block_rows):
        beta_block_row = []
        beta_diag_inverse = invert_matrix_and_check_conditioning(
            inverse_bread[
                beta_dim * i : beta_dim * (i + 1),
                beta_dim * i : beta_dim * (i + 1),
            ],
            diag_inverse_stabilization_method,
        )[0]
        for j in range(num_beta_block_rows):
            if i > j:
                beta_block_row.append(
                    -beta_diag_inverse
                    @ sum(
                        inverse_bread[
                            beta_dim * i : beta_dim * (i + 1),
                            beta_dim * k : beta_dim * (k + 1),
                        ]
                        @ blocks[k][j]
                        for k in range(j, i)
                    )
                )
            elif i == j:
                beta_block_row.append(beta_diag_inverse)
            else:
                beta_block_row.append(np.zeros((beta_dim, beta_dim), dtype=np.float32))

        # Extra beta x theta zero block. This is the last block of the row.
        # Any other zeros in the row have already been handled above.
        beta_block_row.append(np.zeros((beta_dim, theta_dim)))

        blocks.append(beta_block_row)

    # Create the bottom block row of the bread (the theta portion).
    theta_block_row = []
    theta_diag_inverse = invert_matrix_and_check_conditioning(
        inverse_bread[
            -theta_dim:,
            -theta_dim:,
        ],
        diag_inverse_stabilization_method,
    )[0]
    for k in range(num_beta_block_rows):
        theta_block_row.append(
            -theta_diag_inverse
            @ sum(
                inverse_bread[
                    -theta_dim:,
                    beta_dim * h : beta_dim * (h + 1),
                ]
                @ blocks[h][k]
                for h in range(k, num_beta_block_rows)
            )
        )

    theta_block_row.append(theta_diag_inverse)
    blocks.append(theta_block_row)

    return np.block(blocks)

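A self-check of the block forward substitution above, under the assumption that the diagonal blocks are well conditioned (dimensions invented):

import numpy as np

rng = np.random.default_rng(0)
beta_dim, theta_dim, num_updates = 2, 3, 2
q = num_updates * beta_dim + theta_dim

# An elementwise lower triangular matrix with a dominant diagonal is, in
# particular, block lower triangular for this partition.
a = np.tril(rng.normal(size=(q, q))) + q * np.eye(q)
bread = invert_inverse_bread_matrix(a, beta_dim, theta_dim)
print(np.allclose(bread, np.linalg.inv(a)))  # True, up to float tolerance
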
def matrix_inv_sqrt(mat: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Return mat^{-1/2} for a symmetric matrix, with eigenvalues clipped at `eps`."""
    eigval, eigvec = np.linalg.eigh(mat)
    eigval = np.clip(eigval, eps, None)  # ensure strictly positive
    return eigvec @ np.diag(eigval**-0.5) @ eigvec.T

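A sanity check for matrix_inv_sqrt; note that eigh assumes a symmetric input:

import numpy as np

m = np.array([[4.0, 1.0], [1.0, 3.0]])
r = matrix_inv_sqrt(m)
print(np.allclose(r @ m @ r, np.eye(2)))  # True: r is m^{-1/2}
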
def load_module_from_source_file(modname, filename):
    loader = importlib.machinery.SourceFileLoader(modname, filename)
    spec = importlib.util.spec_from_file_location(modname, filename, loader=loader)
    module = importlib.util.module_from_spec(spec)
    # The module is always executed and not cached in sys.modules.
    # Uncomment the following line to cache the module.
    # sys.modules[module.__name__] = module
    loader.exec_module(module)
    return module

def load_function_from_same_named_file(filename):
    module = load_module_from_source_file(filename, filename)
    try:
        return module.__dict__[os.path.basename(filename).split(".")[0]]
    except (AttributeError, KeyError) as e:
        raise ValueError(
            f"Unable to import function from {filename}. Please verify the file has the same name as the function of interest (ignoring the extension)."
        ) from e

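A usage sketch for the two loaders; the path and function name here are hypothetical, but the convention is that the file defines a function whose name matches the file name:

# Suppose /tmp/my_action_prob_func.py defines a function my_action_prob_func(...).
func = load_function_from_same_named_file("/tmp/my_action_prob_func.py")
print(func.__name__)  # my_action_prob_func
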
def confirm_input_check_result(message, suppress_interaction, error=None):
    # Ask the user to confirm an input-check result, exiting on "n". When
    # suppress_interaction is set, log the message and proceed without asking.
    if suppress_interaction:
        logger.info(
            "Skipping the following interactive data check, as requested:\n%s", message
        )
        return
    answer = None
    while answer != "y":
        # pylint: disable=bad-builtin
        answer = input(message).lower()
        # pylint: enable=bad-builtin
        if answer == "y":
            print("\nOk, proceeding.\n")
        elif answer == "n":
            if error:
                raise SystemExit from error
            raise SystemExit
        else:
            print("\nPlease enter 'y' or 'n'.\n")

def get_in_study_df_column(study_df, col_name, in_study_col_name):
    # Pull one column for in-study rows as a (num_rows, 1) JAX array.
    return jnp.array(
        study_df.loc[study_df[in_study_col_name] == 1, col_name]
        .to_numpy()
        .reshape(-1, 1)
    )

def flatten_params(betas: jnp.ndarray, theta: jnp.ndarray) -> jnp.ndarray:
    return jnp.concatenate(list(betas) + [theta])


def unflatten_params(
    flat: jnp.ndarray, beta_dim: int, theta_dim: int
) -> tuple[jnp.ndarray, jnp.ndarray]:
    theta = flat[-theta_dim:]
    betas = jnp.array(
        [
            flat[i * beta_dim : (i + 1) * beta_dim]
            for i in range((len(flat) - theta_dim) // beta_dim)
        ]
    )
    return betas, theta

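The pair above round-trips, for example:

import jax.numpy as jnp

betas = jnp.ones((3, 2))  # three post-update betas of dimension 2
theta = jnp.arange(4.0)   # theta of dimension 4
flat = flatten_params(betas, theta)
print(flat.shape)  # (10,)
betas2, theta2 = unflatten_params(flat, beta_dim=2, theta_dim=4)
print(bool(jnp.all(betas2 == betas) and jnp.all(theta2 == theta)))  # True
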
def get_radon_nikodym_weight(
    beta_target: jnp.ndarray,
    action_prob_func: callable,
    action_prob_func_args_beta_index: int,
    action: int,
    *action_prob_func_args_single_user: tuple[Any, ...],
):
    """
    Computes a ratio of action probabilities under two sets of algorithm parameters:
    in the denominator, beta_target is substituted in alongside the rest of the supplied action
    probability function arguments, and in the numerator the original value is used. The action
    actually taken at the relevant decision time is also supplied, which is used to determine
    whether to use action 1 probabilities or action 0 probabilities in the ratio.

    Even though in practice we call this in such a way that the beta value is the same in numerator
    and denominator, it is important to define the function this way so that differentiation, which
    is with respect to the numerator beta, is done correctly.

    Args:
        beta_target (jnp.ndarray):
            The beta value to use in the denominator. NOT involved in differentiation!
        action_prob_func (callable):
            The function used to compute the probability of action 1 at a given decision time for
            a particular user given their state and the algorithm parameters.
        action_prob_func_args_beta_index (int):
            The index of the beta argument in the action probability function's arguments.
        action (int):
            The action actually taken at the relevant decision time.
        *action_prob_func_args_single_user (tuple[Any, ...]):
            The arguments to the action probability function for the relevant user at this time.

    Returns:
        jnp.float32: The Radon-Nikodym weight.
    """

    # Numerator.
    pi_beta = action_prob_func(*action_prob_func_args_single_user)

    # Denominator, where we thread in beta_target so that differentiation with respect to the
    # original beta in the arguments leaves this alone.
    beta_target_action_prob_func_args_single_user = [*action_prob_func_args_single_user]
    beta_target_action_prob_func_args_single_user[action_prob_func_args_beta_index] = (
        beta_target
    )
    pi_beta_target = action_prob_func(*beta_target_action_prob_func_args_single_user)

    return conditional_x_or_one_minus_x(pi_beta, action) / conditional_x_or_one_minus_x(
        pi_beta_target, action
    )

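A hedged sketch of computing a weight with a made-up sigmoid policy; sigmoid_action_prob and its argument layout are illustrative, not part of the package:

import jax
import jax.numpy as jnp

def sigmoid_action_prob(beta, state):
    # Invented policy: probability of action 1 given state features.
    return jax.nn.sigmoid(jnp.dot(beta, state))

beta = jnp.array([0.5, -0.2])
state = jnp.array([1.0, 2.0])
# Same beta in numerator and denominator, so the weight is 1.0; differentiation
# would flow only through the numerator beta.
weight = get_radon_nikodym_weight(beta, sigmoid_action_prob, 0, 1, beta, state)
print(weight)  # 1.0
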
def get_min_time_by_policy_num(
    single_user_policy_num_by_decision_time, beta_index_by_policy_num
):
    """
    Returns a dictionary mapping each policy number to the first decision time at which
    it was applicable, along with the first decision time after the first update.
    """
    min_time_by_policy_num = {}
    first_time_after_first_update = None
    for decision_time, policy_num in single_user_policy_num_by_decision_time.items():
        if policy_num not in min_time_by_policy_num:
            min_time_by_policy_num[policy_num] = decision_time

        # Grab the first time where a non-initial, non-fallback policy is used.
        # Assumes single_user_policy_num_by_decision_time is sorted.
        if (
            policy_num in beta_index_by_policy_num
            and first_time_after_first_update is None
        ):
            first_time_after_first_update = decision_time

    return min_time_by_policy_num, first_time_after_first_update

def calculate_beta_dim(
    action_prob_func_args: dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]],
    action_prob_func_args_beta_index: int,
) -> int:
    """
    Calculates the dimension of the beta vector based on the action probability function arguments.

    Args:
        action_prob_func_args (dict): Dictionary containing the action probability function arguments.
        action_prob_func_args_beta_index (int): Index of the beta parameter in the action probability function arguments.

    Returns:
        int: The dimension of the beta vector.
    """
    for decision_time in action_prob_func_args:
        for user_id in action_prob_func_args[decision_time]:
            if action_prob_func_args[decision_time][user_id]:
                return len(
                    action_prob_func_args[decision_time][user_id][
                        action_prob_func_args_beta_index
                    ]
                )
    raise ValueError(
        "No valid beta vector found in action probability function arguments. Please check the input data."
    )

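For example, with the nested {decision_time: {user_id: args}} layout these helpers expect (values invented, beta at argument index 0):

import jax.numpy as jnp

args = {1: {"user_a": (jnp.zeros(3), jnp.ones(2))}, 2: {"user_a": ()}}
print(calculate_beta_dim(args, 0))  # 3
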
def construct_beta_index_by_policy_num_map(
    study_df: pd.DataFrame, policy_num_col_name: str, in_study_col_name: str
) -> tuple[dict[int | float, int], int | float]:
    """
    Constructs a mapping from non-initial, non-fallback policy numbers to the index of the
    corresponding beta in all_post_update_betas.

    This is useful because differentiating the stacked estimating functions with respect to all the
    betas is simplest if they are passed in a single list. This auxiliary data then allows us to
    route the right beta to the right policy number at each time.

    If we keep enforcing consecutive policy numbers, we don't actually need all this logic
    and could just pass around the initial policy number; however, this version also handles
    merely increasing (non-fallback) policy numbers, even though upstream we currently require
    no gaps.
    """

    unique_sorted_non_fallback_policy_nums = sorted(
        study_df[
            (study_df[policy_num_col_name] >= 0) & (study_df[in_study_col_name] == 1)
        ][policy_num_col_name]
        .unique()
        .tolist()
    )
    # This assumes only the first policy is an initial policy not produced by an update.
    # Hence the [1:] slice.
    return {
        policy_num: i
        for i, policy_num in enumerate(unique_sorted_non_fallback_policy_nums[1:])
    }, unique_sorted_non_fallback_policy_nums[0]

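A small invented example; negative policy numbers are treated as fallback policies and excluded by the filter above:

import pandas as pd

df = pd.DataFrame(
    {
        "policy_num": [-1, 0, 0, 1, 2, 2],
        "in_study": [1, 1, 1, 1, 1, 1],
    }
)
index_map, initial_policy = construct_beta_index_by_policy_num_map(df, "policy_num", "in_study")
print(index_map)       # {1: 0, 2: 1}
print(initial_policy)  # 0
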
def collect_all_post_update_betas(
    beta_index_by_policy_num, alg_update_func_args, alg_update_func_args_beta_index
):
    """
    Collects all betas produced by the algorithm updates in an ordered list.

    This data structure is chosen because it makes for the most convenient
    differentiation of the stacked estimating functions with respect to all the
    betas. Otherwise a dictionary keyed on policy number would be more natural.
    """
    all_post_update_betas = []
    for policy_num in sorted(beta_index_by_policy_num.keys()):
        for user_id in alg_update_func_args[policy_num]:
            if alg_update_func_args[policy_num][user_id]:
                all_post_update_betas.append(
                    alg_update_func_args[policy_num][user_id][
                        alg_update_func_args_beta_index
                    ]
                )
                # One beta per policy, so the first user with arguments suffices.
                break
    return jnp.array(all_post_update_betas)

def extract_action_and_policy_by_decision_time_by_user_id(
    study_df,
    user_id_col_name,
    in_study_col_name,
    calendar_t_col_name,
    action_col_name,
    policy_num_col_name,
):
    action_by_decision_time_by_user_id = {}
    policy_num_by_decision_time_by_user_id = {}
    for user_id, user_df in study_df.groupby(user_id_col_name):
        in_study_user_df = user_df[user_df[in_study_col_name] == 1]
        action_by_decision_time_by_user_id[user_id] = dict(
            zip(
                in_study_user_df[calendar_t_col_name], in_study_user_df[action_col_name]
            )
        )
        policy_num_by_decision_time_by_user_id[user_id] = dict(
            zip(
                in_study_user_df[calendar_t_col_name],
                in_study_user_df[policy_num_col_name],
            )
        )
    return action_by_decision_time_by_user_id, policy_num_by_decision_time_by_user_id
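A small invented example of the extraction above:

import pandas as pd

df = pd.DataFrame(
    {
        "user_id": ["u1", "u1", "u2"],
        "in_study": [1, 1, 1],
        "calendar_t": [1, 2, 1],
        "action": [0, 1, 1],
        "policy_num": [0, 1, 0],
    }
)
actions, policies = extract_action_and_policy_by_decision_time_by_user_id(
    df, "user_id", "in_study", "calendar_t", "action", "policy_num"
)
print(actions["u1"][2])   # 1: u1's action at decision time 2
print(policies["u2"][1])  # 0: u2's policy number at decision time 1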