anemoi-utils 0.1.9__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of anemoi-utils might be problematic. Click here for more details.

@@ -0,0 +1,72 @@
1
+ #!/usr/bin/env python
2
+ # (C) Copyright 2024 ECMWF.
3
+ #
4
+ # This software is licensed under the terms of the Apache Licence Version 2.0
5
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+ #
10
+
11
+
12
+ import argparse
13
+ import logging
14
+ import sys
15
+ import traceback
16
+
17
+ from . import __version__
18
+ from .commands import COMMANDS
19
+
20
+ LOG = logging.getLogger(__name__)
21
+
22
+
23
def main():
    """Run the ``anemoi-utils`` command-line interface.

    One sub-command is registered for every entry of ``COMMANDS``. With
    ``--version`` the package version is printed; without a sub-command the
    help text is shown. Otherwise logging is configured (DEBUG when
    ``--debug`` is given, INFO otherwise) and the selected command is run.
    A ``ValueError`` raised by the command is reported together with its
    traceback and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--version", "-V", action="store_true", help="show the version and exit")
    parser.add_argument("--debug", "-d", action="store_true", help="Debug mode")

    subparsers = parser.add_subparsers(help="commands:", dest="command")
    for command_name, handler in COMMANDS.items():
        # Each command object declares its own options on its sub-parser.
        sub_parser = subparsers.add_parser(command_name, help=handler.__doc__)
        handler.add_arguments(sub_parser)

    args = parser.parse_args()

    if args.version:
        print(__version__)
        return

    if args.command is None:
        parser.print_help()
        return

    selected = COMMANDS[args.command]

    log_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=log_level,
    )

    try:
        selected.run(args)
    except ValueError as error:
        traceback.print_exc()
        LOG.error("\n💣 %s", str(error).lstrip())
        LOG.error("💣 Exiting")
        sys.exit(1)


if __name__ == "__main__":
    main()
anemoi/utils/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.1.9'
16
- __version_tuple__ = version_tuple = (0, 1, 9)
15
+ __version__ = version = '0.2.0'
16
+ __version_tuple__ = version_tuple = (0, 2, 0)
@@ -13,11 +13,38 @@ are zip archives containing the model weights.
13
13
  import json
14
14
  import logging
15
15
  import os
16
+ import time
16
17
  import zipfile
18
+ from tempfile import TemporaryDirectory
19
+
20
+ import tqdm
17
21
 
18
22
  LOG = logging.getLogger(__name__)
19
23
 
20
- DEFAULT_NAME = "anemoi-metadata.json"
24
# Basename of the metadata JSON file stored inside a checkpoint archive.
DEFAULT_NAME = "ai-models.json"
# Sub-directory, below the archive's top-level directory, that save_metadata writes into.
DEFAULT_FOLDER = "anemoi-metadata"


def has_metadata(path: str, name: str = DEFAULT_NAME) -> bool:
    """Tell whether a checkpoint file contains a metadata file.

    The checkpoint is opened as a zip archive and the member names are
    scanned. Only the basename of each member is compared, so the file is
    found whichever directory it is stored in.

    Parameters
    ----------
    path : str
        The path to the checkpoint file
    name : str, optional
        The basename of the metadata file to look for in the zip archive

    Returns
    -------
    bool
        True if at least one member has basename ``name``, False otherwise
    """
    with zipfile.ZipFile(path, "r") as archive:
        return any(os.path.basename(member) == name for member in archive.namelist())
21
48
 
22
49
 
23
50
  def load_metadata(path: str, name: str = DEFAULT_NAME):
@@ -45,17 +72,17 @@ def load_metadata(path: str, name: str = DEFAULT_NAME):
45
72
  for b in f.namelist():
46
73
  if os.path.basename(b) == name:
47
74
  if metadata is not None:
48
- LOG.warning(f"Found two '{name}' if {path}")
75
+ raise ValueError(f"Found two or more '{name}' in {path}.")
49
76
  metadata = b
50
77
 
51
78
  if metadata is not None:
52
79
  with zipfile.ZipFile(path, "r") as f:
53
80
  return json.load(f.open(metadata, "r"))
54
81
  else:
55
- raise ValueError(f"Could not find {name} in {path}")
82
+ raise ValueError(f"Could not find '{name}' in {path}.")
56
83
 
57
84
 
58
- def save_metadata(path, metadata, name=DEFAULT_NAME):
85
+ def save_metadata(path, metadata, name=DEFAULT_NAME, folder=DEFAULT_FOLDER):
59
86
  """Save metadata to a checkpoint file
