elements 3.18.0.dev1__tar.gz → 3.18.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {elements-3.18.0.dev1/elements.egg-info → elements-3.18.2}/PKG-INFO +1 -1
  2. {elements-3.18.0.dev1 → elements-3.18.2}/elements/__init__.py +2 -1
  3. {elements-3.18.0.dev1 → elements-3.18.2}/elements/checkpoint.py +26 -22
  4. {elements-3.18.0.dev1 → elements-3.18.2/elements.egg-info}/PKG-INFO +1 -1
  5. {elements-3.18.0.dev1 → elements-3.18.2}/tests/test_checkpoint.py +49 -36
  6. {elements-3.18.0.dev1 → elements-3.18.2}/LICENSE +0 -0
  7. {elements-3.18.0.dev1 → elements-3.18.2}/MANIFEST.in +0 -0
  8. {elements-3.18.0.dev1 → elements-3.18.2}/README.md +0 -0
  9. {elements-3.18.0.dev1 → elements-3.18.2}/elements/agg.py +0 -0
  10. {elements-3.18.0.dev1 → elements-3.18.2}/elements/config.py +0 -0
  11. {elements-3.18.0.dev1 → elements-3.18.2}/elements/counter.py +0 -0
  12. {elements-3.18.0.dev1 → elements-3.18.2}/elements/flags.py +0 -0
  13. {elements-3.18.0.dev1 → elements-3.18.2}/elements/fps.py +0 -0
  14. {elements-3.18.0.dev1 → elements-3.18.2}/elements/logger.py +0 -0
  15. {elements-3.18.0.dev1 → elements-3.18.2}/elements/path.py +0 -0
  16. {elements-3.18.0.dev1 → elements-3.18.2}/elements/plotting.py +0 -0
  17. {elements-3.18.0.dev1 → elements-3.18.2}/elements/printing.py +0 -0
  18. {elements-3.18.0.dev1 → elements-3.18.2}/elements/rwlock.py +0 -0
  19. {elements-3.18.0.dev1 → elements-3.18.2}/elements/space.py +0 -0
  20. {elements-3.18.0.dev1 → elements-3.18.2}/elements/timer.py +0 -0
  21. {elements-3.18.0.dev1 → elements-3.18.2}/elements/tree.py +0 -0
  22. {elements-3.18.0.dev1 → elements-3.18.2}/elements/usage.py +0 -0
  23. {elements-3.18.0.dev1 → elements-3.18.2}/elements/utils.py +0 -0
  24. {elements-3.18.0.dev1 → elements-3.18.2}/elements/uuid.py +0 -0
  25. {elements-3.18.0.dev1 → elements-3.18.2}/elements/when.py +0 -0
  26. {elements-3.18.0.dev1 → elements-3.18.2}/elements.egg-info/SOURCES.txt +0 -0
  27. {elements-3.18.0.dev1 → elements-3.18.2}/elements.egg-info/dependency_links.txt +0 -0
  28. {elements-3.18.0.dev1 → elements-3.18.2}/elements.egg-info/requires.txt +0 -0
  29. {elements-3.18.0.dev1 → elements-3.18.2}/elements.egg-info/top_level.txt +0 -0
  30. {elements-3.18.0.dev1 → elements-3.18.2}/pyproject.toml +0 -0
  31. {elements-3.18.0.dev1 → elements-3.18.2}/requirements-optional.txt +0 -0
  32. {elements-3.18.0.dev1 → elements-3.18.2}/requirements.txt +0 -0
  33. {elements-3.18.0.dev1 → elements-3.18.2}/setup.cfg +0 -0
  34. {elements-3.18.0.dev1 → elements-3.18.2}/setup.py +0 -0
  35. {elements-3.18.0.dev1 → elements-3.18.2}/tests/test_basics.py +0 -0
  36. {elements-3.18.0.dev1 → elements-3.18.2}/tests/test_flags.py +0 -0
  37. {elements-3.18.0.dev1 → elements-3.18.2}/tests/test_path.py +0 -0
  38. {elements-3.18.0.dev1 → elements-3.18.2}/tests/test_tree.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: elements
3
- Version: 3.18.0.dev1
3
+ Version: 3.18.2
4
4
  Summary: Building blocks for productive research.
5
5
  Home-page: http://github.com/danijar/elements
6
6
  Classifier: Intended Audience :: Science/Research
@@ -1,4 +1,4 @@
1
- __version__ = '3.18.0.dev1'
1
+ __version__ = '3.18.2'
2
2
 
3
3
  from .agg import Agg
4
4
  from .checkpoint import Checkpoint, Saveable
@@ -17,6 +17,7 @@ from .usage import Usage
17
17
  from .utils import timestamp
18
18
  from .uuid import UUID
19
19
 
20
+ from . import checkpoint
20
21
  from . import logger
21
22
  from . import plotting
22
23
  from . import timer
@@ -92,10 +92,10 @@ class Checkpoint:
92
92
  def save(self, path=None, keys=None):
93
93
  assert self._directory or path
94
94
  if keys is None:
95
- saveables = self._saveables
95
+ savefns = {k: v.save for k, v in self._saveables.items()}
96
96
  else:
97
97
  assert all([not k.startswith('_') for k in keys]), keys
98
- saveables = {k: self._saveables[k] for k in keys}
98
+ savefns = {k: self._saveables[k].save for k in keys}
99
99
  if path:
100
100
  folder = None
101
101
  else:
@@ -104,8 +104,8 @@ class Checkpoint:
104
104
  folder += f'-{int(self._step):012d}'
105
105
  path = self._directory / folder
106
106
  printing.print_(f'Saving checkpoint: {path}')
107
- save(path, saveables, self._write)
108
- if folder:
107
+ save(path, savefns, self._write)
108
+ if folder and self._write:
109
109
  (self._directory / 'latest').write_text(folder)
110
110
  self._cleanup()
111
111
  print('Saved checkpoint.')
@@ -114,15 +114,15 @@ class Checkpoint:
114
114
  def load(self, path=None, keys=None):
115
115
  assert self._directory or path
116
116
  if keys is None:
117
- saveables = self._saveables
117
+ loadfns = {k: v.load for k, v in self._saveables.items()}
118
118
  else:
119
119
  assert all([not k.startswith('_') for k in keys]), keys
120
- saveables = {k: self._saveables[k] for k in keys}
120
+ loadfns = {k: self._saveables[k].load for k in keys}
121
121
  if not path:
122
122
  path = self.latest()
123
123
  assert path
124
124
  printing.print_(f'Loading checkpoint: {path}')
125
- load(path, saveables)
125
+ load(path, loadfns)
126
126
  print('Loaded checkpoint.')
127
127
 
128
128
  def load_or_save(self):
@@ -152,40 +152,44 @@ def exists(path):
152
152
  return (path / 'done').exists()
153
153
 
154
154
 
155
- def save(path, saveables, write=True):
155
+ def save(path, savefns, write=True):
156
156
  path = pathlib.Path(path)
157
157
  assert not exists(path), path
158
- path.mkdir(parents=True)
159
- for name, saveable in saveables.items():
158
+ write and path.mkdir(parents=True)
159
+ for name, savefn in savefns.items():
160
160
  try:
161
- data = saveable.save()
161
+ data = savefn()
162
162
  if inspect.isgenerator(data):
163
163
  for i, shard in enumerate(data):
164
- assert i < 1e7, i
164
+ assert i < 1e5, i
165
165
  if write: # Iterate even if we're not writing.
166
- buffer = pickle.dumps(shard)
167
- (path / f'{name}-{i:06d}.pkl').write_bytes(buffer)
166
+ with timer.section('checkpoint_pickle'):
167
+ buffer = pickle.dumps(shard)
168
+ with timer.section('checkpoint_write'):
169
+ (path / f'{name}-{i:04d}.pkl').write_bytes(buffer)
168
170
  else:
169
171
  if write:
170
- buffer = pickle.dumps(data)
171
- (path / f'{name}.pkl').write_bytes(buffer)
172
+ with timer.section('checkpoint_pickle'):
173
+ buffer = pickle.dumps(data)
174
+ with timer.section('checkpoint_write'):
175
+ (path / f'{name}.pkl').write_bytes(buffer)
172
176
  except Exception:
173
177
  print(f"Error save '{name}' to checkpoint.")
174
178
  raise
175
- (path / 'done').write_bytes(b'')
179
+ write and (path / 'done').write_bytes(b'')
176
180
 
177
181
 
178
- def load(path, saveables):
182
+ def load(path, loadfns):
179
183
  path = pathlib.Path(path)
180
184
  assert exists(path), path
181
185
  filenames = set(path.glob('*'))
182
- for name, saveable in saveables.items():
186
+ for name, loadfn in loadfns.items():
183
187
  try:
184
188
  if (path / f'{name}.pkl') in filenames:
185
189
  buffer = (path / f'{name}.pkl').read_bytes()
186
190
  data = pickle.loads(buffer)
187
- saveable.load(data)
188
- elif (path / f'{name}-000000.pkl') in filenames:
191
+ loadfn(data)
192
+ elif (path / f'{name}-0000.pkl') in filenames:
189
193
  shards = [x for x in filenames if x.name.startswith(f'{name}-')]
190
194
  shards = sorted(shards)
191
195
  def generator():
@@ -193,7 +197,7 @@ def load(path, saveables):
193
197
  buffer = filename.read_bytes()
194
198
  data = pickle.loads(buffer)
195
199
  yield data
196
- saveable.load(generator())
200
+ loadfn(generator())
197
201
  else:
198
202
  raise KeyError(name)
199
203
  except Exception:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: elements
3
- Version: 3.18.0.dev1
3
+ Version: 3.18.2
4
4
  Summary: Building blocks for productive research.
5
5
  Home-page: http://github.com/danijar/elements
6
6
  Classifier: Intended Audience :: Science/Research
@@ -6,7 +6,7 @@ class TestCheckpoint:
6
6
 
7
7
  def test_basic(self, tmpdir):
8
8
  path = elements.Path(tmpdir)
9
- foo = Foo(42)
9
+ foo = SaveableMock(42)
10
10
  cp = elements.Checkpoint(path)
11
11
  cp.foo = foo
12
12
  foo.value = 12
@@ -15,7 +15,7 @@ class TestCheckpoint:
15
15
  filenames = set(x.name for x in cp.latest().glob('*'))
16
16
  assert filenames == {'foo.pkl', 'done'}
17
17
  del cp
18
- foo = Foo(42)
18
+ foo = SaveableMock(42)
19
19
  cp = elements.Checkpoint(tmpdir)
20
20
  cp.foo = foo
21
21
  cp.load()
@@ -24,7 +24,7 @@ class TestCheckpoint:
24
24
  def test_load_or_save(self, tmpdir):
25
25
  path = elements.Path(tmpdir)
26
26
  for restart in range(3):
27
- foo = Foo(42)
27
+ foo = SaveableMock(42)
28
28
  cp = elements.Checkpoint(path, keep=3)
29
29
  cp.foo = foo
30
30
  cp.load_or_save()
@@ -35,7 +35,7 @@ class TestCheckpoint:
35
35
  def test_keep(self, tmpdir, keep=3):
36
36
  path = elements.Path(tmpdir)
37
37
  cp = elements.Checkpoint(path, keep=keep)
38
- cp.foo = Foo(0)
38
+ cp.foo = SaveableMock(0)
39
39
  for i in range(1, 6):
40
40
  cp.foo.value = i
41
41
  cp.save()
@@ -49,7 +49,7 @@ class TestCheckpoint:
49
49
  path = elements.Path(tmpdir)
50
50
  step = elements.Counter(0)
51
51
  cp = elements.Checkpoint(path, step=step, keep=3)
52
- cp.foo = Foo(0)
52
+ cp.foo = SaveableMock(0)
53
53
  for _ in range(5):
54
54
  cp.foo.value = int(step)
55
55
  cp.save()
@@ -60,33 +60,22 @@ class TestCheckpoint:
60
60
  assert steps == {2, 3, 4}
61
61
 
62
62
  def test_generator(self, tmpdir):
63
-
64
- class Bar:
65
- def __init__(self, values):
66
- self.values = values
67
- def save(self):
68
- for value in self.values:
69
- yield {'value': value}
70
- def load(self, data):
71
- for i, shard in enumerate(data):
72
- self.values[i] = shard['value']
73
-
74
63
  path = elements.Path(tmpdir)
75
64
  cp = elements.Checkpoint(path)
76
- cp.bar = Bar([42, 12, 26])
65
+ cp.foo = GeneratorMock([42, 12, 26])
77
66
  cp.save()
78
67
  filenames = set(x.name for x in cp.latest().glob('*'))
79
68
  assert filenames == {
80
- 'bar-000000.pkl',
81
- 'bar-000001.pkl',
82
- 'bar-000002.pkl',
69
+ 'foo-0000.pkl',
70
+ 'foo-0001.pkl',
71
+ 'foo-0002.pkl',
83
72
  'done',
84
73
  }
85
74
  del cp
86
75
  cp = elements.Checkpoint(path)
87
- cp.bar = Bar([0, 0, 0])
76
+ cp.foo = GeneratorMock([0, 0, 0])
88
77
  cp.load()
89
- assert cp.bar.values == [42, 12, 26]
78
+ assert cp.foo.values == [42, 12, 26]
90
79
 
91
80
  def test_saveable_inline(self, tmpdir):
92
81
  path = elements.Path(tmpdir)
@@ -101,25 +90,27 @@ class TestCheckpoint:
101
90
  assert foo == [42]
102
91
 
103
92
  def test_saveable_inherit(self, tmpdir):
104
-
105
- class Bar(elements.Saveable):
106
- def __init__(self, value):
107
- super().__init__(['value'])
108
- self.value = value
109
-
110
93
  path = elements.Path(tmpdir)
111
- bar = Bar(42)
94
+ foo = SubclassMock(42)
112
95
  cp = elements.Checkpoint(path)
113
- cp.bar = bar
96
+ cp.foo = foo
114
97
  cp.save()
115
- bar.value = 12
98
+ foo.value = 12
116
99
  cp.load()
117
- assert bar.value == 42
100
+ assert foo.value == 42
101
+
102
+ def test_write(self, tmpdir):
103
+ path = elements.Path(tmpdir)
104
+ cp = elements.Checkpoint(path, write=False)
105
+ cp.foo = SaveableMock(42)
106
+ cp.bar = GeneratorMock([1, 2, 3])
107
+ cp.save()
108
+ assert list(path.glob('*')) == []
118
109
 
119
110
  def test_path(self, tmpdir):
120
111
  path = elements.Path(tmpdir)
121
112
  cp = elements.Checkpoint()
122
- cp.foo = Foo(42)
113
+ cp.foo = SaveableMock(42)
123
114
  cp.save(path / 'inner')
124
115
  assert set(path.glob('*')) == {path / 'inner'}
125
116
  cp.foo.value = 12
@@ -129,8 +120,8 @@ class TestCheckpoint:
129
120
  def test_keys(self, tmpdir):
130
121
  path = elements.Path(tmpdir)
131
122
  cp = elements.Checkpoint(path)
132
- cp.foo = Foo(42)
133
- cp.bar = Foo(12)
123
+ cp.foo = SaveableMock(42)
124
+ cp.bar = SaveableMock(12)
134
125
  cp.save(keys=['bar'])
135
126
  filenames = set(x.name for x in cp.latest().glob('*'))
136
127
  assert filenames == {'bar.pkl', 'done'}
@@ -143,7 +134,7 @@ class TestCheckpoint:
143
134
  cp.load()
144
135
 
145
136
 
146
- class Foo:
137
+ class SaveableMock:
147
138
 
148
139
  def __init__(self, value):
149
140
  self.value = value
@@ -153,3 +144,25 @@ class Foo:
153
144
 
154
145
  def load(self, data):
155
146
  self.value = data['value']
147
+
148
+
149
+ class SubclassMock(elements.Saveable):
150
+
151
+ def __init__(self, value):
152
+ super().__init__(['value'])
153
+ self.value = value
154
+
155
+
156
+ class GeneratorMock:
157
+
158
+ def __init__(self, values):
159
+ self.values = values
160
+
161
+ def save(self):
162
+ for value in self.values:
163
+ shard = {'value': value}
164
+ yield shard
165
+
166
+ def load(self, data):
167
+ for i, shard in enumerate(data):
168
+ self.values[i] = shard['value']
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes