lifejacket 1.0.0.tar.gz → 1.0.2.tar.gz
This diff represents the content of publicly available package versions as released to one of the supported registries, and is provided for informational purposes only.
- {lifejacket-1.0.0 → lifejacket-1.0.2}/PKG-INFO +1 -1
- {lifejacket-1.0.0 → lifejacket-1.0.2}/lifejacket/calculate_derivatives.py +0 -2
- lifejacket-1.0.2/lifejacket/constants.py +16 -0
- {lifejacket-1.0.0 → lifejacket-1.0.2}/lifejacket/deployment_conditioning_monitor.py +19 -12
- {lifejacket-1.0.0 → lifejacket-1.0.2}/lifejacket/form_adjusted_meat_adjustments_directly.py +25 -27
- {lifejacket-1.0.0 → lifejacket-1.0.2}/lifejacket/get_datum_for_blowup_supervised_learning.py +71 -77
- {lifejacket-1.0.0 → lifejacket-1.0.2}/lifejacket/helper_functions.py +15 -148
- {lifejacket-1.0.0 → lifejacket-1.0.2}/lifejacket/input_checks.py +49 -50
- lifejacket-1.0.0/lifejacket/after_study_analysis.py → lifejacket-1.0.2/lifejacket/post_deployment_analysis.py +127 -124
- {lifejacket-1.0.0 → lifejacket-1.0.2}/lifejacket/small_sample_corrections.py +11 -13
- {lifejacket-1.0.0 → lifejacket-1.0.2}/lifejacket.egg-info/PKG-INFO +1 -1
- {lifejacket-1.0.0 → lifejacket-1.0.2}/lifejacket.egg-info/SOURCES.txt +1 -1
- lifejacket-1.0.2/lifejacket.egg-info/entry_points.txt +2 -0
- {lifejacket-1.0.0 → lifejacket-1.0.2}/pyproject.toml +2 -2
- lifejacket-1.0.0/lifejacket/constants.py +0 -28
- lifejacket-1.0.0/lifejacket.egg-info/entry_points.txt +0 -2
- {lifejacket-1.0.0 → lifejacket-1.0.2}/README.md +0 -0
- {lifejacket-1.0.0 → lifejacket-1.0.2}/lifejacket/__init__.py +0 -0
- {lifejacket-1.0.0 → lifejacket-1.0.2}/lifejacket/arg_threading_helpers.py +0 -0
- {lifejacket-1.0.0 → lifejacket-1.0.2}/lifejacket/vmap_helpers.py +0 -0
- {lifejacket-1.0.0 → lifejacket-1.0.2}/lifejacket.egg-info/dependency_links.txt +0 -0
- {lifejacket-1.0.0 → lifejacket-1.0.2}/lifejacket.egg-info/requires.txt +0 -0
- {lifejacket-1.0.0 → lifejacket-1.0.2}/lifejacket.egg-info/top_level.txt +0 -0
- {lifejacket-1.0.0 → lifejacket-1.0.2}/setup.cfg +0 -0
--- /dev/null
+++ lifejacket-1.0.2/lifejacket/constants.py
@@ -0,0 +1,16 @@
+class SmallSampleCorrections:
+    NONE = "none"
+    Z1theta = "Z1theta"
+    Z2theta = "Z2theta"
+    Z3theta = "Z3theta"
+
+
+class FunctionTypes:
+    LOSS = "loss"
+    ESTIMATING = "estimating"
+
+
+class SandwichFormationMethods:
+    BREAD_T_QR = "bread_T_qr"
+    MEAT_SVD_SOLVE = "meat_svd_solve"
+    NAIVE = "naive"
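The new constants.py replaces the larger 1.0.0 constants module with three plain namespace classes of string options. A minimal usage sketch, assuming only the class and attribute names shown in the hunk above; the `check_small_sample_correction` helper is hypothetical and not part of the package:

```python
# Hypothetical helper (not in lifejacket): validate a caller-supplied option
# against the string constants defined in lifejacket/constants.py.
from lifejacket.constants import SmallSampleCorrections


def check_small_sample_correction(name: str) -> str:
    allowed = {
        SmallSampleCorrections.NONE,
        SmallSampleCorrections.Z1theta,
        SmallSampleCorrections.Z2theta,
        SmallSampleCorrections.Z3theta,
    }
    if name not in allowed:
        raise ValueError(f"Unknown small sample correction: {name!r}")
    return name
```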
--- lifejacket-1.0.0/lifejacket/deployment_conditioning_monitor.py
+++ lifejacket-1.0.2/lifejacket/deployment_conditioning_monitor.py
@@ -35,6 +35,12 @@ logging.basicConfig(
 
 
 class DeploymentConditioningMonitor:
+    """
+    Experimental feature. Monitors the conditioning of the RL portion of the bread matrix.
+    Repeats more logic from post_deployment_analysis.py than is ideal, but this is for experimental use only.
+    Unit tests should be unskipped and expanded if this is to be used more broadly.
+    """
+
     whole_RL_block_conditioning_threshold = None
     diagonal_RL_block_conditioning_threshold = None
 
@@ -76,7 +82,7 @@ class DeploymentConditioningMonitor:
         incremental: bool = True,
     ) -> None:
         """
-        Analyzes a dataset to estimate parameters and variance using
+        Analyzes a dataset to estimate parameters and variance using adjusted and classical sandwich estimators.
 
         Parameters:
             proposed_policy_num (int | float):
@@ -124,13 +130,13 @@ class DeploymentConditioningMonitor:
             small_sample_correction (str):
                 Type of small sample correction to apply.
             collect_data_for_blowup_supervised_learning (bool):
-                Whether to collect data for doing supervised learning about
+                Whether to collect data for doing supervised learning about adjusted sandwich blowup.
             form_adjusted_meat_adjustments_explicitly (bool):
-                If True, explicitly forms the per-subject meat adjustments that differentiate the
+                If True, explicitly forms the per-subject meat adjustments that differentiate the adjusted
                 sandwich from the classical sandwich. This is for diagnostic purposes, as the
-
-
-                If True, stabilizes the joint
+                adjusted sandwich is formed without doing this.
+            stabilize_joint_bread (bool):
+                If True, stabilizes the joint bread matrix if it does not meet conditioning
                 thresholds.
 
         Returns:
@@ -157,6 +163,7 @@ class DeploymentConditioningMonitor:
             alg_update_func_args_beta_index,
             alg_update_func_args_action_prob_index,
             alg_update_func_args_action_prob_times_index,
+            alg_update_func_args_previous_betas_index,
             suppress_interactive_data_checks,
         )
 
@@ -230,7 +237,7 @@ class DeploymentConditioningMonitor:
 
         if whole_RL_block_condition_number > self.whole_RL_block_conditioning_threshold:
             logger.warning(
-                "The RL portion of the bread
+                "The RL portion of the bread up to this point exceeds the threshold set (condition number: %s, threshold: %s). Consider an alternative update strategy which produces less dependence on previous RL parameters (via the data they produced) and/or improves the conditioning of each update itself. Regularization may help with both of these.",
                 whole_RL_block_condition_number,
                 self.whole_RL_block_conditioning_threshold,
            )
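For reference, the kind of check that triggers this warning can be written directly with NumPy. The sketch below is illustrative only and is not the package's incremental implementation:

```python
# Illustrative only: a 2-norm condition-number check of the sort logged above.
import numpy as np


def exceeds_conditioning_threshold(block: np.ndarray, threshold: float) -> bool:
    # np.linalg.cond returns the ratio of the largest to smallest singular value.
    return np.linalg.cond(block) > threshold
```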
@@ -241,7 +248,7 @@ class DeploymentConditioningMonitor:
             > self.diagonal_RL_block_conditioning_threshold
         ):
             logger.warning(
-                "The diagonal RL block of the bread
+                "The diagonal RL block of the bread up to this point exceeds the threshold set (condition number: %s, threshold: %s). This may illustrate a fundamental problem with the conditioning of the RL update procedure.",
                 new_diagonal_RL_block_condition_number,
                 self.diagonal_RL_block_conditioning_threshold,
             )
@@ -295,11 +302,11 @@ class DeploymentConditioningMonitor:
         jnp.ndarray[jnp.float32],
     ]:
         """
-        Constructs the classical and
+        Constructs the classical and adjusted bread and meat matrices, as well as the average
         estimating function stack and some other intermediate pieces.
 
         This is done by computing and differentiating the average weighted estimating function stack
-        with respect to the betas and theta, using the resulting Jacobian to compute the
+        with respect to the betas and theta, using the resulting Jacobian to compute the bread
         and meat matrices, and then stably computing sandwiches.
 
         Args:
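As background for this docstring: a classical sandwich variance is formed from the bread (the Jacobian of the average estimating function) and the meat (the average outer product of per-subject estimating functions). A minimal sketch using linear solves rather than explicit inverses; this is not the package's stabilized `SandwichFormationMethods` code, and the function name is invented:

```python
# Minimal sketch of a classical sandwich variance, V ≈ B^{-1} M B^{-T} / n.
# Not package code; shown only to fix ideas.
import numpy as np


def classical_sandwich(bread: np.ndarray, meat: np.ndarray, num_subjects: int) -> np.ndarray:
    left = np.linalg.solve(bread, meat)                      # B^{-1} M
    return np.linalg.solve(bread, left.T).T / num_subjects   # (B^{-1} M) B^{-T} / n
```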
@@ -471,7 +478,7 @@ class DeploymentConditioningMonitor:
     ]:
         """
         Computes the average weighted estimating function stack across all subjects, along with
-        auxiliary values used to construct the
+        auxiliary values used to construct the adjusted and classical sandwich variances.
 
         If only_latest_block is True, only uses data from the most recent update.
 
@@ -614,7 +621,7 @@ class DeploymentConditioningMonitor:
         )
 
         # 5. Now we can compute the weighted estimating function stacks for all subjects
-        # as well as collect related values used to construct the
+        # as well as collect related values used to construct the adjusted and classical
        # sandwich variances.
         RL_stacks = jnp.array(
             [
--- lifejacket-1.0.0/lifejacket/form_adjusted_meat_adjustments_directly.py
+++ lifejacket-1.0.2/lifejacket/form_adjusted_meat_adjustments_directly.py
@@ -21,7 +21,7 @@ logging.basicConfig(
 def form_adjusted_meat_adjustments_directly(
     theta_dim: int,
     beta_dim: int,
-
+    joint_bread_matrix: jnp.ndarray,
     per_user_estimating_function_stacks: jnp.ndarray,
     study_df: pd.DataFrame,
     active_col_name: str,
@@ -38,18 +38,16 @@ def form_adjusted_meat_adjustments_directly(
     action_prob_col_name: str,
 ) -> jnp.ndarray:
     logger.info(
-        "Explicitly forming the per-user meat adjustments that differentiate the
+        "Explicitly forming the per-user meat adjustments that differentiate the adjusted sandwich from the classical sandwich."
     )
 
     # 1. Form the M-matrices, which are shared across users.
     # This is not quite the paper definition of the M-matrices, which
     # includes multiplication by the classical bread. We don't care about
     # that here, since in forming the adjustments there is a multiplication
-    # by the classical bread
-    V_blocks_together =
-    RL_stack_beta_derivatives_block =
-        :-theta_dim, :-theta_dim
-    ]
+    # by the classical bread that cancels it out.
+    V_blocks_together = joint_bread_matrix[-theta_dim:, :-theta_dim]
+    RL_stack_beta_derivatives_block = joint_bread_matrix[:-theta_dim, :-theta_dim]
     effective_M_blocks_together = np.linalg.solve(
         RL_stack_beta_derivatives_block.T, V_blocks_together.T
     ).T
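The two new slicing lines and the solve above compute the lower-left block of the joint bread times the inverse of its upper-left RL block without forming that inverse explicitly. A small identity check with random data; all dimensions below are made up for illustration:

```python
# Shape/identity check with made-up dimensions: solve(A.T, V.T).T == V @ inv(A).
import numpy as np

rng = np.random.default_rng(0)
theta_dim, beta_dim, num_updates = 3, 2, 4
joint_dim = num_updates * beta_dim + theta_dim

joint_bread_matrix = rng.normal(size=(joint_dim, joint_dim))
V_blocks_together = joint_bread_matrix[-theta_dim:, :-theta_dim]  # (theta_dim, num_updates * beta_dim)
RL_block = joint_bread_matrix[:-theta_dim, :-theta_dim]           # square RL block

effective_M = np.linalg.solve(RL_block.T, V_blocks_together.T).T
assert np.allclose(effective_M, V_blocks_together @ np.linalg.inv(RL_block))
```

Using a transposed solve instead of an explicit inverse is the standard way to keep this product numerically stable.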
@@ -89,17 +87,17 @@ def form_adjusted_meat_adjustments_directly(
     # be added to each users inference estimating function before an outer product is taken with
     # itself to get each users's contributioan theta-only meat matrix).
     # Result is shape (num_users, theta_dim).
-    # Form the per-user
+    # Form the per-user adjusted meat adjustments explicitly for diagnostic purposes.
     per_user_meat_adjustments_stacked = np.einsum(
         "utb,nub->nt", M_blocks_stacked, per_user_RL_only_est_fns_stacked
     )
 
-    # Log some diagnostics about the pieces going into the
+    # Log some diagnostics about the pieces going into the adjusted meat adjustments
     # and the adjustments themselves.
     V_blocks = np.split(
         V_blocks_together, V_blocks_together.shape[1] // beta_dim, axis=1
     )
-    logger.info("Examining
+    logger.info("Examining adjusted meat adjustments.")
     # No scientific notation
     np.set_printoptions(suppress=True)
 
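The einsum above contracts over updates and beta components: for each user it sums M_block[u] @ RL_estimating_function[n, u]. A standalone shape sketch with random inputs; the dimensions are invented:

```python
# Invented dimensions; verifies what "utb,nub->nt" computes.
import numpy as np

rng = np.random.default_rng(1)
num_users, num_updates, theta_dim, beta_dim = 5, 4, 3, 2
M_blocks_stacked = rng.normal(size=(num_updates, theta_dim, beta_dim))
per_user_RL_only_est_fns_stacked = rng.normal(size=(num_users, num_updates, beta_dim))

adjustments = np.einsum("utb,nub->nt", M_blocks_stacked, per_user_RL_only_est_fns_stacked)

# Equivalent explicit loop, for readability.
expected = np.zeros((num_users, theta_dim))
for n in range(num_users):
    for u in range(num_updates):
        expected[n] += M_blocks_stacked[u] @ per_user_RL_only_est_fns_stacked[n, u]

assert adjustments.shape == (num_users, theta_dim)
assert np.allclose(adjustments, expected)
```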
@@ -118,11 +116,11 @@ def form_adjusted_meat_adjustments_directly(
     )
 
     logger.debug(
-        "Per-user
+        "Per-user adjusted meat adjustments, to be added to inference estimating functions before forming the meat. Formed from the sum of the products of the M-blocks and the corresponding RL update estimating functions for each user: %s",
         per_user_meat_adjustments_stacked,
     )
     logger.debug(
-        "Norms of per-user
+        "Norms of per-user adjusted meat adjustments: %s",
         np.linalg.norm(per_user_meat_adjustments_stacked, axis=1),
     )
 
@@ -148,10 +146,10 @@ def form_adjusted_meat_adjustments_directly(
 
     logger.debug(
         "M_blocks, one per update, each shape theta_dim x beta_dim. The sum of the products "
-        "of each of these times a user's corresponding RL estimating function forms their
+        "of each of these times a user's corresponding RL estimating function forms their adjusted "
         "adjustment. The M's are the blocks of the the product of the V's concatened and the inverse of "
-        "the RL-only upper-left block of the joint
-        "left block of the joint
+        "the RL-only upper-left block of the joint bread. In other words, the lower "
+        "left block of the joint bread. Also note that the inference estimating function "
         "derivative inverse is omitted here despite the definition of the M's in the paper, because "
         "that factor simply cancels later: %s",
         M_blocks_stacked,
@@ -159,11 +157,11 @@ def form_adjusted_meat_adjustments_directly(
     logger.debug("Norms of M-blocks: %s", np.linalg.norm(M_blocks_stacked, axis=(1, 2)))
 
     logger.debug(
-        "RL block of joint
+        "RL block of joint bread. The *inverse* of this goes into the M's: %s",
         RL_stack_beta_derivatives_block,
     )
     logger.debug(
-        "Norm of RL block of joint
+        "Norm of RL block of joint bread: %s",
         np.linalg.norm(RL_stack_beta_derivatives_block),
     )
 
@@ -171,11 +169,11 @@ def form_adjusted_meat_adjustments_directly(
         RL_stack_beta_derivatives_block
     )
     logger.debug(
-        "Inverse of RL block of joint
+        "Inverse of RL block of joint bread. This goes into the M's: %s",
         inverse_RL_stack_beta_derivatives_block,
     )
     logger.debug(
-        "Norm of Inverse of RL block of joint
+        "Norm of Inverse of RL block of joint bread: %s",
         np.linalg.norm(inverse_RL_stack_beta_derivatives_block),
     )
 
@@ -230,17 +228,17 @@ def form_adjusted_meat_adjustments_directly(
         per_user_meat_adjustments_stacked
         + per_user_inference_estimating_functions_stacked
     )
-
+    per_user_theta_only_adjusted_meat_contributions = jnp.einsum(
         "ni,nj->nij",
         per_user_adjusted_inference_estimating_functions_stacked,
         per_user_adjusted_inference_estimating_functions_stacked,
     )
-
-
+    adjusted_theta_only_meat_matrix = jnp.mean(
+        per_user_theta_only_adjusted_meat_contributions, axis=0
     )
     logger.info(
-        "Theta-only
-
+        "Theta-only adjusted meat matrix (no small sample corrections): %s",
+        adjusted_theta_only_meat_matrix,
     )
     classical_theta_only_meat_matrix = jnp.mean(
         jnp.einsum(
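The restored lines above build the theta-only adjusted meat as the average over users of the outer product of each user's adjusted inference estimating function with itself. An equivalent NumPy sketch with random inputs; the dimensions are invented:

```python
# Invented dimensions; the meat is the mean of per-user outer products.
import numpy as np

rng = np.random.default_rng(2)
num_users, theta_dim = 6, 3
est_fns = rng.normal(size=(num_users, theta_dim))

per_user_contributions = np.einsum("ni,nj->nij", est_fns, est_fns)  # (num_users, theta_dim, theta_dim)
meat = per_user_contributions.mean(axis=0)

assert np.allclose(meat, sum(np.outer(f, f) for f in est_fns) / num_users)
```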
@@ -311,7 +309,7 @@ def form_adjusted_meat_adjustments_directly(
     # much or being too opinionated about what to log.
     breakpoint()
 
-    # # Visualize the inverse RL block of joint
+    # # Visualize the inverse RL block of joint bread using seaborn heatmap
     # pyplt.figure(figsize=(8, 6))
     # sns.heatmap(inverse_RL_stack_beta_derivatives_block, annot=False, cmap="viridis")
     # pyplt.title("Inverse RL Block of Joint Adaptive Bread Inverse")
@@ -320,7 +318,7 @@ def form_adjusted_meat_adjustments_directly(
     # pyplt.tight_layout()
     # pyplt.show()
 
-    # # # Visualize the RL block of joint
+    # # # Visualize the RL block of joint bread using seaborn heatmap
 
     # pyplt.figure(figsize=(8, 6))
     # sns.heatmap(RL_stack_beta_derivatives_block, annot=False, cmap="viridis")
@@ -330,4 +328,4 @@ def form_adjusted_meat_adjustments_directly(
     # pyplt.tight_layout()
     # pyplt.show()
 
-    return
+    return per_user_theta_only_adjusted_meat_contributions