lifejacket 0.2.1.tar.gz → 1.0.2.tar.gz
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- lifejacket-1.0.2/PKG-INFO +56 -0
- lifejacket-1.0.2/README.md +37 -0
- {lifejacket-0.2.1 → lifejacket-1.0.2}/lifejacket/arg_threading_helpers.py +75 -69
- {lifejacket-0.2.1 → lifejacket-1.0.2}/lifejacket/calculate_derivatives.py +19 -23
- lifejacket-1.0.2/lifejacket/constants.py +16 -0
- lifejacket-0.2.1/lifejacket/trial_conditioning_monitor.py → lifejacket-1.0.2/lifejacket/deployment_conditioning_monitor.py +163 -138
- lifejacket-0.2.1/lifejacket/form_adaptive_meat_adjustments_directly.py → lifejacket-1.0.2/lifejacket/form_adjusted_meat_adjustments_directly.py +32 -34
- {lifejacket-0.2.1 → lifejacket-1.0.2}/lifejacket/get_datum_for_blowup_supervised_learning.py +341 -339
- {lifejacket-0.2.1 → lifejacket-1.0.2}/lifejacket/helper_functions.py +60 -186
- {lifejacket-0.2.1 → lifejacket-1.0.2}/lifejacket/input_checks.py +303 -302
- lifejacket-0.2.1/lifejacket/after_study_analysis.py → lifejacket-1.0.2/lifejacket/post_deployment_analysis.py +470 -457
- {lifejacket-0.2.1 → lifejacket-1.0.2}/lifejacket/small_sample_corrections.py +49 -49
- lifejacket-1.0.2/lifejacket.egg-info/PKG-INFO +56 -0
- {lifejacket-0.2.1 → lifejacket-1.0.2}/lifejacket.egg-info/SOURCES.txt +3 -3
- lifejacket-1.0.2/lifejacket.egg-info/entry_points.txt +2 -0
- {lifejacket-0.2.1 → lifejacket-1.0.2}/pyproject.toml +3 -3
- lifejacket-0.2.1/PKG-INFO +0 -100
- lifejacket-0.2.1/README.md +0 -81
- lifejacket-0.2.1/lifejacket/constants.py +0 -28
- lifejacket-0.2.1/lifejacket.egg-info/PKG-INFO +0 -100
- lifejacket-0.2.1/lifejacket.egg-info/entry_points.txt +0 -2
- {lifejacket-0.2.1 → lifejacket-1.0.2}/lifejacket/__init__.py +0 -0
- {lifejacket-0.2.1 → lifejacket-1.0.2}/lifejacket/vmap_helpers.py +0 -0
- {lifejacket-0.2.1 → lifejacket-1.0.2}/lifejacket.egg-info/dependency_links.txt +0 -0
- {lifejacket-0.2.1 → lifejacket-1.0.2}/lifejacket.egg-info/requires.txt +0 -0
- {lifejacket-0.2.1 → lifejacket-1.0.2}/lifejacket.egg-info/top_level.txt +0 -0
- {lifejacket-0.2.1 → lifejacket-1.0.2}/setup.cfg +0 -0

lifejacket-1.0.2/PKG-INFO
@@ -0,0 +1,56 @@
+Metadata-Version: 2.4
+Name: lifejacket
+Version: 1.0.2
+Summary: Consistent standard errors for longitudinal data collected under pooling online decision policies.
+Author-email: Nowell Closser <nowellclosser@gmail.com>
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+Requires-Dist: click>=8.0
+Requires-Dist: jax>=0.4.0
+Requires-Dist: jaxlib>=0.4.0
+Requires-Dist: numpy>=1.20.0
+Requires-Dist: pandas>=1.3.0
+Requires-Dist: scipy>=1.7.0
+Requires-Dist: plotext>=5.0.0
+Provides-Extra: dev
+Requires-Dist: pytest>=7.0; extra == "dev"
+Requires-Dist: black>=22.0; extra == "dev"
+Requires-Dist: flake8>=4.0; extra == "dev"
+
+```python
+ _ _ __ _ _ _
+| (_)/ _| (_) | | | |
+| |_| |_ ___ _ __ _ ___| | _____| |_
+| | | _/ _ \ |/ _` |/ __| |/ / _ \ __|
+| | | || __/ | (_| | (__| < __/ |_
+|_|_|_| \___| |\__,_|\___|_|\_\___|\__|
+ _/ |
+ |__/
+```
+
+Save your standard errors from pooling in online decision-making algorithms.
+
+## Setup (if not using conda)
+### Create and activate a virtual environment
+- `python3 -m venv .venv; source .venv/bin/activate`
+
+### Adding a package
+- Add to `requirements.txt` with a specific version, or with no version if you want the latest stable release
+- Run `pip freeze > requirements.txt` to lock the versions of your package and all its subpackages
+
+## Running the code
+- `export PYTHONPATH` set to the absolute path of this repository on your computer
+- `./run_local_synthetic.sh`, which outputs to `simulated_data/` by default. See all the possible flags to be toggled in the script code.
+
+## Linting/Formatting
+
+## Testing
+python -m pytest
+python -m pytest tests/unit_tests
+python -m pytest tests/integration_tests
+
+
+
+## TODO
+1. Add precommit hooks (pip freeze, linting, formatting)
+
lifejacket-1.0.2/README.md
@@ -0,0 +1,37 @@
+```python
+ _ _ __ _ _ _
+| (_)/ _| (_) | | | |
+| |_| |_ ___ _ __ _ ___| | _____| |_
+| | | _/ _ \ |/ _` |/ __| |/ / _ \ __|
+| | | || __/ | (_| | (__| < __/ |_
+|_|_|_| \___| |\__,_|\___|_|\_\___|\__|
+ _/ |
+ |__/
+```
+
+Save your standard errors from pooling in online decision-making algorithms.
+
+## Setup (if not using conda)
+### Create and activate a virtual environment
+- `python3 -m venv .venv; source .venv/bin/activate`
+
+### Adding a package
+- Add to `requirements.txt` with a specific version, or with no version if you want the latest stable release
+- Run `pip freeze > requirements.txt` to lock the versions of your package and all its subpackages
+
+## Running the code
+- `export PYTHONPATH` set to the absolute path of this repository on your computer
+- `./run_local_synthetic.sh`, which outputs to `simulated_data/` by default. See all the possible flags to be toggled in the script code.
+
+## Linting/Formatting
+
+## Testing
+python -m pytest
+python -m pytest tests/unit_tests
+python -m pytest tests/integration_tests
+
+
+
+## TODO
+1. Add precommit hooks (pip freeze, linting, formatting)
+
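The PKG-INFO and README added above become the installed package's metadata. As an aside (not part of the diff), those fields can be read back at runtime with the standard library; a minimal sketch, assuming lifejacket 1.0.2 is installed from the registry:

```python
# Read the PKG-INFO fields shown above from an installed copy of the package.
from importlib.metadata import metadata, version

print(version("lifejacket"))          # "1.0.2"
meta = metadata("lifejacket")         # parsed PKG-INFO (an email.message.Message)
print(meta["Summary"])                # the Summary line above
print(meta.get_all("Requires-Dist"))  # click, jax, jaxlib, numpy, pandas, ...
```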

{lifejacket-0.2.1 → lifejacket-1.0.2}/lifejacket/arg_threading_helpers.py
@@ -20,10 +20,10 @@ def replace_tuple_index(tupl, index, value):
 
 
 def thread_action_prob_func_args(
-    action_prob_func_args_by_user_id_by_decision_time: dict[
+    action_prob_func_args_by_subject_id_by_decision_time: dict[
         int, dict[collections.abc.Hashable, tuple[Any, ...]]
     ],
-    policy_num_by_decision_time_by_user_id: dict[
+    policy_num_by_decision_time_by_subject_id: dict[
         collections.abc.Hashable, dict[int, int | float]
     ],
     initial_policy_num: int | float,
@@ -39,12 +39,12 @@ def thread_action_prob_func_args(
         decision time to enable correct differentiation.
 
     Args:
-        action_prob_func_args_by_user_id_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
+        action_prob_func_args_by_subject_id_by_decision_time (dict[int, dict[collections.abc.Hashable, tuple[Any, ...]]]):
             A map from decision times to maps of user ids to tuples of arguments for the action
             probability function. This is for all decision times for all users (args are an empty
             tuple if they are not in the study). Should be sorted by decision time.
 
-        policy_num_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, int | float]]):
+        policy_num_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, int | float]]):
             A dictionary mapping decision times to the policy number in use. This may be user-specific.
             Should be sorted by decision time.
 
@@ -69,56 +69,58 @@ def thread_action_prob_func_args(
         A map from user ids to maps of decision times to action probability function
         argument tuples with the shared betas threaded in. Note the key order switch.
     """
-    threaded_action_prob_func_args_by_decision_time_by_user_id = (
+    threaded_action_prob_func_args_by_decision_time_by_subject_id = (
         collections.defaultdict(dict)
     )
-    action_prob_func_args_by_decision_time_by_user_id = collections.defaultdict(dict)
+    action_prob_func_args_by_decision_time_by_subject_id = collections.defaultdict(dict)
     for (
         decision_time,
-        action_prob_func_args_by_user_id,
-    ) in action_prob_func_args_by_user_id_by_decision_time.items():
-        for user_id, args in action_prob_func_args_by_user_id.items():
+        action_prob_func_args_by_subject_id,
+    ) in action_prob_func_args_by_subject_id_by_decision_time.items():
+        for subject_id, args in action_prob_func_args_by_subject_id.items():
             # Always add a contribution to the reversed key order dictionary.
-            action_prob_func_args_by_decision_time_by_user_id[user_id][
+            action_prob_func_args_by_decision_time_by_subject_id[subject_id][
                 decision_time
             ] = args
 
             # Now proceed with the threading, if necessary.
             if not args:
-                threaded_action_prob_func_args_by_decision_time_by_user_id[user_id][
-                    decision_time
-                ] = ()
+                threaded_action_prob_func_args_by_decision_time_by_subject_id[
+                    subject_id
+                ][decision_time] = ()
                 continue
 
-            policy_num = policy_num_by_decision_time_by_user_id[user_id][decision_time]
+            policy_num = policy_num_by_decision_time_by_subject_id[subject_id][
+                decision_time
+            ]
 
             # The expectation is that fallback policies have empty args, and the only other
             # policy not represented in beta_index_by_policy_num is the initial policy.
             if policy_num == initial_policy_num:
-                threaded_action_prob_func_args_by_decision_time_by_user_id[user_id][
-                    decision_time
-                ] = action_prob_func_args_by_user_id[user_id]
+                threaded_action_prob_func_args_by_decision_time_by_subject_id[
+                    subject_id
+                ][decision_time] = action_prob_func_args_by_subject_id[subject_id]
                 continue
 
             beta_to_introduce = all_post_update_betas[
                 beta_index_by_policy_num[policy_num]
             ]
-            threaded_action_prob_func_args_by_decision_time_by_user_id[user_id][
+            threaded_action_prob_func_args_by_decision_time_by_subject_id[subject_id][
                 decision_time
             ] = replace_tuple_index(
-                action_prob_func_args_by_user_id[user_id],
+                action_prob_func_args_by_subject_id[subject_id],
                 action_prob_func_args_beta_index,
                 beta_to_introduce,
             )
 
     return (
-        threaded_action_prob_func_args_by_decision_time_by_user_id,
-        action_prob_func_args_by_decision_time_by_user_id,
+        threaded_action_prob_func_args_by_decision_time_by_subject_id,
+        action_prob_func_args_by_decision_time_by_subject_id,
    )
 
 
 def thread_update_func_args(
-    update_func_args_by_by_user_id_by_policy_num: dict[
+    update_func_args_by_by_subject_id_by_policy_num: dict[
         int | float, dict[collections.abc.Hashable, tuple[Any, ...]]
     ],
     all_post_update_betas: jnp.ndarray,
@@ -127,7 +129,7 @@ def thread_update_func_args(
     alg_update_func_args_action_prob_index: int,
     alg_update_func_args_action_prob_times_index: int,
     alg_update_func_args_previous_betas_index: int,
-    threaded_action_prob_func_args_by_decision_time_by_user_id: dict[
+    threaded_action_prob_func_args_by_decision_time_by_subject_id: dict[
         collections.abc.Hashable, dict[int, tuple[Any, ...]]
     ],
     action_prob_func: callable,
@@ -139,7 +141,7 @@ def thread_update_func_args(
     with reconstructed action probabilities computed using the shared betas.
 
     Args:
-        update_func_args_by_by_user_id_by_policy_num (dict[int | float, dict[collections.abc.Hashable, tuple[Any, ...]]]):
+        update_func_args_by_by_subject_id_by_policy_num (dict[int | float, dict[collections.abc.Hashable, tuple[Any, ...]]]):
             A dictionary where keys are policy
             numbers and values are dictionaries mapping user IDs to their respective update function
             arguments.
@@ -170,7 +172,7 @@ def thread_update_func_args(
         alg_update_func_args_previous_betas_index (int):
             The index in the update function with previous beta parameters
 
-        threaded_action_prob_func_args_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
+        threaded_action_prob_func_args_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
             A dictionary mapping decision times to the function arguments required to compute action
             probabilities for this user, and with the shared betas threaded in.
 
@@ -183,49 +185,51 @@ def thread_update_func_args(
         argument tuples for the specified user with the shared betas threaded in. Note the key
         order switch relative to the supplied args!
     """
-    threaded_update_func_args_by_policy_num_by_user_id = collections.defaultdict(dict)
+    threaded_update_func_args_by_policy_num_by_subject_id = collections.defaultdict(
+        dict
+    )
     for (
         policy_num,
-        update_func_args_by_user_id,
-    ) in update_func_args_by_by_user_id_by_policy_num.items():
-        for user_id, args in update_func_args_by_user_id.items():
+        update_func_args_by_subject_id,
+    ) in update_func_args_by_by_subject_id_by_policy_num.items():
+        for subject_id, args in update_func_args_by_subject_id.items():
             if not args:
-                threaded_update_func_args_by_policy_num_by_user_id[user_id][
+                threaded_update_func_args_by_policy_num_by_subject_id[subject_id][
                     policy_num
                 ] = ()
                 continue
 
             logger.debug(
                 "Threading in shared betas to update function arguments for user %s and policy number %s.",
-                user_id,
+                subject_id,
                 policy_num,
             )
 
             beta_to_introduce = all_post_update_betas[
                 beta_index_by_policy_num[policy_num]
             ]
-            threaded_update_func_args_by_policy_num_by_user_id[user_id][
-                policy_num
-            ] = replace_tuple_index(
-                update_func_args_by_user_id[user_id],
-                alg_update_func_args_beta_index,
-                beta_to_introduce,
+            threaded_update_func_args_by_policy_num_by_subject_id[subject_id][
+                policy_num
+            ] = replace_tuple_index(
+                update_func_args_by_subject_id[subject_id],
+                alg_update_func_args_beta_index,
+                beta_to_introduce,
             )
             if alg_update_func_args_previous_betas_index >= 0:
                 previous_betas_to_introduce = all_post_update_betas[
                     : len(
-                        update_func_args_by_user_id[user_id][
+                        update_func_args_by_subject_id[subject_id][
                             alg_update_func_args_previous_betas_index
                         ]
                     )
                 ]
                 if previous_betas_to_introduce.size > 0:
-                    threaded_update_func_args_by_policy_num_by_user_id[user_id][
+                    threaded_update_func_args_by_policy_num_by_subject_id[subject_id][
                         policy_num
                     ] = replace_tuple_index(
-                        threaded_update_func_args_by_policy_num_by_user_id[user_id][
-                            policy_num
-                        ],
+                        threaded_update_func_args_by_policy_num_by_subject_id[
+                            subject_id
+                        ][policy_num],
                         alg_update_func_args_previous_betas_index,
                         previous_betas_to_introduce,
                     )
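Every hunk above routes through replace_tuple_index, whose signature appears in the first hunk header but whose body the diff never shows. From its call sites it takes a tuple, an index, and a value, and returns a copy of the tuple with that one slot swapped out (tuples being immutable). A minimal sketch consistent with that usage, not the package's verbatim source:

```python
def replace_tuple_index(tupl, index, value):
    """Return a copy of `tupl` with the element at `index` replaced by `value`."""
    return tupl[:index] + (value,) + tupl[index + 1 :]


# E.g. threading a freshly updated beta into an args tuple at index 2
# (names here are illustrative only):
args = ("subject_features", "decision_time", "old_beta")
assert replace_tuple_index(args, 2, "new_beta") == (
    "subject_features",
    "decision_time",
    "new_beta",
)
```

The remaining hunks for this file continue below.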
@@ -234,20 +238,20 @@ def thread_update_func_args(
             logger.debug(
                 "Action probabilities are used in the algorithm update function. Reconstructing them using the shared betas."
             )
-            action_prob_times = update_func_args_by_user_id[user_id][
+            action_prob_times = update_func_args_by_subject_id[subject_id][
                 alg_update_func_args_action_prob_times_index
             ]
             # Vectorized computation of action_probs_to_introduce using jax.vmap
             flattened_times = action_prob_times.flatten()
             args_list = [
-                threaded_action_prob_func_args_by_decision_time_by_user_id[user_id][
-                    int(t)
-                ]
+                threaded_action_prob_func_args_by_decision_time_by_subject_id[
+                    subject_id
+                ][int(t)]
                 for t in flattened_times.tolist()
             ]
             if len(args_list) == 0:
                 action_probs_to_introduce = jnp.array([]).reshape(
-                    update_func_args_by_user_id[user_id][
+                    update_func_args_by_subject_id[subject_id][
                         alg_update_func_args_action_prob_index
                     ].shape
                 )
@@ -264,31 +268,31 @@ def thread_update_func_args(
                 action_prob_func, in_axes=tuple(0 for _ in batched_tensors)
             )
             action_probs_to_introduce = vmapped_func(*batched_tensors).reshape(
-                update_func_args_by_user_id[user_id][
+                update_func_args_by_subject_id[subject_id][
                     alg_update_func_args_action_prob_index
                 ].shape
             )
-            threaded_update_func_args_by_policy_num_by_user_id[user_id][
+            threaded_update_func_args_by_policy_num_by_subject_id[subject_id][
                 policy_num
             ] = replace_tuple_index(
-                threaded_update_func_args_by_policy_num_by_user_id[user_id][
+                threaded_update_func_args_by_policy_num_by_subject_id[subject_id][
                     policy_num
                 ],
                 alg_update_func_args_action_prob_index,
                 action_probs_to_introduce,
             )
-    return threaded_update_func_args_by_policy_num_by_user_id
+    return threaded_update_func_args_by_policy_num_by_subject_id
 
 
 def thread_inference_func_args(
-    inference_func_args_by_user_id: dict[collections.abc.Hashable, tuple[Any, ...]],
+    inference_func_args_by_subject_id: dict[collections.abc.Hashable, tuple[Any, ...]],
     inference_func_args_theta_index: int,
     theta: jnp.ndarray,
     inference_func_args_action_prob_index: int,
-    threaded_action_prob_func_args_by_decision_time_by_user_id: dict[
+    threaded_action_prob_func_args_by_decision_time_by_subject_id: dict[
         collections.abc.Hashable, dict[int, tuple[Any, ...]]
     ],
-    inference_action_prob_decision_times_by_user_id: dict[
+    inference_action_prob_decision_times_by_subject_id: dict[
         collections.abc.Hashable, list[int]
     ],
     action_prob_func: callable,
@@ -300,7 +304,7 @@ def thread_inference_func_args(
     probabilities computed using the shared betas.
 
     Args:
-        inference_func_args_by_user_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
+        inference_func_args_by_subject_id (dict[collections.abc.Hashable, tuple[Any, ...]]):
             A dictionary mapping user IDs to their respective inference function arguments.
 
         inference_func_args_theta_index (int):
@@ -315,11 +319,11 @@ def thread_inference_func_args(
             tuple where new beta-threaded action probabilities should be inserted, if applicable.
             -1 otherwise.
 
-        threaded_action_prob_func_args_by_decision_time_by_user_id (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
+        threaded_action_prob_func_args_by_decision_time_by_subject_id (dict[collections.abc.Hashable, dict[int, tuple[Any, ...]]]):
             A dictionary mapping decision times to the function arguments required to compute action
             probabilities for this user, and with the shared betas threaded in.
 
-        inference_action_prob_decision_times_by_user_id (dict[collections.abc.Hashable, list[int]]):
+        inference_action_prob_decision_times_by_subject_id (dict[collections.abc.Hashable, list[int]]):
             For each user, a list of decision times to which action probabilities correspond if
             provided. Typically just in-study times if action probabilities are used in the inference
             loss or estimating function.
@@ -332,9 +336,9 @@ def thread_inference_func_args(
         threaded in.
     """
 
-    threaded_inference_func_args_by_user_id = {}
-    for user_id, args in inference_func_args_by_user_id.items():
-        threaded_inference_func_args_by_user_id[user_id] = replace_tuple_index(
+    threaded_inference_func_args_by_subject_id = {}
+    for subject_id, args in inference_func_args_by_subject_id.items():
+        threaded_inference_func_args_by_subject_id[subject_id] = replace_tuple_index(
             args,
             inference_func_args_theta_index,
             theta,
@@ -343,12 +347,12 @@ def thread_inference_func_args(
         if inference_func_args_action_prob_index >= 0:
             # Use a vmap-like pattern to compute action probabilities in batch.
             action_prob_times_flattened = (
-                inference_action_prob_decision_times_by_user_id[user_id].flatten()
+                inference_action_prob_decision_times_by_subject_id[subject_id].flatten()
             )
             args_list = [
-                threaded_action_prob_func_args_by_decision_time_by_user_id[user_id][
-                    int(t)
-                ]
+                threaded_action_prob_func_args_by_decision_time_by_subject_id[
+                    subject_id
+                ][int(t)]
                 for t in action_prob_times_flattened.tolist()
             ]
             if len(args_list) == 0:
@@ -369,9 +373,11 @@ def thread_inference_func_args(
             action_probs_to_introduce = vmapped_func(*batched_tensors).reshape(
                 args[inference_func_args_action_prob_index].shape
             )
-            threaded_inference_func_args_by_user_id[user_id] = replace_tuple_index(
-                threaded_inference_func_args_by_user_id[user_id],
-                inference_func_args_action_prob_index,
-                action_probs_to_introduce,
+            threaded_inference_func_args_by_subject_id[subject_id] = (
+                replace_tuple_index(
+                    threaded_inference_func_args_by_subject_id[subject_id],
+                    inference_func_args_action_prob_index,
+                    action_probs_to_introduce,
+                )
             )
-    return threaded_inference_func_args_by_user_id
+    return threaded_inference_func_args_by_subject_id
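The hunks above batch the per-decision-time argument tuples into stacked tensors and map action_prob_func over the leading axis with jax.vmap. A self-contained sketch of that pattern, with a toy logistic policy and made-up shapes standing in for the real action_prob_func and batched_tensors:

```python
import jax
import jax.numpy as jnp


def action_prob_func(beta, features):
    # Toy stand-in: treatment probability under a logistic policy.
    return jax.nn.sigmoid(jnp.dot(beta, features))


# One argument tuple per decision time, stacked along a leading batch axis.
betas = jnp.ones((5, 3))     # 5 decision times, 3 policy parameters each
features = jnp.ones((5, 3))  # 5 decision times, 3 context features each
batched_tensors = (betas, features)

# Mirrors the diff: vmap over axis 0 of every batched argument.
vmapped_func = jax.vmap(action_prob_func, in_axes=tuple(0 for _ in batched_tensors))
action_probs = vmapped_func(*batched_tensors)  # shape (5,)
```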

{lifejacket-0.2.1 → lifejacket-1.0.2}/lifejacket/calculate_derivatives.py
@@ -18,8 +18,6 @@ logging.basicConfig(
     level=logging.INFO,
 )
 
-# TODO: Consolidate function loading logic
-
 
 def get_batched_arg_lists_and_involved_user_ids(func, sorted_user_ids, args_by_user_id):
     """
@@ -198,10 +196,10 @@ def pad_in_study_derivatives_with_zeros(
 
 def calculate_pi_and_weight_gradients(
     study_df,
-    in_study_col_name,
+    active_col_name,
     action_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     action_prob_func,
     action_prob_func_args,
     action_prob_func_args_beta_index,
@@ -226,10 +224,10 @@ def calculate_pi_and_weight_gradients(
 
     pi_gradients, weight_gradients = calculate_pi_and_weight_gradients_specific_t(
         study_df,
-        in_study_col_name,
+        active_col_name,
         action_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
         action_prob_func,
         action_prob_func_args_beta_index,
         calendar_t,
@@ -252,10 +250,10 @@ def calculate_pi_and_weight_gradients(
 
 def calculate_pi_and_weight_gradients_specific_t(
     study_df,
-    in_study_col_name,
+    active_col_name,
     action_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
     action_prob_func,
     action_prob_func_args_beta_index,
     calendar_t,
@@ -320,10 +318,10 @@ def calculate_pi_and_weight_gradients_specific_t(
         study_df,
         calendar_t,
         sorted_user_ids,
-        in_study_col_name,
+        active_col_name,
         action_col_name,
         calendar_t_col_name,
-        user_id_col_name,
+        subject_id_col_name,
     )
     # Note the first argument here: we extract the betas to pass in
     # again as the "target" denominator betas, whereas we differentiate with
@@ -382,10 +380,10 @@ def collect_batched_in_study_actions(
     study_df,
     calendar_t,
     sorted_user_ids,
-    in_study_col_name,
+    active_col_name,
     action_col_name,
     calendar_t_col_name,
-    user_id_col_name,
+    subject_id_col_name,
 ):
 
     # TODO: This for loop can be removed, just grabbing the actions col after
@@ -394,9 +392,9 @@ def collect_batched_in_study_actions(
     batched_actions_list = []
     for user_id in sorted_user_ids:
         filtered_user_data = study_df.loc[
-            (study_df[user_id_col_name] == user_id)
+            (study_df[subject_id_col_name] == user_id)
             & (study_df[calendar_t_col_name] == calendar_t)
-            & (study_df[in_study_col_name] == 1)
+            & (study_df[active_col_name] == 1)
         ]
         if not filtered_user_data.empty:
             batched_actions_list.append(filtered_user_data[action_col_name].values[0])
@@ -785,9 +783,9 @@ def calculate_inference_loss_derivatives(
     inference_func,
     inference_func_args_theta_index,
     user_ids,
-    user_id_col_name,
+    subject_id_col_name,
     action_prob_col_name,
-    in_study_col_name,
+    active_col_name,
     calendar_t_col_name,
     inference_func_type=FunctionTypes.LOSS,
 ):
@@ -819,18 +817,18 @@ def calculate_inference_loss_derivatives(
     max_calendar_time = study_df[calendar_t_col_name].max()
     for user_id in user_ids:
         user_args_list = []
-        filtered_user_data = study_df.loc[study_df[user_id_col_name] == user_id]
+        filtered_user_data = study_df.loc[study_df[subject_id_col_name] == user_id]
         for idx, col_name in enumerate(inference_func_arg_names):
             if idx == inference_func_args_theta_index:
                 user_args_list.append(theta_est)
             else:
                 user_args_list.append(
-                    get_study_df_column(filtered_user_data, col_name, in_study_col_name)
+                    get_study_df_column(filtered_user_data, col_name, active_col_name)
                 )
         args_by_user_id[user_id] = tuple(user_args_list)
         if using_action_probs:
             action_prob_decision_times_by_user_id[user_id] = get_study_df_column(
-                filtered_user_data, calendar_t_col_name, in_study_col_name
+                filtered_user_data, calendar_t_col_name, active_col_name
             )
 
     # Get a list of subdicts of the user args dict, with each united by having
@@ -957,9 +955,7 @@ def calculate_inference_loss_derivatives(
     return loss_gradients, loss_hessians, loss_gradient_pi_derivatives
 
 
-def get_study_df_column(study_df, col_name, in_study_col_name):
+def get_study_df_column(study_df, col_name, active_col_name):
     return jnp.array(
-        study_df.loc[study_df[in_study_col_name] == 1, col_name]
-        .to_numpy()
-        .reshape(-1, 1)
+        study_df.loc[study_df[active_col_name] == 1, col_name].to_numpy().reshape(-1, 1)
     )
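The rewritten get_study_df_column at the end of this file filters the frame down to rows where the active flag is 1, pulls a single column, and returns it as an (n, 1) JAX array. The same operation on toy data, with hypothetical column names:

```python
import jax.numpy as jnp
import pandas as pd

study_df = pd.DataFrame({"active": [1, 0, 1], "reward": [0.5, 0.1, 0.9]})

# Keep active rows only, then reshape the column to (n, 1) as the diff does.
col = jnp.array(
    study_df.loc[study_df["active"] == 1, "reward"].to_numpy().reshape(-1, 1)
)
print(col.shape)  # (2, 1) -- the active == 0 row is dropped
```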

lifejacket-1.0.2/lifejacket/constants.py
@@ -0,0 +1,16 @@
+class SmallSampleCorrections:
+    NONE = "none"
+    Z1theta = "Z1theta"
+    Z2theta = "Z2theta"
+    Z3theta = "Z3theta"
+
+
+class FunctionTypes:
+    LOSS = "loss"
+    ESTIMATING = "estimating"
+
+
+class SandwichFormationMethods:
+    BREAD_T_QR = "bread_T_qr"
+    MEAT_SVD_SOLVE = "meat_svd_solve"
+    NAIVE = "naive"
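The new constants.py, which replaces the 28-line 0.2.1 version removed in this release, reduces to three namespace classes of string-valued options. A hedged sketch of how such option constants are typically consumed downstream; choose_correction is hypothetical, not part of lifejacket's public API:

```python
from lifejacket.constants import SandwichFormationMethods, SmallSampleCorrections


def choose_correction(name: str) -> str:
    # Hypothetical dispatcher: validate a user-supplied option string
    # against the constants defined in the hunk above.
    valid = {
        SmallSampleCorrections.NONE,
        SmallSampleCorrections.Z1theta,
        SmallSampleCorrections.Z2theta,
        SmallSampleCorrections.Z3theta,
    }
    if name not in valid:
        raise ValueError(f"Unknown small-sample correction: {name!r}")
    return name


assert choose_correction("Z2theta") == SmallSampleCorrections.Z2theta
assert SandwichFormationMethods.NAIVE == "naive"
```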