60
87
 
61
88
  Parameters
@@ -68,8 +95,85 @@ def save_metadata(path, metadata, name=DEFAULT_NAME):
68
95
  The name of the metadata file in the zip archive
69
96
  """
70
97
  with zipfile.ZipFile(path, "a") as zipf:
71
- base, _ = os.path.splitext(os.path.basename(path))
98
+
99
+ directories = set()
100
+
101
+ for b in zipf.namelist():
102
+ directory = os.path.dirname(b)
103
+ while os.path.dirname(directory) not in (".", ""):
104
+ directory = os.path.dirname(directory)
105
+ directories.add(directory)
106
+
107
+ if os.path.basename(b) == name:
108
+ raise ValueError(f"'{name}' already in {path}")
109
+
110
+ if len(directories) != 1:
111
+ # PyTorch checkpoints should have a single directory
112
+ # otherwise PyTorch will complain
113
+ raise ValueError(f"No or multiple directories in the checkpoint {path}, directories={directories}")
114
+
115
+ directory = list(directories)[0]
116
+
117
+ LOG.info("Saving metadata to %s/%s/%s", directory, folder, name)
118
+
72
119
  zipf.writestr(
73
- f"{base}/{name}",
120
+ f"{directory}/{folder}/{name}",
74
121
  json.dumps(metadata),
75
122
  )
123
+
124
+
125
def _edit_metadata(path, name, callback):
    """Rewrite a checkpoint after applying ``callback`` to its metadata member(s).

    The archive is extracted into a temporary directory. ``callback`` is
    called with the full path of every extracted file whose basename is
    ``name``; it may rewrite or delete that file. A new (deflate-compressed)
    archive is then built from the directory content and moved over ``path``.

    Parameters
    ----------
    path : str
        The checkpoint (zip) file to edit in place.
    name : str
        Basename of the member(s) to edit.
    callback : callable
        Called as ``callback(full_path)`` for each matching extracted file.

    Raises
    ------
    ValueError
        If no member called ``name`` exists in the archive. In that case, and
        if rebuilding the archive fails, ``path`` is not modified.
    """
    new_path = f"{path}.anemoi-edit-{time.time()}-{os.getpid()}.tmp"

    with TemporaryDirectory() as temp_dir:
        # Close the source archive as soon as it has been extracted.
        with zipfile.ZipFile(path, "r") as source:
            source.extractall(temp_dir)

        found = False
        for root, _dirs, files in os.walk(temp_dir):
            for f in files:
                if f == name:
                    found = True
                    callback(os.path.join(root, f))

        if not found:
            raise ValueError(f"Could not find '{name}' in {path}")

        # List the members only after the callbacks ran: a callback may have
        # removed files, which would otherwise make the progress total wrong.
        members = [os.path.join(root, f) for root, _dirs, files in os.walk(temp_dir) for f in files]

        try:
            with zipfile.ZipFile(new_path, "w", zipfile.ZIP_DEFLATED) as zipf:
                with tqdm.tqdm(total=len(members), desc="Rebuilding checkpoint") as pbar:
                    for full in members:
                        zipf.write(full, os.path.relpath(full, temp_dir))
                        pbar.update(1)
            # os.replace overwrites an existing destination on every platform,
            # whereas os.rename raises on Windows when the destination exists.
            os.replace(new_path, path)
        finally:
            # Never leave a partially written temporary archive behind; after a
            # successful replace the temporary file no longer exists.
            if os.path.exists(new_path):
                os.remove(new_path)

    LOG.info("Updated metadata in %s", path)
155
+
156
+
157
def replace_metadata(path, metadata, name=DEFAULT_NAME):
    """Overwrite the metadata file stored in a checkpoint.

    Parameters
    ----------
    path : str
        The path to the checkpoint file; it is rewritten in place.
    metadata : dict
        The new metadata, which must contain a ``"version"`` key.
    name : str, optional
        The name of the metadata file in the zip archive.

    Raises
    ------
    ValueError
        If ``metadata`` is not a dict, lacks a ``"version"`` key, or if the
        checkpoint holds no member called ``name``.
    """
    if not isinstance(metadata, dict):
        raise ValueError(f"metadata must be a dict, got {type(metadata)}")

    if "version" not in metadata:
        raise ValueError("metadata must have a 'version' key")

    def _overwrite(member_path):
        # Replace the extracted member's content with the new JSON document.
        with open(member_path, "w") as handle:
            json.dump(metadata, handle)

    _edit_metadata(path, name, _overwrite)
170
+
171
+
172
def remove_metadata(path, name=DEFAULT_NAME):
    """Delete the metadata file from a checkpoint.

    Parameters
    ----------
    path : str
        The path to the checkpoint file; it is rewritten in place.
    name : str, optional
        The name of the metadata file in the zip archive.

    Raises
    ------
    ValueError
        If the checkpoint holds no member called ``name``.
    """
    LOG.info("Removing metadata '%s' from %s", name, path)

    # Deleting each extracted match drops it from the rebuilt archive.
    _edit_metadata(path, name, os.remove)
@@ -0,0 +1,78 @@
1
+ #!/usr/bin/env python
2
+ # (C) Copyright 2024 ECMWF.
3
+ #
4
+ # This software is licensed under the terms of the Apache Licence Version 2.0
5
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+ #
10
+
11
+ import argparse
12
+ import importlib
13
+ import logging
14
+ import os
15
+ import sys
16
+
17
+ LOG = logging.getLogger(__name__)
18
+
19
+
20
def register(here, package, select, fail=None):
    """Import the sub-modules found in a package directory and collect objects from them.

    Every ``*.py`` file and every sub-package (a directory containing an
    ``__init__.py``) in ``here`` whose name does not start with ``_`` is
    imported relative to ``package``.

    Parameters
    ----------
    here : str
        Directory to scan, normally ``os.path.dirname(__file__)`` of ``package``.
    package : str
        Dotted name of the package the modules are imported from.
    select : callable
        Called with each successfully imported module. A non-``None`` return
        value is stored in the result under the module's name.
    fail : callable, optional
        Called as ``fail(name, error)`` for every module whose import raised
        ``ImportError``; its return value is stored under the module's name.
        When ``fail`` is ``None`` (or not callable), such modules are skipped.

    Returns
    -------
    dict
        Mapping of module name to selected object. Successfully imported
        modules come first, then failed ones, each group in sorted name order.
    """
    result = {}
    not_available = {}

    # Sort so the registration order (e.g. sub-commands listed in ``--help``)
    # does not depend on the filesystem's directory ordering.
    for p in sorted(os.listdir(here)):
        full = os.path.join(here, p)
        if p.startswith("_"):
            continue
        is_package = os.path.isdir(full) and os.path.exists(os.path.join(full, "__init__.py"))
        if not (p.endswith(".py") or is_package):
            continue

        name, _ = os.path.splitext(p)

        try:
            imported = importlib.import_module(f".{name}", package=package)
        except ImportError as e:
            not_available[name] = e
            continue

        obj = select(imported)
        if obj is not None:
            result[name] = obj

    if callable(fail):
        for name, e in not_available.items():
            result[name] = fail(name, e)

    return result
53
+
54
+
55
class Command:
    """Base class of the ``anemoi-utils`` sub-commands.

    Sub-classes declare their options in ``add_arguments(command_parser)``
    and override :meth:`run` to do the work.
    """

    def run(self, args):
        """Execute the command; sub-classes must override this."""
        message = f"Command not implemented: {args.command}"
        raise NotImplementedError(message)


class Failed(Command):
    """Stand-in for a sub-command whose module could not be imported.

    It accepts any trailing arguments and, when run, prints the import error
    and exits with status 1.
    """

    def __init__(self, name, error):
        self.name, self.error = name, error

    def add_arguments(self, command_parser):
        # Collect everything the user typed so argument parsing does not fail
        # before the import error can be reported.
        command_parser.add_argument("x", nargs=argparse.REMAINDER)

    def run(self, args):
        print(f"Command '{self.name}' not available: {self.error}")
        sys.exit(1)
71
+
72
+
73
# Registry of the available sub-commands, keyed by module name. Each command
# module's ``command`` attribute is called to build the command object; a
# module that fails to import is represented by a ``Failed`` placeholder.
COMMANDS = register(
    here=os.path.dirname(__file__),
    package=__name__,
    select=lambda module: module.command(),
    fail=Failed,
)
@@ -0,0 +1,61 @@
1
+ #!/usr/bin/env python
2
+ # (C) Copyright 2024 ECMWF.
3
+ #
4
+ # This software is licensed under the terms of the Apache Licence Version 2.0
5
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+ #
10
+
11
+
12
+ import json
13
+
14
+ from . import Command
15
+
16
+
17
def visit(x, path, name, value):
    """Recursively print the entries of a JSON-like structure that match a search.

    For every dict entry whose key equals ``name`` or whose value equals
    ``value``, and for every list item equal to ``value``, a line
    ``<dotted path> <key-or-index> <value>`` is printed (once per entry).
    Pass a fresh ``object()`` for ``name`` or ``value`` to disable that
    criterion.

    Parameters
    ----------
    x : Any
        The structure to search (dicts and lists are descended into).
    path : list of str
        Keys/indices leading to ``x``; used as a stack and restored on return.
    name : Any
        Dict key to look for.
    value : Any
        Value to look for.
    """
    if isinstance(x, dict):
        for k, v in x.items():
            # Print each entry at most once, even when key and value both match.
            if k == name or v == value:
                print(".".join(path), k, v)

            path.append(str(k))
            visit(v, path, name, value)
            path.pop()

    if isinstance(x, list):
        for i, v in enumerate(x):
            # List items have no key, so only their value can match.
            if v == value:
                print(".".join(path), i, v)

            path.append(str(i))
            visit(v, path, name, value)
            path.pop()
35
+
36
+
37
class Checkpoint(Command):
    """Display the metadata of a checkpoint, or search it by key and/or value."""

    def add_arguments(self, command_parser):
        command_parser.add_argument("path", help="Path to the checkpoint.")
        command_parser.add_argument("--name", help="Search for a specific name.")
        command_parser.add_argument("--value", help="Search for a specific value.")

    def run(self, args):
        """Load the checkpoint metadata, then either search it or dump it as JSON."""
        from anemoi.utils.checkpoints import load_metadata

        # load_metadata compares member basenames literally, so a glob such as
        # "*.json" can never match; rely on its default metadata file name.
        checkpoint = load_metadata(args.path)

        if args.name or args.value:
            # A fresh object() never compares equal, disabling an unset criterion.
            visit(
                checkpoint,
                [],
                args.name if args.name is not None else object(),
                args.value if args.value is not None else object(),
            )
            return

        print(json.dumps(checkpoint, sort_keys=True, indent=4))


# Entry point looked up by the command registry.
command = Checkpoint
anemoi/utils/timer.py ADDED
@@ -0,0 +1,32 @@
1
+ # (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
2
+ # This software is licensed under the terms of the Apache Licence Version 2.0
3
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
4
+ # In applying this licence, ECMWF does not waive the privileges and immunities
5
+ # granted to it by virtue of its status as an intergovernmental organisation
6
+ # nor does it submit to any jurisdiction.
7
+
8
+ """Logging utilities."""
9
+
10
+ import logging
11
+ import time
12
+
13
+ from .humanize import seconds
14
+
15
LOGGER = logging.getLogger(__name__)


class Timer:
    """Context manager that logs how long its ``with`` block took.

    Example::

        with Timer("Loading data"):
            ...

    On exit, ``"<title>: <elapsed>."`` is logged at INFO level. The elapsed
    time is formatted with ``humanize.seconds``. Exceptions raised in the
    block are not suppressed.

    Parameters
    ----------
    title : str
        Label included in the log message.
    logger : logging.Logger, optional
        Logger used for the report; defaults to this module's logger.
    """

    def __init__(self, title, logger=LOGGER):
        self.title = title
        self.logger = logger
        # Wall-clock creation time, kept for callers that read it.
        self.start = time.time()
        # Monotonic reference for measuring durations: unlike time.time(), it
        # does not jump when the system clock is adjusted.
        self._start_counter = time.perf_counter()

    def __enter__(self):
        return self

    @property
    def elapsed(self):
        """Seconds elapsed since the timer was created."""
        return time.perf_counter() - self._start_counter

    def __exit__(self, *args):
        self.logger.info("%s: %s.", self.title, seconds(self.elapsed))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: anemoi-utils
3
- Version: 0.1.9
3
+ Version: 0.2.0
4
4
  Summary: A package to hold various functions to support training of ML models on ECMWF data.
5
5
  Author-email: "European Centre for Medium-Range Weather Forecasts (ECMWF)" <software.support@ecmwf.int>
6
6
  License: Apache License
@@ -224,6 +224,7 @@ Requires-Python: >=3.9
224
224
  License-File: LICENSE
225
225
  Requires-Dist: tomli
226
226
  Requires-Dist: pyyaml
227
+ Requires-Dist: tqdm
227
228
  Provides-Extra: all
228
229
  Requires-Dist: tomli ; extra == 'all'
229
230
  Requires-Dist: GitPython ; extra == 'all'
@@ -236,14 +237,15 @@ Requires-Dist: GitPython ; extra == 'dev'
236
237
  Requires-Dist: nvsmi ; extra == 'dev'
237
238
  Requires-Dist: termcolor ; extra == 'dev'
238
239
  Requires-Dist: requests ; extra == 'dev'
239
- Requires-Dist: sphinx ; extra == 'dev'
240
- Requires-Dist: sphinx-rtd-theme ; extra == 'dev'
241
- Requires-Dist: nbsphinx ; extra == 'dev'
242
- Requires-Dist: pandoc ; extra == 'dev'
243
240
  Provides-Extra: docs
244
241
  Requires-Dist: tomli ; extra == 'docs'
245
242
  Requires-Dist: termcolor ; extra == 'docs'
246
243
  Requires-Dist: requests ; extra == 'docs'
244
+ Requires-Dist: sphinx ; extra == 'docs'
245
+ Requires-Dist: sphinx-rtd-theme ; extra == 'docs'
246
+ Requires-Dist: nbsphinx ; extra == 'docs'
247
+ Requires-Dist: pandoc ; extra == 'docs'
248
+ Requires-Dist: sphinx-argparse ; extra == 'docs'
247
249
  Provides-Extra: grib
248
250
  Requires-Dist: requests ; extra == 'grib'
249
251
  Provides-Extra: provenance
@@ -0,0 +1,22 @@
1
+ anemoi/utils/__init__.py,sha256=zZZpbKIoGWwdCOuo6YSruLR7C0GzvzI1Wzhyqaa0K7M,456
2
+ anemoi/utils/__main__.py,sha256=CGl8WF7rWMx9EoArysla0-ThjUFtEZUEGM58LbdU488,1798
3
+ anemoi/utils/_version.py,sha256=H-qsvrxCpdhaQzyddR-yajEqI71hPxLa4KxzpP3uS1g,411
4
+ anemoi/utils/caching.py,sha256=HrC9aFHlcCTaM2Z5u0ivGIXz7eFu35UQQhUuwwuG2pk,1743
5
+ anemoi/utils/checkpoints.py,sha256=1_3mg4B-ykTVfIvIUEv7IxGyREx_ZcilVbB3U-V6O6I,5165
6
+ anemoi/utils/config.py,sha256=XEesqODvkuE3ZA7dnEnZ-ooBRtU6ecPmkfP65FtialA,2147
7
+ anemoi/utils/dates.py,sha256=Ot9OTY1uFvHxW1EU4DPv3oUqmzvkXTwKuwhlfVlY788,8426
8
+ anemoi/utils/grib.py,sha256=gVfo4KYQv31iRyoqRDwk5tiqZDUgOIvhag_kO0qjYD0,3067
9
+ anemoi/utils/humanize.py,sha256=LD6dGnqChxA5j3tMhSybsAGRQzi33d_qS9pUoUHubkc,10330
10
+ anemoi/utils/provenance.py,sha256=v54L9jF1JgYcclOhg3iojRl1v3ajbiWz_oc289xTgO4,9574
11
+ anemoi/utils/text.py,sha256=pGWtDvRFoDxAnSuZJiA-GOGJOJLHsw2dAm0tfVvPKno,8599
12
+ anemoi/utils/timer.py,sha256=5aNdcxVmiCijRmqp0URmsqsDypLUJgME0GSn9bk8zxo,920
13
+ anemoi/utils/commands/__init__.py,sha256=Pc5bhVgW92ox1lMR5WUOLuhiY2HT6PsadSHclyw99Vc,1983
14
+ anemoi/utils/commands/checkpoint.py,sha256=SEnAizU3WklqMXUjmIh4eNrgBVwmheKG9gEBS90zwYU,1741
15
+ anemoi/utils/mars/__init__.py,sha256=RAeY8gJ7ZvsPlcIvrQ4fy9xVHs3SphTAPw_XJDtNIKo,1750
16
+ anemoi/utils/mars/mars.yaml,sha256=R0dujp75lLA4wCWhPeOQnzJ45WZAYLT8gpx509cBFlc,66
17
+ anemoi_utils-0.2.0.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
18
+ anemoi_utils-0.2.0.dist-info/METADATA,sha256=7IO_KqlBHcAKCmdszjmE5Zxe5G_ox2QOwKn16mgBIa4,15174
19
+ anemoi_utils-0.2.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
20
+ anemoi_utils-0.2.0.dist-info/entry_points.txt,sha256=LENOkn88xzFQo-V59AKoA_F_cfYQTJYtrNTtf37YgHY,60
21
+ anemoi_utils-0.2.0.dist-info/top_level.txt,sha256=DYn8VPs-fNwr7fNH9XIBqeXIwiYYd2E2k5-dUFFqUz0,7
22
+ anemoi_utils-0.2.0.dist-info/RECORD,,
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ anemoi-utils = anemoi.utils.__main__:main
@@ -1,17 +0,0 @@
1
- anemoi/utils/__init__.py,sha256=zZZpbKIoGWwdCOuo6YSruLR7C0GzvzI1Wzhyqaa0K7M,456
2
- anemoi/utils/_version.py,sha256=NWmu2cvzOcqY9v-ee-qFLmtXRczssdN-cFGZ9qMNSmY,411
3
- anemoi/utils/caching.py,sha256=HrC9aFHlcCTaM2Z5u0ivGIXz7eFu35UQQhUuwwuG2pk,1743
4
- anemoi/utils/checkpoints.py,sha256=IR86FFNh5JR_uQVlgybnZG74PyU0CNLhyocqARwZIrs,2069
5
- anemoi/utils/config.py,sha256=XEesqODvkuE3ZA7dnEnZ-ooBRtU6ecPmkfP65FtialA,2147
6
- anemoi/utils/dates.py,sha256=Ot9OTY1uFvHxW1EU4DPv3oUqmzvkXTwKuwhlfVlY788,8426
7
- anemoi/utils/grib.py,sha256=gVfo4KYQv31iRyoqRDwk5tiqZDUgOIvhag_kO0qjYD0,3067
8
- anemoi/utils/humanize.py,sha256=LD6dGnqChxA5j3tMhSybsAGRQzi33d_qS9pUoUHubkc,10330
9
- anemoi/utils/provenance.py,sha256=v54L9jF1JgYcclOhg3iojRl1v3ajbiWz_oc289xTgO4,9574
10
- anemoi/utils/text.py,sha256=pGWtDvRFoDxAnSuZJiA-GOGJOJLHsw2dAm0tfVvPKno,8599
11
- anemoi/utils/mars/__init__.py,sha256=RAeY8gJ7ZvsPlcIvrQ4fy9xVHs3SphTAPw_XJDtNIKo,1750
12
- anemoi/utils/mars/mars.yaml,sha256=R0dujp75lLA4wCWhPeOQnzJ45WZAYLT8gpx509cBFlc,66
13
- anemoi_utils-0.1.9.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
14
- anemoi_utils-0.1.9.dist-info/METADATA,sha256=wVXW4E6hpkTjCEHUCWDiQXm-pZfJl81Mrzdy6kwKbfA,15101
15
- anemoi_utils-0.1.9.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
16
- anemoi_utils-0.1.9.dist-info/top_level.txt,sha256=DYn8VPs-fNwr7fNH9XIBqeXIwiYYd2E2k5-dUFFqUz0,7
17
- anemoi_utils-0.1.9.dist-info/RECORD,,