anemoi-utils 0.1.8__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of anemoi-utils might be problematic. Click here for more details.

Files changed (50) hide show
  1. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/.gitignore +1 -0
  2. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/.pre-commit-config.yaml +1 -0
  3. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/PKG-INFO +8 -5
  4. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/conf.py +5 -9
  5. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/index.rst +3 -3
  6. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/requirements.txt +1 -0
  7. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/pyproject.toml +14 -6
  8. anemoi_utils-0.2.0/src/anemoi/utils/__main__.py +72 -0
  9. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi/utils/_version.py +2 -2
  10. anemoi_utils-0.2.0/src/anemoi/utils/checkpoints.py +179 -0
  11. anemoi_utils-0.2.0/src/anemoi/utils/commands/__init__.py +78 -0
  12. anemoi_utils-0.2.0/src/anemoi/utils/commands/checkpoint.py +61 -0
  13. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi/utils/dates.py +75 -2
  14. anemoi_utils-0.2.0/src/anemoi/utils/mars/__init__.py +76 -0
  15. anemoi_utils-0.2.0/src/anemoi/utils/mars/mars.yaml +5 -0
  16. anemoi_utils-0.2.0/src/anemoi/utils/timer.py +32 -0
  17. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi_utils.egg-info/PKG-INFO +8 -5
  18. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi_utils.egg-info/SOURCES.txt +8 -0
  19. anemoi_utils-0.2.0/src/anemoi_utils.egg-info/entry_points.txt +2 -0
  20. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi_utils.egg-info/requires.txt +7 -4
  21. anemoi_utils-0.2.0/tests/test_dates.py +113 -0
  22. anemoi_utils-0.1.8/src/anemoi/utils/checkpoints.py +0 -75
  23. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/.github/workflows/python-publish.yml +0 -0
  24. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/.readthedocs.yaml +0 -0
  25. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/LICENSE +0 -0
  26. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/README.md +0 -0
  27. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/Makefile +0 -0
  28. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/_static/logo.png +0 -0
  29. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/_static/style.css +0 -0
  30. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/_templates/.gitkeep +0 -0
  31. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/installing.rst +0 -0
  32. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/modules/checkpoints.rst +0 -0
  33. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/modules/config.rst +0 -0
  34. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/modules/dates.rst +0 -0
  35. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/modules/grib.rst +0 -0
  36. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/modules/humanize.rst +0 -0
  37. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/modules/provenance.rst +0 -0
  38. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/modules/text.rst +0 -0
  39. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/setup.cfg +0 -0
  40. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi/utils/__init__.py +0 -0
  41. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi/utils/caching.py +0 -0
  42. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi/utils/config.py +0 -0
  43. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi/utils/grib.py +0 -0
  44. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi/utils/humanize.py +0 -0
  45. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi/utils/provenance.py +0 -0
  46. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi/utils/text.py +0 -0
  47. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi_utils.egg-info/dependency_links.txt +0 -0
  48. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi_utils.egg-info/top_level.txt +0 -0
  49. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/tests/requirements.txt +0 -0
  50. {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/tests/test_utils.py +0 -0
@@ -185,3 +185,4 @@ _build/
185
185
  ~*
186
186
  *.sync
187
187
  _version.py
188
+ *.code-workspace
@@ -61,6 +61,7 @@ repos:
61
61
  rev: v0.0.14
62
62
  hooks:
63
63
  - id: rstfmt
64
+ exclude: 'cli/.*' # Because we use argparse
64
65
 
65
66
  - repo: https://github.com/b8raoult/pre-commit-docconvert
66
67
  rev: "0.1.4"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: anemoi-utils
3
- Version: 0.1.8
3
+ Version: 0.2.0
4
4
  Summary: A package to hold various functions to support training of ML models on ECMWF data.
5
5
  Author-email: "European Centre for Medium-Range Weather Forecasts (ECMWF)" <software.support@ecmwf.int>
6
6
  License: Apache License
@@ -223,6 +223,8 @@ Classifier: Operating System :: OS Independent
223
223
  Requires-Python: >=3.9
224
224
  License-File: LICENSE
225
225
  Requires-Dist: tomli
226
+ Requires-Dist: pyyaml
227
+ Requires-Dist: tqdm
226
228
  Provides-Extra: provenance
227
229
  Requires-Dist: GitPython; extra == "provenance"
228
230
  Requires-Dist: nvsmi; extra == "provenance"
@@ -234,6 +236,11 @@ Provides-Extra: docs
234
236
  Requires-Dist: tomli; extra == "docs"
235
237
  Requires-Dist: termcolor; extra == "docs"
236
238
  Requires-Dist: requests; extra == "docs"
239
+ Requires-Dist: sphinx; extra == "docs"
240
+ Requires-Dist: sphinx_rtd_theme; extra == "docs"
241
+ Requires-Dist: nbsphinx; extra == "docs"
242
+ Requires-Dist: pandoc; extra == "docs"
243
+ Requires-Dist: sphinx_argparse; extra == "docs"
237
244
  Provides-Extra: all
238
245
  Requires-Dist: tomli; extra == "all"
239
246
  Requires-Dist: GitPython; extra == "all"
@@ -246,7 +253,3 @@ Requires-Dist: GitPython; extra == "dev"
246
253
  Requires-Dist: nvsmi; extra == "dev"
247
254
  Requires-Dist: termcolor; extra == "dev"
248
255
  Requires-Dist: requests; extra == "dev"
249
- Requires-Dist: sphinx; extra == "dev"
250
- Requires-Dist: sphinx_rtd_theme; extra == "dev"
251
- Requires-Dist: nbsphinx; extra == "dev"
252
- Requires-Dist: pandoc; extra == "dev"
@@ -14,14 +14,9 @@ import datetime
14
14
  import os
15
15
  import sys
16
16
 
17
- sys.path.insert(0, os.path.join(os.path.abspath(".."), "src"))
18
-
19
-
20
17
  read_the_docs_build = os.environ.get("READTHEDOCS", None) == "True"
21
18
 
22
- # top = os.path.realpath(os.path.dirname(os.path.dirname(__file__)))
23
- # sys.path.insert(0, top)
24
-
19
+ sys.path.insert(0, os.path.join(os.path.abspath(".."), "src"))
25
20
 
26
21
  source_suffix = ".rst"
27
22
  master_doc = "index"
@@ -32,7 +27,7 @@ html_logo = "_static/logo.png"
32
27
 
33
28
  # -- Project information -----------------------------------------------------
34
29
 
35
- project = "Anemoi"
30
+ project = "Anemoi Utils"
36
31
 
37
32
  author = "ECMWF"
38
33
 
@@ -45,9 +40,9 @@ else:
45
40
  copyright = "%s, ECMWF" % (years,)
46
41
 
47
42
  try:
48
- import anemoi.utils
43
+ from anemoi.utils._version import __version__
49
44
 
50
- release = anemoi.utils.__version__
45
+ release = __version__
51
46
  except ImportError:
52
47
  release = "0.0.0"
53
48
 
@@ -65,6 +60,7 @@ extensions = [
65
60
  "sphinx.ext.intersphinx",
66
61
  "sphinx.ext.autodoc",
67
62
  "sphinx.ext.napoleon",
63
+ "sphinxarg.ext",
68
64
  ]
69
65
 
70
66
  # Add any paths that contain templates here, relative to this directory.
@@ -2,9 +2,9 @@
2
2
 
3
3
  .. _index-page:
4
4
 
5
- ####################################
6
- Welcome to Anemoi's documentation!
7
- ####################################
5
+ ##########################################
6
+ Welcome to `anemoi-utils` documentation!
7
+ ##########################################
8
8
 
9
9
  .. warning::
10
10
 
@@ -2,6 +2,7 @@
2
2
  sphinx
3
3
  sphinx_rtd_theme
4
4
  nbsphinx
5
+ sphinx_argparse
5
6
 
6
7
  # Also requires `brew install pandoc` on Mac
7
8
  pandoc
@@ -40,7 +40,9 @@ classifiers = [
40
40
  ]
41
41
 
42
42
  dependencies = [
43
- "tomli", # Only needed before 3.11
43
+ "tomli", # Only needed before 3.11
44
+ "pyyaml",
45
+ "tqdm",
44
46
  ]
45
47
 
46
48
  [project.optional-dependencies]
@@ -53,9 +55,14 @@ grib = ["requests"]
53
55
  # Loaded by read-the-docs
54
56
  # `pip install .[docs]`
55
57
  docs = [
56
- "tomli", # Only needed before 3.11
58
+ "tomli", # Only needed before 3.11
57
59
  "termcolor",
58
60
  "requests",
61
+ "sphinx",
62
+ "sphinx_rtd_theme",
63
+ "nbsphinx",
64
+ "pandoc",
65
+ "sphinx_argparse",
59
66
  ]
60
67
 
61
68
  all = [
@@ -72,10 +79,6 @@ dev = [
72
79
  "nvsmi",
73
80
  "termcolor",
74
81
  "requests",
75
- "sphinx",
76
- "sphinx_rtd_theme",
77
- "nbsphinx",
78
- "pandoc",
79
82
  ]
80
83
 
81
84
  [project.urls]
@@ -85,6 +88,11 @@ Repository = "https://github.com/ecmwf/anemoi-utils/"
85
88
  Issues = "https://github.com/ecmwf/anemoi-utils/issues"
86
89
  # Changelog = "https://github.com/ecmwf/anemoi-utils/CHANGELOG.md"
87
90
 
91
+ [project.scripts]
92
+ anemoi-utils = "anemoi.utils.__main__:main"
88
93
 
89
94
  [tool.setuptools_scm]
90
95
  version_file = "src/anemoi/utils/_version.py"
96
+
97
+ [tool.setuptools.package-data]
98
+ "anemoi.utils.mars" = ["*.yaml"]
@@ -0,0 +1,72 @@
1
+ #!/usr/bin/env python
2
+ # (C) Copyright 2024 ECMWF.
3
+ #
4
+ # This software is licensed under the terms of the Apache Licence Version 2.0
5
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+ #
10
+
11
+
12
+ import argparse
13
+ import logging
14
+ import sys
15
+ import traceback
16
+
17
+ from . import __version__
18
+ from .commands import COMMANDS
19
+
20
LOG = logging.getLogger(__name__)


def _make_parser():
    """Build the top-level parser, with one sub-parser per entry of ``COMMANDS``."""
    top = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    top.add_argument("--version", "-V", action="store_true", help="show the version and exit")
    top.add_argument("--debug", "-d", action="store_true", help="Debug mode")

    sub = top.add_subparsers(help="commands:", dest="command")
    for cmd_name, handler in COMMANDS.items():
        # Each command object declares its own options on its sub-parser.
        handler.add_arguments(sub.add_parser(cmd_name, help=handler.__doc__))
    return top


def main():
    """Entry point of the ``anemoi-utils`` command line tool.

    Prints the version with ``--version``, the help when no sub-command is
    given, and otherwise runs the selected command. A ``ValueError`` raised
    by the command is reported (traceback plus short message) and the
    process exits with status 1.
    """
    top = _make_parser()
    options = top.parse_args()

    if options.version:
        print(__version__)
        return

    if options.command is None:
        top.print_help()
        return

    handler = COMMANDS[options.command]

    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.DEBUG if options.debug else logging.INFO,
    )

    try:
        handler.run(options)
    except ValueError as exc:
        # Full traceback first, then a short summary for the user.
        traceback.print_exc()
        LOG.error("\n💣 %s", str(exc).lstrip())
        LOG.error("💣 Exiting")
        sys.exit(1)


if __name__ == "__main__":
    main()
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.1.8'
16
- __version_tuple__ = version_tuple = (0, 1, 8)
15
+ __version__ = version = '0.2.0'
16
+ __version_tuple__ = version_tuple = (0, 2, 0)
@@ -0,0 +1,179 @@
1
+ # (C) Copyright 2024 ECMWF.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ # In applying this licence, ECMWF does not waive the privileges and immunities
6
+ # granted to it by virtue of its status as an intergovernmental organisation
7
+ # nor does it submit to any jurisdiction.
8
+
9
+ """Read and write extra metadata in PyTorch checkpoint files. These files
10
+ are zip archives containing the model weights.
11
+ """
12
+
13
+ import json
14
+ import logging
15
+ import os
16
+ import time
17
+ import zipfile
18
+ from tempfile import TemporaryDirectory
19
+
20
+ import tqdm
21
+
22
LOG = logging.getLogger(__name__)

DEFAULT_NAME = "ai-models.json"
DEFAULT_FOLDER = "anemoi-metadata"


def has_metadata(path: str, name: str = DEFAULT_NAME) -> bool:
    """Tell whether a checkpoint archive contains a metadata file.

    The archive is searched for any member whose basename equals ``name``,
    whatever directory it lives in.

    Parameters
    ----------
    path : str
        The path to the checkpoint file (a zip archive).
    name : str, optional
        The basename of the metadata file to look for.

    Returns
    -------
    bool
        True if at least one member has that basename, False otherwise.
    """
    with zipfile.ZipFile(path, "r") as archive:
        return any(os.path.basename(member) == name for member in archive.namelist())
48
+
49
+
50
def load_metadata(path: str, name: str = DEFAULT_NAME):
    """Load metadata from a checkpoint file

    Parameters
    ----------
    path : str
        The path to the checkpoint file (a zip archive)
    name : str, optional
        The basename of the metadata file in the zip archive

    Returns
    -------
    JSON
        The decoded content of the metadata file

    Raises
    ------
    ValueError
        If the metadata file is not found, or if more than one member of
        the archive has the basename ``name``
    """
    with zipfile.ZipFile(path, "r") as archive:
        metadata = None
        for b in archive.namelist():
            if os.path.basename(b) == name:
                if metadata is not None:
                    raise ValueError(f"Found two or more '{name}' in {path}.")
                metadata = b

        if metadata is None:
            raise ValueError(f"Could not find '{name}' in {path}.")

        # Read from the archive that is already open (instead of re-opening
        # it) and close the member stream deterministically.
        with archive.open(metadata, "r") as f:
            return json.load(f)
83
+
84
+
85
def save_metadata(path, metadata, name=DEFAULT_NAME, folder=DEFAULT_FOLDER):
    """Add a metadata file to a checkpoint archive.

    The metadata is serialised as JSON and appended to the archive as
    ``<top-level directory>/<folder>/<name>``.

    Parameters
    ----------
    path : str
        The path to the checkpoint file (a zip archive)
    metadata : JSON
        A JSON serializable object
    name : str, optional
        The basename of the metadata file in the zip archive
    folder : str, optional
        The sub-directory, below the archive's top-level directory, in which
        the metadata file is written

    Raises
    ------
    ValueError
        If a member with basename ``name`` already exists, or if the members
        of the archive do not share exactly one top-level directory
    """
    with zipfile.ZipFile(path, "a") as zipf:
        top_levels = set()

        for entry in zipf.namelist():
            # Climb the parents of the entry until only its top-level
            # directory is left.
            parent = os.path.dirname(entry)
            while os.path.dirname(parent) not in (".", ""):
                parent = os.path.dirname(parent)
            top_levels.add(parent)

            if os.path.basename(entry) == name:
                raise ValueError(f"'{name}' already in {path}")

        if len(top_levels) != 1:
            # PyTorch checkpoints should have a single directory,
            # otherwise PyTorch will complain.
            raise ValueError(f"No or multiple directories in the checkpoint {path}, directories={top_levels}")

        (directory,) = top_levels
        target = f"{directory}/{folder}/{name}"

        LOG.info("Saving metadata to %s/%s/%s", directory, folder, name)

        zipf.writestr(target, json.dumps(metadata))
123
+
124
+
125
def _edit_metadata(path, name, callback):
    """Rewrite a checkpoint archive after letting ``callback`` edit its metadata.

    The archive is extracted into a temporary directory. ``callback`` is then
    called with the full path of every extracted file whose basename is
    ``name``; it may rewrite or delete that file. The archive is rebuilt from
    the directory into a temporary file next to ``path``, which finally
    replaces ``path`` via ``os.replace``.

    Parameters
    ----------
    path : str
        The path to the checkpoint file (a zip archive).
    name : str
        Basename of the metadata file(s) to edit.
    callback : callable
        Called as ``callback(full_path)`` for each matching extracted file.

    Raises
    ------
    ValueError
        If no member of the archive has the basename ``name``.
    """
    new_path = f"{path}.anemoi-edit-{time.time()}-{os.getpid()}.tmp"

    found = False

    with TemporaryDirectory() as temp_dir:
        # Close the source archive as soon as it has been extracted.
        with zipfile.ZipFile(path, "r") as source:
            source.extractall(temp_dir)

        total = 0
        for root, _, files in os.walk(temp_dir):
            for f in files:
                full = os.path.join(root, f)
                if f == name:
                    found = True
                    callback(full)
                # Count only what will be written back: the callback may have
                # deleted the file (as remove_metadata does).
                if os.path.exists(full):
                    total += 1

        if not found:
            raise ValueError(f"Could not find '{name}' in {path}")

        try:
            with zipfile.ZipFile(new_path, "w", zipfile.ZIP_DEFLATED) as zipf:
                with tqdm.tqdm(total=total, desc="Rebuilding checkpoint") as pbar:
                    for root, _, files in os.walk(temp_dir):
                        for f in files:
                            full = os.path.join(root, f)
                            rel = os.path.relpath(full, temp_dir)
                            zipf.write(full, rel)
                            pbar.update(1)

            # os.replace (unlike os.rename) also overwrites an existing file on Windows.
            os.replace(new_path, path)
        finally:
            # Never leave a half-written temporary archive behind on failure;
            # after a successful replace new_path no longer exists.
            if os.path.exists(new_path):
                os.remove(new_path)

    LOG.info("Updated metadata in %s", path)
155
+
156
+
157
def replace_metadata(path, metadata, name=DEFAULT_NAME):
    """Overwrite the metadata file stored in a checkpoint.

    Parameters
    ----------
    path : str
        The path to the checkpoint file (a zip archive).
    metadata : dict
        The new metadata; must contain a ``"version"`` key.
    name : str, optional
        Basename of the metadata file in the archive.

    Raises
    ------
    ValueError
        If ``metadata`` is not a dict, has no ``"version"`` key, or if the
        archive contains no file named ``name``.
    """
    if not isinstance(metadata, dict):
        raise ValueError(f"metadata must be a dict, got {type(metadata)}")

    if "version" not in metadata:
        raise ValueError("metadata must have a 'version' key")

    def _write_new_content(target):
        with open(target, "w") as out:
            json.dump(metadata, out)

    _edit_metadata(path, name, _write_new_content)
170
+
171
+
172
def remove_metadata(path, name=DEFAULT_NAME):
    """Delete the metadata file(s) named ``name`` from a checkpoint.

    Raises
    ------
    ValueError
        If the archive contains no file named ``name``.
    """
    LOG.info("Removing metadata '%s' from %s", name, path)

    # Every extracted file with a matching basename is simply deleted
    # before the archive is rebuilt.
    _edit_metadata(path, name, os.remove)
@@ -0,0 +1,78 @@
1
+ #!/usr/bin/env python
2
+ # (C) Copyright 2024 ECMWF.
3
+ #
4
+ # This software is licensed under the terms of the Apache Licence Version 2.0
5
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+ #
10
+
11
+ import argparse
12
+ import importlib
13
+ import logging
14
+ import os
15
+ import sys
16
+
17
LOG = logging.getLogger(__name__)


def register(here, package, select, fail=None):
    """Import the sub-modules of a package and collect one object from each.

    Every ``*.py`` file and every sub-package (a directory holding an
    ``__init__.py``) found in ``here`` is imported, except names starting
    with ``_``.

    Parameters
    ----------
    here : str
        Directory of the package to scan (e.g. ``os.path.dirname(__file__)``).
    package : str
        Dotted name of that package, used to resolve the relative imports.
    select : callable
        Called with each imported module; a non-``None`` return value is
        stored under the module name.
    fail : callable, optional
        Called as ``fail(name, error)`` for each module whose import raised
        ``ImportError``; the return value is stored under the module name.
        When ``fail`` is not callable (e.g. ``None``), such modules are skipped.

    Returns
    -------
    dict
        Module name -> selected object. Successfully imported modules come
        first in alphabetical order, followed by the failed ones.
    """
    result = {}
    not_available = {}

    # Sort the listing so that the registry order (and therefore e.g. the
    # order of sub-commands in ``--help``) does not depend on the file system.
    for p in sorted(os.listdir(here)):
        full = os.path.join(here, p)
        if p.startswith("_"):
            continue

        is_module = p.endswith(".py")
        is_package = os.path.isdir(full) and os.path.exists(os.path.join(full, "__init__.py"))
        if not (is_module or is_package):
            continue

        name, _ = os.path.splitext(p)

        try:
            imported = importlib.import_module(f".{name}", package=package)
        except ImportError as e:
            not_available[name] = e
            continue

        obj = select(imported)
        if obj is not None:
            result[name] = obj

    if callable(fail):
        for name, e in not_available.items():
            result[name] = fail(name, e)

    return result
53
+
54
+
55
class Command:
    """Base class for ``anemoi-utils`` sub-commands.

    Sub-classes override :meth:`add_arguments` to declare their options and
    :meth:`run` to do the work.
    """

    def add_arguments(self, command_parser):
        """Declare the command's arguments on ``command_parser``.

        The default declares none, so commands without options need not
        override it; the CLI entry point calls it for every registered command.
        """

    def run(self, args):
        """Execute the command with the parsed command-line ``args``."""
        raise NotImplementedError(f"Command not implemented: {args.command}")
58
+
59
+
60
class Failed(Command):
    """Stand-in for a command whose module could not be imported.

    It accepts any arguments and, when run, prints the import error and
    exits with status 1.
    """

    def __init__(self, name, error):
        self.name, self.error = name, error

    def add_arguments(self, command_parser):
        # Swallow whatever follows so that argument parsing does not fail
        # before the user is told why the command is unavailable.
        command_parser.add_argument("x", nargs=argparse.REMAINDER)

    def run(self, args):
        message = f"Command '{self.name}' not available: {self.error}"
        print(message)
        sys.exit(1)
71
+
72
+
73
# Registry of sub-commands: every module of this package that defines a
# ``command`` class contributes an instance of it under the module's name.
# Modules that fail to import are replaced by a ``Failed`` placeholder.
COMMANDS = register(
    here=os.path.dirname(__file__),
    package=__name__,
    select=lambda module: module.command(),
    fail=Failed,
)
@@ -0,0 +1,61 @@
1
+ #!/usr/bin/env python
2
+ # (C) Copyright 2024 ECMWF.
3
+ #
4
+ # This software is licensed under the terms of the Apache Licence Version 2.0
5
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+ #
10
+
11
+
12
+ import json
13
+
14
+ from . import Command
15
+
16
+
17
def visit(x, path, name, value):
    """Recursively search a JSON-like structure and print matching entries.

    A dictionary entry is printed as ``<dotted path> <key> <value>`` when its
    key equals ``name`` or its value equals ``value``. Each matching entry is
    printed once, even when both conditions hold. List items are descended
    into with their index as the path component.

    Parameters
    ----------
    x : object
        The structure to search; dicts and lists are traversed, anything
        else is a leaf.
    path : list of str
        Path components leading to ``x``; used as a stack and left unchanged
        when the call returns.
    name : object
        Key to look for (pass a fresh ``object()`` to match no key).
    value : object
        Value to look for (pass a fresh ``object()`` to match no value).
    """
    if isinstance(x, dict):
        for k, v in x.items():
            # A single test, so an entry matching both name and value is not
            # reported twice.
            if k == name or v == value:
                print(".".join(path), k, v)

            path.append(k)
            visit(v, path, name, value)
            path.pop()

    elif isinstance(x, list):
        for i, v in enumerate(x):
            path.append(str(i))
            visit(v, path, name, value)
            path.pop()
35
+
36
+
37
class Checkpoint(Command):
    """Print the metadata stored in a checkpoint, or search it by key or value."""

    def add_arguments(self, command_parser):
        command_parser.add_argument("path", help="Path to the checkpoint.")
        command_parser.add_argument("--name", help="Search for a specific name.")
        command_parser.add_argument("--value", help="Search for a specific value.")

    def run(self, args):
        from anemoi.utils.checkpoints import load_metadata

        # load_metadata compares member basenames literally, so a glob such as
        # "*.json" can never match; use its default metadata file name.
        checkpoint = load_metadata(args.path)

        if args.name or args.value:
            # A fresh object() never compares equal to anything, which
            # disables the criterion that was not given.
            visit(
                checkpoint,
                [],
                args.name if args.name is not None else object(),
                args.value if args.value is not None else object(),
            )
            return

        print(json.dumps(checkpoint, sort_keys=True, indent=4))


command = Checkpoint
@@ -10,6 +10,16 @@ import calendar
10
10
  import datetime
11
11
 
12
12
 
13
def normalise_frequency(frequency):
    """Convert a frequency to a number of hours.

    Parameters
    ----------
    frequency : int or str
        Either an integer number of hours, or a string made of an integer
        followed by a case-insensitive unit, ``"h"`` (hours) or ``"d"`` (days),
        e.g. ``"6h"`` or ``"1d"``. A string of digits only (e.g. ``"12"``) is
        taken as a number of hours.

    Returns
    -------
    int
        The frequency in hours.

    Raises
    ------
    TypeError
        If ``frequency`` is neither an int nor a str.
    ValueError
        If the string cannot be parsed or uses an unknown unit.
    """
    if isinstance(frequency, int):
        return frequency

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(frequency, str):
        raise TypeError(f"frequency must be an int or a str, got {type(frequency)}: {frequency!r}")

    text = frequency.strip()
    if text.isdecimal():
        return int(text)

    multipliers = {"h": 1, "d": 24}
    unit = text[-1:].lower()
    if unit not in multipliers:
        raise ValueError(f"Invalid frequency {frequency!r}: unit must be 'h' or 'd'")

    try:
        count = int(text[:-1])
    except ValueError:
        raise ValueError(f"Invalid frequency {frequency!r}") from None

    return count * multipliers[unit]
21
+
22
+
13
23
  def no_time_zone(date):
14
24
  """Remove time zone information from a date.
15
25
 
@@ -27,6 +37,7 @@ def no_time_zone(date):
27
37
  return date.replace(tzinfo=None)
28
38
 
29
39
 
40
+ # This function is used in anemoi-datasets.
30
41
  def as_datetime(date):
31
42
  """Convert a date to a datetime object, removing any time zone information.
32
43
 
@@ -162,11 +173,15 @@ class HindcastDatesTimes:
162
173
  """
163
174
 
164
175
  self.reference_dates = reference_dates
165
- self.years = (1, years + 1)
176
+
177
+ if isinstance(years, list):
178
+ self.years = years
179
+ else:
180
+ self.years = range(1, years + 1)
166
181
 
167
182
  def __iter__(self):
168
183
  for reference_date in self.reference_dates:
169
- for year in range(*self.years):
184
+ for year in self.years:
170
185
  if reference_date.month == 2 and reference_date.day == 29:
171
186
  date = datetime.datetime(reference_date.year - year, 2, 28)
172
187
  else:
@@ -246,3 +261,61 @@ class Autumn(DateTimes):
246
261
  _description_
247
262
  """
248
263
  super().__init__(datetime.datetime(year, 9, 1), datetime.datetime(year, 11, 30), **kwargs)
264
+
265
+
266
class ConcatDateTimes:
    """Chain several iterables of dates into a single iterable.

    The iterables may be passed as separate arguments or as one list.
    """

    def __init__(self, *dates):
        given_as_list = len(dates) == 1 and isinstance(dates[0], list)
        self.dates = dates[0] if given_as_list else dates

    def __iter__(self):
        for group in self.dates:
            for item in group:
                yield item
276
+
277
+
278
class EnumDateTimes:
    """Iterate over an explicit collection of dates, each passed through ``as_datetime``."""

    def __init__(self, dates):
        self.dates = dates

    def __iter__(self):
        yield from map(as_datetime, self.dates)
285
+
286
+
287
def datetimes_factory(*args, **kwargs):
    """Build an iterable of dates from a (typically YAML-loaded) configuration.

    Accepted forms:

    * keyword arguments with ``name="hindcast"``: a ``HindcastDatesTimes``
      built from ``reference_dates`` (itself any accepted configuration)
      and ``years``;
    * other keyword arguments: a ``DateTimes``, where an optional
      ``frequency`` is converted with ``normalise_frequency`` and passed
      as ``increment``;
    * positional arguments none of which is a dict or a list: an
      ``EnumDateTimes`` over them;
    * a single dict or list: unpacked and processed recursively;
    * several positional arguments: a ``ConcatDateTimes`` of each one
      processed recursively.

    Raises
    ------
    ValueError
        If both positional and keyword arguments are given, or neither.
    """
    if args and kwargs:
        raise ValueError("Cannot provide both args and kwargs for a list of dates")

    if not (args or kwargs):
        raise ValueError("No dates provided")

    if kwargs:
        if kwargs.get("name") == "hindcast":
            return HindcastDatesTimes(
                reference_dates=datetimes_factory(kwargs["reference_dates"]),
                years=kwargs["years"],
            )

        options = dict(kwargs)
        if "frequency" in options:
            options["increment"] = normalise_frequency(options.pop("frequency"))
        return DateTimes(**options)

    if all(not isinstance(item, (dict, list)) for item in args):
        return EnumDateTimes(args)

    if len(args) == 1:
        (only,) = args
        # Here `only` is necessarily a dict or a list.
        if isinstance(only, dict):
            return datetimes_factory(**only)
        return datetimes_factory(*only)

    return ConcatDateTimes(*[datetimes_factory(item) for item in args])
@@ -0,0 +1,76 @@
1
+ # (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
2
+ # This software is licensed under the terms of the Apache Licence Version 2.0
3
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
4
+ # In applying this licence, ECMWF does not waive the privileges and immunities
5
+ # granted to it by virtue of its status as an intergovernmental organisation
6
+ # nor does it submit to any jurisdiction.
7
+
8
+
9
+ """Utilities for working with Mars requests.
10
+
11
+ Has some knowledge of how certain streams are organised in Mars.
12
+
13
+ """
14
+
15
+ import datetime
16
+ import logging
17
+ import os
18
+
19
+ import yaml
20
+
21
LOG = logging.getLogger(__name__)

DEFAULT_MARS_LABELLING = {
    "class": "od",
    "type": "an",
    "stream": "oper",
    "expver": "0001",
}


def _expand_mars_labelling(request):
    """Return a copy of ``request`` completed with the default Mars labelling.

    Keys present in ``request`` take precedence over the defaults, which are::

        {'class': 'od', 'type': 'an', 'stream': 'oper', 'expver': '0001'}

    Neither ``request`` nor ``DEFAULT_MARS_LABELLING`` is modified.
    """
    expanded = dict(DEFAULT_MARS_LABELLING)
    expanded.update(request)
    return expanded
45
+
46
+
47
STREAMS = None


def _lookup_mars_stream(request):
    """Return the stream information matching a Mars request.

    The request is first completed with the default labelling, then compared
    with the ``match`` section of each entry of ``mars.yaml``. That file is
    loaded on first use and cached in the module-level ``STREAMS``. The
    ``info`` of the first entry whose ``match`` items all equal the request's
    values is returned.

    Returns
    -------
    dict or None
        The matching ``info`` section, or ``None`` if no entry matches.
    """
    global STREAMS

    if STREAMS is None:
        # Explicit encoding: do not depend on the platform's locale default.
        with open(os.path.join(os.path.dirname(__file__), "mars.yaml"), encoding="utf-8") as f:
            # An empty YAML document loads as None; treat it as "no streams".
            STREAMS = yaml.safe_load(f) or []

    request = _expand_mars_labelling(request)
    for s in STREAMS:
        if all(request.get(k) == v for k, v in s["match"].items()):
            return s["info"]

    return None
63
+
64
+
65
def recenter(date, center, members):
    """Look up the stream information of a centre request and a members request.

    ``date`` is currently unused. Returns the pair
    ``(info for center, info for members)``, each as given by
    ``_lookup_mars_stream`` (``None`` when no stream matches).
    """
    return _lookup_mars_stream(center), _lookup_mars_stream(members)
71
+
72
+
73
if __name__ == "__main__":
    # Manual smoke test: look up an analysis request and an ELDA request.
    date = datetime.datetime(2024, 5, 9, 0)
    center_request = {"type": "an"}
    members_request = {"stream": "elda"}
    print(recenter(date, center_request, members_request))
@@ -0,0 +1,5 @@
1
+ - match:
2
+ class: od
3
+ stream: elda
4
+ info:
5
+ runs: [6, 18]
@@ -0,0 +1,32 @@
1
+ # (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
2
+ # This software is licensed under the terms of the Apache Licence Version 2.0
3
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
4
+ # In applying this licence, ECMWF does not waive the privileges and immunities
5
+ # granted to it by virtue of its status as an intergovernmental organisation
6
+ # nor does it submit to any jurisdiction.
7
+
8
+ """Logging utilities."""
9
+
10
+ import logging
11
+ import time
12
+
13
+ from .humanize import seconds
14
+
15
LOGGER = logging.getLogger(__name__)


class Timer:
    """Context manager that logs how long its body took.

    Example::

        with Timer("Loading data"):
            load()

    On exit, an INFO message ``"<title>: <duration>."`` is logged, also when
    the body raises (the exception is not suppressed).

    Parameters
    ----------
    title : str
        Text placed before the duration in the log message.
    logger : logging.Logger, optional
        Logger to use; defaults to this module's logger.
    """

    def __init__(self, title, logger=LOGGER):
        self.title = title
        self.logger = logger
        # Wall-clock start time, kept for backward compatibility. Durations
        # are measured with the monotonic perf_counter, so clock adjustments
        # cannot produce wrong (even negative) elapsed times.
        self.start = time.time()
        self._t0 = time.perf_counter()

    def __enter__(self):
        return self

    @property
    def elapsed(self):
        """Seconds elapsed since the timer was created."""
        return time.perf_counter() - self._t0

    def __exit__(self, *args):
        self.logger.info("%s: %s.", self.title, seconds(self.elapsed))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: anemoi-utils
3
- Version: 0.1.8
3
+ Version: 0.2.0
4
4
  Summary: A package to hold various functions to support training of ML models on ECMWF data.
5
5
  Author-email: "European Centre for Medium-Range Weather Forecasts (ECMWF)" <software.support@ecmwf.int>
6
6
  License: Apache License
@@ -223,6 +223,8 @@ Classifier: Operating System :: OS Independent
223
223
  Requires-Python: >=3.9
224
224
  License-File: LICENSE
225
225
  Requires-Dist: tomli
226
+ Requires-Dist: pyyaml
227
+ Requires-Dist: tqdm
226
228
  Provides-Extra: provenance
227
229
  Requires-Dist: GitPython; extra == "provenance"
228
230
  Requires-Dist: nvsmi; extra == "provenance"
@@ -234,6 +236,11 @@ Provides-Extra: docs
234
236
  Requires-Dist: tomli; extra == "docs"
235
237
  Requires-Dist: termcolor; extra == "docs"
236
238
  Requires-Dist: requests; extra == "docs"
239
+ Requires-Dist: sphinx; extra == "docs"
240
+ Requires-Dist: sphinx_rtd_theme; extra == "docs"
241
+ Requires-Dist: nbsphinx; extra == "docs"
242
+ Requires-Dist: pandoc; extra == "docs"
243
+ Requires-Dist: sphinx_argparse; extra == "docs"
237
244
  Provides-Extra: all
238
245
  Requires-Dist: tomli; extra == "all"
239
246
  Requires-Dist: GitPython; extra == "all"
@@ -246,7 +253,3 @@ Requires-Dist: GitPython; extra == "dev"
246
253
  Requires-Dist: nvsmi; extra == "dev"
247
254
  Requires-Dist: termcolor; extra == "dev"
248
255
  Requires-Dist: requests; extra == "dev"
249
- Requires-Dist: sphinx; extra == "dev"
250
- Requires-Dist: sphinx_rtd_theme; extra == "dev"
251
- Requires-Dist: nbsphinx; extra == "dev"
252
- Requires-Dist: pandoc; extra == "dev"
@@ -21,6 +21,7 @@ docs/modules/humanize.rst
21
21
  docs/modules/provenance.rst
22
22
  docs/modules/text.rst
23
23
  src/anemoi/utils/__init__.py
24
+ src/anemoi/utils/__main__.py
24
25
  src/anemoi/utils/_version.py
25
26
  src/anemoi/utils/caching.py
26
27
  src/anemoi/utils/checkpoints.py
@@ -30,10 +31,17 @@ src/anemoi/utils/grib.py
30
31
  src/anemoi/utils/humanize.py
31
32
  src/anemoi/utils/provenance.py
32
33
  src/anemoi/utils/text.py
34
+ src/anemoi/utils/timer.py
35
+ src/anemoi/utils/commands/__init__.py
36
+ src/anemoi/utils/commands/checkpoint.py
37
+ src/anemoi/utils/mars/__init__.py
38
+ src/anemoi/utils/mars/mars.yaml
33
39
  src/anemoi_utils.egg-info/PKG-INFO
34
40
  src/anemoi_utils.egg-info/SOURCES.txt
35
41
  src/anemoi_utils.egg-info/dependency_links.txt
42
+ src/anemoi_utils.egg-info/entry_points.txt
36
43
  src/anemoi_utils.egg-info/requires.txt
37
44
  src/anemoi_utils.egg-info/top_level.txt
38
45
  tests/requirements.txt
46
+ tests/test_dates.py
39
47
  tests/test_utils.py
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ anemoi-utils = anemoi.utils.__main__:main
@@ -1,4 +1,6 @@
1
1
  tomli
2
+ pyyaml
3
+ tqdm
2
4
 
3
5
  [all]
4
6
  tomli
@@ -13,15 +15,16 @@ GitPython
13
15
  nvsmi
14
16
  termcolor
15
17
  requests
16
- sphinx
17
- sphinx_rtd_theme
18
- nbsphinx
19
- pandoc
20
18
 
21
19
  [docs]
22
20
  tomli
23
21
  termcolor
24
22
  requests
23
+ sphinx
24
+ sphinx_rtd_theme
25
+ nbsphinx
26
+ pandoc
27
+ sphinx_argparse
25
28
 
26
29
  [grib]
27
30
  requests
@@ -0,0 +1,113 @@
1
+ import datetime
2
+ from textwrap import dedent
3
+
4
+ import yaml
5
+
6
+ from anemoi.utils.dates import datetimes_factory
7
+
8
+
9
def _(txt):
    """Parse an indented YAML snippet and build its dates with ``datetimes_factory``."""
    return datetimes_factory(yaml.safe_load(dedent(txt)))
13
+
14
+
15
def test_date_1():
    """A plain list of dates is enumerated as-is."""
    d = _(
        """
        - 2023-01-01
        - 2023-01-02
        - 2023-01-03
        """
    )
    assert len(list(d)) == 3


def test_date_2():
    """A start/end range with an integer frequency, filtered by day of week."""
    d = _(
        """
        start: 2023-01-01
        end: 2023-01-07
        frequency: 12
        day_of_week: [monday, friday]
        """
    )
    assert len(list(d)) == 4


def test_date_3():
    """A list of ranges is concatenated; a frequency may carry an 'h' unit."""
    d = _(
        """
        - start: 2023-01-01
          end: 2023-01-03
          frequency: 24
        - start: 2024-01-01T06:00:00
          end: 2024-01-02T18:00:00
          frequency: 6h
        """
    )
    assert datetime.datetime(2023, 1, 1, 0) in d
    assert datetime.datetime(2023, 1, 2, 0) in d
    assert datetime.datetime(2023, 1, 3, 0) in d
    assert datetime.datetime(2024, 1, 1, 6) in d
    assert datetime.datetime(2024, 1, 1, 12) in d
    assert datetime.datetime(2024, 1, 1, 18) in d
    assert datetime.datetime(2024, 1, 2, 0) in d
    assert datetime.datetime(2024, 1, 2, 6) in d
    assert datetime.datetime(2024, 1, 2, 12) in d
    assert datetime.datetime(2024, 1, 2, 18) in d
    assert len(list(d)) == 10


def test_date_hindcast_1():
    """Hindcast with an integer number of years: 3 reference dates x 20 years."""
    d = _(
        """
        - name: hindcast
          reference_dates:
            start: 2023-01-01
            end: 2023-01-03
            frequency: 24
          years: 20
        """
    )
    assert len(list(d)) == 60


def test_date_hindcast_2():
    """Hindcast with an explicit list of years: 3 reference dates x 4 years."""
    d = _(
        """
        - name: hindcast
          reference_dates:
            start: 2023-01-01
            end: 2023-01-03
            frequency: 24
          years: [2018, 2019, 2020, 2021]
        """
    )
    assert len(list(d)) == 12


def test_date_hindcast_3():
    """Hindcast whose reference dates are filtered by day of week."""
    d = _(
        """
        - name: hindcast
          reference_dates:
            start: 2022-12-25 00:00:00
            end: 2022-12-31 12:00:00
            frequency: 12h
            day_of_week: tuesday
          years: [2018, 2019, 2020, 2021]
        """
    )
    # The leftover debug print has been removed; iterate once.
    dates = list(d)
    assert len(dates) == 8
104
+
105
+
106
if __name__ == "__main__":
    # Allow running the tests without pytest: call every module-level
    # function whose name starts with "test_", in definition order.
    function_type = type(lambda: 0)
    for test_name, test_func in list(globals().items()):
        if not test_name.startswith("test_") or not isinstance(test_func, function_type):
            continue
        print(f"Running test: {test_func.__name__}")
        test_func()
    print("All tests passed!")
@@ -1,75 +0,0 @@
1
- # (C) Copyright 2024 ECMWF.
2
- #
3
- # This software is licensed under the terms of the Apache Licence Version 2.0
4
- # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
- # In applying this licence, ECMWF does not waive the privileges and immunities
6
- # granted to it by virtue of its status as an intergovernmental organisation
7
- # nor does it submit to any jurisdiction.
8
-
9
- """Read and write extra metadata in PyTorch checkpoints files. These files
10
- are zip archives containing the model weights.
11
- """
12
-
13
- import json
14
- import logging
15
- import os
16
- import zipfile
17
-
18
- LOG = logging.getLogger(__name__)
19
-
20
- DEFAULT_NAME = "anemoi-metadata.json"
21
-
22
-
23
- def load_metadata(path: str, name: str = DEFAULT_NAME):
24
- """Load metadata from a checkpoint file
25
-
26
- Parameters
27
- ----------
28
- path : str
29
- The path to the checkpoint file
30
- name : str, optional
31
- The name of the metadata file in the zip archive
32
-
33
- Returns
34
- -------
35
- JSON
36
- The content of the metadata file
37
-
38
- Raises
39
- ------
40
- ValueError
41
- If the metadata file is not found
42
- """
43
- with zipfile.ZipFile(path, "r") as f:
44
- metadata = None
45
- for b in f.namelist():
46
- if os.path.basename(b) == name:
47
- if metadata is not None:
48
- LOG.warning(f"Found two '{name}' if {path}")
49
- metadata = b
50
-
51
- if metadata is not None:
52
- with zipfile.ZipFile(path, "r") as f:
53
- return json.load(f.open(metadata, "r"))
54
- else:
55
- raise ValueError(f"Could not find {name} in {path}")
56
-
57
-
58
- def save_metadata(path, metadata, name=DEFAULT_NAME):
59
- """Save metadata to a checkpoint file
60
-
61
- Parameters
62
- ----------
63
- path : str
64
- The path to the checkpoint file
65
- metadata : JSON
66
- A JSON serializable object
67
- name : str, optional
68
- The name of the metadata file in the zip archive
69
- """
70
- with zipfile.ZipFile(path, "a") as zipf:
71
- base, _ = os.path.splitext(os.path.basename(path))
72
- zipf.writestr(
73
- f"{base}/{name}",
74
- json.dumps(metadata),
75
- )
File without changes
File without changes
File without changes
File without changes