elements 2.2.0__tar.gz → 3.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {elements-2.2.0/elements.egg-info → elements-3.1.0}/PKG-INFO +1 -1
- {elements-2.2.0 → elements-3.1.0}/elements/__init__.py +1 -1
- elements-3.1.0/elements/agg.py +143 -0
- {elements-2.2.0 → elements-3.1.0}/elements/checkpoint.py +23 -15
- {elements-2.2.0 → elements-3.1.0}/elements/config.py +7 -5
- {elements-2.2.0 → elements-3.1.0}/elements/flags.py +43 -19
- {elements-2.2.0 → elements-3.1.0}/elements/printing.py +27 -18
- {elements-2.2.0 → elements-3.1.0}/elements/tree.py +5 -0
- {elements-2.2.0 → elements-3.1.0}/elements/usage.py +1 -1
- {elements-2.2.0 → elements-3.1.0}/elements/uuid.py +12 -2
- {elements-2.2.0 → elements-3.1.0/elements.egg-info}/PKG-INFO +1 -1
- {elements-2.2.0 → elements-3.1.0}/elements.egg-info/SOURCES.txt +4 -1
- elements-3.1.0/elements.egg-info/requires.txt +1 -0
- {elements-2.2.0 → elements-3.1.0}/requirements-optional.txt +3 -1
- elements-3.1.0/requirements.txt +1 -0
- elements-3.1.0/tests/test_basics.py +70 -0
- elements-3.1.0/tests/test_flags.py +126 -0
- elements-3.1.0/tests/test_path.py +45 -0
- elements-2.2.0/elements/agg.py +0 -78
- elements-2.2.0/elements.egg-info/requires.txt +0 -4
- elements-2.2.0/requirements.txt +0 -4
- {elements-2.2.0 → elements-3.1.0}/LICENSE +0 -0
- {elements-2.2.0 → elements-3.1.0}/MANIFEST.in +0 -0
- {elements-2.2.0 → elements-3.1.0}/README.md +0 -0
- {elements-2.2.0 → elements-3.1.0}/elements/counter.py +0 -0
- {elements-2.2.0 → elements-3.1.0}/elements/fps.py +0 -0
- {elements-2.2.0 → elements-3.1.0}/elements/logger.py +0 -0
- {elements-2.2.0 → elements-3.1.0}/elements/path.py +0 -0
- {elements-2.2.0 → elements-3.1.0}/elements/plotting.py +0 -0
- {elements-2.2.0 → elements-3.1.0}/elements/rwlock.py +0 -0
- {elements-2.2.0 → elements-3.1.0}/elements/timer.py +0 -0
- {elements-2.2.0 → elements-3.1.0}/elements/when.py +0 -0
- {elements-2.2.0 → elements-3.1.0}/elements.egg-info/dependency_links.txt +0 -0
- {elements-2.2.0 → elements-3.1.0}/elements.egg-info/top_level.txt +0 -0
- {elements-2.2.0 → elements-3.1.0}/setup.cfg +0 -0
- {elements-2.2.0 → elements-3.1.0}/setup.py +0 -0
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import operator
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from functools import partial as bind
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Agg:
  """Collects metric values by key and reduces them when results are read.

  The first value added for a key (after construction or `reset()`) creates
  the reducers for that key according to `agg`. Later values for the same key
  update those reducers, and the `agg` argument is ignored for them.
  """

  def __init__(self, maxlen=1e6):
    # Reducers per metric key, in the order the aggregations were requested.
    self.reducers = defaultdict(list)
    # Aggregation names per key when a sequence was passed as `agg`. These are
    # used to build `key/name` result keys. None when `agg` was one string.
    self.names = {}
    # Maximum number of values retained by 'stack' aggregations.
    self.maxlen = int(maxlen)

  def add(self, key_or_dict, value=None, agg='default', prefix=None):
    """Adds one value, or every item of a mapping when `value` is None.

    Args:
      key_or_dict: The metric key if `value` is given. Otherwise a mapping
        from metric keys to values.
      value: The value to add for a single key.
      agg: Aggregation name, or a sequence of names, out of 'avg', 'sum',
        'min', 'max', 'stack', 'last' and 'default'. 'default' resolves to
        'avg' for values with at most one dimension and to 'last' otherwise.
      prefix: Optional string that is prepended to keys as `prefix/key`.

    Raises:
      ValueError: If an aggregation name is unknown.
    """
    if value is not None:
      self._add_single(key_or_dict, value, agg, prefix)
      return
    for key, value in key_or_dict.items():
      self._add_single(key, value, agg, prefix)

  def result(self, reset=True, prefix=None):
    """Returns the current aggregates as a dict from metric name to value.

    Keys added with a single aggregation are reported as `key`. Keys added
    with several aggregations are reported as `key/name` for each of them.
    If `prefix` is given, every result key is further prefixed with
    `prefix/`. When `reset` is true, all accumulated state is cleared
    afterwards.
    """
    metrics = {}
    for key, reducers in self.reducers.items():
      if len(reducers) == 1:
        metrics[key] = reducers[0].current()
      else:
        for name, reducer in zip(self.names[key], reducers):
          metrics[f'{key}/{name}'] = reducer.current()
    if prefix:
      metrics = {f'{prefix}/{k}': v for k, v in metrics.items()}
    if reset:
      self.reset()
    return metrics

  def reset(self):
    """Discards all accumulated values and per-key aggregation settings."""
    self.reducers.clear()
    # Also drop the per-key names, so that keys which are never added again
    # do not accumulate here indefinitely.
    self.names.clear()

  def _add_single(self, key, value, agg, prefix):
    key = f'{prefix}/{key}' if prefix else key
    # Use get() so that looking up an unseen key does not insert an empty
    # entry into the defaultdict.
    reducers = self.reducers.get(key)
    if reducers:
      for reducer in reducers:
        reducer.update(value)
      return
    if isinstance(agg, str):
      aggs, names = (agg,), None
    else:
      # Materialize the sequence, because it is iterated here and again in
      # result() when building names. A generator would be exhausted.
      aggs = tuple(agg)
      names = aggs
    # Create all reducers before registering the key. This way an unknown
    # aggregation name cannot leave a partially registered key behind.
    reducers = [self._make_reducer(name, value) for name in aggs]
    self.reducers[key] = reducers
    self.names[key] = names

  def _make_reducer(self, agg, value):
    """Creates the reducer for aggregation name `agg`, seeded with `value`."""
    if agg == 'default':
      # Resolved per name, so 'default' also works inside a sequence of aggs.
      agg = 'avg' if np.asarray(value).ndim <= 1 else 'last'
    if agg == 'avg':
      return Mean(value)
    elif agg == 'sum':
      return Sum(value)
    elif agg == 'min':
      return Min(value)
    elif agg == 'max':
      return Max(value)
    elif agg == 'stack':
      return Stack(value, self.maxlen)
    elif agg == 'last':
      return Last(value)
    raise ValueError(agg)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class Reducer:
  """Folds a stream of values with a binary function, skipping NaN values.

  Python ints and floats are combined with `scalar_fn`. Anything else is
  converted to a float64 array and combined with `array_fn`, which may update
  the running array in place. Values containing NaN are ignored, and this now
  includes the initial value.
  """

  def __init__(self, scalar_fn, array_fn, initial):
    self.scalar_fn = scalar_fn
    self.array_fn = array_fn
    self.is_scalar = isinstance(initial, (int, float))
    # Shape of the initial value. Used to report NaN of the right shape when
    # no value was ever accepted.
    self.shape = np.shape(initial)
    # Running result; None until the first non-NaN value has been seen.
    self.interm = None
    # Number of values that were folded into `interm`.
    self.count = 0
    # Route the initial value through update(). A NaN initial value is then
    # skipped like any later NaN, instead of poisoning the result forever.
    self.update(initial)

  def update(self, value):
    """Folds `value` into the running result unless it contains NaN."""
    if self.is_scalar:
      if math.isnan(value):
        return
      if self.interm is None:
        self.interm = value
      else:
        self.interm = self.scalar_fn(self.interm, value)
    elif self.interm is None:
      # Copy into a fresh float64 buffer that array_fn may modify in place,
      # so the caller's array is never aliased.
      value = np.array(value, np.float64)
      if np.isnan(value).any():
        return
      self.interm = value
    else:
      value = np.asarray(value)
      if np.isnan(value).any():
        return
      self.interm = self.array_fn(self.interm, value)
    self.count += 1

  def current(self):
    """Returns a copy of the running result (NaN if no value was accepted)."""
    if self.interm is None:
      return np.full(self.shape, np.nan)
    return np.array(self.interm)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class Mean:
  """Running arithmetic mean, computed as a Sum reducer divided by its count."""

  def __init__(self, initial):
    self.reducer = Sum(initial)

  def update(self, value):
    """Adds `value` to the underlying running sum."""
    self.reducer.update(value)

  def current(self):
    """Returns the running sum divided by the number of accepted values."""
    total = self.reducer.current()
    count = self.reducer.count
    return total / count
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class Stack:
  """Keeps up to `maxlen` values and stacks them along a new leading axis."""

  def __init__(self, initial, maxlen=1e5):
    self.stack = [initial]
    self.maxlen = int(maxlen)

  def update(self, value):
    """Appends `value`, silently dropping it once `maxlen` values are kept."""
    if len(self.stack) >= self.maxlen:
      return
    self.stack.append(value)

  def current(self):
    """Returns all kept values stacked into one array."""
    values = self.stack
    return np.stack(values)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class Last:
  """Remembers only the most recently added value."""

  def __init__(self, initial):
    self.update(initial)

  def update(self, value):
    """Replaces the remembered value with `value`."""
    self.value = value

  def current(self):
    """Returns the remembered value as it was passed in (no copy)."""
    return self.value
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _inplace(ufunc):
  """Wraps `ufunc` so that it writes its float64 result into its first arg."""
  def apply(x, y):
    return ufunc(x, y, out=x, dtype=np.float64)
  return apply


# Reducer constructors. Call one with an initial value to get a Reducer that
# combines Python scalars with the builtin and arrays with the in-place ufunc.
Sum = bind(Reducer, operator.add, _inplace(np.add))
Min = bind(Reducer, min, _inplace(np.minimum))
Max = bind(Reducer, max, _inplace(np.maximum))
|
|
@@ -1,8 +1,10 @@
|
|
|
1
1
|
import concurrent.futures
|
|
2
|
-
import
|
|
2
|
+
import pickle
|
|
3
3
|
import time
|
|
4
4
|
|
|
5
|
+
from . import printing
|
|
5
6
|
from . import path
|
|
7
|
+
from . import timer
|
|
6
8
|
|
|
7
9
|
|
|
8
10
|
class Checkpoint:
|
|
@@ -12,7 +14,7 @@ class Checkpoint:
|
|
|
12
14
|
self._values = {}
|
|
13
15
|
self._parallel = parallel
|
|
14
16
|
if self._parallel:
|
|
15
|
-
self._worker = concurrent.futures.ThreadPoolExecutor(1)
|
|
17
|
+
self._worker = concurrent.futures.ThreadPoolExecutor(1, 'checkpoint')
|
|
16
18
|
self._promise = None
|
|
17
19
|
|
|
18
20
|
def __setattr__(self, name, value):
|
|
@@ -31,7 +33,7 @@ class Checkpoint:
|
|
|
31
33
|
if name.startswith('_'):
|
|
32
34
|
raise AttributeError(name)
|
|
33
35
|
try:
|
|
34
|
-
return
|
|
36
|
+
return self._values[name]
|
|
35
37
|
except AttributeError:
|
|
36
38
|
raise ValueError(name)
|
|
37
39
|
|
|
@@ -48,33 +50,39 @@ class Checkpoint:
|
|
|
48
50
|
def save(self, filename=None, keys=None):
|
|
49
51
|
assert self._filename or filename
|
|
50
52
|
filename = path.Path(filename or self._filename)
|
|
51
|
-
|
|
53
|
+
printing.print_(f'Writing checkpoint: {filename}')
|
|
52
54
|
if self._parallel:
|
|
53
55
|
self._promise and self._promise.result()
|
|
54
56
|
self._promise = self._worker.submit(self._save, filename, keys)
|
|
55
57
|
else:
|
|
56
58
|
self._save(filename, keys)
|
|
57
59
|
|
|
60
|
+
@timer.section('checkpoint_save')
|
|
58
61
|
def _save(self, filename, keys):
|
|
59
62
|
keys = tuple(self._values.keys() if keys is None else keys)
|
|
60
63
|
assert all([not k.startswith('_') for k in keys]), keys
|
|
61
64
|
data = {k: self._values[k].save() for k in keys}
|
|
62
65
|
data['_timestamp'] = time.time()
|
|
63
66
|
filename.parent.mkdirs()
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
67
|
+
content = pickle.dumps(data)
|
|
68
|
+
if str(filename).startswith('gs://'):
|
|
69
|
+
filename.write(content, mode='wb')
|
|
70
|
+
else:
|
|
71
|
+
# Write to a temporary file and then atomically rename, so that the file
|
|
72
|
+
# either contains a complete write or is not updated at all if writing was
|
|
73
|
+
# interrupted.
|
|
74
|
+
tmp = filename.parent / (filename.name + '.tmp')
|
|
75
|
+
tmp.write(content, mode='wb')
|
|
76
|
+
tmp.move(filename)
|
|
77
|
+
print('Wrote checkpoint.')
|
|
71
78
|
|
|
79
|
+
@timer.section('checkpoint_load')
|
|
72
80
|
def load(self, filename=None, keys=None):
|
|
73
81
|
assert self._filename or filename
|
|
74
82
|
self._promise and self._promise.result() # Wait for last save.
|
|
75
83
|
filename = path.Path(filename or self._filename)
|
|
76
|
-
|
|
77
|
-
data =
|
|
84
|
+
printing.print_(f'Loading checkpoint: {filename}')
|
|
85
|
+
data = pickle.loads(filename.read('rb'))
|
|
78
86
|
keys = tuple(data.keys() if keys is None else keys)
|
|
79
87
|
for key in keys:
|
|
80
88
|
if key.startswith('_'):
|
|
@@ -82,10 +90,10 @@ class Checkpoint:
|
|
|
82
90
|
try:
|
|
83
91
|
self._values[key].load(data[key])
|
|
84
92
|
except Exception:
|
|
85
|
-
print(f
|
|
93
|
+
print(f"Error loading '{key}' from checkpoint.")
|
|
86
94
|
raise
|
|
87
95
|
age = time.time() - data['_timestamp']
|
|
88
|
-
|
|
96
|
+
printing.print_(f'Loaded checkpoint from {age:.0f} seconds ago.')
|
|
89
97
|
|
|
90
98
|
def load_or_save(self):
|
|
91
99
|
if self.exists():
|
|
@@ -30,9 +30,10 @@ class Config(dict):
|
|
|
30
30
|
if filename.suffix == '.json':
|
|
31
31
|
filename.write(json.dumps(dict(self)))
|
|
32
32
|
elif filename.suffix in ('.yml', '.yaml'):
|
|
33
|
-
|
|
33
|
+
from ruamel.yaml import YAML
|
|
34
|
+
yaml = YAML(typ='safe')
|
|
34
35
|
with io.StringIO() as stream:
|
|
35
|
-
yaml.
|
|
36
|
+
yaml.dump(dict(self), stream)
|
|
36
37
|
filename.write(stream.getvalue())
|
|
37
38
|
else:
|
|
38
39
|
raise NotImplementedError(filename.suffix)
|
|
@@ -41,10 +42,11 @@ class Config(dict):
|
|
|
41
42
|
def load(cls, filename):
|
|
42
43
|
filename = path.Path(filename)
|
|
43
44
|
if filename.suffix == '.json':
|
|
44
|
-
return cls(json.loads(filename.
|
|
45
|
+
return cls(json.loads(filename.read()))
|
|
45
46
|
elif filename.suffix in ('.yml', '.yaml'):
|
|
46
|
-
|
|
47
|
-
|
|
47
|
+
from ruamel.yaml import YAML
|
|
48
|
+
yaml = YAML(typ='safe')
|
|
49
|
+
return cls(yaml.load(filename.read()))
|
|
48
50
|
else:
|
|
49
51
|
raise NotImplementedError(filename.suffix)
|
|
50
52
|
|
|
@@ -12,9 +12,11 @@ class Flags:
|
|
|
12
12
|
def parse(self, argv=None, help_exits=True):
|
|
13
13
|
parsed, remaining = self.parse_known(argv)
|
|
14
14
|
for flag in remaining:
|
|
15
|
-
if flag.startswith('--'):
|
|
16
|
-
raise
|
|
17
|
-
|
|
15
|
+
if flag.startswith('--') and flag[2:] not in self._config.flat:
|
|
16
|
+
raise KeyError(f"Flag '{flag}' did not match any config keys.")
|
|
17
|
+
if remaining:
|
|
18
|
+
raise ValueError(
|
|
19
|
+
f'Could not parse all arguments. Remaining: {remaining}')
|
|
18
20
|
return parsed
|
|
19
21
|
|
|
20
22
|
def parse_known(self, argv=None, help_exits=False):
|
|
@@ -52,25 +54,40 @@ class Flags:
|
|
|
52
54
|
return
|
|
53
55
|
if not key:
|
|
54
56
|
vals = ', '.join(f"'{x}'" for x in vals)
|
|
55
|
-
|
|
57
|
+
remaining.extend(vals)
|
|
58
|
+
return
|
|
59
|
+
# raise ValueError(f"Values {vals} were not preceded by any flag.")
|
|
56
60
|
name = key[len('--'):]
|
|
57
61
|
if '=' in name:
|
|
58
62
|
remaining.extend([key] + vals)
|
|
59
63
|
return
|
|
60
|
-
if
|
|
64
|
+
if not vals:
|
|
65
|
+
remaining.extend([key])
|
|
66
|
+
return
|
|
67
|
+
# raise ValueError(f"Flag '{key}' was not followed by any values.")
|
|
68
|
+
if name.endswith('+') and name[:-1] in self._config:
|
|
69
|
+
key = name[:-1]
|
|
70
|
+
default = self._config[key]
|
|
71
|
+
if not isinstance(default, tuple):
|
|
72
|
+
raise TypeError(
|
|
73
|
+
f"Cannot append to key '{key}' which is of type "
|
|
74
|
+
f"'{type(default).__name__}' instead of tuple.")
|
|
75
|
+
if key not in parsed:
|
|
76
|
+
parsed[key] = default
|
|
77
|
+
parsed[key] += self._parse_flag_value(default, vals, key)
|
|
78
|
+
elif self._config.IS_PATTERN.fullmatch(name):
|
|
61
79
|
pattern = re.compile(name)
|
|
62
|
-
keys =
|
|
80
|
+
keys = [k for k in self._config.flat if pattern.fullmatch(k)]
|
|
81
|
+
if keys:
|
|
82
|
+
for key in keys:
|
|
83
|
+
parsed[key] = self._parse_flag_value(self._config[key], vals, key)
|
|
84
|
+
else:
|
|
85
|
+
remaining.extend([key] + vals)
|
|
63
86
|
elif name in self._config:
|
|
64
|
-
|
|
87
|
+
key = name
|
|
88
|
+
parsed[key] = self._parse_flag_value(self._config[key], vals, key)
|
|
65
89
|
else:
|
|
66
|
-
keys = []
|
|
67
|
-
if not keys:
|
|
68
90
|
remaining.extend([key] + vals)
|
|
69
|
-
return
|
|
70
|
-
if not vals:
|
|
71
|
-
raise ValueError(f"Flag '{key}' was not followed by any values.")
|
|
72
|
-
for key in keys:
|
|
73
|
-
parsed[key] = self._parse_flag_value(self._config[key], vals, key)
|
|
74
91
|
|
|
75
92
|
def _parse_flag_value(self, default, value, key):
|
|
76
93
|
value = value if isinstance(value, (tuple, list)) else (value,)
|
|
@@ -78,7 +95,9 @@ class Flags:
|
|
|
78
95
|
if len(value) == 1 and ',' in value[0]:
|
|
79
96
|
value = value[0].split(',')
|
|
80
97
|
return tuple(self._parse_flag_value(default[0], [x], key) for x in value)
|
|
81
|
-
|
|
98
|
+
if len(value) != 1:
|
|
99
|
+
raise TypeError(
|
|
100
|
+
f"Expected a single value for key '{key}' but got: {value}")
|
|
82
101
|
value = str(value[0])
|
|
83
102
|
if default is None:
|
|
84
103
|
return value
|
|
@@ -92,11 +111,16 @@ class Flags:
|
|
|
92
111
|
try:
|
|
93
112
|
value = float(value) # Allow scientific notation for integers.
|
|
94
113
|
assert float(int(value)) == value
|
|
95
|
-
except (TypeError, AssertionError):
|
|
96
|
-
message = f"Expected int but got
|
|
114
|
+
except (ValueError, TypeError, AssertionError):
|
|
115
|
+
message = f"Expected int but got '{value}' for key '{key}'."
|
|
97
116
|
raise TypeError(message)
|
|
98
117
|
return int(value)
|
|
99
118
|
if isinstance(default, dict):
|
|
100
|
-
raise
|
|
119
|
+
raise KeyError(
|
|
101
120
|
f"Key '{key}' refers to a whole dict. Please speicfy a subkey.")
|
|
102
|
-
|
|
121
|
+
try:
|
|
122
|
+
return type(default)(value)
|
|
123
|
+
except ValueError:
|
|
124
|
+
raise TypeError(
|
|
125
|
+
f"Cannot convert '{value}' to type '{type(default).__name__}' for "
|
|
126
|
+
f"key '{key}'.")
|
|
@@ -2,15 +2,22 @@ import re
|
|
|
2
2
|
|
|
3
3
|
try:
|
|
4
4
|
import colored
|
|
5
|
-
REGEX_TOKEN = re.compile(r"([^a-zA-Z0-9-_./'\"\[\]])", re.MULTILINE)
|
|
6
|
-
REGEX_NUMBER = re.compile(r'[-+]?[0-9]+[0-9.]*(e[-+]?[0-9])?')
|
|
7
|
-
KEYWORDS = () # ('True', 'False', 'None', 'bool', 'int', 'str', 'float')
|
|
8
5
|
except ImportError:
|
|
9
|
-
print('For colored outputs: pip install colored')
|
|
10
6
|
colored = None
|
|
7
|
+
print('For colored outputs: pip install colored')
|
|
8
|
+
|
|
11
9
|
|
|
10
|
+
REGEX_TOKEN = re.compile(
|
|
11
|
+
r"([^a-zA-Z0-9-_./'\"\[\]]|['\"][^\s]*['\"])", re.MULTILINE)
|
|
12
|
+
REGEX_NUMBER = re.compile(
|
|
13
|
+
r'([-+]?[0-9]+[0-9.,]*(e[-+]?[0-9])?|nan|-?inf)')
|
|
14
|
+
KEYWORDS = (
|
|
15
|
+
'True', 'False', 'None', 'bool', 'int', 'str', 'float',
|
|
16
|
+
'uint8', 'float16', 'float32', 'int32', 'int64')
|
|
12
17
|
|
|
13
|
-
|
|
18
|
+
|
|
19
|
+
def print_(*values, color=True, **kwargs):
|
|
20
|
+
value = kwargs.get('sep', ' ').join(str(x) for x in values)
|
|
14
21
|
assert not color or isinstance(color, (bool, str)), color
|
|
15
22
|
if isinstance(color, str) and colored:
|
|
16
23
|
value = colored.stylize(value, colored.fg(color))
|
|
@@ -20,34 +27,36 @@ def print_(value, color=True, **kwargs):
|
|
|
20
27
|
tokens = REGEX_TOKEN.split(value) + [None]
|
|
21
28
|
for i, token in enumerate(tokens[:-1]):
|
|
22
29
|
new = prev.copy()
|
|
23
|
-
|
|
30
|
+
word = token.strip()
|
|
24
31
|
new[2] = None
|
|
25
|
-
if not
|
|
32
|
+
if not word:
|
|
26
33
|
new[0] = None
|
|
27
|
-
elif
|
|
34
|
+
elif word in '/-+':
|
|
28
35
|
new[0] = 'green'
|
|
29
36
|
new[2] = True
|
|
30
|
-
elif
|
|
37
|
+
elif word in '{}()<>,:':
|
|
31
38
|
new[0] = 'white'
|
|
32
39
|
elif token == '=':
|
|
33
40
|
new[0] = 'white'
|
|
34
|
-
elif
|
|
41
|
+
elif word[0].isalpha() and tokens[i + 1] == '=':
|
|
35
42
|
new[0] = 'magenta'
|
|
36
|
-
elif
|
|
43
|
+
elif word in KEYWORDS:
|
|
37
44
|
new[0] = 'blue'
|
|
38
|
-
elif
|
|
45
|
+
elif word.startswith('---'):
|
|
39
46
|
new[1] = True
|
|
40
|
-
elif REGEX_NUMBER.match(
|
|
47
|
+
elif REGEX_NUMBER.match(word):
|
|
41
48
|
new[0] = 'blue'
|
|
42
|
-
elif
|
|
49
|
+
elif word[0] == word[-1] == "'":
|
|
43
50
|
new[0] = 'red'
|
|
44
|
-
elif
|
|
51
|
+
elif word[0] == word[-1] == '"':
|
|
45
52
|
new[0] = 'red'
|
|
46
|
-
elif
|
|
53
|
+
elif word[0] == '[' and word[-1] == ']':
|
|
47
54
|
new[0] = 'cyan'
|
|
48
|
-
elif
|
|
55
|
+
elif any(word.startswith(x) for x in ('/', '~', './')):
|
|
49
56
|
new[0] = 'yellow'
|
|
50
|
-
elif
|
|
57
|
+
elif len(word) >= 3 and word[0] == word[-1] and word[0] in ("'", '"'):
|
|
58
|
+
new[0] = 'green'
|
|
59
|
+
elif word[0] == word[0].upper():
|
|
51
60
|
new[0] = None
|
|
52
61
|
else:
|
|
53
62
|
new[0] = None
|
|
@@ -20,6 +20,11 @@ def map_(fn, *trees, isleaf=None):
|
|
|
20
20
|
assert all(set(x.keys()) == set(first.keys()) for x in trees), (
|
|
21
21
|
printing.format_(trees))
|
|
22
22
|
return {k: map_(fn, *[t[k] for t in trees], **kw) for k in first}
|
|
23
|
+
if hasattr(first, 'keys') and hasattr(first, 'get'):
|
|
24
|
+
assert all(set(x.keys()) == set(first.keys()) for x in trees), (
|
|
25
|
+
printing.format_(trees))
|
|
26
|
+
return type(first)(
|
|
27
|
+
{k: map_(fn, *[t[k] for t in trees], **kw) for k in first})
|
|
23
28
|
return fn(*trees)
|
|
24
29
|
|
|
25
30
|
|
|
@@ -50,7 +50,7 @@ class NvsmiStats:
|
|
|
50
50
|
|
|
51
51
|
@timer.section('nvsmi_stats')
|
|
52
52
|
def __call__(self):
|
|
53
|
-
output = os.popen('nvidia-smi --query -d UTILIZATION').read()
|
|
53
|
+
output = os.popen('nvidia-smi --query -d UTILIZATION 2>&1').read()
|
|
54
54
|
if not output:
|
|
55
55
|
print('To log GPU stats, make sure nvidia-smi is working.')
|
|
56
56
|
return {}
|
|
@@ -5,11 +5,18 @@ import numpy as np
|
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
class UUID:
|
|
8
|
+
"""UUID that is stored as 16 byte string and can be converted to and from
|
|
9
|
+
int, string, and array types."""
|
|
10
|
+
|
|
11
|
+
__slots__ = ('value', '_hash')
|
|
8
12
|
|
|
9
13
|
DEBUG_ID = None
|
|
10
14
|
BASE62 = string.digits + string.ascii_letters
|
|
11
15
|
BASE62REV = {x: i for i, x in enumerate(BASE62)}
|
|
12
16
|
|
|
17
|
+
# def __new__(cls, val=None):
|
|
18
|
+
# return val or np.random.randint(1, 2 ** 63)
|
|
19
|
+
|
|
13
20
|
@classmethod
|
|
14
21
|
def reset(cls, *, debug):
|
|
15
22
|
cls.DEBUG_ID = 0 if debug else None
|
|
@@ -21,10 +28,13 @@ class UUID:
|
|
|
21
28
|
else:
|
|
22
29
|
type(self).DEBUG_ID += 1
|
|
23
30
|
self.value = self.DEBUG_ID.to_bytes(16, 'big')
|
|
24
|
-
elif isinstance(value,
|
|
31
|
+
elif isinstance(value, UUID):
|
|
25
32
|
self.value = value.value
|
|
26
33
|
elif isinstance(value, int):
|
|
27
34
|
self.value = value.to_bytes(16, 'big')
|
|
35
|
+
elif isinstance(value, bytes):
|
|
36
|
+
assert len(value) == 16, value
|
|
37
|
+
self.value = value
|
|
28
38
|
elif isinstance(value, str):
|
|
29
39
|
if self.DEBUG_ID is None:
|
|
30
40
|
integer = 0
|
|
@@ -37,7 +47,7 @@ class UUID:
|
|
|
37
47
|
self.value = value.tobytes()
|
|
38
48
|
else:
|
|
39
49
|
raise ValueError(value)
|
|
40
|
-
assert type(self.value) == bytes, type(self.value)
|
|
50
|
+
assert type(self.value) == bytes, type(self.value) # noqa
|
|
41
51
|
assert len(self.value) == 16, len(self.value)
|
|
42
52
|
self._hash = hash(self.value)
|
|
43
53
|
|
|
@@ -25,4 +25,7 @@ elements.egg-info/PKG-INFO
|
|
|
25
25
|
elements.egg-info/SOURCES.txt
|
|
26
26
|
elements.egg-info/dependency_links.txt
|
|
27
27
|
elements.egg-info/requires.txt
|
|
28
|
-
elements.egg-info/top_level.txt
|
|
28
|
+
elements.egg-info/top_level.txt
|
|
29
|
+
tests/test_basics.py
|
|
30
|
+
tests/test_flags.py
|
|
31
|
+
tests/test_path.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
numpy
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
numpy
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
import elements
|
|
2
|
+
import pytest
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class TestBasics:
  """Smoke tests for Config, Flags, Logger and the `elements.when` helpers."""

  def test_config(self):
    config = elements.Config({'one.two': 12, 'foo': {'bar': True}})
    # Dotted keys become nested configs reachable by attribute or item access.
    assert config.one.two == 12
    assert config.foo.bar is True
    assert config['foo'].bar is True
    assert 'foo' in config
    assert 'foo.bar' in config
    assert str(config)
    # Updating an existing key with a value of the same type is accepted.
    _ = config.update({'one.two': 42})
    # Unknown keys and type changes are rejected.
    rejected = (
        (KeyError, {'new': 1}),
        (TypeError, {'one.two': 'string'}),
    )
    for error, changes in rejected:
      with pytest.raises(error):
        _ = config.update(changes)
    # Attribute assignment is rejected.
    with pytest.raises(AttributeError):
      config.one.two = 1
    # Callables are not accepted as values.
    with pytest.raises(TypeError):
      elements.Config({'foo': lambda: None})

  def test_flags(self):
    flags = elements.Flags(foo=12, bar={'baz': True})
    parsed = flags.parse(['--bar.baz', 'False'])
    assert parsed.foo == 12
    assert parsed.bar.baz is False
    # Flag names may also be regular expressions over the flattened keys.
    assert flags.parse(['--bar.*', 'False']).bar.baz is False
    with pytest.raises(TypeError):
      flags.parse(['--bar.baz', '12'])
    # A pattern that matches no key is reported as an unknown flag.
    with pytest.raises(KeyError):
      flags.parse(['--.*unknown.*', '12'])
    # parse_known() hands back unrecognized arguments instead of raising.
    _, remaining = flags.parse_known(['one=two', '--foo', '42', '--three'])
    assert remaining == ['one=two', '--three']
    flags = elements.Flags({'foo': 12})
    _, remaining = flags.parse_known(['--help'], help_exits=False)
    assert remaining == ['--help']

  def test_logger(self, capsys):
    step = elements.Counter()
    outputs = [elements.logger.TerminalOutput()]
    logger = elements.Logger(step, outputs)
    logger.scalar('name', 15)
    logger.write()
    # NOTE(review): this only prints the captured output and asserts nothing.
    print(capsys.readouterr())

  def test_every(self):
    should = elements.when.Every(5)
    assert [i for i in range(16) if should(i)] == [0, 5, 10, 15]

  def test_once(self):
    should = elements.when.Once()
    assert [i for i in range(16) if should()] == [0]

  def test_until(self):
    should = elements.when.Until(6)
    assert [i for i in range(16) if should(i)] == [0, 1, 2, 3, 4, 5]
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
import elements
|
|
2
|
+
import pytest
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class TestFlags:
  """Tests parsing of command line flags into typed config values."""

  def test_int(self):
    flags = elements.Flags({'foo': 42})
    for text, expected in (('1', 1), ('1.0', 1), ('1e2', 100)):
      assert flags.parse([f'--foo={text}']).foo == expected
    for text in ('0.5', 'foo', '1,2,3'):
      with pytest.raises(TypeError):
        flags.parse([f'--foo={text}'])

  def test_float(self):
    flags = elements.Flags({'foo': 1.0})
    for text, expected in (('0.5', 0.5), ('1', 1.0), ('1e2', 1e2)):
      assert flags.parse([f'--foo={text}']).foo == expected
    for text in ('True', 'foo', '0.5,1.0'):
      with pytest.raises(TypeError):
        flags.parse([f'--foo={text}'])

  def test_bool(self):
    flags = elements.Flags({'foo': True})
    for text, expected in (('True', True), ('False', False)):
      assert flags.parse([f'--foo={text}']).foo is expected
    # Only the exact spellings True and False are accepted.
    for text in ('true', '1', 'foo', 'True,False'):
      with pytest.raises(TypeError):
        flags.parse([f'--foo={text}'])

  def test_str(self):
    flags = elements.Flags({'foo': 'hello'})
    for text in ('hi', '1,2,3', ''):
      assert flags.parse([f'--foo={text}']).foo == text
    # Several separate values cannot be assigned to a string key.
    with pytest.raises(TypeError):
      flags.parse(['--foo', '1', '2', '3'])

  def test_sequence(self):
    flags = elements.Flags({'foo': [1, 2]})
    accepted = (
        (['--foo=1'], (1,)),
        (['--foo=1,2,3'], (1, 2, 3)),
        (['--foo', '1,2,3'], (1, 2, 3)),
        (['--foo', '1', '2', '3'], (1, 2, 3)),
    )
    for argv, expected in accepted:
      assert flags.parse(argv).foo == expected
    for argv in (['--foo', 'False'], ['--foo=1,2,0.5']):
      with pytest.raises(TypeError):
        flags.parse(argv)

  def test_append(self):
    flags = elements.Flags({'foo': [1, 2]})
    # A trailing '+' on the flag name appends to the default tuple.
    assert flags.parse(['--foo+=3']).foo == (1, 2, 3)
    assert flags.parse(['--foo+', '3']).foo == (1, 2, 3)
    with pytest.raises(TypeError):
      flags.parse(['--foo+', '0.5'])
    assert flags.parse(['--foo=1', '--foo+=2', '--foo+=3']).foo == (1, 2, 3)
    assert flags.parse(['--foo+=3', '--foo+=4']).foo == (1, 2, 3, 4)

  def test_nested(self):
    flags = elements.Flags({'foo.bar': 12})
    assert flags.parse(['--foo.bar=42']).foo.bar == 42
    # Neither the parent key nor an unknown sibling can be set.
    for argv in (['--foo=42'], ['--foo.baz=42']):
      with pytest.raises(KeyError):
        flags.parse(argv)
    flags = elements.Flags({'foo': {'bar': 12}})
    assert flags.parse(['--foo.bar=42']).foo.bar == 42

  def test_regex(self):
    flags = elements.Flags({'foo.bar': 12, 'baz': 'text'})
    assert flags.parse([r'--.*\.bar=42']).foo.bar == 42
    assert flags.parse([r'--.*z$=hello']).baz == 'hello'
    # One pattern may match keys of different types; each is converted.
    parsed = flags.parse([r'--.*=42'])
    assert (parsed.foo.bar, parsed.baz) == (42, '42')
    with pytest.raises(TypeError):
      flags.parse([r'--.*\.bar=0.5'])

  def test_kwargs(self):
    assert elements.Flags(foo=42).parse(['--foo=12']).foo == 12
    assert elements.Flags(foo=42, bar='text').parse(['--bar=i']).bar == 'i'

  def test_multiple(self):
    defaults = elements.Config(foo=12, bar=0.5, baz='text')
    flags = elements.Flags(defaults)
    assert flags.parse([]) == defaults
    expected = defaults.update(foo=1, bar=2.5)
    assert flags.parse(['--bar', '2.5', '--foo=1']) == expected
    # A flag without any value leaves unparsed arguments behind.
    with pytest.raises(ValueError):
      flags.parse(['--bar', '--foo=1'])

  def test_invalid(self):
    flags = elements.Flags(foo=12)
    with pytest.raises(KeyError):
      flags.parse(['--bar=1'])
    for argv in (['foo=1'], ['1', '--foo=5']):
      with pytest.raises(ValueError):
        flags.parse(argv)

  def test_from_yaml(self, tmpdir):
    filename = elements.Path(tmpdir) / 'defaults.yaml'
    filename.write("""
    foo: 42
    parent.child: 12
    seq: [1, 2, 3]
    scope:
      inside: foo
    """)
    flags = elements.Flags(elements.Config.load(filename))
    assert flags.parse(['--parent.child=42']).parent.child == 42
    assert flags.parse(['--seq=2,4,6']).seq == (2, 4, 6)
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import elements
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class TestPath:
  """Tests string canonicalization and component accessors of Path."""

  def test_str_canonical(self):
    for text in ('/', 'foo/bar', 'file.txt', '/bar.tar.gz'):
      assert str(elements.Path(text)) == text

  def test_parent_and_name(self):
    texts = ('foo/bar', '/bar.tar.gz', 'file.txt', 'foo/bar/baz')
    for path in map(elements.Path, texts):
      assert path == path.parent / path.name

  def test_stem_and_suffix(self):
    texts = ('foo/bar', '/bar.tar.gz', 'file.txt', 'foo/bar/baz')
    for path in map(elements.Path, texts):
      assert path.name == path.stem + path.suffix

  def test_leading_dot(self):
    cases = {'': '.', '.': '.', './': '.', './foo': 'foo'}
    for given, expected in cases.items():
      assert str(elements.Path(given)) == expected

  def test_trailing_slash(self):
    cases = {'./': '.', 'a/': 'a', 'foo/bar/': 'foo/bar'}
    for given, expected in cases.items():
      assert str(elements.Path(given)) == expected

  # Disabled test for protocol prefixes:
  # @pytest.mark.filterwarnings('ignore::DeprecationWarning')
  # def test_protocols(self):
  #   assert str(elements.Path('gs://')) == ('gs://')
  #   assert str(elements.Path('gs://foo/bar')) == 'gs://foo/bar'

  def test_parent(self):
    empty, root = elements.Path('.'), elements.Path('/')
    assert (root / 'foo' / 'bar.txt').parent.parent == root
    assert (empty / 'foo' / 'bar.txt').parent.parent == empty
    # The parent of the root and of the current directory is itself.
    assert root.parent == root
    assert empty.parent == empty
|
elements-2.2.0/elements/agg.py
DELETED
|
@@ -1,78 +0,0 @@
|
|
|
1
|
-
from collections import defaultdict
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
class Agg:
  """Accumulates metrics per key with one or more aggregations.

  Supported aggregations are 'avg', 'sum', 'min', 'max', 'last', 'stack' and
  'default'. 'default' resolves to 'avg' for values with at most one
  dimension and to 'last' otherwise.
  """

  def __init__(self, maxlen=int(1e6)):
    # Maximum number of values retained per 'stack' key.
    self.maxlen = maxlen
    # Running [sum, count] pairs for 'avg' keys.
    self.avgs = defaultdict(lambda: [0.0, 0.0])
    self.sums = defaultdict(float)
    # Min/max entries are seeded with the first value of a key in
    # _add_single, so the float() defaults here are never used as a start.
    self.mins = defaultdict(float)
    self.maxs = defaultdict(float)
    self.lasts = defaultdict(lambda: None)
    self.stacks = defaultdict(list)

  def add(
      self, key_or_dict, value=None, agg='default', prefix=None, nan='keep'):
    """Adds a value for one key, or every item of a mapping if value is None.

    Args:
      key_or_dict: Metric key, or a mapping from keys to values.
      value: Value for a single key, or None to add a mapping instead.
      agg: Aggregation name or sequence of names.
      prefix: Prefix applied as `prefix/key`. Only allowed for mappings.
      nan: 'keep' to aggregate NaN values, or 'ignore' to skip any value
        that contains a NaN.

    Raises:
      KeyError: If an aggregation name is unknown.
    """
    # Materialize sequences, because _add_single takes len() and iterates
    # once per key.
    aggs = (agg,) if isinstance(agg, str) else tuple(agg)
    assert nan in ('keep', 'ignore')
    if value is None:
      for key, value in dict(key_or_dict).items():
        key = f'{prefix}/{key}' if prefix else key
        self._add_single(key, value, aggs, nan)
    else:
      assert not prefix, prefix
      self._add_single(key_or_dict, value, aggs, nan)

  def result(self, reset=True, prefix=None):
    """Returns all aggregates as a dict and optionally resets the state."""
    metrics = {}
    metrics.update({k: v[0] / v[1] for k, v in self.avgs.items()})
    metrics.update(self.sums)
    metrics.update(self.mins.items())
    metrics.update(self.maxs.items())
    metrics.update(self.lasts.items())
    metrics.update({k: np.stack(v) for k, v in self.stacks.items()})
    if prefix:
      metrics = {f'{prefix}/{k}': v for k, v in metrics.items()}
    if reset:
      self.reset()
    return metrics

  def reset(self):
    """Discards all accumulated values."""
    self.avgs.clear()
    self.sums.clear()
    self.mins.clear()
    self.maxs.clear()
    self.lasts.clear()
    self.stacks.clear()

  def _add_single(self, key, value, aggs, nan):
    value = np.asarray(value)
    # np.isnan() returns an array for non-scalar inputs, so reduce it with
    # any() to avoid an ambiguous truth value for multi-element arrays.
    if nan == 'ignore' and np.isnan(value).any():
      return
    for agg in aggs:
      name = key if len(aggs) == 1 else f'{key}/{agg}'
      if agg == 'default':
        agg = 'avg' if len(value.shape) <= 1 else 'last'
      if agg == 'avg':
        avg = self.avgs[name]
        avg[0] += value
        avg[1] += 1
      elif agg == 'sum':
        self.sums[name] += value
      elif agg == 'min':
        # Seed with a copy of the first value rather than the 0.0 default,
        # which would clamp the result. Compare element-wise so that arrays
        # work as well.
        if name in self.mins:
          self.mins[name] = np.minimum(self.mins[name], value)
        else:
          self.mins[name] = np.array(value)
      elif agg == 'max':
        if name in self.maxs:
          self.maxs[name] = np.maximum(self.maxs[name], value)
        else:
          self.maxs[name] = np.array(value)
      elif agg == 'last':
        self.lasts[name] = value
      elif agg == 'stack':
        stack = self.stacks[name]
        if len(stack) < self.maxlen:
          stack.append(value)
      else:
        raise KeyError(agg)
|
elements-2.2.0/requirements.txt
DELETED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|