lifejacket 1.0.0-py3-none-any.whl → 1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,8 +18,6 @@ logging.basicConfig(
 level=logging.INFO,
 )
 
-# TODO: Consolidate function loading logic
-
 
 def get_batched_arg_lists_and_involved_user_ids(func, sorted_user_ids, args_by_user_id):
 """
lifejacket/constants.py CHANGED
@@ -1,20 +1,8 @@
 class SmallSampleCorrections:
 NONE = "none"
-HC1theta = "HC1theta"
-HC2theta = "HC2theta"
-HC3theta = "HC3theta"
-
-
-class InverseStabilizationMethods:
-NONE = "none"
-TRIM_SMALL_SINGULAR_VALUES = "trim_small_singular_values"
-ZERO_OUT_SMALL_OFF_DIAGONALS = "zero_out_small_off_diagonals"
-ADD_RIDGE_FIXED_CONDITION_NUMBER = "add_ridge_fixed_condition_number"
-ADD_RIDGE_MEDIAN_SINGULAR_VALUE_FRACTION = (
-"add_ridge_median_singular_value_fraction"
-)
-INVERSE_BREAD_STRUCTURE_AWARE_INVERSION = "inverse_bread_structure_aware_inversion"
-ALL_METHODS_COMPETITION = "all_methods_competition"
+Z1theta = "Z1theta"
+Z2theta = "Z2theta"
+Z3theta = "Z3theta"
 
 
 class FunctionTypes:
@@ -23,6 +11,6 @@ class FunctionTypes:
 
 
 class SandwichFormationMethods:
-BREAD_INVERSE_T_QR = "bread_inverse_T_qr"
+BREAD_T_QR = "bread_T_qr"
 MEAT_SVD_SOLVE = "meat_svd_solve"
 NAIVE = "naive"
@@ -35,6 +35,12 @@ logging.basicConfig(
 
 
 class DeploymentConditioningMonitor:
+"""
+Experimental feature. Monitors the conditioning of the RL portion of the bread matrix.
+Repeats more logic from post_deployment_analysis.py than is ideal, but this is for experimental use only.
+Unit tests should be unskipped and expanded if this is to be used more broadly.
+"""
+
 whole_RL_block_conditioning_threshold = None
 diagonal_RL_block_conditioning_threshold = None
 
@@ -76,7 +82,7 @@ class DeploymentConditioningMonitor:
 incremental: bool = True,
 ) -> None:
 """
-Analyzes a dataset to estimate parameters and variance using adaptive and classical sandwich estimators.
+Analyzes a dataset to estimate parameters and variance using adjusted and classical sandwich estimators.
 
 Parameters:
 proposed_policy_num (int | float):
@@ -124,13 +130,13 @@ class DeploymentConditioningMonitor:
 small_sample_correction (str):
 Type of small sample correction to apply.
 collect_data_for_blowup_supervised_learning (bool):
-Whether to collect data for doing supervised learning about adaptive sandwich blowup.
+Whether to collect data for doing supervised learning about adjusted sandwich blowup.
 form_adjusted_meat_adjustments_explicitly (bool):
-If True, explicitly forms the per-subject meat adjustments that differentiate the adaptive
+If True, explicitly forms the per-subject meat adjustments that differentiate the adjusted
 sandwich from the classical sandwich. This is for diagnostic purposes, as the
-adaptive sandwich is formed without doing this.
-stabilize_joint_adjusted_bread_inverse (bool):
-If True, stabilizes the joint adaptive bread inverse matrix if it does not meet conditioning
+adjusted sandwich is formed without doing this.
+stabilize_joint_bread (bool):
+If True, stabilizes the joint bread matrix if it does not meet conditioning
 thresholds.
 
 Returns:
@@ -157,6 +163,7 @@ class DeploymentConditioningMonitor:
 alg_update_func_args_beta_index,
 alg_update_func_args_action_prob_index,
 alg_update_func_args_action_prob_times_index,
+alg_update_func_args_previous_betas_index,
 suppress_interactive_data_checks,
 )
 
@@ -230,7 +237,7 @@ class DeploymentConditioningMonitor:
 
 if whole_RL_block_condition_number > self.whole_RL_block_conditioning_threshold:
 logger.warning(
-"The RL portion of the bread inverse up to this point exceeds the threshold set (condition number: %s, threshold: %s). Consider an alternative update strategy which produces less dependence on previous RL parameters (via the data they produced) and/or improves the conditioning of each update itself. Regularization may help with both of these.",
+"The RL portion of the bread up to this point exceeds the threshold set (condition number: %s, threshold: %s). Consider an alternative update strategy which produces less dependence on previous RL parameters (via the data they produced) and/or improves the conditioning of each update itself. Regularization may help with both of these.",
 whole_RL_block_condition_number,
 self.whole_RL_block_conditioning_threshold,
 )
@@ -241,7 +248,7 @@ class DeploymentConditioningMonitor:
 > self.diagonal_RL_block_conditioning_threshold
 ):
 logger.warning(
-"The diagonal RL block of the bread inverse up to this point exceeds the threshold set (condition number: %s, threshold: %s). This may illustrate a fundamental problem with the conditioning of the RL update procedure.",
+"The diagonal RL block of the bread up to this point exceeds the threshold set (condition number: %s, threshold: %s). This may illustrate a fundamental problem with the conditioning of the RL update procedure.",
 new_diagonal_RL_block_condition_number,
 self.diagonal_RL_block_conditioning_threshold,
 )
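For readers unfamiliar with the conditioning checks in the two hunks above, here is a minimal sketch of this kind of guard: compute a block's condition number with NumPy and warn when it crosses a threshold. The threshold value, block contents, and function name below are illustrative assumptions, not values or APIs taken from lifejacket.

import logging

import numpy as np

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Illustrative threshold; the monitor above reads its thresholds from class attributes.
CONDITIONING_THRESHOLD = 1e8


def check_block_conditioning(block: np.ndarray, threshold: float = CONDITIONING_THRESHOLD) -> float:
    # Condition number in the 2-norm: ratio of largest to smallest singular value.
    condition_number = np.linalg.cond(block)
    if condition_number > threshold:
        logger.warning(
            "Block condition number %s exceeds threshold %s; downstream solves may be unstable.",
            condition_number,
            threshold,
        )
    return float(condition_number)


# A nearly singular block triggers the warning; a well-conditioned one does not.
check_block_conditioning(np.array([[1.0, 1.0], [1.0, 1.0 + 1e-12]]))
check_block_conditioning(np.eye(2))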
@@ -295,11 +302,11 @@ class DeploymentConditioningMonitor:
 jnp.ndarray[jnp.float32],
 ]:
 """
-Constructs the classical and adaptive inverse bread and meat matrices, as well as the average
+Constructs the classical and adjusted bread and meat matrices, as well as the average
 estimating function stack and some other intermediate pieces.
 
 This is done by computing and differentiating the average weighted estimating function stack
-with respect to the betas and theta, using the resulting Jacobian to compute the inverse bread
+with respect to the betas and theta, using the resulting Jacobian to compute the bread
 and meat matrices, and then stably computing sandwiches.
 
 Args:
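The docstring above summarizes the usual sandwich recipe: differentiate the average estimating function stack to get the bread, average outer products of per-subject estimating functions to get the meat, and combine them with solves rather than explicit inverses. Below is a minimal, self-contained JAX sketch of that recipe on a toy least-squares estimating function; the estimating function, data, and all names are invented for illustration and do not come from lifejacket.

import jax
import jax.numpy as jnp


def sandwich_variance(estimating_fn, theta_hat, per_subject_data):
    # Per-subject estimating functions evaluated at the estimate, shape (n, p).
    psi = jax.vmap(lambda d: estimating_fn(theta_hat, d))(per_subject_data)
    n = psi.shape[0]

    # Bread: Jacobian of the average estimating function with respect to theta, shape (p, p).
    def average_estimating_fn(theta):
        return jnp.mean(jax.vmap(lambda d: estimating_fn(theta, d))(per_subject_data), axis=0)

    bread = jax.jacobian(average_estimating_fn)(theta_hat)

    # Meat: average outer product of per-subject estimating functions, shape (p, p).
    meat = psi.T @ psi / n

    # Sandwich: bread^{-1} @ meat @ bread^{-T} / n, formed with solves instead of explicit inverses.
    half = jnp.linalg.solve(bread, meat)
    return jnp.linalg.solve(bread, half.T).T / n


# Toy example: least-squares estimating function psi_i(theta) = x_i * (y_i - x_i @ theta).
X = jax.random.normal(jax.random.PRNGKey(0), (200, 2))
y = X @ jnp.array([1.0, -0.5]) + 0.1 * jax.random.normal(jax.random.PRNGKey(1), (200,))
theta_hat = jnp.linalg.lstsq(X, y)[0]
data = jnp.column_stack([X, y])
cov = sandwich_variance(lambda theta, d: d[:2] * (d[2] - d[:2] @ theta), theta_hat, data)
print(cov)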
@@ -471,7 +478,7 @@
 ]:
 """
 Computes the average weighted estimating function stack across all subjects, along with
-auxiliary values used to construct the adaptive and classical sandwich variances.
+auxiliary values used to construct the adjusted and classical sandwich variances.
 
 If only_latest_block is True, only uses data from the most recent update.
 
@@ -614,7 +621,7 @@ class DeploymentConditioningMonitor:
 )
 
 # 5. Now we can compute the weighted estimating function stacks for all subjects
-# as well as collect related values used to construct the adaptive and classical
+# as well as collect related values used to construct the adjusted and classical
 # sandwich variances.
 RL_stacks = jnp.array(
 [
@@ -21,7 +21,7 @@ logging.basicConfig(
 def form_adjusted_meat_adjustments_directly(
 theta_dim: int,
 beta_dim: int,
-joint_adaptive_bread_inverse_matrix: jnp.ndarray,
+joint_bread_matrix: jnp.ndarray,
 per_user_estimating_function_stacks: jnp.ndarray,
 study_df: pd.DataFrame,
 active_col_name: str,
@@ -38,18 +38,16 @@ def form_adjusted_meat_adjustments_directly(
 action_prob_col_name: str,
 ) -> jnp.ndarray:
 logger.info(
-"Explicitly forming the per-user meat adjustments that differentiate the adaptive sandwich from the classical sandwich."
+"Explicitly forming the per-user meat adjustments that differentiate the adjusted sandwich from the classical sandwich."
 )
 
 # 1. Form the M-matrices, which are shared across users.
 # This is not quite the paper definition of the M-matrices, which
 # includes multiplication by the classical bread. We don't care about
 # that here, since in forming the adjustments there is a multiplication
-# by the classical bread inverse that cancels it out.
-V_blocks_together = joint_adaptive_bread_inverse_matrix[-theta_dim:, :-theta_dim]
-RL_stack_beta_derivatives_block = joint_adaptive_bread_inverse_matrix[
-:-theta_dim, :-theta_dim
-]
+# by the classical bread that cancels it out.
+V_blocks_together = joint_bread_matrix[-theta_dim:, :-theta_dim]
+RL_stack_beta_derivatives_block = joint_bread_matrix[:-theta_dim, :-theta_dim]
 effective_M_blocks_together = np.linalg.solve(
 RL_stack_beta_derivatives_block.T, V_blocks_together.T
 ).T
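The np.linalg.solve call above forms the effective M blocks as V @ inv(A) without ever materializing inv(A). A small sketch of that identity, with randomly generated stand-ins for the two blocks (the shapes and names below are assumptions, not the package's):

import numpy as np

rng = np.random.default_rng(0)

theta_dim, beta_dim, num_updates = 3, 2, 4

# Illustrative stand-ins for the blocks sliced out of the joint bread matrix above:
# A plays the role of RL_stack_beta_derivatives_block, V of V_blocks_together.
A = rng.normal(size=(num_updates * beta_dim, num_updates * beta_dim)) + 5 * np.eye(num_updates * beta_dim)
V = rng.normal(size=(theta_dim, num_updates * beta_dim))

# solve(A.T, V.T).T computes V @ inv(A) without forming inv(A) explicitly,
# which is the better-conditioned way to get the effective M blocks.
M_via_solve = np.linalg.solve(A.T, V.T).T
M_via_inverse = V @ np.linalg.inv(A)

print(np.allclose(M_via_solve, M_via_inverse))  # True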
@@ -89,17 +87,17 @@ def form_adjusted_meat_adjustments_directly(
 # be added to each user's inference estimating function before an outer product is taken with
 # itself to get each user's contribution to the theta-only meat matrix).
 # Result is shape (num_users, theta_dim).
-# Form the per-user adaptive meat adjustments explicitly for diagnostic purposes.
+# Form the per-user adjusted meat adjustments explicitly for diagnostic purposes.
 per_user_meat_adjustments_stacked = np.einsum(
 "utb,nub->nt", M_blocks_stacked, per_user_RL_only_est_fns_stacked
 )
 
-# Log some diagnostics about the pieces going into the adaptive meat adjustments
+# Log some diagnostics about the pieces going into the adjusted meat adjustments
 # and the adjustments themselves.
 V_blocks = np.split(
 V_blocks_together, V_blocks_together.shape[1] // beta_dim, axis=1
 )
-logger.info("Examining adaptive meat adjustments.")
+logger.info("Examining adjusted meat adjustments.")
 # No scientific notation
 np.set_printoptions(suppress=True)
 
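The einsum string "utb,nub->nt" in the hunk above contracts the shared M blocks against each user's per-update RL estimating functions, summing over updates and beta dimensions to leave one theta_dim-length adjustment per user. A short sketch with made-up shapes showing the equivalent explicit loop (array names mirror the ones above but the contents are random):

import numpy as np

rng = np.random.default_rng(1)

num_updates, theta_dim, beta_dim, num_users = 4, 3, 2, 5

# Illustrative stand-ins for the arrays used above.
M_blocks_stacked = rng.normal(size=(num_updates, theta_dim, beta_dim))
per_user_RL_only_est_fns_stacked = rng.normal(size=(num_users, num_updates, beta_dim))

# "utb,nub->nt": for each user n, sum over updates u of M_u @ psi_{n,u}.
adjustments = np.einsum("utb,nub->nt", M_blocks_stacked, per_user_RL_only_est_fns_stacked)

# Equivalent explicit loop, for readability.
manual = np.zeros((num_users, theta_dim))
for n in range(num_users):
    for u in range(num_updates):
        manual[n] += M_blocks_stacked[u] @ per_user_RL_only_est_fns_stacked[n, u]

print(np.allclose(adjustments, manual))  # True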
@@ -118,11 +116,11 @@ def form_adjusted_meat_adjustments_directly(
 )
 
 logger.debug(
-"Per-user adaptive meat adjustments, to be added to inference estimating functions before forming the meat. Formed from the sum of the products of the M-blocks and the corresponding RL update estimating functions for each user: %s",
+"Per-user adjusted meat adjustments, to be added to inference estimating functions before forming the meat. Formed from the sum of the products of the M-blocks and the corresponding RL update estimating functions for each user: %s",
 per_user_meat_adjustments_stacked,
 )
 logger.debug(
-"Norms of per-user adaptive meat adjustments: %s",
+"Norms of per-user adjusted meat adjustments: %s",
 np.linalg.norm(per_user_meat_adjustments_stacked, axis=1),
 )
 
@@ -148,10 +146,10 @@ def form_adjusted_meat_adjustments_directly(
 
 logger.debug(
 "M_blocks, one per update, each shape theta_dim x beta_dim. The sum of the products "
-"of each of these times a user's corresponding RL estimating function forms their adaptive "
+"of each of these times a user's corresponding RL estimating function forms their adjusted "
 "adjustment. The M's are the blocks of the product of the V's concatenated and the inverse of "
-"the RL-only upper-left block of the joint adaptive bread inverse. In other words, the lower "
-"left block of the joint adaptive bread. Also note that the inference estimating function "
+"the RL-only upper-left block of the joint bread. In other words, the lower "
+"left block of the joint bread. Also note that the inference estimating function "
 "derivative inverse is omitted here despite the definition of the M's in the paper, because "
 "that factor simply cancels later: %s",
 M_blocks_stacked,
@@ -159,11 +157,11 @@ def form_adjusted_meat_adjustments_directly(
 )
 logger.debug("Norms of M-blocks: %s", np.linalg.norm(M_blocks_stacked, axis=(1, 2)))
 logger.debug(
-"RL block of joint adaptive bread inverse. The *inverse* of this goes into the M's: %s",
+"RL block of joint bread. The *inverse* of this goes into the M's: %s",
 RL_stack_beta_derivatives_block,
 )
 logger.debug(
-"Norm of RL block of joint adaptive bread inverse: %s",
+"Norm of RL block of joint bread: %s",
 np.linalg.norm(RL_stack_beta_derivatives_block),
 )
 
@@ -171,11 +169,11 @@ def form_adjusted_meat_adjustments_directly(
 RL_stack_beta_derivatives_block
 )
 logger.debug(
-"Inverse of RL block of joint adaptive bread inverse. This goes into the M's: %s",
+"Inverse of RL block of joint bread. This goes into the M's: %s",
 inverse_RL_stack_beta_derivatives_block,
 )
 logger.debug(
-"Norm of Inverse of RL block of joint adaptive bread inverse: %s",
+"Norm of Inverse of RL block of joint bread: %s",
 np.linalg.norm(inverse_RL_stack_beta_derivatives_block),
 )
 
@@ -230,17 +228,17 @@ def form_adjusted_meat_adjustments_directly(
 per_user_meat_adjustments_stacked
 + per_user_inference_estimating_functions_stacked
 )
-per_user_theta_only_adaptive_meat_contributions = jnp.einsum(
+per_user_theta_only_adjusted_meat_contributions = jnp.einsum(
 "ni,nj->nij",
 per_user_adjusted_inference_estimating_functions_stacked,
 per_user_adjusted_inference_estimating_functions_stacked,
 )
-adaptive_theta_only_meat_matrix = jnp.mean(
-per_user_theta_only_adaptive_meat_contributions, axis=0
+adjusted_theta_only_meat_matrix = jnp.mean(
+per_user_theta_only_adjusted_meat_contributions, axis=0
 )
 logger.info(
-"Theta-only adaptive meat matrix (no small sample corrections): %s",
-adaptive_theta_only_meat_matrix,
+"Theta-only adjusted meat matrix (no small sample corrections): %s",
+adjusted_theta_only_meat_matrix,
 )
 classical_theta_only_meat_matrix = jnp.mean(
 jnp.einsum(
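The pattern in the hunk above (an "ni,nj->nij" einsum followed by a mean over users) averages per-user outer products, which is the same matrix as g.T @ g / n. A tiny sketch with an illustrative array g standing in for the adjusted inference estimating functions (the array and its shape are assumptions for the example):

import jax
import jax.numpy as jnp

# Illustrative per-user adjusted inference estimating functions, shape (num_users, theta_dim).
g = jax.random.normal(jax.random.PRNGKey(2), (100, 3))

# Per-user contributions are the outer products g_n g_n^T ...
per_user_contributions = jnp.einsum("ni,nj->nij", g, g)
meat_from_mean = jnp.mean(per_user_contributions, axis=0)

# ... and their average is simply g^T g / n, an equivalent way to form the same meat matrix.
meat_direct = g.T @ g / g.shape[0]

print(jnp.allclose(meat_from_mean, meat_direct, atol=1e-6))  # True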
@@ -311,7 +309,7 @@ def form_adjusted_meat_adjustments_directly(
 # much or being too opinionated about what to log.
 breakpoint()
 
-# # Visualize the inverse RL block of joint adaptive bread inverse using seaborn heatmap
+# # Visualize the inverse RL block of joint bread using seaborn heatmap
 # pyplt.figure(figsize=(8, 6))
 # sns.heatmap(inverse_RL_stack_beta_derivatives_block, annot=False, cmap="viridis")
 # pyplt.title("Inverse RL Block of Joint Adaptive Bread Inverse")
@@ -320,7 +318,7 @@ def form_adjusted_meat_adjustments_directly(
 # pyplt.tight_layout()
 # pyplt.show()
 
-# # # Visualize the RL block of joint adaptive bread inverse using seaborn heatmap
+# # # Visualize the RL block of joint bread using seaborn heatmap
 
 # pyplt.figure(figsize=(8, 6))
 # sns.heatmap(RL_stack_beta_derivatives_block, annot=False, cmap="viridis")
@@ -330,4 +328,4 @@ def form_adjusted_meat_adjustments_directly(
 # pyplt.tight_layout()
 # pyplt.show()
 
-return per_user_theta_only_adaptive_meat_contributions
+return per_user_theta_only_adjusted_meat_contributions