lifejacket-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,28 @@
+ class SmallSampleCorrections:
+     NONE = "none"
+     HC1theta = "HC1theta"
+     HC2theta = "HC2theta"
+     HC3theta = "HC3theta"
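+     # A reading of these names (an inference, not documented here): HC1theta,
+     # HC2theta, and HC3theta appear to mirror the standard HC1-HC3
+     # heteroskedasticity-consistent small-sample corrections, applied to the
+     # theta-level sandwich.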
+
+
+ class InverseStabilizationMethods:
+     NONE = "none"
+     TRIM_SMALL_SINGULAR_VALUES = "trim_small_singular_values"
+     ZERO_OUT_SMALL_OFF_DIAGONALS = "zero_out_small_off_diagonals"
+     ADD_RIDGE_FIXED_CONDITION_NUMBER = "add_ridge_fixed_condition_number"
+     ADD_RIDGE_MEDIAN_SINGULAR_VALUE_FRACTION = (
+         "add_ridge_median_singular_value_fraction"
+     )
+     INVERSE_BREAD_STRUCTURE_AWARE_INVERSION = "inverse_bread_structure_aware_inversion"
+     ALL_METHODS_COMPETITION = "all_methods_competition"
+
+
+ class FunctionTypes:
+     LOSS = "loss"
+     ESTIMATING = "estimating"
+
+
+ class SandwichFormationMethods:
+     BREAD_INVERSE_T_QR = "bread_inverse_T_qr"
+     MEAT_SVD_SOLVE = "meat_svd_solve"
+     NAIVE = "naive"
@@ -0,0 +1,333 @@
+ import collections
+ import logging
+
+ import pandas as pd
+ import jax.numpy as jnp
+ import numpy as np
+
+ from .calculate_derivatives import (
+     calculate_inference_loss_derivatives,
+     calculate_pi_and_weight_gradients,
+ )
+
+ logger = logging.getLogger(__name__)
+ logging.basicConfig(
+     format="%(asctime)s,%(msecs)03d %(levelname)-2s [%(filename)s:%(lineno)d] %(message)s",
+     datefmt="%Y-%m-%d:%H:%M:%S",
+     level=logging.INFO,
+ )
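+ # Note on the basicConfig call above: it configures the root logger and is a
+ # no-op if the root logger already has handlers, so an application that
+ # configures logging before importing this module keeps its own format and
+ # level.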
+
+
+ def form_adaptive_meat_adjustments_directly(
+     theta_dim: int,
+     beta_dim: int,
+     joint_adaptive_bread_inverse_matrix: jnp.ndarray,
+     per_user_estimating_function_stacks: jnp.ndarray,
+     study_df: pd.DataFrame,
+     in_study_col_name: str,
+     action_col_name: str,
+     calendar_t_col_name: str,
+     user_id_col_name: str,
+     action_prob_func: callable,
+     action_prob_func_args: dict,
+     action_prob_func_args_beta_index: int,
+     theta_est: jnp.ndarray,
+     inference_func: callable,
+     inference_func_args_theta_index: int,
+     user_ids: list[collections.abc.Hashable],
+     action_prob_col_name: str,
+ ) -> jnp.ndarray:
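+     """Explicitly form the per-user adaptive meat adjustments and log diagnostics.
+
+     Returns the per-user theta-only adaptive meat contributions: an array of
+     shape (num_users, theta_dim, theta_dim) holding, for each user, the outer
+     product of their adjusted inference estimating function with itself.
+     """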
+     logger.info(
+         "Explicitly forming the per-user meat adjustments that differentiate the adaptive sandwich from the classical sandwich."
+     )
+
+     # 1. Form the M-matrices, which are shared across users.
+     # This is not quite the paper definition of the M-matrices, which
+     # includes multiplication by the classical bread. We don't care about
+     # that here, since in forming the adjustments there is a multiplication
+     # by the classical bread inverse that cancels it out.
+     V_blocks_together = joint_adaptive_bread_inverse_matrix[-theta_dim:, :-theta_dim]
+     RL_stack_beta_derivatives_block = joint_adaptive_bread_inverse_matrix[
+         :-theta_dim, :-theta_dim
+     ]
+     effective_M_blocks_together = np.linalg.solve(
+         RL_stack_beta_derivatives_block.T, V_blocks_together.T
+     ).T
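+     # A note on the solve above (equivalent algebra): we want
+     # effective_M = V @ inv(B), with B the RL block. Solving B.T @ X = V.T
+     # and transposing gives X.T = V @ inv(B), which avoids forming an
+     # explicit inverse and is generally better conditioned.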
+
+     # 2. Extract the RL-only parts of the per-user estimating function stacks
+     per_user_RL_only_est_fn_stacks_together = per_user_estimating_function_stacks[
+         :, :-theta_dim
+     ]
+
+     # 3. Split the effective M blocks into (theta_dim, beta_dim) blocks and the
+     # estimating function stacks into (num_updates, beta_dim) stacks.
+
+     # effective_M_blocks_together is shape (theta_dim, num_updates * beta_dim).
+     # We want to split it into a list of (theta_dim, beta_dim) arrays.
+     M_blocks = np.split(
+         effective_M_blocks_together,
+         effective_M_blocks_together.shape[1] // beta_dim,
+         axis=1,
+     )
+     # Now stack into a 3D array of shape (num_updates, theta_dim, beta_dim)
+     M_blocks_stacked = np.stack(M_blocks, axis=0)
+
+     # per_user_RL_only_est_fn_stacks_together is shape (num_users, num_updates * beta_dim).
+     # We want num_updates arrays of shape (num_users, beta_dim), one per update.
+     per_user_RL_only_est_fns = np.split(
+         per_user_RL_only_est_fn_stacks_together,
+         per_user_RL_only_est_fn_stacks_together.shape[1] // beta_dim,
+         axis=1,
+     )
+     # Stack into a 3D array of shape (num_users, num_updates, beta_dim).
+     # Note the difference between this and the original format of these
+     # estimating functions, which was not broken down by update.
+     per_user_RL_only_est_fns_stacked = np.stack(per_user_RL_only_est_fns, axis=1)
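+     # Shape sketch with illustrative numbers: if beta_dim = 2 and there are
+     # 3 updates, then [M_1 | M_2 | M_3] of shape (theta_dim, 6) becomes
+     # M_blocks_stacked of shape (3, theta_dim, 2), and a (num_users, 6)
+     # RL-only stack becomes (num_users, 3, 2).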
+
+     # Now multiply the M matrices and the per-user estimating functions and
+     # sum over the updates to get the per-user meat adjustments (more
+     # precisely, what is added to each user's inference estimating function
+     # before the outer product with itself is taken to get that user's
+     # contribution to the theta-only meat matrix).
+     # Result is shape (num_users, theta_dim). The adjustments are formed
+     # explicitly here for diagnostic purposes.
+     per_user_meat_adjustments_stacked = np.einsum(
+         "utb,nub->nt", M_blocks_stacked, per_user_RL_only_est_fns_stacked
+     )
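+     # Loop-form equivalent of the einsum, for reference:
+     #     adjustments[n] = sum over updates u of
+     #         M_blocks_stacked[u] @ per_user_RL_only_est_fns_stacked[n, u]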
+
+     # Log some diagnostics about the pieces going into the adaptive meat
+     # adjustments and the adjustments themselves.
+     V_blocks = np.split(
+         V_blocks_together, V_blocks_together.shape[1] // beta_dim, axis=1
+     )
+     logger.info("Examining adaptive meat adjustments.")
+     # No scientific notation in printed arrays
+     np.set_printoptions(suppress=True)
+
+     per_user_inference_estimating_functions_stacked = (
+         per_user_estimating_function_stacks[:, -theta_dim:]
+     )
+     # These log far too much output to be useful at info level, so they are
+     # all at debug level to avoid exhausting the VS Code terminal buffer.
+     logger.debug(
+         "Per-user inference estimating functions. Without adjustment, the average of the outer products of these is the classical meat: %s",
+         per_user_inference_estimating_functions_stacked,
+     )
+     logger.debug(
+         "Norms of per-user inference estimating functions: %s",
+         np.linalg.norm(per_user_inference_estimating_functions_stacked, axis=1),
+     )
+
+     logger.debug(
+         "Per-user adaptive meat adjustments, to be added to inference estimating functions before forming the meat. Formed from the sum of the products of the M-blocks and the corresponding RL update estimating functions for each user: %s",
+         per_user_meat_adjustments_stacked,
+     )
+     logger.debug(
+         "Norms of per-user adaptive meat adjustments: %s",
+         np.linalg.norm(per_user_meat_adjustments_stacked, axis=1),
+     )
+
+     per_user_fractional_adjustments = (
+         per_user_meat_adjustments_stacked
+         / per_user_inference_estimating_functions_stacked
+     )
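+     # This is an elementwise ratio, so entries blow up wherever a component
+     # of the original estimating function is near zero; treat it as a rough
+     # diagnostic only.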
+     logger.debug(
+         "Per-user fractional adjustments (elementwise ratio of adjustment to original inference estimating function): %s",
+         per_user_fractional_adjustments,
+     )
+     logger.debug(
+         "Norms of per-user fractional adjustments: %s",
+         np.linalg.norm(per_user_fractional_adjustments, axis=1),
+     )
+
+     V_blocks_stacked = np.stack(V_blocks, axis=0)
+     logger.debug(
+         "V_blocks, one per update, each shape theta_dim x beta_dim. These measure the sensitivity of the estimating function for theta to the limiting policy parameters per update: %s",
+         V_blocks_stacked,
+     )
+     logger.debug("Norms of V-blocks: %s", np.linalg.norm(V_blocks_stacked, axis=(1, 2)))
+
+     logger.debug(
+         "M_blocks, one per update, each shape theta_dim x beta_dim. The sum of the products "
+         "of each of these times a user's corresponding RL estimating function forms their adaptive "
+         "adjustment. The M's are the blocks of the product of the V's concatenated and the inverse of "
+         "the RL-only upper-left block of the joint adaptive bread inverse. In other words, the lower "
+         "left block of the joint adaptive bread. Also note that the inference estimating function "
+         "derivative inverse is omitted here despite the definition of the M's in the paper, because "
+         "that factor simply cancels later: %s",
+         M_blocks_stacked,
+     )
+     logger.debug("Norms of M-blocks: %s", np.linalg.norm(M_blocks_stacked, axis=(1, 2)))
+
+     logger.debug(
+         "RL block of joint adaptive bread inverse. The *inverse* of this goes into the M's: %s",
+         RL_stack_beta_derivatives_block,
+     )
+     logger.debug(
+         "Norm of RL block of joint adaptive bread inverse: %s",
+         np.linalg.norm(RL_stack_beta_derivatives_block),
+     )
+
+     inverse_RL_stack_beta_derivatives_block = np.linalg.inv(
+         RL_stack_beta_derivatives_block
+     )
+     logger.debug(
+         "Inverse of RL block of joint adaptive bread inverse. This goes into the M's: %s",
+         inverse_RL_stack_beta_derivatives_block,
+     )
+     logger.debug(
+         "Norm of inverse of RL block of joint adaptive bread inverse: %s",
+         np.linalg.norm(inverse_RL_stack_beta_derivatives_block),
+     )
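+     # The explicit inverse here is formed for logging and inspection only;
+     # the M-block computation above uses np.linalg.solve instead.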
+
+     logger.debug(
+         "Per-update RL-only estimating function elementwise maxes across users: %s",
+         np.max(per_user_RL_only_est_fns_stacked, axis=0),
+     )
+     logger.debug(
+         "Per-update RL-only estimating function elementwise mins across users: %s",
+         np.min(per_user_RL_only_est_fns_stacked, axis=0),
+     )
+     logger.debug(
+         "Per-user average RL-only estimating functions across updates: %s",
+         np.mean(per_user_RL_only_est_fns_stacked, axis=1),
+     )
+     logger.debug(
+         "Per-update std of RL-only estimating functions across users: %s",
+         np.std(per_user_RL_only_est_fns_stacked, axis=0),
+     )
+     logger.debug(
+         "Norms of per-user RL-only estimating functions (num users x num updates): %s",
+         np.linalg.norm(per_user_RL_only_est_fns_stacked, axis=2),
+     )
+
+     # Now dig even deeper to get the weight derivatives and inference
+     # estimating function mixed derivatives that go into the V's. These
+     # results are not used in the computation below; they are computed so
+     # they can be inspected at the breakpoint near the end of this function.
+
+     pi_and_weight_gradients_by_calendar_t = calculate_pi_and_weight_gradients(
+         study_df,
+         in_study_col_name,
+         action_col_name,
+         calendar_t_col_name,
+         user_id_col_name,
+         action_prob_func,
+         action_prob_func_args,
+         action_prob_func_args_beta_index,
+     )
+
+     _, _, loss_gradient_pi_derivatives = calculate_inference_loss_derivatives(
+         study_df,
+         theta_est,
+         inference_func,
+         inference_func_args_theta_index,
+         user_ids,
+         user_id_col_name,
+         action_prob_col_name,
+         in_study_col_name,
+         calendar_t_col_name,
+     )
+     # Take the outer product of each row of
+     # (per_user_meat_adjustments_stacked + per_user_inference_estimating_functions_stacked)
+     # with itself.
+     per_user_adjusted_inference_estimating_functions_stacked = (
+         per_user_meat_adjustments_stacked
+         + per_user_inference_estimating_functions_stacked
+     )
+     per_user_theta_only_adaptive_meat_contributions = jnp.einsum(
+         "ni,nj->nij",
+         per_user_adjusted_inference_estimating_functions_stacked,
+         per_user_adjusted_inference_estimating_functions_stacked,
+     )
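+     # In symbols: the adaptive meat below is (1/n) * sum_i psi_adj_i psi_adj_i^T,
+     # where psi_adj_i is user i's inference estimating function plus their
+     # adaptive adjustment.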
+     adaptive_theta_only_meat_matrix = jnp.mean(
+         per_user_theta_only_adaptive_meat_contributions, axis=0
+     )
+     logger.info(
+         "Theta-only adaptive meat matrix (no small sample corrections): %s",
+         adaptive_theta_only_meat_matrix,
+     )
+     classical_theta_only_meat_matrix = jnp.mean(
+         jnp.einsum(
+             "ni,nj->nij",
+             per_user_inference_estimating_functions_stacked,
+             per_user_inference_estimating_functions_stacked,
+         ),
+         axis=0,
+     )
+     logger.info(
+         "Classical meat matrix (no small sample corrections): %s",
+         classical_theta_only_meat_matrix,
+     )
+
+     # Log the condition number of each leading principal block of
+     # RL_stack_beta_derivatives_block, as if we had stopped after the first
+     # update, then the second, and so on, up to the full
+     # (beta_dim * num_updates) block.
+     num_updates = RL_stack_beta_derivatives_block.shape[0] // beta_dim
+     whole_block_condition_numbers = []
+     diagonal_block_condition_numbers = []
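+     # np.linalg.cond with its default norm returns the 2-norm condition
+     # number, i.e. the ratio of the largest to the smallest singular value.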
+     for i in range(1, num_updates + 1):
+         whole_block_size = i * beta_dim
+         whole_block = RL_stack_beta_derivatives_block[
+             :whole_block_size, :whole_block_size
+         ]
+         whole_block_cond_number = np.linalg.cond(whole_block)
+         whole_block_condition_numbers.append(whole_block_cond_number)
+         logger.info(
+             "Condition number of whole RL_stack_beta_derivatives_block (after update %s): %s",
+             i,
+             whole_block_cond_number,
+         )
+         diagonal_block = RL_stack_beta_derivatives_block[
+             (i - 1) * beta_dim : i * beta_dim, (i - 1) * beta_dim : i * beta_dim
+         ]
+         diagonal_block_cond_number = np.linalg.cond(diagonal_block)
+         diagonal_block_condition_numbers.append(diagonal_block_cond_number)
+         logger.info(
+             "Condition number of just RL_stack_beta_derivatives_block *diagonal block* for update %s: %s",
+             i,
+             diagonal_block_cond_number,
+         )
+
+         off_diagonal_scaled_block_norm_sum = 0
+         for j in range(1, i):
+             off_diagonal_block = RL_stack_beta_derivatives_block[
+                 (i - 1) * beta_dim : i * beta_dim, (j - 1) * beta_dim : j * beta_dim
+             ]
+             off_diagonal_scaled_block_norm = np.linalg.norm(
+                 np.linalg.solve(diagonal_block, off_diagonal_block)
+             )
+             off_diagonal_scaled_block_norm_sum += off_diagonal_scaled_block_norm
+             logger.debug(
+                 "Norm of off-diagonal block (%s, %s) scaled by inverse of diagonal block: %s",
+                 i,
+                 j,
+                 off_diagonal_scaled_block_norm,
+             )
+
+         logger.info(
+             "Sum of norms of off-diagonal blocks in row %s scaled by inverse of diagonal block: %s",
+             i,
+             off_diagonal_scaled_block_norm_sum,
+         )
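+         # Rough heuristic: if these scaled off-diagonal norms sum to well
+         # below 1 in every block row, the matrix behaves like a block
+         # diagonally dominant one, which is reassuring for inverting it.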
+
+     # Keeping a breakpoint here is the best way to dig in without logging too
+     # much or being too opinionated about what to log.
+     breakpoint()
+
+     # # Visualize the inverse RL block of the joint adaptive bread inverse as a
+     # # seaborn heatmap (assumes matplotlib.pyplot imported as pyplt and seaborn as sns).
+     # pyplt.figure(figsize=(8, 6))
+     # sns.heatmap(inverse_RL_stack_beta_derivatives_block, annot=False, cmap="viridis")
+     # pyplt.title("Inverse RL Block of Joint Adaptive Bread Inverse")
+     # pyplt.xlabel("Beta Index")
+     # pyplt.ylabel("Beta Index")
+     # pyplt.tight_layout()
+     # pyplt.show()
+
+     # # Visualize the RL block of the joint adaptive bread inverse as a seaborn heatmap
+     # pyplt.figure(figsize=(8, 6))
+     # sns.heatmap(RL_stack_beta_derivatives_block, annot=False, cmap="viridis")
+     # pyplt.title("RL Block of Joint Adaptive Bread Inverse")
+     # pyplt.xlabel("Beta Index")
+     # pyplt.ylabel("Beta Index")
+     # pyplt.tight_layout()
+     # pyplt.show()
+
+     return per_user_theta_only_adaptive_meat_contributions