lifejacket-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,71 @@
+ import collections
+
+ from jax import numpy as jnp
+ import numpy as np
+
+
+ # TODO: Check for exactly the required types earlier
+ # TODO: Try/except and a nicer error message
+ # TODO: This is complicated enough to deserve its own unit tests
+ def stack_batched_arg_lists_into_tensors(batched_arg_lists):
+     """
+     Stack a plain Python list of lists of function arguments into a list of jnp arrays that can
+     be supplied to vmap as batch arguments. vmap requires all elements of such a batched array
+     to have the same shape, as do the stacking functions used here, so this must be called on
+     batches with the same data shape. We also return the axes one must iterate over to get
+     each user's args in a batch.
+     """
+
+     batched_arg_tensors = []
+
+     # This ends up being all zeros because of the way we are (now) doing the
+     # stacking, but better not to assume that externally and to return what
+     # we've done with this list.
+     batch_axes = []
+
+     for batched_arg_list in batched_arg_lists:
+         if (
+             isinstance(
+                 batched_arg_list[0],
+                 (jnp.ndarray, np.ndarray),
+             )
+             and batched_arg_list[0].ndim > 2
+         ):
+             raise TypeError("Arrays with dimension greater than 2 are not supported.")
+         if (
+             isinstance(
+                 batched_arg_list[0],
+                 (jnp.ndarray, np.ndarray),
+             )
+             and batched_arg_list[0].ndim == 2
+         ):
+             ########## We have a matrix (2D array) arg
+
+             batched_arg_tensors.append(jnp.stack(batched_arg_list, 0))
+             batch_axes.append(0)
+         elif isinstance(
+             batched_arg_list[0],
+             (collections.abc.Sequence, jnp.ndarray, np.ndarray),
+         ) and not isinstance(batched_arg_list[0], str):
+             ########## We have a vector (1D array) arg
+             if not isinstance(batched_arg_list[0], (jnp.ndarray, np.ndarray)):
+                 try:
+                     batched_arg_list = [jnp.array(x) for x in batched_arg_list]
+                 except Exception as e:
+                     raise TypeError(
+                         "Argument of sequence type that cannot be cast to a JAX numpy array"
+                     ) from e
+             assert batched_arg_list[0].ndim == 1
+
+             batched_arg_tensors.append(jnp.vstack(batched_arg_list))
+             batch_axes.append(0)
+         else:
+             ########## Otherwise we should have a list of scalars.
+             # Just turn it into a jnp array.
+             batched_arg_tensors.append(jnp.array(batched_arg_list))
+             batch_axes.append(0)
+
+     return (
+         batched_arg_tensors,
+         batch_axes,
+     )
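
For readers unfamiliar with how these outputs plug into `jax.vmap`, here is a minimal usage sketch. It assumes the helper above is importable as `lifejacket.vmap_helpers.stack_batched_arg_lists_into_tensors` (the RECORD below lists that module, though the diff does not label this file); the per-user function and the data are purely illustrative, not part of the package.

```python
import jax
from jax import numpy as jnp

# Assumed import path -- the diff does not name this file explicitly.
from lifejacket.vmap_helpers import stack_batched_arg_lists_into_tensors


# Hypothetical per-user function taking a scalar, a 1D vector, and a 2D matrix.
def per_user_stat(weight, features, design):
    return weight * jnp.dot(design @ features, features)


# One inner list per argument; each inner list holds one entry per user.
batched_arg_lists = [
    [1.0, 2.0],                          # scalar arg for two users
    [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],  # 1D sequence arg
    [jnp.eye(3), 2.0 * jnp.eye(3)],      # 2D array arg
]

tensors, axes = stack_batched_arg_lists_into_tensors(batched_arg_lists)

# The returned batch axes (all zeros here) become vmap's in_axes.
per_user_values = jax.vmap(per_user_stat, in_axes=tuple(axes))(*tensors)
print(per_user_values.shape)  # (2,): one value per user in the batch
```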
@@ -0,0 +1,100 @@
+ Metadata-Version: 2.4
+ Name: lifejacket
+ Version: 0.1.0
+ Summary: A package for after-study analysis of adaptive experiments in which data is pooled across users.
+ Author-email: Nowell Closser <nowellclosser@gmail.com>
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ Requires-Dist: click>=8.0
+ Requires-Dist: jax>=0.4.0
+ Requires-Dist: jaxlib>=0.4.0
+ Requires-Dist: numpy>=1.20.0
+ Requires-Dist: pandas>=1.3.0
+ Requires-Dist: scipy>=1.7.0
+ Requires-Dist: plotext>=5.0.0
+ Provides-Extra: dev
+ Requires-Dist: pytest>=7.0; extra == "dev"
+ Requires-Dist: black>=22.0; extra == "dev"
+ Requires-Dist: flake8>=4.0; extra == "dev"
+
+ ```
+  _ _ __ _ _ _
+ | (_)/ _| (_) | | | |
+ | |_| |_ ___ _ __ _ ___| | _____| |_
+ | | | _/ _ \ |/ _` |/ __| |/ / _ \ __|
+ | | | || __/ | (_| | (__| < __/ |_
+ |_|_|_| \___| |\__,_|\___|_|\_\___|\__|
+  _/ |
+ |__/
+ ```
+
+ Save your standard errors from pooling in adaptive experiments.
+
+ ## Setup (if not using conda)
+ ### Create and activate a virtual environment
+ - `python3 -m venv .venv; source .venv/bin/activate`
+
+ ### Adding a package
+ - Add it to `requirements.txt`, pinned to a specific version or unpinned if you want the latest stable release
+ - Run `pip freeze > requirements.txt` to lock the versions of your package and all of its subpackages
+
+ ## Running the code
+ - Export `PYTHONPATH` as the absolute path of this repository on your computer
+ - Run `./run_local_synthetic.sh`, which outputs to `simulated_data/` by default. See the script code for all the flags that can be toggled.
+
+ ## Linting/Formatting
+
+ ## Testing
+ - `python -m pytest`
+ - `python -m pytest tests/unit_tests`
+ - `python -m pytest tests/integration_tests`
+
+ TODO: Talk about gitignored cluster simulation scripts
+
+ ### Important Large-Scale Simulations
+
+ #### No adaptivity
+ `sbatch --array=[0-999] -t 0-5:00 --mem=50G run_and_analysis_parallel_synthetic --T=10 --n=50000 --recruit_n=50000 --steepness=0.0 --alg_state_feats=intercept,past_reward --action_centering_RL=0 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_no_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_no_action_centering.py"`
+
+ #### No adaptivity, 5 batches incremental recruitment
+ `sbatch --array=[0-999] -t 0-5:00 --mem=50G run_and_analysis_parallel_synthetic --T=10 --n=50000 --recruit_n=10000 --steepness=0.0 --alg_state_feats=intercept,past_reward --action_centering_RL=0 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_no_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_no_action_centering.py"`
+
+ #### Some adaptivity, no action_centering
+ `sbatch --array=[0-999] -t 0-5:00 --mem=50G run_and_analysis_parallel_synthetic --T=10 --n=50000 --recruit_n=50000 --steepness=3.0 --alg_state_feats=intercept,past_reward --action_centering_RL=0 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_no_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_no_action_centering.py"`
+
+ #### Some adaptivity, no action_centering, 5 batches incremental recruitment
+ `sbatch --array=[0-999] -t 0-5:00 --mem=50G run_and_analysis_parallel_synthetic --T=10 --n=50000 --recruit_n=10000 --steepness=3.0 --alg_state_feats=intercept,past_reward --action_centering_RL=0 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_no_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_no_action_centering.py"`
+
+ #### More adaptivity, no action_centering
+ `sbatch --array=[0-999] -t 0-5:00 --mem=50G run_and_analysis_parallel_synthetic --T=10 --n=50000 --recruit_n=50000 --steepness=5.0 --alg_state_feats=intercept,past_reward --action_centering_RL=0 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_no_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_no_action_centering.py"`
+
+ #### Even more adaptivity, no action_centering
+ `sbatch --array=[0-999] -t 0-5:00 --mem=50G run_and_analysis_parallel_synthetic --T=10 --n=50000 --recruit_n=50000 --steepness=10.0 --alg_state_feats=intercept,past_reward --action_centering_RL=0 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_no_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_no_action_centering.py"`
+
+ #### Some adaptivity, RL action_centering, no inference action centering
+ `sbatch --array=[0-999] -t 0-5:00 --mem=50G run_and_analysis_parallel_synthetic --T=10 --n=50000 --recruit_n=50000 --steepness=3.0 --alg_state_feats=intercept,past_reward --action_centering_RL=1 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_no_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_no_action_centering.py"`
+
+ #### Some adaptivity, inference action_centering, no RL action centering
+ `sbatch --array=[0-999] -t 0-5:00 --mem=50G run_and_analysis_parallel_synthetic --T=10 --n=50000 --recruit_n=50000 --steepness=3.0 --alg_state_feats=intercept,past_reward --action_centering_RL=0 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_action_centering.py"`
+
+ #### Some adaptivity, inference and RL action_centering
+ `sbatch --array=[0-999] -t 0-5:00 --mem=50G run_and_analysis_parallel_synthetic --T=10 --n=50000 --recruit_n=50000 --steepness=3.0 --alg_state_feats=intercept,past_reward --action_centering_RL=1 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_action_centering.py"`
+
+ #### Some adaptivity, inference and RL action_centering, even more T
+ `sbatch --array=[0-999] -t 1-00:00 --mem=50G run_and_analysis_parallel_synthetic --T=25 --n=50000 --recruit_n=50000 --steepness=3.0 --alg_state_feats=intercept,past_reward --action_centering_RL=1 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_no_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_no_action_centering.py"`
+
+ #### Some adaptivity, inference and RL action_centering, even more T, 5 batches incremental recruitment
+ `sbatch --array=[0-999] -t 1-00:00 --mem=50G run_and_analysis_parallel_synthetic --T=25 --n=50000 --recruit_n=10000 --steepness=3.0 --alg_state_feats=intercept,past_reward --action_centering_RL=1 --inference_loss_func_filename="functions_to_pass_to_analysis/get_least_squares_loss_inference_no_action_centering.py" --theta_calculation_func_filename="functions_to_pass_to_analysis/estimate_theta_least_squares_no_action_centering.py"`
+
+ ## TODO
+ 1. Add pre-commit hooks (pip freeze, linting, formatting)
+ 2. Run tests on PRs
+
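
The Testing section above and the TODO comment in the stacking helper both call for unit tests of `stack_batched_arg_lists_into_tensors`. The sketch below is a minimal pytest example, not part of the released wheel; the file path and the import location (`lifejacket.vmap_helpers`, suggested by the RECORD listed below but not confirmed by this diff) are assumptions.

```python
# Hypothetical tests/unit_tests/test_vmap_helpers.py -- illustrative only.
import numpy as np
import pytest
from jax import numpy as jnp

from lifejacket.vmap_helpers import stack_batched_arg_lists_into_tensors


def test_scalars_vectors_and_matrices_are_stacked_along_axis_zero():
    batched_arg_lists = [
        [1.0, 2.0],                    # scalars
        [[1.0, 2.0], [3.0, 4.0]],      # 1D sequences, cast to jnp arrays
        [np.eye(2), np.ones((2, 2))],  # 2D arrays
    ]
    tensors, axes = stack_batched_arg_lists_into_tensors(batched_arg_lists)
    assert [t.shape for t in tensors] == [(2,), (2, 2), (2, 2, 2)]
    assert axes == [0, 0, 0]


def test_arrays_with_more_than_two_dimensions_are_rejected():
    with pytest.raises(TypeError):
        stack_batched_arg_lists_into_tensors([[jnp.ones((2, 2, 2))]])
```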
@@ -0,0 +1,17 @@
+ lifejacket/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ lifejacket/after_study_analysis.py,sha256=HtAHRYCni8Lh7HCB9A0SsRkyAol2FbczrkxFn1ww1HA,81170
+ lifejacket/arg_threading_helpers.py,sha256=kxtGlN_B1C1WKcWtJMlgWJwzzI1THytflmWsL5ZML7k,16228
+ lifejacket/calculate_derivatives.py,sha256=3rYukD1wbjDof7d4_3QdQ-A4GSK9H8z8HJsbNQh0DzA,37472
+ lifejacket/constants.py,sha256=2L05p6NJ7l3qRZ1hD2KlrvzWF1ReSmWRUkULPIhdvJo,842
+ lifejacket/form_adaptive_meat_adjustments_directly.py,sha256=_BaziGfYjEySN78nU3lCrVtf2KWIuZ8PmzfMZypAaWI,13728
+ lifejacket/get_datum_for_blowup_supervised_learning.py,sha256=V8H4PE49dQwsKjj93QEu2BKbhwPr56QMtx2jhan39-c,58357
+ lifejacket/helper_functions.py,sha256=xOhRG-Cm4ZdRNm-O0faHna583d74pLWY5_jfnokegWc,23295
+ lifejacket/input_checks.py,sha256=A0f2owqRUjeBAh5jLULKu1nXW1SgZDR5eK7xBm1ahZw,44878
+ lifejacket/small_sample_corrections.py,sha256=f8jmg9U9ZN77WadJud70tt6NMxCTsSGtlsdF_-mfws4,5543
+ lifejacket/trial_conditioning_monitor.py,sha256=qNTHh0zY2P7zJxox_OwhEEK8Ed1l0TPOjGDqNxMNoIQ,42164
+ lifejacket/vmap_helpers.py,sha256=pZqYN3p9Ty9DPOeeY9TKbRJXR2AV__HBwwDFOvdOQ84,2688
+ lifejacket-0.1.0.dist-info/METADATA,sha256=VT6H9TNcYleRhp-wLda4HXrbI7EJYj6ZV0_7K5fraI4,7274
+ lifejacket-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ lifejacket-0.1.0.dist-info/entry_points.txt,sha256=4k8ibVIUT-OHxPaaDv-QwWpC64ErzhdemHpTAXCnb8w,67
+ lifejacket-0.1.0.dist-info/top_level.txt,sha256=vKl8m7jOQ4pkbzVuHCJsq-8LcXRrOAWnok3bBo9qpsE,11
+ lifejacket-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
+ Wheel-Version: 1.0
+ Generator: setuptools (80.9.0)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
+
@@ -0,0 +1,2 @@
+ [console_scripts]
+ lifejacket = lifejacket.after_study_analysis:cli
@@ -0,0 +1 @@
+ lifejacket