dtlpy 1.114.17__py3-none-any.whl → 1.116.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. dtlpy/__init__.py +491 -491
  2. dtlpy/__version__.py +1 -1
  3. dtlpy/assets/__init__.py +26 -26
  4. dtlpy/assets/code_server/config.yaml +2 -2
  5. dtlpy/assets/code_server/installation.sh +24 -24
  6. dtlpy/assets/code_server/launch.json +13 -13
  7. dtlpy/assets/code_server/settings.json +2 -2
  8. dtlpy/assets/main.py +53 -53
  9. dtlpy/assets/main_partial.py +18 -18
  10. dtlpy/assets/mock.json +11 -11
  11. dtlpy/assets/model_adapter.py +83 -83
  12. dtlpy/assets/package.json +61 -61
  13. dtlpy/assets/package_catalog.json +29 -29
  14. dtlpy/assets/package_gitignore +307 -307
  15. dtlpy/assets/service_runners/__init__.py +33 -33
  16. dtlpy/assets/service_runners/converter.py +96 -96
  17. dtlpy/assets/service_runners/multi_method.py +49 -49
  18. dtlpy/assets/service_runners/multi_method_annotation.py +54 -54
  19. dtlpy/assets/service_runners/multi_method_dataset.py +55 -55
  20. dtlpy/assets/service_runners/multi_method_item.py +52 -52
  21. dtlpy/assets/service_runners/multi_method_json.py +52 -52
  22. dtlpy/assets/service_runners/single_method.py +37 -37
  23. dtlpy/assets/service_runners/single_method_annotation.py +43 -43
  24. dtlpy/assets/service_runners/single_method_dataset.py +43 -43
  25. dtlpy/assets/service_runners/single_method_item.py +41 -41
  26. dtlpy/assets/service_runners/single_method_json.py +42 -42
  27. dtlpy/assets/service_runners/single_method_multi_input.py +45 -45
  28. dtlpy/assets/voc_annotation_template.xml +23 -23
  29. dtlpy/caches/base_cache.py +32 -32
  30. dtlpy/caches/cache.py +473 -473
  31. dtlpy/caches/dl_cache.py +201 -201
  32. dtlpy/caches/filesystem_cache.py +89 -89
  33. dtlpy/caches/redis_cache.py +84 -84
  34. dtlpy/dlp/__init__.py +20 -20
  35. dtlpy/dlp/cli_utilities.py +367 -367
  36. dtlpy/dlp/command_executor.py +764 -764
  37. dtlpy/dlp/dlp +1 -1
  38. dtlpy/dlp/dlp.bat +1 -1
  39. dtlpy/dlp/dlp.py +128 -128
  40. dtlpy/dlp/parser.py +651 -651
  41. dtlpy/entities/__init__.py +83 -83
  42. dtlpy/entities/analytic.py +347 -311
  43. dtlpy/entities/annotation.py +1879 -1879
  44. dtlpy/entities/annotation_collection.py +699 -699
  45. dtlpy/entities/annotation_definitions/__init__.py +20 -20
  46. dtlpy/entities/annotation_definitions/base_annotation_definition.py +100 -100
  47. dtlpy/entities/annotation_definitions/box.py +195 -195
  48. dtlpy/entities/annotation_definitions/classification.py +67 -67
  49. dtlpy/entities/annotation_definitions/comparison.py +72 -72
  50. dtlpy/entities/annotation_definitions/cube.py +204 -204
  51. dtlpy/entities/annotation_definitions/cube_3d.py +149 -149
  52. dtlpy/entities/annotation_definitions/description.py +32 -32
  53. dtlpy/entities/annotation_definitions/ellipse.py +124 -124
  54. dtlpy/entities/annotation_definitions/free_text.py +62 -62
  55. dtlpy/entities/annotation_definitions/gis.py +69 -69
  56. dtlpy/entities/annotation_definitions/note.py +139 -139
  57. dtlpy/entities/annotation_definitions/point.py +117 -117
  58. dtlpy/entities/annotation_definitions/polygon.py +182 -182
  59. dtlpy/entities/annotation_definitions/polyline.py +111 -111
  60. dtlpy/entities/annotation_definitions/pose.py +92 -92
  61. dtlpy/entities/annotation_definitions/ref_image.py +86 -86
  62. dtlpy/entities/annotation_definitions/segmentation.py +240 -240
  63. dtlpy/entities/annotation_definitions/subtitle.py +34 -34
  64. dtlpy/entities/annotation_definitions/text.py +85 -85
  65. dtlpy/entities/annotation_definitions/undefined_annotation.py +74 -74
  66. dtlpy/entities/app.py +220 -220
  67. dtlpy/entities/app_module.py +107 -107
  68. dtlpy/entities/artifact.py +174 -174
  69. dtlpy/entities/assignment.py +399 -399
  70. dtlpy/entities/base_entity.py +214 -214
  71. dtlpy/entities/bot.py +113 -113
  72. dtlpy/entities/codebase.py +292 -296
  73. dtlpy/entities/collection.py +38 -38
  74. dtlpy/entities/command.py +169 -169
  75. dtlpy/entities/compute.py +449 -442
  76. dtlpy/entities/dataset.py +1299 -1285
  77. dtlpy/entities/directory_tree.py +44 -44
  78. dtlpy/entities/dpk.py +470 -470
  79. dtlpy/entities/driver.py +235 -223
  80. dtlpy/entities/execution.py +397 -397
  81. dtlpy/entities/feature.py +124 -124
  82. dtlpy/entities/feature_set.py +145 -145
  83. dtlpy/entities/filters.py +798 -645
  84. dtlpy/entities/gis_item.py +107 -107
  85. dtlpy/entities/integration.py +184 -184
  86. dtlpy/entities/item.py +959 -953
  87. dtlpy/entities/label.py +123 -123
  88. dtlpy/entities/links.py +85 -85
  89. dtlpy/entities/message.py +175 -175
  90. dtlpy/entities/model.py +684 -684
  91. dtlpy/entities/node.py +1005 -1005
  92. dtlpy/entities/ontology.py +810 -803
  93. dtlpy/entities/organization.py +287 -287
  94. dtlpy/entities/package.py +657 -657
  95. dtlpy/entities/package_defaults.py +5 -5
  96. dtlpy/entities/package_function.py +185 -185
  97. dtlpy/entities/package_module.py +113 -113
  98. dtlpy/entities/package_slot.py +118 -118
  99. dtlpy/entities/paged_entities.py +299 -299
  100. dtlpy/entities/pipeline.py +624 -624
  101. dtlpy/entities/pipeline_execution.py +279 -279
  102. dtlpy/entities/project.py +394 -394
  103. dtlpy/entities/prompt_item.py +505 -499
  104. dtlpy/entities/recipe.py +301 -301
  105. dtlpy/entities/reflect_dict.py +102 -102
  106. dtlpy/entities/resource_execution.py +138 -138
  107. dtlpy/entities/service.py +963 -958
  108. dtlpy/entities/service_driver.py +117 -117
  109. dtlpy/entities/setting.py +294 -294
  110. dtlpy/entities/task.py +495 -495
  111. dtlpy/entities/time_series.py +143 -143
  112. dtlpy/entities/trigger.py +426 -426
  113. dtlpy/entities/user.py +118 -118
  114. dtlpy/entities/webhook.py +124 -124
  115. dtlpy/examples/__init__.py +19 -19
  116. dtlpy/examples/add_labels.py +135 -135
  117. dtlpy/examples/add_metadata_to_item.py +21 -21
  118. dtlpy/examples/annotate_items_using_model.py +65 -65
  119. dtlpy/examples/annotate_video_using_model_and_tracker.py +75 -75
  120. dtlpy/examples/annotations_convert_to_voc.py +9 -9
  121. dtlpy/examples/annotations_convert_to_yolo.py +9 -9
  122. dtlpy/examples/convert_annotation_types.py +51 -51
  123. dtlpy/examples/converter.py +143 -143
  124. dtlpy/examples/copy_annotations.py +22 -22
  125. dtlpy/examples/copy_folder.py +31 -31
  126. dtlpy/examples/create_annotations.py +51 -51
  127. dtlpy/examples/create_video_annotations.py +83 -83
  128. dtlpy/examples/delete_annotations.py +26 -26
  129. dtlpy/examples/filters.py +113 -113
  130. dtlpy/examples/move_item.py +23 -23
  131. dtlpy/examples/play_video_annotation.py +13 -13
  132. dtlpy/examples/show_item_and_mask.py +53 -53
  133. dtlpy/examples/triggers.py +49 -49
  134. dtlpy/examples/upload_batch_of_items.py +20 -20
  135. dtlpy/examples/upload_items_and_custom_format_annotations.py +55 -55
  136. dtlpy/examples/upload_items_with_modalities.py +43 -43
  137. dtlpy/examples/upload_segmentation_annotations_from_mask_image.py +44 -44
  138. dtlpy/examples/upload_yolo_format_annotations.py +70 -70
  139. dtlpy/exceptions.py +125 -125
  140. dtlpy/miscellaneous/__init__.py +20 -20
  141. dtlpy/miscellaneous/dict_differ.py +95 -95
  142. dtlpy/miscellaneous/git_utils.py +217 -217
  143. dtlpy/miscellaneous/json_utils.py +14 -14
  144. dtlpy/miscellaneous/list_print.py +105 -105
  145. dtlpy/miscellaneous/zipping.py +130 -130
  146. dtlpy/ml/__init__.py +20 -20
  147. dtlpy/ml/base_feature_extractor_adapter.py +27 -27
  148. dtlpy/ml/base_model_adapter.py +1257 -1086
  149. dtlpy/ml/metrics.py +461 -461
  150. dtlpy/ml/predictions_utils.py +274 -274
  151. dtlpy/ml/summary_writer.py +57 -57
  152. dtlpy/ml/train_utils.py +60 -60
  153. dtlpy/new_instance.py +252 -252
  154. dtlpy/repositories/__init__.py +56 -56
  155. dtlpy/repositories/analytics.py +85 -85
  156. dtlpy/repositories/annotations.py +916 -916
  157. dtlpy/repositories/apps.py +383 -383
  158. dtlpy/repositories/artifacts.py +452 -452
  159. dtlpy/repositories/assignments.py +599 -599
  160. dtlpy/repositories/bots.py +213 -213
  161. dtlpy/repositories/codebases.py +559 -559
  162. dtlpy/repositories/collections.py +332 -332
  163. dtlpy/repositories/commands.py +152 -158
  164. dtlpy/repositories/compositions.py +61 -61
  165. dtlpy/repositories/computes.py +439 -435
  166. dtlpy/repositories/datasets.py +1504 -1291
  167. dtlpy/repositories/downloader.py +976 -903
  168. dtlpy/repositories/dpks.py +433 -433
  169. dtlpy/repositories/drivers.py +482 -470
  170. dtlpy/repositories/executions.py +815 -817
  171. dtlpy/repositories/feature_sets.py +226 -226
  172. dtlpy/repositories/features.py +255 -238
  173. dtlpy/repositories/integrations.py +484 -484
  174. dtlpy/repositories/items.py +912 -909
  175. dtlpy/repositories/messages.py +94 -94
  176. dtlpy/repositories/models.py +1000 -988
  177. dtlpy/repositories/nodes.py +80 -80
  178. dtlpy/repositories/ontologies.py +511 -511
  179. dtlpy/repositories/organizations.py +525 -525
  180. dtlpy/repositories/packages.py +1941 -1941
  181. dtlpy/repositories/pipeline_executions.py +451 -451
  182. dtlpy/repositories/pipelines.py +640 -640
  183. dtlpy/repositories/projects.py +539 -539
  184. dtlpy/repositories/recipes.py +419 -399
  185. dtlpy/repositories/resource_executions.py +137 -137
  186. dtlpy/repositories/schema.py +120 -120
  187. dtlpy/repositories/service_drivers.py +213 -213
  188. dtlpy/repositories/services.py +1704 -1704
  189. dtlpy/repositories/settings.py +339 -339
  190. dtlpy/repositories/tasks.py +1477 -1477
  191. dtlpy/repositories/times_series.py +278 -278
  192. dtlpy/repositories/triggers.py +536 -536
  193. dtlpy/repositories/upload_element.py +257 -257
  194. dtlpy/repositories/uploader.py +661 -651
  195. dtlpy/repositories/webhooks.py +249 -249
  196. dtlpy/services/__init__.py +22 -22
  197. dtlpy/services/aihttp_retry.py +131 -131
  198. dtlpy/services/api_client.py +1785 -1782
  199. dtlpy/services/api_reference.py +40 -40
  200. dtlpy/services/async_utils.py +133 -133
  201. dtlpy/services/calls_counter.py +44 -44
  202. dtlpy/services/check_sdk.py +68 -68
  203. dtlpy/services/cookie.py +115 -115
  204. dtlpy/services/create_logger.py +156 -156
  205. dtlpy/services/events.py +84 -84
  206. dtlpy/services/logins.py +235 -235
  207. dtlpy/services/reporter.py +256 -256
  208. dtlpy/services/service_defaults.py +91 -91
  209. dtlpy/utilities/__init__.py +20 -20
  210. dtlpy/utilities/annotations/__init__.py +16 -16
  211. dtlpy/utilities/annotations/annotation_converters.py +269 -269
  212. dtlpy/utilities/base_package_runner.py +285 -264
  213. dtlpy/utilities/converter.py +1650 -1650
  214. dtlpy/utilities/dataset_generators/__init__.py +1 -1
  215. dtlpy/utilities/dataset_generators/dataset_generator.py +670 -670
  216. dtlpy/utilities/dataset_generators/dataset_generator_tensorflow.py +23 -23
  217. dtlpy/utilities/dataset_generators/dataset_generator_torch.py +21 -21
  218. dtlpy/utilities/local_development/__init__.py +1 -1
  219. dtlpy/utilities/local_development/local_session.py +179 -179
  220. dtlpy/utilities/reports/__init__.py +2 -2
  221. dtlpy/utilities/reports/figures.py +343 -343
  222. dtlpy/utilities/reports/report.py +71 -71
  223. dtlpy/utilities/videos/__init__.py +17 -17
  224. dtlpy/utilities/videos/video_player.py +598 -598
  225. dtlpy/utilities/videos/videos.py +470 -470
  226. {dtlpy-1.114.17.data → dtlpy-1.116.6.data}/scripts/dlp +1 -1
  227. dtlpy-1.116.6.data/scripts/dlp.bat +2 -0
  228. {dtlpy-1.114.17.data → dtlpy-1.116.6.data}/scripts/dlp.py +128 -128
  229. {dtlpy-1.114.17.dist-info → dtlpy-1.116.6.dist-info}/METADATA +186 -183
  230. dtlpy-1.116.6.dist-info/RECORD +239 -0
  231. {dtlpy-1.114.17.dist-info → dtlpy-1.116.6.dist-info}/WHEEL +1 -1
  232. {dtlpy-1.114.17.dist-info → dtlpy-1.116.6.dist-info}/licenses/LICENSE +200 -200
  233. tests/features/environment.py +551 -551
  234. dtlpy/assets/__pycache__/__init__.cpython-310.pyc +0 -0
  235. dtlpy-1.114.17.data/scripts/dlp.bat +0 -2
  236. dtlpy-1.114.17.dist-info/RECORD +0 -240
  237. {dtlpy-1.114.17.dist-info → dtlpy-1.116.6.dist-info}/entry_points.txt +0 -0
  238. {dtlpy-1.114.17.dist-info → dtlpy-1.116.6.dist-info}/top_level.txt +0 -0
