elements 3.18.0.dev1__tar.gz → 3.18.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {elements-3.18.0.dev1/elements.egg-info → elements-3.18.2}/PKG-INFO +1 -1
- {elements-3.18.0.dev1 → elements-3.18.2}/elements/__init__.py +2 -1
- {elements-3.18.0.dev1 → elements-3.18.2}/elements/checkpoint.py +26 -22
- {elements-3.18.0.dev1 → elements-3.18.2/elements.egg-info}/PKG-INFO +1 -1
- {elements-3.18.0.dev1 → elements-3.18.2}/tests/test_checkpoint.py +49 -36
- {elements-3.18.0.dev1 → elements-3.18.2}/LICENSE +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/MANIFEST.in +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/README.md +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/elements/agg.py +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/elements/config.py +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/elements/counter.py +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/elements/flags.py +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/elements/fps.py +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/elements/logger.py +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/elements/path.py +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/elements/plotting.py +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/elements/printing.py +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/elements/rwlock.py +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/elements/space.py +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/elements/timer.py +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/elements/tree.py +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/elements/usage.py +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/elements/utils.py +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/elements/uuid.py +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/elements/when.py +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/elements.egg-info/SOURCES.txt +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/elements.egg-info/dependency_links.txt +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/elements.egg-info/requires.txt +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/elements.egg-info/top_level.txt +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/pyproject.toml +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/requirements-optional.txt +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/requirements.txt +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/setup.cfg +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/setup.py +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/tests/test_basics.py +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/tests/test_flags.py +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/tests/test_path.py +0 -0
- {elements-3.18.0.dev1 → elements-3.18.2}/tests/test_tree.py +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
__version__ = '3.18.
|
|
1
|
+
__version__ = '3.18.2'
|
|
2
2
|
|
|
3
3
|
from .agg import Agg
|
|
4
4
|
from .checkpoint import Checkpoint, Saveable
|
|
@@ -17,6 +17,7 @@ from .usage import Usage
|
|
|
17
17
|
from .utils import timestamp
|
|
18
18
|
from .uuid import UUID
|
|
19
19
|
|
|
20
|
+
from . import checkpoint
|
|
20
21
|
from . import logger
|
|
21
22
|
from . import plotting
|
|
22
23
|
from . import timer
|
|
@@ -92,10 +92,10 @@ class Checkpoint:
|
|
|
92
92
|
def save(self, path=None, keys=None):
|
|
93
93
|
assert self._directory or path
|
|
94
94
|
if keys is None:
|
|
95
|
-
|
|
95
|
+
savefns = {k: v.save for k, v in self._saveables.items()}
|
|
96
96
|
else:
|
|
97
97
|
assert all([not k.startswith('_') for k in keys]), keys
|
|
98
|
-
|
|
98
|
+
savefns = {k: self._saveables[k].save for k in keys}
|
|
99
99
|
if path:
|
|
100
100
|
folder = None
|
|
101
101
|
else:
|
|
@@ -104,8 +104,8 @@ class Checkpoint:
|
|
|
104
104
|
folder += f'-{int(self._step):012d}'
|
|
105
105
|
path = self._directory / folder
|
|
106
106
|
printing.print_(f'Saving checkpoint: {path}')
|
|
107
|
-
save(path,
|
|
108
|
-
if folder:
|
|
107
|
+
save(path, savefns, self._write)
|
|
108
|
+
if folder and self._write:
|
|
109
109
|
(self._directory / 'latest').write_text(folder)
|
|
110
110
|
self._cleanup()
|
|
111
111
|
print('Saved checkpoint.')
|
|
@@ -114,15 +114,15 @@ class Checkpoint:
|
|
|
114
114
|
def load(self, path=None, keys=None):
|
|
115
115
|
assert self._directory or path
|
|
116
116
|
if keys is None:
|
|
117
|
-
|
|
117
|
+
loadfns = {k: v.load for k, v in self._saveables.items()}
|
|
118
118
|
else:
|
|
119
119
|
assert all([not k.startswith('_') for k in keys]), keys
|
|
120
|
-
|
|
120
|
+
loadfns = {k: self._saveables[k].load for k in keys}
|
|
121
121
|
if not path:
|
|
122
122
|
path = self.latest()
|
|
123
123
|
assert path
|
|
124
124
|
printing.print_(f'Loading checkpoint: {path}')
|
|
125
|
-
load(path,
|
|
125
|
+
load(path, loadfns)
|
|
126
126
|
print('Loaded checkpoint.')
|
|
127
127
|
|
|
128
128
|
def load_or_save(self):
|
|
@@ -152,40 +152,44 @@ def exists(path):
|
|
|
152
152
|
return (path / 'done').exists()
|
|
153
153
|
|
|
154
154
|
|
|
155
|
-
def save(path,
|
|
155
|
+
def save(path, savefns, write=True):
|
|
156
156
|
path = pathlib.Path(path)
|
|
157
157
|
assert not exists(path), path
|
|
158
|
-
path.mkdir(parents=True)
|
|
159
|
-
for name,
|
|
158
|
+
write and path.mkdir(parents=True)
|
|
159
|
+
for name, savefn in savefns.items():
|
|
160
160
|
try:
|
|
161
|
-
data =
|
|
161
|
+
data = savefn()
|
|
162
162
|
if inspect.isgenerator(data):
|
|
163
163
|
for i, shard in enumerate(data):
|
|
164
|
-
assert i <
|
|
164
|
+
assert i < 1e5, i
|
|
165
165
|
if write: # Iterate even if we're not writing.
|
|
166
|
-
|
|
167
|
-
|
|
166
|
+
with timer.section('checkpoint_pickle'):
|
|
167
|
+
buffer = pickle.dumps(shard)
|
|
168
|
+
with timer.section('checkpoint_write'):
|
|
169
|
+
(path / f'{name}-{i:04d}.pkl').write_bytes(buffer)
|
|
168
170
|
else:
|
|
169
171
|
if write:
|
|
170
|
-
|
|
171
|
-
|
|
172
|
+
with timer.section('checkpoint_pickle'):
|
|
173
|
+
buffer = pickle.dumps(data)
|
|
174
|
+
with timer.section('checkpoint_write'):
|
|
175
|
+
(path / f'{name}.pkl').write_bytes(buffer)
|
|
172
176
|
except Exception:
|
|
173
177
|
print(f"Error save '{name}' to checkpoint.")
|
|
174
178
|
raise
|
|
175
|
-
(path / 'done').write_bytes(b'')
|
|
179
|
+
write and (path / 'done').write_bytes(b'')
|
|
176
180
|
|
|
177
181
|
|
|
178
|
-
def load(path,
|
|
182
|
+
def load(path, loadfns):
|
|
179
183
|
path = pathlib.Path(path)
|
|
180
184
|
assert exists(path), path
|
|
181
185
|
filenames = set(path.glob('*'))
|
|
182
|
-
for name,
|
|
186
|
+
for name, loadfn in loadfns.items():
|
|
183
187
|
try:
|
|
184
188
|
if (path / f'{name}.pkl') in filenames:
|
|
185
189
|
buffer = (path / f'{name}.pkl').read_bytes()
|
|
186
190
|
data = pickle.loads(buffer)
|
|
187
|
-
|
|
188
|
-
elif (path / f'{name}-
|
|
191
|
+
loadfn(data)
|
|
192
|
+
elif (path / f'{name}-0000.pkl') in filenames:
|
|
189
193
|
shards = [x for x in filenames if x.name.startswith(f'{name}-')]
|
|
190
194
|
shards = sorted(shards)
|
|
191
195
|
def generator():
|
|
@@ -193,7 +197,7 @@ def load(path, saveables):
|
|
|
193
197
|
buffer = filename.read_bytes()
|
|
194
198
|
data = pickle.loads(buffer)
|
|
195
199
|
yield data
|
|
196
|
-
|
|
200
|
+
loadfn(generator())
|
|
197
201
|
else:
|
|
198
202
|
raise KeyError(name)
|
|
199
203
|
except Exception:
|
|
@@ -6,7 +6,7 @@ class TestCheckpoint:
|
|
|
6
6
|
|
|
7
7
|
def test_basic(self, tmpdir):
|
|
8
8
|
path = elements.Path(tmpdir)
|
|
9
|
-
foo =
|
|
9
|
+
foo = SaveableMock(42)
|
|
10
10
|
cp = elements.Checkpoint(path)
|
|
11
11
|
cp.foo = foo
|
|
12
12
|
foo.value = 12
|
|
@@ -15,7 +15,7 @@ class TestCheckpoint:
|
|
|
15
15
|
filenames = set(x.name for x in cp.latest().glob('*'))
|
|
16
16
|
assert filenames == {'foo.pkl', 'done'}
|
|
17
17
|
del cp
|
|
18
|
-
foo =
|
|
18
|
+
foo = SaveableMock(42)
|
|
19
19
|
cp = elements.Checkpoint(tmpdir)
|
|
20
20
|
cp.foo = foo
|
|
21
21
|
cp.load()
|
|
@@ -24,7 +24,7 @@ class TestCheckpoint:
|
|
|
24
24
|
def test_load_or_save(self, tmpdir):
|
|
25
25
|
path = elements.Path(tmpdir)
|
|
26
26
|
for restart in range(3):
|
|
27
|
-
foo =
|
|
27
|
+
foo = SaveableMock(42)
|
|
28
28
|
cp = elements.Checkpoint(path, keep=3)
|
|
29
29
|
cp.foo = foo
|
|
30
30
|
cp.load_or_save()
|
|
@@ -35,7 +35,7 @@ class TestCheckpoint:
|
|
|
35
35
|
def test_keep(self, tmpdir, keep=3):
|
|
36
36
|
path = elements.Path(tmpdir)
|
|
37
37
|
cp = elements.Checkpoint(path, keep=keep)
|
|
38
|
-
cp.foo =
|
|
38
|
+
cp.foo = SaveableMock(0)
|
|
39
39
|
for i in range(1, 6):
|
|
40
40
|
cp.foo.value = i
|
|
41
41
|
cp.save()
|
|
@@ -49,7 +49,7 @@ class TestCheckpoint:
|
|
|
49
49
|
path = elements.Path(tmpdir)
|
|
50
50
|
step = elements.Counter(0)
|
|
51
51
|
cp = elements.Checkpoint(path, step=step, keep=3)
|
|
52
|
-
cp.foo =
|
|
52
|
+
cp.foo = SaveableMock(0)
|
|
53
53
|
for _ in range(5):
|
|
54
54
|
cp.foo.value = int(step)
|
|
55
55
|
cp.save()
|
|
@@ -60,33 +60,22 @@ class TestCheckpoint:
|
|
|
60
60
|
assert steps == {2, 3, 4}
|
|
61
61
|
|
|
62
62
|
def test_generator(self, tmpdir):
|
|
63
|
-
|
|
64
|
-
class Bar:
|
|
65
|
-
def __init__(self, values):
|
|
66
|
-
self.values = values
|
|
67
|
-
def save(self):
|
|
68
|
-
for value in self.values:
|
|
69
|
-
yield {'value': value}
|
|
70
|
-
def load(self, data):
|
|
71
|
-
for i, shard in enumerate(data):
|
|
72
|
-
self.values[i] = shard['value']
|
|
73
|
-
|
|
74
63
|
path = elements.Path(tmpdir)
|
|
75
64
|
cp = elements.Checkpoint(path)
|
|
76
|
-
cp.
|
|
65
|
+
cp.foo = GeneratorMock([42, 12, 26])
|
|
77
66
|
cp.save()
|
|
78
67
|
filenames = set(x.name for x in cp.latest().glob('*'))
|
|
79
68
|
assert filenames == {
|
|
80
|
-
'
|
|
81
|
-
'
|
|
82
|
-
'
|
|
69
|
+
'foo-0000.pkl',
|
|
70
|
+
'foo-0001.pkl',
|
|
71
|
+
'foo-0002.pkl',
|
|
83
72
|
'done',
|
|
84
73
|
}
|
|
85
74
|
del cp
|
|
86
75
|
cp = elements.Checkpoint(path)
|
|
87
|
-
cp.
|
|
76
|
+
cp.foo = GeneratorMock([0, 0, 0])
|
|
88
77
|
cp.load()
|
|
89
|
-
assert cp.
|
|
78
|
+
assert cp.foo.values == [42, 12, 26]
|
|
90
79
|
|
|
91
80
|
def test_saveable_inline(self, tmpdir):
|
|
92
81
|
path = elements.Path(tmpdir)
|
|
@@ -101,25 +90,27 @@ class TestCheckpoint:
|
|
|
101
90
|
assert foo == [42]
|
|
102
91
|
|
|
103
92
|
def test_saveable_inherit(self, tmpdir):
|
|
104
|
-
|
|
105
|
-
class Bar(elements.Saveable):
|
|
106
|
-
def __init__(self, value):
|
|
107
|
-
super().__init__(['value'])
|
|
108
|
-
self.value = value
|
|
109
|
-
|
|
110
93
|
path = elements.Path(tmpdir)
|
|
111
|
-
|
|
94
|
+
foo = SubclassMock(42)
|
|
112
95
|
cp = elements.Checkpoint(path)
|
|
113
|
-
cp.
|
|
96
|
+
cp.foo = foo
|
|
114
97
|
cp.save()
|
|
115
|
-
|
|
98
|
+
foo.value = 12
|
|
116
99
|
cp.load()
|
|
117
|
-
assert
|
|
100
|
+
assert foo.value == 42
|
|
101
|
+
|
|
102
|
+
def test_write(self, tmpdir):
|
|
103
|
+
path = elements.Path(tmpdir)
|
|
104
|
+
cp = elements.Checkpoint(path, write=False)
|
|
105
|
+
cp.foo = SaveableMock(42)
|
|
106
|
+
cp.bar = GeneratorMock([1, 2, 3])
|
|
107
|
+
cp.save()
|
|
108
|
+
assert list(path.glob('*')) == []
|
|
118
109
|
|
|
119
110
|
def test_path(self, tmpdir):
|
|
120
111
|
path = elements.Path(tmpdir)
|
|
121
112
|
cp = elements.Checkpoint()
|
|
122
|
-
cp.foo =
|
|
113
|
+
cp.foo = SaveableMock(42)
|
|
123
114
|
cp.save(path / 'inner')
|
|
124
115
|
assert set(path.glob('*')) == {path / 'inner'}
|
|
125
116
|
cp.foo.value = 12
|
|
@@ -129,8 +120,8 @@ class TestCheckpoint:
|
|
|
129
120
|
def test_keys(self, tmpdir):
|
|
130
121
|
path = elements.Path(tmpdir)
|
|
131
122
|
cp = elements.Checkpoint(path)
|
|
132
|
-
cp.foo =
|
|
133
|
-
cp.bar =
|
|
123
|
+
cp.foo = SaveableMock(42)
|
|
124
|
+
cp.bar = SaveableMock(12)
|
|
134
125
|
cp.save(keys=['bar'])
|
|
135
126
|
filenames = set(x.name for x in cp.latest().glob('*'))
|
|
136
127
|
assert filenames == {'bar.pkl', 'done'}
|
|
@@ -143,7 +134,7 @@ class TestCheckpoint:
|
|
|
143
134
|
cp.load()
|
|
144
135
|
|
|
145
136
|
|
|
146
|
-
class
|
|
137
|
+
class SaveableMock:
|
|
147
138
|
|
|
148
139
|
def __init__(self, value):
|
|
149
140
|
self.value = value
|
|
@@ -153,3 +144,25 @@ class Foo:
|
|
|
153
144
|
|
|
154
145
|
def load(self, data):
|
|
155
146
|
self.value = data['value']
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class SubclassMock(elements.Saveable):
|
|
150
|
+
|
|
151
|
+
def __init__(self, value):
|
|
152
|
+
super().__init__(['value'])
|
|
153
|
+
self.value = value
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class GeneratorMock:
|
|
157
|
+
|
|
158
|
+
def __init__(self, values):
|
|
159
|
+
self.values = values
|
|
160
|
+
|
|
161
|
+
def save(self):
|
|
162
|
+
for value in self.values:
|
|
163
|
+
shard = {'value': value}
|
|
164
|
+
yield shard
|
|
165
|
+
|
|
166
|
+
def load(self, data):
|
|
167
|
+
for i, shard in enumerate(data):
|
|
168
|
+
self.values[i] = shard['value']
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|