deepboard 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69) hide show
  1. deepboard/__init__.py +1 -0
  2. deepboard/__version__.py +4 -0
  3. deepboard/gui/THEME.yml +28 -0
  4. deepboard/gui/__init__.py +0 -0
  5. deepboard/gui/assets/artefacts.css +108 -0
  6. deepboard/gui/assets/base.css +208 -0
  7. deepboard/gui/assets/base.js +77 -0
  8. deepboard/gui/assets/charts.css +188 -0
  9. deepboard/gui/assets/compare.css +90 -0
  10. deepboard/gui/assets/datagrid.css +120 -0
  11. deepboard/gui/assets/fileview.css +13 -0
  12. deepboard/gui/assets/right_panel.css +227 -0
  13. deepboard/gui/assets/theme.css +85 -0
  14. deepboard/gui/components/__init__.py +8 -0
  15. deepboard/gui/components/artefact_group.py +12 -0
  16. deepboard/gui/components/chart_type.py +22 -0
  17. deepboard/gui/components/legend.py +34 -0
  18. deepboard/gui/components/log_selector.py +22 -0
  19. deepboard/gui/components/modal.py +20 -0
  20. deepboard/gui/components/smoother.py +21 -0
  21. deepboard/gui/components/split_selector.py +21 -0
  22. deepboard/gui/components/stat_line.py +8 -0
  23. deepboard/gui/entry.py +21 -0
  24. deepboard/gui/main.py +93 -0
  25. deepboard/gui/pages/__init__.py +1 -0
  26. deepboard/gui/pages/compare_page/__init__.py +6 -0
  27. deepboard/gui/pages/compare_page/compare_page.py +22 -0
  28. deepboard/gui/pages/compare_page/components/__init__.py +4 -0
  29. deepboard/gui/pages/compare_page/components/card_list.py +19 -0
  30. deepboard/gui/pages/compare_page/components/chart.py +54 -0
  31. deepboard/gui/pages/compare_page/components/compare_setup.py +30 -0
  32. deepboard/gui/pages/compare_page/components/split_card.py +51 -0
  33. deepboard/gui/pages/compare_page/components/utils.py +20 -0
  34. deepboard/gui/pages/compare_page/routes.py +58 -0
  35. deepboard/gui/pages/main_page/__init__.py +4 -0
  36. deepboard/gui/pages/main_page/datagrid/__init__.py +5 -0
  37. deepboard/gui/pages/main_page/datagrid/compare_button.py +21 -0
  38. deepboard/gui/pages/main_page/datagrid/datagrid.py +67 -0
  39. deepboard/gui/pages/main_page/datagrid/handlers.py +54 -0
  40. deepboard/gui/pages/main_page/datagrid/header.py +43 -0
  41. deepboard/gui/pages/main_page/datagrid/routes.py +112 -0
  42. deepboard/gui/pages/main_page/datagrid/row.py +20 -0
  43. deepboard/gui/pages/main_page/datagrid/sortable_column_js.py +45 -0
  44. deepboard/gui/pages/main_page/datagrid/utils.py +9 -0
  45. deepboard/gui/pages/main_page/handlers.py +16 -0
  46. deepboard/gui/pages/main_page/main_page.py +21 -0
  47. deepboard/gui/pages/main_page/right_panel/__init__.py +12 -0
  48. deepboard/gui/pages/main_page/right_panel/config.py +57 -0
  49. deepboard/gui/pages/main_page/right_panel/fragments.py +133 -0
  50. deepboard/gui/pages/main_page/right_panel/hparams.py +25 -0
  51. deepboard/gui/pages/main_page/right_panel/images.py +358 -0
  52. deepboard/gui/pages/main_page/right_panel/run_info.py +86 -0
  53. deepboard/gui/pages/main_page/right_panel/scalars.py +251 -0
  54. deepboard/gui/pages/main_page/right_panel/template.py +151 -0
  55. deepboard/gui/pages/main_page/routes.py +25 -0
  56. deepboard/gui/pages/not_found.py +3 -0
  57. deepboard/gui/requirements.txt +5 -0
  58. deepboard/gui/utils.py +267 -0
  59. deepboard/resultTable/__init__.py +2 -0
  60. deepboard/resultTable/cursor.py +20 -0
  61. deepboard/resultTable/logwritter.py +667 -0
  62. deepboard/resultTable/resultTable.py +529 -0
  63. deepboard/resultTable/scalar.py +29 -0
  64. deepboard/resultTable/table_schema.py +135 -0
  65. deepboard/resultTable/utils.py +50 -0
  66. deepboard-0.2.0.dist-info/METADATA +164 -0
  67. deepboard-0.2.0.dist-info/RECORD +69 -0
  68. deepboard-0.2.0.dist-info/WHEEL +4 -0
  69. deepboard-0.2.0.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,667 @@
