anemoi-utils 0.4.13__py3-none-any.whl → 0.4.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of anemoi-utils might be problematic. Click here for more details.

anemoi/utils/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.4.13'
21
- __version_tuple__ = version_tuple = (0, 4, 13)
20
+ __version__ = version = '0.4.14'
21
+ __version_tuple__ = version_tuple = (0, 4, 14)
anemoi/utils/provenance.py CHANGED
@@ -10,10 +10,10 @@
10
10
 
11
11
  """Collect information about the current environment, like:
12
12
 
13
- - The Python version
14
- - The versions of the modules which are currently loaded
15
- - The git information for the modules which are currently loaded from a git repository
16
- - ...
13
+ - The Python version
14
+ - The versions of the modules which are currently loaded
15
+ - The git information for the modules which are currently loaded from a git repository
16
+ - ...
17
17
  """
18
18
 
19
19
  import datetime
anemoi/utils/registry.py CHANGED
@@ -12,9 +12,12 @@ import importlib
12
12
  import logging
13
13
  import os
14
14
  import sys
15
+ import warnings
16
+ from functools import cached_property
15
17
  from typing import Any
16
18
  from typing import Callable
17
19
  from typing import Dict
20
+ from typing import List
18
21
  from typing import Optional
19
22
  from typing import Union
20
23
 
@@ -22,6 +25,8 @@ import entrypoints
22
25
 
23
26
  LOG = logging.getLogger(__name__)
24
27
 
28
+ DEBUG_ANEMOI_REGISTRY = int(os.environ.get("DEBUG_ANEMOI_REGISTRY", "0"))
29
+
25
30
 
26
31
  class Wrapper:
27
32
  """A wrapper for the registry.
@@ -55,6 +60,22 @@ class Wrapper:
55
60
  return factory
56
61
 
57
62
 
63
+ class Error:
64
+ """An error class. Used in place of a plugin that failed to load.
65
+
66
+ Parameters
67
+ ----------
68
+ error : Exception
69
+ The error.
70
+ """
71
+
72
+ def __init__(self, error: Exception):
73
+ self.error = error
74
+
75
+ def __call__(self, *args, **kwargs):
76
+ raise self.error
77
+
78
+
58
79
  _BY_KIND = {}
59
80
 
60
81
 
@@ -71,7 +92,8 @@ class Registry:
71
92
 
72
93
  def __init__(self, package: str, key: str = "_type"):
73
94
  self.package = package
74
- self.registered = {}
95
+ self.__registered = {}
96
+ self._sources = {}
75
97
  self.kind = package.split(".")[-1]
76
98
  self.key = key
77
99
  _BY_KIND[self.kind] = self
@@ -92,7 +114,9 @@ class Registry:
92
114
  """
93
115
  return _BY_KIND.get(kind)
94
116
 
95
- def register(self, name: str, factory: Optional[Callable] = None) -> Optional[Wrapper]:
117
+ def register(
118
+ self, name: str, factory: Optional[Callable] = None, source: Optional[Any] = None
119
+ ) -> Optional[Wrapper]:
96
120
  """Register a factory with the registry.
97
121
 
98
122
  Parameters
@@ -101,6 +125,8 @@ class Registry:
101
125
  The name of the factory.
102
126
  factory : Callable, optional
103
127
  The factory to register, by default None.
128
+ source : Any, optional
129
+ The source of the factory, by default None.
104
130
 
105
131
  Returns
106
132
  -------
@@ -108,12 +134,19 @@ class Registry:
108
134
  A wrapper if the factory is None, otherwise None.
109
135
  """
110
136
  if factory is None:
137
+ # This happens when the @register decorator is used
111
138
  return Wrapper(name, self)
112
139
 
113
- self.registered[name] = factory
140
+ if source is None:
141
+ source = getattr(factory, "_source") if hasattr(factory, "_source") else factory
142
+
143
+ if name in self.__registered:
144
+ warnings.warn(f"Factory '{name}' is already registered in {self.package}")
145
+ warnings.warn(f"Existing: {self._sources[name]}")
146
+ warnings.warn(f"New: {source}")
114
147
 
115
- # def registered(self, name: str):
116
- # return name in self.registered
148
+ self.__registered[name] = factory
149
+ self._sources[name] = source
117
150
 
118
151
  def _load(self, file: str) -> None:
119
152
  """Load a module from a file.
@@ -126,8 +159,30 @@ class Registry:
126
159
  name, _ = os.path.splitext(file)
127
160
  try:
128
161
  importlib.import_module(f".{name}", package=self.package)
129
- except Exception:
130
- LOG.warning(f"Error loading filter '{self.package}.{name}'", exc_info=True)
162
+ except Exception as e:
163
+ if DEBUG_ANEMOI_REGISTRY:
164
+ raise
165
+ self._registered[name] = Error(e)
166
+
167
+ def is_registered(self, name: str) -> bool:
168
+ """Check if a factory is registered.
169
+
170
+ Parameters
171
+ ----------
172
+ name : str
173
+ The name of the factory.
174
+
175
+ Returns
176
+ -------
177
+ bool
178
+ Whether the factory is registered.
179
+ """
180
+ ok = name in self.factories
181
+ if not ok:
182
+ LOG.error(f"Cannot find '{name}' in {self.package}")
183
+ for e in self.factories:
184
+ LOG.info(f"Registered: {e} ({self._sources.get(e)})")
185
+ return ok
131
186
 
132
187
  def lookup(self, name: str, *, return_none: bool = False) -> Optional[Callable]:
133
188
  """Lookup a factory by name.
@@ -144,9 +199,22 @@ class Registry:
144
199
  Callable, optional
145
200
  The factory if found, otherwise None.
146
201
  """
147
- # print('✅✅✅✅✅✅✅✅✅✅✅✅✅', name, self.registered)
148
- if name in self.registered:
149
- return self.registered[name]
202
+ if return_none:
203
+ return self.factories.get(name)
204
+
205
+ factory = self.factories.get(name)
206
+ if factory is None:
207
+
208
+ LOG.error(f"Cannot find '{name}' in {self.package}")
209
+ for e in self.factories:
210
+ LOG.info(f"Registered: {e} ({self._sources.get(e)})")
211
+
212
+ raise ValueError(f"Cannot find '{name}' in {self.package}")
213
+
214
+ return factory
215
+
216
+ @cached_property
217
+ def factories(self) -> Dict[str, Callable]:
150
218
 
151
219
  directory = sys.modules[self.package].__path__[0]
152
220
 
@@ -167,25 +235,41 @@ class Registry:
167
235
  if file.endswith(".py"):
168
236
  self._load(file)
169
237
 
170
- entrypoint_group = f"anemoi.{self.kind}"
171
- for entry_point in entrypoints.get_group_all(entrypoint_group):
172
- if entry_point.name == name:
173
- if name in self.registered:
174
- LOG.warning(
175
- f"Overwriting builtin '{name}' from {self.package} with plugin '{entry_point.module_name}'"
176
- )
177
- self.registered[name] = entry_point.load()
238
+ bits = self.package.split(".")
239
+ # We assume a name like anemoi.datasets.create.sources, with kind = sources
240
+ assert bits[-1] == self.kind, (self.package, self.kind)
241
+ assert len(bits) > 1, self.package
242
+
243
+ groups = []
244
+ middle = bits[1:-1]
245
+ while True:
246
+ group = ".".join([bits[0], *middle, bits[-1]])
247
+ groups.append(group)
248
+ if len(middle) == 0:
249
+ break
250
+ middle.pop()
178
251
 
179
- if name not in self.registered:
180
- if return_none:
181
- return None
252
+ groups.reverse()
182
253
 
183
- for e in self.registered:
184
- LOG.info(f"Registered: {e}")
254
+ LOG.debug("Loading plugins from %s", groups)
185
255
 
186
- raise ValueError(f"Cannot load '{name}' from {self.package}")
256
+ for entrypoint_group in groups:
257
+ for entry_point in entrypoints.get_group_all(entrypoint_group):
258
+ source = entry_point.distro
259
+ try:
260
+ self.register(entry_point.name, entry_point.load(), source=source)
261
+ except Exception as e:
262
+ if DEBUG_ANEMOI_REGISTRY:
263
+ raise
264
+ self.register(entry_point.name, Error(e), source=source)
187
265
 
188
- return self.registered[name]
266
+ return self.__registered
267
+
268
+ @property
269
+ def registered(self) -> List[str]:
270
+ """Get the registered factories."""
271
+
272
+ return sorted(self.factories.keys())
189
273
 
190
274
  def create(self, name: str, *args: Any, **kwargs: Any) -> Any:
191
275
  """Create an instance using a factory.
@@ -207,9 +291,6 @@ class Registry:
207
291
  factory = self.lookup(name)
