cellfinder: cellfinder-1.1.3-py3-none-any.whl → cellfinder-1.3.0-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release.


This version of cellfinder might be problematic. (The link to further details on the original page was not preserved in this copy.)

Files changed (34):
  1. cellfinder/__init__.py +21 -12
  2. cellfinder/core/classify/classify.py +13 -6
  3. cellfinder/core/classify/cube_generator.py +27 -11
  4. cellfinder/core/classify/resnet.py +9 -6
  5. cellfinder/core/classify/tools.py +13 -11
  6. cellfinder/core/detect/detect.py +12 -1
  7. cellfinder/core/detect/filters/volume/ball_filter.py +198 -113
  8. cellfinder/core/detect/filters/volume/structure_detection.py +105 -41
  9. cellfinder/core/detect/filters/volume/structure_splitting.py +1 -1
  10. cellfinder/core/detect/filters/volume/volume_filter.py +48 -49
  11. cellfinder/core/download/cli.py +39 -32
  12. cellfinder/core/download/download.py +44 -56
  13. cellfinder/core/main.py +53 -68
  14. cellfinder/core/tools/prep.py +12 -20
  15. cellfinder/core/tools/source_files.py +5 -3
  16. cellfinder/core/tools/system.py +10 -0
  17. cellfinder/core/train/train_yml.py +29 -27
  18. cellfinder/napari/curation.py +1 -1
  19. cellfinder/napari/detect/detect.py +259 -58
  20. cellfinder/napari/detect/detect_containers.py +11 -1
  21. cellfinder/napari/detect/thread_worker.py +16 -2
  22. cellfinder/napari/train/train.py +2 -9
  23. cellfinder/napari/train/train_containers.py +3 -3
  24. cellfinder/napari/utils.py +88 -47
  25. {cellfinder-1.1.3.dist-info → cellfinder-1.3.0.dist-info}/METADATA +12 -11
  26. {cellfinder-1.1.3.dist-info → cellfinder-1.3.0.dist-info}/RECORD +30 -34
  27. cellfinder/core/download/models.py +0 -49
  28. cellfinder/core/tools/IO.py +0 -48
  29. cellfinder/core/tools/tf.py +0 -46
  30. cellfinder/napari/images/brainglobe.png +0 -0
  31. {cellfinder-1.1.3.dist-info → cellfinder-1.3.0.dist-info}/LICENSE +0 -0
  32. {cellfinder-1.1.3.dist-info → cellfinder-1.3.0.dist-info}/WHEEL +0 -0
  33. {cellfinder-1.1.3.dist-info → cellfinder-1.3.0.dist-info}/entry_points.txt +0 -0
  34. {cellfinder-1.1.3.dist-info → cellfinder-1.3.0.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,11 @@
1
+ from functools import partial
1
2
  from math import ceil
2
3
  from pathlib import Path
3
- from typing import Optional
4
+ from typing import Any, Callable, Dict, Optional, Tuple
4
5
 
5
6
  import napari
7
+ import napari.layers
8
+ from brainglobe_utils.cells.cells import Cell
6
9
  from magicgui import magicgui
7
10
  from magicgui.widgets import FunctionGui, ProgressBar
8
11
  from napari.utils.notifications import show_info
@@ -10,10 +13,11 @@ from qtpy.QtWidgets import QScrollArea
10
13
 
11
14
  from cellfinder.core.classify.cube_generator import get_cube_depth_min_max
12
15
  from cellfinder.napari.utils import (
13
- add_layers,
14
- header_label_widget,
16
+ add_classified_layers,
17
+ add_single_layer,
18
+ cellfinder_header,
15
19
  html_label_widget,
16
- widget_header,
20
+ napari_array_to_cells,
17
21
  )
18
22
 
19
23
  from .detect_containers import (
@@ -33,14 +37,194 @@ CUBE_DEPTH = 20
33
37
  MIN_PLANES_ANALYSE = 0
34
38
 
35
39
 
40
+ def get_heavy_widgets(
41
+ options: Dict[str, Any]
42
+ ) -> Tuple[Callable, Callable, Callable]:
43
+ # signal and other input are separated out from the main magicgui
44
+ # parameter selections and are inserted as widget children in their own
45
+ # sub-containers of the root. Because if these image parameters are
46
+ # included in the root widget, every time *any* parameter updates, the gui
47
+ # freezes for a bit likely because magicgui is processing something for
48
+ # all the parameters when any parameter changes. And this processing takes
49
+ # particularly long for image parameters. Placing them as sub-containers
50
+ # alleviates this
51
+ @magicgui(
52
+ call_button=False,
53
+ persist=False,
54
+ scrollable=False,
55
+ labels=False,
56
+ auto_call=True,
57
+ )
58
+ def signal_image_opt(
59
+ viewer: napari.Viewer,
60
+ signal_image: napari.layers.Image,
61
+ ):
62
+ """
63
+ magicgui widget for setting the signal_image parameter.
64
+
65
+ Parameters
66
+ ----------
67
+ signal_image : napari.layers.Image
68
+ Image layer containing the labelled cells
69
+ """
70
+ options["signal_image"] = signal_image
71
+ options["viewer"] = viewer
72
+
73
+ @magicgui(
74
+ call_button=False,
75
+ persist=False,
76
+ scrollable=False,
77
+ labels=False,
78
+ auto_call=True,
79
+ )
80
+ def background_image_opt(
81
+ background_image: napari.layers.Image,
82
+ ):
83
+ """
84
+ magicgui widget for setting the background image parameter.
85
+
86
+ Parameters
87
+ ----------
88
+ background_image : napari.layers.Image
89
+ Image layer without labelled cells
90
+ """
91
+ options["background_image"] = background_image
92
+
93
+ @magicgui(
94
+ call_button=False,
95
+ persist=False,
96
+ scrollable=False,
97
+ labels=False,
98
+ auto_call=True,
99
+ )
100
+ def cell_layer_opt(
101
+ cell_layer: napari.layers.Points,
102
+ ):
103
+ """
104
+ magicgui widget for setting the cell layer input when detection is
105
+ skipped.
106
+
107
+ Parameters
108
+ ----------
109
+ cell_layer : napari.layers.Points
110
+ If detection is skipped, select the cell layer containing the
111
+ detected cells to use for classification
112
+ """
113
+ options["cell_layer"] = cell_layer
114
+
115
+ return signal_image_opt, background_image_opt, cell_layer_opt
116
+
117
+
118
+ def add_heavy_widgets(
119
+ root: FunctionGui,
120
+ widgets: Tuple[FunctionGui, ...],
121
+ new_names: Tuple[str, ...],
122
+ insertions: Tuple[str, ...],
123
+ ) -> None:
124
+ for widget, new_name, insertion in zip(widgets, new_names, insertions):
125
+ # make it look as if it's directly in the root container
126
+ widget.margins = 0, 0, 0, 0
127
+ # the parameters of these widgets are updated using `auto_call` only.
128
+ # If False, magicgui passes these as args to root() when the root's
129
+ # function runs. But that doesn't list them as args of its function
130
+ widget.gui_only = True
131
+ root.insert(root.index(insertion) + 1, widget)
132
+ getattr(root, widget.name).label = new_name
133
+
134
+
135
+ def restore_options_defaults(widget: FunctionGui) -> None:
136
+ """
137
+ Restore default widget values.
138
+ """
139
+ defaults = {
140
+ **DataInputs.defaults(),
141
+ **DetectionInputs.defaults(),
142
+ **ClassificationInputs.defaults(),
143
+ **MiscInputs.defaults(),
144
+ }
145
+ for name, value in defaults.items():
146
+ if value is not None: # ignore fields with no default
147
+ getattr(widget, name).value = value
148
+
149
+
150
+ def get_results_callback(
151
+ skip_classification: bool, viewer: napari.Viewer
152
+ ) -> Callable:
153
+ """
154
+ Returns the callback that is connected to output of the pipeline.
155
+ It returns the detected points that we have to visualize.
156
+ """
157
+ if skip_classification:
158
+ # after detection w/o classification, everything is unknown
159
+ def done_func(points):
160
+ add_single_layer(
161
+ points,
162
+ viewer=viewer,
163
+ name="Cell candidates",
164
+ cell_type=Cell.UNKNOWN,
165
+ )
166
+
167
+ else:
168
+ # after classification we have either cell or unknown
169
+ def done_func(points):
170
+ add_classified_layers(
171
+ points,
172
+ viewer=viewer,
173
+ unknown_name="Rejected",
174
+ cell_name="Detected",
175
+ )
176
+
177
+ return done_func
178
+
179
+
180
+ def find_local_planes(
181
+ viewer: napari.Viewer,
182
+ voxel_size_z: float,
183
+ signal_image: napari.layers.Image,
184
+ ) -> Tuple[int, int]:
185
+ """
186
+ When detecting only locally, it returns the start and end planes to use.
187
+ """
188
+ current_plane = viewer.dims.current_step[0]
189
+
190
+ # so a reasonable number of cells in the plane are detected
191
+ planes_needed = MIN_PLANES_ANALYSE + int(
192
+ ceil((CUBE_DEPTH * NETWORK_VOXEL_SIZES[0]) / voxel_size_z)
193
+ )
194
+
195
+ start_plane, end_plane = get_cube_depth_min_max(
196
+ current_plane, planes_needed
197
+ )
198
+ start_plane = max(0, start_plane)
199
+ end_plane = min(len(signal_image.data), end_plane)
200
+
201
+ return start_plane, end_plane
202
+
203
+
204
+ def reraise(e: Exception) -> None:
205
+ """Re-raises the exception."""
206
+ raise Exception from e
207
+
208
+
36
209
  def detect_widget() -> FunctionGui:
37
210
  """
38
211
  Create a detection plugin GUI.
39
212
  """
40
213
  progress_bar = ProgressBar()
41
214
 
215
+ # options that is filled in from the gui
216
+ options = {
217
+ "signal_image": None,
218
+ "background_image": None,
219
+ "viewer": None,
220
+ "cell_layer": None,
221
+ }
222
+
223
+ signal_image_opt, background_image_opt, cell_layer_opt = get_heavy_widgets(
224
+ options
225
+ )
226
+
42
227
  @magicgui(
43
- header=header_label_widget,
44
228
  detection_label=html_label_widget("Cell detection", tag="h3"),
45
229
  **DataInputs.widget_representation(),
46
230
  **DetectionInputs.widget_representation(),
@@ -52,16 +236,13 @@ def detect_widget() -> FunctionGui:
52
236
  scrollable=True,
53
237
  )
54
238
  def widget(
55
- header,
56
239
  detection_label,
57
240
  data_options,
58
- viewer: napari.Viewer,
59
- signal_image: napari.layers.Image,
60
- background_image: napari.layers.Image,
61
241
  voxel_size_z: float,
62
242
  voxel_size_y: float,
63
243
  voxel_size_x: float,
64
244
  detection_options,
245
+ skip_detection: bool,
65
246
  soma_diameter: float,
66
247
  ball_xy_size: float,
67
248
  ball_z_size: float,
@@ -71,8 +252,10 @@ def detect_widget() -> FunctionGui:
71
252
  soma_spread_factor: float,
72
253
  max_cluster_size: int,
73
254
  classification_options,
74
- trained_model: Optional[Path],
255
+ skip_classification: bool,
75
256
  use_pre_trained_weights: bool,
257
+ trained_model: Optional[Path],
258
+ batch_size: int,
76
259
  misc_options,
77
260
  start_plane: int,
78
261
  end_plane: int,
@@ -86,16 +269,16 @@ def detect_widget() -> FunctionGui:
86
269
 
87
270
  Parameters
88
271
  ----------
89
- signal_image : napari.layers.Image
90
- Image layer containing the labelled cells
91
- background_image : napari.layers.Image
92
- Image layer without labelled cells
93
272
  voxel_size_z : float
94
273
  Size of your voxels in the axial dimension
95
274
  voxel_size_y : float
96
275
  Size of your voxels in the y direction (top to bottom)
97
276
  voxel_size_x : float
98
277
  Size of your voxels in the x direction (left to right)
278
+ skip_detection : bool
279
+ If selected, the detection step is skipped and instead we get the
280
+ detected cells from the cell layer below (from a previous
281
+ detection run or import)
99
282
  soma_diameter : float
100
283
  The expected in-plane soma diameter (microns)
101
284
  ball_xy_size : float
@@ -116,6 +299,11 @@ def detect_widget() -> FunctionGui:
116
299
  should be attempted
117
300
  use_pre_trained_weights : bool
118
301
  Select to use pre-trained model weights
302
+ batch_size : int
303
+ How many points to classify at one time
304
+ skip_classification : bool
305
+ If selected, the classification step is skipped and all cells from
306
+ the detection stage are added
119
307
  trained_model : Optional[Path]
120
308
  Trained model file path (home directory (default) -> pretrained
121
309
  weights)
@@ -132,18 +320,48 @@ def detect_widget() -> FunctionGui:
132
320
  reset_button :
133
321
  Reset parameters to default
134
322
  """
135
- if signal_image is None or background_image is None:
323
+ # we must manually call so that the parameters of these functions are
324
+ # initialized and updated. Because, if the images are open in napari
325
+ # before we open cellfinder, then these functions may never be called,
326
+ # even though the image filenames are shown properly in the parameters
327
+ # in the gui. Likely auto_call doesn't make magicgui call the functions
328
+ # in this circumstance, only if the parameters are updated once
329
+ # cellfinder plugin is fully open and initialized
330
+ signal_image_opt()
331
+ background_image_opt()
332
+ cell_layer_opt()
333
+
334
+ signal_image = options["signal_image"]
335
+
336
+ if signal_image is None or options["background_image"] is None:
136
337
  show_info("Both signal and background images must be specified.")
137
338
  return
339
+
340
+ detected_cells = []
341
+ if skip_detection:
342
+ if options["cell_layer"] is None:
343
+ show_info(
344
+ "Skip detection selected, but no existing cell layer "
345
+ "is selected."
346
+ )
347
+ return
348
+
349
+ # set cells as unknown so that classification will process them
350
+ detected_cells = napari_array_to_cells(
351
+ options["cell_layer"], Cell.UNKNOWN
352
+ )
353
+
138
354
  data_inputs = DataInputs(
139
355
  signal_image.data,
140
- background_image.data,
356
+ options["background_image"].data,
141
357
  voxel_size_z,
142
358
  voxel_size_y,
143
359
  voxel_size_x,
144
360
  )
145
361
 
146
362
  detection_inputs = DetectionInputs(
363
+ skip_detection,
364
+ detected_cells,
147
365
  soma_diameter,
148
366
  ball_xy_size,
149
367
  ball_z_size,
@@ -157,24 +375,18 @@ def detect_widget() -> FunctionGui:
157
375
  if use_pre_trained_weights:
158
376
  trained_model = None
159
377
  classification_inputs = ClassificationInputs(
160
- use_pre_trained_weights, trained_model
378
+ skip_classification,
379
+ use_pre_trained_weights,
380
+ trained_model,
381
+ batch_size,
161
382
  )
162
383
 
163
- end_plane = len(signal_image.data) if end_plane == 0 else end_plane
164
-
165
384
  if analyse_local:
166
- current_plane = viewer.dims.current_step[0]
167
-
168
- # so a reasonable number of cells in the plane are detected
169
- planes_needed = MIN_PLANES_ANALYSE + int(
170
- ceil((CUBE_DEPTH * NETWORK_VOXEL_SIZES[0]) / voxel_size_z)
385
+ start_plane, end_plane = find_local_planes(
386
+ options["viewer"], voxel_size_z, signal_image
171
387
  )
172
-
173
- start_plane, end_plane = get_cube_depth_min_max(
174
- current_plane, planes_needed
175
- )
176
- start_plane = max(0, start_plane)
177
- end_plane = min(len(signal_image.data), end_plane)
388
+ elif not end_plane:
389
+ end_plane = len(signal_image.data)
178
390
 
179
391
  misc_inputs = MiscInputs(
180
392
  start_plane, end_plane, n_free_cpus, analyse_local, debug
@@ -186,45 +398,34 @@ def detect_widget() -> FunctionGui:
186
398
  classification_inputs,
187
399
  misc_inputs,
188
400
  )
401
+
189
402
  worker.returned.connect(
190
- lambda points: add_layers(points, viewer=viewer)
403
+ get_results_callback(skip_classification, options["viewer"])
191
404
  )
192
-
193
405
  # Make sure if the worker emits an error, it is propagated to this
194
406
  # thread
195
- def reraise(e):
196
- raise Exception from e
197
-
198
407
  worker.errored.connect(reraise)
408
+ worker.connect_progress_bar_callback(progress_bar)
199
409
 
200
- def update_progress_bar(label: str, max: int, value: int):
201
- progress_bar.label = label
202
- progress_bar.max = max
203
- progress_bar.value = value
204
-
205
- worker.update_progress_bar.connect(update_progress_bar)
206
410
  worker.start()
207
411
 
208
- widget.header.value = widget_header
209
- widget.header.native.setOpenExternalLinks(True)
412
+ widget.native.layout().insertWidget(0, cellfinder_header())
210
413
 
211
- @widget.reset_button.changed.connect
212
- def restore_defaults():
213
- """
214
- Restore default widget values.
215
- """
216
- defaults = {
217
- **DataInputs.defaults(),
218
- **DetectionInputs.defaults(),
219
- **ClassificationInputs.defaults(),
220
- **MiscInputs.defaults(),
221
- }
222
- for name, value in defaults.items():
223
- if value is not None: # ignore fields with no default
224
- getattr(widget, name).value = value
414
+ # reset restores defaults
415
+ widget.reset_button.changed.connect(
416
+ partial(restore_options_defaults, widget)
417
+ )
225
418
 
226
419
  # Insert progress bar before the run and reset buttons
227
- widget.insert(-3, progress_bar)
420
+ widget.insert(widget.index("debug") + 1, progress_bar)
421
+
422
+ # add the signal and background image etc.
423
+ add_heavy_widgets(
424
+ widget,
425
+ (background_image_opt, signal_image_opt, cell_layer_opt),
426
+ ("Background image", "Signal image", "Candidate cell layer"),
427
+ ("voxel_size_z", "voxel_size_z", "soma_diameter"),
428
+ )
228
429
 
229
430
  scroll = QScrollArea()
230
431
  scroll.setWidget(widget._widget._qwidget)
@@ -1,8 +1,9 @@
1
1
  from dataclasses import dataclass
2
2
  from pathlib import Path
3
- from typing import Optional
3
+ from typing import List, Optional
4
4
 
5
5
  import numpy
6
+ from brainglobe_utils.cells.cells import Cell
6
7
 
7
8
  from cellfinder.napari.input_container import InputContainer
8
9
  from cellfinder.napari.utils import html_label_widget
@@ -59,6 +60,8 @@ class DataInputs(InputContainer):
59
60
  class DetectionInputs(InputContainer):
60
61
  """Container for cell candidate detection inputs."""
61
62
 
63
+ skip_detection: bool = False
64
+ detected_cells: Optional[List[Cell]] = None
62
65
  soma_diameter: float = 16.0
63
66
  ball_xy_size: float = 6
64
67
  ball_z_size: float = 15
@@ -75,6 +78,7 @@ class DetectionInputs(InputContainer):
75
78
  def widget_representation(cls) -> dict:
76
79
  return dict(
77
80
  detection_options=html_label_widget("Detection:"),
81
+ skip_detection=dict(value=cls.defaults()["skip_detection"]),
78
82
  soma_diameter=cls._custom_widget("soma_diameter"),
79
83
  ball_xy_size=cls._custom_widget(
80
84
  "ball_xy_size", custom_label="Ball filter (xy)"
@@ -107,8 +111,10 @@ class DetectionInputs(InputContainer):
107
111
  class ClassificationInputs(InputContainer):
108
112
  """Container for classification inputs."""
109
113
 
114
+ skip_classification: bool = False
110
115
  use_pre_trained_weights: bool = True
111
116
  trained_model: Optional[Path] = Path.home()
117
+ batch_size: int = 64
112
118
 
113
119
  def as_core_arguments(self) -> dict:
114
120
  args = super().as_core_arguments()
@@ -123,6 +129,10 @@ class ClassificationInputs(InputContainer):
123
129
  value=cls.defaults()["use_pre_trained_weights"]
124
130
  ),
125
131
  trained_model=dict(value=cls.defaults()["trained_model"]),
132
+ skip_classification=dict(
133
+ value=cls.defaults()["skip_classification"]
134
+ ),
135
+ batch_size=dict(value=cls.defaults()["batch_size"]),
126
136
  )
127
137
 
128
138
 
@@ -1,3 +1,4 @@
1
+ from magicgui.widgets import ProgressBar
1
2
  from napari.qt.threading import WorkerBase, WorkerBaseSignals
2
3
  from qtpy.QtCore import Signal
3
4
 
@@ -41,6 +42,19 @@ class Worker(WorkerBase):
41
42
  self.classification_inputs = classification_inputs
42
43
  self.misc_inputs = misc_inputs
43
44
 
45
+ def connect_progress_bar_callback(self, progress_bar: ProgressBar):
46
+ """
47
+ Connects the progress bar to the work so that updates are shown on
48
+ the bar.
49
+ """
50
+
51
+ def update_progress_bar(label: str, max: int, value: int):
52
+ progress_bar.label = label
53
+ progress_bar.max = max
54
+ progress_bar.value = value
55
+
56
+ self.update_progress_bar.connect(update_progress_bar)
57
+
44
58
  def work(self) -> list:
45
59
  self.update_progress_bar.emit("Setting up detection...", 1, 0)
46
60
 
@@ -58,10 +72,10 @@ class Worker(WorkerBase):
58
72
  def classify_callback(batch: int) -> None:
59
73
  self.update_progress_bar.emit(
60
74
  "Classifying cells",
61
- # Default cellfinder-core batch size is 32. This seems to give
75
+ # Default cellfinder-core batch size is 64. This seems to give
62
76
  # a slight underestimate of the number of batches though, so
63
77
  # allow for batch number to go over this
64
- max(self.npoints_detected // 32 + 1, batch + 1),
78
+ max(self.npoints_detected // 64 + 1, batch + 1),
65
79
  batch + 1,
66
80
  )
67
81
 
@@ -8,11 +8,7 @@ from napari.utils.notifications import show_info
8
8
  from qtpy.QtWidgets import QScrollArea
9
9
 
10
10
  from cellfinder.core.train.train_yml import run as train_yml
11
- from cellfinder.napari.utils import (
12
- header_label_widget,
13
- html_label_widget,
14
- widget_header,
15
- )
11
+ from cellfinder.napari.utils import cellfinder_header, html_label_widget
16
12
 
17
13
  from .train_containers import (
18
14
  MiscTrainingInputs,
@@ -41,7 +37,6 @@ def run_training(
41
37
 
42
38
  def training_widget() -> FunctionGui:
43
39
  @magicgui(
44
- header=header_label_widget,
45
40
  training_label=html_label_widget("Network training", tag="h3"),
46
41
  **TrainingDataInputs.widget_representation(),
47
42
  **OptionalNetworkInputs.widget_representation(),
@@ -52,7 +47,6 @@ def training_widget() -> FunctionGui:
52
47
  scrollable=True,
53
48
  )
54
49
  def widget(
55
- header: dict,
56
50
  training_label: dict,
57
51
  data_options: dict,
58
52
  yaml_files: Path,
@@ -161,8 +155,7 @@ def training_widget() -> FunctionGui:
161
155
  )
162
156
  worker.start()
163
157
 
164
- widget.header.value = widget_header
165
- widget.header.native.setOpenExternalLinks(True)
158
+ widget.native.layout().insertWidget(0, cellfinder_header())
166
159
 
167
160
  @widget.reset_button.changed.connect
168
161
  def restore_defaults():
@@ -4,7 +4,7 @@ from typing import Optional
4
4
 
5
5
  from magicgui.types import FileDialogMode
6
6
 
7
- from cellfinder.core.download.models import model_weight_urls
7
+ from cellfinder.core.download.download import model_filenames
8
8
  from cellfinder.core.train.train_yml import models
9
9
  from cellfinder.napari.input_container import InputContainer
10
10
  from cellfinder.napari.utils import html_label_widget
@@ -46,7 +46,7 @@ class OptionalNetworkInputs(InputContainer):
46
46
  trained_model: Optional[Path] = Path.home()
47
47
  model_weights: Optional[Path] = Path.home()
48
48
  model_depth: str = list(models.keys())[2]
49
- pretrained_model: str = str(list(model_weight_urls.keys())[0])
49
+ pretrained_model: str = str(list(model_filenames.keys())[0])
50
50
 
51
51
  def as_core_arguments(self) -> dict:
52
52
  arguments = super().as_core_arguments()
@@ -65,7 +65,7 @@ class OptionalNetworkInputs(InputContainer):
65
65
  ),
66
66
  pretrained_model=cls._custom_widget(
67
67
  "pretrained_model",
68
- choices=list(model_weight_urls.keys()),
68
+ choices=list(model_filenames.keys()),
69
69
  ),
70
70
  )
71
71