dtlpy 1.113.10__py3-none-any.whl → 1.114.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (243) hide show
  1. dtlpy/__init__.py +488 -488
  2. dtlpy/__version__.py +1 -1
  3. dtlpy/assets/__init__.py +26 -26
  4. dtlpy/assets/__pycache__/__init__.cpython-38.pyc +0 -0
  5. dtlpy/assets/code_server/config.yaml +2 -2
  6. dtlpy/assets/code_server/installation.sh +24 -24
  7. dtlpy/assets/code_server/launch.json +13 -13
  8. dtlpy/assets/code_server/settings.json +2 -2
  9. dtlpy/assets/main.py +53 -53
  10. dtlpy/assets/main_partial.py +18 -18
  11. dtlpy/assets/mock.json +11 -11
  12. dtlpy/assets/model_adapter.py +83 -83
  13. dtlpy/assets/package.json +61 -61
  14. dtlpy/assets/package_catalog.json +29 -29
  15. dtlpy/assets/package_gitignore +307 -307
  16. dtlpy/assets/service_runners/__init__.py +33 -33
  17. dtlpy/assets/service_runners/converter.py +96 -96
  18. dtlpy/assets/service_runners/multi_method.py +49 -49
  19. dtlpy/assets/service_runners/multi_method_annotation.py +54 -54
  20. dtlpy/assets/service_runners/multi_method_dataset.py +55 -55
  21. dtlpy/assets/service_runners/multi_method_item.py +52 -52
  22. dtlpy/assets/service_runners/multi_method_json.py +52 -52
  23. dtlpy/assets/service_runners/single_method.py +37 -37
  24. dtlpy/assets/service_runners/single_method_annotation.py +43 -43
  25. dtlpy/assets/service_runners/single_method_dataset.py +43 -43
  26. dtlpy/assets/service_runners/single_method_item.py +41 -41
  27. dtlpy/assets/service_runners/single_method_json.py +42 -42
  28. dtlpy/assets/service_runners/single_method_multi_input.py +45 -45
  29. dtlpy/assets/voc_annotation_template.xml +23 -23
  30. dtlpy/caches/base_cache.py +32 -32
  31. dtlpy/caches/cache.py +473 -473
  32. dtlpy/caches/dl_cache.py +201 -201
  33. dtlpy/caches/filesystem_cache.py +89 -89
  34. dtlpy/caches/redis_cache.py +84 -84
  35. dtlpy/dlp/__init__.py +20 -20
  36. dtlpy/dlp/cli_utilities.py +367 -367
  37. dtlpy/dlp/command_executor.py +764 -764
  38. dtlpy/dlp/dlp +1 -1
  39. dtlpy/dlp/dlp.bat +1 -1
  40. dtlpy/dlp/dlp.py +128 -128
  41. dtlpy/dlp/parser.py +651 -651
  42. dtlpy/entities/__init__.py +83 -83
  43. dtlpy/entities/analytic.py +311 -311
  44. dtlpy/entities/annotation.py +1879 -1879
  45. dtlpy/entities/annotation_collection.py +699 -699
  46. dtlpy/entities/annotation_definitions/__init__.py +20 -20
  47. dtlpy/entities/annotation_definitions/base_annotation_definition.py +100 -100
  48. dtlpy/entities/annotation_definitions/box.py +195 -195
  49. dtlpy/entities/annotation_definitions/classification.py +67 -67
  50. dtlpy/entities/annotation_definitions/comparison.py +72 -72
  51. dtlpy/entities/annotation_definitions/cube.py +204 -204
  52. dtlpy/entities/annotation_definitions/cube_3d.py +149 -149
  53. dtlpy/entities/annotation_definitions/description.py +32 -32
  54. dtlpy/entities/annotation_definitions/ellipse.py +124 -124
  55. dtlpy/entities/annotation_definitions/free_text.py +62 -62
  56. dtlpy/entities/annotation_definitions/gis.py +69 -69
  57. dtlpy/entities/annotation_definitions/note.py +139 -139
  58. dtlpy/entities/annotation_definitions/point.py +117 -117
  59. dtlpy/entities/annotation_definitions/polygon.py +182 -182
  60. dtlpy/entities/annotation_definitions/polyline.py +111 -111
  61. dtlpy/entities/annotation_definitions/pose.py +92 -92
  62. dtlpy/entities/annotation_definitions/ref_image.py +86 -86
  63. dtlpy/entities/annotation_definitions/segmentation.py +240 -240
  64. dtlpy/entities/annotation_definitions/subtitle.py +34 -34
  65. dtlpy/entities/annotation_definitions/text.py +85 -85
  66. dtlpy/entities/annotation_definitions/undefined_annotation.py +74 -74
  67. dtlpy/entities/app.py +220 -220
  68. dtlpy/entities/app_module.py +107 -107
  69. dtlpy/entities/artifact.py +174 -174
  70. dtlpy/entities/assignment.py +399 -399
  71. dtlpy/entities/base_entity.py +214 -214
  72. dtlpy/entities/bot.py +113 -113
  73. dtlpy/entities/codebase.py +296 -296
  74. dtlpy/entities/collection.py +38 -38
  75. dtlpy/entities/command.py +169 -169
  76. dtlpy/entities/compute.py +442 -442
  77. dtlpy/entities/dataset.py +1285 -1285
  78. dtlpy/entities/directory_tree.py +44 -44
  79. dtlpy/entities/dpk.py +470 -470
  80. dtlpy/entities/driver.py +222 -222
  81. dtlpy/entities/execution.py +397 -397
  82. dtlpy/entities/feature.py +124 -124
  83. dtlpy/entities/feature_set.py +145 -145
  84. dtlpy/entities/filters.py +641 -641
  85. dtlpy/entities/gis_item.py +107 -107
  86. dtlpy/entities/integration.py +184 -184
  87. dtlpy/entities/item.py +953 -953
  88. dtlpy/entities/label.py +123 -123
  89. dtlpy/entities/links.py +85 -85
  90. dtlpy/entities/message.py +175 -175
  91. dtlpy/entities/model.py +694 -691
  92. dtlpy/entities/node.py +1005 -1005
  93. dtlpy/entities/ontology.py +803 -803
  94. dtlpy/entities/organization.py +287 -287
  95. dtlpy/entities/package.py +657 -657
  96. dtlpy/entities/package_defaults.py +5 -5
  97. dtlpy/entities/package_function.py +185 -185
  98. dtlpy/entities/package_module.py +113 -113
  99. dtlpy/entities/package_slot.py +118 -118
  100. dtlpy/entities/paged_entities.py +290 -267
  101. dtlpy/entities/pipeline.py +593 -593
  102. dtlpy/entities/pipeline_execution.py +279 -279
  103. dtlpy/entities/project.py +394 -394
  104. dtlpy/entities/prompt_item.py +499 -499
  105. dtlpy/entities/recipe.py +301 -301
  106. dtlpy/entities/reflect_dict.py +102 -102
  107. dtlpy/entities/resource_execution.py +138 -138
  108. dtlpy/entities/service.py +958 -958
  109. dtlpy/entities/service_driver.py +117 -117
  110. dtlpy/entities/setting.py +294 -294
  111. dtlpy/entities/task.py +491 -491
  112. dtlpy/entities/time_series.py +143 -143
  113. dtlpy/entities/trigger.py +426 -426
  114. dtlpy/entities/user.py +118 -118
  115. dtlpy/entities/webhook.py +124 -124
  116. dtlpy/examples/__init__.py +19 -19
  117. dtlpy/examples/add_labels.py +135 -135
  118. dtlpy/examples/add_metadata_to_item.py +21 -21
  119. dtlpy/examples/annotate_items_using_model.py +65 -65
  120. dtlpy/examples/annotate_video_using_model_and_tracker.py +75 -75
  121. dtlpy/examples/annotations_convert_to_voc.py +9 -9
  122. dtlpy/examples/annotations_convert_to_yolo.py +9 -9
  123. dtlpy/examples/convert_annotation_types.py +51 -51
  124. dtlpy/examples/converter.py +143 -143
  125. dtlpy/examples/copy_annotations.py +22 -22
  126. dtlpy/examples/copy_folder.py +31 -31
  127. dtlpy/examples/create_annotations.py +51 -51
  128. dtlpy/examples/create_video_annotations.py +83 -83
  129. dtlpy/examples/delete_annotations.py +26 -26
  130. dtlpy/examples/filters.py +113 -113
  131. dtlpy/examples/move_item.py +23 -23
  132. dtlpy/examples/play_video_annotation.py +13 -13
  133. dtlpy/examples/show_item_and_mask.py +53 -53
  134. dtlpy/examples/triggers.py +49 -49
  135. dtlpy/examples/upload_batch_of_items.py +20 -20
  136. dtlpy/examples/upload_items_and_custom_format_annotations.py +55 -55
  137. dtlpy/examples/upload_items_with_modalities.py +43 -43
  138. dtlpy/examples/upload_segmentation_annotations_from_mask_image.py +44 -44
  139. dtlpy/examples/upload_yolo_format_annotations.py +70 -70
  140. dtlpy/exceptions.py +125 -125
  141. dtlpy/miscellaneous/__init__.py +20 -20
  142. dtlpy/miscellaneous/dict_differ.py +95 -95
  143. dtlpy/miscellaneous/git_utils.py +217 -217
  144. dtlpy/miscellaneous/json_utils.py +14 -14
  145. dtlpy/miscellaneous/list_print.py +105 -105
  146. dtlpy/miscellaneous/zipping.py +130 -130
  147. dtlpy/ml/__init__.py +20 -20
  148. dtlpy/ml/base_feature_extractor_adapter.py +27 -27
  149. dtlpy/ml/base_model_adapter.py +945 -940
  150. dtlpy/ml/metrics.py +461 -461
  151. dtlpy/ml/predictions_utils.py +274 -274
  152. dtlpy/ml/summary_writer.py +57 -57
  153. dtlpy/ml/train_utils.py +60 -60
  154. dtlpy/new_instance.py +252 -252
  155. dtlpy/repositories/__init__.py +56 -56
  156. dtlpy/repositories/analytics.py +85 -85
  157. dtlpy/repositories/annotations.py +916 -916
  158. dtlpy/repositories/apps.py +383 -383
  159. dtlpy/repositories/artifacts.py +452 -452
  160. dtlpy/repositories/assignments.py +599 -599
  161. dtlpy/repositories/bots.py +213 -213
  162. dtlpy/repositories/codebases.py +559 -559
  163. dtlpy/repositories/collections.py +332 -348
  164. dtlpy/repositories/commands.py +158 -158
  165. dtlpy/repositories/compositions.py +61 -61
  166. dtlpy/repositories/computes.py +434 -406
  167. dtlpy/repositories/datasets.py +1291 -1291
  168. dtlpy/repositories/downloader.py +895 -895
  169. dtlpy/repositories/dpks.py +433 -433
  170. dtlpy/repositories/drivers.py +266 -266
  171. dtlpy/repositories/executions.py +817 -817
  172. dtlpy/repositories/feature_sets.py +226 -226
  173. dtlpy/repositories/features.py +238 -238
  174. dtlpy/repositories/integrations.py +484 -484
  175. dtlpy/repositories/items.py +909 -915
  176. dtlpy/repositories/messages.py +94 -94
  177. dtlpy/repositories/models.py +877 -867
  178. dtlpy/repositories/nodes.py +80 -80
  179. dtlpy/repositories/ontologies.py +511 -511
  180. dtlpy/repositories/organizations.py +525 -525
  181. dtlpy/repositories/packages.py +1941 -1941
  182. dtlpy/repositories/pipeline_executions.py +448 -448
  183. dtlpy/repositories/pipelines.py +642 -642
  184. dtlpy/repositories/projects.py +539 -539
  185. dtlpy/repositories/recipes.py +399 -399
  186. dtlpy/repositories/resource_executions.py +137 -137
  187. dtlpy/repositories/schema.py +120 -120
  188. dtlpy/repositories/service_drivers.py +213 -213
  189. dtlpy/repositories/services.py +1704 -1704
  190. dtlpy/repositories/settings.py +339 -339
  191. dtlpy/repositories/tasks.py +1124 -1124
  192. dtlpy/repositories/times_series.py +278 -278
  193. dtlpy/repositories/triggers.py +536 -536
  194. dtlpy/repositories/upload_element.py +257 -257
  195. dtlpy/repositories/uploader.py +651 -651
  196. dtlpy/repositories/webhooks.py +249 -249
  197. dtlpy/services/__init__.py +22 -22
  198. dtlpy/services/aihttp_retry.py +131 -131
  199. dtlpy/services/api_client.py +1782 -1782
  200. dtlpy/services/api_reference.py +40 -40
  201. dtlpy/services/async_utils.py +133 -133
  202. dtlpy/services/calls_counter.py +44 -44
  203. dtlpy/services/check_sdk.py +68 -68
  204. dtlpy/services/cookie.py +115 -115
  205. dtlpy/services/create_logger.py +156 -156
  206. dtlpy/services/events.py +84 -84
  207. dtlpy/services/logins.py +235 -235
  208. dtlpy/services/reporter.py +256 -256
  209. dtlpy/services/service_defaults.py +91 -91
  210. dtlpy/utilities/__init__.py +20 -20
  211. dtlpy/utilities/annotations/__init__.py +16 -16
  212. dtlpy/utilities/annotations/annotation_converters.py +269 -269
  213. dtlpy/utilities/base_package_runner.py +264 -264
  214. dtlpy/utilities/converter.py +1650 -1650
  215. dtlpy/utilities/dataset_generators/__init__.py +1 -1
  216. dtlpy/utilities/dataset_generators/dataset_generator.py +670 -670
  217. dtlpy/utilities/dataset_generators/dataset_generator_tensorflow.py +23 -23
  218. dtlpy/utilities/dataset_generators/dataset_generator_torch.py +21 -21
  219. dtlpy/utilities/local_development/__init__.py +1 -1
  220. dtlpy/utilities/local_development/local_session.py +179 -179
  221. dtlpy/utilities/reports/__init__.py +2 -2
  222. dtlpy/utilities/reports/figures.py +343 -343
  223. dtlpy/utilities/reports/report.py +71 -71
  224. dtlpy/utilities/videos/__init__.py +17 -17
  225. dtlpy/utilities/videos/video_player.py +598 -598
  226. dtlpy/utilities/videos/videos.py +470 -470
  227. {dtlpy-1.113.10.data → dtlpy-1.114.13.data}/scripts/dlp +1 -1
  228. dtlpy-1.114.13.data/scripts/dlp.bat +2 -0
  229. {dtlpy-1.113.10.data → dtlpy-1.114.13.data}/scripts/dlp.py +128 -128
  230. {dtlpy-1.113.10.dist-info → dtlpy-1.114.13.dist-info}/LICENSE +200 -200
  231. {dtlpy-1.113.10.dist-info → dtlpy-1.114.13.dist-info}/METADATA +172 -172
  232. dtlpy-1.114.13.dist-info/RECORD +240 -0
  233. {dtlpy-1.113.10.dist-info → dtlpy-1.114.13.dist-info}/WHEEL +1 -1
  234. tests/features/environment.py +551 -550
  235. dtlpy-1.113.10.data/scripts/dlp.bat +0 -2
  236. dtlpy-1.113.10.dist-info/RECORD +0 -244
  237. tests/assets/__init__.py +0 -0
  238. tests/assets/models_flow/__init__.py +0 -0
  239. tests/assets/models_flow/failedmain.py +0 -52
  240. tests/assets/models_flow/main.py +0 -62
  241. tests/assets/models_flow/main_model.py +0 -54
  242. {dtlpy-1.113.10.dist-info → dtlpy-1.114.13.dist-info}/entry_points.txt +0 -0
  243. {dtlpy-1.113.10.dist-info → dtlpy-1.114.13.dist-info}/top_level.txt +0 -0
@@ -1,940 +1,945 @@
1
- import dataclasses
2
- import tempfile
3
- import datetime
4
- import logging
5
- import string
6
- import shutil
7
- import random
8
- import base64
9
- import tqdm
10
- import sys
11
- import io
12
- import os
13
- from PIL import Image
14
- from functools import partial
15
- import numpy as np
16
- from concurrent.futures import ThreadPoolExecutor
17
- import attr
18
- from .. import entities, utilities, repositories, exceptions
19
- from ..services import service_defaults
20
- from ..services.api_client import ApiClient
21
-
22
logger = logging.getLogger('ModelAdapter')


@dataclasses.dataclass
class AdapterDefaults(dict):
    """Default options of a model adapter, exposed both as attributes and as dict items.

    Dataclass fields are mirrored into the underlying ``dict`` so that
    ``defaults.root_path`` and ``defaults['root_path']`` always agree.
    :meth:`resolve` reads the dict side, so attribute assignments must be
    mirrored there to take effect.
    """
    # for predict items, dataset, evaluate
    upload_annotations: bool = dataclasses.field(default=True)
    clean_annotations: bool = dataclasses.field(default=True)
    # for embeddings
    upload_features: bool = dataclasses.field(default=True)
    # for training
    root_path: str = dataclasses.field(default=None)
    data_path: str = dataclasses.field(default=None)
    output_path: str = dataclasses.field(default=None)

    def __post_init__(self):
        # Initialize the internal dictionary with the dataclass fields
        self.update(**dataclasses.asdict(self))

    def __setattr__(self, name, value):
        # Keep the dict view in sync when a field is assigned as an attribute
        # (e.g. ``adapter.adapter_defaults.root_path = '/tmp'``); otherwise
        # resolve() would keep returning the stale dict value.
        super().__setattr__(name, value)
        if name in self.__dataclass_fields__:
            dict.__setitem__(self, name, value)

    def __setitem__(self, key, value):
        # Mirror item assignment of a field back onto the attribute.
        dict.__setitem__(self, key, value)
        if key in self.__dataclass_fields__:
            object.__setattr__(self, key, value)

    def update(self, *args, **kwargs):
        """Update entries like ``dict.update`` and keep matching dataclass fields in sync.

        Accepts an optional mapping / iterable of key-value pairs and/or keyword
        arguments. Keys that are not dataclass fields are stored in the dict only.
        """
        merged = dict(*args, **kwargs)
        for f in dataclasses.fields(self):
            if f.name in merged:
                object.__setattr__(self, f.name, merged[f.name])
        super().update(merged)

    def resolve(self, key, *args):
        """Return the first non-None value in ``args``, else the stored value for ``key`` (or None)."""
        for arg in args:
            if arg is not None:
                return arg
        return self.get(key, None)
53
-
54
-
55
- class BaseModelAdapter(utilities.BaseServiceRunner):
56
- _client_api = attr.ib(type=ApiClient, repr=False)
57
-
58
def __init__(self, model_entity: entities.Model = None):
    """Create the adapter and, when ``model_entity`` is given, load it via `load_from_model`.

    :param model_entity: optional `dl.Model` to load immediately
    """
    # adapter-wide defaults (upload/clean flags, local paths)
    self.adapter_defaults = AdapterDefaults()
    self.logger = logger

    # entity state, filled in by load_from_model / the property setters
    self._model_entity = None
    self._package = None
    self._base_configuration = dict()
    self.package_name = None
    self.model = None
    self.bucket_path = None

    # input_type -> loader used by prepare_item_func
    self.item_to_batch_mapping = dict(text=self._item_to_text,
                                      image=self._item_to_image)

    if model_entity is not None:
        self.load_from_model(model_entity=model_entity)
    logger.warning(
        "in case of a mismatch between 'model.name' and 'model_info.name' in the model adapter, model_info.name will be updated to align with 'model.name'.")
75
-
76
- ##################
77
- # Configurations #
78
- ##################
79
-
80
@property
def configuration(self) -> dict:
    """Return the active configuration dict.

    Lookup order: the loaded model's ``configuration``; otherwise the package's
    ``metadata.system.ml.defaultConfiguration`` (``{}`` if absent); otherwise the
    adapter's base configuration.
    """
    if self._model_entity is not None:
        return self.model_entity.configuration
    if self._package is not None:
        ml_metadata = self.package.metadata.get('system', {}).get('ml', {})
        return ml_metadata.get('defaultConfiguration', {})
    return self._base_configuration
91
-
92
@configuration.setter
def configuration(self, d):
    """Set the adapter configuration.

    With a loaded model the dict is written to the model's ``configuration``;
    without one it is kept as the adapter's base configuration instead of being
    discarded. Note that while a package is set and no model is loaded, the
    getter still returns the package's default configuration.

    :param d: `dict` configuration
    """
    assert isinstance(d, dict)
    if self._model_entity is not None:
        self._model_entity.configuration = d
    else:
        self._base_configuration = d
97
-
98
- ############
99
- # Entities #
100
- ############
101
@property
def model_entity(self):
    """Return the loaded `dl.Model`.

    :raises ValueError: when no model has been loaded or assigned yet
    """
    current = self._model_entity
    if current is None:
        raise ValueError(
            "No model entity loaded. Please load a model (adapter.load_from_model(<dl.Model>)) or set: 'adapter.model_entity=<dl.Model>'")
    assert isinstance(current, entities.Model)
    return current
108
-
109
@model_entity.setter
def model_entity(self, model_entity):
    """Assign a `dl.Model` and refresh the adapter's package from it.

    Logs a warning when a previously loaded model with a different id is replaced.
    """
    assert isinstance(model_entity, entities.Model)
    previous = self._model_entity
    replacing = isinstance(previous, entities.Model) and previous.id != model_entity.id
    if replacing:
        self.logger.warning(
            'Replacing Model from {!r} to {!r}'.format(previous.name, model_entity.name))
    self._model_entity = model_entity
    self.package = model_entity.package
118
-
119
@property
def package(self):
    """Return the adapter's `dl.Package` / `dl.Dpk`.

    When a model is loaded, the package is re-read from the model (and re-assigned
    through the setter) on every access.

    :raises ValueError: when no package is available
    """
    if self._model_entity is not None:
        self.package = self.model_entity.package
    current = self._package
    if current is None:
        raise ValueError('Missing Package entity on adapter. Please set: "adapter.package=package"')
    assert isinstance(current, (entities.Package, entities.Dpk))
    return current
127
-
128
@package.setter
def package(self, package):
    """Store the package (`dl.Package` or `dl.Dpk`) and remember its name in ``package_name``."""
    assert isinstance(package, (entities.Package, entities.Dpk))
    self._package = package
    self.package_name = package.name
133
-
134
- ###################################
135
- # NEED TO IMPLEMENT THESE METHODS #
136
- ###################################
137
-
138
def load(self, local_path, **kwargs):
    """Load the model from ``local_path``; expected to populate ``self.model`` with a runnable model.

    Abstract hook that subclasses must override. `load_from_model` calls it after
    downloading the model artifacts into ``local_path``.

    :param local_path: `str` directory path in local FileSystem
    :raises NotImplementedError: always, in this base implementation
    """
    class_name = type(self).__name__
    raise NotImplementedError(f"Please implement `load` method in {class_name}")
148
-
149
def save(self, local_path, **kwargs):
    """Write the model configuration and weights into ``local_path``.

    Abstract hook that subclasses must override. `save_to_model` calls it and then
    uploads the contents of ``local_path`` to the model artifacts.

    :param local_path: `str` directory path in local FileSystem
    :raises NotImplementedError: always, in this base implementation
    """
    class_name = type(self).__name__
    raise NotImplementedError(f"Please implement `save` method in {class_name}")
159
-
160
def train(self, data_path, output_path, **kwargs):
    """Train on the data under ``data_path`` and write training outputs to ``output_path``.

    Abstract hook that subclasses must override. Outputs include the weights and any
    other artifacts produced during training (checkpoints, logs...).

    :param data_path: `str` local directory where the data was downloaded and converted
    :param output_path: `str` local directory for training mid-results
    :raises NotImplementedError: always, in this base implementation
    """
    class_name = type(self).__name__
    raise NotImplementedError(f"Please implement `train` method in {class_name}")
171
-
172
def predict(self, batch, **kwargs):
    """Run inference on a batch of prepared items.

    Abstract hook that subclasses must override. `predict_items` calls it with a list
    of `prepare_item_func` outputs.

    :param batch: list of `prepare_item_func` outputs
    :return: `list[dl.AnnotationCollection]`, one collection per item in the batch
    :raises NotImplementedError: always, in this base implementation
    """
    class_name = type(self).__name__
    raise NotImplementedError(f"Please implement `predict` method in {class_name}")
181
-
182
def embed(self, batch, **kwargs):
    """Extract embeddings for a batch of prepared items.

    Abstract hook that subclasses must override.

    :param batch: output of the `prepare_item_func` func
    :return: `list[list]`, one feature vector per item in the batch
    :raises NotImplementedError: always, in this base implementation
    """
    class_name = type(self).__name__
    raise NotImplementedError(f"Please implement `embed` method in {class_name}")
191
-
192
def evaluate(self, model: entities.Model, dataset: entities.Dataset, filters: entities.Filters) -> entities.Model:
    """
    Score the model's predictions on a dataset that has GT annotations.

    Delegates to ``dtlpymetrics.scoring.create_model_score``, comparing annotations
    of the model's ``output_type``.

    :param model: the model to evaluate (matched via annotation.metadata.system.model.name)
    :param dataset: dataset where the model predicted and uploaded its annotations
    :param filters: filters to query items on the dataset; a falsy value means "all items",
        and a dict is wrapped as a custom filter
    :return: the model returned by ``create_model_score``
    """
    import dtlpymetrics

    compare_types = model.output_type
    if not filters:
        filters = entities.Filters()
    elif isinstance(filters, dict):
        filters = entities.Filters(custom_filter=filters)
    return dtlpymetrics.scoring.create_model_score(model=model,
                                                   dataset=dataset,
                                                   filters=filters,
                                                   compare_types=compare_types)
213
-
214
def convert_from_dtlpy(self, data_path, **kwargs):
    """Convert data downloaded in Dataloop structure into the model's own format.

    Abstract hook that subclasses must override (e.g. walk the dlp directory
    structure and build an annotation file). `prepare_data` calls it after
    downloading the subsets.

    :param data_path: `str` local directory where the Dataloop data was downloaded
    :raises NotImplementedError: always, in this base implementation
    """
    class_name = type(self).__name__
    raise NotImplementedError(f"Please implement `convert_from_dtlpy` method in {class_name}")
225
-
226
- #################
227
- # DTLPY METHODS #
228
- ################
229
def prepare_item_func(self, item: entities.Item):
    """
    Load a single Dataloop item before it is batched and passed to `predict`.

    Users can override this to load items differently. By default the loader is
    chosen from the model's ``input_type``:

    * list ``input_type``: `_item_to_text` for text mimetypes, `_item_to_image` for
      image mimetypes (each only if that type is listed), else `_item_to_item`;
    * otherwise: the loader registered in ``self.item_to_batch_mapping`` for that
      type, else `_item_to_item`.

    :param item: `dl.Item` to load
    :return: whatever the selected loader returns (e.g. ndarray for images)
    """
    input_type = self.model_entity.input_type
    if isinstance(input_type, list):
        # An item may have no mimetype; treat it as unknown instead of failing on `'x' in None`.
        mimetype = item.mimetype or ''
        if 'text' in input_type and 'text' in mimetype:
            return self._item_to_text(item)
        if 'image' in input_type and 'image' in mimetype:
            return self._item_to_image(item)
        return self._item_to_item(item)

    if input_type in self.item_to_batch_mapping:
        return self.item_to_batch_mapping[input_type](item)
    return self._item_to_item(item)
254
-
255
def prepare_data(self,
                 dataset: entities.Dataset,
                 # paths
                 root_path=None,
                 data_path=None,
                 output_path=None,
                 #
                 overwrite=False,
                 **kwargs):
    """
    Prepare the dataset locally before training or evaluation.

    Downloads every subset listed in the model's ``metadata.system.subsets`` into
    ``<data_path>/<subset>`` and then calls `self.convert_from_dtlpy` on data_path.

    :param dataset: dl.Dataset to download the items from
    :param root_path: `str` root directory for training. Default is
        ``<DATALOOP_PATH>/model_data/<model id>_<model name>/<timestamp>``.
        Can be set using self.adapter_defaults.root_path
    :param data_path: `str` dataset directory. Default is <root_path>/datasets/<dataset id>.
        Can be set using self.adapter_defaults.data_path
    :param output_path: `str` save everything to this folder. Default is <root_path>/output.
        Can be set using self.adapter_defaults.output_path
    :param bool overwrite: download a subset again even if its directory exists. Default is False
    :return: tuple ``(root_path, data_path, output_path)``
    :raises ValueError: if the model has no subsets, or a subset download returns no items
    """
    # define paths
    root_path = self.adapter_defaults.resolve("root_path", root_path)
    data_path = self.adapter_defaults.resolve("data_path", data_path)
    output_path = self.adapter_defaults.resolve("output_path", output_path)

    if root_path is None:
        now = datetime.datetime.now()
        root_path = os.path.join(service_defaults.DATALOOP_PATH,
                                 'model_data',
                                 "{s_id}_{s_n}".format(s_id=self.model_entity.id, s_n=self.model_entity.name),
                                 now.strftime('%Y-%m-%d-%H%M%S'),
                                 )
    if data_path is None:
        # name the directory after the dataset that is actually downloaded below
        data_path = os.path.join(root_path, 'datasets', dataset.id)
    if output_path is None:
        output_path = os.path.join(root_path, 'output')
    # Create the directories also when the paths were given by the caller:
    # os.listdir below raises on a missing directory.
    os.makedirs(data_path, exist_ok=True)
    os.makedirs(output_path, exist_ok=True)

    if len(os.listdir(data_path)) > 0:
        self.logger.warning("Data path directory ({}) is not empty..".format(data_path))

    annotation_options = entities.ViewAnnotationOptions.JSON
    if self.model_entity.output_type in [entities.AnnotationType.SEGMENTATION]:
        annotation_options = entities.ViewAnnotationOptions.INSTANCE

    # Download the subset items
    subsets = self.model_entity.metadata.get("system", dict()).get("subsets", None)
    if subsets is None:
        raise ValueError("Model (id: {}) must have subsets in metadata.system.subsets".format(self.model_entity.id))

    for subset, filters_dict in subsets.items():
        filters = entities.Filters(custom_filter=filters_dict)
        data_subset_base_path = os.path.join(data_path, subset)
        if os.path.isdir(data_subset_base_path) and not overwrite:
            # existing and dont overwrite
            self.logger.debug("Subset {!r} already exists (and overwrite=False). Skipping.".format(subset))
            continue

        self.logger.debug("Downloading subset {!r} of {}".format(subset, dataset.name))

        # Restrict downloaded annotations to the model's output type(s); embedding
        # models (or models without an output type) download all annotations.
        annotation_filters = None
        if self.model_entity.output_type is not None and self.model_entity.output_type != "embedding":
            annotation_filters = entities.Filters(resource=entities.FiltersResource.ANNOTATION, use_defaults=False)
            if self.model_entity.output_type in [entities.AnnotationType.SEGMENTATION,
                                                 entities.AnnotationType.POLYGON]:
                model_output_types = [entities.AnnotationType.SEGMENTATION, entities.AnnotationType.POLYGON]
            else:
                model_output_types = [self.model_entity.output_type]

            annotation_filters.add(
                field=entities.FiltersKnownFields.TYPE,
                values=model_output_types,
                operator=entities.FiltersOperations.IN
            )

            if not self.configuration.get("include_model_annotations", False):
                # exclude annotations created by models (predictions)
                annotation_filters.add(
                    field="metadata.system.model.name",
                    values=False,
                    operator=entities.FiltersOperations.EXISTS
                )

        ret_list = dataset.items.download(filters=filters,
                                          local_path=data_subset_base_path,
                                          annotation_options=annotation_options,
                                          annotation_filters=annotation_filters
                                          )
        if isinstance(ret_list, list) and len(ret_list) == 0:
            # annotation_filters is None for embedding / untyped models; don't call .prepare() on it
            annotation_filters_desc = annotation_filters.prepare() if annotation_filters is not None else None
            raise ValueError(f"No items downloaded for subset {subset}! Cannot train model with empty subset.\n"
                             f"Subset {subset} filters: {filters.prepare()}\nAnnotation filters: {annotation_filters_desc}")

    self.convert_from_dtlpy(data_path=data_path, **kwargs)
    return root_path, data_path, output_path
349
-
350
def load_from_model(self, model_entity=None, local_path=None, overwrite=True, **kwargs):
    """ Load a model from a given `dl.Model`.

    Optionally sets ``self.model_entity``, copies the model configuration into the
    adapter (and into ``self.adapter_defaults``), downloads the model artifacts to
    ``local_path`` and finally calls `self.load(local_path, **kwargs)`.

    :param model_entity: `dl.Model` entity to load; if None, the already-set model entity is used
    :param local_path: `str` local directory to download the model artifacts to
        (default: ``<DATALOOP_PATH>/models/<model name>``)
    :param overwrite: `bool` (default True) if True download all files, else skip files with the same name
    :param kwargs: forwarded to `self.load`
    """
    if model_entity is not None:
        self.model_entity = model_entity
    if local_path is None:
        local_path = os.path.join(service_defaults.DATALOOP_PATH, "models", self.model_entity.name)
    # Load configuration
    self.configuration = self.model_entity.configuration
    # Update the adapter config with the model config to run over defaults if needed
    self.adapter_defaults.update(**self.configuration)
    # Download
    self.model_entity.artifacts.download(
        local_path=local_path,
        overwrite=overwrite
    )
    self.load(local_path, **kwargs)
373
-
374
def save_to_model(self, local_path=None, cleanup=False, replace=True, **kwargs):
    """
    Save the adapter state locally and upload it as the model's artifacts.

    Calls `self.save` to write configuration and weights into ``local_path``, then
    uploads every file in that directory to ``self.model_entity.artifacts``
    (overwriting artifacts with the same name).

    :param local_path: `str` local directory to save the model files to (default: a new temporary dir)
    :param cleanup: `bool` if True, remove ``local_path`` from the local FileSystem after upload (default False)
    :param replace: `bool` currently unused; kept for backward compatibility
    :param kwargs: forwarded to `self.save`
    :raises ValueError: if no model entity is set on the adapter
    """
    # Check up front: the `model_entity` property itself raises when unset, so a
    # check placed after `self.save(...)` could never trigger.
    if self._model_entity is None:
        raise ValueError('Missing model entity on the adapter. '
                         'Please set before saving: "adapter.model_entity=model"')

    if local_path is None:
        local_path = tempfile.mkdtemp(prefix="model_{}".format(self.model_entity.name))
        self.logger.debug("Using temporary dir at {}".format(local_path))

    self.save(local_path=local_path, **kwargs)

    self.model_entity.artifacts.upload(filepath=os.path.join(local_path, '*'),
                                       overwrite=True)
    if cleanup:
        shutil.rmtree(path=local_path, ignore_errors=True)
        self.logger.info("Clean-up. deleting {}".format(local_path))
403
-
404
- # ===============
405
- # SERVICE METHODS
406
- # ===============
407
-
408
@entities.Package.decorators.function(display_name='Predict Items',
                                      inputs={'items': 'Item[]'},
                                      outputs={'items': 'Item[]', 'annotations': 'Annotation[]'})
