lifejacket 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifejacket/__init__.py +0 -0
- lifejacket/after_study_analysis.py +1845 -0
- lifejacket/arg_threading_helpers.py +354 -0
- lifejacket/calculate_derivatives.py +965 -0
- lifejacket/constants.py +28 -0
- lifejacket/form_adaptive_meat_adjustments_directly.py +333 -0
- lifejacket/get_datum_for_blowup_supervised_learning.py +1312 -0
- lifejacket/helper_functions.py +587 -0
- lifejacket/input_checks.py +1145 -0
- lifejacket/small_sample_corrections.py +125 -0
- lifejacket/trial_conditioning_monitor.py +870 -0
- lifejacket/vmap_helpers.py +71 -0
- lifejacket-0.1.0.dist-info/METADATA +100 -0
- lifejacket-0.1.0.dist-info/RECORD +17 -0
- lifejacket-0.1.0.dist-info/WHEEL +5 -0
- lifejacket-0.1.0.dist-info/entry_points.txt +2 -0
- lifejacket-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
import collections
|
|
2
|
+
|
|
3
|
+
from jax import numpy as jnp
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# TODO: Check for exactly the required types earlier
|
|
8
|
+
# TODO: Try except and nice error message
|
|
9
|
+
# TODO: This is complicated enough to deserve its own unit tests
|
|
10
|
+
def stack_batched_arg_lists_into_tensors(batched_arg_lists):
    """
    Stack a simple Python list of lists of function arguments into a list of jnp arrays that can be
    supplied to vmap as batch arguments. vmap requires all elements of such a batched array to be
    the same shape, as do the stacking functions we use here. Thus we require this be called on
    batches with the same data shape. We also supply the axes one must iterate over to get
    each user's args in a batch.

    The kind of each argument is decided by inspecting only the FIRST element of its list:

    * arrays / non-string sequences are converted to arrays and dispatched on ``ndim``:
        - ``ndim == 0`` (e.g. ``jnp.array(1.0)``) -> stacked like scalars, shape ``(n,)``
        - ``ndim == 1`` -> ``jnp.vstack``, shape ``(n, d)``
        - ``ndim == 2`` -> ``jnp.stack`` along axis 0, shape ``(n, r, c)``
        - ``ndim > 2``  -> ``TypeError``
    * anything else (e.g. Python or NumPy scalars) -> ``jnp.array(list)``, shape ``(n,)``

    Args:
        batched_arg_lists: An iterable of per-argument lists. Each inner list holds one value per
            user for a single function argument.

    Returns:
        A tuple ``(batched_arg_tensors, batch_axes)`` where ``batched_arg_tensors[i]`` is the
        stacked jnp array for the i-th argument and ``batch_axes[i]`` is the axis along which it
        was stacked (currently always 0).

    Raises:
        ValueError: If one of the inner argument lists is empty.
        TypeError: If an argument has more than 2 dimensions, or is a sequence that cannot be
            cast to a JAX numpy array.
    """

    batched_arg_tensors = []

    # This ends up being all zeros because of the way we are (now) doing the
    # stacking, but better to not assume that externally and send out what
    # we've done with this list.
    batch_axes = []

    for batched_arg_list in batched_arg_lists:
        if len(batched_arg_list) == 0:
            # Without this check, indexing [0] below fails with an unhelpful IndexError.
            raise ValueError("Cannot stack an empty list of batched arguments.")

        first = batched_arg_list[0]
        is_array = isinstance(first, (jnp.ndarray, np.ndarray))
        is_sequence = isinstance(first, collections.abc.Sequence) and not isinstance(
            first, str
        )

        if is_sequence and not is_array:
            # Plain Python sequences (lists, tuples, ...) are cast element-wise so that
            # their dimensionality can be inspected exactly like any other array.
            try:
                batched_arg_list = [jnp.array(x) for x in batched_arg_list]
            except Exception as e:
                raise TypeError(
                    "Argument of sequence type that cannot be cast to JAX numpy array"
                ) from e
            first = batched_arg_list[0]
            is_array = True

        if not is_array or first.ndim == 0:
            ########## Scalars (including 0-d arrays): just turn into a jnp array.
            batched_arg_tensors.append(jnp.array(batched_arg_list))
        elif first.ndim == 1:
            ########## We have a vector (1D array) arg
            batched_arg_tensors.append(jnp.vstack(batched_arg_list))
        elif first.ndim == 2:
            ########## We have a matrix (2D array) arg
            batched_arg_tensors.append(jnp.stack(batched_arg_list, 0))
        else:
            raise TypeError("Arrays with dimension greater than 2 are not supported.")
        batch_axes.append(0)

    return (
        batched_arg_tensors,
        batch_axes,
    )
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: lifejacket
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A package for after-study analysis of adaptive experiments in which data is pooled across users.
|
|
5
|
+
Author-email: Nowell Closser <nowellclosser@gmail.com>
|
|
6
|
+
Requires-Python: >=3.10
|
|
7
|
+
Description-Content-Type: text/markdown
|
|
8
|
+
Requires-Dist: click>=8.0
|
|
9
|
+
Requires-Dist: jax>=0.4.0
|
|
10
|
+
Requires-Dist: jaxlib>=0.4.0
|
|
11
|
+
Requires-Dist: numpy>=1.20.0
|
|
12
|
+
Requires-Dist: pandas>=1.3.0
|
|
13
|
+
Requires-Dist: scipy>=1.7.0
|
|
14
|
+
Requires-Dist: plotext>=5.0.0
|
|
15
|
+
Provides-Extra: dev
|
|
16
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
17
|
+
Requires-Dist: black>=22.0; extra == "dev"
|
|
18
|
+
Requires-Dist: flake8>=4.0; extra == "dev"
|
|
19
|
+
|
|
20
|
+
```python
|
|
21
|
+
_ _ __ _ _ _
|
|
22
|
+
| (_)/ _| (_) | | | |
|
|
23
|
+
| |_| |_ ___ _ __ _ ___| | _____| |_
|
|
24
|
+
| | | _/ _ \ |/ _` |/ __| |/ / _ \ __|
|
|
25
|
+
| | | || __/ | (_| | (__| < __/ |_
|
|
26
|
+
|_|_|_| \___| |\__,_|\___|_|\_\___|\__|
|
|
27
|
+
_/ |
|
|
28
|
+
|__/
|
|
29
|
+
```
|
|
30
|
+
|
|
31
|
+
Save your standard errors from pooling in adaptive experiments.
|
|
32
|
+
|
|
33
|
+
## Setup (if not using conda)
|
|
34
|
+
### Create and activate a virtual environment
|
|
35
|
+
- `python3 -m venv .venv; source .venv/bin/activate`
|
|
36
|
+
|
|
37
|
+
### Adding a package
|
|
38
|
+
- Add the package to `requirements.txt`, pinned to a specific version, or with no version if you want the latest stable release
|
|
39
|
+
- Run `pip freeze > requirements.txt` to lock the versions of the package and all of its dependencies
|
|
40
|
+
|
|
41
|
+
## Running the code
|
|
42
|
+
- Set `PYTHONPATH` to the absolute path of this repository on your computer (e.g. `export PYTHONPATH=/absolute/path/to/this/repo`)
|
|
43
|
+
- `./run_local_synthetic.sh`, which outputs to `simulated_data/` by default. See all the possible flags to be toggled in the script code.
|
|
44
|
+
|
|
45
|
+
## Linting/Formatting
|
|
46
|
+
|
|
47
|
+
## Testing
|
|
48
|
+
- `python -m pytest`
|
|
49
|
+
- `python -m pytest tests/unit_tests`
|
|
50
|
+
- `python -m pytest tests/integration_tests`
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Talk about gitignored cluster simulation scripts
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
### Important Large-Scale Simulations
|
|
61
|
+
|
|
62
|
+
#### No adaptivity
|
|
63
|
+
sbatch --array=[0-999] -t 0-5:00 --mem=50G run_and_analysis_parallel_synthetic --T=10 --n=50000 --recruit_n=50000 --steepness=0.0 --alg_state_feats=intercept,past_reward --action_centering_RL=0 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_no_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_no_action_centering.py"
|
|
64
|
+
|
|
65
|
+
#### No adaptivity, 5 batches incremental recruitment
|
|
66
|
+
sbatch --array=[0-999] -t 0-5:00 --mem=50G run_and_analysis_parallel_synthetic --T=10 --n=50000 --recruit_n=10000 --steepness=0.0 --alg_state_feats=intercept,past_reward --action_centering_RL=0 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_no_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_no_action_centering.py"
|
|
67
|
+
|
|
68
|
+
#### Some adaptivity, no action_centering
|
|
69
|
+
sbatch --array=[0-999] -t 0-5:00 --mem=50G run_and_analysis_parallel_synthetic --T=10 --n=50000 --recruit_n=50000 --steepness=3.0 --alg_state_feats=intercept,past_reward --action_centering_RL=0 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_no_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_no_action_centering.py"
|
|
70
|
+
|
|
71
|
+
#### Some adaptivity, no action_centering, 5 batches incremental recruitment
|
|
72
|
+
sbatch --array=[0-999] -t 0-5:00 --mem=50G run_and_analysis_parallel_synthetic --T=10 --n=50000 --recruit_n=10000 --steepness=3.0 --alg_state_feats=intercept,past_reward --action_centering_RL=0 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_no_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_no_action_centering.py"
|
|
73
|
+
|
|
74
|
+
#### More adaptivity, no action_centering
|
|
75
|
+
sbatch --array=[0-999] -t 0-5:00 --mem=50G run_and_analysis_parallel_synthetic --T=10 --n=50000 --recruit_n=50000 --steepness=5.0 --alg_state_feats=intercept,past_reward --action_centering_RL=0 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_no_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_no_action_centering.py"
|
|
76
|
+
|
|
77
|
+
#### Even more adaptivity, no action_centering
|
|
78
|
+
sbatch --array=[0-999] -t 0-5:00 --mem=50G run_and_analysis_parallel_synthetic --T=10 --n=50000 --recruit_n=50000 --steepness=10.0 --alg_state_feats=intercept,past_reward --action_centering_RL=0 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_no_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_no_action_centering.py"
|
|
79
|
+
|
|
80
|
+
#### Some adaptivity, RL action_centering, no inference action centering
|
|
81
|
+
sbatch --array=[0-999] -t 0-5:00 --mem=50G run_and_analysis_parallel_synthetic --T=10 --n=50000 --recruit_n=50000 --steepness=3.0 --alg_state_feats=intercept,past_reward --action_centering_RL=1 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_no_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_no_action_centering.py"
|
|
82
|
+
|
|
83
|
+
#### Some adaptivity, inference action_centering, no RL action centering
|
|
84
|
+
sbatch --array=[0-999] -t 0-5:00 --mem=50G run_and_analysis_parallel_synthetic --T=10 --n=50000 --recruit_n=50000 --steepness=3.0 --alg_state_feats=intercept,past_reward --action_centering_RL=0 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_action_centering.py"
|
|
85
|
+
|
|
86
|
+
#### Some adaptivity, inference and RL action_centering
|
|
87
|
+
sbatch --array=[0-999] -t 0-5:00 --mem=50G run_and_analysis_parallel_synthetic --T=10 --n=50000 --recruit_n=50000 --steepness=3.0 --alg_state_feats=intercept,past_reward --action_centering_RL=1 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_action_centering.py"
|
|
88
|
+
|
|
89
|
+
#### Some adaptivity, inference and RL action_centering, even more T
|
|
90
|
+
sbatch --array=[0-999] -t 1-00:00 --mem=50G run_and_analysis_parallel_synthetic --T=25 --n=50000 --recruit_n=50000 --steepness=3.0 --alg_state_feats=intercept,past_reward --action_centering_RL=1 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_no_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_no_action_centering.py"
|
|
91
|
+
|
|
92
|
+
#### Some adaptivity, inference and RL action_centering, even more T, 5 batches incremental recruitment
|
|
93
|
+
sbatch --array=[0-999] -t 1-00:00 --mem=50G run_and_analysis_parallel_synthetic --T=25 --n=50000 --recruit_n=10000 --steepness=3.0 --alg_state_feats=intercept,past_reward --action_centering_RL=1 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_no_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_no_action_centering.py"
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
## TODO
|
|
98
|
+
1. Add precommit hooks (pip freeze, linting, formatting)
|
|
99
|
+
2. Run tests on PRs
|
|
100
|
+
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
lifejacket/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
lifejacket/after_study_analysis.py,sha256=HtAHRYCni8Lh7HCB9A0SsRkyAol2FbczrkxFn1ww1HA,81170
|
|
3
|
+
lifejacket/arg_threading_helpers.py,sha256=kxtGlN_B1C1WKcWtJMlgWJwzzI1THytflmWsL5ZML7k,16228
|
|
4
|
+
lifejacket/calculate_derivatives.py,sha256=3rYukD1wbjDof7d4_3QdQ-A4GSK9H8z8HJsbNQh0DzA,37472
|
|
5
|
+
lifejacket/constants.py,sha256=2L05p6NJ7l3qRZ1hD2KlrvzWF1ReSmWRUkULPIhdvJo,842
|
|
6
|
+
lifejacket/form_adaptive_meat_adjustments_directly.py,sha256=_BaziGfYjEySN78nU3lCrVtf2KWIuZ8PmzfMZypAaWI,13728
|
|
7
|
+
lifejacket/get_datum_for_blowup_supervised_learning.py,sha256=V8H4PE49dQwsKjj93QEu2BKbhwPr56QMtx2jhan39-c,58357
|
|
8
|
+
lifejacket/helper_functions.py,sha256=xOhRG-Cm4ZdRNm-O0faHna583d74pLWY5_jfnokegWc,23295
|
|
9
|
+
lifejacket/input_checks.py,sha256=A0f2owqRUjeBAh5jLULKu1nXW1SgZDR5eK7xBm1ahZw,44878
|
|
10
|
+
lifejacket/small_sample_corrections.py,sha256=f8jmg9U9ZN77WadJud70tt6NMxCTsSGtlsdF_-mfws4,5543
|
|
11
|
+
lifejacket/trial_conditioning_monitor.py,sha256=qNTHh0zY2P7zJxox_OwhEEK8Ed1l0TPOjGDqNxMNoIQ,42164
|
|
12
|
+
lifejacket/vmap_helpers.py,sha256=pZqYN3p9Ty9DPOeeY9TKbRJXR2AV__HBwwDFOvdOQ84,2688
|
|
13
|
+
lifejacket-0.1.0.dist-info/METADATA,sha256=VT6H9TNcYleRhp-wLda4HXrbI7EJYj6ZV0_7K5fraI4,7274
|
|
14
|
+
lifejacket-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
15
|
+
lifejacket-0.1.0.dist-info/entry_points.txt,sha256=4k8ibVIUT-OHxPaaDv-QwWpC64ErzhdemHpTAXCnb8w,67
|
|
16
|
+
lifejacket-0.1.0.dist-info/top_level.txt,sha256=vKl8m7jOQ4pkbzVuHCJsq-8LcXRrOAWnok3bBo9qpsE,11
|
|
17
|
+
lifejacket-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
lifejacket
|