@@ -1,1086 +1,1257 @@
1
- import dataclasses
2
- import tempfile
3
- import datetime
4
- import logging
5
- import string
6
- import shutil
7
- import random
8
- import base64
9
- import tqdm
10
- import sys
11
- import io
12
- import os
13
- from PIL import Image
14
- from functools import partial
15
- import numpy as np
16
- from concurrent.futures import ThreadPoolExecutor
17
- import attr
18
- from collections.abc import MutableMapping
19
- from .. import entities, utilities, repositories, exceptions
20
- from ..services import service_defaults
21
- from ..services.api_client import ApiClient
22
-
23
# Named module logger; BaseModelAdapter.__init__ exposes it as ``self.logger``.
logger = logging.getLogger(name='ModelAdapter')
24
-
25
-
26
class ModelConfigurations(MutableMapping):
    """
    Dict-like view over a model entity's ``configuration`` dict.

    Uses MutableMapping over a backing dict (composition) instead of inheriting from ``dict``.
    Inheriting would create a second dict next to ``model_entity.configuration``, and the two
    could drift apart. When the adapter has a model entity with a configuration, the backing
    dict *is* that configuration object, so writes land on the entity directly.

    Persistence: ``__setitem__``, ``update`` / ``|=``, and ``get`` with a missing key call
    ``model_entity.update(reload_services=False)`` when (and only when) they change a value.
    ``__delitem__`` and ``setdefault`` change the backing dict without pushing an update.

    :param base_model_adapter: object exposing ``model_entity`` (normally a BaseModelAdapter),
        or None for a detached, empty configuration that is never persisted.
    """

    def __init__(self, base_model_adapter):
        self._backing_dict = {}
        # NOTE: BaseModelAdapter.model_entity raises ValueError when no model is loaded, so with
        # an adapter this constructor expects a loaded model.
        if (
            base_model_adapter is not None
            and base_model_adapter.model_entity is not None
            and base_model_adapter.model_entity.configuration is not None
        ):
            # Share (do not copy) the entity's configuration dict.
            self._backing_dict = base_model_adapter.model_entity.configuration
            if 'include_background' not in self._backing_dict:
                self._backing_dict['include_background'] = False
        self._base_model_adapter = base_model_adapter
        # Don't call _update_model_entity during initialization to avoid premature updates

    def _update_model_entity(self):
        """Push the configuration to the platform when a model entity is attached."""
        if self._base_model_adapter is not None and self._base_model_adapter.model_entity is not None:
            self._base_model_adapter.model_entity.update(reload_services=False)

    def __ior__(self, other):
        self.update(other)
        return self

    # Required MutableMapping abstract methods
    def __getitem__(self, key):
        return self._backing_dict[key]

    def __setitem__(self, key, value):
        # Note: This method only updates the backing dict, not object attributes.
        # If you need to also update object attributes, be careful to avoid
        # infinite recursion by not calling __setattr__ from here.
        update = False
        if key not in self._backing_dict or self._backing_dict.get(key) != value:
            update = True
        self._backing_dict[key] = value
        if update:
            self._update_model_entity()

    def __delitem__(self, key):
        # Not persisted immediately: the next persisted write pushes the whole shared dict.
        del self._backing_dict[key]

    def __iter__(self):
        return iter(self._backing_dict)

    def __len__(self):
        return len(self._backing_dict)

    def get(self, key, default=None):
        """
        Return the value for ``key``. Unlike ``dict.get``, a missing key is first stored with
        ``default`` through ``__setitem__``, so it is also persisted to the model entity.
        """
        if key not in self._backing_dict:
            self.__setitem__(key, default)
        return self._backing_dict.get(key)

    def update(self, *args, **kwargs):
        """
        ``dict.update`` semantics. Pushes the model entity once, and only if a value changed.
        """
        # Materialize once: ``args`` may hold a one-shot iterator of (key, value) pairs. Passing
        # it to the backing dict again after the change check would apply nothing.
        update_dict = dict(*args, **kwargs)
        has_changes = any(
            key not in self._backing_dict or self._backing_dict[key] != value
            for key, value in update_dict.items()
        )
        self._backing_dict.update(update_dict)

        if has_changes:
            self._update_model_entity()

    def setdefault(self, key, default=None):
        # Writes straight to the backing dict: no remote update is pushed here.
        if key not in self._backing_dict:
            self._backing_dict[key] = default
        return self._backing_dict[key]
105
-
106
-
107
@dataclasses.dataclass
class AdapterDefaults(ModelConfigurations):
    """
    Adapter-level defaults exposed both as attributes and as configuration keys.

    Every dataclass field below is mirrored between an instance attribute
    (``defaults.root_path``) and the key of the same name in the backing configuration
    (``defaults['root_path']``). The two views stay in sync under all of these writes:
    attribute assignment, item assignment, ``update`` / ``|=``, and ``resolve``.

    On construction, a field that already has a non-None value in the model configuration
    keeps it. Otherwise the field default is written into the configuration.

    ``@dataclass`` keeps the hand-written ``__init__`` below and contributes
    ``__repr__`` / ``__eq__`` over the fields.
    """
    # for predict items, dataset, evaluate
    upload_annotations: bool = dataclasses.field(default=True)
    clean_annotations: bool = dataclasses.field(default=True)
    # for embeddings
    upload_features: bool = dataclasses.field(default=True)
    # for training
    root_path: str = dataclasses.field(default=None)
    data_path: str = dataclasses.field(default=None)
    output_path: str = dataclasses.field(default=None)

    def __init__(self, base_model_adapter=None):
        super().__init__(base_model_adapter)
        for f in dataclasses.fields(AdapterDefaults):
            # Read the backing dict directly. ModelConfigurations.get() would first insert
            # (and push) a None placeholder for a missing key before the real default.
            current = self._backing_dict.get(f.name)
            if current is not None:
                # Already configured: mirror onto the attribute only.
                super().__setattr__(f.name, current)
            else:
                self[f.name] = f.default

    @staticmethod
    def _field_names():
        """Names of the dataclass fields mirrored between attributes and the dict."""
        return {f.name for f in dataclasses.fields(AdapterDefaults)}

    def __setitem__(self, key, value):
        # Keep the attribute view in sync for dataclass fields, then write (and persist)
        # through the parent mapping implementation.
        if key in self._field_names():
            super().__setattr__(key, value)
        super().__setitem__(key, value)

    def __setattr__(self, key, value):
        # Dataclass-like fields behave as attributes, so map to dict
        super().__setattr__(key, value)
        if not key.startswith("_"):
            super().__setitem__(key, value)

    def update(self, *args, **kwargs):
        # Accept positional mappings / pair iterables as well as kwargs (``|=`` passes a
        # positional mapping), and keep the attributes in sync for all of them.
        update_dict = dict(*args, **kwargs)
        field_names = self._field_names()
        for key, value in update_dict.items():
            if key in field_names:
                super().__setattr__(key, value)
        super().update(update_dict)

    def resolve(self, key, *args):
        """
        Return the first non-None candidate in ``args``, storing it under ``key``.
        If every candidate is None, return the configured value (via ``get``).
        """
        for arg in args:
            if arg is not None:
                self[key] = arg
                return arg
        return self.get(key, None)
147
-
148
-
149
- class BaseModelAdapter(utilities.BaseServiceRunner):
150
- _client_api = attr.ib(type=ApiClient, repr=False)
151
-
152
def __init__(self, model_entity: entities.Model = None):
    """
    Create the adapter. When ``model_entity`` is given, it is loaded right away through
    ``load_from_model``.

    :param model_entity: optional `dl.Model` to load immediately
    """
    self.logger = logger
    # entity-related state, filled in by load_from_model / the property setters
    self._model_entity = None
    self._package = None
    self._base_configuration = {}
    self._configuration = None
    self.adapter_defaults = None
    self.package_name = None
    self.model = None
    self.bucket_path = None
    # input type -> loader used by prepare_item_func
    self.item_to_batch_mapping = dict(text=self._item_to_text, image=self._item_to_image)
    if model_entity is not None:
        self.load_from_model(model_entity=model_entity)
    # emitted on every construction, whether or not a model was passed
    logger.warning(
        "in case of a mismatch between 'model.name' and 'model_info.name' in the model adapter, model_info.name will be updated to align with 'model.name'.")
170
-
171
- ##################
172
- # Configurations #
173
- ##################
174
-
175
@property
def configuration(self) -> dict:
    """
    Active configuration, chosen in priority order:
    - the model-backed configuration, when a model entity is loaded;
    - otherwise the package's ``metadata.system.ml.defaultConfiguration``, when a package is set;
    - otherwise the adapter's base configuration dict.
    """
    if self._model_entity is not None:
        return self._configuration
    if self._package is not None:
        ml_metadata = self.package.metadata.get('system', {}).get('ml', {})
        return ml_metadata.get('defaultConfiguration', {})
    return self._base_configuration

@configuration.setter
def configuration(self, configuration: dict):
    # Replaces the model's configuration, then rebuilds the defaults view over it.
    assert isinstance(configuration, dict)
    model = self._model_entity
    if model is not None:
        # Update configuration with received dict
        model.configuration = configuration
    defaults = AdapterDefaults(self)
    self.adapter_defaults = defaults
    self._configuration = defaults
195
-
196
- ############
197
- # Entities #
198
- ############
199
@property
def model_entity(self):
    """The loaded `dl.Model`. Raises ValueError when none has been loaded or set."""
    model = self._model_entity
    if model is None:
        raise ValueError(
            "No model entity loaded. Please load a model (adapter.load_from_model(<dl.Model>)) or set: 'adapter.model_entity=<dl.Model>'")
    assert isinstance(model, entities.Model)
    return model

@model_entity.setter
def model_entity(self, model_entity):
    # Set the model, sync the package from it, and rebuild the configuration view.
    assert isinstance(model_entity, entities.Model)
    previous = self._model_entity
    replacing = (
        previous is not None
        and isinstance(previous, entities.Model)
        and previous.id != model_entity.id
    )
    if replacing:
        self.logger.warning(
            'Replacing Model from {!r} to {!r}'.format(previous.name, model_entity.name))
    self._model_entity = model_entity
    self.package = model_entity.package
    defaults = AdapterDefaults(self)
    self.adapter_defaults = defaults
    self._configuration = defaults
218
-
219
@property
def package(self):
    """
    The package (or DPK) behind the adapter. It is refreshed from the loaded model on every
    access. Raises ValueError when no package is available.
    """
    if self._model_entity is not None:
        self.package = self.model_entity.package
    pkg = self._package
    if pkg is None:
        raise ValueError('Missing Package entity on adapter. Please set: "adapter.package=package"')
    assert isinstance(pkg, (entities.Package, entities.Dpk))
    return pkg

@package.setter
def package(self, package):
    # Remember both the entity and its name.
    assert isinstance(package, (entities.Package, entities.Dpk))
    self.package_name, self._package = package.name, package
233
-
234
- ###################################
235
- # NEED TO IMPLEMENT THESE METHODS #
236
- ###################################
237
-
238
def load(self, local_path, **kwargs):
    """
    Load the model and populate ``self.model`` with a runnable model.

    Abstract hook: subclasses must override it. ``load_from_model`` calls it after downloading
    the model artifacts to ``local_path``.

    :param local_path: `str` directory path in local FileSystem
    :raises NotImplementedError: always, in this base implementation
    """
    adapter_name = self.__class__.__name__
    message = "Please implement `load` method in {}".format(adapter_name)
    raise NotImplementedError(message)
248
-
249
def save(self, local_path, **kwargs):
    """
    Save the configuration and weights to a local directory.

    Abstract hook: subclasses must override it. ``save_to_model`` calls it to save locally
    before uploading to the model entity.

    :param local_path: `str` directory path in local FileSystem
    :raises NotImplementedError: always, in this base implementation
    """
    adapter_name = self.__class__.__name__
    message = "Please implement `save` method in {}".format(adapter_name)
    raise NotImplementedError(message)
258
- raise NotImplementedError("Please implement `save` method in {}".format(self.__class__.__name__))
259
-
260
def train(self, data_path, output_path, **kwargs):
    """
    Train on the data under ``data_path`` and write the outputs (weights and any other
    training artifacts) to ``output_path``.

    Abstract hook: subclasses must override it.

    :param data_path: `str` local File System path where the data was downloaded and converted
    :param output_path: `str` local File System path for training mid-results (checkpoints, logs...)
    :raises NotImplementedError: always, in this base implementation
    """
    adapter_name = self.__class__.__name__
    message = "Please implement `train` method in {}".format(adapter_name)
    raise NotImplementedError(message)
271
-
272
def predict(self, batch, **kwargs):
    """
    Run inference on a batch of prepared items.

    Abstract hook: subclasses must override it.

    :param batch: output of the `prepare_item_func` func
    :return: `list[dl.AnnotationCollection]`, one collection per item in the batch
    :raises NotImplementedError: always, in this base implementation
    """
    adapter_name = self.__class__.__name__
    message = "Please implement `predict` method in {}".format(adapter_name)
    raise NotImplementedError(message)
281
-
282
def embed(self, batch, **kwargs):
    """
    Extract embeddings for a batch of prepared items.

    Abstract hook: subclasses must override it.

    :param batch: output of the `prepare_item_func` func
    :return: `list[list]`, one feature vector per item in the batch
    :raises NotImplementedError: always, in this base implementation
    """
    adapter_name = self.__class__.__name__
    message = "Please implement `embed` method in {}".format(adapter_name)
    raise NotImplementedError(message)
291
-
292
def evaluate(self, model: entities.Model, dataset: entities.Dataset, filters: entities.Filters) -> entities.Model:
    """
    Score the model's predictions on a dataset that has ground-truth annotations.

    Delegates to ``dtlpymetrics.scoring.create_model_score``, which uploads the scores and
    metrics to the platform.

    :param model: model whose predictions are scored; its ``output_type`` selects the
        annotation types to compare
    :param dataset: dataset the model predicted on and uploaded its annotations to
    :param filters: `dl.Filters`, or a raw custom-filter dict, selecting the items.
        A falsy value means a default `dl.Filters()`.
    :return: the model returned by ``create_model_score``
    """
    import dtlpymetrics

    compare_types = model.output_type
    if not filters:
        filters = entities.Filters()
    if isinstance(filters, dict):
        # a raw filter dict is wrapped as a custom filter
        filters = entities.Filters(custom_filter=filters)
    scored_model = dtlpymetrics.scoring.create_model_score(
        model=model,
        dataset=dataset,
        filters=filters,
        compare_types=compare_types,
    )
    return scored_model
313
-
314
def convert_from_dtlpy(self, data_path, **kwargs):
    """
    Convert data downloaded in the Dataloop layout into the structure the model expects
    (e.g. build an annotation file from the downloaded directory tree).

    Abstract hook: subclasses must override it.

    :param data_path: `str` local File System directory where the data from the Dataloop platform was already downloaded
    :raises NotImplementedError: always, in this base implementation
    """
    adapter_name = self.__class__.__name__
    message = "Please implement `convert_from_dtlpy` method in {}".format(adapter_name)
    raise NotImplementedError(message)
325
-
326
- #################
327
- # DTLPY METHODS #
328
- ################
329
def prepare_item_func(self, item: entities.Item):
    """
    Prepare a Dataloop item before it is batched and passed to ``predict``.

    Override this to load items differently. By default the loader is chosen by the model's
    ``input_type``:
    - for a list of input types, the item's mimetype picks text, image, or the generic loader;
    - otherwise ``self.item_to_batch_mapping`` is used, falling back to ``_item_to_item``.

    :param item: `dl.Item` to load
    :return: the loaded item data (e.g. ndarray for an image, dict for a json file)
    """
    input_type = self.model_entity.input_type
    if isinstance(input_type, list):
        # Guard against a missing (None) mimetype, which would make the ``in`` test raise
        # TypeError. Treat it as matching neither text nor image.
        mimetype = item.mimetype or ''
        if 'text' in input_type and 'text' in mimetype:
            return self._item_to_text(item)
        if 'image' in input_type and 'image' in mimetype:
            return self._item_to_image(item)
        return self._item_to_item(item)

    if input_type in self.item_to_batch_mapping:
        return self.item_to_batch_mapping[input_type](item)

    return self._item_to_item(item)
354
-
355
def __include_model_annotations(self, annotation_filters):
    """
    Exclude model-generated annotations from ``annotation_filters``, unless the model
    configuration sets ``include_model_annotations``.

    An annotation counts as model-generated when ``metadata.system.model.name`` exists on it.

    :param annotation_filters: `dl.Filters` on the annotation resource; modified in place
    :return: the same filters object
    """
    include_model_annotations = self.model_entity.configuration.get("include_model_annotations", False)
    if include_model_annotations is False:
        if annotation_filters.custom_filter is None:
            annotation_filters.add(
                field="metadata.system.model.name", values=False, operator=entities.FiltersOperations.EXISTS
            )
        else:
            # A custom filter need not have a 'filter' section or an '$and' list. Create them
            # instead of raising KeyError. NOTE(review): assumes other keys in the filter
            # section are AND-ed with '$and' (Mongo-style); confirm against the DQL docs.
            exclude_model = {'metadata.system.model.name': {'$exists': False}}
            and_clauses = annotation_filters.custom_filter.setdefault('filter', {}).setdefault('$and', [])
            # Idempotent: re-running on the same filter dict does not stack duplicate clauses.
            if exclude_model not in and_clauses:
                and_clauses.append(exclude_model)
    return annotation_filters
365
-
366
def __download_background_images(self, filters, data_subset_base_path, annotation_options, dataset=None):
    """
    Download the subset's un-annotated ("background") items when ``include_background`` is True.

    :param filters: `dl.Filters` whose ``custom_filter`` holds the subset's item query; not modified
    :param data_subset_base_path: `str` local directory of the subset
    :param annotation_options: annotation format downloaded alongside the items
    :param dataset: `dl.Dataset` to download from; defaults to the model's dataset
    :return: the download result, or an empty list when background download is disabled
    """
    background_list = list()
    if self.configuration.get('include_background', False) is True:
        import copy

        # Query a deep copy. prepare_data builds ``filters`` from the dict stored in the
        # model's metadata.system.subsets, and Filters may keep a reference to it.
        # Appending in place could leak {"annotated": False} into later runs of this subset.
        background_query = copy.deepcopy(filters.custom_filter)
        background_query.setdefault("filter", {}).setdefault("$and", []).append({"annotated": False})
        if dataset is None:
            dataset = self.model_entity.dataset
        background_list = dataset.items.download(
            filters=entities.Filters(custom_filter=background_query),
            local_path=data_subset_base_path,
            annotation_options=annotation_options,
        )
    return background_list

def prepare_data(
        self,
        dataset: entities.Dataset,
        # paths
        root_path=None,
        data_path=None,
        output_path=None,
        #
        overwrite=False,
        **kwargs,
):
    """
    Prepare the dataset locally before training or evaluation.

    Each subset in the model's ``metadata.system.subsets`` is downloaded from ``dataset`` into
    ``<data_path>/<subset>``. Then ``self.convert_from_dtlpy`` is called on ``data_path``.

    :param dataset: dl.Dataset to download the subsets from
    :param root_path: `str` root directory for training. Can be set using self.adapter_defaults.root_path.
        default: <DATALOOP_PATH>/model_data/<model id>_<model name>/<timestamp>
    :param data_path: `str` dataset directory. Can be set using self.adapter_defaults.data_path.
        default: <root_path>/datasets/<model dataset id>
    :param output_path: `str` save everything to this folder. Can be set using self.adapter_defaults.output_path.
        default: <root_path>/output
    :param bool overwrite: download a subset again even if its directory already exists. default is False
    :param kwargs: forwarded to ``convert_from_dtlpy``
    :return: tuple (root_path, data_path, output_path)
    :raises ValueError: if the model has no metadata.system.subsets, or a subset downloads no items
    """
    # define paths
    dataloop_path = service_defaults.DATALOOP_PATH
    root_path = self.adapter_defaults.resolve("root_path", root_path)
    data_path = self.adapter_defaults.resolve("data_path", data_path)
    output_path = self.adapter_defaults.resolve("output_path", output_path)
    if root_path is None:
        now = datetime.datetime.now()
        root_path = os.path.join(dataloop_path,
                                 'model_data',
                                 "{s_id}_{s_n}".format(s_id=self.model_entity.id, s_n=self.model_entity.name),
                                 now.strftime('%Y-%m-%d-%H%M%S'),
                                 )
    if data_path is None:
        data_path = os.path.join(root_path, 'datasets', self.model_entity.dataset.id)
    if output_path is None:
        output_path = os.path.join(root_path, 'output')
    # Create both directories even when the paths came from the caller or the adapter
    # defaults: os.listdir below fails on a missing data_path.
    os.makedirs(data_path, exist_ok=True)
    os.makedirs(output_path, exist_ok=True)

    if len(os.listdir(data_path)) > 0:
        self.logger.warning("Data path directory ({}) is not empty..".format(data_path))

    annotation_options = entities.ViewAnnotationOptions.JSON
    if self.model_entity.output_type in [entities.AnnotationType.SEGMENTATION]:
        annotation_options = entities.ViewAnnotationOptions.INSTANCE

    # Download the subset items
    subsets = self.model_entity.metadata.get("system", {}).get("subsets", None)
    annotations_subsets = self.model_entity.metadata.get("system", {}).get("annotationsSubsets", {})
    if subsets is None:
        raise ValueError("Model (id: {}) must have subsets in metadata.system.subsets".format(self.model_entity.id))
    for subset, filters_dict in subsets.items():
        data_subset_base_path = os.path.join(data_path, subset)
        if os.path.isdir(data_subset_base_path) and not overwrite:
            # existing and dont overwrite
            self.logger.debug("Subset {!r} already exists (and overwrite=False). Skipping.".format(subset))
            continue

        filters = entities.Filters(custom_filter=filters_dict)
        self.logger.debug("Downloading subset {!r} of {}".format(subset, self.model_entity.dataset.name))

        annotation_filters = None
        if subset in annotations_subsets:
            annotation_filters = entities.Filters(
                use_defaults=False,
                resource=entities.FiltersResource.ANNOTATION,
                custom_filter=annotations_subsets[subset]
            )
        # if user provided annotation_filters, skip the default filters
        elif self.model_entity.output_type is not None and self.model_entity.output_type != "embedding":
            annotation_filters = entities.Filters(resource=entities.FiltersResource.ANNOTATION, use_defaults=False)
            if self.model_entity.output_type in [
                entities.AnnotationType.SEGMENTATION,
                entities.AnnotationType.POLYGON,
            ]:
                model_output_types = [entities.AnnotationType.SEGMENTATION, entities.AnnotationType.POLYGON]
            else:
                model_output_types = [self.model_entity.output_type]

            annotation_filters.add(
                field=entities.FiltersKnownFields.TYPE,
                values=model_output_types,
                operator=entities.FiltersOperations.IN,
            )

            # NOTE(review): indentation was lost in the source listing. These two lines are
            # placed in this branch, following the "skip the default filters" comment above
            # and the None check on annotation_filters further down; confirm against upstream.
            annotation_filters = self.__include_model_annotations(annotation_filters)
            # Store the prepared default filters for this subset (in the model's
            # metadata.system.annotationsSubsets dict when that key exists).
            annotations_subsets[subset] = annotation_filters.prepare()

        ret_list = dataset.items.download(
            filters=filters,
            local_path=data_subset_base_path,
            annotation_options=annotation_options,
            annotation_filters=annotation_filters,
        )
        filters = entities.Filters(custom_filter=subsets[subset])
        # Background items come from the same dataset as the annotated ones.
        background_ret_list = self.__download_background_images(
            filters=filters,
            data_subset_base_path=data_subset_base_path,
            annotation_options=annotation_options,
            dataset=dataset,
        )
        ret_list = list(ret_list)
        background_ret_list = list(background_ret_list)
        self.logger.debug(f"Subset '{subset}': ret_list length: {len(ret_list)}, background_ret_list length: {len(background_ret_list)}")
        # Combine both download results into a single list
        ret_list = ret_list + background_ret_list
        if not ret_list:
            if annotation_filters is not None:
                annotation_filters_str = annotation_filters.prepare()
            else:
                annotation_filters_str = None
            raise ValueError(
                f"No items downloaded for subset {subset}! Cannot train model with empty subset.\n"
                f"Subset {subset} filters: {filters.prepare()}\n"
                f"Annotation filters: {annotation_filters_str}"
            )

    self.convert_from_dtlpy(data_path=data_path, **kwargs)
    return root_path, data_path, output_path
495
-
496
- def load_from_model(self, model_entity=None, local_path=None, overwrite=True, **kwargs):
497
- """ Loads a model from given `dl.Model`.
498
- Reads configurations and instantiate self.model_entity
499
- Downloads the model_entity bucket (if available)
500
-
501
- :param model_entity: `str` dl.Model entity
502
- :param local_path: `str` directory path in local FileSystem to download the model_entity to
503
- :param overwrite: `bool` (default False) if False does not download files with same name else (True) download all
504
- """
505
- if model_entity is not None:
506
- self.model_entity = model_entity
507
- if local_path is None:
508
- local_path = os.path.join(service_defaults.DATALOOP_PATH, "models", self.model_entity.name)
509
- # Load configuration and adapter defaults
510
- self.adapter_defaults = AdapterDefaults(self)
511
- # Point _configuration to the same object since AdapterDefaults inherits from ModelConfigurations
512
- self._configuration = self.adapter_defaults
513
- # Download
514
- self.model_entity.artifacts.download(
515
- local_path=local_path,
516
- overwrite=overwrite
517
- )
518
- self.load(local_path, **kwargs)
519
-
520
- def save_to_model(self, local_path=None, cleanup=False, replace=True, **kwargs):
521
- """
522
- Saves the model state to a new bucket and configuration
523
-
524
- Saves configuration and weights to artifacts
525
- Mark the model as `trained`
526
- loads only applies for remote buckets
527
-
528
- :param local_path: `str` directory path in local FileSystem to save the current model bucket (weights) (default will create a temp dir)
529
- :param replace: `bool` will clean the bucket's content before uploading new files
530
- :param cleanup: `bool` if True (default) remove the data from local FileSystem after upload
531
- :return:
532
- """
533
-
534
- if local_path is None:
535
- local_path = tempfile.mkdtemp(prefix="model_{}".format(self.model_entity.name))
536
- self.logger.debug("Using temporary dir at {}".format(local_path))
537
-
538
- self.save(local_path=local_path, **kwargs)
539
-
540
- if self.model_entity is None:
541
- raise ValueError('Missing model entity on the adapter. '
542
- 'Please set before saving: "adapter.model_entity=model"')
543
-
544
- self.model_entity.artifacts.upload(filepath=os.path.join(local_path, '*'),
545
- overwrite=True)
546
- if cleanup:
547
- shutil.rmtree(path=local_path, ignore_errors=True)
548
- self.logger.info("Clean-up. deleting {}".format(local_path))
549
-
550
- # ===============
551
- # SERVICE METHODS
552
- # ===============
553
-
554
- @entities.Package.decorators.function(display_name='Predict Items',
555
- inputs={'items': 'Item[]'},
556
- outputs={'items': 'Item[]', 'annotations': 'Annotation[]'})
557
- def predict_items(self, items: list, upload_annotations=None, clean_annotations=None, batch_size=None, **kwargs):
558
- """
559
- Run the predict function on the input list of items (or single) and return the items and the predictions.
560
- Each prediction is by the model output type (package.output_type) and model_info in the metadata
561
-
562
- :param items: `List[dl.Item]` list of items to predict
563
- :param upload_annotations: `bool` uploads the predictions on the given items
564
- :param clean_annotations: `bool` deletes previous model annotations (predictions) before uploading new ones
565
- :param batch_size: `int` size of batch to run a single inference
566
-
567
- :return: `List[dl.Item]`, `List[List[dl.Annotation]]`
568
- """
569
- if batch_size is None:
570
- batch_size = self.configuration.get('batch_size', 4)
571
- upload_annotations = self.adapter_defaults.resolve("upload_annotations", upload_annotations)
572
- clean_annotations = self.adapter_defaults.resolve("clean_annotations", clean_annotations)
573
- input_type = self.model_entity.input_type
574
- self.logger.debug(
575
- "Predicting {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))
576
- pool = ThreadPoolExecutor(max_workers=16)
577
-
578
- annotations = list()
579
- for i_batch in tqdm.tqdm(range(0, len(items), batch_size), desc='predicting', unit='bt', leave=None,
580
- file=sys.stdout):
581
- batch_items = items[i_batch: i_batch + batch_size]
582
- batch = list(pool.map(self.prepare_item_func, batch_items))
583
- batch_collections = self.predict(batch, **kwargs)
584
- _futures = list(pool.map(partial(self._update_predictions_metadata),
585
- batch_items,
586
- batch_collections))
587
- # Loop over the futures to make sure they are all done to avoid race conditions
588
- _ = [_f for _f in _futures]
589
- if upload_annotations is True:
590
- self.logger.debug(
591
- "Uploading items' annotation for model {!r}.".format(self.model_entity.name))
592
- try:
593
- batch_collections = list(pool.map(partial(self._upload_model_annotations,
594
- clean_annotations=clean_annotations),
595
- batch_items,
596
- batch_collections))
597
- except Exception as err:
598
- self.logger.exception("Failed to upload annotations items.")
599
-
600
- for collection in batch_collections:
601
- # function needs to return `List[List[dl.Annotation]]`
602
- # convert annotation collection to a list of dl.Annotation for each batch
603
- if isinstance(collection, entities.AnnotationCollection):
604
- annotations.extend([annotation for annotation in collection.annotations])
605
- else:
606
- logger.warning(f'RETURN TYPE MAY BE INVALID: {type(collection)}')
607
- annotations.extend(collection)
608
- # TODO call the callback
609
-
610
- pool.shutdown()
611
- return items, annotations
612
-
613
- @entities.Package.decorators.function(display_name='Embed Items',
614
- inputs={'items': 'Item[]'},
615
- outputs={'items': 'Item[]', 'features': 'Json[]'})
616
- def embed_items(self, items: list, upload_features=None, batch_size=None, progress:utilities.Progress=None, **kwargs):
617
- """
618
- Extract feature from an input list of items (or single) and return the items and the feature vector.
619
-
620
- :param items: `List[dl.Item]` list of items to embed
621
- :param upload_features: `bool` uploads the features on the given items
622
- :param batch_size: `int` size of batch to run a single embed
623
-
624
- :return: `List[dl.Item]`, `List[List[vector]]`
625
- """
626
- if batch_size is None:
627
- batch_size = self.configuration.get('batch_size', 4)
628
- upload_features = self.adapter_defaults.resolve("upload_features", upload_features)
629
- input_type = self.model_entity.input_type
630
- self.logger.debug(
631
- "Embedding {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))
632
-
633
- # Search for existing feature set for this model id
634
- feature_set = self.model_entity.feature_set
635
- if feature_set is None:
636
- logger.info('Feature Set not found. creating... ')
637
- try:
638
- self.model_entity.project.feature_sets.get(feature_set_name=self.model_entity.name)
639
- feature_set_name = f"{self.model_entity.name}-{''.join(random.choices(string.ascii_letters + string.digits, k=5))}"
640
- logger.warning(
641
- f"Feature set with the model name already exists. Creating new feature set with name {feature_set_name}")
642
- except exceptions.NotFound:
643
- feature_set_name = self.model_entity.name
644
- feature_set = self.model_entity.project.feature_sets.create(name=feature_set_name,
645
- entity_type=entities.FeatureEntityType.ITEM,
646
- model_id=self.model_entity.id,
647
- project_id=self.model_entity.project_id,
648
- set_type=self.model_entity.name,
649
- size=self.configuration.get('embeddings_size',
650
- 256))
651
- logger.info(f'Feature Set created! name: {feature_set.name}, id: {feature_set.id}')
652
- else:
653
- logger.info(f'Feature Set found! name: {feature_set.name}, id: {feature_set.id}')
654
-
655
- # upload the feature vectors
656
- pool = ThreadPoolExecutor(max_workers=16)
657
- vectors = list()
658
- for i_batch in tqdm.tqdm(range(0, len(items), batch_size),
659
- desc='embedding',
660
- unit='bt',
661
- leave=None,
662
- file=sys.stdout):
663
- batch_items = items[i_batch: i_batch + batch_size]
664
- batch = list(pool.map(self.prepare_item_func, batch_items))
665
- batch_vectors = self.embed(batch, **kwargs)
666
- vectors.extend(batch_vectors)
667
- if upload_features is True:
668
- self.logger.debug(
669
- "Uploading items' feature vectors for model {!r}.".format(self.model_entity.name))
670
- try:
671
- list(pool.map(partial(self._upload_model_features,
672
- progress.logger if progress is not None else self.logger,
673
- feature_set.id,
674
- self.model_entity.project_id),
675
- batch_items,
676
- batch_vectors))
677
- except Exception as err:
678
- self.logger.exception("Failed to upload feature vectors to items.")
679
-
680
- pool.shutdown()
681
- return items, vectors
682
-
683
- @entities.Package.decorators.function(display_name='Embed Dataset with DQL',
684
- inputs={'dataset': 'Dataset',
685
- 'filters': 'Json'})
686
- def embed_dataset(self,
687
- dataset: entities.Dataset,
688
- filters: entities.Filters = None,
689
- upload_features=None,
690
- batch_size=None,
691
- progress:utilities.Progress=None,
692
- **kwargs):
693
- """
694
- Extract feature from all items given
695
-
696
- :param dataset: Dataset entity to predict
697
- :param filters: Filters entity for a filtering before embedding
698
- :param upload_features: `bool` uploads the features back to the given items
699
- :param batch_size: `int` size of batch to run a single embed
700
-
701
- :return: `bool` indicating if the embedding process completed successfully
702
- """
703
- if batch_size is None:
704
- batch_size = self.configuration.get('batch_size', 4)
705
- upload_features = self.adapter_defaults.resolve("upload_features", upload_features)
706
-
707
- self.logger.debug("Creating embeddings for dataset (name:{}, id:{}), using batch size {}".format(dataset.name,
708
- dataset.id,
709
- batch_size))
710
- if not filters:
711
- filters = entities.Filters()
712
- if filters is not None and isinstance(filters, dict):
713
- filters = entities.Filters(custom_filter=filters)
714
- pages = dataset.items.list(filters=filters, page_size=batch_size)
715
- # Item type is 'file' only, can be deleted if default filters are added to custom filters
716
- items = [item for page in pages for item in page if item.type == 'file']
717
- self.embed_items(items=items,
718
- upload_features=upload_features,
719
- batch_size=batch_size,
720
- progress=progress,
721
- **kwargs)
722
- return True
723
-
724
- @entities.Package.decorators.function(display_name='Predict Dataset with DQL',
725
- inputs={'dataset': 'Dataset',
726
- 'filters': 'Json'})
727
- def predict_dataset(self,
728
- dataset: entities.Dataset,
729
- filters: entities.Filters = None,
730
- upload_annotations=None,
731
- clean_annotations=None,
732
- batch_size=None,
733
- **kwargs):
734
- """
735
- Predicts all items given
736
-
737
- :param dataset: Dataset entity to predict
738
- :param filters: Filters entity for a filtering before predicting
739
- :param upload_annotations: `bool` uploads the predictions back to the given items
740
- :param clean_annotations: `bool` if set removes existing predictions with the same package-model name (default: False)
741
- :param batch_size: `int` size of batch to run a single inference
742
-
743
- :return: `bool` indicating if the prediction process completed successfully
744
- """
745
-
746
- if batch_size is None:
747
- batch_size = self.configuration.get('batch_size', 4)
748
-
749
- self.logger.debug("Predicting dataset (name:{}, id:{}, using batch size {}".format(dataset.name,
750
- dataset.id,
751
- batch_size))
752
- if not filters:
753
- filters = entities.Filters()
754
- if filters is not None and isinstance(filters, dict):
755
- filters = entities.Filters(custom_filter=filters)
756
- pages = dataset.items.list(filters=filters, page_size=batch_size)
757
- # Item type is 'file' only, can be deleted if default filters are added to custom filters
758
- items = [item for page in pages for item in page if item.type == 'file']
759
- self.predict_items(items=items,
760
- upload_annotations=upload_annotations,
761
- clean_annotations=clean_annotations,
762
- batch_size=batch_size,
763
- **kwargs)
764
- return True
765
-
766
- @entities.Package.decorators.function(display_name='Train a Model',
767
- inputs={'model': entities.Model},
768
- outputs={'model': entities.Model})
769
- def train_model(self,
770
- model: entities.Model,
771
- cleanup=False,
772
- progress: utilities.Progress = None,
773
- context: utilities.Context = None):
774
- """
775
- Train on existing model.
776
- data will be taken from dl.Model.datasetId
777
- configuration is as defined in dl.Model.configuration
778
- upload the output the model's bucket (model.bucket)
779
- """
780
- if isinstance(model, dict):
781
- model = repositories.Models(client_api=self._client_api).get(model_id=model['id'])
782
- output_path = None
783
- try:
784
- logger.info("Received {s} for training".format(s=model.id))
785
- model = model.wait_for_model_ready()
786
- if model.status == 'failed':
787
- raise ValueError("Model is in failed state, cannot train.")
788
-
789
- ##############
790
- # Set status #
791
- ##############
792
- model.status = 'training'
793
- if context is not None:
794
- if 'system' not in model.metadata:
795
- model.metadata['system'] = dict()
796
- model.update(reload_services=False)
797
-
798
- ##########################
799
- # load model and weights #
800
- ##########################
801
- logger.info("Loading Adapter with: {n} ({i!r})".format(n=model.name, i=model.id))
802
- self.load_from_model(model_entity=model)
803
-
804
- ################
805
- # prepare data #
806
- ################
807
- root_path, data_path, output_path = self.prepare_data(
808
- dataset=self.model_entity.dataset,
809
- root_path=os.path.join('tmp', model.id)
810
- )
811
- # Start the Train
812
- logger.info("Training {p_name!r} with model {m_name!r} on data {d_path!r}".
813
- format(p_name=self.package_name, m_name=model.id, d_path=data_path))
814
- if progress is not None:
815
- progress.update(message='starting training')
816
-
817
- def on_epoch_end_callback(i_epoch, n_epoch):
818
- if progress is not None:
819
- progress.update(progress=int(100 * (i_epoch + 1) / n_epoch),
820
- message='finished epoch: {}/{}'.format(i_epoch, n_epoch))
821
-
822
- self.train(data_path=data_path,
823
- output_path=output_path,
824
- on_epoch_end_callback=on_epoch_end_callback)
825
- if progress is not None:
826
- progress.update(message='saving model',
827
- progress=99)
828
-
829
- self.save_to_model(local_path=output_path, replace=True)
830
- model.status = 'trained'
831
- model.update(reload_services=False)
832
- ###########
833
- # cleanup #
834
- ###########
835
- if cleanup:
836
- shutil.rmtree(output_path, ignore_errors=True)
837
- except Exception:
838
- # save also on fail
839
- if output_path is not None:
840
- self.save_to_model(local_path=output_path, replace=True)
841
- logger.info('Execution failed. Setting model.status to failed')
842
- raise
843
- return model
844
-
845
- @entities.Package.decorators.function(display_name='Evaluate a Model',
846
- inputs={'model': entities.Model,
847
- 'dataset': entities.Dataset,
848
- 'filters': 'Json'},
849
- outputs={'model': entities.Model,
850
- 'dataset': entities.Dataset
851
- })
852
- def evaluate_model(self,
853
- model: entities.Model,
854
- dataset: entities.Dataset,
855
- filters: entities.Filters = None,
856
- #
857
- progress: utilities.Progress = None,
858
- context: utilities.Context = None):
859
- """
860
- Evaluate a model.
861
- data will be downloaded from the dataset and query
862
- configuration is as defined in dl.Model.configuration
863
- upload annotations and calculate metrics vs GT
864
-
865
- :param model: Model entity to run prediction
866
- :param dataset: Dataset to evaluate
867
- :param filters: Filter for specific items from dataset
868
- :param progress: dl.Progress for report FaaS progress
869
- :param context:
870
- :return:
871
- """
872
- logger.info(
873
- f"Received model: {model.id} for evaluation on dataset (name: {dataset.name}, id: {dataset.id}")
874
- ##########################
875
- # load model and weights #
876
- ##########################
877
- logger.info(f"Loading Adapter with: {model.name} ({model.id!r})")
878
- self.load_from_model(dataset=dataset,
879
- model_entity=model)
880
-
881
- ##############
882
- # Predicting #
883
- ##############
884
- logger.info(f"Calling prediction, dataset: {dataset.name!r} ({model.id!r}), filters: {filters}")
885
- if not filters:
886
- filters = entities.Filters()
887
- self.predict_dataset(dataset=dataset,
888
- filters=filters,
889
- with_upload=True)
890
-
891
- ##############
892
- # Evaluating #
893
- ##############
894
- logger.info(f"Starting adapter.evaluate()")
895
- if progress is not None:
896
- progress.update(message='calculating metrics',
897
- progress=98)
898
- model = self.evaluate(model=model,
899
- dataset=dataset,
900
- filters=filters)
901
- #########
902
- # Done! #
903
- #########
904
- if progress is not None:
905
- progress.update(message='finishing evaluation',
906
- progress=99)
907
- return model, dataset
908
-
909
- # =============
910
- # INNER METHODS
911
- # =============
912
-
913
- @staticmethod
914
- def _upload_model_features(logger, feature_set_id, project_id, item: entities.Item, vector):
915
- try:
916
- if vector is not None:
917
- item.features.create(value=vector,
918
- project_id=project_id,
919
- feature_set_id=feature_set_id,
920
- entity=item)
921
- except Exception as e:
922
- logger.error(f'Failed to upload feature vector of length {len(vector)} to item {item.id}, Error: {e}')
923
-
924
- def _upload_model_annotations(self, item: entities.Item, predictions, clean_annotations):
925
- """
926
- Utility function that upload prediction to dlp platform based on the package.output_type
927
- :param predictions: `dl.AnnotationCollection`
928
- :param cleanup: `bool` if set removes existing predictions with the same package-model name
929
- """
930
- if not (isinstance(predictions, entities.AnnotationCollection) or isinstance(predictions, list)):
931
- raise TypeError('predictions was expected to be of type {}, but instead it is {}'.
932
- format(entities.AnnotationCollection, type(predictions)))
933
- if clean_annotations:
934
- clean_filter = entities.Filters(resource=entities.FiltersResource.ANNOTATION)
935
- clean_filter.add(field='metadata.user.model.name', values=self.model_entity.name)
936
- # clean_filter.add(field='type', values=self.model_entity.output_type,)
937
- item.annotations.delete(filters=clean_filter)
938
- annotations = item.annotations.upload(annotations=predictions)
939
- return annotations
940
-
941
- @staticmethod
942
- def _item_to_image(item):
943
- """
944
- Preprocess items before calling the `predict` functions.
945
- Convert item to numpy array
946
-
947
- :param item:
948
- :return:
949
- """
950
- buffer = item.download(save_locally=False)
951
- image = np.asarray(Image.open(buffer))
952
- return image
953
-
954
- @staticmethod
955
- def _item_to_item(item):
956
- """
957
- Default item to batch function.
958
- This function should prepare a single item for the predict function, e.g. for images, it loads the image as numpy array
959
- :param item:
960
- :return:
961
- """
962
- return item
963
-
964
- @staticmethod
965
- def _item_to_text(item):
966
- filename = item.download(overwrite=True)
967
- text = None
968
- if item.mimetype == 'text/plain' or item.mimetype == 'text/markdown':
969
- with open(filename, 'r') as f:
970
- text = f.read()
971
- text = text.replace('\n', ' ')
972
- else:
973
- logger.warning('Item is not text file. mimetype: {}'.format(item.mimetype))
974
- text = item
975
- if os.path.exists(filename):
976
- os.remove(filename)
977
- return text
978
-
979
- @staticmethod
980
- def _uri_to_image(data_uri):
981
- # data_uri = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAS4AAAEuCAYAAAAwQP9DAAAU80lEQVR4Xu2da+hnRRnHv0qZKV42LDOt1eyGULoSJBGpRBFprBJBQrBJBBWGSm8jld5WroHUCyEXKutNu2IJ1QtXetULL0uQFCu24WoRsV5KpYvGYzM4nv6X8zu/mTnznPkcWP6XPTPzzOf7/L7/OXPmzDlOHBCAAAScETjOWbyECwEIQEAYF0kAAQi4I4BxuZOMgCEAAYyLHIAABNwRwLjcSUbAEIAAxkUOQAAC7ghgXO4kI2AIQADjIgcgAAF3BDAud5IRMAQggHGRAxCAgDsCGJc7yQgYAhDAuMgBCEDAHQGMy51kBAwBCGBc5AAEIOCOAMblTjIChgAEMC5yAAIQcEcA43InGQFDAAIYFzkAAQi4I4BxuZOMgCEAAYyLHIAABNwRwLjcSUbAEIAAxkUOQAAC7ghgXO4kI2AIQADjIgcgAAF3BDAud5IRMAQggHGRAxCAgDsCGJc7yQgYAhDAuMgBCEDAHQGMy51kBAwBCGBc5AAEIOCOAMblTjIChgAEMC5yAAIQcEcA43InGQFDAAIYFzkAAQi4I4BxuZOMgCEAAYyLHIAABNwRwLjcSUbAEIAAxkUOQAAC7ghgXO4kI2AIQADjIgcgAAF3BDAud5IRMAQggHGRAxDwTeDTkr4s6UxJ/5F0QNK3JD3lu1tbR49xLVld+jYXgcskvSTpIkmnS/qgpJMk/Tv8bHHZ7+PXPw6M5kRJx0t6Ijkv9uUsSW+U9Iykczfp4K8lfXiuztdoF+OqQZk2vBEwUzFTsK9mQNFkotGkhvFeSc+G86NRtdDfd0h6tIVASsSAcZWgSp0eCJjJ7JR0SRgZ2SjHDMp+38Jho7PXTAzkBUmvn1jWRTGMy4VMBJmBgBnSpZLsMs7+paOodao3k/hLqCBe8j0cfj4Yvtp8k/1fPLaaf4pxxXPSS8r4/Vsl3SXp5EHgNjo8JukDkg6v06nWy2JcrSvUX3xmKjYSipdqF0h6V/jgp6Mh+2DHf0YpnSd6p6TTkjml7UZRL4bLPasnmo7VHb+PKsQ20rZTQ6ql1lclfXODxr4u6Ru1gpizHYxrTvq0beZkE9cfkXRxxcu0pyXZaMiMKX71dBfua5sY1Psk/baHtMK4elC5rT5eFS7Z7Otmd8VyRDwcRZkxmUlFo8rRxlx13Clpz6Dxn0r61FwB1W4X46pNvM/27PLPPmhmVhvNLUWTiaZil1/xEswMx/7fbv9bWfs5nfcxommdceQU55eWSNxGihcmHbMRZK45Oxe8MK75ZYofaku8MyQ9J+mQpKNJMqbzLfeHkIeTuPP35JUIbCSVToRvNrKyftqCSfs3nE9qqT+txWKT8OmxT9LnWguyZDwYV0m6m9dtH+SbJNlamw+tGIIl7Va6/VPS8xusP4rN2JojG8E8NrhUS+d4ht/bbfkTJP0umGk6ER7PtfkVmwR/wzaXgEck7Q1mNcfE9oq4mzx9aFxXB55NBlsiKIyrBNXt67xB0q3bn7aYM+xSxkZVNjez5Eu4GoLZ5fb+pCFb/mB/LLo6MK555LaRyUPzND251VUWRJpRxTt2cUJ8csMUfBUBG61en/ymu8tE6zvGNd+nwuao7N8PJO0Kz7JZNDbH9aSkv4fQ0su2RyS9VtKD4dJtOClt5+4Il4Fpz+KkdqzLnpuzdrY74vnppWG6ujx9xMXOsUWPjw8WW27XBv+/GgH7Q2Dzh/G4NoxkV6vF+dkYV1sCRoNpKyqiaYmA/TGxxbXxsD963d3YwLhaSkligcDWBIZTDHajo+RauGb1wLialYbAIPB/BO6Q9Pnkt7dJshs93R0YV3eS02HHBGz+8Owk/vN6nU/EuBxnMaF3RWC4DOJ7kr7UFYGksxhXr8rTb28Eho/5dDvaMuEwLm/pS7w9EhiOtu4Oz332yOLlPmNc3UpPx5
0QsCUytlg5vXvY5RKIVC+My0n2Ema3BG4Oz7VGAN2PthhxdftZoOOOCKQLTu1RKlvL1f3D6Yy4HGUwoXZHwLaq+X7S6xvDzhrdgRh2GOPqPgUA0DCB9LlE27tsu73zG+5K3tAwrrw8qQ0CuQjYZLztmRaP7vbc2gokxpUrzagHAnkJpNvXMNoasMW48iYbtUEgF4F0Up7RFsaVK6+oBwLFCKST8t3uAMGlYrH8omIIFCFg21zvDjV3uwMExlUkt6gUAkUIDCflu34mcTPCzHEVyT0qhcBkAumLVJiU3wQjxjU5vygIgSIE0l0gutxPfgxVjGsMJc6BQB0C9kC1vW4sHvbik/RlKXWicNAKxuVAJELshkC6fY29sdzecs6xAQGMi7SAQDsE7IW5e0I4PJe4hS4YVztJSyQQsF0fdgYM3E3EuPhEQKB5Aumrx7ibuI1cjLiaz2cC7IRAugyCy0SMq5O0p5veCaSr5blMxLi85zPxd0LgGUmnSOIycYTgXCqOgMQpEChMwJY93MfdxPGUMa7xrDgTAqUIxGUQ7Ck/kjDGNRIUp0GgIIG49xaXiSMhY1wjQXEaBAoRSFfLczdxJGSMayQoToNAIQLpannuJo6EjHGNBMVpEChEgMvECWAxrgnQKAKBTAS4TJwIEuOaCI5iEMhAgMvEiRAxrongKAaBDAS4TJwIEeOaCI5iEFiTQPpQNXcTV4SJca0IjNMhkIlA+sJX7iauCBXjWhEYp0MgE4G49xaLTicAxbgmQKMIBNYkkL6CjPcmToCJcU2ARhEIrEkgfVP1Lkn2Zh+OFQhgXCvA4lQIZCIQl0EckWSjL44VCWBcKwLjdAhkIHBY0vmS9kmy0RfHigQwrhWBcToE1iSQLoO4QtK9a9bXZXGMq0vZ6fSMBOLe8rb3ll0m8sLXCWJgXBOgUQQCaxA4KOlStmheg6AkjGs9fpSGwKoEXgoFbpF086qFOf9/BDAuMgEC9Qike8tfLslGXxwTCGBcE6BRBAITCdgI66ZQls/eRIiMuNYAR1EITCAQ57ful2SjL46JBHD9ieAoBoEJBJjfmgBtoyIYVyaQVAOBbQik67eulmRvruaYSADjmgiOYhBYkUBcv2XFdrB+a0V6g9MxrvX4URoCYwnwfOJYUiPOw7hGQOIUCGQgEPff4vnEDDAxrgwQqQIC2xBI99+6VpKNvjjWIIBxrQGPohAYSSDdf4ttmkdC2+o0jCsDRKqAwDYEmN/KnCIYV2agVAeBDQgclfQW9t/KlxsYVz6W1ASBjQiw/1aBvMC4CkClSggkBOLziey/lTEtMK6MMKkKAhsQsBdhXMj+W3lzA+PKy5PaIJASOF3SsfAL3ladMTcwrowwqQoCAwK8hqxQSmBchcBSLQTCg9S7Jdn8lo2+ODIRwLgygaQaCGxAwF6EcRrLIPLnBsaVnyk1QsAIXCVpf0DBNjaZcwLjygyU6iAQCOyVdH34nm1sMqcFxpUZKNVBIBCIu0HcHUZfgMlIAOPKCJOqIBAIpKvl2Q2iQFpgXAWgUmX3BLhMLJwCGFdhwFTfJQEuEwvLjnEVBkz13RHgpRgVJMe4KkCmia4IpA9Vs+i0kPQYVyGwVNstgQcl7WLRaVn9Ma6yfKm9LwLsvVVJb4yrEmia6YJAvJvIs4mF5ca4CgOm+q4I8GxiJbkxrkqgaWbxBNJnE22OyzYQ5ChEAOMqBJZquyMQ124dkWTvUeQoSADjKgiXqrshcJmk+0Jv2em0guwYVwXINLF4Agck2YaBdvDC1wpyY1wVINPEognYZeHvJZ0g6RFJFyy6t410DuNqRAjCcEvgBkm3huhvl3Sd2544ChzjciQWoTZJIL5+zILjbmIliTCuSqBpZpEE0tePsei0osQYV0XYNLU4Aunrx/ZJsp85KhDAuCpAponFErhT0p7QO5ZBVJQZ46oIm6YWR4D5rZkkxbhmAk+z7gkwvzWjhBjXjPBp2jWBz0i6K/TgN5Iucd0bZ8
FjXM4EI9xmCMSdTi2gn0gyI+OoRADjqgSaZhZHIH3Mh1eQVZYX46oMnOYWQyDuBmEdulzSwcX0zEFHMC4HIhFikwReSqLiwerKEmFclYHT3CIIpNvYWIf4HFWWFeCVgdPcIgh8R9JXQk/+KulNi+iVo05gXI7EItRmCPxS0kdDNLalzXuaiayTQDCuToSmm9kI2MJT25751FDjLZJsaQRHRQIYV0XYNLUIAvdIujLpCXcUZ5AV45oBOk26JvCMpFNCD+zO4vGue+M0eIzLqXCEPQuBdBsbC+BeSVfMEknnjWJcnScA3V+JwJOS3pyUuFqSraDnqEwA46oMnOZcE0gXnVpH+PzMJCfgZwJPsy4JYFyNyIZxNSIEYbggMDSuHZKechH5woLEuBYmKN0pSoARV1G84yvHuMaz4sy+CQzvKB6VdE7fSObrPcY1H3ta9kVgeEeRt/rMqB/GNSN8mnZFYHiZyIr5GeXDuGaET9NuCFwlaX8SLTtCzCwdxjWzADTvgkC6v7wFfJukG1xEvtAgMa6FCku3shL4s6QzkxpZMZ8V7+qVYVyrM6NEfwSel3Ri0m3Wb82cAxjXzALQfPMEhvNbf5D07uajXniAGNfCBaZ7axN4VNLbk1pulLR37VqpYC0CGNda+Ci8cAK22+mxQR95o08DomNcDYhACM0SGK6Wt3cpmnFxzEwA45pZAJpvmsBwtTyXiY3IhXE1IgRhNElguFqey8RGZMK4GhGCMJojMLybyGViQxJhXA2JQShNEbhT0p4kIlbLNyQPxtWQGITSFAH2l29KjlcHg3E1LA6hzUrgxcGe8nxWZpUD42oIP6E0SuAiSQ8NYtsl6eFG4+0uLP6KdCc5HR5BYKOFp+y/NQJcrVMwrlqkaccTgQckXTwI+DJJ93vqxJJjxbiWrC59m0LgfEmHBwX/JemEKZVRpgwBjKsMV2r1S8BGVvcNwv+spB/67dLyIse4lqcpPVqPwEbGxcaB6zHNXhrjyo6UCp0TuFLSPYM+XCPpx877tajwMa5FyUlnMhCwveRvHdTDjqcZwOasAuPKSZO6lkDggKTdSUeOSDp3CR1bUh8wriWpSV9yEPiHpJOSinhGMQfVzHVgXJmBUp17AsOtbFgx36CkGFeDohDSbASGj/r8TdIZs0VDw5sSwLhIDgi8QmC4VfPdkmxfLo7GCGBcjQlCOLMSGO7BxVbNs8qxeeMYV6PCENYsBGyX051JyzxYPYsM2zeKcW3PiDP6ITCcmGf9VqPaY1yNCkNY1QkMJ+YPSbLfcTRIAONqUBRCmoXA8BlF1m/NIsO4RjGucZw4a/kEhncUebC6Yc0xrobFIbSqBIbPKDK/VRX/ao1hXKvx4uzlEtgr6frQvUckXbDcrvrvGcblX0N6kIdAaly/kPTxPNVSSwkCGFcJqtTpkUC6+JSFp40riHE1LhDhVSNwUNKloTUm5qthn9YQxjWNG6WWRyA1LlbMN64vxtW4QIRXjcBTkk4LrWFc1bBPawjjmsaNUssjkD7ug3E1ri/G1bhAhFeNQGpcbB5YDfu0hjCuadwotTwCqXGdJ8l2iuBolADG1agwhFWdQGpcfC6q41+tQQRajRdnL5dANK6nJZ2+3G4uo2cY1zJ0pBfrEbDXjz0WquB1ZOuxrFIa46qCmUYaJ/AJST8PMf5K0scaj7f78DCu7lMAAJLSnSFul3QdVNomgHG1rQ/R1SGQPmDNGq46zNdqBeNaCx+FF0LgYUkXhr6wFMKBqBiXA5EIsTgB7igWR5y3AYwrL09q80cg3WueF8A60Q/jciIUYRYjcLOkm0Lt7MNVDHPeijGuvDypzR+BdH6LZxSd6IdxORGKMIsQsBXyx0LNLDwtgrhMpRhXGa7U6oNA+kqyfZLsZw4HBDAuByIRYjEC6T7zbNdcDHP+ijGu/Eyp0Q+BuOspD1b70ezlSDEuZ4IRbjYCF0l6KNTGZWI2rHUqwrjqcKaV9gikj/lwmdiePltGhH
E5E4xwsxGIyyC4TMyGtF5FGFc91rTUFoEXJL1OEqvl29JlVDQY1yhMnLQwAuljPl+QdMfC+rf47mBci5eYDm5AIJ3fYjcIhymCcTkUjZDXJhDnt1gtvzbKeSrAuObhTqvzEUj3l78t7H46XzS0PIkAxjUJG4UcE0i3aWYZhFMhMS6nwhH2ZAIHJO0Opcn/yRjnLYhw8/Kn9foE4m6nhyTZ6nkOhwQwLoeiEfJkAryGbDK6tgpiXG3pQTRlCaS7nfJ8YlnWRWvHuIripfLGCLCNTWOCTA0H45pKjnIeCaTbNPP+RI8KclfFsWqEPpVAnJi38jsk2X5cHA4JMOJyKBohTyaQGhe5Pxnj/AURb34NiKAOgXTjQLayqcO8WCsYVzG0VNwYgXRHCNZwNSbOquFgXKsS43yvBOxlr98OwT8g6f1eO0Lc7DlPDvRD4LuSvhi6+zNJn+yn68vrKSOu5WlKjzYmkD6jaKMv25OLwykBjMupcIS9MoH4KjIryK4QK+NrqwDG1ZYeRFOGQDoxby2whqsM52q1YlzVUNPQjAR+JOma0P5zkk6eMRaazkAA48oAkSqaJ/CEpLNClM9KOrX5iAlwSwIYFwmydAJnS3p80MlzJB1deseX3D+Ma8nq0rdIwF6K8bbww58k7QSNbwIYl2/9iH4cAdtA0O4k2rFf0r3jinFWqwQwrlaVIS4IQGBTAhgXyQEBCLgjgHG5k4yAIQABjIscgAAE3BHAuNxJRsAQgADGRQ5AAALuCGBc7iQjYAhAAOMiByAAAXcEMC53khEwBCCAcZEDEICAOwIYlzvJCBgCEMC4yAEIQMAdAYzLnWQEDAEIYFzkAAQg4I4AxuVOMgKGAAQwLnIAAhBwRwDjcicZAUMAAhgXOQABCLgjgHG5k4yAIQABjIscgAAE3BHAuNxJRsAQgADGRQ5AAALuCGBc7iQjYAhAAOMiByAAAXcEMC53khEwBCCAcZEDEICAOwIYlzvJCBgCEMC4yAEIQMAdAYzLnWQEDAEIYFzkAAQg4I4AxuVOMgKGAAQwLnIAAhBwRwDjcicZAUMAAhgXOQABCLgjgHG5k4yAIQABjIscgAAE3BHAuNxJRsAQgADGRQ5AAALuCGBc7iQjYAhAAOMiByAAAXcEMC53khEwBCCAcZEDEICAOwIYlzvJCBgCEMC4yAEIQMAdAYzLnWQEDAEIYFzkAAQg4I4AxuVOMgKGAAQwLnIAAhBwR+C/doIhTZIi/uMAAAAASUVORK5CYII="
982
- image_b64 = data_uri.split(",")[1]
983
- binary = base64.b64decode(image_b64)
984
- image = np.asarray(Image.open(io.BytesIO(binary)))
985
- return image
986
-
987
- def _update_predictions_metadata(self, item: entities.Item, predictions: entities.AnnotationCollection):
988
- """
989
- add model_name and model_id to the metadata of the annotations.
990
- add model_info to the metadata of the system metadata of the annotation.
991
- Add item id to all the annotations in the AnnotationCollection
992
-
993
- :param item: Entity.Item
994
- :param predictions: item's AnnotationCollection
995
- :return:
996
- """
997
- for prediction in predictions:
998
- if prediction.type == entities.AnnotationType.SEGMENTATION:
999
- color = None
1000
- try:
1001
- color = item.dataset._get_ontology().color_map.get(prediction.label, None)
1002
- except (exceptions.BadRequest, exceptions.NotFound):
1003
- ...
1004
- if color is None:
1005
- if self.model_entity._dataset is not None:
1006
- try:
1007
- color = self.model_entity.dataset._get_ontology().color_map.get(prediction.label,
1008
- (255, 255, 255))
1009
- except (exceptions.BadRequest, exceptions.NotFound):
1010
- ...
1011
- if color is None:
1012
- logger.warning("Can't get annotation color from model's dataset, using default.")
1013
- color = prediction.color
1014
- prediction.color = color
1015
-
1016
- prediction.item_id = item.id
1017
- if 'user' in prediction.metadata and 'model' in prediction.metadata['user']:
1018
- prediction.metadata['user']['model']['model_id'] = self.model_entity.id
1019
- prediction.metadata['user']['model']['name'] = self.model_entity.name
1020
- if 'system' not in prediction.metadata:
1021
- prediction.metadata['system'] = dict()
1022
- if 'model' not in prediction.metadata['system']:
1023
- prediction.metadata['system']['model'] = dict()
1024
- confidence = prediction.metadata.get('user', dict()).get('model', dict()).get('confidence', None)
1025
- prediction.metadata['system']['model'] = {
1026
- 'model_id': self.model_entity.id,
1027
- 'name': self.model_entity.name,
1028
- 'confidence': confidence
1029
- }
1030
-
1031
- ##############################
1032
- # Callback Factory functions #
1033
- ##############################
1034
- @property
1035
- def dataloop_keras_callback(self):
1036
- """
1037
- Returns the constructor for a keras api dump callback
1038
- The callback is used for dlp platform to show train losses
1039
-
1040
- :return: DumpHistoryCallback constructor
1041
- """
1042
- try:
1043
- import keras
1044
- except (ImportError, ModuleNotFoundError) as err:
1045
- raise RuntimeError(
1046
- '{} depends on extenral package. Please install '.format(self.__class__.__name__)) from err
1047
-
1048
- import os
1049
- import time
1050
- import json
1051
-
1052
- class DumpHistoryCallback(keras.callbacks.Callback):
1053
- def __init__(self, dump_path):
1054
- super().__init__()
1055
- if os.path.isdir(dump_path):
1056
- dump_path = os.path.join(dump_path,
1057
- '__view__training-history__{}.json'.format(time.strftime("%F-%X")))
1058
- self.dump_file = dump_path
1059
- self.data = dict()
1060
-
1061
- def on_epoch_end(self, epoch, logs=None):
1062
- logs = logs or {}
1063
- for name, val in logs.items():
1064
- if name not in self.data:
1065
- self.data[name] = {'x': list(), 'y': list()}
1066
- self.data[name]['x'].append(float(epoch))
1067
- self.data[name]['y'].append(float(val))
1068
- self.dump_history()
1069
-
1070
- def dump_history(self):
1071
- _json = {
1072
- "query": {},
1073
- "datasetId": "",
1074
- "xlabel": "epoch",
1075
- "title": "training loss",
1076
- "ylabel": "val",
1077
- "type": "metric",
1078
- "data": [{"name": name,
1079
- "x": values['x'],
1080
- "y": values['y']} for name, values in self.data.items()]
1081
- }
1082
-
1083
- with open(self.dump_file, 'w') as f:
1084
- json.dump(_json, f, indent=2)
1085
-
1086
- return DumpHistoryCallback
1
+ import dataclasses
2
+ import threading
3
+ import tempfile
4
+ import datetime
5
+ import logging
6
+ import string
7
+ import shutil
8
+ import random
9
+ import base64
10
+ import copy
11
+ import time
12
+ import tqdm
13
+ import traceback
14
+ import sys
15
+ import io
16
+ import os
17
+ from itertools import chain
18
+ from PIL import Image
19
+ from functools import partial
20
+ import numpy as np
21
+ from concurrent.futures import ThreadPoolExecutor
22
+ import attr
23
+ from collections.abc import MutableMapping
24
+ from typing import Optional
25
+ from .. import entities, utilities, repositories, exceptions
26
+ from ..services import service_defaults
27
+ from ..services.api_client import ApiClient
28
+
29
# Module logger; BaseModelAdapter stores it as ``self.logger`` and also logs through it directly.
logger = logging.getLogger('ModelAdapter')

# Constants
# NOTE(review): not referenced in the visible part of this module; presumably consumed by the
# dataset-level predict/embed helpers defined further down — confirm at the use site.
PREDICT_EMBED_DEFAULT_SUBSET_LIMIT = 1000
# 3600 * 24 == 24 hours if the unit is seconds (assumed — confirm at the use site).
PREDICT_EMBED_DEFAULT_TIMEOUT = 3600 * 24
34
+
35
+
36
class ModelConfigurations(MutableMapping):
    """
    Dict-like view over a model entity's ``configuration`` dict.

    Uses composition (a backing dict) instead of inheriting from ``dict``. When a model
    entity is attached, the backing dict *is* ``model_entity.configuration``, so there is a
    single source of truth and no duplicated, drifting state.

    Writes that change something (``__setitem__``, ``update``, ``__delitem__`` and ``get``
    on a missing key) are persisted via ``model_entity.update(reload_services=False)``.
    """

    def __init__(self, base_model_adapter):
        """
        :param base_model_adapter: adapter exposing ``model_entity``, or ``None`` for a
            detached in-memory configuration
        """
        self._backing_dict = {}
        if (
            base_model_adapter is not None
            and base_model_adapter.model_entity is not None
            and base_model_adapter.model_entity.configuration is not None
        ):
            # Share (not copy) the entity's dict so edits land on the entity directly.
            self._backing_dict = base_model_adapter.model_entity.configuration
        if 'include_background' not in self._backing_dict:
            self._backing_dict['include_background'] = False
        self._base_model_adapter = base_model_adapter
        # Don't call _update_model_entity during initialization to avoid premature updates

    def _update_model_entity(self):
        """Persist the configuration through the attached model entity (no-op when detached)."""
        if self._base_model_adapter is not None and self._base_model_adapter.model_entity is not None:
            self._base_model_adapter.model_entity.update(reload_services=False)

    def __ior__(self, other):
        self.update(other)
        return self

    # Required MutableMapping abstract methods
    def __getitem__(self, key):
        return self._backing_dict[key]

    def __setitem__(self, key, value):
        # Only the backing dict is touched here (never object attributes), so subclasses can
        # call this from __setattr__ without infinite recursion.
        changed = key not in self._backing_dict or self._backing_dict[key] != value
        self._backing_dict[key] = value
        if changed:
            self._update_model_entity()

    def __delitem__(self, key):
        del self._backing_dict[key]
        # Persist deletions the same way __setitem__ persists assignments.
        self._update_model_entity()

    def __iter__(self):
        return iter(self._backing_dict)

    def __len__(self):
        return len(self._backing_dict)

    def get(self, key, default=None):
        """
        Return ``self[key]``; when the key is missing, first store ``default`` under it.

        Unlike ``dict.get`` the default is recorded (and persisted), so the effective value
        becomes visible in the model configuration.
        """
        if key not in self._backing_dict:
            self.__setitem__(key, default)
        return self._backing_dict.get(key)

    def update(self, *args, **kwargs):
        """Merge like ``dict.update``; persist once if any key was added or changed."""
        # Materialize once: passing *args twice would let the change check consume a one-shot
        # iterator of pairs so that nothing reached the backing dict.
        update_dict = dict(*args, **kwargs)
        has_changes = any(
            key not in self._backing_dict or self._backing_dict[key] != value
            for key, value in update_dict.items()
        )
        self._backing_dict.update(update_dict)
        if has_changes:
            self._update_model_entity()

    def setdefault(self, key, default=None):
        # NOTE(review): unlike get(), this stores the default without persisting it — confirm intended.
        if key not in self._backing_dict:
            self._backing_dict[key] = default
        return self._backing_dict[key]


@dataclasses.dataclass
class AdapterDefaults(ModelConfigurations):
    """
    Typed adapter defaults mirrored into the model configuration.

    Every dataclass field is reachable both as an attribute (``defaults.root_path``) and as
    a configuration key (``defaults['root_path']``). Attribute writes, item writes,
    ``update`` and ``resolve`` all keep the two views in sync.
    """
    # for predict items, dataset, evaluate
    upload_annotations: bool = dataclasses.field(default=True)
    clean_annotations: bool = dataclasses.field(default=True)
    overwrite_annotations: bool = dataclasses.field(default=True)
    # for embeddings
    upload_features: bool = dataclasses.field(default=None)
    # for training
    root_path: str = dataclasses.field(default=None)
    data_path: str = dataclasses.field(default=None)
    output_path: str = dataclasses.field(default=None)

    def __init__(self, base_model_adapter=None):
        super().__init__(base_model_adapter)
        missing = {}
        for f in dataclasses.fields(AdapterDefaults):
            # Read the backing dict directly: ModelConfigurations.get() would first insert
            # (and persist) a placeholder None before the real default is written.
            value = self._backing_dict.get(f.name)
            if value is None:
                # not configured on the model -> fall back to the field default
                value = f.default
                missing[f.name] = value
            # bypass our __setattr__: only the attribute is set here
            object.__setattr__(self, f.name, value)
        if missing:
            # Store all filled-in defaults with at most one entity update.
            super().update(missing)

    @staticmethod
    def _field_names():
        """Names of the dataclass fields that are mirrored between attributes and dict keys."""
        return {f.name for f in dataclasses.fields(AdapterDefaults)}

    def __setattr__(self, key, value):
        # Dataclass-like fields behave as attributes, so map to dict
        super().__setattr__(key, value)
        if not key.startswith("_"):
            super().__setitem__(key, value)

    def __setitem__(self, key, value):
        # Dict-style writes to a field key must also refresh the attribute view.
        if key in self._field_names():
            object.__setattr__(self, key, value)
        super().__setitem__(key, value)

    def update(self, *args, **kwargs):
        """``dict.update`` that also refreshes field attributes (positional mapping or kwargs)."""
        update_dict = dict(*args, **kwargs)
        field_names = self._field_names()
        for key, value in update_dict.items():
            if key in field_names:
                object.__setattr__(self, key, value)
        super().update(update_dict)

    def resolve(self, key, *args):
        """
        Return the first non-None of ``args`` (storing it under ``key``), else the stored value.

        :param key: configuration key / field name
        :param args: candidate values in priority order (e.g. an explicit method argument)
        :return: the resolved value (``None`` is stored and returned when nothing is set)
        """
        for arg in args:
            if arg is not None:
                self[key] = arg
                return arg
        return self.get(key, None)
154
+
155
+
156
+ class BaseModelAdapter(utilities.BaseServiceRunner):
157
+ _client_api = attr.ib(type=ApiClient, repr=False)
158
+ _feature_set_lock = threading.Lock()
159
+
160
+ def __init__(self, model_entity: entities.Model = None):
161
+ self.logger = logger
162
+ # entities
163
+ self._model_entity = None
164
+ self._configuration = None
165
+ self.adapter_defaults = None
166
+ self.model = None
167
+ self.bucket_path = None
168
+ self._project = None
169
+ self._feature_set = None
170
+ # funcs
171
+ self.item_to_batch_mapping = {'text': self._item_to_text, 'image': self._item_to_image}
172
+ if model_entity is not None:
173
+ self.load_from_model(model_entity=model_entity)
174
+ logger.warning(
175
+ "in case of a mismatch between 'model.name' and 'model_info.name' in the model adapter, model_info.name will be updated to align with 'model.name'."
176
+ )
177
+
178
+ ##################
179
+ # Configurations #
180
+ ##################
181
+
182
+ @property
183
+ def configuration(self) -> dict:
184
+ # load from model
185
+ if self._model_entity is not None:
186
+ configuration = self._configuration
187
+ else:
188
+ configuration = dict()
189
+ return configuration
190
+
191
+ @configuration.setter
192
+ def configuration(self, configuration: dict):
193
+ assert isinstance(configuration, dict)
194
+ if self._model_entity is not None:
195
+ # Update configuration with received dict
196
+ self._model_entity.configuration = configuration
197
+ self.adapter_defaults = AdapterDefaults(self)
198
+ self._configuration = self.adapter_defaults
199
+
200
+ ############
201
+ # Entities #
202
+ ############
203
+ @property
204
+ def project(self):
205
+ if self._project is None:
206
+ self._project = self.model_entity.project
207
+ assert isinstance(self._project, entities.Project)
208
+ return self._project
209
+
210
+ @property
211
+ def feature_set(self):
212
+ if self._feature_set is None:
213
+ self._feature_set = self._get_feature_set()
214
+ assert isinstance(self._feature_set, entities.FeatureSet)
215
+ return self._feature_set
216
+
217
+ @property
218
+ def model_entity(self):
219
+ if self._model_entity is None:
220
+ raise ValueError("No model entity loaded. Please load a model (adapter.load_from_model(<dl.Model>)) or set: 'adapter.model_entity=<dl.Model>'")
221
+ assert isinstance(self._model_entity, entities.Model)
222
+ return self._model_entity
223
+
224
+ @model_entity.setter
225
+ def model_entity(self, model_entity):
226
+ assert isinstance(model_entity, entities.Model)
227
+ if self._model_entity is not None and isinstance(self._model_entity, entities.Model):
228
+ if self._model_entity.id != model_entity.id:
229
+ self.logger.warning('Replacing Model from {!r} to {!r}'.format(self._model_entity.name, model_entity.name))
230
+ self._model_entity = model_entity
231
+ self.adapter_defaults = AdapterDefaults(self)
232
+ self._configuration = self.adapter_defaults
233
+
234
+ ###################################
235
+ # NEED TO IMPLEMENT THESE METHODS #
236
+ ###################################
237
+
238
+ def load(self, local_path, **kwargs):
239
+ """
240
+ Loads model and populates self.model with a `runnable` model
241
+
242
+ Virtual method - need to implement
243
+
244
+ This function is called by load_from_model (download to local and then loads)
245
+
246
+ :param local_path: `str` directory path in local FileSystem
247
+ """
248
+ raise NotImplementedError("Please implement `load` method in {}".format(self.__class__.__name__))
249
+
250
+ def save(self, local_path, **kwargs):
251
+ """
252
+ Saves configuration and weights locally
253
+
254
+ Virtual method - need to implement
255
+
256
+ the function is called in save_to_model which first save locally and then uploads to model entity
257
+
258
+ :param local_path: `str` directory path in local FileSystem
259
+ """
260
+ raise NotImplementedError("Please implement `save` method in {}".format(self.__class__.__name__))
261
+
262
+ def train(self, data_path, output_path, **kwargs):
263
+ """
264
+ Virtual method - need to implement
265
+
266
+ Train the model according to data in data_paths and save the train outputs to output_path,
267
+ this include the weights and any other artifacts created during train
268
+
269
+ :param data_path: `str` local File System path to where the data was downloaded and converted at
270
+ :param output_path: `str` local File System path where to dump training mid-results (checkpoints, logs...)
271
+ """
272
+ raise NotImplementedError("Please implement `train` method in {}".format(self.__class__.__name__))
273
+
274
+ def predict(self, batch, **kwargs):
275
+ """
276
+ Model inference (predictions) on batch of items
277
+
278
+ Virtual method - need to implement
279
+
280
+ :param batch: output of the `prepare_item_func` func
281
+ :return: `list[dl.AnnotationCollection]` each collection is per each image / item in the batch
282
+ """
283
+ raise NotImplementedError("Please implement `predict` method in {}".format(self.__class__.__name__))
284
+
285
+ def embed(self, batch, **kwargs):
286
+ """
287
+ Extract model embeddings on batch of items
288
+
289
+ Virtual method - need to implement
290
+
291
+ :param batch: output of the `prepare_item_func` func
292
+ :return: `list[list]` a feature vector per each item in the batch
293
+ """
294
+ raise NotImplementedError("Please implement `embed` method in {}".format(self.__class__.__name__))
295
+
296
+ def evaluate(self, model: entities.Model, dataset: entities.Dataset, filters: entities.Filters) -> entities.Model:
297
+ """
298
+ This function evaluates the model prediction on a dataset (with GT annotations).
299
+ The evaluation process will upload the scores and metrics to the platform.
300
+
301
+ :param model: The model to evaluate (annotation.metadata.system.model.name
302
+ :param dataset: Dataset where the model predicted and uploaded its annotations
303
+ :param filters: Filters to query items on the dataset
304
+ :return:
305
+ """
306
+ import dtlpymetrics
307
+
308
+ compare_types = model.output_type
309
+ if not filters:
310
+ filters = entities.Filters()
311
+ if filters is not None and isinstance(filters, dict):
312
+ filters = entities.Filters(custom_filter=filters)
313
+ model = dtlpymetrics.scoring.create_model_score(
314
+ model=model,
315
+ dataset=dataset,
316
+ filters=filters,
317
+ compare_types=compare_types,
318
+ )
319
+ return model
320
+
321
+ def convert_from_dtlpy(self, data_path, **kwargs):
322
+ """Convert Dataloop structure data to model structured
323
+
324
+ Virtual method - need to implement
325
+
326
+ e.g. take dlp dir structure and construct annotation file
327
+
328
+ :param data_path: `str` local File System directory path where we already downloaded the data from dataloop platform
329
+ :return:
330
+ """
331
+ raise NotImplementedError("Please implement `convert_from_dtlpy` method in {}".format(self.__class__.__name__))
332
+
333
+ #################
334
+ # DTLPY METHODS #
335
+ ################
336
+ def prepare_item_func(self, item: entities.Item):
337
+ """
338
+ Prepare the Dataloop item before calling the `predict` function with a batch.
339
+ A user can override this function to load item differently
340
+ Default will load the item according the input_type (mapping type to function is in self.item_to_batch_mapping)
341
+
342
+ :param item:
343
+ :return: preprocessed: the var with the loaded item information (e.g. ndarray for image, dict for json files etc)
344
+ """
345
+ # Item to batch func
346
+ if isinstance(self.model_entity.input_type, list):
347
+ if 'text' in self.model_entity.input_type and 'text' in item.mimetype:
348
+ processed = self._item_to_text(item)
349
+ elif 'image' in self.model_entity.input_type and 'image' in item.mimetype:
350
+ processed = self._item_to_image(item)
351
+ else:
352
+ processed = self._item_to_item(item)
353
+
354
+ elif self.model_entity.input_type in self.item_to_batch_mapping:
355
+ processed = self.item_to_batch_mapping[self.model_entity.input_type](item)
356
+
357
+ else:
358
+ processed = self._item_to_item(item)
359
+
360
+ return processed
361
+
362
+ def __include_model_annotations(self, annotation_filters):
363
+ include_model_annotations = self.model_entity.configuration.get("include_model_annotations", False)
364
+ if include_model_annotations is False:
365
+ if annotation_filters.custom_filter is None:
366
+ annotation_filters.add(field="metadata.system.model.name", values=False, operator=entities.FiltersOperations.EXISTS)
367
+ else:
368
+ annotation_filters.custom_filter['filter']['$and'].append({'metadata.system.model.name': {'$exists': False}})
369
+ return annotation_filters
370
+
371
+ def __download_background_images(self, filters, data_subset_base_path, annotation_options):
372
+ background_list = list()
373
+ if self.configuration.get('include_background', False) is True:
374
+ filters.custom_filter["filter"]["$and"].append({"annotated": False})
375
+ background_list = self.model_entity.dataset.items.download(
376
+ filters=filters,
377
+ local_path=data_subset_base_path,
378
+ annotation_options=annotation_options,
379
+ )
380
+ return background_list
381
+
382
+ def prepare_data(
383
+ self,
384
+ dataset: entities.Dataset,
385
+ # paths
386
+ root_path=None,
387
+ data_path=None,
388
+ output_path=None,
389
+ #
390
+ overwrite=False,
391
+ **kwargs,
392
+ ):
393
+ """
394
+ Prepares dataset locally before training or evaluation.
395
+ download the specific subset selected to data_path and preforms `self.convert` to the data_path dir
396
+
397
+ :param dataset: dl.Dataset
398
+ :param root_path: `str` root directory for training. default is "tmp". Can be set using self.adapter_defaults.root_path
399
+ :param data_path: `str` dataset directory. default <root_path>/"data". Can be set using self.adapter_defaults.data_path
400
+ :param output_path: `str` save everything to this folder. default <root_path>/"output". Can be set using self.adapter_defaults.output_path
401
+
402
+ :param bool overwrite: overwrite the data path (download again). default is False
403
+ """
404
+ # define paths
405
+ dataloop_path = service_defaults.DATALOOP_PATH
406
+ root_path = self.adapter_defaults.resolve("root_path", root_path)
407
+ data_path = self.adapter_defaults.resolve("data_path", data_path)
408
+ output_path = self.adapter_defaults.resolve("output_path", output_path)
409
+ if root_path is None:
410
+ now = datetime.datetime.now()
411
+ root_path = os.path.join(
412
+ dataloop_path,
413
+ 'model_data',
414
+ "{s_id}_{s_n}".format(s_id=self.model_entity.id, s_n=self.model_entity.name),
415
+ now.strftime('%Y-%m-%d-%H%M%S'),
416
+ )
417
+ if data_path is None:
418
+ data_path = os.path.join(root_path, 'datasets', self.model_entity.dataset.id)
419
+ os.makedirs(data_path, exist_ok=True)
420
+ if output_path is None:
421
+ output_path = os.path.join(root_path, 'output')
422
+ os.makedirs(output_path, exist_ok=True)
423
+
424
+ if len(os.listdir(data_path)) > 0:
425
+ self.logger.warning("Data path directory ({}) is not empty..".format(data_path))
426
+
427
+ annotation_options = entities.ViewAnnotationOptions.JSON
428
+ if self.model_entity.output_type in [entities.AnnotationType.SEGMENTATION]:
429
+ annotation_options = entities.ViewAnnotationOptions.INSTANCE
430
+
431
+ # Download the subset items
432
+ subsets = self.model_entity.metadata.get("system", {}).get("subsets", None)
433
+ annotations_subsets = self.model_entity.metadata.get("system", {}).get("annotationsSubsets", {})
434
+ if subsets is None:
435
+ raise ValueError("Model (id: {}) must have subsets in metadata.system.subsets".format(self.model_entity.id))
436
+ for subset, filters_dict in subsets.items():
437
+ data_subset_base_path = os.path.join(data_path, subset)
438
+ if os.path.isdir(data_subset_base_path) and not overwrite:
439
+ # existing and dont overwrite
440
+ self.logger.debug("Subset {!r} already exists (and overwrite=False). Skipping.".format(subset))
441
+ continue
442
+
443
+ filters = entities.Filters(custom_filter=filters_dict)
444
+ self.logger.debug("Downloading subset {!r} of {}".format(subset, self.model_entity.dataset.name))
445
+
446
+ annotation_filters = None
447
+ if subset in annotations_subsets:
448
+ annotation_filters = entities.Filters(
449
+ use_defaults=False,
450
+ resource=entities.FiltersResource.ANNOTATION,
451
+ custom_filter=annotations_subsets[subset],
452
+ )
453
+ # if user provided annotation_filters, skip the default filters
454
+ elif self.model_entity.output_type is not None and self.model_entity.output_type != "embedding":
455
+ annotation_filters = entities.Filters(resource=entities.FiltersResource.ANNOTATION, use_defaults=False)
456
+ if self.model_entity.output_type in [
457
+ entities.AnnotationType.SEGMENTATION,
458
+ entities.AnnotationType.POLYGON,
459
+ ]:
460
+ model_output_types = [entities.AnnotationType.SEGMENTATION, entities.AnnotationType.POLYGON]
461
+ else:
462
+ model_output_types = [self.model_entity.output_type]
463
+
464
+ annotation_filters.add(
465
+ field=entities.FiltersKnownFields.TYPE,
466
+ values=model_output_types,
467
+ operator=entities.FiltersOperations.IN,
468
+ )
469
+
470
+ annotation_filters = self.__include_model_annotations(annotation_filters)
471
+ annotations_subsets[subset] = annotation_filters.prepare()
472
+
473
+ ret_list = dataset.items.download(
474
+ filters=filters,
475
+ local_path=data_subset_base_path,
476
+ annotation_options=annotation_options,
477
+ annotation_filters=annotation_filters,
478
+ )
479
+ filters = entities.Filters(custom_filter=subsets[subset])
480
+ background_ret_list = self.__download_background_images(
481
+ filters=filters,
482
+ data_subset_base_path=data_subset_base_path,
483
+ annotation_options=annotation_options,
484
+ )
485
+ ret_list = list(ret_list)
486
+ background_ret_list = list(background_ret_list)
487
+ self.logger.debug(f"Subset '{subset}': ret_list length: {len(ret_list)}, background_ret_list length: {len(background_ret_list)}")
488
+ # Combine ret_list and background_ret_list generators into a single generator
489
+ ret_list = ret_list + background_ret_list
490
+ if isinstance(ret_list, list) and len(ret_list) == 0:
491
+ if annotation_filters is not None:
492
+ annotation_filters_str = annotation_filters.prepare()
493
+ else:
494
+ annotation_filters_str = None
495
+ raise ValueError(
496
+ f"No items downloaded for subset {subset}! Cannot train model with empty subset.\n"
497
+ f"Subset {subset} filters: {filters.prepare()}\n"
498
+ f"Annotation filters: {annotation_filters_str}"
499
+ )
500
+
501
+ self.convert_from_dtlpy(data_path=data_path, **kwargs)
502
+ return root_path, data_path, output_path
503
+
504
+ def load_from_model(self, model_entity=None, local_path=None, overwrite=True, **kwargs):
505
+ """Loads a model from given `dl.Model`.
506
+ Reads configurations and instantiate self.model_entity
507
+ Downloads the model_entity bucket (if available)
508
+
509
+ :param model_entity: `str` dl.Model entity
510
+ :param local_path: `str` directory path in local FileSystem to download the model_entity to
511
+ :param overwrite: `bool` (default False) if False does not download files with same name else (True) download all
512
+ """
513
+ if model_entity is not None:
514
+ self.model_entity = model_entity
515
+ if local_path is None:
516
+ local_path = os.path.join(service_defaults.DATALOOP_PATH, "models", self.model_entity.name)
517
+ # Load configuration and adapter defaults
518
+ self.adapter_defaults = AdapterDefaults(self)
519
+ # Point _configuration to the same object since AdapterDefaults inherits from ModelConfigurations
520
+ self._configuration = self.adapter_defaults
521
+ # Download
522
+ self.model_entity.artifacts.download(local_path=local_path, overwrite=overwrite)
523
+ self.load(local_path, **kwargs)
524
+
525
+ def save_to_model(self, local_path=None, cleanup=False, replace=True, **kwargs):
526
+ """
527
+ Saves the model state to a new bucket and configuration
528
+
529
+ Saves configuration and weights to artifacts
530
+ Mark the model as `trained`
531
+ loads only applies for remote buckets
532
+
533
+ :param local_path: `str` directory path in local FileSystem to save the current model bucket (weights) (default will create a temp dir)
534
+ :param replace: `bool` will clean the bucket's content before uploading new files
535
+ :param cleanup: `bool` if True (default) remove the data from local FileSystem after upload
536
+ :return:
537
+ """
538
+
539
+ if local_path is None:
540
+ local_path = tempfile.mkdtemp(prefix="model_{}".format(self.model_entity.name))
541
+ self.logger.debug("Using temporary dir at {}".format(local_path))
542
+
543
+ self.save(local_path=local_path, **kwargs)
544
+
545
+ if self.model_entity is None:
546
+ raise ValueError('Missing model entity on the adapter. ' 'Please set before saving: "adapter.model_entity=model"')
547
+
548
+ self.model_entity.artifacts.upload(filepath=os.path.join(local_path, '*'), overwrite=True)
549
+ if cleanup:
550
+ shutil.rmtree(path=local_path, ignore_errors=True)
551
+ self.logger.info("Clean-up. deleting {}".format(local_path))
552
+
553
+ # ===============
554
+ # SERVICE METHODS
555
+ # ===============
556
+
557
+ @entities.Package.decorators.function(
558
+ display_name='Predict Items',
559
+ inputs={'items': 'Item[]'},
560
+ outputs={'items': 'Item[]', 'annotations': 'Annotation[]'},
561
+ )
562
+ def predict_items(self, items: list, batch_size=None, **kwargs):
563
+ """
564
+ Run the predict function on the input list of items (or single) and return the items and the predictions.
565
+ Each prediction is by the model output type (package.output_type) and model_info in the metadata
566
+
567
+ :param items: `List[dl.Item]` list of items to predict
568
+ :param batch_size: `int` size of batch to run a single inference
569
+
570
+ :return: `List[dl.Item]`, `List[List[dl.Annotation]]`
571
+ """
572
+ if batch_size is None:
573
+ batch_size = self.configuration.get('batch_size', 4)
574
+ input_type = self.model_entity.input_type
575
+ self.logger.debug("Predicting {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))
576
+ pool = ThreadPoolExecutor(max_workers=16)
577
+ error_counter = 0
578
+ fail_ids = list()
579
+ annotations = list()
580
+ for i_batch in tqdm.tqdm(range(0, len(items), batch_size), desc='predicting', unit='bt', leave=None, file=sys.stdout):
581
+ batch_items = items[i_batch : i_batch + batch_size]
582
+ batch = list(pool.map(self.prepare_item_func, batch_items))
583
+ try:
584
+ batch_collections = self.predict(batch, **kwargs)
585
+ except Exception as e:
586
+ item_ids = [item.id for item in batch_items]
587
+ self.logger.error(f"Failed to predict batch {i_batch} for items {item_ids}. Error: {e}\n{traceback.format_exc()}")
588
+ error_counter += 1
589
+ fail_ids.extend(item_ids)
590
+ continue
591
+ _futures = list(pool.map(partial(self._update_predictions_metadata), batch_items, batch_collections))
592
+ # Loop over the futures to make sure they are all done to avoid race conditions
593
+ _ = [_f for _f in _futures]
594
+ self.logger.debug("Uploading items' annotation for model {!r}.".format(self.model_entity.name))
595
+ try:
596
+ batch_collections = list(
597
+ pool.map(partial(self._upload_model_annotations), batch_items, batch_collections)
598
+ )
599
+ except Exception as err:
600
+ item_ids = [item.id for item in batch_items]
601
+ self.logger.error(
602
+ f"Failed to upload annotations for items {item_ids}. Error: {err}\n{traceback.format_exc()}"
603
+ )
604
+ error_counter += 1
605
+ fail_ids.extend(item_ids)
606
+
607
+ for collection in batch_collections:
608
+ # function needs to return `List[List[dl.Annotation]]`
609
+ # convert annotation collection to a list of dl.Annotation for each batch
610
+ if isinstance(collection, entities.AnnotationCollection):
611
+ annotations.extend([annotation for annotation in collection.annotations])
612
+ else:
613
+ logger.warning(f'RETURN TYPE MAY BE INVALID: {type(collection)}')
614
+ annotations.extend(collection)
615
+ # TODO call the callback
616
+
617
+ pool.shutdown()
618
+ if error_counter > 0:
619
+ raise Exception(f"Failed to predict all items. Failed IDs: {fail_ids}, See logs for more details")
620
+ return items, annotations
621
+
622
+ @entities.Package.decorators.function(
623
+ display_name='Embed Items',
624
+ inputs={'items': 'Item[]'},
625
+ outputs={'items': 'Item[]', 'features': 'Json[]'},
626
+ )
627
+ def embed_items(self, items: list, upload_features=None, batch_size=None, progress: utilities.Progress = None, **kwargs):
628
+ """
629
+ Extract feature from an input list of items (or single) and return the items and the feature vector.
630
+
631
+ :param items: `List[dl.Item]` list of items to embed
632
+ :param upload_features: `bool` uploads the features on the given items
633
+ :param batch_size: `int` size of batch to run a single embed
634
+
635
+ :return: `List[dl.Item]`, `List[List[vector]]`
636
+ """
637
+ if batch_size is None:
638
+ batch_size = self.configuration.get('batch_size', 4)
639
+ upload_features = self.adapter_defaults.resolve("upload_features", upload_features)
640
+ skip_default_items = upload_features is None
641
+ if upload_features is None:
642
+ upload_features = True
643
+ input_type = self.model_entity.input_type
644
+ self.logger.debug("Embedding {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))
645
+ error_counter = 0
646
+ fail_ids = list()
647
+
648
+ feature_set = self.feature_set
649
+
650
+ # upload the feature vectors
651
+ pool = ThreadPoolExecutor(max_workers=16)
652
+ vectors = list()
653
+ _items = list()
654
+ for i_batch in tqdm.tqdm(
655
+ range(0, len(items), batch_size),
656
+ desc='embedding',
657
+ unit='bt',
658
+ leave=None,
659
+ file=sys.stdout,
660
+ ):
661
+ batch_items = items[i_batch : i_batch + batch_size]
662
+ batch = list(pool.map(self.prepare_item_func, batch_items))
663
+ try:
664
+ batch_vectors = self.embed(batch, **kwargs)
665
+ except Exception as err:
666
+ item_ids = [item.id for item in batch_items]
667
+ self.logger.error(f"Failed to embed batch {i_batch} for items {item_ids}. Error: {err}\n{traceback.format_exc()}")
668
+ error_counter += 1
669
+ fail_ids.extend(item_ids)
670
+ continue
671
+ vectors.extend(batch_vectors)
672
+ # Save the items in the order of the vectors
673
+ _items.extend(batch_items)
674
+ pool.shutdown()
675
+
676
+ if upload_features is True:
677
+ embeddings_size = self.configuration.get('embeddings_size', 256)
678
+ valid_items = []
679
+ valid_vectors = []
680
+ items_to_upload = []
681
+ vectors_to_upload = []
682
+
683
+ for item, vector in zip(_items, vectors):
684
+ # Check if vector is valid
685
+ if vector is None or len(vector) != embeddings_size:
686
+ self.logger.warning(f"Vector generated for item {item.id} is None or has wrong size. Skipping...")
687
+ continue
688
+
689
+ # Item and vector are valid
690
+ valid_items.append(item)
691
+ valid_vectors.append(vector)
692
+
693
+ # Check if item should be skipped (prompt items)
694
+ _system_metadata = getattr(item, 'system', dict())
695
+ is_prompt = _system_metadata.get('shebang', dict()).get('dltype', '') == 'prompt'
696
+ if skip_default_items and is_prompt:
697
+ self.logger.debug(f"Skipping feature upload for prompt item {item.id}")
698
+ continue
699
+
700
+ # Items were not skipped - should be uploaded
701
+ items_to_upload.append(item)
702
+ vectors_to_upload.append(vector)
703
+
704
+ # Update the original lists with valid items only
705
+ _items[:] = valid_items
706
+ vectors[:] = valid_vectors
707
+
708
+ if len(_items) != len(vectors):
709
+ raise ValueError(f"The number of items ({len(_items)}) is not equal to the number of vectors ({len(vectors)}).")
710
+
711
+ self.logger.debug(f"Uploading {len(items_to_upload)} items' feature vectors for model {self.model_entity.name}.")
712
+ try:
713
+ start_time = time.time()
714
+ feature_set.features.create(entity=items_to_upload, value=vectors_to_upload, feature_set_id=feature_set.id, project_id=self.model_entity.project_id)
715
+ self.logger.debug(f"Uploaded {len(items_to_upload)} items' feature vectors for model {self.model_entity.name} in {time.time() - start_time} seconds.")
716
+ except Exception as err:
717
+ self.logger.error(f"Failed to upload feature vectors. Error: {err}\n{traceback.format_exc()}")
718
+ error_counter += 1
719
+ if error_counter > 0:
720
+ raise Exception(f"Failed to embed all items. Failed IDs: {fail_ids}, See logs for more details")
721
+ return _items, vectors
722
+
723
+ @entities.Package.decorators.function(
724
+ display_name='Embed Dataset with DQL',
725
+ inputs={'dataset': 'Dataset', 'filters': 'Json'},
726
+ )
727
+ def embed_dataset(
728
+ self,
729
+ dataset: entities.Dataset,
730
+ filters: Optional[entities.Filters] = None,
731
+ upload_features: Optional[bool] = None,
732
+ batch_size: Optional[int] = None,
733
+ progress: Optional[utilities.Progress] = None,
734
+ **kwargs,
735
+ ):
736
+ """
737
+ Run model embedding on all items in a dataset
738
+
739
+ :param dataset: Dataset entity to embed
740
+ :param filters: Filters entity for filtering before embedding
741
+ :param upload_features: bool whether to upload features back to platform
742
+ :param batch_size: int size of batch to run a single embedding
743
+ :param progress: dl.Progress object to track progress
744
+ :return: bool indicating if embedding completed successfully
745
+ """
746
+
747
+ self._execute_dataset_operation(
748
+ dataset=dataset,
749
+ operation_type='embed',
750
+ filters=filters,
751
+ progress=progress,
752
+ batch_size=batch_size,
753
+ )
754
+
755
+ @entities.Package.decorators.function(
756
+ display_name='Predict Dataset with DQL',
757
+ inputs={'dataset': 'Dataset', 'filters': 'Json'},
758
+ )
759
+ def predict_dataset(
760
+ self,
761
+ dataset: entities.Dataset,
762
+ filters: Optional[entities.Filters] = None,
763
+ batch_size: Optional[int] = None,
764
+ progress: Optional[utilities.Progress] = None,
765
+ **kwargs,
766
+ ):
767
+ """
768
+ Run model prediction on all items in a dataset
769
+
770
+ :param dataset: Dataset entity to predict
771
+ :param filters: Filters entity for filtering before prediction
772
+ :param batch_size: int size of batch to run a single prediction
773
+ :param progress: dl.Progress object to track progress
774
+ :return: bool indicating if prediction completed successfully
775
+ """
776
+ self._execute_dataset_operation(
777
+ dataset=dataset,
778
+ operation_type='predict',
779
+ filters=filters,
780
+ progress=progress,
781
+ batch_size=batch_size,
782
+ )
783
+
784
@entities.Package.decorators.function(
    display_name='Train a Model',
    inputs={'model': entities.Model},
    outputs={'model': entities.Model},
)
def train_model(self, model: entities.Model, cleanup=False, progress: utilities.Progress = None, context: utilities.Context = None):
    """
    Train on an existing model.

    Data is taken from ``dl.Model.datasetId`` and the configuration is the one defined in
    ``dl.Model.configuration``. The output directory is uploaded to the model's bucket
    (``model.bucket``) with ``save_to_model``, also when training fails.

    :param model: model entity to train (a dict with an ``'id'`` key is also accepted and re-fetched)
    :param cleanup: if True, delete the local output directory after a successful upload
    :param progress: optional dl.Progress used to report FaaS progress
    :param context: optional FaaS context; when given, ``model.metadata['system']`` is ensured to exist
    :return: the updated model entity (status ``'trained'``)
    :raises ValueError: if the model is in ``'failed'`` state after waiting for it to be ready
    """
    if isinstance(model, dict):
        model = repositories.Models(client_api=self._client_api).get(model_id=model['id'])
    output_path = None
    try:
        logger.info("Received {s} for training".format(s=model.id))
        model = model.wait_for_model_ready()
        if model.status == 'failed':
            raise ValueError("Model is in failed state, cannot train.")

        ##############
        # Set status #
        ##############
        model.status = 'training'
        if context is not None:
            if 'system' not in model.metadata:
                model.metadata['system'] = dict()
        model.update(reload_services=False)

        ##########################
        # load model and weights #
        ##########################
        logger.info("Loading Adapter with: {n} ({i!r})".format(n=model.name, i=model.id))
        self.load_from_model(model_entity=model)

        ################
        # prepare data #
        ################
        root_path, data_path, output_path = self.prepare_data(dataset=self.model_entity.dataset, root_path=os.path.join('tmp', model.id))
        # Start the Train
        logger.info(f"Training model {model.name!r} ({model.id!r}) on data {data_path!r}")
        if progress is not None:
            progress.update(message='starting training')

        def on_epoch_end_callback(i_epoch, n_epoch):
            # i_epoch is 0-based (the last epoch maps to 100%), so report the 1-based epoch number;
            # skip reporting when n_epoch is 0/None instead of crashing training with ZeroDivisionError
            if progress is None or not n_epoch:
                return
            finished_epochs = i_epoch + 1
            progress.update(
                progress=int(100 * finished_epochs / n_epoch),
                message='finished epoch: {}/{}'.format(finished_epochs, n_epoch),
            )

        self.train(data_path=data_path, output_path=output_path, on_epoch_end_callback=on_epoch_end_callback)
        if progress is not None:
            progress.update(message='saving model', progress=99)

        self.save_to_model(local_path=output_path, replace=True)
        model.status = 'trained'
        model.update(reload_services=False)
        ###########
        # cleanup #
        ###########
        if cleanup:
            shutil.rmtree(output_path, ignore_errors=True)
    except Exception:
        # save also on fail, so partial artifacts are not lost
        if output_path is not None:
            try:
                self.save_to_model(local_path=output_path, replace=True)
            except Exception:
                # a failing upload must not mask the original training error re-raised below
                logger.exception('Failed to save model artifacts after training failure')
        logger.info('Execution failed. Setting model.status to failed')
        raise
    return model
852
+
853
@entities.Package.decorators.function(
    display_name='Evaluate a Model',
    inputs={'model': entities.Model, 'dataset': entities.Dataset, 'filters': 'Json'},
    outputs={'model': entities.Model, 'dataset': entities.Dataset},
)
def evaluate_model(
    self,
    model: entities.Model,
    dataset: entities.Dataset,
    filters: entities.Filters = None,
    #
    progress: utilities.Progress = None,
    context: utilities.Context = None,
):
    """
    Evaluate a model on a dataset.

    Unless the adapter default ``overwrite_annotations`` is set to anything other than True,
    the model first predicts on the filtered items with ``multiple_executions=False`` (processed
    by this adapter rather than fanned out to new executions). Then ``self.evaluate`` calculates
    metrics vs the GT. Configuration is as defined in ``dl.Model.configuration``.

    :param model: Model entity to run prediction with
    :param dataset: Dataset to evaluate
    :param filters: Filters (or their JSON dict form) selecting items from the dataset; all items when empty
    :param progress: dl.Progress for reporting FaaS progress
    :param context: FaaS context (not used here)
    :return: tuple ``(model, dataset)`` where ``model`` is the entity returned by ``self.evaluate``
    """
    logger.info(f"Received model: {model.id} for evaluation on dataset (name: {dataset.name}, id: {dataset.id})")
    ##########################
    # load model and weights #
    ##########################
    logger.info(f"Loading Adapter with: {model.name} ({model.id!r})")
    self.load_from_model(dataset=dataset, model_entity=model)

    ##############
    # Predicting #
    ##############
    logger.info(f"Calling prediction, dataset: {dataset.name!r} ({dataset.id!r}), filters: {filters}")
    if not filters:
        filters = entities.Filters()
    elif isinstance(filters, dict):
        # FaaS delivers 'Json' inputs as plain dicts; normalize like _execute_dataset_operation does
        # so evaluate() always receives a Filters entity
        filters = entities.Filters(custom_filter=filters)
    if self.adapter_defaults.get("overwrite_annotations", True) is True:
        self._execute_dataset_operation(
            dataset=dataset,
            operation_type='predict',
            filters=filters,
            multiple_executions=False,
        )

    ##############
    # Evaluating #
    ##############
    logger.info("Starting adapter.evaluate()")
    if progress is not None:
        progress.update(message='calculating metrics', progress=98)
    model = self.evaluate(model=model, dataset=dataset, filters=filters)
    #########
    # Done! #
    #########
    if progress is not None:
        progress.update(message='finishing evaluation', progress=99)
    return model, dataset
914
+
915
+ # =============
916
+ # INNER METHODS
917
+ # =============
918
def _get_feature_set(self):
    """
    Return the model entity's feature set, creating one when it has none.

    Lookup and creation run under the class-level ``_feature_set_lock`` (thread-safety across
    the class). A new feature set is named after the model. If the project already has a feature
    set with that name, a random 5-character alphanumeric suffix is appended instead.

    :return: the existing or newly created feature set entity
    """
    with self.__class__._feature_set_lock:
        existing = self.model_entity.feature_set
        if existing is not None:
            logger.info(f'Feature Set found! name: {existing.name}, id: {existing.id}')
            return existing

        logger.info('Feature Set not found. creating... ')
        base_name = self.model_entity.name
        try:
            self.project.feature_sets.get(feature_set_name=base_name)
        except exceptions.NotFound:
            # the model name is free in this project - use it as-is
            new_name = base_name
        else:
            suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=5))
            new_name = f"{base_name}-{suffix}"
            logger.warning(
                f"Feature set with the model name already exists. Creating new feature set with name {new_name}"
            )

        created = self.project.feature_sets.create(
            name=new_name,
            entity_type=entities.FeatureEntityType.ITEM,
            model_id=self.model_entity.id,
            project_id=self.project.id,
            set_type=self.model_entity.name,
            size=self.configuration.get('embeddings_size', 256),
        )
        logger.info(f'Feature Set created! name: {created.name}, id: {created.id}')
        return created
946
+
947
def _execute_dataset_operation(
    self,
    dataset: entities.Dataset,
    operation_type: str,
    filters: Optional[entities.Filters] = None,
    progress: Optional[utilities.Progress] = None,
    batch_size: Optional[int] = None,
    multiple_executions: bool = True,
) -> bool:
    """
    Execute dataset operation (predict/embed) with batching and filtering support.

    The filtered items are split into subsets of at most ``predict_embed_subset_limit`` items
    (model configuration). When there is a single subset, or ``multiple_executions`` is False,
    every subset is processed in this process with ``predict_items`` / ``embed_items``.
    Otherwise one platform execution is created per subset, and their status is polled every
    5 seconds until all of them finish or ``predict_embed_timeout`` seconds elapse.

    :param dataset: Dataset entity to run operation on
    :param operation_type: Type of operation to execute ('predict' or 'embed')
    :param filters: Filters entity (or a raw filter dict) to filter items, default None (all items)
    :param progress: Progress object for tracking progress, default None
    :param batch_size: Size of batches for local processing, default None (model config ``batch_size``, else 4)
    :param multiple_executions: Whether to use multiple executions when filters exceed subset limit, default True
    :return: True when processing finished; with remote executions also when the wait timed out
        (a warning is logged and still-running executions are not cancelled)
    :raises ValueError: If operation_type is not 'predict' or 'embed', or if any remote execution failed
    """
    if operation_type not in ('predict', 'embed'):
        # validate up-front so no API calls or executions are made for an invalid request
        raise ValueError(f"Unsupported operation type: {operation_type}")

    self.logger.debug(f"Running {operation_type} for dataset (name:{dataset.name}, id:{dataset.id})")

    if not filters:
        self.logger.debug("No filters provided, using default filters")
        filters = entities.Filters()
    elif isinstance(filters, dict):
        self.logger.debug(f"Received custom filters {filters}")
        filters = entities.Filters(custom_filter=filters)

    if operation_type == 'embed':
        # resolve the model's feature set before embedding (property defined elsewhere;
        # presumably backed by _get_feature_set)
        feature_set = self.feature_set
        logger.info(f"Feature set found! name: {feature_set.name}, id: {feature_set.id}")

    predict_embed_subset_limit = self.configuration.get('predict_embed_subset_limit', PREDICT_EMBED_DEFAULT_SUBSET_LIMIT)
    predict_embed_timeout = self.configuration.get('predict_embed_timeout', PREDICT_EMBED_DEFAULT_TIMEOUT)
    self.logger.debug(f"Inputs: predict_embed_subset_limit: {predict_embed_subset_limit}, predict_embed_timeout: {predict_embed_timeout}")
    # count-only query (pageSize 0) on a copy so the caller's filters are not modified
    tmp_filters = copy.deepcopy(filters.prepare())
    tmp_filters['pageSize'] = 0
    num_items = dataset.items.list(filters=entities.Filters(custom_filter=tmp_filters)).items_count
    self.logger.debug(f"Number of items for current filters: {num_items}")

    # One-item lookahead on the generator: a single subset runs locally, several may fan out to executions
    gen = entities.Filters._get_split_filters(dataset, filters, predict_embed_subset_limit)
    try:
        first_filter = next(gen)
    except StopIteration:
        self.logger.info("Filters is empty, nothing to run")
        return True

    try:
        second_filter = next(gen)
    except StopIteration:
        multiple = False
        all_filters = [first_filter]
    else:
        multiple = True
        # put the two pre-consumed subsets back in front of the remaining generator
        all_filters = chain([first_filter, second_filter], gen)

    if not multiple or not multiple_executions:
        self.logger.info("Running locally")
        if batch_size is None:
            batch_size = self.configuration.get('batch_size', 4)
        run_items = self.embed_items if operation_type == 'embed' else self.predict_items

        # Process each filter locally
        for filter_dict in all_filters:
            filter_dict["pageSize"] = 1000
            pages = dataset.items.list(filters=entities.Filters(custom_filter=filter_dict))
            self.logger.info(f"Processing filter on: {pages.items_count} items")
            items = [item for page in pages for item in page if item.type == 'file']
            self.logger.debug(f"Items length: {len(items)}")
            run_items(items=items, batch_size=batch_size, progress=progress)
        return True

    models_repo = self.model_entity.models
    create_execution = models_repo.embed if operation_type == 'embed' else models_repo.predict
    executions = []
    for filter_dict in all_filters:
        self.logger.debug(f"Creating execution for models {operation_type} with dataset id {dataset.id} and filter_dict {filter_dict}")
        executions.append(
            create_execution(
                model=self.model_entity,
                dataset_id=dataset.id,
                filters=entities.Filters(custom_filter=filter_dict),
            )
        )

    if executions:
        self.logger.info(f'Created {len(executions)} executions for {operation_type}, ' f'execution ids: {[ex.id for ex in executions]}')

    wait_time = 5
    start_time = time.time()
    last_perc = 0
    # latest percentComplete per execution; finished executions keep their final value and are not re-fetched
    perc_by_id = {ex.id: 0 for ex in executions}
    pending_ids = set(perc_by_id)
    self.logger.debug(f"Waiting for executions with timeout {predict_embed_timeout}")
    while pending_ids and time.time() - start_time < predict_embed_timeout:
        for ex_id in list(pending_ids):
            execution = self.project.executions.get(execution_id=ex_id)
            perc_by_id[ex_id] = execution.latest_status.get('percentComplete', 0)
            if not execution.in_progress():
                pending_ids.discard(ex_id)

        avg_perc = round(sum(perc_by_id.values()) / len(perc_by_id))
        if progress is not None and last_perc != avg_perc:
            last_perc = avg_perc
            progress.update(progress=last_perc, message=f'running {operation_type}')

        if pending_ids:
            time.sleep(wait_time)
    self.logger.debug("End waiting for executions")
    if pending_ids:
        # do not pretend the work is done: surface the timeout explicitly
        self.logger.warning(
            f"Timed out after {predict_embed_timeout} seconds waiting for {len(pending_ids)} "
            f"{operation_type} executions, still running: {sorted(pending_ids)}"
        )

    # Check if any execution failed
    executions_filter = entities.Filters(resource=entities.FiltersResource.EXECUTION)
    executions_filter.add(field="id", values=[ex.id for ex in executions], operator=entities.FiltersOperations.IN)
    executions_filter.add(field='latestStatus.status', values='failed')
    executions_filter.page_size = 0
    failed_executions_count = self.project.executions.list(filters=executions_filter).items_count
    if failed_executions_count > 0:
        self.logger.error(f"Failed to {operation_type} for {failed_executions_count} executions")
        raise ValueError(f"Failed to {operation_type} entire dataset, please check the logs for more details")
    return True
1089
+
1090
def _upload_model_annotations(self, item: entities.Item, predictions):
    """
    Replace this model's annotations on an item with new predictions.

    Existing annotations matching ``metadata.user.model.name`` OR ``metadata.system.model.name``
    equal to the model entity's name are deleted first, then ``predictions`` are uploaded.

    :param item: `dl.Item` to upload the predictions to
    :param predictions: `dl.AnnotationCollection` (or a list of annotations)
    :return: the result of ``item.annotations.upload``
    :raises TypeError: if ``predictions`` is neither an AnnotationCollection nor a list
    """
    if not isinstance(predictions, (entities.AnnotationCollection, list)):
        raise TypeError(
            f'predictions was expected to be of type {entities.AnnotationCollection} or list, '
            f'but instead it is {type(predictions)}'
        )
    clean_filter = entities.Filters(resource=entities.FiltersResource.ANNOTATION)
    clean_filter.add(field='metadata.user.model.name', values=self.model_entity.name, method=entities.FiltersMethod.OR)
    clean_filter.add(field='metadata.system.model.name', values=self.model_entity.name, method=entities.FiltersMethod.OR)
    # NOTE: the cleanup is not restricted by annotation type (e.g. the model's output_type), so
    # every annotation attributed to this model name on the item is removed
    item.annotations.delete(filters=clean_filter)
    return item.annotations.upload(annotations=predictions)
1105
+
1106
@staticmethod
def _item_to_image(item):
    """
    Load an item's content as a numpy array (image preprocessing before calling `predict`).

    The item is downloaded into memory (``save_locally=False``) and decoded with PIL.

    :param item: dl.Item to load
    :return: ``np.ndarray`` of the decoded image, or None if downloading/decoding failed
        (the error and its traceback are logged)
    """
    image = None
    try:
        image = np.asarray(Image.open(item.download(save_locally=False)))
    except Exception as e:
        logger.error(f"Failed to convert image to np.array, Error: {e}\n{traceback.format_exc()}")
    return image
1122
+
1123
+ @staticmethod
1124
+ def _item_to_item(item):
1125
+ """
1126
+ Default item to batch function.
1127
+ This function should prepare a single item for the predict function, e.g. for images, it loads the image as numpy array
1128
+ :param item:
1129
+ :return:
1130
+ """
1131
+ return item
1132
+
1133
+ @staticmethod
1134
+ def _item_to_text(item):
1135
+ filename = item.download(overwrite=True)
1136
+ text = None
1137
+ if item.mimetype == 'text/plain' or item.mimetype == 'text/markdown':
1138
+ with open(filename, 'r') as f:
1139
+ text = f.read()
1140
+ text = text.replace('\n', ' ')
1141
+ else:
1142
+ logger.warning('Item is not text file. mimetype: {}'.format(item.mimetype))
1143
+ text = item
1144
+ if os.path.exists(filename):
1145
+ os.remove(filename)
1146
+ return text
1147
+
1148
@staticmethod
def _uri_to_image(data_uri):
    """
    Decode a base64-encoded image (typically a ``data:`` URI) into a numpy array.

    Example input: ``"data:image/png;base64,iVBORw0KGgoAAAANSUhEUg..."``. The base64 payload
    following the header's comma is decoded and opened with PIL. A bare base64 string with no
    ``data:...,`` header is accepted as well.

    :param data_uri: data URI string (or bare base64 payload) of an encoded image
    :return: ``np.ndarray`` of the decoded image
    """
    # base64 text never contains ',', so a comma can only come from the data-URI header
    image_b64 = data_uri.split(",")[1] if "," in data_uri else data_uri
    binary = base64.b64decode(image_b64)
    image = np.asarray(Image.open(io.BytesIO(binary)))
    return image
1155
+
1156
def _update_predictions_metadata(self, item: entities.Item, predictions: entities.AnnotationCollection):
    """
    Stamp model information onto every prediction in ``predictions`` (modified in place).

    For each annotation:
    * segmentation annotations get a color, chosen in this order:
      - the label's color in the item's dataset ontology;
      - else its color in the model's dataset ontology (white ``(255, 255, 255)`` for labels
        missing from that ontology);
      - else the annotation's own color, with a warning;
    * ``item_id`` is set to ``item.id``;
    * an existing ``metadata['user']['model']`` dict gets ``model_id`` and ``name``;
    * ``metadata['system']['model']`` is (re)set to ``model_id``, ``name`` and ``confidence``.
      The confidence is copied from ``metadata['user']['model']['confidence']`` (None if absent).

    Each ontology color map is fetched at most once per call.

    :param item: Entity.Item the predictions belong to
    :param predictions: item's AnnotationCollection
    :return: None
    """
    color_maps = {}

    def _color_map(key, get_dataset):
        # fetch and cache (for this call only) an ontology color map; None if it cannot be fetched
        if key not in color_maps:
            try:
                color_maps[key] = get_dataset()._get_ontology().color_map
            except (exceptions.BadRequest, exceptions.NotFound):
                color_maps[key] = None
        return color_maps[key]

    for prediction in predictions:
        if prediction.type == entities.AnnotationType.SEGMENTATION:
            item_map = _color_map('item', lambda: item.dataset)
            color = item_map.get(prediction.label, None) if item_map is not None else None
            if color is None and self.model_entity._dataset is not None:
                model_map = _color_map('model', lambda: self.model_entity.dataset)
                if model_map is not None:
                    color = model_map.get(prediction.label, (255, 255, 255))
            if color is None:
                logger.warning("Can't get annotation color from model's dataset, using default.")
                color = prediction.color
            prediction.color = color

        prediction.item_id = item.id
        if 'user' in prediction.metadata and 'model' in prediction.metadata['user']:
            prediction.metadata['user']['model']['model_id'] = self.model_entity.id
            prediction.metadata['user']['model']['name'] = self.model_entity.name
        confidence = prediction.metadata.get('user', dict()).get('model', dict()).get('confidence', None)
        prediction.metadata.setdefault('system', dict())
        prediction.metadata['system']['model'] = {
            'model_id': self.model_entity.id,
            'name': self.model_entity.name,
            'confidence': confidence,
        }
1198
+
1199
+ ##############################
1200
+ # Callback Factory functions #
1201
+ ##############################
1202
@property
def dataloop_keras_callback(self):
    """
    Return the constructor of a Keras callback that dumps training history for the dlp platform.

    At the end of each epoch, the callback records every value of the Keras ``logs`` dict and
    rewrites a JSON "metric" file (x = epoch, y = value, one series per log key). The platform
    uses this file to show train losses.

    :return: ``DumpHistoryCallback`` class. Construct it with ``dump_path``: either a file path,
        or a directory in which a timestamped ``__view__training-history__<timestamp>.json`` is written
    :raises RuntimeError: if ``keras`` cannot be imported
    """
    try:
        import keras
    except ImportError as err:
        raise RuntimeError(
            f'{self.__class__.__name__} depends on external package "keras". '
            f'Please install it (e.g. `pip install keras`)'
        ) from err

    import os
    import time
    import json

    class DumpHistoryCallback(keras.callbacks.Callback):
        def __init__(self, dump_path):
            super().__init__()
            if os.path.isdir(dump_path):
                # '%X' is locale dependent and contains ':', which is invalid in Windows file names
                timestamp = time.strftime("%Y-%m-%d-%H-%M-%S")
                dump_path = os.path.join(dump_path, f'__view__training-history__{timestamp}.json')
            self.dump_file = dump_path
            self.data = dict()

        def on_epoch_end(self, epoch, logs=None):
            logs = logs or {}
            for name, val in logs.items():
                if name not in self.data:
                    self.data[name] = {'x': list(), 'y': list()}
                self.data[name]['x'].append(float(epoch))
                self.data[name]['y'].append(float(val))
            self.dump_history()

        def dump_history(self):
            _json = {
                "query": {},
                "datasetId": "",
                "xlabel": "epoch",
                "title": "training loss",
                "ylabel": "val",
                "type": "metric",
                "data": [
                    {
                        "name": name,
                        "x": values['x'],
                        "y": values['y'],
                    }
                    for name, values in self.data.items()
                ],
            }

            with open(self.dump_file, 'w') as f:
                json.dump(_json, f, indent=2)

    return DumpHistoryCallback