anemoi-utils 0.1.8__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of anemoi-utils might be problematic. Click here for more details.
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/.gitignore +1 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/.pre-commit-config.yaml +1 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/PKG-INFO +8 -5
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/conf.py +5 -9
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/index.rst +3 -3
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/requirements.txt +1 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/pyproject.toml +14 -6
- anemoi_utils-0.2.0/src/anemoi/utils/__main__.py +72 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi/utils/_version.py +2 -2
- anemoi_utils-0.2.0/src/anemoi/utils/checkpoints.py +179 -0
- anemoi_utils-0.2.0/src/anemoi/utils/commands/__init__.py +78 -0
- anemoi_utils-0.2.0/src/anemoi/utils/commands/checkpoint.py +61 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi/utils/dates.py +75 -2
- anemoi_utils-0.2.0/src/anemoi/utils/mars/__init__.py +76 -0
- anemoi_utils-0.2.0/src/anemoi/utils/mars/mars.yaml +5 -0
- anemoi_utils-0.2.0/src/anemoi/utils/timer.py +32 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi_utils.egg-info/PKG-INFO +8 -5
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi_utils.egg-info/SOURCES.txt +8 -0
- anemoi_utils-0.2.0/src/anemoi_utils.egg-info/entry_points.txt +2 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi_utils.egg-info/requires.txt +7 -4
- anemoi_utils-0.2.0/tests/test_dates.py +113 -0
- anemoi_utils-0.1.8/src/anemoi/utils/checkpoints.py +0 -75
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/.github/workflows/python-publish.yml +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/.readthedocs.yaml +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/LICENSE +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/README.md +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/Makefile +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/_static/logo.png +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/_static/style.css +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/_templates/.gitkeep +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/installing.rst +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/modules/checkpoints.rst +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/modules/config.rst +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/modules/dates.rst +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/modules/grib.rst +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/modules/humanize.rst +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/modules/provenance.rst +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/docs/modules/text.rst +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/setup.cfg +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi/utils/__init__.py +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi/utils/caching.py +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi/utils/config.py +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi/utils/grib.py +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi/utils/humanize.py +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi/utils/provenance.py +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi/utils/text.py +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi_utils.egg-info/dependency_links.txt +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/src/anemoi_utils.egg-info/top_level.txt +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/tests/requirements.txt +0 -0
- {anemoi_utils-0.1.8 → anemoi_utils-0.2.0}/tests/test_utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: anemoi-utils
|
|
3
|
-
Version: 0.1.8
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: A package to hold various functions to support training of ML models on ECMWF data.
|
|
5
5
|
Author-email: "European Centre for Medium-Range Weather Forecasts (ECMWF)" <software.support@ecmwf.int>
|
|
6
6
|
License: Apache License
|
|
@@ -223,6 +223,8 @@ Classifier: Operating System :: OS Independent
|
|
|
223
223
|
Requires-Python: >=3.9
|
|
224
224
|
License-File: LICENSE
|
|
225
225
|
Requires-Dist: tomli
|
|
226
|
+
Requires-Dist: pyyaml
|
|
227
|
+
Requires-Dist: tqdm
|
|
226
228
|
Provides-Extra: provenance
|
|
227
229
|
Requires-Dist: GitPython; extra == "provenance"
|
|
228
230
|
Requires-Dist: nvsmi; extra == "provenance"
|
|
@@ -234,6 +236,11 @@ Provides-Extra: docs
|
|
|
234
236
|
Requires-Dist: tomli; extra == "docs"
|
|
235
237
|
Requires-Dist: termcolor; extra == "docs"
|
|
236
238
|
Requires-Dist: requests; extra == "docs"
|
|
239
|
+
Requires-Dist: sphinx; extra == "docs"
|
|
240
|
+
Requires-Dist: sphinx_rtd_theme; extra == "docs"
|
|
241
|
+
Requires-Dist: nbsphinx; extra == "docs"
|
|
242
|
+
Requires-Dist: pandoc; extra == "docs"
|
|
243
|
+
Requires-Dist: sphinx_argparse; extra == "docs"
|
|
237
244
|
Provides-Extra: all
|
|
238
245
|
Requires-Dist: tomli; extra == "all"
|
|
239
246
|
Requires-Dist: GitPython; extra == "all"
|
|
@@ -246,7 +253,3 @@ Requires-Dist: GitPython; extra == "dev"
|
|
|
246
253
|
Requires-Dist: nvsmi; extra == "dev"
|
|
247
254
|
Requires-Dist: termcolor; extra == "dev"
|
|
248
255
|
Requires-Dist: requests; extra == "dev"
|
|
249
|
-
Requires-Dist: sphinx; extra == "dev"
|
|
250
|
-
Requires-Dist: sphinx_rtd_theme; extra == "dev"
|
|
251
|
-
Requires-Dist: nbsphinx; extra == "dev"
|
|
252
|
-
Requires-Dist: pandoc; extra == "dev"
|
|
@@ -14,14 +14,9 @@ import datetime
|
|
|
14
14
|
import os
|
|
15
15
|
import sys
|
|
16
16
|
|
|
17
|
-
sys.path.insert(0, os.path.join(os.path.abspath(".."), "src"))
|
|
18
|
-
|
|
19
|
-
|
|
20
17
|
read_the_docs_build = os.environ.get("READTHEDOCS", None) == "True"
|
|
21
18
|
|
|
22
|
-
|
|
23
|
-
# sys.path.insert(0, top)
|
|
24
|
-
|
|
19
|
+
sys.path.insert(0, os.path.join(os.path.abspath(".."), "src"))
|
|
25
20
|
|
|
26
21
|
source_suffix = ".rst"
|
|
27
22
|
master_doc = "index"
|
|
@@ -32,7 +27,7 @@ html_logo = "_static/logo.png"
|
|
|
32
27
|
|
|
33
28
|
# -- Project information -----------------------------------------------------
|
|
34
29
|
|
|
35
|
-
project = "Anemoi"
|
|
30
|
+
project = "Anemoi Utils"
|
|
36
31
|
|
|
37
32
|
author = "ECMWF"
|
|
38
33
|
|
|
@@ -45,9 +40,9 @@ else:
|
|
|
45
40
|
copyright = "%s, ECMWF" % (years,)
|
|
46
41
|
|
|
47
42
|
try:
|
|
48
|
-
|
|
43
|
+
from anemoi.utils._version import __version__
|
|
49
44
|
|
|
50
|
-
release =
|
|
45
|
+
release = __version__
|
|
51
46
|
except ImportError:
|
|
52
47
|
release = "0.0.0"
|
|
53
48
|
|
|
@@ -65,6 +60,7 @@ extensions = [
|
|
|
65
60
|
"sphinx.ext.intersphinx",
|
|
66
61
|
"sphinx.ext.autodoc",
|
|
67
62
|
"sphinx.ext.napoleon",
|
|
63
|
+
"sphinxarg.ext",
|
|
68
64
|
]
|
|
69
65
|
|
|
70
66
|
# Add any paths that contain templates here, relative to this directory.
|
|
@@ -2,9 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
.. _index-page:
|
|
4
4
|
|
|
5
|
-
|
|
6
|
-
Welcome to
|
|
7
|
-
|
|
5
|
+
##########################################
|
|
6
|
+
Welcome to `anemoi-utils` documentation!
|
|
7
|
+
##########################################
|
|
8
8
|
|
|
9
9
|
.. warning::
|
|
10
10
|
|
|
@@ -40,7 +40,9 @@ classifiers = [
|
|
|
40
40
|
]
|
|
41
41
|
|
|
42
42
|
dependencies = [
|
|
43
|
-
"tomli",
|
|
43
|
+
"tomli", # Only needed before 3.11
|
|
44
|
+
"pyyaml",
|
|
45
|
+
"tqdm",
|
|
44
46
|
]
|
|
45
47
|
|
|
46
48
|
[project.optional-dependencies]
|
|
@@ -53,9 +55,14 @@ grib = ["requests"]
|
|
|
53
55
|
# Loaded by read-the-docs
|
|
54
56
|
# `pip install .[docs]`
|
|
55
57
|
docs = [
|
|
56
|
-
"tomli",
|
|
58
|
+
"tomli", # Only needed before 3.11
|
|
57
59
|
"termcolor",
|
|
58
60
|
"requests",
|
|
61
|
+
"sphinx",
|
|
62
|
+
"sphinx_rtd_theme",
|
|
63
|
+
"nbsphinx",
|
|
64
|
+
"pandoc",
|
|
65
|
+
"sphinx_argparse",
|
|
59
66
|
]
|
|
60
67
|
|
|
61
68
|
all = [
|
|
@@ -72,10 +79,6 @@ dev = [
|
|
|
72
79
|
"nvsmi",
|
|
73
80
|
"termcolor",
|
|
74
81
|
"requests",
|
|
75
|
-
"sphinx",
|
|
76
|
-
"sphinx_rtd_theme",
|
|
77
|
-
"nbsphinx",
|
|
78
|
-
"pandoc",
|
|
79
82
|
]
|
|
80
83
|
|
|
81
84
|
[project.urls]
|
|
@@ -85,6 +88,11 @@ Repository = "https://github.com/ecmwf/anemoi-utils/"
|
|
|
85
88
|
Issues = "https://github.com/ecmwf/anemoi-utils/issues"
|
|
86
89
|
# Changelog = "https://github.com/ecmwf/anemoi-utils/CHANGELOG.md"
|
|
87
90
|
|
|
91
|
+
[project.scripts]
|
|
92
|
+
anemoi-utils = "anemoi.utils.__main__:main"
|
|
88
93
|
|
|
89
94
|
[tool.setuptools_scm]
|
|
90
95
|
version_file = "src/anemoi/utils/_version.py"
|
|
96
|
+
|
|
97
|
+
[tool.setuptools.package-data]
|
|
98
|
+
"anemoi.utils.mars" = ["*.yaml"]
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
# (C) Copyright 2024 ECMWF.
|
|
3
|
+
#
|
|
4
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
5
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
6
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
7
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
8
|
+
# nor does it submit to any jurisdiction.
|
|
9
|
+
#
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
import argparse
|
|
13
|
+
import logging
|
|
14
|
+
import sys
|
|
15
|
+
import traceback
|
|
16
|
+
|
|
17
|
+
from . import __version__
|
|
18
|
+
from .commands import COMMANDS
|
|
19
|
+
|
|
20
|
+
LOG = logging.getLogger(__name__)


def main():
    """Entry point of the ``anemoi-utils`` command-line tool.

    Builds one sub-parser per entry of ``COMMANDS``, parses the command line
    and runs the selected sub-command. ``--version`` prints the package
    version; with no sub-command the help text is printed. A ValueError
    raised by the sub-command is reported and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("--version", "-V", action="store_true", help="show the version and exit")
    parser.add_argument("--debug", "-d", action="store_true", help="Debug mode")

    subparsers = parser.add_subparsers(help="commands:", dest="command")
    for command_name, command in COMMANDS.items():
        # The command's docstring doubles as its one-line help text.
        command.add_arguments(subparsers.add_parser(command_name, help=command.__doc__))

    args = parser.parse_args()

    if args.version:
        print(__version__)
        return

    if args.command is None:
        parser.print_help()
        return

    log_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=log_level,
    )

    selected = COMMANDS[args.command]
    try:
        selected.run(args)
    except ValueError as error:
        traceback.print_exc()
        LOG.error("\n💣 %s", str(error).lstrip())
        LOG.error("💣 Exiting")
        sys.exit(1)


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
# (C) Copyright 2024 ECMWF.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
6
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
7
|
+
# nor does it submit to any jurisdiction.
|
|
8
|
+
|
|
9
|
+
"""Read and write extra metadata in PyTorch checkpoints files. These files
|
|
10
|
+
are zip archives containing the model weights.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import json
|
|
14
|
+
import logging
|
|
15
|
+
import os
|
|
16
|
+
import time
|
|
17
|
+
import zipfile
|
|
18
|
+
from tempfile import TemporaryDirectory
|
|
19
|
+
|
|
20
|
+
import tqdm
|
|
21
|
+
|
|
22
|
+
LOG = logging.getLogger(__name__)

DEFAULT_NAME = "ai-models.json"
DEFAULT_FOLDER = "anemoi-metadata"


def has_metadata(path: str, name: str = DEFAULT_NAME) -> bool:
    """Tell whether a checkpoint archive contains a metadata file.

    Only the basename of each archive member is compared, so the file may
    live in any sub-directory of the archive.

    Parameters
    ----------
    path : str
        The path to the checkpoint (zip) file
    name : str, optional
        The basename of the metadata file to look for

    Returns
    -------
    bool
        True if at least one member of the archive is called ``name``
    """
    with zipfile.ZipFile(path, "r") as archive:
        return any(os.path.basename(member) == name for member in archive.namelist())
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def load_metadata(path: str, name: str = DEFAULT_NAME):
    """Load metadata from a checkpoint file

    Parameters
    ----------
    path : str
        The path to the checkpoint file
    name : str, optional
        The basename of the metadata file in the zip archive; it is compared
        literally (no wildcard matching) with the basename of every member

    Returns
    -------
    JSON
        The decoded content of the metadata file

    Raises
    ------
    ValueError
        If the metadata file is not found, or if more than one member has
        that basename
    """
    with zipfile.ZipFile(path, "r") as f:
        metadata = None
        for b in f.namelist():
            if os.path.basename(b) == name:
                if metadata is not None:
                    raise ValueError(f"Found two or more '{name}' in {path}.")
                metadata = b

        if metadata is None:
            raise ValueError(f"Could not find '{name}' in {path}.")

        # Read while the archive is still open (no second open of the file)
        # and close the member stream explicitly.
        with f.open(metadata, "r") as member:
            return json.load(member)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def save_metadata(path, metadata, name=DEFAULT_NAME, folder=DEFAULT_FOLDER):
    """Save metadata to a checkpoint file

    The metadata is serialised as JSON and written to
    ``<directory>/<folder>/<name>`` inside the archive, where ``<directory>``
    is the single top-level directory of the checkpoint.

    Parameters
    ----------
    path : str
        The path to the checkpoint file
    metadata : JSON
        A JSON serializable object
    name : str, optional
        The name of the metadata file in the zip archive
    folder : str, optional
        The sub-folder of the top-level directory that receives the file

    Raises
    ------
    ValueError
        If ``path`` is not a zip archive, if a member called ``name`` already
        exists, or if the archive does not have exactly one top-level directory
    """
    # Inspect the archive read-only. Opening a file that is not a zip archive
    # (or does not exist) in "a" mode makes zipfile append/create an archive,
    # which would alter the file even though it is then rejected.
    try:
        with zipfile.ZipFile(path, "r") as zipf:
            members = zipf.namelist()
    except zipfile.BadZipFile as e:
        raise ValueError(f"{path} is not a zip archive") from e

    directories = set()

    for b in members:
        # Walk up to the top-level directory of this member.
        directory = os.path.dirname(b)
        while os.path.dirname(directory) not in (".", ""):
            directory = os.path.dirname(directory)
        directories.add(directory)

        if os.path.basename(b) == name:
            raise ValueError(f"'{name}' already in {path}")

    if len(directories) != 1:
        # PyTorch checkpoints should have a single directory
        # otherwise PyTorch will complain
        raise ValueError(f"No or multiple directories in the checkpoint {path}, directories={directories}")

    directory = next(iter(directories))

    # Serialise before opening for writing, so a non-serialisable object
    # fails without touching the archive.
    payload = json.dumps(metadata)

    LOG.info("Saving metadata to %s/%s/%s", directory, folder, name)

    with zipfile.ZipFile(path, "a") as zipf:
        zipf.writestr(f"{directory}/{folder}/{name}", payload)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _edit_metadata(path, name, callback):
    """Rewrite the checkpoint at ``path`` after editing its ``name`` member(s).

    The archive is extracted to a temporary directory. ``callback`` is called
    with the extracted path of every file whose basename is ``name``; it may
    modify or delete that file. The archive is then rebuilt from the
    directory and moved over the original file.

    Raises
    ------
    ValueError
        If no file called ``name`` is found; the checkpoint is left untouched.
    """
    # Temporary archive written next to the checkpoint; it replaces the
    # checkpoint once it is complete.
    new_path = f"{path}.anemoi-edit-{time.time()}-{os.getpid()}.tmp"

    with TemporaryDirectory() as temp_dir:
        with zipfile.ZipFile(path, "r") as source:
            source.extractall(temp_dir)

        found = False
        for root, _, files in os.walk(temp_dir):
            for f in files:
                if f == name:
                    found = True
                    callback(os.path.join(root, f))

        if not found:
            raise ValueError(f"Could not find '{name}' in {path}")

        # List the files after the callbacks ran, so the progress total is
        # right even when a callback deleted a file.
        members = [os.path.join(root, f) for root, _, files in os.walk(temp_dir) for f in files]

        try:
            with zipfile.ZipFile(new_path, "w", zipfile.ZIP_DEFLATED) as zipf:
                with tqdm.tqdm(total=len(members), desc="Rebuilding checkpoint") as pbar:
                    for full in members:
                        zipf.write(full, os.path.relpath(full, temp_dir))
                        pbar.update(1)
        except BaseException:
            # Do not leave a partially written archive next to the checkpoint.
            if os.path.exists(new_path):
                os.remove(new_path)
            raise

    # os.replace overwrites an existing destination on every platform,
    # whereas os.rename fails on Windows when ``path`` exists.
    os.replace(new_path, path)
    LOG.info("Updated metadata in %s", path)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def replace_metadata(path, metadata, name=DEFAULT_NAME):
    """Replace the metadata file stored in a checkpoint.

    Parameters
    ----------
    path : str
        The path to the checkpoint file, rewritten in place
    metadata : dict
        The new metadata; it must contain a ``"version"`` key
    name : str, optional
        The name of the metadata file in the zip archive

    Raises
    ------
    ValueError
        If ``metadata`` is not a dict, lacks a ``"version"`` key, or the
        checkpoint contains no file called ``name``
    """
    if not isinstance(metadata, dict):
        raise ValueError(f"metadata must be a dict, got {type(metadata)}")

    if "version" not in metadata:
        raise ValueError("metadata must have a 'version' key")

    def overwrite(extracted_path):
        # Runs on the extracted copy of the metadata file, before the archive is rebuilt.
        with open(extracted_path, "w") as out:
            json.dump(metadata, out)

    _edit_metadata(path, name, overwrite)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def remove_metadata(path, name=DEFAULT_NAME):
    """Delete the metadata file from a checkpoint.

    Parameters
    ----------
    path : str
        The path to the checkpoint file, rewritten in place
    name : str, optional
        The name of the metadata file in the zip archive

    Raises
    ------
    ValueError
        If the checkpoint contains no file called ``name``
    """
    LOG.info("Removing metadata '%s' from %s", name, path)

    # Deleting the extracted copy drops it from the rebuilt archive.
    _edit_metadata(path, name, os.remove)
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
# (C) Copyright 2024 ECMWF.
|
|
3
|
+
#
|
|
4
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
5
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
6
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
7
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
8
|
+
# nor does it submit to any jurisdiction.
|
|
9
|
+
#
|
|
10
|
+
|
|
11
|
+
import argparse
|
|
12
|
+
import importlib
|
|
13
|
+
import logging
|
|
14
|
+
import os
|
|
15
|
+
import sys
|
|
16
|
+
|
|
17
|
+
LOG = logging.getLogger(__name__)


def register(here, package, select, fail=None):
    """Import the public sub-modules found in ``here`` and collect objects from them.

    Files ending in ``.py`` and directories containing ``__init__.py`` are
    imported as ``package.<name>``; entries starting with ``_`` are skipped.

    Parameters
    ----------
    here : str
        Directory to scan (normally the directory of ``package``)
    package : str
        Dotted package name used to resolve the relative imports
    select : callable
        Called with each imported module; results that are not None are
        registered under the module name
    fail : callable, optional
        Called as ``fail(name, error)`` for each module whose import raised
        ImportError; its result is registered under that name. When ``fail``
        is not callable, such modules are silently skipped.

    Returns
    -------
    dict
        Name-to-object mapping. Successfully imported modules come first,
        then the ``fail`` results, each group in sorted module-name order.
    """
    result = {}
    not_available = {}

    # Sort so that the registration order (e.g. sub-command order in the CLI
    # help) does not depend on the filesystem's listing order.
    for p in sorted(os.listdir(here)):
        if p.startswith("_"):
            continue
        full = os.path.join(here, p)
        if not (p.endswith(".py") or (os.path.isdir(full) and os.path.exists(os.path.join(full, "__init__.py")))):
            continue

        name, _ = os.path.splitext(p)

        try:
            imported = importlib.import_module(f".{name}", package=package)
        except ImportError as e:
            not_available[name] = e
            continue

        obj = select(imported)
        if obj is not None:
            result[name] = obj

    if callable(fail):
        for name, e in not_available.items():
            result[name] = fail(name, e)

    return result
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class Command:
    """Base class for ``anemoi-utils`` sub-commands.

    Sub-classes override :meth:`add_arguments` to declare their command-line
    options and :meth:`run` to do the work. The class docstring of a
    sub-class is used by the CLI as the sub-command's help text.
    """

    def add_arguments(self, command_parser):
        """Declare the sub-command's arguments on ``command_parser``.

        The CLI calls this for every registered command. The default
        declares no arguments, so commands without options need not override it.
        """

    def run(self, args):
        """Execute the command with the parsed ``args``; must be overridden."""
        raise NotImplementedError(f"Command not implemented: {args.command}")
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class Failed(Command):
    """Stand-in for a sub-command whose module raised ImportError.

    It accepts any arguments after the sub-command name and, when run,
    prints the import error and exits with status 1.
    """

    def __init__(self, name, error):
        self.name = name  # name of the sub-command (its module name)
        self.error = error  # exception raised while importing the module

    def add_arguments(self, command_parser):
        # Swallow whatever follows the sub-command name so parsing never fails.
        command_parser.add_argument("x", nargs=argparse.REMAINDER)

    def run(self, args):
        message = f"Command '{self.name}' not available: {self.error}"
        print(message)
        sys.exit(1)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# Every public module of this package must expose a ``command`` class; one
# instance per module is registered. Modules that fail to import are replaced
# by a ``Failed`` placeholder that reports the error when invoked.
COMMANDS = register(
    here=os.path.dirname(__file__),
    package=__name__,
    select=lambda module: module.command(),
    fail=Failed,
)
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
# (C) Copyright 2024 ECMWF.
|
|
3
|
+
#
|
|
4
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
5
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
6
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
7
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
8
|
+
# nor does it submit to any jurisdiction.
|
|
9
|
+
#
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
import json
|
|
13
|
+
|
|
14
|
+
from . import Command
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def visit(x, path, name, value):
    """Recursively print entries of ``x`` whose key is ``name`` or whose value is ``value``.

    Each match is printed as the dotted path of its parent container, the
    key and the value. Lists are traversed with their indices as path
    components, and the search descends into every value.

    Parameters
    ----------
    x : object
        Decoded JSON data (nested dicts/lists/scalars)
    path : list of str
        Path components leading to ``x``; restored to its original content on return
    name : object
        Key to search for
    value : object
        Value to search for
    """
    if isinstance(x, dict):
        for k, v in x.items():
            # Report an entry once, even when both its key and value match.
            if k == name or v == value:
                print(".".join(path), k, v)

            path.append(k)
            visit(v, path, name, value)
            path.pop()

    elif isinstance(x, list):
        for i, v in enumerate(x):
            path.append(str(i))
            visit(v, path, name, value)
            path.pop()
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class Checkpoint(Command):
    """Print the metadata stored in a checkpoint, or search it by key or value."""

    def add_arguments(self, command_parser):
        command_parser.add_argument("path", help="Path to the checkpoint.")
        command_parser.add_argument("--name", help="Search for a specific name.")
        command_parser.add_argument("--value", help="Search for a specific value.")

    def run(self, args):
        from anemoi.utils.checkpoints import load_metadata

        # load_metadata compares its ``name`` argument literally with the
        # basename of each archive member (no wildcard expansion), so a
        # pattern such as "*.json" can never match; use the default
        # metadata file name.
        checkpoint = load_metadata(args.path)

        if args.name or args.value:
            # A fresh object() is equal only to itself, so an option that was
            # not given matches no JSON key or value.
            visit(
                checkpoint,
                [],
                args.name if args.name is not None else object(),
                args.value if args.value is not None else object(),
            )
            return

        print(json.dumps(checkpoint, sort_keys=True, indent=4))


command = Checkpoint
|
|
@@ -10,6 +10,16 @@ import calendar
|
|
|
10
10
|
import datetime
|
|
11
11
|
|
|
12
12
|
|
|
13
|
+
def normalise_frequency(frequency):
    """Convert a frequency to a number of hours.

    Parameters
    ----------
    frequency : int or str
        Either an integer number of hours, or a string made of an integer
        followed by a case-insensitive unit: ``"h"`` (hours) or ``"d"``
        (days), e.g. ``"6h"`` or ``"1d"``.

    Returns
    -------
    int
        The frequency in hours.

    Raises
    ------
    TypeError
        If ``frequency`` is neither an int nor a str.
    ValueError
        If the string is empty, has an unknown unit, or its number part is
        not an integer.
    """
    if isinstance(frequency, int):
        return frequency

    if not isinstance(frequency, str):
        raise TypeError(f"Invalid frequency type {type(frequency)}: {frequency!r}")

    hours_per_unit = {"h": 1, "d": 24}

    # Slicing (rather than indexing) keeps the empty string from raising IndexError.
    unit = frequency[-1:].lower()
    if unit not in hours_per_unit:
        raise ValueError(f"Invalid frequency {frequency!r}: unit must be one of {sorted(hours_per_unit)}")

    try:
        v = int(frequency[:-1])
    except ValueError:
        raise ValueError(f"Invalid frequency {frequency!r}: expected an integer before the unit") from None

    return v * hours_per_unit[unit]
|
|
21
|
+
|
|
22
|
+
|
|
13
23
|
def no_time_zone(date):
|
|
14
24
|
"""Remove time zone information from a date.
|
|
15
25
|
|
|
@@ -27,6 +37,7 @@ def no_time_zone(date):
|
|
|
27
37
|
return date.replace(tzinfo=None)
|
|
28
38
|
|
|
29
39
|
|
|
40
|
+
# this function is used in anemoi-datasets
|
|
30
41
|
def as_datetime(date):
|
|
31
42
|
"""Convert a date to a datetime object, removing any time zone information.
|
|
32
43
|
|
|
@@ -162,11 +173,15 @@ class HindcastDatesTimes:
|
|
|
162
173
|
"""
|
|
163
174
|
|
|
164
175
|
self.reference_dates = reference_dates
|
|
165
|
-
|
|
176
|
+
|
|
177
|
+
if isinstance(years, list):
|
|
178
|
+
self.years = years
|
|
179
|
+
else:
|
|
180
|
+
self.years = range(1, years + 1)
|
|
166
181
|
|
|
167
182
|
def __iter__(self):
|
|
168
183
|
for reference_date in self.reference_dates:
|
|
169
|
-
for year in
|
|
184
|
+
for year in self.years:
|
|
170
185
|
if reference_date.month == 2 and reference_date.day == 29:
|
|
171
186
|
date = datetime.datetime(reference_date.year - year, 2, 28)
|
|
172
187
|
else:
|
|
@@ -246,3 +261,61 @@ class Autumn(DateTimes):
|
|
|
246
261
|
_description_
|
|
247
262
|
"""
|
|
248
263
|
super().__init__(datetime.datetime(year, 9, 1), datetime.datetime(year, 11, 30), **kwargs)
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
class ConcatDateTimes:
    """Chain several date iterables into one sequence.

    Accepts the iterables either as separate arguments or as a single list.
    """

    def __init__(self, *dates):
        # A lone list argument is taken to be the list of iterables itself.
        single_list = len(dates) == 1 and isinstance(dates[0], list)
        self.dates = dates[0] if single_list else dates

    def __iter__(self):
        for group in self.dates:
            yield from group
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
class EnumDateTimes:
    """Iterate over an explicit collection of dates, converting each with ``as_datetime``."""

    def __init__(self, dates):
        self.dates = dates

    def __iter__(self):
        # Conversion happens lazily, one item per iteration step.
        for item in self.dates:
            yield as_datetime(item)
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def datetimes_factory(*args, **kwargs):
    """Build an iterable of dates from a flexible description.

    Accepted forms:

    * keyword arguments with ``name="hindcast"``: a ``HindcastDatesTimes``
      built from ``reference_dates`` (itself parsed recursively) and ``years``;
    * other keyword arguments: passed to ``DateTimes``, with an optional
      ``frequency`` (e.g. ``"6h"``) converted into an ``increment`` in hours;
    * positional arguments with no dict or list among them: an ``EnumDateTimes``;
    * a single dict or list: unpacked and parsed recursively;
    * several arguments including dicts/lists: each parsed on its own and
      chained with ``ConcatDateTimes``.

    Raises
    ------
    ValueError
        If both positional and keyword arguments are given, or none at all.
    """
    if args and kwargs:
        raise ValueError("Cannot provide both args and kwargs for a list of dates")

    if not (args or kwargs):
        raise ValueError("No dates provided")

    if kwargs:
        if kwargs.get("name") == "hindcast":
            return HindcastDatesTimes(
                reference_dates=datetimes_factory(kwargs["reference_dates"]),
                years=kwargs["years"],
            )

        options = dict(kwargs)
        if "frequency" in options:
            options["increment"] = normalise_frequency(options.pop("frequency"))
        return DateTimes(**options)

    if all(not isinstance(item, (dict, list)) for item in args):
        return EnumDateTimes(args)

    if len(args) == 1:
        (only,) = args
        if isinstance(only, dict):
            return datetimes_factory(**only)
        if isinstance(only, list):
            return datetimes_factory(*only)

    return ConcatDateTimes(*[datetimes_factory(item) for item in args])
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
|
|
2
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
3
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
4
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
5
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
6
|
+
# nor does it submit to any jurisdiction.
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
"""Utilities for working with Mars requests.
|
|
10
|
+
|
|
11
|
+
Has some knowledge of how certain streams are organised in Mars.
|
|
12
|
+
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import datetime
|
|
16
|
+
import logging
|
|
17
|
+
import os
|
|
18
|
+
|
|
19
|
+
import yaml
|
|
20
|
+
|
|
21
|
+
LOG = logging.getLogger(__name__)

DEFAULT_MARS_LABELLING = {
    "class": "od",
    "type": "an",
    "stream": "oper",
    "expver": "0001",
}


def _expand_mars_labelling(request):
    """Return a copy of ``request`` completed with the default Mars labelling.

    Keys present in ``request`` take precedence over the defaults, which are:

        {'class': 'od',
         'type': 'an',
         'stream': 'oper',
         'expver': '0001'}

    Neither ``request`` nor ``DEFAULT_MARS_LABELLING`` is modified.
    """
    merged = dict(DEFAULT_MARS_LABELLING)
    merged.update(request)
    return merged
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# Stream table read from mars.yaml; loaded on first use by _lookup_mars_stream.
STREAMS = None


def _lookup_mars_stream(request):
    """Return the ``info`` of the first stream in ``mars.yaml`` whose ``match`` fits ``request``.

    ``request`` is first completed with the default Mars labelling. An entry
    matches when every key of its ``match`` mapping has the same value in the
    completed request. Returns None when no entry matches.
    """
    global STREAMS

    if STREAMS is None:
        # mars.yaml is shipped alongside this module as package data.
        table_path = os.path.join(os.path.dirname(__file__), "mars.yaml")
        with open(table_path) as f:
            STREAMS = yaml.safe_load(f)

    expanded = _expand_mars_labelling(request)
    for stream in STREAMS:
        criteria = stream["match"]
        if all(expanded.get(key) == expected for key, expected in criteria.items()):
            return stream["info"]
    return None
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def recenter(date, center, members):
    """Look up the Mars stream information for a centre request and a members request.

    Parameters
    ----------
    date : datetime.datetime
        Not used by the current implementation.
    center : dict
        Partial Mars request for the centre; missing labelling keys take the defaults.
    members : dict
        Partial Mars request for the members; missing labelling keys take the defaults.

    Returns
    -------
    tuple
        ``(center_info, members_info)``, each the ``info`` of the first
        matching stream in ``mars.yaml``, or None when nothing matches.
    """
    return _lookup_mars_stream(center), _lookup_mars_stream(members)


if __name__ == "__main__":
    example_date = datetime.datetime(2024, 5, 9, 0)
    print(recenter(example_date, {"type": "an"}, {"stream": "elda"}))
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
|
|
2
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
3
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
4
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
5
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
6
|
+
# nor does it submit to any jurisdiction.
|
|
7
|
+
|
|
8
|
+
"""Logging utilities."""
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
import time
|
|
12
|
+
|
|
13
|
+
from .humanize import seconds
|
|
14
|
+
|
|
15
|
+
LOGGER = logging.getLogger(__name__)


class Timer:
    """Context manager that logs how long its ``with`` block took.

    Example::

        with Timer("Loading data"):
            ...

    On exit, logs ``"<title>: <duration>."`` at INFO level on ``logger``,
    with the duration formatted by ``humanize.seconds``.
    """

    def __init__(self, title, logger=LOGGER):
        self.title = title
        self.start = time.time()
        self.logger = logger

    def __enter__(self):
        # Restart the clock so a timer created ahead of time measures only
        # the ``with`` block, not the delay before entering it.
        self.start = time.time()
        return self

    @property
    def elapsed(self):
        """Seconds since the timer was created or last entered."""
        return time.time() - self.start

    def __exit__(self, *args):
        self.logger.info("%s: %s.", self.title, seconds(self.elapsed))
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: anemoi-utils
|
|
3
|
-
Version: 0.1.8
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: A package to hold various functions to support training of ML models on ECMWF data.
|
|
5
5
|
Author-email: "European Centre for Medium-Range Weather Forecasts (ECMWF)" <software.support@ecmwf.int>
|
|
6
6
|
License: Apache License
|
|
@@ -223,6 +223,8 @@ Classifier: Operating System :: OS Independent
|
|
|
223
223
|
Requires-Python: >=3.9
|
|
224
224
|
License-File: LICENSE
|
|
225
225
|
Requires-Dist: tomli
|
|
226
|
+
Requires-Dist: pyyaml
|
|
227
|
+
Requires-Dist: tqdm
|
|
226
228
|
Provides-Extra: provenance
|
|
227
229
|
Requires-Dist: GitPython; extra == "provenance"
|
|
228
230
|
Requires-Dist: nvsmi; extra == "provenance"
|
|
@@ -234,6 +236,11 @@ Provides-Extra: docs
|
|
|
234
236
|
Requires-Dist: tomli; extra == "docs"
|
|
235
237
|
Requires-Dist: termcolor; extra == "docs"
|
|
236
238
|
Requires-Dist: requests; extra == "docs"
|
|
239
|
+
Requires-Dist: sphinx; extra == "docs"
|
|
240
|
+
Requires-Dist: sphinx_rtd_theme; extra == "docs"
|
|
241
|
+
Requires-Dist: nbsphinx; extra == "docs"
|
|
242
|
+
Requires-Dist: pandoc; extra == "docs"
|
|
243
|
+
Requires-Dist: sphinx_argparse; extra == "docs"
|
|
237
244
|
Provides-Extra: all
|
|
238
245
|
Requires-Dist: tomli; extra == "all"
|
|
239
246
|
Requires-Dist: GitPython; extra == "all"
|
|
@@ -246,7 +253,3 @@ Requires-Dist: GitPython; extra == "dev"
|
|
|
246
253
|
Requires-Dist: nvsmi; extra == "dev"
|
|
247
254
|
Requires-Dist: termcolor; extra == "dev"
|
|
248
255
|
Requires-Dist: requests; extra == "dev"
|
|
249
|
-
Requires-Dist: sphinx; extra == "dev"
|
|
250
|
-
Requires-Dist: sphinx_rtd_theme; extra == "dev"
|
|
251
|
-
Requires-Dist: nbsphinx; extra == "dev"
|
|
252
|
-
Requires-Dist: pandoc; extra == "dev"
|
|
@@ -21,6 +21,7 @@ docs/modules/humanize.rst
|
|
|
21
21
|
docs/modules/provenance.rst
|
|
22
22
|
docs/modules/text.rst
|
|
23
23
|
src/anemoi/utils/__init__.py
|
|
24
|
+
src/anemoi/utils/__main__.py
|
|
24
25
|
src/anemoi/utils/_version.py
|
|
25
26
|
src/anemoi/utils/caching.py
|
|
26
27
|
src/anemoi/utils/checkpoints.py
|
|
@@ -30,10 +31,17 @@ src/anemoi/utils/grib.py
|
|
|
30
31
|
src/anemoi/utils/humanize.py
|
|
31
32
|
src/anemoi/utils/provenance.py
|
|
32
33
|
src/anemoi/utils/text.py
|
|
34
|
+
src/anemoi/utils/timer.py
|
|
35
|
+
src/anemoi/utils/commands/__init__.py
|
|
36
|
+
src/anemoi/utils/commands/checkpoint.py
|
|
37
|
+
src/anemoi/utils/mars/__init__.py
|
|
38
|
+
src/anemoi/utils/mars/mars.yaml
|
|
33
39
|
src/anemoi_utils.egg-info/PKG-INFO
|
|
34
40
|
src/anemoi_utils.egg-info/SOURCES.txt
|
|
35
41
|
src/anemoi_utils.egg-info/dependency_links.txt
|
|
42
|
+
src/anemoi_utils.egg-info/entry_points.txt
|
|
36
43
|
src/anemoi_utils.egg-info/requires.txt
|
|
37
44
|
src/anemoi_utils.egg-info/top_level.txt
|
|
38
45
|
tests/requirements.txt
|
|
46
|
+
tests/test_dates.py
|
|
39
47
|
tests/test_utils.py
|
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
tomli
|
|
2
|
+
pyyaml
|
|
3
|
+
tqdm
|
|
2
4
|
|
|
3
5
|
[all]
|
|
4
6
|
tomli
|
|
@@ -13,15 +15,16 @@ GitPython
|
|
|
13
15
|
nvsmi
|
|
14
16
|
termcolor
|
|
15
17
|
requests
|
|
16
|
-
sphinx
|
|
17
|
-
sphinx_rtd_theme
|
|
18
|
-
nbsphinx
|
|
19
|
-
pandoc
|
|
20
18
|
|
|
21
19
|
[docs]
|
|
22
20
|
tomli
|
|
23
21
|
termcolor
|
|
24
22
|
requests
|
|
23
|
+
sphinx
|
|
24
|
+
sphinx_rtd_theme
|
|
25
|
+
nbsphinx
|
|
26
|
+
pandoc
|
|
27
|
+
sphinx_argparse
|
|
25
28
|
|
|
26
29
|
[grib]
|
|
27
30
|
requests
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
import datetime
|
|
2
|
+
from textwrap import dedent
|
|
3
|
+
|
|
4
|
+
import yaml
|
|
5
|
+
|
|
6
|
+
from anemoi.utils.dates import datetimes_factory
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _(txt):
|
|
10
|
+
txt = dedent(txt)
|
|
11
|
+
config = yaml.safe_load(txt)
|
|
12
|
+
return datetimes_factory(config)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def test_date_1():
|
|
16
|
+
d = _(
|
|
17
|
+
"""
|
|
18
|
+
- 2023-01-01
|
|
19
|
+
- 2023-01-02
|
|
20
|
+
- 2023-01-03
|
|
21
|
+
"""
|
|
22
|
+
)
|
|
23
|
+
assert len(list(d)) == 3
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def test_date_2():
|
|
27
|
+
d = _(
|
|
28
|
+
"""
|
|
29
|
+
start: 2023-01-01
|
|
30
|
+
end: 2023-01-07
|
|
31
|
+
frequency: 12
|
|
32
|
+
day_of_week: [monday, friday]
|
|
33
|
+
"""
|
|
34
|
+
)
|
|
35
|
+
assert len(list(d)) == 4
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def test_date_3():
|
|
39
|
+
d = _(
|
|
40
|
+
"""
|
|
41
|
+
- start: 2023-01-01
|
|
42
|
+
end: 2023-01-03
|
|
43
|
+
frequency: 24
|
|
44
|
+
- start: 2024-01-01T06:00:00
|
|
45
|
+
end: 2024-01-02T18:00:00
|
|
46
|
+
frequency: 6h
|
|
47
|
+
"""
|
|
48
|
+
)
|
|
49
|
+
assert datetime.datetime(2023, 1, 1, 0) in d
|
|
50
|
+
assert datetime.datetime(2023, 1, 2, 0) in d
|
|
51
|
+
assert datetime.datetime(2023, 1, 3, 0) in d
|
|
52
|
+
assert datetime.datetime(2024, 1, 1, 6) in d
|
|
53
|
+
assert datetime.datetime(2024, 1, 1, 12) in d
|
|
54
|
+
assert datetime.datetime(2024, 1, 1, 18) in d
|
|
55
|
+
assert datetime.datetime(2024, 1, 2, 0) in d
|
|
56
|
+
assert datetime.datetime(2024, 1, 2, 6) in d
|
|
57
|
+
assert datetime.datetime(2024, 1, 2, 12) in d
|
|
58
|
+
assert datetime.datetime(2024, 1, 2, 18) in d
|
|
59
|
+
assert len(list(d)) == 10
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def test_date_hindcast_1():
|
|
63
|
+
d = _(
|
|
64
|
+
"""
|
|
65
|
+
- name: hindcast
|
|
66
|
+
reference_dates:
|
|
67
|
+
start: 2023-01-01
|
|
68
|
+
end: 2023-01-03
|
|
69
|
+
frequency: 24
|
|
70
|
+
years: 20
|
|
71
|
+
"""
|
|
72
|
+
)
|
|
73
|
+
assert len(list(d)) == 60
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def test_date_hindcast_2():
|
|
77
|
+
d = _(
|
|
78
|
+
"""
|
|
79
|
+
- name: hindcast
|
|
80
|
+
reference_dates:
|
|
81
|
+
start: 2023-01-01
|
|
82
|
+
end: 2023-01-03
|
|
83
|
+
frequency: 24
|
|
84
|
+
years: [2018, 2019, 2020, 2021]
|
|
85
|
+
"""
|
|
86
|
+
)
|
|
87
|
+
assert len(list(d)) == 12
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def test_date_hindcast_3():
|
|
91
|
+
d = _(
|
|
92
|
+
"""
|
|
93
|
+
- name: hindcast
|
|
94
|
+
reference_dates:
|
|
95
|
+
start: 2022-12-25 00:00:00
|
|
96
|
+
end: 2022-12-31 12:00:00
|
|
97
|
+
frequency: 12h
|
|
98
|
+
day_of_week: tuesday
|
|
99
|
+
years: [2018, 2019, 2020, 2021]
|
|
100
|
+
"""
|
|
101
|
+
)
|
|
102
|
+
print(list(d))
|
|
103
|
+
assert len(list(d)) == 8
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
if __name__ == "__main__":
|
|
107
|
+
test_functions = [
|
|
108
|
+
obj for name, obj in globals().items() if name.startswith("test_") and isinstance(obj, type(lambda: 0))
|
|
109
|
+
]
|
|
110
|
+
for test_func in test_functions:
|
|
111
|
+
print(f"Running test: {test_func.__name__}")
|
|
112
|
+
test_func()
|
|
113
|
+
print("All tests passed!")
|
|
@@ -1,75 +0,0 @@
|
|
|
1
|
-
# (C) Copyright 2024 ECMWF.
|
|
2
|
-
#
|
|
3
|
-
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
-
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
-
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
6
|
-
# granted to it by virtue of its status as an intergovernmental organisation
|
|
7
|
-
# nor does it submit to any jurisdiction.
|
|
8
|
-
|
|
9
|
-
"""Read and write extra metadata in PyTorch checkpoints files. These files
|
|
10
|
-
are zip archives containing the model weights.
|
|
11
|
-
"""
|
|
12
|
-
|
|
13
|
-
import json
|
|
14
|
-
import logging
|
|
15
|
-
import os
|
|
16
|
-
import zipfile
|
|
17
|
-
|
|
18
|
-
LOG = logging.getLogger(__name__)
|
|
19
|
-
|
|
20
|
-
DEFAULT_NAME = "anemoi-metadata.json"
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def load_metadata(path: str, name: str = DEFAULT_NAME):
|
|
24
|
-
"""Load metadata from a checkpoint file
|
|
25
|
-
|
|
26
|
-
Parameters
|
|
27
|
-
----------
|
|
28
|
-
path : str
|
|
29
|
-
The path to the checkpoint file
|
|
30
|
-
name : str, optional
|
|
31
|
-
The name of the metadata file in the zip archive
|
|
32
|
-
|
|
33
|
-
Returns
|
|
34
|
-
-------
|
|
35
|
-
JSON
|
|
36
|
-
The content of the metadata file
|
|
37
|
-
|
|
38
|
-
Raises
|
|
39
|
-
------
|
|
40
|
-
ValueError
|
|
41
|
-
If the metadata file is not found
|
|
42
|
-
"""
|
|
43
|
-
with zipfile.ZipFile(path, "r") as f:
|
|
44
|
-
metadata = None
|
|
45
|
-
for b in f.namelist():
|
|
46
|
-
if os.path.basename(b) == name:
|
|
47
|
-
if metadata is not None:
|
|
48
|
-
LOG.warning(f"Found two '{name}' if {path}")
|
|
49
|
-
metadata = b
|
|
50
|
-
|
|
51
|
-
if metadata is not None:
|
|
52
|
-
with zipfile.ZipFile(path, "r") as f:
|
|
53
|
-
return json.load(f.open(metadata, "r"))
|
|
54
|
-
else:
|
|
55
|
-
raise ValueError(f"Could not find {name} in {path}")
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
def save_metadata(path, metadata, name=DEFAULT_NAME):
|
|
59
|
-
"""Save metadata to a checkpoint file
|
|
60
|
-
|
|
61
|
-
Parameters
|
|
62
|
-
----------
|
|
63
|
-
path : str
|
|
64
|
-
The path to the checkpoint file
|
|
65
|
-
metadata : JSON
|
|
66
|
-
A JSON serializable object
|
|
67
|
-
name : str, optional
|
|
68
|
-
The name of the metadata file in the zip archive
|
|
69
|
-
"""
|
|
70
|
-
with zipfile.ZipFile(path, "a") as zipf:
|
|
71
|
-
base, _ = os.path.splitext(os.path.basename(path))
|
|
72
|
-
zipf.writestr(
|
|
73
|
-
f"{base}/{name}",
|
|
74
|
-
json.dumps(metadata),
|
|
75
|
-
)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|