climemu 0.1.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- climemu-0.1.6/LICENSE +21 -0
- climemu-0.1.6/PKG-INFO +41 -0
- climemu-0.1.6/README.md +0 -0
- climemu-0.1.6/pyproject.toml +80 -0
- climemu-0.1.6/setup.cfg +4 -0
- climemu-0.1.6/src/climemu/__init__.py +17 -0
- climemu-0.1.6/src/climemu/emulators/__init__.py +1 -0
- climemu-0.1.6/src/climemu/emulators/abstractemulator.py +44 -0
- climemu-0.1.6/src/climemu/emulators/bouabid2025.py +200 -0
- climemu-0.1.6/src/climemu/utils/__init__.py +3 -0
- climemu-0.1.6/src/climemu/utils/registry.py +71 -0
- climemu-0.1.6/src/climemu.egg-info/PKG-INFO +41 -0
- climemu-0.1.6/src/climemu.egg-info/SOURCES.txt +46 -0
- climemu-0.1.6/src/climemu.egg-info/dependency_links.txt +1 -0
- climemu-0.1.6/src/climemu.egg-info/requires.txt +35 -0
- climemu-0.1.6/src/climemu.egg-info/top_level.txt +4 -0
- climemu-0.1.6/src/datasets/__init__.py +12 -0
- climemu-0.1.6/src/datasets/cmip6.py +246 -0
- climemu-0.1.6/src/datasets/constants.py +10 -0
- climemu-0.1.6/src/datasets/pattern_to_cmip6.py +192 -0
- climemu-0.1.6/src/diffusion/__init__.py +25 -0
- climemu-0.1.6/src/diffusion/losses/__init__.py +9 -0
- climemu-0.1.6/src/diffusion/losses/denoising_score_matching.py +123 -0
- climemu-0.1.6/src/diffusion/nn/__init__.py +3 -0
- climemu-0.1.6/src/diffusion/nn/backbones/__init__.py +3 -0
- climemu-0.1.6/src/diffusion/nn/backbones/convnet.py +86 -0
- climemu-0.1.6/src/diffusion/nn/healpixunet.py +559 -0
- climemu-0.1.6/src/diffusion/nn/modules/__init__.py +32 -0
- climemu-0.1.6/src/diffusion/nn/modules/healpix/__init__.py +21 -0
- climemu-0.1.6/src/diffusion/nn/modules/healpix/conv.py +493 -0
- climemu-0.1.6/src/diffusion/nn/modules/healpix/padding.py +274 -0
- climemu-0.1.6/src/diffusion/nn/modules/remap.py +74 -0
- climemu-0.1.6/src/diffusion/nn/timeencoder/__init__.py +7 -0
- climemu-0.1.6/src/diffusion/nn/timeencoder/gaussianfourier.py +34 -0
- climemu-0.1.6/src/diffusion/samplers/__init__.py +7 -0
- climemu-0.1.6/src/diffusion/samplers/continuous_ode_sampler.py +103 -0
- climemu-0.1.6/src/diffusion/schedules/__init__.py +7 -0
- climemu-0.1.6/src/diffusion/schedules/variance_exploding.py +111 -0
- climemu-0.1.6/src/utils/__init__.py +0 -0
- climemu-0.1.6/src/utils/arrays/__init__.py +27 -0
- climemu-0.1.6/src/utils/arrays/aggregate.py +52 -0
- climemu-0.1.6/src/utils/arrays/convert.py +23 -0
- climemu-0.1.6/src/utils/arrays/filter.py +7 -0
- climemu-0.1.6/src/utils/arrays/reshape.py +11 -0
- climemu-0.1.6/src/utils/arrays/smooth.py +5 -0
- climemu-0.1.6/src/utils/collate.py +4 -0
- climemu-0.1.6/src/utils/graphs/__init__.py +5 -0
- climemu-0.1.6/src/utils/graphs/haversine.py +84 -0
climemu-0.1.6/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Shahine Bouabid
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
climemu-0.1.6/PKG-INFO
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: climemu
|
|
3
|
+
Version: 0.1.6
|
|
4
|
+
Summary: Score-based generative emulation of impact-relevant earth system model outputs
|
|
5
|
+
Author-email: Shahine Bouabid <shahineb@mit.edu>
|
|
6
|
+
License: MIT
|
|
7
|
+
Requires-Python: >=3.12
|
|
8
|
+
Description-Content-Type: text/markdown
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Requires-Dist: dask>=2025.7.0
|
|
11
|
+
Requires-Dist: diffrax>=0.7.0
|
|
12
|
+
Requires-Dist: einops>=0.8.1
|
|
13
|
+
Requires-Dist: equinox>=0.13.0
|
|
14
|
+
Requires-Dist: huggingface-hub>=0.35.0
|
|
15
|
+
Requires-Dist: jax>=0.7.1
|
|
16
|
+
Requires-Dist: netcdf4>=1.7.2
|
|
17
|
+
Requires-Dist: numpy>=2.3.2
|
|
18
|
+
Requires-Dist: optax>=0.2.5
|
|
19
|
+
Requires-Dist: scikit-learn>=1.7.1
|
|
20
|
+
Requires-Dist: scipy>=1.16.1
|
|
21
|
+
Requires-Dist: tqdm>=4.67.1
|
|
22
|
+
Requires-Dist: xarray>=2025.9.0
|
|
23
|
+
Provides-Extra: gpu
|
|
24
|
+
Requires-Dist: jax[cuda12]>=0.7.1; extra == "gpu"
|
|
25
|
+
Provides-Extra: train
|
|
26
|
+
Requires-Dist: torch>=2.8.0; extra == "train"
|
|
27
|
+
Requires-Dist: healpy>=1.18.1; extra == "train"
|
|
28
|
+
Requires-Dist: wandb>=0.21.3; extra == "train"
|
|
29
|
+
Provides-Extra: plots
|
|
30
|
+
Requires-Dist: matplotlib>=3.10.6; extra == "plots"
|
|
31
|
+
Requires-Dist: seaborn>=0.13.2; extra == "plots"
|
|
32
|
+
Requires-Dist: cartopy>=0.25.0; extra == "plots"
|
|
33
|
+
Requires-Dist: regionmask>=0.13.0; extra == "plots"
|
|
34
|
+
Requires-Dist: pandas>=2.3.2; extra == "plots"
|
|
35
|
+
Requires-Dist: pyshtools>=4.13.1; extra == "plots"
|
|
36
|
+
Provides-Extra: test
|
|
37
|
+
Requires-Dist: pytest>=8.0.0; extra == "test"
|
|
38
|
+
Requires-Dist: pytest-cov>=4.0.0; extra == "test"
|
|
39
|
+
Requires-Dist: pytest-mock>=3.10.0; extra == "test"
|
|
40
|
+
Requires-Dist: jaxlib>=0.4.0; extra == "test"
|
|
41
|
+
Dynamic: license-file
|
climemu-0.1.6/README.md
ADDED
|
File without changes
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "climemu"
|
|
3
|
+
version = "0.1.6"
|
|
4
|
+
description = "Score-based generative emulation of impact-relevant earth system model outputs"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
authors = [
|
|
7
|
+
{ name = "Shahine Bouabid", email = "shahineb@mit.edu" },
|
|
8
|
+
]
|
|
9
|
+
requires-python = ">=3.12"
|
|
10
|
+
license = { text = "MIT" }
|
|
11
|
+
dependencies = [
|
|
12
|
+
"dask>=2025.7.0",
|
|
13
|
+
"diffrax>=0.7.0",
|
|
14
|
+
"einops>=0.8.1",
|
|
15
|
+
"equinox>=0.13.0",
|
|
16
|
+
"huggingface-hub>=0.35.0",
|
|
17
|
+
"jax>=0.7.1",
|
|
18
|
+
"netcdf4>=1.7.2",
|
|
19
|
+
"numpy>=2.3.2",
|
|
20
|
+
"optax>=0.2.5",
|
|
21
|
+
"scikit-learn>=1.7.1",
|
|
22
|
+
"scipy>=1.16.1",
|
|
23
|
+
"tqdm>=4.67.1",
|
|
24
|
+
"xarray>=2025.9.0",
|
|
25
|
+
]
|
|
26
|
+
|
|
27
|
+
[project.optional-dependencies]
|
|
28
|
+
gpu = [
|
|
29
|
+
"jax[cuda12]>=0.7.1"
|
|
30
|
+
]
|
|
31
|
+
train = [
|
|
32
|
+
"torch>=2.8.0",
|
|
33
|
+
"healpy>=1.18.1",
|
|
34
|
+
"wandb>=0.21.3",
|
|
35
|
+
]
|
|
36
|
+
plots = [
|
|
37
|
+
"matplotlib>=3.10.6",
|
|
38
|
+
"seaborn>=0.13.2",
|
|
39
|
+
"cartopy>=0.25.0",
|
|
40
|
+
"regionmask>=0.13.0",
|
|
41
|
+
"pandas>=2.3.2",
|
|
42
|
+
"pyshtools>=4.13.1",
|
|
43
|
+
]
|
|
44
|
+
test = [
|
|
45
|
+
"pytest>=8.0.0",
|
|
46
|
+
"pytest-cov>=4.0.0",
|
|
47
|
+
"pytest-mock>=3.10.0",
|
|
48
|
+
"jaxlib>=0.4.0",
|
|
49
|
+
]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
[build-system]
|
|
54
|
+
requires = ["setuptools>=61.0"]
|
|
55
|
+
build-backend = "setuptools.build_meta"
|
|
56
|
+
|
|
57
|
+
[tool.setuptools]
|
|
58
|
+
package-dir = {"" = "src"}
|
|
59
|
+
|
|
60
|
+
[tool.setuptools.packages.find]
|
|
61
|
+
where = ["src"]
|
|
62
|
+
|
|
63
|
+
[tool.pytest.ini_options]
|
|
64
|
+
filterwarnings = [
|
|
65
|
+
"ignore::DeprecationWarning:equinox.internal._noinline",
|
|
66
|
+
"ignore::DeprecationWarning:jax.interpreters.batching",
|
|
67
|
+
]
|
|
68
|
+
|
|
69
|
+
[dependency-groups]
|
|
70
|
+
dev = [
|
|
71
|
+
"pytest>=8.4.2",
|
|
72
|
+
"pytest-cov>=7.0.0",
|
|
73
|
+
"pytest-mock>=3.15.1",
|
|
74
|
+
]
|
|
75
|
+
|
|
76
|
+
[[tool.uv.index]]
|
|
77
|
+
name = "testpypi"
|
|
78
|
+
url = "https://test.pypi.org/simple/"
|
|
79
|
+
publish-url = "https://test.pypi.org/legacy/"
|
|
80
|
+
explicit = true
|
climemu-0.1.6/setup.cfg
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from climemu.utils import Registry
|
|
2
|
+
|
|
3
|
+
__version__ = "0.1.6"
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
Registry of pretrained emulators available for use
|
|
7
|
+
"""
|
|
8
|
+
EMULATORS = Registry()
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def build_emulator(name):
    """Instantiate the emulator registered under ``name``.

    The builder stored in ``EMULATORS`` for ``name`` is looked up and called
    without arguments; whatever it returns is handed back to the caller.
    """
    builder = EMULATORS[name]
    return builder()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
from .emulators import Bouabid2025Emulator
|
|
17
|
+
__all__ = ['build_emulator', 'Bouabid2025Emulator']
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .bouabid2025 import Bouabid2025Emulator
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class AbstractEmulator(ABC):
    """Base interface for emulators: callable objects that generate samples."""

    @abstractmethod
    def __call__(self, *args, **kwargs):
        """Generate climate data samples (must be implemented by subclasses)."""

    @classmethod
    def build(cls, **kwargs):
        """Alternate constructor: create an instance by forwarding ``kwargs``."""
        instance = cls(**kwargs)
        return instance
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class GriddedEmulator(AbstractEmulator):
    """Emulator exposing latitude/longitude coordinates and a variable list.

    Subclasses provide ``lat``, ``lon`` and ``vars``; the ``n*`` properties
    are derived from them with ``len``.
    """

    @property
    @abstractmethod
    def lat(self):
        """Latitude coordinates (must be provided by subclasses)."""

    @property
    @abstractmethod
    def lon(self):
        """Longitude coordinates (must be provided by subclasses)."""

    @property
    @abstractmethod
    def vars(self):
        """Variable list (must be provided by subclasses)."""

    @property
    def nlat(self):
        """Number of latitude coordinates."""
        n_latitudes = len(self.lat)
        return n_latitudes

    @property
    def nlon(self):
        """Number of longitude coordinates."""
        n_longitudes = len(self.lon)
        return n_longitudes

    @property
    def nvar(self):
        """Number of variables."""
        n_variables = len(self.vars)
        return n_variables
|
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import yaml
|
|
3
|
+
from functools import partial
|
|
4
|
+
import xarray as xr
|
|
5
|
+
import numpy as np
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
import jax.random as jr
|
|
8
|
+
import equinox as eqx
|
|
9
|
+
from huggingface_hub import hf_hub_download
|
|
10
|
+
from diffusion import HealPIXUNet, ContinuousVESchedule, ContinuousHeunSampler
|
|
11
|
+
from .abstractemulator import GriddedEmulator
|
|
12
|
+
from .. import EMULATORS
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Bouabid2025Emulator(GriddedEmulator):
    """Emulator combining pattern scaling with a diffusion-based sampler.

    All assets (config, weights, normalization statistics, pattern scaling
    coefficients, climatology) are downloaded from the Hugging Face Hub
    repository ``self.repo_id``.

    Intended call order, as established by the attributes each step sets:
    ``load()`` (downloads assets) -> ``compile(n_samples)`` (fixes the sample
    count and does a dry run) -> ``emulator(gmst, month)`` (draws samples).
    """

    def __init__(self, esm_name: str):
        """Remember the ESM name, used as the top-level folder in the Hub repo."""
        self.esm = esm_name
        self.repo_id = "shahineb/climemu"

    def load(self, which: str = "default"):
        """Download and load every asset needed for generation.

        Args:
            which: Sub-folder of the ESM folder in the Hub repo holding the
                model files (config, weights, statistics, coefficients).
        """
        # Paths inside a Hub repository always use forward slashes, so build
        # the prefix explicitly; os.path.join would insert "\\" on Windows.
        self.files_dir = f"{self.esm}/{which}"

        # Load pattern scaling coefficients
        self.β = self._load_pattern_scaling()

        # Load climatology data
        self.climatology = self._load_climatology()

        # Load the generative model precursor
        self.precursor = self._load_precursor()

    def compile(self, n_samples, n_steps=30):
        """Fix the number of samples and sampler steps, then run once.

        Must be called after ``load()``: it reads ``self.precursor`` and the
        grid size derived from ``self.climatology``.

        Args:
            n_samples: Number of samples drawn per call.
            n_steps: Number of sampler steps per call.
        """
        # Fix number of samples and steps for generation
        self.generative_model = partial(self.precursor,
                                        n_samples=n_samples,
                                        n_steps=n_steps)

        # Perform a dry run to compile the JAX functions (important for performance)
        dummy_pattern = jnp.zeros((self.nlat, self.nlon))
        _ = self.generative_model(pattern=dummy_pattern, key=jr.PRNGKey(0))

    def __call__(self, gmst, month, seed=None, xarray=False):
        """Draw samples for a given ``gmst`` value and ``month``.

        Must be called after ``compile()``, which sets ``self.generative_model``.

        Args:
            gmst: Value multiplied with the slope coefficient in
                ``pattern = β₀ + β₁ * ΔT``. NOTE(review): presumably a global
                mean surface temperature anomaly — confirm the expected units.
            month: 1-based month index (``month - 1`` selects the row of β).
            seed: Integer seed for the PRNG key. When ``None``, a random seed
                in ``[0, 1000000)`` is drawn with NumPy.
            xarray: When true, wrap the samples in an ``xr.Dataset`` with one
                ``(member, lat, lon)`` variable per entry of ``self.vars``.

        Returns:
            The output of the generative model, or an ``xr.Dataset`` built
            from it when ``xarray`` is true.
        """
        # Apply pattern scaling: pattern = β₀ + β₁ * ΔT
        β_month = self.β[month - 1]
        pattern = β_month[:, 1] * gmst + β_month[:, 0]
        pattern = pattern.reshape((self.nlat, self.nlon))

        # Generate samples using the diffusion model.
        # Compare against None explicitly: seed=0 is a valid seed and must
        # not fall through to the random branch.
        if seed is None:
            seed = np.random.randint(0, 1000000)
        key = jr.PRNGKey(seed)
        samples = self.generative_model(pattern=pattern, key=key)

        # Convert to xarray Dataset
        if xarray:
            samples = xr.Dataset(
                {
                    var: (("member", "lat", "lon"), samples[:, i, :, :])
                    for i, var in enumerate(self.vars)
                },
                coords={
                    "member": jnp.arange(len(samples)) + 1,
                    "lat": self.lat,
                    "lon": self.lon,
                },
            )
        return samples

    def _load_precursor(self):
        """Build the sampling function with everything but the pattern, key,
        sample count and step count already bound."""
        # Build the neural network and noise schedule
        config = self._load_config()
        nn = self._load_nn(config)
        schedule = self._load_schedule(config)

        # Load normalization statistics used during training (needed to denormalize the generated samples)
        μ, σ = self._load_normalization()

        # Define the output size for the generated samples
        output_size = (config['out_channels'], config['input_size'][1], config['input_size'][2])

        # Create a precursor for the generative model
        precursor = partial(draw_samples_single,
                            nn=nn,
                            schedule=schedule,
                            output_size=output_size,
                            μ=μ, σ=σ)
        return precursor

    def _load_config(self):
        """Download ``config.yaml`` and parse it with ``yaml.safe_load``."""
        config_path = hf_hub_download(self.repo_id, f"{self.files_dir}/config.yaml")
        with open(config_path, 'r') as file:
            config = yaml.safe_load(file)
        return config

    def _load_normalization(self):
        """Download ``μ_σ.npz`` and return its ``'μ'`` and ``'σ'`` entries."""
        norm_stats_path = hf_hub_download(self.repo_id, f"{self.files_dir}/μ_σ.npz")
        stats = jnp.load(norm_stats_path)
        μ, σ = stats['μ'], stats['σ']
        return μ, σ

    def _load_pattern_scaling(self):
        """Download ``β.npy`` and return the pattern scaling coefficients."""
        pattern_scaling_path = hf_hub_download(self.repo_id, f"{self.files_dir}/β.npy")
        β = jnp.load(pattern_scaling_path)
        return β

    def _load_climatology(self):
        """Download ``piControl_climatology.nc`` (stored at the ESM level, not
        under ``files_dir``) and open it as an xarray dataset."""
        climatology_path = hf_hub_download(self.repo_id, f"{self.esm}/piControl_climatology.nc")
        climatology = xr.open_dataset(climatology_path)
        return climatology

    def _load_nn(self, config):
        """Build the ``HealPIXUNet`` from ``config`` and load its weights."""
        # Load graph edges for HEALPix to lat-lon connectivity
        edges_path = hf_hub_download(self.repo_id, f"{self.files_dir}/edges.npz")
        edges_data = jnp.load(edges_path)
        to_healpix = jnp.array(edges_data['to_healpix']).astype(jnp.int32)
        to_latlon = jnp.array(edges_data['to_latlon']).astype(jnp.int32)

        # Initialize the neural network
        nn = HealPIXUNet(input_size=config['input_size'],
                         nside=config['nside'],
                         enc_filters=config['enc_filters'],
                         dec_filters=config['dec_filters'],
                         out_channels=config['out_channels'],
                         temb_dim=config['temb_dim'],
                         healpix_emb_dim=config['healpix_emb_dim'],
                         edges_to_healpix=to_healpix,
                         edges_to_latlon=to_latlon)

        # Load the pre-trained weights from the saved model file
        weights_path = hf_hub_download(self.repo_id, f"{self.files_dir}/weights.eqx")
        nn = eqx.tree_deserialise_leaves(weights_path, nn)
        return nn

    def _load_schedule(self, config):
        """Build the variance exploding schedule from ``config['sigma_min']``
        and the downloaded maximum noise level."""
        # Load the maximum noise level as used in training
        sigma_max_path = hf_hub_download(self.repo_id, f"{self.files_dir}/σmax.npy")
        σmax = jnp.load(sigma_max_path)

        # Create the variance exploding schedule
        schedule = ContinuousVESchedule(config['sigma_min'], σmax)
        return schedule

    @property
    def lat(self):
        """Latitude values of the climatology dataset."""
        return self.climatology['lat'].values

    @property
    def lon(self):
        """Longitude values of the climatology dataset."""
        return self.climatology['lon'].values

    @property
    def vars(self):
        """Names of the data variables of the climatology dataset."""
        return list(self.climatology.data_vars)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
@eqx.filter_jit
def normalize(x, μ, σ):
    """Center ``x`` by ``μ`` and scale it by ``σ``: ``(x - μ) / σ``."""
    centered = x - μ
    return centered / σ
|
|
162
|
+
|
|
163
|
+
@eqx.filter_jit
def denormalize(x, μ, σ):
    """Undo ``normalize``: rescale ``x`` by ``σ`` and shift it by ``μ``."""
    rescaled = σ * x
    return rescaled + μ
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def create_sampler(nn, schedule, pattern, μ, σ, output_size):
    """Build a ``ContinuousHeunSampler`` whose network is conditioned on ``pattern``.

    The pattern is normalized with the last entries of ``μ`` and ``σ``, given
    a leading axis, and concatenated to the network input along axis 0 on
    every network call.
    """
    normalized_pattern = normalize(pattern, μ[-1], σ[-1])
    context = normalized_pattern[None, ...]

    def nn_with_context(x, t):
        # Append the conditioning pattern to the network input.
        conditioned = jnp.concatenate((x, context), axis=0)
        return nn(conditioned, t)

    return ContinuousHeunSampler(schedule, nn_with_context, output_size)
|
|
176
|
+
|
|
177
|
+
@eqx.filter_jit
def draw_samples_single(nn, schedule, pattern, n_samples, n_steps, μ, σ, output_size, key=jr.PRNGKey(0)):
    """Draw ``n_samples`` samples conditioned on ``pattern`` and denormalize them.

    The sampler output is denormalized with all but the last entries of ``μ``
    and ``σ``; the last entries are the ones used to normalize the pattern.
    """
    heun_sampler = create_sampler(nn, schedule, pattern, μ, σ, output_size)
    normalized_draws = heun_sampler.sample(n_samples, steps=n_steps, key=key)
    μ_out, σ_out = μ[:-1], σ[:-1]
    return denormalize(normalized_draws, μ_out, σ_out)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@EMULATORS.register("MPI-ESM1-2-LR")
class MPIEmulator(Bouabid2025Emulator):
    """Emulator for MPI-ESM1-2-LR, registered in ``EMULATORS`` under that name."""

    def __init__(self, which="default"):
        # NOTE(review): `which` is accepted but not used by this constructor.
        super().__init__("MPI-ESM1-2-LR")
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
@EMULATORS.register("MIROC6")
class MIROCEmulator(Bouabid2025Emulator):
    """Emulator for MIROC6, registered in ``EMULATORS`` under that name."""

    def __init__(self, which="default"):
        # NOTE(review): `which` is accepted but not used by this constructor.
        super().__init__("MIROC6")
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
@EMULATORS.register("ACCESS-ESM1-5")
class ACCESSEmulator(Bouabid2025Emulator):
    """Emulator for ACCESS-ESM1-5, registered in ``EMULATORS`` under that name."""

    def __init__(self, which="default"):
        # NOTE(review): `which` is accepted but not used by this constructor.
        super().__init__("ACCESS-ESM1-5")
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
import types
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class Registry(dict):
    """
    A helper class for managing access to builders. It extends a dictionary
    and provides a registering function that can also be used as a decorator.

    Creating a registry:
        MODULES = Registry()

    There are two types of builder callables you can register:

    (1) : Functions, which can be registered either by a simple call
    ```
    def build_bar():
        return True
    MODULES.register('bar', build_bar)
    ```
    or by using a decorator at function definition
    ```
    @MODULES.register('bar')
    def build_bar():
        return True
    ```

    (2) : Classes exposing a ``cls.build`` constructor method; it is
    ``cls.build`` that gets stored. The class can be registered with a call
    ```
    MODULES.register('foo', Foo)
    ```
    or by using a decorator at class definition
    ```
    @MODULES.register('foo')
    class Foo:
        @classmethod
        def build(cls, *args, **kwargs):
            # build a class instance foo
            return foo
    ```

    Access to a builder is just like using a dictionary, e.g.:
        build_bar = MODULES['bar']
        build_foo = MODULES['foo']
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def register(self, name, module_object=None):
        """Register a builder under ``name``.

        Args:
            name: Key under which the builder is stored.
            module_object: A plain function (stored as is) or a class (its
                ``build`` attribute is stored). When omitted, a decorator is
                returned that registers the decorated object and hands it
                back unchanged.

        Returns:
            The decorator when ``module_object`` is ``None``, else ``None``.

        Raises:
            TypeError: If ``module_object`` is neither a class nor a function.
            AssertionError: If ``name`` is already registered.
        """
        # Used as a decorator
        if module_object is None:
            def register_func(module_object):
                self.register(name=name, module_object=module_object)
                return module_object
            return register_func

        # Used as a function call
        if isinstance(module_object, type):
            builder = module_object.build
        elif isinstance(module_object, types.FunctionType):
            builder = module_object
        else:
            raise TypeError("Trying to register unknown data type")
        self._register_generic(module_dict=self, name=name, builder=builder)

    @staticmethod
    def _register_generic(module_dict, name, builder):
        """Store ``builder`` under ``name``, refusing to overwrite an entry."""
        # Raise explicitly rather than with an `assert` statement: asserts are
        # stripped under `python -O`, which would allow silent overwrites.
        # AssertionError is kept so the exception type seen by callers is
        # unchanged.
        if name in module_dict:
            raise AssertionError(f"'{name}' is already registered")
        module_dict[name] = builder
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: climemu
|
|
3
|
+
Version: 0.1.6
|
|
4
|
+
Summary: Score-based generative emulation of impact-relevant earth system model outputs
|
|
5
|
+
Author-email: Shahine Bouabid <shahineb@mit.edu>
|
|
6
|
+
License: MIT
|
|
7
|
+
Requires-Python: >=3.12
|
|
8
|
+
Description-Content-Type: text/markdown
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Requires-Dist: dask>=2025.7.0
|
|
11
|
+
Requires-Dist: diffrax>=0.7.0
|
|
12
|
+
Requires-Dist: einops>=0.8.1
|
|
13
|
+
Requires-Dist: equinox>=0.13.0
|
|
14
|
+
Requires-Dist: huggingface-hub>=0.35.0
|
|
15
|
+
Requires-Dist: jax>=0.7.1
|
|
16
|
+
Requires-Dist: netcdf4>=1.7.2
|
|
17
|
+
Requires-Dist: numpy>=2.3.2
|
|
18
|
+
Requires-Dist: optax>=0.2.5
|
|
19
|
+
Requires-Dist: scikit-learn>=1.7.1
|
|
20
|
+
Requires-Dist: scipy>=1.16.1
|
|
21
|
+
Requires-Dist: tqdm>=4.67.1
|
|
22
|
+
Requires-Dist: xarray>=2025.9.0
|
|
23
|
+
Provides-Extra: gpu
|
|
24
|
+
Requires-Dist: jax[cuda12]>=0.7.1; extra == "gpu"
|
|
25
|
+
Provides-Extra: train
|
|
26
|
+
Requires-Dist: torch>=2.8.0; extra == "train"
|
|
27
|
+
Requires-Dist: healpy>=1.18.1; extra == "train"
|
|
28
|
+
Requires-Dist: wandb>=0.21.3; extra == "train"
|
|
29
|
+
Provides-Extra: plots
|
|
30
|
+
Requires-Dist: matplotlib>=3.10.6; extra == "plots"
|
|
31
|
+
Requires-Dist: seaborn>=0.13.2; extra == "plots"
|
|
32
|
+
Requires-Dist: cartopy>=0.25.0; extra == "plots"
|
|
33
|
+
Requires-Dist: regionmask>=0.13.0; extra == "plots"
|
|
34
|
+
Requires-Dist: pandas>=2.3.2; extra == "plots"
|
|
35
|
+
Requires-Dist: pyshtools>=4.13.1; extra == "plots"
|
|
36
|
+
Provides-Extra: test
|
|
37
|
+
Requires-Dist: pytest>=8.0.0; extra == "test"
|
|
38
|
+
Requires-Dist: pytest-cov>=4.0.0; extra == "test"
|
|
39
|
+
Requires-Dist: pytest-mock>=3.10.0; extra == "test"
|
|
40
|
+
Requires-Dist: jaxlib>=0.4.0; extra == "test"
|
|
41
|
+
Dynamic: license-file
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
src/climemu/__init__.py
|
|
5
|
+
src/climemu.egg-info/PKG-INFO
|
|
6
|
+
src/climemu.egg-info/SOURCES.txt
|
|
7
|
+
src/climemu.egg-info/dependency_links.txt
|
|
8
|
+
src/climemu.egg-info/requires.txt
|
|
9
|
+
src/climemu.egg-info/top_level.txt
|
|
10
|
+
src/climemu/emulators/__init__.py
|
|
11
|
+
src/climemu/emulators/abstractemulator.py
|
|
12
|
+
src/climemu/emulators/bouabid2025.py
|
|
13
|
+
src/climemu/utils/__init__.py
|
|
14
|
+
src/climemu/utils/registry.py
|
|
15
|
+
src/datasets/__init__.py
|
|
16
|
+
src/datasets/cmip6.py
|
|
17
|
+
src/datasets/constants.py
|
|
18
|
+
src/datasets/pattern_to_cmip6.py
|
|
19
|
+
src/diffusion/__init__.py
|
|
20
|
+
src/diffusion/losses/__init__.py
|
|
21
|
+
src/diffusion/losses/denoising_score_matching.py
|
|
22
|
+
src/diffusion/nn/__init__.py
|
|
23
|
+
src/diffusion/nn/healpixunet.py
|
|
24
|
+
src/diffusion/nn/backbones/__init__.py
|
|
25
|
+
src/diffusion/nn/backbones/convnet.py
|
|
26
|
+
src/diffusion/nn/modules/__init__.py
|
|
27
|
+
src/diffusion/nn/modules/remap.py
|
|
28
|
+
src/diffusion/nn/modules/healpix/__init__.py
|
|
29
|
+
src/diffusion/nn/modules/healpix/conv.py
|
|
30
|
+
src/diffusion/nn/modules/healpix/padding.py
|
|
31
|
+
src/diffusion/nn/timeencoder/__init__.py
|
|
32
|
+
src/diffusion/nn/timeencoder/gaussianfourier.py
|
|
33
|
+
src/diffusion/samplers/__init__.py
|
|
34
|
+
src/diffusion/samplers/continuous_ode_sampler.py
|
|
35
|
+
src/diffusion/schedules/__init__.py
|
|
36
|
+
src/diffusion/schedules/variance_exploding.py
|
|
37
|
+
src/utils/__init__.py
|
|
38
|
+
src/utils/collate.py
|
|
39
|
+
src/utils/arrays/__init__.py
|
|
40
|
+
src/utils/arrays/aggregate.py
|
|
41
|
+
src/utils/arrays/convert.py
|
|
42
|
+
src/utils/arrays/filter.py
|
|
43
|
+
src/utils/arrays/reshape.py
|
|
44
|
+
src/utils/arrays/smooth.py
|
|
45
|
+
src/utils/graphs/__init__.py
|
|
46
|
+
src/utils/graphs/haversine.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
dask>=2025.7.0
|
|
2
|
+
diffrax>=0.7.0
|
|
3
|
+
einops>=0.8.1
|
|
4
|
+
equinox>=0.13.0
|
|
5
|
+
huggingface-hub>=0.35.0
|
|
6
|
+
jax>=0.7.1
|
|
7
|
+
netcdf4>=1.7.2
|
|
8
|
+
numpy>=2.3.2
|
|
9
|
+
optax>=0.2.5
|
|
10
|
+
scikit-learn>=1.7.1
|
|
11
|
+
scipy>=1.16.1
|
|
12
|
+
tqdm>=4.67.1
|
|
13
|
+
xarray>=2025.9.0
|
|
14
|
+
|
|
15
|
+
[gpu]
|
|
16
|
+
jax[cuda12]>=0.7.1
|
|
17
|
+
|
|
18
|
+
[plots]
|
|
19
|
+
matplotlib>=3.10.6
|
|
20
|
+
seaborn>=0.13.2
|
|
21
|
+
cartopy>=0.25.0
|
|
22
|
+
regionmask>=0.13.0
|
|
23
|
+
pandas>=2.3.2
|
|
24
|
+
pyshtools>=4.13.1
|
|
25
|
+
|
|
26
|
+
[test]
|
|
27
|
+
pytest>=8.0.0
|
|
28
|
+
pytest-cov>=4.0.0
|
|
29
|
+
pytest-mock>=3.10.0
|
|
30
|
+
jaxlib>=0.4.0
|
|
31
|
+
|
|
32
|
+
[train]
|
|
33
|
+
torch>=2.8.0
|
|
34
|
+
healpy>=1.18.1
|
|
35
|
+
wandb>=0.21.3
|