def predict_items(self, items: list, upload_annotations=None, clean_annotations=None, batch_size=None, **kwargs):
    """
    Run `predict` on the input items in batches and return the items and the predictions.

    Each batch is prepared with `prepare_item_func`, predicted with `predict`, has its
    prediction metadata updated, and is optionally uploaded to the platform.

    :param items: `List[dl.Item]` items to predict (must support len() and slicing)
    :param upload_annotations: `bool` upload the predictions to the given items
        (default taken from self.adapter_defaults)
    :param clean_annotations: `bool` delete previous model annotations before uploading new ones
        (default taken from self.adapter_defaults)
    :param batch_size: `int` number of items per inference call
        (default: configuration 'batch_size', else 4)
    :param kwargs: forwarded to `predict`
    :return: `List[dl.Item]`, flat `List[dl.Annotation]` with the predictions of all items
    """
    if batch_size is None:
        batch_size = self.configuration.get('batch_size', 4)
    upload_annotations = self.adapter_defaults.resolve("upload_annotations", upload_annotations)
    clean_annotations = self.adapter_defaults.resolve("clean_annotations", clean_annotations)
    input_type = self.model_entity.input_type
    self.logger.debug(
        "Predicting {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))

    annotations = list()
    # `with` shuts the worker threads down even if prepare/predict raises.
    with ThreadPoolExecutor(max_workers=16) as pool:
        for i_batch in tqdm.tqdm(range(0, len(items), batch_size), desc='predicting', unit='bt', leave=None,
                                 file=sys.stdout):
            batch_items = items[i_batch: i_batch + batch_size]
            batch = list(pool.map(self.prepare_item_func, batch_items))
            batch_collections = self.predict(batch, **kwargs)
            # list() drains the map iterator, so every metadata update has finished
            # (and any error has been raised) before the upload starts.
            list(pool.map(self._update_predictions_metadata, batch_items, batch_collections))
            if upload_annotations is True:
                self.logger.debug(
                    "Uploading items' annotation for model {!r}.".format(self.model_entity.name))
                try:
                    batch_collections = list(pool.map(partial(self._upload_model_annotations,
                                                              clean_annotations=clean_annotations),
                                                      batch_items,
                                                      batch_collections))
                except Exception:
                    # Best-effort upload: keep the local predictions and continue.
                    self.logger.exception("Failed to upload annotations items.")

            for collection in batch_collections:
                # flatten each item's collection into the returned list of dl.Annotation
                if isinstance(collection, entities.AnnotationCollection):
                    annotations.extend(collection.annotations)
                else:
                    self.logger.warning(f'RETURN TYPE MAY BE INVALID: {type(collection)}')
                    annotations.extend(collection)
            # TODO call the callback
    return items, annotations
466
-
467
@entities.Package.decorators.function(display_name='Embed Items',
                                      inputs={'items': 'Item[]'},
                                      outputs={'items': 'Item[]', 'features': 'Json[]'})
def embed_items(self, items: list, upload_features=None, batch_size=None, progress: utilities.Progress = None,
                **kwargs):
    """
    Extract feature vectors for a list of items and return the items and the vectors.

    :param items: `List[dl.Item]` list of items to embed
    :param upload_features: `bool` uploads the features on the given items; `None` falls back to
        `self.adapter_defaults` ("upload_features")
    :param batch_size: `int` size of batch to run a single embed; `None` falls back to
        `self.configuration['batch_size']` (4 when missing)
    :param progress: optional `dl.Progress`; when given, its logger reports per-item upload failures
    :param kwargs: forwarded to `self.embed`

    :return: `List[dl.Item]`, `List[List[vector]]`
    """
    if batch_size is None:
        batch_size = self.configuration.get('batch_size', 4)
    upload_features = self.adapter_defaults.resolve("upload_features", upload_features)
    input_type = self.model_entity.input_type
    self.logger.debug(
        "Embedding {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))

    # Search for existing feature set for this model id
    feature_set = self.model_entity.feature_set
    if feature_set is None:
        logger.info('Feature Set not found. creating... ')
        try:
            # a feature set named after the model already exists -> append a random 5-character suffix
            self.model_entity.project.feature_sets.get(feature_set_name=self.model_entity.name)
            feature_set_name = f"{self.model_entity.name}-{''.join(random.choices(string.ascii_letters + string.digits, k=5))}"
            logger.warning(
                f"Feature set with the model name already exists. Creating new feature set with name {feature_set_name}")
        except exceptions.NotFound:
            feature_set_name = self.model_entity.name
        feature_set = self.model_entity.project.feature_sets.create(name=feature_set_name,
                                                                    entity_type=entities.FeatureEntityType.ITEM,
                                                                    model_id=self.model_entity.id,
                                                                    project_id=self.model_entity.project_id,
                                                                    set_type=self.model_entity.name,
                                                                    size=self.configuration.get('embeddings_size',
                                                                                                256))
        logger.info(f'Feature Set created! name: {feature_set.name}, id: {feature_set.id}')
    else:
        logger.info(f'Feature Set found! name: {feature_set.name}, id: {feature_set.id}')

    # upload the feature vectors
    vectors = list()
    # `with` guarantees the worker threads are shut down even if prepare_item_func / embed raises
    with ThreadPoolExecutor(max_workers=16) as pool:
        for i_batch in tqdm.tqdm(range(0, len(items), batch_size),
                                 desc='embedding',
                                 unit='bt',
                                 leave=None,
                                 file=sys.stdout):
            batch_items = items[i_batch: i_batch + batch_size]
            batch = list(pool.map(self.prepare_item_func, batch_items))
            batch_vectors = self.embed(batch, **kwargs)
            vectors.extend(batch_vectors)
            if upload_features is True:
                self.logger.debug(
                    "Uploading items' feature vectors for model {!r}.".format(self.model_entity.name))
                try:
                    list(pool.map(partial(self._upload_model_features,
                                          progress.logger if progress is not None else self.logger,
                                          feature_set.id,
                                          self.model_entity.project_id),
                                  batch_items,
                                  batch_vectors))
                except Exception:
                    # best effort: a failed upload must not abort the embedding run
                    self.logger.exception("Failed to upload feature vectors to items.")
    return items, vectors
536
-
537
@entities.Package.decorators.function(display_name='Embed Dataset with DQL',
                                      inputs={'dataset': 'Dataset',
                                              'filters': 'Json'})
def embed_dataset(self,
                  dataset: entities.Dataset,
                  filters: entities.Filters = None,
                  upload_features=None,
                  batch_size=None,
                  progress: utilities.Progress = None,
                  **kwargs):
    """
    Embed every file item of a dataset, optionally narrowed by a filter.

    :param dataset: `dl.Dataset` whose items are embedded
    :param filters: `dl.Filters` or a DQL `dict`; an empty/None value selects all items
    :param upload_features: `bool` upload the vectors back to the items; None uses the adapter default
    :param batch_size: `int` items per embed call; None uses configuration['batch_size'] (default 4)
    :param progress: optional `dl.Progress`, forwarded to `embed_items`
    :param kwargs: forwarded to `embed_items`

    :return: `True` once `embed_items` has returned
    """
    batch_size = self.configuration.get('batch_size', 4) if batch_size is None else batch_size
    upload_features = self.adapter_defaults.resolve("upload_features", upload_features)

    self.logger.debug("Creating embeddings for dataset (name:{}, id:{}), using batch size {}".format(
        dataset.name, dataset.id, batch_size))

    if not filters:
        filters = entities.Filters()
    if isinstance(filters, dict):
        filters = entities.Filters(custom_filter=filters)

    # keep only entries of type 'file' (could be dropped if default filters were merged into custom filters)
    file_items = [
        entry
        for page in dataset.items.list(filters=filters, page_size=batch_size)
        for entry in page
        if entry.type == 'file'
    ]
    self.embed_items(items=file_items,
                     upload_features=upload_features,
                     batch_size=batch_size,
                     progress=progress,
                     **kwargs)
    return True
577
-
578
@entities.Package.decorators.function(display_name='Predict Dataset with DQL',
                                      inputs={'dataset': 'Dataset',
                                              'filters': 'Json'})
def predict_dataset(self,
                    dataset: entities.Dataset,
                    filters: entities.Filters = None,
                    upload_annotations=None,
                    clean_annotations=None,
                    batch_size=None,
                    **kwargs):
    """
    Predicts all (file) items of a dataset, optionally narrowed by a filter.

    :param dataset: Dataset entity to predict
    :param filters: Filters entity (or DQL `dict`) for filtering before predicting; empty/None selects all items
    :param upload_annotations: `bool` uploads the predictions back to the given items
        (None falls back to the adapter defaults)
    :param clean_annotations: `bool` if set removes existing predictions with the same model name
        (None falls back to the adapter defaults)
    :param batch_size: `int` size of batch to run a single inference; None uses configuration['batch_size'] (default 4)
    :param kwargs: forwarded to `predict_items`

    :return: `True` once `predict_items` has returned
    """
    if batch_size is None:
        batch_size = self.configuration.get('batch_size', 4)

    self.logger.debug("Predicting dataset (name:{}, id:{}), using batch size {}".format(dataset.name,
                                                                                        dataset.id,
                                                                                        batch_size))
    if not filters:
        filters = entities.Filters()
    if isinstance(filters, dict):
        filters = entities.Filters(custom_filter=filters)
    pages = dataset.items.list(filters=filters, page_size=batch_size)
    # Item type is 'file' only, can be deleted if default filters are added to custom filters
    items = [item for page in pages for item in page if item.type == 'file']
    self.predict_items(items=items,
                       upload_annotations=upload_annotations,
                       clean_annotations=clean_annotations,
                       batch_size=batch_size,
                       **kwargs)
    return True
619
-
620
@entities.Package.decorators.function(display_name='Train a Model',
                                      inputs={'model': entities.Model},
                                      outputs={'model': entities.Model})
def train_model(self,
                model: entities.Model,
                cleanup=False,
                progress: utilities.Progress = None,
                context: utilities.Context = None):
    """
    Train on existing model.

    Data is taken from dl.Model.datasetId, the configuration is dl.Model.configuration and the
    training outputs are uploaded to the model's bucket (model.bucket).

    :param model: `dl.Model` (or a dict with an 'id' key) to train
    :param cleanup: `bool` remove the local output directory after a successful training
    :param progress: optional `dl.Progress` used to report training progress
    :param context: optional `dl.Context`
    :return: the trained `dl.Model`
    :raises ValueError: if the model is already in 'failed' state
    """
    if isinstance(model, dict):
        model = repositories.Models(client_api=self._client_api).get(model_id=model['id'])
    output_path = None
    try:
        logger.info("Received {s} for training".format(s=model.id))
        model = model.wait_for_model_ready()
        if model.status == 'failed':
            raise ValueError("Model is in failed state, cannot train.")

        ##############
        # Set status #
        ##############
        model.status = 'training'
        if context is not None:
            if 'system' not in model.metadata:
                model.metadata['system'] = dict()
        model.update()

        ##########################
        # load model and weights #
        ##########################
        logger.info("Loading Adapter with: {n} ({i!r})".format(n=model.name, i=model.id))
        self.load_from_model(model_entity=model)

        ################
        # prepare data #
        ################
        root_path, data_path, output_path = self.prepare_data(
            dataset=self.model_entity.dataset,
            root_path=os.path.join('tmp', model.id)
        )
        # Start the Train
        logger.info("Training {p_name!r} with model {m_name!r} on data {d_path!r}".
                    format(p_name=self.package_name, m_name=model.id, d_path=data_path))
        if progress is not None:
            progress.update(message='starting training')

        def on_epoch_end_callback(i_epoch, n_epoch):
            # i_epoch is 0-based (see the percentage formula), so the finished epoch is i_epoch + 1
            if progress is not None:
                progress.update(progress=int(100 * (i_epoch + 1) / n_epoch),
                                message='finished epoch: {}/{}'.format(i_epoch + 1, n_epoch))

        self.train(data_path=data_path,
                   output_path=output_path,
                   on_epoch_end_callback=on_epoch_end_callback)
        if progress is not None:
            progress.update(message='saving model',
                            progress=99)

        self.save_to_model(local_path=output_path, replace=True)
        model.status = 'trained'
        model.update()
        ###########
        # cleanup #
        ###########
        if cleanup:
            shutil.rmtree(output_path, ignore_errors=True)
    except Exception:
        # save also on fail -- best effort, it must not hide the original training error
        if output_path is not None:
            try:
                self.save_to_model(local_path=output_path, replace=True)
            except Exception:
                logger.exception('Failed to save model outputs after training failure')
        logger.info('Execution failed. Setting model.status to failed')
        try:
            model.status = 'failed'
            model.update()
        except Exception:
            logger.exception('Failed to set model.status to failed')
        # re-raise the original training error
        raise
    return model
698
-
699
@entities.Package.decorators.function(display_name='Evaluate a Model',
                                      inputs={'model': entities.Model,
                                              'dataset': entities.Dataset,
                                              'filters': 'Json'},
                                      outputs={'model': entities.Model,
                                               'dataset': entities.Dataset
                                               })
def evaluate_model(self,
                   model: entities.Model,
                   dataset: entities.Dataset,
                   filters: entities.Filters = None,
                   progress: utilities.Progress = None,
                   context: utilities.Context = None):
    """
    Evaluate a model.

    The model predicts on the dataset items (optionally filtered) and uploads its annotations,
    then `self.evaluate` calculates the metrics vs GT.

    :param model: Model entity to run prediction
    :param dataset: Dataset to evaluate
    :param filters: Filter for specific items from dataset (empty/None selects all items)
    :param progress: dl.Progress for report FaaS progress
    :param context: dl.Context (unused here)
    :return: tuple of (model returned by `self.evaluate`, dataset)
    """
    logger.info(
        f"Received model: {model.id} for evaluation on dataset (name: {dataset.name}, id: {dataset.id})")
    ##########################
    # load model and weights #
    ##########################
    logger.info(f"Loading Adapter with: {model.name} ({model.id!r})")
    self.load_from_model(dataset=dataset,
                         model_entity=model)

    ##############
    # Predicting #
    ##############
    logger.info(f"Calling prediction, dataset: {dataset.name!r} ({dataset.id!r}), filters: {filters}")
    if not filters:
        filters = entities.Filters()
    # Predictions must be uploaded: `evaluate` scores the annotations the model uploaded to the dataset.
    # `upload_annotations` is the actual predict_dataset parameter (an unknown keyword would be
    # forwarded through **kwargs into `self.predict`).
    self.predict_dataset(dataset=dataset,
                         filters=filters,
                         upload_annotations=True)

    ##############
    # Evaluating #
    ##############
    logger.info("Starting adapter.evaluate()")
    if progress is not None:
        progress.update(message='calculating metrics',
                        progress=98)
    model = self.evaluate(model=model,
                          dataset=dataset,
                          filters=filters)
    #########
    # Done! #
    #########
    if progress is not None:
        progress.update(message='finishing evaluation',
                        progress=99)
    return model, dataset
762
-
763
- # =============
764
- # INNER METHODS
765
- # =============
766
-
767
@staticmethod
def _upload_model_features(logger, feature_set_id, project_id, item: entities.Item, vector):
    """
    Attach one feature vector to an item; upload failures are logged, never raised.

    :param logger: logger used to report a failed upload
    :param feature_set_id: id of the feature set the vector belongs to
    :param project_id: id of the project owning the feature set
    :param item: `dl.Item` that receives the vector
    :param vector: the feature vector; `None` means "nothing to upload"
    """
    if vector is None:
        return
    try:
        item.features.create(value=vector,
                             project_id=project_id,
                             feature_set_id=feature_set_id,
                             entity=item)
    except Exception as e:
        logger.error(f'Failed to upload feature vector of length {len(vector)} to item {item.id}, Error: {e}')
777
-
778
def _upload_model_annotations(self, item: entities.Item, predictions, clean_annotations):
    """
    Upload predictions to an item, optionally deleting this model's earlier predictions first.

    :param item: `dl.Item` that receives the annotations
    :param predictions: `dl.AnnotationCollection` or `list` of annotations
    :param clean_annotations: `bool` if set, annotations whose `metadata.user.model.name` equals this
        model's name are deleted from the item before uploading
    :return: the result of `item.annotations.upload`
    :raises TypeError: if predictions is neither an AnnotationCollection nor a list
    """
    accepted_types = (entities.AnnotationCollection, list)
    if not isinstance(predictions, accepted_types):
        raise TypeError('predictions was expected to be of type {}, but instead it is {}'.
                        format(entities.AnnotationCollection, type(predictions)))
    if clean_annotations:
        previous_predictions = entities.Filters(resource=entities.FiltersResource.ANNOTATION)
        previous_predictions.add(field='metadata.user.model.name', values=self.model_entity.name)
        item.annotations.delete(filters=previous_predictions)
    return item.annotations.upload(annotations=predictions)
794
-
795
@staticmethod
def _item_to_image(item):
    """
    Prepare an item for `predict`: download it into memory and decode it as an image.

    :param item: `dl.Item` holding an image
    :return: `np.ndarray` with the decoded image pixels
    """
    return np.asarray(Image.open(item.download(save_locally=False)))
807
-
808
@staticmethod
def _item_to_item(item):
    """
    Identity item preparation: the item is handed to `predict` / `embed` unchanged.

    Subclasses that need e.g. a numpy image should use (or provide) a different preparation function.

    :param item: `dl.Item` to prepare
    :return: the same item object
    """
    prepared = item
    return prepared
817
-
818
@staticmethod
def _item_to_text(item):
    """
    Prepare a text item: download it and return its content with newlines replaced by spaces.

    Items whose mimetype is not 'text/plain' or 'text/markdown' are returned unchanged (with a warning).
    The downloaded local file is always removed, even if reading it fails.

    :param item: `dl.Item` to prepare
    :return: `str` content for text items, otherwise the item itself
    :raises UnicodeDecodeError: if a text item is not valid UTF-8
    """
    filename = item.download(overwrite=True)
    try:
        if item.mimetype in ('text/plain', 'text/markdown'):
            # explicit encoding: the platform default (e.g. cp1252 on Windows) would garble UTF-8 text
            with open(filename, 'r', encoding='utf-8') as f:
                return f.read().replace('\n', ' ')
        logger.warning('Item is not text file. mimetype: {}'.format(item.mimetype))
        return item
    finally:
        if os.path.exists(filename):
            os.remove(filename)
832
-
833
@staticmethod
def _uri_to_image(data_uri):
    """
    Decode a base64 data URI (e.g. ``"data:image/png;base64,<payload>"``) into an image array.

    :param data_uri: `str` data URI; the payload is the text after the first comma
    :return: `np.ndarray` with the decoded image pixels
    """
    payload = data_uri.split(",")[1]
    raw_bytes = base64.b64decode(payload)
    return np.asarray(Image.open(io.BytesIO(raw_bytes)))
840
-
841
def _update_predictions_metadata(self, item: entities.Item, predictions: entities.AnnotationCollection):
    """
    Stamp model information on every prediction of an item.

    - sets `item_id` on each annotation
    - for segmentation annotations, picks the label color from the item's dataset ontology, then from the
      model's dataset ontology (white when the label is missing there), then keeps the annotation's own color
    - updates `metadata.user.model` (model_id, name) when present, and always writes
      `metadata.system.model` = {model_id, name, confidence}

    :param item: Entity.Item
    :param predictions: item's AnnotationCollection (modified in place)
    """
    # Ontology color maps are fetched lazily and at most once per call instead of once per
    # segmentation annotation. A cached None means "not available".
    color_maps = dict()

    def get_color_map(source):
        if source not in color_maps:
            color_map = None
            try:
                if source == 'item':
                    color_map = item.dataset._get_ontology().color_map
                elif self.model_entity._dataset is not None:
                    color_map = self.model_entity.dataset._get_ontology().color_map
            except (exceptions.BadRequest, exceptions.NotFound):
                color_map = None
            color_maps[source] = color_map
        return color_maps[source]

    for prediction in predictions:
        if prediction.type == entities.AnnotationType.SEGMENTATION:
            color = None
            item_color_map = get_color_map('item')
            if item_color_map is not None:
                color = item_color_map.get(prediction.label, None)
            if color is None:
                model_color_map = get_color_map('model')
                if model_color_map is not None:
                    color = model_color_map.get(prediction.label, (255, 255, 255))
            if color is None:
                logger.warning("Can't get annotation color from model's dataset, using default.")
                color = prediction.color
            prediction.color = color

        prediction.item_id = item.id
        if 'user' in prediction.metadata and 'model' in prediction.metadata['user']:
            prediction.metadata['user']['model']['model_id'] = self.model_entity.id
            prediction.metadata['user']['model']['name'] = self.model_entity.name
        if 'system' not in prediction.metadata:
            prediction.metadata['system'] = dict()
        confidence = prediction.metadata.get('user', dict()).get('model', dict()).get('confidence', None)
        prediction.metadata['system']['model'] = {
            'model_id': self.model_entity.id,
            'name': self.model_entity.name,
            'confidence': confidence
        }
884
-
885
- ##############################
886
- # Callback Factory functions #
887
- ##############################
888
@property
def dataloop_keras_callback(self):
    """
    Returns the constructor for a keras api dump callback.

    The callback writes the per-epoch values of keras `logs` to a JSON "metric" file, which the
    dlp platform uses to show train losses.

    :return: DumpHistoryCallback constructor
    :raises RuntimeError: if keras cannot be imported
    """
    try:
        import keras
    except ImportError as err:  # ModuleNotFoundError is a subclass of ImportError
        raise RuntimeError(
            "{} depends on external package 'keras'. Please install it (e.g. `pip install keras`).".format(
                self.__class__.__name__)) from err

    import os
    import time
    import json

    class DumpHistoryCallback(keras.callbacks.Callback):
        """Keras callback that rewrites a JSON history file at the end of every epoch."""

        def __init__(self, dump_path):
            """
            :param dump_path: output JSON file path, or an existing directory in which a timestamped file is created
            """
            super().__init__()
            if os.path.isdir(dump_path):
                # NOTE(review): "%X" usually contains ':' which is not allowed in Windows file names -- confirm
                dump_path = os.path.join(dump_path,
                                         '__view__training-history__{}.json'.format(time.strftime("%F-%X")))
            self.dump_file = dump_path
            # metric name -> {'x': [epochs], 'y': [values]}
            self.data = dict()

        def on_epoch_end(self, epoch, logs=None):
            logs = logs or {}
            for name, val in logs.items():
                if name not in self.data:
                    self.data[name] = {'x': list(), 'y': list()}
                self.data[name]['x'].append(float(epoch))
                self.data[name]['y'].append(float(val))
            self.dump_history()

        def dump_history(self):
            _json = {
                "query": {},
                "datasetId": "",
                "xlabel": "epoch",
                "title": "training loss",
                "ylabel": "val",
                "type": "metric",
                "data": [{"name": name,
                          "x": values['x'],
                          "y": values['y']} for name, values in self.data.items()]
            }

            with open(self.dump_file, 'w') as f:
                json.dump(_json, f, indent=2)

    return DumpHistoryCallback
1
+ import dataclasses
2
+ import tempfile
3
+ import datetime
4
+ import logging
5
+ import string
6
+ import shutil
7
+ import random
8
+ import base64
9
+ import tqdm
10
+ import sys
11
+ import io
12
+ import os
13
+ from PIL import Image
14
+ from functools import partial
15
+ import numpy as np
16
+ from concurrent.futures import ThreadPoolExecutor
17
+ import attr
18
+ from .. import entities, utilities, repositories, exceptions
19
+ from ..services import service_defaults
20
+ from ..services.api_client import ApiClient
21
+
22
# Module logger; BaseModelAdapter instances expose it as `self.logger`.
logger = logging.getLogger(name='ModelAdapter')
23
+
24
+
25
@dataclasses.dataclass
class AdapterDefaults(dict):
    """
    Default argument values for the adapter entry points.

    Values are readable both as attributes (``defaults.upload_annotations``) and as dict items
    (``defaults['upload_annotations']``). For the dataclass fields both views are kept in sync on
    assignment, so `resolve` (which reads the dict view) always sees the latest value.
    """
    # for predict items, dataset, evaluate
    upload_annotations: bool = dataclasses.field(default=True)
    clean_annotations: bool = dataclasses.field(default=True)
    # for embeddings
    upload_features: bool = dataclasses.field(default=True)
    # for training
    root_path: str = dataclasses.field(default=None)
    data_path: str = dataclasses.field(default=None)
    output_path: str = dataclasses.field(default=None)

    def __post_init__(self):
        # Initialize the internal dictionary with the dataclass fields
        self.update(**dataclasses.asdict(self))

    @classmethod
    def _field_names(cls):
        """Return the set of dataclass field names."""
        return {f.name for f in dataclasses.fields(cls)}

    def __setattr__(self, name, value):
        # Mirror field attributes into the dict view; without this `defaults.x = v` leaves a stale dict value.
        super().__setattr__(name, value)
        if name in self._field_names():
            dict.__setitem__(self, name, value)

    def __setitem__(self, key, value):
        # Mirror field keys into the attribute view.
        dict.__setitem__(self, key, value)
        if key in self._field_names():
            object.__setattr__(self, key, value)

    def update(self, **kwargs):
        """
        Update stored defaults; names matching dataclass fields are also set as attributes.

        :param kwargs: default name -> value
        """
        field_names = self._field_names()
        for name, value in kwargs.items():
            if name in field_names:
                setattr(self, name, value)
        super().update(**kwargs)

    def resolve(self, key, *args):
        """
        Return the first non-None value in `args`, falling back to the stored default for `key`.

        :param key: name of the default to fall back to
        :param args: candidate explicit values, checked in order
        :return: the resolved value (None when nothing is found)
        """
        for arg in args:
            if arg is not None:
                return arg
        return self.get(key, None)
53
+
54
+
55
+ class BaseModelAdapter(utilities.BaseServiceRunner):
56
# API client handle, hidden from repr.
# NOTE(review): the class header shown has no @attr.s decorator, so this attr.ib may only act as a
# class-level placeholder -- confirm against utilities.BaseServiceRunner.
_client_api = attr.ib(repr=False, type=ApiClient)
57
+
58
def __init__(self, model_entity: entities.Model = None):
    """
    Create an adapter, optionally loading a model entity right away.

    :param model_entity: optional `dl.Model`; when given, `load_from_model(model_entity=...)` is called
    """
    self.adapter_defaults = AdapterDefaults()
    self.logger = logger
    # entities (filled later, e.g. through the `model_entity` / `package` setters)
    self._model_entity, self._package = None, None
    self._base_configuration = dict()
    self.package_name = None
    # runtime model object and remote weights location
    self.model = None
    self.bucket_path = None
    # funcs: input type -> item preparation function
    self.item_to_batch_mapping = {
        'text': self._item_to_text,
        'image': self._item_to_image,
    }
    if model_entity is not None:
        self.load_from_model(model_entity=model_entity)
    logger.warning(
        "in case of a mismatch between 'model.name' and 'model_info.name' in the model adapter, model_info.name will be updated to align with 'model.name'.")
75
+
76
+ ##################
77
+ # Configurations #
78
+ ##################
79
+
80
@property
def configuration(self) -> dict:
    """
    The active configuration dict.

    Resolution order: the loaded model entity's configuration; otherwise the package's
    ``metadata.system.ml.defaultConfiguration`` (``{}`` when missing); otherwise the adapter's
    base configuration.
    """
    if self._model_entity is not None:
        return self.model_entity.configuration
    if self._package is not None:
        ml_metadata = self.package.metadata.get('system', {}).get('ml', {})
        return ml_metadata.get('defaultConfiguration', {})
    return self._base_configuration

@configuration.setter
def configuration(self, d):
    """
    Replace the loaded model entity's configuration with `d`.

    When no model entity is loaded the assignment is ignored.

    :param d: new configuration `dict` (checked with an assert)
    """
    assert isinstance(d, dict)
    if self._model_entity is None:
        return
    self._model_entity.configuration = d
97
+
98
+ ############
99
+ # Entities #
100
+ ############
101
@property
def model_entity(self):
    """
    The loaded `dl.Model`.

    :raises ValueError: if no model entity has been loaded
    """
    current = self._model_entity
    if current is None:
        raise ValueError(
            "No model entity loaded. Please load a model (adapter.load_from_model(<dl.Model>)) or set: 'adapter.model_entity=<dl.Model>'")
    assert isinstance(current, entities.Model)
    return current

@model_entity.setter
def model_entity(self, model_entity):
    """
    Set the model entity and sync `package` from it; warns when a different model is replaced.

    :param model_entity: `dl.Model` to use (checked with an assert)
    """
    assert isinstance(model_entity, entities.Model)
    previous = self._model_entity
    if isinstance(previous, entities.Model) and previous.id != model_entity.id:
        self.logger.warning(
            'Replacing Model from {!r} to {!r}'.format(previous.name, model_entity.name))
    self._model_entity = model_entity
    self.package = model_entity.package
118
+
119
@property
def package(self):
    """
    The `dl.Package` / `dl.Dpk` backing this adapter.

    When a model entity is loaded, the package is re-read from it (through the setter) on every access.

    :raises ValueError: if no package is available
    """
    if self._model_entity is not None:
        self.package = self.model_entity.package
    current = self._package
    if current is None:
        raise ValueError('Missing Package entity on adapter. Please set: "adapter.package=package"')
    assert isinstance(current, (entities.Package, entities.Dpk))
    return current

@package.setter
def package(self, package):
    """
    Store the package and cache its name in `package_name`.

    :param package: `dl.Package` or `dl.Dpk` (checked with an assert)
    """
    assert isinstance(package, (entities.Package, entities.Dpk))
    self.package_name = package.name
    self._package = package
133
+
134
+ ###################################
135
+ # NEED TO IMPLEMENT THESE METHODS #
136
+ ###################################
137
+
138
def load(self, local_path, **kwargs):
    """
    Load the model from a local directory and populate `self.model` with a runnable model.

    Virtual method - subclasses must implement it. Called by load_from_model, which first
    downloads the artifacts locally and then calls this method.

    :param local_path: `str` directory path in local FileSystem
    :raises NotImplementedError: always, in the base class
    """
    adapter_name = type(self).__name__
    raise NotImplementedError(f"Please implement `load` method in {adapter_name}")
148
+
149
def save(self, local_path, **kwargs):
    """
    Save the configuration and weights to a local directory.

    Virtual method - subclasses must implement it. Called by save_to_model, which first saves
    locally and then uploads to the model entity.

    :param local_path: `str` directory path in local FileSystem
    :raises NotImplementedError: always, in the base class
    """
    adapter_name = type(self).__name__
    raise NotImplementedError(f"Please implement `save` method in {adapter_name}")
159
+
160
def train(self, data_path, output_path, **kwargs):
    """
    Train the model on the data in `data_path` and write the outputs to `output_path`.

    Virtual method - subclasses must implement it. Outputs include the weights and any other
    artifacts created during training.

    :param data_path: `str` local File System path where the data was downloaded and converted
    :param output_path: `str` local File System path for training mid-results (checkpoints, logs...)
    :raises NotImplementedError: always, in the base class
    """
    adapter_name = type(self).__name__
    raise NotImplementedError(f"Please implement `train` method in {adapter_name}")
171
+
172
def predict(self, batch, **kwargs):
    """Run model inference on a batch of prepared items.

    Abstract hook: subclasses must override it.

    :param batch: list of outputs of ``prepare_item_func``
    :return: `list[dl.AnnotationCollection]` with one collection per item in the batch
    :raises NotImplementedError: always, in this base implementation
    """
    adapter_name = type(self).__name__
    raise NotImplementedError(f"Please implement `predict` method in {adapter_name}")
181
+
182
def embed(self, batch, **kwargs):
    """Extract feature embeddings for a batch of prepared items.

    Abstract hook: subclasses must override it.

    :param batch: list of outputs of ``prepare_item_func``
    :return: `list[list]` with one feature vector per item in the batch
    :raises NotImplementedError: always, in this base implementation
    """
    adapter_name = type(self).__name__
    raise NotImplementedError(f"Please implement `embed` method in {adapter_name}")
191
+
192
def evaluate(self, model: entities.Model, dataset: entities.Dataset, filters: entities.Filters) -> entities.Model:
    """
    Score the model's predictions on a dataset (with GT annotations).

    Delegates to ``dtlpymetrics.scoring.create_model_score``. Annotations of the model's
    ``output_type`` are compared, and the scores and metrics are uploaded to the platform.

    :param model: the evaluated model (predictions are identified by annotation.metadata.system.model.name)
    :param dataset: dataset where the model predicted and uploaded its annotations
    :param filters: dl.Filters or a filters dict selecting the items; falsy values become a default dl.Filters()
    :return: the model entity returned by ``create_model_score``
    """
    import dtlpymetrics

    output_types = model.output_type
    if isinstance(filters, dict) and filters:
        filters = entities.Filters(custom_filter=filters)
    elif not filters:
        filters = entities.Filters()
    scored_model = dtlpymetrics.scoring.create_model_score(model=model,
                                                           dataset=dataset,
                                                           filters=filters,
                                                           compare_types=output_types)
    return scored_model
213
+
214
def convert_from_dtlpy(self, data_path, **kwargs):
    """Convert data downloaded in the Dataloop layout into the model's own format.

    Abstract hook: subclasses must override it (e.g. read the dlp directory structure and
    build an annotation file).

    :param data_path: `str` local directory where the data was already downloaded from the Dataloop platform
    :raises NotImplementedError: always, in this base implementation
    """
    adapter_name = type(self).__name__
    raise NotImplementedError(f"Please implement `convert_from_dtlpy` method in {adapter_name}")
225
+
226
+ #################
227
+ # DTLPY METHODS #
228
+ ################
229
def prepare_item_func(self, item: entities.Item):
    """
    Prepare a single Dataloop item before it is batched into ``predict``/``embed``.

    Override to load items differently. By default:
    - list input types: text items go to ``_item_to_text``, image items go to ``_item_to_image``,
      and anything else is passed through unchanged;
    - scalar input types: ``self.item_to_batch_mapping[input_type]`` is used when present,
      otherwise the item is passed through unchanged.

    :param item: dl.Item to prepare
    :return: the loaded item data (e.g. ndarray for images, str for text, or the item itself)
    """
    input_type = self.model_entity.input_type
    if isinstance(input_type, list):
        if 'text' in input_type and 'text' in item.mimetype:
            return self._item_to_text(item)
        if 'image' in input_type and 'image' in item.mimetype:
            return self._item_to_image(item)
        return self._item_to_item(item)
    if input_type in self.item_to_batch_mapping:
        converter = self.item_to_batch_mapping[input_type]
        return converter(item)
    return self._item_to_item(item)
254
+
255
def prepare_data(self,
                 dataset: entities.Dataset,
                 # paths
                 root_path=None,
                 data_path=None,
                 output_path=None,
                 #
                 overwrite=False,
                 **kwargs):
    """
    Prepare the dataset locally before training or evaluation.

    Every subset listed in the model's ``metadata.system.subsets`` is downloaded into
    ``<data_path>/<subset>``, and then ``self.convert_from_dtlpy`` is called on ``data_path``.

    :param dataset: dl.Dataset to download the items from
    :param root_path: `str` root directory for training. Can be set using self.adapter_defaults.root_path.
        default: <DATALOOP_PATH>/model_data/<model id>_<model name>/<timestamp>
    :param data_path: `str` dataset directory. default <root_path>/datasets/<model dataset id>.
        Can be set using self.adapter_defaults.data_path
    :param output_path: `str` save everything to this folder. default <root_path>/output.
        Can be set using self.adapter_defaults.output_path
    :param bool overwrite: download a subset again even if its directory already exists. default is False
    :param kwargs: forwarded to ``convert_from_dtlpy``
    :return: tuple ``(root_path, data_path, output_path)``
    :raises ValueError: if the model has no ``metadata.system.subsets``, or a subset downloads no items
    """
    # define paths
    dataloop_path = service_defaults.DATALOOP_PATH
    root_path = self.adapter_defaults.resolve("root_path", root_path)
    data_path = self.adapter_defaults.resolve("data_path", data_path)
    output_path = self.adapter_defaults.resolve("output_path", output_path)

    if root_path is None:
        now = datetime.datetime.now()
        root_path = os.path.join(dataloop_path,
                                 'model_data',
                                 "{s_id}_{s_n}".format(s_id=self.model_entity.id, s_n=self.model_entity.name),
                                 now.strftime('%Y-%m-%d-%H%M%S'),
                                 )
    if data_path is None:
        data_path = os.path.join(root_path, 'datasets', self.model_entity.dataset.id)
    if output_path is None:
        output_path = os.path.join(root_path, 'output')
    # Create both directories even when they were passed in or came from the adapter defaults:
    # os.listdir below raises FileNotFoundError for a data_path that does not exist yet.
    os.makedirs(data_path, exist_ok=True)
    os.makedirs(output_path, exist_ok=True)

    if os.listdir(data_path):
        self.logger.warning("Data path directory ({}) is not empty..".format(data_path))

    annotation_options = entities.ViewAnnotationOptions.JSON
    if self.model_entity.output_type in [entities.AnnotationType.SEGMENTATION]:
        annotation_options = entities.ViewAnnotationOptions.INSTANCE

    # Download the subset items
    subsets = self.model_entity.metadata.get("system", dict()).get("subsets", None)
    if subsets is None:
        raise ValueError("Model (id: {}) must have subsets in metadata.system.subsets".format(self.model_entity.id))
    for subset, filters_dict in subsets.items():
        filters = entities.Filters(custom_filter=filters_dict)
        data_subset_base_path = os.path.join(data_path, subset)
        if os.path.isdir(data_subset_base_path) and not overwrite:
            # existing and don't overwrite
            self.logger.debug("Subset {!r} already exists (and overwrite=False). Skipping.".format(subset))
            continue

        self.logger.debug("Downloading subset {!r} of {}".format(subset,
                                                                 self.model_entity.dataset.name))

        annotation_filters = None
        if self.model_entity.output_type is not None and self.model_entity.output_type != "embedding":
            annotation_filters = entities.Filters(resource=entities.FiltersResource.ANNOTATION, use_defaults=False)
            # segmentation and polygon models download both of these annotation types
            if self.model_entity.output_type in [entities.AnnotationType.SEGMENTATION,
                                                 entities.AnnotationType.POLYGON]:
                model_output_types = [entities.AnnotationType.SEGMENTATION, entities.AnnotationType.POLYGON]
            else:
                model_output_types = [self.model_entity.output_type]

            annotation_filters.add(
                field=entities.FiltersKnownFields.TYPE,
                values=model_output_types,
                operator=entities.FiltersOperations.IN
            )

            # unless configured otherwise, exclude annotations created by models (predictions)
            if not self.configuration.get("include_model_annotations", False):
                annotation_filters.add(
                    field="metadata.system.model.name",
                    values=False,
                    operator=entities.FiltersOperations.EXISTS
                )

        ret_list = dataset.items.download(filters=filters,
                                          local_path=data_subset_base_path,
                                          annotation_options=annotation_options,
                                          annotation_filters=annotation_filters
                                          )
        if isinstance(ret_list, list) and len(ret_list) == 0:
            message = (f"No items downloaded for subset {subset}! Cannot train model with empty subset.\n"
                       f"Subset {subset} filters: {filters.prepare()}")
            if annotation_filters is not None:
                message += f"\nAnnotation filters: {annotation_filters.prepare()}"
            raise ValueError(message)

    self.convert_from_dtlpy(data_path=data_path, **kwargs)
    return root_path, data_path, output_path
354
+
355
def load_from_model(self, model_entity=None, local_path=None, overwrite=True, **kwargs):
    """
    Load the adapter from a ``dl.Model``.

    The model is loaded in four steps:
    1. Set ``self.model_entity``, when a model is given.
    2. Copy the model configuration to ``self.configuration``.
    3. Update ``self.adapter_defaults`` with that configuration.
    4. Download the model artifacts and call ``self.load`` on the download directory.

    :param model_entity: `dl.Model` to load; when None the adapter's current model entity is used
    :param local_path: `str` local directory to download the artifacts to
        (default: <DATALOOP_PATH>/models/<model name>)
    :param overwrite: `bool` (default True) if True download all artifacts, if False skip files with the same name
    :param kwargs: forwarded to ``self.load``
    """
    if model_entity is not None:
        self.model_entity = model_entity
    entity = self.model_entity
    artifacts_dir = local_path
    if artifacts_dir is None:
        artifacts_dir = os.path.join(service_defaults.DATALOOP_PATH, "models", entity.name)

    # the model configuration is applied over the adapter defaults
    self.configuration = entity.configuration
    self.adapter_defaults.update(**self.configuration)

    entity.artifacts.download(local_path=artifacts_dir, overwrite=overwrite)
    self.load(artifacts_dir, **kwargs)
378
+
379
def save_to_model(self, local_path=None, cleanup=False, replace=True, **kwargs):
    """
    Save the model state locally and upload it as the model entity's artifacts.

    Calls ``self.save`` to write configuration and weights to ``local_path``, then uploads every
    file in that directory to ``self.model_entity.artifacts`` (with ``overwrite=True``).

    :param local_path: `str` local directory to save the current model state (weights) to
        (default: a new temporary directory)
    :param cleanup: `bool` if True remove ``local_path`` from the local file system after the upload (default False)
    :param replace: `bool` kept for backward compatibility.
        NOTE(review): not used by this implementation; uploads always pass ``overwrite=True``.
    :param kwargs: forwarded to ``self.save``
    :raises ValueError: if no model entity is set on the adapter
    """
    # Validate before doing any work: the temp-dir prefix below already needs the model name,
    # and there is no point running a (possibly expensive) save without an upload target.
    if self.model_entity is None:
        raise ValueError('Missing model entity on the adapter. '
                         'Please set before saving: "adapter.model_entity=model"')

    if local_path is None:
        local_path = tempfile.mkdtemp(prefix="model_{}".format(self.model_entity.name))
        self.logger.debug("Using temporary dir at {}".format(local_path))

    self.save(local_path=local_path, **kwargs)

    self.model_entity.artifacts.upload(filepath=os.path.join(local_path, '*'),
                                       overwrite=True)
    if cleanup:
        shutil.rmtree(path=local_path, ignore_errors=True)
        self.logger.info("Clean-up. deleting {}".format(local_path))
408
+
409
+ # ===============
410
+ # SERVICE METHODS
411
+ # ===============
412
+
413
@entities.Package.decorators.function(display_name='Predict Items',
                                      inputs={'items': 'Item[]'},
                                      outputs={'items': 'Item[]', 'annotations': 'Annotation[]'})
def predict_items(self, items: list, upload_annotations=None, clean_annotations=None, batch_size=None, **kwargs):
    """
    Run the predict function on the input list of items and return the items and the predictions.

    Items are processed in batches through the following steps:
    1. Prepare each item with ``prepare_item_func`` (in worker threads).
    2. Pass the prepared batch to ``predict``.
    3. Attach model metadata to the resulting collections via ``_update_predictions_metadata``.
    4. Optionally upload the predictions to the items.

    :param items: `List[dl.Item]` list of items to predict
    :param upload_annotations: `bool` upload the predictions to the given items
        (resolved from the adapter defaults when None)
    :param clean_annotations: `bool` delete previous model annotations (predictions) before uploading new ones
        (resolved from the adapter defaults when None)
    :param batch_size: `int` number of items per ``predict`` call (default: configuration 'batch_size', else 4)
    :param kwargs: forwarded to ``predict``
    :return: `List[dl.Item]`, and a flat `List[dl.Annotation]` with the predictions of all items
    """
    if batch_size is None:
        batch_size = self.configuration.get('batch_size', 4)
    upload_annotations = self.adapter_defaults.resolve("upload_annotations", upload_annotations)
    clean_annotations = self.adapter_defaults.resolve("clean_annotations", clean_annotations)
    input_type = self.model_entity.input_type
    self.logger.debug(
        "Predicting {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))

    annotations = list()
    # The context manager shuts the worker threads down even when `predict` (user code)
    # or a metadata update raises in the middle of the loop.
    with ThreadPoolExecutor(max_workers=16) as pool:
        for i_batch in tqdm.tqdm(range(0, len(items), batch_size), desc='predicting', unit='bt', leave=None,
                                 file=sys.stdout):
            batch_items = items[i_batch: i_batch + batch_size]
            batch = list(pool.map(self.prepare_item_func, batch_items))
            batch_collections = self.predict(batch, **kwargs)
            # list() drains the iterator: every metadata update finishes (and re-raises any error)
            # before the predictions are uploaded or collected.
            list(pool.map(self._update_predictions_metadata, batch_items, batch_collections))
            if upload_annotations is True:
                self.logger.debug(
                    "Uploading items' annotation for model {!r}.".format(self.model_entity.name))
                try:
                    batch_collections = list(pool.map(partial(self._upload_model_annotations,
                                                              clean_annotations=clean_annotations),
                                                      batch_items,
                                                      batch_collections))
                except Exception:
                    # keep the local predictions when the upload fails
                    self.logger.exception("Failed to upload annotations items.")

            for collection in batch_collections:
                # the function returns a flat list of dl.Annotation: unpack the collections
                if isinstance(collection, entities.AnnotationCollection):
                    annotations.extend(collection.annotations)
                else:
                    logger.warning(f'RETURN TYPE MAY BE INVALID: {type(collection)}')
                    annotations.extend(collection)
            # TODO call the callback

    return items, annotations
471
+
472
@entities.Package.decorators.function(display_name='Embed Items',
                                      inputs={'items': 'Item[]'},
                                      outputs={'items': 'Item[]', 'features': 'Json[]'})
def embed_items(self, items: list, upload_features=None, batch_size=None, progress: utilities.Progress = None,
                **kwargs):
    """
    Extract features from an input list of items and return the items and the feature vectors.

    If the model has no feature set yet, one is created first. If the project already has a
    feature set with the model's name, the new one gets a random 5-character suffix.

    :param items: `List[dl.Item]` list of items to embed
    :param upload_features: `bool` upload the features to the given items (resolved from the adapter defaults when None)
    :param batch_size: `int` number of items per ``embed`` call (default: configuration 'batch_size', else 4)
    :param progress: dl.Progress whose logger reports per-item upload failures (adapter logger otherwise)
    :param kwargs: forwarded to ``embed``
    :return: `List[dl.Item]`, and a flat list with one feature vector per item
    """
    if batch_size is None:
        batch_size = self.configuration.get('batch_size', 4)
    upload_features = self.adapter_defaults.resolve("upload_features", upload_features)
    input_type = self.model_entity.input_type
    self.logger.debug(
        "Embedding {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))

    # Search for existing feature set for this model id
    feature_set = self.model_entity.feature_set
    if feature_set is None:
        logger.info('Feature Set not found. creating... ')
        try:
            self.model_entity.project.feature_sets.get(feature_set_name=self.model_entity.name)
            # the model's name is already taken by another feature set in the project: add a random suffix
            suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=5))
            feature_set_name = f"{self.model_entity.name}-{suffix}"
            logger.warning(
                f"Feature set with the model name already exists. Creating new feature set with name {feature_set_name}")
        except exceptions.NotFound:
            feature_set_name = self.model_entity.name
        feature_set = self.model_entity.project.feature_sets.create(name=feature_set_name,
                                                                    entity_type=entities.FeatureEntityType.ITEM,
                                                                    model_id=self.model_entity.id,
                                                                    project_id=self.model_entity.project_id,
                                                                    set_type=self.model_entity.name,
                                                                    size=self.configuration.get('embeddings_size',
                                                                                                256))
        logger.info(f'Feature Set created! name: {feature_set.name}, id: {feature_set.id}')
    else:
        logger.info(f'Feature Set found! name: {feature_set.name}, id: {feature_set.id}')

    # upload the feature vectors; the context manager shuts the pool down even if `embed` raises
    vectors = list()
    with ThreadPoolExecutor(max_workers=16) as pool:
        for i_batch in tqdm.tqdm(range(0, len(items), batch_size),
                                 desc='embedding',
                                 unit='bt',
                                 leave=None,
                                 file=sys.stdout):
            batch_items = items[i_batch: i_batch + batch_size]
            batch = list(pool.map(self.prepare_item_func, batch_items))
            batch_vectors = self.embed(batch, **kwargs)
            vectors.extend(batch_vectors)
            if upload_features is True:
                self.logger.debug(
                    "Uploading items' feature vectors for model {!r}.".format(self.model_entity.name))
                try:
                    list(pool.map(partial(self._upload_model_features,
                                          progress.logger if progress is not None else self.logger,
                                          feature_set.id,
                                          self.model_entity.project_id),
                                  batch_items,
                                  batch_vectors))
                except Exception:
                    self.logger.exception("Failed to upload feature vectors to items.")

    return items, vectors
541
+
542
@entities.Package.decorators.function(display_name='Embed Dataset with DQL',
                                      inputs={'dataset': 'Dataset',
                                              'filters': 'Json'})
def embed_dataset(self,
                  dataset: entities.Dataset,
                  filters: entities.Filters = None,
                  upload_features=None,
                  batch_size=None,
                  progress: utilities.Progress = None,
                  **kwargs):
    """
    Embed every file item of a dataset that matches the given filters.

    :param dataset: dl.Dataset whose items are embedded
    :param filters: dl.Filters or a filters dict selecting items; falsy values become a default dl.Filters()
    :param upload_features: `bool` upload the features back to the items (resolved from the adapter defaults when None)
    :param batch_size: `int` items per embed call, also used as the listing page size
        (default: configuration 'batch_size', else 4)
    :param progress: dl.Progress passed on to ``embed_items``
    :param kwargs: forwarded to ``embed_items``
    :return: `bool` True when no exception propagated
    """
    batch_size = self.configuration.get('batch_size', 4) if batch_size is None else batch_size
    upload_features = self.adapter_defaults.resolve("upload_features", upload_features)
    self.logger.debug("Creating embeddings for dataset (name:{}, id:{}), using batch size {}".format(
        dataset.name, dataset.id, batch_size))

    if isinstance(filters, dict) and filters:
        filters = entities.Filters(custom_filter=filters)
    elif not filters:
        filters = entities.Filters()

    # Item type is 'file' only, can be deleted if default filters are added to custom filters
    file_items = [entry
                  for page in dataset.items.list(filters=filters, page_size=batch_size)
                  for entry in page
                  if entry.type == 'file']
    self.embed_items(items=file_items,
                     upload_features=upload_features,
                     batch_size=batch_size,
                     progress=progress,
                     **kwargs)
    return True
582
+
583
@entities.Package.decorators.function(display_name='Predict Dataset with DQL',
                                      inputs={'dataset': 'Dataset',
                                              'filters': 'Json'})
def predict_dataset(self,
                    dataset: entities.Dataset,
                    filters: entities.Filters = None,
                    upload_annotations=None,
                    clean_annotations=None,
                    batch_size=None,
                    **kwargs):
    """
    Predict every file item of a dataset that matches the given filters.

    :param dataset: dl.Dataset whose items are predicted
    :param filters: dl.Filters or a filters dict selecting items; falsy values become a default dl.Filters()
    :param upload_annotations: `bool` upload the predictions back to the items (passed on to ``predict_items``)
    :param clean_annotations: `bool` remove existing predictions of this model before uploading (passed on to ``predict_items``)
    :param batch_size: `int` items per inference call, also used as the listing page size
        (default: configuration 'batch_size', else 4)
    :param kwargs: forwarded to ``predict_items``
    :return: `bool` True when no exception propagated
    """
    batch_size = self.configuration.get('batch_size', 4) if batch_size is None else batch_size

    self.logger.debug("Predicting dataset (name:{}, id:{}, using batch size {}".format(
        dataset.name, dataset.id, batch_size))

    if isinstance(filters, dict) and filters:
        filters = entities.Filters(custom_filter=filters)
    elif not filters:
        filters = entities.Filters()

    # Item type is 'file' only, can be deleted if default filters are added to custom filters
    file_items = [entry
                  for page in dataset.items.list(filters=filters, page_size=batch_size)
                  for entry in page
                  if entry.type == 'file']
    self.predict_items(items=file_items,
                       upload_annotations=upload_annotations,
                       clean_annotations=clean_annotations,
                       batch_size=batch_size,
                       **kwargs)
    return True
624
+
625
@entities.Package.decorators.function(display_name='Train a Model',
                                      inputs={'model': entities.Model},
                                      outputs={'model': entities.Model})
def train_model(self,
                model: entities.Model,
                cleanup=False,
                progress: utilities.Progress = None,
                context: utilities.Context = None):
    """
    Train an existing model.

    The training runs in these steps:
    1. Wait for the model to be ready and set its status to ``training``.
    2. Load the adapter from the model.
    3. Download the model's dataset subsets with ``prepare_data``.
    4. Call ``train``.
    5. Upload the outputs to the model's artifacts and set the status to ``trained``.

    If anything fails after data preparation, the outputs are still uploaded (best effort)
    before the error is re-raised.

    :param model: dl.Model entity (or a dict with an ``id`` key) to train
    :param cleanup: `bool` if True delete the local output directory after a successful upload
    :param progress: dl.Progress for reporting FaaS progress
    :param context: FaaS execution context; when given, ``model.metadata['system']`` is created if missing
    :return: the trained dl.Model entity
    """
    if isinstance(model, dict):
        model = repositories.Models(client_api=self._client_api).get(model_id=model['id'])
    output_path = None
    try:
        logger.info("Received {s} for training".format(s=model.id))
        model = model.wait_for_model_ready()
        if model.status == 'failed':
            raise ValueError("Model is in failed state, cannot train.")

        ##############
        # Set status #
        ##############
        model.status = 'training'
        if context is not None:
            if 'system' not in model.metadata:
                model.metadata['system'] = dict()
        model.update()

        ##########################
        # load model and weights #
        ##########################
        logger.info("Loading Adapter with: {n} ({i!r})".format(n=model.name, i=model.id))
        self.load_from_model(model_entity=model)

        ################
        # prepare data #
        ################
        root_path, data_path, output_path = self.prepare_data(
            dataset=self.model_entity.dataset,
            root_path=os.path.join('tmp', model.id)
        )
        # Start the Train
        logger.info("Training {p_name!r} with model {m_name!r} on data {d_path!r}".
                    format(p_name=self.package_name, m_name=model.id, d_path=data_path))
        if progress is not None:
            progress.update(message='starting training')

        def on_epoch_end_callback(i_epoch, n_epoch):
            # presumably i_epoch is 0-based (the percentage uses i_epoch + 1) - confirm against train() implementations
            if progress is not None:
                progress.update(progress=int(100 * (i_epoch + 1) / n_epoch),
                                message='finished epoch: {}/{}'.format(i_epoch, n_epoch))

        self.train(data_path=data_path,
                   output_path=output_path,
                   on_epoch_end_callback=on_epoch_end_callback)
        if progress is not None:
            progress.update(message='saving model',
                            progress=99)

        self.save_to_model(local_path=output_path, replace=True)
        model.status = 'trained'
        model.update()
        ###########
        # cleanup #
        ###########
        if cleanup:
            shutil.rmtree(output_path, ignore_errors=True)
    except Exception:
        # save also on fail, so partial outputs (checkpoints, logs) reach the model artifacts
        if output_path is not None:
            try:
                self.save_to_model(local_path=output_path, replace=True)
            except Exception:
                # best effort only: a failing save must not replace the original training error
                logger.exception('Failed to save model artifacts after a failed training')
        # NOTE(review): model.status is not changed here despite the message; presumably the
        # platform marks the model of a failed execution - confirm.
        logger.info('Execution failed. Setting model.status to failed')
        raise
    return model
703
+
704
@entities.Package.decorators.function(display_name='Evaluate a Model',
                                      inputs={'model': entities.Model,
                                              'dataset': entities.Dataset,
                                              'filters': 'Json'},
                                      outputs={'model': entities.Model,
                                               'dataset': entities.Dataset
                                               })
def evaluate_model(self,
                   model: entities.Model,
                   dataset: entities.Dataset,
                   filters: entities.Filters = None,
                   #
                   progress: utilities.Progress = None,
                   context: utilities.Context = None):
    """
    Evaluate a model.

    The evaluation runs in three steps:
    1. Load the adapter from the model.
    2. Predict the (filtered) dataset items and upload the predictions.
    3. Call ``self.evaluate`` to calculate the metrics against the ground-truth annotations.

    The configuration is as defined in dl.Model.configuration.

    :param model: Model entity to run prediction with
    :param dataset: Dataset to evaluate on
    :param filters: dl.Filters or a filters dict selecting specific items from the dataset
    :param progress: dl.Progress for reporting FaaS progress
    :param context: FaaS execution context (not used by this implementation)
    :return: tuple ``(model, dataset)``, where model is the entity returned by ``self.evaluate``
    """
    logger.info(
        f"Received model: {model.id} for evaluation on dataset (name: {dataset.name}, id: {dataset.id})")
    ##########################
    # load model and weights #
    ##########################
    logger.info(f"Loading Adapter with: {model.name} ({model.id!r})")
    self.load_from_model(dataset=dataset,
                         model_entity=model)

    ##############
    # Predicting #
    ##############
    logger.info(f"Calling prediction, dataset: {dataset.name!r} ({dataset.id!r}), filters: {filters}")
    if not filters:
        filters = entities.Filters()
    # predict_dataset has no `with_upload` parameter (it was only forwarded to `predict` via **kwargs);
    # request the upload explicitly so the metrics below have predictions to score.
    self.predict_dataset(dataset=dataset,
                         filters=filters,
                         upload_annotations=True)

    ##############
    # Evaluating #
    ##############
    logger.info("Starting adapter.evaluate()")
    if progress is not None:
        progress.update(message='calculating metrics',
                        progress=98)
    model = self.evaluate(model=model,
                          dataset=dataset,
                          filters=filters)
    #########
    # Done! #
    #########
    if progress is not None:
        progress.update(message='finishing evaluation',
                        progress=99)
    return model, dataset
767
+
768
+ # =============
769
+ # INNER METHODS
770
+ # =============
771
+
772
@staticmethod
def _upload_model_features(logger, feature_set_id, project_id, item: entities.Item, vector):
    """
    Attach one feature vector to an item in the given feature set.

    ``None`` vectors are skipped. Upload errors are logged with the supplied logger and not re-raised.

    :param logger: logger used to report upload failures
    :param feature_set_id: id of the feature set the vector belongs to
    :param project_id: id of the project owning the feature set
    :param item: dl.Item the vector belongs to
    :param vector: the feature vector passed as ``value`` to ``item.features.create``, or None
    """
    if vector is None:
        return
    try:
        item.features.create(value=vector,
                             project_id=project_id,
                             feature_set_id=feature_set_id,
                             entity=item)
    except Exception as upload_error:
        logger.error(f'Failed to upload feature vector of length {len(vector)} to item {item.id}, Error: {upload_error}')
782
+
783
def _upload_model_annotations(self, item: entities.Item, predictions, clean_annotations):
    """
    Upload predictions to an item on the Dataloop platform.

    :param item: dl.Item to upload the predictions to
    :param predictions: `dl.AnnotationCollection` (or a list of annotations)
    :param clean_annotations: `bool` if set, first delete the item's annotations whose
        ``metadata.user.model.name`` equals this model's name
    :return: the result of ``item.annotations.upload``
    :raises TypeError: if predictions is neither an AnnotationCollection nor a list
    """
    if not isinstance(predictions, (entities.AnnotationCollection, list)):
        raise TypeError('predictions was expected to be of type {}, but instead it is {}'.
                        format(entities.AnnotationCollection, type(predictions)))
    if clean_annotations:
        stale_filter = entities.Filters(resource=entities.FiltersResource.ANNOTATION)
        stale_filter.add(field='metadata.user.model.name', values=self.model_entity.name)
        item.annotations.delete(filters=stale_filter)
    return item.annotations.upload(annotations=predictions)
799
+
800
@staticmethod
def _item_to_image(item):
    """
    Default batch preparation for image items.

    Downloads the item into memory (``save_locally=False``) and decodes it with PIL.

    :param item: dl.Item to load
    :return: `np.ndarray` with the image pixels
    """
    in_memory_file = item.download(save_locally=False)
    pil_image = Image.open(in_memory_file)
    return np.asarray(pil_image)
812
+
813
@staticmethod
def _item_to_item(item):
    """
    Identity batch preparation: hand the item to ``predict``/``embed`` unchanged.

    Used when no input-type-specific loader applies. Other loaders, for example the image one,
    return the loaded data (a numpy array) instead.

    :param item: dl.Item to prepare
    :return: the same item object
    """
    prepared = item
    return prepared
822
+
823
@staticmethod
def _item_to_text(item):
    """
    Default batch preparation for text items.

    Downloads the item to a local file:
    - For `text/plain` and `text/markdown` items, returns the file content decoded as UTF-8,
      with newlines replaced by spaces.
    - For any other mimetype, logs a warning and returns the item itself.

    The downloaded file is removed afterwards, even if reading it fails.

    :param item: dl.Item to load
    :return: `str` with the text content, or the item itself for non-text mimetypes
    """
    filename = item.download(overwrite=True)
    try:
        if item.mimetype in ('text/plain', 'text/markdown'):
            # decode explicitly instead of relying on the platform-dependent locale encoding
            with open(filename, 'r', encoding='utf-8') as f:
                text = f.read()
            text = text.replace('\n', ' ')
        else:
            logger.warning('Item is not text file. mimetype: {}'.format(item.mimetype))
            text = item
    finally:
        # do not leave the downloaded copy behind when reading/decoding fails
        if os.path.exists(filename):
            os.remove(filename)
    return text
837
+
838
+ @staticmethod
839
+ def _uri_to_image(data_uri):
840
+ # data_uri = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAS4AAAEuCAYAAAAwQP9DAAAU80lEQVR4Xu2da+hnRRnHv0qZKV42LDOt1eyGULoSJBGpRBFprBJBQrBJBBWGSm8jld5WroHUCyEXKutNu2IJ1QtXetULL0uQFCu24WoRsV5KpYvGYzM4nv6X8zu/mTnznPkcWP6XPTPzzOf7/L7/OXPmzDlOHBCAAAScETjOWbyECwEIQEAYF0kAAQi4I4BxuZOMgCEAAYyLHIAABNwRwLjcSUbAEIAAxkUOQAAC7ghgXO4kI2AIQADjIgcgAAF3BDAud5IRMAQggHGRAxCAgDsCGJc7yQgYAhDAuMgBCEDAHQGMy51kBAwBCGBc5AAEIOCOAMblTjIChgAEMC5yAAIQcEcA43InGQFDAAIYFzkAAQi4I4BxuZOMgCEAAYyLHIAABNwRwLjcSUbAEIAAxkUOQAAC7ghgXO4kI2AIQADjIgcgAAF3BDAud5IRMAQggHGRAxCAgDsCGJc7yQgYAhDAuMgBCEDAHQGMy51kBAwBCGBc5AAEIOCOAMblTjIChgAEMC5yAAIQcEcA43InGQFDAAIYFzkAAQi4I4BxuZOMgCEAAYyLHIAABNwRwLjcSUbAEIAAxkUOQAAC7ghgXO4kI2AIQADjIgcgAAF3BDAud5IRMAQggHGRAxDwTeDTkr4s6UxJ/5F0QNK3JD3lu1tbR49xLVld+jYXgcskvSTpIkmnS/qgpJMk/Tv8bHHZ7+PXPw6M5kRJx0t6Ijkv9uUsSW+U9Iykczfp4K8lfXiuztdoF+OqQZk2vBEwUzFTsK9mQNFkotGkhvFeSc+G86NRtdDfd0h6tIVASsSAcZWgSp0eCJjJ7JR0SRgZ2SjHDMp+38Jho7PXTAzkBUmvn1jWRTGMy4VMBJmBgBnSpZLsMs7+paOodao3k/hLqCBe8j0cfj4Yvtp8k/1fPLaaf4pxxXPSS8r4/Vsl3SXp5EHgNjo8JukDkg6v06nWy2JcrSvUX3xmKjYSipdqF0h6V/jgp6Mh+2DHf0YpnSd6p6TTkjml7UZRL4bLPasnmo7VHb+PKsQ20rZTQ6ql1lclfXODxr4u6Ru1gpizHYxrTvq0beZkE9cfkXRxxcu0pyXZaMiMKX71dBfua5sY1Psk/baHtMK4elC5rT5eFS7Z7Otmd8VyRDwcRZkxmUlFo8rRxlx13Clpz6Dxn0r61FwB1W4X46pNvM/27PLPPmhmVhvNLUWTiaZil1/xEswMx/7fbv9bWfs5nfcxommdceQU55eWSNxGihcmHbMRZK45Oxe8MK75ZYofaku8MyQ9J+mQpKNJMqbzLfeHkIeTuPP35JUIbCSVToRvNrKyftqCSfs3nE9qqT+txWKT8OmxT9LnWguyZDwYV0m6m9dtH+SbJNlamw+tGIIl7Va6/VPS8xusP4rN2JojG8E8NrhUS+d4ht/bbfkTJP0umGk6ER7PtfkVmwR/wzaXgEck7Q1mNcfE9oq4mzx9aFxXB55NBlsiKIyrBNXt67xB0q3bn7aYM+xSxkZVNjez5Eu4GoLZ5fb+pCFb/mB/LLo6MK555LaRyUPzND251VUWRJpRxTt2cUJ8csMUfBUBG61en/ymu8tE6zvGNd+nwuao7N8PJO0Kz7JZNDbH9aSkv4fQ0su2RyS9VtKD4dJtOClt5+4Il4Fpz+KkdqzLnpuzdrY74vnppWG6ujx9xMXOsUWPjw8WW27XBv+/GgH7Q2Dzh/G4NoxkV6vF+dkYV1sCRoNpKyqiaYmA/TGxxbXxsD963d3YwLhaSkligcDWBIZTDHajo+RauGb1wLialYbAIPB/BO6Q9Pnkt7dJshs93R0YV3eS02HHBGz+8Owk/vN6nU/EuBxnMaF3RWC4DOJ7kr7UFYGksxhXr8rTb28Eho/5dDvaMuEwLm/pS7w9EhiOtu4Oz332yOLlPmNc3UpPx5
0QsCUytlg5vXvY5RKIVC+My0n2Ema3BG4Oz7VGAN2PthhxdftZoOOOCKQLTu1RKlvL1f3D6Yy4HGUwoXZHwLaq+X7S6xvDzhrdgRh2GOPqPgUA0DCB9LlE27tsu73zG+5K3tAwrrw8qQ0CuQjYZLztmRaP7vbc2gokxpUrzagHAnkJpNvXMNoasMW48iYbtUEgF4F0Up7RFsaVK6+oBwLFCKST8t3uAMGlYrH8omIIFCFg21zvDjV3uwMExlUkt6gUAkUIDCflu34mcTPCzHEVyT0qhcBkAumLVJiU3wQjxjU5vygIgSIE0l0gutxPfgxVjGsMJc6BQB0C9kC1vW4sHvbik/RlKXWicNAKxuVAJELshkC6fY29sdzecs6xAQGMi7SAQDsE7IW5e0I4PJe4hS4YVztJSyQQsF0fdgYM3E3EuPhEQKB5Aumrx7ibuI1cjLiaz2cC7IRAugyCy0SMq5O0p5veCaSr5blMxLi85zPxd0LgGUmnSOIycYTgXCqOgMQpEChMwJY93MfdxPGUMa7xrDgTAqUIxGUQ7Ck/kjDGNRIUp0GgIIG49xaXiSMhY1wjQXEaBAoRSFfLczdxJGSMayQoToNAIQLpannuJo6EjHGNBMVpEChEgMvECWAxrgnQKAKBTAS4TJwIEuOaCI5iEMhAgMvEiRAxrongKAaBDAS4TJwIEeOaCI5iEFiTQPpQNXcTV4SJca0IjNMhkIlA+sJX7iauCBXjWhEYp0MgE4G49xaLTicAxbgmQKMIBNYkkL6CjPcmToCJcU2ARhEIrEkgfVP1Lkn2Zh+OFQhgXCvA4lQIZCIQl0EckWSjL44VCWBcKwLjdAhkIHBY0vmS9kmy0RfHigQwrhWBcToE1iSQLoO4QtK9a9bXZXGMq0vZ6fSMBOLe8rb3ll0m8sLXCWJgXBOgUQQCaxA4KOlStmheg6AkjGs9fpSGwKoEXgoFbpF086qFOf9/BDAuMgEC9Qike8tfLslGXxwTCGBcE6BRBAITCdgI66ZQls/eRIiMuNYAR1EITCAQ57ful2SjL46JBHD9ieAoBoEJBJjfmgBtoyIYVyaQVAOBbQik67eulmRvruaYSADjmgiOYhBYkUBcv2XFdrB+a0V6g9MxrvX4URoCYwnwfOJYUiPOw7hGQOIUCGQgEPff4vnEDDAxrgwQqQIC2xBI99+6VpKNvjjWIIBxrQGPohAYSSDdf4ttmkdC2+o0jCsDRKqAwDYEmN/KnCIYV2agVAeBDQgclfQW9t/KlxsYVz6W1ASBjQiw/1aBvMC4CkClSggkBOLziey/lTEtMK6MMKkKAhsQsBdhXMj+W3lzA+PKy5PaIJASOF3SsfAL3ladMTcwrowwqQoCAwK8hqxQSmBchcBSLQTCg9S7Jdn8lo2+ODIRwLgygaQaCGxAwF6EcRrLIPLnBsaVnyk1QsAIXCVpf0DBNjaZcwLjygyU6iAQCOyVdH34nm1sMqcFxpUZKNVBIBCIu0HcHUZfgMlIAOPKCJOqIBAIpKvl2Q2iQFpgXAWgUmX3BLhMLJwCGFdhwFTfJQEuEwvLjnEVBkz13RHgpRgVJMe4KkCmia4IpA9Vs+i0kPQYVyGwVNstgQcl7WLRaVn9Ma6yfKm9LwLsvVVJb4yrEmia6YJAvJvIs4mF5ca4CgOm+q4I8GxiJbkxrkqgaWbxBNJnE22OyzYQ5ChEAOMqBJZquyMQ124dkWTvUeQoSADjKgiXqrshcJmk+0Jv2em0guwYVwXINLF4Agck2YaBdvDC1wpyY1wVINPEognYZeHvJZ0g6RFJFyy6t410DuNqRAjCcEvgBkm3huhvl3Sd2544ChzjciQWoTZJIL5+zILjbmIliTCuSqBpZpEE0tePsei0osQYV0XYNLU4Aunrx/ZJsp85KhDAuCpAponFErhT0p7QO5ZBVJQZ46oIm6YWR4D5rZkkxbhmAk+z7gkwvzWjhBjXjPBp2jWBz0i6K/TgN5Iucd0bZ8
FjXM4EI9xmCMSdTi2gn0gyI+OoRADjqgSaZhZHIH3Mh1eQVZYX46oMnOYWQyDuBmEdulzSwcX0zEFHMC4HIhFikwReSqLiwerKEmFclYHT3CIIpNvYWIf4HFWWFeCVgdPcIgh8R9JXQk/+KulNi+iVo05gXI7EItRmCPxS0kdDNLalzXuaiayTQDCuToSmm9kI2MJT25751FDjLZJsaQRHRQIYV0XYNLUIAvdIujLpCXcUZ5AV45oBOk26JvCMpFNCD+zO4vGue+M0eIzLqXCEPQuBdBsbC+BeSVfMEknnjWJcnScA3V+JwJOS3pyUuFqSraDnqEwA46oMnOZcE0gXnVpH+PzMJCfgZwJPsy4JYFyNyIZxNSIEYbggMDSuHZKechH5woLEuBYmKN0pSoARV1G84yvHuMaz4sy+CQzvKB6VdE7fSObrPcY1H3ta9kVgeEeRt/rMqB/GNSN8mnZFYHiZyIr5GeXDuGaET9NuCFwlaX8SLTtCzCwdxjWzADTvgkC6v7wFfJukG1xEvtAgMa6FCku3shL4s6QzkxpZMZ8V7+qVYVyrM6NEfwSel3Ri0m3Wb82cAxjXzALQfPMEhvNbf5D07uajXniAGNfCBaZ7axN4VNLbk1pulLR37VqpYC0CGNda+Ci8cAK22+mxQR95o08DomNcDYhACM0SGK6Wt3cpmnFxzEwA45pZAJpvmsBwtTyXiY3IhXE1IgRhNElguFqey8RGZMK4GhGCMJojMLybyGViQxJhXA2JQShNEbhT0p4kIlbLNyQPxtWQGITSFAH2l29KjlcHg3E1LA6hzUrgxcGe8nxWZpUD42oIP6E0SuAiSQ8NYtsl6eFG4+0uLP6KdCc5HR5BYKOFp+y/NQJcrVMwrlqkaccTgQckXTwI+DJJ93vqxJJjxbiWrC59m0LgfEmHBwX/JemEKZVRpgwBjKsMV2r1S8BGVvcNwv+spB/67dLyIse4lqcpPVqPwEbGxcaB6zHNXhrjyo6UCp0TuFLSPYM+XCPpx877tajwMa5FyUlnMhCwveRvHdTDjqcZwOasAuPKSZO6lkDggKTdSUeOSDp3CR1bUh8wriWpSV9yEPiHpJOSinhGMQfVzHVgXJmBUp17AsOtbFgx36CkGFeDohDSbASGj/r8TdIZs0VDw5sSwLhIDgi8QmC4VfPdkmxfLo7GCGBcjQlCOLMSGO7BxVbNs8qxeeMYV6PCENYsBGyX051JyzxYPYsM2zeKcW3PiDP6ITCcmGf9VqPaY1yNCkNY1QkMJ+YPSbLfcTRIAONqUBRCmoXA8BlF1m/NIsO4RjGucZw4a/kEhncUebC6Yc0xrobFIbSqBIbPKDK/VRX/ao1hXKvx4uzlEtgr6frQvUckXbDcrvrvGcblX0N6kIdAaly/kPTxPNVSSwkCGFcJqtTpkUC6+JSFp40riHE1LhDhVSNwUNKloTUm5qthn9YQxjWNG6WWRyA1LlbMN64vxtW4QIRXjcBTkk4LrWFc1bBPawjjmsaNUssjkD7ug3E1ri/G1bhAhFeNQGpcbB5YDfu0hjCuadwotTwCqXGdJ8l2iuBolADG1agwhFWdQGpcfC6q41+tQQRajRdnL5dANK6nJZ2+3G4uo2cY1zJ0pBfrEbDXjz0WquB1ZOuxrFIa46qCmUYaJ/AJST8PMf5K0scaj7f78DCu7lMAAJLSnSFul3QdVNomgHG1rQ/R1SGQPmDNGq46zNdqBeNaCx+FF0LgYUkXhr6wFMKBqBiXA5EIsTgB7igWR5y3AYwrL09q80cg3WueF8A60Q/jciIUYRYjcLOkm0Lt7MNVDHPeijGuvDypzR+BdH6LZxSd6IdxORGKMIsQsBXyx0LNLDwtgrhMpRhXGa7U6oNA+kqyfZLsZw4HBDAuByIRYjEC6T7zbNdcDHP+ijGu/Eyp0Q+BuOspD1b70ezlSDEuZ4IRbjYCF0l6KNTGZWI2rHUqwrjqcKaV9gikj/lwmdiePltGhH
E5E4xwsxGIyyC4TMyGtF5FGFc91rTUFoEXJL1OEqvl29JlVDQY1yhMnLQwAuljPl+QdMfC+rf47mBci5eYDm5AIJ3fYjcIhymCcTkUjZDXJhDnt1gtvzbKeSrAuObhTqvzEUj3l78t7H46XzS0PIkAxjUJG4UcE0i3aWYZhFMhMS6nwhH2ZAIHJO0Opcn/yRjnLYhw8/Kn9foE4m6nhyTZ6nkOhwQwLoeiEfJkAryGbDK6tgpiXG3pQTRlCaS7nfJ8YlnWRWvHuIripfLGCLCNTWOCTA0H45pKjnIeCaTbNPP+RI8KclfFsWqEPpVAnJi38jsk2X5cHA4JMOJyKBohTyaQGhe5Pxnj/AURb34NiKAOgXTjQLayqcO8WCsYVzG0VNwYgXRHCNZwNSbOquFgXKsS43yvBOxlr98OwT8g6f1eO0Lc7DlPDvRD4LuSvhi6+zNJn+yn68vrKSOu5WlKjzYmkD6jaKMv25OLwykBjMupcIS9MoH4KjIryK4QK+NrqwDG1ZYeRFOGQDoxby2whqsM52q1YlzVUNPQjAR+JOma0P5zkk6eMRaazkAA48oAkSqaJ/CEpLNClM9KOrX5iAlwSwIYFwmydAJnS3p80MlzJB1deseX3D+Ma8nq0rdIwF6K8bbww58k7QSNbwIYl2/9iH4cAdtA0O4k2rFf0r3jinFWqwQwrlaVIS4IQGBTAhgXyQEBCLgjgHG5k4yAIQABjIscgAAE3BHAuNxJRsAQgADGRQ5AAALuCGBc7iQjYAhAAOMiByAAAXcEMC53khEwBCCAcZEDEICAOwIYlzvJCBgCEMC4yAEIQMAdAYzLnWQEDAEIYFzkAAQg4I4AxuVOMgKGAAQwLnIAAhBwRwDjcicZAUMAAhgXOQABCLgjgHG5k4yAIQABjIscgAAE3BHAuNxJRsAQgADGRQ5AAALuCGBc7iQjYAhAAOMiByAAAXcEMC53khEwBCCAcZEDEICAOwIYlzvJCBgCEMC4yAEIQMAdAYzLnWQEDAEIYFzkAAQg4I4AxuVOMgKGAAQwLnIAAhBwRwDjcicZAUMAAhgXOQABCLgjgHG5k4yAIQABjIscgAAE3BHAuNxJRsAQgADGRQ5AAALuCGBc7iQjYAhAAOMiByAAAXcEMC53khEwBCCAcZEDEICAOwIYlzvJCBgCEMC4yAEIQMAdAYzLnWQEDAEIYFzkAAQg4I4AxuVOMgKGAAQwLnIAAhBwR+C/doIhTZIi/uMAAAAASUVORK5CYII="
841
+ image_b64 = data_uri.split(",")[1]
842
+ binary = base64.b64decode(image_b64)
843
+ image = np.asarray(Image.open(io.BytesIO(binary)))
844
+ return image
845
+
846
def _update_predictions_metadata(self, item: entities.Item, predictions: entities.AnnotationCollection):
    """
    Stamp model information onto every annotation of an item's predictions (modified in place).

    For each annotation in ``predictions``:

    * segmentation annotations get a colour from the item's dataset ontology, falling back to the
      model's dataset ontology and finally to the annotation's own ``color``;
    * ``item_id`` is set to ``item.id``;
    * if ``metadata['user']['model']`` already exists, its ``model_id`` and ``name`` are overwritten
      with the model entity's id and name;
    * ``metadata['system']['model']`` is replaced with ``{'model_id', 'name', 'confidence'}``, where
      ``confidence`` is copied from ``metadata['user']['model']['confidence']`` (``None`` if absent).

    :param item: Entity.Item the predictions belong to
    :param predictions: item's AnnotationCollection; its annotations are modified in place
    :return: None
    """
    for prediction in predictions:
        if prediction.type == entities.AnnotationType.SEGMENTATION:
            color = None
            try:
                color = item.dataset._get_ontology().color_map.get(prediction.label, None)
            except (exceptions.BadRequest, exceptions.NotFound) as err:
                # Best effort: a missing/unreadable ontology must not block the predictions.
                logger.debug("Could not read label colors from the item's dataset ontology: %s", err)
            # NOTE(review): the private `_dataset` is checked, presumably so that a model without a
            # dataset does not trigger a fetch through the `dataset` property - confirm.
            if color is None and self.model_entity._dataset is not None:
                try:
                    # NOTE(review): because of the white default, the `prediction.color` fallback below
                    # is effectively reached only when the model has no dataset or this lookup raises.
                    color = self.model_entity.dataset._get_ontology().color_map.get(prediction.label,
                                                                                    (255, 255, 255))
                except (exceptions.BadRequest, exceptions.NotFound) as err:
                    logger.debug("Could not read label colors from the model's dataset ontology: %s", err)
            if color is None:
                logger.warning("Can't get annotation color from model's dataset, using default.")
                color = prediction.color
            prediction.color = color

        prediction.item_id = item.id
        if 'user' in prediction.metadata and 'model' in prediction.metadata['user']:
            prediction.metadata['user']['model']['model_id'] = self.model_entity.id
            prediction.metadata['user']['model']['name'] = self.model_entity.name
        confidence = prediction.metadata.get('user', dict()).get('model', dict()).get('confidence', None)
        # The whole 'model' entry is (re)written here, so only the 'system' dict needs to exist first.
        prediction.metadata.setdefault('system', dict())['model'] = {
            'model_id': self.model_entity.id,
            'name': self.model_entity.name,
            'confidence': confidence,
        }
889
+
890
+ ##############################
891
+ # Callback Factory functions #
892
+ ##############################
893
@property
def dataloop_keras_callback(self):
    """
    Return the constructor (class) of a Keras callback that dumps training history for the Dataloop platform.

    The callback is used by the dlp platform to show train losses. ``DumpHistoryCallback(dump_path)``
    records every value of the Keras ``logs`` dict at the end of each epoch, then rewrites a JSON
    "metric" file containing the full history so far. If ``dump_path`` is an existing directory, a
    timestamped ``__view__training-history__<timestamp>.json`` file inside it is used instead.

    :return: DumpHistoryCallback class (a ``keras.callbacks.Callback`` subclass)
    :raises RuntimeError: if ``keras`` cannot be imported
    """
    try:
        import keras
    except ImportError as err:  # ModuleNotFoundError is a subclass of ImportError
        raise RuntimeError(
            '{} depends on external package keras. Please install keras.'.format(self.__class__.__name__)
        ) from err

    import os
    import time
    import json

    class DumpHistoryCallback(keras.callbacks.Callback):
        """Keras callback that accumulates per-epoch ``logs`` values and dumps them as JSON."""

        def __init__(self, dump_path):
            super().__init__()
            if os.path.isdir(dump_path):
                # Explicit format instead of '%F-%X': '%X' is locale-dependent and contains ':',
                # which is not a valid filename character on Windows.
                timestamp = time.strftime('%Y-%m-%d-%H-%M-%S')
                dump_path = os.path.join(dump_path, '__view__training-history__{}.json'.format(timestamp))
            self.dump_file = dump_path
            # {metric_name: {'x': [epoch, ...], 'y': [value, ...]}}
            self.data = dict()

        def on_epoch_end(self, epoch, logs=None):
            logs = logs or {}
            for name, val in logs.items():
                series = self.data.setdefault(name, {'x': list(), 'y': list()})
                series['x'].append(float(epoch))
                series['y'].append(float(val))
            self.dump_history()

        def dump_history(self):
            """Rewrite ``self.dump_file`` with the full history collected so far."""
            _json = {
                "query": {},
                "datasetId": "",
                "xlabel": "epoch",
                "title": "training loss",
                "ylabel": "val",
                "type": "metric",
                "data": [{"name": name,
                          "x": values['x'],
                          "y": values['y']} for name, values in self.data.items()]
            }

            with open(self.dump_file, 'w') as f:
                json.dump(_json, f, indent=2)

    return DumpHistoryCallback