pyglove 0.4.5.dev202410290809__py3-none-any.whl → 0.4.5.dev202411010809__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyglove/core/__init__.py +1 -0
- pyglove/core/io/__init__.py +1 -0
- pyglove/core/io/file_system.py +14 -4
- pyglove/core/io/file_system_test.py +2 -0
- pyglove/core/io/sequence.py +299 -0
- pyglove/core/io/sequence_test.py +124 -0
- pyglove/core/object_utils/__init__.py +4 -3
- pyglove/core/object_utils/error_utils.py +65 -1
- pyglove/core/object_utils/error_utils_test.py +56 -2
- pyglove/core/object_utils/json_conversion.py +1 -2
- pyglove/core/object_utils/{profiling.py → timing.py} +70 -21
- pyglove/core/object_utils/{profiling_test.py → timing_test.py} +34 -17
- pyglove/core/symbolic/__init__.py +1 -0
- pyglove/core/symbolic/base.py +75 -24
- pyglove/core/symbolic/dict.py +13 -4
- pyglove/core/symbolic/list.py +13 -4
- pyglove/core/symbolic/object.py +5 -2
- pyglove/core/views/html/base.py +18 -3
- pyglove/core/views/html/base_test.py +2 -0
- pyglove/core/views/html/tree_view.py +7 -1
- pyglove/core/views/html/tree_view_test.py +9 -0
- {pyglove-0.4.5.dev202410290809.dist-info → pyglove-0.4.5.dev202411010809.dist-info}/METADATA +1 -1
- {pyglove-0.4.5.dev202410290809.dist-info → pyglove-0.4.5.dev202411010809.dist-info}/RECORD +26 -24
- {pyglove-0.4.5.dev202410290809.dist-info → pyglove-0.4.5.dev202411010809.dist-info}/WHEEL +1 -1
- {pyglove-0.4.5.dev202410290809.dist-info → pyglove-0.4.5.dev202411010809.dist-info}/LICENSE +0 -0
- {pyglove-0.4.5.dev202410290809.dist-info → pyglove-0.4.5.dev202411010809.dist-info}/top_level.txt +0 -0
pyglove/core/__init__.py
CHANGED
@@ -163,6 +163,7 @@ to_json = symbolic.to_json
|
|
163
163
|
to_json_str = symbolic.to_json_str
|
164
164
|
save = symbolic.save
|
165
165
|
load = symbolic.load
|
166
|
+
open_jsonl = symbolic.open_jsonl
|
166
167
|
get_load_handler = symbolic.get_load_handler
|
167
168
|
set_load_handler = symbolic.set_load_handler
|
168
169
|
get_save_handler = symbolic.get_save_handler
|
pyglove/core/io/__init__.py
CHANGED
pyglove/core/io/file_system.py
CHANGED
@@ -51,6 +51,10 @@ class File(metaclass=abc.ABCMeta):
|
|
51
51
|
def tell(self) -> int:
|
52
52
|
"""Returns the current position of the file."""
|
53
53
|
|
54
|
+
@abc.abstractmethod
def flush(self) -> None:
  """Pushes content written so far down to the underlying storage.

  Every concrete `File` must provide this. Implementations that have nothing
  to flush may implement it as a no-op.
  """
|
57
|
+
|
54
58
|
@abc.abstractmethod
|
55
59
|
def close(self) -> None:
|
56
60
|
"""Closes the file."""
|
@@ -112,7 +116,7 @@ class FileSystem(metaclass=abc.ABCMeta):
|
|
112
116
|
"""Removes a directory chain based on a path."""
|
113
117
|
|
114
118
|
|
115
|
-
def
|
119
|
+
def resolve_path(path: Union[str, os.PathLike[str]]) -> str:
|
116
120
|
if isinstance(path, str):
|
117
121
|
return path
|
118
122
|
elif hasattr(path, '__fspath__'):
|
@@ -148,6 +152,9 @@ class StdFile(File):
|
|
148
152
|
def tell(self) -> int:
|
149
153
|
return self._file_object.tell()
|
150
154
|
|
155
|
+
def flush(self) -> None:
  """Flushes the write buffer of the wrapped file object."""
  file_object = self._file_object
  file_object.flush()
|
157
|
+
|
151
158
|
def close(self) -> None:
|
152
159
|
self._file_object.close()
|
153
160
|
|
@@ -220,6 +227,9 @@ class MemoryFile(File):
|
|
220
227
|
def tell(self) -> int:
|
221
228
|
return self._buffer.tell()
|
222
229
|
|
230
|
+
def flush(self) -> None:
  """Intentionally a no-op for in-memory files."""
|
232
|
+
|
223
233
|
def close(self) -> None:
|
224
234
|
self.seek(0)
|
225
235
|
|
@@ -233,7 +243,7 @@ class MemoryFileSystem(FileSystem):
|
|
233
243
|
self._prefix = prefix
|
234
244
|
|
235
245
|
def _internal_path(self, path: Union[str, os.PathLike[str]]) -> str:
  """Maps an external path to this file system's internal '/'-rooted path.

  Args:
    path: A `str` or path-like object, normally starting with `self._prefix`.

  Returns:
    The path with `self._prefix` removed as a whole string, re-rooted at '/'.
  """
  path = resolve_path(path)
  # Remove the prefix as a literal string. `str.lstrip(prefix)` would instead
  # strip any leading characters that occur in `prefix`; with a prefix of
  # '/mem/', the path '/mem/metrics' would wrongly become '/trics'.
  if path.startswith(self._prefix):
    path = path[len(self._prefix):]
  elif path == self._prefix.rstrip('/'):
    # The prefix itself without its trailing slash denotes the root, as before.
    path = ''
  # Collapse leading slashes so '/x' and 'x' both map to '/x'. This matches
  # what the previous character stripping did for such paths.
  return '/' + path.lstrip('/')
|
237
247
|
|
238
248
|
def _locate(self, path: Union[str, os.PathLike[str]]) -> Any:
|
239
249
|
current = self._root
|
@@ -277,7 +287,7 @@ class MemoryFileSystem(FileSystem):
|
|
277
287
|
def _parent_and_name(
|
278
288
|
self, path: Union[str, os.PathLike[str]]
|
279
289
|
) -> tuple[dict[str, Any], str]:
|
280
|
-
path =
|
290
|
+
path = resolve_path(path)
|
281
291
|
rpos = path.rfind('/')
|
282
292
|
assert rpos >= 0, path
|
283
293
|
name = path[rpos + 1:]
|
@@ -372,7 +382,7 @@ class _FileSystemRegistry:
|
|
372
382
|
|
373
383
|
def get(self, path: Union[str, os.PathLike[str]]) -> FileSystem:
|
374
384
|
"""Gets the file system for a path."""
|
375
|
-
path =
|
385
|
+
path = resolve_path(path)
|
376
386
|
for prefix, fs in self._filesystems:
|
377
387
|
if path.startswith(prefix):
|
378
388
|
return fs
|
@@ -133,6 +133,7 @@ class MemoryFileSystemTest(unittest.TestCase):
|
|
133
133
|
self.assertFalse(fs.exists(file1))
|
134
134
|
with fs.open(file1, 'w') as f:
|
135
135
|
f.write('hello')
|
136
|
+
f.flush()
|
136
137
|
self.assertTrue(fs.exists(file1))
|
137
138
|
self.assertFalse(fs.isdir(file1))
|
138
139
|
|
@@ -224,6 +225,7 @@ class FileIoApiTest(unittest.TestCase):
|
|
224
225
|
with file_system.open(file1, 'w') as f:
|
225
226
|
self.assertIsInstance(f, file_system.MemoryFile)
|
226
227
|
f.write('foo')
|
228
|
+
f.flush()
|
227
229
|
self.assertTrue(file_system.path_exists(file1))
|
228
230
|
self.assertEqual(file_system.readfile(file1), 'foo')
|
229
231
|
file_system.writefile(file1, 'bar')
|
@@ -0,0 +1,299 @@
|
|
1
|
+
# Copyright 2024 The PyGlove Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Pluggable record IO."""
|
15
|
+
|
16
|
+
import abc
|
17
|
+
import collections
|
18
|
+
import os
|
19
|
+
from typing import Any, Callable, Iterator, Optional, Union
|
20
|
+
|
21
|
+
from pyglove.core.io import file_system
|
22
|
+
|
23
|
+
|
24
|
+
class Sequence(metaclass=abc.ABCMeta):
  """Abstract record sequence used both for reading and for writing.

  Subclasses store raw records (`str` or `bytes`) by implementing `_add`, and
  expose them by implementing `_iter`. An optional `serializer` turns a
  structured record into `str`/`bytes` inside `add`. An optional
  `deserializer` turns raw records back into structured values while
  iterating. A `Sequence` is a context manager that calls `close()` on exit
  and never suppresses exceptions.
  """

  def __init__(
      self,
      serializer: Optional[Callable[[Any], Union[bytes, str]]] = None,
      deserializer: Optional[Callable[[Union[bytes, str]], Any]] = None
  ):
    self._serializer = serializer
    self._deserializer = deserializer

  def add(self, record: Any) -> None:
    """Appends a record to the sequence.

    Args:
      record: The record to append. If a serializer is set, it is applied
        first. The value that ends up being stored must be `str` or `bytes`.

    Raises:
      ValueError: If the value to store is neither `str` nor `bytes`.
    """
    payload = self._serializer(record) if self._serializer else record
    if isinstance(payload, (str, bytes)):
      self._add(payload)
      return
    raise ValueError(
        f'Cannot write record with type {type(payload)}. '
        'Did you forget to pass a serializer?'
    )

  @abc.abstractmethod
  def _add(self, record: Union[bytes, str]) -> None:
    """Stores an already-serialized raw record."""

  @abc.abstractmethod
  def __len__(self):
    """Returns how many records the sequence holds."""

  def __iter__(self) -> Iterator[Any]:
    """Yields the records, deserialized when a deserializer is set."""
    for raw in self._iter():
      yield self._deserializer(raw) if self._deserializer else raw

  @abc.abstractmethod
  def _iter(self) -> Iterator[Union[bytes, str]]:
    """Yields the raw, not yet deserialized records."""

  @abc.abstractmethod
  def close(self) -> None:
    """Releases the resources held by the sequence."""

  @abc.abstractmethod
  def flush(self) -> None:
    """Pushes records written so far to the underlying storage."""

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    # Exception details are unused; returning None lets exceptions propagate.
    del exc_type, exc_value, traceback
    self.close()
|
80
|
+
|
81
|
+
|
82
|
+
class SequenceIO(metaclass=abc.ABCMeta):
  """Abstract factory that opens `Sequence` objects for one storage backend."""

  @abc.abstractmethod
  def open(
      self,
      path: Union[str, os.PathLike[str]],
      mode: str,
      *,
      serializer: Optional[Callable[[Any], Union[bytes, str]]],
      deserializer: Optional[Callable[[Union[bytes, str]], Any]],
      **kwargs
  ) -> Sequence:
    """Opens the sequence stored at `path`.

    Args:
      path: Location of the sequence.
      mode: Open mode such as 'r', 'w' or 'a'.
      serializer: Optional callable applied to records before they are
        written.
      deserializer: Optional callable applied to raw records when they are
        read.
      **kwargs: Backend-specific options.

    Returns:
      The opened sequence.
    """
|
96
|
+
|
97
|
+
|
98
|
+
class _SequenceIORegistry(object):
  """Registry that maps file extensions to `SequenceIO` implementations."""

  def __init__(self):
    self._registry = {}

  @staticmethod
  def _normalize_extension(extension: str) -> str:
    """Canonicalizes an extension: lower-cased, without leading dots."""
    return extension.lower().lstrip('.')

  def add(self, extension: str, sequence_io: SequenceIO) -> None:
    """Registers a record IO system for a file extension.

    Args:
      extension: The file extension, e.g. 'mem'. It is matched
        case-insensitively, and a leading dot (as in '.mem') is ignored. This
        agrees with the lower-casing done in `get`; previously an upper-case
        or dotted registration could never be found.
      sequence_io: The `SequenceIO` used for paths with that extension.
    """
    self._registry[self._normalize_extension(extension)] = sequence_io

  def get(self, path: Union[str, os.PathLike[str]]) -> SequenceIO:
    """Gets the record IO system for a path.

    The extension is read from the last path component only, after its last
    dot. Anything from '@' onwards is ignored (e.g. 'file1.mem@123' -> 'mem').
    Paths whose last component has no dot, or whose extension is not
    registered, fall back to `LineSequenceIO`.

    Args:
      path: A `str` or path-like object.

    Returns:
      The registered `SequenceIO`, or a new `LineSequenceIO` by default.
    """
    path = file_system.resolve_path(path)
    # Only the final component may carry the extension. A dot-less name such
    # as 'mem' must not be mistaken for one with a registered extension.
    _, dot, extension = os.path.basename(path).rpartition('.')
    if dot:
      extension = self._normalize_extension(extension.split('@')[0])
      if extension in self._registry:
        return self._registry[extension]
    return LineSequenceIO()
|
119
|
+
|
120
|
+
|
121
|
+
# Module-wide registry: `add_sequence_io` populates it and `open_sequence`
# consults it.
_registry: _SequenceIORegistry = _SequenceIORegistry()
|
122
|
+
|
123
|
+
|
124
|
+
def add_sequence_io(extension: str, sequence_io: SequenceIO) -> None:
  """Registers `sequence_io` as the handler for a file extension.

  Args:
    extension: The file extension (e.g. 'mem') whose paths `sequence_io`
      should open.
    sequence_io: The `SequenceIO` used by `open_sequence` for such paths.
  """
  _registry.add(extension=extension, sequence_io=sequence_io)
|
127
|
+
|
128
|
+
|
129
|
+
def open_sequence(
    path: Union[str, os.PathLike[str]],
    mode: str = 'r',
    *,
    serializer: Optional[
        Callable[[Any], Union[bytes, str]]
    ] = None,
    deserializer: Optional[
        Callable[[Union[bytes, str]], Any]
    ] = None,
    make_dirs_if_not_exist: bool = True,
) -> Sequence:
  """Open sequence for reading or writing.

  The `SequenceIO` is chosen from the path's extension via the module
  registry.

  Args:
    path: The path to the sequence.
    mode: The mode of the sequence, e.g. 'r' (default), 'w' or 'a'.
    serializer: (Optional) A serializer function for converting a structured
      object to a string or bytes.
    deserializer: (Optional) A deserializer function for converting a string or
      bytes to a structured object.
    make_dirs_if_not_exist: (Optional) Whether to create the directories
      if they do not exist. Applicable when opening in write or append mode.

  Returns:
    A sequence for reading or writing.
  """
  if make_dirs_if_not_exist and ('w' in mode or 'a' in mode):
    parent_dir = os.path.dirname(path)
    # A bare filename (e.g. 'out.jsonl') has no parent component and
    # `dirname` returns ''. Asking the file system to create '' fails
    # (`os.makedirs('')` raises FileNotFoundError), so skip it.
    if parent_dir:
      file_system.mkdirs(parent_dir, exist_ok=True)
  return _registry.get(path).open(
      path, mode, serializer=serializer, deserializer=deserializer
  )
|
163
|
+
|
164
|
+
|
165
|
+
class MemorySequence(Sequence):
  """A sequence whose raw records live in a caller-provided Python list.

  The list is appended to in place. Sequences built on the same list object
  therefore see each other's writes.
  """

  def __init__(
      self,
      path: str,
      mode: str,
      records: list[Union[str, bytes]],
      serializer: Optional[Callable[[Any], Union[bytes, str]]],
      deserializer: Optional[Callable[[Union[bytes, str]], Any]]
  ):
    super().__init__(serializer, deserializer)
    self._path = path
    self._mode = mode
    self._records = records
    self._closed = False

  def _add(self, record: Union[str, bytes]) -> None:
    # Writing requires 'w' or 'a' in the mode and an open sequence.
    can_write = 'w' in self._mode or 'a' in self._mode
    if not can_write:
      raise ValueError(
          f'Cannot write record {record!r} to memory sequence {self._path!r} '
          f'with mode {self._mode!r}.'
      )
    if self._closed:
      raise ValueError(
          f'Cannot write record {record!r} to a closed writer for '
          f'{self._path!r}.'
      )
    self._records.append(record)

  def __len__(self):
    return len(self._records)

  def _iter(self):
    # Reading requires 'r' in the mode and an open sequence.
    can_read = 'r' in self._mode
    if not can_read:
      raise ValueError(
          f'Cannot read memory sequence {self._path!r} with '
          f'mode {self._mode!r}.'
      )
    if self._closed:
      raise ValueError(
          f'Cannot iterate over a closed sequence reader {self._path!r}.'
      )
    return iter(self._records)

  def flush(self):
    """No-op: records are appended straight into the backing list."""

  def close(self) -> None:
    """Marks the sequence closed; later reads and writes raise ValueError."""
    self._closed = True
|
215
|
+
|
216
|
+
|
217
|
+
class MemorySequenceIO(SequenceIO):
  """Memory-based record IO.

  Records for each path are kept in a list owned by this object. Reopening
  the same path ('r' or 'a') reuses that list, while 'w' replaces it with an
  empty one.
  """

  def __init__(self):
    super().__init__()
    # Maps resolved `str` paths to their record lists.
    self._root = collections.defaultdict(list)

  def open(
      self,
      path: Union[str, os.PathLike[str]],
      mode: str,
      *,
      serializer: Optional[Callable[[Any], Union[bytes, str]]],
      deserializer: Optional[Callable[[Union[bytes, str]], Any]],
      **kwargs
  ) -> Sequence:
    """Opens an in-memory sequence for `path`.

    Args:
      path: A `str` or path-like object. It is resolved to `str` so that a
        path-like object and the equivalent string address the same records.
      mode: Open mode; a mode containing 'w' discards existing records.
      serializer: Optional record serializer.
      deserializer: Optional record deserializer.
      **kwargs: Ignored.

    Returns:
      A `MemorySequence` backed by the records list for `path`.
    """
    del kwargs
    # Without resolving, e.g. a pathlib path and its string form would be
    # different dict keys and thus different sequences.
    path = file_system.resolve_path(path)
    if 'w' in mode:
      self._root[path] = []
    return MemorySequence(
        path, mode, self._root[path],
        serializer=serializer, deserializer=deserializer
    )
|
241
|
+
|
242
|
+
|
243
|
+
# Paths with the 'mem' extension (e.g. '/file1.mem@123') are kept in memory.
add_sequence_io(extension='mem', sequence_io=MemorySequenceIO())
|
244
|
+
|
245
|
+
|
246
|
+
class LineSequence(Sequence):
  """A newline-delimited sequence: one raw record per line of a file.

  The newline delimiter is chosen to match the record type: '\\n' for `str`
  and b'\\n' for `bytes`. `bytes` records are presumably meant for files
  opened in binary mode. Trailing newlines are stripped from a record before
  it is written, and a record with interior newlines is read back as several
  records.
  """

  def __init__(
      self,
      file: file_system.File,
      serializer: Optional[Callable[[Any], Union[bytes, str]]],
      deserializer: Optional[Callable[[Union[bytes, str]], Any]],
  ) -> None:
    super().__init__(serializer, deserializer)
    self._file = file

  @staticmethod
  def _newline_for(value: Union[str, bytes]) -> Union[str, bytes]:
    """Returns a newline delimiter of the same type as `value`."""
    return b'\n' if isinstance(value, bytes) else '\n'

  def __len__(self):
    raise NotImplementedError(
        '__len__ is not supported for LineSequence. '
        'Use `len(list(iter(sequence)))` instead.'
    )

  def _iter(self):
    while True:
      line = self._file.readline()
      if not line:  # An empty line value marks the end of the file.
        break
      yield line.rstrip(self._newline_for(line))

  def _add(self, record: Union[str, bytes]) -> None:
    # `str.rstrip`/`bytes.rstrip` only accept an argument of their own type,
    # so a hard-coded '\n' raised TypeError for `bytes` records.
    newline = self._newline_for(record)
    self._file.write(record.rstrip(newline))
    self._file.write(newline)

  def flush(self) -> None:
    self._file.flush()

  def close(self) -> None:
    self._file.close()
|
280
|
+
|
281
|
+
|
282
|
+
class LineSequenceIO(SequenceIO):
  """`SequenceIO` that stores one record per line in a regular file."""

  def open(
      self,
      path: Union[str, os.PathLike[str]],
      mode: str,
      *,
      serializer: Optional[Callable[[Any], Union[bytes, str]]],
      deserializer: Optional[Callable[[Union[bytes, str]], Any]],
      **kwargs
  ) -> Sequence:
    """Opens `path` with `file_system.open` and wraps it in a `LineSequence`.

    Extra keyword arguments are accepted for interface compatibility and are
    ignored.
    """
    del kwargs  # Unused.
    opened_file = file_system.open(path, mode)
    return LineSequence(
        file=opened_file, serializer=serializer, deserializer=deserializer
    )
|
299
|
+
|
@@ -0,0 +1,124 @@
|
|
1
|
+
# Copyright 2024 The PyGlove Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import os
|
16
|
+
import tempfile
|
17
|
+
import unittest
|
18
|
+
from pyglove.core.io import sequence as sequence_io
|
19
|
+
# We need to import this module to register the default parser/serializer.
|
20
|
+
import pyglove.core.symbolic as pg_symbolic
|
21
|
+
|
22
|
+
|
23
|
+
class LineSequenceIOTest(unittest.TestCase):
  """Tests for the newline-delimited sequence IO."""

  def _create_tmp_dir(self) -> str:
    """Returns a fresh temporary directory that is removed after the test.

    A per-test directory replaces the shared `tempfile.gettempdir()`. This
    keeps runs from colliding on fixed file names and leaves nothing behind.
    It also ensures the 'abc' sub-directory really has to be created by
    `open_jsonl`, instead of lingering from an earlier run.
    """
    tmp_dir = tempfile.TemporaryDirectory()
    self.addCleanup(tmp_dir.cleanup)
    return tmp_dir.name

  def test_read_write(self):
    tmp_dir = self._create_tmp_dir()
    file1 = os.path.join(tmp_dir, 'abc', 'file1')
    with pg_symbolic.open_jsonl(file1, 'w') as f:
      self.assertIsInstance(f, sequence_io.LineSequence)
      f.add(1)
      f.add(' foo')
      f.add(' bar ')
      f.flush()
      self.assertTrue(os.path.exists(file1))

      f.add('baz\n')
      f.add(dict(x=1))

    self.assertTrue(os.path.exists(file1))
    with pg_symbolic.open_jsonl(file1, 'r') as f:
      self.assertIsInstance(f, sequence_io.LineSequence)
      self.assertEqual(list(iter(f)), [1, ' foo', ' bar ', 'baz\n', dict(x=1)])

    with pg_symbolic.open_jsonl(file1, 'a') as f:
      f.add('qux')

    with pg_symbolic.open_jsonl(file1, 'r') as f:
      with self.assertRaisesRegex(
          NotImplementedError, '__len__ is not supported'
      ):
        _ = len(f)
      self.assertEqual(
          list(iter(f)), [1, ' foo', ' bar ', 'baz\n', dict(x=1), 'qux']
      )

  def test_read_write_with_raw_texts(self):
    tmp_dir = self._create_tmp_dir()
    file2 = os.path.join(tmp_dir, 'file2')
    with sequence_io.open_sequence(file2, 'w') as f:
      self.assertIsInstance(f, sequence_io.LineSequence)
      with self.assertRaisesRegex(
          ValueError, 'Cannot write record with type'
      ):
        f.add(1)
      # Interior newlines split a raw record into several lines.
      f.add('foo\nbar\n')

    with sequence_io.open_sequence(file2, 'r') as f:
      with self.assertRaisesRegex(
          ValueError, 'not writable'
      ):
        f.add('baz')
      self.assertEqual(list(iter(f)), ['foo', 'bar'])
|
73
|
+
|
74
|
+
|
75
|
+
class MemorySequenceIOTest(unittest.TestCase):
  """Tests for the in-memory sequence IO registered for the 'mem' extension."""

  _PATH = '/file1.mem@123'

  def test_read_write(self):
    path = self._PATH

    # Write mode: str records are accepted, other types are rejected, and
    # reading is refused.
    with sequence_io.open_sequence(path, 'w') as writer:
      self.assertIsInstance(writer, sequence_io.MemorySequence)
      for record in (' foo', ' bar '):
        writer.add(record)
      writer.flush()
      writer.add('baz')
      with self.assertRaisesRegex(ValueError, 'Cannot write record with type'):
        writer.add(1)
      with self.assertRaisesRegex(ValueError, 'Cannot read memory sequence'):
        next(iter(writer))

    # Leaving the context closed the writer.
    with self.assertRaisesRegex(
        ValueError, 'Cannot write record .* to a closed writer'
    ):
      writer.add('qux')

    # Append mode keeps the existing records.
    with sequence_io.open_sequence(path, 'a') as appender:
      self.assertIsInstance(appender, sequence_io.MemorySequence)
      appender.add('qux')

    # Default (read) mode sees everything written so far and refuses writes.
    with sequence_io.open_sequence(path) as reader:
      self.assertIsInstance(reader, sequence_io.MemorySequence)
      self.assertEqual(len(reader), 4)
      self.assertEqual(list(reader), [' foo', ' bar ', 'baz', 'qux'])
      with self.assertRaisesRegex(
          ValueError, 'Cannot write record .* to memory sequence'
      ):
        reader.add('abc')

    with self.assertRaisesRegex(
        ValueError, 'Cannot iterate over a closed sequence reader'
    ):
      next(iter(reader))

    # Reopening in write mode discards the previous records.
    with sequence_io.open_sequence(path, 'w') as rewriter:
      rewriter.add('abc')

    with sequence_io.open_sequence(path, 'r') as rereader:
      self.assertEqual(list(iter(rereader)), ['abc'])
|
121
|
+
|
122
|
+
|
123
|
+
if __name__ == '__main__':
  # Run this module's tests when it is executed directly as a script.
  unittest.main()
|
@@ -154,10 +154,11 @@ from pyglove.core.object_utils.docstr_utils import docstr
|
|
154
154
|
# Handling exceptions.
|
155
155
|
from pyglove.core.object_utils.error_utils import catch_errors
|
156
156
|
from pyglove.core.object_utils.error_utils import CatchErrorsContext
|
157
|
+
from pyglove.core.object_utils.error_utils import ErrorInfo
|
157
158
|
|
158
|
-
#
|
159
|
-
from pyglove.core.object_utils.
|
160
|
-
from pyglove.core.object_utils.
|
159
|
+
# Timing.
|
160
|
+
from pyglove.core.object_utils.timing import timeit
|
161
|
+
from pyglove.core.object_utils.timing import TimeIt
|
161
162
|
|
162
163
|
# pylint: enable=g-importing-member
|
163
164
|
# pylint: enable=g-bad-import-order
|
@@ -17,7 +17,71 @@ import contextlib
|
|
17
17
|
import dataclasses
|
18
18
|
import inspect
|
19
19
|
import re
|
20
|
-
|
20
|
+
import sys
|
21
|
+
import traceback
|
22
|
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
|
23
|
+
|
24
|
+
from pyglove.core.object_utils import formatting
|
25
|
+
from pyglove.core.object_utils import json_conversion
|
26
|
+
|
27
|
+
|
28
|
+
@dataclasses.dataclass(frozen=True)
class ErrorInfo(json_conversion.JSONConvertible, formatting.Formattable):
  """Serializable error information.

  Attributes:
    tag: A path of the error types in the exception chain. For example,
      `ValueError.ZeroDivisionError` means the error is a `ZeroDivisionError`
      raised at the first place and then reraised as a `ValueError`.
    description: The description of the error.
    stacktrace: The stacktrace of the error.
  """

  tag: str
  description: str
  stacktrace: str

  @classmethod
  def _compute_tag(cls, error: BaseException) -> str:
    """Returns the dot-joined class names along the error's cause chain.

    Each step follows the error's `cause` attribute when it has one, and
    otherwise its `__cause__`. Traversal stops at `None` or at a non-exception
    cause. It also stops at an error already visited: a cycle (e.g.
    `raise e from e`) would otherwise loop forever.
    """
    error_types = []
    visited = set()
    while isinstance(error, BaseException) and id(error) not in visited:
      visited.add(id(error))
      error_types.append(error.__class__.__name__)
      error = getattr(error, 'cause', error.__cause__)
    return '.'.join(error_types)

  @classmethod
  def from_exception(cls, error: BaseException) -> 'ErrorInfo':
    """Creates an error info from an exception.

    The stacktrace is rendered from `error` itself (its type, value and
    `__traceback__`) instead of `sys.exc_info()`. This keeps it correct when
    called outside the `except` block that caught `error`, e.g. for a stored
    exception, where `sys.exc_info()` would be empty or describe another error.
    Inside that block, both sources yield the same text.
    """
    return cls(
        tag=cls._compute_tag(error),
        description=str(error),
        stacktrace=''.join(
            traceback.format_exception(
                type(error), error, error.__traceback__
            )
        ),
    )

  def to_json(self, **kwargs) -> Dict[str, Any]:
    return self.to_json_dict(
        fields=dict(
            tag=(self.tag, None),
            description=(self.description, None),
            stacktrace=(self.stacktrace, None),
        ),
        exclude_default=True,
        **kwargs,
    )

  def format(self, *args, **kwargs) -> str:
    return formatting.kvlist_str(
        [
            ('tag', self.tag, None),
            ('description', self.description, None),
            ('stacktrace', self.stacktrace, None),
        ],
        *args,
        label=self.__class__.__name__,
        **kwargs,
    )
|
21
85
|
|
22
86
|
|
23
87
|
@dataclasses.dataclass()
|
@@ -11,12 +11,66 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
|
15
|
-
|
14
|
+
import inspect
|
16
15
|
import unittest
|
17
16
|
from pyglove.core.object_utils import error_utils
|
18
17
|
|
19
18
|
|
19
|
+
class ErrorInfoTest(unittest.TestCase):
  """Tests for `error_utils.ErrorInfo`."""

  def _sample(self) -> error_utils.ErrorInfo:
    """Builds the fixed `ErrorInfo` shared by the serialization tests."""
    return error_utils.ErrorInfo(
        tag='ValueError.ZeroDivisionError',
        description='Bad call to `foo`',
        stacktrace='Traceback (most recent call last)',
    )

  def test_from_exception(self):

    def foo():
      return 1 / 0

    def bar():
      try:
        foo()
      except ZeroDivisionError as e:
        raise ValueError('Bad call to `foo`') from e

    info = None
    try:
      bar()
    except ValueError as e:
      info = error_utils.ErrorInfo.from_exception(e)

    self.assertIsNotNone(info)
    expected = {
        'tag': 'ValueError.ZeroDivisionError',
        'description': 'Bad call to `foo`',
    }
    for attr, value in expected.items():
      self.assertEqual(getattr(info, attr), value)
    self.assertIn('Traceback (most recent call last)', info.stacktrace)

  def test_to_json(self):
    original = self._sample()
    restored = error_utils.ErrorInfo.from_json(original.to_json())
    self.assertIsNot(restored, original)
    self.assertEqual(restored, original)

  def test_format(self):
    expected = inspect.cleandoc(
        """
        ErrorInfo(
          tag='ValueError.ZeroDivisionError',
          description='Bad call to `foo`',
          stacktrace='Traceback (most recent call last)'
        )
        """
    )
    self.assertEqual(self._sample().format(compact=False), expected)
|
72
|
+
|
73
|
+
|
20
74
|
class CatchErrorsTest(unittest.TestCase):
|
21
75
|
|
22
76
|
def assert_caught_error(self, func, errors_to_catch):
|
@@ -202,9 +202,8 @@ class JSONConvertible(metaclass=abc.ABCMeta):
|
|
202
202
|
Returns:
|
203
203
|
An instance of cls.
|
204
204
|
"""
|
205
|
-
del kwargs
|
206
205
|
assert isinstance(json_value, dict)
|
207
|
-
init_args = {k: from_json(v) for k, v in json_value.items()
|
206
|
+
init_args = {k: from_json(v, **kwargs) for k, v in json_value.items()
|
208
207
|
if k != JSONConvertible.TYPE_NAME_KEY}
|
209
208
|
return cls(**init_args)
|
210
209
|
|