senoquant 1.0.0b2__py3-none-any.whl → 1.0.0b3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. senoquant/__init__.py +6 -2
  2. senoquant/_reader.py +1 -1
  3. senoquant/reader/core.py +201 -18
  4. senoquant/tabs/batch/backend.py +18 -3
  5. senoquant/tabs/batch/frontend.py +8 -4
  6. senoquant/tabs/quantification/features/marker/dialog.py +26 -6
  7. senoquant/tabs/quantification/features/marker/export.py +97 -24
  8. senoquant/tabs/quantification/features/marker/rows.py +2 -2
  9. senoquant/tabs/quantification/features/spots/dialog.py +41 -11
  10. senoquant/tabs/quantification/features/spots/export.py +163 -10
  11. senoquant/tabs/quantification/frontend.py +2 -2
  12. senoquant/tabs/segmentation/frontend.py +46 -9
  13. senoquant/tabs/segmentation/models/cpsam/model.py +1 -1
  14. senoquant/tabs/segmentation/models/default_2d/model.py +22 -77
  15. senoquant/tabs/segmentation/models/default_3d/model.py +8 -74
  16. senoquant/tabs/segmentation/stardist_onnx_utils/_csbdeep/tools/create_zip_contents.py +0 -0
  17. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/inspect/probe.py +13 -13
  18. senoquant/tabs/segmentation/stardist_onnx_utils/onnx_framework/stardist_libs.py +171 -0
  19. senoquant/tabs/spots/frontend.py +42 -5
  20. senoquant/tabs/spots/models/ufish/details.json +17 -0
  21. senoquant/tabs/spots/models/ufish/model.py +129 -0
  22. senoquant/tabs/spots/ufish_utils/__init__.py +13 -0
  23. senoquant/tabs/spots/ufish_utils/core.py +357 -0
  24. senoquant/utils.py +1 -1
  25. senoquant-1.0.0b3.dist-info/METADATA +161 -0
  26. {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b3.dist-info}/RECORD +41 -28
  27. {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b3.dist-info}/top_level.txt +1 -0
  28. ufish/__init__.py +1 -0
  29. ufish/api.py +778 -0
  30. ufish/model/__init__.py +0 -0
  31. ufish/model/loss.py +62 -0
  32. ufish/model/network/__init__.py +0 -0
  33. ufish/model/network/spot_learn.py +50 -0
  34. ufish/model/network/ufish_net.py +204 -0
  35. ufish/model/train.py +175 -0
  36. ufish/utils/__init__.py +0 -0
  37. ufish/utils/img.py +418 -0
  38. ufish/utils/log.py +8 -0
  39. ufish/utils/spot_calling.py +115 -0
  40. senoquant/tabs/spots/models/rmp/details.json +0 -61
  41. senoquant/tabs/spots/models/rmp/model.py +0 -499
  42. senoquant/tabs/spots/models/udwt/details.json +0 -103
  43. senoquant/tabs/spots/models/udwt/model.py +0 -482
  44. senoquant-1.0.0b2.dist-info/METADATA +0 -193
  45. {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b3.dist-info}/WHEEL +0 -0
  46. {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b3.dist-info}/entry_points.txt +0 -0
  47. {senoquant-1.0.0b2.dist-info → senoquant-1.0.0b3.dist-info}/licenses/LICENSE +0 -0
@@ -216,45 +216,75 @@ class SpotsChannelsDialog(QDialog):
216
216
  return
217
217
  for layer in viewer.layers:
218
218
  if layer.__class__.__name__ == "Labels":
219
- layer_name = layer.name
220
219
  # Filter based on label type
221
- if filter_type == "cellular" and self._is_cellular_label(layer_name):
222
- combo.addItem(layer_name)
223
- elif filter_type == "spots" and self._is_spot_label(layer_name):
224
- combo.addItem(layer_name)
220
+ if filter_type == "cellular" and self._is_cellular_label(layer):
221
+ combo.addItem(layer.name)
222
+ elif filter_type == "spots" and self._is_spot_label(layer):
223
+ combo.addItem(layer.name)
225
224
  if current:
226
225
  index = combo.findText(current)
227
226
  if index != -1:
228
227
  combo.setCurrentIndex(index)
229
228
 
230
- def _is_cellular_label(self, layer_name: str) -> bool:
229
+ def _layer_task(self, layer: object) -> str | None:
230
+ """Return normalized segmentation task from layer metadata."""
231
+ metadata = getattr(layer, "metadata", None)
232
+ if not isinstance(metadata, dict):
233
+ return None
234
+ task = metadata.get("task")
235
+ if not isinstance(task, str):
236
+ return None
237
+ normalized = task.strip().lower()
238
+ return normalized or None
239
+
240
+ def _is_cellular_label(self, layer: object | str) -> bool:
231
241
  """Check if a label layer is a cellular segmentation.
232
242
 
233
243
  Parameters
234
244
  ----------
235
- layer_name : str
236
- Name of the labels layer.
245
+ layer : object or str
246
+ Labels layer object or labels layer name.
237
247
 
238
248
  Returns
239
249
  -------
240
250
  bool
241
251
  True if the layer is a cellular label (nuclear or cytoplasmic).
242
252
  """
253
+ if isinstance(layer, str):
254
+ layer_name = layer
255
+ task = None
256
+ else:
257
+ layer_name = str(getattr(layer, "name", ""))
258
+ task = self._layer_task(layer)
259
+ if task in {"nuclear", "cytoplasmic"}:
260
+ return True
261
+ if task is not None:
262
+ return False
243
263
  return layer_name.endswith("_nuc_labels") or layer_name.endswith("_cyto_labels")
244
264
 
245
- def _is_spot_label(self, layer_name: str) -> bool:
265
+ def _is_spot_label(self, layer: object | str) -> bool:
246
266
  """Check if a label layer is a spot segmentation.
247
267
 
248
268
  Parameters
249
269
  ----------
250
- layer_name : str
251
- Name of the labels layer.
270
+ layer : object or str
271
+ Labels layer object or labels layer name.
252
272
 
253
273
  Returns
254
274
  -------
255
275
  bool
256
276
  True if the layer is a spot label.
257
277
  """
278
+ if isinstance(layer, str):
279
+ layer_name = layer
280
+ task = None
281
+ else:
282
+ layer_name = str(getattr(layer, "name", ""))
283
+ task = self._layer_task(layer)
284
+ if task == "spots":
285
+ return True
286
+ if task is not None:
287
+ return False
258
288
  return layer_name.endswith("_spot_labels")
259
289
 
260
290
  def _refresh_image_combo(self, combo: QComboBox) -> None:
@@ -50,7 +50,7 @@ def export_spots(
50
50
  temp_dir : Path
51
51
  Temporary directory where outputs should be written.
52
52
  viewer : object, optional
53
- Napari viewer instance used to resolve layers by name and read
53
+ napari viewer instance used to resolve layers by name and read
54
54
  layer data. When ``None``, export is skipped.
55
55
  export_format : str, optional
56
56
  File format for exports (``"csv"`` or ``"xlsx"``). Values are
@@ -98,12 +98,21 @@ def export_spots(
98
98
  if not data.segmentations or not channels:
99
99
  return []
100
100
 
101
+ cross_map = _build_cell_cross_segmentation_map(viewer, data.segmentations)
102
+
101
103
  # --- Resolve a reference channel for physical pixel sizes ---
102
104
  first_channel_layer = None
103
105
  for channel in channels:
104
106
  first_channel_layer = _find_layer(viewer, channel.channel, "Image")
105
107
  if first_channel_layer is not None:
106
108
  break
109
+ file_path = ""
110
+ if first_channel_layer is not None:
111
+ metadata = getattr(first_channel_layer, "metadata", None)
112
+ if isinstance(metadata, dict):
113
+ raw_path = metadata.get("path")
114
+ if raw_path:
115
+ file_path = str(raw_path)
107
116
 
108
117
  for index, segmentation in enumerate(data.segmentations, start=0):
109
118
  # --- Resolve the cell segmentation labels layer ---
@@ -133,6 +142,8 @@ def export_spots(
133
142
  cell_rows = _initialize_rows(
134
143
  cell_ids, cell_centroids, cell_pixel_sizes
135
144
  )
145
+ for row in cell_rows:
146
+ row["file_path"] = file_path
136
147
 
137
148
  # --- Add morphological descriptors to the cell table ---
138
149
  add_morphology_columns(cell_rows, cell_labels, cell_ids, cell_pixel_sizes)
@@ -145,6 +156,7 @@ def export_spots(
145
156
  data.rois,
146
157
  label_name,
147
158
  )
159
+ _add_cross_reference_column(cell_rows, label_name, cell_ids, cross_map)
148
160
  cell_header = list(cell_rows[0].keys()) if cell_rows else []
149
161
 
150
162
  # --- Prepare containers and ROI masks for the spots table ---
@@ -182,6 +194,7 @@ def export_spots(
182
194
  spot_lookup,
183
195
  spot_table_pixel_sizes,
184
196
  spot_roi_columns,
197
+ file_path,
185
198
  )
186
199
 
187
200
  # --- Apply colocalization columns (if requested) ---
@@ -205,7 +218,9 @@ def export_spots(
205
218
  outputs.append(cell_path)
206
219
  if not spot_header:
207
220
  spot_header = _spot_header(
208
- cell_labels.ndim, spot_table_pixel_sizes, spot_roi_columns
221
+ cell_labels.ndim,
222
+ spot_table_pixel_sizes,
223
+ spot_roi_columns,
209
224
  )
210
225
  if data.export_colocalization:
211
226
  if "colocalizes_with" not in spot_header:
@@ -219,6 +234,136 @@ def export_spots(
219
234
  return outputs
220
235
 
221
236
 
237
def _build_cell_cross_segmentation_map(
    viewer: object, segmentations: Sequence[object]
) -> dict[tuple[str, int], list[tuple[str, int]]]:
    """Build overlap mapping for configured cell segmentations.

    Each config's ``label`` attribute is resolved to a Labels layer via
    ``_find_layer``. Layers are skipped when they cannot be found, when their
    data is empty, or when no label ids are found in them. The collected
    ``(labels, label_ids)`` pairs are then handed to
    :func:`_build_cross_segmentation_map`.

    Parameters
    ----------
    viewer : object
        napari viewer instance containing labels layers.
    segmentations : sequence of object
        Segmentation configs with ``label`` attributes. Configs whose
        ``label`` is missing, ``None`` or blank are ignored. A label name
        referenced by several configs is loaded only once.

    Returns
    -------
    dict
        Mapping from ``(segmentation_name, label_id)`` to overlapping
        ``(other_segmentation_name, other_label_id)`` entries.
    """
    all_segmentations: dict[str, tuple[np.ndarray, np.ndarray]] = {}
    for segmentation in segmentations:
        raw_label = getattr(segmentation, "label", None)
        # ``str(None)`` would produce the truthy name "None"; treat as blank.
        label_name = "" if raw_label is None else str(raw_label).strip()
        # A repeated name would look up the same layer again and recompute
        # its ids only to overwrite the existing entry, so skip it.
        if not label_name or label_name in all_segmentations:
            continue
        labels_layer = _find_layer(viewer, label_name, "Labels")
        if labels_layer is None:
            continue
        labels = layer_data_asarray(labels_layer)
        if labels.size == 0:
            continue
        # Only the label ids are needed here; the centroids are discarded.
        label_ids, _centroids = _compute_centroids(labels)
        if label_ids.size == 0:
            continue
        all_segmentations[label_name] = (labels, label_ids)
    return _build_cross_segmentation_map(all_segmentations)
271
+
272
+
273
+ def _build_cross_segmentation_map(
274
+ all_segmentations: dict[str, tuple[np.ndarray, np.ndarray]]
275
+ ) -> dict[tuple[str, int], list[tuple[str, int]]]:
276
+ """Build bidirectional overlap mapping across segmentations.
277
+
278
+ Parameters
279
+ ----------
280
+ all_segmentations : dict
281
+ Mapping from segmentation name to ``(labels, label_ids)`` tuples.
282
+
283
+ Returns
284
+ -------
285
+ dict
286
+ Mapping from ``(seg_name, label_id)`` to list of overlapping
287
+ ``(other_seg_name, other_label_id)`` tuples.
288
+ """
289
+ cross_map: dict[tuple[str, int], list[tuple[str, int]]] = {}
290
+ valid_ids: dict[str, set[int]] = {}
291
+
292
+ for seg_name, (_labels, label_ids) in all_segmentations.items():
293
+ ids = {int(label_id) for label_id in np.asarray(label_ids, dtype=int)}
294
+ valid_ids[seg_name] = ids
295
+ for label_id in ids:
296
+ cross_map[(seg_name, label_id)] = []
297
+
298
+ seg_names = list(all_segmentations.keys())
299
+ for idx_a, seg_name_a in enumerate(seg_names):
300
+ labels_a, _label_ids_a = all_segmentations[seg_name_a]
301
+ for seg_name_b in seg_names[idx_a + 1 :]:
302
+ labels_b, _label_ids_b = all_segmentations[seg_name_b]
303
+ if labels_a.shape != labels_b.shape:
304
+ warnings.warn(
305
+ "Spots export: segmentation shape mismatch for "
306
+ f"'{seg_name_a}' vs '{seg_name_b}'. "
307
+ "Skipping cross-segmentation overlap mapping for this pair.",
308
+ RuntimeWarning,
309
+ )
310
+ continue
311
+
312
+ mask = (labels_a > 0) & (labels_b > 0)
313
+ if not np.any(mask):
314
+ continue
315
+
316
+ overlap_pairs = np.column_stack((labels_a[mask], labels_b[mask]))
317
+ unique_pairs = np.unique(overlap_pairs, axis=0)
318
+ for label_id_a, label_id_b in unique_pairs:
319
+ id_a = int(label_id_a)
320
+ id_b = int(label_id_b)
321
+ if (
322
+ id_a not in valid_ids[seg_name_a]
323
+ or id_b not in valid_ids[seg_name_b]
324
+ ):
325
+ continue
326
+ cross_map[(seg_name_a, id_a)].append((seg_name_b, id_b))
327
+ cross_map[(seg_name_b, id_b)].append((seg_name_a, id_a))
328
+
329
+ return cross_map
330
+
331
+
332
+ def _add_cross_reference_column(
333
+ rows: list[dict[str, object]],
334
+ segmentation_name: str,
335
+ label_ids: np.ndarray,
336
+ cross_map: dict[tuple[str, int], list[tuple[str, int]]],
337
+ ) -> str:
338
+ """Add cross-segmentation overlap references to cell rows.
339
+
340
+ Parameters
341
+ ----------
342
+ rows : list of dict
343
+ Cell table rows to update in-place.
344
+ segmentation_name : str
345
+ Name of the segmentation being exported.
346
+ label_ids : numpy.ndarray
347
+ Cell label ids corresponding to ``rows``.
348
+ cross_map : dict
349
+ Overlap mapping from :func:`_build_cross_segmentation_map`.
350
+
351
+ Returns
352
+ -------
353
+ str
354
+ Name of the added column (``"overlaps_with"``).
355
+ """
356
+ for row, label_id in zip(rows, label_ids):
357
+ overlaps = cross_map.get((segmentation_name, int(label_id)), [])
358
+ if overlaps:
359
+ row["overlaps_with"] = ";".join(
360
+ f"{seg_name}_{other_id}" for seg_name, other_id in overlaps
361
+ )
362
+ else:
363
+ row["overlaps_with"] = ""
364
+ return "overlaps_with"
365
+
366
+
222
367
  def _build_channel_entries(
223
368
  viewer: object,
224
369
  channels: list,
@@ -230,7 +375,7 @@ def _build_channel_entries(
230
375
  Parameters
231
376
  ----------
232
377
  viewer : object
233
- Napari viewer instance used to resolve layers.
378
+ napari viewer instance used to resolve layers.
234
379
  channels : list
235
380
  Spots channel configurations (image + labels names).
236
381
  cell_shape : tuple of int
@@ -298,6 +443,7 @@ def _append_channel_exports(
298
443
  spot_lookup: dict[tuple[int, int], dict[str, object]],
299
444
  spot_table_pixel_sizes: np.ndarray | None,
300
445
  spot_roi_columns: list[tuple[str, np.ndarray]],
446
+ file_path: str,
301
447
  ) -> None:
302
448
  """Compute and append per-channel cell/spot metrics.
303
449
 
@@ -325,6 +471,8 @@ def _append_channel_exports(
325
471
  Pixel sizes to use for spot physical units.
326
472
  spot_roi_columns : list of tuple
327
473
  ROI masks for spot ROI membership columns.
474
+ file_path : str
475
+ Source image path copied to exported spot rows.
328
476
  """
329
477
  channel_label = entry["channel_label"]
330
478
  channel_layer = entry["channel_layer"]
@@ -393,6 +541,7 @@ def _append_channel_exports(
393
541
  channel_label,
394
542
  spot_table_pixel_sizes,
395
543
  spot_roi_columns,
544
+ file_path,
396
545
  )
397
546
  if spot_rows_for_channel:
398
547
  if not spot_header:
@@ -528,7 +677,7 @@ def _find_layer(viewer, name: str, layer_type: str):
528
677
  Parameters
529
678
  ----------
530
679
  viewer : object
531
- Napari viewer instance containing layers.
680
+ napari viewer instance containing layers.
532
681
  name : str
533
682
  Layer name to locate.
534
683
  layer_type : str
@@ -645,7 +794,7 @@ def _pixel_sizes(layer, ndim: int) -> np.ndarray | None:
645
794
  Parameters
646
795
  ----------
647
796
  layer : object
648
- Napari layer providing ``metadata``.
797
+ napari layer providing ``metadata``.
649
798
  ndim : int
650
799
  Dimensionality of the labels or image array.
651
800
 
@@ -781,7 +930,7 @@ def _add_roi_columns(
781
930
  label_ids : numpy.ndarray
782
931
  Label ids corresponding to the output rows.
783
932
  viewer : object or None
784
- Napari viewer used to resolve shapes layers.
933
+ napari viewer used to resolve shapes layers.
785
934
  rois : sequence of ROIConfig
786
935
  ROI configuration entries to evaluate.
787
936
  label_name : str
@@ -832,7 +981,7 @@ def _shapes_layer_mask(
832
981
  Parameters
833
982
  ----------
834
983
  layer : object
835
- Napari shapes layer instance.
984
+ napari shapes layer instance.
836
985
  shape : tuple of int
837
986
  Target mask shape matching the labels array.
838
987
 
@@ -863,7 +1012,7 @@ def _shape_masks_array(
863
1012
  Parameters
864
1013
  ----------
865
1014
  layer : object
866
- Napari shapes layer instance.
1015
+ napari shapes layer instance.
867
1016
  shape : tuple of int
868
1017
  Target mask shape.
869
1018
 
@@ -974,6 +1123,7 @@ def _spot_rows(
974
1123
  channel_label: str,
975
1124
  pixel_sizes: np.ndarray | None,
976
1125
  roi_columns: list[tuple[str, np.ndarray]],
1126
+ file_path: str,
977
1127
  ) -> list[dict[str, object]]:
978
1128
  """Build per-spot rows for export.
979
1129
 
@@ -996,6 +1146,8 @@ def _spot_rows(
996
1146
  centroid coordinates and area/volume are included.
997
1147
  roi_columns : list of tuple
998
1148
  Precomputed ROI column names and boolean masks.
1149
+ file_path : str
1150
+ Source image path to include on each row.
999
1151
 
1000
1152
  Returns
1001
1153
  -------
@@ -1015,6 +1167,7 @@ def _spot_rows(
1015
1167
  "spot_id": int(spot_id),
1016
1168
  "cell_id": int(cell_id),
1017
1169
  "channel": channel_label,
1170
+ "file_path": file_path,
1018
1171
  }
1019
1172
  for axis, value in zip(axes, centroid):
1020
1173
  row[f"centroid_{axis}_pixels"] = float(value)
@@ -1075,7 +1228,7 @@ def _spot_roi_columns(
1075
1228
  Parameters
1076
1229
  ----------
1077
1230
  viewer : object or None
1078
- Napari viewer instance used to resolve shapes layers.
1231
+ napari viewer instance used to resolve shapes layers.
1079
1232
  rois : sequence of ROIConfig
1080
1233
  ROI configuration entries to evaluate.
1081
1234
  label_name : str
@@ -1175,7 +1328,7 @@ def _spot_header(
1175
1328
  """
1176
1329
  axes = _axis_names(ndim)
1177
1330
  size_key_px, size_key_um, _scale = _spot_size_keys(ndim, pixel_sizes)
1178
- header = ["spot_id", "cell_id", "channel"]
1331
+ header = ["spot_id", "cell_id", "channel", "file_path"]
1179
1332
  header.extend([f"centroid_{axis}_pixels" for axis in axes])
1180
1333
  if pixel_sizes is not None and pixel_sizes.size == len(axes):
1181
1334
  header.extend([f"centroid_{axis}_um" for axis in axes])
@@ -56,7 +56,7 @@ class QuantificationTab(QWidget):
56
56
  backend : QuantificationBackend or None
57
57
  Backend instance for quantification workflows.
58
58
  napari_viewer : object or None
59
- Napari viewer used to populate layer dropdowns.
59
+ napari viewer used to populate layer dropdowns.
60
60
  """
61
61
  def __init__(
62
62
  self,
@@ -76,7 +76,7 @@ class QuantificationTab(QWidget):
76
76
  backend : QuantificationBackend or None
77
77
  Backend instance for quantification workflows.
78
78
  napari_viewer : object or None
79
- Napari viewer used to populate layer dropdowns.
79
+ napari viewer used to populate layer dropdowns.
80
80
  show_output_section : bool, optional
81
81
  Whether to show the output configuration controls.
82
82
  show_process_button : bool, optional
@@ -69,7 +69,7 @@ class SegmentationTab(QWidget):
69
69
  backend : SegmentationBackend or None
70
70
  Backend instance used to discover and load models.
71
71
  napari_viewer : object or None
72
- Napari viewer used to populate layer choices.
72
+ napari viewer used to populate layer choices.
73
73
  settings_backend : SettingsBackend or None
74
74
  Settings store used for preload configuration.
75
75
  """
@@ -87,7 +87,7 @@ class SegmentationTab(QWidget):
87
87
  backend : SegmentationBackend or None
88
88
  Backend instance used to discover and load models.
89
89
  napari_viewer : object or None
90
- Napari viewer used to populate layer choices.
90
+ napari viewer used to populate layer choices.
91
91
  settings_backend : SettingsBackend or None
92
92
  Settings store used for preload configuration.
93
93
  """
@@ -889,7 +889,7 @@ class SegmentationTab(QWidget):
889
889
  Parameters
890
890
  ----------
891
891
  layer : object or None
892
- Napari layer to validate.
892
+ napari layer to validate.
893
893
  label : str
894
894
  User-facing label for notifications.
895
895
 
@@ -972,13 +972,50 @@ class SegmentationTab(QWidget):
972
972
  if self._viewer is None or source_layer is None or masks is None:
973
973
  return
974
974
  label_name = f"{source_layer.name}_{model_name}_{label_type}_labels"
975
- self._viewer.add_labels(
976
- masks,
977
- name=label_name,
978
- )
975
+ task_value = {
976
+ "nuc": "nuclear",
977
+ "cyto": "cytoplasmic",
978
+ }.get(label_type)
979
+ source_metadata = getattr(source_layer, "metadata", {})
980
+ merged_metadata: dict[str, object] = {}
981
+ if isinstance(source_metadata, dict):
982
+ merged_metadata.update(source_metadata)
983
+ if task_value:
984
+ merged_metadata["task"] = task_value
985
+
986
+ labels_layer = None
987
+ if Labels is not None and hasattr(self._viewer, "add_layer"):
988
+ # Add a fully configured Labels layer object to avoid name-based lookup.
989
+ labels_layer = Labels(
990
+ masks,
991
+ name=label_name,
992
+ metadata=merged_metadata,
993
+ )
994
+ added_layer = self._viewer.add_layer(labels_layer)
995
+ if added_layer is not None:
996
+ labels_layer = added_layer
997
+ elif hasattr(self._viewer, "add_labels"):
998
+ try:
999
+ labels_layer = self._viewer.add_labels(
1000
+ masks,
1001
+ name=label_name,
1002
+ metadata=merged_metadata,
1003
+ )
1004
+ except TypeError:
1005
+ labels_layer = self._viewer.add_labels(
1006
+ masks,
1007
+ name=label_name,
1008
+ )
1009
+
1010
+ if labels_layer is None:
1011
+ return
979
1012
 
980
- # Get the labels layer and set contour = 2
981
- labels_layer = self._viewer.layers[label_name]
1013
+ layer_metadata = getattr(labels_layer, "metadata", {})
1014
+ if isinstance(layer_metadata, dict):
1015
+ merged_metadata.update(layer_metadata)
1016
+ if task_value:
1017
+ merged_metadata["task"] = task_value
1018
+ labels_layer.metadata = merged_metadata
982
1019
  labels_layer.contour = 2
983
1020
 
984
1021
 
@@ -103,7 +103,7 @@ class CPSAMModel(SenoQuantSegmentationModel):
103
103
  Parameters
104
104
  ----------
105
105
  layer : object or None
106
- Napari layer to convert.
106
+ napari layer to convert.
107
107
  required : bool
108
108
  Whether a missing layer should raise an error.
109
109
 
@@ -19,6 +19,9 @@ from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework import (
19
19
  normalize,
20
20
  predict_tiled,
21
21
  )
22
+ from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework.stardist_libs import (
23
+ ensure_stardist_libs,
24
+ )
22
25
  from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework.inspect import (
23
26
  make_probe_image,
24
27
  )
@@ -234,7 +237,7 @@ class StarDistOnnxModel(SenoQuantSegmentationModel):
234
237
  Parameters
235
238
  ----------
236
239
  layer : object or None
237
- Napari layer to convert.
240
+ napari layer to convert.
238
241
  required : bool
239
242
  Whether a missing layer should raise an error.
240
243
 
@@ -257,6 +260,9 @@ class StarDistOnnxModel(SenoQuantSegmentationModel):
257
260
  session = self._sessions.get(model_path)
258
261
  if session is None or providers_override is not None:
259
262
  providers = providers_override or self._preferred_providers()
263
+ preload = getattr(ort, "preload_dlls", None)
264
+ if callable(preload):
265
+ preload()
260
266
  session = ort.InferenceSession(
261
267
  str(model_path),
262
268
  providers=providers,
@@ -321,7 +327,17 @@ class StarDistOnnxModel(SenoQuantSegmentationModel):
321
327
  ndim = image.ndim
322
328
  div_by = self._div_by_cache.get(model_path)
323
329
  if div_by is None:
324
- div_by = (16,) * ndim
330
+ try:
331
+ from senoquant.tabs.segmentation.stardist_onnx_utils.onnx_framework.inspect import (
332
+ infer_div_by,
333
+ )
334
+ except Exception:
335
+ div_by = (16,) * ndim
336
+ else:
337
+ try:
338
+ div_by = infer_div_by(model_path, ndim=ndim)
339
+ except Exception:
340
+ div_by = (16,) * ndim
325
341
  self._div_by_cache[model_path] = div_by
326
342
 
327
343
  overlap = self._overlap_cache.get(model_path)
@@ -373,8 +389,9 @@ class StarDistOnnxModel(SenoQuantSegmentationModel):
373
389
 
374
390
  tile_shape = snap_shape(tile_shape, patterns)
375
391
  tile_shape = tuple(
376
- max(16, (ts // 16) * 16)
377
- for ts in tile_shape
392
+ max(int(div), (int(ts) // int(div)) * int(div))
393
+ if int(div) > 0 else int(ts)
394
+ for ts, div in zip(tile_shape, div_by)
378
395
  )
379
396
  overlap = tuple(
380
397
  max(0, min(int(ov), max(0, ts - 1)))
@@ -487,85 +504,13 @@ class StarDistOnnxModel(SenoQuantSegmentationModel):
487
504
  libraries are absent, allowing Python utilities to import.
488
505
  """
489
506
  utils_root = self._get_utils_root()
490
- csbdeep_root = utils_root / "_csbdeep"
491
- if csbdeep_root.exists():
492
- csbdeep_path = str(csbdeep_root)
493
- if csbdeep_path not in sys.path:
494
- sys.path.insert(0, csbdeep_path)
495
-
496
507
  stardist_pkg = (
497
508
  "senoquant.tabs.segmentation.stardist_onnx_utils._stardist"
498
509
  )
499
- if stardist_pkg not in sys.modules:
500
- pkg = types.ModuleType(stardist_pkg)
501
- pkg.__path__ = [str(utils_root / "_stardist")]
502
- sys.modules[stardist_pkg] = pkg
503
-
504
- base_pkg = f"{stardist_pkg}.lib"
505
- lib_dirs = [utils_root / "_stardist" / "lib"]
506
- for entry in list(sys.path):
507
- if not entry:
508
- continue
509
- try:
510
- candidate = (
511
- Path(entry)
512
- / "senoquant"
513
- / "tabs"
514
- / "segmentation"
515
- / "stardist_onnx_utils"
516
- / "_stardist"
517
- / "lib"
518
- )
519
- except Exception:
520
- continue
521
- if candidate.exists():
522
- lib_dirs.append(candidate)
523
-
524
- if base_pkg in sys.modules:
525
- pkg = sys.modules[base_pkg]
526
- pkg.__path__ = [str(p) for p in lib_dirs]
527
- else:
528
- pkg = types.ModuleType(base_pkg)
529
- pkg.__path__ = [str(p) for p in lib_dirs]
530
- sys.modules[base_pkg] = pkg
531
-
532
- def _stub(*_args, **_kwargs):
533
- raise RuntimeError("StarDist compiled ops are unavailable.")
534
-
535
- has_2d = False
536
- has_3d = False
537
- for lib_dir in lib_dirs:
538
- has_2d = has_2d or any(lib_dir.glob("stardist2d*.so")) or any(
539
- lib_dir.glob("stardist2d*.pyd")
540
- )
541
- has_3d = has_3d or any(lib_dir.glob("stardist3d*.so")) or any(
542
- lib_dir.glob("stardist3d*.pyd")
543
- )
510
+ has_2d, has_3d = ensure_stardist_libs(utils_root, stardist_pkg)
544
511
  self._has_stardist_2d_lib = has_2d
545
512
  self._has_stardist_3d_lib = has_3d
546
513
 
547
- mod2d = f"{base_pkg}.stardist2d"
548
- if has_2d and mod2d in sys.modules:
549
- if getattr(sys.modules[mod2d], "__file__", None) is None:
550
- del sys.modules[mod2d]
551
- if not has_2d and mod2d not in sys.modules:
552
- module = types.ModuleType(mod2d)
553
- module.c_star_dist = _stub
554
- module.c_non_max_suppression_inds_old = _stub
555
- module.c_non_max_suppression_inds = _stub
556
- sys.modules[mod2d] = module
557
-
558
- mod3d = f"{base_pkg}.stardist3d"
559
- if has_3d and mod3d in sys.modules:
560
- if getattr(sys.modules[mod3d], "__file__", None) is None:
561
- del sys.modules[mod3d]
562
- if not has_3d and mod3d not in sys.modules:
563
- module = types.ModuleType(mod3d)
564
- module.c_star_dist3d = _stub
565
- module.c_polyhedron_to_label = _stub
566
- module.c_non_max_suppression_inds = _stub
567
- sys.modules[mod3d] = module
568
-
569
514
  def _get_rays_class(self):
570
515
  """Load and cache the StarDist Rays_GoldenSpiral class."""
571
516
  if self._rays_class is not None: