orca-sdk 0.0.95__py3-none-any.whl → 0.0.97__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +1 -5
- orca_sdk/_generated_api_client/api/__init__.py +22 -2
- orca_sdk/_generated_api_client/api/{datasource/create_datasource_datasource_post.py → auth/create_org_plan_auth_org_plan_post.py} +32 -31
- orca_sdk/_generated_api_client/api/auth/get_org_plan_auth_org_plan_get.py +122 -0
- orca_sdk/_generated_api_client/api/auth/update_org_plan_auth_org_plan_put.py +168 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_from_content_datasource_post.py +224 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_from_files_datasource_upload_post.py +229 -0
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +21 -26
- orca_sdk/_generated_api_client/api/telemetry/generate_memory_suggestions_telemetry_prediction_prediction_id_memory_suggestions_post.py +239 -0
- orca_sdk/_generated_api_client/api/telemetry/get_action_recommendation_telemetry_prediction_prediction_id_action_get.py +192 -0
- orca_sdk/_generated_api_client/models/__init__.py +54 -4
- orca_sdk/_generated_api_client/models/action_recommendation.py +82 -0
- orca_sdk/_generated_api_client/models/action_recommendation_action.py +11 -0
- orca_sdk/_generated_api_client/models/add_memory_recommendations.py +85 -0
- orca_sdk/_generated_api_client/models/add_memory_suggestion.py +79 -0
- orca_sdk/_generated_api_client/models/body_create_datasource_from_files_datasource_upload_post.py +145 -0
- orca_sdk/_generated_api_client/models/class_representatives.py +92 -0
- orca_sdk/_generated_api_client/models/classification_model_metadata.py +14 -0
- orca_sdk/_generated_api_client/models/clone_memoryset_request.py +40 -0
- orca_sdk/_generated_api_client/models/constraint_violation_error_response.py +8 -7
- orca_sdk/_generated_api_client/models/constraint_violation_error_response_status_code.py +8 -0
- orca_sdk/_generated_api_client/models/create_classification_model_request.py +40 -0
- orca_sdk/_generated_api_client/models/create_datasource_from_content_request.py +101 -0
- orca_sdk/_generated_api_client/models/create_memoryset_request.py +40 -0
- orca_sdk/_generated_api_client/models/create_org_plan_request.py +73 -0
- orca_sdk/_generated_api_client/models/create_org_plan_request_tier.py +11 -0
- orca_sdk/_generated_api_client/models/create_regression_model_request.py +20 -0
- orca_sdk/_generated_api_client/models/embed_request.py +20 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +28 -10
- orca_sdk/_generated_api_client/models/embedding_evaluation_request.py +28 -10
- orca_sdk/_generated_api_client/models/embedding_model_result.py +9 -0
- orca_sdk/_generated_api_client/models/filter_item.py +31 -23
- orca_sdk/_generated_api_client/models/filter_item_field_type_1_item_type_0.py +8 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_0.py +8 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +2 -0
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +8 -7
- orca_sdk/_generated_api_client/models/internal_server_error_response_status_code.py +8 -0
- orca_sdk/_generated_api_client/models/labeled_memory.py +5 -5
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +16 -16
- orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +5 -5
- orca_sdk/_generated_api_client/models/lookup_request.py +20 -0
- orca_sdk/_generated_api_client/models/memory_metrics.py +98 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +33 -0
- orca_sdk/_generated_api_client/models/memoryset_class_patterns_analysis_config.py +79 -0
- orca_sdk/_generated_api_client/models/memoryset_class_patterns_metrics.py +138 -0
- orca_sdk/_generated_api_client/models/memoryset_metadata.py +42 -0
- orca_sdk/_generated_api_client/models/memoryset_metrics.py +33 -0
- orca_sdk/_generated_api_client/models/memoryset_update.py +20 -0
- orca_sdk/_generated_api_client/models/not_found_error_response.py +6 -7
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
- orca_sdk/_generated_api_client/models/not_found_error_response_status_code.py +8 -0
- orca_sdk/_generated_api_client/models/org_plan.py +99 -0
- orca_sdk/_generated_api_client/models/org_plan_tier.py +11 -0
- orca_sdk/_generated_api_client/models/paginated_task.py +108 -0
- orca_sdk/_generated_api_client/models/predictive_model_update.py +20 -0
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +8 -0
- orca_sdk/_generated_api_client/models/regression_model_metadata.py +14 -0
- orca_sdk/_generated_api_client/models/scored_memory_update.py +9 -9
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +8 -7
- orca_sdk/_generated_api_client/models/service_unavailable_error_response_status_code.py +8 -0
- orca_sdk/_generated_api_client/models/telemetry_field_type_0_item_type_0.py +8 -0
- orca_sdk/_generated_api_client/models/telemetry_field_type_1_item_type_0.py +8 -0
- orca_sdk/_generated_api_client/models/telemetry_field_type_1_item_type_1.py +8 -0
- orca_sdk/_generated_api_client/models/telemetry_filter_item.py +42 -30
- orca_sdk/_generated_api_client/models/telemetry_sort_options.py +42 -30
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +8 -7
- orca_sdk/_generated_api_client/models/unauthenticated_error_response_status_code.py +8 -0
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +8 -7
- orca_sdk/_generated_api_client/models/unauthorized_error_response_status_code.py +8 -0
- orca_sdk/_generated_api_client/models/update_org_plan_request.py +73 -0
- orca_sdk/_generated_api_client/models/update_org_plan_request_tier.py +11 -0
- orca_sdk/_shared/metrics.py +1 -1
- orca_sdk/classification_model.py +4 -1
- orca_sdk/classification_model_test.py +53 -0
- orca_sdk/credentials.py +15 -1
- orca_sdk/datasource.py +180 -41
- orca_sdk/datasource_test.py +194 -0
- orca_sdk/embedding_model.py +51 -13
- orca_sdk/embedding_model_test.py +27 -0
- orca_sdk/job.py +15 -14
- orca_sdk/job_test.py +34 -0
- orca_sdk/memoryset.py +47 -7
- orca_sdk/regression_model_test.py +0 -1
- orca_sdk/telemetry.py +94 -3
- {orca_sdk-0.0.95.dist-info → orca_sdk-0.0.97.dist-info}/METADATA +18 -1
- {orca_sdk-0.0.95.dist-info → orca_sdk-0.0.97.dist-info}/RECORD +87 -56
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -207
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -246
- {orca_sdk-0.0.95.dist-info → orca_sdk-0.0.97.dist-info}/WHEEL +0 -0
orca_sdk/datasource.py
CHANGED

@@ -4,32 +4,105 @@ import logging
 import tempfile
 import zipfile
 from datetime import datetime
+from io import BytesIO
 from os import PathLike
 from pathlib import Path
-from typing import
+from typing import Union
 
 import pandas as pd
 import pyarrow as pa
-from datasets import Dataset
+from datasets import Dataset, DatasetDict
+from pyarrow import parquet
 from torch.utils.data import DataLoader as TorchDataLoader
 from torch.utils.data import Dataset as TorchDataset
 from tqdm.auto import tqdm
 
 from ._generated_api_client.api import (
+    create_datasource_from_content,
     delete_datasource,
     get_datasource,
     list_datasources,
 )
-from ._generated_api_client.api.datasource.
+from ._generated_api_client.api.datasource.create_datasource_from_files_datasource_upload_post import (
     _parse_response as parse_create_response,
 )
 from ._generated_api_client.client import get_client
-from ._generated_api_client.models import
+from ._generated_api_client.models import (
+    ColumnType,
+    CreateDatasourceFromContentRequest,
+    DatasourceMetadata,
+)
 from ._utils.common import CreateMode, DropMode
-from ._utils.data_parsing import
+from ._utils.data_parsing import hf_dataset_from_torch
 from ._utils.tqdm_file_reader import TqdmFileReader
 
 
+def _upload_files_to_datasource(
+    name: str,
+    file_paths: list[Path],
+    description: str | None = None,
+) -> DatasourceMetadata:
+    """
+    Helper function to upload files to create a datasource using manual HTTP requests.
+
+    This bypasses the generated client because it doesn't handle file uploads properly.
+
+    Params:
+        name: Name for the datasource
+        file_paths: List of file paths to upload
+        description: Optional description for the datasource
+
+    Returns:
+        Metadata for the created datasource
+    """
+    client = get_client()
+    files = []
+
+    # Calculate total size for all files
+    total_size = sum(file_path.stat().st_size for file_path in file_paths)
+
+    with tqdm(total=total_size, unit="B", unit_scale=True, desc="Uploading") as pbar:
+        for file_path in file_paths:
+            buffered_reader = open(file_path, "rb")
+            tqdm_reader = TqdmFileReader(buffered_reader, pbar)
+            files.append(("files", (file_path.name, tqdm_reader)))
+
+        # Use manual HTTP request for file uploads
+        metadata = parse_create_response(
+            response=client.get_httpx_client().request(
+                method="post",
+                url="/datasource/upload",
+                files=files,
+                data={"name": name, "description": description},
+            )
+        )
+
+    return metadata
+
+
+def _handle_existing_datasource(name: str, if_exists: CreateMode) -> Union["Datasource", None]:
+    """
+    Helper function to handle the common pattern of checking if a datasource exists
+    and taking action based on the if_exists parameter.
+
+    Params:
+        name: Name of the datasource to check
+        if_exists: What to do if a datasource with the same name already exists
+
+    Returns:
+        Datasource instance if opening existing, None if should proceed with creation
+
+    Raises:
+        ValueError: If the datasource already exists and if_exists is "error"
+    """
+    if Datasource.exists(name):
+        if if_exists == "error":
+            raise ValueError(f"Dataset with name {name} already exists")
+        elif if_exists == "open":
+            return Datasource.open(name)
+    return None
+
+
 class Datasource:
     """
     A Handle to a datasource in the OrcaCloud
@@ -138,40 +211,54 @@ class Datasource:
         Raises:
             ValueError: If the datasource already exists and if_exists is `"error"`
         """
-
-
-        if
-
-                raise ValueError(f"Dataset with name {name} already exists")
-            elif if_exists == "open":
-                return cls.open(name)
+        # Check if datasource already exists and handle accordingly
+        existing = _handle_existing_datasource(name, if_exists)
+        if existing is not None:
+            return existing
 
         with tempfile.TemporaryDirectory() as tmp_dir:
            dataset.save_to_disk(tmp_dir)
-            files = []
 
-            #
+            # Get all file paths in the directory
            file_paths = list(Path(tmp_dir).iterdir())
-            total_size = sum(file_path.stat().st_size for file_path in file_paths)
-
-            with tqdm(total=total_size, unit="B", unit_scale=True, desc="Uploading") as pbar:
-                for file_path in file_paths:
-                    buffered_reader = open(file_path, "rb")
-                    tqdm_reader = TqdmFileReader(buffered_reader, pbar)
-                    files.append(("files", (file_path.name, tqdm_reader)))
-
-            # Do not use Generated client for this endpoint b/c it does not handle files properly
-            metadata = parse_create_response(
-                response=client.get_httpx_client().request(
-                    method="post",
-                    url="/datasource/",
-                    files=files,
-                    data={"name": name, "description": description},
-                )
-            )
 
+            # Use the helper function to upload files
+            metadata = _upload_files_to_datasource(name, file_paths, description)
        return cls(metadata=metadata)
 
+    @classmethod
+    def from_hf_dataset_dict(
+        cls,
+        name: str,
+        dataset_dict: DatasetDict,
+        if_exists: CreateMode = "error",
+        description: dict[str, str | None] | str | None = None,
+    ) -> dict[str, Datasource]:
+        """
+        Create datasources from a Hugging Face DatasetDict
+
+        Params:
+            name: Name prefix for the new datasources, will be suffixed with the dataset name
+            dataset_dict: The Hugging Face DatasetDict to create the datasources from
+            if_exists: What to do if a datasource with the same name already exists, defaults to
+                `"error"`. Other option is `"open"` to open the existing datasource.
+            description: Optional description for the datasources, can be a string or a dictionary of dataset names to descriptions
+
+        Returns:
+            A dictionary of datasource handles, keyed by the dataset name
+
+        Raises:
+            ValueError: If a datasource already exists and if_exists is `"error"`
+        """
+        if description is None or isinstance(description, str):
+            description = {dataset_name: description for dataset_name in dataset_dict.keys()}
+        return {
+            dataset_name: cls.from_hf_dataset(
+                f"{name}_{dataset_name}", dataset, if_exists=if_exists, description=description[dataset_name]
+            )
+            for dataset_name, dataset in dataset_dict.items()
+        }
+
     @classmethod
     def from_pytorch(
         cls,
@@ -225,8 +312,16 @@ class Datasource:
         Examples:
             >>> Datasource.from_list("my_datasource", [{"text": "Hello, world!", "label": 1}, {"text": "Goodbye", "label": 0}])
         """
-
-
+        # Check if datasource already exists and handle accordingly
+        existing = _handle_existing_datasource(name, if_exists)
+        if existing is not None:
+            return existing
+
+        # Use the generated API client function for content creation
+        body = CreateDatasourceFromContentRequest(name=name, description=description, content=data)
+
+        metadata = create_datasource_from_content(body=body)
+        return cls(metadata=metadata)
 
     @classmethod
     def from_dict(
@@ -251,8 +346,16 @@ class Datasource:
         Examples:
             >>> Datasource.from_dict("my_datasource", {"text": ["Hello, world!", "Goodbye"], "label": [1, 0]})
         """
-
-
+        # Check if datasource already exists and handle accordingly
+        existing = _handle_existing_datasource(name, if_exists)
+        if existing is not None:
+            return existing
+
+        # Use the generated API client function for content creation
+        body = CreateDatasourceFromContentRequest(name=name, description=description, content=data)
+
+        metadata = create_datasource_from_content(body=body)
+        return cls(metadata=metadata)
 
     @classmethod
     def from_pandas(
@@ -274,8 +377,8 @@ class Datasource:
         Raises:
             ValueError: If the datasource already exists and if_exists is `"error"`
         """
-
-        return cls.from_hf_dataset(name,
+        dataset = Dataset.from_pandas(dataframe)
+        return cls.from_hf_dataset(name, dataset, if_exists=if_exists, description=description)
 
     @classmethod
     def from_arrow(
@@ -297,8 +400,29 @@ class Datasource:
         Raises:
             ValueError: If the datasource already exists and if_exists is `"error"`
         """
-
-
+        # Check if datasource already exists and handle accordingly
+        existing = _handle_existing_datasource(name, if_exists)
+        if existing is not None:
+            return existing
+
+        # Write to bytes buffer
+        buffer = BytesIO()
+        parquet.write_table(pyarrow_table, buffer)
+        parquet_bytes = buffer.getvalue()
+
+        client = get_client()
+
+        # Use manual HTTP request for file uploads
+        metadata = parse_create_response(
+            response=client.get_httpx_client().request(
+                method="post",
+                url="/datasource/upload",
+                files=[("files", ("data.parquet", parquet_bytes))],
+                data={"name": name, "description": description},
+            )
+        )
+
+        return cls(metadata=metadata)
 
     @classmethod
     def from_disk(
@@ -328,8 +452,23 @@ class Datasource:
         Raises:
             ValueError: If the datasource already exists and if_exists is `"error"`
         """
-
-
+        # Check if datasource already exists and handle accordingly
+        existing = _handle_existing_datasource(name, if_exists)
+        if existing is not None:
+            return existing
+
+        file_path = Path(file_path)
+
+        # For dataset directories, use the upload endpoint with multiple files
+        if file_path.is_dir():
+            return cls.from_hf_dataset(
+                name, Dataset.load_from_disk(file_path), if_exists=if_exists, description=description
+            )
+
+        # For single files, use the helper function to upload files
+        metadata = _upload_files_to_datasource(name, [file_path], description)
+
+        return cls(metadata=metadata)
 
     @classmethod
     def open(cls, name: str) -> Datasource:
orca_sdk/datasource_test.py
CHANGED

@@ -2,6 +2,8 @@ import os
 import tempfile
 from uuid import uuid4
 
+import pandas as pd
+import pyarrow as pa
 import pytest
 
 from .datasource import Datasource
@@ -102,3 +104,195 @@ def test_download_datasource(datasource):
         output_path = os.path.join(temp_dir, "datasource.zip")
         datasource.download(output_path)
         assert os.path.exists(output_path)
+
+
+def test_from_list():
+    # Test creating datasource from list of dictionaries
+    data = [
+        {"column1": 1, "column2": "a"},
+        {"column1": 2, "column2": "b"},
+        {"column1": 3, "column2": "c"},
+    ]
+    datasource = Datasource.from_list(f"test_list_{uuid4()}", data)
+    assert datasource.name.startswith("test_list_")
+    assert datasource.length == 3
+    assert "column1" in datasource.columns
+    assert "column2" in datasource.columns
+
+
+def test_from_dict():
+    # Test creating datasource from dictionary of columns
+    data = {
+        "column1": [1, 2, 3],
+        "column2": ["a", "b", "c"],
+    }
+    datasource = Datasource.from_dict(f"test_dict_{uuid4()}", data)
+    assert datasource.name.startswith("test_dict_")
+    assert datasource.length == 3
+    assert "column1" in datasource.columns
+    assert "column2" in datasource.columns
+
+
+def test_from_pandas():
+    # Test creating datasource from pandas DataFrame
+    df = pd.DataFrame(
+        {
+            "column1": [1, 2, 3],
+            "column2": ["a", "b", "c"],
+        }
+    )
+    datasource = Datasource.from_pandas(f"test_pandas_{uuid4()}", df)
+    assert datasource.name.startswith("test_pandas_")
+    assert datasource.length == 3
+    assert "column1" in datasource.columns
+    assert "column2" in datasource.columns
+
+
+def test_from_arrow():
+    # Test creating datasource from pyarrow Table
+    table = pa.table(
+        {
+            "column1": [1, 2, 3],
+            "column2": ["a", "b", "c"],
+        }
+    )
+    datasource = Datasource.from_arrow(f"test_arrow_{uuid4()}", table)
+    assert datasource.name.startswith("test_arrow_")
+    assert datasource.length == 3
+    assert "column1" in datasource.columns
+    assert "column2" in datasource.columns
+
+
+def test_from_list_already_exists():
+    # Test the if_exists parameter with from_list
+    data = [{"column1": 1, "column2": "a"}]
+    name = f"test_list_exists_{uuid4()}"
+
+    # Create the first datasource
+    datasource1 = Datasource.from_list(name, data)
+    assert datasource1.length == 1
+
+    # Try to create again with if_exists="error" (should raise)
+    with pytest.raises(ValueError):
+        Datasource.from_list(name, data, if_exists="error")
+
+    # Try to create again with if_exists="open" (should return existing)
+    datasource2 = Datasource.from_list(name, data, if_exists="open")
+    assert datasource2.id == datasource1.id
+    assert datasource2.name == datasource1.name
+
+
+def test_from_dict_already_exists():
+    # Test the if_exists parameter with from_dict
+    data = {"column1": [1], "column2": ["a"]}
+    name = f"test_dict_exists_{uuid4()}"
+
+    # Create the first datasource
+    datasource1 = Datasource.from_dict(name, data)
+    assert datasource1.length == 1
+
+    # Try to create again with if_exists="error" (should raise)
+    with pytest.raises(ValueError):
+        Datasource.from_dict(name, data, if_exists="error")
+
+    # Try to create again with if_exists="open" (should return existing)
+    datasource2 = Datasource.from_dict(name, data, if_exists="open")
+    assert datasource2.id == datasource1.id
+    assert datasource2.name == datasource1.name
+
+
+def test_from_pandas_already_exists():
+    # Test the if_exists parameter with from_pandas
+    df = pd.DataFrame({"column1": [1], "column2": ["a"]})
+    name = f"test_pandas_exists_{uuid4()}"
+
+    # Create the first datasource
+    datasource1 = Datasource.from_pandas(name, df)
+    assert datasource1.length == 1
+
+    # Try to create again with if_exists="error" (should raise)
+    with pytest.raises(ValueError):
+        Datasource.from_pandas(name, df, if_exists="error")
+
+    # Try to create again with if_exists="open" (should return existing)
+    datasource2 = Datasource.from_pandas(name, df, if_exists="open")
+    assert datasource2.id == datasource1.id
+    assert datasource2.name == datasource1.name
+
+
+def test_from_arrow_already_exists():
+    # Test the if_exists parameter with from_arrow
+    table = pa.table({"column1": [1], "column2": ["a"]})
+    name = f"test_arrow_exists_{uuid4()}"
+
+    # Create the first datasource
+    datasource1 = Datasource.from_arrow(name, table)
+    assert datasource1.length == 1
+
+    # Try to create again with if_exists="error" (should raise)
+    with pytest.raises(ValueError):
+        Datasource.from_arrow(name, table, if_exists="error")
+
+    # Try to create again with if_exists="open" (should return existing)
+    datasource2 = Datasource.from_arrow(name, table, if_exists="open")
+    assert datasource2.id == datasource1.id
+    assert datasource2.name == datasource1.name
+
+
+def test_from_disk_csv():
+    # Test creating datasource from CSV file
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
+        f.write("column1,column2\n1,a\n2,b\n3,c")
+        f.flush()
+
+    try:
+        datasource = Datasource.from_disk(f"test_csv_{uuid4()}", f.name)
+        assert datasource.length == 3
+        assert "column1" in datasource.columns
+        assert "column2" in datasource.columns
+    finally:
+        os.unlink(f.name)
+
+
+def test_from_disk_json():
+    # Test creating datasource from JSON file
+    import json
+
+    data = [{"column1": 1, "column2": "a"}, {"column1": 2, "column2": "b"}]
+
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
+        json.dump(data, f)
+        f.flush()
+
+    try:
+        datasource = Datasource.from_disk(f"test_json_{uuid4()}", f.name)
+        assert datasource.length == 2
+        assert "column1" in datasource.columns
+        assert "column2" in datasource.columns
+    finally:
+        os.unlink(f.name)
+
+
+def test_from_disk_already_exists():
+    # Test the if_exists parameter with from_disk
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
+        f.write("column1,column2\n1,a")
+        f.flush()
+
+    try:
+        name = f"test_disk_exists_{uuid4()}"
+
+        # Create the first datasource
+        datasource1 = Datasource.from_disk(name, f.name)
+        assert datasource1.length == 1
+
+        # Try to create again with if_exists="error" (should raise)
+        with pytest.raises(ValueError):
+            Datasource.from_disk(name, f.name, if_exists="error")
+
+        # Try to create again with if_exists="open" (should return existing)
+        datasource2 = Datasource.from_disk(name, f.name, if_exists="open")
+        assert datasource2.id == datasource1.id
+        assert datasource2.name == datasource1.name
+    finally:
+        os.unlink(f.name)
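
The new suite exercises every factory path except the DatasetDict one added in datasource.py. A hedged sketch of what such a test could look like, following the same pattern as the tests above; the split names and assertions are illustrative, not part of the released test file:

def test_from_hf_dataset_dict():
    # Illustrative only: mirrors the existing tests for the new DatasetDict path
    from datasets import Dataset, DatasetDict

    splits = DatasetDict(
        {
            "train": Dataset.from_dict({"column1": [1, 2], "column2": ["a", "b"]}),
            "test": Dataset.from_dict({"column1": [3], "column2": ["c"]}),
        }
    )
    name = f"test_dataset_dict_{uuid4()}"
    datasources = Datasource.from_hf_dataset_dict(name, splits)
    assert set(datasources.keys()) == {"train", "test"}
    assert datasources["train"].length == 2
    assert datasources["test"].length == 1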
orca_sdk/embedding_model.py
CHANGED

@@ -23,7 +23,7 @@ from ._generated_api_client.models import (
     PretrainedEmbeddingModelMetadata,
     PretrainedEmbeddingModelName,
 )
-from ._utils.common import CreateMode, DropMode
+from ._utils.common import UNSET, CreateMode, DropMode
 from .datasource import Datasource
 from .job import Job, Status
@@ -36,40 +36,58 @@ class _EmbeddingModel:
     embedding_dim: int
     max_seq_length: int
     uses_context: bool
+    supports_instructions: bool
 
-    def __init__(
+    def __init__(
+        self, *, name: str, embedding_dim: int, max_seq_length: int, uses_context: bool, supports_instructions: bool
+    ):
         self.name = name
         self.embedding_dim = embedding_dim
         self.max_seq_length = max_seq_length
         self.uses_context = uses_context
+        self.supports_instructions = supports_instructions
 
     @classmethod
     @abstractmethod
     def all(cls) -> Sequence[_EmbeddingModel]:
         pass
 
+    def _get_instruction_error_message(self) -> str:
+        """Get error message for instruction not supported"""
+        if isinstance(self, FinetunedEmbeddingModel):
+            return f"Model {self.name} does not support instructions. Instruction-following is only supported by models based on instruction-supporting models."
+        else:
+            return f"Model {self.name} does not support instructions. Instruction-following is only supported by instruction-supporting models."
+
     @overload
-    def embed(self, value: str, max_seq_length: int | None = None) -> list[float]:
+    def embed(self, value: str, max_seq_length: int | None = None, prompt: str | None = None) -> list[float]:
         pass
 
     @overload
-    def embed(
+    def embed(
+        self, value: list[str], max_seq_length: int | None = None, prompt: str | None = None
+    ) -> list[list[float]]:
         pass
 
-    def embed(
+    def embed(
+        self, value: str | list[str], max_seq_length: int | None = None, prompt: str | None = None
+    ) -> list[float] | list[list[float]]:
         """
         Generate embeddings for a value or list of values
 
         Params:
             value: The value or list of values to embed
             max_seq_length: The maximum sequence length to truncate the input to
+            prompt: Optional prompt for prompt-following embedding models.
 
         Returns:
             A matrix of floats representing the embedding for each value if the input is a list of
             values, or a list of floats representing the embedding for the single value if the
             input is a single value
         """
-        request = EmbedRequest(
+        request = EmbedRequest(
+            values=value if isinstance(value, list) else [value], max_seq_length=max_seq_length, prompt=prompt
+        )
         if isinstance(self, PretrainedEmbeddingModel):
             embeddings = embed_with_pretrained_model_gpu(self._model_name, body=request)
         elif isinstance(self, FinetunedEmbeddingModel):
@@ -152,17 +170,27 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
     - **`GIST_LARGE`**: GIST-Large embedding model from Hugging Face ([avsolatorio/GIST-large-Embedding-v0](https://huggingface.co/avsolatorio/GIST-large-Embedding-v0))
     - **`MXBAI_LARGE`**: Mixbreas's Large embedding model from Hugging Face ([mixedbread-ai/mxbai-embed-large-v1](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1))
     - **`QWEN2_1_5B`**: Alibaba's Qwen2-1.5B instruction-tuned embedding model from Hugging Face ([Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct))
+    - **`BGE_BASE`**: BAAI's BGE-Base instruction-tuned embedding model from Hugging Face ([BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5))
+
+    **Instruction Support:**
 
+    Some models support instruction-following for better task-specific embeddings. You can check if a model supports instructions
+    using the `supports_instructions` attribute.
 
     Examples:
         >>> PretrainedEmbeddingModel.CDE_SMALL
         PretrainedEmbeddingModel({name: CDE_SMALL, embedding_dim: 768, max_seq_length: 512})
 
+        >>> # Using instruction with an instruction-supporting model
+        >>> model = PretrainedEmbeddingModel.E5_LARGE
+        >>> embeddings = model.embed("Hello world", prompt="Represent this sentence for retrieval:")
+
     Attributes:
         name: Name of the pretrained embedding model
         embedding_dim: Dimension of the embeddings that are generated by the model
         max_seq_length: Maximum input length (in tokens not characters) that this model can process. Inputs that are longer will be truncated during the embedding process
         uses_context: Whether the pretrained embedding model uses context
+        supports_instructions: Whether this model supports instruction-following
     """
 
     # Define descriptors for model access with IDE autocomplete
@@ -175,17 +203,22 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
     GIST_LARGE = _ModelDescriptor("GIST_LARGE")
     MXBAI_LARGE = _ModelDescriptor("MXBAI_LARGE")
     QWEN2_1_5B = _ModelDescriptor("QWEN2_1_5B")
+    BGE_BASE = _ModelDescriptor("BGE_BASE")
 
     _model_name: PretrainedEmbeddingModelName
 
     def __init__(self, metadata: PretrainedEmbeddingModelMetadata):
         # for internal use only, do not document
         self._model_name = metadata.name
+
         super().__init__(
             name=metadata.name.value,
             embedding_dim=metadata.embedding_dim,
             max_seq_length=metadata.max_seq_length,
             uses_context=metadata.uses_context,
+            supports_instructions=(
+                bool(metadata.supports_instructions) if metadata.supports_instructions is not UNSET else False
+            ),
         )
 
     def __eq__(self, other) -> bool:
@@ -209,9 +242,11 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
     @classmethod
     def _get(cls, name: PretrainedEmbeddingModelName | str) -> PretrainedEmbeddingModel:
         # for internal use only, do not document - we want people to use dot notation to get the model
-
-
-
+        cache_key = str(name)
+        if cache_key not in cls._instances:
+            metadata = get_pretrained_embedding_model(cast(PretrainedEmbeddingModelName, name))
+            cls._instances[cache_key] = cls(metadata)
+        return cls._instances[cache_key]
 
     @classmethod
     def open(cls, name: str) -> PretrainedEmbeddingModel:
@@ -231,9 +266,9 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
         >>> model = PretrainedEmbeddingModel.open("GTE_BASE")
         """
         try:
-            #
-            return
-        except AttributeError:
+            # Always use the _get method which handles caching properly
+            return cls._get(name)
+        except (KeyError, AttributeError):
             raise ValueError(f"Unknown model name: {name}")
 
     @classmethod
@@ -385,11 +420,13 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
         self.updated_at = metadata.updated_at
         self.base_model_name = metadata.base_model
         self._status = Status(metadata.finetuning_status.value)
+
         super().__init__(
             name=metadata.name,
             embedding_dim=metadata.embedding_dim,
             max_seq_length=metadata.max_seq_length,
             uses_context=metadata.uses_context,
+            supports_instructions=self.base_model.supports_instructions,
        )
 
     def __eq__(self, other) -> bool:
@@ -434,7 +471,8 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
         Raises:
             LookupError: If the finetuned embedding model does not exist
         """
-
+        metadata = get_finetuned_embedding_model(name)
+        return cls(metadata)
 
     @classmethod
     def exists(cls, name_or_id: str) -> bool:
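
The instruction-related additions above fit together as follows. A minimal sketch, assuming the module import path (the top-level re-export is not shown in this diff) and that E5_LARGE, as used in the new docstring example, is an instruction-supporting model in your deployment:

from orca_sdk.embedding_model import PretrainedEmbeddingModel

model = PretrainedEmbeddingModel.open("E5_LARGE")  # open() now resolves through the cached _get()

# New in 0.0.97: models expose whether they accept a task instruction
if model.supports_instructions:
    # prompt= is forwarded to EmbedRequest and applied when embedding
    vectors = model.embed(
        ["Hello, world!", "Goodbye"],
        prompt="Represent this sentence for retrieval:",
    )
else:
    vectors = model.embed(["Hello, world!", "Goodbye"])

print(len(vectors), len(vectors[0]))  # number of inputs, embedding_dim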