1
+ from traceback import print_tb
2
+ from typing import *
3
+ from .cursor import Cursor
4
+ from datetime import datetime
5
+ from .scalar import Scalar
6
+ import sys
7
+ from PIL import Image
8
+ from io import BytesIO
9
+ import matplotlib.pyplot as plt
10
+ import hashlib
11
+
12
+ class LogWriter:
13
+ """
14
+ This class makes an object that is bound to a run row in the result table. This means that everything that is
15
+ logged through this object is added into the result table and this object can be used to interact with a specific
16
+ run. This object is single use. This means that once the final results are written, the object becomes read-only.
17
+
18
+ You should not instantiate this class directly, but use the ResultTable class to create it instead.
19
+
20
+ """
21
+ def __init__(self, db_path, run_id: int, start: datetime, flush_each: int = 10, keep_each: int = 1,
22
+ disable: bool = False, auto_log_plt=True):
23
+ """
24
+ :param db_path: The path to the database file
25
+ :param run_id: The run id of this run
26
+ :param start: The start time of this run
27
+ :param flush_each: Every how many logs should we write them to the database. (increase it to reduce io)
28
+ :param keep_each: Every how many logs datapoint should we store. The others will be discarted. If 2, only one
29
+ datapoint every two times the add_scalar method is called will be stored.
30
+ :param disable: If True, the logger is disabled and will not log anything in the database.
31
+ :param auto_log_plt: If True, automatically detect if matplotlib figures were generated and log them. Note that
32
+ it checks only when a method on the socket is called.
33
+ """
34
+ if keep_each <= 0:
35
+ raise ValueError("Parameter keep_each must be grater than 0: {1, 2, 3, ...}")
36
+ if flush_each <= 0:
37
+ raise ValueError("Parameter keep_each must be grater than 0: {1, 2, 3, ...}")
38
+ self.db_path = db_path
39
+ self.run_id = run_id
40
+ self.start = start
41
+ self.flush_each = flush_each
42
+ self.keep_each = keep_each
43
+
44
+
45
+ self.global_step = {}
46
+ self.buffer = {}
47
+ self.log_count = {}
48
+
49
+ self.image_buffer = {}
50
+ self.fig_ids = set()
51
+
52
+ self.fragments_buffer = {}
53
+
54
+ self.enabled = True
55
+ self.run_rep = 0
56
+ self.disable = disable
57
+
58
+ self.pre_hooks: List[Callable[[], None]] = []
59
+
60
+ if auto_log_plt:
61
+ self.pre_hooks.append(self.detect_and_log_figures)
62
+
63
+ # Set the exception handler to set the status to failed and disable the logger if the program crashes
64
+ self._exception_handler()
65
+
66
def new_repetition(self):
    """
    Start a new repetition of the current run.

    The pending scalars are written first, then the repetition index is incremented and the per-tag counters
    and the start time are reset. It can be called at the end of a training loop, before the next full
    training loop is run again.
    :return: None
    """
    # Write the scalars that are still pending for the repetition that just ended
    for pending_tag in self.buffer:
        self._flush(pending_tag)

    self.run_rep = self.run_rep + 1

    # Start counting again from scratch
    self.log_count, self.global_step = {}, {}
    self.start = datetime.now()
