anemoi-utils 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of anemoi-utils might be problematic. Click here for more details.

@@ -0,0 +1,353 @@
1
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Collect information about the current environment, such as:

- The Python version
- The versions of the modules which are currently loaded
- The git information for the modules which are currently loaded from a git repository
- On request (``full=True`` in ``gather_provenance_info``): the executable,
  the command-line arguments, platform and GPU details, and the size, times
  and MD5 checksum of given files

"""

import datetime
import json
import logging
import os
import subprocess
import sys
import sysconfig

# Module logger; used to report errors met while inspecting git repositories.
LOG = logging.getLogger(__name__)
27
+
28
+
29
def lookup_git_repo(path):
    """Return the git repository containing *path*, or None.

    Starting from *path*, walk up the directory hierarchy and return the
    first directory that GitPython accepts as a repository.

    Parameters
    ----------
    path : str
        A file or directory path. Relative paths are resolved against the
        current working directory.

    Returns
    -------
    git.Repo or None
        The enclosing repository, or None if neither *path* nor any of its
        ancestors is a git repository.
    """
    from git import InvalidGitRepositoryError
    from git import NoSuchPathError
    from git import Repo

    # Work on an absolute path so that the walk always ends at the filesystem
    # root: a relative path would shrink to "" (whose dirname is "") and loop
    # forever, and on Windows the root is not "/".
    path = os.path.abspath(path)

    while True:
        try:
            return Repo(path)
        except (InvalidGitRepositoryError, NoSuchPathError):
            # Not a repository (or not an existing path): try the parent.
            pass

        parent = os.path.dirname(path)
        if parent == path:
            # Reached the filesystem root without finding a repository.
            return None
        path = parent
40
+
41
+
42
+ def _check_for_git(paths, full):
43
+ versions = {}
44
+ for name, path in paths:
45
+ repo = lookup_git_repo(path)
46
+ if repo is None:
47
+ continue
48
+
49
+ try:
50
+
51
+ if not full:
52
+ versions[name] = dict(
53
+ git=dict(
54
+ sha1=repo.head.commit.hexsha,
55
+ modified_files=len([item.a_path for item in repo.index.diff(None)]),
56
+ untracked_files=len(repo.untracked_files),
57
+ ),
58
+ )
59
+ continue
60
+
61
+ versions[name] = dict(
62
+ path=path,
63
+ git=dict(
64
+ sha1=repo.head.commit.hexsha,
65
+ remotes=[r.url for r in repo.remotes],
66
+ modified_files=sorted([item.a_path for item in repo.index.diff(None)]),
67
+ untracked_files=sorted(repo.untracked_files),
68
+ ),
69
+ )
70
+
71
+ except ValueError as e:
72
+ LOG.error(f"Error checking git repo {path}: {e}")
73
+
74
+ return versions
75
+
76
+
77
def version(versions, name, module, roots, namespaces, paths, full):
    """Record version information about one loaded module.

    Parameters
    ----------
    versions : dict
        Updated in place. Maps the module name to its ``__version__``.
        Without a ``__version__``, it maps to the module's (possibly
        abbreviated) file path.
    name : str
        The module name (its key in ``sys.modules``).
    module : module
        The module object.
    roots : dict
        Maps installation directories to labels, most specific (longest)
        first. A module file under one of these directories is reported as
        ``<label>/rest/of/path``.
    namespaces : set
        Updated in place with the names of modules that have no file
        (presumably namespace packages), so that callers can inspect their
        sub-modules.
    paths : set
        Updated in place with ``(name, path)`` for modules whose file lies
        outside all *roots*; these are candidates for a git check.
    full : bool
        If true, record the full (label-substituted) path of modules without
        ``__version__``; otherwise only ``.../<basename>`` for files outside
        the known roots.
    """
    path = getattr(module, "__file__", None)
    if path is not None:
        # Replace the most specific matching installation root by its label.
        # Match whole leading path components only, so that a root "/opt"
        # does not rewrite "/opt2/...", and stop at the first match.
        for root, label in roots.items():
            if path == root or path.startswith(root + os.sep):
                path = f"<{label}>{path[len(root):]}"
                break

        # Files outside all known roots may come from a git checkout.
        if os.path.isabs(path):
            paths.add((name, path))

    try:
        versions[name] = module.__version__
        return
    except AttributeError:
        pass

    if path is None:
        namespaces.add(name)
        return

    # For now, don't report on stdlib modules
    if path.startswith("<stdlib>"):
        return

    if full:
        versions[name] = path
    elif not path.startswith("<"):
        versions[name] = os.path.join("...", os.path.basename(path))
117
+
118
+
119
def _module_versions(full):
    """Gather version information about the modules in ``sys.modules``.

    Top-level modules are inspected first. Then two-component modules
    (``a.b``) are inspected when their parent ``a`` was recorded as having
    no file.

    Parameters
    ----------
    full : bool
        Passed on to :func:`version`.

    Returns
    -------
    tuple of (dict, set)
        The module versions, and the ``(name, path)`` pairs that
        :func:`version` reported for modules outside the sysconfig
        installation directories.
    """
    # https://docs.python.org/3/library/sysconfig.html
    # Map each installation directory to the first scheme key that names it.
    roots = {}
    for key, directory in sysconfig.get_paths().items():
        roots.setdefault(directory, key)

    # Longest directory first, so that the most specific root is tried first.
    roots = dict(sorted(roots.items(), key=lambda item: len(item[0]), reverse=True))

    versions = {}
    namespaces = set()
    paths = set()

    for module_name, module in sorted(sys.modules.items()):
        if "." in module_name:
            continue
        version(versions, module_name, module, roots, namespaces, paths, full)

    # Cater for modules like "earthkit.meteo", whose parent has no file.
    for module_name, module in sorted(sys.modules.items()):
        parts = module_name.split(".")
        if len(parts) != 2 or parts[0] not in namespaces:
            continue
        version(versions, module_name, module, roots, namespaces, paths, full)

    return versions, paths
145
+
146
+
147
def module_versions(full):
    """Return ``(versions, git_versions)`` for the currently loaded modules.

    *versions* is the version mapping from :func:`_module_versions`.
    *git_versions* is the git information, from :func:`_check_for_git`, for
    the module paths that :func:`_module_versions` reported.
    """
    collected, module_paths = _module_versions(full)
    return collected, _check_for_git(module_paths, full)
151
+
152
+
153
+ def _name(obj):
154
+ if hasattr(obj, "__name__"):
155
+ if hasattr(obj, "__module__"):
156
+ return f"{obj.__module__}.{obj.__name__}"
157
+ return obj.__name__
158
+ if hasattr(obj, "__class__"):
159
+ return _name(obj.__class__)
160
+ return str(obj)
161
+
162
+
163
+ def _paths(path_or_object):
164
+
165
+ if path_or_object is None:
166
+ _, paths = _module_versions(full=False)
167
+ return paths
168
+
169
+ if isinstance(path_or_object, (list, tuple, set)):
170
+ paths = []
171
+ for p in path_or_object:
172
+ paths.extend(_paths(p))
173
+ return paths
174
+
175
+ if isinstance(path_or_object, str):
176
+ module = sys.modules.get(path_or_object)
177
+ if module is not None:
178
+ return _paths(module)
179
+ return [(path_or_object, path_or_object)]
180
+
181
+ if hasattr(path_or_object, "__module__"):
182
+ module = sys.modules.get(path_or_object.__module__)
183
+ return [(path_or_object.__module__, module.__file__)]
184
+
185
+ name = _name(path_or_object)
186
+ paths = []
187
+ if hasattr(path_or_object, "__file__"):
188
+ paths.append((name, path_or_object.__file__))
189
+
190
+ if hasattr(path_or_object, "__code__"):
191
+ paths.append((name, path_or_object.__code__.co_filename))
192
+
193
+ if hasattr(path_or_object, "__module__"):
194
+ module = sys.modules.get(path_or_object.__module__)
195
+ paths.append((name, module.__file__))
196
+
197
+ if not paths:
198
+ raise ValueError(f"Could not find path for {name} {path_or_object} {type(path_or_object)}")
199
+
200
+ return paths
201
+
202
+
203
def git_check(*args):
    """Return the git information for the given arguments.

    Arguments can be:

    - nothing, in which case all loaded modules are checked
    - a module name
    - a module object
    - an object or a class
    - a path to a directory

    Returns
    -------
    dict
        Maps each name to the git information of the repository it belongs
        to. Arguments that are not inside a git repository are omitted.
        For example::

            {
                "anemoi.utils": {
                    "sha1": "c999d83ae283bcbb99f68d92c42d24315922129f",
                    "remotes": ["git@github.com:ecmwf/anemoi-utils.git"],
                    "modified_files": ["anemoi/utils/checkpoints.py"],
                    "untracked_files": [],
                }
            }
    """
    targets = _paths(args or None)
    found = _check_for_git(targets, full=True)
    return {name: info["git"] for name, info in found.items()}
240
+
241
+
242
def platform_info():
    """Return the results of the public functions of the ``platform`` module.

    Every public attribute of ``platform`` is called without arguments.
    Attributes that cannot be called that way (or that raise) are skipped.
    Results that are lists/tuples made only of empty strings, at any nesting
    depth, are dropped.
    """
    import platform

    def all_empty(values):
        return all(all_empty(item) if isinstance(item, (list, tuple)) else item == "" for item in values)

    info = {}
    for attr in dir(platform):
        if attr.startswith("_"):
            continue
        try:
            info[attr] = getattr(platform, attr)()
        except Exception:
            pass

    return {key: value for key, value in info.items() if not (isinstance(value, (list, tuple)) and all_empty(value))}
262
+
263
+
264
def gpu_info():
    """Return information about the NVIDIA GPUs, using the ``nvsmi`` package.

    Returns
    -------
    list or str
        One dict per GPU (decoded from ``nvsmi``'s JSON), or a string that
        explains why no information was collected. Possible reasons: the
        ``nvsmi`` package cannot be imported, ``nvidia-smi`` is not on the
        path, or ``nvidia-smi`` failed (its output is returned).
    """
    try:
        import nvsmi
    except ImportError as e:
        # nvsmi is imported lazily and may not be installed: report it rather
        # than failing the whole provenance collection.
        return f"nvsmi not available: {e}"

    if not nvsmi.is_nvidia_smi_on_path():
        return "nvidia-smi not found"

    try:
        return [json.loads(gpu.to_json()) for gpu in nvsmi.get_gpus()]
    except subprocess.CalledProcessError as e:
        output = e.output
        if isinstance(output, bytes):
            output = output.decode("utf-8", errors="replace")
        # CalledProcessError.output is None when the output was not captured.
        return (output or str(e)).strip()
274
+
275
+
276
def path_md5(path):
    """Return the hexadecimal MD5 digest of the file at *path*.

    The file is read in 1 MiB chunks, so large files are not loaded into
    memory at once.
    """
    import hashlib

    digest = hashlib.md5()
    with open(path, "rb") as stream:
        while True:
            block = stream.read(1024 * 1024)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
284
+
285
+
286
def assets_info(paths):
    """Describe each file in *paths*: size, access/modification/change times and MD5.

    Parameters
    ----------
    paths : iterable of str
        The files to describe.

    Returns
    -------
    dict
        Maps each path to a dict with ``size``, ``atime``, ``mtime``,
        ``ctime`` and ``md5``. The times are local-time ISO 8601 strings
        built from the whole-second fields of the stat tuple. A ``peek``
        entry is added when ``.checkpoint.peek`` can be imported and
        succeeds. If the file cannot be stat-ed or read, the path maps to the
        error message instead.
    """

    def _iso(seconds):
        return datetime.datetime.fromtimestamp(seconds).isoformat()

    info = {}
    for path in paths:
        try:
            # Items 6-9 of the stat tuple: size, atime, mtime, ctime.
            size, atime, mtime, ctime = os.stat(path)[6:10]
            md5 = path_md5(path)
        except Exception as e:
            info[path] = str(e)
            continue

        entry = dict(
            size=size,
            atime=_iso(atime),
            mtime=_iso(mtime),
            ctime=_iso(ctime),
            md5=md5,
        )
        info[path] = entry

        # Best effort: record peek(path); any failure, including the import
        # itself, is ignored.
        # NOTE(review): elsewhere this package refers to
        # "anemoi/utils/checkpoints.py" (plural). If a `.checkpoint` module
        # does not exist, this import always fails silently and "peek" is
        # never recorded. Confirm the module name.
        try:
            from .checkpoint import peek

            entry["peek"] = peek(path)
        except Exception:
            pass

    return info
313
+
314
+
315
def gather_provenance_info(assets=None, full=False) -> dict:
    """Gather information about the current environment.

    Parameters
    ----------
    assets : list, optional
        A list of file paths for which to collect the MD5 sum, the size and
        time attributes (only used when *full* is true), by default no files.
    full : bool, optional
        If true, also collect the executable, the command-line arguments,
        ``sys.path``, the sysconfig paths, platform, GPU and asset
        information, by default False.

    Returns
    -------
    dict
        A dictionary with the collected information. ``time`` is the current
        UTC time as an ISO 8601 string without a timezone suffix. ``python``
        is the ``major.minor.micro`` version of the interpreter.
    """
    if assets is None:
        assets = []

    versions, git_versions = module_versions(full)

    # datetime.utcnow() is deprecated since Python 3.12; build the same naive
    # UTC timestamp from an aware datetime instead.
    now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat()
    python = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"

    if not full:
        return dict(
            time=now,
            python=python,
            module_versions=versions,
            git_versions=git_versions,
        )

    return dict(
        time=now,
        python=python,
        executable=sys.executable,
        args=sys.argv,
        python_path=sys.path,
        config_paths=sysconfig.get_paths(),
        module_versions=versions,
        git_versions=git_versions,
        platform=platform_info(),
        gpus=gpu_info(),
        assets=assets_info(assets),
    )
anemoi/utils/text.py ADDED
@@ -0,0 +1,345 @@
1
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Text utilities for terminal output: dotted lines, boxed text, bold and
coloured text, trees, tables and progress bars.
"""

import sys
from collections import defaultdict

# Reference for the box-drawing characters used below (┈ ┌ ─ │ ├ └ ┼ ┴ ...):
# https://en.wikipedia.org/wiki/Box-drawing_character
16
+
17
+
18
def dotted_line(width=84) -> str:
    """Return a horizontal dotted line made of *width* '┈' characters.

    >>> print(dotted_line(10))
    ┈┈┈┈┈┈┈┈┈┈

    Parameters
    ----------
    width : int, optional
        Number of characters, by default 84

    Returns
    -------
    str
        The dotted line (empty when *width* is zero or negative)
    """
    dot = "┈"
    return dot * width
36
+
37
+
38
def boxed(text, min_width=80, max_width=80) -> str:
    """Draw a box around a (possibly multi-line) text.

    >>> print(boxed("Hello,\\nWorld!", min_width=None, max_width=None))
    ┌────────┐
    │ Hello, │
    │ World! │
    └────────┘

    The inner width is the longest line, raised to *min_width* and capped at
    *max_width*. With *max_width* set, lines longer than it are cut and end
    with '…'.

    Parameters
    ----------
    text : str
        The text to box
    min_width : int, optional
        The minimum inner width of the box, or None, by default 80
    max_width : int, optional
        The maximum inner width of the box, or None, by default 80

    Returns
    -------
    str
        A boxed version of the input text
    """
    rows = text.split("\n")
    width = max(len(row) for row in rows)

    if min_width is not None:
        width = max(width, min_width)

    if max_width is not None:
        width = min(width, max_width)
        rows = [row if len(row) <= max_width else row[: max_width - 1] + "…" for row in rows]

    horizontal = "─" * (width + 2)
    body = [f"│ {row.ljust(width)} │" for row in rows]
    return "\n".join([f"┌{horizontal}┐", *body, f"└{horizontal}┘"])
88
+
89
+
90
def bold(text):
    """Return *text* with the terminal "bold" attribute applied (via termcolor)."""
    from termcolor import colored

    attributes = ["bold"]
    return colored(text, attrs=attributes)
94
+
95
+
96
def red(text):
    """Return *text* coloured red for terminal output (via termcolor)."""
    from termcolor import colored

    return colored(text, color="red")
100
+
101
+
102
def green(text):
    """Return *text* coloured green for terminal output (via termcolor)."""
    from termcolor import colored

    return colored(text, color="green")
106
+
107
+
108
class Tree:
    """A tree of "actors", printable as text with box-drawing characters.

    Each node wraps an *actor* object. From its use in this class, an actor
    must provide a ``summary`` attribute (the text printed for the node) and
    an ``as_dict()`` method (used for JSON output and to compare nodes, see
    ``key``). A ``Tree`` provides both itself, so a node can be the actor of
    another node, which ``_factorise`` relies on.
    """

    def __init__(self, actor, parent=None):
        self._actor = actor
        self._kids = []  # child Tree nodes, in display order
        self._parent = parent  # parent Tree node, or None for a root

    def adopt(self, kid):
        """Move *kid* from its current parent to the end of this node's children."""
        # NOTE(review): raises AttributeError if kid has no parent (kid._parent is None).
        kid._parent._kids.remove(kid)
        self._kids.append(kid)
        kid._parent = self

    def forget(self):
        """Detach this node from its parent."""
        self._parent._kids.remove(self)
        self._parent = None

    @property
    def is_leaf(self):
        """True if this node has no children."""
        return len(self._kids) == 0

    @property
    def key(self):
        """A hashable identity for the node, built from the actor's ``as_dict()``.

        Nodes with equal keys are treated as the same by ``_factorise``. The
        key is used as a dictionary key there, so this assumes the dict keys
        are mutually sortable and its values hashable.
        """
        return tuple(sorted(self._actor.as_dict().items()))

    @property
    def _text(self):
        # The line printed for this node by _print.
        return self._actor.summary

    @property
    def summary(self):
        """The actor's summary (lets a Tree act as the actor of another Tree)."""
        return self._actor.summary

    def as_dict(self):
        """The actor's ``as_dict()`` (lets a Tree act as the actor of another Tree)."""
        return self._actor.as_dict()

    def node(self, actor, insert=False):
        """Create a child node wrapping *actor* and return it.

        The child is appended, or inserted as the first child if *insert* is true.
        """
        node = Tree(actor, self)
        if insert:
            self._kids.insert(0, node)
        else:
            self._kids.append(node)
        return node

    def print(self, file=sys.stdout):
        """Simplify the tree (see ``_factorise``) and print it to *file*.

        This modifies the tree. The default *file* is bound when the class is
        defined; a later reassignment of ``sys.stdout`` is not picked up
        unless *file* is passed explicitly.
        """
        padding = []

        # Repeat until no further simplification applies.
        while self._factorise():
            pass

        self._print(padding, file=file)

    def _leaves(self, result):
        """Append the leaves of this subtree to *result*, depth first."""
        if self.is_leaf:
            result.append(self)
        else:
            for kid in self._kids:
                kid._leaves(result)

    def _factorise(self):
        """Apply one round of simplification to this subtree, in place.

        Returns True if the tree changed, so callers repeat until it returns
        False. The rules, in order:

        1. Children are simplified first; if any changed, stop there.
        2. Grand-children whose ``key`` occurs ``len(self._kids)`` times (with
           more than one child) are removed from their parents. For each such
           key, one new node is inserted at the front of this node's children.
        3. Otherwise, if this node has exactly one child, that child's leaf
           children are moved to the front of this node's children.
        """
        if len(self._kids) == 0:
            return False

        result = False
        for kid in self._kids:
            # `or result` comes last so that every kid is simplified.
            result = kid._factorise() or result

        if result:
            return True

        # Group the grand-children by key, remembering their parent.
        same = defaultdict(list)
        for kid in self._kids:
            for grand_kid in kid._kids:
                same[grand_kid.key].append((kid, grand_kid))

        result = False
        n = len(self._kids)
        texts = []
        for text, v in same.items():
            # NOTE(review): len(v) counts (kid, grand_kid) pairs, not distinct
            # kids; a kid with two grand-children of the same key could make
            # this true while another kid has none. Confirm this is intended.
            if len(v) == n and n > 1:
                for kid, grand_kid in v:
                    kid._kids.remove(grand_kid)
                # The new node's actor is one representative grand-child (the
                # one under the second kid).
                # NOTE(review): node() creates the new node with no children,
                # so any children of the removed grand-children are dropped.
                texts.append((text, v[1][1]))
                result = True

        # Each node is inserted at index 0; going in reverse keeps the order of `texts`.
        for text, actor in reversed(texts):
            self.node(actor, True)

        if result:
            return True

        if len(self._kids) != 1:
            return False

        # Single child: hoist its leaf children up to this node.
        kid = self._kids[0]
        texts = []
        for grand_kid in list(kid._kids):  # iterate a copy: kid._kids is modified below
            if len(grand_kid._kids) == 0:
                kid._kids.remove(grand_kid)
                texts.append((grand_kid.key, grand_kid))
                result = True

        for text, actor in reversed(texts):
            self.node(actor, True)

        return result

    def _print(self, padding, file=sys.stdout):
        """Print this node and, recursively, its children.

        *padding* holds one prefix per ancestor level. Its last entry is the
        connector (" ├" or " └") for this node.
        """
        # Turn ancestor connectors into continuation marks: a finished branch
        # (" └") becomes blank and an open one (" ├") becomes a bar.
        for i, p in enumerate(padding[:-1]):
            if p == " └":
                # NOTE(review): this replacement is one character wide while the
                # connectors are two, which would shift deeper lines; confirm
                # (the intended value may be two spaces).
                padding[i] = " "
            if p == " ├":
                padding[i] = " │"
        # `print` here is the builtin; methods are not in scope inside methods.
        if padding:
            print(f"{''.join(padding)}─{self._text}", file=file)
        else:
            print(self._text, file=file)
        # Placeholder for the children's connector, overwritten per child below.
        padding.append(" ")
        for i, k in enumerate(self._kids):
            # The last child gets a closing connector.
            sep = " ├" if i < len(self._kids) - 1 else " └"
            padding[-1] = sep
            k._print(padding, file=file)

        padding.pop()

    def to_json(self, depth=0):
        """Simplify the tree (see ``_factorise``) and return it as nested dicts.

        Each dict holds the actor's ``as_dict()`` under "actor", the
        children under "kids" and the node's *depth* (0 for the node this is
        called on) under "depth". This modifies the tree.
        """
        while self._factorise():
            pass

        return {
            "actor": self._actor.as_dict(),
            "kids": [k.to_json(depth + 1) for k in self._kids],
            "depth": depth,
        }
241
+
242
+
243
def table(rows, header, align, margin=0):
    """Format a table.

    >>> print(table([['Aa', 12, 5],
    ...              ['B', 120, 1],
    ...              ['C', 9, 123]],
    ...             ['C1', 'C2', 'C3'],
    ...             ['<', '>', '>']))
    C1 │  C2 │  C3
    ───┼─────┼────
    Aa │  12 │   5
    B  │ 120 │   1
    C  │   9 │ 123
    ───┴─────┴────

    Parameters
    ----------
    rows : list of lists (or tuples)
        The rows of the table. Strings are shown as-is and integers
        (including bools) in full. Anything else that ``float()`` accepts is
        shown with the ``g`` format; other values via ``str``.
    header : A list or tuple
        The header of the table; each cell is converted with ``str``.
    align : A list of '<', '>', or '^'
        To align the columns to the left, right, or center
    margin : int, optional
        Extra spaces on the left side of the table, by default 0

    Returns
    -------
    str
        A table as a string
    """
    import numbers

    def _cell(value):
        # Strings are kept verbatim and integers printed exactly. Only other
        # numbers go through float formatting, so that e.g. 12345678 is not
        # shown as 1.23457e+07.
        if isinstance(value, str):
            return value
        if isinstance(value, numbers.Integral):
            return str(value)
        try:
            return f"{float(value):g}"
        except Exception:
            return str(value)

    def _justify(text, width, how):
        if how == "<":
            return text.ljust(width)
        if how == ">":
            return text.rjust(width)
        return text.center(width)

    all_rows = [[str(cell) for cell in header]] + [[_cell(cell) for cell in row] for row in rows]
    widths = [max(len(cell) for cell in column) for column in zip(*all_rows)]

    lines = []
    for index, row in enumerate(all_rows):
        cells = [_justify(cell, width, align[j]) for j, (cell, width) in enumerate(zip(row, widths))]
        lines.append(" │ ".join(cells))
        if index == 0:
            # Separator between the header and the body.
            lines.append("─┼─".join("─" * width for width in widths))

    lines.append("─┴─".join("─" * width for width in widths))

    if margin:
        lines = [" " * margin + line for line in lines]

    return "\n".join(lines)
320
+
321
+
322
def progress(done, todo, width=80) -> str:
    """Return a coloured progress bar of *width* cells.

    The first ``done / todo * width`` cells (rounded half up) are green and
    the rest are red. Colours come from termcolor, so the result contains
    terminal escape codes. For example, ``progress(10, 100, width=50)`` gives
    5 green cells followed by 45 red cells.

    Parameters
    ----------
    done : int or float
        Amount of work completed.
    todo : int or float
        Total amount of work. If zero or negative, the bar is shown as
        complete.
    width : int, optional
        Number of cells in the bar, by default 80

    Returns
    -------
    str
        The progress bar
    """
    if todo <= 0:
        # Nothing to do: avoid a division by zero and show a full bar.
        filled = width
    else:
        filled = int(done / todo * width + 0.5)

    # Clamp so that done < 0 or done > todo still gives a bar of *width* cells.
    filled = max(0, min(filled, width))
    return green("█" * filled) + red("█" * (width - filled))