dtlpy 1.91.37__py3-none-any.whl → 1.92.19__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (38)
  1. dtlpy/__init__.py +5 -2
  2. dtlpy/__version__.py +1 -1
  3. dtlpy/entities/__init__.py +1 -1
  4. dtlpy/entities/command.py +3 -2
  5. dtlpy/entities/dataset.py +52 -2
  6. dtlpy/entities/feature_set.py +3 -0
  7. dtlpy/entities/filters.py +2 -2
  8. dtlpy/entities/item.py +15 -1
  9. dtlpy/entities/node.py +11 -1
  10. dtlpy/entities/ontology.py +36 -40
  11. dtlpy/entities/pipeline.py +20 -1
  12. dtlpy/entities/pipeline_execution.py +23 -0
  13. dtlpy/entities/prompt_item.py +240 -37
  14. dtlpy/entities/service.py +5 -5
  15. dtlpy/ml/base_model_adapter.py +101 -41
  16. dtlpy/new_instance.py +80 -9
  17. dtlpy/repositories/apps.py +56 -10
  18. dtlpy/repositories/commands.py +10 -2
  19. dtlpy/repositories/datasets.py +142 -12
  20. dtlpy/repositories/dpks.py +5 -1
  21. dtlpy/repositories/feature_sets.py +23 -3
  22. dtlpy/repositories/models.py +1 -1
  23. dtlpy/repositories/pipeline_executions.py +53 -0
  24. dtlpy/repositories/uploader.py +3 -0
  25. dtlpy/services/api_client.py +59 -3
  26. {dtlpy-1.91.37.dist-info → dtlpy-1.92.19.dist-info}/METADATA +1 -1
  27. {dtlpy-1.91.37.dist-info → dtlpy-1.92.19.dist-info}/RECORD +35 -38
  28. tests/features/environment.py +29 -0
  29. dtlpy/callbacks/__init__.py +0 -16
  30. dtlpy/callbacks/piper_progress_reporter.py +0 -29
  31. dtlpy/callbacks/progress_viewer.py +0 -54
  32. {dtlpy-1.91.37.data → dtlpy-1.92.19.data}/scripts/dlp +0 -0
  33. {dtlpy-1.91.37.data → dtlpy-1.92.19.data}/scripts/dlp.bat +0 -0
  34. {dtlpy-1.91.37.data → dtlpy-1.92.19.data}/scripts/dlp.py +0 -0
  35. {dtlpy-1.91.37.dist-info → dtlpy-1.92.19.dist-info}/LICENSE +0 -0
  36. {dtlpy-1.91.37.dist-info → dtlpy-1.92.19.dist-info}/WHEEL +0 -0
  37. {dtlpy-1.91.37.dist-info → dtlpy-1.92.19.dist-info}/entry_points.txt +0 -0
  38. {dtlpy-1.91.37.dist-info → dtlpy-1.92.19.dist-info}/top_level.txt +0 -0
dtlpy/entities/prompt_item.py CHANGED
@@ -1,7 +1,11 @@
  import logging
- import io
  import enum
  import json
+ import os.path
+ from dtlpy import entities, repositories
+ from dtlpy.services.api_client import client as client_api
+ import base64
+ import requests

  logger = logging.getLogger(name='dtlpy')

@@ -11,27 +15,37 @@ class PromptType(str, enum.Enum):
      IMAGE = 'image/*'
      AUDIO = 'audio/*'
      VIDEO = 'video/*'
+     METADATA = 'metadata'


  class Prompt:
-     def __init__(self, key):
+     def __init__(self, key, role='user'):
          """
          Create a single Prompt. Prompt can contain multiple mimetype elements, e.g. text sentence and an image.
-
          :param key: unique identifier of the prompt in the item
          """
          self.key = key
          self.elements = list()
+         self._items = repositories.Items(client_api=client_api)
+         self.metadata = {'role': role}

-     def add(self, value, mimetype='text'):
+     def add_element(self, value, mimetype='application/text'):
          """

          :param value: url or string of the input
          :param mimetype: mimetype of the input. options: `text`, `image/*`, `video/*`, `audio/*`
          :return:
          """
-         self.elements.append({'mimetype': mimetype,
-                               'value': value})
+         allowed_prompt_types = [prompt_type for prompt_type in PromptType]
+         if mimetype not in allowed_prompt_types:
+             raise ValueError(f'Invalid mimetype: {mimetype}. Allowed values: {allowed_prompt_types}')
+         if not isinstance(value, str) and mimetype != PromptType.METADATA:
+             raise ValueError(f'Expected str for Prompt element value, got {type(value)} instead')
+         if mimetype == PromptType.METADATA and isinstance(value, dict):
+             self.metadata.update(value)
+         else:
+             self.elements.append({'mimetype': mimetype,
+                                   'value': value})

      def to_json(self):
          """
@@ -39,26 +53,169 @@ class Prompt:

          :return:
          """
+         elements_json = [
+             {
+                 "mimetype": e['mimetype'],
+                 "value": e['value'],
+             } for e in self.elements
+         ]
+         elements_json.append({
+             "mimetype": PromptType.METADATA,
+             "value": self.metadata
+         })
          return {
-             self.key: [
-                 {
-                     "mimetype": e['mimetype'],
-                     "value": e['value']
-                 } for e in self.elements
-             ]
+             self.key: elements_json
          }

-
- class PromptItem:
-     def __init__(self, name):
+     def _convert_stream_to_binary(self, image_url: str):
          """
-         Create a new Prompt Item. Single item can have multiple prompt, e.g. a conversation.
+         Convert a stream to binary
+         :param image_url: dataloop image stream url
+         :return: binary object
+         """
+         image_buffer = None
+         if '.' in image_url and 'dataloop.ai' not in image_url:
+             # URL and not DL item stream
+             try:
+                 response = requests.get(image_url, stream=True)
+                 response.raise_for_status()  # Raise an exception for bad status codes
+
+                 # Check for valid image content type
+                 if response.headers["Content-Type"].startswith("image/"):
+                     # Read the image data in chunks to avoid loading large images in memory
+                     image_buffer = b"".join(chunk for chunk in response.iter_content(1024))
+             except requests.exceptions.RequestException as e:
+                 logger.error(f"Failed to download image from URL: {image_url}, error: {e}")
+
+         elif '.' in image_url and 'stream' in image_url:
+             # DL Stream URL
+             item_id = image_url.split("/stream")[0].split("/items/")[-1]
+             image_buffer = self._items.get(item_id=item_id).download(save_locally=False).getvalue()
+         else:
+             # DL item ID
+             image_buffer = self._items.get(item_id=image_url).download(save_locally=False).getvalue()

-         :param name: name of the item (filename)
+         if image_buffer is not None:
+             encoded_image = base64.b64encode(image_buffer).decode()
+         else:
+             logger.error(f'Invalid image url: {image_url}')
+             return None
+
+         return f'data:image/jpeg;base64,{encoded_image}'
+
+     def messages(self):
+         """
+         return a list of messages in the prompt item,
+         messages are returned following the openai SDK format https://platform.openai.com/docs/guides/vision
          """
+         messages = []
+         for element in self.elements:
+             if element['mimetype'] == PromptType.TEXT:
+                 data = {
+                     "type": "text",
+                     "text": element['value']
+                 }
+                 messages.append(data)
+             elif element['mimetype'] == PromptType.IMAGE:
+                 image_url = self._convert_stream_to_binary(element['value'])
+                 data = {
+                     "type": "image_url",
+                     "image_url": {
+                         "url": image_url
+                     }
+                 }
+                 messages.append(data)
+             elif element['mimetype'] == PromptType.AUDIO:
+                 raise NotImplementedError('Audio prompt is not supported yet')
+             elif element['mimetype'] == PromptType.VIDEO:
+                 raise NotImplementedError('Video prompt is not supported yet')
+             else:
+                 raise ValueError(f'Invalid mimetype: {element["mimetype"]}')
+         return messages, self.key
+
+
+ class PromptItem:
+     def __init__(self, name, item: entities.Item = None):
+         # prompt item name
          self.name = name
-         self.type = "prompt"
+         # list of user prompts in the prompt item
          self.prompts = list()
+         # list of assistant (annotations) prompts in the prompt item
+         self.assistant_prompts = dict()
+         # Dataloop Item
+         self._item = None
+
+     @classmethod
+     def from_item(cls, item: entities.Item):
+         """
+         Load a prompt item from the platform
+         :param item : Item object
+         :return: PromptItem object
+         """
+         if 'json' not in item.mimetype or item.system.get('shebang', dict()).get('dltype') != 'prompt':
+             raise ValueError('Expecting a json item with system.shebang.dltype = prompt')
+         # Not using `save_locally=False` to use the from_local_file method
+         item_file_path = item.download()
+         prompt_item = cls.from_local_file(file_path=item_file_path)
+         if os.path.exists(item_file_path):
+             os.remove(item_file_path)
+         prompt_item._item = item
+         return prompt_item
+
+     @classmethod
+     def from_local_file(cls, file_path):
+         """
+         Create a new prompt item from a file
+         :param file_path: path to the file
+         :return: PromptItem object
+         """
+         if os.path.exists(file_path) is False:
+             raise FileNotFoundError(f'File does not exists: {file_path}')
+         if 'json' not in os.path.splitext(file_path)[-1]:
+             raise ValueError(f'Expected path to json item, got {os.path.splitext(file_path)[-1]}')
+         prompt_item = cls(name=file_path)
+         with open(file_path, 'r') as f:
+             data = json.load(f)
+         for prompt_key, prompt_values in data.get('prompts', dict()).items():
+             prompt = Prompt(key=prompt_key)
+             for val in prompt_values:
+                 if val['mimetype'] == PromptType.METADATA:
+                     _ = val.pop('mimetype')
+                     prompt.add_element(value=val, mimetype=PromptType.METADATA)
+                 else:
+                     prompt.add_element(mimetype=val['mimetype'], value=val['value'])
+             prompt_item.add_prompt(prompt=prompt, update_item=False)
+         return prompt_item
+
+     def get_assistant_messages(self, annotations: entities.AnnotationCollection):
+         """
+         Get all the annotations in the item for the assistant messages
+         """
+         # clearing the assistant prompts from previous annotations that might not belong
+         self.assistant_prompts = dict()
+         for annotation in annotations:
+             prompt_id = annotation.metadata.get('system', dict()).get('promptId', None)
+             if annotation.type == 'ref_image':
+                 prompt = Prompt(key=prompt_id)
+                 prompt.add_element(value=annotation.coordinates.get('ref'), mimetype=PromptType.IMAGE)
+                 self.assistant_prompts[annotation.id] = prompt
+             elif annotation.type == 'text':
+                 prompt = Prompt(key=prompt_id)
+                 prompt.add_element(value=annotation.coordinates, mimetype=PromptType.TEXT)
+                 self.assistant_prompts[annotation.id] = prompt
+
+     def get_assistant_prompts(self, model_name):
+         """
+         Get assistant prompts
+         :return:
+         """
+         if self._item is None:
+             logger.warning('Item is not loaded, skipping annotations context')
+             return
+         filters = entities.Filters(resource=entities.FiltersResource.ANNOTATION)
+         filters.add(field='metadata.user.model.name', values=model_name)
+         annotations = self._item.annotations.list(filters=filters)
+         self.get_assistant_messages(annotations=annotations)

      def to_json(self):
          """
@@ -69,36 +226,82 @@ class PromptItem:
          prompts_json = {
              "shebang": "dataloop",
              "metadata": {
-                 "dltype": self.type
+                 "dltype": 'prompt'
              },
              "prompts": {}
          }
          for prompt in self.prompts:
              for prompt_key, prompt_values in prompt.to_json().items():
                  prompts_json["prompts"][prompt_key] = prompt_values
-         return prompts_json
-
-     @classmethod
-     def from_json(cls, _json):
-         inst = cls(name='dummy')
-         for prompt_key, prompt_values in _json["prompts"].items():
-             prompt = Prompt(key=prompt_key)
-             for val in prompt_values:
-                 prompt.add(mimetype=val['mimetype'], value=val['value'])
-             inst.prompts.append(prompt)
-         return inst
+                 prompts_json["prompts"][prompt_key].append({'metadata'})

-     def to_bytes_io(self):
-         byte_io = io.BytesIO()
-         byte_io.name = self.name
-         byte_io.write(json.dumps(self.to_json()).encode())
-         byte_io.seek(0)
-         return byte_io
+         return prompts_json

-     def add(self, prompt):
+     def add_prompt(self, prompt: Prompt, update_item=True):
          """
          add a prompt to the prompt item
          prompt: a dictionary. keys are prompt message id, values are prompt messages
          responses: a list of annotations representing responses to the prompt
          """
          self.prompts.append(prompt)
+         if update_item is True:
+             if self._item is not None:
+                 self._item._Item__update_item_binary(_json=self.to_json())
+             else:
+                 logger.warning('Item is not loaded, skipping upload')
+
+     def messages(self, model_name=None):
+         """
+         return a list of messages in the prompt item
+         messages are returned following the openai SDK format
+         """
+         if model_name is not None:
+             self.get_assistant_prompts(model_name=model_name)
+         else:
+             logger.warning('Model name is not provided, skipping assistant prompts')
+
+         all_prompts_messages = dict()
+         for prompt in self.prompts:
+             if prompt.key not in all_prompts_messages:
+                 all_prompts_messages[prompt.key] = list()
+             prompt_messages, prompt_key = prompt.messages()
+             messages = {
+                 'role': prompt.metadata.get('role', 'user'),
+                 'content': prompt_messages
+             }
+             all_prompts_messages[prompt.key].append(messages)
+
+         for ann_id, prompt in self.assistant_prompts.items():
+             if prompt.key not in all_prompts_messages:
+                 logger.warning(f'Prompt key {prompt.key} is not found in the user prompts, skipping Assistant prompt')
+                 continue
+             prompt_messages, prompt_key = prompt.messages()
+             assistant_messages = {
+                 'role': 'assistant',
+                 'content': prompt_messages
+             }
+             all_prompts_messages[prompt.key].append(assistant_messages)
+         res = list()
+         for prompts in all_prompts_messages.values():
+             for prompt in prompts:
+                 res.append(prompt)
+         return res
+
+     def add_responses(self, annotation: entities.BaseAnnotationDefinition, model: entities.Model):
+         """
+         Add an annotation to the prompt item
+         :param annotation: Annotation object
+         :param model: Model object
+         """
+         if self._item is None:
+             raise ValueError('Item is not loaded, cannot add annotation')
+         annotation_collection = entities.AnnotationCollection()
+         annotation_collection.add(annotation_definition=annotation,
+                                   prompt_id=self.prompts[-1].key,
+                                   model_info={
+                                       'name': model.name,
+                                       'model_id': model.id,
+                                       'confidence': 1.0
+                                   })
+         annotations = self._item.annotations.upload(annotations=annotation_collection)
+         self.get_assistant_messages(annotations=annotations)
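Taken together, the prompt_item.py changes replace Prompt.add with add_element, attach a role to every prompt, and let a PromptItem round-trip through a platform item instead of the removed to_bytes_io/from_json pair. A minimal usage sketch against the new API, assuming an authenticated dtlpy session; the ids and model name below are placeholders, not real values:

import dtlpy as dl
from dtlpy.entities.prompt_item import Prompt, PromptItem, PromptType

# build a user prompt with one text element and one image element
prompt = Prompt(key='prompt-1', role='user')
prompt.add_element(value='What is shown in this image?', mimetype=PromptType.TEXT)
prompt.add_element(value='<dataloop-item-id-or-stream-url>', mimetype=PromptType.IMAGE)

# load an existing prompt item and append the prompt; update_item=True
# (the default) re-uploads the item's json binary to the platform
item = dl.items.get(item_id='<prompt-item-id>')
prompt_item = PromptItem.from_item(item=item)
prompt_item.add_prompt(prompt=prompt)

# OpenAI-style messages; passing model_name also folds that model's
# annotations back in as 'assistant' messages
messages = prompt_item.messages(model_name='<model-name>')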
dtlpy/entities/service.py CHANGED
@@ -370,11 +370,6 @@ class Service(entities.BaseEntity):
      def package(self):
          if self._package is None:
              try:
-                 self._package = repositories.Packages(client_api=self._client_api).get(package_id=self.package_id,
-                                                                                        fetch=None,
-                                                                                        log_error=False)
-                 assert isinstance(self._package, entities.Package)
-             except:
                  dpk_id = None
                  dpk_version = None
                  if self.app and isinstance(self.app, dict):
@@ -389,6 +384,11 @@ class Service(entities.BaseEntity):
                                                  version=dpk_version)

                  assert isinstance(self._package, entities.Dpk)
+             except:
+                 self._package = repositories.Packages(client_api=self._client_api).get(package_id=self.package_id,
+                                                                                        fetch=None,
+                                                                                        log_error=False)
+                 assert isinstance(self._package, entities.Package)
          return self._package

      @property
dtlpy/ml/base_model_adapter.py CHANGED
@@ -110,7 +110,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):

          :param local_path: `str` directory path in local FileSystem
          """
-         raise NotImplementedError("Please implement 'load' method in {}".format(self.__class__.__name__))
+         raise NotImplementedError("Please implement `load` method in {}".format(self.__class__.__name__))

      def save(self, local_path, **kwargs):
          """ saves configuration and weights locally
@@ -121,7 +121,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):

          :param local_path: `str` directory path in local FileSystem
          """
-         raise NotImplementedError("Please implement 'save' method in {}".format(self.__class__.__name__))
+         raise NotImplementedError("Please implement `save` method in {}".format(self.__class__.__name__))

      def train(self, data_path, output_path, **kwargs):
          """
@@ -133,27 +133,27 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
          :param data_path: `str` local File System path to where the data was downloaded and converted at
          :param output_path: `str` local File System path where to dump training mid-results (checkpoints, logs...)
          """
-         raise NotImplementedError("Please implement 'train' method in {}".format(self.__class__.__name__))
+         raise NotImplementedError("Please implement `train` method in {}".format(self.__class__.__name__))

      def predict(self, batch, **kwargs):
-         """ Model inference (predictions) on batch of images
+         """ Model inference (predictions) on batch of items

          Virtual method - need to implement

-         :param batch: `np.ndarray`
+         :param batch: output of the `prepare_item_func` func
          :return: `list[dl.AnnotationCollection]` each collection is per each image / item in the batch
          """
-         raise NotImplementedError("Please implement 'predict' method in {}".format(self.__class__.__name__))
+         raise NotImplementedError("Please implement `predict` method in {}".format(self.__class__.__name__))

-     def extract_features(self, batch, **kwargs):
-         """ Extract model features on batch of images
+     def embed(self, batch, **kwargs):
+         """ Extract model embeddings on batch of items

          Virtual method - need to implement

-         :param batch: `np.ndarray`
-         :return: `list[list]` each feature is per each image / item in the batch
+         :param batch: output of the `prepare_item_func` func
+         :return: `list[list]` a feature vector per each item in the batch
          """
-         raise NotImplementedError("Please implement 'extract_features' method in {}".format(self.__class__.__name__))
+         raise NotImplementedError("Please implement `embed` method in {}".format(self.__class__.__name__))

      def evaluate(self, model: entities.Model, dataset: entities.Dataset, filters: entities.Filters) -> entities.Model:
          """
@@ -187,7 +187,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
          :param data_path: `str` local File System directory path where we already downloaded the data from dataloop platform
          :return:
          """
-         raise NotImplementedError("Please implement 'convert_from_dtlpy' method in {}".format(self.__class__.__name__))
+         raise NotImplementedError("Please implement `convert_from_dtlpy` method in {}".format(self.__class__.__name__))

      #################
      # DTLPY METHODS #
@@ -265,14 +265,26 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
          self.logger.debug("Downloading subset {!r} of {}".format(subset,
                                                                   self.model_entity.dataset.name))

-         if self.configuration.get("include_model_annotations", False):
-             annotation_filters = None
-         else:
+         if self.model_entity.output_type is not None:
+             if self.model_entity.output_type in [entities.AnnotationType.SEGMENTATION,
+                                                  entities.AnnotationType.POLYGON]:
+                 model_output_types = [entities.AnnotationType.SEGMENTATION, entities.AnnotationType.POLYGON]
+             else:
+                 model_output_types = [self.model_entity.output_type]
              annotation_filters = entities.Filters(
+                 field=entities.FiltersKnownFields.TYPE,
+                 values=model_output_types,
+                 resource=entities.FiltersResource.ANNOTATION,
+                 operator=entities.FiltersOperations.IN
+             )
+         else:
+             annotation_filters = entities.Filters(resource=entities.FiltersResource.ANNOTATION)
+
+         if not self.configuration.get("include_model_annotations", False):
+             annotation_filters.add(
                  field="metadata.system.model.name",
                  values=False,
-                 operator=entities.FiltersOperations.EXISTS,
-                 resource=entities.FiltersResource.ANNOTATION
+                 operator=entities.FiltersOperations.EXISTS
              )

          ret_list = dataset.items.download(filters=filters,
@@ -396,10 +408,10 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
          pool.shutdown()
          return items, annotations

-     @entities.Package.decorators.function(display_name='Extract Feature',
+     @entities.Package.decorators.function(display_name='Embed Items',
                                            inputs={'items': 'Item[]'},
                                            outputs={'items': 'Item[]', 'features': '[]'})
-     def extract_item_features(self, items: list, upload_features=True, batch_size=None, **kwargs):
+     def embed_items(self, items: list, upload_features=True, batch_size=None, **kwargs):
          """
          Extract feature from an input list of items (or single) and return the items and the feature vector.

@@ -414,17 +426,18 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
          input_type = self.model_entity.input_type
          self.logger.debug(
              "Predicting {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))
-         pool = ThreadPoolExecutor(max_workers=16)

-         vectors = list()
-         feature_set_name = self.configuration.get('featureSetName', self.model_entity.name)
-         try:
-             feature_set = self.model_entity.project.feature_sets.get(feature_set_name)
-             logger.info(f'Feature Set found! name: {feature_set_name}')
-         except exceptions.NotFound as e:
+         # Search for existing feature set for this model id
+         filters = entities.Filters(field='modelId',
+                                    values=self.model_entity.id,
+                                    resource=entities.FiltersResource.FEATURE_SET)
+         pages = self.model_entity.project.feature_sets.list(filters)
+         if pages.items_count == 0:
+             feature_set_name = self.configuration.get('featureSetName', self.model_entity.name)
              logger.info('Feature Set not found. creating... ')
              feature_set = self.model_entity.project.feature_sets.create(name=feature_set_name,
                                                                          entity_type=entities.FeatureEntityType.ITEM,
+                                                                          model_id=self.model_entity.id,
                                                                          project_id=self.model_entity.project_id,
                                                                          set_type=self.model_entity.name,
                                                                          size=self.configuration.get('embeddings_size',
@@ -433,10 +446,16 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
              self.model_entity.configuration['featureSetName'] = feature_set_name
              self.model_entity.update()
              logger.info(f'Feature Set created! name: {feature_set.name}, id: {feature_set.id}')
+         elif pages.items_count > 1:
+             raise ValueError(
+                 f'More than one feature set for model. model_id: {self.model_entity.id}, feature_sets_ids: {[f.id for f in pages.all()]}')
+         else:
+             feature_set = pages.items[0]
+             logger.info(f'Feature Set found! name: {feature_set.name}, id: {feature_set.id}')

-         feature_set_id = feature_set.id
-         project_id = self.model_entity.project_id
-
+         # upload the feature vectors
+         pool = ThreadPoolExecutor(max_workers=16)
+         vectors = list()
          for i_batch in tqdm.tqdm(range(0, len(items), batch_size),
                                   desc='predicting',
                                   unit='bt',
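Feature-set resolution now keys off the model id instead of the configured featureSetName, and errors out when a model owns more than one set. A hedged sketch of the equivalent standalone lookup; the project and model ids are placeholders:

import dtlpy as dl

project = dl.projects.get(project_id='<project-id>')
model = project.models.get(model_id='<model-id>')

# the same DQL the adapter now runs before deciding to create a feature set
filters = dl.Filters(field='modelId',
                     values=model.id,
                     resource=dl.FiltersResource.FEATURE_SET)
pages = project.feature_sets.list(filters)
if pages.items_count == 1:
    feature_set = pages.items[0]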
@@ -444,24 +463,61 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
                                   file=sys.stdout):
              batch_items = items[i_batch: i_batch + batch_size]
              batch = list(pool.map(self.prepare_item_func, batch_items))
-             batch_vectors = self.extract_features(batch, **kwargs)
-             batch_features = list()
+             batch_vectors = self.embed(batch, **kwargs)
+             vectors.extend(batch_vectors)
              if upload_features is True:
                  self.logger.debug(
                      "Uploading items' feature vectors for model {!r}.".format(self.model_entity.name))
                  try:
-                     batch_features = list(pool.map(partial(self._upload_model_features,
-                                                            feature_set_id,
-                                                            project_id),
-                                                    batch_items,
-                                                    batch_vectors))
+                     _ = list(pool.map(partial(self._upload_model_features,
+                                               feature_set.id,
+                                               self.model_entity.project_id),
+                                       batch_items,
+                                       batch_vectors))
                  except Exception as err:
                      self.logger.exception("Failed to upload feature vectors to items.")

-             vectors.extend(batch_features)
          pool.shutdown()
          return items, vectors

+     @entities.Package.decorators.function(display_name='Embed Dataset with DQL',
+                                           inputs={'dataset': 'Dataset',
+                                                   'filters': 'Json'})
+     def embed_dataset(self,
+                       dataset: entities.Dataset,
+                       filters: entities.Filters = None,
+                       upload_features=True,
+                       batch_size=None,
+                       **kwargs):
+         """
+         Extract feature from all items given
+
+         :param dataset: Dataset entity to predict
+         :param filters: Filters entity for a filtering before predicting
+         :param upload_features: `bool` uploads the features back to the given items
+         :param batch_size: `int` size of batch to run a single inference
+
+         :return: `bool` indicating if the prediction process completed successfully
+         """
+         if batch_size is None:
+             batch_size = self.configuration.get('batch_size', 4)
+
+         self.logger.debug("Creating embedings for dataset (name:{}, id:{}, using batch size {}".format(dataset.name,
+                                                                                                        dataset.id,
+                                                                                                        batch_size))
+         if not filters:
+             filters = entities.Filters()
+         if filters is not None and isinstance(filters, dict):
+             filters = entities.Filters(custom_filter=filters)
+         pages = dataset.items.list(filters=filters, page_size=batch_size)
+         # Item type is 'file' only, can be deleted if default filters are added to custom filters
+         items = [item for page in pages for item in page if item.type == 'file']
+         self.embed_items(items=items,
+                          upload_features=upload_features,
+                          batch_size=batch_size,
+                          **kwargs)
+         return True
+
      @entities.Package.decorators.function(display_name='Predict Dataset with DQL',
                                            inputs={'dataset': 'Dataset',
                                                    'filters': 'Json'})
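With the decorated embed_items / embed_dataset in place, the embedding flow can be driven directly from an adapter instance. A short usage sketch reusing the SimpleAdapter from the sketch above; the ids and filter value are placeholders:

import dtlpy as dl

project = dl.projects.get(project_id='<project-id>')
model = project.models.get(model_id='<model-id>')
dataset = project.datasets.get(dataset_id='<dataset-id>')

adapter = SimpleAdapter(model_entity=model)
adapter.load_from_model()  # downloads artifacts and calls the adapter's load()

# embed only the items matching a DQL filter; vectors go to the model's feature set
filters = dl.Filters(field='dir', values='/validation')
adapter.embed_dataset(dataset=dataset, filters=filters, upload_features=True)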
@@ -481,9 +537,12 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
          :param cleanup: `bool` if set removes existing predictions with the same package-model name (default: False)
          :param batch_size: `int` size of batch to run a single inference

-         :return: `List[dl.AnnotationCollection]` where all annotation in the collection are of type package.output_type
-         and has prediction fields (model_info)
+         :return: `bool` indicating if the prediction process completed successfully
          """
+
+         if batch_size is None:
+             batch_size = self.configuration.get('batch_size', 4)
+
          self.logger.debug("Predicting dataset (name:{}, id:{}, using batch size {}".format(dataset.name,
                                                                                             dataset.id,
                                                                                             batch_size))
@@ -492,9 +551,10 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
          if filters is not None and isinstance(filters, dict):
              filters = entities.Filters(custom_filter=filters)
          pages = dataset.items.list(filters=filters, page_size=batch_size)
-         items = [item for item in pages.all() if item.type == 'file']
+         # Item type is 'file' only, can be deleted if default filters are added to custom filters
+         items = [item for page in pages for item in page if item.type == 'file']
          self.predict_items(items=items,
-                            with_upload=with_upload,
+                            upload_annotations=with_upload,
                             cleanup=cleanup,
                             batch_size=batch_size,
                             **kwargs)
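Note the pass-through rename at the call site: predict_dataset still accepts with_upload, but now forwards it to predict_items as upload_annotations. A short call sketch continuing the placeholders above:

# predict_dataset pages through the dataset with the same DQL handling as
# embed_dataset and skips non-file entries
adapter.predict_dataset(dataset=dataset,
                        filters=dl.Filters(field='dir', values='/test'),
                        with_upload=True,  # forwarded as upload_annotations
                        cleanup=False,
                        batch_size=8)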