83
+
84
def add_scalar(self, tag: str, scalar_value: Union[float, int],
               step: Optional[int] = None, epoch: Optional[int] = None,
               walltime: Optional[float] = None, flush: bool = False):
    """
    Add a scalar to the resultTable
    :param tag: The tag, formatted as: 'split/name' or simply 'name' (the split is then empty)
    :param scalar_value: The value, it is stored as a float
    :param step: The global step. If None, the one calculated for this tag is used
    :param epoch: The epoch. If None, 0 is stored
    :param walltime: Override the wall time with this. If None, the seconds elapsed since self.start are used
    :param flush: Force flush all the buffers in memory
    :return: None
    :raises RuntimeError: If the writer is read only
    """
    self._run_pre_hooks()
    if not self.enabled:
        raise RuntimeError("The LogWriter is read only! This might be due to the fact that you loaded an already "
                           "existing one or you reported final metrics.")
    # Early return if we are not supposed to keep this datapoint.
    if not self._keep(tag):
        return

    # We split the tag as a split and a name for readability
    # NOTE: a tag that does not contain exactly one '/' is stored with an empty split and its first segment as name
    splitted_tag = tag.split("/")
    if len(splitted_tag) == 2:
        split, name = splitted_tag[0], splitted_tag[1]
    else:
        split, name = "", splitted_tag[0]

    scalar_value = float(scalar_value)  # Cast it as float

    step = self._get_global_step(tag) if step is None else step

    walltime = (datetime.now() - self.start).total_seconds() if walltime is None else walltime

    epoch = 0 if epoch is None else epoch

    # Add a row to the scalar buffer
    self._log(tag, epoch, step, split, name, scalar_value, walltime, self.run_rep)

    # Flush all if requested to force flush
    if flush:
        self._flush_all()
126
+
127
def read_scalar(self, tag) -> List[Scalar]:
    """
    Fetch every logged value of one scalar of this run.
    :param tag: The tag to read formatted as: 'split/name' or simply 'name'.
    :return: A list of Scalar items, one per row found in the Logs table
    """
    parts = tag.split("/")
    # Same convention as when writing: exactly two segments means split/name, otherwise the split is empty
    split, name = (parts[0], parts[1]) if len(parts) == 2 else ("", parts[0])

    lookup = "SELECT * FROM Logs WHERE run_id=? AND split=? AND label=?"
    with self._cursor as cursor:
        cursor.execute(lookup, (self.run_id, split, name))
        records = cursor.fetchall()
    # The first column of each row is skipped when building the Scalar
    return [Scalar(*record[1:]) for record in records]
144
+
145
def add_image(self, image: Union[bytes, Image.Image], step: Optional[int] = None, split: Optional[str] = None,
              epoch: Optional[int] = None, flush: bool = False):
    """
    Buffer an image for the resultTable
    :param image: Must be png bytes or a PIL Image object. A PIL Image is encoded as PNG, bytes are stored as is.
    :param step: The global step at which the image was generated. If None, the maximum of all the scalar global
    steps is used (0 when there is none).
    :param split: The split in which the image was generated.
    :param epoch: The epoch at which the image was generated. If None, no epoch is saved.
    :param flush: If True, flush all data in memory to the database.
    :return: None
    """
    self._run_pre_hooks()

    img_bytes = image
    if isinstance(image, Image.Image):
        # Encode the PIL image as PNG in memory
        png_stream = BytesIO()
        image.save(png_stream, format='PNG')
        img_bytes = png_stream.getvalue()

    if step is None:
        # Fall back on the most advanced scalar step
        step = max(self.global_step.values(), default=0)

    # Add to buffer
    self._log_image(img_bytes, step, split, self.run_rep, epoch)

    if flush:
        self._flush_all()
174
+
175
def read_images(self, id: Optional[int] = None, step: Optional[int] = None, split: Optional[str] = None,
                epoch: Optional[int] = None, repetition: Optional[int] = None) -> List[dict]:
    """
    Return the images of type "IMAGE" logged in the run, optionally filtered.
    :param id: The id of the image to read. If None, no filter on the id.
    :param step: The step at which the image was generated. If None, all steps are returned.
    :param split: The split in which the images were generated. If None, all splits are returned.
    :param epoch: The epoch at which the images were generated. If None, all epochs are returned.
    :param repetition: The repetition of the images. If None, all repetitions are returned.
    :return: The list of dicts produced by _get_images
    """
    found = self._get_images(id, step, split, epoch, repetition, img_type="IMAGE")
    return found
187
+
188
+
189
def detect_and_log_figures(self, step: Optional[int] = None, split: Optional[str] = None,
                           epoch: Optional[int] = None, flush: bool = False):
    """
    Render every matplotlib figure that is currently open as png and buffer the ones not logged yet.

    A figure is considered already logged when the sha256 of its png rendering is in self.fig_ids.
    :param step: The global step at which the figures were generated. If None, the maximum of all the scalar
    global steps is used (0 when there is none).
    :param split: The split in which the figures were generated.
    :param epoch: The epoch at which the figures were generated. If None, no epoch is saved.
    :param flush: If True, flush all data in memory to the database.
    :return: None
    """
    if step is None:
        # Fall back on the most advanced scalar step
        step = max(self.global_step.values(), default=0)

    for fig_number in plt.get_fignums():
        figure = plt.figure(fig_number)
        figure.tight_layout()

        # Render the figure as png bytes
        png_stream = BytesIO()
        figure.savefig(png_stream, format='png')
        png_bytes = png_stream.getvalue()

        # Only log a rendering that was never seen before
        digest = hashlib.sha256(png_bytes).hexdigest()
        if digest not in self.fig_ids:
            self.fig_ids.add(digest)
            self._log_image(png_bytes, step, split, self.run_rep, epoch, type="PLOT")

    if flush:
        self._flush_all()
223
+
224
def add_text(self, text: str, step: Optional[int] = None, split: Optional[str] = None,
             epoch: Optional[int] = None, flush: bool = False):
    """
    Buffer a raw text sample for the resultTable
    :param text: Must be a string
    :param step: The global step at which the text was generated. If None, the maximum of all the scalar global
    steps is used (0 when there is none).
    :param split: The split in which the text was generated.
    :param epoch: The epoch at which the text was generated. If None, no epoch is saved.
    :param flush: If True, flush all data in memory to the database.
    :return: None
    """
    self._run_pre_hooks()

    # Fall back on the most advanced scalar step
    resolved_step = max(self.global_step.values(), default=0) if step is None else step

    self._log_fragment(text, resolved_step, split, self.run_rep, epoch, type="RAW")
    if flush:
        self._flush_all()
244
+
245
def read_text(self, id: Optional[int] = None, step: Optional[int] = None, split: Optional[str] = None,
              epoch: Optional[int] = None, repetition: Optional[int] = None):
    """
    Return the raw text samples logged in the run, optionally filtered.
    :param id: The id of the text sample to read. If None, no filter on the id.
    :param step: The step at which the text was generated. If None, all steps are returned.
    :param split: The split in which the texts were generated. If None, all splits are returned.
    :param epoch: The epoch at which the texts were generated. If None, all epochs are returned.
    :param repetition: The repetition of the run. If None, all repetitions are returned.
    :return: The list of dicts produced by _get_fragments
    """
    samples = self._get_fragments(id, step, split, epoch, repetition, fragment_type="RAW")
    return samples
257
+
258
def add_fragment(self, content: str, step: Optional[int] = None, split: Optional[str] = None,
                 epoch: Optional[int] = None, flush: bool = False):
    """
    Buffer a html fragment for the resultTable
    :param content: Must be a string containing valid HTML content.
    :param step: The global step at which the fragment was generated. If None, the maximum of all the scalar
    global steps is used (0 when there is none).
    :param split: The split in which the fragment was generated.
    :param epoch: The epoch at which the fragment was generated. If None, no epoch is saved.
    :param flush: If True, flush all data in memory to the database.
    :return: None
    """
    self._run_pre_hooks()

    # Fall back on the most advanced scalar step
    resolved_step = max(self.global_step.values(), default=0) if step is None else step

    self._log_fragment(content, resolved_step, split, self.run_rep, epoch, type="HTML")
    if flush:
        self._flush_all()
278
+
279
def read_fragment(self, id: Optional[int] = None, step: Optional[int] = None, split: Optional[str] = None,
                  epoch: Optional[int] = None, repetition: Optional[int] = None):
    """
    Return the html fragments logged in the run, optionally filtered.
    :param id: The id of the html fragment to read. If None, no filter on the id.
    :param step: The step at which the html fragment was generated. If None, all steps are returned.
    :param split: The split in which the html fragments were generated. If None, all splits are returned.
    :param epoch: The epoch at which the html fragments were generated. If None, all epochs are returned.
    :param repetition: The repetition of the run. If None, all repetitions are returned.
    :return: The list of dicts produced by _get_fragments
    """
    fragments = self._get_fragments(id, step, split, epoch, repetition, fragment_type="HTML")
    return fragments
291
+
292
def read_figures(self, id: Optional[int] = None, step: Optional[int] = None, split: Optional[str] = None,
                 epoch: Optional[int] = None, repetition: Optional[int] = None):
    """
    Return the images of type "PLOT" logged in the run, optionally filtered.
    :param id: The id of the figure to read. If None, no filter on the id.
    :param step: The step at which the figure was generated. If None, all steps are returned.
    :param split: The split in which the figures were generated. If None, all splits are returned.
    :param epoch: The epoch at which the figures were generated. If None, all epochs are returned.
    :param repetition: The repetition of the figures. If None, all repetitions are returned.
    :return: The list of dicts produced by _get_images
    """
    figures = self._get_images(id, step, split, epoch, repetition, img_type="PLOT")
    return figures
304
+
305
def add_hparams(self, **kwargs):
    """
    Store hyperparameters of the run in the Results table (rows flagged with is_hparam set to True)
    :param kwargs: The hyperparameters to save, one row per keyword
    :return: None
    """
    self._run_pre_hooks()
    if self.disable:
        # A disabled writer never touches the database
        return

    rows = [(self.run_id, name, value, True) for name, value in kwargs.items()]
    insert = "INSERT INTO Results (run_id, metric, value, is_hparam) VALUES (?, ?, ?, ?)"
    with self._cursor as cursor:
        cursor.executemany(insert, rows)
320
+
321
def get_hparams(self) -> Dict[str, Any]:
    """
    Read back the hyperparameters of the current run from the Results table
    :return: A dict mapping each hyperparameter name to its stored value
    """
    lookup = "SELECT metric, value FROM Results WHERE run_id=? AND is_hparam=1"
    with self._cursor as cursor:
        cursor.execute(lookup, (self.run_id,))
        records = cursor.fetchall()
    hparams = {}
    for record in records:
        hparams[record[0]] = record[1]
    return hparams
330
+
331
def get_repetitions(self) -> List[int]:
    """
    List the repetition ids that have at least one scalar logged for the current run
    :return: A list of repetitions ids
    """
    lookup = "SELECT DISTINCT run_rep FROM Logs WHERE run_id=?"
    with self._cursor as cursor:
        cursor.execute(lookup, (self.run_id,))
        records = cursor.fetchall()
    repetitions = []
    for record in records:
        repetitions.append(record[0])
    return repetitions
340
+
341
def write_result(self, **kwargs):
    """
    Write the final metrics of the run, mark the run as finished and make the writer read-only.

    If you run multiple repetitions, consider writing the results only once all of them are finished. You can
    aggregate the different metrics before passing them.
    :param kwargs: The metrics to save, one row per keyword
    :return: None
    """
    self._run_pre_hooks()
    if self.disable:
        # A disabled writer never touches the database and stays enabled
        return

    # Nothing must stay in memory once the run is over
    self._flush_all()

    rows = [(self.run_id, metric, value, False) for metric, value in kwargs.items()]
    insert = "INSERT INTO Results (run_id, metric, value, is_hparam) VALUES (?, ?, ?, ?)"
    with self._cursor as cursor:
        cursor.executemany(insert, rows)

    self.set_status("finished")

    # From now on the writer is read-only
    self.enabled = False
366
+
367
def set_status(self, status: Literal["running", "finished", "failed"]):
    """
    Manually set the status of the run in the Experiments table
    :param status: The status to set
    :return: None
    :raises ValueError: If the status is not one of the accepted values (only checked when not disabled)
    """
    self._run_pre_hooks()
    if self.disable:
        return

    allowed = ("running", "finished", "failed")
    if status not in allowed:
        raise ValueError("Status must be one of: running, finished, failed")

    update = "UPDATE Experiments SET status=? WHERE run_id=?"
    with self._cursor as cursor:
        cursor.execute(update, (status, self.run_id))
380
+
381
@property
def status(self) -> str:
    """
    Read the status of the run from the Experiments table
    :return: The status of the run
    :raises RuntimeError: If no row exists for this run id
    """
    lookup = "SELECT status FROM Experiments WHERE run_id=?"
    with self._cursor as cursor:
        cursor.execute(lookup, (self.run_id,))
        record = cursor.fetchone()
    if record is not None:
        return record[0]
    raise RuntimeError(f"Run {self.run_id} does not exist.")
393
+
394
@property
def scalars(self) -> List[str]:
    """
    Return the tags of all scalars logged in the run, formatted as 'split/label', or 'label' alone when the
    split is empty
    """
    tags = []
    for split, label in self.formatted_scalars:
        tags.append(label if split == "" else split + "/" + label)
    return tags
403
+
404
@property
def formatted_scalars(self) -> List[Tuple[str, str]]:
    """
    Return the distinct (split, label) pairs of the scalars logged in the run
    """
    lookup = "SELECT DISTINCT split, label FROM Logs WHERE run_id=?"
    with self._cursor as cursor:
        cursor.execute(lookup, (self.run_id,))
        records = cursor.fetchall()
    # Keep the split and the label separate, the caller decides how to join them
    pairs = []
    for record in records:
        pairs.append((record[0], record[1]))
    return pairs
415
+
416
+ def __getitem__(self, tag):
417
+ """
418
+ Get the scalar values for a given tag.
419
+ """
420
+ return self.read_scalar(tag)
421
+
422
+ def _run_pre_hooks(self):
423
+ for hook in self.pre_hooks:
424
+ hook()
425
+
426
+ def _get_fragments(self, id: Optional[int], step: Optional[int], split: Optional[str], epoch: Optional[int],
427
+ repetition: Optional[int], fragment_type: Literal["RAW", "HTML"]) -> List[dict]:
428
+ command = f"SELECT id_, step, epoch, run_rep, split, fragment FROM Fragments WHERE run_id=? AND fragment_type='{fragment_type}'"
429
+ params = [self.run_id]
430
+ if id is not None:
431
+ command += " AND id_=?"
432
+ params.append(id)
433
+
434
+ if step is not None:
435
+ command += " AND step=?"
436
+ params.append(step)
437
+
438
+ if split is not None:
439
+ command += " AND split=?"
440
+ params.append(split)
441
+
442
+ if epoch is not None:
443
+ command += " AND epoch=?"
444
+ params.append(epoch)
445
+
446
+ if repetition is not None:
447
+ command += " AND run_rep=?"
448
+ params.append(repetition)
449
+
450
+ with self._cursor as cursor:
451
+ cursor.execute(f'{command};', tuple(params))
452
+ rows = cursor.fetchall()
453
+ # Convert the bytes to PIL Image objects
454
+ return [dict(
455
+ id=row[0],
456
+ step=row[1],
457
+ epoch=row[2],
458
+ run_rep=row[3],
459
+ split=row[4],
460
+ fragment=row[5]
461
+ ) for row in rows]
462
+
463
+ def _get_images(self, id: Optional[int], step: Optional[int], split: Optional[str], epoch: Optional[int],
464
+ repetition: Optional[int], img_type: Literal["IMAGE", "PLOT"]) -> List[dict]:
465
+ command = f"SELECT id_, step, epoch, run_rep, split, image FROM Images WHERE run_id=? AND img_type='{img_type}'"
466
+ params = [self.run_id]
467
+ if id is not None:
468
+ command += " AND id_=?"
469
+ params.append(id)
470
+
471
+ if step is not None:
472
+ command += " AND step=?"
473
+ params.append(step)
474
+
475
+ if split is not None:
476
+ command += " AND split=?"
477
+ params.append(split)
478
+
479
+ if epoch is not None:
480
+ command += " AND epoch=?"
481
+ params.append(epoch)
482
+
483
+ if repetition is not None:
484
+ command += " AND run_rep=?"
485
+ params.append(repetition)
486
+
487
+ with self._cursor as cursor:
488
+ cursor.execute(f'{command};', tuple(params))
489
+ rows = cursor.fetchall()
490
+ # Convert the bytes to PIL Image objects
491
+ return [dict(
492
+ id=row[0],
493
+ step=row[1],
494
+ epoch=row[2],
495
+ run_rep=row[3],
496
+ split=row[4],
497
+ image=Image.open(BytesIO(row[5]))
498
+ ) for row in rows]
499
+
500
+ def _get_global_step(self, tag):
501
+ """
502
+ Keep track of the global step for each tag.
503
+ :param tag: The tag to get the step
504
+ :return: The current global step
505
+ """
506
+ if tag not in self.global_step:
507
+ self.global_step[tag] = 0
508
+
509
+ out = self.global_step[tag]
510
+ self.global_step[tag] += 1
511
+ return out
512
+
513
+ def _log(self, tag: str, epoch: int, step: int, split: str, name: str, scalar_value: float, walltime: float,
514
+ run_rep: int):
515
+ """
516
+ Store the scalar log into the buffer, and flush the buffer if it is full.
517
+ :param tag: The tag
518
+ :param epoch: The epoch
519
+ :param step: The step
520
+ :param split: The split
521
+ :param name: The name
522
+ :param scalar_value: The value
523
+ :param walltime: The wall time
524
+ :param run_rep: The run repetition
525
+ :return: None
526
+ """
527
+ if tag not in self.buffer:
528
+ self.buffer[tag] = []
529
+ self.buffer[tag].append((self.run_id, epoch, step, split, name, scalar_value, walltime, run_rep))
530
+
531
+ if len(self.buffer[tag]) >= self.flush_each:
532
+ self._flush(tag)
533
+
534
+ def _log_fragment(self, fragment: str, step: int, split: Optional[int], repetition: int, epoch: Optional[int],
535
+ type: Literal["RAW", "HTML"] = "RAW"):
536
+ """
537
+ Log a text or html fragment to the resultTable.
538
+ :param fragment: The content to log
539
+ :param step: The step
540
+ :param split: The split that made it
541
+ :param repetition: The run repetition
542
+ :param epoch: The epoch
543
+ :param type: Raw (for text only) or HTML (for html content)
544
+ :return: None
545
+ """
546
+ if split not in self.fragments_buffer:
547
+ self.fragments_buffer[split] = []
548
+
549
+ self.fragments_buffer[split].append((self.run_id, step, epoch, repetition, type, split, fragment))
550
+
551
+ if len(self.fragments_buffer[split]) >= self.flush_each:
552
+ self._flush_fragment(split)
553
+
554
+ def _log_image(self, image: bytes, step: int, split: Optional[int], repetition: int, epoch: Optional[int],
555
+ type: Literal["IMAGE", "PLOT"] = "IMAGE"):
556
+ """
557
+ Store the image log into the buffer, and flush the buffer if it is full.
558
+ :param image: The image bytes
559
+ :param step: The step
560
+ :param split: The split
561
+ :param repetition: The run repetition
562
+ :param epoch: The epoch
563
+ :param type: The type of the image, either "IMAGE" or "PLOT". Default is "IMAGE".
564
+ :return: None
565
+ """
566
+ if split not in self.image_buffer:
567
+ self.image_buffer[split] = []
568
+
569
+ self.image_buffer[split].append((self.run_id, step, epoch, repetition, type, split, image))
570
+
571
+ if len(self.image_buffer[split]) >= self.flush_each:
572
+ self._flush_image(split)
573
+
574
+ def _flush_all(self):
575
+ """
576
+ Flush all buffers.
577
+ :return: None
578
+ """
579
+ # Flush all the scalars
580
+ for tag in self.buffer.keys():
581
+ self._flush(tag)
582
+
583
+ # Flush all the images
584
+ for split in self.image_buffer.keys():
585
+ self._flush_image(split)
586
+
587
+ # Flush all the fragments
588
+ for split in self.fragments_buffer.keys():
589
+ self._flush_fragment(split)
590
+
591
+ def _flush_image(self, split):
592
+ query = """
593
+ INSERT INTO Images (run_id, step, epoch, run_rep, img_type, split, image)
594
+ VALUES (?, ?, ?, ?, ?, ?, ?)
595
+ """
596
+ if not self.disable:
597
+ with self._cursor as cursor:
598
+ cursor.executemany(query, self.image_buffer[split])
599
+
600
+ # Reset the buffer
601
+ self.image_buffer[split] = []
602
+
603
+ def _flush_fragment(self, split):
604
+ query = """
605
+ INSERT INTO Fragments (run_id, step, epoch, run_rep, fragment_type, split, fragment)
606
+ VALUES (?, ?, ?, ?, ?, ?, ?)
607
+ """
608
+ if not self.disable:
609
+ with self._cursor as cursor:
610
+ cursor.executemany(query, self.fragments_buffer[split])
611
+
612
+ # Reset the buffer
613
+ self.fragments_buffer[split] = []
614
+
615
+ def _flush(self, tag: str):
616
+ """
617
+ Flush the scalar values into the db and reset the buffer.
618
+ :param tag: The tag to flush
619
+ :return: None
620
+ """
621
+ query = """
622
+ INSERT INTO Logs (run_id, epoch, step, split, label, value, wall_time, run_rep)
623
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
624
+ """
625
+
626
+ if not self.disable:
627
+ with self._cursor as cursor:
628
+ cursor.executemany(query, self.buffer[tag])
629
+
630
+ # Reset the buffer
631
+ self.buffer[tag] = []
632
+
633
+ def _keep(self, tag: str) -> bool:
634
+ """
635
+ Assert if we need to record this log or drop it. Depends on the kep_each attribute
636
+ :param tag: The tag
637
+ :return: True if we need to keep it and False if we drop it
638
+ """
639
+ if tag not in self.log_count:
640
+ self.log_count[tag] = 0
641
+ self.log_count[tag] += 1
642
+ if self.log_count[tag] >= self.keep_each:
643
+ self.log_count[tag] = 0
644
+ return True
645
+ else:
646
+ return False
647
+
648
+ def _exception_handler(self):
649
+ """
650
+ Set the exception handler to set the status to failed and disable the logger if the program crashes
651
+ """
652
+ previous_hooks = sys.excepthook
653
+ def handler(exc_type, exc_value, traceback):
654
+ # Set the status to failed
655
+ self.set_status("failed")
656
+ # Disable the logger
657
+ self.enabled = False
658
+
659
+ # Call the previous exception handler
660
+ previous_hooks(exc_type, exc_value, traceback)
661
+
662
+ # Set the new exception handler
663
+ sys.excepthook = handler
664
+
665
@property
def _cursor(self):
    """
    Build a new Cursor on the database file of this writer. A fresh object is created on every access.
    """
    cursor = Cursor(self.db_path)
    return cursor