climemu 0.1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. climemu-0.1.6/LICENSE +21 -0
  2. climemu-0.1.6/PKG-INFO +41 -0
  3. climemu-0.1.6/README.md +0 -0
  4. climemu-0.1.6/pyproject.toml +80 -0
  5. climemu-0.1.6/setup.cfg +4 -0
  6. climemu-0.1.6/src/climemu/__init__.py +17 -0
  7. climemu-0.1.6/src/climemu/emulators/__init__.py +1 -0
  8. climemu-0.1.6/src/climemu/emulators/abstractemulator.py +44 -0
  9. climemu-0.1.6/src/climemu/emulators/bouabid2025.py +200 -0
  10. climemu-0.1.6/src/climemu/utils/__init__.py +3 -0
  11. climemu-0.1.6/src/climemu/utils/registry.py +71 -0
  12. climemu-0.1.6/src/climemu.egg-info/PKG-INFO +41 -0
  13. climemu-0.1.6/src/climemu.egg-info/SOURCES.txt +46 -0
  14. climemu-0.1.6/src/climemu.egg-info/dependency_links.txt +1 -0
  15. climemu-0.1.6/src/climemu.egg-info/requires.txt +35 -0
  16. climemu-0.1.6/src/climemu.egg-info/top_level.txt +4 -0
  17. climemu-0.1.6/src/datasets/__init__.py +12 -0
  18. climemu-0.1.6/src/datasets/cmip6.py +246 -0
  19. climemu-0.1.6/src/datasets/constants.py +10 -0
  20. climemu-0.1.6/src/datasets/pattern_to_cmip6.py +192 -0
  21. climemu-0.1.6/src/diffusion/__init__.py +25 -0
  22. climemu-0.1.6/src/diffusion/losses/__init__.py +9 -0
  23. climemu-0.1.6/src/diffusion/losses/denoising_score_matching.py +123 -0
  24. climemu-0.1.6/src/diffusion/nn/__init__.py +3 -0
  25. climemu-0.1.6/src/diffusion/nn/backbones/__init__.py +3 -0
  26. climemu-0.1.6/src/diffusion/nn/backbones/convnet.py +86 -0
  27. climemu-0.1.6/src/diffusion/nn/healpixunet.py +559 -0
  28. climemu-0.1.6/src/diffusion/nn/modules/__init__.py +32 -0
  29. climemu-0.1.6/src/diffusion/nn/modules/healpix/__init__.py +21 -0
  30. climemu-0.1.6/src/diffusion/nn/modules/healpix/conv.py +493 -0
  31. climemu-0.1.6/src/diffusion/nn/modules/healpix/padding.py +274 -0
  32. climemu-0.1.6/src/diffusion/nn/modules/remap.py +74 -0
  33. climemu-0.1.6/src/diffusion/nn/timeencoder/__init__.py +7 -0
  34. climemu-0.1.6/src/diffusion/nn/timeencoder/gaussianfourier.py +34 -0
  35. climemu-0.1.6/src/diffusion/samplers/__init__.py +7 -0
  36. climemu-0.1.6/src/diffusion/samplers/continuous_ode_sampler.py +103 -0
  37. climemu-0.1.6/src/diffusion/schedules/__init__.py +7 -0
  38. climemu-0.1.6/src/diffusion/schedules/variance_exploding.py +111 -0
  39. climemu-0.1.6/src/utils/__init__.py +0 -0
  40. climemu-0.1.6/src/utils/arrays/__init__.py +27 -0
  41. climemu-0.1.6/src/utils/arrays/aggregate.py +52 -0
  42. climemu-0.1.6/src/utils/arrays/convert.py +23 -0
  43. climemu-0.1.6/src/utils/arrays/filter.py +7 -0
  44. climemu-0.1.6/src/utils/arrays/reshape.py +11 -0
  45. climemu-0.1.6/src/utils/arrays/smooth.py +5 -0
  46. climemu-0.1.6/src/utils/collate.py +4 -0
  47. climemu-0.1.6/src/utils/graphs/__init__.py +5 -0
  48. climemu-0.1.6/src/utils/graphs/haversine.py +84 -0
