dtlpy 1.114.17__py3-none-any.whl → 1.116.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtlpy/__init__.py +491 -491
- dtlpy/__version__.py +1 -1
- dtlpy/assets/__init__.py +26 -26
- dtlpy/assets/code_server/config.yaml +2 -2
- dtlpy/assets/code_server/installation.sh +24 -24
- dtlpy/assets/code_server/launch.json +13 -13
- dtlpy/assets/code_server/settings.json +2 -2
- dtlpy/assets/main.py +53 -53
- dtlpy/assets/main_partial.py +18 -18
- dtlpy/assets/mock.json +11 -11
- dtlpy/assets/model_adapter.py +83 -83
- dtlpy/assets/package.json +61 -61
- dtlpy/assets/package_catalog.json +29 -29
- dtlpy/assets/package_gitignore +307 -307
- dtlpy/assets/service_runners/__init__.py +33 -33
- dtlpy/assets/service_runners/converter.py +96 -96
- dtlpy/assets/service_runners/multi_method.py +49 -49
- dtlpy/assets/service_runners/multi_method_annotation.py +54 -54
- dtlpy/assets/service_runners/multi_method_dataset.py +55 -55
- dtlpy/assets/service_runners/multi_method_item.py +52 -52
- dtlpy/assets/service_runners/multi_method_json.py +52 -52
- dtlpy/assets/service_runners/single_method.py +37 -37
- dtlpy/assets/service_runners/single_method_annotation.py +43 -43
- dtlpy/assets/service_runners/single_method_dataset.py +43 -43
- dtlpy/assets/service_runners/single_method_item.py +41 -41
- dtlpy/assets/service_runners/single_method_json.py +42 -42
- dtlpy/assets/service_runners/single_method_multi_input.py +45 -45
- dtlpy/assets/voc_annotation_template.xml +23 -23
- dtlpy/caches/base_cache.py +32 -32
- dtlpy/caches/cache.py +473 -473
- dtlpy/caches/dl_cache.py +201 -201
- dtlpy/caches/filesystem_cache.py +89 -89
- dtlpy/caches/redis_cache.py +84 -84
- dtlpy/dlp/__init__.py +20 -20
- dtlpy/dlp/cli_utilities.py +367 -367
- dtlpy/dlp/command_executor.py +764 -764
- dtlpy/dlp/dlp +1 -1
- dtlpy/dlp/dlp.bat +1 -1
- dtlpy/dlp/dlp.py +128 -128
- dtlpy/dlp/parser.py +651 -651
- dtlpy/entities/__init__.py +83 -83
- dtlpy/entities/analytic.py +347 -311
- dtlpy/entities/annotation.py +1879 -1879
- dtlpy/entities/annotation_collection.py +699 -699
- dtlpy/entities/annotation_definitions/__init__.py +20 -20
- dtlpy/entities/annotation_definitions/base_annotation_definition.py +100 -100
- dtlpy/entities/annotation_definitions/box.py +195 -195
- dtlpy/entities/annotation_definitions/classification.py +67 -67
- dtlpy/entities/annotation_definitions/comparison.py +72 -72
- dtlpy/entities/annotation_definitions/cube.py +204 -204
- dtlpy/entities/annotation_definitions/cube_3d.py +149 -149
- dtlpy/entities/annotation_definitions/description.py +32 -32
- dtlpy/entities/annotation_definitions/ellipse.py +124 -124
- dtlpy/entities/annotation_definitions/free_text.py +62 -62
- dtlpy/entities/annotation_definitions/gis.py +69 -69
- dtlpy/entities/annotation_definitions/note.py +139 -139
- dtlpy/entities/annotation_definitions/point.py +117 -117
- dtlpy/entities/annotation_definitions/polygon.py +182 -182
- dtlpy/entities/annotation_definitions/polyline.py +111 -111
- dtlpy/entities/annotation_definitions/pose.py +92 -92
- dtlpy/entities/annotation_definitions/ref_image.py +86 -86
- dtlpy/entities/annotation_definitions/segmentation.py +240 -240
- dtlpy/entities/annotation_definitions/subtitle.py +34 -34
- dtlpy/entities/annotation_definitions/text.py +85 -85
- dtlpy/entities/annotation_definitions/undefined_annotation.py +74 -74
- dtlpy/entities/app.py +220 -220
- dtlpy/entities/app_module.py +107 -107
- dtlpy/entities/artifact.py +174 -174
- dtlpy/entities/assignment.py +399 -399
- dtlpy/entities/base_entity.py +214 -214
- dtlpy/entities/bot.py +113 -113
- dtlpy/entities/codebase.py +292 -296
- dtlpy/entities/collection.py +38 -38
- dtlpy/entities/command.py +169 -169
- dtlpy/entities/compute.py +449 -442
- dtlpy/entities/dataset.py +1299 -1285
- dtlpy/entities/directory_tree.py +44 -44
- dtlpy/entities/dpk.py +470 -470
- dtlpy/entities/driver.py +235 -223
- dtlpy/entities/execution.py +397 -397
- dtlpy/entities/feature.py +124 -124
- dtlpy/entities/feature_set.py +145 -145
- dtlpy/entities/filters.py +798 -645
- dtlpy/entities/gis_item.py +107 -107
- dtlpy/entities/integration.py +184 -184
- dtlpy/entities/item.py +959 -953
- dtlpy/entities/label.py +123 -123
- dtlpy/entities/links.py +85 -85
- dtlpy/entities/message.py +175 -175
- dtlpy/entities/model.py +684 -684
- dtlpy/entities/node.py +1005 -1005
- dtlpy/entities/ontology.py +810 -803
- dtlpy/entities/organization.py +287 -287
- dtlpy/entities/package.py +657 -657
- dtlpy/entities/package_defaults.py +5 -5
- dtlpy/entities/package_function.py +185 -185
- dtlpy/entities/package_module.py +113 -113
- dtlpy/entities/package_slot.py +118 -118
- dtlpy/entities/paged_entities.py +299 -299
- dtlpy/entities/pipeline.py +624 -624
- dtlpy/entities/pipeline_execution.py +279 -279
- dtlpy/entities/project.py +394 -394
- dtlpy/entities/prompt_item.py +505 -499
- dtlpy/entities/recipe.py +301 -301
- dtlpy/entities/reflect_dict.py +102 -102
- dtlpy/entities/resource_execution.py +138 -138
- dtlpy/entities/service.py +963 -958
- dtlpy/entities/service_driver.py +117 -117
- dtlpy/entities/setting.py +294 -294
- dtlpy/entities/task.py +495 -495
- dtlpy/entities/time_series.py +143 -143
- dtlpy/entities/trigger.py +426 -426
- dtlpy/entities/user.py +118 -118
- dtlpy/entities/webhook.py +124 -124
- dtlpy/examples/__init__.py +19 -19
- dtlpy/examples/add_labels.py +135 -135
- dtlpy/examples/add_metadata_to_item.py +21 -21
- dtlpy/examples/annotate_items_using_model.py +65 -65
- dtlpy/examples/annotate_video_using_model_and_tracker.py +75 -75
- dtlpy/examples/annotations_convert_to_voc.py +9 -9
- dtlpy/examples/annotations_convert_to_yolo.py +9 -9
- dtlpy/examples/convert_annotation_types.py +51 -51
- dtlpy/examples/converter.py +143 -143
- dtlpy/examples/copy_annotations.py +22 -22
- dtlpy/examples/copy_folder.py +31 -31
- dtlpy/examples/create_annotations.py +51 -51
- dtlpy/examples/create_video_annotations.py +83 -83
- dtlpy/examples/delete_annotations.py +26 -26
- dtlpy/examples/filters.py +113 -113
- dtlpy/examples/move_item.py +23 -23
- dtlpy/examples/play_video_annotation.py +13 -13
- dtlpy/examples/show_item_and_mask.py +53 -53
- dtlpy/examples/triggers.py +49 -49
- dtlpy/examples/upload_batch_of_items.py +20 -20
- dtlpy/examples/upload_items_and_custom_format_annotations.py +55 -55
- dtlpy/examples/upload_items_with_modalities.py +43 -43
- dtlpy/examples/upload_segmentation_annotations_from_mask_image.py +44 -44
- dtlpy/examples/upload_yolo_format_annotations.py +70 -70
- dtlpy/exceptions.py +125 -125
- dtlpy/miscellaneous/__init__.py +20 -20
- dtlpy/miscellaneous/dict_differ.py +95 -95
- dtlpy/miscellaneous/git_utils.py +217 -217
- dtlpy/miscellaneous/json_utils.py +14 -14
- dtlpy/miscellaneous/list_print.py +105 -105
- dtlpy/miscellaneous/zipping.py +130 -130
- dtlpy/ml/__init__.py +20 -20
- dtlpy/ml/base_feature_extractor_adapter.py +27 -27
- dtlpy/ml/base_model_adapter.py +1257 -1086
- dtlpy/ml/metrics.py +461 -461
- dtlpy/ml/predictions_utils.py +274 -274
- dtlpy/ml/summary_writer.py +57 -57
- dtlpy/ml/train_utils.py +60 -60
- dtlpy/new_instance.py +252 -252
- dtlpy/repositories/__init__.py +56 -56
- dtlpy/repositories/analytics.py +85 -85
- dtlpy/repositories/annotations.py +916 -916
- dtlpy/repositories/apps.py +383 -383
- dtlpy/repositories/artifacts.py +452 -452
- dtlpy/repositories/assignments.py +599 -599
- dtlpy/repositories/bots.py +213 -213
- dtlpy/repositories/codebases.py +559 -559
- dtlpy/repositories/collections.py +332 -332
- dtlpy/repositories/commands.py +152 -158
- dtlpy/repositories/compositions.py +61 -61
- dtlpy/repositories/computes.py +439 -435
- dtlpy/repositories/datasets.py +1504 -1291
- dtlpy/repositories/downloader.py +976 -903
- dtlpy/repositories/dpks.py +433 -433
- dtlpy/repositories/drivers.py +482 -470
- dtlpy/repositories/executions.py +815 -817
- dtlpy/repositories/feature_sets.py +226 -226
- dtlpy/repositories/features.py +255 -238
- dtlpy/repositories/integrations.py +484 -484
- dtlpy/repositories/items.py +912 -909
- dtlpy/repositories/messages.py +94 -94
- dtlpy/repositories/models.py +1000 -988
- dtlpy/repositories/nodes.py +80 -80
- dtlpy/repositories/ontologies.py +511 -511
- dtlpy/repositories/organizations.py +525 -525
- dtlpy/repositories/packages.py +1941 -1941
- dtlpy/repositories/pipeline_executions.py +451 -451
- dtlpy/repositories/pipelines.py +640 -640
- dtlpy/repositories/projects.py +539 -539
- dtlpy/repositories/recipes.py +419 -399
- dtlpy/repositories/resource_executions.py +137 -137
- dtlpy/repositories/schema.py +120 -120
- dtlpy/repositories/service_drivers.py +213 -213
- dtlpy/repositories/services.py +1704 -1704
- dtlpy/repositories/settings.py +339 -339
- dtlpy/repositories/tasks.py +1477 -1477
- dtlpy/repositories/times_series.py +278 -278
- dtlpy/repositories/triggers.py +536 -536
- dtlpy/repositories/upload_element.py +257 -257
- dtlpy/repositories/uploader.py +661 -651
- dtlpy/repositories/webhooks.py +249 -249
- dtlpy/services/__init__.py +22 -22
- dtlpy/services/aihttp_retry.py +131 -131
- dtlpy/services/api_client.py +1785 -1782
- dtlpy/services/api_reference.py +40 -40
- dtlpy/services/async_utils.py +133 -133
- dtlpy/services/calls_counter.py +44 -44
- dtlpy/services/check_sdk.py +68 -68
- dtlpy/services/cookie.py +115 -115
- dtlpy/services/create_logger.py +156 -156
- dtlpy/services/events.py +84 -84
- dtlpy/services/logins.py +235 -235
- dtlpy/services/reporter.py +256 -256
- dtlpy/services/service_defaults.py +91 -91
- dtlpy/utilities/__init__.py +20 -20
- dtlpy/utilities/annotations/__init__.py +16 -16
- dtlpy/utilities/annotations/annotation_converters.py +269 -269
- dtlpy/utilities/base_package_runner.py +285 -264
- dtlpy/utilities/converter.py +1650 -1650
- dtlpy/utilities/dataset_generators/__init__.py +1 -1
- dtlpy/utilities/dataset_generators/dataset_generator.py +670 -670
- dtlpy/utilities/dataset_generators/dataset_generator_tensorflow.py +23 -23
- dtlpy/utilities/dataset_generators/dataset_generator_torch.py +21 -21
- dtlpy/utilities/local_development/__init__.py +1 -1
- dtlpy/utilities/local_development/local_session.py +179 -179
- dtlpy/utilities/reports/__init__.py +2 -2
- dtlpy/utilities/reports/figures.py +343 -343
- dtlpy/utilities/reports/report.py +71 -71
- dtlpy/utilities/videos/__init__.py +17 -17
- dtlpy/utilities/videos/video_player.py +598 -598
- dtlpy/utilities/videos/videos.py +470 -470
- {dtlpy-1.114.17.data → dtlpy-1.116.6.data}/scripts/dlp +1 -1
- dtlpy-1.116.6.data/scripts/dlp.bat +2 -0
- {dtlpy-1.114.17.data → dtlpy-1.116.6.data}/scripts/dlp.py +128 -128
- {dtlpy-1.114.17.dist-info → dtlpy-1.116.6.dist-info}/METADATA +186 -183
- dtlpy-1.116.6.dist-info/RECORD +239 -0
- {dtlpy-1.114.17.dist-info → dtlpy-1.116.6.dist-info}/WHEEL +1 -1
- {dtlpy-1.114.17.dist-info → dtlpy-1.116.6.dist-info}/licenses/LICENSE +200 -200
- tests/features/environment.py +551 -551
- dtlpy/assets/__pycache__/__init__.cpython-310.pyc +0 -0
- dtlpy-1.114.17.data/scripts/dlp.bat +0 -2
- dtlpy-1.114.17.dist-info/RECORD +0 -240
- {dtlpy-1.114.17.dist-info → dtlpy-1.116.6.dist-info}/entry_points.txt +0 -0
- {dtlpy-1.114.17.dist-info → dtlpy-1.116.6.dist-info}/top_level.txt +0 -0
dtlpy/ml/base_model_adapter.py
CHANGED
|
@@ -1,1086 +1,1257 @@
|
|
|
1
|
-
import dataclasses
|
|
2
|
-
import
|
|
3
|
-
import
|
|
4
|
-
import
|
|
5
|
-
import
|
|
6
|
-
import
|
|
7
|
-
import
|
|
8
|
-
import
|
|
9
|
-
import
|
|
10
|
-
import
|
|
11
|
-
import
|
|
12
|
-
import
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
import
|
|
16
|
-
|
|
17
|
-
import
|
|
18
|
-
from
|
|
19
|
-
from
|
|
20
|
-
|
|
21
|
-
from
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
self.
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
if
|
|
53
|
-
self.
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
def
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
def
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
self.
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
self.
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
self.
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
self.
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
self.
|
|
215
|
-
self.
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
self.
|
|
232
|
-
self.
|
|
233
|
-
|
|
234
|
-
###################################
|
|
235
|
-
# NEED TO IMPLEMENT THESE METHODS #
|
|
236
|
-
###################################
|
|
237
|
-
|
|
238
|
-
def load(self, local_path, **kwargs):
|
|
239
|
-
"""
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
"
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
"""
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
"""
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
"""
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
)
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
:param
|
|
398
|
-
""
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
if
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
self.model_entity
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
self.
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
:
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
self.
|
|
575
|
-
|
|
576
|
-
pool = ThreadPoolExecutor(max_workers=16)
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
batch_items = items[i_batch: i_batch + batch_size]
|
|
582
|
-
batch = list(pool.map(self.prepare_item_func, batch_items))
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
:
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
def
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
:param
|
|
740
|
-
:param
|
|
741
|
-
:param
|
|
742
|
-
|
|
743
|
-
:
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
:
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
#
|
|
883
|
-
|
|
884
|
-
logger.info(f"
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
#
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
"""
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
:param
|
|
960
|
-
:
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
|
|
974
|
-
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
'
|
|
1027
|
-
|
|
1028
|
-
'
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
|
-
|
|
1084
|
-
|
|
1085
|
-
|
|
1086
|
-
|
|
1
|
+
import dataclasses
|
|
2
|
+
import threading
|
|
3
|
+
import tempfile
|
|
4
|
+
import datetime
|
|
5
|
+
import logging
|
|
6
|
+
import string
|
|
7
|
+
import shutil
|
|
8
|
+
import random
|
|
9
|
+
import base64
|
|
10
|
+
import copy
|
|
11
|
+
import time
|
|
12
|
+
import tqdm
|
|
13
|
+
import traceback
|
|
14
|
+
import sys
|
|
15
|
+
import io
|
|
16
|
+
import os
|
|
17
|
+
from itertools import chain
|
|
18
|
+
from PIL import Image
|
|
19
|
+
from functools import partial
|
|
20
|
+
import numpy as np
|
|
21
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
22
|
+
import attr
|
|
23
|
+
from collections.abc import MutableMapping
|
|
24
|
+
from typing import Optional
|
|
25
|
+
from .. import entities, utilities, repositories, exceptions
|
|
26
|
+
from ..services import service_defaults
|
|
27
|
+
from ..services.api_client import ApiClient
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger('ModelAdapter')

# Constants
# NOTE(review): presumably the default cap on how many items a predict/embed
# run over a dataset subset processes -- confirm against the call sites.
PREDICT_EMBED_DEFAULT_SUBSET_LIMIT = 1000
# One day spelled out as 24 * 60 * 60 (presumably seconds -- confirm at call sites).
PREDICT_EMBED_DEFAULT_TIMEOUT = 24 * 60 * 60
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ModelConfigurations(MutableMapping):
    """
    Dict-like view over a model entity's ``configuration`` dict.

    Uses MutableMapping composition over a single backing dict instead of
    inheriting from ``dict``. Inheriting would create a second dict next to
    ``model_entity.configuration``, and the two could drift apart.

    When the adapter's ``model_entity.configuration`` is not None, the backing
    dict *is* that configuration object, so writes made here are visible on
    the entity immediately. Writes that change a value (``__setitem__``,
    ``update``, and ``get`` on a missing key) additionally call
    ``model_entity.update(reload_services=False)``. ``__delitem__`` and
    ``setdefault`` modify the dict without making that call.
    """

    def __init__(self, base_model_adapter):
        """
        :param base_model_adapter: adapter exposing ``model_entity``; may be None.
            Without an entity configuration, a private empty dict is used.
        """
        self._backing_dict = {}

        if (base_model_adapter is not None
                and base_model_adapter.model_entity is not None
                and base_model_adapter.model_entity.configuration is not None):
            # Share the entity's dict object rather than copying it.
            self._backing_dict = base_model_adapter.model_entity.configuration
        if 'include_background' not in self._backing_dict:
            self._backing_dict['include_background'] = False
        self._base_model_adapter = base_model_adapter
        # Don't call _update_model_entity during initialization to avoid premature updates

    def _update_model_entity(self):
        """Push the configuration to the platform through the model entity, if there is one."""
        if self._base_model_adapter is not None and self._base_model_adapter.model_entity is not None:
            self._base_model_adapter.model_entity.update(reload_services=False)

    def __ior__(self, other):
        """Support ``cfg |= other`` by delegating to :meth:`update`."""
        self.update(other)
        return self

    # Required MutableMapping abstract methods
    def __getitem__(self, key):
        return self._backing_dict[key]

    def __setitem__(self, key, value):
        # Note: This method only updates the backing dict, not object attributes.
        # If you need to also update object attributes, be careful to avoid
        # infinite recursion by not calling __setattr__ from here.
        changed = key not in self._backing_dict or self._backing_dict[key] != value
        self._backing_dict[key] = value
        # Only sync with the platform when the stored value actually changed.
        if changed:
            self._update_model_entity()

    def __delitem__(self, key):
        # Deletion touches only the local dict; no entity update is issued.
        del self._backing_dict[key]

    def __iter__(self):
        return iter(self._backing_dict)

    def __len__(self):
        return len(self._backing_dict)

    def get(self, key, default=None):
        """
        Return the value stored under ``key``.

        Unlike ``dict.get``, a missing key is first stored with ``default``
        (through ``__setitem__``, which syncs the entity), so the requested
        default becomes part of the configuration.
        """
        if key not in self._backing_dict:
            self.__setitem__(key, default)
        return self._backing_dict.get(key)

    def update(self, *args, **kwargs):
        """
        ``dict.update`` semantics, followed by one entity sync if anything changed.

        The arguments are materialised exactly once. Passing ``*args`` to
        ``dict.update`` after ``dict(*args)`` would find a generator/iterator
        argument already exhausted and silently drop its pairs.
        """
        update_dict = dict(*args, **kwargs)
        has_changes = any(key not in self._backing_dict or self._backing_dict[key] != value
                          for key, value in update_dict.items())
        self._backing_dict.update(update_dict)

        if has_changes:
            self._update_model_entity()

    def setdefault(self, key, default=None):
        """Store ``default`` under a missing ``key`` without syncing the entity; return the stored value."""
        if key not in self._backing_dict:
            self._backing_dict[key] = default
        return self._backing_dict[key]


@dataclasses.dataclass
class AdapterDefaults(ModelConfigurations):
    """
    ModelConfigurations with default fields mirrored as attributes.

    Each dataclass field below exists both as an attribute and as a key of the
    backing configuration dict. ``__setattr__``, ``update`` and ``resolve``
    keep the two copies in sync.
    """
    # for predict items, dataset, evaluate
    upload_annotations: bool = dataclasses.field(default=True)
    clean_annotations: bool = dataclasses.field(default=True)
    overwrite_annotations: bool = dataclasses.field(default=True)
    # for embeddings
    upload_features: bool = dataclasses.field(default=None)
    # for training
    root_path: str = dataclasses.field(default=None)
    data_path: str = dataclasses.field(default=None)
    output_path: str = dataclasses.field(default=None)

    def __init__(self, base_model_adapter=None):
        super().__init__(base_model_adapter)
        for f in dataclasses.fields(AdapterDefaults):
            # Read the backing dict directly. ``get`` would insert None for a
            # missing key (and sync it) only for the value to be overwritten below.
            value = self._backing_dict.get(f.name)
            if value is not None:
                # The configured value wins; the dict already holds it, so set only the attribute.
                super().__setattr__(f.name, value)
            else:
                # Missing or explicitly None: store the field default in the dict.
                # The attribute keeps the class-level default installed by @dataclass.
                super().__setitem__(f.name, f.default)

    def __setattr__(self, key, value):
        # Dataclass-like fields behave as attributes, so map to dict
        super().__setattr__(key, value)
        if not key.startswith("_"):
            super().__setitem__(key, value)

    def update(self, *args, **kwargs):
        """
        Update the configuration and refresh the mirrored attributes.

        Both keyword and positional (mapping / iterable of pairs) arguments
        refresh the attributes. The arguments are consumed only once.
        """
        update_dict = dict(*args, **kwargs)
        for f in dataclasses.fields(AdapterDefaults):
            if f.name in update_dict:
                super().__setattr__(f.name, update_dict[f.name])
        super().update(update_dict)

    def resolve(self, key, *args):
        """
        Return the first non-None value in ``args`` and store it under ``key``.

        If ``key`` is one of the fields, the attribute is updated as well so
        it matches the dict. When every arg is None (or none are given), the
        stored value is returned via :meth:`get`, which inserts None for a
        missing key.
        """
        for arg in args:
            if arg is not None:
                if key in {f.name for f in dataclasses.fields(AdapterDefaults)}:
                    super().__setattr__(key, arg)
                super().__setitem__(key, arg)
                return arg
        return self.get(key, None)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class BaseModelAdapter(utilities.BaseServiceRunner):
|
|
157
|
+
_client_api = attr.ib(type=ApiClient, repr=False)
|
|
158
|
+
_feature_set_lock = threading.Lock()
|
|
159
|
+
|
|
160
|
+
def __init__(self, model_entity: entities.Model = None):
    """Create the adapter and, when ``model_entity`` is given, load it immediately.

    :param model_entity: optional `dl.Model`; passed to `load_from_model` when provided
    """
    self.logger = logger
    # entity-related state, filled in by load_from_model / the model_entity setter
    self._model_entity = None
    self._configuration = None
    self.adapter_defaults = None
    self.model = None
    self.bucket_path = None
    self._project = None
    self._feature_set = None
    # default item loaders keyed by model input type (consumed by prepare_item_func)
    self.item_to_batch_mapping = dict(text=self._item_to_text, image=self._item_to_image)
    if model_entity is not None:
        self.load_from_model(model_entity=model_entity)
    logger.warning(
        "in case of a mismatch between 'model.name' and 'model_info.name' in the model adapter, model_info.name will be updated to align with 'model.name'."
    )
|
|
177
|
+
|
|
178
|
+
##################
|
|
179
|
+
# Configurations #
|
|
180
|
+
##################
|
|
181
|
+
|
|
182
|
+
@property
def configuration(self) -> dict:
    """The active model configuration.

    Returns the adapter's loaded configuration when a model entity is attached,
    otherwise a fresh empty dict.
    """
    if self._model_entity is None:
        return dict()
    return self._configuration
|
|
190
|
+
|
|
191
|
+
@configuration.setter
def configuration(self, configuration: dict):
    """Replace the model configuration and rebuild the adapter defaults from it.

    The assignment has no effect when no model entity is loaded.
    """
    assert isinstance(configuration, dict)
    entity = self._model_entity
    if entity is None:
        return
    # push the received dict to the model entity, then re-derive the defaults from it
    entity.configuration = configuration
    self.adapter_defaults = AdapterDefaults(self)
    self._configuration = self.adapter_defaults
|
|
199
|
+
|
|
200
|
+
############
|
|
201
|
+
# Entities #
|
|
202
|
+
############
|
|
203
|
+
@property
def project(self):
    """The `dl.Project` of the loaded model (fetched on first access, then cached)."""
    if self._project is None:
        self._project = self.model_entity.project
    cached_project = self._project
    assert isinstance(cached_project, entities.Project)
    return cached_project
|
|
209
|
+
|
|
210
|
+
@property
def feature_set(self):
    """The model's `dl.FeatureSet`, obtained via `_get_feature_set` on first access and cached."""
    if self._feature_set is None:
        self._feature_set = self._get_feature_set()
    cached_feature_set = self._feature_set
    assert isinstance(cached_feature_set, entities.FeatureSet)
    return cached_feature_set
|
|
216
|
+
|
|
217
|
+
@property
def model_entity(self):
    """The loaded `dl.Model`.

    :raises ValueError: when no model entity has been loaded or set
    """
    entity = self._model_entity
    if entity is None:
        raise ValueError("No model entity loaded. Please load a model (adapter.load_from_model(<dl.Model>)) or set: 'adapter.model_entity=<dl.Model>'")
    assert isinstance(entity, entities.Model)
    return entity
|
|
223
|
+
|
|
224
|
+
@model_entity.setter
def model_entity(self, model_entity):
    """Attach a `dl.Model` and rebuild the adapter defaults from its configuration.

    A warning is logged when a different model (by id) replaces the current one.
    """
    assert isinstance(model_entity, entities.Model)
    previous = self._model_entity
    is_replacement = isinstance(previous, entities.Model) and previous.id != model_entity.id
    if is_replacement:
        self.logger.warning('Replacing Model from {!r} to {!r}'.format(previous.name, model_entity.name))
    self._model_entity = model_entity
    self.adapter_defaults = AdapterDefaults(self)
    self._configuration = self.adapter_defaults
|
|
233
|
+
|
|
234
|
+
###################################
|
|
235
|
+
# NEED TO IMPLEMENT THESE METHODS #
|
|
236
|
+
###################################
|
|
237
|
+
|
|
238
|
+
def load(self, local_path, **kwargs):
    """
    Load the model from ``local_path`` and populate self.model with a `runnable` model.

    Abstract hook - subclasses must override it.
    `load_from_model` calls it after downloading the model artifacts into ``local_path``.

    :param local_path: `str` directory path in local FileSystem
    :raises NotImplementedError: always, in this base implementation
    """
    adapter_name = type(self).__name__
    raise NotImplementedError(f"Please implement `load` method in {adapter_name}")
|
|
249
|
+
|
|
250
|
+
def save(self, local_path, **kwargs):
    """
    Save the configuration and weights into ``local_path``.

    Abstract hook - subclasses must override it.
    `save_to_model` calls it to write the files locally before uploading them to the model entity.

    :param local_path: `str` directory path in local FileSystem
    :raises NotImplementedError: always, in this base implementation
    """
    adapter_name = type(self).__name__
    raise NotImplementedError(f"Please implement `save` method in {adapter_name}")
|
|
261
|
+
|
|
262
|
+
def train(self, data_path, output_path, **kwargs):
    """
    Train the model on the data in ``data_path`` and write the training outputs to ``output_path``.

    Abstract hook - subclasses must override it.
    The outputs include the weights and any other artifacts created during training.

    :param data_path: `str` local File System path to where the data was downloaded and converted at
    :param output_path: `str` local File System path where to dump training mid-results (checkpoints, logs...)
    :raises NotImplementedError: always, in this base implementation
    """
    adapter_name = type(self).__name__
    raise NotImplementedError(f"Please implement `train` method in {adapter_name}")
|
|
273
|
+
|
|
274
|
+
def predict(self, batch, **kwargs):
    """
    Run model inference (predictions) on a batch of items.

    Abstract hook - subclasses must override it.

    :param batch: list of outputs of the `prepare_item_func` func, one per item
    :return: `list[dl.AnnotationCollection]` one collection per image / item in the batch
    :raises NotImplementedError: always, in this base implementation
    """
    adapter_name = type(self).__name__
    raise NotImplementedError(f"Please implement `predict` method in {adapter_name}")
|
|
284
|
+
|
|
285
|
+
def embed(self, batch, **kwargs):
    """
    Extract model embeddings for a batch of items.

    Abstract hook - subclasses must override it.

    :param batch: list of outputs of the `prepare_item_func` func, one per item
    :return: `list[list]` a feature vector per each item in the batch
    :raises NotImplementedError: always, in this base implementation
    """
    adapter_name = type(self).__name__
    raise NotImplementedError(f"Please implement `embed` method in {adapter_name}")
|
|
295
|
+
|
|
296
|
+
def evaluate(self, model: entities.Model, dataset: entities.Dataset, filters: entities.Filters) -> entities.Model:
    """
    Score the model's predictions on a dataset against its GT annotations.

    Delegates to ``dtlpymetrics.scoring.create_model_score``, comparing annotations of
    ``model.output_type``. The evaluation process uploads the scores and metrics to the platform.

    :param model: The model to evaluate (annotation.metadata.system.model.name)
    :param dataset: Dataset where the model predicted and uploaded its annotations
    :param filters: Filters (or a custom-filter dict) to query items on the dataset; empty/None means all items
    :return: the model returned by ``create_model_score``
    """
    import dtlpymetrics

    compare_types = model.output_type
    if not filters:
        query = entities.Filters()
    elif isinstance(filters, dict):
        query = entities.Filters(custom_filter=filters)
    else:
        query = filters
    return dtlpymetrics.scoring.create_model_score(
        model=model,
        dataset=dataset,
        filters=query,
        compare_types=compare_types,
    )
|
|
320
|
+
|
|
321
|
+
def convert_from_dtlpy(self, data_path, **kwargs):
    """Convert data downloaded in the Dataloop structure into the model's own structure.

    Abstract hook - subclasses must override it.
    `prepare_data` calls it after downloading the subsets, e.g. to take the dlp dir
    structure and construct an annotation file.

    :param data_path: `str` local File System directory path where we already downloaded the data from dataloop platform
    :raises NotImplementedError: always, in this base implementation
    """
    adapter_name = type(self).__name__
    raise NotImplementedError(f"Please implement `convert_from_dtlpy` method in {adapter_name}")
|
|
332
|
+
|
|
333
|
+
#################
|
|
334
|
+
# DTLPY METHODS #
|
|
335
|
+
################
|
|
336
|
+
def prepare_item_func(self, item: entities.Item):
    """
    Turn a Dataloop item into the object placed in the batch passed to `predict` / `embed`.

    Override this to load items differently. The default picks a loader from the model's
    ``input_type``:

    * a list of input types - the item's mimetype decides between text and image loading;
    * a single input type - the loader is looked up in ``self.item_to_batch_mapping``;
    * anything else falls back to `_item_to_item`.

    :param item: dl.Item to load
    :return: preprocessed: the var with the loaded item information (e.g. ndarray for image, dict for json files etc)
    """
    input_type = self.model_entity.input_type
    if isinstance(input_type, list):
        if 'text' in input_type and 'text' in item.mimetype:
            return self._item_to_text(item)
        if 'image' in input_type and 'image' in item.mimetype:
            return self._item_to_image(item)
        return self._item_to_item(item)
    if input_type in self.item_to_batch_mapping:
        loader = self.item_to_batch_mapping[input_type]
        return loader(item)
    return self._item_to_item(item)
|
|
361
|
+
|
|
362
|
+
def __include_model_annotations(self, annotation_filters):
    """Exclude model-generated annotations from ``annotation_filters`` unless configured otherwise.

    When the model configuration does not set ``include_model_annotations`` (or sets it to
    False), a clause requiring ``metadata.system.model.name`` to NOT exist is added, so only
    annotations without a model name in their metadata are kept.

    :param annotation_filters: annotation-resource Filters object; modified in place
    :return: the same filters object
    """
    include_model_annotations = self.model_entity.configuration.get("include_model_annotations", False)
    if include_model_annotations is False:
        if annotation_filters.custom_filter is None:
            annotation_filters.add(field="metadata.system.model.name", values=False, operator=entities.FiltersOperations.EXISTS)
        else:
            # A user-supplied custom filter may lack the 'filter' / '$and' keys; create them
            # instead of failing with KeyError.
            and_clauses = annotation_filters.custom_filter.setdefault('filter', {}).setdefault('$and', [])
            and_clauses.append({'metadata.system.model.name': {'$exists': False}})
    return annotation_filters
|
|
370
|
+
|
|
371
|
+
def __download_background_images(self, filters, data_subset_base_path, annotation_options):
    """Download the subset's un-annotated ("background") items when enabled in the configuration.

    Only runs when the configuration sets ``include_background`` to True. The subset query is
    narrowed with ``{"annotated": False}`` and the matching items are downloaded from the
    model's dataset into ``data_subset_base_path``.

    :param filters: Filters object for the subset; its ``custom_filter`` is replaced by a narrowed copy
    :param data_subset_base_path: `str` local directory to download into
    :param annotation_options: annotation format to download with the items
    :return: whatever ``items.download`` returns, or an empty list when background download is disabled
    """
    import copy

    background_list = list()
    if self.configuration.get('include_background', False) is True:
        # Narrow a deep copy: the custom filter dict is typically the very dict stored in
        # model.metadata['system']['subsets'] (see prepare_data). Appending to it in place
        # leaked {"annotated": False} into the model entity, and into the platform on a later
        # model.update(), so subsequent runs only downloaded un-annotated items.
        custom_filter = copy.deepcopy(filters.custom_filter)
        custom_filter.setdefault("filter", {}).setdefault("$and", []).append({"annotated": False})
        filters.custom_filter = custom_filter
        background_list = self.model_entity.dataset.items.download(
            filters=filters,
            local_path=data_subset_base_path,
            annotation_options=annotation_options,
        )
    return background_list
|
|
381
|
+
|
|
382
|
+
def prepare_data(
    self,
    dataset: entities.Dataset,
    # paths
    root_path=None,
    data_path=None,
    output_path=None,
    #
    overwrite=False,
    **kwargs,
):
    """
    Prepares dataset locally before training or evaluation.
    Downloads every subset listed in the model's ``metadata.system.subsets`` into
    ``<data_path>/<subset name>`` and then calls `self.convert_from_dtlpy` on ``data_path``.

    :param dataset: dl.Dataset to download the subset items from
    :param root_path: `str` root directory for training. default is
        <DATALOOP_PATH>/model_data/<model id>_<model name>/<timestamp>. Can be set using self.adapter_defaults.root_path
    :param data_path: `str` dataset directory. default <root_path>/datasets/<model dataset id>. Can be set using self.adapter_defaults.data_path
    :param output_path: `str` save everything to this folder. default <root_path>/output. Can be set using self.adapter_defaults.output_path
    :param bool overwrite: overwrite the data path (download again). default is False, which skips subsets whose directory already exists
    :param kwargs: forwarded to `convert_from_dtlpy`
    :return: tuple of (root_path, data_path, output_path)
    :raises ValueError: if the model has no ``metadata.system.subsets`` or a subset yields no items
    """
    # define paths
    dataloop_path = service_defaults.DATALOOP_PATH
    root_path = self.adapter_defaults.resolve("root_path", root_path)
    data_path = self.adapter_defaults.resolve("data_path", data_path)
    output_path = self.adapter_defaults.resolve("output_path", output_path)
    if root_path is None:
        now = datetime.datetime.now()
        root_path = os.path.join(
            dataloop_path,
            'model_data',
            "{s_id}_{s_n}".format(s_id=self.model_entity.id, s_n=self.model_entity.name),
            now.strftime('%Y-%m-%d-%H%M%S'),
        )
    if data_path is None:
        data_path = os.path.join(root_path, 'datasets', self.model_entity.dataset.id)
    if output_path is None:
        output_path = os.path.join(root_path, 'output')
    # Create the directories also when the caller supplied the paths: a new user-given
    # data_path would otherwise make the listdir below fail with FileNotFoundError.
    os.makedirs(data_path, exist_ok=True)
    os.makedirs(output_path, exist_ok=True)

    if len(os.listdir(data_path)) > 0:
        self.logger.warning("Data path directory ({}) is not empty..".format(data_path))

    annotation_options = entities.ViewAnnotationOptions.JSON
    if self.model_entity.output_type in [entities.AnnotationType.SEGMENTATION]:
        annotation_options = entities.ViewAnnotationOptions.INSTANCE

    # Download the subset items
    subsets = self.model_entity.metadata.get("system", {}).get("subsets", None)
    annotations_subsets = self.model_entity.metadata.get("system", {}).get("annotationsSubsets", {})
    if subsets is None:
        raise ValueError("Model (id: {}) must have subsets in metadata.system.subsets".format(self.model_entity.id))
    for subset, filters_dict in subsets.items():
        data_subset_base_path = os.path.join(data_path, subset)
        if os.path.isdir(data_subset_base_path) and not overwrite:
            # existing and dont overwrite
            self.logger.debug("Subset {!r} already exists (and overwrite=False). Skipping.".format(subset))
            continue

        filters = entities.Filters(custom_filter=filters_dict)
        self.logger.debug("Downloading subset {!r} of {}".format(subset, self.model_entity.dataset.name))

        annotation_filters = None
        if subset in annotations_subsets:
            annotation_filters = entities.Filters(
                use_defaults=False,
                resource=entities.FiltersResource.ANNOTATION,
                custom_filter=annotations_subsets[subset],
            )
        # if user provided annotation_filters, skip the default filters
        elif self.model_entity.output_type is not None and self.model_entity.output_type != "embedding":
            annotation_filters = entities.Filters(resource=entities.FiltersResource.ANNOTATION, use_defaults=False)
            if self.model_entity.output_type in [
                entities.AnnotationType.SEGMENTATION,
                entities.AnnotationType.POLYGON,
            ]:
                model_output_types = [entities.AnnotationType.SEGMENTATION, entities.AnnotationType.POLYGON]
            else:
                model_output_types = [self.model_entity.output_type]

            annotation_filters.add(
                field=entities.FiltersKnownFields.TYPE,
                values=model_output_types,
                operator=entities.FiltersOperations.IN,
            )

            # NOTE(review): placed inside this branch because running it with
            # annotation_filters=None (e.g. embedding models) would raise - confirm against upstream.
            annotation_filters = self.__include_model_annotations(annotation_filters)
            annotations_subsets[subset] = annotation_filters.prepare()

        ret_list = dataset.items.download(
            filters=filters,
            local_path=data_subset_base_path,
            annotation_options=annotation_options,
            annotation_filters=annotation_filters,
        )
        filters = entities.Filters(custom_filter=subsets[subset])
        background_ret_list = self.__download_background_images(
            filters=filters,
            data_subset_base_path=data_subset_base_path,
            annotation_options=annotation_options,
        )
        # Materialize both results (they may be lazy iterables) so they can be counted and combined
        ret_list = list(ret_list)
        background_ret_list = list(background_ret_list)
        self.logger.debug(f"Subset '{subset}': ret_list length: {len(ret_list)}, background_ret_list length: {len(background_ret_list)}")
        ret_list = ret_list + background_ret_list
        if len(ret_list) == 0:
            if annotation_filters is not None:
                annotation_filters_str = annotation_filters.prepare()
            else:
                annotation_filters_str = None
            raise ValueError(
                f"No items downloaded for subset {subset}! Cannot train model with empty subset.\n"
                f"Subset {subset} filters: {filters.prepare()}\n"
                f"Annotation filters: {annotation_filters_str}"
            )

    self.convert_from_dtlpy(data_path=data_path, **kwargs)
    return root_path, data_path, output_path
|
|
503
|
+
|
|
504
|
+
def load_from_model(self, model_entity=None, local_path=None, overwrite=True, **kwargs):
    """Loads a model from given `dl.Model`.
    Reads configurations and instantiate self.model_entity.
    Downloads the model_entity artifacts and calls `self.load` on the download directory.

    :param model_entity: `dl.Model` entity to load. When None, the currently attached model entity is reloaded
    :param local_path: `str` directory path in local FileSystem to download the model_entity artifacts to.
        default <DATALOOP_PATH>/models/<model name>
    :param overwrite: `bool` (default True) passed to the artifacts download: if False does not download files
        with same name else (True) download all
    :param kwargs: forwarded to `self.load`
    """
    if model_entity is not None:
        self.model_entity = model_entity
    if local_path is None:
        local_path = os.path.join(service_defaults.DATALOOP_PATH, "models", self.model_entity.name)
    # Load configuration and adapter defaults (also needed when reloading the already-attached model)
    self.adapter_defaults = AdapterDefaults(self)
    # Point _configuration to the same object since AdapterDefaults inherits from ModelConfigurations
    self._configuration = self.adapter_defaults
    # Download
    self.model_entity.artifacts.download(local_path=local_path, overwrite=overwrite)
    self.load(local_path, **kwargs)
|
|
524
|
+
|
|
525
|
+
def save_to_model(self, local_path=None, cleanup=False, replace=True, **kwargs):
    """
    Saves the model state (configuration and weights) and uploads it to the model's artifacts.

    Calls `self.save` into a local directory, then uploads everything in that directory to
    ``self.model_entity.artifacts`` with overwrite=True.

    :param local_path: `str` directory path in local FileSystem to save the current model (weights) to (default will create a temp dir)
    :param cleanup: `bool` if True remove the local directory after upload (default False)
    :param replace: `bool` kept for backward compatibility; not used by this implementation
        (artifacts are always uploaded with overwrite=True)
    :param kwargs: forwarded to `self.save`
    :raises ValueError: if no model entity is set on the adapter
    """
    # Check up front, before running the possibly expensive `save`. The `model_entity`
    # property raises its own error when unset, so the old post-save `is None` check was dead code.
    if self._model_entity is None:
        raise ValueError('Missing model entity on the adapter. Please set before saving: "adapter.model_entity=model"')

    if local_path is None:
        local_path = tempfile.mkdtemp(prefix="model_{}".format(self.model_entity.name))
        self.logger.debug("Using temporary dir at {}".format(local_path))

    self.save(local_path=local_path, **kwargs)

    self.model_entity.artifacts.upload(filepath=os.path.join(local_path, '*'), overwrite=True)
    if cleanup:
        self.logger.info("Clean-up. deleting {}".format(local_path))
        shutil.rmtree(path=local_path, ignore_errors=True)
|
|
552
|
+
|
|
553
|
+
# ===============
|
|
554
|
+
# SERVICE METHODS
|
|
555
|
+
# ===============
|
|
556
|
+
|
|
557
|
+
@entities.Package.decorators.function(
    display_name='Predict Items',
    inputs={'items': 'Item[]'},
    outputs={'items': 'Item[]', 'annotations': 'Annotation[]'},
)
def predict_items(self, items: list, batch_size=None, **kwargs):
    """
    Run the predict function on the input list of items and return the items and the predictions.
    Each prediction is by the model output type (package.output_type) and model_info in the metadata.

    Items are processed in batches. Each batch is:
      1. prepared with `prepare_item_func`;
      2. passed to `predict`;
      3. stamped with `_update_predictions_metadata`;
      4. uploaded with `_upload_model_annotations`.
    A failing batch is logged and skipped. After all batches, an exception listing the failed item ids is raised.

    :param items: `List[dl.Item]` list of items to predict
    :param batch_size: `int` size of batch to run a single inference (default: configuration 'batch_size', or 4)
    :param kwargs: forwarded to `predict`
    :return: `List[dl.Item]`, `List[dl.Annotation]` - the input items and a flat list of all annotations
    :raises Exception: if any batch failed to predict or to upload its annotations
    """
    if batch_size is None:
        batch_size = self.configuration.get('batch_size', 4)
    input_type = self.model_entity.input_type
    self.logger.debug("Predicting {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))
    error_counter = 0
    fail_ids = list()
    annotations = list()
    # The context manager shuts the pool down even when preparing a batch raises out of the loop
    with ThreadPoolExecutor(max_workers=16) as pool:
        for i_batch in tqdm.tqdm(range(0, len(items), batch_size), desc='predicting', unit='bt', leave=None, file=sys.stdout):
            batch_items = items[i_batch : i_batch + batch_size]
            batch = list(pool.map(self.prepare_item_func, batch_items))
            try:
                # Materialize: the predictions are iterated twice below (metadata update, then upload),
                # which silently yielded nothing the second time if `predict` returned a generator.
                batch_collections = list(self.predict(batch, **kwargs))
            except Exception as e:
                item_ids = [item.id for item in batch_items]
                self.logger.error(f"Failed to predict batch {i_batch} for items {item_ids}. Error: {e}\n{traceback.format_exc()}")
                error_counter += 1
                fail_ids.extend(item_ids)
                continue
            # Consuming the map waits for every metadata update (and re-raises the first failure)
            list(pool.map(self._update_predictions_metadata, batch_items, batch_collections))
            self.logger.debug("Uploading items' annotation for model {!r}.".format(self.model_entity.name))
            try:
                batch_collections = list(pool.map(self._upload_model_annotations, batch_items, batch_collections))
            except Exception as err:
                item_ids = [item.id for item in batch_items]
                self.logger.error(
                    f"Failed to upload annotations for items {item_ids}. Error: {err}\n{traceback.format_exc()}"
                )
                error_counter += 1
                fail_ids.extend(item_ids)

            for collection in batch_collections:
                # flatten each annotation collection into the returned list of dl.Annotation
                if isinstance(collection, entities.AnnotationCollection):
                    annotations.extend(collection.annotations)
                else:
                    self.logger.warning(f'RETURN TYPE MAY BE INVALID: {type(collection)}')
                    annotations.extend(collection)
            # TODO call the callback

    if error_counter > 0:
        raise Exception(f"Failed to predict all items. Failed IDs: {fail_ids}, See logs for more details")
    return items, annotations
|
|
621
|
+
|
|
622
|
+
@entities.Package.decorators.function(
    display_name='Embed Items',
    inputs={'items': 'Item[]'},
    outputs={'items': 'Item[]', 'features': 'Json[]'},
)
def embed_items(self, items: list, upload_features=None, batch_size=None, progress: utilities.Progress = None, **kwargs):
    """
    Extract features from an input list of items and return the items and the feature vectors.

    :param items: `List[dl.Item]` list of items to embed
    :param upload_features: `bool` uploads the features on the given items. Resolved against
        ``self.adapter_defaults.upload_features``; when still None, features are uploaded but prompt items are skipped
    :param batch_size: `int` size of batch to run a single embed (default: configuration 'batch_size', or 4)
    :param progress: currently unused
    :param kwargs: forwarded to `embed`
    :return: `List[dl.Item]`, `List[vector]` - embedded items and their vectors, in matching order. When uploading,
        items whose vector is None or not of length configuration 'embeddings_size' (default 256) are dropped.
    :raises Exception: if any batch failed to embed or the feature upload failed
    """
    if batch_size is None:
        batch_size = self.configuration.get('batch_size', 4)
    upload_features = self.adapter_defaults.resolve("upload_features", upload_features)
    # Only when neither the caller nor the defaults decided are prompt items excluded from the upload
    skip_default_items = upload_features is None
    if upload_features is None:
        upload_features = True
    input_type = self.model_entity.input_type
    self.logger.debug("Embedding {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))
    error_counter = 0
    fail_ids = list()

    feature_set = self.feature_set

    vectors = list()
    _items = list()
    # The context manager shuts the pool down even when preparing a batch raises out of the loop
    with ThreadPoolExecutor(max_workers=16) as pool:
        for i_batch in tqdm.tqdm(
            range(0, len(items), batch_size),
            desc='embedding',
            unit='bt',
            leave=None,
            file=sys.stdout,
        ):
            batch_items = items[i_batch : i_batch + batch_size]
            batch = list(pool.map(self.prepare_item_func, batch_items))
            try:
                batch_vectors = self.embed(batch, **kwargs)
            except Exception as err:
                item_ids = [item.id for item in batch_items]
                self.logger.error(f"Failed to embed batch {i_batch} for items {item_ids}. Error: {err}\n{traceback.format_exc()}")
                error_counter += 1
                fail_ids.extend(item_ids)
                continue
            vectors.extend(batch_vectors)
            # Save the items in the order of the vectors
            _items.extend(batch_items)

    if upload_features is True:
        embeddings_size = self.configuration.get('embeddings_size', 256)
        valid_items = []
        valid_vectors = []
        items_to_upload = []
        vectors_to_upload = []

        for item, vector in zip(_items, vectors):
            # Check if vector is valid
            if vector is None or len(vector) != embeddings_size:
                self.logger.warning(f"Vector generated for item {item.id} is None or has wrong size. Skipping...")
                continue

            # Item and vector are valid
            valid_items.append(item)
            valid_vectors.append(vector)

            # Check if item should be skipped (prompt items); system metadata / shebang may be missing or None
            _system_metadata = getattr(item, 'system', None) or dict()
            _shebang = _system_metadata.get('shebang', None) or dict()
            is_prompt = _shebang.get('dltype', '') == 'prompt'
            if skip_default_items and is_prompt:
                self.logger.debug(f"Skipping feature upload for prompt item {item.id}")
                continue

            # Items were not skipped - should be uploaded
            items_to_upload.append(item)
            vectors_to_upload.append(vector)

        # Update the original lists with valid items only
        _items[:] = valid_items
        vectors[:] = valid_vectors

        if items_to_upload:
            self.logger.debug(f"Uploading {len(items_to_upload)} items' feature vectors for model {self.model_entity.name}.")
            try:
                start_time = time.time()
                feature_set.features.create(entity=items_to_upload, value=vectors_to_upload, feature_set_id=feature_set.id, project_id=self.model_entity.project_id)
                self.logger.debug(f"Uploaded {len(items_to_upload)} items' feature vectors for model {self.model_entity.name} in {time.time() - start_time} seconds.")
            except Exception as err:
                self.logger.error(f"Failed to upload feature vectors. Error: {err}\n{traceback.format_exc()}")
                error_counter += 1
        else:
            # Avoid calling the features API with empty lists
            self.logger.debug(f"No feature vectors to upload for model {self.model_entity.name}.")
    if error_counter > 0:
        raise Exception(f"Failed to embed all items. Failed IDs: {fail_ids}, See logs for more details")
    return _items, vectors
|
|
722
|
+
|
|
723
|
+
@entities.Package.decorators.function(
    display_name='Embed Dataset with DQL',
    inputs={'dataset': 'Dataset', 'filters': 'Json'},
)
def embed_dataset(
    self,
    dataset: entities.Dataset,
    filters: Optional[entities.Filters] = None,
    upload_features: Optional[bool] = None,
    batch_size: Optional[int] = None,
    progress: Optional[utilities.Progress] = None,
    **kwargs,
):
    """
    Run model embedding on all items in a dataset

    :param dataset: Dataset entity to embed
    :param filters: Filters entity for filtering before embedding
    :param upload_features: bool whether to upload features back to platform. `_execute_dataset_operation`
        takes no such argument, so when given it is set on ``self.adapter_defaults.upload_features`` for the
        duration of the call and restored afterwards (`embed_items` resolves the flag from there)
    :param batch_size: int size of batch to run a single embedding
    :param progress: dl.Progress object to track progress
    :param kwargs: currently unused
    :return: the result of `_execute_dataset_operation` (previously documented as a bool indicating success)
    """
    # Previously `upload_features` was silently ignored.
    # NOTE(review): assumes `_execute_dataset_operation` embeds through `embed_items` - confirm.
    # The temporary override is adapter-wide, so concurrent calls on one adapter may observe it.
    defaults = self.adapter_defaults
    override = upload_features is not None and defaults is not None
    previous = None
    if override:
        previous = defaults.get("upload_features")
        defaults.upload_features = upload_features
    try:
        return self._execute_dataset_operation(
            dataset=dataset,
            operation_type='embed',
            filters=filters,
            progress=progress,
            batch_size=batch_size,
        )
    finally:
        if override:
            defaults.upload_features = previous
|
|
754
|
+
|
|
755
|
+
@entities.Package.decorators.function(
    display_name='Predict Dataset with DQL',
    inputs={'dataset': 'Dataset', 'filters': 'Json'},
)
def predict_dataset(
    self,
    dataset: entities.Dataset,
    filters: Optional[entities.Filters] = None,
    batch_size: Optional[int] = None,
    progress: Optional[utilities.Progress] = None,
    **kwargs,
):
    """
    Run model prediction on all items in a dataset

    :param dataset: Dataset entity to predict
    :param filters: Filters entity for filtering before prediction
    :param batch_size: int size of batch to run a single prediction
    :param progress: dl.Progress object to track progress
    :param kwargs: currently unused
    :return: the result of `_execute_dataset_operation` (previously documented as a bool indicating success)
    """
    # Return the operation's result instead of dropping it (the method used to always return None)
    return self._execute_dataset_operation(
        dataset=dataset,
        operation_type='predict',
        filters=filters,
        progress=progress,
        batch_size=batch_size,
    )
|
|
783
|
+
|
|
784
|
+
@entities.Package.decorators.function(
    display_name='Train a Model',
    inputs={'model': entities.Model},
    outputs={'model': entities.Model},
)
def train_model(self, model: entities.Model, cleanup=False, progress: utilities.Progress = None, context: utilities.Context = None):
    """
    Train on existing model.
    data will be taken from dl.Model.datasetId
    configuration is as defined in dl.Model.configuration
    upload the output to the model's artifacts (via `save_to_model`)

    :param model: dl.Model to train (a dict with an 'id' key is also accepted and fetched)
    :param cleanup: `bool` remove the local output directory after a successful upload
    :param progress: optional dl.Progress updated when training starts, after each epoch and before saving
    :param context: optional dl.Context; when given, ensures model.metadata['system'] exists
    :return: the trained dl.Model
    :raises ValueError: if the model is in 'failed' state after waiting for it to be ready
    """
    if isinstance(model, dict):
        model = repositories.Models(client_api=self._client_api).get(model_id=model['id'])
    output_path = None
    try:
        logger.info("Received {s} for training".format(s=model.id))
        model = model.wait_for_model_ready()
        if model.status == 'failed':
            raise ValueError("Model is in failed state, cannot train.")

        ##############
        # Set status #
        ##############
        model.status = 'training'
        if context is not None:
            if 'system' not in model.metadata:
                model.metadata['system'] = dict()
        model.update(reload_services=False)

        ##########################
        # load model and weights #
        ##########################
        logger.info("Loading Adapter with: {n} ({i!r})".format(n=model.name, i=model.id))
        self.load_from_model(model_entity=model)

        ################
        # prepare data #
        ################
        root_path, data_path, output_path = self.prepare_data(dataset=self.model_entity.dataset, root_path=os.path.join('tmp', model.id))
        # Start the Train
        logger.info(f"Training model {model.name!r} ({model.id!r}) on data {data_path!r}")
        if progress is not None:
            progress.update(message='starting training')

        def on_epoch_end_callback(i_epoch, n_epoch):
            # i_epoch is 0-based (the percentage already uses i_epoch + 1), so report it 1-based
            # in the message too; previously the first epoch was reported as "0/N".
            if progress is not None:
                progress.update(progress=int(100 * (i_epoch + 1) / n_epoch), message='finished epoch: {}/{}'.format(i_epoch + 1, n_epoch))

        self.train(data_path=data_path, output_path=output_path, on_epoch_end_callback=on_epoch_end_callback)
        if progress is not None:
            progress.update(message='saving model', progress=99)

        self.save_to_model(local_path=output_path, replace=True)
        model.status = 'trained'
        model.update(reload_services=False)
        ###########
        # cleanup #
        ###########
        if cleanup:
            shutil.rmtree(output_path, ignore_errors=True)
    except Exception:
        # save also on fail - best effort: an error while saving must not replace the training error
        if output_path is not None:
            try:
                self.save_to_model(local_path=output_path, replace=True)
            except Exception:
                logger.exception('Failed to save model artifacts after a training failure')
        # NOTE(review): this block does not itself change model.status; the message below is kept as-is
        logger.info('Execution failed. Setting model.status to failed')
        raise
    return model
|
|
852
|
+
|
|
853
|
+
@entities.Package.decorators.function(
    display_name='Evaluate a Model',
    inputs={'model': entities.Model, 'dataset': entities.Dataset, 'filters': 'Json'},
    outputs={'model': entities.Model, 'dataset': entities.Dataset},
)
def evaluate_model(
    self,
    model: entities.Model,
    dataset: entities.Dataset,
    filters: entities.Filters = None,
    #
    progress: utilities.Progress = None,
    context: utilities.Context = None,
):
    """
    Evaluate a model.

    Data is downloaded from the dataset and query, the configuration is the one defined
    in dl.Model.configuration; predictions are uploaded as annotations and metrics are
    calculated against the ground truth.

    :param model: Model entity to run prediction with
    :param dataset: Dataset to evaluate on
    :param filters: Filters entity (or a raw filter dict, as delivered for the 'Json' FaaS input)
        selecting specific items from the dataset; all items when empty
    :param progress: dl.Progress for reporting FaaS progress
    :param context: not used by this method
    :return: tuple of (model returned by ``self.evaluate``, dataset)
    """
    logger.info(f"Received model: {model.id} for evaluation on dataset (name: {dataset.name}, id: {dataset.id})")
    ##########################
    # load model and weights #
    ##########################
    logger.info(f"Loading Adapter with: {model.name} ({model.id!r})")
    self.load_from_model(dataset=dataset, model_entity=model)

    ##############
    # Predicting #
    ##############
    logger.info(f"Calling prediction, dataset: {dataset.name!r} ({dataset.id!r}), filters: {filters}")
    if not filters:
        filters = entities.Filters()
    elif isinstance(filters, dict):
        # Normalize a raw dict the same way _execute_dataset_operation does, so that prediction
        # and self.evaluate() both receive a Filters entity.
        filters = entities.Filters(custom_filter=filters)
    if self.adapter_defaults.get("overwrite_annotations", True) is True:
        self._execute_dataset_operation(
            dataset=dataset,
            operation_type='predict',
            filters=filters,
            multiple_executions=False,
        )

    ##############
    # Evaluating #
    ##############
    logger.info("Starting adapter.evaluate()")
    if progress is not None:
        progress.update(message='calculating metrics', progress=98)
    model = self.evaluate(model=model, dataset=dataset, filters=filters)
    #########
    # Done! #
    #########
    if progress is not None:
        progress.update(message='finishing evaluation', progress=99)
    return model, dataset
|
|
914
|
+
|
|
915
|
+
# =============
|
|
916
|
+
# INNER METHODS
|
|
917
|
+
# =============
|
|
918
|
+
def _get_feature_set(self):
    """
    Return the model's feature set, creating one when the model has none.

    Everything runs under the class-level ``_feature_set_lock``. A newly created feature set
    is named after the model; if the project already has a feature set with that name, a
    random 5-character alphanumeric suffix is appended instead.

    :return: the existing or newly created feature set entity
    """
    with self.__class__._feature_set_lock:
        existing = self.model_entity.feature_set
        if existing is not None:
            logger.info(f'Feature Set found! name: {existing.name}, id: {existing.id}')
            return existing

        logger.info('Feature Set not found. creating... ')
        try:
            # Probe for a name clash with an existing project feature set
            self.project.feature_sets.get(feature_set_name=self.model_entity.name)
        except exceptions.NotFound:
            new_name = self.model_entity.name
        else:
            suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=5))
            new_name = f"{self.model_entity.name}-{suffix}"
            logger.warning(
                f"Feature set with the model name already exists. Creating new feature set with name {new_name}"
            )

        created = self.project.feature_sets.create(
            name=new_name,
            entity_type=entities.FeatureEntityType.ITEM,
            model_id=self.model_entity.id,
            project_id=self.project.id,
            set_type=self.model_entity.name,
            size=self.configuration.get('embeddings_size', 256),
        )
        logger.info(f'Feature Set created! name: {created.name}, id: {created.id}')
        return created
|
|
946
|
+
|
|
947
|
+
def _execute_dataset_operation(
    self,
    dataset: entities.Dataset,
    operation_type: str,
    filters: Optional[entities.Filters] = None,
    progress: Optional[utilities.Progress] = None,
    batch_size: Optional[int] = None,
    multiple_executions: bool = True,
) -> bool:
    """
    Execute dataset operation (predict/embed) with batching and filtering support.

    The filtered items are split with ``entities.Filters._get_split_filters`` using the
    configured ``predict_embed_subset_limit``. A single subset (or ``multiple_executions=False``)
    is processed locally. Otherwise one remote execution is created per subset, and the executions
    are polled every 5 seconds until they leave the in-progress state or ``predict_embed_timeout``
    expires. A timeout is logged as a warning.

    :param dataset: Dataset entity to run operation on
    :param operation_type: Type of operation to execute ('predict' or 'embed')
    :param filters: Filters entity (or raw filter dict) to filter items, default None (all items)
    :param progress: Progress object for tracking progress, default None
    :param batch_size: Size of batches to process items, default None (uses model config)
    :param multiple_executions: Whether to use multiple executions when filters exceed subset limit, default True
    :return: True if operation completes successfully
    :raises ValueError: If operation_type is not 'predict' or 'embed', or if any remote execution failed
    """
    self.logger.debug(f"Running {operation_type} for dataset (name:{dataset.name}, id:{dataset.id})")
    # Validate first, so an unsupported operation fails before any items are listed or executions created
    if operation_type not in ('predict', 'embed'):
        raise ValueError(f"Unsupported operation type: {operation_type}")

    if not filters:
        self.logger.debug("No filters provided, using default filters")
        filters = entities.Filters()
    if isinstance(filters, dict):
        self.logger.debug(f"Received custom filters {filters}")
        filters = entities.Filters(custom_filter=filters)

    if operation_type == 'embed':
        # NOTE(review): presumably the `feature_set` property resolves/creates the model's feature set — confirm
        feature_set = self.feature_set
        logger.info(f"Feature set found! name: {feature_set.name}, id: {feature_set.id}")

    predict_embed_subset_limit = self.configuration.get('predict_embed_subset_limit', PREDICT_EMBED_DEFAULT_SUBSET_LIMIT)
    predict_embed_timeout = self.configuration.get('predict_embed_timeout', PREDICT_EMBED_DEFAULT_TIMEOUT)
    self.logger.debug(f"Inputs: predict_embed_subset_limit: {predict_embed_subset_limit}, predict_embed_timeout: {predict_embed_timeout}")
    # pageSize 0 asks only for the item count
    tmp_filters = copy.deepcopy(filters.prepare())
    tmp_filters['pageSize'] = 0
    num_items = dataset.items.list(filters=entities.Filters(custom_filter=tmp_filters)).items_count
    self.logger.debug(f"Number of items for current filters: {num_items}")

    # One-item lookahead on the generator: a single subset runs locally, more may fan out to executions
    gen = entities.Filters._get_split_filters(dataset, filters, predict_embed_subset_limit)
    try:
        first_filter = next(gen)
    except StopIteration:
        self.logger.info("Filters is empty, nothing to run")
        return True

    try:
        second_filter = next(gen)
    except StopIteration:
        multiple = False
        all_filters = iter([first_filter])
    else:
        multiple = True
        # Put the two pre-consumed subsets back in front of the remaining generator
        all_filters = chain([first_filter, second_filter], gen)

    if not multiple or not multiple_executions:
        self.logger.info("Running locally")
        if batch_size is None:
            batch_size = self.configuration.get('batch_size', 4)

        # Process each filter locally
        for filter_dict in all_filters:
            filter_dict["pageSize"] = 1000
            pages = dataset.items.list(filters=entities.Filters(custom_filter=filter_dict))
            self.logger.info(f"Processing filter on: {pages.items_count} items")
            # Only items of type 'file' are processed; other item types are skipped
            items = [item for page in pages for item in page if item.type == 'file']
            self.logger.debug(f"Items length: {len(items)}")

            if operation_type == 'embed':
                self.embed_items(items=items, batch_size=batch_size, progress=progress)
            else:
                self.predict_items(items=items, batch_size=batch_size, progress=progress)
        return True

    executions = []
    for filter_dict in all_filters:
        self.logger.debug(f"Creating execution for models {operation_type} with dataset id {dataset.id} and filter_dict {filter_dict}")
        subset_filters = entities.Filters(custom_filter=filter_dict)
        if operation_type == 'embed':
            execution = self.model_entity.models.embed(
                model=self.model_entity,
                dataset_id=dataset.id,
                filters=subset_filters,
            )
        else:
            execution = self.model_entity.models.predict(
                model=self.model_entity, dataset_id=dataset.id, filters=subset_filters
            )
        executions.append(execution)

    if executions:
        self.logger.info(f'Created {len(executions)} executions for {operation_type}, ' f'execution ids: {[ex.id for ex in executions]}')

    wait_time = 5
    start_time = time.time()
    last_perc = 0
    pending_ids = [ex.id for ex in executions]
    self.logger.debug(f"Waiting for executions with timeout {predict_embed_timeout}")
    while time.time() - start_time < predict_embed_timeout:
        total_perc = 0
        pending_ids = []

        for ex in executions:
            execution = self.project.executions.get(execution_id=ex.id)
            total_perc += execution.latest_status.get('percentComplete', 0)
            if execution.in_progress():
                pending_ids.append(ex.id)

        avg_perc = round(total_perc / len(executions), 0)
        if progress is not None and last_perc != avg_perc:
            last_perc = avg_perc
            progress.update(progress=last_perc, message=f'running {operation_type}')

        if not pending_ids:
            break

        time.sleep(wait_time)
    else:
        # Reached only when the deadline expired (never via `break`). Executions that are still running
        # are not counted as failed below, so report them instead of returning silently.
        self.logger.warning(
            f"Timed out after {predict_embed_timeout} seconds waiting for {operation_type} executions; "
            f"still in progress: {pending_ids}"
        )
    self.logger.debug("End waiting for executions")
    # Check if any execution failed
    executions_filter = entities.Filters(resource=entities.FiltersResource.EXECUTION)
    executions_filter.add(field="id", values=[ex.id for ex in executions], operator=entities.FiltersOperations.IN)
    executions_filter.add(field='latestStatus.status', values='failed')
    executions_filter.page_size = 0
    failed_executions_count = self.project.executions.list(filters=executions_filter).items_count
    if failed_executions_count > 0:
        self.logger.error(f"Failed to {operation_type} for {failed_executions_count} executions")
        raise ValueError(f"Failed to {operation_type} entire dataset, please check the logs for more details")
    return True
|
|
1089
|
+
|
|
1090
|
+
def _upload_model_annotations(self, item: entities.Item, predictions):
    """
    Replace this model's annotations on an item with new predictions.

    Deletes the item's annotations whose ``metadata.user.model.name`` or
    ``metadata.system.model.name`` equals this adapter's model name, then uploads
    ``predictions`` to the item.

    :param item: `dl.Item` to upload the predictions to
    :param predictions: `dl.AnnotationCollection` (a list of annotations is also accepted)
    :return: the result of ``item.annotations.upload``
    :raises TypeError: if ``predictions`` is neither an AnnotationCollection nor a list
    """
    if not isinstance(predictions, (entities.AnnotationCollection, list)):
        raise TypeError(f'predictions was expected to be of type {entities.AnnotationCollection}, but instead it is {type(predictions)}')
    clean_filter = entities.Filters(resource=entities.FiltersResource.ANNOTATION)
    # Match previous predictions of this model whether the name sits in user or system metadata
    clean_filter.add(field='metadata.user.model.name', values=self.model_entity.name, method=entities.FiltersMethod.OR)
    clean_filter.add(field='metadata.system.model.name', values=self.model_entity.name, method=entities.FiltersMethod.OR)
    # clean_filter.add(field='type', values=self.model_entity.output_type,)
    item.annotations.delete(filters=clean_filter)
    annotations = item.annotations.upload(annotations=predictions)
    return annotations
|
|
1105
|
+
|
|
1106
|
+
@staticmethod
def _item_to_image(item):
    """
    Default image preprocessing: load an item as a numpy array before ``predict`` is called.

    The item is downloaded into memory (``save_locally=False``) and decoded with PIL.

    :param item: `dl.Item` pointing to an image
    :return: the image as ``np.ndarray``, or ``None`` if downloading/decoding failed (the error is logged)
    """
    try:
        return np.asarray(Image.open(item.download(save_locally=False)))
    except Exception as e:
        logger.error(f"Failed to convert image to np.array, Error: {e}\n{traceback.format_exc()}")
        return None
|
|
1122
|
+
|
|
1123
|
+
@staticmethod
def _item_to_item(item):
    """
    Pass-through item-to-batch hook: hand the item to ``predict`` unchanged.

    Item-to-batch hooks prepare a single item for the predict function (e.g. an image
    hook would load it as a numpy array); this default does no conversion at all.

    :param item: the item to prepare
    :return: ``item`` itself
    """
    prepared = item
    return prepared
|
|
1132
|
+
|
|
1133
|
+
@staticmethod
def _item_to_text(item):
    """
    Load a text item's content as a single line of text.

    ``text/plain`` and ``text/markdown`` items are downloaded, decoded as UTF-8 and
    returned with newlines replaced by spaces; the downloaded file is always removed.
    Any other item is returned unchanged (with a warning) and is not downloaded.

    :param item: `dl.Item` to convert
    :return: the item's text, or the item itself when it is not a text file
    """
    if item.mimetype not in ('text/plain', 'text/markdown'):
        logger.warning('Item is not text file. mimetype: {}'.format(item.mimetype))
        return item

    filename = item.download(overwrite=True)
    try:
        # Explicit UTF-8 instead of the locale-dependent default encoding
        with open(filename, 'r', encoding='utf-8') as f:
            text = f.read()
    finally:
        # Remove the downloaded file even when reading/decoding fails
        if os.path.exists(filename):
            os.remove(filename)
    return text.replace('\n', ' ')
|
|
1147
|
+
|
|
1148
|
+
@staticmethod
def _uri_to_image(data_uri):
    """
    Decode a base64 data URI into a numpy array.

    Example input: ``"data:image/png;base64,iVBORw0KGgo..."``. The segment following the
    first comma is base64-decoded and opened with PIL.

    :param data_uri: data URI string
    :return: the decoded image as ``np.ndarray``
    """
    payload = data_uri.split(",", 2)[1]
    raw_bytes = base64.b64decode(payload)
    buffer = io.BytesIO(raw_bytes)
    return np.asarray(Image.open(buffer))
|
|
1155
|
+
|
|
1156
|
+
def _update_predictions_metadata(self, item: entities.Item, predictions: entities.AnnotationCollection):
    """
    Stamp model information onto every prediction (modifies the annotations in place).

    For each annotation in ``predictions``:
      * segmentation annotations get a color:
          - first, from the item's dataset ontology;
          - otherwise, from the model's dataset ontology, where labels missing from the map get white;
          - finally, the annotation's own color;
      * ``item_id`` is set to ``item.id``;
      * an existing ``metadata['user']['model']`` entry gets ``model_id`` and ``name``;
      * ``metadata['system']['model']`` is set to ``{'model_id', 'name', 'confidence'}``, with
        ``confidence`` copied from ``metadata['user']['model']`` when present.

    The ontology color maps are fetched lazily, at most once per call, rather than once per
    segmentation annotation.

    :param item: Entity.Item the predictions belong to
    :param predictions: item's AnnotationCollection
    :return: None
    """
    not_fetched = object()  # color map not requested yet during this call
    unavailable = object()  # ontology lookup raised BadRequest/NotFound

    def _color_map_of(get_dataset):
        # Resolve the dataset and its ontology inside the try, like the former per-annotation lookup
        try:
            return get_dataset()._get_ontology().color_map
        except (exceptions.BadRequest, exceptions.NotFound):
            return unavailable

    item_colors = not_fetched
    model_colors = not_fetched
    for prediction in predictions:
        if prediction.type == entities.AnnotationType.SEGMENTATION:
            if item_colors is not_fetched:
                item_colors = _color_map_of(lambda: item.dataset)
            color = None
            if item_colors is not unavailable:
                color = item_colors.get(prediction.label, None)
            if color is None and self.model_entity._dataset is not None:
                if model_colors is not_fetched:
                    model_colors = _color_map_of(lambda: self.model_entity.dataset)
                if model_colors is not unavailable:
                    color = model_colors.get(prediction.label, (255, 255, 255))
            if color is None:
                logger.warning("Can't get annotation color from model's dataset, using default.")
                color = prediction.color
            prediction.color = color

        prediction.item_id = item.id
        if 'user' in prediction.metadata and 'model' in prediction.metadata['user']:
            prediction.metadata['user']['model']['model_id'] = self.model_entity.id
            prediction.metadata['user']['model']['name'] = self.model_entity.name
        confidence = prediction.metadata.get('user', dict()).get('model', dict()).get('confidence', None)
        # Any previous system.model entry is replaced as a whole
        prediction.metadata.setdefault('system', dict())['model'] = {
            'model_id': self.model_entity.id,
            'name': self.model_entity.name,
            'confidence': confidence,
        }
|
|
1198
|
+
|
|
1199
|
+
##############################
|
|
1200
|
+
# Callback Factory functions #
|
|
1201
|
+
##############################
|
|
1202
|
+
@property
def dataloop_keras_callback(self):
    """
    Returns the constructor for a keras api dump callback.
    The callback is used for dlp platform to show train losses.

    :return: DumpHistoryCallback constructor (a ``keras.callbacks.Callback`` subclass taking ``dump_path``)
    :raises RuntimeError: if keras cannot be imported
    """
    try:
        import keras
    except ImportError as err:
        raise RuntimeError(
            f'{self.__class__.__name__} depends on external package `keras`. Please install it (e.g. `pip install keras`).'
        ) from err

    import os
    import time
    import json

    class DumpHistoryCallback(keras.callbacks.Callback):
        """Keras callback that rewrites a JSON file with all logged metric values after every epoch."""

        def __init__(self, dump_path):
            super().__init__()
            # A directory gets a timestamped file name inside it; anything else is used as the file path
            if os.path.isdir(dump_path):
                dump_path = os.path.join(dump_path, f'__view__training-history__{time.strftime("%F-%X")}.json')
            self.dump_file = dump_path
            # metric name -> {'x': [epochs], 'y': [values]}
            self.data = dict()

        def on_epoch_end(self, epoch, logs=None):
            logs = logs or {}
            for name, val in logs.items():
                if name not in self.data:
                    self.data[name] = {'x': list(), 'y': list()}
                self.data[name]['x'].append(float(epoch))
                self.data[name]['y'].append(float(val))
            self.dump_history()

        def dump_history(self):
            _json = {
                "query": {},
                "datasetId": "",
                "xlabel": "epoch",
                "title": "training loss",
                "ylabel": "val",
                "type": "metric",
                "data": [
                    {
                        "name": name,
                        "x": values['x'],
                        "y": values['y'],
                    }
                    for name, values in self.data.items()
                ],
            }

            with open(self.dump_file, 'w') as f:
                json.dump(_json, f, indent=2)

    return DumpHistoryCallback
|