dtlpy 1.115.44__py3-none-any.whl → 1.117.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. dtlpy/__init__.py +491 -491
  2. dtlpy/__version__.py +1 -1
  3. dtlpy/assets/__init__.py +26 -26
  4. dtlpy/assets/code_server/config.yaml +2 -2
  5. dtlpy/assets/code_server/installation.sh +24 -24
  6. dtlpy/assets/code_server/launch.json +13 -13
  7. dtlpy/assets/code_server/settings.json +2 -2
  8. dtlpy/assets/main.py +53 -53
  9. dtlpy/assets/main_partial.py +18 -18
  10. dtlpy/assets/mock.json +11 -11
  11. dtlpy/assets/model_adapter.py +83 -83
  12. dtlpy/assets/package.json +61 -61
  13. dtlpy/assets/package_catalog.json +29 -29
  14. dtlpy/assets/package_gitignore +307 -307
  15. dtlpy/assets/service_runners/__init__.py +33 -33
  16. dtlpy/assets/service_runners/converter.py +96 -96
  17. dtlpy/assets/service_runners/multi_method.py +49 -49
  18. dtlpy/assets/service_runners/multi_method_annotation.py +54 -54
  19. dtlpy/assets/service_runners/multi_method_dataset.py +55 -55
  20. dtlpy/assets/service_runners/multi_method_item.py +52 -52
  21. dtlpy/assets/service_runners/multi_method_json.py +52 -52
  22. dtlpy/assets/service_runners/single_method.py +37 -37
  23. dtlpy/assets/service_runners/single_method_annotation.py +43 -43
  24. dtlpy/assets/service_runners/single_method_dataset.py +43 -43
  25. dtlpy/assets/service_runners/single_method_item.py +41 -41
  26. dtlpy/assets/service_runners/single_method_json.py +42 -42
  27. dtlpy/assets/service_runners/single_method_multi_input.py +45 -45
  28. dtlpy/assets/voc_annotation_template.xml +23 -23
  29. dtlpy/caches/base_cache.py +32 -32
  30. dtlpy/caches/cache.py +473 -473
  31. dtlpy/caches/dl_cache.py +201 -201
  32. dtlpy/caches/filesystem_cache.py +89 -89
  33. dtlpy/caches/redis_cache.py +84 -84
  34. dtlpy/dlp/__init__.py +20 -20
  35. dtlpy/dlp/cli_utilities.py +367 -367
  36. dtlpy/dlp/command_executor.py +764 -764
  37. dtlpy/dlp/dlp +1 -1
  38. dtlpy/dlp/dlp.bat +1 -1
  39. dtlpy/dlp/dlp.py +128 -128
  40. dtlpy/dlp/parser.py +651 -651
  41. dtlpy/entities/__init__.py +83 -83
  42. dtlpy/entities/analytic.py +347 -347
  43. dtlpy/entities/annotation.py +1879 -1879
  44. dtlpy/entities/annotation_collection.py +699 -699
  45. dtlpy/entities/annotation_definitions/__init__.py +20 -20
  46. dtlpy/entities/annotation_definitions/base_annotation_definition.py +100 -100
  47. dtlpy/entities/annotation_definitions/box.py +195 -195
  48. dtlpy/entities/annotation_definitions/classification.py +67 -67
  49. dtlpy/entities/annotation_definitions/comparison.py +72 -72
  50. dtlpy/entities/annotation_definitions/cube.py +204 -204
  51. dtlpy/entities/annotation_definitions/cube_3d.py +149 -149
  52. dtlpy/entities/annotation_definitions/description.py +32 -32
  53. dtlpy/entities/annotation_definitions/ellipse.py +124 -124
  54. dtlpy/entities/annotation_definitions/free_text.py +62 -62
  55. dtlpy/entities/annotation_definitions/gis.py +69 -69
  56. dtlpy/entities/annotation_definitions/note.py +139 -139
  57. dtlpy/entities/annotation_definitions/point.py +117 -117
  58. dtlpy/entities/annotation_definitions/polygon.py +182 -182
  59. dtlpy/entities/annotation_definitions/polyline.py +111 -111
  60. dtlpy/entities/annotation_definitions/pose.py +92 -92
  61. dtlpy/entities/annotation_definitions/ref_image.py +86 -86
  62. dtlpy/entities/annotation_definitions/segmentation.py +240 -240
  63. dtlpy/entities/annotation_definitions/subtitle.py +34 -34
  64. dtlpy/entities/annotation_definitions/text.py +85 -85
  65. dtlpy/entities/annotation_definitions/undefined_annotation.py +74 -74
  66. dtlpy/entities/app.py +220 -220
  67. dtlpy/entities/app_module.py +107 -107
  68. dtlpy/entities/artifact.py +174 -174
  69. dtlpy/entities/assignment.py +399 -399
  70. dtlpy/entities/base_entity.py +214 -214
  71. dtlpy/entities/bot.py +113 -113
  72. dtlpy/entities/codebase.py +292 -292
  73. dtlpy/entities/collection.py +38 -38
  74. dtlpy/entities/command.py +169 -169
  75. dtlpy/entities/compute.py +449 -449
  76. dtlpy/entities/dataset.py +1299 -1299
  77. dtlpy/entities/directory_tree.py +44 -44
  78. dtlpy/entities/dpk.py +470 -470
  79. dtlpy/entities/driver.py +235 -235
  80. dtlpy/entities/execution.py +397 -397
  81. dtlpy/entities/feature.py +124 -124
  82. dtlpy/entities/feature_set.py +152 -145
  83. dtlpy/entities/filters.py +798 -798
  84. dtlpy/entities/gis_item.py +107 -107
  85. dtlpy/entities/integration.py +184 -184
  86. dtlpy/entities/item.py +975 -959
  87. dtlpy/entities/label.py +123 -123
  88. dtlpy/entities/links.py +85 -85
  89. dtlpy/entities/message.py +175 -175
  90. dtlpy/entities/model.py +684 -684
  91. dtlpy/entities/node.py +1005 -1005
  92. dtlpy/entities/ontology.py +810 -803
  93. dtlpy/entities/organization.py +287 -287
  94. dtlpy/entities/package.py +657 -657
  95. dtlpy/entities/package_defaults.py +5 -5
  96. dtlpy/entities/package_function.py +185 -185
  97. dtlpy/entities/package_module.py +113 -113
  98. dtlpy/entities/package_slot.py +118 -118
  99. dtlpy/entities/paged_entities.py +299 -299
  100. dtlpy/entities/pipeline.py +624 -624
  101. dtlpy/entities/pipeline_execution.py +279 -279
  102. dtlpy/entities/project.py +394 -394
  103. dtlpy/entities/prompt_item.py +505 -505
  104. dtlpy/entities/recipe.py +301 -301
  105. dtlpy/entities/reflect_dict.py +102 -102
  106. dtlpy/entities/resource_execution.py +138 -138
  107. dtlpy/entities/service.py +974 -963
  108. dtlpy/entities/service_driver.py +117 -117
  109. dtlpy/entities/setting.py +294 -294
  110. dtlpy/entities/task.py +495 -495
  111. dtlpy/entities/time_series.py +143 -143
  112. dtlpy/entities/trigger.py +426 -426
  113. dtlpy/entities/user.py +118 -118
  114. dtlpy/entities/webhook.py +124 -124
  115. dtlpy/examples/__init__.py +19 -19
  116. dtlpy/examples/add_labels.py +135 -135
  117. dtlpy/examples/add_metadata_to_item.py +21 -21
  118. dtlpy/examples/annotate_items_using_model.py +65 -65
  119. dtlpy/examples/annotate_video_using_model_and_tracker.py +75 -75
  120. dtlpy/examples/annotations_convert_to_voc.py +9 -9
  121. dtlpy/examples/annotations_convert_to_yolo.py +9 -9
  122. dtlpy/examples/convert_annotation_types.py +51 -51
  123. dtlpy/examples/converter.py +143 -143
  124. dtlpy/examples/copy_annotations.py +22 -22
  125. dtlpy/examples/copy_folder.py +31 -31
  126. dtlpy/examples/create_annotations.py +51 -51
  127. dtlpy/examples/create_video_annotations.py +83 -83
  128. dtlpy/examples/delete_annotations.py +26 -26
  129. dtlpy/examples/filters.py +113 -113
  130. dtlpy/examples/move_item.py +23 -23
  131. dtlpy/examples/play_video_annotation.py +13 -13
  132. dtlpy/examples/show_item_and_mask.py +53 -53
  133. dtlpy/examples/triggers.py +49 -49
  134. dtlpy/examples/upload_batch_of_items.py +20 -20
  135. dtlpy/examples/upload_items_and_custom_format_annotations.py +55 -55
  136. dtlpy/examples/upload_items_with_modalities.py +43 -43
  137. dtlpy/examples/upload_segmentation_annotations_from_mask_image.py +44 -44
  138. dtlpy/examples/upload_yolo_format_annotations.py +70 -70
  139. dtlpy/exceptions.py +125 -125
  140. dtlpy/miscellaneous/__init__.py +20 -20
  141. dtlpy/miscellaneous/dict_differ.py +95 -95
  142. dtlpy/miscellaneous/git_utils.py +217 -217
  143. dtlpy/miscellaneous/json_utils.py +14 -14
  144. dtlpy/miscellaneous/list_print.py +105 -105
  145. dtlpy/miscellaneous/zipping.py +130 -130
  146. dtlpy/ml/__init__.py +20 -20
  147. dtlpy/ml/base_feature_extractor_adapter.py +27 -27
  148. dtlpy/ml/base_model_adapter.py +1287 -1230
  149. dtlpy/ml/metrics.py +461 -461
  150. dtlpy/ml/predictions_utils.py +274 -274
  151. dtlpy/ml/summary_writer.py +57 -57
  152. dtlpy/ml/train_utils.py +60 -60
  153. dtlpy/new_instance.py +252 -252
  154. dtlpy/repositories/__init__.py +56 -56
  155. dtlpy/repositories/analytics.py +85 -85
  156. dtlpy/repositories/annotations.py +916 -916
  157. dtlpy/repositories/apps.py +383 -383
  158. dtlpy/repositories/artifacts.py +452 -452
  159. dtlpy/repositories/assignments.py +599 -599
  160. dtlpy/repositories/bots.py +213 -213
  161. dtlpy/repositories/codebases.py +559 -559
  162. dtlpy/repositories/collections.py +332 -332
  163. dtlpy/repositories/commands.py +152 -152
  164. dtlpy/repositories/compositions.py +61 -61
  165. dtlpy/repositories/computes.py +439 -439
  166. dtlpy/repositories/datasets.py +1585 -1504
  167. dtlpy/repositories/downloader.py +1157 -923
  168. dtlpy/repositories/dpks.py +433 -433
  169. dtlpy/repositories/drivers.py +482 -482
  170. dtlpy/repositories/executions.py +815 -815
  171. dtlpy/repositories/feature_sets.py +256 -226
  172. dtlpy/repositories/features.py +255 -255
  173. dtlpy/repositories/integrations.py +484 -484
  174. dtlpy/repositories/items.py +912 -912
  175. dtlpy/repositories/messages.py +94 -94
  176. dtlpy/repositories/models.py +1000 -1000
  177. dtlpy/repositories/nodes.py +80 -80
  178. dtlpy/repositories/ontologies.py +511 -511
  179. dtlpy/repositories/organizations.py +525 -525
  180. dtlpy/repositories/packages.py +1941 -1941
  181. dtlpy/repositories/pipeline_executions.py +451 -451
  182. dtlpy/repositories/pipelines.py +640 -640
  183. dtlpy/repositories/projects.py +539 -539
  184. dtlpy/repositories/recipes.py +429 -399
  185. dtlpy/repositories/resource_executions.py +137 -137
  186. dtlpy/repositories/schema.py +120 -120
  187. dtlpy/repositories/service_drivers.py +213 -213
  188. dtlpy/repositories/services.py +1704 -1704
  189. dtlpy/repositories/settings.py +339 -339
  190. dtlpy/repositories/tasks.py +1477 -1477
  191. dtlpy/repositories/times_series.py +278 -278
  192. dtlpy/repositories/triggers.py +536 -536
  193. dtlpy/repositories/upload_element.py +257 -257
  194. dtlpy/repositories/uploader.py +661 -661
  195. dtlpy/repositories/webhooks.py +249 -249
  196. dtlpy/services/__init__.py +22 -22
  197. dtlpy/services/aihttp_retry.py +131 -131
  198. dtlpy/services/api_client.py +1786 -1785
  199. dtlpy/services/api_reference.py +40 -40
  200. dtlpy/services/async_utils.py +133 -133
  201. dtlpy/services/calls_counter.py +44 -44
  202. dtlpy/services/check_sdk.py +68 -68
  203. dtlpy/services/cookie.py +115 -115
  204. dtlpy/services/create_logger.py +156 -156
  205. dtlpy/services/events.py +84 -84
  206. dtlpy/services/logins.py +235 -235
  207. dtlpy/services/reporter.py +256 -256
  208. dtlpy/services/service_defaults.py +91 -91
  209. dtlpy/utilities/__init__.py +20 -20
  210. dtlpy/utilities/annotations/__init__.py +16 -16
  211. dtlpy/utilities/annotations/annotation_converters.py +269 -269
  212. dtlpy/utilities/base_package_runner.py +285 -264
  213. dtlpy/utilities/converter.py +1650 -1650
  214. dtlpy/utilities/dataset_generators/__init__.py +1 -1
  215. dtlpy/utilities/dataset_generators/dataset_generator.py +670 -670
  216. dtlpy/utilities/dataset_generators/dataset_generator_tensorflow.py +23 -23
  217. dtlpy/utilities/dataset_generators/dataset_generator_torch.py +21 -21
  218. dtlpy/utilities/local_development/__init__.py +1 -1
  219. dtlpy/utilities/local_development/local_session.py +179 -179
  220. dtlpy/utilities/reports/__init__.py +2 -2
  221. dtlpy/utilities/reports/figures.py +343 -343
  222. dtlpy/utilities/reports/report.py +71 -71
  223. dtlpy/utilities/videos/__init__.py +17 -17
  224. dtlpy/utilities/videos/video_player.py +598 -598
  225. dtlpy/utilities/videos/videos.py +470 -470
  226. {dtlpy-1.115.44.data → dtlpy-1.117.6.data}/scripts/dlp +1 -1
  227. dtlpy-1.117.6.data/scripts/dlp.bat +2 -0
  228. {dtlpy-1.115.44.data → dtlpy-1.117.6.data}/scripts/dlp.py +128 -128
  229. {dtlpy-1.115.44.dist-info → dtlpy-1.117.6.dist-info}/METADATA +186 -186
  230. dtlpy-1.117.6.dist-info/RECORD +239 -0
  231. {dtlpy-1.115.44.dist-info → dtlpy-1.117.6.dist-info}/WHEEL +1 -1
  232. {dtlpy-1.115.44.dist-info → dtlpy-1.117.6.dist-info}/licenses/LICENSE +200 -200
  233. tests/features/environment.py +551 -551
  234. dtlpy/assets/__pycache__/__init__.cpython-310.pyc +0 -0
  235. dtlpy-1.115.44.data/scripts/dlp.bat +0 -2
  236. dtlpy-1.115.44.dist-info/RECORD +0 -240
  237. {dtlpy-1.115.44.dist-info → dtlpy-1.117.6.dist-info}/entry_points.txt +0 -0
  238. {dtlpy-1.115.44.dist-info → dtlpy-1.117.6.dist-info}/top_level.txt +0 -0
@@ -1,1230 +1,1287 @@
1
- import dataclasses
2
- import threading
3
- import tempfile
4
- import datetime
5
- import logging
6
- import string
7
- import shutil
8
- import random
9
- import base64
10
- import copy
11
- import time
12
- import tqdm
13
- import traceback
14
- import sys
15
- import io
16
- import os
17
- from itertools import chain
18
- from PIL import Image
19
- from functools import partial
20
- import numpy as np
21
- from concurrent.futures import ThreadPoolExecutor
22
- import attr
23
- from collections.abc import MutableMapping
24
- from typing import Optional
25
- from .. import entities, utilities, repositories, exceptions
26
- from ..services import service_defaults
27
- from ..services.api_client import ApiClient
28
-
29
- logger = logging.getLogger('ModelAdapter')
30
-
31
- # Constants
32
- PREDICT_EMBED_DEFAULT_SUBSET_LIMIT = 1000
33
- PREDICT_EMBED_DEFAULT_TIMEOUT = 3600 * 24
34
-
35
-
36
class ModelConfigurations(MutableMapping):
    """
    Dict-like view over a model's configuration that persists content changes.

    Uses composition (a backing dict) instead of inheriting from ``dict`` so there is a
    single source of truth: when the adapter has a model entity whose ``configuration``
    is not None, the backing dict *is* that configuration object (not a copy).
    Operations that change the content call ``model_entity.update(reload_services=False)``.
    """

    def __init__(self, base_model_adapter):
        self._base_model_adapter = base_model_adapter
        model_entity = self._loaded_model_entity()
        if model_entity is not None and model_entity.configuration is not None:
            # Share the model's configuration dict so edits are visible on the entity.
            self._backing_dict = model_entity.configuration
        else:
            self._backing_dict = {}
        # Written directly (no remote update) to avoid premature updates during init.
        self._backing_dict.setdefault('include_background', False)

    def _loaded_model_entity(self):
        """Return the adapter's model entity, or None when there is no adapter/model."""
        adapter = self._base_model_adapter
        if adapter is None:
            return None
        try:
            return adapter.model_entity
        except ValueError:
            # BaseModelAdapter.model_entity raises ValueError when no model is loaded.
            return None

    def _update_model_entity(self):
        """Push the configuration to the platform if a model entity is available."""
        model_entity = self._loaded_model_entity()
        if model_entity is not None:
            model_entity.update(reload_services=False)

    def __ior__(self, other):
        self.update(other)
        return self

    # Required MutableMapping abstract methods
    def __getitem__(self, key):
        return self._backing_dict[key]

    def __setitem__(self, key, value):
        # Only updates the backing dict (never object attributes) to avoid recursion
        # with subclasses that map attribute writes onto items.
        changed = key not in self._backing_dict or self._backing_dict[key] != value
        self._backing_dict[key] = value
        if changed:
            self._update_model_entity()

    def __delitem__(self, key):
        del self._backing_dict[key]
        # Removing a key is a content change, persisted like __setitem__.
        self._update_model_entity()

    def __iter__(self):
        return iter(self._backing_dict)

    def __len__(self):
        return len(self._backing_dict)

    def get(self, key, default=None):
        """
        Return ``self[key]``; a missing key is first inserted with ``default``.

        Note: unlike ``dict.get`` this records (and persists) the default.
        """
        if key not in self._backing_dict:
            self[key] = default
        return self._backing_dict.get(key)

    def update(self, *args, **kwargs):
        # Materialize once: args may be a one-shot iterator of (key, value) pairs.
        update_dict = dict(*args, **kwargs)
        has_changes = any(
            key not in self._backing_dict or self._backing_dict[key] != value
            for key, value in update_dict.items()
        )
        self._backing_dict.update(update_dict)
        if has_changes:
            self._update_model_entity()

    def setdefault(self, key, default=None):
        # Route new keys through __setitem__ so they are persisted like get() does.
        if key not in self._backing_dict:
            self[key] = default
        return self._backing_dict[key]


@dataclasses.dataclass
class AdapterDefaults(ModelConfigurations):
    """
    Adapter defaults exposed both as attributes and as configuration items.

    Every dataclass field below is kept in sync between ``self.<field>`` and
    ``self['<field>']``. Configured values that are not None win over the defaults.
    """

    # for predict items, dataset, evaluate
    upload_annotations: bool = dataclasses.field(default=True)
    clean_annotations: bool = dataclasses.field(default=True)
    # for embeddings
    upload_features: bool = dataclasses.field(default=True)
    # for training
    root_path: str = dataclasses.field(default=None)
    data_path: str = dataclasses.field(default=None)
    output_path: str = dataclasses.field(default=None)

    @staticmethod
    def _field_names():
        """Names of the dataclass fields mirrored as attributes."""
        return {f.name for f in dataclasses.fields(AdapterDefaults)}

    def __init__(self, base_model_adapter=None):
        super().__init__(base_model_adapter)
        changed = False
        for f in dataclasses.fields(AdapterDefaults):
            configured = self._backing_dict.get(f.name)
            # use the value from model_entity.configuration if set, else the field default
            value = f.default if configured is None else configured
            super().__setattr__(f.name, value)
            if f.name not in self._backing_dict or self._backing_dict[f.name] != value:
                self._backing_dict[f.name] = value
                changed = True
        # Persist the filled-in defaults with a single remote update.
        if changed:
            self._update_model_entity()

    def __setattr__(self, key, value):
        # Dataclass-like fields behave as attributes, so map public ones to the dict
        super().__setattr__(key, value)
        if not key.startswith("_"):
            super().__setitem__(key, value)

    def __setitem__(self, key, value):
        # Keep the attribute view in sync when a field is written as an item.
        if key in self._field_names():
            super().__setattr__(key, value)
        super().__setitem__(key, value)

    def update(self, *args, **kwargs):
        update_dict = dict(*args, **kwargs)
        field_names = self._field_names()
        for key, value in update_dict.items():
            if key in field_names:
                super().__setattr__(key, value)
        super().update(update_dict)

    def resolve(self, key, *args):
        """Return the first non-None arg (storing it under ``key``), else ``self.get(key)``."""
        for arg in args:
            if arg is not None:
                self[key] = arg
                return arg
        return self.get(key, None)
153
-
154
-
155
- class BaseModelAdapter(utilities.BaseServiceRunner):
156
- _client_api = attr.ib(type=ApiClient, repr=False)
157
- _feature_set_lock = threading.Lock()
158
-
159
def __init__(self, model_entity: entities.Model = None):
    """
    Create the adapter, optionally binding it to a model right away.

    :param model_entity: optional `dl.Model`; when given, `load_from_model` is called with it
    """
    self.logger = logger
    self._base_configuration = {}
    # entity / state slots start empty; load_from_model and the property setters fill them
    for attr_name in (
        '_model_entity',
        '_package',
        '_configuration',
        'adapter_defaults',
        'package_name',
        'model',
        'bucket_path',
    ):
        setattr(self, attr_name, None)
    # input_type -> loader used by prepare_item_func for single-type models
    self.item_to_batch_mapping = dict(text=self._item_to_text, image=self._item_to_image)
    if model_entity is not None:
        self.load_from_model(model_entity=model_entity)
    logger.warning(
        "in case of a mismatch between 'model.name' and 'model_info.name' in the model adapter, model_info.name will be updated to align with 'model.name'."
    )
177
-
178
- ##################
179
- # Configurations #
180
- ##################
181
-
182
@property
def configuration(self) -> dict:
    """
    The active configuration dict.

    Resolution order: the loaded model's configuration view (``self._configuration``),
    else the package's ``metadata.system.ml.defaultConfiguration``, else the locally
    stored ``self._base_configuration``.
    """
    if self._model_entity is not None:
        return self._configuration
    if self._package is not None:
        return self.package.metadata.get('system', {}).get('ml', {}).get('defaultConfiguration', {})
    return self._base_configuration


@configuration.setter
def configuration(self, configuration: dict):
    """
    Replace the configuration.

    With a model loaded, the model entity's configuration is replaced and the adapter
    defaults are rebuilt over it. Without a model, the dict is kept in
    ``self._base_configuration`` instead of being dropped. Note that the getter prefers
    the package default configuration whenever a package is set.
    """
    assert isinstance(configuration, dict)
    if self._model_entity is None:
        self._base_configuration = configuration
        return
    # Update configuration with received dict
    self._model_entity.configuration = configuration
    self.adapter_defaults = AdapterDefaults(self)
    self._configuration = self.adapter_defaults
202
-
203
- ############
204
- # Entities #
205
- ############
206
@property
def model_entity(self):
    """
    The `dl.Model` bound to this adapter.

    :raises ValueError: when no model has been loaded/set yet
    """
    entity = self._model_entity
    if entity is None:
        raise ValueError("No model entity loaded. Please load a model (adapter.load_from_model(<dl.Model>)) or set: 'adapter.model_entity=<dl.Model>'")
    assert isinstance(entity, entities.Model)
    return entity


@model_entity.setter
def model_entity(self, model_entity):
    """Bind a new model, refresh its package, and rebuild the adapter defaults over it."""
    assert isinstance(model_entity, entities.Model)
    previous = self._model_entity
    # isinstance() is False for None, so this also covers "no previous model"
    if isinstance(previous, entities.Model) and previous.id != model_entity.id:
        self.logger.warning('Replacing Model from {!r} to {!r}'.format(previous.name, model_entity.name))
    self._model_entity = model_entity
    self.package = model_entity.package
    defaults = AdapterDefaults(self)
    self.adapter_defaults = defaults
    self._configuration = defaults
223
-
224
@property
def package(self):
    """
    The `dl.Package` / `dl.Dpk` of the adapter.

    When a model is loaded, the cached package is refreshed from ``model_entity.package``
    on every access.

    :raises ValueError: when no package is available
    """
    if self._model_entity is not None:
        self.package = self.model_entity.package
    current = self._package
    if current is None:
        raise ValueError('Missing Package entity on adapter. Please set: "adapter.package=package"')
    assert isinstance(current, (entities.Package, entities.Dpk))
    return current


@package.setter
def package(self, package):
    """Store the package and remember its name in ``self.package_name``."""
    assert isinstance(package, (entities.Package, entities.Dpk))
    name = package.name
    self.package_name = name
    self._package = package
238
-
239
- ###################################
240
- # NEED TO IMPLEMENT THESE METHODS #
241
- ###################################
242
-
243
def load(self, local_path, **kwargs):
    """
    Load the model from ``local_path`` and populate ``self.model`` with a runnable model.

    Abstract: subclasses must override. Called by ``load_from_model`` after the model
    files are downloaded locally.

    :param local_path: `str` directory path in the local file system
    :raises NotImplementedError: always, in the base class
    """
    class_name = self.__class__.__name__
    raise NotImplementedError(f"Please implement `load` method in {class_name}")
254
-
255
def save(self, local_path, **kwargs):
    """
    Save configuration and weights to ``local_path``.

    Abstract: subclasses must override. Called by ``save_to_model``, which saves locally
    and then uploads to the model entity.

    :param local_path: `str` directory path in the local file system
    :raises NotImplementedError: always, in the base class
    """
    class_name = self.__class__.__name__
    raise NotImplementedError(f"Please implement `save` method in {class_name}")
266
-
267
def train(self, data_path, output_path, **kwargs):
    """
    Train on the data in ``data_path`` and write all outputs to ``output_path``.

    Abstract: subclasses must override. Outputs include weights and any other artifacts
    produced during training (checkpoints, logs, ...).

    :param data_path: `str` local path where the data was downloaded and converted
    :param output_path: `str` local path for training mid-results
    :raises NotImplementedError: always, in the base class
    """
    class_name = self.__class__.__name__
    raise NotImplementedError(f"Please implement `train` method in {class_name}")
278
-
279
def predict(self, batch, **kwargs):
    """
    Run inference on a batch of items.

    Abstract: subclasses must override.

    :param batch: a batch built from `prepare_item_func` outputs
    :return: `list[dl.AnnotationCollection]`, one collection per item in the batch
    :raises NotImplementedError: always, in the base class
    """
    class_name = self.__class__.__name__
    raise NotImplementedError(f"Please implement `predict` method in {class_name}")
289
-
290
def embed(self, batch, **kwargs):
    """
    Extract embeddings for a batch of items.

    Abstract: subclasses must override.

    :param batch: a batch built from `prepare_item_func` outputs
    :return: `list[list]`, one feature vector per item in the batch
    :raises NotImplementedError: always, in the base class
    """
    class_name = self.__class__.__name__
    raise NotImplementedError(f"Please implement `embed` method in {class_name}")
300
-
301
def evaluate(self, model: entities.Model, dataset: entities.Dataset, filters: entities.Filters) -> entities.Model:
    """
    Evaluate the model's predictions on a dataset that has GT annotations.

    Delegates to ``dtlpymetrics.scoring.create_model_score``, comparing annotations of
    ``model.output_type``; that call uploads the scores and metrics to the platform.

    :param model: the model to evaluate (its annotations carry ``metadata.system.model.name``)
    :param dataset: dataset where the model predicted and uploaded its annotations
    :param filters: `dl.Filters` or a raw filter dict selecting items; falsy means default filters
    :return: the model entity returned by ``create_model_score``
    """
    import dtlpymetrics

    compare_types = model.output_type
    if not filters:
        filters = entities.Filters()
    elif isinstance(filters, dict):
        filters = entities.Filters(custom_filter=filters)
    return dtlpymetrics.scoring.create_model_score(
        model=model,
        dataset=dataset,
        filters=filters,
        compare_types=compare_types,
    )
325
-
326
def convert_from_dtlpy(self, data_path, **kwargs):
    """
    Convert downloaded Dataloop-structured data into the model's own format.

    Abstract: subclasses must override (e.g. read the dlp directory structure and build
    an annotation file).

    :param data_path: `str` local directory where the platform data was already downloaded
    :raises NotImplementedError: always, in the base class
    """
    class_name = self.__class__.__name__
    raise NotImplementedError(f"Please implement `convert_from_dtlpy` method in {class_name}")
337
-
338
- #################
339
- # DTLPY METHODS #
340
- ################
341
def prepare_item_func(self, item: entities.Item):
    """
    Load a Dataloop item into the form passed to `predict`/`embed` in a batch.

    Override to load items differently. For a list ``input_type`` the loader is chosen
    from the item's mimetype ('text', then 'image'). For a single ``input_type`` it is
    looked up in ``self.item_to_batch_mapping``. Anything else falls back to
    ``self._item_to_item``.

    :param item: the item to load
    :return: the loaded item data
    """
    input_type = self.model_entity.input_type
    if not isinstance(input_type, list):
        if input_type in self.item_to_batch_mapping:
            return self.item_to_batch_mapping[input_type](item)
        return self._item_to_item(item)
    # multi-modal model: pick the loader that matches this item's mimetype
    if 'text' in input_type and 'text' in item.mimetype:
        return self._item_to_text(item)
    if 'image' in input_type and 'image' in item.mimetype:
        return self._item_to_image(item)
    return self._item_to_item(item)
366
-
367
def __include_model_annotations(self, annotation_filters):
    """
    Exclude model-generated annotations unless the configuration asks to include them.

    When the model configuration's ``include_model_annotations`` is missing or ``False``,
    a condition requiring ``metadata.system.model.name`` to NOT exist is added.

    :param annotation_filters: annotation `dl.Filters`; modified in place
    :return: the same `annotation_filters` object
    """
    include_model_annotations = self.model_entity.configuration.get("include_model_annotations", False)
    if include_model_annotations is not False:
        return annotation_filters
    if annotation_filters.custom_filter is None:
        annotation_filters.add(field="metadata.system.model.name", values=False, operator=entities.FiltersOperations.EXISTS)
    else:
        # Custom filters are raw query dicts; tolerate ones without a 'filter' / '$and' entry.
        query = annotation_filters.custom_filter.setdefault('filter', {})
        query.setdefault('$and', []).append({'metadata.system.model.name': {'$exists': False}})
    return annotation_filters
375
-
376
def __download_background_images(self, filters, data_subset_base_path, annotation_options):
    """
    Download the subset's un-annotated ("background") items when enabled.

    Runs only when the adapter configuration has ``include_background`` set to ``True``.
    The filters are restricted to ``annotated == False``. A custom-filter dict is copied
    before that condition is appended, so a dict shared with the caller (e.g. the model's
    ``metadata.system.subsets`` entry) is not mutated. The passed ``filters`` object
    itself is updated.

    :param filters: item `dl.Filters` for the subset
    :param data_subset_base_path: local directory to download into
    :param annotation_options: annotation export option forwarded to the download
    :return: the result of ``items.download``, or an empty list when disabled
    """
    if self.configuration.get('include_background', False) is not True:
        return list()
    if filters.custom_filter is None:
        filters.add(field="annotated", values=False)
    else:
        custom_filter = copy.deepcopy(filters.custom_filter)
        # tolerate custom filters without a 'filter' / '$and' entry
        custom_filter.setdefault("filter", {}).setdefault("$and", []).append({"annotated": False})
        filters.custom_filter = custom_filter
    return self.model_entity.dataset.items.download(
        filters=filters,
        local_path=data_subset_base_path,
        annotation_options=annotation_options,
    )
386
-
387
def prepare_data(
    self,
    dataset: entities.Dataset,
    # paths
    root_path=None,
    data_path=None,
    output_path=None,
    #
    overwrite=False,
    **kwargs,
):
    """
    Prepare the dataset locally before training or evaluation.

    Downloads every subset listed in the model's ``metadata.system.subsets`` into
    ``<data_path>/<subset>``, then calls ``self.convert_from_dtlpy`` on ``data_path``.

    :param dataset: dl.Dataset to download items from
    :param root_path: `str` root directory. Default is
        ``<DATALOOP_PATH>/model_data/<model id>_<model name>/<timestamp>``.
        Can be set using self.adapter_defaults.root_path
    :param data_path: `str` dataset directory. Default is ``<root_path>/datasets/<model dataset id>``.
        Can be set using self.adapter_defaults.data_path
    :param output_path: `str` output directory. Default is ``<root_path>/output``.
        Can be set using self.adapter_defaults.output_path
    :param bool overwrite: re-download subsets whose directory already exists. Default is False
    :return: tuple of (root_path, data_path, output_path)
    :raises ValueError: when the model has no subsets, or a subset downloads no items
    """
    # Explicit arguments win (and are recorded in adapter_defaults); otherwise stored values are used.
    root_path = self.adapter_defaults.resolve("root_path", root_path)
    data_path = self.adapter_defaults.resolve("data_path", data_path)
    output_path = self.adapter_defaults.resolve("output_path", output_path)
    if root_path is None:
        now = datetime.datetime.now()
        root_path = os.path.join(
            service_defaults.DATALOOP_PATH,
            'model_data',
            "{s_id}_{s_n}".format(s_id=self.model_entity.id, s_n=self.model_entity.name),
            now.strftime('%Y-%m-%d-%H%M%S'),
        )
    if data_path is None:
        data_path = os.path.join(root_path, 'datasets', self.model_entity.dataset.id)
    if output_path is None:
        output_path = os.path.join(root_path, 'output')
    # Create the directories even when caller-supplied; os.listdir below needs data_path to exist.
    os.makedirs(data_path, exist_ok=True)
    os.makedirs(output_path, exist_ok=True)

    if os.listdir(data_path):
        self.logger.warning("Data path directory ({}) is not empty..".format(data_path))

    annotation_options = entities.ViewAnnotationOptions.JSON
    if self.model_entity.output_type in [entities.AnnotationType.SEGMENTATION]:
        annotation_options = entities.ViewAnnotationOptions.INSTANCE

    # Download the subset items
    subsets = self.model_entity.metadata.get("system", {}).get("subsets", None)
    annotations_subsets = self.model_entity.metadata.get("system", {}).get("annotationsSubsets", {})
    if subsets is None:
        raise ValueError("Model (id: {}) must have subsets in metadata.system.subsets".format(self.model_entity.id))
    for subset, filters_dict in subsets.items():
        data_subset_base_path = os.path.join(data_path, subset)
        if os.path.isdir(data_subset_base_path) and not overwrite:
            # existing and dont overwrite
            self.logger.debug("Subset {!r} already exists (and overwrite=False). Skipping.".format(subset))
            continue

        filters = entities.Filters(custom_filter=filters_dict)
        self.logger.debug("Downloading subset {!r} of {}".format(subset, self.model_entity.dataset.name))

        annotation_filters = None
        if subset in annotations_subsets:
            annotation_filters = entities.Filters(
                use_defaults=False,
                resource=entities.FiltersResource.ANNOTATION,
                custom_filter=annotations_subsets[subset],
            )
        # if user provided annotation_filters, skip the default filters
        elif self.model_entity.output_type is not None and self.model_entity.output_type != "embedding":
            annotation_filters = entities.Filters(resource=entities.FiltersResource.ANNOTATION, use_defaults=False)
            if self.model_entity.output_type in [
                entities.AnnotationType.SEGMENTATION,
                entities.AnnotationType.POLYGON,
            ]:
                model_output_types = [entities.AnnotationType.SEGMENTATION, entities.AnnotationType.POLYGON]
            else:
                model_output_types = [self.model_entity.output_type]

            annotation_filters.add(
                field=entities.FiltersKnownFields.TYPE,
                values=model_output_types,
                operator=entities.FiltersOperations.IN,
            )

            annotation_filters = self.__include_model_annotations(annotation_filters)
            annotations_subsets[subset] = annotation_filters.prepare()

        ret_list = list(
            dataset.items.download(
                filters=filters,
                local_path=data_subset_base_path,
                annotation_options=annotation_options,
                annotation_filters=annotation_filters,
                overwrite=overwrite,
            )
        )
        # Separate Filters over a deep copy: the background helper appends an
        # "annotated: False" condition, which must not leak into the model metadata
        # subset dict or into the subset filters reported below.
        background_filters = entities.Filters(custom_filter=copy.deepcopy(filters_dict))
        background_ret_list = list(
            self.__download_background_images(
                filters=background_filters,
                data_subset_base_path=data_subset_base_path,
                annotation_options=annotation_options,
            )
        )
        self.logger.debug(f"Subset '{subset}': ret_list length: {len(ret_list)}, background_ret_list length: {len(background_ret_list)}")
        ret_list = ret_list + background_ret_list
        if not ret_list:
            annotation_filters_str = annotation_filters.prepare() if annotation_filters is not None else None
            raise ValueError(
                f"No items downloaded for subset {subset}! Cannot train model with empty subset.\n"
                f"Subset {subset} filters: {filters.prepare()}\n"
                f"Annotation filters: {annotation_filters_str}"
            )

    self.convert_from_dtlpy(data_path=data_path, **kwargs)
    return root_path, data_path, output_path
508
-
509
def load_from_model(self, model_entity=None, local_path=None, overwrite=True, **kwargs):
    """Load a model from a given `dl.Model`.

    Sets ``self.model_entity`` (when given), builds the adapter defaults/configuration
    from it, downloads the model artifacts and then calls ``self.load`` on the local directory.

    :param model_entity: `dl.Model` entity to load; if None, the current ``self.model_entity`` is used
    :param local_path: `str` directory path in local FileSystem to download the model artifacts to
        (default: ``<DATALOOP_PATH>/models/<model name>``)
    :param overwrite: `bool` (default True) if False does not download files with same name, else (True) download all
    :param kwargs: forwarded to ``self.load``
    :raises ValueError: if no model entity is given and none is set on the adapter
    """
    if model_entity is not None:
        self.model_entity = model_entity
    if getattr(self, 'model_entity', None) is None:
        # Fail early with a clear message instead of an AttributeError on `None.name`
        raise ValueError(
            'Missing model entity on the adapter. '
            'Please provide `model_entity` or set "adapter.model_entity=model"'
        )
    if local_path is None:
        local_path = os.path.join(service_defaults.DATALOOP_PATH, "models", self.model_entity.name)
    # Load configuration and adapter defaults
    self.adapter_defaults = AdapterDefaults(self)
    # Point _configuration to the same object since AdapterDefaults inherits from ModelConfigurations
    self._configuration = self.adapter_defaults
    # Download the artifacts, then let the concrete adapter load them
    self.model_entity.artifacts.download(local_path=local_path, overwrite=overwrite)
    self.load(local_path, **kwargs)
529
-
530
def save_to_model(self, local_path=None, cleanup=False, replace=True, **kwargs):
    """
    Save the model state (configuration and weights) to the model artifacts.

    Calls ``self.save`` to write the model files into ``local_path``, then uploads
    every file in that directory to ``self.model_entity`` artifacts, overwriting
    artifacts with the same name.

    :param local_path: `str` directory path in local FileSystem to save the current model bucket (weights) (default will create a temp dir)
    :param cleanup: `bool` (default False) if True remove ``local_path`` from the local FileSystem after upload
    :param replace: `bool` currently unused by this method - artifacts are always uploaded with ``overwrite=True``
    :param kwargs: forwarded to ``self.save``
    :raises ValueError: if ``self.model_entity`` is not set
    :return: None
    """
    # Validate before anything else: the temp-dir prefix and the upload both need the entity
    if self.model_entity is None:
        raise ValueError('Missing model entity on the adapter. ' 'Please set before saving: "adapter.model_entity=model"')

    if local_path is None:
        local_path = tempfile.mkdtemp(prefix="model_{}".format(self.model_entity.name))
        self.logger.debug("Using temporary dir at {}".format(local_path))

    self.save(local_path=local_path, **kwargs)

    self.model_entity.artifacts.upload(filepath=os.path.join(local_path, '*'), overwrite=True)
    if cleanup:
        shutil.rmtree(path=local_path, ignore_errors=True)
        self.logger.info("Clean-up. deleting {}".format(local_path))
557
-
558
- # ===============
559
- # SERVICE METHODS
560
- # ===============
561
-
562
@entities.Package.decorators.function(
    display_name='Predict Items',
    inputs={'items': 'Item[]'},
    outputs={'items': 'Item[]', 'annotations': 'Annotation[]'},
)
def predict_items(self, items: list, upload_annotations=None, clean_annotations=None, batch_size=None, **kwargs):
    """
    Run the predict function on the input list of items and return the items and the predictions.
    Each prediction is by the model output type (package.output_type) and model_info in the metadata

    :param items: `List[dl.Item]` list of items to predict
    :param upload_annotations: `bool` uploads the predictions on the given items (None resolves the adapter default)
    :param clean_annotations: `bool` deletes previous model annotations (predictions) before uploading new ones
        (None resolves the adapter default)
    :param batch_size: `int` size of batch to run a single inference (default: configuration 'batch_size' or 4)
    :param kwargs: forwarded to ``self.predict``

    :return: `List[dl.Item]`, `List[dl.Annotation]` - the input items and a flat list of all predicted annotations
    :raises Exception: if any batch failed to predict or upload (raised after all batches were processed)
    """
    if batch_size is None:
        batch_size = self.configuration.get('batch_size', 4)
    upload_annotations = self.adapter_defaults.resolve("upload_annotations", upload_annotations)
    clean_annotations = self.adapter_defaults.resolve("clean_annotations", clean_annotations)
    input_type = self.model_entity.input_type
    self.logger.debug("Predicting {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))
    error_counter = 0
    fail_ids = list()
    annotations = list()
    # The context manager shuts the worker threads down even when an exception
    # (e.g. from prepare_item_func) propagates out of the loop
    with ThreadPoolExecutor(max_workers=16) as pool:
        for i_batch in tqdm.tqdm(range(0, len(items), batch_size), desc='predicting', unit='bt', leave=None, file=sys.stdout):
            batch_items = items[i_batch : i_batch + batch_size]
            batch = list(pool.map(self.prepare_item_func, batch_items))
            try:
                batch_collections = self.predict(batch, **kwargs)
            except Exception as e:
                item_ids = [item.id for item in batch_items]
                self.logger.error(f"Failed to predict batch {i_batch} for items {item_ids}. Error: {e}\n{traceback.format_exc()}")
                error_counter += 1
                fail_ids.extend(item_ids)
                continue
            # Consuming the map iterator waits for every metadata update (and re-raises its
            # errors) before annotations are uploaded, avoiding races between the two steps
            list(pool.map(self._update_predictions_metadata, batch_items, batch_collections))
            if upload_annotations is True:
                self.logger.debug("Uploading items' annotation for model {!r}.".format(self.model_entity.name))
                try:
                    batch_collections = list(
                        pool.map(partial(self._upload_model_annotations, clean_annotations=clean_annotations), batch_items, batch_collections)
                    )
                except Exception as err:
                    item_ids = [item.id for item in batch_items]
                    self.logger.error(f"Failed to upload annotations for items {item_ids}. Error: {err}\n{traceback.format_exc()}")
                    error_counter += 1
                    fail_ids.extend(item_ids)

            for collection in batch_collections:
                # flatten every prediction into the single returned annotations list
                if isinstance(collection, entities.AnnotationCollection):
                    annotations.extend(collection.annotations)
                else:
                    logger.warning(f'RETURN TYPE MAY BE INVALID: {type(collection)}')
                    annotations.extend(collection)
            # TODO call the callback

    if error_counter > 0:
        raise Exception(f"Failed to predict all items. Failed IDs: {fail_ids}, See logs for more details")
    return items, annotations
629
-
630
@entities.Package.decorators.function(
    display_name='Embed Items',
    inputs={'items': 'Item[]'},
    outputs={'items': 'Item[]', 'features': 'Json[]'},
)
def embed_items(self, items: list, upload_features=None, batch_size=None, progress: utilities.Progress = None, **kwargs):
    """
    Extract features from an input list of items and return the items and the feature vectors.

    When uploading, vectors that are None or whose length differs from the configured
    ``embeddings_size`` (default 256) are skipped with a warning and removed, together
    with their items, from the returned lists.

    :param items: `List[dl.Item]` list of items to embed
    :param upload_features: `bool` uploads the features on the given items (None resolves the adapter default)
    :param batch_size: `int` size of batch to run a single embed (default: configuration 'batch_size' or 4)
    :param progress: `dl.Progress` currently not used by this method
    :param kwargs: forwarded to ``self.embed``

    :return: `List[dl.Item]`, `List[vector]` - items and their vectors, in matching order
    :raises Exception: if any batch failed to embed or the feature upload failed
    """
    if batch_size is None:
        batch_size = self.configuration.get('batch_size', 4)
    upload_features = self.adapter_defaults.resolve("upload_features", upload_features)
    input_type = self.model_entity.input_type
    self.logger.debug("Embedding {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))
    error_counter = 0
    fail_ids = list()

    feature_set = self._get_feature_set()

    vectors = list()
    _items = list()
    # The context manager shuts the worker threads down even when an exception
    # (e.g. from prepare_item_func) propagates out of the loop
    with ThreadPoolExecutor(max_workers=16) as pool:
        for i_batch in tqdm.tqdm(
            range(0, len(items), batch_size),
            desc='embedding',
            unit='bt',
            leave=None,
            file=sys.stdout,
        ):
            batch_items = items[i_batch : i_batch + batch_size]
            batch = list(pool.map(self.prepare_item_func, batch_items))
            try:
                batch_vectors = self.embed(batch, **kwargs)
            except Exception as err:
                item_ids = [item.id for item in batch_items]
                self.logger.error(f"Failed to embed batch {i_batch} for items {item_ids}. Error: {err}\n{traceback.format_exc()}")
                error_counter += 1
                fail_ids.extend(item_ids)
                continue
            vectors.extend(batch_vectors)
            # Save the items in the order of the vectors
            _items.extend(batch_items)

    if upload_features is True:
        _indicies_to_remove = list()
        embeddings_size = self.configuration.get('embeddings_size', 256)
        for idx, vector in enumerate(vectors):
            if vector is None or len(vector) != embeddings_size:
                self.logger.warning(f"Vector generated for item {_items[idx].id} is None or has wrong size. Skipping...")
                _indicies_to_remove.append(idx)

        # Remove indices in descending order so earlier removals don't shift later indices
        for index in sorted(_indicies_to_remove, reverse=True):
            _items.pop(index)
            vectors.pop(index)

        if len(_items) != len(vectors):
            raise ValueError(f"The number of items ({len(_items)}) is not equal to the number of vectors ({len(vectors)}).")
        self.logger.debug(f"Uploading {len(_items)} items' feature vectors for model {self.model_entity.name}.")
        try:
            start_time = time.time()
            feature_set.features.create(entity=_items, value=vectors, feature_set_id=feature_set.id, project_id=self.model_entity.project_id)
            self.logger.debug(f"Uploaded {len(_items)} items' feature vectors for model {self.model_entity.name} in {time.time() - start_time} seconds.")
        except Exception as err:
            self.logger.error(f"Failed to upload feature vectors. Error: {err}\n{traceback.format_exc()}")
            error_counter += 1
            # Report the items whose features were not stored, so the raised message is actionable
            fail_ids.extend(item.id for item in _items)
    if error_counter > 0:
        raise Exception(f"Failed to embed all items. Failed IDs: {fail_ids}, See logs for more details")
    return _items, vectors
708
-
709
@entities.Package.decorators.function(
    display_name='Embed Dataset with DQL',
    inputs={'dataset': 'Dataset', 'filters': 'Json'},
)
def embed_dataset(
    self,
    dataset: entities.Dataset,
    filters: Optional[entities.Filters] = None,
    upload_features: Optional[bool] = None,
    batch_size: Optional[int] = None,
    progress: Optional[utilities.Progress] = None,
    **kwargs,
):
    """
    Run model embedding on all items in a dataset that match the filters.

    :param dataset: Dataset entity to embed
    :param filters: Filters entity (or filters dict) for filtering before embedding
    :param upload_features: bool whether to upload features back to platform.
        NOTE(review): currently not forwarded to ``_execute_dataset_operation`` - confirm
        whether the adapter default is intended to apply instead
    :param batch_size: int size of batch to run a single embedding
    :param progress: dl.Progress object to track progress
    :param kwargs: accepted for compatibility; currently not used
    :return: None
    :raises ValueError: if the dataset operation reported a failure
    """
    status = self._execute_dataset_operation(
        dataset=dataset,
        operation_type='embed',
        filters=filters,
        progress=progress,
        batch_size=batch_size,
    )
    if status is False:
        raise ValueError("Failed to embed entire dataset, please check the logs for more details")
742
-
743
@entities.Package.decorators.function(
    display_name='Predict Dataset with DQL',
    inputs={'dataset': 'Dataset', 'filters': 'Json'},
)
def predict_dataset(
    self,
    dataset: entities.Dataset,
    filters: Optional[entities.Filters] = None,
    upload_annotations: Optional[bool] = None,
    clean_annotations: Optional[bool] = None,
    batch_size: Optional[int] = None,
    progress: Optional[utilities.Progress] = None,
    **kwargs,
):
    """
    Run model prediction on all items in a dataset that match the filters.

    :param dataset: Dataset entity to predict
    :param filters: Filters entity (or filters dict) for filtering before prediction
    :param upload_annotations: bool whether to upload annotations back to platform.
        NOTE(review): currently not forwarded to ``_execute_dataset_operation``
    :param clean_annotations: bool whether to clean existing annotations.
        NOTE(review): currently not forwarded to ``_execute_dataset_operation``
    :param batch_size: int size of batch to run a single prediction
    :param progress: dl.Progress object to track progress
    :param kwargs: accepted for compatibility; currently not used
    :return: None
    :raises ValueError: if the dataset operation reported a failure
    """
    status = self._execute_dataset_operation(
        dataset=dataset,
        operation_type='predict',
        filters=filters,
        progress=progress,
        batch_size=batch_size,
    )
    if status is False:
        raise ValueError("Failed to predict entire dataset, please check the logs for more details")
777
-
778
@entities.Package.decorators.function(
    display_name='Train a Model',
    inputs={'model': entities.Model},
    outputs={'model': entities.Model},
)
def train_model(self, model: entities.Model, cleanup=False, progress: utilities.Progress = None, context: utilities.Context = None):
    """
    Train an existing model.

    data will be taken from dl.Model.datasetId
    configuration is as defined in dl.Model.configuration
    the output is uploaded to the model's artifacts

    :param model: `dl.Model` to train (a dict with an 'id' key is also accepted and fetched)
    :param cleanup: `bool` if True remove the local output directory after a successful upload
    :param progress: `dl.Progress` used to report FaaS progress, optional
    :param context: `dl.Context` of the FaaS execution, optional
    :return: the trained `dl.Model`
    :raises Exception: re-raises any training error, after saving the partial output and,
        if training had started, setting ``model.status`` to 'failed'
    """
    if isinstance(model, dict):
        model = repositories.Models(client_api=self._client_api).get(model_id=model['id'])
    output_path = None
    training_started = False
    try:
        logger.info("Received {s} for training".format(s=model.id))
        model = model.wait_for_model_ready()
        if model.status == 'failed':
            raise ValueError("Model is in failed state, cannot train.")

        ##############
        # Set status #
        ##############
        model.status = 'training'
        if context is not None:
            if 'system' not in model.metadata:
                model.metadata['system'] = dict()
        model.update(reload_services=False)
        training_started = True

        ##########################
        # load model and weights #
        ##########################
        logger.info("Loading Adapter with: {n} ({i!r})".format(n=model.name, i=model.id))
        self.load_from_model(model_entity=model)

        ################
        # prepare data #
        ################
        root_path, data_path, output_path = self.prepare_data(dataset=self.model_entity.dataset, root_path=os.path.join('tmp', model.id))
        # Start the Train
        logger.info("Training {p_name!r} with model {m_name!r} on data {d_path!r}".format(p_name=self.package_name, m_name=model.id, d_path=data_path))
        if progress is not None:
            progress.update(message='starting training')

        def on_epoch_end_callback(i_epoch, n_epoch):
            if progress is not None:
                progress.update(progress=int(100 * (i_epoch + 1) / n_epoch), message='finished epoch: {}/{}'.format(i_epoch, n_epoch))

        self.train(data_path=data_path, output_path=output_path, on_epoch_end_callback=on_epoch_end_callback)
        if progress is not None:
            progress.update(message='saving model', progress=99)

        self.save_to_model(local_path=output_path, replace=True)
        model.status = 'trained'
        model.update(reload_services=False)
        ###########
        # cleanup #
        ###########
        if cleanup:
            shutil.rmtree(output_path, ignore_errors=True)
    except Exception:
        # save also on fail - best effort, must not mask the original error
        if output_path is not None:
            try:
                self.save_to_model(local_path=output_path, replace=True)
            except Exception:
                logger.exception('Failed to save model artifacts after a training failure')
        if training_started:
            # Otherwise the model would stay in 'training' forever
            logger.info('Execution failed. Setting model.status to failed')
            try:
                model.status = 'failed'
                model.update(reload_services=False)
            except Exception:
                logger.exception('Failed to set model.status to failed')
        else:
            logger.info('Execution failed before training started')
        raise
    return model
846
-
847
@entities.Package.decorators.function(
    display_name='Evaluate a Model',
    inputs={'model': entities.Model, 'dataset': entities.Dataset, 'filters': 'Json'},
    outputs={'model': entities.Model, 'dataset': entities.Dataset},
)
def evaluate_model(
    self,
    model: entities.Model,
    dataset: entities.Dataset,
    filters: entities.Filters = None,
    #
    progress: utilities.Progress = None,
    context: utilities.Context = None,
):
    """
    Evaluate a model.

    Loads the model, runs ``predict_dataset`` on the (filtered) dataset, then calls
    ``self.evaluate`` to calculate metrics vs GT.
    configuration is as defined in dl.Model.configuration

    :param model: Model entity to run prediction
    :param dataset: Dataset to evaluate
    :param filters: Filter for specific items from dataset (default: ``entities.Filters()``)
    :param progress: dl.Progress for report FaaS progress
    :param context: FaaS execution context; currently not used by this method
    :return: tuple of (the model returned by ``self.evaluate``, the dataset)
    """
    logger.info(f"Received model: {model.id} for evaluation on dataset (name: {dataset.name}, id: {dataset.id})")
    ##########################
    # load model and weights #
    ##########################
    logger.info(f"Loading Adapter with: {model.name} ({model.id!r})")
    self.load_from_model(dataset=dataset, model_entity=model)

    ##############
    # Predicting #
    ##############
    logger.info(f"Calling prediction, dataset: {dataset.name!r} ({dataset.id!r}), filters: {filters}")
    if not filters:
        filters = entities.Filters()
    # `upload_annotations` is predict_dataset's parameter; the old `with_upload` kwarg was silently ignored
    self.predict_dataset(dataset=dataset, filters=filters, upload_annotations=True)

    ##############
    # Evaluating #
    ##############
    logger.info("Starting adapter.evaluate()")
    if progress is not None:
        progress.update(message='calculating metrics', progress=98)
    model = self.evaluate(model=model, dataset=dataset, filters=filters)
    #########
    # Done! #
    #########
    if progress is not None:
        progress.update(message='finishing evaluation', progress=99)
    return model, dataset
902
-
903
- # =============
904
- # INNER METHODS
905
- # =============
906
def _get_feature_set(self):
    """
    Return the feature set attached to ``self.model_entity``, creating one if missing.

    Retrieval and creation run while holding the class-level ``_feature_set_lock``
    (intended to make this thread-safe across instances of the class). A new feature
    set is named after the model. If the project already has a feature set with that
    name, a random 5-character alphanumeric suffix is appended.

    :return: the existing or newly created feature set entity
    """
    with self.__class__._feature_set_lock:
        existing = self.model_entity.feature_set
        if existing is not None:
            logger.info(f'Feature Set found! name: {existing.name}, id: {existing.id}')
            return existing

        logger.info('Feature Set not found. creating... ')
        try:
            self.model_entity.project.feature_sets.get(feature_set_name=self.model_entity.name)
        except exceptions.NotFound:
            # the plain model name is free to use
            new_name = self.model_entity.name
        else:
            suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=5))
            new_name = f"{self.model_entity.name}-{suffix}"
            logger.warning(
                f"Feature set with the model name already exists. Creating new feature set with name {new_name}"
            )
        created = self.model_entity.project.feature_sets.create(
            name=new_name,
            entity_type=entities.FeatureEntityType.ITEM,
            model_id=self.model_entity.id,
            project_id=self.model_entity.project_id,
            set_type=self.model_entity.name,
            size=self.configuration.get('embeddings_size', 256),
        )
        logger.info(f'Feature Set created! name: {created.name}, id: {created.id}')
        return created
934
-
935
def _execute_dataset_operation(
    self,
    dataset: entities.Dataset,
    operation_type: str,
    filters: Optional[entities.Filters] = None,
    progress: Optional[utilities.Progress] = None,
    batch_size: Optional[int] = None,
) -> bool:
    """
    Execute a dataset operation (predict/embed) with batching and filtering support.

    The filtered items are split into subsets of at most ``predict_embed_subset_limit`` items
    (configuration, default PREDICT_EMBED_DEFAULT_SUBSET_LIMIT). A single subset is processed
    locally with ``predict_items``/``embed_items``. Multiple subsets are sent as remote executions
    (``models.predict``/``models.embed``), which are polled every 5 seconds until they finish or
    ``predict_embed_timeout`` seconds (configuration, default PREDICT_EMBED_DEFAULT_TIMEOUT) elapse.

    :param dataset: Dataset entity to run operation on
    :param operation_type: Type of operation to execute ('predict' or 'embed')
    :param filters: Filters entity (or a filters dict) to filter items, default None (uses ``entities.Filters()``)
    :param progress: Progress object for tracking progress, default None
    :param batch_size: Size of batches to process items, default None (uses model config)
    :return: True if the operation completed successfully (or there was nothing to run);
        False if any remote execution failed or did not finish before the timeout
    :raises ValueError: If operation_type is not 'predict' or 'embed'
    """
    # Validate up front so an unsupported type fails even when the filters match no items
    if operation_type not in ('predict', 'embed'):
        raise ValueError(f"Unsupported operation type: {operation_type}")
    self.logger.debug(f"Running {operation_type} for dataset (name:{dataset.name}, id:{dataset.id})")

    if not filters:
        self.logger.debug("No filters provided, using default filters")
        filters = entities.Filters()
    if isinstance(filters, dict):
        self.logger.debug(f"Received custom filters {filters}")
        filters = entities.Filters(custom_filter=filters)

    if operation_type == 'embed':
        feature_set = self._get_feature_set()
        logger.info(f"Feature set found! name: {feature_set.name}, id: {feature_set.id}")

    predict_embed_subset_limit = self.configuration.get('predict_embed_subset_limit', PREDICT_EMBED_DEFAULT_SUBSET_LIMIT)
    predict_embed_timeout = self.configuration.get('predict_embed_timeout', PREDICT_EMBED_DEFAULT_TIMEOUT)
    self.logger.debug(f"Inputs: predict_embed_subset_limit: {predict_embed_subset_limit}, predict_embed_timeout: {predict_embed_timeout}")
    tmp_filters = copy.deepcopy(filters.prepare())
    tmp_filters['pageSize'] = 0
    num_items = dataset.items.list(filters=entities.Filters(custom_filter=tmp_filters)).items_count
    self.logger.debug(f"Number of items for current filters: {num_items}")

    # One-item lookahead on generator: if only one subset, run locally; else create executions for all
    gen = entities.Filters._get_split_filters(dataset, filters, predict_embed_subset_limit)
    try:
        first_filter = next(gen)
    except StopIteration:
        self.logger.info("Filters is empty, nothing to run")
        return True

    try:
        second_filter = next(gen)
        multiple = True
    except StopIteration:
        multiple = False

    if not multiple:
        self.logger.info("Split filters has only one subset, running locally")
        if batch_size is None:
            batch_size = self.configuration.get('batch_size', 4)
        first_filter["pageSize"] = 1000
        single_filters = entities.Filters(custom_filter=first_filter)
        pages = dataset.items.list(filters=single_filters)
        self.logger.info(f"Single run pages on: {pages.items_count} items")
        items = [item for page in pages for item in page if item.type == 'file']
        self.logger.debug(f"Single run items length: {len(items)}")
        if operation_type == 'embed':
            self.embed_items(items=items, batch_size=batch_size, progress=progress)
        else:
            self.predict_items(items=items, batch_size=batch_size, progress=progress)
        return True

    executions = []
    for filter_dict in chain([first_filter, second_filter], gen):
        self.logger.debug(f"Creating execution for models {operation_type} with dataset id {dataset.id} and filter_dict {filter_dict}")
        subset_filters = entities.Filters(custom_filter=filter_dict)
        if operation_type == 'embed':
            execution = self.model_entity.models.embed(model=self.model_entity, dataset_id=dataset.id, filters=subset_filters)
        else:
            execution = self.model_entity.models.predict(model=self.model_entity, dataset_id=dataset.id, filters=subset_filters)
        executions.append(execution)

    self.logger.info(f'Created {len(executions)} executions for {operation_type}, ' f'execution ids: {[ex.id for ex in executions]}')

    wait_time = 5
    start_time = time.time()
    last_perc = 0
    all_finished = False
    self.logger.debug(f"Waiting for executions with timeout {predict_embed_timeout}")
    while time.time() - start_time < predict_embed_timeout:
        continue_loop = False
        total_perc = 0

        for ex in executions:
            execution = dataset.project.executions.get(execution_id=ex.id)
            perc = execution.latest_status.get('percentComplete', 0)
            total_perc += perc
            if execution.in_progress():
                continue_loop = True

        avg_perc = round(total_perc / len(executions), 0)
        if progress is not None and last_perc != avg_perc:
            last_perc = avg_perc
            progress.update(progress=last_perc, message=f'running {operation_type}')

        if not continue_loop:
            all_finished = True
            break

        time.sleep(wait_time)
    self.logger.debug("End waiting for executions")
    # Check if any execution failed
    executions_filter = entities.Filters(resource=entities.FiltersResource.EXECUTION)
    executions_filter.add(field="id", values=[ex.id for ex in executions], operator=entities.FiltersOperations.IN)
    executions_filter.add(field='latestStatus.status', values='failed')
    executions_filter.page_size = 0
    failed_executions_count = dataset.project.executions.list(filters=executions_filter).items_count
    if failed_executions_count > 0:
        self.logger.error(f"Failed to {operation_type} for {failed_executions_count} executions")
        return False
    if not all_finished:
        # Executions still running when the timeout expired must not be reported as success
        self.logger.error(f"Timed out after {predict_embed_timeout} seconds waiting for {operation_type} executions to finish")
        return False
    return True
1062
-
1063
def _upload_model_annotations(self, item: entities.Item, predictions, clean_annotations):
    """
    Upload predictions to the Dataloop platform on the given item.

    :param item: `dl.Item` to upload the predictions to
    :param predictions: `dl.AnnotationCollection` or `list` of annotations to upload
    :param clean_annotations: `bool` if True, first delete the item's annotations whose
        ``metadata.user.model.name`` equals this model's name
    :return: the result of ``item.annotations.upload``
    :raises TypeError: if ``predictions`` is neither an AnnotationCollection nor a list
    """
    if not isinstance(predictions, (entities.AnnotationCollection, list)):
        raise TypeError(
            f'predictions was expected to be of type {entities.AnnotationCollection} or {list}, '
            f'but instead it is {type(predictions)}'
        )
    if clean_annotations:
        clean_filter = entities.Filters(resource=entities.FiltersResource.ANNOTATION)
        clean_filter.add(field='metadata.user.model.name', values=self.model_entity.name)
        item.annotations.delete(filters=clean_filter)
    return item.annotations.upload(annotations=predictions)
1078
-
1079
@staticmethod
def _item_to_image(item):
    """
    Preprocess an item before calling the `predict` functions: convert it to a numpy array.

    Best effort: any download or decoding failure is logged and None is returned
    instead of raising.

    :param item: `dl.Item` to download (in memory) and decode as an image
    :return: `np.ndarray` with the image pixels, or None on failure
    """
    try:
        buffer = item.download(save_locally=False)
        # Close the PIL image once the pixels have been copied into the array
        with Image.open(buffer) as pil_image:
            image = np.asarray(pil_image)
    except Exception as e:
        logger.error(f"Failed to convert image to np.array, Error: {e}\n{traceback.format_exc()}")
        image = None
    return image
1095
-
1096
@staticmethod
def _item_to_item(item):
    """
    Identity preparation function: hand the item to the model unchanged.

    Used when the model consumes items directly, unlike the image/text preparation
    functions, which convert the item first.

    :param item: the item to prepare
    :return: the very same object that was passed in
    """
    prepared = item
    return prepared
1105
-
1106
@staticmethod
def _item_to_text(item):
    """
    Preprocess a text item before calling the `predict` functions.

    Downloads the item to a local file. For 'text/plain' and 'text/markdown' items, returns
    the file content decoded as UTF-8, with newlines replaced by spaces. For any other mimetype,
    logs a warning and returns the item itself. The downloaded file is always removed afterwards.

    :param item: `dl.Item` to convert
    :return: `str` with the item text, or the item itself for non-text mimetypes
    """
    filename = item.download(overwrite=True)
    try:
        if item.mimetype in ('text/plain', 'text/markdown'):
            # Explicit encoding: don't depend on the platform's default locale
            with open(filename, 'r', encoding='utf-8') as f:
                text = f.read()
            text = text.replace('\n', ' ')
        else:
            logger.warning('Item is not text file. mimetype: {}'.format(item.mimetype))
            text = item
    finally:
        # Remove the local copy even if reading/decoding failed
        if filename and os.path.exists(filename):
            os.remove(filename)
    return text
1120
-
1121
- @staticmethod
1122
- def _uri_to_image(data_uri):
1123
- # data_uri = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAS4AAAEuCAYAAAAwQP9DAAAU80lEQVR4Xu2da+hnRRnHv0qZKV42LDOt1eyGULoSJBGpRBFprBJBQrBJBBWGSm8jld5WroHUCyEXKutNu2IJ1QtXetULL0uQFCu24WoRsV5KpYvGYzM4nv6X8zu/mTnznPkcWP6XPTPzzOf7/L7/OXPmzDlOHBCAAAScETjOWbyECwEIQEAYF0kAAQi4I4BxuZOMgCEAAYyLHIAABNwRwLjcSUbAEIAAxkUOQAAC7ghgXO4kI2AIQADjIgcgAAF3BDAud5IRMAQggHGRAxCAgDsCGJc7yQgYAhDAuMgBCEDAHQGMy51kBAwBCGBc5AAEIOCOAMblTjIChgAEMC5yAAIQcEcA43InGQFDAAIYFzkAAQi4I4BxuZOMgCEAAYyLHIAABNwRwLjcSUbAEIAAxkUOQAAC7ghgXO4kI2AIQADjIgcgAAF3BDAud5IRMAQggHGRAxCAgDsCGJc7yQgYAhDAuMgBCEDAHQGMy51kBAwBCGBc5AAEIOCOAMblTjIChgAEMC5yAAIQcEcA43InGQFDAAIYFzkAAQi4I4BxuZOMgCEAAYyLHIAABNwRwLjcSUbAEIAAxkUOQAAC7ghgXO4kI2AIQADjIgcgAAF3BDAud5IRMAQggHGRAxDwTeDTkr4s6UxJ/5F0QNK3JD3lu1tbR49xLVld+jYXgcskvSTpIkmnS/qgpJMk/Tv8bHHZ7+PXPw6M5kRJx0t6Ijkv9uUsSW+U9Iykczfp4K8lfXiuztdoF+OqQZk2vBEwUzFTsK9mQNFkotGkhvFeSc+G86NRtdDfd0h6tIVASsSAcZWgSp0eCJjJ7JR0SRgZ2SjHDMp+38Jho7PXTAzkBUmvn1jWRTGMy4VMBJmBgBnSpZLsMs7+paOodao3k/hLqCBe8j0cfj4Yvtp8k/1fPLaaf4pxxXPSS8r4/Vsl3SXp5EHgNjo8JukDkg6v06nWy2JcrSvUX3xmKjYSipdqF0h6V/jgp6Mh+2DHf0YpnSd6p6TTkjml7UZRL4bLPasnmo7VHb+PKsQ20rZTQ6ql1lclfXODxr4u6Ru1gpizHYxrTvq0beZkE9cfkXRxxcu0pyXZaMiMKX71dBfua5sY1Psk/baHtMK4elC5rT5eFS7Z7Otmd8VyRDwcRZkxmUlFo8rRxlx13Clpz6Dxn0r61FwB1W4X46pNvM/27PLPPmhmVhvNLUWTiaZil1/xEswMx/7fbv9bWfs5nfcxommdceQU55eWSNxGihcmHbMRZK45Oxe8MK75ZYofaku8MyQ9J+mQpKNJMqbzLfeHkIeTuPP35JUIbCSVToRvNrKyftqCSfs3nE9qqT+txWKT8OmxT9LnWguyZDwYV0m6m9dtH+SbJNlamw+tGIIl7Va6/VPS8xusP4rN2JojG8E8NrhUS+d4ht/bbfkTJP0umGk6ER7PtfkVmwR/wzaXgEck7Q1mNcfE9oq4mzx9aFxXB55NBlsiKIyrBNXt67xB0q3bn7aYM+xSxkZVNjez5Eu4GoLZ5fb+pCFb/mB/LLo6MK555LaRyUPzND251VUWRJpRxTt2cUJ8csMUfBUBG61en/ymu8tE6zvGNd+nwuao7N8PJO0Kz7JZNDbH9aSkv4fQ0su2RyS9VtKD4dJtOClt5+4Il4Fpz+KkdqzLnpuzdrY74vnppWG6ujx9xMXOsUWPjw8WW27XBv+/GgH7Q2Dzh/G4NoxkV6vF+dkYV1sCRoNpKyqiaYmA/TGxxbXxsD963d3YwLhaSkligcDWBIZTDHajo+RauGb1wLialYbAIPB/BO6Q9Pnkt7dJshs93R0YV3eS02HHBGz+8Owk/vN6nU/EuBxnMaF3RWC4DOJ7kr7UFYGksxhXr8rTb28Eho/5dDvaMuEwLm/pS7w9EhiOtu4Oz332yOLlPmNc3UpPx5
0QsCUytlg5vXvY5RKIVC+My0n2Ema3BG4Oz7VGAN2PthhxdftZoOOOCKQLTu1RKlvL1f3D6Yy4HGUwoXZHwLaq+X7S6xvDzhrdgRh2GOPqPgUA0DCB9LlE27tsu73zG+5K3tAwrrw8qQ0CuQjYZLztmRaP7vbc2gokxpUrzagHAnkJpNvXMNoasMW48iYbtUEgF4F0Up7RFsaVK6+oBwLFCKST8t3uAMGlYrH8omIIFCFg21zvDjV3uwMExlUkt6gUAkUIDCflu34mcTPCzHEVyT0qhcBkAumLVJiU3wQjxjU5vygIgSIE0l0gutxPfgxVjGsMJc6BQB0C9kC1vW4sHvbik/RlKXWicNAKxuVAJELshkC6fY29sdzecs6xAQGMi7SAQDsE7IW5e0I4PJe4hS4YVztJSyQQsF0fdgYM3E3EuPhEQKB5Aumrx7ibuI1cjLiaz2cC7IRAugyCy0SMq5O0p5veCaSr5blMxLi85zPxd0LgGUmnSOIycYTgXCqOgMQpEChMwJY93MfdxPGUMa7xrDgTAqUIxGUQ7Ck/kjDGNRIUp0GgIIG49xaXiSMhY1wjQXEaBAoRSFfLczdxJGSMayQoToNAIQLpannuJo6EjHGNBMVpEChEgMvECWAxrgnQKAKBTAS4TJwIEuOaCI5iEMhAgMvEiRAxrongKAaBDAS4TJwIEeOaCI5iEFiTQPpQNXcTV4SJca0IjNMhkIlA+sJX7iauCBXjWhEYp0MgE4G49xaLTicAxbgmQKMIBNYkkL6CjPcmToCJcU2ARhEIrEkgfVP1Lkn2Zh+OFQhgXCvA4lQIZCIQl0EckWSjL44VCWBcKwLjdAhkIHBY0vmS9kmy0RfHigQwrhWBcToE1iSQLoO4QtK9a9bXZXGMq0vZ6fSMBOLe8rb3ll0m8sLXCWJgXBOgUQQCaxA4KOlStmheg6AkjGs9fpSGwKoEXgoFbpF086qFOf9/BDAuMgEC9Qike8tfLslGXxwTCGBcE6BRBAITCdgI66ZQls/eRIiMuNYAR1EITCAQ57ful2SjL46JBHD9ieAoBoEJBJjfmgBtoyIYVyaQVAOBbQik67eulmRvruaYSADjmgiOYhBYkUBcv2XFdrB+a0V6g9MxrvX4URoCYwnwfOJYUiPOw7hGQOIUCGQgEPff4vnEDDAxrgwQqQIC2xBI99+6VpKNvjjWIIBxrQGPohAYSSDdf4ttmkdC2+o0jCsDRKqAwDYEmN/KnCIYV2agVAeBDQgclfQW9t/KlxsYVz6W1ASBjQiw/1aBvMC4CkClSggkBOLziey/lTEtMK6MMKkKAhsQsBdhXMj+W3lzA+PKy5PaIJASOF3SsfAL3ladMTcwrowwqQoCAwK8hqxQSmBchcBSLQTCg9S7Jdn8lo2+ODIRwLgygaQaCGxAwF6EcRrLIPLnBsaVnyk1QsAIXCVpf0DBNjaZcwLjygyU6iAQCOyVdH34nm1sMqcFxpUZKNVBIBCIu0HcHUZfgMlIAOPKCJOqIBAIpKvl2Q2iQFpgXAWgUmX3BLhMLJwCGFdhwFTfJQEuEwvLjnEVBkz13RHgpRgVJMe4KkCmia4IpA9Vs+i0kPQYVyGwVNstgQcl7WLRaVn9Ma6yfKm9LwLsvVVJb4yrEmia6YJAvJvIs4mF5ca4CgOm+q4I8GxiJbkxrkqgaWbxBNJnE22OyzYQ5ChEAOMqBJZquyMQ124dkWTvUeQoSADjKgiXqrshcJmk+0Jv2em0guwYVwXINLF4Agck2YaBdvDC1wpyY1wVINPEognYZeHvJZ0g6RFJFyy6t410DuNqRAjCcEvgBkm3huhvl3Sd2544ChzjciQWoTZJIL5+zILjbmIliTCuSqBpZpEE0tePsei0osQYV0XYNLU4Aunrx/ZJsp85KhDAuCpAponFErhT0p7QO5ZBVJQZ46oIm6YWR4D5rZkkxbhmAk+z7gkwvzWjhBjXjPBp2jWBz0i6K/TgN5Iucd0bZ8
FjXM4EI9xmCMSdTi2gn0gyI+OoRADjqgSaZhZHIH3Mh1eQVZYX46oMnOYWQyDuBmEdulzSwcX0zEFHMC4HIhFikwReSqLiwerKEmFclYHT3CIIpNvYWIf4HFWWFeCVgdPcIgh8R9JXQk/+KulNi+iVo05gXI7EItRmCPxS0kdDNLalzXuaiayTQDCuToSmm9kI2MJT25751FDjLZJsaQRHRQIYV0XYNLUIAvdIujLpCXcUZ5AV45oBOk26JvCMpFNCD+zO4vGue+M0eIzLqXCEPQuBdBsbC+BeSVfMEknnjWJcnScA3V+JwJOS3pyUuFqSraDnqEwA46oMnOZcE0gXnVpH+PzMJCfgZwJPsy4JYFyNyIZxNSIEYbggMDSuHZKechH5woLEuBYmKN0pSoARV1G84yvHuMaz4sy+CQzvKB6VdE7fSObrPcY1H3ta9kVgeEeRt/rMqB/GNSN8mnZFYHiZyIr5GeXDuGaET9NuCFwlaX8SLTtCzCwdxjWzADTvgkC6v7wFfJukG1xEvtAgMa6FCku3shL4s6QzkxpZMZ8V7+qVYVyrM6NEfwSel3Ri0m3Wb82cAxjXzALQfPMEhvNbf5D07uajXniAGNfCBaZ7axN4VNLbk1pulLR37VqpYC0CGNda+Ci8cAK22+mxQR95o08DomNcDYhACM0SGK6Wt3cpmnFxzEwA45pZAJpvmsBwtTyXiY3IhXE1IgRhNElguFqey8RGZMK4GhGCMJojMLybyGViQxJhXA2JQShNEbhT0p4kIlbLNyQPxtWQGITSFAH2l29KjlcHg3E1LA6hzUrgxcGe8nxWZpUD42oIP6E0SuAiSQ8NYtsl6eFG4+0uLP6KdCc5HR5BYKOFp+y/NQJcrVMwrlqkaccTgQckXTwI+DJJ93vqxJJjxbiWrC59m0LgfEmHBwX/JemEKZVRpgwBjKsMV2r1S8BGVvcNwv+spB/67dLyIse4lqcpPVqPwEbGxcaB6zHNXhrjyo6UCp0TuFLSPYM+XCPpx877tajwMa5FyUlnMhCwveRvHdTDjqcZwOasAuPKSZO6lkDggKTdSUeOSDp3CR1bUh8wriWpSV9yEPiHpJOSinhGMQfVzHVgXJmBUp17AsOtbFgx36CkGFeDohDSbASGj/r8TdIZs0VDw5sSwLhIDgi8QmC4VfPdkmxfLo7GCGBcjQlCOLMSGO7BxVbNs8qxeeMYV6PCENYsBGyX051JyzxYPYsM2zeKcW3PiDP6ITCcmGf9VqPaY1yNCkNY1QkMJ+YPSbLfcTRIAONqUBRCmoXA8BlF1m/NIsO4RjGucZw4a/kEhncUebC6Yc0xrobFIbSqBIbPKDK/VRX/ao1hXKvx4uzlEtgr6frQvUckXbDcrvrvGcblX0N6kIdAaly/kPTxPNVSSwkCGFcJqtTpkUC6+JSFp40riHE1LhDhVSNwUNKloTUm5qthn9YQxjWNG6WWRyA1LlbMN64vxtW4QIRXjcBTkk4LrWFc1bBPawjjmsaNUssjkD7ug3E1ri/G1bhAhFeNQGpcbB5YDfu0hjCuadwotTwCqXGdJ8l2iuBolADG1agwhFWdQGpcfC6q41+tQQRajRdnL5dANK6nJZ2+3G4uo2cY1zJ0pBfrEbDXjz0WquB1ZOuxrFIa46qCmUYaJ/AJST8PMf5K0scaj7f78DCu7lMAAJLSnSFul3QdVNomgHG1rQ/R1SGQPmDNGq46zNdqBeNaCx+FF0LgYUkXhr6wFMKBqBiXA5EIsTgB7igWR5y3AYwrL09q80cg3WueF8A60Q/jciIUYRYjcLOkm0Lt7MNVDHPeijGuvDypzR+BdH6LZxSd6IdxORGKMIsQsBXyx0LNLDwtgrhMpRhXGa7U6oNA+kqyfZLsZw4HBDAuByIRYjEC6T7zbNdcDHP+ijGu/Eyp0Q+BuOspD1b70ezlSDEuZ4IRbjYCF0l6KNTGZWI2rHUqwrjqcKaV9gikj/lwmdiePltGhH
E5E4xwsxGIyyC4TMyGtF5FGFc91rTUFoEXJL1OEqvl29JlVDQY1yhMnLQwAuljPl+QdMfC+rf47mBci5eYDm5AIJ3fYjcIhymCcTkUjZDXJhDnt1gtvzbKeSrAuObhTqvzEUj3l78t7H46XzS0PIkAxjUJG4UcE0i3aWYZhFMhMS6nwhH2ZAIHJO0Opcn/yRjnLYhw8/Kn9foE4m6nhyTZ6nkOhwQwLoeiEfJkAryGbDK6tgpiXG3pQTRlCaS7nfJ8YlnWRWvHuIripfLGCLCNTWOCTA0H45pKjnIeCaTbNPP+RI8KclfFsWqEPpVAnJi38jsk2X5cHA4JMOJyKBohTyaQGhe5Pxnj/AURb34NiKAOgXTjQLayqcO8WCsYVzG0VNwYgXRHCNZwNSbOquFgXKsS43yvBOxlr98OwT8g6f1eO0Lc7DlPDvRD4LuSvhi6+zNJn+yn68vrKSOu5WlKjzYmkD6jaKMv25OLwykBjMupcIS9MoH4KjIryK4QK+NrqwDG1ZYeRFOGQDoxby2whqsM52q1YlzVUNPQjAR+JOma0P5zkk6eMRaazkAA48oAkSqaJ/CEpLNClM9KOrX5iAlwSwIYFwmydAJnS3p80MlzJB1deseX3D+Ma8nq0rdIwF6K8bbww58k7QSNbwIYl2/9iH4cAdtA0O4k2rFf0r3jinFWqwQwrlaVIS4IQGBTAhgXyQEBCLgjgHG5k4yAIQABjIscgAAE3BHAuNxJRsAQgADGRQ5AAALuCGBc7iQjYAhAAOMiByAAAXcEMC53khEwBCCAcZEDEICAOwIYlzvJCBgCEMC4yAEIQMAdAYzLnWQEDAEIYFzkAAQg4I4AxuVOMgKGAAQwLnIAAhBwRwDjcicZAUMAAhgXOQABCLgjgHG5k4yAIQABjIscgAAE3BHAuNxJRsAQgADGRQ5AAALuCGBc7iQjYAhAAOMiByAAAXcEMC53khEwBCCAcZEDEICAOwIYlzvJCBgCEMC4yAEIQMAdAYzLnWQEDAEIYFzkAAQg4I4AxuVOMgKGAAQwLnIAAhBwRwDjcicZAUMAAhgXOQABCLgjgHG5k4yAIQABjIscgAAE3BHAuNxJRsAQgADGRQ5AAALuCGBc7iQjYAhAAOMiByAAAXcEMC53khEwBCCAcZEDEICAOwIYlzvJCBgCEMC4yAEIQMAdAYzLnWQEDAEIYFzkAAQg4I4AxuVOMgKGAAQwLnIAAhBwR+C/doIhTZIi/uMAAAAASUVORK5CYII="
1124
- image_b64 = data_uri.split(",")[1]
1125
- binary = base64.b64decode(image_b64)
1126
- image = np.asarray(Image.open(io.BytesIO(binary)))
1127
- return image
1128
-
1129
- def _update_predictions_metadata(self, item: entities.Item, predictions: entities.AnnotationCollection):
1130
- """
1131
- add model_name and model_id to the metadata of the annotations.
1132
- add model_info to the metadata of the system metadata of the annotation.
1133
- Add item id to all the annotations in the AnnotationCollection
1134
-
1135
- :param item: Entity.Item
1136
- :param predictions: item's AnnotationCollection
1137
- :return:
1138
- """
1139
- for prediction in predictions:
1140
- if prediction.type == entities.AnnotationType.SEGMENTATION:
1141
- color = None
1142
- try:
1143
- color = item.dataset._get_ontology().color_map.get(prediction.label, None)
1144
- except (exceptions.BadRequest, exceptions.NotFound):
1145
- ...
1146
- if color is None:
1147
- if self.model_entity._dataset is not None:
1148
- try:
1149
- color = self.model_entity.dataset._get_ontology().color_map.get(prediction.label, (255, 255, 255))
1150
- except (exceptions.BadRequest, exceptions.NotFound):
1151
- ...
1152
- if color is None:
1153
- logger.warning("Can't get annotation color from model's dataset, using default.")
1154
- color = prediction.color
1155
- prediction.color = color
1156
-
1157
- prediction.item_id = item.id
1158
- if 'user' in prediction.metadata and 'model' in prediction.metadata['user']:
1159
- prediction.metadata['user']['model']['model_id'] = self.model_entity.id
1160
- prediction.metadata['user']['model']['name'] = self.model_entity.name
1161
- if 'system' not in prediction.metadata:
1162
- prediction.metadata['system'] = dict()
1163
- if 'model' not in prediction.metadata['system']:
1164
- prediction.metadata['system']['model'] = dict()
1165
- confidence = prediction.metadata.get('user', dict()).get('model', dict()).get('confidence', None)
1166
- prediction.metadata['system']['model'] = {
1167
- 'model_id': self.model_entity.id,
1168
- 'name': self.model_entity.name,
1169
- 'confidence': confidence,
1170
- }
1171
-
1172
- ##############################
1173
- # Callback Factory functions #
1174
- ##############################
1175
- @property
1176
- def dataloop_keras_callback(self):
1177
- """
1178
- Returns the constructor for a keras api dump callback
1179
- The callback is used for dlp platform to show train losses
1180
-
1181
- :return: DumpHistoryCallback constructor
1182
- """
1183
- try:
1184
- import keras
1185
- except (ImportError, ModuleNotFoundError) as err:
1186
- raise RuntimeError(f'{self.__class__.__name__} depends on extenral package. Please install ') from err
1187
-
1188
- import os
1189
- import time
1190
- import json
1191
-
1192
- class DumpHistoryCallback(keras.callbacks.Callback):
1193
- def __init__(self, dump_path):
1194
- super().__init__()
1195
- if os.path.isdir(dump_path):
1196
- dump_path = os.path.join(dump_path, f'__view__training-history__{time.strftime("%F-%X")}.json')
1197
- self.dump_file = dump_path
1198
- self.data = dict()
1199
-
1200
- def on_epoch_end(self, epoch, logs=None):
1201
- logs = logs or {}
1202
- for name, val in logs.items():
1203
- if name not in self.data:
1204
- self.data[name] = {'x': list(), 'y': list()}
1205
- self.data[name]['x'].append(float(epoch))
1206
- self.data[name]['y'].append(float(val))
1207
- self.dump_history()
1208
-
1209
- def dump_history(self):
1210
- _json = {
1211
- "query": {},
1212
- "datasetId": "",
1213
- "xlabel": "epoch",
1214
- "title": "training loss",
1215
- "ylabel": "val",
1216
- "type": "metric",
1217
- "data": [
1218
- {
1219
- "name": name,
1220
- "x": values['x'],
1221
- "y": values['y'],
1222
- }
1223
- for name, values in self.data.items()
1224
- ],
1225
- }
1226
-
1227
- with open(self.dump_file, 'w') as f:
1228
- json.dump(_json, f, indent=2)
1229
-
1230
- return DumpHistoryCallback
1
+ import dataclasses
2
+ import threading
3
+ import tempfile
4
+ import datetime
5
+ import logging
6
+ import string
7
+ import shutil
8
+ import random
9
+ import base64
10
+ import copy
11
+ import time
12
+ import tqdm
13
+ import traceback
14
+ import sys
15
+ import io
16
+ import os
17
+ from itertools import chain
18
+ from PIL import Image
19
+ from functools import partial
20
+ import numpy as np
21
+ from concurrent.futures import ThreadPoolExecutor
22
+ import attr
23
+ from collections.abc import MutableMapping
24
+ from typing import Optional
25
+ from .. import entities, utilities, repositories, exceptions
26
+ from ..services import service_defaults
27
+ from ..services.api_client import ApiClient
28
+
29
+ logger = logging.getLogger('ModelAdapter')
30
+
31
+ # Constants
32
+ PREDICT_EMBED_DEFAULT_SUBSET_LIMIT = 1000
33
+ PREDICT_EMBED_DEFAULT_TIMEOUT = 3600 * 24
34
+
35
+
36
+ class ModelConfigurations(MutableMapping):
37
+ """
38
+ Manages model configuration using composition with a backing dict.
39
+
40
+ Uses MutableMapping to implement dict-like behavior without inheritance.
41
+ This avoids duplication: if we inherited from dict, we'd have two dicts
42
+ (one from inheritance, one from model_entity.configuration), leading to
43
+ data inconsistency and maintenance issues.
44
+ """
45
+
46
+ def __init__(self, base_model_adapter):
47
+ # Store reference to base_model_adapter dictionary
48
+ self._backing_dict = {}
49
+
50
+ if base_model_adapter is not None and base_model_adapter.model_entity is not None and base_model_adapter.model_entity.configuration is not None:
51
+ self._backing_dict = base_model_adapter.model_entity.configuration
52
+ if 'include_background' not in self._backing_dict:
53
+ self._backing_dict['include_background'] = False
54
+ self._base_model_adapter = base_model_adapter
55
+ # Don't call _update_model_entity during initialization to avoid premature updates
56
+
57
+ def _update_model_entity(self):
58
+ if self._base_model_adapter is not None and self._base_model_adapter.model_entity is not None:
59
+ self._base_model_adapter.model_entity.update(reload_services=False)
60
+
61
+ def __ior__(self, other):
62
+ self.update(other)
63
+ return self
64
+
65
+ # Required MutableMapping abstract methods
66
+ def __getitem__(self, key):
67
+ return self._backing_dict[key]
68
+
69
+ def __setitem__(self, key, value):
70
+ # Note: This method only updates the backing dict, not object attributes.
71
+ # If you need to also update object attributes, be careful to avoid
72
+ # infinite recursion by not calling __setattr__ from here.
73
+ update = False
74
+ if key not in self._backing_dict or self._backing_dict.get(key) != value:
75
+ update = True
76
+ self._backing_dict[key] = value
77
+ if update:
78
+ self._update_model_entity()
79
+
80
+ def __delitem__(self, key):
81
+ del self._backing_dict[key]
82
+
83
+ def __iter__(self):
84
+ return iter(self._backing_dict)
85
+
86
+ def __len__(self):
87
+ return len(self._backing_dict)
88
+
89
+ def get(self, key, default=None):
90
+ if key not in self._backing_dict:
91
+ self.__setitem__(key, default)
92
+ return self._backing_dict.get(key)
93
+
94
+ def update(self, *args, **kwargs):
95
+ # Check if there will be any modifications
96
+ update_dict = dict(*args, **kwargs)
97
+ has_changes = False
98
+ for key, value in update_dict.items():
99
+ if key not in self._backing_dict or self._backing_dict[key] != value:
100
+ has_changes = True
101
+ break
102
+ self._backing_dict.update(*args, **kwargs)
103
+
104
+ if has_changes:
105
+ self._update_model_entity()
106
+
107
+ def setdefault(self, key, default=None):
108
+ if key not in self._backing_dict:
109
+ self._backing_dict[key] = default
110
+ return self._backing_dict[key]
111
+
112
+
113
+ @dataclasses.dataclass
114
+ class AdapterDefaults(ModelConfigurations):
115
+ # for predict items, dataset, evaluate
116
+ upload_annotations: bool = dataclasses.field(default=True)
117
+ clean_annotations: bool = dataclasses.field(default=True)
118
+ overwrite_annotations: bool = dataclasses.field(default=True)
119
+ # for embeddings
120
+ upload_features: bool = dataclasses.field(default=None)
121
+ # for training
122
+ root_path: str = dataclasses.field(default=None)
123
+ data_path: str = dataclasses.field(default=None)
124
+ output_path: str = dataclasses.field(default=None)
125
+
126
+ def __init__(self, base_model_adapter=None):
127
+ super().__init__(base_model_adapter)
128
+ for f in dataclasses.fields(AdapterDefaults):
129
+ # if the field exists in model_entity.configuration, use it
130
+ # else set it from the attribute default value
131
+ if super().get(f.name) is not None:
132
+ super().__setattr__(f.name, super().get(f.name))
133
+ else:
134
+ super().__setitem__(f.name, f.default)
135
+
136
+ def __setattr__(self, key, value):
137
+ # Dataclass-like fields behave as attributes, so map to dict
138
+ super().__setattr__(key, value)
139
+ if not key.startswith("_"):
140
+ super().__setitem__(key, value)
141
+
142
+ def update(self, *args, **kwargs):
143
+ for f in dataclasses.fields(AdapterDefaults):
144
+ if f.name in kwargs:
145
+ setattr(self, f.name, kwargs[f.name])
146
+ super().update(*args, **kwargs)
147
+
148
+ def resolve(self, key, *args):
149
+ for arg in args:
150
+ if arg is not None:
151
+ super().__setitem__(key, arg)
152
+ return arg
153
+ return self.get(key, None)
154
+
155
+
156
+ class BaseModelAdapter(utilities.BaseServiceRunner):
157
+ _client_api = attr.ib(type=ApiClient, repr=False)
158
+ _feature_set_lock = threading.Lock()
159
+
160
+ def __init__(self, model_entity: entities.Model = None):
161
+ self.logger = logger
162
+ # entities
163
+ self._model_entity = None
164
+ self._configuration = None
165
+ self.adapter_defaults = None
166
+ self.model = None
167
+ self.bucket_path = None
168
+ self._project = None
169
+ self._feature_set = None
170
+ # funcs
171
+ self.item_to_batch_mapping = {'text': self._item_to_text, 'image': self._item_to_image}
172
+ if model_entity is not None:
173
+ self.load_from_model(model_entity=model_entity)
174
+ logger.warning(
175
+ "in case of a mismatch between 'model.name' and 'model_info.name' in the model adapter, model_info.name will be updated to align with 'model.name'."
176
+ )
177
+
178
+ ##################
179
+ # Configurations #
180
+ ##################
181
+
182
+ @property
183
+ def configuration(self) -> dict:
184
+ # load from model
185
+ if self._model_entity is not None:
186
+ configuration = self._configuration
187
+ else:
188
+ configuration = dict()
189
+ return configuration
190
+
191
+ @configuration.setter
192
+ def configuration(self, configuration: dict):
193
+ assert isinstance(configuration, dict)
194
+ if self._model_entity is not None:
195
+ # Update configuration with received dict
196
+ self._model_entity.configuration = configuration
197
+ self.adapter_defaults = AdapterDefaults(self)
198
+ self._configuration = self.adapter_defaults
199
+
200
+ ############
201
+ # Entities #
202
+ ############
203
+ @property
204
+ def project(self):
205
+ if self._project is None:
206
+ self._project = self.model_entity.project
207
+ assert isinstance(self._project, entities.Project)
208
+ return self._project
209
+
210
+ @property
211
+ def feature_set(self):
212
+ if self._feature_set is None:
213
+ self._feature_set = self._get_feature_set()
214
+ assert isinstance(self._feature_set, entities.FeatureSet)
215
+ return self._feature_set
216
+
217
+ @property
218
+ def model_entity(self):
219
+ if self._model_entity is None:
220
+ raise ValueError("No model entity loaded. Please load a model (adapter.load_from_model(<dl.Model>)) or set: 'adapter.model_entity=<dl.Model>'")
221
+ assert isinstance(self._model_entity, entities.Model)
222
+ return self._model_entity
223
+
224
+ @model_entity.setter
225
+ def model_entity(self, model_entity):
226
+ assert isinstance(model_entity, entities.Model)
227
+ if self._model_entity is not None and isinstance(self._model_entity, entities.Model):
228
+ if self._model_entity.id != model_entity.id:
229
+ self.logger.warning('Replacing Model from {!r} to {!r}'.format(self._model_entity.name, model_entity.name))
230
+ self._model_entity = model_entity
231
+ self.adapter_defaults = AdapterDefaults(self)
232
+ self._configuration = self.adapter_defaults
233
+
234
+ ###################################
235
+ # NEED TO IMPLEMENT THESE METHODS #
236
+ ###################################
237
+
238
+ def load(self, local_path, **kwargs):
239
+ """
240
+ Loads model and populates self.model with a `runnable` model
241
+
242
+ Virtual method - need to implement
243
+
244
+ This function is called by load_from_model (download to local and then loads)
245
+
246
+ :param local_path: `str` directory path in local FileSystem
247
+ """
248
+ raise NotImplementedError("Please implement `load` method in {}".format(self.__class__.__name__))
249
+
250
+ def save(self, local_path, **kwargs):
251
+ """
252
+ Saves configuration and weights locally
253
+
254
+ Virtual method - need to implement
255
+
256
+ the function is called in save_to_model which first save locally and then uploads to model entity
257
+
258
+ :param local_path: `str` directory path in local FileSystem
259
+ """
260
+ raise NotImplementedError("Please implement `save` method in {}".format(self.__class__.__name__))
261
+
262
+ def train(self, data_path, output_path, **kwargs):
263
+ """
264
+ Virtual method - need to implement
265
+
266
+ Train the model according to data in data_paths and save the train outputs to output_path,
267
+ this include the weights and any other artifacts created during train
268
+
269
+ :param data_path: `str` local File System path to where the data was downloaded and converted at
270
+ :param output_path: `str` local File System path where to dump training mid-results (checkpoints, logs...)
271
+ """
272
+ raise NotImplementedError("Please implement `train` method in {}".format(self.__class__.__name__))
273
+
274
+ def predict(self, batch, **kwargs):
275
+ """
276
+ Model inference (predictions) on batch of items
277
+
278
+ Virtual method - need to implement
279
+
280
+ :param batch: output of the `prepare_item_func` func
281
+ :return: `list[dl.AnnotationCollection]` each collection is per each image / item in the batch
282
+ """
283
+ raise NotImplementedError("Please implement `predict` method in {}".format(self.__class__.__name__))
284
+
285
+ def embed(self, batch, **kwargs):
286
+ """
287
+ Extract model embeddings on batch of items
288
+
289
+ Virtual method - need to implement
290
+
291
+ :param batch: output of the `prepare_item_func` func
292
+ :return: `list[list]` a feature vector per each item in the batch
293
+ """
294
+ raise NotImplementedError("Please implement `embed` method in {}".format(self.__class__.__name__))
295
+
296
+ def evaluate(self, model: entities.Model, dataset: entities.Dataset, filters: entities.Filters) -> entities.Model:
297
+ """
298
+ This function evaluates the model prediction on a dataset (with GT annotations).
299
+ The evaluation process will upload the scores and metrics to the platform.
300
+
301
+ :param model: The model to evaluate (annotation.metadata.system.model.name
302
+ :param dataset: Dataset where the model predicted and uploaded its annotations
303
+ :param filters: Filters to query items on the dataset
304
+ :return:
305
+ """
306
+ import dtlpymetrics
307
+
308
+ compare_types = model.output_type
309
+ if not filters:
310
+ filters = entities.Filters()
311
+ if filters is not None and isinstance(filters, dict):
312
+ filters = entities.Filters(custom_filter=filters)
313
+ model = dtlpymetrics.scoring.create_model_score(
314
+ model=model,
315
+ dataset=dataset,
316
+ filters=filters,
317
+ compare_types=compare_types,
318
+ )
319
+ return model
320
+
321
+ def convert_from_dtlpy(self, data_path, **kwargs):
322
+ """Convert Dataloop structure data to model structured
323
+
324
+ Virtual method - need to implement
325
+
326
+ e.g. take dlp dir structure and construct annotation file
327
+
328
+ :param data_path: `str` local File System directory path where we already downloaded the data from dataloop platform
329
+ :return:
330
+ """
331
+ raise NotImplementedError("Please implement `convert_from_dtlpy` method in {}".format(self.__class__.__name__))
332
+
333
+ #################
334
+ # DTLPY METHODS #
335
+ ################
336
+ def prepare_item_func(self, item: entities.Item):
337
+ """
338
+ Prepare the Dataloop item before calling the `predict` function with a batch.
339
+ A user can override this function to load item differently
340
+ Default will load the item according the input_type (mapping type to function is in self.item_to_batch_mapping)
341
+
342
+ :param item:
343
+ :return: preprocessed: the var with the loaded item information (e.g. ndarray for image, dict for json files etc)
344
+ """
345
+ # Item to batch func
346
+ if isinstance(self.model_entity.input_type, list):
347
+ if 'text' in self.model_entity.input_type and 'text' in item.mimetype:
348
+ processed = self._item_to_text(item)
349
+ elif 'image' in self.model_entity.input_type and 'image' in item.mimetype:
350
+ processed = self._item_to_image(item)
351
+ else:
352
+ processed = self._item_to_item(item)
353
+
354
+ elif self.model_entity.input_type in self.item_to_batch_mapping:
355
+ processed = self.item_to_batch_mapping[self.model_entity.input_type](item)
356
+
357
+ else:
358
+ processed = self._item_to_item(item)
359
+
360
+ return processed
361
+
362
+ def __include_model_annotations(self, annotation_filters):
363
+ include_model_annotations = self.model_entity.configuration.get("include_model_annotations", False)
364
+ if include_model_annotations is False:
365
+ if annotation_filters.custom_filter is None:
366
+ annotation_filters.add(field="metadata.system.model.name", values=False, operator=entities.FiltersOperations.EXISTS)
367
+ else:
368
+ annotation_filters.custom_filter['filter']['$and'].append({'metadata.system.model.name': {'$exists': False}})
369
+ return annotation_filters
370
+
371
+ def __download_items(self, dataset, filters, local_path, annotation_options, annotation_filters=None):
372
+ """
373
+ Download items from dataset with optional annotation filters.
374
+
375
+ :param dataset: Dataset to download from
376
+ :param filters: Filters to apply
377
+ :param local_path: Local path to save files
378
+ :param annotation_options: Annotation download options
379
+ :param annotation_filters: Optional filters for annotations
380
+ :return: List of downloaded items
381
+ """
382
+ if annotation_options == entities.ViewAnnotationOptions.JSON:
383
+ downloader = repositories.Downloader(dataset.items)
384
+ return downloader._download_recursive(
385
+ local_path=local_path,
386
+ filters=filters,
387
+ annotation_filters=annotation_filters
388
+ )
389
+ else:
390
+ return dataset.items.download(
391
+ filters=filters,
392
+ local_path=local_path,
393
+ annotation_options=annotation_options,
394
+ annotation_filters=annotation_filters
395
+ )
396
+
397
+ def __download_background_images(self, filters, data_subset_base_path, annotation_options):
398
+ background_list = list()
399
+ if self.configuration.get('include_background', False) is True:
400
+ filters.custom_filter["filter"]["$and"].append({"annotated": False})
401
+ background_list = self.__download_items(
402
+ dataset=self.model_entity.dataset,
403
+ filters=filters,
404
+ local_path=data_subset_base_path,
405
+ annotation_options=annotation_options
406
+ )
407
+ return background_list
408
+
409
+ def prepare_data(
410
+ self,
411
+ dataset: entities.Dataset,
412
+ # paths
413
+ root_path=None,
414
+ data_path=None,
415
+ output_path=None,
416
+ #
417
+ overwrite=False,
418
+ **kwargs,
419
+ ):
420
+ """
421
+ Prepares dataset locally before training or evaluation.
422
+ download the specific subset selected to data_path and preforms `self.convert` to the data_path dir
423
+
424
+ :param dataset: dl.Dataset
425
+ :param root_path: `str` root directory for training. default is "tmp". Can be set using self.adapter_defaults.root_path
426
+ :param data_path: `str` dataset directory. default <root_path>/"data". Can be set using self.adapter_defaults.data_path
427
+ :param output_path: `str` save everything to this folder. default <root_path>/"output". Can be set using self.adapter_defaults.output_path
428
+
429
+ :param bool overwrite: overwrite the data path (download again). default is False
430
+ """
431
+ # define paths
432
+ dataloop_path = service_defaults.DATALOOP_PATH
433
+ root_path = self.adapter_defaults.resolve("root_path", root_path)
434
+ data_path = self.adapter_defaults.resolve("data_path", data_path)
435
+ output_path = self.adapter_defaults.resolve("output_path", output_path)
436
+ if root_path is None:
437
+ now = datetime.datetime.now()
438
+ root_path = os.path.join(
439
+ dataloop_path,
440
+ 'model_data',
441
+ "{s_id}_{s_n}".format(s_id=self.model_entity.id, s_n=self.model_entity.name),
442
+ now.strftime('%Y-%m-%d-%H%M%S'),
443
+ )
444
+ if data_path is None:
445
+ data_path = os.path.join(root_path, 'datasets', self.model_entity.dataset.id)
446
+ os.makedirs(data_path, exist_ok=True)
447
+ if output_path is None:
448
+ output_path = os.path.join(root_path, 'output')
449
+ os.makedirs(output_path, exist_ok=True)
450
+
451
+ if len(os.listdir(data_path)) > 0:
452
+ self.logger.warning("Data path directory ({}) is not empty..".format(data_path))
453
+
454
+ annotation_options = entities.ViewAnnotationOptions.JSON
455
+ if self.model_entity.output_type in [entities.AnnotationType.SEGMENTATION]:
456
+ annotation_options = entities.ViewAnnotationOptions.INSTANCE
457
+
458
+ # Download the subset items
459
+ subsets = self.model_entity.metadata.get("system", {}).get("subsets", None)
460
+ annotations_subsets = self.model_entity.metadata.get("system", {}).get("annotationsSubsets", {})
461
+ if subsets is None:
462
+ raise ValueError("Model (id: {}) must have subsets in metadata.system.subsets".format(self.model_entity.id))
463
+ for subset, filters_dict in subsets.items():
464
+ _filters_dict = filters_dict.copy()
465
+ data_subset_base_path = os.path.join(data_path, subset)
466
+ if os.path.isdir(data_subset_base_path) and not overwrite:
467
+ # existing and dont overwrite
468
+ self.logger.debug("Subset {!r} already exists (and overwrite=False). Skipping.".format(subset))
469
+ continue
470
+
471
+ filters = entities.Filters(custom_filter=_filters_dict)
472
+ self.logger.debug("Downloading subset {!r} of {}".format(subset, self.model_entity.dataset.name))
473
+
474
+ annotation_filters = None
475
+ if subset in annotations_subsets:
476
+ annotation_filters = entities.Filters(
477
+ use_defaults=False,
478
+ resource=entities.FiltersResource.ANNOTATION,
479
+ custom_filter=annotations_subsets[subset],
480
+ )
481
+ # if user provided annotation_filters, skip the default filters
482
+ elif self.model_entity.output_type is not None and self.model_entity.output_type != "embedding":
483
+ annotation_filters = entities.Filters(resource=entities.FiltersResource.ANNOTATION, use_defaults=False)
484
+ if self.model_entity.output_type in [
485
+ entities.AnnotationType.SEGMENTATION,
486
+ entities.AnnotationType.POLYGON,
487
+ ]:
488
+ model_output_types = [entities.AnnotationType.SEGMENTATION, entities.AnnotationType.POLYGON]
489
+ else:
490
+ model_output_types = [self.model_entity.output_type]
491
+
492
+ annotation_filters.add(
493
+ field=entities.FiltersKnownFields.TYPE,
494
+ values=model_output_types,
495
+ operator=entities.FiltersOperations.IN,
496
+ )
497
+
498
+ annotation_filters = self.__include_model_annotations(annotation_filters)
499
+ annotations_subsets[subset] = annotation_filters.prepare()
500
+
501
+ ret_list = self.__download_items(
502
+ dataset=dataset,
503
+ filters=filters,
504
+ local_path=data_subset_base_path,
505
+ annotation_options=annotation_options,
506
+ annotation_filters=annotation_filters
507
+ )
508
+ _filters_dict = subsets[subset].copy()
509
+ filters = entities.Filters(custom_filter=_filters_dict)
510
+ background_ret_list = self.__download_background_images(
511
+ filters=filters,
512
+ data_subset_base_path=data_subset_base_path,
513
+ annotation_options=annotation_options,
514
+ )
515
+ ret_list = list(ret_list)
516
+ background_ret_list = list(background_ret_list)
517
+ self.logger.debug(f"Subset '{subset}': ret_list length: {len(ret_list)}, background_ret_list length: {len(background_ret_list)}")
518
+ # Combine ret_list and background_ret_list generators into a single generator
519
+ ret_list = ret_list + background_ret_list
520
+ if isinstance(ret_list, list) and len(ret_list) == 0:
521
+ if annotation_filters is not None:
522
+ annotation_filters_str = annotation_filters.prepare()
523
+ else:
524
+ annotation_filters_str = None
525
+ raise ValueError(
526
+ f"No items downloaded for subset {subset}! Cannot train model with empty subset.\n"
527
+ f"Subset {subset} filters: {filters.prepare()}\n"
528
+ f"Annotation filters: {annotation_filters_str}"
529
+ )
530
+
531
+ self.convert_from_dtlpy(data_path=data_path, **kwargs)
532
+ return root_path, data_path, output_path
533
+
534
+ def load_from_model(self, model_entity=None, local_path=None, overwrite=True, **kwargs):
535
+ """Loads a model from given `dl.Model`.
536
+ Reads configurations and instantiate self.model_entity
537
+ Downloads the model_entity bucket (if available)
538
+
539
+ :param model_entity: `str` dl.Model entity
540
+ :param local_path: `str` directory path in local FileSystem to download the model_entity to
541
+ :param overwrite: `bool` (default False) if False does not download files with same name else (True) download all
542
+ """
543
+ if model_entity is not None:
544
+ self.model_entity = model_entity
545
+ if local_path is None:
546
+ local_path = os.path.join(service_defaults.DATALOOP_PATH, "models", self.model_entity.name)
547
+ # Load configuration and adapter defaults
548
+ self.adapter_defaults = AdapterDefaults(self)
549
+ # Point _configuration to the same object since AdapterDefaults inherits from ModelConfigurations
550
+ self._configuration = self.adapter_defaults
551
+ # Download
552
+ self.model_entity.artifacts.download(local_path=local_path, overwrite=overwrite)
553
+ self.load(local_path, **kwargs)
554
+
555
+ def save_to_model(self, local_path=None, cleanup=False, replace=True, **kwargs):
556
+ """
557
+ Saves the model state to a new bucket and configuration
558
+
559
+ Saves configuration and weights to artifacts
560
+ Mark the model as `trained`
561
+ loads only applies for remote buckets
562
+
563
+ :param local_path: `str` directory path in local FileSystem to save the current model bucket (weights) (default will create a temp dir)
564
+ :param replace: `bool` will clean the bucket's content before uploading new files
565
+ :param cleanup: `bool` if True (default) remove the data from local FileSystem after upload
566
+ :return:
567
+ """
568
+
569
+ if local_path is None:
570
+ local_path = tempfile.mkdtemp(prefix="model_{}".format(self.model_entity.name))
571
+ self.logger.debug("Using temporary dir at {}".format(local_path))
572
+
573
+ self.save(local_path=local_path, **kwargs)
574
+
575
+ if self.model_entity is None:
576
+ raise ValueError('Missing model entity on the adapter. ' 'Please set before saving: "adapter.model_entity=model"')
577
+
578
+ self.model_entity.artifacts.upload(filepath=os.path.join(local_path, '*'), overwrite=True)
579
+ if cleanup:
580
+ shutil.rmtree(path=local_path, ignore_errors=True)
581
+ self.logger.info("Clean-up. deleting {}".format(local_path))
582
+
583
+ # ===============
584
+ # SERVICE METHODS
585
+ # ===============
586
+
587
@entities.Package.decorators.function(
    display_name='Predict Items',
    inputs={'items': 'Item[]'},
    outputs={'items': 'Item[]', 'annotations': 'Annotation[]'},
)
def predict_items(self, items: list, batch_size=None, progress: utilities.Progress = None, **kwargs):
    """
    Run the predict function on the input list of items (or single) and return the items and the predictions.
    Each prediction is by the model output type (package.output_type) and model_info in the metadata

    :param items: `List[dl.Item]` list of items to predict (a single `dl.Item` is wrapped in a list)
    :param batch_size: `int` size of batch to run a single inference
        (default: configuration['batch_size'], or 4)
    :param progress: `dl.Progress` accepted explicitly so callers can pass it; it is NOT forwarded to
        `self.predict` (which only receives the remaining `**kwargs`)
    :param kwargs: extra keyword arguments forwarded to `self.predict`

    :return: `List[dl.Item]`, `List[dl.Annotation]` (annotations of all batches, flattened into one list)
    :raises Exception: after all batches ran, if any batch failed to predict or to upload its annotations
    """
    if isinstance(items, entities.Item):
        # the contract allows a single item; normalize so len() and slicing below work
        items = [items]
    if batch_size is None:
        batch_size = self.configuration.get('batch_size', 4)
    input_type = self.model_entity.input_type
    self.logger.debug("Predicting {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))
    error_counter = 0
    fail_ids = list()
    annotations = list()
    # the context manager shuts the pool down even if a step below raises
    with ThreadPoolExecutor(max_workers=16) as pool:
        for i_batch in tqdm.tqdm(range(0, len(items), batch_size), desc='predicting', unit='bt', leave=None, file=sys.stdout):
            batch_items = items[i_batch : i_batch + batch_size]
            batch = list(pool.map(self.prepare_item_func, batch_items))
            try:
                batch_collections = self.predict(batch, **kwargs)
            except Exception as e:
                item_ids = [item.id for item in batch_items]
                self.logger.error(f"Failed to predict batch {i_batch} for items {item_ids}. Error: {e}\n{traceback.format_exc()}")
                error_counter += 1
                fail_ids.extend(item_ids)
                continue
            # list() drains the map iterator, so every metadata update has finished before uploading
            list(pool.map(self._update_predictions_metadata, batch_items, batch_collections))
            self.logger.debug("Uploading items' annotation for model {!r}.".format(self.model_entity.name))
            try:
                batch_collections = list(pool.map(self._upload_model_annotations, batch_items, batch_collections))
            except Exception as err:
                item_ids = [item.id for item in batch_items]
                self.logger.error(
                    f"Failed to upload annotations for items {item_ids}. Error: {err}\n{traceback.format_exc()}"
                )
                error_counter += 1
                fail_ids.extend(item_ids)

            for collection in batch_collections:
                # flatten each item's predictions into the single returned annotation list
                if isinstance(collection, entities.AnnotationCollection):
                    annotations.extend(collection.annotations)
                else:
                    logger.warning(f'RETURN TYPE MAY BE INVALID: {type(collection)}')
                    annotations.extend(collection)
                # TODO call the callback

    if error_counter > 0:
        raise Exception(f"Failed to predict all items. Failed IDs: {fail_ids}, See logs for more details")
    return items, annotations
651
+
652
@entities.Package.decorators.function(
    display_name='Embed Items',
    inputs={'items': 'Item[]'},
    outputs={'items': 'Item[]', 'features': 'Json[]'},
)
def embed_items(self, items: list, upload_features=None, batch_size=None, progress: utilities.Progress = None, **kwargs):
    """
    Extract feature from an input list of items (or single) and return the items and the feature vector.

    :param items: `List[dl.Item]` list of items to embed (a single `dl.Item` is wrapped in a list)
    :param upload_features: `bool` uploads the features on the given items. Resolved through
        `self.adapter_defaults`; when it resolves to None, features are uploaded but prompt items are skipped
    :param batch_size: `int` size of batch to run a single embed (default: configuration['batch_size'], or 4)
    :param progress: `dl.Progress` accepted for API compatibility; not used here
    :param kwargs: extra keyword arguments forwarded to `self.embed`

    :return: `List[dl.Item]`, `List[vector]` (one vector per returned item). When uploading, only items whose
        vector is not None and has length configuration['embeddings_size'] (default 256) are kept.
    :raises ValueError: when uploading and the number of embedded items and vectors differ
    :raises Exception: if any batch failed to embed or the feature upload failed
    """
    if isinstance(items, entities.Item):
        # the contract allows a single item; normalize so len() and slicing below work
        items = [items]
    if batch_size is None:
        batch_size = self.configuration.get('batch_size', 4)
    upload_features = self.adapter_defaults.resolve("upload_features", upload_features)
    # nothing explicit resolved: upload by default, but skip prompt items
    skip_default_items = upload_features is None
    if upload_features is None:
        upload_features = True
    input_type = self.model_entity.input_type
    self.logger.debug("Embedding {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))
    error_counter = 0
    fail_ids = list()

    feature_set = self.feature_set

    vectors = list()
    _items = list()
    # the context manager shuts the pool down even if a step below raises
    with ThreadPoolExecutor(max_workers=16) as pool:
        for i_batch in tqdm.tqdm(
            range(0, len(items), batch_size),
            desc='embedding',
            unit='bt',
            leave=None,
            file=sys.stdout,
        ):
            batch_items = items[i_batch : i_batch + batch_size]
            batch = list(pool.map(self.prepare_item_func, batch_items))
            try:
                batch_vectors = self.embed(batch, **kwargs)
            except Exception as err:
                item_ids = [item.id for item in batch_items]
                self.logger.error(f"Failed to embed batch {i_batch} for items {item_ids}. Error: {err}\n{traceback.format_exc()}")
                error_counter += 1
                fail_ids.extend(item_ids)
                continue
            vectors.extend(batch_vectors)
            # Save the items in the order of the vectors
            _items.extend(batch_items)

    if upload_features is True:
        # zip() below would silently truncate and pair vectors with the wrong items
        if len(_items) != len(vectors):
            raise ValueError(f"The number of items ({len(_items)}) is not equal to the number of vectors ({len(vectors)}).")

        embeddings_size = self.configuration.get('embeddings_size', 256)
        valid_items = []
        valid_vectors = []
        items_to_upload = []
        vectors_to_upload = []

        for item, vector in zip(_items, vectors):
            # Check if vector is valid
            if vector is None or len(vector) != embeddings_size:
                self.logger.warning(f"Vector generated for item {item.id} is None or has wrong size. Skipping...")
                continue

            # Item and vector are valid
            valid_items.append(item)
            valid_vectors.append(vector)

            # Check if item should be skipped (prompt items); tolerate missing/None metadata
            _system_metadata = getattr(item, 'system', None) or dict()
            _shebang = _system_metadata.get('shebang') or dict()
            is_prompt = _shebang.get('dltype', '') == 'prompt'
            if skip_default_items and is_prompt:
                self.logger.debug(f"Skipping feature upload for prompt item {item.id}")
                continue

            # Items were not skipped - should be uploaded
            items_to_upload.append(item)
            vectors_to_upload.append(vector)

        # Update the original lists with valid items only
        _items[:] = valid_items
        vectors[:] = valid_vectors

        if items_to_upload:
            self.logger.debug(f"Uploading {len(items_to_upload)} items' feature vectors for model {self.model_entity.name}.")
            try:
                start_time = time.time()
                feature_set.features.create(entity=items_to_upload, value=vectors_to_upload, feature_set_id=feature_set.id, project_id=self.model_entity.project_id)
                self.logger.debug(f"Uploaded {len(items_to_upload)} items' feature vectors for model {self.model_entity.name} in {time.time() - start_time} seconds.")
            except Exception as err:
                self.logger.error(f"Failed to upload feature vectors. Error: {err}\n{traceback.format_exc()}")
                error_counter += 1
                # record which items were affected so the final error message is actionable
                fail_ids.extend(item.id for item in items_to_upload)
        else:
            self.logger.debug(f"No feature vectors to upload for model {self.model_entity.name}.")
    if error_counter > 0:
        raise Exception(f"Failed to embed all items. Failed IDs: {fail_ids}, See logs for more details")
    return _items, vectors
752
+
753
@entities.Package.decorators.function(
    display_name='Embed Dataset with DQL',
    inputs={'dataset': 'Dataset', 'filters': 'Json'},
)
def embed_dataset(
    self,
    dataset: entities.Dataset,
    filters: Optional[entities.Filters] = None,
    upload_features: Optional[bool] = None,
    batch_size: Optional[int] = None,
    progress: Optional[utilities.Progress] = None,
    **kwargs,
):
    """
    Run model embedding on all items in a dataset

    :param dataset: Dataset entity to embed
    :param filters: Filters entity for filtering before embedding
    :param upload_features: bool whether to upload features back to platform.
        NOTE(review): this value (and ``kwargs``) is currently not forwarded to
        ``_execute_dataset_operation`` - confirm this is intended
    :param batch_size: int size of batch to run a single embedding
    :param progress: dl.Progress object to track progress
    :return: bool indicating if embedding completed successfully
    """
    return self._execute_dataset_operation(
        dataset=dataset,
        operation_type='embed',
        filters=filters,
        progress=progress,
        batch_size=batch_size,
    )
784
+
785
@entities.Package.decorators.function(
    display_name='Predict Dataset with DQL',
    inputs={'dataset': 'Dataset', 'filters': 'Json'},
)
def predict_dataset(
    self,
    dataset: entities.Dataset,
    filters: Optional[entities.Filters] = None,
    batch_size: Optional[int] = None,
    progress: Optional[utilities.Progress] = None,
    **kwargs,
):
    """
    Run model prediction on all items in a dataset

    :param dataset: Dataset entity to predict
    :param filters: Filters entity for filtering before prediction
    :param batch_size: int size of batch to run a single prediction
    :param progress: dl.Progress object to track progress
    :param kwargs: accepted for API compatibility; currently not forwarded
    :return: bool indicating if prediction completed successfully
    """
    return self._execute_dataset_operation(
        dataset=dataset,
        operation_type='predict',
        filters=filters,
        progress=progress,
        batch_size=batch_size,
    )
813
+
814
@entities.Package.decorators.function(
    display_name='Train a Model',
    inputs={'model': entities.Model},
    outputs={'model': entities.Model},
)
def train_model(self, model: entities.Model, cleanup=False, progress: utilities.Progress = None, context: utilities.Context = None):
    """
    Train on existing model.
    data will be taken from dl.Model.datasetId
    configuration is as defined in dl.Model.configuration
    upload the output to the model's bucket (model.bucket)

    :param model: `dl.Model` to train (a dict with an 'id' key is also accepted and fetched)
    :param cleanup: `bool` if True, remove the local output directory after a successful run
    :param progress: `dl.Progress` used to report training progress, optional
    :param context: `dl.Context` of the execution, optional
    :return: the trained `dl.Model`
    :raises ValueError: if the model is in 'failed' state after waiting for it to be ready
    """
    if isinstance(model, dict):
        model = repositories.Models(client_api=self._client_api).get(model_id=model['id'])
    output_path = None
    try:
        logger.info("Received {s} for training".format(s=model.id))
        model = model.wait_for_model_ready()
        if model.status == 'failed':
            raise ValueError("Model is in failed state, cannot train.")

        ##############
        # Set status #
        ##############
        model.status = 'training'
        if context is not None:
            # NOTE(review): only guarantees metadata['system'] exists; context is not otherwise used
            if 'system' not in model.metadata:
                model.metadata['system'] = dict()
        model.update(reload_services=False)

        ##########################
        # load model and weights #
        ##########################
        logger.info("Loading Adapter with: {n} ({i!r})".format(n=model.name, i=model.id))
        self.load_from_model(model_entity=model)

        ################
        # prepare data #
        ################
        _root_path, data_path, output_path = self.prepare_data(dataset=self.model_entity.dataset, root_path=os.path.join('tmp', model.id))
        # Start the Train
        logger.info(f"Training model {model.name!r} ({model.id!r}) on data {data_path!r}")
        if progress is not None:
            progress.update(message='starting training')

        def on_epoch_end_callback(i_epoch, n_epoch):
            if progress is not None:
                # i_epoch is treated as 0-based (see the percentage), so report the 1-based epoch count
                progress.update(progress=int(100 * (i_epoch + 1) / n_epoch), message='finished epoch: {}/{}'.format(i_epoch + 1, n_epoch))

        self.train(data_path=data_path, output_path=output_path, on_epoch_end_callback=on_epoch_end_callback)
        if progress is not None:
            progress.update(message='saving model', progress=99)

        self.save_to_model(local_path=output_path, replace=True)
        model.status = 'trained'
        model.update(reload_services=False)
        ###########
        # cleanup #
        ###########
        if cleanup:
            shutil.rmtree(output_path, ignore_errors=True)
    except Exception:
        # save also on fail (best effort)
        if output_path is not None:
            try:
                self.save_to_model(local_path=output_path, replace=True)
            except Exception:
                # a failed save must not replace the original training error re-raised below
                logger.exception("Failed to save model artifacts after training failure")
        logger.info('Execution failed. Setting model.status to failed')
        raise
    return model
882
+
883
@entities.Package.decorators.function(
    display_name='Evaluate a Model',
    inputs={'model': entities.Model, 'dataset': entities.Dataset, 'filters': 'Json'},
    outputs={'model': entities.Model, 'dataset': entities.Dataset},
)
def evaluate_model(
    self,
    model: entities.Model,
    dataset: entities.Dataset,
    filters: entities.Filters = None,
    progress: utilities.Progress = None,
    context: utilities.Context = None,
):
    """
    Evaluate a model.
    data will be downloaded from the dataset and query
    configuration is as defined in dl.Model.configuration
    upload annotations and calculate metrics vs GT

    Prediction is run locally (no child executions), and only when
    ``adapter_defaults`` "overwrite_annotations" is True (the default).

    :param model: Model entity to run prediction
    :param dataset: Dataset to evaluate
    :param filters: Filter for specific items from dataset (default: all items)
    :param progress: dl.Progress for report FaaS progress
    :param context: execution context; not used here
    :return: tuple of (model returned by `self.evaluate`, dataset)
    """
    logger.info(f"Received model: {model.id} for evaluation on dataset (name: {dataset.name}, id: {dataset.id})")
    ##########################
    # load model and weights #
    ##########################
    logger.info(f"Loading Adapter with: {model.name} ({model.id!r})")
    self.load_from_model(dataset=dataset, model_entity=model)

    ##############
    # Predicting #
    ##############
    logger.info(f"Calling prediction, dataset: {dataset.name!r} ({dataset.id!r}), filters: {filters}")
    if not filters:
        filters = entities.Filters()
    if self.adapter_defaults.get("overwrite_annotations", True) is True:
        self._execute_dataset_operation(
            dataset=dataset,
            operation_type='predict',
            filters=filters,
            multiple_executions=False,
        )

    ##############
    # Evaluating #
    ##############
    logger.info("Starting adapter.evaluate()")
    if progress is not None:
        progress.update(message='calculating metrics', progress=98)
    model = self.evaluate(model=model, dataset=dataset, filters=filters)
    #########
    # Done! #
    #########
    if progress is not None:
        progress.update(message='finishing evaluation', progress=99)
    return model, dataset
944
+
945
+ # =============
946
+ # INNER METHODS
947
+ # =============
948
def _get_feature_set(self):
    """
    Return the feature set attached to this adapter's model, creating one when the model has none.

    The lookup/creation runs under the class-level ``_feature_set_lock``. If the project already
    holds a feature set named after the model, the new one gets a random 5-character suffix.

    :return: the existing or the newly created feature set entity
    """
    with self.__class__._feature_set_lock:
        existing = self.model_entity.feature_set
        if existing is not None:
            logger.info(f'Feature Set found! name: {existing.name}, id: {existing.id}')
            return existing

        logger.info('Feature Set not found. creating... ')
        try:
            self.project.feature_sets.get(feature_set_name=self.model_entity.name)
        except exceptions.NotFound:
            # no name clash: reuse the model name as-is
            new_name = self.model_entity.name
        else:
            suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=5))
            new_name = f"{self.model_entity.name}-{suffix}"
            logger.warning(
                f"Feature set with the model name already exists. Creating new feature set with name {new_name}"
            )

        created = self.project.feature_sets.create(
            name=new_name,
            entity_type=entities.FeatureEntityType.ITEM,
            model_id=self.model_entity.id,
            project_id=self.project.id,
            set_type=self.model_entity.name,
            size=self.configuration.get('embeddings_size', 256),
        )
        logger.info(f'Feature Set created! name: {created.name}, id: {created.id}')
        return created
