deepboard 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepboard/__init__.py +1 -0
- deepboard/__version__.py +4 -0
- deepboard/gui/THEME.yml +28 -0
- deepboard/gui/__init__.py +0 -0
- deepboard/gui/assets/artefacts.css +108 -0
- deepboard/gui/assets/base.css +208 -0
- deepboard/gui/assets/base.js +77 -0
- deepboard/gui/assets/charts.css +188 -0
- deepboard/gui/assets/compare.css +90 -0
- deepboard/gui/assets/datagrid.css +120 -0
- deepboard/gui/assets/fileview.css +13 -0
- deepboard/gui/assets/right_panel.css +227 -0
- deepboard/gui/assets/theme.css +85 -0
- deepboard/gui/components/__init__.py +8 -0
- deepboard/gui/components/artefact_group.py +12 -0
- deepboard/gui/components/chart_type.py +22 -0
- deepboard/gui/components/legend.py +34 -0
- deepboard/gui/components/log_selector.py +22 -0
- deepboard/gui/components/modal.py +20 -0
- deepboard/gui/components/smoother.py +21 -0
- deepboard/gui/components/split_selector.py +21 -0
- deepboard/gui/components/stat_line.py +8 -0
- deepboard/gui/entry.py +21 -0
- deepboard/gui/main.py +93 -0
- deepboard/gui/pages/__init__.py +1 -0
- deepboard/gui/pages/compare_page/__init__.py +6 -0
- deepboard/gui/pages/compare_page/compare_page.py +22 -0
- deepboard/gui/pages/compare_page/components/__init__.py +4 -0
- deepboard/gui/pages/compare_page/components/card_list.py +19 -0
- deepboard/gui/pages/compare_page/components/chart.py +54 -0
- deepboard/gui/pages/compare_page/components/compare_setup.py +30 -0
- deepboard/gui/pages/compare_page/components/split_card.py +51 -0
- deepboard/gui/pages/compare_page/components/utils.py +20 -0
- deepboard/gui/pages/compare_page/routes.py +58 -0
- deepboard/gui/pages/main_page/__init__.py +4 -0
- deepboard/gui/pages/main_page/datagrid/__init__.py +5 -0
- deepboard/gui/pages/main_page/datagrid/compare_button.py +21 -0
- deepboard/gui/pages/main_page/datagrid/datagrid.py +67 -0
- deepboard/gui/pages/main_page/datagrid/handlers.py +54 -0
- deepboard/gui/pages/main_page/datagrid/header.py +43 -0
- deepboard/gui/pages/main_page/datagrid/routes.py +112 -0
- deepboard/gui/pages/main_page/datagrid/row.py +20 -0
- deepboard/gui/pages/main_page/datagrid/sortable_column_js.py +45 -0
- deepboard/gui/pages/main_page/datagrid/utils.py +9 -0
- deepboard/gui/pages/main_page/handlers.py +16 -0
- deepboard/gui/pages/main_page/main_page.py +21 -0
- deepboard/gui/pages/main_page/right_panel/__init__.py +12 -0
- deepboard/gui/pages/main_page/right_panel/config.py +57 -0
- deepboard/gui/pages/main_page/right_panel/fragments.py +133 -0
- deepboard/gui/pages/main_page/right_panel/hparams.py +25 -0
- deepboard/gui/pages/main_page/right_panel/images.py +358 -0
- deepboard/gui/pages/main_page/right_panel/run_info.py +86 -0
- deepboard/gui/pages/main_page/right_panel/scalars.py +251 -0
- deepboard/gui/pages/main_page/right_panel/template.py +151 -0
- deepboard/gui/pages/main_page/routes.py +25 -0
- deepboard/gui/pages/not_found.py +3 -0
- deepboard/gui/requirements.txt +5 -0
- deepboard/gui/utils.py +267 -0
- deepboard/resultTable/__init__.py +2 -0
- deepboard/resultTable/cursor.py +20 -0
- deepboard/resultTable/logwritter.py +667 -0
- deepboard/resultTable/resultTable.py +529 -0
- deepboard/resultTable/scalar.py +29 -0
- deepboard/resultTable/table_schema.py +135 -0
- deepboard/resultTable/utils.py +50 -0
- deepboard-0.2.0.dist-info/METADATA +164 -0
- deepboard-0.2.0.dist-info/RECORD +69 -0
- deepboard-0.2.0.dist-info/WHEEL +4 -0
- deepboard-0.2.0.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,667 @@
|
|
1
|
+
from traceback import print_tb
|
2
|
+
from typing import *
|
3
|
+
from .cursor import Cursor
|
4
|
+
from datetime import datetime
|
5
|
+
from .scalar import Scalar
|
6
|
+
import sys
|
7
|
+
from PIL import Image
|
8
|
+
from io import BytesIO
|
9
|
+
import matplotlib.pyplot as plt
|
10
|
+
import hashlib
|
11
|
+
|
12
|
+
class LogWriter:
|
13
|
+
"""
|
14
|
+
This class makes an object that is bound to a run row in the result table. This means that everything that is
|
15
|
+
logged through this object is added into the result table and this object can be used to interact with a specific
|
16
|
+
run. This object is single use. This means that once the final results are written, the object becomes read-only.
|
17
|
+
|
18
|
+
You should not instantiate this class directly, but use the ResultTable class to create it instead.
|
19
|
+
|
20
|
+
"""
|
21
|
+
def __init__(self, db_path, run_id: int, start: datetime, flush_each: int = 10, keep_each: int = 1,
             disable: bool = False, auto_log_plt=True):
    """
    Bind a writer to the run identified by ``run_id`` in the database at ``db_path``.

    :param db_path: The path to the database file
    :param run_id: The run id of this run
    :param start: The start time of this run. add_scalar measures its default wall time from it.
    :param flush_each: Number of buffered entries (per scalar tag, or per split for images and fragments) that
    triggers a write to the database. (increase it to reduce io)
    :param keep_each: Keep one datapoint every ``keep_each`` calls to add_scalar for a given tag; the others are
    discarded. If 2, only one datapoint every two calls is stored.
    :param disable: If True, the logger is disabled and will not log anything in the database.
    :param auto_log_plt: If True, detect_and_log_figures is registered as a pre-hook so that open matplotlib
    figures are logged. Note that the check only happens when a method running the pre-hooks is called.
    :raises ValueError: If keep_each or flush_each is not greater than 0.
    """
    if keep_each <= 0:
        raise ValueError("Parameter keep_each must be greater than 0: {1, 2, 3, ...}")
    if flush_each <= 0:
        # The message must name flush_each: it is the parameter validated here.
        raise ValueError("Parameter flush_each must be greater than 0: {1, 2, 3, ...}")
    self.db_path = db_path
    self.run_id = run_id
    self.start = start
    self.flush_each = flush_each
    self.keep_each = keep_each

    # Per-tag scalar state: next automatic step, pending rows, calls since the last kept datapoint
    self.global_step = {}
    self.buffer = {}
    self.log_count = {}

    # Pending image rows per split, and hashes of the figures that were already logged
    self.image_buffer = {}
    self.fig_ids = set()

    # Pending text/html rows per split
    self.fragments_buffer = {}

    # enabled turns False once the writer is read-only; disable skips every database write
    self.enabled = True
    self.run_rep = 0
    self.disable = disable

    # Callables executed by _run_pre_hooks
    self.pre_hooks: List[Callable[[], None]] = []

    if auto_log_plt:
        self.pre_hooks.append(self.detect_and_log_figures)

    # Set the exception handler to set the status to failed and disable the logger if the program crashes
    self._exception_handler()
