datamint 2.3.1__py3-none-any.whl → 2.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datamint/api/base_api.py +66 -8
- datamint/api/client.py +16 -5
- datamint/api/dto/__init__.py +10 -2
- datamint/api/endpoints/__init__.py +2 -0
- datamint/api/endpoints/annotations_api.py +47 -7
- datamint/api/endpoints/annotationsets_api.py +11 -0
- datamint/api/endpoints/projects_api.py +36 -34
- datamint/api/endpoints/resources_api.py +75 -28
- datamint/api/entity_base_api.py +11 -43
- datamint/apihandler/dto/annotation_dto.py +6 -2
- datamint/configs.py +6 -0
- datamint/dataset/base_dataset.py +18 -12
- datamint/dataset/dataset.py +2 -2
- datamint/entities/__init__.py +4 -2
- datamint/entities/annotation.py +74 -4
- datamint/entities/base_entity.py +47 -6
- datamint/entities/cache_manager.py +302 -0
- datamint/entities/datasetinfo.py +108 -1
- datamint/entities/project.py +47 -6
- datamint/entities/resource.py +146 -19
- datamint/types.py +17 -0
- {datamint-2.3.1.dist-info → datamint-2.3.3.dist-info}/METADATA +2 -1
- {datamint-2.3.1.dist-info → datamint-2.3.3.dist-info}/RECORD +25 -22
- {datamint-2.3.1.dist-info → datamint-2.3.3.dist-info}/WHEEL +0 -0
- {datamint-2.3.1.dist-info → datamint-2.3.3.dist-info}/entry_points.txt +0 -0
datamint/api/entity_base_api.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
from typing import Any, TypeVar, Generic, Type, Sequence
|
|
2
2
|
import logging
|
|
3
3
|
import httpx
|
|
4
|
-
from dataclasses import dataclass
|
|
5
4
|
from datamint.entities.base_entity import BaseEntity
|
|
6
5
|
from datamint.exceptions import DatamintException, ResourceNotFoundError
|
|
7
6
|
import aiohttp
|
|
@@ -37,9 +36,14 @@ class EntityBaseApi(BaseApi, Generic[T]):
|
|
|
37
36
|
client: Optional HTTP client instance. If None, a new one will be created.
|
|
38
37
|
"""
|
|
39
38
|
super().__init__(config, client)
|
|
40
|
-
self.
|
|
39
|
+
self.__entity_class = entity_class
|
|
41
40
|
self.endpoint_base = endpoint_base.strip('/')
|
|
42
41
|
|
|
42
|
+
def _init_entity_obj(self, **kwargs) -> T:
|
|
43
|
+
obj = self.__entity_class(**kwargs)
|
|
44
|
+
obj._api = self
|
|
45
|
+
return obj
|
|
46
|
+
|
|
43
47
|
@staticmethod
|
|
44
48
|
def _entid(entity: BaseEntity | str) -> str:
|
|
45
49
|
return entity if isinstance(entity, str) else entity.id
|
|
@@ -117,7 +121,7 @@ class EntityBaseApi(BaseApi, Generic[T]):
|
|
|
117
121
|
for resp, items in items_gen:
|
|
118
122
|
all_items.extend(items)
|
|
119
123
|
|
|
120
|
-
return [self.
|
|
124
|
+
return [self._init_entity_obj(**item) for item in all_items]
|
|
121
125
|
|
|
122
126
|
def get_all(self, limit: int | None = None) -> Sequence[T]:
|
|
123
127
|
"""Get all entities with optional pagination and filtering.
|
|
@@ -143,7 +147,7 @@ class EntityBaseApi(BaseApi, Generic[T]):
|
|
|
143
147
|
httpx.HTTPStatusError: If the entity is not found or request fails.
|
|
144
148
|
"""
|
|
145
149
|
response = self._make_entity_request('GET', entity_id)
|
|
146
|
-
return self.
|
|
150
|
+
return self._init_entity_obj(**response.json())
|
|
147
151
|
|
|
148
152
|
async def _create_async(self, entity_data: dict[str, Any]) -> str | Sequence[str | dict]:
|
|
149
153
|
"""Create a new entity.
|
|
@@ -177,42 +181,6 @@ class EntityBaseApi(BaseApi, Generic[T]):
|
|
|
177
181
|
add_path=child_entity_name)
|
|
178
182
|
return response
|
|
179
183
|
|
|
180
|
-
# def bulk_create(self, entities_data: list[dict[str, Any]]) -> list[T]:
|
|
181
|
-
# """Create multiple entities in a single request.
|
|
182
|
-
|
|
183
|
-
# Args:
|
|
184
|
-
# entities_data: List of dictionaries containing entity data
|
|
185
|
-
|
|
186
|
-
# Returns:
|
|
187
|
-
# List of created entity instances
|
|
188
|
-
|
|
189
|
-
# Raises:
|
|
190
|
-
# httpx.HTTPStatusError: If bulk creation fails
|
|
191
|
-
# """
|
|
192
|
-
# payload = {'items': entities_data} # Common bulk API format
|
|
193
|
-
# response = self._make_request('POST', f'/{self.endpoint_base}/bulk', json=payload)
|
|
194
|
-
# data = response.json()
|
|
195
|
-
|
|
196
|
-
# # Handle response format - may be direct list or wrapped
|
|
197
|
-
# items = data if isinstance(data, list) else data.get('items', [])
|
|
198
|
-
# return [self.entity_class(**item) for item in items]
|
|
199
|
-
|
|
200
|
-
# def count(self, **params: Any) -> int:
|
|
201
|
-
# """Get the total count of entities matching the given filters.
|
|
202
|
-
|
|
203
|
-
# Args:
|
|
204
|
-
# **params: Query parameters for filtering
|
|
205
|
-
|
|
206
|
-
# Returns:
|
|
207
|
-
# Total count of matching entities
|
|
208
|
-
|
|
209
|
-
# Raises:
|
|
210
|
-
# httpx.HTTPStatusError: If the request fails
|
|
211
|
-
# """
|
|
212
|
-
# response = self._make_request('GET', f'/{self.endpoint_base}/count', params=params)
|
|
213
|
-
# data = response.json()
|
|
214
|
-
# return data.get('count', 0) if isinstance(data, dict) else data
|
|
215
|
-
|
|
216
184
|
|
|
217
185
|
class DeletableEntityApi(EntityBaseApi[T]):
|
|
218
186
|
"""Extension of EntityBaseApi for entities that support soft deletion.
|
|
@@ -221,7 +189,7 @@ class DeletableEntityApi(EntityBaseApi[T]):
|
|
|
221
189
|
retrieval and restoration of such entities.
|
|
222
190
|
"""
|
|
223
191
|
|
|
224
|
-
def delete(self, entity: str |
|
|
192
|
+
def delete(self, entity: str | T) -> None:
|
|
225
193
|
"""Delete an entity.
|
|
226
194
|
|
|
227
195
|
Args:
|
|
@@ -232,7 +200,7 @@ class DeletableEntityApi(EntityBaseApi[T]):
|
|
|
232
200
|
"""
|
|
233
201
|
self._make_entity_request('DELETE', entity)
|
|
234
202
|
|
|
235
|
-
def bulk_delete(self, entities: Sequence[str |
|
|
203
|
+
def bulk_delete(self, entities: Sequence[str | T]) -> None:
|
|
236
204
|
"""Delete multiple entities.
|
|
237
205
|
|
|
238
206
|
Args:
|
|
@@ -264,7 +232,7 @@ class DeletableEntityApi(EntityBaseApi[T]):
|
|
|
264
232
|
httpx.HTTPStatusError: If deletion fails or entity not found
|
|
265
233
|
"""
|
|
266
234
|
async with self._make_entity_request_async('DELETE', entity,
|
|
267
|
-
|
|
235
|
+
session=session) as resp:
|
|
268
236
|
await resp.text() # Consume response to complete request
|
|
269
237
|
|
|
270
238
|
# def get_deleted(self, **kwargs) -> Sequence[T]:
|
|
@@ -17,7 +17,11 @@ Classes:
|
|
|
17
17
|
import json
|
|
18
18
|
from typing import Any, TypeAlias, Literal
|
|
19
19
|
import logging
|
|
20
|
-
|
|
20
|
+
import sys
|
|
21
|
+
if sys.version_info >= (3, 11):
|
|
22
|
+
from enum import StrEnum
|
|
23
|
+
else:
|
|
24
|
+
from backports.strenum import StrEnum
|
|
21
25
|
from medimgkit.dicom_utils import pixel_to_patient
|
|
22
26
|
import pydicom
|
|
23
27
|
import numpy as np
|
|
@@ -31,7 +35,7 @@ CoordinateSystem: TypeAlias = Literal['pixel', 'patient']
|
|
|
31
35
|
"""
|
|
32
36
|
|
|
33
37
|
|
|
34
|
-
class AnnotationType(
|
|
38
|
+
class AnnotationType(StrEnum):
|
|
35
39
|
SEGMENTATION = 'segmentation'
|
|
36
40
|
AREA = 'area'
|
|
37
41
|
DISTANCE = 'distance'
|
datamint/configs.py
CHANGED
|
@@ -18,6 +18,12 @@ _LOGGER = logging.getLogger(__name__)
|
|
|
18
18
|
|
|
19
19
|
DIRS = PlatformDirs(appname='datamintapi')
|
|
20
20
|
CONFIG_FILE = os.path.join(DIRS.user_config_dir, 'datamintapi.yaml')
|
|
21
|
+
try:
|
|
22
|
+
DATAMINT_DATA_DIR = os.path.join(os.path.expanduser("~"), '.datamint')
|
|
23
|
+
except Exception as e:
|
|
24
|
+
_LOGGER.error(f"Could not determine home directory: {e}")
|
|
25
|
+
DATAMINT_DATA_DIR = None
|
|
26
|
+
|
|
21
27
|
|
|
22
28
|
|
|
23
29
|
def get_env_var_name(key: str) -> str:
|
datamint/dataset/base_dataset.py
CHANGED
|
@@ -11,7 +11,6 @@ from torch.utils.data import DataLoader
|
|
|
11
11
|
import torch
|
|
12
12
|
from torch import Tensor
|
|
13
13
|
from datamint.exceptions import DatamintException
|
|
14
|
-
from medimgkit.dicom_utils import is_dicom
|
|
15
14
|
from medimgkit.readers import read_array_normalized
|
|
16
15
|
from medimgkit.format_detection import guess_extension, guess_typez
|
|
17
16
|
from medimgkit.nifti_utils import NIFTI_MIMES, get_nifti_shape
|
|
@@ -20,6 +19,7 @@ from pathlib import Path
|
|
|
20
19
|
from datamint.entities import Annotation, DatasetInfo
|
|
21
20
|
import cv2
|
|
22
21
|
from datamint.entities import Resource
|
|
22
|
+
import datamint.configs
|
|
23
23
|
|
|
24
24
|
_LOGGER = logging.getLogger(__name__)
|
|
25
25
|
|
|
@@ -55,7 +55,7 @@ class DatamintBaseDataset:
|
|
|
55
55
|
exclude_frame_label_names: List of frame label names to exclude. If None, no frame labels will be excluded.
|
|
56
56
|
"""
|
|
57
57
|
|
|
58
|
-
|
|
58
|
+
|
|
59
59
|
DATAMINT_DATASETS_DIR = "datasets"
|
|
60
60
|
|
|
61
61
|
def __init__(
|
|
@@ -184,8 +184,7 @@ class DatamintBaseDataset:
|
|
|
184
184
|
"""Setup root and dataset directories."""
|
|
185
185
|
if root is None:
|
|
186
186
|
root = os.path.join(
|
|
187
|
-
|
|
188
|
-
self.DATAMINT_DEFAULT_DIR,
|
|
187
|
+
datamint.configs.DATAMINT_DATA_DIR,
|
|
189
188
|
self.DATAMINT_DATASETS_DIR
|
|
190
189
|
)
|
|
191
190
|
os.makedirs(root, exist_ok=True)
|
|
@@ -306,6 +305,15 @@ class DatamintBaseDataset:
|
|
|
306
305
|
"""Setup label sets and mappings."""
|
|
307
306
|
self.frame_lsets, self.frame_lcodes = self._get_labels_set(framed=True)
|
|
308
307
|
self.image_lsets, self.image_lcodes = self._get_labels_set(framed=False)
|
|
308
|
+
worklist_id = self.get_info()['worklist_id']
|
|
309
|
+
groups: dict[str, dict] = self.api.annotationsets.get_segmentation_group(worklist_id)['groups']
|
|
310
|
+
# order by 'index' key
|
|
311
|
+
max_index = max([g['index'] for g in groups.values()])
|
|
312
|
+
self.seglabel_list : list[str] = ['UNKNOWN'] * max_index # 1-based
|
|
313
|
+
for segname, g in groups.items():
|
|
314
|
+
self.seglabel_list[g['index'] - 1] = segname
|
|
315
|
+
|
|
316
|
+
self.seglabel2code = {label: idx + 1 for idx, label in enumerate(self.seglabel_list)}
|
|
309
317
|
|
|
310
318
|
def _filter_items(self, images_metainfo: list[dict]) -> list[dict]:
|
|
311
319
|
"""Filter items that have annotations."""
|
|
@@ -357,9 +365,7 @@ class DatamintBaseDataset:
|
|
|
357
365
|
@property
|
|
358
366
|
def segmentation_labels_set(self) -> list[str]:
|
|
359
367
|
"""Returns the set of segmentation labels in the dataset."""
|
|
360
|
-
|
|
361
|
-
b = set(self.image_lsets['segmentation'])
|
|
362
|
-
return list(a.union(b))
|
|
368
|
+
return self.seglabel_list
|
|
363
369
|
|
|
364
370
|
def _get_annotations_internal(
|
|
365
371
|
self,
|
|
@@ -463,8 +469,8 @@ class DatamintBaseDataset:
|
|
|
463
469
|
label_anns = self.get_annotations(i, type='label', scope=scope)
|
|
464
470
|
multilabel_set.update(ann.name for ann in label_anns)
|
|
465
471
|
|
|
466
|
-
seg_anns = self.get_annotations(i, type='segmentation', scope=scope)
|
|
467
|
-
segmentation_labels.update(ann.name for ann in seg_anns)
|
|
472
|
+
# seg_anns = self.get_annotations(i, type='segmentation', scope=scope)
|
|
473
|
+
# segmentation_labels.update(ann.name for ann in seg_anns)
|
|
468
474
|
|
|
469
475
|
cat_anns = self.get_annotations(i, type='category', scope=scope)
|
|
470
476
|
multiclass_set.update((ann.name, ann.value) for ann in cat_anns)
|
|
@@ -472,17 +478,17 @@ class DatamintBaseDataset:
|
|
|
472
478
|
# Sort and create mappings
|
|
473
479
|
multilabel_list = sorted(multilabel_set)
|
|
474
480
|
multiclass_list = sorted(multiclass_set)
|
|
475
|
-
segmentation_list = sorted(segmentation_labels)
|
|
481
|
+
# segmentation_list = sorted(segmentation_labels)
|
|
476
482
|
|
|
477
483
|
sets = {
|
|
478
484
|
'multilabel': multilabel_list,
|
|
479
|
-
'segmentation': segmentation_list,
|
|
485
|
+
# 'segmentation': segmentation_list,
|
|
480
486
|
'multiclass': multiclass_list
|
|
481
487
|
}
|
|
482
488
|
|
|
483
489
|
codes_map = {
|
|
484
490
|
'multilabel': {label: idx for idx, label in enumerate(multilabel_list)},
|
|
485
|
-
'segmentation': {label: idx + 1 for idx, label in enumerate(segmentation_list)},
|
|
491
|
+
# 'segmentation': {label: idx + 1 for idx, label in enumerate(segmentation_list)},
|
|
486
492
|
'multiclass': {label: idx for idx, label in enumerate(multiclass_list)}
|
|
487
493
|
}
|
|
488
494
|
|
datamint/dataset/dataset.py
CHANGED
|
@@ -191,7 +191,7 @@ class DatamintDataset(DatamintBaseDataset):
|
|
|
191
191
|
author_labels = seg_labels[author]
|
|
192
192
|
|
|
193
193
|
if frame_index is not None and ann.scope == 'frame':
|
|
194
|
-
seg_code = self.
|
|
194
|
+
seg_code = self.seglabel2code[ann.name]
|
|
195
195
|
if author_segs[frame_index] is None:
|
|
196
196
|
author_segs[frame_index] = []
|
|
197
197
|
author_labels[frame_index] = []
|
|
@@ -199,7 +199,7 @@ class DatamintDataset(DatamintBaseDataset):
|
|
|
199
199
|
author_segs[frame_index].append(s)
|
|
200
200
|
author_labels[frame_index].append(seg_code)
|
|
201
201
|
elif frame_index is None and ann.scope == 'image':
|
|
202
|
-
seg_code = self.
|
|
202
|
+
seg_code = self.seglabel2code[ann.name]
|
|
203
203
|
# apply to all frames
|
|
204
204
|
for i in range(nframes):
|
|
205
205
|
if author_segs[i] is None:
|
datamint/entities/__init__.py
CHANGED
|
@@ -7,14 +7,16 @@ from .project import Project
|
|
|
7
7
|
from .resource import Resource
|
|
8
8
|
from .user import User # new export
|
|
9
9
|
from .datasetinfo import DatasetInfo
|
|
10
|
+
from .cache_manager import CacheManager
|
|
10
11
|
|
|
11
12
|
__all__ = [
|
|
12
13
|
'Annotation',
|
|
13
14
|
'BaseEntity',
|
|
15
|
+
'CacheManager',
|
|
14
16
|
'Channel',
|
|
15
17
|
'ChannelResourceData',
|
|
18
|
+
'DatasetInfo',
|
|
16
19
|
'Project',
|
|
17
20
|
'Resource',
|
|
18
|
-
|
|
19
|
-
'DatasetInfo',
|
|
21
|
+
'User',
|
|
20
22
|
]
|
datamint/entities/annotation.py
CHANGED
|
@@ -5,11 +5,20 @@ This module defines the Annotation model used to represent annotation
|
|
|
5
5
|
records returned by the DataMint API.
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
|
-
from typing import Any
|
|
8
|
+
from typing import TYPE_CHECKING, Any
|
|
9
9
|
import logging
|
|
10
|
+
import os
|
|
11
|
+
|
|
10
12
|
from .base_entity import BaseEntity, MISSING_FIELD
|
|
11
|
-
from
|
|
13
|
+
from .cache_manager import CacheManager
|
|
14
|
+
from pydantic import PrivateAttr
|
|
12
15
|
from datetime import datetime
|
|
16
|
+
from datamint.api.dto import AnnotationType
|
|
17
|
+
from datamint.types import ImagingData
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from datamint.api.endpoints.annotations_api import AnnotationsApi
|
|
21
|
+
from .resource import Resource
|
|
13
22
|
|
|
14
23
|
logger = logging.getLogger(__name__)
|
|
15
24
|
|
|
@@ -21,6 +30,8 @@ _FIELD_MAPPING = {
|
|
|
21
30
|
'index': 'frame_index',
|
|
22
31
|
}
|
|
23
32
|
|
|
33
|
+
_ANNOTATION_CACHE_KEY = "annotation_data"
|
|
34
|
+
|
|
24
35
|
|
|
25
36
|
class Annotation(BaseEntity):
|
|
26
37
|
"""Pydantic Model representing a DataMint annotation.
|
|
@@ -60,7 +71,7 @@ class Annotation(BaseEntity):
|
|
|
60
71
|
identifier: str
|
|
61
72
|
scope: str
|
|
62
73
|
frame_index: int | None
|
|
63
|
-
annotation_type:
|
|
74
|
+
annotation_type: AnnotationType
|
|
64
75
|
text_value: str | None
|
|
65
76
|
numeric_value: float | int | None
|
|
66
77
|
units: str | None
|
|
@@ -83,7 +94,66 @@ class Annotation(BaseEntity):
|
|
|
83
94
|
annotation_worklist_name: str | None
|
|
84
95
|
user_info: dict | None
|
|
85
96
|
values: list | None = MISSING_FIELD
|
|
86
|
-
file: str | None = None
|
|
97
|
+
file: str | None = None
|
|
98
|
+
|
|
99
|
+
_api: 'AnnotationsApi' = PrivateAttr()
|
|
100
|
+
|
|
101
|
+
def __init__(self, **data):
|
|
102
|
+
"""Initialize the annotation entity."""
|
|
103
|
+
super().__init__(**data)
|
|
104
|
+
self._cache: CacheManager = CacheManager('annotations')
|
|
105
|
+
self._resource: 'Resource | None' = None
|
|
106
|
+
|
|
107
|
+
@property
|
|
108
|
+
def resource(self) -> 'Resource':
|
|
109
|
+
"""Lazily load and cache the associated Resource entity."""
|
|
110
|
+
if self._resource is None:
|
|
111
|
+
self._resource = self._api._get_resource(self)
|
|
112
|
+
return self._resource
|
|
113
|
+
|
|
114
|
+
def fetch_file_data(
|
|
115
|
+
self,
|
|
116
|
+
save_path: os.PathLike | str | None = None,
|
|
117
|
+
auto_convert: bool = True,
|
|
118
|
+
use_cache: bool = False,
|
|
119
|
+
) -> bytes | ImagingData:
|
|
120
|
+
# Version info for cache validation
|
|
121
|
+
version_info = self._generate_version_info()
|
|
122
|
+
|
|
123
|
+
# Try to get from cache
|
|
124
|
+
img_data = None
|
|
125
|
+
if use_cache:
|
|
126
|
+
img_data = self._cache.get(self.id, _ANNOTATION_CACHE_KEY, version_info)
|
|
127
|
+
|
|
128
|
+
if img_data is None:
|
|
129
|
+
# Fetch from server using download_resource_file
|
|
130
|
+
logger.debug(f"Fetching image data from server for resource {self.id}")
|
|
131
|
+
img_data = self._api.download_file(
|
|
132
|
+
self,
|
|
133
|
+
fpath_out=save_path
|
|
134
|
+
)
|
|
135
|
+
# Cache the data
|
|
136
|
+
if use_cache:
|
|
137
|
+
self._cache.set(self.id, _ANNOTATION_CACHE_KEY, img_data, version_info)
|
|
138
|
+
|
|
139
|
+
if auto_convert:
|
|
140
|
+
return self._api.convert_format(img_data)
|
|
141
|
+
|
|
142
|
+
return img_data
|
|
143
|
+
|
|
144
|
+
def _generate_version_info(self) -> dict:
|
|
145
|
+
"""Helper to generate version info for caching."""
|
|
146
|
+
return {
|
|
147
|
+
'created_at': self.created_at,
|
|
148
|
+
'deleted_at': self.deleted_at,
|
|
149
|
+
'associated_file': self.associated_file,
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
def invalidate_cache(self) -> None:
|
|
153
|
+
"""Invalidate all cached data for this annotation."""
|
|
154
|
+
self._cache.invalidate(self.id)
|
|
155
|
+
self._resource = None
|
|
156
|
+
logger.debug(f"Invalidated cache for annotation {self.id}")
|
|
87
157
|
|
|
88
158
|
@classmethod
|
|
89
159
|
def from_dict(cls, data: dict[str, Any]) -> 'Annotation':
|
datamint/entities/base_entity.py
CHANGED
|
@@ -1,7 +1,11 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import sys
|
|
3
|
-
from typing import Any
|
|
4
|
-
from pydantic import ConfigDict, BaseModel
|
|
3
|
+
from typing import Any, TYPE_CHECKING
|
|
4
|
+
from pydantic import ConfigDict, BaseModel, PrivateAttr
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
from datamint.api.client import Api
|
|
8
|
+
from datamint.api.entity_base_api import EntityBaseApi
|
|
5
9
|
|
|
6
10
|
if sys.version_info >= (3, 11):
|
|
7
11
|
from typing import Self
|
|
@@ -22,9 +26,14 @@ class BaseEntity(BaseModel):
|
|
|
22
26
|
This class provides common functionality for all entities, such as
|
|
23
27
|
serialization and deserialization from dictionaries, as well as
|
|
24
28
|
handling unknown fields gracefully.
|
|
29
|
+
|
|
30
|
+
The API client is automatically injected by the Api class when entities
|
|
31
|
+
are created through API endpoints.
|
|
25
32
|
"""
|
|
26
33
|
|
|
27
|
-
model_config = ConfigDict(extra='allow') # Allow extra fields
|
|
34
|
+
model_config = ConfigDict(extra='allow', arbitrary_types_allowed=True) # Allow extra fields and arbitrary types
|
|
35
|
+
|
|
36
|
+
_api: 'EntityBaseApi[Self] | EntityBaseApi' = PrivateAttr()
|
|
28
37
|
|
|
29
38
|
def asdict(self) -> dict[str, Any]:
|
|
30
39
|
"""Convert the entity to a dictionary, including unknown fields."""
|
|
@@ -38,14 +47,46 @@ class BaseEntity(BaseModel):
|
|
|
38
47
|
"""Handle unknown fields by logging a warning once per class/field combination in debug mode."""
|
|
39
48
|
if self.__pydantic_extra__ and _LOGGER.isEnabledFor(logging.DEBUG):
|
|
40
49
|
class_name = self.__class__.__name__
|
|
41
|
-
|
|
50
|
+
|
|
42
51
|
have_to_log = False
|
|
43
52
|
for key in self.__pydantic_extra__.keys():
|
|
44
53
|
warning_key = (class_name, key)
|
|
45
|
-
|
|
54
|
+
|
|
46
55
|
if warning_key not in _LOGGED_WARNINGS:
|
|
47
56
|
_LOGGED_WARNINGS.add(warning_key)
|
|
48
57
|
have_to_log = True
|
|
49
|
-
|
|
58
|
+
|
|
50
59
|
if have_to_log:
|
|
51
60
|
_LOGGER.warning(f"Unknown fields {list(self.__pydantic_extra__.keys())} found in {class_name}")
|
|
61
|
+
|
|
62
|
+
@staticmethod
|
|
63
|
+
def is_attr_missing(value: Any) -> bool:
|
|
64
|
+
"""Check if a value is the MISSING_FIELD sentinel."""
|
|
65
|
+
return value == MISSING_FIELD
|
|
66
|
+
|
|
67
|
+
def _refresh(self) -> Self:
|
|
68
|
+
"""Refresh the entity data from the server.
|
|
69
|
+
|
|
70
|
+
This method fetches the latest data from the server and updates
|
|
71
|
+
the current instance with any missing or updated fields.
|
|
72
|
+
|
|
73
|
+
Returns:
|
|
74
|
+
The updated Entity instance (self)
|
|
75
|
+
"""
|
|
76
|
+
updated_ent = self._api.get_by_id(self._api._entid(self))
|
|
77
|
+
|
|
78
|
+
# Update all fields from the fresh data
|
|
79
|
+
for field_name, field_value in updated_ent.model_dump().items():
|
|
80
|
+
if field_value != MISSING_FIELD:
|
|
81
|
+
setattr(self, field_name, field_value)
|
|
82
|
+
|
|
83
|
+
return self
|
|
84
|
+
|
|
85
|
+
def _ensure_attr(self, attr_name: str) -> None:
|
|
86
|
+
"""Ensure that a given attribute is not MISSING_FIELD, refreshing if necessary.
|
|
87
|
+
|
|
88
|
+
Args:
|
|
89
|
+
attr_name: Name of the attribute to check and ensure
|
|
90
|
+
"""
|
|
91
|
+
if self.is_attr_missing(getattr(self, attr_name)):
|
|
92
|
+
self._refresh()
|