208
292
  return factory(*args, **kwargs)
209
293
 
210
- # def __call__(self, name: str, *args, **kwargs):
211
- # return self.create(name, *args, **kwargs)
212
-
213
294
  def from_config(self, config: Union[str, Dict[str, Any]], *args: Any, **kwargs: Any) -> Any:
214
295
  """Create an instance from a configuration.
215
296
 
anemoi/utils/testing.py ADDED
@@ -0,0 +1,182 @@
1
+ # (C) Copyright 2025- Anemoi contributors.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+
10
+ import atexit
11
+ import logging
12
+ import os
13
+ import shutil
14
+ import tempfile
15
+ import threading
16
+
17
+ from multiurl import download
18
+
19
+ LOG = logging.getLogger(__name__)
20
+
21
+ TEST_DATA_URL = "https://object-store.os-api.cci1.ecmwf.int/ml-tests/test-data/samples/"
22
+
23
+ lock = threading.RLock()
24
+ TEMPORARY_DIRECTORY = None
25
+
26
+
27
+ def _temporary_directory() -> str:
28
+ """Return a temporary directory in which to download test data.
29
+
30
+ Returns
31
+ -------
32
+ str
33
+ The path to the temporary directory.
34
+ """
35
+ global TEMPORARY_DIRECTORY
36
+ with lock:
37
+ if TEMPORARY_DIRECTORY is not None:
38
+ return TEMPORARY_DIRECTORY
39
+
40
+ TEMPORARY_DIRECTORY = tempfile.mkdtemp()
41
+
42
+ # Register a cleanup function to remove the directory at exit
43
+ atexit.register(shutil.rmtree, TEMPORARY_DIRECTORY)
44
+
45
+ return TEMPORARY_DIRECTORY
46
+
47
+
48
+ def _check_path(path: str) -> None:
49
+ """Check if the given path is normalized, not absolute, and does not start with a dot.
50
+
51
+ Parameters
52
+ ----------
53
+ path : str
54
+ The path to check.
55
+
56
+ Raises
57
+ ------
58
+ AssertionError
59
+ If the path is not normalized, is absolute, or starts with a dot.
60
+ """
61
+ assert os.path.normpath(path) == path, f"Path '{path}' should be normalized"
62
+ assert not os.path.isabs(path), f"Path '{path}' should not be absolute"
63
+ assert not path.startswith("."), f"Path '{path}' should not start with '.'"
64
+
65
+
66
+ def url_for_test_data(path: str) -> str:
67
+ """Generate the URL for the test data based on the given path.
68
+
69
+ Parameters
70
+ ----------
71
+ path : str
72
+ The relative path to the test data.
73
+
74
+ Returns
75
+ -------
76
+ str
77
+ The full URL to the test data.
78
+ """
79
+ _check_path(path)
80
+
81
+ return f"{TEST_DATA_URL}{path}"
82
+
83
+
84
+ def get_test_data(path: str, gzipped=False) -> str:
85
+ """Download the test data to a temporary directory and return the local path.
86
+
87
+ Parameters
88
+ ----------
89
+ path : str
90
+ The relative path to the test data.
91
+ gzipped : bool, optional
92
+ Flag indicating if the remote file is gzipped, by default False. The local file will be gunzipped.
93
+
94
+ Returns
95
+ -------
96
+ str
97
+ The local path to the downloaded test data.
98
+ """
99
+ _check_path(path)
100
+
101
+ target = os.path.normpath(os.path.join(_temporary_directory(), path))
102
+ with lock:
103
+ if os.path.exists(target):
104
+ return target
105
+
106
+ os.makedirs(os.path.dirname(target), exist_ok=True)
107
+ url = url_for_test_data(path)
108
+
109
+ if gzipped:
110
+ url += ".gz"
111
+ target += ".gz"
112
+
113
+ LOG.info(f"Downloading test data from {url} to {target}")
114
+
115
+ download(url, target)
116
+
117
+ if gzipped:
118
+ import gzip
119
+
120
+ with gzip.open(target, "rb") as f_in:
121
+ with open(target[:-3], "wb") as f_out:
122
+ shutil.copyfileobj(f_in, f_out)
123
+ os.remove(target)
124
+ target = target[:-3]
125
+
126
+ return target
127
+
128
+
129
+ def get_test_archive(path: str, extension=".extracted") -> str:
130
+ """Download an archive file (.zip, .tar, .tar.gz, .tar.bz2, .tar.xz) to a temporary directory
131
+ unpack it, and return the local path to the directory containing the extracted files.
132
+
133
+ Parameters
134
+ ----------
135
+ path : str
136
+ The relative path to the test data.
137
+ extension : str, optional
138
+ The extension to add to the extracted directory, by default '.extracted'
139
+
140
+ Returns
141
+ -------
142
+ str
143
+ The local path to the downloaded test data.
144
+ """
145
+
146
+ with lock:
147
+
148
+ archive = get_test_data(path)
149
+ target = archive + extension
150
+
151
+ shutil.unpack_archive(archive, os.path.dirname(target) + ".tmp")
152
+ os.rename(os.path.dirname(target) + ".tmp", target)
153
+
154
+ return target
155
+
156
+
157
+ def packages_installed(*names) -> bool:
158
+ """Check if all the given packages are installed.
159
+
160
+ Use this function to check if the required packages are installed before running tests.
161
+
162
+ >>> @pytest.mark.skipif(not packages_installed("foo", "bar"), reason="Packages 'foo' and 'bar' are not installed")
163
+ >>> def test_foo_bar() -> None:
164
+ >>> ...
165
+
166
+ Parameters
167
+ ----------
168
+ names : str
169
+ The names of the packages to check.
170
+
171
+ Returns
172
+ -------
173
+ bool:
174
+ Flag indicating if all the packages are installed."
175
+ """
176
+
177
+ for name in names:
178
+ try:
179
+ __import__(name)
180
+ except ImportError:
181
+ return False
182
+ return True
anemoi_utils-0.4.14.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: anemoi-utils
3
- Version: 0.4.13
3
+ Version: 0.4.14
4
4
  Summary: A package to hold various functions to support training of ML models on ECMWF data.
5
5
  Author-email: "European Centre for Medium-Range Weather Forecasts (ECMWF)" <software.support@ecmwf.int>
6
6
  License: Apache License
@@ -256,3 +256,4 @@ Requires-Dist: pytest; extra == "tests"
256
256
  Provides-Extra: text
257
257
  Requires-Dist: termcolor; extra == "text"
258
258
  Requires-Dist: wcwidth; extra == "text"
259
+ Dynamic: license-file
anemoi_utils-0.4.14.dist-info/RECORD CHANGED
@@ -1,6 +1,6 @@
1
1
  anemoi/utils/__init__.py,sha256=uVhpF-VjIl_4mMywOVtgTutgsdIsqz-xdkwxeMhzuag,730
2
2
  anemoi/utils/__main__.py,sha256=6LlE4MYrPvqqrykxXh7XMi50UZteUY59NeM8P9Zs2dU,910
3
- anemoi/utils/_version.py,sha256=NEWZMsf9hAPHbHDxsYo2Rgs87lmnioXpPKJ62lRxhqs,513
3
+ anemoi/utils/_version.py,sha256=ggKLQlA5gegh3l4LAMWuwHPmL8zgZ4b0pYtFx7quk78,513
4
4
  anemoi/utils/caching.py,sha256=rXbeAmpBcMbbfN4EVblaHWKicsrtx1otER84FEBtz98,6183
5
5
  anemoi/utils/checkpoints.py,sha256=N4WpAZXa4etrpSEKhHqUUtG2-x9w3FJMHcLO-dDAXPY,9600
6
6
  anemoi/utils/cli.py,sha256=IyZfnSw0u0yYnrjOrzvm2RuuKvDk4cVb8pf8BkaChgA,6209
@@ -13,11 +13,12 @@ anemoi/utils/grids.py,sha256=edTrMK8hpE9ZBzSfwcRftgk0jljNAK3i8CraadILQoM,4427
13
13
  anemoi/utils/hindcasts.py,sha256=iYVIxSNFL2HJcc_k1abCFLkpJFGHT8WKRIR4wcAwA3s,2144
14
14
  anemoi/utils/humanize.py,sha256=hCrHr5ppREuJR-tBqRqynqe58BHR6Ga_gCQqgEmmrfU,25301
15
15
  anemoi/utils/logs.py,sha256=naTgrmPwWHD4eekFttXftS4gtcAGYHpCqG4iwYprNDA,1804
16
- anemoi/utils/provenance.py,sha256=l6TRVadM5l3SKUvYM20EQn0TbolSY6vbbZ_WqekjxwM,14619
17
- anemoi/utils/registry.py,sha256=J2onjVqR4LLYRAj5sioobsdxo_e4kxuBOSiNeAQbGe8,7061
16
+ anemoi/utils/provenance.py,sha256=tIIgweS0EJyaYzgKwuv3iWny-Gz7N5e5CNWH5MeSYWU,14615
17
+ anemoi/utils/registry.py,sha256=zTnVSCecmrI6SkZhF4ipV7WKiZYHEBJ_5ZwqhRGM4T0,9287
18
18
  anemoi/utils/s3.py,sha256=xMT48kbcelcjjqsaU567WI3oZ5eqo88Rlgyx5ECszAU,4074
19
19
  anemoi/utils/sanitise.py,sha256=ZYGdSX6qihQANr3pHZjbKnoapnzP1KcrWdW1Ul1mOGk,3668
20
20
  anemoi/utils/sanitize.py,sha256=43ZKDcfVpeXSsJ9TFEc9aZnD6oe2cUh151XnDspM98M,462
21
+ anemoi/utils/testing.py,sha256=N1y4dfZLE9zqOhIR3o-933fdAdd9BxDvjcJx7SwFC9A,4803
21
22
  anemoi/utils/text.py,sha256=HkzIvi24obDceFLpJEwBJ9PmPrJUkQN2TrElJ-A87gU,14441
22
23
  anemoi/utils/timer.py,sha256=_leKMYza2faM7JKlGE7LCNy13rbdPnwaCF7PSrI_NmI,3895
23
24
  anemoi/utils/commands/__init__.py,sha256=5u_6EwdqYczIAgJfCwRSyQAYFEqh2ZuHHT57g9g7sdI,808
@@ -29,9 +30,9 @@ anemoi/utils/mars/requests.py,sha256=VFMHBVAAl0_2lOcMBa1lvaKHctN0lDJsI6_U4BucGew
29
30
  anemoi/utils/remote/__init__.py,sha256=-uaYFi4yRYFRf46ubQbJo86GCn6HE5VQrcaoyrmyW28,20704
30
31
  anemoi/utils/remote/s3.py,sha256=spQ8l0rwQjLZh9dZu5cOsYIvNwKihQfCJ6YsFYegeqI,17339
31
32
  anemoi/utils/remote/ssh.py,sha256=xNtsawh8okytCKRehkRCVExbHZj-CRUQNormEHglfuw,8088
32
- anemoi_utils-0.4.13.dist-info/LICENSE,sha256=8HznKF1Vi2IvfLsKNE5A2iVyiri3pRjRPvPC9kxs6qk,11354
33
- anemoi_utils-0.4.13.dist-info/METADATA,sha256=WSExhiI_8UePoAsCE4h65G2moeuMEgWZcgECBsisx30,15338
34
- anemoi_utils-0.4.13.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
35
- anemoi_utils-0.4.13.dist-info/entry_points.txt,sha256=LENOkn88xzFQo-V59AKoA_F_cfYQTJYtrNTtf37YgHY,60
36
- anemoi_utils-0.4.13.dist-info/top_level.txt,sha256=DYn8VPs-fNwr7fNH9XIBqeXIwiYYd2E2k5-dUFFqUz0,7
37
- anemoi_utils-0.4.13.dist-info/RECORD,,
33
+ anemoi_utils-0.4.14.dist-info/licenses/LICENSE,sha256=8HznKF1Vi2IvfLsKNE5A2iVyiri3pRjRPvPC9kxs6qk,11354
34
+ anemoi_utils-0.4.14.dist-info/METADATA,sha256=VLuh3lxCAuACOrJRnE_Dm0zX_pVfIHoppmRrP505pkY,15360
35
+ anemoi_utils-0.4.14.dist-info/WHEEL,sha256=1tXe9gY0PYatrMPMDd6jXqjfpz_B-Wqm32CPfRC58XU,91
36
+ anemoi_utils-0.4.14.dist-info/entry_points.txt,sha256=LENOkn88xzFQo-V59AKoA_F_cfYQTJYtrNTtf37YgHY,60
37
+ anemoi_utils-0.4.14.dist-info/top_level.txt,sha256=DYn8VPs-fNwr7fNH9XIBqeXIwiYYd2E2k5-dUFFqUz0,7
38
+ anemoi_utils-0.4.14.dist-info/RECORD,,
anemoi_utils-0.4.14.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (76.0.0)
2
+ Generator: setuptools (77.0.3)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5