dtlpy 1.92.19__py3-none-any.whl → 1.93.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtlpy/__init__.py +8 -2
- dtlpy/__version__.py +1 -1
- dtlpy/entities/__init__.py +4 -1
- dtlpy/entities/app.py +5 -1
- dtlpy/entities/compute.py +374 -0
- dtlpy/entities/dpk.py +1 -0
- dtlpy/entities/filters.py +1 -0
- dtlpy/entities/item.py +5 -6
- dtlpy/entities/model.py +3 -3
- dtlpy/entities/prompt_item.py +257 -107
- dtlpy/ml/base_model_adapter.py +57 -14
- dtlpy/repositories/__init__.py +1 -0
- dtlpy/repositories/apps.py +5 -2
- dtlpy/repositories/computes.py +228 -0
- dtlpy/repositories/models.py +2 -5
- {dtlpy-1.92.19.dist-info → dtlpy-1.93.11.dist-info}/METADATA +1 -1
- {dtlpy-1.92.19.dist-info → dtlpy-1.93.11.dist-info}/RECORD +24 -22
- {dtlpy-1.92.19.data → dtlpy-1.93.11.data}/scripts/dlp +0 -0
- {dtlpy-1.92.19.data → dtlpy-1.93.11.data}/scripts/dlp.bat +0 -0
- {dtlpy-1.92.19.data → dtlpy-1.93.11.data}/scripts/dlp.py +0 -0
- {dtlpy-1.92.19.dist-info → dtlpy-1.93.11.dist-info}/LICENSE +0 -0
- {dtlpy-1.92.19.dist-info → dtlpy-1.93.11.dist-info}/WHEEL +0 -0
- {dtlpy-1.92.19.dist-info → dtlpy-1.93.11.dist-info}/entry_points.txt +0 -0
- {dtlpy-1.92.19.dist-info → dtlpy-1.93.11.dist-info}/top_level.txt +0 -0
dtlpy/entities/prompt_item.py
CHANGED
|
@@ -1,11 +1,14 @@
|
|
|
1
|
+
import requests
|
|
1
2
|
import logging
|
|
3
|
+
import base64
|
|
2
4
|
import enum
|
|
3
5
|
import json
|
|
4
|
-
import
|
|
5
|
-
|
|
6
|
+
import io
|
|
7
|
+
import os
|
|
8
|
+
|
|
9
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
10
|
+
from .. import entities, repositories
|
|
6
11
|
from dtlpy.services.api_client import client as client_api
|
|
7
|
-
import base64
|
|
8
|
-
import requests
|
|
9
12
|
|
|
10
13
|
logger = logging.getLogger(name='dtlpy')
|
|
11
14
|
|
|
@@ -26,6 +29,8 @@ class Prompt:
|
|
|
26
29
|
"""
|
|
27
30
|
self.key = key
|
|
28
31
|
self.elements = list()
|
|
32
|
+
# to avoid broken stream of json files - DAT-75653
|
|
33
|
+
client_api.default_headers['x-dl-sanitize'] = '0'
|
|
29
34
|
self._items = repositories.Items(client_api=client_api)
|
|
30
35
|
self.metadata = {'role': role}
|
|
31
36
|
|
|
@@ -39,8 +44,6 @@ class Prompt:
|
|
|
39
44
|
allowed_prompt_types = [prompt_type for prompt_type in PromptType]
|
|
40
45
|
if mimetype not in allowed_prompt_types:
|
|
41
46
|
raise ValueError(f'Invalid mimetype: {mimetype}. Allowed values: {allowed_prompt_types}')
|
|
42
|
-
if not isinstance(value, str) and mimetype != PromptType.METADATA:
|
|
43
|
-
raise ValueError(f'Expected str for Prompt element value, got {type(value)} instead')
|
|
44
47
|
if mimetype == PromptType.METADATA and isinstance(value, dict):
|
|
45
48
|
self.metadata.update(value)
|
|
46
49
|
else:
|
|
@@ -57,7 +60,7 @@ class Prompt:
|
|
|
57
60
|
{
|
|
58
61
|
"mimetype": e['mimetype'],
|
|
59
62
|
"value": e['value'],
|
|
60
|
-
} for e in self.elements
|
|
63
|
+
} for e in self.elements if not e['mimetype'] == PromptType.METADATA
|
|
61
64
|
]
|
|
62
65
|
elements_json.append({
|
|
63
66
|
"mimetype": PromptType.METADATA,
|
|
@@ -135,15 +138,34 @@ class Prompt:
|
|
|
135
138
|
|
|
136
139
|
|
|
137
140
|
class PromptItem:
|
|
138
|
-
def __init__(self, name, item: entities.Item = None):
|
|
141
|
+
def __init__(self, name, item: entities.Item = None, role_mapping=None):
|
|
142
|
+
if role_mapping is None:
|
|
143
|
+
role_mapping = {'user': 'item',
|
|
144
|
+
'assistant': 'annotation'}
|
|
145
|
+
if not isinstance(role_mapping, dict):
|
|
146
|
+
raise ValueError(f'input role_mapping must be dict. type: {type(role_mapping)}')
|
|
147
|
+
self.role_mapping = role_mapping
|
|
139
148
|
# prompt item name
|
|
140
149
|
self.name = name
|
|
141
150
|
# list of user prompts in the prompt item
|
|
142
151
|
self.prompts = list()
|
|
152
|
+
self.assistant_prompts = list()
|
|
143
153
|
# list of assistant (annotations) prompts in the prompt item
|
|
144
|
-
self.assistant_prompts = dict()
|
|
145
154
|
# Dataloop Item
|
|
146
|
-
self.
|
|
155
|
+
self._messages = []
|
|
156
|
+
self._item: entities.Item = item
|
|
157
|
+
self._annotations: entities.AnnotationCollection = None
|
|
158
|
+
if item is not None:
|
|
159
|
+
self._items = item.items
|
|
160
|
+
self.fetch()
|
|
161
|
+
else:
|
|
162
|
+
self._items = repositories.Items(client_api=client_api)
|
|
163
|
+
# to avoid broken stream of json files - DAT-75653
|
|
164
|
+
self._items._client_api.default_headers['x-dl-sanitize'] = '0'
|
|
165
|
+
|
|
166
|
+
@classmethod
|
|
167
|
+
def from_messages(cls, messages: list):
|
|
168
|
+
...
|
|
147
169
|
|
|
148
170
|
@classmethod
|
|
149
171
|
def from_item(cls, item: entities.Item):
|
|
@@ -154,68 +176,67 @@ class PromptItem:
|
|
|
154
176
|
"""
|
|
155
177
|
if 'json' not in item.mimetype or item.system.get('shebang', dict()).get('dltype') != 'prompt':
|
|
156
178
|
raise ValueError('Expecting a json item with system.shebang.dltype = prompt')
|
|
157
|
-
|
|
158
|
-
item_file_path = item.download()
|
|
159
|
-
prompt_item = cls.from_local_file(file_path=item_file_path)
|
|
160
|
-
if os.path.exists(item_file_path):
|
|
161
|
-
os.remove(item_file_path)
|
|
162
|
-
prompt_item._item = item
|
|
163
|
-
return prompt_item
|
|
179
|
+
return cls(name=item.name, item=item)
|
|
164
180
|
|
|
165
181
|
@classmethod
|
|
166
|
-
def from_local_file(cls,
|
|
182
|
+
def from_local_file(cls, filepath):
|
|
167
183
|
"""
|
|
168
184
|
Create a new prompt item from a file
|
|
169
|
-
:param
|
|
185
|
+
:param filepath: path to the file
|
|
170
186
|
:return: PromptItem object
|
|
171
187
|
"""
|
|
172
|
-
if os.path.exists(
|
|
173
|
-
raise FileNotFoundError(f'File does not exists: {
|
|
174
|
-
if 'json' not in os.path.splitext(
|
|
175
|
-
raise ValueError(f'Expected path to json item, got {os.path.splitext(
|
|
176
|
-
prompt_item = cls(name=
|
|
177
|
-
with open(
|
|
188
|
+
if os.path.exists(filepath) is False:
|
|
189
|
+
raise FileNotFoundError(f'File does not exists: {filepath}')
|
|
190
|
+
if 'json' not in os.path.splitext(filepath)[-1]:
|
|
191
|
+
raise ValueError(f'Expected path to json item, got {os.path.splitext(filepath)[-1]}')
|
|
192
|
+
prompt_item = cls(name=filepath)
|
|
193
|
+
with open(filepath, 'r', encoding='utf-8') as f:
|
|
178
194
|
data = json.load(f)
|
|
179
|
-
|
|
180
|
-
prompt = Prompt(key=prompt_key)
|
|
181
|
-
for val in prompt_values:
|
|
182
|
-
if val['mimetype'] == PromptType.METADATA:
|
|
183
|
-
_ = val.pop('mimetype')
|
|
184
|
-
prompt.add_element(value=val, mimetype=PromptType.METADATA)
|
|
185
|
-
else:
|
|
186
|
-
prompt.add_element(mimetype=val['mimetype'], value=val['value'])
|
|
187
|
-
prompt_item.add_prompt(prompt=prompt, update_item=False)
|
|
195
|
+
prompt_item.prompts = prompt_item._load_item_prompts(data=data)
|
|
188
196
|
return prompt_item
|
|
189
197
|
|
|
190
|
-
|
|
198
|
+
@staticmethod
|
|
199
|
+
def _load_item_prompts(data):
|
|
200
|
+
prompts = list()
|
|
201
|
+
for prompt_key, prompt_elements in data.get('prompts', dict()).items():
|
|
202
|
+
content = list()
|
|
203
|
+
for element in prompt_elements:
|
|
204
|
+
content.append({'value': element.get('value', dict()),
|
|
205
|
+
'mimetype': element['mimetype']})
|
|
206
|
+
prompt = Prompt(key=prompt_key, role="user")
|
|
207
|
+
for element in content:
|
|
208
|
+
prompt.add_element(value=element.get('value', ''),
|
|
209
|
+
mimetype=element.get('mimetype', PromptType.TEXT))
|
|
210
|
+
prompts.append(prompt)
|
|
211
|
+
return prompts
|
|
212
|
+
|
|
213
|
+
@staticmethod
|
|
214
|
+
def _load_annotations_prompts(annotations: entities.AnnotationCollection):
|
|
191
215
|
"""
|
|
192
216
|
Get all the annotations in the item for the assistant messages
|
|
193
217
|
"""
|
|
194
218
|
# clearing the assistant prompts from previous annotations that might not belong
|
|
195
|
-
|
|
219
|
+
assistant_prompts = list()
|
|
196
220
|
for annotation in annotations:
|
|
197
221
|
prompt_id = annotation.metadata.get('system', dict()).get('promptId', None)
|
|
222
|
+
model_info = annotation.metadata.get('user', dict()).get('model', dict())
|
|
223
|
+
annotation_id = annotation.id
|
|
198
224
|
if annotation.type == 'ref_image':
|
|
199
|
-
prompt = Prompt(key=prompt_id)
|
|
200
|
-
prompt.add_element(value=annotation.coordinates.get('ref'),
|
|
201
|
-
|
|
225
|
+
prompt = Prompt(key=prompt_id, role='assistant')
|
|
226
|
+
prompt.add_element(value=annotation.annotation_definition.coordinates.get('ref'),
|
|
227
|
+
mimetype=PromptType.IMAGE)
|
|
202
228
|
elif annotation.type == 'text':
|
|
203
|
-
prompt = Prompt(key=prompt_id)
|
|
204
|
-
prompt.add_element(value=annotation.coordinates,
|
|
205
|
-
|
|
229
|
+
prompt = Prompt(key=prompt_id, role='assistant')
|
|
230
|
+
prompt.add_element(value=annotation.annotation_definition.coordinates,
|
|
231
|
+
mimetype=PromptType.TEXT)
|
|
232
|
+
else:
|
|
233
|
+
raise ValueError(f"Unsupported annotation type: {annotation.type}")
|
|
206
234
|
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
if self._item is None:
|
|
213
|
-
logger.warning('Item is not loaded, skipping annotations context')
|
|
214
|
-
return
|
|
215
|
-
filters = entities.Filters(resource=entities.FiltersResource.ANNOTATION)
|
|
216
|
-
filters.add(field='metadata.user.model.name', values=model_name)
|
|
217
|
-
annotations = self._item.annotations.list(filters=filters)
|
|
218
|
-
self.get_assistant_messages(annotations=annotations)
|
|
235
|
+
prompt.add_element(value={'id': annotation_id,
|
|
236
|
+
'model_info': model_info},
|
|
237
|
+
mimetype=PromptType.METADATA)
|
|
238
|
+
assistant_prompts.append(prompt)
|
|
239
|
+
return assistant_prompts
|
|
219
240
|
|
|
220
241
|
def to_json(self):
|
|
221
242
|
"""
|
|
@@ -233,33 +254,9 @@ class PromptItem:
|
|
|
233
254
|
for prompt in self.prompts:
|
|
234
255
|
for prompt_key, prompt_values in prompt.to_json().items():
|
|
235
256
|
prompts_json["prompts"][prompt_key] = prompt_values
|
|
236
|
-
prompts_json["prompts"][prompt_key].append({'metadata'})
|
|
237
|
-
|
|
238
257
|
return prompts_json
|
|
239
258
|
|
|
240
|
-
def
|
|
241
|
-
"""
|
|
242
|
-
add a prompt to the prompt item
|
|
243
|
-
prompt: a dictionary. keys are prompt message id, values are prompt messages
|
|
244
|
-
responses: a list of annotations representing responses to the prompt
|
|
245
|
-
"""
|
|
246
|
-
self.prompts.append(prompt)
|
|
247
|
-
if update_item is True:
|
|
248
|
-
if self._item is not None:
|
|
249
|
-
self._item._Item__update_item_binary(_json=self.to_json())
|
|
250
|
-
else:
|
|
251
|
-
logger.warning('Item is not loaded, skipping upload')
|
|
252
|
-
|
|
253
|
-
def messages(self, model_name=None):
|
|
254
|
-
"""
|
|
255
|
-
return a list of messages in the prompt item
|
|
256
|
-
messages are returned following the openai SDK format
|
|
257
|
-
"""
|
|
258
|
-
if model_name is not None:
|
|
259
|
-
self.get_assistant_prompts(model_name=model_name)
|
|
260
|
-
else:
|
|
261
|
-
logger.warning('Model name is not provided, skipping assistant prompts')
|
|
262
|
-
|
|
259
|
+
def to_messages(self, model_name=None, include_assistant=True):
|
|
263
260
|
all_prompts_messages = dict()
|
|
264
261
|
for prompt in self.prompts:
|
|
265
262
|
if prompt.key not in all_prompts_messages:
|
|
@@ -270,38 +267,191 @@ class PromptItem:
|
|
|
270
267
|
'content': prompt_messages
|
|
271
268
|
}
|
|
272
269
|
all_prompts_messages[prompt.key].append(messages)
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
270
|
+
if include_assistant is True:
|
|
271
|
+
# reload to filer model annotations
|
|
272
|
+
for prompt in self.assistant_prompts:
|
|
273
|
+
prompt_model_name = prompt.metadata.get('model_info', dict()).get('name')
|
|
274
|
+
if model_name is not None and prompt_model_name != model_name:
|
|
275
|
+
continue
|
|
276
|
+
if prompt.key not in all_prompts_messages:
|
|
277
|
+
logger.warning(
|
|
278
|
+
f'Prompt key {prompt.key} is not found in the user prompts, skipping Assistant prompt')
|
|
279
|
+
continue
|
|
280
|
+
prompt_messages, prompt_key = prompt.messages()
|
|
281
|
+
assistant_messages = {
|
|
282
|
+
'role': 'assistant',
|
|
283
|
+
'content': prompt_messages
|
|
284
|
+
}
|
|
285
|
+
all_prompts_messages[prompt.key].append(assistant_messages)
|
|
284
286
|
res = list()
|
|
285
287
|
for prompts in all_prompts_messages.values():
|
|
286
288
|
for prompt in prompts:
|
|
287
289
|
res.append(prompt)
|
|
288
|
-
|
|
290
|
+
self._messages = res
|
|
291
|
+
return self._messages
|
|
292
|
+
|
|
293
|
+
def to_bytes_io(self):
|
|
294
|
+
# Used for item upload, do not delete
|
|
295
|
+
byte_io = io.BytesIO()
|
|
296
|
+
byte_io.name = self.name
|
|
297
|
+
byte_io.write(json.dumps(self.to_json()).encode())
|
|
298
|
+
byte_io.seek(0)
|
|
299
|
+
return byte_io
|
|
289
300
|
|
|
290
|
-
def
|
|
301
|
+
def fetch(self):
|
|
302
|
+
if self._item is None:
|
|
303
|
+
raise ValueError('Missing item, nothing to fetch..')
|
|
304
|
+
self._item = self._items.get(item_id=self._item.id)
|
|
305
|
+
self._annotations = self._item.annotations.list()
|
|
306
|
+
self.prompts = self._load_item_prompts(data=json.load(self._item.download(save_locally=False)))
|
|
307
|
+
self.assistant_prompts = self._load_annotations_prompts(self._annotations)
|
|
308
|
+
|
|
309
|
+
def build_context(self, nearest_items, add_metadata=None) -> str:
|
|
310
|
+
"""
|
|
311
|
+
Create a context stream from nearest items list.
|
|
312
|
+
add_metadata is a list of location in the item.metadata to add to the context, for instance ['system.document.source']
|
|
313
|
+
:param nearest_items: list of item ids
|
|
314
|
+
:param add_metadata: list of metadata location to add metadata to context
|
|
315
|
+
:return:
|
|
291
316
|
"""
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
317
|
+
if add_metadata is None:
|
|
318
|
+
add_metadata = list()
|
|
319
|
+
|
|
320
|
+
def stream_single(w_id):
|
|
321
|
+
context_item = self._items.get(item_id=w_id)
|
|
322
|
+
buf = context_item.download(save_locally=False)
|
|
323
|
+
text = buf.read().decode(encoding='utf-8')
|
|
324
|
+
m = ""
|
|
325
|
+
for path in add_metadata:
|
|
326
|
+
parts = path.split('.')
|
|
327
|
+
value = context_item.metadata
|
|
328
|
+
part = ""
|
|
329
|
+
for part in parts:
|
|
330
|
+
if isinstance(value, dict):
|
|
331
|
+
value = value.get(part)
|
|
332
|
+
else:
|
|
333
|
+
value = ""
|
|
334
|
+
|
|
335
|
+
m += f"{part}:{value}\n"
|
|
336
|
+
return text, m
|
|
337
|
+
|
|
338
|
+
pool = ThreadPoolExecutor(max_workers=32)
|
|
339
|
+
context = ""
|
|
340
|
+
if len(nearest_items) > 0:
|
|
341
|
+
# build context
|
|
342
|
+
results = pool.map(stream_single, nearest_items)
|
|
343
|
+
for res in results:
|
|
344
|
+
context += f"\n<source>\n{res[1]}\n</source>\n<text>\n{res[0]}\n</text>"
|
|
345
|
+
return context
|
|
346
|
+
|
|
347
|
+
def add(self,
|
|
348
|
+
message: dict,
|
|
349
|
+
prompt_key: str = None,
|
|
350
|
+
stream: bool = True,
|
|
351
|
+
model_info: dict = None):
|
|
295
352
|
"""
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
353
|
+
add a prompt to the prompt item
|
|
354
|
+
prompt: a dictionary. keys are prompt message id, values are prompt messages
|
|
355
|
+
responses: a list of annotations representing responses to the prompt
|
|
356
|
+
|
|
357
|
+
:param message:
|
|
358
|
+
:param prompt_key:
|
|
359
|
+
:param stream:
|
|
360
|
+
:param model_info:
|
|
361
|
+
:return:
|
|
362
|
+
"""
|
|
363
|
+
if prompt_key is None:
|
|
364
|
+
prompt_key = len(self.prompts) + 1
|
|
365
|
+
role = message.get('role', 'user')
|
|
366
|
+
content = message.get('content', list())
|
|
367
|
+
|
|
368
|
+
if self.role_mapping.get(role, 'item') == 'item':
|
|
369
|
+
# for new prompt we need a new key
|
|
370
|
+
prompt = Prompt(key=str(prompt_key), role=role)
|
|
371
|
+
for element in content:
|
|
372
|
+
prompt.add_element(value=element.get('value', ''),
|
|
373
|
+
mimetype=element.get('mimetype', PromptType.TEXT))
|
|
374
|
+
|
|
375
|
+
# create new prompt and add to prompts
|
|
376
|
+
self.prompts.append(prompt)
|
|
377
|
+
if self._item is not None and stream is True:
|
|
378
|
+
self._item._Item__update_item_binary(_json=self.to_json())
|
|
379
|
+
else:
|
|
380
|
+
# for response - we need to assign to previous key
|
|
381
|
+
prompt_key = str(prompt_key - 1)
|
|
382
|
+
assistant_message = content[0]
|
|
383
|
+
assistant_mimetype = assistant_message.get('mimetype', PromptType.TEXT)
|
|
384
|
+
uploaded_annotation = None
|
|
385
|
+
|
|
386
|
+
# find if prompt
|
|
387
|
+
if model_info is None:
|
|
388
|
+
# dont search for existing if there's no model information
|
|
389
|
+
existing_prompt = None
|
|
390
|
+
else:
|
|
391
|
+
existing_prompts = list()
|
|
392
|
+
for prompt in self.assistant_prompts:
|
|
393
|
+
prompt_id = prompt.key
|
|
394
|
+
model_name = prompt.metadata.get('model_info', dict()).get('name')
|
|
395
|
+
if prompt_id == prompt_key and model_name == model_info.get('name'):
|
|
396
|
+
# TODO how to handle multiple annotations
|
|
397
|
+
existing_prompts.append(prompt)
|
|
398
|
+
if len(existing_prompts) > 1:
|
|
399
|
+
assert False, "shouldn't be here! more than 1 annotation for a single model"
|
|
400
|
+
elif len(existing_prompts) == 1:
|
|
401
|
+
# found model annotation to upload
|
|
402
|
+
existing_prompt = existing_prompts[0]
|
|
403
|
+
else:
|
|
404
|
+
# no annotation found
|
|
405
|
+
existing_prompt = None
|
|
406
|
+
|
|
407
|
+
if existing_prompt is None:
|
|
408
|
+
prompt = Prompt(key=prompt_key)
|
|
409
|
+
if assistant_mimetype == PromptType.TEXT:
|
|
410
|
+
annotation_definition = entities.FreeText(text=assistant_message.get('value'))
|
|
411
|
+
prompt.add_element(value=annotation_definition.to_coordinates(None),
|
|
412
|
+
mimetype=PromptType.TEXT)
|
|
413
|
+
elif assistant_mimetype == PromptType.IMAGE:
|
|
414
|
+
annotation_definition = entities.RefImage(ref=assistant_message.get('value'))
|
|
415
|
+
prompt.add_element(value=annotation_definition.to_coordinates(None).get('ref'),
|
|
416
|
+
mimetype=PromptType.IMAGE)
|
|
417
|
+
else:
|
|
418
|
+
raise NotImplementedError('Only images of mimetype image and text are supported')
|
|
419
|
+
metadata = {'system': {'promptId': prompt_key},
|
|
420
|
+
'user': {'model': model_info}}
|
|
421
|
+
prompt.add_element(mimetype=PromptType.METADATA,
|
|
422
|
+
value={"model_info": model_info})
|
|
423
|
+
|
|
424
|
+
if stream:
|
|
425
|
+
existing_annotation = entities.Annotation.new(item=self._item,
|
|
426
|
+
metadata=metadata,
|
|
427
|
+
annotation_definition=annotation_definition)
|
|
428
|
+
uploaded_annotation = existing_annotation.upload()
|
|
429
|
+
prompt.add_element(mimetype=PromptType.METADATA,
|
|
430
|
+
value={"id": uploaded_annotation.id})
|
|
431
|
+
existing_prompt = prompt
|
|
432
|
+
self.assistant_prompts.append(prompt)
|
|
433
|
+
|
|
434
|
+
# TODO Shadi fix
|
|
435
|
+
existing_prompt_element = [element for element in existing_prompt.elements if
|
|
436
|
+
element['mimetype'] != PromptType.METADATA][-1]
|
|
437
|
+
existing_prompt_element['value'] = assistant_message.get('value')
|
|
438
|
+
if stream is True and uploaded_annotation is None:
|
|
439
|
+
# Creating annotation with old dict to match platform dict
|
|
440
|
+
annotation_definition = entities.FreeText(text='')
|
|
441
|
+
metadata = {'system': {'promptId': prompt_key},
|
|
442
|
+
'user': {'model': existing_prompt.metadata.get('model_info')}}
|
|
443
|
+
annotation = entities.Annotation.new(item=self._item,
|
|
444
|
+
metadata=metadata,
|
|
445
|
+
annotation_definition=annotation_definition
|
|
446
|
+
)
|
|
447
|
+
annotation.id = existing_prompt.metadata['id']
|
|
448
|
+
# set the platform dict to match the old annotation for the dict difference check, otherwise it won't
|
|
449
|
+
# update
|
|
450
|
+
annotation._platform_dict = annotation.to_json()
|
|
451
|
+
# update the annotation with the new text
|
|
452
|
+
annotation.annotation_definition.text = existing_prompt_element['value']
|
|
453
|
+
self._item.annotations.update(annotation)
|
|
454
|
+
|
|
455
|
+
def update(self):
|
|
456
|
+
if self._item is not None:
|
|
457
|
+
self._item._Item__update_item_binary(_json=self.to_json())
|
dtlpy/ml/base_model_adapter.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import dataclasses
|
|
1
2
|
import tempfile
|
|
2
3
|
import datetime
|
|
3
4
|
import logging
|
|
@@ -19,10 +20,41 @@ from ..services.api_client import ApiClient
|
|
|
19
20
|
logger = logging.getLogger('ModelAdapter')
|
|
20
21
|
|
|
21
22
|
|
|
23
|
+
@dataclasses.dataclass
|
|
24
|
+
class AdapterDefaults(dict):
|
|
25
|
+
# for predict items, dataset, evaluate
|
|
26
|
+
upload_annotations: bool = dataclasses.field(default=True)
|
|
27
|
+
clean_annotations: bool = dataclasses.field(default=True)
|
|
28
|
+
# for embeddings
|
|
29
|
+
upload_features: bool = dataclasses.field(default=True)
|
|
30
|
+
# for training
|
|
31
|
+
root_path: str = dataclasses.field(default=None)
|
|
32
|
+
data_path: str = dataclasses.field(default=None)
|
|
33
|
+
output_path: str = dataclasses.field(default=None)
|
|
34
|
+
|
|
35
|
+
def __post_init__(self):
|
|
36
|
+
# Initialize the internal dictionary with the dataclass fields
|
|
37
|
+
self.update(**dataclasses.asdict(self))
|
|
38
|
+
|
|
39
|
+
def update(self, **kwargs):
|
|
40
|
+
for f in dataclasses.fields(AdapterDefaults):
|
|
41
|
+
if f.name in kwargs:
|
|
42
|
+
setattr(self, f.name, kwargs[f.name])
|
|
43
|
+
super().update(**kwargs)
|
|
44
|
+
|
|
45
|
+
def resolve(self, key, *args):
|
|
46
|
+
|
|
47
|
+
for arg in args:
|
|
48
|
+
if arg is not None:
|
|
49
|
+
return arg
|
|
50
|
+
return self.get(key, None)
|
|
51
|
+
|
|
52
|
+
|
|
22
53
|
class BaseModelAdapter(utilities.BaseServiceRunner):
|
|
23
54
|
_client_api = attr.ib(type=ApiClient, repr=False)
|
|
24
55
|
|
|
25
56
|
def __init__(self, model_entity: entities.Model = None):
|
|
57
|
+
self.adapter_defaults = AdapterDefaults()
|
|
26
58
|
self.logger = logger
|
|
27
59
|
# entities
|
|
28
60
|
self._model_entity = None
|
|
@@ -222,14 +254,18 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
|
|
|
222
254
|
download the specific subset selected to data_path and preforms `self.convert` to the data_path dir
|
|
223
255
|
|
|
224
256
|
:param dataset: dl.Dataset
|
|
225
|
-
:param root_path: `str` root directory for training. default is "tmp"
|
|
226
|
-
:param data_path: `str` dataset directory. default <root_path>/"data"
|
|
227
|
-
:param output_path: `str` save everything to this folder. default <root_path>/"output"
|
|
257
|
+
:param root_path: `str` root directory for training. default is "tmp". Can be set using self.adapter_defaults.root_path
|
|
258
|
+
:param data_path: `str` dataset directory. default <root_path>/"data". Can be set using self.adapter_defaults.data_path
|
|
259
|
+
:param output_path: `str` save everything to this folder. default <root_path>/"output". Can be set using self.adapter_defaults.output_path
|
|
228
260
|
|
|
229
261
|
:param bool overwrite: overwrite the data path (download again). default is False
|
|
230
262
|
"""
|
|
231
263
|
# define paths
|
|
232
264
|
dataloop_path = os.path.join(os.path.expanduser('~'), '.dataloop')
|
|
265
|
+
root_path = self.adapter_defaults.resolve("root_path", root_path)
|
|
266
|
+
data_path = self.adapter_defaults.resolve("data_path", data_path)
|
|
267
|
+
output_path = self.adapter_defaults.resolve("output_path", output_path)
|
|
268
|
+
|
|
233
269
|
if root_path is None:
|
|
234
270
|
now = datetime.datetime.now()
|
|
235
271
|
root_path = os.path.join(dataloop_path,
|
|
@@ -311,6 +347,8 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
|
|
|
311
347
|
local_path = os.path.join(service_defaults.DATALOOP_PATH, "models", self.model_entity.name)
|
|
312
348
|
# Load configuration
|
|
313
349
|
self.configuration = self.model_entity.configuration
|
|
350
|
+
# Update the adapter config with the model config to run over defaults if needed
|
|
351
|
+
self.adapter_defaults.update(**self.configuration)
|
|
314
352
|
# Download
|
|
315
353
|
self.model_entity.artifacts.download(
|
|
316
354
|
local_path=local_path,
|
|
@@ -331,6 +369,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
|
|
|
331
369
|
:param cleanup: `bool` if True (default) remove the data from local FileSystem after upload
|
|
332
370
|
:return:
|
|
333
371
|
"""
|
|
372
|
+
|
|
334
373
|
if local_path is None:
|
|
335
374
|
local_path = tempfile.mkdtemp(prefix="model_{}".format(self.model_entity.name))
|
|
336
375
|
self.logger.debug("Using temporary dir at {}".format(local_path))
|
|
@@ -354,7 +393,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
|
|
|
354
393
|
@entities.Package.decorators.function(display_name='Predict Items',
|
|
355
394
|
inputs={'items': 'Item[]'},
|
|
356
395
|
outputs={'items': 'Item[]', 'annotations': 'Annotation[]'})
|
|
357
|
-
def predict_items(self, items: list, upload_annotations=
|
|
396
|
+
def predict_items(self, items: list, upload_annotations=None, clean_annotations=None, batch_size=None, **kwargs):
|
|
358
397
|
"""
|
|
359
398
|
Run the predict function on the input list of items (or single) and return the items and the predictions.
|
|
360
399
|
Each prediction is by the model output type (package.output_type) and model_info in the metadata
|
|
@@ -368,6 +407,8 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
|
|
|
368
407
|
"""
|
|
369
408
|
if batch_size is None:
|
|
370
409
|
batch_size = self.configuration.get('batch_size', 4)
|
|
410
|
+
upload_annotations = self.adapter_defaults.resolve("upload_annotations", upload_annotations)
|
|
411
|
+
clean_annotations = self.adapter_defaults.resolve("clean_annotations", clean_annotations)
|
|
371
412
|
input_type = self.model_entity.input_type
|
|
372
413
|
self.logger.debug(
|
|
373
414
|
"Predicting {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))
|
|
@@ -410,8 +451,8 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
|
|
|
410
451
|
|
|
411
452
|
@entities.Package.decorators.function(display_name='Embed Items',
|
|
412
453
|
inputs={'items': 'Item[]'},
|
|
413
|
-
outputs={'items': 'Item[]', 'features': '[]'})
|
|
414
|
-
def embed_items(self, items: list, upload_features=
|
|
454
|
+
outputs={'items': 'Item[]', 'features': 'Json[]'})
|
|
455
|
+
def embed_items(self, items: list, upload_features=None, batch_size=None, **kwargs):
|
|
415
456
|
"""
|
|
416
457
|
Extract feature from an input list of items (or single) and return the items and the feature vector.
|
|
417
458
|
|
|
@@ -423,6 +464,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
|
|
|
423
464
|
"""
|
|
424
465
|
if batch_size is None:
|
|
425
466
|
batch_size = self.configuration.get('batch_size', 4)
|
|
467
|
+
upload_features = self.adapter_defaults.resolve("upload_features", upload_features)
|
|
426
468
|
input_type = self.model_entity.input_type
|
|
427
469
|
self.logger.debug(
|
|
428
470
|
"Predicting {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))
|
|
@@ -486,7 +528,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
|
|
|
486
528
|
def embed_dataset(self,
|
|
487
529
|
dataset: entities.Dataset,
|
|
488
530
|
filters: entities.Filters = None,
|
|
489
|
-
upload_features=
|
|
531
|
+
upload_features=None,
|
|
490
532
|
batch_size=None,
|
|
491
533
|
**kwargs):
|
|
492
534
|
"""
|
|
@@ -501,8 +543,9 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
|
|
|
501
543
|
"""
|
|
502
544
|
if batch_size is None:
|
|
503
545
|
batch_size = self.configuration.get('batch_size', 4)
|
|
546
|
+
upload_features = self.adapter_defaults.resolve("upload_features", upload_features)
|
|
504
547
|
|
|
505
|
-
self.logger.debug("Creating
|
|
548
|
+
self.logger.debug("Creating embeddings for dataset (name:{}, id:{}, using batch size {}".format(dataset.name,
|
|
506
549
|
dataset.id,
|
|
507
550
|
batch_size))
|
|
508
551
|
if not filters:
|
|
@@ -524,8 +567,8 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
|
|
|
524
567
|
def predict_dataset(self,
|
|
525
568
|
dataset: entities.Dataset,
|
|
526
569
|
filters: entities.Filters = None,
|
|
527
|
-
|
|
528
|
-
|
|
570
|
+
upload_annotations=None,
|
|
571
|
+
clean_annotations=None,
|
|
529
572
|
batch_size=None,
|
|
530
573
|
**kwargs):
|
|
531
574
|
"""
|
|
@@ -533,8 +576,8 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
|
|
|
533
576
|
|
|
534
577
|
:param dataset: Dataset entity to predict
|
|
535
578
|
:param filters: Filters entity for a filtering before predicting
|
|
536
|
-
:param
|
|
537
|
-
:param
|
|
579
|
+
:param upload_annotations: `bool` uploads the predictions back to the given items
|
|
580
|
+
:param clean_annotations: `bool` if set removes existing predictions with the same package-model name (default: False)
|
|
538
581
|
:param batch_size: `int` size of batch to run a single inference
|
|
539
582
|
|
|
540
583
|
:return: `bool` indicating if the prediction process completed successfully
|
|
@@ -554,8 +597,8 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
|
|
|
554
597
|
# Item type is 'file' only, can be deleted if default filters are added to custom filters
|
|
555
598
|
items = [item for page in pages for item in page if item.type == 'file']
|
|
556
599
|
self.predict_items(items=items,
|
|
557
|
-
upload_annotations=
|
|
558
|
-
|
|
600
|
+
upload_annotations=upload_annotations,
|
|
601
|
+
clean_annotations=clean_annotations,
|
|
559
602
|
batch_size=batch_size,
|
|
560
603
|
**kwargs)
|
|
561
604
|
return True
|
dtlpy/repositories/__init__.py
CHANGED
dtlpy/repositories/apps.py
CHANGED
|
@@ -210,7 +210,8 @@ class Apps:
|
|
|
210
210
|
organization_id: str = None,
|
|
211
211
|
custom_installation: dict = None,
|
|
212
212
|
scope: entities.AppScope = None,
|
|
213
|
-
wait: bool = True
|
|
213
|
+
wait: bool = True,
|
|
214
|
+
integrations: list = None
|
|
214
215
|
) -> entities.App:
|
|
215
216
|
"""
|
|
216
217
|
Install the specified app in the project.
|
|
@@ -222,6 +223,7 @@ class Apps:
|
|
|
222
223
|
:param dict custom_installation: partial installation.
|
|
223
224
|
:param str scope: the scope of the app. default is project.
|
|
224
225
|
:param bool wait: wait for the operation to finish.
|
|
226
|
+
:param list integrations: list of integrations to install with the app.
|
|
225
227
|
|
|
226
228
|
:return the installed app.
|
|
227
229
|
:rtype entities.App
|
|
@@ -243,7 +245,8 @@ class Apps:
|
|
|
243
245
|
'dpkName': dpk.name,
|
|
244
246
|
"customInstallation": custom_installation,
|
|
245
247
|
'dpkVersion': dpk.version,
|
|
246
|
-
'scope': scope
|
|
248
|
+
'scope': scope,
|
|
249
|
+
'integrations': integrations
|
|
247
250
|
},
|
|
248
251
|
client_api=self._client_api,
|
|
249
252
|
project=self.project)
|