climemu-0.1.6/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Shahine Bouabid
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
climemu-0.1.6/PKG-INFO ADDED
@@ -0,0 +1,41 @@
1
+ Metadata-Version: 2.4
2
+ Name: climemu
3
+ Version: 0.1.6
4
+ Summary: Score-based generative emulation of impact-relevant earth system model outputs
5
+ Author-email: Shahine Bouabid <shahineb@mit.edu>
6
+ License: MIT
7
+ Requires-Python: >=3.12
8
+ Description-Content-Type: text/markdown
9
+ License-File: LICENSE
10
+ Requires-Dist: dask>=2025.7.0
11
+ Requires-Dist: diffrax>=0.7.0
12
+ Requires-Dist: einops>=0.8.1
13
+ Requires-Dist: equinox>=0.13.0
14
+ Requires-Dist: huggingface-hub>=0.35.0
15
+ Requires-Dist: jax>=0.7.1
16
+ Requires-Dist: netcdf4>=1.7.2
17
+ Requires-Dist: numpy>=2.3.2
18
+ Requires-Dist: optax>=0.2.5
19
+ Requires-Dist: scikit-learn>=1.7.1
20
+ Requires-Dist: scipy>=1.16.1
21
+ Requires-Dist: tqdm>=4.67.1
22
+ Requires-Dist: xarray>=2025.9.0
23
+ Provides-Extra: gpu
24
+ Requires-Dist: jax[cuda12]>=0.7.1; extra == "gpu"
25
+ Provides-Extra: train
26
+ Requires-Dist: torch>=2.8.0; extra == "train"
27
+ Requires-Dist: healpy>=1.18.1; extra == "train"
28
+ Requires-Dist: wandb>=0.21.3; extra == "train"
29
+ Provides-Extra: plots
30
+ Requires-Dist: matplotlib>=3.10.6; extra == "plots"
31
+ Requires-Dist: seaborn>=0.13.2; extra == "plots"
32
+ Requires-Dist: cartopy>=0.25.0; extra == "plots"
33
+ Requires-Dist: regionmask>=0.13.0; extra == "plots"
34
+ Requires-Dist: pandas>=2.3.2; extra == "plots"
35
+ Requires-Dist: pyshtools>=4.13.1; extra == "plots"
36
+ Provides-Extra: test
37
+ Requires-Dist: pytest>=8.0.0; extra == "test"
38
+ Requires-Dist: pytest-cov>=4.0.0; extra == "test"
39
+ Requires-Dist: pytest-mock>=3.10.0; extra == "test"
40
+ Requires-Dist: jaxlib>=0.4.0; extra == "test"
41
+ Dynamic: license-file
File without changes
@@ -0,0 +1,80 @@
1
+ [project]
2
+ name = "climemu"
3
+ version = "0.1.6"
4
+ description = "Score-based generative emulation of impact-relevant earth system model outputs"
5
+ readme = "README.md"
6
+ authors = [
7
+ { name = "Shahine Bouabid", email = "shahineb@mit.edu" },
8
+ ]
9
+ requires-python = ">=3.12"
10
+ license = { text = "MIT" }
11
+ dependencies = [
12
+ "dask>=2025.7.0",
13
+ "diffrax>=0.7.0",
14
+ "einops>=0.8.1",
15
+ "equinox>=0.13.0",
16
+ "huggingface-hub>=0.35.0",
17
+ "jax>=0.7.1",
18
+ "netcdf4>=1.7.2",
19
+ "numpy>=2.3.2",
20
+ "optax>=0.2.5",
21
+ "scikit-learn>=1.7.1",
22
+ "scipy>=1.16.1",
23
+ "tqdm>=4.67.1",
24
+ "xarray>=2025.9.0",
25
+ ]
26
+
27
+ [project.optional-dependencies]
28
+ gpu = [
29
+ "jax[cuda12]>=0.7.1"
30
+ ]
31
+ train = [
32
+ "torch>=2.8.0",
33
+ "healpy>=1.18.1",
34
+ "wandb>=0.21.3",
35
+ ]
36
+ plots = [
37
+ "matplotlib>=3.10.6",
38
+ "seaborn>=0.13.2",
39
+ "cartopy>=0.25.0",
40
+ "regionmask>=0.13.0",
41
+ "pandas>=2.3.2",
42
+ "pyshtools>=4.13.1",
43
+ ]
44
+ test = [
45
+ "pytest>=8.0.0",
46
+ "pytest-cov>=4.0.0",
47
+ "pytest-mock>=3.10.0",
48
+ "jaxlib>=0.4.0",
49
+ ]
50
+
51
+
52
+
53
+ [build-system]
54
+ requires = ["setuptools>=61.0"]
55
+ build-backend = "setuptools.build_meta"
56
+
57
+ [tool.setuptools]
58
+ package-dir = {"" = "src"}
59
+
60
+ [tool.setuptools.packages.find]
61
+ where = ["src"]
62
+
63
+ [tool.pytest.ini_options]
64
+ filterwarnings = [
65
+ "ignore::DeprecationWarning:equinox.internal._noinline",
66
+ "ignore::DeprecationWarning:jax.interpreters.batching",
67
+ ]
68
+
69
+ [dependency-groups]
70
+ dev = [
71
+ "pytest>=8.4.2",
72
+ "pytest-cov>=7.0.0",
73
+ "pytest-mock>=3.15.1",
74
+ ]
75
+
76
+ [[tool.uv.index]]
77
+ name = "testpypi"
78
+ url = "https://test.pypi.org/simple/"
79
+ publish-url = "https://test.pypi.org/legacy/"
80
+ explicit = true
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,17 @@
1
+ from climemu.utils import Registry
2
+
3
+ __version__ = "0.1.6"
4
+
5
+ """
6
+ Registry of pretrained emulators available for use
7
+ """
8
+ EMULATORS = Registry()
9
+
10
+
11
def build_emulator(name):
    """Instantiate the pretrained emulator registered under ``name``.

    Args:
        name: Key of the emulator in the ``EMULATORS`` registry.

    Returns:
        Whatever the registered builder returns when called with no arguments.

    Raises:
        KeyError: If ``name`` is not registered; the message lists the
            registered names.
    """
    # Look the builder up separately from calling it, so that a KeyError raised
    # inside the builder itself is never reported as an unknown emulator name.
    try:
        builder = EMULATORS[name]
    except KeyError:
        available = ", ".join(sorted(map(str, EMULATORS)))
        raise KeyError(f"Unknown emulator {name!r}; available: {available}") from None
    return builder()
