dtlpy 1.91.37__py3-none-any.whl → 1.92.18__py3-none-any.whl
This diff shows the changes between the two publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
- dtlpy/__init__.py +5 -2
- dtlpy/__version__.py +1 -1
- dtlpy/entities/__init__.py +1 -1
- dtlpy/entities/command.py +3 -2
- dtlpy/entities/dataset.py +52 -2
- dtlpy/entities/feature_set.py +3 -0
- dtlpy/entities/filters.py +2 -2
- dtlpy/entities/item.py +15 -1
- dtlpy/entities/node.py +11 -1
- dtlpy/entities/ontology.py +36 -40
- dtlpy/entities/pipeline.py +20 -1
- dtlpy/entities/pipeline_execution.py +23 -0
- dtlpy/entities/prompt_item.py +240 -37
- dtlpy/entities/service.py +5 -5
- dtlpy/ml/base_model_adapter.py +99 -41
- dtlpy/new_instance.py +80 -9
- dtlpy/repositories/apps.py +56 -10
- dtlpy/repositories/commands.py +10 -2
- dtlpy/repositories/datasets.py +142 -12
- dtlpy/repositories/dpks.py +5 -1
- dtlpy/repositories/feature_sets.py +23 -3
- dtlpy/repositories/models.py +1 -1
- dtlpy/repositories/pipeline_executions.py +53 -0
- dtlpy/repositories/uploader.py +3 -0
- dtlpy/services/api_client.py +59 -3
- {dtlpy-1.91.37.dist-info → dtlpy-1.92.18.dist-info}/METADATA +1 -1
- {dtlpy-1.91.37.dist-info → dtlpy-1.92.18.dist-info}/RECORD +35 -38
- tests/features/environment.py +29 -0
- dtlpy/callbacks/__init__.py +0 -16
- dtlpy/callbacks/piper_progress_reporter.py +0 -29
- dtlpy/callbacks/progress_viewer.py +0 -54
- {dtlpy-1.91.37.data → dtlpy-1.92.18.data}/scripts/dlp +0 -0
- {dtlpy-1.91.37.data → dtlpy-1.92.18.data}/scripts/dlp.bat +0 -0
- {dtlpy-1.91.37.data → dtlpy-1.92.18.data}/scripts/dlp.py +0 -0
- {dtlpy-1.91.37.dist-info → dtlpy-1.92.18.dist-info}/LICENSE +0 -0
- {dtlpy-1.91.37.dist-info → dtlpy-1.92.18.dist-info}/WHEEL +0 -0
- {dtlpy-1.91.37.dist-info → dtlpy-1.92.18.dist-info}/entry_points.txt +0 -0
- {dtlpy-1.91.37.dist-info → dtlpy-1.92.18.dist-info}/top_level.txt +0 -0
dtlpy/entities/prompt_item.py
CHANGED

@@ -1,7 +1,11 @@
 import logging
-import io
 import enum
 import json
+import os.path
+from dtlpy import entities, repositories
+from dtlpy.services.api_client import client as client_api
+import base64
+import requests

 logger = logging.getLogger(name='dtlpy')

@@ -11,27 +15,37 @@ class PromptType(str, enum.Enum):
     IMAGE = 'image/*'
     AUDIO = 'audio/*'
     VIDEO = 'video/*'
+    METADATA = 'metadata'


 class Prompt:
-    def __init__(self, key):
+    def __init__(self, key, role='user'):
         """
         Create a single Prompt. Prompt can contain multiple mimetype elements, e.g. text sentence and an image.
-
         :param key: unique identifier of the prompt in the item
         """
         self.key = key
         self.elements = list()
+        self._items = repositories.Items(client_api=client_api)
+        self.metadata = {'role': role}

-    def
+    def add_element(self, value, mimetype='application/text'):
         """

         :param value: url or string of the input
         :param mimetype: mimetype of the input. options: `text`, `image/*`, `video/*`, `audio/*`
         :return:
         """
-
-
+        allowed_prompt_types = [prompt_type for prompt_type in PromptType]
+        if mimetype not in allowed_prompt_types:
+            raise ValueError(f'Invalid mimetype: {mimetype}. Allowed values: {allowed_prompt_types}')
+        if not isinstance(value, str) and mimetype != PromptType.METADATA:
+            raise ValueError(f'Expected str for Prompt element value, got {type(value)} instead')
+        if mimetype == PromptType.METADATA and isinstance(value, dict):
+            self.metadata.update(value)
+        else:
+            self.elements.append({'mimetype': mimetype,
+                                  'value': value})

     def to_json(self):
         """
@@ -39,26 +53,169 @@ class Prompt:

         :return:
         """
+        elements_json = [
+            {
+                "mimetype": e['mimetype'],
+                "value": e['value'],
+            } for e in self.elements
+        ]
+        elements_json.append({
+            "mimetype": PromptType.METADATA,
+            "value": self.metadata
+        })
         return {
-            self.key:
-                {
-                    "mimetype": e['mimetype'],
-                    "value": e['value']
-                } for e in self.elements
-            ]
+            self.key: elements_json
         }

-
-class PromptItem:
-    def __init__(self, name):
+    def _convert_stream_to_binary(self, image_url: str):
         """
-
+        Convert a stream to binary
+        :param image_url: dataloop image stream url
+        :return: binary object
+        """
+        image_buffer = None
+        if '.' in image_url and 'dataloop.ai' not in image_url:
+            # URL and not DL item stream
+            try:
+                response = requests.get(image_url, stream=True)
+                response.raise_for_status()  # Raise an exception for bad status codes
+
+                # Check for valid image content type
+                if response.headers["Content-Type"].startswith("image/"):
+                    # Read the image data in chunks to avoid loading large images in memory
+                    image_buffer = b"".join(chunk for chunk in response.iter_content(1024))
+            except requests.exceptions.RequestException as e:
+                logger.error(f"Failed to download image from URL: {image_url}, error: {e}")
+
+        elif '.' in image_url and 'stream' in image_url:
+            # DL Stream URL
+            item_id = image_url.split("/stream")[0].split("/items/")[-1]
+            image_buffer = self._items.get(item_id=item_id).download(save_locally=False).getvalue()
+        else:
+            # DL item ID
+            image_buffer = self._items.get(item_id=image_url).download(save_locally=False).getvalue()

-
+        if image_buffer is not None:
+            encoded_image = base64.b64encode(image_buffer).decode()
+        else:
+            logger.error(f'Invalid image url: {image_url}')
+            return None
+
+        return f'data:image/jpeg;base64,{encoded_image}'
+
+    def messages(self):
+        """
+        return a list of messages in the prompt item,
+        messages are returned following the openai SDK format https://platform.openai.com/docs/guides/vision
         """
+        messages = []
+        for element in self.elements:
+            if element['mimetype'] == PromptType.TEXT:
+                data = {
+                    "type": "text",
+                    "text": element['value']
+                }
+                messages.append(data)
+            elif element['mimetype'] == PromptType.IMAGE:
+                image_url = self._convert_stream_to_binary(element['value'])
+                data = {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": image_url
+                    }
+                }
+                messages.append(data)
+            elif element['mimetype'] == PromptType.AUDIO:
+                raise NotImplementedError('Audio prompt is not supported yet')
+            elif element['mimetype'] == PromptType.VIDEO:
+                raise NotImplementedError('Video prompt is not supported yet')
+            else:
+                raise ValueError(f'Invalid mimetype: {element["mimetype"]}')
+        return messages, self.key
+
+
+class PromptItem:
+    def __init__(self, name, item: entities.Item = None):
+        # prompt item name
         self.name = name
-
+        # list of user prompts in the prompt item
         self.prompts = list()
+        # list of assistant (annotations) prompts in the prompt item
+        self.assistant_prompts = dict()
+        # Dataloop Item
+        self._item = None
+
+    @classmethod
+    def from_item(cls, item: entities.Item):
+        """
+        Load a prompt item from the platform
+        :param item : Item object
+        :return: PromptItem object
+        """
+        if 'json' not in item.mimetype or item.system.get('shebang', dict()).get('dltype') != 'prompt':
+            raise ValueError('Expecting a json item with system.shebang.dltype = prompt')
+        # Not using `save_locally=False` to use the from_local_file method
+        item_file_path = item.download()
+        prompt_item = cls.from_local_file(file_path=item_file_path)
+        if os.path.exists(item_file_path):
+            os.remove(item_file_path)
+        prompt_item._item = item
+        return prompt_item
+
+    @classmethod
+    def from_local_file(cls, file_path):
+        """
+        Create a new prompt item from a file
+        :param file_path: path to the file
+        :return: PromptItem object
+        """
+        if os.path.exists(file_path) is False:
+            raise FileNotFoundError(f'File does not exists: {file_path}')
+        if 'json' not in os.path.splitext(file_path)[-1]:
+            raise ValueError(f'Expected path to json item, got {os.path.splitext(file_path)[-1]}')
+        prompt_item = cls(name=file_path)
+        with open(file_path, 'r') as f:
+            data = json.load(f)
+        for prompt_key, prompt_values in data.get('prompts', dict()).items():
+            prompt = Prompt(key=prompt_key)
+            for val in prompt_values:
+                if val['mimetype'] == PromptType.METADATA:
+                    _ = val.pop('mimetype')
+                    prompt.add_element(value=val, mimetype=PromptType.METADATA)
+                else:
+                    prompt.add_element(mimetype=val['mimetype'], value=val['value'])
+            prompt_item.add_prompt(prompt=prompt, update_item=False)
+        return prompt_item
+
+    def get_assistant_messages(self, annotations: entities.AnnotationCollection):
+        """
+        Get all the annotations in the item for the assistant messages
+        """
+        # clearing the assistant prompts from previous annotations that might not belong
+        self.assistant_prompts = dict()
+        for annotation in annotations:
+            prompt_id = annotation.metadata.get('system', dict()).get('promptId', None)
+            if annotation.type == 'ref_image':
+                prompt = Prompt(key=prompt_id)
+                prompt.add_element(value=annotation.coordinates.get('ref'), mimetype=PromptType.IMAGE)
+                self.assistant_prompts[annotation.id] = prompt
+            elif annotation.type == 'text':
+                prompt = Prompt(key=prompt_id)
+                prompt.add_element(value=annotation.coordinates, mimetype=PromptType.TEXT)
+                self.assistant_prompts[annotation.id] = prompt
+
+    def get_assistant_prompts(self, model_name):
+        """
+        Get assistant prompts
+        :return:
+        """
+        if self._item is None:
+            logger.warning('Item is not loaded, skipping annotations context')
+            return
+        filters = entities.Filters(resource=entities.FiltersResource.ANNOTATION)
+        filters.add(field='metadata.user.model.name', values=model_name)
+        annotations = self._item.annotations.list(filters=filters)
+        self.get_assistant_messages(annotations=annotations)

     def to_json(self):
         """
@@ -69,36 +226,82 @@ class PromptItem:
         prompts_json = {
             "shebang": "dataloop",
             "metadata": {
-                "dltype":
+                "dltype": 'prompt'
             },
             "prompts": {}
         }
         for prompt in self.prompts:
             for prompt_key, prompt_values in prompt.to_json().items():
                 prompts_json["prompts"][prompt_key] = prompt_values
-
-
-    @classmethod
-    def from_json(cls, _json):
-        inst = cls(name='dummy')
-        for prompt_key, prompt_values in _json["prompts"].items():
-            prompt = Prompt(key=prompt_key)
-            for val in prompt_values:
-                prompt.add(mimetype=val['mimetype'], value=val['value'])
-            inst.prompts.append(prompt)
-        return inst
+                prompts_json["prompts"][prompt_key].append({'metadata'})

-
-        byte_io = io.BytesIO()
-        byte_io.name = self.name
-        byte_io.write(json.dumps(self.to_json()).encode())
-        byte_io.seek(0)
-        return byte_io
+        return prompts_json

-    def
+    def add_prompt(self, prompt: Prompt, update_item=True):
         """
         add a prompt to the prompt item
         prompt: a dictionary. keys are prompt message id, values are prompt messages
         responses: a list of annotations representing responses to the prompt
         """
         self.prompts.append(prompt)
+        if update_item is True:
+            if self._item is not None:
+                self._item._Item__update_item_binary(_json=self.to_json())
+            else:
+                logger.warning('Item is not loaded, skipping upload')
+
+    def messages(self, model_name=None):
+        """
+        return a list of messages in the prompt item
+        messages are returned following the openai SDK format
+        """
+        if model_name is not None:
+            self.get_assistant_prompts(model_name=model_name)
+        else:
+            logger.warning('Model name is not provided, skipping assistant prompts')
+
+        all_prompts_messages = dict()
+        for prompt in self.prompts:
+            if prompt.key not in all_prompts_messages:
+                all_prompts_messages[prompt.key] = list()
+            prompt_messages, prompt_key = prompt.messages()
+            messages = {
+                'role': prompt.metadata.get('role', 'user'),
+                'content': prompt_messages
+            }
+            all_prompts_messages[prompt.key].append(messages)
+
+        for ann_id, prompt in self.assistant_prompts.items():
+            if prompt.key not in all_prompts_messages:
+                logger.warning(f'Prompt key {prompt.key} is not found in the user prompts, skipping Assistant prompt')
+                continue
+            prompt_messages, prompt_key = prompt.messages()
+            assistant_messages = {
+                'role': 'assistant',
+                'content': prompt_messages
+            }
+            all_prompts_messages[prompt.key].append(assistant_messages)
+        res = list()
+        for prompts in all_prompts_messages.values():
+            for prompt in prompts:
+                res.append(prompt)
+        return res
+
+    def add_responses(self, annotation: entities.BaseAnnotationDefinition, model: entities.Model):
+        """
+        Add an annotation to the prompt item
+        :param annotation: Annotation object
+        :param model: Model object
+        """
+        if self._item is None:
+            raise ValueError('Item is not loaded, cannot add annotation')
+        annotation_collection = entities.AnnotationCollection()
+        annotation_collection.add(annotation_definition=annotation,
+                                  prompt_id=self.prompts[-1].key,
+                                  model_info={
+                                      'name': model.name,
+                                      'model_id': model.id,
+                                      'confidence': 1.0
+                                  })
+        annotations = self._item.annotations.upload(annotations=annotation_collection)
+        self.get_assistant_messages(annotations=annotations)
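
The rewrite above replaces the old `from_json`/BytesIO flow with a role-aware `Prompt`, platform loading via `from_item`/`from_local_file`, and OpenAI-style `messages()`. A minimal usage sketch of the new API, assuming a configured dtlpy environment; the image URL, item id, and model name are placeholders, not values from this diff:

import dtlpy as dl
from dtlpy.entities.prompt_item import Prompt, PromptItem, PromptType

# Build a single user prompt holding a text element and an image element.
prompt = Prompt(key='prompt-1', role='user')
prompt.add_element(value='What is shown in this picture?', mimetype=PromptType.TEXT)
prompt.add_element(value='https://example.com/picture.jpg', mimetype=PromptType.IMAGE)  # placeholder URL

# Collect prompts in a prompt item; update_item=False because no platform item is attached yet.
prompt_item = PromptItem(name='my-prompt-item.json')
prompt_item.add_prompt(prompt=prompt, update_item=False)
print(prompt_item.to_json())  # platform JSON, with a trailing metadata element per prompt

# Loading an existing prompt item from the platform instead (placeholder ids):
# item = dl.items.get(item_id='<item-id>')
# prompt_item = PromptItem.from_item(item=item)
# messages = prompt_item.messages(model_name='<model-name>')  # OpenAI-style role/content dicts
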
dtlpy/entities/service.py
CHANGED

@@ -370,11 +370,6 @@ class Service(entities.BaseEntity):
     def package(self):
         if self._package is None:
             try:
-                self._package = repositories.Packages(client_api=self._client_api).get(package_id=self.package_id,
-                                                                                       fetch=None,
-                                                                                       log_error=False)
-                assert isinstance(self._package, entities.Package)
-            except:
                 dpk_id = None
                 dpk_version = None
                 if self.app and isinstance(self.app, dict):
@@ -389,6 +384,11 @@ class Service(entities.BaseEntity):
                                                                version=dpk_version)

                 assert isinstance(self._package, entities.Dpk)
+            except:
+                self._package = repositories.Packages(client_api=self._client_api).get(package_id=self.package_id,
+                                                                                       fetch=None,
+                                                                                       log_error=False)
+                assert isinstance(self._package, entities.Package)
         return self._package

     @property
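
The `package` property change above reverses the lookup order: the service now resolves its backing DPK (the new app model) first and falls back to the legacy `Packages` repository only on failure. A schematic of the new control flow; `lookup_dpk` and `lookup_legacy_package` are hypothetical helpers standing in for the inlined lookup code, not real dtlpy functions:

def resolve_package(service):
    # Sketch of the reordered Service.package resolution, not the actual implementation.
    try:
        package = lookup_dpk(service)  # hypothetical: derive dpk_id/dpk_version from service.app, fetch the Dpk
    except Exception:
        package = lookup_legacy_package(service)  # hypothetical: fetch the legacy Package by package_id
    return package
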
dtlpy/ml/base_model_adapter.py
CHANGED

@@ -110,7 +110,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):

         :param local_path: `str` directory path in local FileSystem
         """
-        raise NotImplementedError("Please implement
+        raise NotImplementedError("Please implement `load` method in {}".format(self.__class__.__name__))

     def save(self, local_path, **kwargs):
         """ saves configuration and weights locally
@@ -121,7 +121,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):

         :param local_path: `str` directory path in local FileSystem
         """
-        raise NotImplementedError("Please implement
+        raise NotImplementedError("Please implement `save` method in {}".format(self.__class__.__name__))

     def train(self, data_path, output_path, **kwargs):
         """
@@ -133,27 +133,27 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         :param data_path: `str` local File System path to where the data was downloaded and converted at
         :param output_path: `str` local File System path where to dump training mid-results (checkpoints, logs...)
         """
-        raise NotImplementedError("Please implement
+        raise NotImplementedError("Please implement `train` method in {}".format(self.__class__.__name__))

     def predict(self, batch, **kwargs):
-        """ Model inference (predictions) on batch of
+        """ Model inference (predictions) on batch of items

         Virtual method - need to implement

-        :param batch: `
+        :param batch: output of the `prepare_item_func` func
         :return: `list[dl.AnnotationCollection]` each collection is per each image / item in the batch
         """
-        raise NotImplementedError("Please implement
+        raise NotImplementedError("Please implement `predict` method in {}".format(self.__class__.__name__))

-    def
-    """ Extract model
+    def embed(self, batch, **kwargs):
+        """ Extract model embeddings on batch of items

         Virtual method - need to implement

-        :param batch: `
-        :return: `list[list]`
+        :param batch: output of the `prepare_item_func` func
+        :return: `list[list]` a feature vector per each item in the batch
         """
-        raise NotImplementedError("Please implement
+        raise NotImplementedError("Please implement `embed` method in {}".format(self.__class__.__name__))

     def evaluate(self, model: entities.Model, dataset: entities.Dataset, filters: entities.Filters) -> entities.Model:
         """
@@ -187,7 +187,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         :param data_path: `str` local File System directory path where we already downloaded the data from dataloop platform
         :return:
         """
-        raise NotImplementedError("Please implement
+        raise NotImplementedError("Please implement `convert_from_dtlpy` method in {}".format(self.__class__.__name__))

     #################
     # DTLPY METHODS #
@@ -265,14 +265,26 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         self.logger.debug("Downloading subset {!r} of {}".format(subset,
                                                                  self.model_entity.dataset.name))

-        if self.
-
-
+        if self.model_entity.output_type is not None:
+            if self.model_entity.output_type in [entities.AnnotationType.SEGMENTATION,
+                                                 entities.AnnotationType.POLYGON]:
+                model_output_types = [entities.AnnotationType.SEGMENTATION, entities.AnnotationType.POLYGON]
+            else:
+                model_output_types = [self.model_entity.output_type]
             annotation_filters = entities.Filters(
+                field=entities.FiltersKnownFields.TYPE,
+                values=model_output_types,
+                resource=entities.FiltersResource.ANNOTATION,
+                operator=entities.FiltersOperations.IN
+            )
+        else:
+            annotation_filters = entities.Filters(resource=entities.FiltersResource.ANNOTATION)
+
+        if not self.configuration.get("include_model_annotations", False):
+            annotation_filters.add(
                 field="metadata.system.model.name",
                 values=False,
-                operator=entities.FiltersOperations.EXISTS
-                resource=entities.FiltersResource.ANNOTATION
+                operator=entities.FiltersOperations.EXISTS
             )

         ret_list = dataset.items.download(filters=filters,
@@ -396,10 +408,10 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         pool.shutdown()
         return items, annotations

-    @entities.Package.decorators.function(display_name='
+    @entities.Package.decorators.function(display_name='Embed Items',
                                           inputs={'items': 'Item[]'},
                                           outputs={'items': 'Item[]', 'features': '[]'})
-    def
+    def embed_items(self, items: list, upload_features=True, batch_size=None, **kwargs):
         """
         Extract feature from an input list of items (or single) and return the items and the feature vector.

@@ -414,17 +426,18 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         input_type = self.model_entity.input_type
         self.logger.debug(
             "Predicting {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))
-        pool = ThreadPoolExecutor(max_workers=16)

-
-
-
-
-
-
+        # Search for existing feature set for this model id
+        filters = entities.Filters(field='modelId',
+                                   values=self.model_entity.id,
+                                   resource=entities.FiltersResource.FEATURE_SET)
+        pages = self.model_entity.project.feature_sets.list(filters)
+        if pages.items_count == 0:
+            feature_set_name = self.configuration.get('featureSetName', self.model_entity.name)
             logger.info('Feature Set not found. creating... ')
             feature_set = self.model_entity.project.feature_sets.create(name=feature_set_name,
                                                                         entity_type=entities.FeatureEntityType.ITEM,
+                                                                        model_id=self.model_entity.id,
                                                                         project_id=self.model_entity.project_id,
                                                                         set_type=self.model_entity.name,
                                                                         size=self.configuration.get('embeddings_size',
@@ -433,10 +446,16 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
             self.model_entity.configuration['featureSetName'] = feature_set_name
             self.model_entity.update()
             logger.info(f'Feature Set created! name: {feature_set.name}, id: {feature_set.id}')
+        elif pages.items_count > 1:
+            raise ValueError(
+                f'More than one feature set for model. model_id: {self.model_entity.id}, feature_sets_ids: {[f.id for f in pages.all()]}')
+        else:
+            feature_set = pages.items[0]
+            logger.info(f'Feature Set found! name: {feature_set.name}, id: {feature_set.id}')

-
-
-
+        # upload the feature vectors
+        pool = ThreadPoolExecutor(max_workers=16)
+        vectors = list()
         for i_batch in tqdm.tqdm(range(0, len(items), batch_size),
                                  desc='predicting',
                                  unit='bt',
@@ -444,24 +463,60 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
                                  file=sys.stdout):
             batch_items = items[i_batch: i_batch + batch_size]
             batch = list(pool.map(self.prepare_item_func, batch_items))
-            batch_vectors = self.
-
+            batch_vectors = self.embed(batch, **kwargs)
+            vectors.extend(batch_vectors)
             if upload_features is True:
                 self.logger.debug(
                     "Uploading items' feature vectors for model {!r}.".format(self.model_entity.name))
                 try:
-
-
-
-
-
+                    _ = list(pool.map(partial(self._upload_model_features,
+                                              feature_set.id,
+                                              self.model_entity.project_id),
+                                      batch_items,
+                                      batch_vectors))
                 except Exception as err:
                     self.logger.exception("Failed to upload feature vectors to items.")

-            vectors.extend(batch_features)
         pool.shutdown()
         return items, vectors

+    @entities.Package.decorators.function(display_name='Embed Dataset with DQL',
+                                          inputs={'dataset': 'Dataset',
+                                                  'filters': 'Json'})
+    def embed_dataset(self,
+                      dataset: entities.Dataset,
+                      filters: entities.Filters = None,
+                      upload_features=True,
+                      batch_size=None,
+                      **kwargs):
+        """
+        Extract feature from all items given
+
+        :param dataset: Dataset entity to predict
+        :param filters: Filters entity for a filtering before predicting
+        :param upload_features: `bool` uploads the features back to the given items
+        :param batch_size: `int` size of batch to run a single inference
+
+        :return: `bool` indicating if the prediction process completed successfully
+        """
+        if batch_size is None:
+            batch_size = self.configuration.get('batch_size', 4)
+
+        self.logger.debug("Creating embedings for dataset (name:{}, id:{}, using batch size {}".format(dataset.name,
+                                                                                                       dataset.id,
+                                                                                                       batch_size))
+        if not filters:
+            filters = entities.Filters()
+        if filters is not None and isinstance(filters, dict):
+            filters = entities.Filters(custom_filter=filters)
+        pages = dataset.items.list(filters=filters, page_size=batch_size)
+        items = [item for page in pages for item in page]
+        self.embed_items(items=items,
+                         upload_features=upload_features,
+                         batch_size=batch_size,
+                         **kwargs)
+        return True
+
     @entities.Package.decorators.function(display_name='Predict Dataset with DQL',
                                           inputs={'dataset': 'Dataset',
                                                   'filters': 'Json'})
@@ -481,9 +536,12 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         :param cleanup: `bool` if set removes existing predictions with the same package-model name (default: False)
         :param batch_size: `int` size of batch to run a single inference

-        :return: `
-        and has prediction fields (model_info)
+        :return: `bool` indicating if the prediction process completed successfully
         """
+
+        if batch_size is None:
+            batch_size = self.configuration.get('batch_size', 4)
+
         self.logger.debug("Predicting dataset (name:{}, id:{}, using batch size {}".format(dataset.name,
                                                                                            dataset.id,
                                                                                            batch_size))
@@ -492,9 +550,9 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
         if filters is not None and isinstance(filters, dict):
             filters = entities.Filters(custom_filter=filters)
         pages = dataset.items.list(filters=filters, page_size=batch_size)
-        items = [item for
+        items = [item for page in pages for item in page]
         self.predict_items(items=items,
-
+                           upload_annotations=with_upload,
                            cleanup=cleanup,
                            batch_size=batch_size,
                            **kwargs)