datamint 2.3.3__py3-none-any.whl → 2.9.0__py3-none-any.whl

This diff reflects the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (59)
  1. datamint/__init__.py +1 -3
  2. datamint/api/__init__.py +0 -3
  3. datamint/api/base_api.py +286 -54
  4. datamint/api/client.py +76 -13
  5. datamint/api/endpoints/__init__.py +2 -2
  6. datamint/api/endpoints/annotations_api.py +186 -28
  7. datamint/api/endpoints/deploy_model_api.py +78 -0
  8. datamint/api/endpoints/models_api.py +1 -0
  9. datamint/api/endpoints/projects_api.py +38 -7
  10. datamint/api/endpoints/resources_api.py +227 -100
  11. datamint/api/entity_base_api.py +66 -7
  12. datamint/apihandler/base_api_handler.py +0 -1
  13. datamint/apihandler/dto/annotation_dto.py +2 -0
  14. datamint/client_cmd_tools/datamint_config.py +0 -1
  15. datamint/client_cmd_tools/datamint_upload.py +3 -1
  16. datamint/configs.py +11 -7
  17. datamint/dataset/base_dataset.py +24 -4
  18. datamint/dataset/dataset.py +1 -1
  19. datamint/entities/__init__.py +1 -1
  20. datamint/entities/annotations/__init__.py +13 -0
  21. datamint/entities/{annotation.py → annotations/annotation.py} +81 -47
  22. datamint/entities/annotations/image_classification.py +12 -0
  23. datamint/entities/annotations/image_segmentation.py +252 -0
  24. datamint/entities/annotations/volume_segmentation.py +273 -0
  25. datamint/entities/base_entity.py +100 -6
  26. datamint/entities/cache_manager.py +129 -15
  27. datamint/entities/datasetinfo.py +60 -65
  28. datamint/entities/deployjob.py +18 -0
  29. datamint/entities/project.py +39 -0
  30. datamint/entities/resource.py +310 -46
  31. datamint/lightning/__init__.py +1 -0
  32. datamint/lightning/datamintdatamodule.py +103 -0
  33. datamint/mlflow/__init__.py +65 -0
  34. datamint/mlflow/artifact/__init__.py +1 -0
  35. datamint/mlflow/artifact/datamint_artifacts_repo.py +8 -0
  36. datamint/mlflow/env_utils.py +131 -0
  37. datamint/mlflow/env_vars.py +5 -0
  38. datamint/mlflow/flavors/__init__.py +17 -0
  39. datamint/mlflow/flavors/datamint_flavor.py +150 -0
  40. datamint/mlflow/flavors/model.py +877 -0
  41. datamint/mlflow/lightning/callbacks/__init__.py +1 -0
  42. datamint/mlflow/lightning/callbacks/modelcheckpoint.py +410 -0
  43. datamint/mlflow/models/__init__.py +93 -0
  44. datamint/mlflow/tracking/datamint_store.py +76 -0
  45. datamint/mlflow/tracking/default_experiment.py +27 -0
  46. datamint/mlflow/tracking/fluent.py +91 -0
  47. datamint/utils/env.py +27 -0
  48. datamint/utils/visualization.py +21 -13
  49. datamint-2.9.0.dist-info/METADATA +220 -0
  50. datamint-2.9.0.dist-info/RECORD +73 -0
  51. {datamint-2.3.3.dist-info → datamint-2.9.0.dist-info}/WHEEL +1 -1
  52. datamint-2.9.0.dist-info/entry_points.txt +18 -0
  53. datamint/apihandler/exp_api_handler.py +0 -204
  54. datamint/experiment/__init__.py +0 -1
  55. datamint/experiment/_patcher.py +0 -570
  56. datamint/experiment/experiment.py +0 -1049
  57. datamint-2.3.3.dist-info/METADATA +0 -125
  58. datamint-2.3.3.dist-info/RECORD +0 -54
  59. datamint-2.3.3.dist-info/entry_points.txt +0 -4
datamint/mlflow/flavors/model.py (new file)
@@ -0,0 +1,877 @@
1
+ """
2
+ DataMint Model Adapter Module
3
+
4
+ This module provides a flexible framework for wrapping ML models to work with DataMint's
5
+ annotation system. It supports various prediction modes for different data types and use cases.
6
+ """
7
+
8
+ from typing import Any, TypeAlias
9
+ from collections.abc import Callable
10
+ from abc import ABC, abstractmethod
11
+ from enum import Enum
12
+ from dataclasses import dataclass
13
+ from mlflow.environment_variables import MLFLOW_DEFAULT_PREDICTION_DEVICE
14
+ from mlflow.pyfunc import load_model as pyfunc_load_model
15
+ from mlflow.pytorch import load_model as pytorch_load_model
16
+ from mlflow.pyfunc import PyFuncModel, PythonModel, PythonModelContext
17
+ from datamint.entities.annotations import Annotation
18
+ from datamint.entities.resource import Resource
19
+ import logging
20
+ import os
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # Type aliases
25
+ PredictionResult: TypeAlias = list[list[Annotation]]
26
+
27
+
28
+ @dataclass
29
+ class ModelSettings:
30
+ """
31
+ Deployment and inference configuration for DatamintModel.
32
+
33
+ These settings are serialized with the model and used by remote MLflow servers
34
+ to properly configure the runtime environment.
35
+ """
36
+ # Hardware requirements
37
+ need_gpu: bool = False
38
+ """Whether GPU is required for inference"""
39
+
40
+ @classmethod
41
+ def from_dict(cls, data: dict[str, Any]) -> 'ModelSettings':
42
+ """Create config from dictionary, raising error on unknown keys."""
43
+ valid_fields = {f.name for f in cls.__dataclass_fields__.values()}
44
+ invalid_fields = set(data.keys()) - valid_fields
45
+ if invalid_fields:
46
+ raise ValueError(f"Invalid fields for ModelSettings: {', '.join(sorted(invalid_fields))}")
47
+ return cls(**data)
48
+
49
+
50
+ class PredictionMode(str, Enum):
51
+ """
52
+ Enumeration of supported prediction modes.
53
+
54
+ Each mode corresponds to a specific method signature in DatamintModel.
55
+ """
56
+ # Standard modes
57
+ DEFAULT = 'default' # Default: process entire resource as-is
58
+
59
+ # Simple modes
60
+ IMAGE = 'image' # Process a single 2D image resource
61
+
62
+ # Video/temporal modes
63
+ FRAME = 'frame' # Extract and process specific frame
64
+ FRAME_RANGE = 'frame_range' # Process contiguous frame range
+ FRAME_INTERVAL = 'frame_interval' # Process every nth frame
65
+ ALL_FRAMES = 'all_frames' # Process all frames independently
66
+ TEMPORAL_SEQUENCE = 'temporal_sequence' # Process with temporal context window
67
+
68
+ # 3D volume modes
69
+ SLICE = 'slice' # Extract and process specific slice
70
+ SLICE_RANGE = 'slice_range' # Process contiguous slice range
71
+ PRIMARY_SLICE = 'primary_slice' # Process center/primary slice
72
+ # MULTI_PLANE = 'multi_plane' # Process multiple anatomical planes
73
+ VOLUME = 'volume' # Process entire 3D volume
74
+
75
+ # Spatial modes
76
+ # ROI = 'roi' # Process single region of interest
77
+ # MULTI_ROI = 'multi_roi' # Process multiple regions
78
+ # TILE = 'tile' # Split into tiles (whole slide imaging)
79
+ # PATCH = 'patch' # Extract patches around points
80
+
81
+ # Advanced modes
82
+ INTERACTIVE = 'interactive' # With user prompts (SAM-like)
83
+ FEW_SHOT = 'few_shot' # With context examples
84
+ # MULTI_VIEW = 'multi_view' # Multiple views of same subject
85
+
86
+
87
+ class DatamintModel(ABC, PythonModel):
88
+ """
89
+ Abstract adapter class for wrapping models to produce Datamint annotations.
90
+
91
+ This class provides a flexible framework for integrating ML models with DataMint.
92
+ The main `predict()` method routes requests to specific handlers based on the
93
+ prediction mode, allowing users to implement only the modes they need.
94
+
95
+ Quick Start:
96
+ -----------
97
+ ```python
98
+ class MyModel(DatamintModel):
99
+ def __init__(self):
100
+ super().__init__(
101
+ mlflow_models_uri={'model': 'models:/MyModel/latest'},
102
+ settings=ModelSettings(need_gpu=True)
103
+ )
104
+
105
+ def predict_default(self, model_input, **kwargs):
106
+ # Access the device for your computation
107
+ device = self.inference_device # Reads from MLFLOW_DEFAULT_PREDICTION_DEVICE or defaults to 'cpu'
108
+ model = self.get_mlflow_models()['model'].get_raw_model().to(device)
109
+ # ... process and return annotations
110
+ return predictions
111
+ ```
112
+
113
+ Prediction Modes:
114
+ ----------------
115
+ Users can request different prediction modes via params['mode']:
116
+
117
+ **Default**: Process the entire resource as-is
118
+ ```python
119
+ model.predict(resources) # or params={'mode': 'default'}
120
+ ```
121
+
122
+ **Video Frame**: Extract specific frame
123
+ ```python
124
+ model.predict(videos, params={'mode': 'frame', 'frame_index': 42})
125
+ ```
126
+
127
+ **3D Slice**: Extract specific slice
128
+ ```python
129
+ model.predict(volumes, params={'mode': 'slice', 'slice_index': 50, 'axis': 'axial'})
130
+ ```
131
+
132
+ **Interactive**: With prompts
133
+ ```python
134
+ model.predict(images, params={'mode': 'interactive', 'prompt': {'points': [[x, y]], 'labels': [1]}})
135
+ ```
136
+
137
+ Common Parameters:
138
+ -----------------
139
+ - `confidence_threshold` (float): Filter predictions by confidence score
140
+ - `batch_size` (int): Batch size for processing
141
+ - `render_annotation` (bool): Return annotated images instead of annotations
142
+
143
+ Device Configuration:
144
+ --------------------
145
+ When loaded through MLflow, the computation device is resolved in order from the model config,
146
+ the `MLFLOW_DEFAULT_PREDICTION_DEVICE` environment variable, and CUDA availability.
147
+ Access it via `self.inference_device`; it defaults to 'cpu' if none of these is set.
148
+
149
+ Implementation Guide:
150
+ --------------------
151
+ 1. Implement `predict_default()` - this is required and serves as fallback
152
+ 2. Optionally implement specific modes your model supports
153
+ 3. Override `_render_annotations()` if you want to support visualization
154
+ 4. Use `self.get_mlflow_models()` to access loaded MLflow models
155
+ 5. Configure deployment settings via `ModelSettings`
156
+
157
+ See individual method docstrings for detailed parameter specifications.
158
+ """
159
+
160
+ LINKED_MODELS_DIR = "linked_models"
161
+ _CACHED_ATTRS = ['_mlflow_models', '_mlflow_torch_models', '_inference_device']
162
+
163
+ def __init__(self,
164
+ settings: ModelSettings | dict[str, Any] | None = None,
165
+ mlflow_torch_models_uri: dict[str, str] | None = None,
166
+ mlflow_models_uri: dict[str, str] | None = None,
167
+ ) -> None:
168
+ """
169
+ Initialize the DatamintModel adapter.
170
+
171
+ Args:
172
+ settings: ModelSettings instance or dict with deployment settings.
173
+ Example: {'need_gpu': True}
174
+ mlflow_torch_models_uri: Dictionary mapping model names to PyTorch model URIs.
175
+ Example: {'backbone': 'models:/MyClassifier/2'}
176
+ These models will be lazy-loaded and accessible via ``self.get_mlflow_torch_models()['backbone']``
177
+ mlflow_models_uri: Dictionary mapping model names to MLflow URIs.
178
+ Example: {'detector': 'models:/MyDetector/1',
179
+ 'classifier': 'models:/MyClassifier/latest'}
180
+ These models will be lazy-loaded and accessible via ``self.get_mlflow_models()['detector']``
181
+
182
+ """
183
+ super().__init__()
184
+ self.mlflow_models_uri = (mlflow_models_uri or {}).copy()
185
+ self.mlflow_torch_models_uri = (mlflow_torch_models_uri or {}).copy()
186
+
187
+ # Handle settings - convert dict to ModelSettings if needed
188
+ if isinstance(settings, dict):
189
+ self.settings = ModelSettings.from_dict(settings)
190
+ elif isinstance(settings, ModelSettings):
191
+ self.settings = settings
192
+ else:
193
+ self.settings = ModelSettings()
194
+
195
+ self._supported_modes_cache = None
196
+
197
+ def load_context(self, context: PythonModelContext):
198
+ """
199
+ Called by MLflow when loading the model.
200
+
201
+ Override this if you need custom loading logic.
202
+ """
203
+ self._inference_device = self._load_inference_device(context=context)
204
+ self._mlflow_models = self._load_mlflow_models()
205
+ self._mlflow_torch_models = self._load_mlflow_torch_models()
206
+ # model_config = context.model_config
207
+
208
+ def _get_linked_models_uri(self) -> dict[str, Any]:
209
+ """Get all linked models (MLflow and PyTorch)"""
210
+ linked = {}
211
+ linked.update(self.mlflow_models_uri)
212
+ linked.update(self.mlflow_torch_models_uri)
213
+ return linked
214
+
215
+ def _clear_linked_models_cache(self):
216
+ """Clear loaded linked models to free memory"""
217
+
218
+ for attr in self._CACHED_ATTRS:
219
+ if hasattr(self, attr):
220
+ delattr(self, attr)
221
+
222
+
223
+ def __getstate__(self):
224
+ state = self.__dict__.copy()
225
+
226
+ for attr in self._CACHED_ATTRS:
227
+ if attr in state:
228
+ del state[attr]
229
+
230
+ return state
231
+
232
+ def __setstate__(self, state):
233
+ self.__dict__.update(state)
234
+ # avoid possible invalid states after unpickling
235
+ self._clear_linked_models_cache()
236
+
237
+ def _load_inference_device(self, context: PythonModelContext | None = None) -> str:
238
+ """
239
+ Load inference device from model config or environment variable.
240
+ """
241
+ import torch
242
+
243
+ device = None
244
+ if context and context.model_config:
245
+ device = context.model_config.get("device", None)
246
+ logger.info(f"Model config device: {device}")
247
+ if device is None:
248
+ env_device = MLFLOW_DEFAULT_PREDICTION_DEVICE.get()
249
+ if env_device:
250
+ device = env_device
251
+ elif torch.cuda.is_available():
252
+ device = 'cuda'
253
+ else:
254
+ device = 'cpu'
255
+
256
+ logger.info(f"Set inference device: {device}")
257
+ return device
258
+
259
+ @property
260
+ def inference_device(self) -> str:
261
+ if hasattr(self, '_inference_device') and self._inference_device is not None:
262
+ return self._inference_device
263
+ env_device = MLFLOW_DEFAULT_PREDICTION_DEVICE.get()
264
+ if env_device:
265
+ logger.info(f"Inference device not set; getting from environment variable ({env_device})")
266
+ return env_device
267
+ logger.warning("Inference device not set; defaulting to 'cpu'")
268
+ return 'cpu'
269
+
270
+ def _load_models_generic(self, uris: dict[str, str],
271
+ loader_func: Callable,
272
+ **loader_kwargs) -> dict[str, Any]:
273
+ """Generic helper to load models from URIs."""
274
+ loaded_models = {}
275
+ for name, uri in uris.items():
276
+ model_uri = uri
277
+ if os.path.exists(uri):
278
+ logger.info(f"Model '{name}' found locally at '{uri}'")
279
+ model_uri = os.path.abspath(uri)
280
+ elif uri.startswith("models:/"):
281
+ local_path = uri.replace("models:/", DatamintModel.LINKED_MODELS_DIR + "/", 1)
282
+ if os.path.exists(local_path):
283
+ logger.info(f"Model '{name}' found locally at '{local_path}'")
284
+ model_uri = os.path.abspath(local_path)
285
+
286
+ try:
287
+ loaded_models[name] = loader_func(model_uri, **loader_kwargs)
288
+ logger.info(f"Loaded model '{name}' from {model_uri}")
289
+ except Exception as e:
290
+ logger.error(f"Failed to load model '{name}' from {model_uri}: {e}")
291
+ raise
292
+ return loaded_models
293
+
294
+ def _load_mlflow_models(self) -> dict[str, PyFuncModel]:
295
+ """Load all MLflow models specified in mlflow_models_uri."""
296
+ return self._load_models_generic(
297
+ self.mlflow_models_uri,
298
+ pyfunc_load_model,
299
+ model_config={'device': self.inference_device}
300
+ )
301
+
302
+ def _load_mlflow_torch_models(self) -> dict[str, Any]:
303
+ """Load all MLflow PyTorch models specified in mlflow_torch_models_uri."""
304
+ models = self._load_models_generic(
305
+ self.mlflow_torch_models_uri,
306
+ pytorch_load_model,
307
+ device=self.inference_device,
308
+ map_location=self.inference_device,
309
+ )
310
+ for m in models.values():
311
+ if hasattr(m, 'eval'):
312
+ m.eval()
313
+ return models
314
+
315
+ def get_mlflow_models(self) -> dict[str, PyFuncModel]:
316
+ """
317
+ Access loaded MLflow models.
318
+
319
+ Returns:
320
+ Dictionary mapping model names to PyFuncModel instances.
321
+ Use .get_raw_model() to access the underlying model (e.g., torch.nn.Module)
322
+ """
323
+ if not hasattr(self, '_mlflow_models'):
324
+ logger.warning("Loading MLflow models on first access")
325
+ self._mlflow_models = self._load_mlflow_models()
326
+ return self._mlflow_models
327
+
328
+ def get_mlflow_torch_models(self) -> dict[str, Any]:
329
+ """
330
+ Access loaded MLflow PyTorch models.
331
+
332
+ Returns:
333
+ Dictionary mapping model names to PyTorch model instances.
334
+ """
335
+ if not hasattr(self, '_mlflow_torch_models'):
336
+ logger.warning("Loading MLflow PyTorch models on first access")
337
+ self._mlflow_torch_models = self._load_mlflow_torch_models()
338
+ return self._mlflow_torch_models
339
+
340
+ # def _preprocess_input(self,
341
+ # model_input: list[InferenceResource | Resource | dict[str, Any]],
342
+ # params: dict[str, Any]) -> list[Resource]:
343
+ # """
344
+ # Preprocess input to convert to list of Resource objects.
345
+
346
+ # Args:
347
+ # model_input: List of InferenceResource, Resource, or dict
348
+ # params: Additional parameters (unused here)
349
+ # Returns:
350
+ # List of Resource objects
351
+ # """
352
+ # resources = []
353
+ # for item in model_input:
354
+ # if isinstance(item, Resource):
355
+ # resources.append(item)
356
+ # elif isinstance(item, InferenceResource):
357
+ # resources.append(item.fabricate_resource())
358
+ # elif isinstance(item, dict):
359
+ # if 'local_filepath' in item or item.get('id', None) == '':
360
+ # logger.debug(f'Creating LocalResource from dict: {item}')
361
+ # resources.append(LocalResource(local_filepath=item['local_filepath']))
362
+ # elif 'upload_channel' in item or 'location' in item or 'storage' in item:
363
+ # resources.append(Resource(**item))
364
+ # else:
365
+ # resources.append(InferenceResource(**item).fabricate_resource())
366
+ # else:
367
+ # raise ValueError(f"Unsupported input type: {type(item)}")
368
+ # return resources
369
+
370
+ def predict(self,
371
+ model_input: list[Resource],
372
+ params: dict[str, Any] | None = None) -> PredictionResult:
373
+ """
374
+ Main prediction entry point.
375
+
376
+ Routes to appropriate prediction method based on params['mode'].
377
+ DO NOT override this method - implement specific predict_* methods instead.
378
+
379
+ Args:
380
+ model_input: List of Resource objects to process
381
+ params: Optional configuration dictionary with keys:
382
+ - mode (str): Prediction mode (default: 'default')
383
+ - confidence_threshold (float): Filter by confidence
384
+ - batch_size (int): Batch size for processing
385
+ - render_annotation (bool): Return rendered images
386
+ - device (str): Computation device
387
+ + mode-specific parameters (see individual method docs)
388
+
389
+ Returns:
390
+ List of annotation lists (one per resource), or rendered outputs
391
+ if render_annotation=True
392
+
393
+ Raises:
394
+ ValueError: If mode is invalid or required parameters are missing
395
+ NotImplementedError: If requested mode is not implemented
396
+ """
397
+ params = params or {}
398
+ # model_input = self._preprocess_input(model_input, params)
399
+
400
+ # Parse and validate mode
401
+ mode = self._parse_mode(model_input=model_input, params=params)
402
+
403
+ # Route to appropriate prediction method
404
+ try:
405
+ if not self._is_mode_implemented(mode):
406
+ if self._is_mode_implemented(PredictionMode.DEFAULT):
407
+ logger.info(f"Mode '{mode.value}' not implemented, falling back to default")
408
+ mode = PredictionMode.DEFAULT
409
+ else:
410
+ raise NotImplementedError
411
+ logger.debug(f"Routing to '{mode.value}' mode for {len(model_input)} resources")
412
+ result = self._route_prediction(model_input, mode, params)
413
+
414
+ # Apply common post-processing
415
+ result = self._post_process(result, model_input, params)
416
+
417
+ return result
418
+
419
+ except NotImplementedError:
420
+ available = self.get_supported_modes()
421
+ raise NotImplementedError(
422
+ f"Prediction mode '{mode.value}' is not supported by this model.\n"
423
+ f"Supported modes: {', '.join(available)}\n"
424
+ f"Implement predict_{mode.value}() to add support for this mode."
425
+ )
426
+
427
+ def _parse_mode(self,
428
+ params: dict[str, Any],
429
+ model_input: list[Resource] | None = None) -> PredictionMode:
430
+ """Parse and validate prediction mode from params."""
431
+ mode_str = params.get('mode', PredictionMode.DEFAULT.value)
432
+ try:
433
+ is_all_image = all(res.mimetype.startswith('image/') for res in model_input) if model_input else False
434
+ except Exception:
435
+ is_all_image = False
436
+
437
+ logger.debug(f"Parsing prediction mode: '{mode_str}' | {is_all_image=}")
438
+
439
+ if mode_str == PredictionMode.DEFAULT.value and is_all_image:
440
+ mode_str = PredictionMode.IMAGE.value
441
+
442
+ try:
443
+ return PredictionMode(mode_str)
444
+ except ValueError:
445
+ valid_modes = [m.value for m in PredictionMode]
446
+ raise ValueError(
447
+ f"Invalid prediction mode: '{mode_str}'\n"
448
+ f"Valid modes: {', '.join(valid_modes)}"
449
+ )
450
+
451
+ def _route_prediction(self,
452
+ model_input: list[Resource],
453
+ mode: PredictionMode,
454
+ params: dict[str, Any]) -> PredictionResult:
455
+ """Route to the appropriate prediction method based on mode."""
456
+
457
+ # Extract mode-specific parameters and remove from kwargs
458
+ mode_params, common_params = self._extract_mode_params(mode, params)
459
+
460
+ # Dispatch to appropriate method
461
+ method = self._get_method_for_mode(mode)
462
+
463
+ # if method is None or not self._is_mode_implemented(mode, method):
464
+ # raise NotImplementedError
465
+
466
+ # Call with explicit parameters
467
+ return method(model_input, **mode_params, **common_params)
468
+
469
+ def _extract_mode_params(self, mode: PredictionMode, params: dict[str, Any]) -> tuple[dict, dict]:
470
+ """
471
+ Extract mode-specific and common parameters.
472
+
473
+ Returns:
474
+ Tuple of (mode_specific_params, common_params)
475
+ """
476
+ # Define mode-specific parameter mappings
477
+ mode_param_keys = {
478
+ PredictionMode.FRAME: ['frame_index'],
479
+ PredictionMode.FRAME_RANGE: ['start_frame', 'end_frame', 'step'],
+ PredictionMode.FRAME_INTERVAL: ['interval', 'start_frame', 'end_frame'],
480
+ PredictionMode.SLICE: ['slice_index', 'axis'],
481
+ PredictionMode.SLICE_RANGE: ['start_index', 'end_index', 'axis', 'step'],
482
+ PredictionMode.PRIMARY_SLICE: ['axis'],
483
+ PredictionMode.INTERACTIVE: ['prompt'],
484
+ PredictionMode.FEW_SHOT: ['context_resources', 'k'],
485
+ PredictionMode.TEMPORAL_SEQUENCE: ['center_frame', 'window_size'],
486
+ PredictionMode.IMAGE: [],
487
+ }
488
+
489
+ reserved_keys = {'mode', 'confidence_threshold'}
490
+
491
+ # Extract parameters
492
+ mode_specific = {}
493
+ common = {}
494
+
495
+ mode_keys = set(mode_param_keys.get(mode, ()))
496
+
497
+ for key, value in params.items():
498
+ if key in reserved_keys:
499
+ continue # Skip mode itself and post-processing-only params
500
+ if key in mode_keys:
501
+ mode_specific[key] = value
502
+ else:
503
+ common[key] = value
504
+
505
+ return mode_specific, common
506
+
507
+ def _get_method_for_mode(self, mode: PredictionMode):
508
+ """Get the method corresponding to the given prediction mode."""
509
+ method_name = f"predict_{mode.value}"
510
+ method = getattr(self, method_name, None)
511
+ return method
512
+
513
+ def get_supported_modes(self) -> list[str]:
514
+ """
515
+ Get list of prediction modes supported by this model.
516
+
517
+ Returns:
518
+ List of mode names (strings)
519
+ """
520
+ if self._supported_modes_cache is not None:
521
+ return self._supported_modes_cache
522
+
523
+ supported = []
524
+ for mode in PredictionMode:
525
+ if self._is_mode_implemented(mode):
526
+ supported.append(mode.value)
527
+
528
+ self._supported_modes_cache = supported
529
+ return supported
530
+
531
+ def _is_mode_implemented(self, mode: PredictionMode) -> bool:
532
+ """Determine whether the given mode has a concrete implementation."""
533
+ method = self._get_method_for_mode(mode)
534
+ if method is None:
535
+ return False
536
+
537
+ # Check if method is from DatamintModel base class (not overridden)
538
+ if hasattr(DatamintModel, method.__name__):
540
+ base_method = getattr(DatamintModel, method.__name__)
541
+ # Method is implemented if it's not the same as base class method
542
+ return method.__func__ is not base_method
543
+
544
+ return True
545
+
546
+ def _post_process(self,
547
+ predictions: PredictionResult,
548
+ resources: list[Resource],
549
+ params: dict[str, Any]) -> PredictionResult:
550
+ """Apply common post-processing based on params."""
551
+
552
+ # Apply confidence threshold filtering
553
+ conf_threshold = params.get('confidence_threshold')
554
+ if conf_threshold is not None:
555
+ predictions = [
556
+ [ann for ann in pred_list
557
+ if getattr(ann, 'confiability', 1.0) >= conf_threshold]
558
+ for pred_list in predictions
559
+ ]
560
+ logger.debug(f"Applied confidence threshold: {conf_threshold}")
561
+
562
+ return predictions
563
+
564
+ def predict_default(self,
565
+ model_input: list[Resource],
566
+ **kwargs) -> PredictionResult:
567
+ """
568
+ **REQUIRED**: Default prediction on entire resources.
569
+
570
+ This is the default mode and serves as fallback for unimplemented modes.
571
+ Override this method to implement default prediction behavior.
572
+ If called without being overridden, raises NotImplementedError.
573
+
574
+ Args:
575
+ model_input: Resources to process
576
+ **kwargs: Additional user-defined parameters
577
+
578
+ Returns:
579
+ List of annotation lists, one per resource
580
+
581
+ Example:
582
+ ```python
583
+ def predict_default(self, model_input, **kwargs):
584
+ dataset = MyDataset(model_input)
585
+ dataloader = DataLoader(dataset)
586
+ model = self.get_mlflow_models()['model'].get_raw_model()
587
+
588
+ predictions = []
589
+ for batch in dataloader:
590
+ outputs = model(batch)
591
+ predictions.extend(self._outputs_to_annotations(outputs))
592
+
593
+ return predictions
594
+ ```
595
+ """
596
+ raise NotImplementedError(
597
+ "predict_default() must be implemented in your DatamintModel subclass. "
598
+ "This is the default fallback mode for prediction."
599
+ )
600
+
601
+ # ========================================================================
602
+ # VIDEO/TEMPORAL MODES
603
+ # ========================================================================
604
+
605
+ def predict_frame(self,
606
+ model_input: list[Resource],
607
+ frame_index: int,
608
+ **kwargs) -> PredictionResult:
609
+ """
610
+ Process specific frame from video resources.
611
+
612
+ Args:
613
+ model_input: Video resources
614
+ frame_index: Index of frame to extract and process (0-based)
615
+
616
+ Returns:
617
+ Annotations for the specified frame (one list per resource)
618
+
619
+ Example:
620
+ ```python
621
+ # Extract frame 42 from multiple videos
622
+ predictions = model.predict(
623
+ videos,
624
+ params={'mode': 'frame', 'frame_index': 42}
625
+ )
626
+ ```
627
+ """
628
+ logger.warning("predict_frame not implemented, falling back to predict_default")
629
+ return self.predict_default(model_input, **kwargs)
630
+
631
+ def predict_frame_range(self,
632
+ model_input: list[Resource],
633
+ start_frame: int,
634
+ end_frame: int,
635
+ step: int = 1,
636
+ **kwargs) -> PredictionResult:
637
+ """
638
+ Process range of frames from video resources.
639
+
640
+ Args:
641
+ model_input: Video resources
642
+ start_frame: Start frame index (inclusive)
643
+ end_frame: End frame index (inclusive)
644
+ step: Step size between frames (default: 1)
645
+
646
+ Returns:
647
+ Annotations for frames in range (may be frame-scoped annotations)
648
+
649
+ Example:
650
+ ```python
651
+ # Process frames 0-100, every 10th frame
652
+ predictions = model.predict(
653
+ videos,
654
+ params={'mode': 'frame_range',
655
+ 'start_frame': 0,
656
+ 'end_frame': 100,
657
+ 'step': 10}
658
+ )
659
+ ```
660
+ """
661
+ logger.warning("predict_frame_range not implemented, falling back to predict_default")
662
+ return self.predict_default(model_input, **kwargs)
663
+
664
+ def predict_frame_interval(self,
665
+ model_input: list[Resource],
666
+ interval: int,
667
+ start_frame: int = 0,
668
+ end_frame: int | None = None,
669
+ **kwargs) -> PredictionResult:
670
+ """
671
+ Process every nth frame from video resources.
672
+
673
+ Args:
674
+ model_input: Video resources
675
+ interval: Process every nth frame
676
+ start_frame: Starting frame index
677
+ end_frame: Ending frame index (None = last frame)
678
+
679
+ Returns:
680
+ Annotations for sampled frames
681
+
682
+ Example:
683
+ ```python
684
+ # Process every 30th frame (1 fps for 30fps video)
685
+ predictions = model.predict(
686
+ videos,
687
+ params={'mode': 'frame_interval', 'interval': 30}
688
+ )
689
+ ```
690
+ """
691
+ logger.warning("predict_frame_interval not implemented, falling back to predict_default")
692
+ return self.predict_default(model_input, **kwargs)
693
+
694
+ def predict_all_frames(self,
695
+ model_input: list[Resource],
696
+ **kwargs) -> PredictionResult:
697
+ """
698
+ Process all frames independently.
699
+
700
+ Args:
701
+ model_input: Video resources
702
+
703
+ Returns:
704
+ Annotations for all frames (likely frame-scoped)
705
+
706
+ Example:
707
+ ```python
708
+ # Analyze every frame
709
+ predictions = model.predict(
710
+ videos,
711
+ params={'mode': 'all_frames'}
712
+ )
713
+ ```
714
+ """
715
+ logger.warning("predict_all_frames not implemented, falling back to predict_default")
716
+ return self.predict_default(model_input, **kwargs)
717
+
718
+ # ========================================================================
719
+ # 3D VOLUME MODES
720
+ # ========================================================================
721
+
722
+ def predict_slice(self,
723
+ model_input: list[Resource],
724
+ slice_index: int,
725
+ axis: str = 'axial',
726
+ **kwargs) -> PredictionResult:
727
+ """
728
+ Process specific slice from 3D volume.
729
+
730
+ Args:
731
+ model_input: 3D volume resources (DICOM series, NIfTI, etc.)
732
+ slice_index: Index of slice to extract
733
+ axis: Anatomical axis ('axial', 'sagittal', 'coronal')
734
+
735
+ Returns:
736
+ Annotations for the specified slice
737
+
738
+ Example:
739
+ ```python
740
+ # Extract and analyze axial slice 50
741
+ predictions = model.predict(
742
+ ct_scans,
743
+ params={'mode': 'slice',
744
+ 'slice_index': 50,
745
+ 'axis': 'axial'}
746
+ )
747
+ ```
748
+ """
749
+ logger.warning("predict_slice not implemented, falling back to predict_default")
750
+ return self.predict_default(model_input, **kwargs)
751
+
752
+ def predict_slice_range(self,
753
+ model_input: list[Resource],
754
+ start_index: int,
755
+ end_index: int,
756
+ axis: str = 'axial',
757
+ step: int = 1,
758
+ **kwargs) -> PredictionResult:
759
+ """
760
+ Process range of slices from 3D volume.
761
+
762
+ Args:
763
+ model_input: 3D volume resources
764
+ start_index: Start slice index (inclusive)
765
+ end_index: End slice index (inclusive)
766
+ axis: Anatomical axis
767
+ step: Step size between slices
768
+
769
+ Returns:
770
+ Annotations for slices in range
771
+ """
772
+ logger.warning("predict_slice_range not implemented, falling back to predict_default")
773
+ return self.predict_default(model_input, **kwargs)
774
+
775
+ def predict_volume(self,
776
+ model_input: list[Resource],
777
+ **kwargs) -> PredictionResult:
778
+ """
779
+ Process entire 3D volume.
780
+
781
+ For true 3D models (not slice-by-slice).
782
+
783
+ Args:
784
+ model_input: 3D volume resources
785
+
786
+ Returns:
787
+ 3D annotations for entire volume
788
+ """
789
+ logger.warning("predict_volume not implemented, falling back to predict_default")
790
+ return self.predict_default(model_input, **kwargs)
791
+
792
+ # ========================================================================
793
+ # ADVANCED MODES
794
+ # ========================================================================
795
+
796
+ def predict_interactive(self,
797
+ model_input: list[Resource],
798
+ prompt: dict[str, Any],
799
+ **kwargs) -> PredictionResult:
800
+ """
801
+ Interactive prediction with user prompts.
802
+
803
+ For models like Segment Anything (SAM) that accept user guidance.
804
+
805
+ Args:
806
+ model_input: Resources to process
807
+ prompt: Prompt dictionary with keys:
808
+ - 'points': list of [x, y] coordinates
809
+ - 'labels': list of labels (1=foreground, 0=background)
810
+ - 'boxes': list of [x1, y1, x2, y2] bounding boxes
811
+ - 'masks': list of binary mask arrays
812
+
813
+ Returns:
814
+ Annotations based on prompts
815
+
816
+ Example:
817
+ ```python
818
+ # Segment based on positive and negative points
819
+ predictions = model.predict(
820
+ images,
821
+ params={'mode': 'interactive',
822
+ 'prompt': {
823
+ 'points': [[100, 150], [200, 250]],
824
+ 'labels': [1, 0] # foreground, background
825
+ }}
826
+ )
827
+ ```
828
+ """
829
+ logger.warning("predict_interactive not implemented, falling back to predict_default")
830
+ return self.predict_default(model_input, **kwargs)
831
+
832
+ def predict_few_shot(self,
833
+ model_input: list[Resource],
834
+ context_resources: list[Resource],
835
+ k: int = 5,
836
+ **kwargs) -> PredictionResult:
837
+ """
838
+ Few-shot prediction with context examples.
839
+
840
+ For models that can adapt based on a few labeled examples.
841
+
842
+ Args:
843
+ model_input: Resources to annotate
844
+ context_resources: Resources with existing annotations to use as examples
845
+ k: Number of examples to use (if more are provided)
846
+
847
+ Returns:
848
+ Annotations informed by context examples
849
+
850
+ Example:
851
+ ```python
852
+ # Predict using similar annotated examples
853
+ predictions = model.predict(
854
+ new_images,
855
+ params={'mode': 'few_shot',
856
+ 'context_resources': annotated_examples,
857
+ 'k': 3}
858
+ )
859
+ ```
860
+ """
861
+ logger.warning("predict_few_shot not implemented, falling back to predict_default")
862
+ return self.predict_default(model_input, **kwargs)
863
+
864
+ def predict_image(self,
865
+ model_input: list[Resource],
866
+ **kwargs) -> PredictionResult:
867
+ """
868
+ Process single 2D image resources.
869
+
870
+ Args:
871
+ model_input: 2D image resources
872
+
873
+ Returns:
874
+ Annotations for each image
875
+ """
876
+ logger.warning("predict_image not implemented, falling back to predict_default")
877
+ return self.predict_default(model_input, **kwargs)
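
For orientation, the sketch below shows one way the new `DatamintModel` adapter might be subclassed and invoked. It is illustrative only and makes assumptions not present in the diff: `MyChestXrayModel`, the `models:/MyClassifier/latest` registry URI, and the empty annotation lists are placeholders, and the actual conversion of model outputs into `Annotation` objects depends on the entity classes under `datamint/entities/annotations/`, which this hunk does not show.

```python
from typing import Any

from datamint.entities.resource import Resource
from datamint.mlflow.flavors.model import DatamintModel, ModelSettings, PredictionResult


class MyChestXrayModel(DatamintModel):
    """Hypothetical adapter wrapping a single registered PyTorch backbone."""

    def __init__(self) -> None:
        super().__init__(
            settings=ModelSettings(need_gpu=False),
            # 'models:/MyClassifier/latest' is a placeholder registry URI.
            mlflow_torch_models_uri={'backbone': 'models:/MyClassifier/latest'},
        )

    def predict_default(self, model_input: list[Resource], **kwargs: Any) -> PredictionResult:
        device = self.inference_device  # resolved from model config, env var, or CUDA availability
        backbone = self.get_mlflow_torch_models()['backbone'].to(device)

        predictions: PredictionResult = []
        for resource in model_input:
            # Load the resource's pixel data, run `backbone`, and convert the outputs
            # into Annotation objects from datamint.entities.annotations here; that
            # conversion is resource- and task-specific and is omitted from this sketch.
            predictions.append([])
        return predictions


# Typical invocation, per the predict() docstring: one annotation list per resource.
# results = MyChestXrayModel().predict(resources,
#                                      params={'mode': 'default', 'confidence_threshold': 0.5})
```

An adapter written this way would then be packaged through the new `datamint.mlflow` flavor machinery (for example via `datamint/mlflow/flavors/datamint_flavor.py`) or plain `mlflow.pyfunc`, with `ModelSettings` serialized alongside the model so a remote deployment knows whether a GPU is required.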