14
+
15
+
16
+ from .emulators import Bouabid2025Emulator
17
+ __all__ = ['build_emulator', 'Bouabid2025Emulator']
@@ -0,0 +1 @@
1
+ from .bouabid2025 import Bouabid2025Emulator
@@ -0,0 +1,44 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+
4
class AbstractEmulator(ABC):
    """Common interface for emulators: every concrete emulator is callable."""

    @abstractmethod
    def __call__(self, *args, **kwargs):
        """Generate climate data samples"""

    @classmethod
    def build(cls, **kwargs):
        """Alternate constructor: create an instance of ``cls`` from keyword arguments."""
        instance = cls(**kwargs)
        return instance
13
+
14
+
15
class GriddedEmulator(AbstractEmulator):
    """Emulator interface exposing ``lat``, ``lon`` and ``vars`` plus their lengths.

    Subclasses provide the three abstract properties; the ``n*`` properties
    are derived from them with ``len``.
    """

    @property
    @abstractmethod
    def lat(self):
        """Latitude coordinates"""

    @property
    @abstractmethod
    def lon(self):
        """Longitude coordinates"""

    @property
    @abstractmethod
    def vars(self):
        """Variable list"""

    @property
    def nlat(self):
        """Number of latitude coordinates."""
        latitudes = self.lat
        return len(latitudes)

    @property
    def nlon(self):
        """Number of longitude coordinates."""
        longitudes = self.lon
        return len(longitudes)

    @property
    def nvar(self):
        """Number of variables."""
        variables = self.vars
        return len(variables)
@@ -0,0 +1,200 @@
1
+ import os
2
+ import yaml
3
+ from functools import partial
4
+ import xarray as xr
5
+ import numpy as np
6
+ import jax.numpy as jnp
7
+ import jax.random as jr
8
+ import equinox as eqx
9
+ from huggingface_hub import hf_hub_download
10
+ from diffusion import HealPIXUNet, ContinuousVESchedule, ContinuousHeunSampler
11
+ from .abstractemulator import GriddedEmulator
12
+ from .. import EMULATORS
13
+
14
+
15
class Bouabid2025Emulator(GriddedEmulator):
    """Emulator whose pretrained artifacts are fetched from the Hugging Face Hub.

    Call order implied by the attributes each step sets and reads:
    ``load()`` -> ``compile()`` -> ``emulator(gmst, month)``.
    """

    def __init__(self, esm_name: str):
        """Record the ESM name, used as the top-level folder in the Hub repo."""
        self.esm = esm_name
        self.repo_id = "shahineb/climemu"

    def load(self, which: str = "default"):
        """Download and load every artifact needed for generation.

        Args:
            which: Sub-folder of ``<esm_name>/`` in the Hub repo that holds the
                model files.
        """
        # Paths inside a Hub repo use forward slashes (they are joined with "/"
        # below), so build the prefix explicitly instead of with os.path.join,
        # which would insert a backslash on Windows.
        self.files_dir = f"{self.esm}/{which}"

        # Load pattern scaling coefficients
        self.β = self._load_pattern_scaling()

        # Load climatology data (also the source of lat / lon / variable names)
        self.climatology = self._load_climatology()

        # Load the generative model precursor
        self.precursor = self._load_precursor()

    def compile(self, n_samples, n_steps=30):
        """Fix the sampling settings and warm up the JIT-compiled sampler.

        Args:
            n_samples: Number of samples drawn per call.
            n_steps: Number of sampler steps.
        """
        # Fix number of samples and steps for generation
        self.generative_model = partial(self.precursor,
                                        n_samples=n_samples,
                                        n_steps=n_steps)

        # Perform a dry run to compile the JAX functions (important for performance)
        dummy_pattern = jnp.zeros((self.nlat, self.nlon))
        _ = self.generative_model(pattern=dummy_pattern, key=jr.PRNGKey(0))

    def __call__(self, gmst, month, seed=None, xarray=False):
        """Generate samples for a given temperature value and month.

        Args:
            gmst: Value multiplied by the slope coefficients (ΔT in
                ``pattern = β₀ + β₁ * ΔT``).
            month: 1-based month; ``month - 1`` indexes the first axis of β.
            seed: Integer seed for the PRNG key. When ``None`` a seed is drawn
                with ``np.random.randint``. ``0`` is a valid, reproducible seed.
            xarray: When true, wrap the samples in an ``xr.Dataset`` with
                dimensions ``("member", "lat", "lon")``, one entry per variable.

        Returns:
            The output of the generative model, or an ``xr.Dataset`` built
            from it when ``xarray`` is true.
        """
        # Apply pattern scaling: pattern = β₀ + β₁ * ΔT
        coefficients = self.β[month - 1]
        pattern = coefficients[:, 1] * gmst + coefficients[:, 0]
        pattern = pattern.reshape((self.nlat, self.nlon))

        # Only fall back to a random seed when none was given: a plain
        # truthiness test would silently discard seed=0.
        if seed is None:
            seed = np.random.randint(0, 1000000)
        key = jr.PRNGKey(seed)

        # Generate samples using the diffusion model
        samples = self.generative_model(pattern=pattern, key=key)

        # Convert to xarray Dataset
        if xarray:
            samples = xr.Dataset(
                {
                    var: (("member", "lat", "lon"), samples[:, i, :, :])
                    for i, var in enumerate(self.vars)
                },
                coords={
                    "member": jnp.arange(len(samples)) + 1,
                    "lat": self.lat,
                    "lon": self.lon,
                },
            )
        return samples

    def _load_precursor(self):
        """Return ``draw_samples_single`` with network, schedule, output size and statistics bound."""
        # Build the neural network and noise schedule
        config = self._load_config()
        nn = self._load_nn(config)
        schedule = self._load_schedule(config)

        # Load normalization statistics used during training (needed to denormalize the generated samples)
        μ, σ = self._load_normalization()

        # Define the output size for the generated samples
        output_size = (config['out_channels'], config['input_size'][1], config['input_size'][2])

        # Create a precursor for the generative model
        precursor = partial(draw_samples_single,
                            nn=nn,
                            schedule=schedule,
                            output_size=output_size,
                            μ=μ, σ=σ)
        return precursor

    def _load_config(self):
        """Download ``config.yaml`` and return its parsed content."""
        config_path = hf_hub_download(self.repo_id, f"{self.files_dir}/config.yaml")
        # Decode explicitly as UTF-8 rather than relying on the platform default
        with open(config_path, 'r', encoding='utf-8') as file:
            config = yaml.safe_load(file)
        return config

    def _load_normalization(self):
        """Download ``μ_σ.npz`` and return its ``μ`` and ``σ`` entries."""
        norm_stats_path = hf_hub_download(self.repo_id, f"{self.files_dir}/μ_σ.npz")
        stats = jnp.load(norm_stats_path)
        μ, σ = stats['μ'], stats['σ']
        return μ, σ

    def _load_pattern_scaling(self):
        """Download ``β.npy`` and return the loaded pattern scaling coefficients."""
        pattern_scaling_path = hf_hub_download(self.repo_id, f"{self.files_dir}/β.npy")
        β = jnp.load(pattern_scaling_path)
        return β

    def _load_climatology(self):
        """Download ``piControl_climatology.nc`` (stored at the ESM level) and open it."""
        climatology_path = hf_hub_download(self.repo_id, f"{self.esm}/piControl_climatology.nc")
        climatology = xr.open_dataset(climatology_path)
        return climatology

    def _load_nn(self, config):
        """Build the ``HealPIXUNet`` described by ``config`` and load its pretrained weights."""
        # Load graph edges for HEALPix to lat-lon connectivity
        edges_path = hf_hub_download(self.repo_id, f"{self.files_dir}/edges.npz")
        edges_data = jnp.load(edges_path)
        to_healpix = jnp.array(edges_data['to_healpix']).astype(jnp.int32)
        to_latlon = jnp.array(edges_data['to_latlon']).astype(jnp.int32)

        # Initialize the neural network
        nn = HealPIXUNet(input_size=config['input_size'],
                         nside=config['nside'],
                         enc_filters=config['enc_filters'],
                         dec_filters=config['dec_filters'],
                         out_channels=config['out_channels'],
                         temb_dim=config['temb_dim'],
                         healpix_emb_dim=config['healpix_emb_dim'],
                         edges_to_healpix=to_healpix,
                         edges_to_latlon=to_latlon)

        # Load the pre-trained weights from the saved model file
        weights_path = hf_hub_download(self.repo_id, f"{self.files_dir}/weights.eqx")
        nn = eqx.tree_deserialise_leaves(weights_path, nn)
        return nn

    def _load_schedule(self, config):
        """Build the variance exploding schedule from ``config['sigma_min']`` and the stored σmax."""
        # Load the maximum noise level as used in training
        sigma_max_path = hf_hub_download(self.repo_id, f"{self.files_dir}/σmax.npy")
        σmax = jnp.load(sigma_max_path)

        # Create the variance exploding schedule
        schedule = ContinuousVESchedule(config['sigma_min'], σmax)
        return schedule

    @property
    def lat(self):
        """Values of the ``lat`` coordinate of the climatology dataset."""
        return self.climatology['lat'].values

    @property
    def lon(self):
        """Values of the ``lon`` coordinate of the climatology dataset."""
        return self.climatology['lon'].values

    @property
    def vars(self):
        """Names of the data variables of the climatology dataset."""
        return list(self.climatology.data_vars)
