anemoi-utils 0.4.35__py3-none-any.whl → 0.4.37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of anemoi-utils might be problematic.

@@ -17,6 +17,9 @@ class Environment:
     ANEMOI_CONFIG_OVERRIDE_PATH: str
     """Path to the configuration override file for Anemoi."""
 
+    ANEMOI_DEBUG_IMPORTS: bool
+    """Enable debug imports to trace module loading."""
+
     def __setattr__(self, name: str, value: Any) -> None:
         raise AttributeError("Cannot set attributes on Environment class. Use environment variables instead.")
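
The new ANEMOI_DEBUG_IMPORTS variable is read from the environment like the other Environment attributes and, as the anemoi/utils/cli.py change below shows, installs an import tracer when set. A minimal way to try it; how the string is parsed into a bool is not shown in this diff, so treating any non-empty value as true is an assumption:

import os

# Assumption: any non-empty value enables the flag; it must be set before anemoi.utils is imported.
os.environ["ANEMOI_DEBUG_IMPORTS"] = "1"

import anemoi.utils.cli  # noqa: E402  -- module imports are now printed with timestamps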
 
anemoi/utils/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.4.35'
-__version_tuple__ = version_tuple = (0, 4, 35)
+__version__ = version = '0.4.37'
+__version_tuple__ = version_tuple = (0, 4, 37)
 
 __commit_id__ = commit_id = None
@@ -214,7 +214,7 @@ def save_metadata(
             zipf.writestr(entry["path"], value.tobytes())
 
 
-def _edit_metadata(path: str, name: str, callback: Callable, supporting_arrays: dict = None) -> None:
+def _edit_metadata(path: str, name: str, callback: Callable, supporting_arrays: dict | None = None) -> None:
     """Edit metadata in a checkpoint file.
 
     Parameters
@@ -230,41 +230,56 @@ def _edit_metadata(path: str, name: str, callback: Callable, supporting_arrays:
     """
     new_path = f"{path}.anemoi-edit-{time.time()}-{os.getpid()}.tmp"
 
-    found = False
-
-    directory = None
-    with TemporaryDirectory() as temp_dir:
-        zipfile.ZipFile(path, "r").extractall(temp_dir)
-        total = 0
-        for root, dirs, files in os.walk(temp_dir):
-            for f in files:
-                total += 1
-                full = os.path.join(root, f)
-                if f == name:
-                    found = True
-                    callback(full)
-                    directory = os.path.dirname(full)
-
-        if not found:
+    with zipfile.ZipFile(path, "r") as source_zip:
+        file_list = source_zip.namelist()
+
+        # Find the target file and its directory
+        target_file = None
+        directory = None
+        for file_path in file_list:
+            if os.path.basename(file_path) == name:
+                target_file = file_path
+                directory = os.path.dirname(file_path)
+                break
+
+        if target_file is None:
             raise ValueError(f"Could not find '{name}' in {path}")
 
+        # Calculate total files for progress bar
+        total_files = len(file_list)
         if supporting_arrays is not None:
+            total_files += len(supporting_arrays)
+
+        with zipfile.ZipFile(new_path, "w", zipfile.ZIP_STORED) as new_zip:
+            with tqdm.tqdm(total=total_files, desc="Rebuilding checkpoint") as pbar:
+
+                # Copy all files except the target file
+                for file_path in file_list:
+                    if file_path != target_file:
+                        with source_zip.open(file_path) as source_file:
+                            data = source_file.read()
+                            new_zip.writestr(file_path, data)
+                            pbar.update(1)
 
-            for key, entry in supporting_arrays.items():
-                value = entry.tobytes()
-                fname = os.path.join(directory, f"{key}.numpy")
-                os.makedirs(os.path.dirname(fname), exist_ok=True)
-                with open(fname, "wb") as f:
-                    f.write(value)
-                total += 1
-
-        with zipfile.ZipFile(new_path, "w", zipfile.ZIP_DEFLATED) as zipf:
-            with tqdm.tqdm(total=total, desc="Rebuilding checkpoint") as pbar:
-                for root, dirs, files in os.walk(temp_dir):
-                    for f in files:
-                        full = os.path.join(root, f)
-                        rel = os.path.relpath(full, temp_dir)
-                        zipf.write(full, rel)
+                # Handle the target file with callback
+                with TemporaryDirectory() as temp_dir:
+                    # Extract only the target file
+                    source_zip.extract(target_file, temp_dir)
+                    target_full_path = os.path.join(temp_dir, target_file)
+
+                    # Apply the callback
+                    callback(target_full_path)
+
+                    # Add the modified file to the new zip (if it still exists)
+                    if os.path.exists(target_full_path):
+                        new_zip.write(target_full_path, target_file)
+                        pbar.update(1)
+
+                # Add supporting arrays if provided
+                if supporting_arrays is not None:
+                    for key, entry in supporting_arrays.items():
+                        array_path = os.path.join(directory, f"{key}.numpy") if directory else f"{key}.numpy"
+                        new_zip.writestr(array_path, entry.tobytes())
                         pbar.update(1)
 
     os.rename(new_path, path)
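
The rewritten _edit_metadata no longer extracts the whole checkpoint into a temporary directory: entries are streamed from the source archive straight into the replacement archive, only the file being edited is written to disk, and supporting arrays are added directly. A self-contained sketch of that rebuild pattern, with generic names rather than the anemoi function itself:

import os
import tempfile
import zipfile


def rewrite_one_member(archive: str, member: str, callback) -> None:
    """Rebuild `archive` in place, applying `callback` to a copy of `member`."""
    tmp_archive = f"{archive}.tmp"
    with zipfile.ZipFile(archive, "r") as src, zipfile.ZipFile(tmp_archive, "w", zipfile.ZIP_STORED) as dst:
        for entry in src.namelist():
            if entry != member:
                dst.writestr(entry, src.read(entry))  # stream untouched entries across
        with tempfile.TemporaryDirectory() as tmp:
            edited = src.extract(member, tmp)  # only the edited file touches the filesystem
            callback(edited)  # e.g. rewrite a JSON file in place
            if os.path.exists(edited):
                dst.write(edited, member)
    os.replace(tmp_archive, archive)

The release also switches the rebuilt archive from ZIP_DEFLATED to ZIP_STORED, so the copied entries are no longer recompressed.
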
anemoi/utils/cli.py CHANGED
@@ -16,11 +16,26 @@ import sys
 import traceback
 from collections.abc import Callable
 
+from anemoi.utils import ENV
+
 try:
     import argcomplete
 except ImportError:
     argcomplete = None
 
+
+if ENV.ANEMOI_DEBUG_IMPORTS:
+    from datetime import datetime
+    from importlib.abc import MetaPathFinder
+
+    class ImportTracer(MetaPathFinder):
+        def find_spec(self, fullname, path, target=None):
+            now = datetime.now().isoformat(timespec="milliseconds")
+            print(f"[{now}] Importing {fullname} from {path}")
+            return None  # allow normal import processing to continue
+
+    sys.meta_path.insert(0, ImportTracer())
+
 LOG = logging.getLogger(__name__)
 
 
@@ -29,6 +44,10 @@ class Command:
 
     accept_unknown_args = False
 
+    def check(self, parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
+        """Check the command arguments."""
+        pass
+
     def run(self, args: argparse.Namespace) -> None:
         """Run the command.
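
cli_main() now calls check() after argument parsing and before run() (see the hunk further down), so a command can reject invalid argument combinations with the parser's own error reporting. A hedged sketch of a subclass using the hook; the Hello command and its flags are invented, and the add_arguments method name is assumed, since this diff only shows the add_argument calls inside it:

import argparse

from anemoi.utils.cli import Command


class Hello(Command):
    """Hypothetical command, only to illustrate the check() hook."""

    def add_arguments(self, command_parser: argparse.ArgumentParser) -> None:
        command_parser.add_argument("--json", action="store_true")
        command_parser.add_argument("--yaml", action="store_true")

    def check(self, parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
        # parser.error() prints the usage string and exits with status 2.
        if args.json and args.yaml:
            parser.error("--json and --yaml are mutually exclusive")

    def run(self, args: argparse.Namespace) -> None:
        print("format:", "json" if args.json else "yaml" if args.yaml else "text")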
 
@@ -72,6 +91,11 @@ def make_parser(description: str, commands: dict[str, Command]) -> argparse.Argu
         action="store_true",
         help="Debug mode",
     )
+    parser.add_argument(
+        "--rich",
+        action="store_true",
+        help="Use rich for logging",
+    )
 
     subparsers = parser.add_subparsers(help="commands:", dest="command")
     for name, command in commands.items():
@@ -216,16 +240,25 @@ def cli_main(
 
     cmd = commands[args.command]
 
-    logging.basicConfig(
-        format="%(asctime)s %(levelname)s %(message)s",
-        datefmt="%Y-%m-%d %H:%M:%S",
-        level=logging.DEBUG if args.debug else logging.INFO,
-    )
+    if args.rich:
+        from .logs import get_rich_handler
+
+        logging.basicConfig(
+            format="%(message)s", level=logging.DEBUG if args.debug else logging.INFO, handlers=[get_rich_handler()]
+        )
+    else:
+        logging.basicConfig(
+            format="%(asctime)s %(levelname)s %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
+            level=logging.DEBUG if args.debug else logging.INFO,
+        )
 
     if unknown and not cmd.accept_unknown_args:
         # This should trigger an error
         parser.parse_args(test_arguments)
 
+    cmd.check(parser, args)
+
     try:
         if unknown:
             cmd.run(args, unknown)
@@ -139,13 +139,13 @@ class Metadata(Command):
         command_parser.add_argument(
             "--json",
             action="store_true",
-            help="Use the JSON format with ``--dump``, ``--view`` and ``--edit``.",
+            help="Use the JSON format with ``--dump``, ``--view``, ``--get`` and ``--edit``.",
         )
 
         command_parser.add_argument(
             "--yaml",
             action="store_true",
-            help="Use the YAML format with ``--dump``, ``--view`` and ``--edit``.",
+            help="Use the YAML format with ``--dump``, ``--view``, ``--get`` and ``--edit``.",
         )
 
     def run(self, args: Namespace) -> None:
@@ -315,7 +315,6 @@ class Metadata(Command):
         args : Namespace
             The arguments passed to the command.
         """
-        from pprint import pprint
 
         from anemoi.utils.checkpoints import load_metadata
 
@@ -335,7 +334,10 @@ class Metadata(Command):
 
         print(f"Metadata values for {args.get}: ", end="\n" if isinstance(metadata, (dict, list)) else "")
         if isinstance(metadata, dict):
-            pprint(metadata, indent=2, compact=True)
+            if args.yaml:
+                print(yaml.dump(metadata, indent=2, sort_keys=True))
+                return
+            print(json.dumps(metadata, indent=2, sort_keys=True))
         else:
             print(metadata)
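
With pprint removed, --get output is rendered with json.dumps, or with yaml.dump when --yaml is given. The same rendering can be reproduced programmatically; the checkpoint path and the selected key below are made up, and load_metadata is assumed to return the metadata dictionary the way the command above uses it:

import json

from anemoi.utils.checkpoints import load_metadata

metadata = load_metadata("model.ckpt")  # hypothetical checkpoint path
subtree = metadata.get("dataset", {})  # illustrative key
print(json.dumps(subtree, indent=2, sort_keys=True))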
 
@@ -28,10 +28,14 @@ class Transfer(Command):
             The argument parser to which the arguments will be added.
         """
         command_parser.add_argument(
-            "--source", help="A path to a local file or folder or a URL to a file or a folder on S3."
+            "--source",
+            help="A path to a local file or folder or a URL to a file or a folder on S3.",
+            required=True,
         )
         command_parser.add_argument(
-            "--target", help="A path to a local file or folder or a URL to a file or a folder on S3 or a remote folder."
+            "--target",
+            help="A path to a local file or folder or a URL to a file or a folder on S3 or a remote folder.",
+            required=True,
         )
         command_parser.add_argument(
             "--overwrite",
anemoi/utils/config.py CHANGED
@@ -15,7 +15,6 @@ import json
 import logging
 import os
 import threading
-import warnings
 from typing import Any
 
 import yaml
@@ -187,7 +186,6 @@ class DotDict(dict):
             The attribute value.
         """
 
-        self.warn_on_mutation(attr)
         value = self.convert_to_nested_dot_dict(value)
         super().__setitem__(attr, value)
 
@@ -201,13 +199,9 @@ class DotDict(dict):
         value : Any
             The value to set.
         """
-        self.warn_on_mutation(key)
         value = self.convert_to_nested_dot_dict(value)
         super().__setitem__(key, value)
 
-    def warn_on_mutation(self, key):
-        warnings.warn("Modifying an instance of DotDict(). This class is intended to be immutable.")
-
     def __repr__(self) -> str:
         """Return a string representation of the DotDict.
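
With the warning removed, DotDict instances can be mutated freely again; the setters above still convert assigned dictionaries into nested DotDicts. A small sketch, with the dot access behaviour inferred from the class shown here:

from anemoi.utils.config import DotDict

cfg = DotDict()
cfg.training = {"lr": 1e-3}  # converted to a nested DotDict by __setattr__, no warning emitted
cfg.training.lr = 5e-4
cfg["output"] = {"path": "/tmp/run"}  # __setitem__ applies the same conversion
print(cfg.training.lr, cfg.output.path)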
 
anemoi/utils/logs.py CHANGED
@@ -10,10 +10,10 @@
 
 """Logging utilities."""
 
+import contextvars
 import logging
-import threading
 
-thread_local = threading.local()
+LOGGING_NAME = contextvars.ContextVar("logging_name", default="main")
 
 
 LOGGER = logging.getLogger(__name__)
@@ -27,7 +27,7 @@ def set_logging_name(name: str) -> None:
     name : str
         The name to set for logging.
     """
-    thread_local.logging_name = name
+    LOGGING_NAME.set(name)
 
 
 class ThreadCustomFormatter(logging.Formatter):
@@ -46,7 +46,7 @@ class ThreadCustomFormatter(logging.Formatter):
         str
             The formatted log record.
         """
-        record.logging_name = thread_local.logging_name
+        record.logging_name = LOGGING_NAME.get()
         return super().format(record)
 
 
@@ -58,11 +58,39 @@ def enable_logging_name(name: str = "main") -> None:
     name : str, optional
         The default logging name to set, by default "main".
     """
-    thread_local.logging_name = name
 
-    formatter = ThreadCustomFormatter("%(asctime)s - %(logging_name)s - %(levelname)s - %(message)s")
+    logger = logging.getLogger()
+    is_rich = any(handler.__class__.__name__ == "CustomRichHandler" for handler in logger.handlers)
+
+    set_logging_name(name)
+
+    if is_rich:
+        formatter = ThreadCustomFormatter("%(message)s")
+    else:
+        formatter = ThreadCustomFormatter("%(asctime)s - [%(logging_name)s] - %(levelname)s - %(message)s")
 
     logger = logging.getLogger()
 
     for handler in logger.handlers:
        handler.setFormatter(formatter)
+
+
+def get_rich_handler() -> logging.Handler:
+    """Return a RichHandler with custom formatting for logging."""
+
+    from rich.logging import RichHandler
+    from rich.text import Text
+
+    class CustomRichHandler(RichHandler):
+        def render_message(self, record, message):
+            global width
+
+            text = super().render_message(record, message)
+
+            if hasattr(record, "logging_name"):
+                name = record.logging_name
+                text = Text.assemble(f"[{name}]", (" → ", "dim"), text)
+
+            return text
+
+    return CustomRichHandler(log_time_format="[%X]")
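
Together with the new --rich CLI flag, the handler and the context-variable logging name can also be wired up directly; each asyncio task (or thread) keeps its own name because ContextVar values do not leak between contexts, which is the motivation for dropping threading.local. A sketch using only functions shown in this diff; it needs the optional rich dependency, and the exact console output is an assumption:

import asyncio
import logging

from anemoi.utils.logs import enable_logging_name, get_rich_handler, set_logging_name

logging.basicConfig(format="%(message)s", level=logging.INFO, handlers=[get_rich_handler()])
enable_logging_name("main")  # installs ThreadCustomFormatter on the root handlers

LOG = logging.getLogger(__name__)


async def worker(n: int) -> None:
    # Each task runs in its own copy of the context, so the name set here
    # does not affect other tasks.
    set_logging_name(f"worker-{n}")
    LOG.info("processing")


async def main() -> None:
    await asyncio.gather(*(worker(n) for n in range(3)))


asyncio.run(main())
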
@@ -13,6 +13,9 @@ from __future__ import annotations
 import logging
 import os
 import time
+import warnings
+from abc import ABC
+from abc import abstractmethod
 from datetime import datetime
 from datetime import timezone
 from functools import wraps
@@ -20,10 +23,15 @@ from getpass import getpass
 from typing import TYPE_CHECKING
 
 import requests
+from pydantic import BaseModel
+from pydantic import RootModel
+from pydantic import field_validator
+from pydantic import model_validator
 from requests.exceptions import HTTPError
 
+from ..config import CONFIG_LOCK
 from ..config import config_path
-from ..config import load_config
+from ..config import load_raw_config
 from ..config import save_config
 from ..remote import robust
 from ..timer import Timer
@@ -35,10 +43,96 @@ if TYPE_CHECKING:
     from collections.abc import Callable
 
 
-class TokenAuth:
+class ServerConfig(BaseModel):
+    refresh_token: str | None = None
+    refresh_expires: int = 0
+
+    @field_validator("refresh_expires", mode="before")
+    def to_int(cls, value: float | int) -> int:
+        if not isinstance(value, int):
+            return int(value)
+        return value
+
+
+class ServerStore(RootModel):
+    root: dict[str, ServerConfig] = {}
+
+    def get(self, url: str) -> ServerConfig | None:
+        return self.root.get(url)
+
+    def __getitem__(self, url: str) -> ServerConfig:
+        return self.root[url]
+
+    def items(self):
+        return self.root.items()
+
+    def update(self, url, config: ServerConfig) -> None:
+        """Update the server configuration for a given URL."""
+        self.root[url] = config
+
+    @property
+    def servers(self) -> list[tuple[str, int]]:
+        """List of servers in the store, as a tuple (url, refresh_expires). Ordered most recently used first."""
+        return [
+            (url, cfg.refresh_expires)
+            for url, cfg in sorted(
+                self.root.items(),
+                key=lambda item: item[1].refresh_expires,
+                reverse=True,
+            )
+        ]
+
+    @model_validator(mode="before")
+    @classmethod
+    def load_legacy_format(cls, data: dict) -> dict:
+        """Convert legacy single-server config format to multi-server."""
+        if isinstance(data, dict) and "url" in data:
+            _data = data.copy()
+            _url = _data.pop("url")
+            data = {_url: ServerConfig(**_data)}
+        return data
+
+
+class AuthBase(ABC):
+    """Base class for authentication implementations."""
+
+    @abstractmethod
+    def __init__(self, *args, **kwargs):
+        pass
+
+    @abstractmethod
+    def save(self, **kwargs):
+        pass
+
+    @abstractmethod
+    def login(self, force_credentials: bool = False, **kwargs):
+        pass
+
+    @abstractmethod
+    def authenticate(self, **kwargs):
+        pass
+
+
+class NoAuth(AuthBase):
+    """No-op authentication class."""
+
+    def __init__(self, *args, **kwargs):
+        self._enabled = False
+
+    def save(self, **kwargs):
+        pass
+
+    def login(self, force_credentials: bool = False, **kwargs):
+        pass
+
+    def authenticate(self, **kwargs):
+        pass
+
+
+class TokenAuth(AuthBase):
     """Manage authentication with a keycloak token server."""
 
-    config_file = "mlflow-token.json"
+    _config_file = "mlflow-token.json"
 
     def __init__(
         self,
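
ServerStore replaces the old single-server token file with a mapping keyed by tracking-server URL, and the before-validator migrates the legacy layout transparently when the file is read. A condensed, self-contained sketch of that migration; the classes are simplified copies of the ones above and the file contents are made up:

from pydantic import BaseModel, RootModel, model_validator


class ServerConfig(BaseModel):
    refresh_token: str | None = None
    refresh_expires: int = 0


class ServerStore(RootModel):
    root: dict[str, ServerConfig] = {}

    @model_validator(mode="before")
    @classmethod
    def load_legacy_format(cls, data: dict) -> dict:
        # Legacy files kept a single server at the top level: {"url": ..., "refresh_token": ..., ...}
        if isinstance(data, dict) and "url" in data:
            _data = data.copy()
            _url = _data.pop("url")
            data = {_url: ServerConfig(**_data)}
        return data


legacy = {"url": "https://mlflow.example.int", "refresh_token": "abc", "refresh_expires": 1700000000}
store = ServerStore(**legacy)  # same call pattern as _get_store() uses below
print(list(store.root))  # ['https://mlflow.example.int']

# New layout: one entry per tracking-server URL.
store = ServerStore({"https://a.example": {"refresh_token": "t1", "refresh_expires": 1}})
print(store.root["https://a.example"].refresh_token)  # t1
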
@@ -63,10 +157,16 @@
         self.target_env_var = target_env_var
         self._enabled = enabled
 
-        config = self.load_config()
+        store = self._get_store()
+        config = store.get(self.url)
+
+        if config is not None:
+            self._refresh_token = config.refresh_token
+            self.refresh_expires = config.refresh_expires
+        else:
+            self._refresh_token = None
+            self.refresh_expires = 0
 
-        self._refresh_token = config.get("refresh_token")
-        self.refresh_expires = config.get("refresh_expires", 0)
         self.access_token = None
         self.access_expires = 0
 
@@ -86,17 +186,50 @@
         self._refresh_token = value
         self.refresh_expires = time.time() + (REFRESH_EXPIRE_DAYS * 86400)  # 86400 seconds in a day
 
+    @staticmethod
+    def _get_store() -> ServerStore:
+        """Read the server store from disk."""
+        with CONFIG_LOCK:
+            file = TokenAuth._config_file
+            path = config_path(file)
+
+            if not os.path.exists(path):
+                save_config(file, {})
+
+            if os.path.exists(path) and os.stat(path).st_mode & 0o777 != 0o600:
+                os.chmod(path, 0o600)
+
+            return ServerStore(**load_raw_config(file))
+
+    @staticmethod
+    def get_servers() -> list[tuple[str, int]]:
+        """List of all saved servers, as a tuple (url, refresh_expires). Ordered most recently used first."""
+        return TokenAuth._get_store().servers
+
     @staticmethod
     def load_config() -> dict:
-        path = config_path(TokenAuth.config_file)
+        """Load the last used server configuration
 
-        if not os.path.exists(path):
-            save_config(TokenAuth.config_file, {})
+        Returns
+        -------
+        config : dict
+            Dictionary with the following keys: `url`, `refresh_token`, `refresh_expires`.
+            If no configuration is found, an empty dictionary is returned.
+        """
+        warnings.warn(
+            "TokenAuth.load_config() is deprecated and will be removed in a future release.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
 
-        if os.path.exists(path) and os.stat(path).st_mode & 0o777 != 0o600:
-            os.chmod(path, 0o600)
+        store = TokenAuth._get_store()
 
-        return load_config(TokenAuth.config_file)
+        last = {}
+        for url, cfg in store.items():
+            if cfg.refresh_expires > last.get("refresh_expires", 0):
+                last = dict(url=url, **cfg.model_dump())
+
+        return last
 
     def enabled(fn: Callable) -> Callable:  # noqa: N805
         """Decorator to call or ignore a function based on the `enabled` flag."""
@@ -199,12 +332,15 @@
             self.log.warning("No refresh token to save.")
             return
 
-        config = {
-            "url": self.url,
-            "refresh_token": self.refresh_token,
-            "refresh_expires": self.refresh_expires,
-        }
-        save_config(self.config_file, config)
+        server_config = ServerConfig(
+            refresh_token=self.refresh_token,
+            refresh_expires=self.refresh_expires,
+        )
+
+        with CONFIG_LOCK:
+            store = self._get_store()
+            store.update(self.url, server_config)
+            save_config(self._config_file, store.model_dump())
 
         expire_date = datetime.fromtimestamp(self.refresh_expires, tz=timezone.utc)
         self.log.info(