dtlpy 1.114.17__py3-none-any.whl → 1.115.44__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtlpy/__init__.py +1 -1
- dtlpy/__version__.py +1 -1
- dtlpy/entities/__init__.py +1 -1
- dtlpy/entities/analytic.py +42 -6
- dtlpy/entities/codebase.py +1 -5
- dtlpy/entities/compute.py +12 -5
- dtlpy/entities/dataset.py +19 -5
- dtlpy/entities/driver.py +14 -2
- dtlpy/entities/filters.py +156 -3
- dtlpy/entities/item.py +9 -3
- dtlpy/entities/prompt_item.py +7 -1
- dtlpy/entities/service.py +5 -0
- dtlpy/ml/base_model_adapter.py +407 -263
- dtlpy/repositories/commands.py +1 -7
- dtlpy/repositories/computes.py +17 -13
- dtlpy/repositories/datasets.py +287 -74
- dtlpy/repositories/downloader.py +23 -3
- dtlpy/repositories/drivers.py +12 -0
- dtlpy/repositories/executions.py +1 -3
- dtlpy/repositories/features.py +31 -14
- dtlpy/repositories/items.py +5 -2
- dtlpy/repositories/models.py +16 -4
- dtlpy/repositories/uploader.py +22 -12
- dtlpy/services/api_client.py +6 -3
- dtlpy/services/reporter.py +1 -1
- {dtlpy-1.114.17.dist-info → dtlpy-1.115.44.dist-info}/METADATA +15 -12
- {dtlpy-1.114.17.dist-info → dtlpy-1.115.44.dist-info}/RECORD +34 -34
- {dtlpy-1.114.17.data → dtlpy-1.115.44.data}/scripts/dlp +0 -0
- {dtlpy-1.114.17.data → dtlpy-1.115.44.data}/scripts/dlp.bat +0 -0
- {dtlpy-1.114.17.data → dtlpy-1.115.44.data}/scripts/dlp.py +0 -0
- {dtlpy-1.114.17.dist-info → dtlpy-1.115.44.dist-info}/WHEEL +0 -0
- {dtlpy-1.114.17.dist-info → dtlpy-1.115.44.dist-info}/entry_points.txt +0 -0
- {dtlpy-1.114.17.dist-info → dtlpy-1.115.44.dist-info}/licenses/LICENSE +0 -0
- {dtlpy-1.114.17.dist-info → dtlpy-1.115.44.dist-info}/top_level.txt +0 -0
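The bulk of this release is a rewrite of dtlpy/ml/base_model_adapter.py, shown in full below: batch failures in predict_items/embed_items are now collected and re-raised with the failing item IDs, feature-set creation is serialized behind a class-level threading.Lock, and dataset-level predict/embed is routed through a new _execute_dataset_operation helper. The sketch that follows is not part of the diff; it only illustrates, under assumed names (my-project, my-model, my-dataset, MyAdapter) and with the adapter reduced to stubs, how a custom adapter typically exercises the affected entry points.

# Sketch only (not from the package): a minimal custom adapter exercising the
# entry points touched by this release. Project/model/dataset names are placeholders.
import dtlpy as dl


class MyAdapter(dl.BaseModelAdapter):
    def load(self, local_path, **kwargs):
        # load weights from local_path into self.model (stubbed here)
        self.model = ...

    def predict(self, batch, **kwargs):
        # return one AnnotationCollection per item in the batch
        return [dl.AnnotationCollection() for _ in batch]


project = dl.projects.get(project_name='my-project')       # placeholder
model = project.models.get(model_name='my-model')          # placeholder
dataset = project.datasets.get(dataset_name='my-dataset')  # placeholder

adapter = MyAdapter(model_entity=model)
items = list(dataset.items.list().all())

# Per this diff, a failing batch no longer passes silently: predict_items logs the
# traceback, continues with the remaining batches, and raises at the end with the
# IDs of the items that failed.
adapter.predict_items(items=items, upload_annotations=True)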
dtlpy/ml/base_model_adapter.py CHANGED

@@ -1,4 +1,5 @@
 import dataclasses
+import threading
 import tempfile
 import datetime
 import logging
@@ -6,22 +7,31 @@ import string
 import shutil
 import random
 import base64
+import copy
+import time
 import tqdm
+import traceback
 import sys
 import io
 import os
+from itertools import chain
 from PIL import Image
 from functools import partial
 import numpy as np
 from concurrent.futures import ThreadPoolExecutor
 import attr
 from collections.abc import MutableMapping
+from typing import Optional
 from .. import entities, utilities, repositories, exceptions
 from ..services import service_defaults
 from ..services.api_client import ApiClient

 logger = logging.getLogger('ModelAdapter')

+# Constants
+PREDICT_EMBED_DEFAULT_SUBSET_LIMIT = 1000
+PREDICT_EMBED_DEFAULT_TIMEOUT = 3600 * 24
+

 class ModelConfigurations(MutableMapping):
     """
@@ -37,11 +47,7 @@ class ModelConfigurations(MutableMapping):
         # Store reference to base_model_adapter dictionary
         self._backing_dict = {}

-        if (
-            base_model_adapter is not None
-            and base_model_adapter.model_entity is not None
-            and base_model_adapter.model_entity.configuration is not None
-        ):
+        if base_model_adapter is not None and base_model_adapter.model_entity is not None and base_model_adapter.model_entity.configuration is not None:
             self._backing_dict = base_model_adapter.model_entity.configuration
         if 'include_background' not in self._backing_dict:
             self._backing_dict['include_background'] = False
@@ -148,6 +154,7 @@ class AdapterDefaults(ModelConfigurations):

 class BaseModelAdapter(utilities.BaseServiceRunner):
     _client_api = attr.ib(type=ApiClient, repr=False)
+    _feature_set_lock = threading.Lock()

     def __init__(self, model_entity: entities.Model = None):
         self.logger = logger
@@ -161,12 +168,12 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         self.model = None
         self.bucket_path = None
         # funcs
-        self.item_to_batch_mapping = {'text': self._item_to_text,
-                                      'image': self._item_to_image}
+        self.item_to_batch_mapping = {'text': self._item_to_text, 'image': self._item_to_image}
         if model_entity is not None:
             self.load_from_model(model_entity=model_entity)
         logger.warning(
-            "in case of a mismatch between 'model.name' and 'model_info.name' in the model adapter, model_info.name will be updated to align with 'model.name'."
+            "in case of a mismatch between 'model.name' and 'model_info.name' in the model adapter, model_info.name will be updated to align with 'model.name'."
+        )

     ##################
     # Configurations #
@@ -199,8 +206,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
     @property
     def model_entity(self):
         if self._model_entity is None:
-            raise ValueError(
-                "No model entity loaded. Please load a model (adapter.load_from_model(<dl.Model>)) or set: 'adapter.model_entity=<dl.Model>'")
+            raise ValueError("No model entity loaded. Please load a model (adapter.load_from_model(<dl.Model>)) or set: 'adapter.model_entity=<dl.Model>'")
         assert isinstance(self._model_entity, entities.Model)
         return self._model_entity

@@ -209,8 +215,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         assert isinstance(model_entity, entities.Model)
         if self._model_entity is not None and isinstance(self._model_entity, entities.Model):
             if self._model_entity.id != model_entity.id:
-                self.logger.warning(
-                    'Replacing Model from {!r} to {!r}'.format(self._model_entity.name, model_entity.name))
+                self.logger.warning('Replacing Model from {!r} to {!r}'.format(self._model_entity.name, model_entity.name))
         self._model_entity = model_entity
         self.package = model_entity.package
         self.adapter_defaults = AdapterDefaults(self)
@@ -236,22 +241,24 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
     ###################################

     def load(self, local_path, **kwargs):
-        """
+        """
+        Loads model and populates self.model with a `runnable` model

-
+        Virtual method - need to implement

-
+        This function is called by load_from_model (download to local and then loads)

         :param local_path: `str` directory path in local FileSystem
         """
         raise NotImplementedError("Please implement `load` method in {}".format(self.__class__.__name__))

     def save(self, local_path, **kwargs):
-        """
+        """
+        Saves configuration and weights locally

-
+        Virtual method - need to implement

-
+        the function is called in save_to_model which first save locally and then uploads to model entity

         :param local_path: `str` directory path in local FileSystem
         """
@@ -270,9 +277,10 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         raise NotImplementedError("Please implement `train` method in {}".format(self.__class__.__name__))

     def predict(self, batch, **kwargs):
-        """
+        """
+        Model inference (predictions) on batch of items

-
+        Virtual method - need to implement

         :param batch: output of the `prepare_item_func` func
         :return: `list[dl.AnnotationCollection]` each collection is per each image / item in the batch
@@ -280,9 +288,10 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         raise NotImplementedError("Please implement `predict` method in {}".format(self.__class__.__name__))

     def embed(self, batch, **kwargs):
-        """
+        """
+        Extract model embeddings on batch of items

-
+        Virtual method - need to implement

         :param batch: output of the `prepare_item_func` func
         :return: `list[list]` a feature vector per each item in the batch
@@ -300,19 +309,22 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         :return:
         """
         import dtlpymetrics
+
         compare_types = model.output_type
         if not filters:
             filters = entities.Filters()
         if filters is not None and isinstance(filters, dict):
             filters = entities.Filters(custom_filter=filters)
-        model = dtlpymetrics.scoring.create_model_score(
-
-
-
+        model = dtlpymetrics.scoring.create_model_score(
+            model=model,
+            dataset=dataset,
+            filters=filters,
+            compare_types=compare_types,
+        )
         return model

     def convert_from_dtlpy(self, data_path, **kwargs):
-        """
+        """Convert Dataloop structure data to model structured

         Virtual method - need to implement

@@ -356,9 +368,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         include_model_annotations = self.model_entity.configuration.get("include_model_annotations", False)
         if include_model_annotations is False:
             if annotation_filters.custom_filter is None:
-                annotation_filters.add(
-                    field="metadata.system.model.name", values=False, operator=entities.FiltersOperations.EXISTS
-                )
+                annotation_filters.add(field="metadata.system.model.name", values=False, operator=entities.FiltersOperations.EXISTS)
             else:
                 annotation_filters.custom_filter['filter']['$and'].append({'metadata.system.model.name': {'$exists': False}})
         return annotation_filters
@@ -403,11 +413,12 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         output_path = self.adapter_defaults.resolve("output_path", output_path)
         if root_path is None:
             now = datetime.datetime.now()
-            root_path = os.path.join(
-
-
-
-
+            root_path = os.path.join(
+                dataloop_path,
+                'model_data',
+                "{s_id}_{s_n}".format(s_id=self.model_entity.id, s_n=self.model_entity.name),
+                now.strftime('%Y-%m-%d-%H%M%S'),
+            )
         if data_path is None:
             data_path = os.path.join(root_path, 'datasets', self.model_entity.dataset.id)
             os.makedirs(data_path, exist_ok=True)
@@ -442,7 +453,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
             annotation_filters = entities.Filters(
                 use_defaults=False,
                 resource=entities.FiltersResource.ANNOTATION,
-                custom_filter=annotations_subsets[subset]
+                custom_filter=annotations_subsets[subset],
             )
         # if user provided annotation_filters, skip the default filters
         elif self.model_entity.output_type is not None and self.model_entity.output_type != "embedding":
@@ -472,7 +483,9 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
                 )
                 filters = entities.Filters(custom_filter=subsets[subset])
                 background_ret_list = self.__download_background_images(
-                    filters=filters,
+                    filters=filters,
+                    data_subset_base_path=data_subset_base_path,
+                    annotation_options=annotation_options,
                 )
                 ret_list = list(ret_list)
                 background_ret_list = list(background_ret_list)
@@ -494,7 +507,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         return root_path, data_path, output_path

     def load_from_model(self, model_entity=None, local_path=None, overwrite=True, **kwargs):
-        """
+        """Loads a model from given `dl.Model`.
         Reads configurations and instantiate self.model_entity
         Downloads the model_entity bucket (if available)

@@ -511,10 +524,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         # Point _configuration to the same object since AdapterDefaults inherits from ModelConfigurations
         self._configuration = self.adapter_defaults
         # Download
-        self.model_entity.artifacts.download(
-            local_path=local_path,
-            overwrite=overwrite
-        )
+        self.model_entity.artifacts.download(local_path=local_path, overwrite=overwrite)
         self.load(local_path, **kwargs)

     def save_to_model(self, local_path=None, cleanup=False, replace=True, **kwargs):
@@ -538,11 +548,9 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         self.save(local_path=local_path, **kwargs)

         if self.model_entity is None:
-            raise ValueError('Missing model entity on the adapter. '
-                             'Please set before saving: "adapter.model_entity=model"')
+            raise ValueError('Missing model entity on the adapter. ' 'Please set before saving: "adapter.model_entity=model"')

-        self.model_entity.artifacts.upload(filepath=os.path.join(local_path, '*'),
-                                           overwrite=True)
+        self.model_entity.artifacts.upload(filepath=os.path.join(local_path, '*'), overwrite=True)
         if cleanup:
             shutil.rmtree(path=local_path, ignore_errors=True)
             self.logger.info("Clean-up. deleting {}".format(local_path))
@@ -551,9 +559,11 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
     # SERVICE METHODS
     # ===============

-    @entities.Package.decorators.function(
-
-
+    @entities.Package.decorators.function(
+        display_name='Predict Items',
+        inputs={'items': 'Item[]'},
+        outputs={'items': 'Item[]', 'annotations': 'Annotation[]'},
+    )
     def predict_items(self, items: list, upload_annotations=None, clean_annotations=None, batch_size=None, **kwargs):
         """
         Run the predict function on the input list of items (or single) and return the items and the predictions.
@@ -571,31 +581,36 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         upload_annotations = self.adapter_defaults.resolve("upload_annotations", upload_annotations)
         clean_annotations = self.adapter_defaults.resolve("clean_annotations", clean_annotations)
         input_type = self.model_entity.input_type
-        self.logger.debug(
-            "Predicting {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))
+        self.logger.debug("Predicting {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))
         pool = ThreadPoolExecutor(max_workers=16)
-
+        error_counter = 0
+        fail_ids = list()
         annotations = list()
-        for i_batch in tqdm.tqdm(range(0, len(items), batch_size), desc='predicting', unit='bt', leave=None,
-
-            batch_items = items[i_batch: i_batch + batch_size]
+        for i_batch in tqdm.tqdm(range(0, len(items), batch_size), desc='predicting', unit='bt', leave=None, file=sys.stdout):
+            batch_items = items[i_batch : i_batch + batch_size]
             batch = list(pool.map(self.prepare_item_func, batch_items))
-
-
-
-
+            try:
+                batch_collections = self.predict(batch, **kwargs)
+            except Exception as e:
+                item_ids = [item.id for item in batch_items]
+                self.logger.error(f"Failed to predict batch {i_batch} for items {item_ids}. Error: {e}\n{traceback.format_exc()}")
+                error_counter += 1
+                fail_ids.extend(item_ids)
+                continue
+            _futures = list(pool.map(partial(self._update_predictions_metadata), batch_items, batch_collections))
             # Loop over the futures to make sure they are all done to avoid race conditions
             _ = [_f for _f in _futures]
             if upload_annotations is True:
-                self.logger.debug(
-                    "Uploading items' annotation for model {!r}.".format(self.model_entity.name))
+                self.logger.debug("Uploading items' annotation for model {!r}.".format(self.model_entity.name))
                 try:
-                    batch_collections = list(
-
-
-                        batch_collections))
+                    batch_collections = list(
+                        pool.map(partial(self._upload_model_annotations, clean_annotations=clean_annotations), batch_items, batch_collections)
+                    )
                 except Exception as err:
-
+                    item_ids = [item.id for item in batch_items]
+                    self.logger.error(f"Failed to upload annotations for items {item_ids}. Error: {err}\n{traceback.format_exc()}")
+                    error_counter += 1
+                    fail_ids.extend(item_ids)

             for collection in batch_collections:
                 # function needs to return `List[List[dl.Annotation]]`
@@ -608,12 +623,16 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
             # TODO call the callback

         pool.shutdown()
+        if error_counter > 0:
+            raise Exception(f"Failed to predict all items. Failed IDs: {fail_ids}, See logs for more details")
         return items, annotations

-    @entities.Package.decorators.function(
-
-
-
+    @entities.Package.decorators.function(
+        display_name='Embed Items',
+        inputs={'items': 'Item[]'},
+        outputs={'items': 'Item[]', 'features': 'Json[]'},
+    )
+    def embed_items(self, items: list, upload_features=None, batch_size=None, progress: utilities.Progress = None, **kwargs):
         """
         Extract feature from an input list of items (or single) and return the items and the feature vector.

@@ -627,150 +646,141 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
             batch_size = self.configuration.get('batch_size', 4)
         upload_features = self.adapter_defaults.resolve("upload_features", upload_features)
         input_type = self.model_entity.input_type
-        self.logger.debug(
-
+        self.logger.debug("Embedding {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))
+        error_counter = 0
+        fail_ids = list()

-
-        feature_set = self.model_entity.feature_set
-        if feature_set is None:
-            logger.info('Feature Set not found. creating... ')
-            try:
-                self.model_entity.project.feature_sets.get(feature_set_name=self.model_entity.name)
-                feature_set_name = f"{self.model_entity.name}-{''.join(random.choices(string.ascii_letters + string.digits, k=5))}"
-                logger.warning(
-                    f"Feature set with the model name already exists. Creating new feature set with name {feature_set_name}")
-            except exceptions.NotFound:
-                feature_set_name = self.model_entity.name
-            feature_set = self.model_entity.project.feature_sets.create(name=feature_set_name,
-                                                                        entity_type=entities.FeatureEntityType.ITEM,
-                                                                        model_id=self.model_entity.id,
-                                                                        project_id=self.model_entity.project_id,
-                                                                        set_type=self.model_entity.name,
-                                                                        size=self.configuration.get('embeddings_size',
-                                                                                                    256))
-            logger.info(f'Feature Set created! name: {feature_set.name}, id: {feature_set.id}')
-        else:
-            logger.info(f'Feature Set found! name: {feature_set.name}, id: {feature_set.id}')
+        feature_set = self._get_feature_set()

         # upload the feature vectors
         pool = ThreadPoolExecutor(max_workers=16)
         vectors = list()
-
-
-
-
-
-
+        _items = list()
+        for i_batch in tqdm.tqdm(
+            range(0, len(items), batch_size),
+            desc='embedding',
+            unit='bt',
+            leave=None,
+            file=sys.stdout,
+        ):
+            batch_items = items[i_batch : i_batch + batch_size]
             batch = list(pool.map(self.prepare_item_func, batch_items))
-
+            try:
+                batch_vectors = self.embed(batch, **kwargs)
+            except Exception as err:
+                item_ids = [item.id for item in batch_items]
+                self.logger.error(f"Failed to embed batch {i_batch} for items {item_ids}. Error: {err}\n{traceback.format_exc()}")
+                error_counter += 1
+                fail_ids.extend(item_ids)
+                continue
             vectors.extend(batch_vectors)
-
-
-                    "Uploading items' feature vectors for model {!r}.".format(self.model_entity.name))
-                try:
-                    list(pool.map(partial(self._upload_model_features,
-                                          progress.logger if progress is not None else self.logger,
-                                          feature_set.id,
-                                          self.model_entity.project_id),
-                                  batch_items,
-                                  batch_vectors))
-                except Exception as err:
-                    self.logger.exception("Failed to upload feature vectors to items.")
-
+            # Save the items in the order of the vectors
+            _items.extend(batch_items)
         pool.shutdown()
-        return items, vectors
-
-    @entities.Package.decorators.function(display_name='Embed Dataset with DQL',
-                                          inputs={'dataset': 'Dataset',
-                                                  'filters': 'Json'})
-    def embed_dataset(self,
-                      dataset: entities.Dataset,
-                      filters: entities.Filters = None,
-                      upload_features=None,
-                      batch_size=None,
-                      progress:utilities.Progress=None,
-                      **kwargs):
-        """
-        Extract feature from all items given
-
-        :param dataset: Dataset entity to predict
-        :param filters: Filters entity for a filtering before embedding
-        :param upload_features: `bool` uploads the features back to the given items
-        :param batch_size: `int` size of batch to run a single embed

-
+        if upload_features is True:
+            _indicies_to_remove = list()
+            embeddings_size = self.configuration.get('embeddings_size', 256)
+            for idx, vector in enumerate(vectors):
+                if vector is None or len(vector) != embeddings_size:
+                    self.logger.warning(f"Vector generated for item {_items[idx].id} is None or has wrong size. Skipping...")
+                    _indicies_to_remove.append(idx)
+
+            # Remove indices in descending order to avoid IndexError
+            # When removing items, indices shift down, so we must remove from highest to lowest
+            for index in sorted(_indicies_to_remove, reverse=True):
+                _items.pop(index)
+                vectors.pop(index)
+
+            if len(_items) != len(vectors):
+                raise ValueError(f"The number of items ({len(_items)}) is not equal to the number of vectors ({len(vectors)}).")
+            self.logger.debug(f"Uploading {len(_items)} items' feature vectors for model {self.model_entity.name}.")
+            try:
+                start_time = time.time()
+                feature_set.features.create(entity=_items, value=vectors, feature_set_id=feature_set.id, project_id=self.model_entity.project_id)
+                self.logger.debug(f"Uploaded {len(_items)} items' feature vectors for model {self.model_entity.name} in {time.time() - start_time} seconds.")
+            except Exception as err:
+                self.logger.error(f"Failed to upload feature vectors. Error: {err}\n{traceback.format_exc()}")
+                error_counter += 1
+        if error_counter > 0:
+            raise Exception(f"Failed to embed all items. Failed IDs: {fail_ids}, See logs for more details")
+        return _items, vectors
+
+    @entities.Package.decorators.function(
+        display_name='Embed Dataset with DQL',
+        inputs={'dataset': 'Dataset', 'filters': 'Json'},
+    )
+    def embed_dataset(
+        self,
+        dataset: entities.Dataset,
+        filters: Optional[entities.Filters] = None,
+        upload_features: Optional[bool] = None,
+        batch_size: Optional[int] = None,
+        progress: Optional[utilities.Progress] = None,
+        **kwargs,
+    ):
+        """
+        Run model embedding on all items in a dataset
+
+        :param dataset: Dataset entity to embed
+        :param filters: Filters entity for filtering before embedding
+        :param upload_features: bool whether to upload features back to platform
+        :param batch_size: int size of batch to run a single embedding
+        :param progress: dl.Progress object to track progress
+        :return: bool indicating if embedding completed successfully
         """
-        if batch_size is None:
-            batch_size = self.configuration.get('batch_size', 4)
-        upload_features = self.adapter_defaults.resolve("upload_features", upload_features)
-
-        self.logger.debug("Creating embeddings for dataset (name:{}, id:{}), using batch size {}".format(dataset.name,
-                                                                                                         dataset.id,
-                                                                                                         batch_size))
-        if not filters:
-            filters = entities.Filters()
-        if filters is not None and isinstance(filters, dict):
-            filters = entities.Filters(custom_filter=filters)
-        pages = dataset.items.list(filters=filters, page_size=batch_size)
-        # Item type is 'file' only, can be deleted if default filters are added to custom filters
-        items = [item for page in pages for item in page if item.type == 'file']
-        self.embed_items(items=items,
-                         upload_features=upload_features,
-                         batch_size=batch_size,
-                         progress=progress,
-                         **kwargs)
-        return True

-
-
-
-
-
-
-
-
-
-
+        status = self._execute_dataset_operation(
+            dataset=dataset,
+            operation_type='embed',
+            filters=filters,
+            progress=progress,
+            batch_size=batch_size,
+        )
+        if status is False:
+            raise ValueError(f"Failed to embed entire dataset, please check the logs for more details")
+
+    @entities.Package.decorators.function(
+        display_name='Predict Dataset with DQL',
+        inputs={'dataset': 'Dataset', 'filters': 'Json'},
+    )
+    def predict_dataset(
+        self,
+        dataset: entities.Dataset,
+        filters: Optional[entities.Filters] = None,
+        upload_annotations: Optional[bool] = None,
+        clean_annotations: Optional[bool] = None,
+        batch_size: Optional[int] = None,
+        progress: Optional[utilities.Progress] = None,
+        **kwargs,
+    ):
         """
-
+        Run model prediction on all items in a dataset

         :param dataset: Dataset entity to predict
-        :param filters: Filters entity for
-        :param upload_annotations:
-        :param clean_annotations:
-        :param batch_size:
-
-        :return:
+        :param filters: Filters entity for filtering before prediction
+        :param upload_annotations: bool whether to upload annotations back to platform
+        :param clean_annotations: bool whether to clean existing annotations
+        :param batch_size: int size of batch to run a single prediction
+        :param progress: dl.Progress object to track progress
+        :return: bool indicating if prediction completed successfully
         """
-
-
-
-
-
-
-
-        if
-
-
-
-
-
-
-
-
-                              clean_annotations=clean_annotations,
-                              batch_size=batch_size,
-                              **kwargs)
-        return True
-
-    @entities.Package.decorators.function(display_name='Train a Model',
-                                          inputs={'model': entities.Model},
-                                          outputs={'model': entities.Model})
-    def train_model(self,
-                    model: entities.Model,
-                    cleanup=False,
-                    progress: utilities.Progress = None,
-                    context: utilities.Context = None):
+        status = self._execute_dataset_operation(
+            dataset=dataset,
+            operation_type='predict',
+            filters=filters,
+            progress=progress,
+            batch_size=batch_size,
+        )
+        if status is False:
+            raise ValueError(f"Failed to predict entire dataset, please check the logs for more details")
+
+    @entities.Package.decorators.function(
+        display_name='Train a Model',
+        inputs={'model': entities.Model},
+        outputs={'model': entities.Model},
+    )
+    def train_model(self, model: entities.Model, cleanup=False, progress: utilities.Progress = None, context: utilities.Context = None):
         """
         Train on existing model.
         data will be taken from dl.Model.datasetId
@@ -804,27 +814,19 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         ################
         # prepare data #
         ################
-        root_path, data_path, output_path = self.prepare_data(
-            dataset=self.model_entity.dataset,
-            root_path=os.path.join('tmp', model.id)
-        )
+        root_path, data_path, output_path = self.prepare_data(dataset=self.model_entity.dataset, root_path=os.path.join('tmp', model.id))
         # Start the Train
-        logger.info("Training {p_name!r} with model {m_name!r} on data {d_path!r}".
-                    format(p_name=self.package_name, m_name=model.id, d_path=data_path))
+        logger.info("Training {p_name!r} with model {m_name!r} on data {d_path!r}".format(p_name=self.package_name, m_name=model.id, d_path=data_path))
         if progress is not None:
             progress.update(message='starting training')

         def on_epoch_end_callback(i_epoch, n_epoch):
             if progress is not None:
-                progress.update(progress=int(100 * (i_epoch + 1) / n_epoch),
-                                message='finished epoch: {}/{}'.format(i_epoch, n_epoch))
+                progress.update(progress=int(100 * (i_epoch + 1) / n_epoch), message='finished epoch: {}/{}'.format(i_epoch, n_epoch))

-        self.train(data_path=data_path,
-                   output_path=output_path,
-                   on_epoch_end_callback=on_epoch_end_callback)
+        self.train(data_path=data_path, output_path=output_path, on_epoch_end_callback=on_epoch_end_callback)
         if progress is not None:
-            progress.update(message='saving model',
-                            progress=99)
+            progress.update(message='saving model', progress=99)

         self.save_to_model(local_path=output_path, replace=True)
         model.status = 'trained'
@@ -842,20 +844,20 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
             raise
         return model

-    @entities.Package.decorators.function(
-
-
-
-
-
-
-
-
-
-
-
-
-
+    @entities.Package.decorators.function(
+        display_name='Evaluate a Model',
+        inputs={'model': entities.Model, 'dataset': entities.Dataset, 'filters': 'Json'},
+        outputs={'model': entities.Model, 'dataset': entities.Dataset},
+    )
+    def evaluate_model(
+        self,
+        model: entities.Model,
+        dataset: entities.Dataset,
+        filters: entities.Filters = None,
+        #
+        progress: utilities.Progress = None,
+        context: utilities.Context = None,
+    ):
         """
         Evaluate a model.
         data will be downloaded from the dataset and query
@@ -869,14 +871,12 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         :param context:
         :return:
         """
-        logger.info(
-            f"Received model: {model.id} for evaluation on dataset (name: {dataset.name}, id: {dataset.id}")
+        logger.info(f"Received model: {model.id} for evaluation on dataset (name: {dataset.name}, id: {dataset.id}")
         ##########################
         # load model and weights #
         ##########################
         logger.info(f"Loading Adapter with: {model.name} ({model.id!r})")
-        self.load_from_model(dataset=dataset,
-                             model_entity=model)
+        self.load_from_model(dataset=dataset, model_entity=model)

         ##############
         # Predicting #
@@ -884,42 +884,181 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         logger.info(f"Calling prediction, dataset: {dataset.name!r} ({model.id!r}), filters: {filters}")
         if not filters:
             filters = entities.Filters()
-        self.predict_dataset(dataset=dataset,
-                             filters=filters,
-                             with_upload=True)
+        self.predict_dataset(dataset=dataset, filters=filters, with_upload=True)

         ##############
         # Evaluating #
         ##############
         logger.info(f"Starting adapter.evaluate()")
         if progress is not None:
-            progress.update(message='calculating metrics',
-
-        model = self.evaluate(model=model,
-                              dataset=dataset,
-                              filters=filters)
+            progress.update(message='calculating metrics', progress=98)
+        model = self.evaluate(model=model, dataset=dataset, filters=filters)
         #########
         # Done! #
         #########
         if progress is not None:
-            progress.update(message='finishing evaluation',
-                            progress=99)
+            progress.update(message='finishing evaluation', progress=99)
         return model, dataset

     # =============
     # INNER METHODS
     # =============
+    def _get_feature_set(self):
+        # Ensure feature set creation/retrieval is thread-safe across the class
+        with self.__class__._feature_set_lock:
+            # Search for existing feature set for this model id
+            feature_set = self.model_entity.feature_set
+            if feature_set is None:
+                logger.info('Feature Set not found. creating... ')
+                try:
+                    self.model_entity.project.feature_sets.get(feature_set_name=self.model_entity.name)
+                    feature_set_name = f"{self.model_entity.name}-{''.join(random.choices(string.ascii_letters + string.digits, k=5))}"
+                    logger.warning(
+                        f"Feature set with the model name already exists. Creating new feature set with name {feature_set_name}"
+                    )
+
+                except exceptions.NotFound:
+                    feature_set_name = self.model_entity.name
+                feature_set = self.model_entity.project.feature_sets.create(
+                    name=feature_set_name,
+                    entity_type=entities.FeatureEntityType.ITEM,
+                    model_id=self.model_entity.id,
+                    project_id=self.model_entity.project_id,
+                    set_type=self.model_entity.name,
+                    size=self.configuration.get('embeddings_size', 256),
+                )
+                logger.info(f'Feature Set created! name: {feature_set.name}, id: {feature_set.id}')
+            else:
+                logger.info(f'Feature Set found! name: {feature_set.name}, id: {feature_set.id}')
+            return feature_set

-
-
+    def _execute_dataset_operation(
+        self,
+        dataset: entities.Dataset,
+        operation_type: str,
+        filters: Optional[entities.Filters] = None,
+        progress: Optional[utilities.Progress] = None,
+        batch_size: Optional[int] = None,
+    ) -> bool:
+        """
+        Execute dataset operation (predict/embed) with batching and filtering support.
+
+        :param dataset: Dataset entity to run operation on
+        :param operation_type: Type of operation to execute ('predict' or 'embed')
+        :param filters: Filters entity to filter items, default None
+        :param progress: Progress object for tracking progress, default None
+        :param batch_size: Size of batches to process items, default None (uses model config)
+        :return: True if operation completes successfully
+        :raises ValueError: If operation_type is not 'predict' or 'embed'
+        """
+        self.logger.debug(f"Running {operation_type} for dataset (name:{dataset.name}, id:{dataset.id})")
+
+        if not filters:
+            self.logger.debug("No filters provided, using default filters")
+            filters = entities.Filters()
+        if filters is not None and isinstance(filters, dict):
+            self.logger.debug(f"Received custom filters {filters}")
+            filters = entities.Filters(custom_filter=filters)
+
+        if operation_type == 'embed':
+            feature_set = self._get_feature_set()
+            logger.info(f"Feature set found! name: {feature_set.name}, id: {feature_set.id}")
+
+        predict_embed_subset_limit = self.configuration.get('predict_embed_subset_limit', PREDICT_EMBED_DEFAULT_SUBSET_LIMIT)
+        predict_embed_timeout = self.configuration.get('predict_embed_timeout', PREDICT_EMBED_DEFAULT_TIMEOUT)
+        self.logger.debug(f"Inputs: predict_embed_subset_limit: {predict_embed_subset_limit}, predict_embed_timeout: {predict_embed_timeout}")
+        tmp_filters = copy.deepcopy(filters.prepare())
+        tmp_filters['pageSize'] = 0
+        num_items = dataset.items.list(filters=entities.Filters(custom_filter=tmp_filters)).items_count
+        self.logger.debug(f"Number of items for current filters: {num_items}")
+
+        # One-item lookahead on generator: if only one subset, run locally; else create executions for all
+        gen = entities.Filters._get_split_filters(dataset, filters, predict_embed_subset_limit)
         try:
-
-
-
-
-
-
-
+            first_filter = next(gen)
+        except StopIteration:
+            self.logger.info("Filters is empty, nothing to run")
+            return True
+
+        try:
+            second_filter = next(gen)
+            multiple = True
+        except StopIteration:
+            multiple = False
+
+        if not multiple:
+            self.logger.info("Split filters has only one subset, running locally")
+            if batch_size is None:
+                batch_size = self.configuration.get('batch_size', 4)
+            first_filter["pageSize"] = 1000
+            single_filters = entities.Filters(custom_filter=first_filter)
+            pages = dataset.items.list(filters=single_filters)
+            self.logger.info(f"Single run pages on: {pages.items_count} items")
+            items = [item for page in pages for item in page if item.type == 'file']
+            self.logger.debug(f"Single run items length: {len(items)}")
+            if operation_type == 'embed':
+                self.embed_items(items=items, batch_size=batch_size, progress=progress)
+            elif operation_type == 'predict':
+                self.predict_items(items=items, batch_size=batch_size, progress=progress)
+            else:
+                raise ValueError(f"Unsupported operation type: {operation_type}")
+            return True
+
+        executions = []
+        for filter_dict in chain([first_filter, second_filter], gen):
+            self.logger.debug(f"Creating execution for models {operation_type} with dataset id {dataset.id} and filter_dict {filter_dict}")
+            if operation_type == 'embed':
+                execution = self.model_entity.models.embed(
+                    model=self.model_entity,
+                    dataset_id=dataset.id,
+                    filters=entities.Filters(custom_filter=filter_dict),
+                )
+            elif operation_type == 'predict':
+                execution = self.model_entity.models.predict(
+                    model=self.model_entity, dataset_id=dataset.id, filters=entities.Filters(custom_filter=filter_dict)
+                )
+            else:
+                raise ValueError(f"Unsupported operation type: {operation_type}")
+            executions.append(execution)
+
+        if executions:
+            self.logger.info(f'Created {len(executions)} executions for {operation_type}, ' f'execution ids: {[ex.id for ex in executions]}')
+
+        wait_time = 5
+        start_time = time.time()
+        last_perc = 0
+        self.logger.debug(f"Waiting for executions with timeout {predict_embed_timeout}")
+        while time.time() - start_time < predict_embed_timeout:
+            continue_loop = False
+            total_perc = 0
+
+            for ex in executions:
+                execution = dataset.project.executions.get(execution_id=ex.id)
+                perc = execution.latest_status.get('percentComplete', 0)
+                total_perc += perc
+                if execution.in_progress():
+                    continue_loop = True
+
+            avg_perc = round(total_perc / len(executions), 0)
+            if progress is not None and last_perc != avg_perc:
+                last_perc = avg_perc
+                progress.update(progress=last_perc, message=f'running {operation_type}')
+
+            if not continue_loop:
+                break
+
+            time.sleep(wait_time)
+        self.logger.debug("End waiting for executions")
+        # Check if any execution failed
+        executions_filter = entities.Filters(resource=entities.FiltersResource.EXECUTION)
+        executions_filter.add(field="id", values=[ex.id for ex in executions], operator=entities.FiltersOperations.IN)
+        executions_filter.add(field='latestStatus.status', values='failed')
+        executions_filter.page_size = 0
+        failed_executions_count = dataset.project.executions.list(filters=executions_filter).items_count
+        if failed_executions_count > 0:
+            self.logger.error(f"Failed to {operation_type} for {failed_executions_count} executions")
+            return False
+        return True

     def _upload_model_annotations(self, item: entities.Item, predictions, clean_annotations):
         """
@@ -928,8 +1067,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         :param cleanup: `bool` if set removes existing predictions with the same package-model name
         """
         if not (isinstance(predictions, entities.AnnotationCollection) or isinstance(predictions, list)):
-            raise TypeError('predictions was expected to be of type {}, but instead it is {}'
-                            format(entities.AnnotationCollection, type(predictions)))
+            raise TypeError(f'predictions was expected to be of type {entities.AnnotationCollection}, but instead it is {type(predictions)}')
         if clean_annotations:
             clean_filter = entities.Filters(resource=entities.FiltersResource.ANNOTATION)
             clean_filter.add(field='metadata.user.model.name', values=self.model_entity.name)
@@ -947,8 +1085,12 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         :param item:
         :return:
        """
-
-
+        try:
+            buffer = item.download(save_locally=False)
+            image = np.asarray(Image.open(buffer))
+        except Exception as e:
+            logger.error(f"Failed to convert image to np.array, Error: {e}\n{traceback.format_exc()}")
+            image = None
         return image

     @staticmethod
@@ -1004,8 +1146,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         if color is None:
             if self.model_entity._dataset is not None:
                 try:
-                    color = self.model_entity.dataset._get_ontology().color_map.get(prediction.label,
-                                                                                    (255, 255, 255))
+                    color = self.model_entity.dataset._get_ontology().color_map.get(prediction.label, (255, 255, 255))
                 except (exceptions.BadRequest, exceptions.NotFound):
                     ...
         if color is None:
@@ -1025,7 +1166,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
             prediction.metadata['system']['model'] = {
                 'model_id': self.model_entity.id,
                 'name': self.model_entity.name,
-                'confidence': confidence
+                'confidence': confidence,
             }

     ##############################
@@ -1042,8 +1183,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         try:
             import keras
         except (ImportError, ModuleNotFoundError) as err:
-            raise RuntimeError(
-                '{} depends on extenral package. Please install '.format(self.__class__.__name__)) from err
+            raise RuntimeError(f'{self.__class__.__name__} depends on extenral package. Please install ') from err

         import os
         import time
@@ -1053,8 +1193,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         def __init__(self, dump_path):
             super().__init__()
             if os.path.isdir(dump_path):
-                dump_path = os.path.join(dump_path,
-                                         '__view__training-history__{}.json'.format(time.strftime("%F-%X")))
+                dump_path = os.path.join(dump_path, f'__view__training-history__{time.strftime("%F-%X")}.json')
             self.dump_file = dump_path
             self.data = dict()

@@ -1075,9 +1214,14 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
                 "title": "training loss",
                 "ylabel": "val",
                 "type": "metric",
-                "data": [
-
-
+                "data": [
+                    {
+                        "name": name,
+                        "x": values['x'],
+                        "y": values['y'],
+                    }
+                    for name, values in self.data.items()
+                ],
             }

             with open(self.dump_file, 'w') as f: