scope 0.2.2__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: scope
3
- Version: 0.2.2
3
+ Version: 0.3.0
4
4
  Summary: Metrics logging and analysis
5
5
  Home-page: http://github.com/danijar/scope
6
6
  Classifier: Intended Audience :: Science/Research
@@ -0,0 +1,9 @@
1
+ __version__ = '0.3.0'
2
+
3
+ from .reader import Reader
4
+ from .writer import Writer
5
+
6
+ from .formats import table_read
7
+ from .formats import table_append
8
+
9
+ from . import formats
@@ -0,0 +1,210 @@
1
+ import io
2
+ import struct
3
+ import time
4
+
5
+ import av
6
+ import numpy as np
7
+ import PIL.Image
8
+
9
+
10
class Float:
    """Scalar format: (step, value) rows appended to a single table file."""

    # Row layout handed to the table helpers: big-endian int64 step
    # followed by a float64 value.
    _ROW = '>qd'

    @property
    def extension(self):
        """Extension used for the files of this format."""
        return 'float'

    def valid(self, x):
        """Accept zero-dimensional, real-valued inputs only."""
        if x.ndim != 0:
            return False
        return np.isreal(x)

    def create(self, path):
        """Nothing to prepare; the table file appears on the first append."""
        pass

    def write(self, path, steps, values):
        """Append the given steps and values as rows of the table at `path`."""
        table_append(path, self._ROW, steps, values)

    def read(self, path):
        """Load the whole table as an int64 step array and a float64 value array."""
        raw_steps, raw_values = table_read(path, self._ROW)
        return np.int64(raw_steps), np.float64(raw_values)

    def length(self, path):
        """Return the number of rows stored in the table at `path`."""
        return table_length(path, self._ROW)
33
+
34
+
35
class Text:
    """String format: each value is stored as its own UTF-8 encoded file."""

    @property
    def extension(self):
        """Extension used for the column directory and its files."""
        return 'txt'

    def valid(self, x):
        """Accept `str` instances only."""
        return isinstance(x, str)

    def create(self, path):
        """Create the column directory; an existing directory is fine."""
        path.mkdir(exist_ok=True)

    def write(self, path, steps, values):
        """Encode every value and store it under `path`, one file per step."""
        files_write(path, steps, values, self.encode)

    def read(self, path):
        """Return the steps and the filenames recorded in the column index."""
        return files_read(path)

    def encode(self, value):
        """Turn a string into the bytes written to disk."""
        return value.encode('utf-8')

    def decode(self, buffer):
        """Turn the bytes read from disk back into a string."""
        # The method name was misspelled here, which made every call fail
        # with AttributeError.
        return buffer.decode('utf-8')

    def length(self, path):
        """Return the number of entries recorded in the column index."""
        return files_length(path)
61
+
62
+
63
class Image:
    """Image format: each value is stored as one encoded image file."""

    def __init__(self, ext='png', quality=90):
        self.ext = ext
        self.quality = quality

    @property
    def extension(self):
        """Extension used for the column directory and its files."""
        return self.ext

    def valid(self, x):
        """Accept uint8 arrays of rank 3 whose last axis has 1 or 3 entries."""
        if not x.dtype == np.uint8:
            return False
        if not x.ndim == 3:
            return False
        return x.shape[-1] in (1, 3)

    def create(self, path):
        """Create the column directory; an existing directory is fine."""
        path.mkdir(exist_ok=True)

    def write(self, path, steps, values):
        """Encode every image and store it under `path`, one file per step."""
        files_write(path, steps, values, self.encode)

    def read(self, path):
        """Return the steps and the filenames recorded in the column index."""
        return files_read(path)

    def encode(self, value):
        """Serialize one image array into the bytes of an image file."""
        # Single-channel input is expanded to three identical channels.
        pixels = value.repeat(3, -1) if value.shape[-1] == 1 else value
        # PIL is given 'JPEG' for the 'jpg' extension; any other extension
        # is passed through upper-cased.
        codec_name = 'jpeg' if self.ext == 'jpg' else self.ext
        sink = io.BytesIO()
        picture = PIL.Image.fromarray(pixels)
        picture.save(sink, codec_name.upper(), quality=self.quality)
        return sink.getvalue()

    def decode(self, buffer):
        """Parse image file bytes into an RGB array."""
        picture = PIL.Image.open(io.BytesIO(buffer))
        return np.asarray(picture.convert('RGB'))

    def length(self, path):
        """Return the number of entries recorded in the column index."""
        return files_length(path)
101
+
102
+
103
class Video:
    """Video format: each value is stored as one encoded video file."""

    def __init__(self, ext='mp4', codec='h264', fps=10):
        self.ext = ext
        self.codec = codec
        self.fps = fps

    @property
    def extension(self):
        """Extension used for the column directory and its files."""
        return self.ext

    def valid(self, x):
        """Accept uint8 arrays of rank 4 whose last axis has 1 or 3 entries."""
        if not x.dtype == np.uint8:
            return False
        if not x.ndim == 4:
            return False
        return x.shape[-1] in (1, 3)

    def create(self, path):
        """Create the column directory; an existing directory is fine."""
        path.mkdir(exist_ok=True)

    def write(self, path, steps, values):
        """Encode every clip and store it under `path`, one file per step."""
        files_write(path, steps, values, self.encode)

    def read(self, path):
        """Return the steps and the filenames recorded in the column index."""
        return files_read(path)

    def encode(self, value):
        """Serialize one clip (frames along the first axis) into video bytes."""
        # Single-channel input is expanded to three identical channels.
        if value.shape[-1] == 1:
            value = value.repeat(3, -1)
        _, height, width, _ = value.shape
        sink = io.BytesIO()
        output = av.open(sink, mode='w', format=self.ext)
        stream = output.add_stream(self.codec, rate=float(self.fps))
        stream.width = width
        stream.height = height
        stream.pix_fmt = 'yuv420p'
        for index, image in enumerate(value):
            frame = av.VideoFrame.from_ndarray(image, format='rgb24')
            frame.pts = index
            output.mux(stream.encode(frame))
        # Encoding `None` drains the packets still buffered in the encoder.
        output.mux(stream.encode(None))
        output.close()
        return sink.getvalue()

    def decode(self, buffer):
        """Parse video bytes into a stacked array of RGB frames."""
        container = av.open(io.BytesIO(buffer))
        frames = [
            frame.to_ndarray(format='rgb24')
            for frame in container.decode(video=0)]
        clip = np.stack(frames)
        container.close()
        return clip

    def length(self, path):
        """Return the number of entries recorded in the column index."""
        return files_length(path)
158
+
159
+
160
def table_append(filename, fmt, *cols):
    """Append rows to a binary table file.

    `cols` are parallel sequences, one per field of the struct format `fmt`;
    the i-th row is built from the i-th entry of every column. The packed rows
    are appended to `filename` (a path object) in binary mode.
    """
    codec = struct.Struct(fmt)
    payload = b''.join(codec.pack(*record) for record in zip(*cols))
    with filename.open('ab') as handle:
        handle.write(payload)
168
+
169
+
170
def table_read(filename, fmt, start=0, stop=None):
    """Read rows of a binary table file and return them column-wise.

    Args:
      filename: Path object of the table file.
      fmt: Struct format describing one row.
      start: Index of the first row to read.
      stop: Index one past the last row to read, or None to read to the end.

    Returns:
      A tuple with one tuple per field of `fmt`. When no rows are read, the
      result still has one (empty) tuple per field, so callers can always
      unpack it as `steps, values = table_read(...)`.

    Raises:
      AssertionError: If `stop` is given and not greater than `start`.
      struct.error: If the bytes read are not a whole number of rows.
    """
    assert stop is None or start < stop, (start, stop)
    codec = struct.Struct(fmt)
    if start == 0 and stop is None:
        buffer = filename.read_bytes()
    else:
        with filename.open('rb') as f:
            if start:
                f.seek(start * codec.size)
            count = -1 if stop is None else (stop - start) * codec.size
            buffer = f.read(count)
    rows = list(codec.iter_unpack(buffer))
    if not rows:
        # Transposing zero rows would yield `()` and lose the column count,
        # so derive it from the format by unpacking one all-zero row.
        width = len(codec.unpack(bytes(codec.size)))
        return tuple(() for _ in range(width))
    return tuple(zip(*rows))
182
+
183
+
184
def table_length(filename, fmt):
    """Return how many whole rows of struct format `fmt` fit in `filename`."""
    total_bytes = filename.stat().st_size
    row_bytes = struct.calcsize(fmt)
    return total_bytes // row_bytes
186
+
187
+
188
def files_write(path, steps, values, encode):
    """Store each value as its own file under `path` and record it in the index.

    Every entry gets an 8-byte identifier: the current Unix time as 4
    big-endian bytes followed by 4 random bytes. The file is named from the
    zero-padded step, the identifier in hex, and the suffix of `path`. After
    all files are written, the (step, identifier) pairs are appended to the
    `index` table inside `path`.
    """
    generator = np.random.default_rng(seed=None)
    stamp = int(time.time()).to_bytes(4, 'big')
    count = len(steps)
    idents = [stamp + generator.bytes(4) for _ in range(count)]
    for ident, step, value in zip(idents, steps, values):
        target = path / f'{step:020}-{ident.hex()}{path.suffix}'
        target.write_bytes(encode(value))
    # The index is written last, after every file it refers to.
    table_append(path / 'index', 'q8s', steps, idents)
198
+
199
+
200
def files_read(path):
    """Return the steps and file paths recorded in the index of `path`.

    Only the index is read; the files themselves are not opened. The returned
    paths are built by joining `path` with each entry's filename.
    """
    steps, idents = table_read(path / 'index', 'q8s')
    suffix = path.suffix
    filenames = []
    for step, ident in zip(steps, idents):
        filenames.append(path / f'{step:020}-{ident.hex()}{suffix}')
    return np.int64(steps), filenames
207
+
208
+
209
def files_length(path):
    """Return the number of entries recorded in the index table of `path`."""
    index_file = path / 'index'
    return table_length(index_file, 'q8s')
@@ -0,0 +1,44 @@
1
+ import re
2
+ import pathlib
3
+
4
+ from . import formats
5
+
6
+
7
# Format handlers a Reader uses unless the caller supplies its own; the
# Reader looks them up by their `extension`.
FORMATS = [formats.Text(), formats.Float(), formats.Image(), formats.Video()]
13
+
14
+
15
class Reader:
    """Read access to the columns stored in a log directory."""

    def __init__(self, logdir, formats=FORMATS):
        """Discover the columns stored directly inside `logdir`.

        Args:
          logdir: Log directory as a string or path object.
          formats: Format handlers; each is matched to directory entries by
            its `extension`.
        """
        if isinstance(logdir, str):
            logdir = pathlib.Path(logdir)
        self.logdir = logdir
        self.fmts = {x.extension: x for x in formats}
        self.cols = {}
        for child in sorted(logdir.glob('*')):
            basename, dot, ext = child.name.rpartition('.')
            if not dot or not basename or ext not in self.fmts:
                # Entries that do not look like `<name>.<known extension>`
                # (stray or hidden files) are not columns; skip them instead
                # of failing the whole reader.
                continue
            # Column files are named after the key with '/' replaced by '-'.
            key = basename.replace('-', '/')
            # NOTE: `re.match` anchors only at the start, so this checks the
            # beginning of the key rather than the whole key.
            assert re.match(r'[a-z0-9_]+(/[a-z0-9_]+)?', key), key
            self.cols[key] = (child.name, self.fmts[ext])

    def keys(self):
        """Return the discovered keys as a tuple, in sorted entry order."""
        return tuple(self.cols.keys())

    def __getitem__(self, key):
        """Return whatever the column's format `read` returns for `key`."""
        name, fmt = self.cols[key]
        return fmt.read(self.logdir / name)

    def length(self, key):
        """Return the number of entries stored for `key`."""
        name, fmt = self.cols[key]
        return fmt.length(self.logdir / name)

    def load(self, key, filename):
        """Read one stored file and decode it with the format of `key`.

        `filename` may be a path as returned by `reader[key]`, which already
        contains the log directory, or a path relative to the log directory.
        """
        _, fmt = self.cols[key]
        path = pathlib.Path(filename)
        # Paths coming from `__getitem__` are already prefixed with the log
        # directory; joining it again would duplicate the prefix whenever the
        # log directory is relative.
        if not path.is_absolute() and self.logdir not in path.parents:
            path = self.logdir / path
        buffer = path.read_bytes()
        value = fmt.decode(buffer)
        return value
@@ -0,0 +1,96 @@
1
+ import concurrent.futures
2
+ import dataclasses
3
+ import pathlib
4
+ import re
5
+
6
+ import numpy as np
7
+
8
+ from . import formats
9
+
10
+
11
# Default format handlers. Order matters: a Writer assigns a new key to the
# first handler whose `valid` accepts the value.
FORMATS = [formats.Text(), formats.Float(), formats.Image(), formats.Video()]
17
+
18
+
19
@dataclasses.dataclass
class Column:
    """Per-key state of a Writer: the format plus the values not yet flushed."""

    # Format handler of the column. The Writer stores a handler object here
    # (something with `valid`, `create`, `write`, `extension`), not a string.
    fmt: object
    # Name of the column's file or directory inside the log directory.
    name: str
    # Whether the format's `create` has already been called for this column.
    created: bool
    # Steps and values buffered since the last flush; parallel lists.
    steps: list
    values: list
27
+
28
+
29
class Writer:
    """Buffers values per key and writes them into a log directory on flush."""

    def __init__(self, logdir, fps=20, workers=8, formats=FORMATS):
        """Create the log directory and, if requested, a worker pool.

        Args:
          logdir: Log directory as a string or path object; created with
            parents if missing.
          fps: Stored on the instance; not otherwise used by this class.
          workers: Number of threads used for writing; a falsy value makes
            `flush` write synchronously.
          formats: Format handlers tried in order when a new key is added.
        """
        if isinstance(logdir, str):
            logdir = pathlib.Path(logdir)
        self.logdir = logdir
        self.logdir.mkdir(parents=True, exist_ok=True)
        self.fps = fps
        self.workers = workers
        self.rng = np.random.default_rng(seed=None)
        # Use the handlers the caller passed in rather than always falling
        # back to the module default.
        self.fmts = formats
        self.cols = {}
        if workers:
            self.pool = concurrent.futures.ThreadPoolExecutor(workers, 'scope')
            self.futures = []

    def add(self, step, *args, **kwargs):
        """Buffer one value per key for the given step.

        The positional and keyword arguments are combined with `dict(...)`
        into a key-to-value mapping. Values that are not strings are
        converted with `np.asarray`.

        Raises:
          AssertionError: If `step` is not an integer or a new key fails the
            key pattern check.
          NotImplementedError: If no format accepts the value of a new key.
          ValueError: If a value does not fit the format chosen for its key.
        """
        assert isinstance(step, (int, np.integer)), type(step)
        step = int(step)
        mapping = dict(*args, **kwargs)
        for key, value in mapping.items():
            if not isinstance(value, str):
                value = np.asarray(value)
            if key not in self.cols:
                # NOTE: `re.match` anchors only at the start, so this checks
                # the beginning of the key rather than the whole key.
                assert re.match(r'[a-z0-9_]+(/[a-z0-9_]+)?', key), key
                # The first format that accepts the value is fixed for the key.
                for fmt in self.fmts:
                    if fmt.valid(value):
                        break
                else:
                    raise NotImplementedError(
                        f"No format supports key '{key}' with {self._info(value)}")
                name = key.replace('/', '-') + '.' + fmt.extension
                self.cols[key] = Column(fmt, name, False, [], [])
            col = self.cols[key]
            if not col.fmt.valid(value):
                raise ValueError(
                    f"Key '{key}' contains invalid value {self._info(value)}")
            col.steps.append(step)
            col.values.append(value)

    def flush(self):
        """Write all buffered values and reset the buffers.

        With workers, the previous batch is awaited first (re-raising any
        error from it) and the new batch is only submitted, so it may still
        be in progress when this method returns.
        """
        jobs = [(c, c.steps, c.values) for c in self.cols.values() if c.steps]
        if self.workers:
            # Consuming the previous result iterator waits for those writes.
            list(self.futures)
            self.futures = self.pool.map(self._write, *zip(*jobs)) if jobs else []
        else:
            for job in jobs:
                self._write(*job)
        for col in self.cols.values():
            # Rebind to fresh lists instead of clearing in place: submitted
            # jobs still hold references to the old lists.
            col.steps = []
            col.values = []

    def _write(self, col, steps, values):
        """Write one column's batch, creating the column on first use."""
        try:
            path = self.logdir / col.name
            if not col.created:
                col.fmt.create(path)
                col.created = True
            col.fmt.write(path, steps, values)
        except Exception:
            print(f"Exception writing '{col.name}' column")
            raise

    def _info(self, value):
        """Describe a value by dtype and shape, or by type, for error messages."""
        if hasattr(value, 'dtype') and hasattr(value, 'shape'):
            return f"dtype '{value.dtype}' and shape '{value.shape}'"
        return f"type '{type(value)}'"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: scope
3
- Version: 0.2.2
3
+ Version: 0.3.0
4
4
  Summary: Metrics logging and analysis
5
5
  Home-page: http://github.com/danijar/scope
6
6
  Classifier: Intended Audience :: Science/Research
@@ -4,7 +4,7 @@ pyproject.toml
4
4
  requirements.txt
5
5
  setup.py
6
6
  scope/__init__.py
7
- scope/columns.py
7
+ scope/formats.py
8
8
  scope/reader.py
9
9
  scope/writer.py
10
10
  scope.egg-info/PKG-INFO
@@ -22,21 +22,6 @@ class TestFloat:
22
22
  assert equal(reader['foo'], ([0, 5], [12, 42]), (np.int64, np.float64))
23
23
  assert equal(reader['bar'], ([5], [np.pi]), (np.int64, np.float64))
24
24
 
25
- def test_slicing(self, tmpdir):
26
- logdir = pathlib.Path(tmpdir)
27
- writer = scope.Writer(logdir, workers=0)
28
- writer.add(0, {'foo': 12})
29
- writer.add(5, {'foo': 42})
30
- writer.flush()
31
- reader = scope.Reader(logdir)
32
- assert equal(reader['foo', 0], ([0], [12]))
33
- assert equal(reader['foo', :2], ([0], [12]))
34
- assert equal(reader['foo', :5], ([0], [12]))
35
- assert equal(reader['foo', :6], ([0, 5], [12, 42]))
36
- assert equal(reader['foo', 1:6], ([5], [42]))
37
- assert equal(reader['foo', :-1], ([], []))
38
- assert equal(reader['foo', 7:], ([], []))
39
-
40
25
  def test_workers(self, tmpdir):
41
26
  logdir = pathlib.Path(tmpdir)
42
27
  writer = scope.Writer(logdir, workers=8)
@@ -62,6 +47,21 @@ class TestFloat:
62
47
  assert reader.length('foo/bar') == 1
63
48
  assert equal(reader['foo/bar'], ([0], [12]), (np.int64, np.float64))
64
49
 
50
+ # def test_slicing(self, tmpdir):
51
+ # logdir = pathlib.Path(tmpdir)
52
+ # writer = scope.Writer(logdir, workers=0)
53
+ # writer.add(0, {'foo': 12})
54
+ # writer.add(5, {'foo': 42})
55
+ # writer.flush()
56
+ # reader = scope.Reader(logdir)
57
+ # assert equal(reader['foo', 0], ([0], [12]))
58
+ # assert equal(reader['foo', :2], ([0], [12]))
59
+ # assert equal(reader['foo', :5], ([0], [12]))
60
+ # assert equal(reader['foo', :6], ([0, 5], [12, 42]))
61
+ # assert equal(reader['foo', 1:6], ([5], [42]))
62
+ # assert equal(reader['foo', :-1], ([], []))
63
+ # assert equal(reader['foo', 7:], ([], []))
64
+
65
65
 
66
66
  def equal(actuals, references, dtypes=None):
67
67
  dtypes = dtypes or [x.dtype for x in actuals]
@@ -20,32 +20,10 @@ class TestImage:
20
20
  reader = scope.Reader(logdir)
21
21
  assert reader.keys() == ('foo',)
22
22
  assert reader.length('foo') == 2
23
- steps, values = reader['foo']
23
+ steps, filenames = reader['foo']
24
+ values = [reader.load('foo', x) for x in filenames]
24
25
  assert (steps == np.array([0, 5])).all()
25
- assert (values == np.array([img1, img2])).all()
26
-
27
- def test_slicing(self, tmpdir):
28
- logdir = pathlib.Path(tmpdir)
29
- writer = scope.Writer(logdir, workers=0)
30
- img1 = np.ones((64, 128, 3), np.uint8) + 12
31
- img2 = np.ones((64, 128, 3), np.uint8) + 255
32
- writer.add(0, {'foo': img1})
33
- writer.add(5, {'foo': img2})
34
- writer.flush()
35
- assert {x.name for x in logdir.glob('*')} == {'foo.png'}
36
- assert (logdir / 'foo.png' / 'index').stat().st_size == (8 + 8) * 2
37
- reader = scope.Reader(logdir)
38
- assert reader.keys() == ('foo',)
39
- assert reader.length('foo') == 2
40
- steps, values = reader['foo']
41
- assert (steps == np.array([0, 5])).all()
42
- assert (values == np.array([img1, img2])).all()
43
- assert (reader['foo', 0][1] == img1[None]).all()
44
- assert (reader['foo', :5][1] == img1[None]).all()
45
- assert (reader['foo', :6][1] == np.array([img1, img2])).all()
46
- assert (reader['foo', 1:6][1] == img2[None]).all()
47
- assert reader['foo', :-1][1] == ()
48
- assert reader['foo', 6:][1] == ()
26
+ assert (np.array(values) == np.array([img1, img2])).all()
49
27
 
50
28
  def test_workers(self, tmpdir):
51
29
  logdir = pathlib.Path(tmpdir)
@@ -64,7 +42,8 @@ class TestImage:
64
42
  assert reader.keys() == tuple(sorted(['foo', 'bar', 'baz']))
65
43
  for key in ('foo', 'bar', 'baz'):
66
44
  assert reader.length(key) == 5
67
- steps, values = reader[key]
45
+ steps, filenames = reader[key]
46
+ values = [reader.load(key, x) for x in filenames]
68
47
  assert (steps == np.arange(5)).all()
69
48
  assert all(x.dtype == np.uint8 for x in values)
70
49
  reference = np.arange(5, dtype=np.uint8)[:, None, None, None]
@@ -81,4 +60,29 @@ class TestImage:
81
60
  reader = scope.Reader(logdir)
82
61
  assert reader.keys() == ('foo/bar',)
83
62
  assert reader.length('foo/bar') == 1
84
- assert (reader['foo/bar'][1] == img).all()
63
+ _, filenames = reader['foo/bar']
64
+ assert len(filenames) == 1
65
+ assert (reader.load('foo/bar', filenames[0]) == img).all()
66
+
67
+ # def test_slicing(self, tmpdir):
68
+ # logdir = pathlib.Path(tmpdir)
69
+ # writer = scope.Writer(logdir, workers=0)
70
+ # img1 = np.ones((64, 128, 3), np.uint8) + 12
71
+ # img2 = np.ones((64, 128, 3), np.uint8) + 255
72
+ # writer.add(0, {'foo': img1})
73
+ # writer.add(5, {'foo': img2})
74
+ # writer.flush()
75
+ # assert {x.name for x in logdir.glob('*')} == {'foo.png'}
76
+ # assert (logdir / 'foo.png' / 'index').stat().st_size == (8 + 8) * 2
77
+ # reader = scope.Reader(logdir)
78
+ # assert reader.keys() == ('foo',)
79
+ # assert reader.length('foo') == 2
80
+ # steps, values = reader['foo']
81
+ # assert (steps == np.array([0, 5])).all()
82
+ # assert (values == np.array([img1, img2])).all()
83
+ # assert (reader['foo', 0][1] == img1[None]).all()
84
+ # assert (reader['foo', :5][1] == img1[None]).all()
85
+ # assert (reader['foo', :6][1] == np.array([img1, img2])).all()
86
+ # assert (reader['foo', 1:6][1] == img2[None]).all()
87
+ # assert reader['foo', :-1][1] == ()
88
+ # assert reader['foo', 6:][1] == ()
@@ -20,6 +20,7 @@ class TestVideo:
20
20
  reader = scope.Reader(logdir)
21
21
  assert reader.keys() == ('foo',)
22
22
  assert reader.length('foo') == 2
23
- steps, values = reader['foo']
23
+ steps, filenames = reader['foo']
24
+ values = [reader.load('foo', x) for x in filenames]
24
25
  assert (steps == np.array([0, 5])).all()
25
26
  assert np.allclose(values, [vid1, vid2], rtol=0.1)
@@ -1,4 +0,0 @@
1
- __version__ = '0.2.2'
2
-
3
- from .writer import Writer
4
- from .reader import Reader
@@ -1,188 +0,0 @@
1
- import io
2
- import struct
3
- import time
4
-
5
- import av
6
- import numpy as np
7
- from PIL import Image
8
-
9
-
10
- def table_length(filename, fmt):
11
- return filename.stat().st_size // struct.calcsize(fmt)
12
-
13
-
14
- def table_write(filename, fmt, *cols):
15
- rows = tuple(zip(*cols))
16
- size = struct.calcsize(fmt)
17
- buffer = bytearray(len(rows) * size)
18
- for index, row in enumerate(rows):
19
- struct.pack_into(fmt, buffer, index * size, *row)
20
- with filename.open('ab') as f:
21
- f.write(buffer)
22
-
23
-
24
- def table_read(filename, fmt, start=0, stop=None):
25
- assert stop is None or start < stop, (start, stop)
26
- size = struct.calcsize(fmt)
27
- with filename.open('rb') as f:
28
- start and f.seek(start * size)
29
- buffer = f.read((stop - start) * size if stop else None)
30
- rows = struct.iter_unpack(fmt, buffer)
31
- cols = tuple(zip(*rows))
32
- return cols
33
-
34
-
35
- class FloatColumn:
36
-
37
- def __init__(self, logdir, key):
38
- name = key.replace('/', '-') + '.float'
39
- self.filename = logdir / name
40
-
41
- def validate(self, value):
42
- assert value.dtype in (float, int) and value.ndim == 0, (
43
- value.dtype, value.shape)
44
- return value
45
-
46
- def write(self, values):
47
- steps, values = zip(*values)
48
- table_write(self.filename, '>qd', steps, values)
49
-
50
- def length(self):
51
- return table_length(self.filename, '>qd')
52
-
53
- def read(self, start, stop):
54
- steps, values = table_read(self.filename, '>qd')
55
- filtered = [(s, v) for s, v in zip(steps, values) if start <= s < stop]
56
- steps, values = zip(*filtered) if filtered else ([], [])
57
- steps = np.array(steps, np.int64)
58
- values = np.array(values, np.float64)
59
- return steps, values
60
-
61
-
62
- class FileColumn:
63
-
64
- def __init__(self, logdir, key, ext, encfn, decfn):
65
- name = key.replace('/', '-') + '.' + ext
66
- self.folder = logdir / name
67
- self.folder.mkdir(exist_ok=True)
68
- self.index = self.folder / 'index'
69
- self.rng = np.random.default_rng(seed=None)
70
- self.ext = ext
71
- self.encfn = encfn
72
- self.decfn = decfn
73
-
74
- def validate(self, value):
75
- raise NotImplementedError
76
-
77
- def write(self, values):
78
- prefix = int(time.time()).to_bytes(4, 'big')
79
- steps, values = zip(*values)
80
- idents = [prefix + self.rng.bytes(4) for _ in range(len(steps))]
81
- for ident, step, value in zip(idents, steps, values):
82
- buffer = self.encfn(value)
83
- with self._filename(step, ident).open('wb') as f:
84
- f.write(buffer)
85
- table_write(self.index, 'q8s', steps, idents)
86
-
87
- def length(self):
88
- return table_length(self.index, 'q8s')
89
-
90
- def read(self, start, stop):
91
- steps, idents = table_read(self.index, 'q8s')
92
- filtered = [(s, v) for s, v in zip(steps, idents) if start <= s < stop]
93
- steps, idents = zip(*filtered) if filtered else ([], [])
94
- values = []
95
- for step, ident in zip(steps, idents):
96
- with self._filename(step, ident).open('rb') as f:
97
- buffer = f.read()
98
- values.append(self.decfn(buffer))
99
- steps = np.array(steps, np.int64)
100
- values = tuple(values)
101
- return steps, values
102
-
103
- def _filename(self, step, ident):
104
- return self.folder / f'{step:020}-{ident.hex()}.{self.ext}'
105
-
106
-
107
- class TextColumn(FileColumn):
108
-
109
- def __init__(self, logdir, key, fmt='txt'):
110
- super().__init__(logdir, key, fmt, self.encode, self.decode)
111
- self.fmt = fmt
112
-
113
- def validate(self, value):
114
- assert isinstance(value, str), type(value)
115
- return value
116
-
117
- def encode(self, value):
118
- return value.encode('utf-8')
119
-
120
- def decode(self, buffer):
121
- return buffer.decode('utf-8')
122
-
123
-
124
- class ImageColumn(FileColumn):
125
-
126
- def __init__(self, logdir, key, fmt='png', quality=None):
127
- super().__init__(logdir, key, fmt, self.encode, self.decode)
128
- self.fmt = fmt
129
- self.quality = quality
130
-
131
- def validate(self, value):
132
- assert (
133
- value.dtype == np.uint8 and value.ndim == 3 and
134
- value.shape[-1] in (1, 3)), (value.dtype, value.shape)
135
- return value
136
-
137
- def encode(self, value):
138
- if value.shape[-1] == 1:
139
- value = value.repeat(3, -1)
140
- fmt = ('jpeg' if self.fmt == 'jpg' else self.fmt).upper()
141
- fp = io.BytesIO()
142
- Image.fromarray(value).save(fp, format=fmt, quality=self.quality)
143
- return fp.getvalue()
144
-
145
- def decode(self, buffer):
146
- return np.asarray(Image.open(io.BytesIO(buffer)).convert('RGB'))
147
-
148
-
149
- class VideoColumn(FileColumn):
150
-
151
- def __init__(self, logdir, key, fmt='mp4', fps=20, codec='h264'):
152
- super().__init__(logdir, key, fmt, self.encode, self.decode)
153
- self.fmt = fmt
154
- self.fps = fps
155
- self.codec = codec
156
-
157
- def validate(self, value):
158
- assert (
159
- value.dtype == np.uint8 and value.ndim == 4 and
160
- value.shape[-1] in (1, 3)), (value.dtype, value.shape)
161
- return value
162
-
163
- def encode(self, value):
164
- if value.shape[-1] == 1:
165
- value = value.repeat(3, -1)
166
- T, H, W, C = value.shape
167
- fp = io.BytesIO()
168
- output = av.open(fp, mode='w', format=self.fmt)
169
- stream = output.add_stream(self.codec, rate=float(self.fps))
170
- stream.width = W
171
- stream.height = H
172
- stream.pix_fmt = 'yuv420p'
173
- for t in range(T):
174
- frame = av.VideoFrame.from_ndarray(value[t], format='rgb24')
175
- frame.pts = t
176
- output.mux(stream.encode(frame))
177
- output.mux(stream.encode(None))
178
- output.close()
179
- return fp.getvalue()
180
-
181
- def decode(self, buffer):
182
- container = av.open(io.BytesIO(buffer))
183
- value = []
184
- for frame in container.decode(video=0):
185
- value.append(frame.to_ndarray(format='rgb24'))
186
- value = np.stack(value)
187
- container.close()
188
- return value
@@ -1,43 +0,0 @@
1
- import re
2
- import pathlib
3
-
4
- import numpy as np
5
-
6
- from . import columns
7
-
8
-
9
- class Reader:
10
-
11
- def __init__(self, logdir):
12
- if isinstance(logdir, str):
13
- logdir = pathlib.Path(logdir)
14
- self.coltypes = {
15
- 'float': columns.FloatColumn,
16
- 'png': columns.ImageColumn,
17
- 'mp4': columns.VideoColumn,
18
- }
19
- self.columns = {}
20
- for child in sorted(logdir.glob('*')):
21
- name, ext = child.name.rsplit('.', 1)
22
- key = name.replace('-', '/')
23
- assert re.match(r'[a-z0-9_]+(/[a-z0-9_]+)?', key), key
24
- self.columns[key] = self.coltypes[ext](logdir, key)
25
-
26
- def keys(self):
27
- return tuple(self.columns.keys())
28
-
29
- def length(self, key):
30
- return self.columns[key].length()
31
-
32
- def __getitem__(self, index):
33
- if isinstance(index, str):
34
- key, start, stop = index, -np.inf, +np.inf
35
- else:
36
- key, pos = index
37
- if isinstance(pos, int):
38
- start, stop = pos, pos + 1
39
- else:
40
- assert pos.step is None
41
- start = -np.inf if pos.start is None else pos.start
42
- stop = +np.inf if pos.stop is None else pos.stop
43
- return self.columns[key].read(start, stop)
@@ -1,71 +0,0 @@
1
- import concurrent.futures
2
- import pathlib
3
- import re
4
- from functools import partial as bind
5
-
6
- import numpy as np
7
-
8
- from . import columns
9
-
10
-
11
- class Writer:
12
-
13
- def __init__(self, logdir, fps=20, workers=32):
14
- if isinstance(logdir, str):
15
- logdir = pathlib.Path(logdir)
16
- self.logdir = logdir
17
- self.logdir.mkdir(parents=True, exist_ok=True)
18
- self.fps = fps
19
- self.workers = workers
20
- self.coltypes = [
21
- (lambda x: isinstance(x, str), columns.TextColumn),
22
- (lambda x: x.ndim == 0, columns.FloatColumn),
23
- (lambda x: x.ndim == 3, bind(columns.ImageColumn, fmt='png')),
24
- (lambda x: x.ndim == 4, bind(columns.VideoColumn, fmt='mp4', fps=fps)),
25
- ]
26
- self.columns = {}
27
- self.values = {}
28
- if workers:
29
- self.pool = concurrent.futures.ThreadPoolExecutor(workers, 'writer')
30
- self.futures = []
31
-
32
- def add(self, step, *args, **kwargs):
33
- step = int(step)
34
- mapping = dict(*args, **kwargs)
35
- for key, value in mapping.items():
36
- value = value if isinstance(value, str) else np.asarray(value)
37
- if key not in self.columns:
38
- assert re.match(r'[a-z0-9_]+(/[a-z0-9_]+)?', key), key
39
- for applies, coltype in self.coltypes:
40
- if applies(value):
41
- break
42
- else:
43
- raise NotImplementedError((
44
- key, value,
45
- getattr(value, 'shape', None),
46
- getattr(value, 'dtype', None)))
47
- self.columns[key] = coltype(self.logdir, key)
48
- self.values[key] = []
49
- column = self.columns[key]
50
- try:
51
- value = column.validate(value)
52
- except Exception:
53
- print(f"Error validating key '{key}' with value '{value}'.")
54
- raise
55
- self.values[key].append((step, value))
56
-
57
- def flush(self):
58
- keys = [key for key, values in self.values.items() if values]
59
- if self.workers:
60
- list(self.futures)
61
- columns = [self.columns[x] for x in keys]
62
- values = [self.values[x] for x in keys]
63
- self.futures = self.pool.map(lambda x, y: x.write(y), columns, values)
64
- else:
65
- for key in keys:
66
- try:
67
- self.columns[key].write(self.values[key])
68
- except Exception:
69
- print(f"Exception writing '{key}' column.")
70
- raise
71
- self.values = {key: [] for key in self.values.keys()}
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes