anemoi-utils 0.1.2__tar.gz → 0.1.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of anemoi-utils might be problematic. Click here for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: anemoi-utils
3
- Version: 0.1.2
3
+ Version: 0.1.3
4
4
  Summary: A package to hold various functions to support training of ML models on ECMWF data.
5
5
  Home-page: https://github.com/ecmwf/anemoi-utils
6
6
  Author: European Centre for Medium-Range Weather Forecasts (ECMWF)
@@ -6,4 +6,4 @@
6
6
  # nor does it submit to any jurisdiction.
7
7
 
8
8
 
9
- __version__ = "0.1.2"
9
+ __version__ = "0.1.3"
@@ -6,6 +6,11 @@
6
6
  # granted to it by virtue of its status as an intergovernmental organisation
7
7
  # nor does it submit to any jurisdiction.
8
8
 
9
+ """
10
+ Read and write extra metadata in PyTorch checkpoint files. These files
11
+ are zip archives containing the model weights.
12
+ """
13
+
9
14
  import json
10
15
  import logging
11
16
  import os
@@ -16,7 +21,26 @@ LOG = logging.getLogger(__name__)
16
21
  DEFAULT_NAME = "anemoi-metadata.json"
17
22
 
18
23
 
19
- def load_metadata(path, name=DEFAULT_NAME):
24
+ def load_metadata(path: str, name: str = DEFAULT_NAME):
25
+ """Load metadata from a checkpoint file
26
+
27
+ Parameters
28
+ ----------
29
+ path : str
30
+ The path to the checkpoint file
31
+ name : str, optional
32
+ The name of the metadata file in the zip archive
33
+
34
+ Returns
35
+ -------
36
+ JSON
37
+ The content of the metadata file
38
+
39
+ Raises
40
+ ------
41
+ ValueError
42
+ If the metadata file is not found
43
+ """
20
44
  with zipfile.ZipFile(path, "r") as f:
21
45
  metadata = None
22
46
  for b in f.namelist():
@@ -33,6 +57,17 @@ def load_metadata(path, name=DEFAULT_NAME):
33
57
 
34
58
 
35
59
  def save_metadata(path, metadata, name=DEFAULT_NAME):
60
+ """Save metadata to a checkpoint file
61
+
62
+ Parameters
63
+ ----------
64
+ path : str
65
+ The path to the checkpoint file
66
+ metadata : JSON
67
+ A JSON serializable object
68
+ name : str, optional
69
+ The name of the metadata file in the zip archive
70
+ """
36
71
  with zipfile.ZipFile(path, "a") as zipf:
37
72
  base, _ = os.path.splitext(os.path.basename(path))
38
73
  zipf.writestr(
@@ -7,18 +7,41 @@
7
7
  # nor does it submit to any jurisdiction.
8
8
  #
9
9
 
10
+ """
11
+ Generate human readable strings
12
+ """
13
+
10
14
  import datetime
11
15
  import re
12
16
  from collections import defaultdict
13
17
 
14
18
 
15
- def bytes(n):
16
- """
19
+ def bytes(n: float) -> str:
20
+ """Convert a number of bytes to a human readable string
21
+
17
22
  >>> bytes(4096)
18
23
  '4 KiB'
24
+
19
25
  >>> bytes(4000)
20
26
  '3.9 KiB'
27
+
28
+ Parameters
29
+ ----------
30
+ n : float
31
+ the number of bytes
32
+
33
+ Returns
34
+ -------
35
+ str
36
+ a human readable string
37
+ """
38
+
21
39
  """
40
+
41
+
42
+
43
+ """
44
+
22
45
  if n < 0:
23
46
  sign = "-"
24
47
  n -= 0
@@ -33,13 +56,7 @@ def bytes(n):
33
56
  return "%s%g%s" % (sign, int(n * 10 + 0.5) / 10.0, u[i])
34
57
 
35
58
 
36
- def base2(n):
37
- """
38
- >>> base2(4096)
39
- '4K'
40
- >>> base2(4000)
41
- '3.9K'
42
- """
59
+ def base2(n) -> str:
43
60
 
44
61
  u = ["", "K", "M", "G", "T", " P", "E", "Z", "Y"]
45
62
  i = 0
@@ -65,7 +82,24 @@ def _plural(count):
65
82
  return ""
66
83
 
67
84
 
68
- def seconds(seconds):
85
+ def seconds(seconds: float) -> str:
86
+ """Convert a number of seconds to a human readable string
87
+
88
+ >>> seconds(4000)
89
+ '1 hour 6 minutes 40 seconds'
90
+
91
+
92
+ Parameters
93
+ ----------
94
+ seconds : float
95
+ The number of seconds
96
+
97
+ Returns
98
+ -------
99
+ str
100
+ A human readable string
101
+
102
+ """
69
103
  if isinstance(seconds, datetime.timedelta):
70
104
  seconds = seconds.total_seconds()
71
105
 
@@ -113,6 +147,20 @@ def number(value):
113
147
 
114
148
 
115
149
  def plural(value, what):
150
+ """_summary_
151
+
152
+ Parameters
153
+ ----------
154
+ value : _type_
155
+ _description_
156
+ what : _type_
157
+ _description_
158
+
159
+ Returns
160
+ -------
161
+ _type_
162
+ _description_
163
+ """
116
164
  return f"{number(value)} {what}{_plural(value)}"
117
165
 
118
166
 
@@ -159,6 +207,39 @@ def __(n):
159
207
 
160
208
 
161
209
  def when(then, now=None, short=True):
210
+ """Generate a human readable string for a date, relative to now
211
+
212
+
213
+ >>> when(datetime.datetime.now() - datetime.timedelta(hours=2))
214
+ '2 hours ago'
215
+
216
+ >>> when(datetime.datetime.now() - datetime.timedelta(days=1))
217
+ 'yesterday at 08:46'
218
+
219
+ >>> when(datetime.datetime.now() - datetime.timedelta(days=5))
220
+ 'last Sunday'
221
+
222
+ >>> when(datetime.datetime.now() - datetime.timedelta(days=365))
223
+ 'last year'
224
+
225
+ >>> when(datetime.datetime.now() + datetime.timedelta(days=365))
226
+ 'next year'
227
+
228
+ Parameters
229
+ ----------
230
+ then : datetime.datetime
231
+ A datetime
232
+ now : datetime.datetime, optional
233
+ The reference date, by default NOW
234
+ short : bool, optional
235
+ Generate shorter strings, by default True
236
+
237
+ Returns
238
+ -------
239
+ str
240
+ A human readable string
241
+
242
+ """
162
243
  last = "last"
163
244
 
164
245
  if now is None:
@@ -5,6 +5,16 @@
5
5
  # granted to it by virtue of its status as an intergovernmental organisation
6
6
  # nor does it submit to any jurisdiction.
7
7
 
8
+ """
9
+ Collect information about the current environment, like:
10
+
11
+ - The Python version
12
+ - The versions of the modules which are currently loaded
13
+ - The git information for the modules which are currently loaded from a git repository
14
+ - ...
15
+
16
+ """
17
+
8
18
  import datetime
9
19
  import json
10
20
  import logging
@@ -29,7 +39,7 @@ def lookup_git_repo(path):
29
39
  return None
30
40
 
31
41
 
32
- def check_for_git(paths, full):
42
+ def _check_for_git(paths, full):
33
43
  versions = {}
34
44
  for name, path in paths:
35
45
  repo = lookup_git_repo(path)
@@ -106,7 +116,7 @@ def version(versions, name, module, roots, namespaces, paths, full):
106
116
  versions[name] = str(module)
107
117
 
108
118
 
109
- def module_versions(full):
119
+ def _module_versions(full):
110
120
  # https://docs.python.org/3/library/sysconfig.html
111
121
 
112
122
  roots = {}
@@ -131,11 +141,104 @@ def module_versions(full):
131
141
  if len(bits) == 2 and bits[0] in namespaces:
132
142
  version(versions, k, v, roots, namespaces, paths, full)
133
143
 
134
- git_versions = check_for_git(paths, full)
144
+ return versions, paths
135
145
 
146
+
147
+ def module_versions(full):
148
+ versions, paths = _module_versions(full)
149
+ git_versions = _check_for_git(paths, full)
136
150
  return versions, git_versions
137
151
 
138
152
 
153
+ def _name(obj):
154
+ if hasattr(obj, "__name__"):
155
+ if hasattr(obj, "__module__"):
156
+ return f"{obj.__module__}.{obj.__name__}"
157
+ return obj.__name__
158
+ if hasattr(obj, "__class__"):
159
+ return _name(obj.__class__)
160
+ return str(obj)
161
+
162
+
163
+ def _paths(path_or_object):
164
+
165
+ if path_or_object is None:
166
+ _, paths = _module_versions(full=False)
167
+ return paths
168
+
169
+ if isinstance(path_or_object, (list, tuple, set)):
170
+ paths = []
171
+ for p in path_or_object:
172
+ paths.extend(_paths(p))
173
+ return paths
174
+
175
+ if isinstance(path_or_object, str):
176
+ module = sys.modules.get(path_or_object)
177
+ if module is not None:
178
+ return _paths(module)
179
+ return [(path_or_object, path_or_object)]
180
+
181
+ if hasattr(path_or_object, "__module__"):
182
+ module = sys.modules.get(path_or_object.__module__)
183
+ return [(path_or_object.__module__, module.__file__)]
184
+
185
+ name = _name(path_or_object)
186
+ paths = []
187
+ if hasattr(path_or_object, "__file__"):
188
+ paths.append((name, path_or_object.__file__))
189
+
190
+ if hasattr(path_or_object, "__code__"):
191
+ paths.append((name, path_or_object.__code__.co_filename))
192
+
193
+ if hasattr(path_or_object, "__module__"):
194
+ module = sys.modules.get(path_or_object.__module__)
195
+ paths.append((name, module.__file__))
196
+
197
+ if not paths:
198
+ raise ValueError(f"Could not find path for {name} {path_or_object} {type(path_or_object)}")
199
+
200
+ return paths
201
+
202
+
203
+ def git_check(*args):
204
+ """Return the git information for the given arguments.
205
+
206
+ Arguments can be:
207
+ - no arguments, in which case all loaded modules are checked
208
+ - a module name
209
+ - a module object
210
+ - an object or a class
211
+ - a path to a directory
212
+
213
+ Returns
214
+ -------
215
+ dict
216
+ An object with the git information for the given arguments.
217
+
218
+
219
+ >>> {
220
+ "anemoi.utils": {
221
+ "sha1": "c999d83ae283bcbb99f68d92c42d24315922129f",
222
+ "remotes": [
223
+ "git@github.com:ecmwf/anemoi-utils.git"
224
+ ],
225
+ "modified_files": [
226
+ "anemoi/utils/checkpoints.py"
227
+ ],
228
+ "untracked_files": []
229
+ }
230
+ }
231
+ """
232
+ paths = _paths(args if len(args) > 0 else None)
233
+
234
+ git = _check_for_git(paths, full=True)
235
+ result = {}
236
+ for k, v in git.items():
237
+ result[k] = v["git"]
238
+
239
+ return result
240
+
241
+
139
242
  def platform_info():
140
243
  import platform
141
244
 
@@ -209,7 +312,21 @@ def assets_info(paths):
209
312
  return result
210
313
 
211
314
 
212
- def gather_provenance_info(assets=[], full=False):
315
+ def gather_provenance_info(assets=[], full=False) -> dict:
316
+ """Gather information about the current environment
317
+
318
+ Parameters
319
+ ----------
320
+ assets : list, optional
321
+ A list of file paths for which to collect the MD5 sum, the size and time attributes, by default []
322
+ full : bool, optional
323
+ If true, will also collect various paths, by default False
324
+
325
+ Returns
326
+ -------
327
+ dict
328
+ A dictionary with the collected information
329
+ """
213
330
  executable = sys.executable
214
331
 
215
332
  versions, git_versions = module_versions(full)
@@ -5,20 +5,60 @@
5
5
  # granted to it by virtue of its status as an intergovernmental organisation
6
6
  # nor does it submit to any jurisdiction.
7
7
 
8
+ """
9
+ Text utilities
10
+ """
8
11
 
9
12
  import sys
10
13
  from collections import defaultdict
11
14
 
12
- from termcolor import colored
13
-
14
15
  # https://en.wikipedia.org/wiki/Box-drawing_character
15
16
 
16
17
 
17
- def dotted_line(n=84, file=sys.stdout):
18
- print("" * n, file=file)
18
+ def dotted_line(n=84) -> str:
19
+ """Return a dotted line made of ``n`` "┈" characters
20
+
21
+ Parameters
22
+ ----------
23
+ n : int, optional
24
+ _description_, by default 84
25
+
26
+ Returns
27
+ -------
28
+ str
29
+ _description_
30
+ """
31
+ return "┈" * n
32
+
33
+
34
+ def boxed(text, min_width=80, max_width=80) -> str:
35
+ """Put a box around a text
36
+
37
+ >>> print(boxed("Hello,\\nWorld!", max_width=40))
38
+ ┌──────────────────────────────────────────┐
39
+ │ Hello, │
40
+ │ World! │
41
+ └──────────────────────────────────────────┘
42
+
43
+ Parameters
44
+ ----------
45
+ text : str
46
+ The text to box
47
+ min_width : int, optional
48
+ The minimum width of the box, by default 80
49
+ max_width : int, optional
50
+ The maximum width of the box, by default 80
51
+
52
+ Returns
53
+ -------
54
+ str
55
+ A boxed version of the input text
56
+
57
+
19
58
 
20
59
 
21
- def boxed(text, min_width=80, max_width=80):
60
+ """
61
+
22
62
  lines = text.split("\n")
23
63
  width = max(len(_) for _ in lines)
24
64
 
@@ -44,14 +84,20 @@ def boxed(text, min_width=80, max_width=80):
44
84
 
45
85
 
46
86
  def bold(text):
87
+ from termcolor import colored
88
+
47
89
  return colored(text, attrs=["bold"])
48
90
 
49
91
 
50
92
  def red(text):
93
+ from termcolor import colored
94
+
51
95
  return colored(text, "red")
52
96
 
53
97
 
54
98
  def green(text):
99
+ from termcolor import colored
100
+
55
101
  return colored(text, "green")
56
102
 
57
103
 
@@ -191,6 +237,20 @@ class Tree:
191
237
 
192
238
 
193
239
  def table(rows, header, align, margin=0):
240
+ """_summary_
241
+
242
+ Parameters
243
+ ----------
244
+ rows : _type_
245
+ _description_
246
+ header : _type_
247
+ _description_
248
+ align : _type_
249
+ _description_
250
+ margin : int, optional
251
+ _description_, by default 0
252
+ """
253
+
194
254
  def _(x):
195
255
  try:
196
256
  x = float(x)
@@ -231,6 +291,27 @@ def table(rows, header, align, margin=0):
231
291
  return "\n".join(result)
232
292
 
233
293
 
234
- def progress(done, todo, width=80):
294
+ def progress(done, todo, width=80) -> str:
295
+ """_summary_
296
+
297
+ >>> print(progress(10, 100,width=50))
298
+ █████▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒
299
+
300
+ Parameters
301
+ ----------
302
+ done : function
303
+ _description_
304
+ todo : _type_
305
+ _description_
306
+ width : int, optional
307
+ _description_, by default 80
308
+
309
+ Returns
310
+ -------
311
+ str
312
+ _description_
313
+
314
+
315
+ """
235
316
  done = min(int(done / todo * width + 0.5), width)
236
317
  return green("█" * done) + red("█" * (width - done))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: anemoi-utils
3
- Version: 0.1.2
3
+ Version: 0.1.3
4
4
  Summary: A package to hold various functions to support training of ML models on ECMWF data.
5
5
  Home-page: https://github.com/ecmwf/anemoi-utils
6
6
  Author: European Centre for Medium-Range Weather Forecasts (ECMWF)
@@ -41,6 +41,7 @@ provenance_requires = [
41
41
  text_requires = [
42
42
  "termcolor",
43
43
  ]
44
+
44
45
  doc_requires = ["sphinx", "sphinx_rtd_theme", "nbsphinx", "pandoc"]
45
46
 
46
47
  all_requires = install_requires + provenance_requires + text_requires
File without changes
File without changes
File without changes