anemoi-utils 0.1.9__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of anemoi-utils might be problematic. Click here for more details.

Files changed (51) hide show
  1. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/.pre-commit-config.yaml +1 -0
  2. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/PKG-INFO +7 -5
  3. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/docs/conf.py +5 -9
  4. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/docs/index.rst +3 -3
  5. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/docs/requirements.txt +1 -0
  6. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/pyproject.toml +10 -6
  7. anemoi_utils-0.2.1/src/anemoi/utils/__main__.py +77 -0
  8. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/src/anemoi/utils/_version.py +2 -2
  9. anemoi_utils-0.2.1/src/anemoi/utils/checkpoints.py +179 -0
  10. anemoi_utils-0.2.1/src/anemoi/utils/cli.py +126 -0
  11. anemoi_utils-0.2.1/src/anemoi/utils/commands/__init__.py +78 -0
  12. anemoi_utils-0.2.1/src/anemoi/utils/commands/checkpoint.py +61 -0
  13. anemoi_utils-0.2.1/src/anemoi/utils/timer.py +32 -0
  14. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/src/anemoi_utils.egg-info/PKG-INFO +7 -5
  15. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/src/anemoi_utils.egg-info/SOURCES.txt +6 -0
  16. anemoi_utils-0.2.1/src/anemoi_utils.egg-info/entry_points.txt +2 -0
  17. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/src/anemoi_utils.egg-info/requires.txt +6 -4
  18. anemoi_utils-0.1.9/src/anemoi/utils/checkpoints.py +0 -75
  19. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/.github/workflows/python-publish.yml +0 -0
  20. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/.gitignore +0 -0
  21. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/.readthedocs.yaml +0 -0
  22. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/LICENSE +0 -0
  23. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/README.md +0 -0
  24. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/docs/Makefile +0 -0
  25. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/docs/_static/logo.png +0 -0
  26. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/docs/_static/style.css +0 -0
  27. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/docs/_templates/.gitkeep +0 -0
  28. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/docs/installing.rst +0 -0
  29. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/docs/modules/checkpoints.rst +0 -0
  30. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/docs/modules/config.rst +0 -0
  31. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/docs/modules/dates.rst +0 -0
  32. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/docs/modules/grib.rst +0 -0
  33. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/docs/modules/humanize.rst +0 -0
  34. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/docs/modules/provenance.rst +0 -0
  35. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/docs/modules/text.rst +0 -0
  36. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/setup.cfg +0 -0
  37. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/src/anemoi/utils/__init__.py +0 -0
  38. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/src/anemoi/utils/caching.py +0 -0
  39. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/src/anemoi/utils/config.py +0 -0
  40. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/src/anemoi/utils/dates.py +0 -0
  41. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/src/anemoi/utils/grib.py +0 -0
  42. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/src/anemoi/utils/humanize.py +0 -0
  43. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/src/anemoi/utils/mars/__init__.py +0 -0
  44. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/src/anemoi/utils/mars/mars.yaml +0 -0
  45. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/src/anemoi/utils/provenance.py +0 -0
  46. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/src/anemoi/utils/text.py +0 -0
  47. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/src/anemoi_utils.egg-info/dependency_links.txt +0 -0
  48. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/src/anemoi_utils.egg-info/top_level.txt +0 -0
  49. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/tests/requirements.txt +0 -0
  50. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/tests/test_dates.py +0 -0
  51. {anemoi_utils-0.1.9 → anemoi_utils-0.2.1}/tests/test_utils.py +0 -0
@@ -61,6 +61,7 @@ repos:
61
61
  rev: v0.0.14
62
62
  hooks:
63
63
  - id: rstfmt
64
+ exclude: 'cli/.*' # Because we use argparse
64
65
 
65
66
  - repo: https://github.com/b8raoult/pre-commit-docconvert
66
67
  rev: "0.1.4"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: anemoi-utils
3
- Version: 0.1.9
3
+ Version: 0.2.1
4
4
  Summary: A package to hold various functions to support training of ML models on ECMWF data.
5
5
  Author-email: "European Centre for Medium-Range Weather Forecasts (ECMWF)" <software.support@ecmwf.int>
6
6
  License: Apache License
@@ -224,6 +224,7 @@ Requires-Python: >=3.9
224
224
  License-File: LICENSE
225
225
  Requires-Dist: tomli
226
226
  Requires-Dist: pyyaml
227
+ Requires-Dist: tqdm
227
228
  Provides-Extra: provenance
228
229
  Requires-Dist: GitPython; extra == "provenance"
229
230
  Requires-Dist: nvsmi; extra == "provenance"
@@ -235,6 +236,11 @@ Provides-Extra: docs
235
236
  Requires-Dist: tomli; extra == "docs"
236
237
  Requires-Dist: termcolor; extra == "docs"
237
238
  Requires-Dist: requests; extra == "docs"
239
+ Requires-Dist: sphinx; extra == "docs"
240
+ Requires-Dist: sphinx_rtd_theme; extra == "docs"
241
+ Requires-Dist: nbsphinx; extra == "docs"
242
+ Requires-Dist: pandoc; extra == "docs"
243
+ Requires-Dist: sphinx_argparse; extra == "docs"
238
244
  Provides-Extra: all
239
245
  Requires-Dist: tomli; extra == "all"
240
246
  Requires-Dist: GitPython; extra == "all"
@@ -247,7 +253,3 @@ Requires-Dist: GitPython; extra == "dev"
247
253
  Requires-Dist: nvsmi; extra == "dev"
248
254
  Requires-Dist: termcolor; extra == "dev"
249
255
  Requires-Dist: requests; extra == "dev"
250
- Requires-Dist: sphinx; extra == "dev"
251
- Requires-Dist: sphinx_rtd_theme; extra == "dev"
252
- Requires-Dist: nbsphinx; extra == "dev"
253
- Requires-Dist: pandoc; extra == "dev"
@@ -14,14 +14,9 @@ import datetime
14
14
  import os
15
15
  import sys
16
16
 
17
- sys.path.insert(0, os.path.join(os.path.abspath(".."), "src"))
18
-
19
-
20
17
  read_the_docs_build = os.environ.get("READTHEDOCS", None) == "True"
21
18
 
22
- # top = os.path.realpath(os.path.dirname(os.path.dirname(__file__)))
23
- # sys.path.insert(0, top)
24
-
19
+ sys.path.insert(0, os.path.join(os.path.abspath(".."), "src"))
25
20
 
26
21
  source_suffix = ".rst"
27
22
  master_doc = "index"
@@ -32,7 +27,7 @@ html_logo = "_static/logo.png"
32
27
 
33
28
  # -- Project information -----------------------------------------------------
34
29
 
35
- project = "Anemoi"
30
+ project = "Anemoi Utils"
36
31
 
37
32
  author = "ECMWF"
38
33
 
@@ -45,9 +40,9 @@ else:
45
40
  copyright = "%s, ECMWF" % (years,)
46
41
 
47
42
  try:
48
- import anemoi.utils
43
+ from anemoi.utils._version import __version__
49
44
 
50
- release = anemoi.utils.__version__
45
+ release = __version__
51
46
  except ImportError:
52
47
  release = "0.0.0"
53
48
 
@@ -65,6 +60,7 @@ extensions = [
65
60
  "sphinx.ext.intersphinx",
66
61
  "sphinx.ext.autodoc",
67
62
  "sphinx.ext.napoleon",
63
+ "sphinxarg.ext",
68
64
  ]
69
65
 
70
66
  # Add any paths that contain templates here, relative to this directory.
@@ -2,9 +2,9 @@
2
2
 
3
3
  .. _index-page:
4
4
 
5
- ####################################
6
- Welcome to Anemoi's documentation!
7
- ####################################
5
+ ##########################################
6
+ Welcome to `anemoi-utils` documentation!
7
+ ##########################################
8
8
 
9
9
  .. warning::
10
10
 
@@ -2,6 +2,7 @@
2
2
  sphinx
3
3
  sphinx_rtd_theme
4
4
  nbsphinx
5
+ sphinx_argparse
5
6
 
6
7
  # Also requires `brew install pandoc` on Mac
7
8
  pandoc
@@ -40,8 +40,9 @@ classifiers = [
40
40
  ]
41
41
 
42
42
  dependencies = [
43
- "tomli", # Only needed before 3.11
43
+ "tomli", # Only needed before 3.11
44
44
  "pyyaml",
45
+ "tqdm",
45
46
  ]
46
47
 
47
48
  [project.optional-dependencies]
@@ -54,9 +55,14 @@ grib = ["requests"]
54
55
  # Loaded by read-the-docs
55
56
  # `pip install .[docs]`
56
57
  docs = [
57
- "tomli", # Only needed before 3.11
58
+ "tomli", # Only needed before 3.11
58
59
  "termcolor",
59
60
  "requests",
61
+ "sphinx",
62
+ "sphinx_rtd_theme",
63
+ "nbsphinx",
64
+ "pandoc",
65
+ "sphinx_argparse",
60
66
  ]
61
67
 
62
68
  all = [
@@ -73,10 +79,6 @@ dev = [
73
79
  "nvsmi",
74
80
  "termcolor",
75
81
  "requests",
76
- "sphinx",
77
- "sphinx_rtd_theme",
78
- "nbsphinx",
79
- "pandoc",
80
82
  ]
81
83
 
82
84
  [project.urls]
@@ -86,6 +88,8 @@ Repository = "https://github.com/ecmwf/anemoi-utils/"
86
88
  Issues = "https://github.com/ecmwf/anemoi-utils/issues"
87
89
  # Changelog = "https://github.com/ecmwf/anemoi-utils/CHANGELOG.md"
88
90
 
91
+ [project.scripts]
92
+ anemoi-utils = "anemoi.utils.__main__:main"
89
93
 
90
94
  [tool.setuptools_scm]
91
95
  version_file = "src/anemoi/utils/_version.py"
@@ -0,0 +1,77 @@
1
+ #!/usr/bin/env python
2
+ # (C) Copyright 2024 ECMWF.
3
+ #
4
+ # This software is licensed under the terms of the Apache Licence Version 2.0
5
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+ #
10
+
11
+
12
+ import argparse
13
+ import logging
14
+ import sys
15
+ import traceback
16
+
17
+ from . import __version__
18
+ from .commands import COMMANDS
19
+
20
+ LOG = logging.getLogger(__name__)
21
+
22
+
23
def create_parser():
    """Build the argument parser of the ``anemoi-utils`` command-line tool.

    The parser understands two global flags (``--version``/``-V`` and
    ``--debug``/``-d``) and one sub-parser per entry of ``COMMANDS``. Each
    command object fills in its own sub-parser through ``add_arguments``.

    Returns
    -------
    argparse.ArgumentParser
        The configured parser; the chosen sub-command is stored in ``command``.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    global_flags = (
        (("--version", "-V"), "show the version and exit"),
        (("--debug", "-d"), "Debug mode"),
    )
    for flags, help_text in global_flags:
        parser.add_argument(*flags, action="store_true", help=help_text)

    subparsers = parser.add_subparsers(help="commands:", dest="command")
    for command_name in COMMANDS:
        handler = COMMANDS[command_name]
        sub = subparsers.add_parser(command_name, help=handler.__doc__)
        handler.add_arguments(sub)

    return parser
45
+
46
+
47
def main():
    """Entry point of the ``anemoi-utils`` command-line tool.

    - With ``--version``, prints the package version and returns.
    - With no sub-command, prints the help text and returns.
    - Otherwise, configures logging (DEBUG with ``--debug``, INFO otherwise)
      and runs the selected command.

    A ``ValueError`` raised by the command is reported and the process exits
    with status 1.
    """
    parser = create_parser()
    options = parser.parse_args()

    if options.version:
        print(__version__)
    elif options.command is None:
        parser.print_help()
    else:
        selected = COMMANDS[options.command]

        logging.basicConfig(
            format="%(asctime)s %(levelname)s %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
            level=logging.DEBUG if options.debug else logging.INFO,
        )

        try:
            selected.run(options)
        except ValueError as exc:
            traceback.print_exc()
            LOG.error("\n💣 %s", str(exc).lstrip())
            LOG.error("💣 Exiting")
            sys.exit(1)


if __name__ == "__main__":
    main()
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.1.9'
16
- __version_tuple__ = version_tuple = (0, 1, 9)
15
+ __version__ = version = '0.2.1'
16
+ __version_tuple__ = version_tuple = (0, 2, 1)
@@ -0,0 +1,179 @@
1
+ # (C) Copyright 2024 ECMWF.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ # In applying this licence, ECMWF does not waive the privileges and immunities
6
+ # granted to it by virtue of its status as an intergovernmental organisation
7
+ # nor does it submit to any jurisdiction.
8
+
9
+ """Read and write extra metadata in PyTorch checkpoints files. These files
10
+ are zip archives containing the model weights.
11
+ """
12
+
13
+ import json
14
+ import logging
15
+ import os
16
+ import time
17
+ import zipfile
18
+ from tempfile import TemporaryDirectory
19
+
20
+ import tqdm
21
+
22
+ LOG = logging.getLogger(__name__)
23
+
24
+ DEFAULT_NAME = "ai-models.json"
25
+ DEFAULT_FOLDER = "anemoi-metadata"
26
+
27
+
28
def has_metadata(path: str, name: str = DEFAULT_NAME) -> bool:
    """Tell whether a checkpoint archive contains a metadata file.

    A member matches when its base name (the part after the last "/") equals
    ``name``, whatever folder it lives in.

    Parameters
    ----------
    path : str
        The path to the checkpoint file
    name : str, optional
        The name of the metadata file in the zip archive

    Returns
    -------
    bool
        True if at least one matching member exists
    """
    with zipfile.ZipFile(path, "r") as archive:
        return any(os.path.basename(member) == name for member in archive.namelist())
48
+
49
+
50
def load_metadata(path: str, name: str = DEFAULT_NAME):
    """Load metadata from a checkpoint file

    The archive member whose base name equals ``name`` is read and decoded
    as JSON.

    Parameters
    ----------
    path : str
        The path to the checkpoint file
    name : str, optional
        The name of the metadata file in the zip archive

    Returns
    -------
    JSON
        The content of the metadata file

    Raises
    ------
    ValueError
        If the metadata file is not found, or if more than one member has
        that name
    """
    with zipfile.ZipFile(path, "r") as zipf:
        matches = [b for b in zipf.namelist() if os.path.basename(b) == name]

        if not matches:
            raise ValueError(f"Could not find '{name}' in {path}.")

        if len(matches) > 1:
            raise ValueError(f"Found two or more '{name}' in {path}.")

        # Read through the archive that is already open and make sure the
        # member stream is closed (the previous version reopened the zip file
        # and never closed the stream returned by ``open``).
        with zipf.open(matches[0], "r") as member:
            return json.load(member)
83
+
84
+
85
def save_metadata(path, metadata, name=DEFAULT_NAME, folder=DEFAULT_FOLDER):
    """Save metadata to a checkpoint file

    The metadata is serialised with ``json.dumps`` and appended to the
    archive as ``<top-level directory>/<folder>/<name>``.

    Parameters
    ----------
    path : str
        The path to the checkpoint file
    metadata : JSON
        A JSON serializable object
    name : str, optional
        The name of the metadata file in the zip archive
    folder : str, optional
        The folder, below the archive's top-level directory, in which the
        metadata file is written

    Raises
    ------
    ValueError
        If a member called ``name`` already exists in the archive, or if the
        archive does not have exactly one top-level directory
    """
    with zipfile.ZipFile(path, "a") as zipf:

        directories = set()

        for b in zipf.namelist():
            # Zip member names always use "/" as separator, so the top-level
            # directory is the text before the first "/" ("" for a member at
            # the root). Walking up with os.path.dirname() instead never
            # terminates for an absolute name, because dirname("/") == "/".
            top, sep, _ = b.partition("/")
            directories.add(top if sep else "")

            if os.path.basename(b) == name:
                raise ValueError(f"'{name}' already in {path}")

        if len(directories) != 1:
            # PyTorch checkpoints should have a single directory
            # otherwise PyTorch will complain
            raise ValueError(f"No or multiple directories in the checkpoint {path}, directories={directories}")

        (directory,) = directories

        LOG.info("Saving metadata to %s/%s/%s", directory, folder, name)

        zipf.writestr(
            f"{directory}/{folder}/{name}",
            json.dumps(metadata),
        )
123
+
124
+
125
def _edit_metadata(path, name, callback):
    """Rewrite a checkpoint after letting ``callback`` modify a member file.

    The process has three steps:

    1. The archive is unpacked into a temporary directory.
    2. ``callback`` is called with the on-disk path of every extracted file
       whose base name is ``name``; it may rewrite or delete that file.
    3. The directory is zipped into a temporary file next to ``path``, which
       then replaces ``path``.

    Parameters
    ----------
    path : str
        The path to the checkpoint file
    name : str
        The base name of the member(s) to edit
    callback : callable
        Called with the full path of each matching extracted file

    Raises
    ------
    ValueError
        If no member called ``name`` exists; ``path`` is left untouched
    """
    new_path = f"{path}.anemoi-edit-{time.time()}-{os.getpid()}.tmp"

    found = False

    with TemporaryDirectory() as temp_dir:
        # Close the source archive as soon as it has been unpacked.
        with zipfile.ZipFile(path, "r") as source:
            source.extractall(temp_dir)

        for root, _, files in os.walk(temp_dir):
            for f in files:
                if f == name:
                    found = True
                    callback(os.path.join(root, f))

        if not found:
            raise ValueError(f"Could not find '{name}' in {path}")

        # Count after the callbacks have run, since a callback may delete files.
        total = sum(len(files) for _, _, files in os.walk(temp_dir))

        try:
            with zipfile.ZipFile(new_path, "w", zipfile.ZIP_DEFLATED) as zipf:
                with tqdm.tqdm(total=total, desc="Rebuilding checkpoint") as pbar:
                    for root, _, files in os.walk(temp_dir):
                        for f in files:
                            full = os.path.join(root, f)
                            rel = os.path.relpath(full, temp_dir)
                            zipf.write(full, rel)
                            pbar.update(1)
        except BaseException:
            # Do not leave a half-written archive next to the checkpoint.
            if os.path.exists(new_path):
                os.remove(new_path)
            raise

    # Unlike os.rename, os.replace also overwrites an existing destination on Windows.
    os.replace(new_path, path)
    LOG.info("Updated metadata in %s", path)
155
+
156
+
157
def replace_metadata(path, metadata, name=DEFAULT_NAME):
    """Overwrite an existing metadata file inside a checkpoint.

    Parameters
    ----------
    path : str
        The path to the checkpoint file
    metadata : dict
        The new metadata; must contain a ``version`` key
    name : str, optional
        The name of the metadata file in the zip archive

    Raises
    ------
    ValueError
        If ``metadata`` is not a dict or has no ``version`` key. A
        ``ValueError`` also comes from ``_edit_metadata`` when the checkpoint
        has no member called ``name``.
    """
    if not isinstance(metadata, dict):
        raise ValueError(f"metadata must be a dict, got {type(metadata)}")
    if "version" not in metadata:
        raise ValueError("metadata must have a 'version' key")

    def _overwrite(target):
        # Replace the extracted file's content with the new JSON document.
        with open(target, "w") as out:
            json.dump(metadata, out)

    _edit_metadata(path, name, callback=_overwrite)
170
+
171
+
172
def remove_metadata(path, name=DEFAULT_NAME):
    """Delete a metadata file from a checkpoint.

    Parameters
    ----------
    path : str
        The path to the checkpoint file
    name : str, optional
        The name of the metadata file in the zip archive

    Raises
    ------
    ValueError
        From ``_edit_metadata``, when the checkpoint has no member called
        ``name``
    """
    LOG.info("Removing metadata '%s' from %s", name, path)
    # Each matching extracted file is simply deleted before the archive is rebuilt.
    _edit_metadata(path, name, os.remove)
@@ -0,0 +1,126 @@
1
+ # (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
2
+ # This software is licensed under the terms of the Apache Licence Version 2.0
3
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
4
+ # In applying this licence, ECMWF does not waive the privileges and immunities
5
+ # granted to it by virtue of its status as an intergovernmental organisation
6
+ # nor does it submit to any jurisdiction.
7
+
8
+ import argparse
9
+ import importlib
10
+ import logging
11
+ import os
12
+ import sys
13
+ import traceback
14
+
15
+ LOG = logging.getLogger(__name__)
16
+
17
+
18
# Base class for CLI sub-commands. Concrete commands are expected to provide
# ``add_arguments(parser)`` (called by ``make_parser``) and to override ``run``.
class Command:
    def run(self, args):
        """Execute the command; subclasses must override this."""
        message = f"Command not implemented: {args.command}"
        raise NotImplementedError(message)
21
+
22
+
23
def make_parser(description, commands):
    """Create an argument parser with one sub-command per entry of ``commands``.

    Parameters
    ----------
    description : str
        Text shown at the top of ``--help`` (printed verbatim)
    commands : dict
        Maps sub-command names to objects with an ``add_arguments(parser)``
        method; each object's ``__doc__`` becomes the sub-command help

    Returns
    -------
    argparse.ArgumentParser
        Parser with ``--version``/``-V`` and ``--debug``/``-d`` flags; the chosen
        sub-command name is stored in ``command`` (None when omitted)
    """
    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    global_flags = [
        ("--version", "-V", "show the version and exit"),
        ("--debug", "-d", "Debug mode"),
    ]
    for long_flag, short_flag, help_text in global_flags:
        parser.add_argument(long_flag, short_flag, action="store_true", help=help_text)

    subparsers = parser.add_subparsers(help="commands:", dest="command")
    for command_name, handler in commands.items():
        handler.add_arguments(subparsers.add_parser(command_name, help=handler.__doc__))

    return parser
48
+
49
+
50
# Placeholder registered in place of a sub-command whose module could not be
# imported: it accepts any arguments and reports the import error when run.
# (No class docstring on purpose: ``__doc__`` is used as the sub-command help.)
class Failed(Command):
    def __init__(self, name, error):
        self.name, self.error = name, error

    def add_arguments(self, command_parser):
        """Accept (and ignore) any remaining command-line arguments."""
        command_parser.add_argument("x", nargs=argparse.REMAINDER)

    def run(self, args):
        """Explain why the command is unavailable and exit with status 1."""
        message = f"Command '{self.name}' not available: {self.error}"
        print(message)
        sys.exit(1)
61
+
62
+
63
def register_commands(here, package, select, fail=None):
    """Discover and import the sub-command modules of a package.

    Every ``*.py`` file and every sub-package (a directory with an
    ``__init__.py``) in ``here`` whose name does not start with "_" is imported
    relative to ``package``.

    Parameters
    ----------
    here : str
        Directory to scan (normally the package's own directory)
    package : str
        Dotted package name used as the anchor for the relative imports
    select : callable
        Called with each imported module; its return value is registered under
        the module name, unless it is None
    fail : callable, optional
        Called as ``fail(name, error)`` for each module that raised
        ``ImportError``; its return value is registered in its place. When not
        callable, such modules are silently left out.

    Returns
    -------
    dict
        Maps module names to the selected objects. Successfully imported
        modules come first, in sorted order, followed by any placeholders from
        ``fail``.
    """
    result = {}
    not_available = {}

    # Sort so that the mapping (and hence the order of sub-commands in
    # ``--help``) does not depend on the file system's listing order.
    for p in sorted(os.listdir(here)):
        full = os.path.join(here, p)
        if p.startswith("_"):
            continue
        is_package = os.path.isdir(full) and os.path.exists(os.path.join(full, "__init__.py"))
        if not (p.endswith(".py") or is_package):
            continue

        name, _ = os.path.splitext(p)

        try:
            imported = importlib.import_module(f".{name}", package=package)
        except ImportError as e:
            not_available[name] = e
            continue

        obj = select(imported)
        if obj is not None:
            result[name] = obj

    if callable(fail):
        for name, e in not_available.items():
            result[name] = fail(name, e)

    return result
96
+
97
+
98
def cli_main(version, description, commands):
    """Parse the command line and dispatch to the selected sub-command.

    - With ``--version``, prints ``version`` and returns.
    - With no sub-command, prints the help text and returns.
    - Otherwise, configures logging (DEBUG with ``--debug``, INFO otherwise)
      and runs the command.

    A ``ValueError`` from the command is reported and the process exits with
    status 1; otherwise it exits with status 0.

    Parameters
    ----------
    version : str
        Text printed by ``--version``
    description : str
        Description shown in ``--help``
    commands : dict
        Maps sub-command names to command objects (see ``make_parser``)
    """
    parser = make_parser(description, commands)
    options = parser.parse_args()

    if options.version:
        print(version)
    elif options.command is None:
        parser.print_help()
    else:
        selected = commands[options.command]

        logging.basicConfig(
            format="%(asctime)s %(levelname)s %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
            level=logging.DEBUG if options.debug else logging.INFO,
        )

        try:
            selected.run(options)
        except ValueError as exc:
            traceback.print_exc()
            LOG.error("\n💣 %s", str(exc).lstrip())
            LOG.error("💣 Exiting")
            sys.exit(1)

        sys.exit(0)
@@ -0,0 +1,78 @@
1
+ #!/usr/bin/env python
2
+ # (C) Copyright 2024 ECMWF.
3
+ #
4
+ # This software is licensed under the terms of the Apache Licence Version 2.0
5
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+ #
10
+
11
+ import argparse
12
+ import importlib
13
+ import logging
14
+ import os
15
+ import sys
16
+
17
+ LOG = logging.getLogger(__name__)
18
+
19
+
20
def register(here, package, select, fail=None):
    """Discover and import the sub-command modules of a package.

    Every ``*.py`` file and every sub-package (a directory with an
    ``__init__.py``) in ``here`` whose name does not start with "_" is imported
    relative to ``package``.

    Parameters
    ----------
    here : str
        Directory to scan (normally the package's own directory)
    package : str
        Dotted package name used as the anchor for the relative imports
    select : callable
        Called with each imported module; its return value is registered under
        the module name, unless it is None
    fail : callable, optional
        Called as ``fail(name, error)`` for each module that raised
        ``ImportError``; its return value is registered in its place. When not
        callable, such modules are silently left out.

    Returns
    -------
    dict
        Maps module names to the selected objects. Successfully imported
        modules come first, in sorted order, followed by any placeholders from
        ``fail``.
    """
    result = {}
    not_available = {}

    # Sort so that the mapping (and hence the order of sub-commands in
    # ``--help``) does not depend on the file system's listing order.
    for p in sorted(os.listdir(here)):
        full = os.path.join(here, p)
        if p.startswith("_"):
            continue
        is_package = os.path.isdir(full) and os.path.exists(os.path.join(full, "__init__.py"))
        if not (p.endswith(".py") or is_package):
            continue

        name, _ = os.path.splitext(p)

        try:
            imported = importlib.import_module(f".{name}", package=package)
        except ImportError as e:
            not_available[name] = e
            continue

        obj = select(imported)
        if obj is not None:
            result[name] = obj

    if callable(fail):
        for name, e in not_available.items():
            result[name] = fail(name, e)

    return result
53
+
54
+
55
# Base class for ``anemoi-utils`` sub-commands. Concrete commands are expected
# to provide ``add_arguments(parser)`` and to override ``run``.
class Command:
    def run(self, args):
        """Execute the command; subclasses must override this."""
        message = f"Command not implemented: {args.command}"
        raise NotImplementedError(message)
58
+
59
+
60
# Stand-in for a sub-command whose module raised ImportError: it accepts any
# arguments and, when run, prints the import error and exits with status 1.
# (No class docstring on purpose: ``__doc__`` is used as the sub-command help.)
class Failed(Command):
    def __init__(self, name, error):
        self.name, self.error = name, error

    def add_arguments(self, command_parser):
        """Accept (and ignore) any remaining command-line arguments."""
        command_parser.add_argument("x", nargs=argparse.REMAINDER)

    def run(self, args):
        """Explain why the command is unavailable and exit with status 1."""
        message = f"Command '{self.name}' not available: {self.error}"
        print(message)
        sys.exit(1)
71
+
72
+
73
# Maps each sub-command name to an instance of the class exported as
# ``command`` by the module of the same name in this package. Modules without
# a ``command`` attribute (e.g. shared helpers) are skipped instead of crashing
# the CLI at import time; modules that fail to import get a ``Failed``
# placeholder.
COMMANDS = register(
    os.path.dirname(__file__),
    __name__,
    lambda module: module.command() if hasattr(module, "command") else None,
    Failed,
)
@@ -0,0 +1,61 @@
1
+ #!/usr/bin/env python
2
+ # (C) Copyright 2024 ECMWF.
3
+ #
4
+ # This software is licensed under the terms of the Apache Licence Version 2.0
5
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+ #
10
+
11
+
12
+ import json
13
+
14
+ from . import Command
15
+
16
+
17
def visit(x, path, name, value):
    """Print the entries of a JSON-like structure that match a key or a value.

    Walks ``x`` recursively. A dictionary entry is printed when its key equals
    ``name`` or its value equals ``value``; a list element is printed when it
    equals ``value``. Each match is printed once, as
    ``<dotted path of the container> <key or index> <value>``.

    Parameters
    ----------
    x : object
        The (sub-)structure to search; only dicts and lists are descended into
    path : list of str
        Keys/indices leading to ``x``; used as a stack and restored on return
    name : object
        Key to look for (pass an ``object()`` sentinel to disable)
    value : object
        Value to look for (pass an ``object()`` sentinel to disable)
    """
    if isinstance(x, dict):
        for k, v in x.items():
            # Report an entry only once, even when both its key and value match.
            if k == name or v == value:
                print(".".join(path), k, v)

            path.append(k)
            visit(v, path, name, value)
            path.pop()

    elif isinstance(x, list):
        for i, v in enumerate(x):
            if v == value:
                print(".".join(path), i, v)

            path.append(str(i))
            visit(v, path, name, value)
            path.pop()
35
+
36
+
37
class Checkpoint(Command):
    """Print the metadata stored in a checkpoint, or search it by key or value."""

    def add_arguments(self, command_parser):
        command_parser.add_argument("path", help="Path to the checkpoint.")
        command_parser.add_argument("--name", help="Search for a specific name.")
        command_parser.add_argument("--value", help="Search for a specific value.")

    def run(self, args):
        """Load the checkpoint metadata and print it, or the entries matching --name/--value."""
        from anemoi.utils.checkpoints import load_metadata

        # load_metadata compares member base names literally, so a pattern such
        # as "*.json" never matches anything; read the default metadata file.
        checkpoint = load_metadata(args.path)

        if args.name or args.value:
            # A fresh object() equals nothing in the decoded JSON, which
            # disables whichever criterion was not given on the command line.
            visit(
                checkpoint,
                [],
                args.name if args.name is not None else object(),
                args.value if args.value is not None else object(),
            )
            return

        print(json.dumps(checkpoint, sort_keys=True, indent=4))


command = Checkpoint
@@ -0,0 +1,32 @@
1
+ # (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
2
+ # This software is licensed under the terms of the Apache Licence Version 2.0
3
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
4
+ # In applying this licence, ECMWF does not waive the privileges and immunities
5
+ # granted to it by virtue of its status as an intergovernmental organisation
6
+ # nor does it submit to any jurisdiction.
7
+
8
+ """Logging utilities."""
9
+
10
+ import logging
11
+ import time
12
+
13
+ from .humanize import seconds
14
+
15
+ LOGGER = logging.getLogger(__name__)
16
+
17
+
18
class Timer:
    """Context manager that logs how long its ``with`` block took.

    Example::

        with Timer("Loading data"):
            ...

    Parameters
    ----------
    title : str
        Label included in the log message
    logger : logging.Logger, optional
        Logger used to report the elapsed time (this module's logger by default)
    """

    def __init__(self, title, logger=LOGGER):
        self.title = title
        # Wall-clock creation time, kept for callers that read it as a timestamp.
        self.start = time.time()
        # Durations are measured with perf_counter(): unlike time.time() it is
        # monotonic, so system clock adjustments cannot distort them.
        self._start_counter = time.perf_counter()
        self.logger = logger

    def __enter__(self):
        return self

    @property
    def elapsed(self):
        """Seconds elapsed since the timer was created."""
        return time.perf_counter() - self._start_counter

    def __exit__(self, *args):
        # Logs at INFO level; returns None, so exceptions are not suppressed.
        self.logger.info("%s: %s.", self.title, seconds(self.elapsed))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: anemoi-utils
3
- Version: 0.1.9
3
+ Version: 0.2.1
4
4
  Summary: A package to hold various functions to support training of ML models on ECMWF data.
5
5
  Author-email: "European Centre for Medium-Range Weather Forecasts (ECMWF)" <software.support@ecmwf.int>
6
6
  License: Apache License
@@ -224,6 +224,7 @@ Requires-Python: >=3.9
224
224
  License-File: LICENSE
225
225
  Requires-Dist: tomli
226
226
  Requires-Dist: pyyaml
227
+ Requires-Dist: tqdm
227
228
  Provides-Extra: provenance
228
229
  Requires-Dist: GitPython; extra == "provenance"
229
230
  Requires-Dist: nvsmi; extra == "provenance"
@@ -235,6 +236,11 @@ Provides-Extra: docs
235
236
  Requires-Dist: tomli; extra == "docs"
236
237
  Requires-Dist: termcolor; extra == "docs"
237
238
  Requires-Dist: requests; extra == "docs"
239
+ Requires-Dist: sphinx; extra == "docs"
240
+ Requires-Dist: sphinx_rtd_theme; extra == "docs"
241
+ Requires-Dist: nbsphinx; extra == "docs"
242
+ Requires-Dist: pandoc; extra == "docs"
243
+ Requires-Dist: sphinx_argparse; extra == "docs"
238
244
  Provides-Extra: all
239
245
  Requires-Dist: tomli; extra == "all"
240
246
  Requires-Dist: GitPython; extra == "all"
@@ -247,7 +253,3 @@ Requires-Dist: GitPython; extra == "dev"
247
253
  Requires-Dist: nvsmi; extra == "dev"
248
254
  Requires-Dist: termcolor; extra == "dev"
249
255
  Requires-Dist: requests; extra == "dev"
250
- Requires-Dist: sphinx; extra == "dev"
251
- Requires-Dist: sphinx_rtd_theme; extra == "dev"
252
- Requires-Dist: nbsphinx; extra == "dev"
253
- Requires-Dist: pandoc; extra == "dev"
@@ -21,20 +21,26 @@ docs/modules/humanize.rst
21
21
  docs/modules/provenance.rst
22
22
  docs/modules/text.rst
23
23
  src/anemoi/utils/__init__.py
24
+ src/anemoi/utils/__main__.py
24
25
  src/anemoi/utils/_version.py
25
26
  src/anemoi/utils/caching.py
26
27
  src/anemoi/utils/checkpoints.py
28
+ src/anemoi/utils/cli.py
27
29
  src/anemoi/utils/config.py
28
30
  src/anemoi/utils/dates.py
29
31
  src/anemoi/utils/grib.py
30
32
  src/anemoi/utils/humanize.py
31
33
  src/anemoi/utils/provenance.py
32
34
  src/anemoi/utils/text.py
35
+ src/anemoi/utils/timer.py
36
+ src/anemoi/utils/commands/__init__.py
37
+ src/anemoi/utils/commands/checkpoint.py
33
38
  src/anemoi/utils/mars/__init__.py
34
39
  src/anemoi/utils/mars/mars.yaml
35
40
  src/anemoi_utils.egg-info/PKG-INFO
36
41
  src/anemoi_utils.egg-info/SOURCES.txt
37
42
  src/anemoi_utils.egg-info/dependency_links.txt
43
+ src/anemoi_utils.egg-info/entry_points.txt
38
44
  src/anemoi_utils.egg-info/requires.txt
39
45
  src/anemoi_utils.egg-info/top_level.txt
40
46
  tests/requirements.txt
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ anemoi-utils = anemoi.utils.__main__:main
@@ -1,5 +1,6 @@
1
1
  tomli
2
2
  pyyaml
3
+ tqdm
3
4
 
4
5
  [all]
5
6
  tomli
@@ -14,15 +15,16 @@ GitPython
14
15
  nvsmi
15
16
  termcolor
16
17
  requests
17
- sphinx
18
- sphinx_rtd_theme
19
- nbsphinx
20
- pandoc
21
18
 
22
19
  [docs]
23
20
  tomli
24
21
  termcolor
25
22
  requests
23
+ sphinx
24
+ sphinx_rtd_theme
25
+ nbsphinx
26
+ pandoc
27
+ sphinx_argparse
26
28
 
27
29
  [grib]
28
30
  requests
@@ -1,75 +0,0 @@
1
- # (C) Copyright 2024 ECMWF.
2
- #
3
- # This software is licensed under the terms of the Apache Licence Version 2.0
4
- # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
- # In applying this licence, ECMWF does not waive the privileges and immunities
6
- # granted to it by virtue of its status as an intergovernmental organisation
7
- # nor does it submit to any jurisdiction.
8
-
9
- """Read and write extra metadata in PyTorch checkpoints files. These files
10
- are zip archives containing the model weights.
11
- """
12
-
13
- import json
14
- import logging
15
- import os
16
- import zipfile
17
-
18
- LOG = logging.getLogger(__name__)
19
-
20
- DEFAULT_NAME = "anemoi-metadata.json"
21
-
22
-
23
def load_metadata(path: str, name: str = DEFAULT_NAME):
    """Load metadata from a checkpoint file

    If several members share the base name ``name``, a warning is logged and
    the last one listed in the archive is used.

    Parameters
    ----------
    path : str
        The path to the checkpoint file
    name : str, optional
        The name of the metadata file in the zip archive

    Returns
    -------
    JSON
        The content of the metadata file

    Raises
    ------
    ValueError
        If the metadata file is not found
    """
    with zipfile.ZipFile(path, "r") as f:
        metadata = None
        for b in f.namelist():
            if os.path.basename(b) == name:
                if metadata is not None:
                    LOG.warning("Found two '%s' in %s", name, path)
                metadata = b

        if metadata is None:
            raise ValueError(f"Could not find {name} in {path}")

        # Read from the archive that is already open and close the member stream.
        with f.open(metadata, "r") as member:
            return json.load(member)
56
-
57
-
58
def save_metadata(path, metadata, name=DEFAULT_NAME):
    """Save metadata to a checkpoint file

    The metadata is serialised with ``json.dumps`` and written as
    ``<folder>/<name>`` inside the archive, where ``<folder>`` is the single
    top-level folder used by the existing members. If there is no such single
    folder, ``<folder>`` is the checkpoint's file name without extension.

    Parameters
    ----------
    path : str
        The path to the checkpoint file
    metadata : JSON
        A JSON serializable object
    name : str, optional
        The name of the metadata file in the zip archive
    """
    with zipfile.ZipFile(path, "a") as zipf:
        # The archive's internal folder need not match the file name (for
        # example after the checkpoint file has been renamed), so prefer the
        # folder actually used by the existing members.
        tops = {b.partition("/")[0] for b in zipf.namelist() if "/" in b}
        if len(tops) == 1:
            (base,) = tops
        else:
            base, _ = os.path.splitext(os.path.basename(path))
        zipf.writestr(
            f"{base}/{name}",
            json.dumps(metadata),
        )
File without changes
File without changes
File without changes
File without changes
File without changes