156
+
157
+
158
@eqx.filter_jit
def normalize(x, μ, σ):
    """Standardize ``x``: subtract ``μ``, then divide by ``σ``."""
    centered = x - μ
    return centered / σ
162
+
163
@eqx.filter_jit
def denormalize(x, μ, σ):
    """Map standardized ``x`` back: multiply by ``σ``, then add ``μ``."""
    rescaled = σ * x
    return rescaled + μ
167
+
168
+
169
def create_sampler(nn, schedule, pattern, μ, σ, output_size):
    """Build a Heun sampler whose network is conditioned on ``pattern``.

    The pattern is normalized with the last entries of ``μ`` and ``σ`` and
    given a leading axis so it can be concatenated to the network input.
    """
    normalized_pattern = normalize(pattern, μ[-1], σ[-1])
    conditioning = normalized_pattern[None, ...]

    def conditioned_nn(x, t):
        # Append the conditioning along axis 0 before calling the network
        stacked = jnp.concatenate((x, conditioning), axis=0)
        return nn(stacked, t)

    return ContinuousHeunSampler(schedule, conditioned_nn, output_size)
176
+
177
@eqx.filter_jit
def draw_samples_single(nn, schedule, pattern, n_samples, n_steps, μ, σ, output_size, key=jr.PRNGKey(0)):
    """Sample from the model conditioned on ``pattern`` and denormalize the result.

    Samples come out of the sampler in normalized space and are mapped back
    with all but the last entries of ``μ`` and ``σ``.
    """
    heun_sampler = create_sampler(nn, schedule, pattern, μ, σ, output_size)
    normalized_samples = heun_sampler.sample(n_samples, steps=n_steps, key=key)
    μ_out, σ_out = μ[:-1], σ[:-1]
    return denormalize(normalized_samples, μ_out, σ_out)
183
+
184
+
185
@EMULATORS.register("MPI-ESM1-2-LR")
class MPIEmulator(Bouabid2025Emulator):
    """Emulator registered under the name "MPI-ESM1-2-LR"."""

    def __init__(self, which="default"):
        # NOTE(review): ``which`` is accepted but not used by this constructor —
        # confirm whether it was meant to be forwarded.
        esm_name = "MPI-ESM1-2-LR"
        super().__init__(esm_name=esm_name)
189
+
190
+
191
@EMULATORS.register("MIROC6")
class MIROCEmulator(Bouabid2025Emulator):
    """Emulator registered under the name "MIROC6"."""

    def __init__(self, which="default"):
        # NOTE(review): ``which`` is accepted but not used by this constructor —
        # confirm whether it was meant to be forwarded.
        esm_name = "MIROC6"
        super().__init__(esm_name=esm_name)
195
+
196
+
197
@EMULATORS.register("ACCESS-ESM1-5")
class ACCESSEmulator(Bouabid2025Emulator):
    """Emulator registered under the name "ACCESS-ESM1-5"."""

    def __init__(self, which="default"):
        # NOTE(review): ``which`` is accepted but not used by this constructor —
        # confirm whether it was meant to be forwarded.
        esm_name = "ACCESS-ESM1-5"
        super().__init__(esm_name=esm_name)
@@ -0,0 +1,3 @@
1
+ from .registry import Registry
2
+
3
+ __all__ = ['Registry']
@@ -0,0 +1,71 @@
1
+ import types
2
+
3
+
4
class Registry(dict):
    """
    A helper class for managing access to builders. It extends a dictionary
    and provides a ``register`` method that can also be used as a decorator.

    Creating a registry:
        MODULES = Registry()

    There are two types of builder callable you can register:

    (1) : Functions, which can be registered either by a simple call
    ```
    def build_bar():
        return True
    MODULES.register('bar', build_bar)
    ```
    or by using a decorator at function definition
    ```
    @MODULES.register('bar')
    def build_bar():
        return True
    ```

    (2) : Classes exposing a ``build`` classmethod; it is ``cls.build`` that
    gets stored. The class itself (not an instance) is registered with a call
    ```
    MODULES.register('foo', Foo)
    ```
    or by using a decorator at class definition
    ```
    @MODULES.register('foo')
    class Foo:
        @classmethod
        def build(cls, *args, **kwargs):
            # build a class instance foo
            return foo
    ```

    Access of module is just like using a dictionary, eg:
        build_bar = MODULES['bar']
        build_foo = MODULES['foo']
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def register(self, name, module_object=None):
        """Register a builder under ``name``.

        Args:
            name: Key under which the builder is stored.
            module_object: A class (its ``build`` attribute is stored), a plain
                function (stored as is), or ``None`` to get a decorator.

        Returns:
            A decorator returning its argument unchanged when ``module_object``
            is ``None``; otherwise ``None``.

        Raises:
            TypeError: If ``module_object`` is neither a class nor a function.
            AssertionError: If ``name`` is already registered.
        """
        # Used as a decorator
        if module_object is None:
            def register_func(module_object):
                self.register(name=name, module_object=module_object)
                return module_object
            return register_func

        # Used as a function call
        if isinstance(module_object, type):
            builder = module_object.build
        elif isinstance(module_object, types.FunctionType):
            builder = module_object
        else:
            raise TypeError("Trying to register unknown data type")
        self._register_generic(module_dict=self, name=name, builder=builder)

    @staticmethod
    def _register_generic(module_dict, name, builder):
        """Store ``builder`` under ``name``, refusing to overwrite an existing entry."""
        if name in module_dict:
            # Raised explicitly instead of via ``assert`` so the duplicate check
            # still runs under ``python -O``; the exception type is unchanged.
            raise AssertionError(f"{name!r} is already registered")
        module_dict[name] = builder
@@ -0,0 +1,41 @@
1
+ Metadata-Version: 2.4
2
+ Name: climemu
3
+ Version: 0.1.6
4
+ Summary: Score-based generative emulation of impact-relevant earth system model outputs
5
+ Author-email: Shahine Bouabid <shahineb@mit.edu>
6
+ License: MIT
7
+ Requires-Python: >=3.12
8
+ Description-Content-Type: text/markdown
9
+ License-File: LICENSE
10
+ Requires-Dist: dask>=2025.7.0
11
+ Requires-Dist: diffrax>=0.7.0
12
+ Requires-Dist: einops>=0.8.1
13
+ Requires-Dist: equinox>=0.13.0
14
+ Requires-Dist: huggingface-hub>=0.35.0
15
+ Requires-Dist: jax>=0.7.1
16
+ Requires-Dist: netcdf4>=1.7.2
17
+ Requires-Dist: numpy>=2.3.2
18
+ Requires-Dist: optax>=0.2.5
19
+ Requires-Dist: scikit-learn>=1.7.1
20
+ Requires-Dist: scipy>=1.16.1
21
+ Requires-Dist: tqdm>=4.67.1
22
+ Requires-Dist: xarray>=2025.9.0
23
+ Provides-Extra: gpu
24
+ Requires-Dist: jax[cuda12]>=0.7.1; extra == "gpu"
25
+ Provides-Extra: train
26
+ Requires-Dist: torch>=2.8.0; extra == "train"
27
+ Requires-Dist: healpy>=1.18.1; extra == "train"
28
+ Requires-Dist: wandb>=0.21.3; extra == "train"
29
+ Provides-Extra: plots
30
+ Requires-Dist: matplotlib>=3.10.6; extra == "plots"
31
+ Requires-Dist: seaborn>=0.13.2; extra == "plots"
32
+ Requires-Dist: cartopy>=0.25.0; extra == "plots"
33
+ Requires-Dist: regionmask>=0.13.0; extra == "plots"
34
+ Requires-Dist: pandas>=2.3.2; extra == "plots"
35
+ Requires-Dist: pyshtools>=4.13.1; extra == "plots"
36
+ Provides-Extra: test
37
+ Requires-Dist: pytest>=8.0.0; extra == "test"
38
+ Requires-Dist: pytest-cov>=4.0.0; extra == "test"
39
+ Requires-Dist: pytest-mock>=3.10.0; extra == "test"
40
+ Requires-Dist: jaxlib>=0.4.0; extra == "test"
41
+ Dynamic: license-file
@@ -0,0 +1,46 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ src/climemu/__init__.py
5
+ src/climemu.egg-info/PKG-INFO
6
+ src/climemu.egg-info/SOURCES.txt
7
+ src/climemu.egg-info/dependency_links.txt
8
+ src/climemu.egg-info/requires.txt
9
+ src/climemu.egg-info/top_level.txt
10
+ src/climemu/emulators/__init__.py
11
+ src/climemu/emulators/abstractemulator.py
12
+ src/climemu/emulators/bouabid2025.py
13
+ src/climemu/utils/__init__.py
14
+ src/climemu/utils/registry.py
15
+ src/datasets/__init__.py
16
+ src/datasets/cmip6.py
17
+ src/datasets/constants.py
18
+ src/datasets/pattern_to_cmip6.py
19
+ src/diffusion/__init__.py
20
+ src/diffusion/losses/__init__.py
21
+ src/diffusion/losses/denoising_score_matching.py
22
+ src/diffusion/nn/__init__.py
23
+ src/diffusion/nn/healpixunet.py
24
+ src/diffusion/nn/backbones/__init__.py
25
+ src/diffusion/nn/backbones/convnet.py
26
+ src/diffusion/nn/modules/__init__.py
27
+ src/diffusion/nn/modules/remap.py
28
+ src/diffusion/nn/modules/healpix/__init__.py
29
+ src/diffusion/nn/modules/healpix/conv.py
30
+ src/diffusion/nn/modules/healpix/padding.py
31
+ src/diffusion/nn/timeencoder/__init__.py
32
+ src/diffusion/nn/timeencoder/gaussianfourier.py
33
+ src/diffusion/samplers/__init__.py
34
+ src/diffusion/samplers/continuous_ode_sampler.py
35
+ src/diffusion/schedules/__init__.py
36
+ src/diffusion/schedules/variance_exploding.py
37
+ src/utils/__init__.py
38
+ src/utils/collate.py
39
+ src/utils/arrays/__init__.py
40
+ src/utils/arrays/aggregate.py
41
+ src/utils/arrays/convert.py
42
+ src/utils/arrays/filter.py
43
+ src/utils/arrays/reshape.py
44
+ src/utils/arrays/smooth.py
45
+ src/utils/graphs/__init__.py
46
+ src/utils/graphs/haversine.py
@@ -0,0 +1,35 @@
1
+ dask>=2025.7.0
2
+ diffrax>=0.7.0
3
+ einops>=0.8.1
4
+ equinox>=0.13.0
5
+ huggingface-hub>=0.35.0
6
+ jax>=0.7.1
7
+ netcdf4>=1.7.2
8
+ numpy>=2.3.2
9
+ optax>=0.2.5
10
+ scikit-learn>=1.7.1
11
+ scipy>=1.16.1
12
+ tqdm>=4.67.1
13
+ xarray>=2025.9.0
14
+
15
+ [gpu]
16
+ jax[cuda12]>=0.7.1
17
+
18
+ [plots]
19
+ matplotlib>=3.10.6
20
+ seaborn>=0.13.2
21
+ cartopy>=0.25.0
22
+ regionmask>=0.13.0
23
+ pandas>=2.3.2
24
+ pyshtools>=4.13.1
25
+
26
+ [test]
27
+ pytest>=8.0.0
28
+ pytest-cov>=4.0.0
29
+ pytest-mock>=3.10.0
30
+ jaxlib>=0.4.0
31
+
32
+ [train]
33
+ torch>=2.8.0
34
+ healpy>=1.18.1
35
+ wandb>=0.21.3
@@ -0,0 +1,4 @@
1
+ climemu
2
+ datasets
3
+ diffusion
4
+ utils
@@ -0,0 +1,12 @@
1
+ from .cmip6 import (
2
+ CMIP6Data
3
+ )
4
+
5
+ from .pattern_to_cmip6 import (
6
+ PatternToCMIP6Dataset
7
+ )
8
+
9
+ __all__ = [
10
+ "CMIP6Data",
11
+ "PatternToCMIP6Dataset"
12
+ ]