cellfinder 1.6.0__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cellfinder/core/main.py CHANGED
@@ -1,34 +1,33 @@
  import os
  from typing import Callable, List, Optional, Tuple

- import numpy as np
  from brainglobe_utils.cells.cells import Cell

- from cellfinder.core import logger
+ from cellfinder.core import logger, types
  from cellfinder.core.download.download import model_type
  from cellfinder.core.train.train_yaml import depth_type


  def main(
-     signal_array: np.ndarray,
-     background_array: np.ndarray,
-     voxel_sizes: Tuple[int, int, int],
+     signal_array: types.array,
+     background_array: types.array,
+     voxel_sizes: Tuple[float, float, float],
      start_plane: int = 0,
      end_plane: int = -1,
      trained_model: Optional[os.PathLike] = None,
      model_weights: Optional[os.PathLike] = None,
      model: model_type = "resnet50_tv",
-     batch_size: int = 64,
+     classification_batch_size: int = 64,
      n_free_cpus: int = 2,
-     network_voxel_sizes: Tuple[int, int, int] = (5, 1, 1),
-     soma_diameter: int = 16,
-     ball_xy_size: int = 6,
-     ball_z_size: int = 15,
+     network_voxel_sizes: Tuple[float, float, float] = (5, 1, 1),
+     soma_diameter: float = 16,
+     ball_xy_size: float = 6,
+     ball_z_size: float = 15,
      ball_overlap_fraction: float = 0.6,
      log_sigma_size: float = 0.2,
      n_sds_above_mean_thresh: float = 10,
      soma_spread_factor: float = 1.4,
-     max_cluster_size: int = 100000,
+     max_cluster_size: float = 100000,
      cube_width: int = 50,
      cube_height: int = 50,
      cube_depth: int = 20,
@@ -36,8 +35,13 @@ def main(
      skip_detection: bool = False,
      skip_classification: bool = False,
      detected_cells: List[Cell] = None,
-     classification_batch_size: Optional[int] = None,
-     classification_torch_device: str = "cpu",
+     detection_batch_size: Optional[int] = None,
+     torch_device: Optional[str] = None,
+     pin_memory: bool = False,
+     split_ball_xy_size: float = 6,
+     split_ball_z_size: float = 15,
+     split_ball_overlap_fraction: float = 0.8,
+     n_splitting_iter: int = 10,
      *,
      detect_callback: Optional[Callable[[int], None]] = None,
      classify_callback: Optional[Callable[[int], None]] = None,
@@ -46,6 +50,111 @@ def main(
      """
      Parameters
      ----------
+     signal_array : numpy.ndarray or dask array
+         3D array representing the signal data in z, y, x order.
+     background_array : numpy.ndarray or dask array
+         3D array representing the background data in z, y, x order.
+     voxel_sizes : 3-tuple of floats
+         Size of your voxels in the z, y, and x dimensions (microns).
+     start_plane : int
+         First plane index to process (inclusive, to process a subset of
+         the data).
+     end_plane : int
+         Last plane index to process (exclusive, to process a subset of
+         the data).
+     trained_model : Optional[Path]
+         Trained model file path (home directory (default) -> pretrained
+         weights).
+     model_weights : Optional[Path]
+         Model weights path (home directory (default) -> pretrained
+         weights).
+     model : str
+         Type of model to use. Defaults to `"resnet50_tv"`.
+     classification_batch_size : int
+         How many potential cells to classify at one time. The GPU/CPU
+         memory must be able to hold this many data cubes at once for
+         the models. For performance-critical applications, tune to
+         maximize memory usage without running out. Check your GPU/CPU
+         memory to verify it's not full.
+     n_free_cpus : int
+         How many CPU cores to leave free.
+     network_voxel_sizes : 3-tuple of floats
+         Size of the pre-trained network's voxels (microns) in the z, y,
+         and x dimensions.
+     soma_diameter : float
+         The expected in-plane (xy) soma diameter (microns).
+     ball_xy_size : float
+         3d filter's in-plane (xy) ball size (microns).
+     ball_z_size : float
+         3d filter's axial (z) ball size (microns).
+     ball_overlap_fraction : float
+         Fraction of the 3d ball filter, centered on a voxel, that must
+         be filled by foreground voxels for that voxel to be retained.
+     log_sigma_size : float
+         Gaussian filter width (as a fraction of soma diameter) used
+         during 2d in-plane Laplacian of Gaussian filtering.
+     n_sds_above_mean_thresh : float
+         Intensity threshold (the number of standard deviations above
+         the mean) of the filtered 2d planes used to mark pixels as
+         foreground or background.
+     soma_spread_factor : float
+         Cell spread factor for determining the largest cell volume
+         before splitting up cell clusters. Structures with a spherical
+         volume of diameter `soma_spread_factor * soma_diameter` or less
+         will not be split.
+     max_cluster_size : float
+         Largest detected cell cluster (in cubic um) where splitting
+         should be attempted. Clusters above this size will be labeled
+         as artifacts.
+     cube_width : int
+         The width of the data cube centered on the cell used for
+         classification. Defaults to `50`.
+     cube_height : int
+         The height of the data cube centered on the cell used for
+         classification. Defaults to `50`.
+     cube_depth : int
+         The depth of the data cube centered on the cell used for
+         classification. Defaults to `20`.
+     network_depth : str
+         The network depth to use during classification. Defaults to
+         `"50"`.
+     skip_detection : bool
+         If selected, the detection step is skipped and instead we get
+         the detected cells from the cell layer below (from a previous
+         detection run or import).
+     skip_classification : bool
+         If selected, the classification step is skipped and all cells
+         from the detection stage are added.
+     detected_cells : list of Cell, optional
+         If specified, the cells to use during classification.
+     detection_batch_size : int
+         The number of planes of the original data volume to process at
+         once. The GPU/CPU memory must be able to contain this many
+         planes for all the filters. For performance-critical
+         applications, tune to maximize memory usage without running
+         out. Check your GPU/CPU memory to verify it's not full.
+     torch_device : str, optional
+         The device on which to run the computation. If not specified
+         (None), "cuda" will be used if a GPU is available, otherwise
+         "cpu". You can also manually specify "cuda" or "cpu".
+     pin_memory : bool
+         Pins to CPU memory the data to be sent to the GPU. This speeds
+         up transfers to the GPU, but it requires that the data stay
+         resident in CPU RAM while the GPU uses it, i.e. that there's
+         enough RAM. If the RAM risks being paged out, pinning shouldn't
+         be used. Defaults to False.
+     split_ball_xy_size : float
+         Similar to `ball_xy_size`, except the value to use for the 3d
+         filter during cluster splitting.
+     split_ball_z_size : float
+         Similar to `ball_z_size`, except the value to use for the 3d
+         filter during cluster splitting.
+     split_ball_overlap_fraction : float
+         Similar to `ball_overlap_fraction`, except the value to use for
+         the 3d filter during cluster splitting.
+     n_splitting_iter : int
+         The number of iterations to run the 3d filtering on a cluster.
+         Each iteration reduces the cluster size by the voxels not
+         retained in the previous iteration.
      detect_callback : Callable[int], optional
          Called every time a plane has finished being processed during the
          detection stage. Called with the plane number that has finished.
@@ -76,9 +185,14 @@ def main(
          n_free_cpus,
          log_sigma_size,
          n_sds_above_mean_thresh,
-         batch_size=classification_batch_size,
-         torch_device=classification_torch_device,
+         batch_size=detection_batch_size,
+         torch_device=torch_device,
+         pin_memory=pin_memory,
          callback=detect_callback,
+         split_ball_z_size=split_ball_z_size,
+         split_ball_xy_size=split_ball_xy_size,
+         split_ball_overlap_fraction=split_ball_overlap_fraction,
+         n_splitting_iter=n_splitting_iter,
      )

      if detect_finished_callback is not None:
@@ -101,7 +215,7 @@ def main(
          n_free_cpus,
          voxel_sizes,
          network_voxel_sizes,
-         batch_size,
+         classification_batch_size,
          cube_height,
          cube_width,
          cube_depth,
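
Taken together, the `main()` changes rename `batch_size` to `classification_batch_size`, replace the classification-only torch device with a shared `torch_device`, and expose the cluster-splitting filter settings. A minimal sketch of a call against the 1.8.0 signature (the array shapes and parameter values below are illustrative placeholders, not taken from this diff):

```python
import numpy as np
from cellfinder.core.main import main

# Placeholder 3D stacks in z, y, x order; real use would pass imaging data
# (numpy or dask arrays, per the new types.array annotation).
signal = np.random.random((30, 256, 256)).astype(np.float32)
background = np.random.random((30, 256, 256)).astype(np.float32)

cells = main(
    signal_array=signal,
    background_array=background,
    voxel_sizes=(5.0, 2.0, 2.0),       # microns (z, y, x); now floats
    classification_batch_size=64,      # was `batch_size` in 1.6.0
    detection_batch_size=1,            # new: planes filtered per batch
    torch_device=None,                 # new: None -> "cuda" if available
    pin_memory=False,                  # new: pin host memory for GPU transfers
    split_ball_xy_size=6,              # new: 3d filter settings used only
    split_ball_z_size=15,              #      while splitting cell clusters
    split_ball_overlap_fraction=0.8,
    n_splitting_iter=10,
)
```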
cellfinder/core/tools/threading.py CHANGED
@@ -15,6 +15,7 @@ Typical example::

  from cellfinder.core.tools.threading import ThreadWithException, \\
      EOFSignal, ProcessWithException
+ from cellfinder.core import logger
  import torch


@@ -63,7 +64,7 @@ Typical example::
              # thread exited for whatever reason (not exception)
              break

-         print(f"Thread processed tensor {i}")
+         logger.debug(f"Thread processed tensor {i}")
      finally:
          # whatever happens, make sure thread is told to finish so it
          # doesn't get stuck
@@ -248,8 +249,8 @@ class ExceptionWithQueueMixIn:
          ...         # do something with the msg
          ...         pass
          ...     except ExecutionFailure as e:
-         ...         print(f"got exception {type(e.__cause__)}")
-         ...         print(f"with message {e.__cause__.args[0]}")
+         ...         logger.error(f"got exception {type(e.__cause__)}")
+         ...         logger.error(f"with message {e.__cause__.args[0]}")
          """
          msg, value = self.from_thread_queue.get(block=True, timeout=timeout)
          if msg == "eof":
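
These hunks route the threading examples' output through the package logger instead of `print`. To actually see the `logger.debug` messages in a script, something like the following should work, assuming `cellfinder.core.logger` behaves like a standard `logging` logger (the diff only shows that it is importable from `cellfinder.core`):

```python
import logging

from cellfinder.core import logger

# Assumption: `logger` is a regular `logging.Logger`; raise the root
# level so DEBUG records are actually emitted somewhere visible.
logging.basicConfig(level=logging.DEBUG)

logger.debug("Thread processed tensor %d", 0)
logger.error("got exception %s", ValueError)
```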
cellfinder/core/train/train_yaml.py CHANGED
@@ -3,9 +3,6 @@ main
  ===============

  Trains a network based on a yaml file specifying cubes of cells/non cells.
-
- N.B imports are within functions to prevent tensorflow being imported before
- it's warnings are silenced
  """

  import os
@@ -29,12 +26,16 @@ from brainglobe_utils.general.system import (
  from brainglobe_utils.IO.cells import find_relevant_tiffs
  from brainglobe_utils.IO.yaml import read_yaml_section
  from fancylog import fancylog
+ from keras.callbacks import CSVLogger, ModelCheckpoint, TensorBoard
  from sklearn.model_selection import train_test_split

  import cellfinder.core as program_for_log
  from cellfinder.core import logger
+ from cellfinder.core.classify.cube_generator import CubeGeneratorFromDisk
  from cellfinder.core.classify.resnet import layer_type
+ from cellfinder.core.classify.tools import get_model, make_lists
  from cellfinder.core.download.download import DEFAULT_DOWNLOAD_DIRECTORY
+ from cellfinder.core.tools.prep import prep_model_weights

  depth_type = Literal["18", "34", "50", "101", "152"]

@@ -316,16 +317,6 @@ def run(
      save_progress=False,
      epochs=100,
  ):
-     from keras.callbacks import (
-         CSVLogger,
-         ModelCheckpoint,
-         TensorBoard,
-     )
-
-     from cellfinder.core.classify.cube_generator import CubeGeneratorFromDisk
-     from cellfinder.core.classify.tools import get_model, make_lists
-     from cellfinder.core.tools.prep import prep_model_weights
-
      start_time = datetime.now()

      ensure_directory_exists(output_dir)
cellfinder/napari/curation.py CHANGED
@@ -95,6 +95,22 @@ class CurationWidget(QWidget):
              self.training_data_non_cell_choice, self.point_layer_names
          )

+         @self.viewer.layers.events.removed.connect
+         def _remove_selection_layers(event: QtCore.QEvent):
+             """
+             Set internal background, signal, training data cell,
+             and training data non-cell layers to None when they
+             are removed from the napari viewer GUI.
+             """
+             if event.value == self.signal_layer:
+                 self.signal_layer = None
+             if event.value == self.background_layer:
+                 self.background_layer = None
+             if event.value == self.training_data_cell_layer:
+                 self.training_data_cell_layer = None
+             if event.value == self.training_data_non_cell_layer:
+                 self.training_data_non_cell_layer = None
+
      @staticmethod
      def _update_combobox_options(combobox: QComboBox, options_list: List[str]):
          original_text = combobox.currentText()
@@ -212,8 +228,8 @@ class CurationWidget(QWidget):
          self.layout.addWidget(self.load_data_panel, row, column, 1, 1)

      def setup_keybindings(self):
-         self.viewer.bind_key("c", self.mark_as_cell)
-         self.viewer.bind_key("x", self.mark_as_non_cell)
+         self.viewer.bind_key("c", self.mark_as_cell, overwrite=True)
+         self.viewer.bind_key("x", self.mark_as_non_cell, overwrite=True)

      def set_signal_image(self):
          """
cellfinder/napari/detect/detect.py CHANGED
@@ -244,23 +244,26 @@ def detect_widget() -> FunctionGui:
          detection_options,
          skip_detection: bool,
          soma_diameter: float,
+         log_sigma_size: float,
+         n_sds_above_mean_thresh: float,
          ball_xy_size: float,
          ball_z_size: float,
          ball_overlap_fraction: float,
-         log_sigma_size: float,
-         n_sds_above_mean_thresh: int,
+         detection_batch_size: int,
          soma_spread_factor: float,
-         max_cluster_size: int,
+         max_cluster_size: float,
          classification_options,
          skip_classification: bool,
          use_pre_trained_weights: bool,
          trained_model: Optional[Path],
-         batch_size: int,
+         classification_batch_size: int,
          misc_options,
          start_plane: int,
          end_plane: int,
          n_free_cpus: int,
          analyse_local: bool,
+         use_gpu: bool,
+         pin_memory: bool,
          debug: bool,
          reset_button,
      ) -> None:
@@ -270,43 +273,60 @@ def detect_widget() -> FunctionGui:
          Parameters
          ----------
          voxel_size_z : float
-             Size of your voxels in the axial dimension
+             Size of your voxels in the axial dimension (microns)
          voxel_size_y : float
-             Size of your voxels in the y direction (top to bottom)
+             Size of your voxels in the y direction (top to bottom) (microns)
          voxel_size_x : float
-             Size of your voxels in the x direction (left to right)
+             Size of your voxels in the x direction (left to right) (microns)
          skip_detection : bool
              If selected, the detection step is skipped and instead we get the
              detected cells from the cell layer below (from a previous
              detection run or import)
          soma_diameter : float
-             The expected in-plane soma diameter (microns)
+             The expected in-plane (xy) soma diameter (microns)
+         log_sigma_size : float
+             Gaussian filter width (as a fraction of soma diameter) used during
+             2d in-plane Laplacian of Gaussian filtering
+         n_sds_above_mean_thresh : float
+             Intensity threshold (the number of standard deviations above
+             the mean) of the filtered 2d planes used to mark pixels as
+             foreground or background
          ball_xy_size : float
-             Elliptical morphological in-plane filter size (microns)
+             3d filter's in-plane (xy) ball size (microns)
          ball_z_size : float
-             Elliptical morphological axial filter size (microns)
+             3d filter's axial (z) ball size (microns)
          ball_overlap_fraction : float
-             Fraction of the morphological filter needed to be filled
-             to retain a voxel
-         log_sigma_size : float
-             Laplacian of Gaussian filter width (as a fraction of soma diameter)
-         n_sds_above_mean_thresh : int
-             Cell intensity threshold (as a multiple of noise above the mean)
+             Fraction of the 3d ball filter, centered on a voxel, that must
+             be filled by foreground voxels for that voxel to be retained
+         detection_batch_size : int
+             The number of planes of the original data volume to process at
+             once. The GPU/CPU memory must be able to contain this many planes
+             for all the filters. For performance-critical applications, tune
+             to maximize memory usage without running out. Check your GPU/CPU
+             memory to verify it's not full
          soma_spread_factor : float
-             Cell spread factor (for splitting up cell clusters)
-         max_cluster_size : int
-             Largest putative cell cluster (in cubic um) where splitting
-             should be attempted
-         use_pre_trained_weights : bool
-             Select to use pre-trained model weights
-         batch_size : int
-             How many points to classify at one time
+             Cell spread factor for determining the largest cell volume before
+             splitting up cell clusters. Structures with a spherical volume of
+             diameter `soma_spread_factor * soma_diameter` or less will not be
+             split
+         max_cluster_size : float
+             Largest detected cell cluster (in cubic um) where splitting
+             should be attempted. Clusters above this size will be labeled
+             as artifacts
          skip_classification : bool
              If selected, the classification step is skipped and all cells from
              the detection stage are added
+         use_pre_trained_weights : bool
+             Select to use pre-trained model weights
          trained_model : Optional[Path]
              Trained model file path (home directory (default) -> pretrained
              weights)
+         classification_batch_size : int
+             How many potential cells to classify at one time. The GPU/CPU
+             memory must be able to hold this many data cubes at once for
+             the models. For performance-critical applications, tune to
+             maximize memory usage without running out. Check your GPU/CPU
+             memory to verify it's not full
          start_plane : int
              First plane to process (to process a subset of the data)
          end_plane : int
@@ -315,6 +335,14 @@ def detect_widget() -> FunctionGui:
              How many CPU cores to leave free
          analyse_local : bool
              Only analyse planes around the current position
+         use_gpu : bool
+             If True, use GPU for processing (if available); otherwise, use
+             CPU.
+         pin_memory : bool
+             Pins to CPU memory the data to be sent to the GPU. This allows
+             faster GPU transfers, but it requires that the data stay in CPU
+             RAM while the GPU uses it, i.e. that there's enough RAM. If the
+             RAM risks being paged out, pinning shouldn't be used. Defaults
+             to False.
          debug : bool
              Increase logging
          reset_button :
@@ -370,6 +398,7 @@ def detect_widget() -> FunctionGui:
              n_sds_above_mean_thresh,
              soma_spread_factor,
              max_cluster_size,
+             detection_batch_size,
          )

          if use_pre_trained_weights:
@@ -378,7 +407,7 @@ def detect_widget() -> FunctionGui:
              skip_classification,
              use_pre_trained_weights,
              trained_model,
-             batch_size,
+             classification_batch_size,
          )

          if analyse_local:
@@ -389,7 +418,13 @@ def detect_widget() -> FunctionGui:
              end_plane = len(signal_image.data)

          misc_inputs = MiscInputs(
-             start_plane, end_plane, n_free_cpus, analyse_local, debug
+             start_plane,
+             end_plane,
+             n_free_cpus,
+             analyse_local,
+             use_gpu,
+             pin_memory,
+             debug,
          )

          worker = Worker(
cellfinder/napari/detect/detect_containers.py CHANGED
@@ -1,8 +1,9 @@
- from dataclasses import dataclass
+ from dataclasses import dataclass, field
  from pathlib import Path
  from typing import List, Optional

  import numpy
+ import torch
  from brainglobe_utils.cells.cells import Cell

  from cellfinder.napari.input_container import InputContainer
@@ -30,7 +31,7 @@ class DataInputs(InputContainer):
              self.voxel_size_x,
          )
          # del operator doesn't affect self, because asdict creates a copy of
-         # fields.
+         # the dict.
          del data_input_dict["voxel_size_z"]
          del data_input_dict["voxel_size_y"]
          del data_input_dict["voxel_size_x"]
@@ -67,9 +68,10 @@ class DetectionInputs(InputContainer):
      ball_z_size: float = 15
      ball_overlap_fraction: float = 0.6
      log_sigma_size: float = 0.2
-     n_sds_above_mean_thresh: int = 10
+     n_sds_above_mean_thresh: float = 10
      soma_spread_factor: float = 1.4
-     max_cluster_size: int = 100000
+     max_cluster_size: float = 100000
+     detection_batch_size: int = 1

      def as_core_arguments(self) -> dict:
          return super().as_core_arguments()
@@ -96,14 +98,17 @@ class DetectionInputs(InputContainer):
                  "n_sds_above_mean_thresh", custom_label="Threshold"
              ),
              soma_spread_factor=cls._custom_widget(
-                 "soma_spread_factor", custom_label="Cell spread"
+                 "soma_spread_factor", custom_label="Split cell spread"
              ),
              max_cluster_size=cls._custom_widget(
                  "max_cluster_size",
-                 custom_label="Max cluster",
+                 custom_label="Split max cluster",
                  min=0,
                  max=10000000,
              ),
+             detection_batch_size=cls._custom_widget(
+                 "detection_batch_size", custom_label="Batch size (detection)"
+             ),
          )


@@ -114,7 +119,7 @@ class ClassificationInputs(InputContainer):
      skip_classification: bool = False
      use_pre_trained_weights: bool = True
      trained_model: Optional[Path] = Path.home()
-     batch_size: int = 64
+     classification_batch_size: int = 64

      def as_core_arguments(self) -> dict:
          args = super().as_core_arguments()
@@ -132,7 +137,10 @@ class ClassificationInputs(InputContainer):
              skip_classification=dict(
                  value=cls.defaults()["skip_classification"]
              ),
-             batch_size=dict(value=cls.defaults()["batch_size"]),
+             classification_batch_size=dict(
+                 value=cls.defaults()["classification_batch_size"],
+                 label="Batch size (classification)",
+             ),
          )


@@ -144,10 +152,14 @@ class MiscInputs(InputContainer):
      end_plane: int = 0
      n_free_cpus: int = 2
      analyse_local: bool = False
+     use_gpu: bool = field(default_factory=lambda: torch.cuda.is_available())
+     pin_memory: bool = False
      debug: bool = False

      def as_core_arguments(self) -> dict:
          misc_input_dict = super().as_core_arguments()
+         misc_input_dict["torch_device"] = "cuda" if self.use_gpu else "cpu"
+         del misc_input_dict["use_gpu"]
          del misc_input_dict["debug"]
          del misc_input_dict["analyse_local"]
          return misc_input_dict
@@ -162,5 +174,16 @@ class MiscInputs(InputContainer):
              "n_free_cpus", custom_label="Number of free CPUs"
          ),
          analyse_local=dict(value=cls.defaults()["analyse_local"]),
+         use_gpu=dict(
+             widget_type="CheckBox",
+             label="Use GPU",
+             value=cls.defaults()["use_gpu"],
+             enabled=torch.cuda.is_available(),
+         ),
+         pin_memory=dict(
+             widget_type="CheckBox",
+             label="Pin data to memory",
+             value=cls.defaults()["pin_memory"],
+         ),
          debug=dict(value=cls.defaults()["debug"]),
      )
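
The new `MiscInputs` fields translate a GUI checkbox into the `torch_device` string the core API expects. A stripped-down sketch of just that mapping, outside the widget machinery (the field names match the diff; everything else here is simplified for illustration):

```python
import torch
from dataclasses import dataclass, field, fields


@dataclass
class MiscInputs:
    # Default to GPU only when CUDA is actually available.
    use_gpu: bool = field(default_factory=lambda: torch.cuda.is_available())
    pin_memory: bool = False

    def as_core_arguments(self) -> dict:
        d = {f.name: getattr(self, f.name) for f in fields(self)}
        # The core API takes a device string, not a boolean flag.
        d["torch_device"] = "cuda" if self.use_gpu else "cpu"
        del d["use_gpu"]
        return d


print(MiscInputs(use_gpu=False).as_core_arguments())
# {'pin_memory': False, 'torch_device': 'cpu'}
```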
cellfinder/napari/input_container.py CHANGED
@@ -1,8 +1,18 @@
  from abc import abstractmethod
- from dataclasses import asdict, dataclass
+ from dataclasses import dataclass, fields
  from typing import Optional


+ def asdict_no_copy(obj: dataclass) -> dict:
+     """
+     Similar to `asdict`, except it makes no copies of the field values.
+     asdict will do a deep copy of field values that are non-basic objects.
+
+     It still creates a new dict to return, though.
+     """
+     return {field.name: getattr(obj, field.name) for field in fields(obj)}
+
+
  @dataclass
  class InputContainer:
      """Base for classes that contain inputs
@@ -23,7 +33,7 @@ class InputContainer:
          # Derived classes are not expected to be particularly
          # slow to instantiate, so use the default constructor
          # to avoid code repetition.
-         return asdict(cls())
+         return asdict_no_copy(cls())

      @abstractmethod
      def as_core_arguments(self) -> dict:
@@ -32,10 +42,10 @@ class InputContainer:
          The implementation provided here can be re-used in derived classes, if
          convenient.
          """
-         # note that asdict returns a new instance of a dict,
+         # note that asdict_no_copy returns a new instance of a dict,
          # so any subsequent modifications of this dict won't affect the class
          # instance
-         return asdict(self)
+         return asdict_no_copy(self)

      @classmethod
      def _custom_widget(
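
The motivation for `asdict_no_copy` is that `dataclasses.asdict` deep-copies any non-basic field value, which is wasteful (or outright wrong) for fields holding things like image layers or large arrays. A self-contained sketch contrasting the two; the `Holder` class is hypothetical, standing in for such a field value:

```python
from dataclasses import asdict, dataclass, fields


class Holder:
    """Hypothetical non-basic field value (e.g. an image layer or array)."""


def asdict_no_copy(obj) -> dict:
    # Same approach as the diff: a fresh dict over the original field values.
    return {f.name: getattr(obj, f.name) for f in fields(obj)}


@dataclass
class Inputs:
    item: Holder


inputs = Inputs(item=Holder())
print(asdict(inputs)["item"] is inputs.item)          # False: deep-copied
print(asdict_no_copy(inputs)["item"] is inputs.item)  # True: same object
```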
cellfinder/napari/train/train.py CHANGED
@@ -25,14 +25,14 @@ def run_training(
      optional_training_inputs: OptionalTrainingInputs,
      misc_training_inputs: MiscTrainingInputs,
  ):
-     print("Running training")
+     show_info("Running training...")
      train_yaml(
          **training_data_inputs.as_core_arguments(),
          **optional_network_inputs.as_core_arguments(),
          **optional_training_inputs.as_core_arguments(),
          **misc_training_inputs.as_core_arguments(),
      )
-     print("Finished!")
+     show_info("Training finished!")


  def training_widget() -> FunctionGui:
@@ -60,7 +60,6 @@ def training_widget() -> FunctionGui:
          continue_training: bool,
          augment: bool,
          tensorboard: bool,
-         save_weights: bool,
          save_checkpoints: bool,
          save_progress: bool,
          epochs: int,
@@ -96,9 +95,6 @@ def training_widget() -> FunctionGui:
              Augment the training data to improve generalisation
          tensorboard : bool
              Log to output_directory/tensorboard
-         save_weights : bool
-             Only store the model weights, and not the full model
-             Useful to save storage space
          save_checkpoints : bool
              Store the model at intermediate points during training
          save_progress : bool
@@ -133,7 +129,6 @@ def training_widget() -> FunctionGui:
              continue_training,
              augment,
              tensorboard,
-             save_weights,
              save_checkpoints,
              save_progress,
              epochs,
@@ -147,6 +142,7 @@ def training_widget() -> FunctionGui:
          if yaml_files[0] == Path.home():  # type: ignore
              show_info("Please select a YAML file for training")
          else:
+             show_info("Starting training process...")
              worker = run_training(
                  training_data_inputs,
                  optional_network_inputs,
cellfinder/napari/train/train_containers.py CHANGED
@@ -75,7 +75,6 @@ class OptionalTrainingInputs(InputContainer):
      continue_training: bool = False
      augment: bool = True
      tensorboard: bool = False
-     save_weights: bool = False
      save_checkpoints: bool = True
      save_progress: bool = True
      epochs: int = 100
@@ -98,7 +97,6 @@ class OptionalTrainingInputs(InputContainer):
              continue_training=cls._custom_widget("continue_training"),
              augment=cls._custom_widget("augment"),
              tensorboard=cls._custom_widget("tensorboard"),
-             save_weights=cls._custom_widget("save_weights"),
              save_checkpoints=cls._custom_widget("save_checkpoints"),
              save_progress=cls._custom_widget("save_progress"),
              epochs=cls._custom_widget("epochs"),