gensbi-examples 0.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,217 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[codz]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py.cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+ #poetry.toml
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
114
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
115
+ #pdm.lock
116
+ #pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # pixi
121
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
122
+ #pixi.lock
123
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
124
+ # in the .venv directory. It is recommended not to include this directory in version control.
125
+ .pixi
126
+
127
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128
+ __pypackages__/
129
+
130
+ # Celery stuff
131
+ celerybeat-schedule
132
+ celerybeat.pid
133
+
134
+ # SageMath parsed files
135
+ *.sage.py
136
+
137
+ # Environments
138
+ .env
139
+ .envrc
140
+ .venv
141
+ env/
142
+ venv/
143
+ ENV/
144
+ env.bak/
145
+ venv.bak/
146
+
147
+ # Spyder project settings
148
+ .spyderproject
149
+ .spyproject
150
+
151
+ # Rope project settings
152
+ .ropeproject
153
+
154
+ # mkdocs documentation
155
+ /site
156
+
157
+ # mypy
158
+ .mypy_cache/
159
+ .dmypy.json
160
+ dmypy.json
161
+
162
+ # Pyre type checker
163
+ .pyre/
164
+
165
+ # pytype static type analyzer
166
+ .pytype/
167
+
168
+ # Cython debug symbols
169
+ cython_debug/
170
+
171
+ # PyCharm
172
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
173
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
174
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
175
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
176
+ #.idea/
177
+
178
+ # Abstra
179
+ # Abstra is an AI-powered process automation framework.
180
+ # Ignore directories containing user credentials, local state, and settings.
181
+ # Learn more at https://abstra.io/docs
182
+ .abstra/
183
+
184
+ # Visual Studio Code
185
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
186
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
187
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
188
+ # you could uncomment the following to ignore the entire vscode folder
189
+ .vscode/
190
+
191
+ # Ruff stuff:
192
+ .ruff_cache/
193
+
194
+ # PyPI configuration file
195
+ .pypirc
196
+
197
+ # Cursor
198
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
199
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
200
+ # refer to https://docs.cursor.com/context/ignore-files
201
+ .cursorignore
202
+ .cursorindexingignore
203
+
204
+ # Marimo
205
+ marimo/_static/
206
+ marimo/_lsp/
207
+ __marimo__/
208
+
209
+ # task data
210
+ task_data/*
211
+ *.npz
212
+
213
+ # HTCondor job submission logs
214
+ condor_logs/
215
+
216
+ # Temporary files written by tests
217
+ tests/tmp/*
@@ -0,0 +1,13 @@
1
+ Copyright 2025 Amerio Aurelio
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
@@ -0,0 +1,72 @@
1
+ Metadata-Version: 2.4
2
+ Name: gensbi-examples
3
+ Version: 0.0.2
4
+ Summary: Examples for the GenSBI library
5
+ Project-URL: Homepage, https://github.com/aurelio-amerio/GenSBI-examples
6
+ Project-URL: Issues, https://github.com/aurelio-amerio/GenSBI-examples/issues
7
+ Author-email: Aurelio Amerio <aure.amerio@gmail.com>
8
+ License: Copyright 2025 Amerio Aurelio
9
+
10
+ Licensed under the Apache License, Version 2.0 (the "License");
11
+ you may not use this file except in compliance with the License.
12
+ You may obtain a copy of the License at
13
+
14
+ http://www.apache.org/licenses/LICENSE-2.0
15
+
16
+ Unless required by applicable law or agreed to in writing, software
17
+ distributed under the License is distributed on an "AS IS" BASIS,
18
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ See the License for the specific language governing permissions and
20
+ limitations under the License.
21
+ License-File: LICENSE
22
+ Requires-Python: >=3.11
23
+ Requires-Dist: datasets
24
+ Requires-Dist: flax>=0.12.0
25
+ Requires-Dist: grain>=0.2.12
26
+ Requires-Dist: huggingface-hub
27
+ Requires-Dist: jax<=0.8.1,>=0.7.2
28
+ Requires-Dist: matplotlib>=3.10
29
+ Requires-Dist: numpy>=2.0
30
+ Requires-Dist: scikit-learn>=1.7.0
31
+ Description-Content-Type: text/markdown
32
+
33
+ # GenSBI Examples
34
+
35
+ This repository contains a collection of examples, tutorials, and recipes for **GenSBI**, a JAX-based library for Simulation-Based Inference using generative models.
36
+
37
+ These examples demonstrate how to use GenSBI for various tasks, including:
38
+
39
+ - Defining and running inference pipelines.
40
+ - Using different embedding networks (MLP, ResNet, etc.).
41
+ - Handling various data types (1D signals, 2D images).
42
+
43
+ ## Installation
44
+
45
+ ### Prerequisites
46
+
47
+ You need to have **GenSBI** installed.
48
+
49
+ **With CUDA 12 support (Recommended):**
50
+
51
+ ```bash
52
+ pip install "gensbi[cuda12]"
53
+ ```
54
+
55
+ **CPU-only:**
56
+
57
+ ```bash
58
+ pip install gensbi
59
+ ```
60
+
61
+ ### Install Examples Package
62
+
63
+ To run the examples and ensure all dependencies are met, install this package:
64
+
65
+ ```bash
66
+ pip install gensbi-examples
67
+ ```
68
+
69
+ ## Structure
70
+
71
+ - `examples/`: Contains standalone example scripts and notebooks.
72
+ - `src/gensbi_examples`: Helper utilities for the examples.
@@ -0,0 +1,40 @@
1
+ # GenSBI Examples
2
+
3
+ This repository contains a collection of examples, tutorials, and recipes for **GenSBI**, a JAX-based library for Simulation-Based Inference using generative models.
4
+
5
+ These examples demonstrate how to use GenSBI for various tasks, including:
6
+
7
+ - Defining and running inference pipelines.
8
+ - Using different embedding networks (MLP, ResNet, etc.).
9
+ - Handling various data types (1D signals, 2D images).
10
+
11
+ ## Installation
12
+
13
+ ### Prerequisites
14
+
15
+ You need to have **GenSBI** installed.
16
+
17
+ **With CUDA 12 support (Recommended):**
18
+
19
+ ```bash
20
+ pip install "gensbi[cuda12]"
21
+ ```
22
+
23
+ **CPU-only:**
24
+
25
+ ```bash
26
+ pip install gensbi
27
+ ```
28
+
29
+ ### Install Examples Package
30
+
31
+ To run the examples and ensure all dependencies are met, install this package:
32
+
33
+ ```bash
34
+ pip install gensbi-examples
35
+ ```
36
+
37
+ ## Structure
38
+
39
+ - `examples/`: Contains standalone example scripts and notebooks.
40
+ - `src/gensbi_examples`: Helper utilities for the examples.
File without changes
@@ -0,0 +1,111 @@
1
+ from typing import Optional
2
+
3
+ import jax
4
+ from jax import numpy as jnp
5
+ from jax import Array
6
+ import numpy as np
7
+ from sklearn.model_selection import KFold, cross_val_score
8
+ from sklearn.neural_network import MLPClassifier
9
+
10
+
11
def c2st(
    X: Array,
    Y: Array,
    seed: int = 1,
    n_folds: int = 5,
    scoring: str = "accuracy",
    z_score: bool = True,
    noise_scale: Optional[float] = None,
) -> Array:
    """Classifier-based 2-sample test (C2ST) using a scikit-learn MLP.

    Trains an ``MLPClassifier`` to distinguish samples of ``X`` (label 0) from
    samples of ``Y`` (label 1) and reports its mean score over ``n_folds``
    shuffled cross-validation folds [1]. The classifier has two hidden layers
    of ``10 * dim`` ReLU units each, where ``dim`` is the dimensionality of
    the samples.

    Args:
        X: Sample 1, shape ``(n_x, dim)``.
        Y: Sample 2, shape ``(n_y, dim)``.
        seed: Seed for the added noise, the fold shuffling and the classifier.
        n_folds: Number of cross-validation folds.
        scoring: scikit-learn scoring name passed to ``cross_val_score``
            (e.g. ``"accuracy"`` or ``"roc_auc"``).
        z_score: If True, standardize both samples with the mean and standard
            deviation of X.
        noise_scale: If given, add independent Gaussian noise with standard
            deviation ``noise_scale`` to both samples.

    Returns:
        The mean cross-validated score as a 0-d ``np.float32`` array.

    References:
        [1]: https://scikit-learn.org/stable/modules/cross_validation.html
    """
    if z_score:
        X_mean = jnp.mean(X, axis=0)
        X_std = jnp.std(X, axis=0)
        # A constant dimension in X has std 0; dividing by it would produce
        # NaN/inf, which MLPClassifier rejects. Leave such dimensions unscaled.
        X_std = jnp.where(X_std > 0, X_std, 1.0)
        X = (X - X_mean) / X_std
        Y = (Y - X_mean) / X_std

    if noise_scale is not None:
        # Separate keys so X and Y get independent noise realizations. Use
        # out-of-place addition so a caller's numpy array is never mutated.
        key_x, key_y = jax.random.split(jax.random.PRNGKey(seed))
        X = X + noise_scale * jax.random.normal(key_x, jnp.shape(X))
        Y = Y + noise_scale * jax.random.normal(key_y, jnp.shape(Y))

    # scikit-learn works on host numpy arrays.
    X = np.asarray(X)
    Y = np.asarray(Y)

    ndim = X.shape[1]

    clf = MLPClassifier(
        activation="relu",
        hidden_layer_sizes=(10 * ndim, 10 * ndim),
        max_iter=10000,
        solver="adam",
        random_state=seed,
    )

    data = np.concatenate((X, Y))
    target = np.concatenate(
        (
            np.zeros((X.shape[0],)),
            np.ones((Y.shape[0],)),
        )
    )

    shuffle = KFold(n_splits=n_folds, shuffle=True, random_state=seed)
    scores = cross_val_score(clf, data, target, cv=shuffle, scoring=scoring)

    return np.asarray(np.mean(scores), dtype=np.float32)
78
+
79
+
80
def c2st_auc(
    X: Array,
    Y: Array,
    seed: int = 1,
    n_folds: int = 5,
    z_score: bool = True,
    noise_scale: Optional[float] = None,
) -> Array:
    """C2ST variant scored by ROC AUC instead of accuracy.

    Delegates to ``c2st`` with ``scoring="roc_auc"``; every other argument is
    forwarded unchanged.

    Args:
        X: Sample 1.
        Y: Sample 2.
        seed: Seed forwarded to ``c2st``.
        n_folds: Number of cross-validation folds.
        z_score: Whether to standardize both samples with X's statistics.
        noise_scale: If given, std of the Gaussian noise added to the samples.

    Returns:
        The mean cross-validated ROC AUC, as returned by ``c2st``.
    """
    forwarded = dict(
        seed=seed,
        n_folds=n_folds,
        z_score=z_score,
        noise_scale=noise_scale,
    )
    return c2st(X, Y, scoring="roc_auc", **forwarded)
@@ -0,0 +1,147 @@
1
+ from typing import Optional
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ from jax import Array
6
+ import numpy as np
7
+ from sklearn.model_selection import KFold
8
+ from flax import nnx
9
+ import optax
10
+
11
+ # Define MLP using flax.nnx
12
class MLP(nnx.Module):
    """ReLU MLP with two hidden layers and 2 output logits.

    Args:
        in_dim: Number of input features.
        hidden_dim: Width of both hidden layers.
        rngs: ``nnx.Rngs`` used to initialise the linear layers.
    """

    def __init__(self, in_dim, hidden_dim, *, rngs):
        # nnx.Sequential takes its layers as positional arguments (not a list),
        # and the ReLU activation is the plain function ``nnx.relu``.
        self.seq = nnx.Sequential(
            nnx.Linear(in_dim, hidden_dim, rngs=rngs),
            nnx.relu,
            nnx.Linear(hidden_dim, hidden_dim, rngs=rngs),
            nnx.relu,
            nnx.Linear(hidden_dim, 2, rngs=rngs),
        )

    def __call__(self, x):
        """Return the 2-class logits for the input batch ``x``."""
        return self.seq(x)
23
+
24
+
25
def loss_fn(state, x, y):
    """Mean softmax cross-entropy of a classifier on integer labels.

    Args:
        state: Callable model (e.g. an ``MLP`` instance) mapping ``x`` to
            logits whose last axis indexes the classes.
        x: Input batch.
        y: Integer class labels, one per sample.

    Returns:
        Scalar mean cross-entropy loss.
    """
    # nnx modules are invoked directly; they have no ``.value`` attribute.
    logits = state(x)
    labels = jax.nn.one_hot(y, logits.shape[-1])
    # Same formula as optax.softmax_cross_entropy: -sum(labels * log_softmax).
    per_sample = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)
    return per_sample.mean()
30
+
31
def accuracy_fn(state, x, y):
    """Fraction of samples whose arg-max logit equals the integer label.

    Args:
        state: Callable model (e.g. an ``MLP`` instance) mapping ``x`` to
            logits whose last axis indexes the classes.
        x: Input batch.
        y: Integer class labels, one per sample.

    Returns:
        Scalar accuracy in ``[0, 1]``.
    """
    # nnx modules are invoked directly; they have no ``.value`` attribute.
    logits = state(x)
    preds = jnp.argmax(logits, axis=-1)
    return (preds == y).mean()
35
+
36
def c2st(
    X: Array,
    Y: Array,
    seed: int = 1,
    n_folds: int = 5,
    z_score: bool = True,
    noise_scale: Optional[float] = None,
    n_epochs: int = 100,
    batch_size: int = 128,
    learning_rate: float = 1e-3,
) -> Array:
    """Classifier-based 2-sample test returning accuracy (flax.nnx classifier).

    Trains an ``MLP`` (two hidden layers of ``10 * dim`` units) to separate
    samples of ``X`` (label 0) from samples of ``Y`` (label 1) with
    ``n_folds``-fold shuffled cross-validation [1]. Returns the mean
    held-out accuracy. Training runs through JAX, so it uses JAX's default
    device (e.g. a GPU when one is available).

    Args:
        X: Sample 1, shape ``(n_x, dim)``.
        Y: Sample 2, shape ``(n_y, dim)``.
        seed: Seed for the noise, fold split, weight init and minibatch order.
        n_folds: Number of cross-validation folds.
        z_score: If True, standardize both samples with the mean and standard
            deviation of X.
        noise_scale: If given, add independent Gaussian noise with standard
            deviation ``noise_scale`` to both samples.
        n_epochs: Training epochs per fold.
        batch_size: Minibatch size (capped at the training-set size).
        learning_rate: Adam learning rate.

    Returns:
        Mean held-out accuracy over the folds as a 0-d ``np.float32`` array.

    References:
        [1]: https://scikit-learn.org/stable/modules/cross_validation.html
    """
    if z_score:
        X_mean = jnp.mean(X, axis=0)
        X_std = jnp.std(X, axis=0)
        # Constant dimensions of X would otherwise divide by zero.
        X_std = jnp.where(X_std > 0, X_std, 1.0)
        X = (X - X_mean) / X_std
        Y = (Y - X_mean) / X_std

    if noise_scale is not None:
        # Separate keys so X and Y receive independent noise.
        key_x, key_y = jax.random.split(jax.random.PRNGKey(seed))
        X = X + noise_scale * jax.random.normal(key_x, jnp.shape(X))
        Y = Y + noise_scale * jax.random.normal(key_y, jnp.shape(Y))

    X = jnp.asarray(X)
    Y = jnp.asarray(Y)
    ndim = X.shape[1]

    data = jnp.concatenate([X, Y], axis=0)
    target = jnp.concatenate(
        [
            jnp.zeros((X.shape[0],), dtype=jnp.int32),
            jnp.ones((Y.shape[0],), dtype=jnp.int32),
        ],
        axis=0,
    )

    def _batch_loss(model, x, y):
        # Mean softmax cross-entropy; the nnx model is called directly.
        logits = model(x)
        return optax.softmax_cross_entropy(logits, jax.nn.one_hot(y, 2)).mean()

    # Defined once (not per fold). nnx.jit (unlike jax.jit) accepts nnx
    # objects and writes the parameter/optimizer-state updates back to them.
    @nnx.jit
    def _train_step(model, optimizer, x, y):
        loss, grads = nnx.value_and_grad(_batch_loss)(model, x, y)
        optimizer.update(model, grads)
        return loss

    kf = KFold(n_splits=n_folds, shuffle=True, random_state=seed)
    base_key = jax.random.PRNGKey(seed)
    scores = []

    # Exactly one model is trained and scored per fold.
    for fold, (train_idx, test_idx) in enumerate(kf.split(np.asarray(data))):
        x_train, y_train = data[train_idx], target[train_idx]
        x_test, y_test = data[test_idx], target[test_idx]

        model = MLP(ndim, 10 * ndim, rngs=nnx.Rngs(seed + fold))
        # flax>=0.11 API: parameters to optimize are selected with ``wrt`` and
        # ``update`` receives the model explicitly.
        optimizer = nnx.Optimizer(model, optax.adam(learning_rate), wrt=nnx.Param)
        fold_key = jax.random.fold_in(base_key, fold)

        n_train = x_train.shape[0]
        step = min(batch_size, n_train)
        for epoch in range(n_epochs):
            # Fresh permutation every epoch (the key is folded with the epoch).
            perm = jax.random.permutation(jax.random.fold_in(fold_key, epoch), n_train)
            x_shuffled = x_train[perm]
            y_shuffled = y_train[perm]
            for start in range(0, n_train, step):
                _train_step(
                    model,
                    optimizer,
                    x_shuffled[start:start + step],
                    y_shuffled[start:start + step],
                )

        preds = jnp.argmax(model(x_test), axis=-1)
        scores.append(float(jnp.mean(preds == y_test)))

    return np.asarray(np.mean(scores), dtype=np.float32)
125
+
126
+
127
+ # def c2st_auc(
128
+ # X: Array,
129
+ # Y: Array,
130
+ # seed: int = 1,
131
+ # n_folds: int = 5,
132
+ # z_score: bool = True,
133
+ # noise_scale: Optional[float] = None,
134
+ # ) -> Array:
135
+ # """Classifier-based 2-sample test returning AUC (area under curve)
136
+
137
+ # Same as c2st, except that it returns ROC AUC rather than accuracy
138
+ # """
139
+ # return c2st(
140
+ # X,
141
+ # Y,
142
+ # seed=seed,
143
+ # n_folds=n_folds,
144
+ # scoring="roc_auc",
145
+ # z_score=z_score,
146
+ # noise_scale=noise_scale,
147
+ # )