|
65
|
+
|
66
|
+
def new_repetition(self):
    """
    Start a new repetition of the current run.

    Pending scalars are written out first, then the repetition index is incremented and the per-tag counters
    and the wall-time origin are reset. Call it at the end of a training loop, before the next full training
    loop is run again.
    :return: None
    """
    # Write out the scalars still held in memory for the repetition that ends here
    for pending_tag in self.buffer.keys():
        self._flush(pending_tag)

    # Fresh counters and a fresh wall-time origin for the next repetition
    self.global_step, self.log_count = {}, {}
    self.start = datetime.now()
    self.run_rep += 1
|
83
|
+
|
84
|
+
def add_scalar(self, tag: str, scalar_value: Union[float, int],
               step: Optional[int] = None, epoch: Optional[int] = None,
               walltime: Optional[float] = None, flush: bool = False):
    """
    Add a scalar to the resultTable
    :param tag: The tag, formatted as: 'split/name' or simply 'name'. A tag that does not have exactly two
    '/'-separated parts is stored with an empty split and its first part as name.
    :param scalar_value: The value. It is cast to float.
    :param step: The global step. If None, the per-tag counter kept by this writer is used
    :param epoch: The epoch. If None, 0 is stored
    :param walltime: Override the wall time with this. If None, the seconds elapsed since ``self.start`` are used
    :param flush: Force flush all the data in memory
    :return: None
    :raises RuntimeError: If the writer is read-only
    """
    self._run_pre_hooks()
    if not self.enabled:
        # Adjacent literals are concatenated, so the first one has to end with a space.
        raise RuntimeError("The LogWriter is read only! This might be due to the fact that you loaded an already "
                           "existing one or you reported final metrics.")
    # Early return if this datapoint falls between two kept ones (see keep_each).
    if not self._keep(tag):
        return

    # We split the tag as a split and a name for readability
    splitted_tag = tag.split("/")
    if len(splitted_tag) == 2:
        split, name = splitted_tag[0], splitted_tag[1]
    else:
        split, name = "", splitted_tag[0]

    scalar_value = float(scalar_value)  # Cast it as float

    step = self._get_global_step(tag) if step is None else step

    walltime = (datetime.now() - self.start).total_seconds() if walltime is None else walltime

    epoch = 0 if epoch is None else epoch

    # Add a row to the scalar buffer
    self._log(tag, epoch, step, split, name, scalar_value, walltime, self.run_rep)

    # Flush all if requested to force flush
    if flush:
        self._flush_all()
