lifejacket-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/__init__.py +0 -0
- lifejacket/after_study_analysis.py +1845 -0
- lifejacket/arg_threading_helpers.py +354 -0
- lifejacket/calculate_derivatives.py +965 -0
- lifejacket/constants.py +28 -0
- lifejacket/form_adaptive_meat_adjustments_directly.py +333 -0
- lifejacket/get_datum_for_blowup_supervised_learning.py +1312 -0
- lifejacket/helper_functions.py +587 -0
- lifejacket/input_checks.py +1145 -0
- lifejacket/small_sample_corrections.py +125 -0
- lifejacket/trial_conditioning_monitor.py +870 -0
- lifejacket/vmap_helpers.py +71 -0
- lifejacket-0.1.0.dist-info/METADATA +100 -0
- lifejacket-0.1.0.dist-info/RECORD +17 -0
- lifejacket-0.1.0.dist-info/WHEEL +5 -0
- lifejacket-0.1.0.dist-info/entry_points.txt +2 -0
- lifejacket-0.1.0.dist-info/top_level.txt +1 -0
lifejacket/constants.py
ADDED
@@ -0,0 +1,28 @@
+class SmallSampleCorrections:
+    NONE = "none"
+    HC1theta = "HC1theta"
+    HC2theta = "HC2theta"
+    HC3theta = "HC3theta"
+
+
+class InverseStabilizationMethods:
+    NONE = "none"
+    TRIM_SMALL_SINGULAR_VALUES = "trim_small_singular_values"
+    ZERO_OUT_SMALL_OFF_DIAGONALS = "zero_out_small_off_diagonals"
+    ADD_RIDGE_FIXED_CONDITION_NUMBER = "add_ridge_fixed_condition_number"
+    ADD_RIDGE_MEDIAN_SINGULAR_VALUE_FRACTION = (
+        "add_ridge_median_singular_value_fraction"
+    )
+    INVERSE_BREAD_STRUCTURE_AWARE_INVERSION = "inverse_bread_structure_aware_inversion"
+    ALL_METHODS_COMPETITION = "all_methods_competition"
+
+
+class FunctionTypes:
+    LOSS = "loss"
+    ESTIMATING = "estimating"
+
+
+class SandwichFormationMethods:
+    BREAD_INVERSE_T_QR = "bread_inverse_T_qr"
+    MEAT_SVD_SOLVE = "meat_svd_solve"
+    NAIVE = "naive"
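These classes are plain string-constant namespaces. A hypothetical usage sketch (the validate_correction helper is illustrative only, not part of the package):

from lifejacket.constants import SmallSampleCorrections

def validate_correction(requested: str) -> str:
    # Treat the class attributes as the set of valid options.
    valid = {
        SmallSampleCorrections.NONE,
        SmallSampleCorrections.HC1theta,
        SmallSampleCorrections.HC2theta,
        SmallSampleCorrections.HC3theta,
    }
    if requested not in valid:
        raise ValueError(f"Unknown small-sample correction: {requested!r}")
    return requested

validate_correction("HC3theta")  # returns "HC3theta"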
lifejacket/form_adaptive_meat_adjustments_directly.py
ADDED
@@ -0,0 +1,333 @@
+import collections
+import logging
+
+import pandas as pd
+import jax.numpy as jnp
+import numpy as np
+
+from .calculate_derivatives import (
+    calculate_inference_loss_derivatives,
+    calculate_pi_and_weight_gradients,
+)
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(
+    format="%(asctime)s,%(msecs)03d %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s",
+    datefmt="%Y-%m-%d:%H:%M:%S",
+    level=logging.INFO,
+)
+
+
+def form_adaptive_meat_adjustments_directly(
+    theta_dim: int,
+    beta_dim: int,
+    joint_adaptive_bread_inverse_matrix: jnp.ndarray,
+    per_user_estimating_function_stacks: jnp.ndarray,
+    study_df: pd.DataFrame,
+    in_study_col_name: str,
+    action_col_name: str,
+    calendar_t_col_name: str,
+    user_id_col_name: str,
+    action_prob_func: callable,
+    action_prob_func_args: dict,
+    action_prob_func_args_beta_index: int,
+    theta_est: jnp.ndarray,
+    inference_func: callable,
+    inference_func_args_theta_index: int,
+    user_ids: list[collections.abc.Hashable],
+    action_prob_col_name: str,
+) -> jnp.ndarray:
+    logger.info(
+        "Explicitly forming the per-user meat adjustments that differentiate the adaptive sandwich from the classical sandwich."
+    )
+
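The slicing that follows assumes a joint matrix whose leading rows and columns index the stacked RL (beta) updates and whose trailing theta_dim rows and columns index the inference parameter. A toy sketch of that layout, with hypothetical sizes:

import numpy as np

theta_dim, beta_dim, num_updates = 2, 3, 4  # hypothetical sizes
dim = num_updates * beta_dim + theta_dim
joint = np.arange(dim * dim, dtype=float).reshape(dim, dim)  # stand-in matrix

V_blocks_together = joint[-theta_dim:, :-theta_dim]  # theta rows vs. beta columns
RL_block = joint[:-theta_dim, :-theta_dim]           # beta rows vs. beta columns

assert V_blocks_together.shape == (theta_dim, num_updates * beta_dim)
assert RL_block.shape == (num_updates * beta_dim, num_updates * beta_dim)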
+    # 1. Form the M-matrices, which are shared across users.
+    # This is not quite the paper definition of the M-matrices, which
+    # includes multiplication by the classical bread. We don't care about
+    # that here, since in forming the adjustments there is a multiplication
+    # by the classical bread inverse that cancels it out.
+    V_blocks_together = joint_adaptive_bread_inverse_matrix[-theta_dim:, :-theta_dim]
+    RL_stack_beta_derivatives_block = joint_adaptive_bread_inverse_matrix[
+        :-theta_dim, :-theta_dim
+    ]
+    effective_M_blocks_together = np.linalg.solve(
+        RL_stack_beta_derivatives_block.T, V_blocks_together.T
+    ).T
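The transposed solve is the standard trick for computing V @ inv(A) without explicitly forming inv(A). A minimal self-contained check of the identity, with hypothetical sizes:

import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(6, 6))   # stands in for RL_stack_beta_derivatives_block
V = rng.normal(size=(2, 6))   # stands in for V_blocks_together

# solve(A.T, V.T) = (A.T)^{-1} V.T = (V A^{-1}).T, so transposing recovers V A^{-1}.
M_via_solve = np.linalg.solve(A.T, V.T).T
M_via_inverse = V @ np.linalg.inv(A)
assert np.allclose(M_via_solve, M_via_inverse)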
+
+    # 2. Extract the RL-only parts of the per-user estimating function stacks
+    per_user_RL_only_est_fn_stacks_together = per_user_estimating_function_stacks[
+        :, :-theta_dim
+    ]
+
+    # 3. Split the effective M blocks into (theta_dim, beta_dim) blocks and the
+    #    estimating function stacks into (num_updates, beta_dim) stacks.
+
+    # effective_M_blocks_together is shape (theta_dim, num_updates * beta_dim).
+    # We want to split it into a list of (theta_dim, beta_dim) arrays, one per update.
+    M_blocks = np.split(
+        effective_M_blocks_together,
+        effective_M_blocks_together.shape[1] // beta_dim,
+        axis=1,
+    )
+    # Now stack into a 3D array of shape (num_updates, theta_dim, beta_dim)
+    M_blocks_stacked = np.stack(M_blocks, axis=0)
+
+    # per_user_RL_only_est_fn_stacks_together is shape (num_users, num_updates * beta_dim).
+    # We want to split it into a list of (num_users, beta_dim) arrays, one per update.
+    per_user_RL_only_est_fns = np.split(
+        per_user_RL_only_est_fn_stacks_together,
+        per_user_RL_only_est_fn_stacks_together.shape[1] // beta_dim,
+        axis=1,
+    )
+    # Stack into a 3D array of shape (num_users, num_updates, beta_dim).
+    # Note the difference between this and the original format of these estimating
+    # functions, which was not broken down by update.
+    per_user_RL_only_est_fns_stacked = np.stack(per_user_RL_only_est_fns, axis=1)
+
+    # Now multiply the M matrices and the per-user estimating functions
+    # and sum over the updates to get the per-user meat adjustments (more precisely, what
+    # would be added to each user's inference estimating function before an outer product
+    # is taken with itself to get each user's contribution to the theta-only meat matrix).
+    # Result is shape (num_users, theta_dim).
+    # Form the per-user adaptive meat adjustments explicitly for diagnostic purposes.
+    per_user_meat_adjustments_stacked = np.einsum(
+        "utb,nub->nt", M_blocks_stacked, per_user_RL_only_est_fns_stacked
+    )
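Spelled out with loops (toy dimensions, hypothetical values), the einsum contracts each update's M-block against that user's RL estimating function for the same update, then sums over updates:

import numpy as np

num_users, num_updates, theta_dim, beta_dim = 3, 4, 2, 5
rng = np.random.default_rng(0)
M = rng.normal(size=(num_updates, theta_dim, beta_dim))   # M_blocks_stacked
psi = rng.normal(size=(num_users, num_updates, beta_dim))  # per-user RL est. fns

via_einsum = np.einsum("utb,nub->nt", M, psi)
# Equivalent loop: adjustment_n = sum_u M_u @ psi_{n,u}
via_loop = np.array(
    [sum(M[u] @ psi[n, u] for u in range(num_updates)) for n in range(num_users)]
)
assert np.allclose(via_einsum, via_loop)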
+
+    # Log some diagnostics about the pieces going into the adaptive meat adjustments
+    # and the adjustments themselves.
+    V_blocks = np.split(
+        V_blocks_together, V_blocks_together.shape[1] // beta_dim, axis=1
+    )
+    logger.info("Examining adaptive meat adjustments.")
+    # No scientific notation
+    np.set_printoptions(suppress=True)
+
+    per_user_inference_estimating_functions_stacked = (
+        per_user_estimating_function_stacks[:, -theta_dim:]
+    )
+    # This actually logs way too much, so these are all at debug level to avoid
+    # exhausting the VS Code terminal buffer.
+    logger.debug(
+        "Per-user inference estimating functions. Without adjustment, the average of the outer products of these is the classical meat: %s",
+        per_user_inference_estimating_functions_stacked,
+    )
+    logger.debug(
+        "Norms of per-user inference estimating functions: %s",
+        np.linalg.norm(per_user_inference_estimating_functions_stacked, axis=1),
+    )
+
+    logger.debug(
+        "Per-user adaptive meat adjustments, to be added to inference estimating functions before forming the meat. Formed from the sum of the products of the M-blocks and the corresponding RL update estimating functions for each user: %s",
+        per_user_meat_adjustments_stacked,
+    )
+    logger.debug(
+        "Norms of per-user adaptive meat adjustments: %s",
+        np.linalg.norm(per_user_meat_adjustments_stacked, axis=1),
+    )
+
+    per_user_fractional_adjustments = (
+        per_user_meat_adjustments_stacked
+        / per_user_inference_estimating_functions_stacked
+    )
+    logger.debug(
+        "Per-user fractional adjustments (elementwise ratio of adjustment to original inference estimating function): %s",
+        per_user_fractional_adjustments,
+    )
+    logger.debug(
+        "Norms of per-user fractional adjustments: %s",
+        np.linalg.norm(per_user_fractional_adjustments, axis=1),
+    )
+
+    V_blocks_stacked = np.stack(V_blocks, axis=0)
+    logger.debug(
+        "V_blocks, one per update, each shape theta_dim x beta_dim. These measure the sensitivity of the estimating function for theta to the limiting policy parameters per update: %s",
+        V_blocks_stacked,
+    )
+    logger.debug("Norms of V-blocks: %s", np.linalg.norm(V_blocks_stacked, axis=(1, 2)))
+
+    logger.debug(
+        "M_blocks, one per update, each shape theta_dim x beta_dim. The sum of the products "
+        "of each of these times a user's corresponding RL estimating function forms their adaptive "
+        "adjustment. The M's are the blocks of the product of the V's concatenated and the inverse of "
+        "the RL-only upper-left block of the joint adaptive bread inverse. In other words, the lower "
+        "left block of the joint adaptive bread. Also note that the inference estimating function "
+        "derivative inverse is omitted here despite the definition of the M's in the paper, because "
+        "that factor simply cancels later: %s",
+        M_blocks_stacked,
+    )
+    logger.debug("Norms of M-blocks: %s", np.linalg.norm(M_blocks_stacked, axis=(1, 2)))
+
+    logger.debug(
+        "RL block of joint adaptive bread inverse. The *inverse* of this goes into the M's: %s",
+        RL_stack_beta_derivatives_block,
+    )
+    logger.debug(
+        "Norm of RL block of joint adaptive bread inverse: %s",
+        np.linalg.norm(RL_stack_beta_derivatives_block),
+    )
+
+    inverse_RL_stack_beta_derivatives_block = np.linalg.inv(
+        RL_stack_beta_derivatives_block
+    )
+    logger.debug(
+        "Inverse of RL block of joint adaptive bread inverse. This goes into the M's: %s",
+        inverse_RL_stack_beta_derivatives_block,
+    )
+    logger.debug(
+        "Norm of inverse of RL block of joint adaptive bread inverse: %s",
+        np.linalg.norm(inverse_RL_stack_beta_derivatives_block),
+    )
+
+    logger.debug(
+        "Per-update RL-only estimating function elementwise maxes across users: %s",
+        np.max(per_user_RL_only_est_fns_stacked, axis=0),
+    )
+    logger.debug(
+        "Per-update RL-only estimating function elementwise mins across users: %s",
+        np.min(per_user_RL_only_est_fns_stacked, axis=0),
+    )
+    logger.debug(
+        "Per-user average RL-only estimating functions across updates: %s",
+        np.mean(per_user_RL_only_est_fns_stacked, axis=1),
+    )
+    logger.debug(
+        "Per-update std of RL-only estimating functions across users: %s",
+        np.std(per_user_RL_only_est_fns_stacked, axis=0),
+    )
+    logger.debug(
+        "Norms of per-user RL-only estimating functions (num users x num updates): %s",
+        np.linalg.norm(per_user_RL_only_est_fns_stacked, axis=2),
+    )
+
+    # Now dig even deeper to get the weight derivatives and inference estimating
+    # function mixed derivatives that go into the V's.
+
+    pi_and_weight_gradients_by_calendar_t = calculate_pi_and_weight_gradients(
+        study_df,
+        in_study_col_name,
+        action_col_name,
+        calendar_t_col_name,
+        user_id_col_name,
+        action_prob_func,
+        action_prob_func_args,
+        action_prob_func_args_beta_index,
+    )
+
+    _, _, loss_gradient_pi_derivatives = calculate_inference_loss_derivatives(
+        study_df,
+        theta_est,
+        inference_func,
+        inference_func_args_theta_index,
+        user_ids,
+        user_id_col_name,
+        action_prob_col_name,
+        in_study_col_name,
+        calendar_t_col_name,
+    )
+    # Take the outer product of each row of (per_user_meat_adjustments_stacked + per_user_inference_estimating_functions_stacked)
+    per_user_adjusted_inference_estimating_functions_stacked = (
+        per_user_meat_adjustments_stacked
+        + per_user_inference_estimating_functions_stacked
+    )
+    per_user_theta_only_adaptive_meat_contributions = jnp.einsum(
+        "ni,nj->nij",
+        per_user_adjusted_inference_estimating_functions_stacked,
+        per_user_adjusted_inference_estimating_functions_stacked,
+    )
+    adaptive_theta_only_meat_matrix = jnp.mean(
+        per_user_theta_only_adaptive_meat_contributions, axis=0
+    )
+    logger.info(
+        "Theta-only adaptive meat matrix (no small sample corrections): %s",
+        adaptive_theta_only_meat_matrix,
+    )
+    classical_theta_only_meat_matrix = jnp.mean(
+        jnp.einsum(
+            "ni,nj->nij",
+            per_user_inference_estimating_functions_stacked,
+            per_user_inference_estimating_functions_stacked,
+        ),
+        axis=0,
+    )
+    logger.info(
+        "Classical meat matrix (no small sample corrections): %s",
+        classical_theta_only_meat_matrix,
+    )
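Both meats are means of per-user outer products. A quick toy check (hypothetical shapes) that einsum-then-mean agrees with the familiar X.T @ X / n form:

import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(7, 3))  # stands in for the (adjusted) inference estimating functions

# Mean of per-row outer products equals X.T @ X divided by the number of rows.
via_einsum = np.einsum("ni,nj->nij", X, X).mean(axis=0)
assert np.allclose(via_einsum, X.T @ X / X.shape[0])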
+
+    # np.linalg.cond(RL_stack_beta_derivatives_block)
+
+    # Log the condition number of each upper-left block of RL_stack_beta_derivatives_block,
+    # as if we stopped after the first update, then the second update, etc., up to the
+    # full beta_dim * num_updates.
+    num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
+    whole_block_condition_numbers = []
+    diagonal_block_condition_numbers = []
+    for i in range(1, num_updates + 1):
+        whole_block_size = i * beta_dim
+        whole_block = RL_stack_beta_derivatives_block[
+            :whole_block_size, :whole_block_size
+        ]
+        whole_block_cond_number = np.linalg.cond(whole_block)
+        whole_block_condition_numbers.append(whole_block_cond_number)
+        logger.info(
+            "Condition number of whole RL_stack_beta_derivatives_block (after update %s): %s",
+            i,
+            whole_block_cond_number,
+        )
+        diagonal_block = RL_stack_beta_derivatives_block[
+            (i - 1) * beta_dim : i * beta_dim, (i - 1) * beta_dim : i * beta_dim
+        ]
+        diagonal_block_cond_number = np.linalg.cond(diagonal_block)
+        diagonal_block_condition_numbers.append(diagonal_block_cond_number)
+        logger.info(
+            "Condition number of just RL_stack_beta_derivatives_block *diagonal block* for update %s: %s",
+            i,
+            diagonal_block_cond_number,
+        )
+
+        off_diagonal_scaled_block_norm_sum = 0
+        for j in range(1, i):
+            off_diagonal_block = RL_stack_beta_derivatives_block[
+                (i - 1) * beta_dim : i * beta_dim, (j - 1) * beta_dim : j * beta_dim
+            ]
+            off_diagonal_scaled_block_norm = np.linalg.norm(
+                np.linalg.solve(diagonal_block, off_diagonal_block)
+            )
+            off_diagonal_scaled_block_norm_sum += off_diagonal_scaled_block_norm
+            logger.debug(
+                "Norm of off-diagonal block (%s, %s) scaled by inverse of diagonal block: %s",
+                i,
+                j,
+                off_diagonal_scaled_block_norm,
+            )
+
+        logger.info(
+            "Sum of norms of off-diagonal blocks in row %s scaled by inverse of diagonal block: %s",
+            i,
+            off_diagonal_scaled_block_norm_sum,
+        )
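A toy run of the same block diagnostics (hypothetical sizes, a random well-conditioned stand-in matrix), showing what the loop reports: condition numbers of the growing upper-left blocks, of each diagonal block, and the per-row sums of off-diagonal block norms scaled by that row's diagonal block inverse:

import numpy as np

beta_dim, num_updates = 2, 3
size = beta_dim * num_updates
rng = np.random.default_rng(0)
A = np.eye(size) + 0.1 * rng.normal(size=(size, size))  # stand-in for the RL block

for i in range(1, num_updates + 1):
    lead = A[: i * beta_dim, : i * beta_dim]
    diag = A[(i - 1) * beta_dim : i * beta_dim, (i - 1) * beta_dim : i * beta_dim]
    row_sum = sum(
        np.linalg.norm(
            np.linalg.solve(
                diag, A[(i - 1) * beta_dim : i * beta_dim, (j - 1) * beta_dim : j * beta_dim]
            )
        )
        for j in range(1, i)
    )
    print(i, np.linalg.cond(lead), np.linalg.cond(diag), row_sum)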
+
+    # Keeping a breakpoint here is the best way to dig in without logging too
+    # much or being too opinionated about what to log.
+    breakpoint()
+
+    # # Visualize the inverse RL block of joint adaptive bread inverse using a seaborn heatmap
+    # pyplt.figure(figsize=(8, 6))
+    # sns.heatmap(inverse_RL_stack_beta_derivatives_block, annot=False, cmap="viridis")
+    # pyplt.title("Inverse RL Block of Joint Adaptive Bread Inverse")
+    # pyplt.xlabel("Beta Index")
+    # pyplt.ylabel("Beta Index")
+    # pyplt.tight_layout()
+    # pyplt.show()
+
+    # # Visualize the RL block of joint adaptive bread inverse using a seaborn heatmap
+    # pyplt.figure(figsize=(8, 6))
+    # sns.heatmap(RL_stack_beta_derivatives_block, annot=False, cmap="viridis")
+    # pyplt.title("RL Block of Joint Adaptive Bread Inverse")
+    # pyplt.xlabel("Beta Index")
+    # pyplt.ylabel("Beta Index")
+    # pyplt.tight_layout()
+    # pyplt.show()
+
+    return per_user_theta_only_adaptive_meat_contributions