976
+
977
def _execute_dataset_operation(
    self,
    dataset: entities.Dataset,
    operation_type: str,
    filters: Optional[entities.Filters] = None,
    progress: Optional[utilities.Progress] = None,
    batch_size: Optional[int] = None,
    multiple_executions: bool = True,
) -> bool:
    """
    Execute dataset operation (predict/embed) with batching and filtering support.

    The filtered items are split into subsets of at most configuration['predict_embed_subset_limit'] items.
    A single subset (or ``multiple_executions=False``) is processed locally. Otherwise, one platform
    execution is created per subset and polled every 5 seconds for up to
    configuration['predict_embed_timeout'] seconds. A timeout is logged as a warning; it is not raised.

    :param dataset: Dataset entity to run operation on
    :param operation_type: Type of operation to execute ('predict' or 'embed')
    :param filters: Filters entity (or a raw filter dict) to filter items, default None (all items)
    :param progress: Progress object for tracking progress, default None
    :param batch_size: Size of batches to process items, default None (uses model config)
    :param multiple_executions: Whether to use multiple executions when filters exceed subset limit, default True
    :return: True if operation completes successfully
    :raises ValueError: If operation_type is not 'predict' or 'embed', or if any created execution failed
    """
    # validate before making any platform calls
    if operation_type not in ('predict', 'embed'):
        raise ValueError(f"Unsupported operation type: {operation_type}")

    self.logger.debug(f"Running {operation_type} for dataset (name:{dataset.name}, id:{dataset.id})")

    if not filters:
        self.logger.debug("No filters provided, using default filters")
        filters = entities.Filters()
    if isinstance(filters, dict):
        self.logger.debug(f"Received custom filters {filters}")
        filters = entities.Filters(custom_filter=filters)

    if operation_type == 'embed':
        # resolve (and create if needed) the feature set up front
        feature_set = self.feature_set
        logger.info(f"Feature set found! name: {feature_set.name}, id: {feature_set.id}")

    predict_embed_subset_limit = self.configuration.get('predict_embed_subset_limit', PREDICT_EMBED_DEFAULT_SUBSET_LIMIT)
    predict_embed_timeout = self.configuration.get('predict_embed_timeout', PREDICT_EMBED_DEFAULT_TIMEOUT)
    self.logger.debug(f"Inputs: predict_embed_subset_limit: {predict_embed_subset_limit}, predict_embed_timeout: {predict_embed_timeout}")
    tmp_filters = copy.deepcopy(filters.prepare())
    tmp_filters['pageSize'] = 0
    num_items = dataset.items.list(filters=entities.Filters(custom_filter=tmp_filters)).items_count
    self.logger.debug(f"Number of items for current filters: {num_items}")

    # One-item lookahead on generator: if only one subset, run locally; else create executions for all
    gen = entities.Filters._get_split_filters(dataset, filters, predict_embed_subset_limit)
    try:
        first_filter = next(gen)
    except StopIteration:
        self.logger.info("Filters is empty, nothing to run")
        return True

    try:
        second_filter = next(gen)
        multiple = True
    except StopIteration:
        multiple = False

    # Re-attach the pre-consumed filters in front of whatever the generator still holds
    if multiple:
        all_filters = chain([first_filter, second_filter], gen)
    else:
        all_filters = iter([first_filter])

    if not multiple or not multiple_executions:
        self.logger.info("Running locally")
        if batch_size is None:
            batch_size = self.configuration.get('batch_size', 4)
        run_items = self.embed_items if operation_type == 'embed' else self.predict_items

        # Process each filter locally
        for filter_dict in all_filters:
            filter_dict["pageSize"] = 1000
            single_filters = entities.Filters(custom_filter=filter_dict)
            pages = dataset.items.list(filters=single_filters)
            self.logger.info(f"Processing filter on: {pages.items_count} items")
            items = [item for page in pages for item in page if item.type == 'file']
            self.logger.debug(f"Items length: {len(items)}")
            run_items(items=items, batch_size=batch_size, progress=progress)
        return True

    executions = []
    for filter_dict in all_filters:
        self.logger.debug(f"Creating execution for models {operation_type} with dataset id {dataset.id} and filter_dict {filter_dict}")
        subset_filters = entities.Filters(custom_filter=filter_dict)
        if operation_type == 'embed':
            execution = self.model_entity.models.embed(
                model=self.model_entity,
                dataset_id=dataset.id,
                filters=subset_filters,
            )
        else:
            execution = self.model_entity.models.predict(
                model=self.model_entity, dataset_id=dataset.id, filters=subset_filters
            )
        executions.append(execution)

    if executions:
        self.logger.info(f'Created {len(executions)} executions for {operation_type}, ' f'execution ids: {[ex.id for ex in executions]}')

    wait_time = 5
    # monotonic clock: immune to wall-clock adjustments while waiting
    start_time = time.monotonic()
    last_perc = 0
    timed_out = True
    self.logger.debug(f"Waiting for executions with timeout {predict_embed_timeout}")
    while time.monotonic() - start_time < predict_embed_timeout:
        continue_loop = False
        total_perc = 0

        for ex in executions:
            execution = self.project.executions.get(execution_id=ex.id)
            total_perc += execution.latest_status.get('percentComplete', 0)
            if execution.in_progress():
                continue_loop = True

        avg_perc = round(total_perc / len(executions), 0)
        if progress is not None and last_perc != avg_perc:
            last_perc = avg_perc
            progress.update(progress=last_perc, message=f'running {operation_type}')

        if not continue_loop:
            timed_out = False
            break

        time.sleep(wait_time)
    self.logger.debug("End waiting for executions")
    if timed_out:
        # executions may still be running; the failure check below cannot see their final status
        self.logger.warning(
            f"Timed out after {predict_embed_timeout} seconds waiting for {operation_type} executions; "
            f"some may still be running: {[ex.id for ex in executions]}"
        )
    # Check if any execution failed
    executions_filter = entities.Filters(resource=entities.FiltersResource.EXECUTION)
    executions_filter.add(field="id", values=[ex.id for ex in executions], operator=entities.FiltersOperations.IN)
    executions_filter.add(field='latestStatus.status', values='failed')
    executions_filter.page_size = 0
    failed_executions_count = self.project.executions.list(filters=executions_filter).items_count
    if failed_executions_count > 0:
        self.logger.error(f"Failed to {operation_type} for {failed_executions_count} executions")
        raise ValueError(f"Failed to {operation_type} entire dataset, please check the logs for more details")
    return True
1119
+
1120
def _upload_model_annotations(self, item: entities.Item, predictions):
    """
    Upload an item's predictions, replacing the earlier predictions of this model.

    Before uploading, annotations on the item whose ``metadata.user.model.name`` or
    ``metadata.system.model.name`` equals this model's name are deleted.

    :param item: `dl.Item` the predictions belong to
    :param predictions: `dl.AnnotationCollection` or `list` of annotations to upload
    :return: the result of `item.annotations.upload`
    :raises TypeError: if `predictions` is neither an AnnotationCollection nor a list
    """
    if not isinstance(predictions, (entities.AnnotationCollection, list)):
        raise TypeError(
            f'predictions was expected to be of type {entities.AnnotationCollection} or list, but instead it is {type(predictions)}'
        )
    clean_filter = entities.Filters(resource=entities.FiltersResource.ANNOTATION)
    clean_filter.add(field='metadata.user.model.name', values=self.model_entity.name, method=entities.FiltersMethod.OR)
    clean_filter.add(field='metadata.system.model.name', values=self.model_entity.name, method=entities.FiltersMethod.OR)
    item.annotations.delete(filters=clean_filter)
    annotations = item.annotations.upload(annotations=predictions)
    return annotations
1135
+
1136
@staticmethod
def _item_to_image(item):
    """
    Preprocess items before calling the `predict` functions.
    Download the item into memory and decode it into a numpy array.

    :param item: `dl.Item` pointing to an image file
    :return: `np.ndarray` of the decoded image, or None if download/decoding failed (the error is logged)
    """
    try:
        buffer = item.download(save_locally=False)
        # close the PIL image once its pixels have been copied into the array
        with Image.open(buffer) as img:
            image = np.asarray(img)
    except Exception as e:
        logger.error(f"Failed to convert image to np.array, Error: {e}\n{traceback.format_exc()}")
        image = None
    return image
1152
+
1153
@staticmethod
def _item_to_item(item):
    """
    Identity item-preparation step: hand the item to `predict` unchanged.

    This is the default item-to-batch function, used when no conversion (such as decoding an
    image into a numpy array) is required.

    :param item: the item to prepare
    :return: the very same object that was passed in
    """
    prepared = item
    return prepared
1162
+
1163
@staticmethod
def _item_to_text(item):
    """
    Preprocess a text item before calling `predict`/`embed`.

    Items with mimetype 'text/plain' or 'text/markdown' are downloaded and read as UTF-8. Newlines are
    replaced by spaces, and the downloaded file is removed afterwards, even if reading fails. For any other
    mimetype a warning is logged and the item itself is returned, without downloading it.

    :param item: `dl.Item`
    :return: `str` content of the file, or the item itself for non-text mimetypes
    """
    if item.mimetype not in ('text/plain', 'text/markdown'):
        logger.warning('Item is not text file. mimetype: {}'.format(item.mimetype))
        return item
    filename = item.download(overwrite=True)
    try:
        # explicit encoding: the platform default (locale) encoding differs between OSes
        with open(filename, 'r', encoding='utf-8') as f:
            text = f.read()
    finally:
        if filename and os.path.exists(filename):
            os.remove(filename)
    return text.replace('\n', ' ')
1177
+
1178
@staticmethod
def _uri_to_image(data_uri):
    """
    Decode a base64 data URI into an image array.

    Example input (truncated): ``"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAA..."``

    :param data_uri: `str` data URI; the base64 payload is everything after the first comma
    :return: `np.ndarray` of the decoded image
    """
    image_b64 = data_uri.split(",", 1)[1]
    binary = base64.b64decode(image_b64)
    # close the PIL image once its pixels have been copied into the array
    with Image.open(io.BytesIO(binary)) as img:
        image = np.asarray(img)
    return image
1185
+
1186
def _update_predictions_metadata(self, item: entities.Item, predictions: entities.AnnotationCollection):
    """
    Stamp model information onto every prediction of an item, mutating the predictions in place.

    For each prediction:
      * segmentation annotations get a color. It is taken from the item's dataset ontology first,
        then from the model's dataset ontology (white if the label is missing there). As a last
        resort, a warning is logged and the prediction's own color is kept.
      * ``item_id`` is set to the item's id.
      * an existing ``metadata['user']['model']`` gets ``model_id`` and ``name``.
      * ``metadata['system']['model']`` is replaced with ``model_id``, ``name`` and the
        confidence found under ``metadata['user']['model']`` (None when absent).

    :param item: Entity.Item the predictions belong to
    :param predictions: item's AnnotationCollection
    :return: None
    """
    for prediction in predictions:
        if prediction.type == entities.AnnotationType.SEGMENTATION:
            color = None
            try:
                color = item.dataset._get_ontology().color_map.get(prediction.label, None)
            except (exceptions.BadRequest, exceptions.NotFound):
                pass
            if color is None and self.model_entity._dataset is not None:
                try:
                    color = self.model_entity.dataset._get_ontology().color_map.get(prediction.label, (255, 255, 255))
                except (exceptions.BadRequest, exceptions.NotFound):
                    pass
            if color is None:
                logger.warning("Can't get annotation color from model's dataset, using default.")
                color = prediction.color
            prediction.color = color

        prediction.item_id = item.id
        metadata = prediction.metadata
        if 'user' in metadata and 'model' in metadata['user']:
            user_model = metadata['user']['model']
            user_model['model_id'] = self.model_entity.id
            user_model['name'] = self.model_entity.name
        system_metadata = metadata.setdefault('system', dict())
        confidence = metadata.get('user', dict()).get('model', dict()).get('confidence', None)
        system_metadata['model'] = {
            'model_id': self.model_entity.id,
            'name': self.model_entity.name,
            'confidence': confidence,
        }
1228
+
1229
+ ##############################
1230
+ # Callback Factory functions #
1231
+ ##############################
1232
@property
def dataloop_keras_callback(self):
    """
    Returns the class of a keras callback that dumps the training history to a JSON file.

    The callback is used by the Dataloop platform to show train losses / metrics per epoch.

    :return: DumpHistoryCallback class. Instantiate it with ``dump_path``: either a file path, or an
             existing directory in which a timestamped ``__view__training-history__*.json`` file is created.
    :raises RuntimeError: if the ``keras`` package cannot be imported
    """
    try:
        import keras
    except ImportError as err:  # ModuleNotFoundError is a subclass of ImportError
        raise RuntimeError(f'{self.__class__.__name__} depends on external package `keras`. '
                           f'Please install it (e.g. `pip install keras`)') from err

    import os
    import time
    import json

    class DumpHistoryCallback(keras.callbacks.Callback):
        """Keras callback that rewrites a JSON metric plot of every logged value at each epoch end."""

        def __init__(self, dump_path):
            super().__init__()
            if os.path.isdir(dump_path):
                # Explicit, locale-independent timestamp without ':' (':' is invalid in Windows file names).
                timestamp = time.strftime("%Y-%m-%d-%H-%M-%S")
                dump_path = os.path.join(dump_path, f'__view__training-history__{timestamp}.json')
            self.dump_file = dump_path
            # metric name -> {'x': [epoch, ...], 'y': [value, ...]}
            self.data = dict()

        def on_epoch_end(self, epoch, logs=None):
            logs = logs or {}
            for name, val in logs.items():
                if name not in self.data:
                    self.data[name] = {'x': list(), 'y': list()}
                self.data[name]['x'].append(float(epoch))
                self.data[name]['y'].append(float(val))
            # the whole history is rewritten every epoch so the file is always complete
            self.dump_history()

        def dump_history(self):
            _json = {
                "query": {},
                "datasetId": "",
                "xlabel": "epoch",
                "title": "training loss",
                "ylabel": "val",
                "type": "metric",
                "data": [
                    {
                        "name": name,
                        "x": values['x'],
                        "y": values['y'],
                    }
                    for name, values in self.data.items()
                ],
            }

            with open(self.dump_file, 'w') as f:
                json.dump(_json, f, indent=2)

    return DumpHistoryCallback