datamint 2.3.3__py3-none-any.whl → 2.9.0__py3-none-any.whl
- datamint/__init__.py +1 -3
- datamint/api/__init__.py +0 -3
- datamint/api/base_api.py +286 -54
- datamint/api/client.py +76 -13
- datamint/api/endpoints/__init__.py +2 -2
- datamint/api/endpoints/annotations_api.py +186 -28
- datamint/api/endpoints/deploy_model_api.py +78 -0
- datamint/api/endpoints/models_api.py +1 -0
- datamint/api/endpoints/projects_api.py +38 -7
- datamint/api/endpoints/resources_api.py +227 -100
- datamint/api/entity_base_api.py +66 -7
- datamint/apihandler/base_api_handler.py +0 -1
- datamint/apihandler/dto/annotation_dto.py +2 -0
- datamint/client_cmd_tools/datamint_config.py +0 -1
- datamint/client_cmd_tools/datamint_upload.py +3 -1
- datamint/configs.py +11 -7
- datamint/dataset/base_dataset.py +24 -4
- datamint/dataset/dataset.py +1 -1
- datamint/entities/__init__.py +1 -1
- datamint/entities/annotations/__init__.py +13 -0
- datamint/entities/{annotation.py → annotations/annotation.py} +81 -47
- datamint/entities/annotations/image_classification.py +12 -0
- datamint/entities/annotations/image_segmentation.py +252 -0
- datamint/entities/annotations/volume_segmentation.py +273 -0
- datamint/entities/base_entity.py +100 -6
- datamint/entities/cache_manager.py +129 -15
- datamint/entities/datasetinfo.py +60 -65
- datamint/entities/deployjob.py +18 -0
- datamint/entities/project.py +39 -0
- datamint/entities/resource.py +310 -46
- datamint/lightning/__init__.py +1 -0
- datamint/lightning/datamintdatamodule.py +103 -0
- datamint/mlflow/__init__.py +65 -0
- datamint/mlflow/artifact/__init__.py +1 -0
- datamint/mlflow/artifact/datamint_artifacts_repo.py +8 -0
- datamint/mlflow/env_utils.py +131 -0
- datamint/mlflow/env_vars.py +5 -0
- datamint/mlflow/flavors/__init__.py +17 -0
- datamint/mlflow/flavors/datamint_flavor.py +150 -0
- datamint/mlflow/flavors/model.py +877 -0
- datamint/mlflow/lightning/callbacks/__init__.py +1 -0
- datamint/mlflow/lightning/callbacks/modelcheckpoint.py +410 -0
- datamint/mlflow/models/__init__.py +93 -0
- datamint/mlflow/tracking/datamint_store.py +76 -0
- datamint/mlflow/tracking/default_experiment.py +27 -0
- datamint/mlflow/tracking/fluent.py +91 -0
- datamint/utils/env.py +27 -0
- datamint/utils/visualization.py +21 -13
- datamint-2.9.0.dist-info/METADATA +220 -0
- datamint-2.9.0.dist-info/RECORD +73 -0
- {datamint-2.3.3.dist-info → datamint-2.9.0.dist-info}/WHEEL +1 -1
- datamint-2.9.0.dist-info/entry_points.txt +18 -0
- datamint/apihandler/exp_api_handler.py +0 -204
- datamint/experiment/__init__.py +0 -1
- datamint/experiment/_patcher.py +0 -570
- datamint/experiment/experiment.py +0 -1049
- datamint-2.3.3.dist-info/METADATA +0 -125
- datamint-2.3.3.dist-info/RECORD +0 -54
- datamint-2.3.3.dist-info/entry_points.txt +0 -4
@@ -0,0 +1,877 @@
+"""
+DataMint Model Adapter Module
+
+This module provides a flexible framework for wrapping ML models to work with DataMint's
+annotation system. It supports various prediction modes for different data types and use cases.
+"""
+
+from typing import Any, TypeAlias
+from collections.abc import Callable
+from abc import ABC, abstractmethod
+from enum import Enum
+from dataclasses import dataclass
+from mlflow.environment_variables import MLFLOW_DEFAULT_PREDICTION_DEVICE
+from mlflow.pyfunc import load_model as pyfunc_load_model
+from mlflow.pytorch import load_model as pytorch_load_model
+from mlflow.pyfunc import PyFuncModel, PythonModel, PythonModelContext
+from datamint.entities.annotations import Annotation
+from datamint.entities.resource import Resource
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+# Type aliases
+PredictionResult: TypeAlias = list[list[Annotation]]
+
+
+@dataclass
+class ModelSettings:
+    """
+    Deployment and inference configuration for DatamintModel.
+
+    These settings are serialized with the model and used by remote MLflow servers
+    to properly configure the runtime environment.
+    """
+    # Hardware requirements
+    need_gpu: bool = False
+    """Whether GPU is required for inference"""
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> 'ModelSettings':
+        """Create a config from a dictionary, raising an error on unknown keys."""
+        valid_fields = {f.name for f in cls.__dataclass_fields__.values()}
+        invalid_fields = set(data.keys()) - valid_fields
+        if invalid_fields:
+            raise ValueError(f"Invalid fields for ModelSettings: {', '.join(sorted(invalid_fields))}")
+        return cls(**data)
+
+
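
A minimal sketch of `from_dict`'s contract, assuming the module is importable as `datamint.mlflow.flavors.model` (the path this diff adds); only `need_gpu` is a declared field in this version:

```python
from datamint.mlflow.flavors.model import ModelSettings

settings = ModelSettings.from_dict({'need_gpu': True})
assert settings.need_gpu is True

try:
    ModelSettings.from_dict({'need_gpu': True, 'gpu_count': 2})  # 'gpu_count' is unknown
except ValueError as exc:
    print(exc)  # Invalid fields for ModelSettings: gpu_count
```

Failing fast on unknown keys matters here because the settings dict round-trips through serialization to remote MLflow servers; a silent typo would otherwise drop a hardware requirement.
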
+class PredictionMode(str, Enum):
+    """
+    Enumeration of supported prediction modes.
+
+    Each mode corresponds to a specific method signature in DatamintModel.
+    """
+    # Standard modes
+    DEFAULT = 'default'  # Default: process entire resource as-is
+
+    # Simple modes
+    IMAGE = 'image'  # Process a single 2D image resource
+
+    # Video/temporal modes
+    FRAME = 'frame'  # Extract and process a specific frame
+    FRAME_RANGE = 'frame_range'  # Process a contiguous frame range
+    FRAME_INTERVAL = 'frame_interval'  # Process every nth frame
+    ALL_FRAMES = 'all_frames'  # Process all frames independently
+    TEMPORAL_SEQUENCE = 'temporal_sequence'  # Process with a temporal context window
+
+    # 3D volume modes
+    SLICE = 'slice'  # Extract and process a specific slice
+    SLICE_RANGE = 'slice_range'  # Process a contiguous slice range
+    PRIMARY_SLICE = 'primary_slice'  # Process the center/primary slice
+    # MULTI_PLANE = 'multi_plane'  # Process multiple anatomical planes
+    VOLUME = 'volume'  # Process the entire 3D volume
+
+    # Spatial modes
+    # ROI = 'roi'  # Process a single region of interest
+    # MULTI_ROI = 'multi_roi'  # Process multiple regions
+    # TILE = 'tile'  # Split into tiles (whole-slide imaging)
+    # PATCH = 'patch'  # Extract patches around points
+
+    # Advanced modes
+    INTERACTIVE = 'interactive'  # With user prompts (SAM-like)
+    FEW_SHOT = 'few_shot'  # With context examples
+    # MULTI_VIEW = 'multi_view'  # Multiple views of the same subject
+
+
+class DatamintModel(ABC, PythonModel):
+    """
+    Abstract adapter class for wrapping models to produce Datamint annotations.
+
+    This class provides a flexible framework for integrating ML models with DataMint.
+    The main `predict()` method routes requests to specific handlers based on the
+    prediction mode, allowing users to implement only the modes they need.
+
+    Quick Start:
+    -----------
+    ```python
+    class MyModel(DatamintModel):
+        def __init__(self):
+            super().__init__(
+                mlflow_models_uri={'model': 'models:/MyModel/latest'},
+                settings=ModelSettings(need_gpu=True)
+            )
+
+        def predict_default(self, model_input, **kwargs):
+            # Access the device for your computation
+            device = self.inference_device  # Reads from MLFLOW_DEFAULT_PREDICTION_DEVICE or defaults to 'cpu'
+            model = self.get_mlflow_models()['model'].get_raw_model().to(device)
+            # ... process and return annotations
+            return predictions
+    ```
+
+    Prediction Modes:
+    ----------------
+    Users can request different prediction modes via params['mode']:
+
+    **Default**: Default processing
+    ```python
+    model.predict(resources)  # or params={'mode': 'default'}
+    ```
+
+    **Video Frame**: Extract specific frame
+    ```python
+    model.predict(videos, params={'mode': 'frame', 'frame_index': 42})
+    ```
+
+    **3D Slice**: Extract specific slice
+    ```python
+    model.predict(volumes, params={'mode': 'slice', 'slice_index': 50, 'axis': 'axial'})
+    ```
+
+    **Interactive**: With prompts
+    ```python
+    model.predict(images, params={'mode': 'interactive', 'prompt': {'points': [[x, y]], 'labels': [1]}})
+    ```
+
+    Common Parameters:
+    -----------------
+    - `confidence_threshold` (float): Filter predictions by confidence score
+    - `batch_size` (int): Batch size for processing
+    - `render_annotation` (bool): Return annotated images instead of annotations
+
+    Device Configuration:
+    --------------------
+    The device for computation is automatically configured from the
+    `MLFLOW_DEFAULT_PREDICTION_DEVICE` environment variable. Access it via `self.inference_device`.
+    Defaults to 'cpu' if not set.
+
+    Implementation Guide:
+    --------------------
+    1. Implement `predict_default()` - this is required and serves as the fallback
+    2. Optionally implement the specific modes your model supports
+    3. Override `_render_annotations()` if you want to support visualization
+    4. Use `self.get_mlflow_models()` to access loaded MLflow models
+    5. Configure deployment settings via `ModelSettings`
+
+    See individual method docstrings for detailed parameter specifications.
+    """
+
+    LINKED_MODELS_DIR = "linked_models"
+    _CACHED_ATTRS = ['_mlflow_models', '_mlflow_torch_models', '_inference_device']
+
+    def __init__(self,
+                 settings: ModelSettings | dict[str, Any] | None = None,
+                 mlflow_torch_models_uri: dict[str, str] | None = None,
+                 mlflow_models_uri: dict[str, str] | None = None,
+                 ) -> None:
+        """
+        Initialize the DatamintModel adapter.
+
+        Args:
+            settings: ModelSettings instance or dict with deployment settings.
+                Example: {'need_gpu': True}
+            mlflow_torch_models_uri: Dictionary mapping model names to PyTorch model URIs.
+                Example: {'backbone': 'models:/MyClassifier/2'}
+                These models will be lazy-loaded and accessible via ``self.get_mlflow_torch_models()['backbone']``
+            mlflow_models_uri: Dictionary mapping model names to MLflow URIs.
+                Example: {'detector': 'models:/MyDetector/1',
+                          'classifier': 'models:/MyClassifier/latest'}
+                These models will be lazy-loaded and accessible via ``self.get_mlflow_models()['detector']``
+
+        """
+        super().__init__()
+        self.mlflow_models_uri = (mlflow_models_uri or {}).copy()
+        self.mlflow_torch_models_uri = (mlflow_torch_models_uri or {}).copy()
+
+        # Handle settings - convert dict to ModelSettings if needed
+        if isinstance(settings, dict):
+            self.settings = ModelSettings.from_dict(settings)
+        elif isinstance(settings, ModelSettings):
+            self.settings = settings
+        else:
+            self.settings = ModelSettings()
+
+        self._supported_modes_cache = None
+
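
To illustrate the `settings` handling above, a minimal sketch (the `EchoModel` subclass is hypothetical, not part of the package):

```python
from datamint.mlflow.flavors.model import DatamintModel, ModelSettings

class EchoModel(DatamintModel):
    """Hypothetical minimal subclass: returns no annotations."""
    def predict_default(self, model_input, **kwargs):
        return [[] for _ in model_input]

# Both forms are equivalent; dicts are validated via ModelSettings.from_dict.
model_a = EchoModel(settings={'need_gpu': True})
model_b = EchoModel(settings=ModelSettings(need_gpu=True))
assert model_a.settings == model_b.settings  # dataclass equality
```
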
+    def load_context(self, context: PythonModelContext):
+        """
+        Called by MLflow when loading the model.
+
+        Override this if you need custom loading logic.
+        """
+        self._inference_device = self._load_inference_device(context=context)
+        self._mlflow_models = self._load_mlflow_models()
+        self._mlflow_torch_models = self._load_mlflow_torch_models()
+        # model_config = context.model_config
+
+    def _get_linked_models_uri(self) -> dict[str, Any]:
+        """Get all linked model URIs (MLflow and PyTorch)."""
+        linked = {}
+        linked.update(self.mlflow_models_uri)
+        linked.update(self.mlflow_torch_models_uri)
+        return linked
+
+    def _clear_linked_models_cache(self):
+        """Clear loaded linked models to free memory."""
+        for attr in self._CACHED_ATTRS:
+            if hasattr(self, attr):
+                delattr(self, attr)
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+
+        for attr in self._CACHED_ATTRS:
+            if attr in state:
+                del state[attr]
+
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        # avoid possible invalid states after unpickling
+        self._clear_linked_models_cache()
+
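
A sketch of the intended pickling behavior, reusing the hypothetical `EchoModel` above: loaded models and the resolved device are caches, dropped at serialization time and rebuilt lazily afterwards.

```python
import pickle

model = EchoModel()
model.mlflow_models_uri['detector'] = 'models:/MyDetector/1'  # hypothetical URI

restored = pickle.loads(pickle.dumps(model))

# Only the URIs survive the round trip; loaded model objects do not.
assert restored.mlflow_models_uri == {'detector': 'models:/MyDetector/1'}
assert not hasattr(restored, '_mlflow_models')
# restored.get_mlflow_models() would re-load 'models:/MyDetector/1' on demand.
```
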
+    def _load_inference_device(self, context: PythonModelContext | None = None) -> str:
+        """
+        Load the inference device from the model config, the environment variable,
+        or CUDA availability (in that order).
+        """
+        import torch
+
+        device = None
+        if context and context.model_config:
+            device = context.model_config.get("device", None)
+            logger.info(f"Model config device: {device}")
+        if device is None:
+            env_device = MLFLOW_DEFAULT_PREDICTION_DEVICE.get()
+            if env_device:
+                device = env_device
+            elif torch.cuda.is_available():
+                device = 'cuda'
+            else:
+                device = 'cpu'
+
+        logger.info(f"Set inference device: {device}")
+        return device
+
+    @property
+    def inference_device(self) -> str:
+        if hasattr(self, '_inference_device') and self._inference_device is not None:
+            return self._inference_device
+        env_device = MLFLOW_DEFAULT_PREDICTION_DEVICE.get()
+        if env_device:
+            logger.info(f"Inference device not set; getting it from the environment variable ({env_device})")
+            return env_device
+        logger.warning("Inference device not set; defaulting to 'cpu'")
+        return 'cpu'
+
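
The resolution order, then, is: explicit `model_config['device']` at load time, then `MLFLOW_DEFAULT_PREDICTION_DEVICE`, then `'cuda'` if available, else `'cpu'`. A sketch of pinning the device through MLflow's standard `model_config` mechanism (assuming an MLflow version whose `pyfunc.load_model` accepts `model_config`, as the loader code below already relies on):

```python
import mlflow.pyfunc

# 1) An explicit request wins; it reaches load_context via context.model_config.
loaded = mlflow.pyfunc.load_model('models:/MyModel/1',  # hypothetical URI
                                  model_config={'device': 'cuda:1'})

# 2) Otherwise the environment variable is consulted:
#      export MLFLOW_DEFAULT_PREDICTION_DEVICE=cpu
# 3) Otherwise torch.cuda.is_available() picks 'cuda' over 'cpu'.
```
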
+    def _load_models_generic(self, uris: dict[str, str],
+                             loader_func: Callable,
+                             **loader_kwargs) -> dict[str, Any]:
+        """Generic helper to load models from URIs."""
+        loaded_models = {}
+        for name, uri in uris.items():
+            model_uri = uri
+            if os.path.exists(uri):
+                logger.info(f"Model '{name}' found locally at '{uri}'")
+                model_uri = os.path.abspath(uri)
+            elif uri.startswith("models:/"):
+                local_path = uri.replace("models:/", DatamintModel.LINKED_MODELS_DIR + "/", 1)
+                if os.path.exists(local_path):
+                    logger.info(f"Model '{name}' found locally at '{local_path}'")
+                    model_uri = os.path.abspath(local_path)
+
+            try:
+                loaded_models[name] = loader_func(model_uri, **loader_kwargs)
+                logger.info(f"Loaded model '{name}' from {model_uri}")
+            except Exception as e:
+                logger.error(f"Failed to load model '{name}' from {model_uri}: {e}")
+                raise
+        return loaded_models
+
+    def _load_mlflow_models(self) -> dict[str, PyFuncModel]:
+        """Load all MLflow models specified in mlflow_models_uri."""
+        return self._load_models_generic(
+            self.mlflow_models_uri,
+            pyfunc_load_model,
+            model_config={'device': self.inference_device}
+        )
+
+    def _load_mlflow_torch_models(self) -> dict[str, Any]:
+        """Load all MLflow PyTorch models specified in mlflow_torch_models_uri."""
+        models = self._load_models_generic(
+            self.mlflow_torch_models_uri,
+            pytorch_load_model,
+            device=self.inference_device,
+            map_location=self.inference_device,
+        )
+        for m in models.values():
+            if hasattr(m, 'eval'):
+                m.eval()
+        return models
+
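
A note on the lookup `_load_models_generic` performs per URI: existing local paths are used directly, and `models:/...` URIs are first probed under the bundled `linked_models/` directory (`LINKED_MODELS_DIR`) before falling back to the registry. Roughly:

```python
import os

uri = 'models:/MyDetector/1'  # hypothetical registry URI
local_path = uri.replace('models:/', 'linked_models/', 1)
model_uri = os.path.abspath(local_path) if os.path.exists(local_path) else uri
# 'linked_models/MyDetector/1' is used when the model was bundled alongside
# the artifact; otherwise the registry URI is handed to the MLflow loader as-is.
```
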
+    def get_mlflow_models(self) -> dict[str, PyFuncModel]:
+        """
+        Access loaded MLflow models.
+
+        Returns:
+            Dictionary mapping model names to PyFuncModel instances.
+            Use .get_raw_model() to access the underlying model (e.g., torch.nn.Module).
+        """
+        if not hasattr(self, '_mlflow_models'):
+            logger.warning("Loading MLflow models on first access")
+            self._mlflow_models = self._load_mlflow_models()
+        return self._mlflow_models
+
+    def get_mlflow_torch_models(self) -> dict[str, Any]:
+        """
+        Access loaded MLflow PyTorch models.
+
+        Returns:
+            Dictionary mapping model names to PyTorch model instances.
+        """
+        if not hasattr(self, '_mlflow_torch_models'):
+            logger.warning("Loading MLflow PyTorch models on first access")
+            self._mlflow_torch_models = self._load_mlflow_torch_models()
+        return self._mlflow_torch_models
+
+    # def _preprocess_input(self,
+    #                       model_input: list[InferenceResource | Resource | dict[str, Any]],
+    #                       params: dict[str, Any]) -> list[Resource]:
+    #     """
+    #     Preprocess input to convert to list of Resource objects.
+
+    #     Args:
+    #         model_input: List of InferenceResource, Resource, or dict
+    #         params: Additional parameters (unused here)
+    #     Returns:
+    #         List of Resource objects
+    #     """
+    #     resources = []
+    #     for item in model_input:
+    #         if isinstance(item, Resource):
+    #             resources.append(item)
+    #         elif isinstance(item, InferenceResource):
+    #             resources.append(item.fabricate_resource())
+    #         elif isinstance(item, dict):
+    #             if 'local_filepath' in item or item.get('id', None) == '':
+    #                 logger.debug(f'Creating LocalResource from dict: {item}')
+    #                 resources.append(LocalResource(local_filepath=item['local_filepath']))
+    #             elif 'upload_channel' in item or 'location' in item or 'storage' in item:
+    #                 resources.append(Resource(**item))
+    #             else:
+    #                 resources.append(InferenceResource(**item).fabricate_resource())
+    #         else:
+    #             raise ValueError(f"Unsupported input type: {type(item)}")
+    #     return resources
+
+    def predict(self,
+                model_input: list[Resource],
+                params: dict[str, Any] | None = None) -> PredictionResult:
+        """
+        Main prediction entry point.
+
+        Routes to the appropriate prediction method based on params['mode'].
+        DO NOT override this method - implement specific predict_* methods instead.
+
+        Args:
+            model_input: List of Resource objects to process
+            params: Optional configuration dictionary with keys:
+                - mode (str): Prediction mode (default: 'default')
+                - confidence_threshold (float): Filter by confidence
+                - batch_size (int): Batch size for processing
+                - render_annotation (bool): Return rendered images
+                - device (str): Computation device
+                + mode-specific parameters (see individual method docs)
+
+        Returns:
+            List of annotation lists (one per resource), or rendered outputs
+            if render_annotation=True
+
+        Raises:
+            ValueError: If the mode is invalid or required parameters are missing
+            NotImplementedError: If the requested mode is not implemented
+        """
+        params = params or {}
+        # model_input = self._preprocess_input(model_input, params)
+
+        # Parse and validate mode
+        mode = self._parse_mode(model_input=model_input, params=params)
+
+        # Route to the appropriate prediction method
+        try:
+            if not self._is_mode_implemented(mode):
+                if self._is_mode_implemented(PredictionMode.DEFAULT):
+                    logger.info(f"Mode '{mode.value}' not implemented, falling back to default")
+                    mode = PredictionMode.DEFAULT
+                else:
+                    raise NotImplementedError
+            logger.debug(f"Routing to '{mode.value}' mode for {len(model_input)} resources")
+            result = self._route_prediction(model_input, mode, params)
+
+            # Apply common post-processing
+            result = self._post_process(result, model_input, params)
+
+            return result
+
+        except NotImplementedError:
+            available = self.get_supported_modes()
+            raise NotImplementedError(
+                f"Prediction mode '{mode.value}' is not supported by this model.\n"
+                f"Supported modes: {', '.join(available)}\n"
+                f"Implement predict_{mode.value}() to add support for this mode."
+            )
+
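
To make the routing contract concrete: a mode the subclass has not overridden falls back to `predict_default`, and `predict()` only raises once no fallback exists either. A sketch, again with the hypothetical `EchoModel`:

```python
# EchoModel does not override predict_slice, so this logs
# "Mode 'slice' not implemented, falling back to default"
# and dispatches to predict_default; the unused 'slice_index'
# is forwarded to it via **kwargs.
volumes = []  # stand-in for a list[Resource] of 3D volumes
preds = EchoModel().predict(volumes, params={'mode': 'slice', 'slice_index': 50})

# A subclass that overrides nothing at all would instead raise:
# NotImplementedError: Prediction mode 'slice' is not supported by this model. ...
```
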
+    def _parse_mode(self,
+                    params: dict[str, Any],
+                    model_input: list[Resource] | None = None) -> PredictionMode:
+        """Parse and validate the prediction mode from params.
+
+        When no explicit mode is given and every input is a 2D image,
+        the mode is upgraded from DEFAULT to IMAGE.
+        """
+        mode_str = params.get('mode', PredictionMode.DEFAULT.value)
+        try:
+            is_all_image = all(res.mimetype.startswith('image/') for res in model_input) if model_input else False
+        except Exception:
+            is_all_image = False
+
+        logger.debug(f"Parsing prediction mode: '{mode_str}' | {is_all_image=}")
+
+        if mode_str == PredictionMode.DEFAULT.value and is_all_image:
+            mode_str = PredictionMode.IMAGE.value
+
+        try:
+            return PredictionMode(mode_str)
+        except ValueError:
+            valid_modes = [m.value for m in PredictionMode]
+            raise ValueError(
+                f"Invalid prediction mode: '{mode_str}'\n"
+                f"Valid modes: {', '.join(valid_modes)}"
+            )
+
+    def _route_prediction(self,
+                          model_input: list[Resource],
+                          mode: PredictionMode,
+                          params: dict[str, Any]) -> PredictionResult:
+        """Route to the appropriate prediction method based on mode."""
+
+        # Extract mode-specific parameters and remove them from the kwargs
+        mode_params, common_params = self._extract_mode_params(mode, params)
+
+        # Dispatch to the appropriate method
+        method = self._get_method_for_mode(mode)
+
+        # if method is None or not self._is_mode_implemented(mode, method):
+        #     raise NotImplementedError
+
+        # Call with explicit parameters
+        return method(model_input, **mode_params, **common_params)
+
+    def _extract_mode_params(self, mode: PredictionMode, params: dict[str, Any]) -> tuple[dict, dict]:
+        """
+        Extract mode-specific and common parameters.
+
+        Returns:
+            Tuple of (mode_specific_params, common_params)
+        """
+        # Define mode-specific parameter mappings
+        mode_param_keys = {
+            PredictionMode.FRAME: ['frame_index'],
+            PredictionMode.FRAME_RANGE: ['start_frame', 'end_frame', 'step'],
+            PredictionMode.FRAME_INTERVAL: ['interval', 'start_frame', 'end_frame'],
+            PredictionMode.SLICE: ['slice_index', 'axis'],
+            PredictionMode.SLICE_RANGE: ['start_index', 'end_index', 'axis', 'step'],
+            PredictionMode.PRIMARY_SLICE: ['axis'],
+            PredictionMode.INTERACTIVE: ['prompt'],
+            PredictionMode.FEW_SHOT: ['context_resources', 'k'],
+            PredictionMode.TEMPORAL_SEQUENCE: ['center_frame', 'window_size'],
+            PredictionMode.IMAGE: [],
+        }
+
+        reserved_keys = {'mode', 'confidence_threshold'}
+
+        # Extract parameters
+        mode_specific = {}
+        common = {}
+
+        mode_keys = set(mode_param_keys.get(mode, ()))
+
+        for key, value in params.items():
+            if key in reserved_keys:
+                continue  # Skip the mode itself and post-processing-only params
+            if key in mode_keys:
+                mode_specific[key] = value
+            else:
+                common[key] = value
+
+        return mode_specific, common
+
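
A worked example of the split for `mode='slice'`; reserved keys are consumed by `predict()`/`_post_process` and never reach the handler:

```python
params = {'mode': 'slice', 'slice_index': 50, 'axis': 'axial',
          'confidence_threshold': 0.5, 'batch_size': 8}

# _extract_mode_params(PredictionMode.SLICE, params) yields:
#   mode_specific -> {'slice_index': 50, 'axis': 'axial'}  (keys declared for SLICE)
#   common        -> {'batch_size': 8}                     (forwarded as **kwargs)
# 'mode' and 'confidence_threshold' are reserved and dropped here, so the
# handler is called as predict_slice(model_input, slice_index=50,
# axis='axial', batch_size=8).
```
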
+    def _get_method_for_mode(self, mode: PredictionMode):
+        """Get the method corresponding to the given prediction mode."""
+        method_name = f"predict_{mode.value}"
+        method = getattr(self, method_name, None)
+        return method
+
+    def get_supported_modes(self) -> list[str]:
+        """
+        Get the list of prediction modes supported by this model.
+
+        Returns:
+            List of mode names (strings)
+        """
+        if self._supported_modes_cache is not None:
+            return self._supported_modes_cache
+
+        supported = []
+        for mode in PredictionMode:
+            if self._is_mode_implemented(mode):
+                supported.append(mode.value)
+
+        self._supported_modes_cache = supported
+        return supported
+
+    def _is_mode_implemented(self, mode: PredictionMode) -> bool:
+        """Determine whether the given mode has a concrete implementation."""
+        method = self._get_method_for_mode(mode)
+        if method is None:
+            return False
+
+        # Check whether the method comes from the DatamintModel base class (not overridden)
+        if hasattr(DatamintModel, method.__name__):
+            base_method = getattr(DatamintModel, method.__name__)
+            # The mode is implemented only if the method is not the base class method
+            return method.__func__ is not base_method
+
+        return True
+
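
Override detection compares the bound method's `__func__` with the base-class attribute, so only genuinely overridden `predict_*` methods are reported; the base-class fallbacks do not count. For a hypothetical subclass of the earlier `EchoModel`:

```python
class SliceModel(EchoModel):
    """Hypothetical: adds a concrete slice handler."""
    def predict_slice(self, model_input, slice_index, axis='axial', **kwargs):
        return [[] for _ in model_input]

print(SliceModel().get_supported_modes())  # ['default', 'slice']
```
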
+    def _post_process(self,
+                      predictions: PredictionResult,
+                      resources: list[Resource],
+                      params: dict[str, Any]) -> PredictionResult:
+        """Apply common post-processing based on params."""
+
+        # Apply confidence threshold filtering
+        conf_threshold = params.get('confidence_threshold')
+        if conf_threshold is not None:
+            # Annotations without a confidence attribute pass the filter (default 1.0)
+            predictions = [
+                [ann for ann in pred_list
+                 if getattr(ann, 'confiability', 1.0) >= conf_threshold]
+                for pred_list in predictions
+            ]
+            logger.debug(f"Applied confidence threshold: {conf_threshold}")
+
+        return predictions
+
+    def predict_default(self,
+                        model_input: list[Resource],
+                        **kwargs) -> PredictionResult:
+        """
+        Default prediction on entire resources.
+
+        This is the default mode and serves as the fallback for unimplemented modes.
+        Override this method to implement default prediction behavior.
+        If called without being overridden, it raises NotImplementedError.
+
+        Args:
+            model_input: Resources to process
+            **kwargs: Additional user-defined parameters
+
+        Returns:
+            List of annotation lists, one per resource
+
+        Example:
+            ```python
+            def predict_default(self, model_input, **kwargs):
+                dataset = MyDataset(model_input)
+                dataloader = DataLoader(dataset)
+                model = self.get_mlflow_models()['model'].get_raw_model()
+
+                predictions = []
+                for batch in dataloader:
+                    outputs = model(batch)
+                    predictions.extend(self._outputs_to_annotations(outputs))
+
+                return predictions
+            ```
+        """
+        raise NotImplementedError(
+            "predict_default() must be implemented in your DatamintModel subclass. "
+            "This is the default fallback mode for prediction."
+        )
+
+    # ========================================================================
+    # VIDEO/TEMPORAL MODES
+    # ========================================================================
+
+    def predict_frame(self,
+                      model_input: list[Resource],
+                      frame_index: int,
+                      **kwargs) -> PredictionResult:
+        """
+        Process a specific frame from video resources.
+
+        Args:
+            model_input: Video resources
+            frame_index: Index of the frame to extract and process (0-based)
+
+        Returns:
+            Annotations for the specified frame (one list per resource)
+
+        Example:
+            ```python
+            # Extract frame 42 from multiple videos
+            predictions = model.predict(
+                videos,
+                params={'mode': 'frame', 'frame_index': 42}
+            )
+            ```
+        """
+        logger.warning("predict_frame not implemented, falling back to predict_default")
+        return self.predict_default(model_input, **kwargs)
+
+    def predict_frame_range(self,
+                            model_input: list[Resource],
+                            start_frame: int,
+                            end_frame: int,
+                            step: int = 1,
+                            **kwargs) -> PredictionResult:
+        """
+        Process a range of frames from video resources.
+
+        Args:
+            model_input: Video resources
+            start_frame: Start frame index (inclusive)
+            end_frame: End frame index (inclusive)
+            step: Step size between frames (default: 1)
+
+        Returns:
+            Annotations for frames in range (may be frame-scoped annotations)
+
+        Example:
+            ```python
+            # Process frames 0-100, every 10th frame
+            predictions = model.predict(
+                videos,
+                params={'mode': 'frame_range',
+                        'start_frame': 0,
+                        'end_frame': 100,
+                        'step': 10}
+            )
+            ```
+        """
+        logger.warning("predict_frame_range not implemented, falling back to predict_default")
+        return self.predict_default(model_input, **kwargs)
+
+    def predict_frame_interval(self,
+                               model_input: list[Resource],
+                               interval: int,
+                               start_frame: int = 0,
+                               end_frame: int | None = None,
+                               **kwargs) -> PredictionResult:
+        """
+        Process every nth frame from video resources.
+
+        Args:
+            model_input: Video resources
+            interval: Process every nth frame
+            start_frame: Starting frame index
+            end_frame: Ending frame index (None = last frame)
+
+        Returns:
+            Annotations for the sampled frames
+
+        Example:
+            ```python
+            # Process every 30th frame (1 fps for a 30fps video)
+            predictions = model.predict(
+                videos,
+                params={'mode': 'frame_interval', 'interval': 30}
+            )
+            ```
+        """
+        logger.warning("predict_frame_interval not implemented, falling back to predict_default")
+        return self.predict_default(model_input, **kwargs)
+
+    def predict_all_frames(self,
+                           model_input: list[Resource],
+                           **kwargs) -> PredictionResult:
+        """
+        Process all frames independently.
+
+        Args:
+            model_input: Video resources
+
+        Returns:
+            Annotations for all frames (likely frame-scoped)
+
+        Example:
+            ```python
+            # Analyze every frame
+            predictions = model.predict(
+                videos,
+                params={'mode': 'all_frames'}
+            )
+            ```
+        """
+        logger.warning("predict_all_frames not implemented, falling back to predict_default")
+        return self.predict_default(model_input, **kwargs)
+
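
A sketch of what overriding one of these temporal hooks can look like; `extract_frame` and `outputs_to_annotations` are hypothetical helpers standing in for the project's actual frame extraction and Annotation construction:

```python
class FrameModel(EchoModel):
    """Hypothetical subclass with a real per-frame handler."""
    def predict_frame(self, model_input, frame_index, **kwargs):
        detector = self.get_mlflow_models()['detector']  # assumes a linked 'detector' URI
        results = []
        for resource in model_input:
            frame = extract_frame(resource, frame_index)     # hypothetical helper
            outputs = detector.predict(frame)
            results.append(outputs_to_annotations(outputs))  # hypothetical helper
        return results
```
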
+    # ========================================================================
+    # 3D VOLUME MODES
+    # ========================================================================
+
+    def predict_slice(self,
+                      model_input: list[Resource],
+                      slice_index: int,
+                      axis: str = 'axial',
+                      **kwargs) -> PredictionResult:
+        """
+        Process a specific slice from a 3D volume.
+
+        Args:
+            model_input: 3D volume resources (DICOM series, NIfTI, etc.)
+            slice_index: Index of the slice to extract
+            axis: Anatomical axis ('axial', 'sagittal', 'coronal')
+
+        Returns:
+            Annotations for the specified slice
+
+        Example:
+            ```python
+            # Extract and analyze axial slice 50
+            predictions = model.predict(
+                ct_scans,
+                params={'mode': 'slice',
+                        'slice_index': 50,
+                        'axis': 'axial'}
+            )
+            ```
+        """
+        logger.warning("predict_slice not implemented, falling back to predict_default")
+        return self.predict_default(model_input, **kwargs)
+
+    def predict_slice_range(self,
+                            model_input: list[Resource],
+                            start_index: int,
+                            end_index: int,
+                            axis: str = 'axial',
+                            step: int = 1,
+                            **kwargs) -> PredictionResult:
+        """
+        Process a range of slices from a 3D volume.
+
+        Args:
+            model_input: 3D volume resources
+            start_index: Start slice index (inclusive)
+            end_index: End slice index (inclusive)
+            axis: Anatomical axis
+            step: Step size between slices
+
+        Returns:
+            Annotations for slices in range
+        """
+        logger.warning("predict_slice_range not implemented, falling back to predict_default")
+        return self.predict_default(model_input, **kwargs)
+
+    def predict_volume(self,
+                       model_input: list[Resource],
+                       **kwargs) -> PredictionResult:
+        """
+        Process an entire 3D volume.
+
+        For true 3D models (not slice-by-slice).
+
+        Args:
+            model_input: 3D volume resources
+
+        Returns:
+            3D annotations for the entire volume
+        """
+        logger.warning("predict_volume not implemented, falling back to predict_default")
+        return self.predict_default(model_input, **kwargs)
+
+    # ========================================================================
+    # ADVANCED MODES
+    # ========================================================================
+
+    def predict_interactive(self,
+                            model_input: list[Resource],
+                            prompt: dict[str, Any],
+                            **kwargs) -> PredictionResult:
+        """
+        Interactive prediction with user prompts.
+
+        For models like Segment Anything (SAM) that accept user guidance.
+
+        Args:
+            model_input: Resources to process
+            prompt: Prompt dictionary with keys:
+                - 'points': list of [x, y] coordinates
+                - 'labels': list of labels (1=foreground, 0=background)
+                - 'boxes': list of [x1, y1, x2, y2] bounding boxes
+                - 'masks': list of binary mask arrays
+
+        Returns:
+            Annotations based on the prompts
+
+        Example:
+            ```python
+            # Segment based on positive and negative points
+            predictions = model.predict(
+                images,
+                params={'mode': 'interactive',
+                        'prompt': {
+                            'points': [[100, 150], [200, 250]],
+                            'labels': [1, 0]  # foreground, background
+                        }}
+            )
+            ```
+        """
+        logger.warning("predict_interactive not implemented, falling back to predict_default")
+        return self.predict_default(model_input, **kwargs)
+
+    def predict_few_shot(self,
+                         model_input: list[Resource],
+                         context_resources: list[Resource],
+                         k: int = 5,
+                         **kwargs) -> PredictionResult:
+        """
+        Few-shot prediction with context examples.
+
+        For models that can adapt based on a few labeled examples.
+
+        Args:
+            model_input: Resources to annotate
+            context_resources: Resources with existing annotations to use as examples
+            k: Number of examples to use (if more are provided)
+
+        Returns:
+            Annotations informed by the context examples
+
+        Example:
+            ```python
+            # Predict using similar annotated examples
+            predictions = model.predict(
+                new_images,
+                params={'mode': 'few_shot',
+                        'context_resources': annotated_examples,
+                        'k': 3}
+            )
+            ```
+        """
+        logger.warning("predict_few_shot not implemented, falling back to predict_default")
+        return self.predict_default(model_input, **kwargs)
+
+    def predict_image(self,
+                      model_input: list[Resource],
+                      **kwargs) -> PredictionResult:
+        """
+        Process single 2D image resources.
+
+        Args:
+            model_input: 2D image resources
+
+        Returns:
+            Annotations for each image
+        """
+        logger.warning("predict_image not implemented, falling back to predict_default")
+        return self.predict_default(model_input, **kwargs)
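
Putting the pieces together, a minimal end-to-end sketch using plain `mlflow.pyfunc` calls (the `datamint_flavor` module added elsewhere in this release presumably wraps these steps, but its API is not shown in this hunk; `SliceModel` is the hypothetical subclass above, and `resources` stands in for a `list[Resource]`):

```python
import mlflow
import mlflow.pyfunc

with mlflow.start_run():
    info = mlflow.pyfunc.log_model(
        artifact_path='adapter',
        python_model=SliceModel(settings={'need_gpu': False}),
    )

loaded = mlflow.pyfunc.load_model(info.model_uri)
predictions = loaded.predict(resources, params={'mode': 'slice', 'slice_index': 50})
```
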