|
126
|
+
|
127
|
+
def read_scalar(self, tag) -> List[Scalar]:
    """
    Read the scalars stored for this run under the given tag
    :param tag: The tag to read, formatted as: 'split/name' or simply 'name' (empty split).
    :return: A list of Scalar items, each built from a row of the Logs table without its first column
    """
    parts = tag.split("/")
    # Same parsing as add_scalar: exactly two parts give (split, name), otherwise the split is empty
    split, name = (parts[0], parts[1]) if len(parts) == 2 else ("", parts[0])

    query = "SELECT * FROM Logs WHERE run_id=? AND split=? AND label=?"
    with self._cursor as cursor:
        cursor.execute(query, (self.run_id, split, name))
        records = cursor.fetchall()
    return [Scalar(*record[1:]) for record in records]
|
144
|
+
|
145
|
+
def add_image(self, image: Union[bytes, Image.Image], step: Optional[int] = None, split: Optional[str] = None,
              epoch: Optional[int] = None, flush: bool = False):
    """
    Add an image to the resultTable
    :param image: Either a PIL Image object, which is encoded as PNG, or bytes, which are stored unchanged.
    :param step: The global step at which the image was generated. If None, the largest scalar step known to
    this writer is used, or 0 when no scalar step is known.
    :param split: The split in which the image was generated.
    :param epoch: The epoch at which the image was generated. If None, no epoch is saved.
    :param flush: If True, flush all data in memory to the database.
    :return: None
    """
    self._run_pre_hooks()

    # Bytes go through unchanged; PIL images are serialized to PNG first
    img_bytes = image
    if isinstance(image, Image.Image):
        png_stream = BytesIO()
        image.save(png_stream, format='PNG')
        img_bytes = png_stream.getvalue()

    if step is None:
        # Fall back to the largest step reached by the scalars
        step = 0
        if self.global_step:
            step = max(self.global_step.values())

    # Add to buffer
    self._log_image(img_bytes, step, split, self.run_rep, epoch)

    if flush:
        self._flush_all()
|
174
|
+
|
175
|
+
def read_images(self, id: Optional[int] = None, step: Optional[int] = None, split: Optional[str] = None, epoch: Optional[int] = None,
                repetition: Optional[int] = None) -> List[dict]:
    """
    Return the images of type "IMAGE" logged in the run, optionally filtered.
    :param id: The id of the image to read. If None, the id is not filtered.
    :param step: The step at which the image was generated. If None, all steps are returned.
    :param split: The split in which the images were generated. If None, all splits are returned.
    :param epoch: The epoch at which the images were generated. If None, all epochs are returned.
    :param repetition: The repetition of the images. If None, all repetitions are returned.
    :return: The list produced by _get_images for these filters
    """
    filters = (id, step, split, epoch, repetition)
    return self._get_images(*filters, img_type="IMAGE")
|
187
|
+
|
188
|
+
|
189
|
+
def detect_and_log_figures(self, step: Optional[int] = None, split: Optional[str] = None,
                           epoch: Optional[int] = None, flush: bool = False):
    """
    Log every matplotlib figure currently registered in pyplot as a PNG image of type "PLOT".

    A figure whose rendered PNG has the same sha256 digest as one already logged by this writer is skipped.
    :param step: The global step at which the figures were generated. If None, the largest scalar step known to
    this writer is used, or 0 when no scalar step is known.
    :param split: The split in which the figures were generated.
    :param epoch: The epoch at which the figures were generated. If None, no epoch is saved.
    :param flush: If True, flush all data in memory to the database.
    :return: None
    """
    if step is None:
        # Fall back to the largest step reached by the scalars
        step = 0
        if self.global_step:
            step = max(self.global_step.values())

    for fig_num in plt.get_fignums():
        figure = plt.figure(fig_num)
        figure.tight_layout()

        # Render the figure to PNG in memory
        png_stream = BytesIO()
        figure.savefig(png_stream, format='png')
        png_stream.seek(0)
        digest = hashlib.sha256(png_stream.read()).hexdigest()

        # Only figures whose rendering was never seen before are logged
        if digest not in self.fig_ids:
            self.fig_ids.add(digest)
            self._log_image(png_stream.getvalue(), step, split, self.run_rep, epoch, type="PLOT")

    if flush:
        self._flush_all()
|
223
|
+
|
224
|
+
def add_text(self, text: str, step: Optional[int] = None, split: Optional[str] = None,
             epoch: Optional[int] = None, flush: bool = False):
    """
    Add a text sample to the resultTable (stored as a fragment of type "RAW")
    :param text: Must be a string
    :param step: The global step at which the text was generated. If None, the largest scalar step known to
    this writer is used, or 0 when no scalar step is known.
    :param split: The split in which the text was generated.
    :param epoch: The epoch at which the text was generated. If None, no epoch is saved.
    :param flush: If True, flush all data in memory to the database.
    :return: None
    """
    self._run_pre_hooks()

    if step is None:
        # Fall back to the largest step reached by the scalars
        step = 0
        if self.global_step:
            step = max(self.global_step.values())

    self._log_fragment(text, step, split, self.run_rep, epoch, type="RAW")

    if flush:
        self._flush_all()
|
244
|
+
|
245
|
+
def read_text(self, id: Optional[int] = None, step: Optional[int] = None, split: Optional[str] = None,
              epoch: Optional[int] = None, repetition: Optional[int] = None):
    """
    Return the text samples (fragments of type "RAW") logged in the run, optionally filtered.
    :param id: The id of the text sample to read. If None, the id is not filtered.
    :param step: The step at which the text was generated. If None, all steps are returned.
    :param split: The split in which the texts were generated. If None, all splits are returned.
    :param epoch: The epoch at which the texts were generated. If None, all epochs are returned.
    :param repetition: The repetition of the run. If None, all repetitions are returned.
    :return: The list produced by _get_fragments for these filters
    """
    filters = (id, step, split, epoch, repetition)
    return self._get_fragments(*filters, fragment_type="RAW")
|
257
|
+
|
258
|
+
def add_fragment(self, content: str, step: Optional[int] = None, split: Optional[str] = None,
                 epoch: Optional[int] = None, flush: bool = False):
    """
    Add a html fragment to the resultTable (stored as a fragment of type "HTML")
    :param content: Must be a string containing valid HTML content.
    :param step: The global step at which the fragment was generated. If None, the largest scalar step known to
    this writer is used, or 0 when no scalar step is known.
    :param split: The split in which the fragment was generated.
    :param epoch: The epoch at which the fragment was generated. If None, no epoch is saved.
    :param flush: If True, flush all data in memory to the database.
    :return: None
    """
    self._run_pre_hooks()

    if step is None:
        # Fall back to the largest step reached by the scalars
        step = 0
        if self.global_step:
            step = max(self.global_step.values())

    self._log_fragment(content, step, split, self.run_rep, epoch, type="HTML")

    if flush:
        self._flush_all()
|
278
|
+
|
279
|
+
def read_fragment(self, id: Optional[int] = None, step: Optional[int] = None, split: Optional[str] = None,
                  epoch: Optional[int] = None, repetition: Optional[int] = None):
    """
    Return the html fragments (fragments of type "HTML") logged in the run, optionally filtered.
    :param id: The id of the html fragment to read. If None, the id is not filtered.
    :param step: The step at which the html fragment was generated. If None, all steps are returned.
    :param split: The split in which the html fragments were generated. If None, all splits are returned.
    :param epoch: The epoch at which the html fragments were generated. If None, all epochs are returned.
    :param repetition: The repetition of the run. If None, all repetitions are returned.
    :return: The list produced by _get_fragments for these filters
    """
    filters = (id, step, split, epoch, repetition)
    return self._get_fragments(*filters, fragment_type="HTML")
|
291
|
+
|
292
|
+
def read_figures(self, id: Optional[int] = None, step: Optional[int] = None, split: Optional[str] = None, epoch: Optional[int] = None,
                 repetition: Optional[int] = None):
    """
    Return the figures (images of type "PLOT") logged in the run, optionally filtered.
    :param id: The id of the figure to read. If None, the id is not filtered.
    :param step: The step at which the figure was generated. If None, all steps are returned.
    :param split: The split in which the figures were generated. If None, all splits are returned.
    :param epoch: The epoch at which the figures were generated. If None, all epochs are returned.
    :param repetition: The repetition of the figures. If None, all repetitions are returned.
    :return: The list produced by _get_images for these filters
    """
    filters = (id, step, split, epoch, repetition)
    return self._get_images(*filters, img_type="PLOT")
|
304
|
+
|
305
|
+
def add_hparams(self, **kwargs):
    """
    Add hyperparameters to the result table
    :param kwargs: The hyperparameters to save, one Results row per keyword, flagged as hyperparameter
    :return: None
    """
    self._run_pre_hooks()

    # Nothing is written when the writer is disabled
    if not self.disable:
        rows = [(self.run_id, name, value, True) for name, value in kwargs.items()]
        with self._cursor as cursor:
            cursor.executemany(
                "INSERT INTO Results (run_id, metric, value, is_hparam) VALUES (?, ?, ?, ?)",
                rows,
            )
|
320
|
+
|
321
|
+
def get_hparams(self) -> Dict[str, Any]:
    """
    Get the hyperparameters of the current run
    :return: A dict mapping each hyperparameter name to its stored value
    """
    query = "SELECT metric, value FROM Results WHERE run_id=? AND is_hparam=1"
    with self._cursor as cursor:
        cursor.execute(query, (self.run_id,))
        records = cursor.fetchall()
    hparams = {}
    for record in records:
        hparams[record[0]] = record[1]
    return hparams
|
330
|
+
|
331
|
+
def get_repetitions(self) -> List[int]:
    """
    Get the distinct repetition ids found in the Logs table for the current run
    :return: A list of repetition ids
    """
    query = "SELECT DISTINCT run_rep FROM Logs WHERE run_id=?"
    with self._cursor as cursor:
        cursor.execute(query, (self.run_id,))
        records = cursor.fetchall()
    repetitions = []
    for record in records:
        repetitions.append(record[0])
    return repetitions
|
340
|
+
|
341
|
+
def write_result(self, **kwargs):
    """
    Log the results of the run to the table, then make the logger read-only.

    Unless the writer is disabled, all buffers are flushed, one Results row is stored per keyword, the status
    is set to "finished" and ``enabled`` is switched off. If you run multiple iterations, consider writing the
    results only once all the runs are finished. You can aggregate the different metrics before passing them.
    :param kwargs: The metrics to save
    :return: None
    """
    self._run_pre_hooks()

    # A disabled writer stores nothing and stays as it is
    if not self.disable:
        # Everything still in memory has to reach the database before the run is closed
        self._flush_all()

        rows = [(self.run_id, metric, value, False) for metric, value in kwargs.items()]
        with self._cursor as cursor:
            cursor.executemany(
                "INSERT INTO Results (run_id, metric, value, is_hparam) VALUES (?, ?, ?, ?)",
                rows,
            )

        self.set_status("finished")

        # From now on the writer is read-only
        self.enabled = False
|
366
|
+
|
367
|
+
def set_status(self, status: Literal["running", "finished", "failed"]):
    """
    Manually set the status of the run
    :param status: The status to set: "running", "finished" or "failed"
    :return: None
    :raises ValueError: If the status is not one of the accepted values (not checked when the writer is disabled)
    """
    self._run_pre_hooks()

    # A disabled writer never touches the database, so nothing is validated either
    if self.disable:
        return

    accepted = ("running", "finished", "failed")
    if status not in accepted:
        raise ValueError("Status must be one of: running, finished, failed")

    update = "UPDATE Experiments SET status=? WHERE run_id=?"
    with self._cursor as cursor:
        cursor.execute(update, (status, self.run_id))
|
380
|
+
|
381
|
+
@property
def status(self) -> str:
    """
    Get the status of the run
    :return: The status stored in the Experiments table for this run
    :raises RuntimeError: If no row exists for this run id
    """
    query = "SELECT status FROM Experiments WHERE run_id=?"
    with self._cursor as cursor:
        cursor.execute(query, (self.run_id,))
        record = cursor.fetchone()
        if record is not None:
            return record[0]
        raise RuntimeError(f"Run {self.run_id} does not exist.")
|
393
|
+
|
394
|
+
@property
def scalars(self) -> List[str]:
    """
    Return the tags of all scalars logged in the run, formatted as 'split/label', or just 'label' when the
    split is empty
    """
    tags = []
    for pair in self.formatted_scalars:
        split, label = pair[0], pair[1]
        if split != "":
            tags.append(split + "/" + label)
        else:
            tags.append(label)
    return tags
|
403
|
+
|
404
|
+
@property
def formatted_scalars(self) -> List[Tuple[str, str]]:
    """
    Return the distinct (split, label) pairs found in the Logs table for this run
    """
    query = "SELECT DISTINCT split, label FROM Logs WHERE run_id=?"
    with self._cursor as cursor:
        cursor.execute(query, (self.run_id,))
        records = cursor.fetchall()
    # Each pair is rebuilt as a plain tuple; formatting as 'split/label' is left to the caller
    pairs = []
    for record in records:
        pairs.append((record[0], record[1]))
    return pairs
|
415
|
+
|
416
|
+
def __getitem__(self, tag):
    """
    Subscript access to the scalars of a tag: ``writer[tag]`` returns what ``read_scalar(tag)`` returns.
    """
    scalars_for_tag = self.read_scalar(tag)
    return scalars_for_tag
|
421
|
+
|
422
|
+
def _run_pre_hooks(self):
    """
    Call every registered pre-hook, in registration order, without arguments.
    """
    for callback in self.pre_hooks:
        callback()
|
425
|
+
|
426
|
+
def _get_fragments(self, id: Optional[int], step: Optional[int], split: Optional[str], epoch: Optional[int],
                   repetition: Optional[int], fragment_type: Literal["RAW", "HTML"]) -> List[dict]:
    """
    Fetch the fragments of this run that have the given type and match the given filters.

    :param id: Keep only the fragment with this id. If None, the id is not filtered.
    :param step: Keep only this step. If None, the step is not filtered.
    :param split: Keep only this split. If None, the split is not filtered.
    :param epoch: Keep only this epoch. If None, the epoch is not filtered.
    :param repetition: Keep only this repetition. If None, the repetition is not filtered.
    :param fragment_type: "RAW" for text samples or "HTML" for html fragments
    :return: A list of dicts with the keys id, step, epoch, run_rep, split and fragment
    """
    # Every value, including the fragment type, is bound as a parameter instead of being formatted into the SQL
    command = "SELECT id_, step, epoch, run_rep, split, fragment FROM Fragments WHERE run_id=? AND fragment_type=?"
    params = [self.run_id, fragment_type]

    # Column names below are fixed literals; a filter left to None adds no condition
    optional_filters = (
        ("id_", id),
        ("step", step),
        ("split", split),
        ("epoch", epoch),
        ("run_rep", repetition),
    )
    for column, value in optional_filters:
        if value is not None:
            command += f" AND {column}=?"
            params.append(value)

    with self._cursor as cursor:
        cursor.execute(f'{command};', tuple(params))
        rows = cursor.fetchall()
    return [dict(
        id=row[0],
        step=row[1],
        epoch=row[2],
        run_rep=row[3],
        split=row[4],
        fragment=row[5]
    ) for row in rows]
|
462
|
+
|
463
|
+
def _get_images(self, id: Optional[int], step: Optional[int], split: Optional[str], epoch: Optional[int],
                repetition: Optional[int], img_type: Literal["IMAGE", "PLOT"]) -> List[dict]:
    """
    Fetch the images of this run that have the given type and match the given filters.

    :param id: Keep only the image with this id. If None, the id is not filtered.
    :param step: Keep only this step. If None, the step is not filtered.
    :param split: Keep only this split. If None, the split is not filtered.
    :param epoch: Keep only this epoch. If None, the epoch is not filtered.
    :param repetition: Keep only this repetition. If None, the repetition is not filtered.
    :param img_type: "IMAGE" for logged images or "PLOT" for matplotlib figures
    :return: A list of dicts with the keys id, step, epoch, run_rep, split and image, where image is a PIL Image
    opened from the stored bytes
    """
    # Every value, including the image type, is bound as a parameter instead of being formatted into the SQL
    command = "SELECT id_, step, epoch, run_rep, split, image FROM Images WHERE run_id=? AND img_type=?"
    params = [self.run_id, img_type]

    # Column names below are fixed literals; a filter left to None adds no condition
    optional_filters = (
        ("id_", id),
        ("step", step),
        ("split", split),
        ("epoch", epoch),
        ("run_rep", repetition),
    )
    for column, value in optional_filters:
        if value is not None:
            command += f" AND {column}=?"
            params.append(value)

    with self._cursor as cursor:
        cursor.execute(f'{command};', tuple(params))
        rows = cursor.fetchall()
    # Convert the bytes to PIL Image objects
    return [dict(
        id=row[0],
        step=row[1],
        epoch=row[2],
        run_rep=row[3],
        split=row[4],
        image=Image.open(BytesIO(row[5]))
    ) for row in rows]
|
499
|
+
|
500
|
+
def _get_global_step(self, tag):
    """
    Return the automatic step of a tag and advance its counter by one.
    :param tag: The tag to get the step
    :return: The step to use for this call: 0 the first time a tag is seen, then 1, 2, ...
    """
    current = self.global_step.setdefault(tag, 0)
    self.global_step[tag] = current + 1
    return current
|
512
|
+
|
513
|
+
def _log(self, tag: str, epoch: int, step: int, split: str, name: str, scalar_value: float, walltime: float,
         run_rep: int):
    """
    Append a scalar row to the buffer of its tag and flush that tag once ``flush_each`` rows are pending.
    :param tag: The tag, used as the buffer key
    :param epoch: The epoch
    :param step: The step
    :param split: The split
    :param name: The name
    :param scalar_value: The value
    :param walltime: The wall time
    :param run_rep: The run repetition
    :return: None
    """
    pending = self.buffer.setdefault(tag, [])
    # The row layout follows the column order of the insert done by _flush
    pending.append((self.run_id, epoch, step, split, name, scalar_value, walltime, run_rep))

    if len(pending) >= self.flush_each:
        self._flush(tag)
|
533
|
+
|
534
|
+
def _log_fragment(self, fragment: str, step: int, split: Optional[str], repetition: int, epoch: Optional[int],
                  type: Literal["RAW", "HTML"] = "RAW"):
    """
    Append a text or html fragment to the buffer of its split and flush that split once ``flush_each`` rows are
    pending.
    :param fragment: The content to log
    :param step: The step
    :param split: The split that made it (a string or None); it is also the buffer key
    :param repetition: The run repetition
    :param epoch: The epoch
    :param type: Raw (for text only) or HTML (for html content)
    :return: None
    """
    pending = self.fragments_buffer.setdefault(split, [])
    # The row layout follows the column order of the insert done by _flush_fragment
    pending.append((self.run_id, step, epoch, repetition, type, split, fragment))

    if len(pending) >= self.flush_each:
        self._flush_fragment(split)
|
553
|
+
|
554
|
+
def _log_image(self, image: bytes, step: int, split: Optional[str], repetition: int, epoch: Optional[int],
               type: Literal["IMAGE", "PLOT"] = "IMAGE"):
    """
    Append an image to the buffer of its split and flush that split once ``flush_each`` rows are pending.
    :param image: The image bytes
    :param step: The step
    :param split: The split (a string or None); it is also the buffer key
    :param repetition: The run repetition
    :param epoch: The epoch
    :param type: The type of the image, either "IMAGE" or "PLOT". Default is "IMAGE".
    :return: None
    """
    pending = self.image_buffer.setdefault(split, [])
    # The row layout follows the column order of the insert done by _flush_image
    pending.append((self.run_id, step, epoch, repetition, type, split, image))

    if len(pending) >= self.flush_each:
        self._flush_image(split)
|
573
|
+
|
574
|
+
def _flush_all(self):
    """
    Flush the scalar buffers, then the image buffers, then the fragment buffers.
    :return: None
    """
    # Scalars are keyed by tag
    for pending_tag in self.buffer:
        self._flush(pending_tag)

    # Images and fragments are keyed by split
    for image_split in self.image_buffer:
        self._flush_image(image_split)

    for fragment_split in self.fragments_buffer:
        self._flush_fragment(fragment_split)
|
590
|
+
|
591
|
+
def _flush_image(self, split):
    """
    Write the buffered images of a split into the db (unless the writer is disabled) and empty that buffer.
    :param split: The split to flush
    :return: None
    """
    if self.disable:
        # Nothing is stored, the pending rows are simply dropped
        self.image_buffer[split] = []
        return

    insert = """
        INSERT INTO Images (run_id, step, epoch, run_rep, img_type, split, image)
        VALUES (?, ?, ?, ?, ?, ?, ?)
        """
    with self._cursor as cursor:
        cursor.executemany(insert, self.image_buffer[split])

    # Reset the buffer
    self.image_buffer[split] = []
|
602
|
+
|
603
|
+
def _flush_fragment(self, split):
    """
    Write the buffered fragments of a split into the db (unless the writer is disabled) and empty that buffer.
    :param split: The split to flush
    :return: None
    """
    if self.disable:
        # Nothing is stored, the pending rows are simply dropped
        self.fragments_buffer[split] = []
        return

    insert = """
        INSERT INTO Fragments (run_id, step, epoch, run_rep, fragment_type, split, fragment)
        VALUES (?, ?, ?, ?, ?, ?, ?)
        """
    with self._cursor as cursor:
        cursor.executemany(insert, self.fragments_buffer[split])

    # Reset the buffer
    self.fragments_buffer[split] = []
|
614
|
+
|
615
|
+
def _flush(self, tag: str):
    """
    Write the buffered scalars of a tag into the db (unless the writer is disabled) and empty that buffer.
    :param tag: The tag to flush
    :return: None
    """
    if self.disable:
        # Nothing is stored, the pending rows are simply dropped
        self.buffer[tag] = []
        return

    insert = """
        INSERT INTO Logs (run_id, epoch, step, split, label, value, wall_time, run_rep)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        """
    with self._cursor as cursor:
        cursor.executemany(insert, self.buffer[tag])

    # Reset the buffer
    self.buffer[tag] = []
|
632
|
+
|
633
|
+
def _keep(self, tag: str) -> bool:
    """
    Tell whether this call must be recorded or dropped, according to the keep_each attribute.

    Every call counts for its tag; the call that brings the count to ``keep_each`` is kept and resets the count.
    :param tag: The tag
    :return: True if we need to keep it and False if we drop it
    """
    seen = self.log_count.get(tag, 0) + 1
    if seen >= self.keep_each:
        self.log_count[tag] = 0
        return True
    self.log_count[tag] = seen
    return False
|
647
|
+
|
648
|
+
def _exception_handler(self):
    """
    Replace ``sys.excepthook`` with a hook that marks the run as failed, makes the writer read-only and then
    delegates to the hook that was installed before.
    """
    chained_hook = sys.excepthook

    def on_uncaught(exc_type, exc_value, exc_tb):
        # Record the failure first, then stop accepting scalars
        self.set_status("failed")
        self.enabled = False

        # Let the previously installed hook report the exception
        chained_hook(exc_type, exc_value, exc_tb)

    sys.excepthook = on_uncaught
|
664
|
+
|
665
|
+
@property
def _cursor(self):
    """
    Build a new Cursor on the database file of this writer. The other methods use it as a context manager.
    """
    database_file = self.db_path
    return Cursor(database_file)
|