orca-sdk 0.0.96__py3-none-any.whl → 0.0.97__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. orca_sdk/__init__.py +1 -5
  2. orca_sdk/_generated_api_client/api/__init__.py +28 -8
  3. orca_sdk/_generated_api_client/api/{datasource/create_datasource_datasource_post.py → auth/create_org_plan_auth_org_plan_post.py} +32 -31
  4. orca_sdk/_generated_api_client/api/auth/get_org_plan_auth_org_plan_get.py +122 -0
  5. orca_sdk/_generated_api_client/api/auth/update_org_plan_auth_org_plan_put.py +168 -0
  6. orca_sdk/_generated_api_client/api/classification_model/{create_classification_model_gpu_classification_model_post.py → create_classification_model_classification_model_post.py} +1 -1
  7. orca_sdk/_generated_api_client/api/datasource/create_datasource_from_content_datasource_post.py +224 -0
  8. orca_sdk/_generated_api_client/api/datasource/create_datasource_from_files_datasource_upload_post.py +229 -0
  9. orca_sdk/_generated_api_client/api/regression_model/{create_regression_model_gpu_regression_model_post.py → create_regression_model_regression_model_post.py} +1 -1
  10. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +21 -26
  11. orca_sdk/_generated_api_client/api/telemetry/generate_memory_suggestions_telemetry_prediction_prediction_id_memory_suggestions_post.py +239 -0
  12. orca_sdk/_generated_api_client/api/telemetry/get_action_recommendation_telemetry_prediction_prediction_id_action_get.py +192 -0
  13. orca_sdk/_generated_api_client/models/__init__.py +54 -4
  14. orca_sdk/_generated_api_client/models/action_recommendation.py +82 -0
  15. orca_sdk/_generated_api_client/models/action_recommendation_action.py +11 -0
  16. orca_sdk/_generated_api_client/models/add_memory_recommendations.py +85 -0
  17. orca_sdk/_generated_api_client/models/add_memory_suggestion.py +79 -0
  18. orca_sdk/_generated_api_client/models/body_create_datasource_from_files_datasource_upload_post.py +145 -0
  19. orca_sdk/_generated_api_client/models/class_representatives.py +92 -0
  20. orca_sdk/_generated_api_client/models/classification_model_metadata.py +14 -0
  21. orca_sdk/_generated_api_client/models/clone_memoryset_request.py +40 -0
  22. orca_sdk/_generated_api_client/models/constraint_violation_error_response.py +8 -7
  23. orca_sdk/_generated_api_client/models/constraint_violation_error_response_status_code.py +8 -0
  24. orca_sdk/_generated_api_client/models/create_classification_model_request.py +40 -0
  25. orca_sdk/_generated_api_client/models/create_datasource_from_content_request.py +101 -0
  26. orca_sdk/_generated_api_client/models/create_memoryset_request.py +40 -0
  27. orca_sdk/_generated_api_client/models/create_org_plan_request.py +73 -0
  28. orca_sdk/_generated_api_client/models/create_org_plan_request_tier.py +11 -0
  29. orca_sdk/_generated_api_client/models/create_regression_model_request.py +20 -0
  30. orca_sdk/_generated_api_client/models/embed_request.py +20 -0
  31. orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +28 -10
  32. orca_sdk/_generated_api_client/models/embedding_evaluation_request.py +28 -10
  33. orca_sdk/_generated_api_client/models/embedding_model_result.py +9 -0
  34. orca_sdk/_generated_api_client/models/filter_item.py +31 -23
  35. orca_sdk/_generated_api_client/models/filter_item_field_type_1_item_type_0.py +8 -0
  36. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_0.py +8 -0
  37. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +2 -0
  38. orca_sdk/_generated_api_client/models/internal_server_error_response.py +8 -7
  39. orca_sdk/_generated_api_client/models/internal_server_error_response_status_code.py +8 -0
  40. orca_sdk/_generated_api_client/models/labeled_memory.py +5 -5
  41. orca_sdk/_generated_api_client/models/labeled_memory_update.py +16 -16
  42. orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +5 -5
  43. orca_sdk/_generated_api_client/models/lookup_request.py +20 -0
  44. orca_sdk/_generated_api_client/models/memory_metrics.py +98 -0
  45. orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +33 -0
  46. orca_sdk/_generated_api_client/models/memoryset_class_patterns_analysis_config.py +79 -0
  47. orca_sdk/_generated_api_client/models/memoryset_class_patterns_metrics.py +138 -0
  48. orca_sdk/_generated_api_client/models/memoryset_metadata.py +42 -0
  49. orca_sdk/_generated_api_client/models/memoryset_metrics.py +33 -0
  50. orca_sdk/_generated_api_client/models/memoryset_update.py +20 -0
  51. orca_sdk/_generated_api_client/models/not_found_error_response.py +6 -7
  52. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
  53. orca_sdk/_generated_api_client/models/not_found_error_response_status_code.py +8 -0
  54. orca_sdk/_generated_api_client/models/org_plan.py +99 -0
  55. orca_sdk/_generated_api_client/models/org_plan_tier.py +11 -0
  56. orca_sdk/_generated_api_client/models/paginated_task.py +108 -0
  57. orca_sdk/_generated_api_client/models/predictive_model_update.py +20 -0
  58. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +8 -0
  59. orca_sdk/_generated_api_client/models/regression_model_metadata.py +14 -0
  60. orca_sdk/_generated_api_client/models/scored_memory_update.py +9 -9
  61. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +8 -7
  62. orca_sdk/_generated_api_client/models/service_unavailable_error_response_status_code.py +8 -0
  63. orca_sdk/_generated_api_client/models/telemetry_field_type_0_item_type_0.py +8 -0
  64. orca_sdk/_generated_api_client/models/telemetry_field_type_1_item_type_0.py +8 -0
  65. orca_sdk/_generated_api_client/models/telemetry_field_type_1_item_type_1.py +8 -0
  66. orca_sdk/_generated_api_client/models/telemetry_filter_item.py +42 -30
  67. orca_sdk/_generated_api_client/models/telemetry_sort_options.py +42 -30
  68. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +8 -7
  69. orca_sdk/_generated_api_client/models/unauthenticated_error_response_status_code.py +8 -0
  70. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +8 -7
  71. orca_sdk/_generated_api_client/models/unauthorized_error_response_status_code.py +8 -0
  72. orca_sdk/_generated_api_client/models/update_org_plan_request.py +73 -0
  73. orca_sdk/_generated_api_client/models/update_org_plan_request_tier.py +11 -0
  74. orca_sdk/_shared/metrics.py +1 -1
  75. orca_sdk/classification_model.py +2 -2
  76. orca_sdk/classification_model_test.py +53 -0
  77. orca_sdk/credentials.py +15 -1
  78. orca_sdk/datasource.py +180 -41
  79. orca_sdk/datasource_test.py +194 -0
  80. orca_sdk/embedding_model.py +51 -13
  81. orca_sdk/embedding_model_test.py +27 -0
  82. orca_sdk/job.py +15 -14
  83. orca_sdk/job_test.py +34 -0
  84. orca_sdk/memoryset.py +47 -7
  85. orca_sdk/regression_model.py +2 -2
  86. orca_sdk/telemetry.py +94 -3
  87. {orca_sdk-0.0.96.dist-info → orca_sdk-0.0.97.dist-info}/METADATA +18 -1
  88. {orca_sdk-0.0.96.dist-info → orca_sdk-0.0.97.dist-info}/RECORD +89 -58
  89. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -207
  90. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -246
  91. {orca_sdk-0.0.96.dist-info → orca_sdk-0.0.97.dist-info}/WHEEL +0 -0
orca_sdk/datasource.py CHANGED
@@ -4,32 +4,105 @@ import logging
 import tempfile
 import zipfile
 from datetime import datetime
+from io import BytesIO
 from os import PathLike
 from pathlib import Path
-from typing import cast
+from typing import Union

 import pandas as pd
 import pyarrow as pa
-from datasets import Dataset
+from datasets import Dataset, DatasetDict
+from pyarrow import parquet
 from torch.utils.data import DataLoader as TorchDataLoader
 from torch.utils.data import Dataset as TorchDataset
 from tqdm.auto import tqdm

 from ._generated_api_client.api import (
+    create_datasource_from_content,
     delete_datasource,
     get_datasource,
     list_datasources,
 )
-from ._generated_api_client.api.datasource.create_datasource_datasource_post import (
+from ._generated_api_client.api.datasource.create_datasource_from_files_datasource_upload_post import (
     _parse_response as parse_create_response,
 )
 from ._generated_api_client.client import get_client
-from ._generated_api_client.models import ColumnType, DatasourceMetadata
+from ._generated_api_client.models import (
+    ColumnType,
+    CreateDatasourceFromContentRequest,
+    DatasourceMetadata,
+)
 from ._utils.common import CreateMode, DropMode
-from ._utils.data_parsing import hf_dataset_from_disk, hf_dataset_from_torch
+from ._utils.data_parsing import hf_dataset_from_torch
 from ._utils.tqdm_file_reader import TqdmFileReader


+def _upload_files_to_datasource(
+    name: str,
+    file_paths: list[Path],
+    description: str | None = None,
+) -> DatasourceMetadata:
+    """
+    Helper function to upload files to create a datasource using manual HTTP requests.
+
+    This bypasses the generated client because it doesn't handle file uploads properly.
+
+    Params:
+        name: Name for the datasource
+        file_paths: List of file paths to upload
+        description: Optional description for the datasource
+
+    Returns:
+        Metadata for the created datasource
+    """
+    client = get_client()
+    files = []
+
+    # Calculate total size for all files
+    total_size = sum(file_path.stat().st_size for file_path in file_paths)
+
+    with tqdm(total=total_size, unit="B", unit_scale=True, desc="Uploading") as pbar:
+        for file_path in file_paths:
+            buffered_reader = open(file_path, "rb")
+            tqdm_reader = TqdmFileReader(buffered_reader, pbar)
+            files.append(("files", (file_path.name, tqdm_reader)))
+
+        # Use manual HTTP request for file uploads
+        metadata = parse_create_response(
+            response=client.get_httpx_client().request(
+                method="post",
+                url="/datasource/upload",
+                files=files,
+                data={"name": name, "description": description},
+            )
+        )
+
+    return metadata
+
+
+def _handle_existing_datasource(name: str, if_exists: CreateMode) -> Union["Datasource", None]:
+    """
+    Helper function to handle the common pattern of checking if a datasource exists
+    and taking action based on the if_exists parameter.
+
+    Params:
+        name: Name of the datasource to check
+        if_exists: What to do if a datasource with the same name already exists
+
+    Returns:
+        Datasource instance if opening existing, None if should proceed with creation
+
+    Raises:
+        ValueError: If the datasource already exists and if_exists is "error"
+    """
+    if Datasource.exists(name):
+        if if_exists == "error":
+            raise ValueError(f"Dataset with name {name} already exists")
+        elif if_exists == "open":
+            return Datasource.open(name)
+    return None
+
+
 class Datasource:
     """
     A Handle to a datasource in the OrcaCloud
@@ -138,40 +211,54 @@ class Datasource:
         Raises:
             ValueError: If the datasource already exists and if_exists is `"error"`
         """
-        client = get_client()
-
-        if cls.exists(name):
-            if if_exists == "error":
-                raise ValueError(f"Dataset with name {name} already exists")
-            elif if_exists == "open":
-                return cls.open(name)
+        # Check if datasource already exists and handle accordingly
+        existing = _handle_existing_datasource(name, if_exists)
+        if existing is not None:
+            return existing

         with tempfile.TemporaryDirectory() as tmp_dir:
             dataset.save_to_disk(tmp_dir)
-            files = []

-            # Calculate total size for all files
+            # Get all file paths in the directory
             file_paths = list(Path(tmp_dir).iterdir())
-            total_size = sum(file_path.stat().st_size for file_path in file_paths)
-
-            with tqdm(total=total_size, unit="B", unit_scale=True, desc="Uploading") as pbar:
-                for file_path in file_paths:
-                    buffered_reader = open(file_path, "rb")
-                    tqdm_reader = TqdmFileReader(buffered_reader, pbar)
-                    files.append(("files", (file_path.name, tqdm_reader)))
-
-                # Do not use Generated client for this endpoint b/c it does not handle files properly
-                metadata = parse_create_response(
-                    response=client.get_httpx_client().request(
-                        method="post",
-                        url="/datasource/",
-                        files=files,
-                        data={"name": name, "description": description},
-                    )
-                )

+            # Use the helper function to upload files
+            metadata = _upload_files_to_datasource(name, file_paths, description)
         return cls(metadata=metadata)

+    @classmethod
+    def from_hf_dataset_dict(
+        cls,
+        name: str,
+        dataset_dict: DatasetDict,
+        if_exists: CreateMode = "error",
+        description: dict[str, str | None] | str | None = None,
+    ) -> dict[str, Datasource]:
+        """
+        Create datasources from a Hugging Face DatasetDict
+
+        Params:
+            name: Name prefix for the new datasources, will be suffixed with the dataset name
+            dataset_dict: The Hugging Face DatasetDict to create the datasources from
+            if_exists: What to do if a datasource with the same name already exists, defaults to
+                `"error"`. Other option is `"open"` to open the existing datasource.
+            description: Optional description for the datasources, can be a string or a dictionary of dataset names to descriptions
+
+        Returns:
+            A dictionary of datasource handles, keyed by the dataset name
+
+        Raises:
+            ValueError: If a datasource already exists and if_exists is `"error"`
+        """
+        if description is None or isinstance(description, str):
+            description = {dataset_name: description for dataset_name in dataset_dict.keys()}
+        return {
+            dataset_name: cls.from_hf_dataset(
+                f"{name}_{dataset_name}", dataset, if_exists=if_exists, description=description[dataset_name]
+            )
+            for dataset_name, dataset in dataset_dict.items()
+        }
+
     @classmethod
     def from_pytorch(
         cls,
@@ -225,8 +312,16 @@ class Datasource:
         Examples:
             >>> Datasource.from_list("my_datasource", [{"text": "Hello, world!", "label": 1}, {"text": "Goodbye", "label": 0}])
         """
-        hf_dataset = Dataset.from_list(data)
-        return cls.from_hf_dataset(name, hf_dataset, if_exists=if_exists, description=description)
+        # Check if datasource already exists and handle accordingly
+        existing = _handle_existing_datasource(name, if_exists)
+        if existing is not None:
+            return existing
+
+        # Use the generated API client function for content creation
+        body = CreateDatasourceFromContentRequest(name=name, description=description, content=data)
+
+        metadata = create_datasource_from_content(body=body)
+        return cls(metadata=metadata)

     @classmethod
     def from_dict(
@@ -251,8 +346,16 @@ class Datasource:
         Examples:
             >>> Datasource.from_dict("my_datasource", {"text": ["Hello, world!", "Goodbye"], "label": [1, 0]})
         """
-        hf_dataset = Dataset.from_dict(data)
-        return cls.from_hf_dataset(name, hf_dataset, if_exists=if_exists, description=description)
+        # Check if datasource already exists and handle accordingly
+        existing = _handle_existing_datasource(name, if_exists)
+        if existing is not None:
+            return existing
+
+        # Use the generated API client function for content creation
+        body = CreateDatasourceFromContentRequest(name=name, description=description, content=data)
+
+        metadata = create_datasource_from_content(body=body)
+        return cls(metadata=metadata)

     @classmethod
     def from_pandas(
@@ -274,8 +377,8 @@ class Datasource:
         Raises:
             ValueError: If the datasource already exists and if_exists is `"error"`
         """
-        hf_dataset = Dataset.from_pandas(dataframe)
-        return cls.from_hf_dataset(name, hf_dataset, if_exists=if_exists, description=description)
+        dataset = Dataset.from_pandas(dataframe)
+        return cls.from_hf_dataset(name, dataset, if_exists=if_exists, description=description)

     @classmethod
     def from_arrow(
@@ -297,8 +400,29 @@ class Datasource:
         Raises:
             ValueError: If the datasource already exists and if_exists is `"error"`
         """
-        hf_dataset = Dataset(pyarrow_table)
-        return cls.from_hf_dataset(name, hf_dataset, if_exists=if_exists, description=description)
+        # Check if datasource already exists and handle accordingly
+        existing = _handle_existing_datasource(name, if_exists)
+        if existing is not None:
+            return existing
+
+        # Write to bytes buffer
+        buffer = BytesIO()
+        parquet.write_table(pyarrow_table, buffer)
+        parquet_bytes = buffer.getvalue()
+
+        client = get_client()
+
+        # Use manual HTTP request for file uploads
+        metadata = parse_create_response(
+            response=client.get_httpx_client().request(
+                method="post",
+                url="/datasource/upload",
+                files=[("files", ("data.parquet", parquet_bytes))],
+                data={"name": name, "description": description},
+            )
+        )
+
+        return cls(metadata=metadata)

     @classmethod
     def from_disk(
@@ -328,8 +452,23 @@ class Datasource:
         Raises:
             ValueError: If the datasource already exists and if_exists is `"error"`
         """
-        hf_dataset = hf_dataset_from_disk(file_path)
-        return cls.from_hf_dataset(name, cast(Dataset, hf_dataset), if_exists=if_exists, description=description)
+        # Check if datasource already exists and handle accordingly
+        existing = _handle_existing_datasource(name, if_exists)
+        if existing is not None:
+            return existing
+
+        file_path = Path(file_path)
+
+        # For dataset directories, use the upload endpoint with multiple files
+        if file_path.is_dir():
+            return cls.from_hf_dataset(
+                name, Dataset.load_from_disk(file_path), if_exists=if_exists, description=description
+            )
+
+        # For single files, use the helper function to upload files
+        metadata = _upload_files_to_datasource(name, [file_path], description)
+
+        return cls(metadata=metadata)

     @classmethod
     def open(cls, name: str) -> Datasource:
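Taken together, the datasource.py changes split creation into two paths: in-memory rows and columns go through the generated create_datasource_from_content call with a JSON body, while files and saved Hugging Face datasets go through a manual multipart request to /datasource/upload. The following is a minimal usage sketch based on the docstrings above; the datasource names are illustrative and the calls assume an authenticated OrcaCloud client:

    from datasets import Dataset, DatasetDict
    from orca_sdk.datasource import Datasource

    # Rows are now sent as JSON content rather than uploaded as files
    reviews = Datasource.from_list("reviews", [{"text": "Hello, world!", "label": 1}])

    # A DatasetDict creates one datasource per split, suffixed with the split name
    splits = Datasource.from_hf_dataset_dict(
        "reviews", DatasetDict({"train": Dataset.from_dict({"text": ["Hello, world!"], "label": [1]})})
    )

    # Re-running with if_exists="open" returns the existing handle instead of raising
    same = Datasource.from_list("reviews", [{"text": "Hello, world!", "label": 1}], if_exists="open")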
orca_sdk/datasource_test.py CHANGED
@@ -2,6 +2,8 @@ import os
 import tempfile
 from uuid import uuid4

+import pandas as pd
+import pyarrow as pa
 import pytest

 from .datasource import Datasource
@@ -102,3 +104,195 @@ def test_download_datasource(datasource):
         output_path = os.path.join(temp_dir, "datasource.zip")
         datasource.download(output_path)
         assert os.path.exists(output_path)
+
+
+def test_from_list():
+    # Test creating datasource from list of dictionaries
+    data = [
+        {"column1": 1, "column2": "a"},
+        {"column1": 2, "column2": "b"},
+        {"column1": 3, "column2": "c"},
+    ]
+    datasource = Datasource.from_list(f"test_list_{uuid4()}", data)
+    assert datasource.name.startswith("test_list_")
+    assert datasource.length == 3
+    assert "column1" in datasource.columns
+    assert "column2" in datasource.columns
+
+
+def test_from_dict():
+    # Test creating datasource from dictionary of columns
+    data = {
+        "column1": [1, 2, 3],
+        "column2": ["a", "b", "c"],
+    }
+    datasource = Datasource.from_dict(f"test_dict_{uuid4()}", data)
+    assert datasource.name.startswith("test_dict_")
+    assert datasource.length == 3
+    assert "column1" in datasource.columns
+    assert "column2" in datasource.columns
+
+
+def test_from_pandas():
+    # Test creating datasource from pandas DataFrame
+    df = pd.DataFrame(
+        {
+            "column1": [1, 2, 3],
+            "column2": ["a", "b", "c"],
+        }
+    )
+    datasource = Datasource.from_pandas(f"test_pandas_{uuid4()}", df)
+    assert datasource.name.startswith("test_pandas_")
+    assert datasource.length == 3
+    assert "column1" in datasource.columns
+    assert "column2" in datasource.columns
+
+
+def test_from_arrow():
+    # Test creating datasource from pyarrow Table
+    table = pa.table(
+        {
+            "column1": [1, 2, 3],
+            "column2": ["a", "b", "c"],
+        }
+    )
+    datasource = Datasource.from_arrow(f"test_arrow_{uuid4()}", table)
+    assert datasource.name.startswith("test_arrow_")
+    assert datasource.length == 3
+    assert "column1" in datasource.columns
+    assert "column2" in datasource.columns
+
+
+def test_from_list_already_exists():
+    # Test the if_exists parameter with from_list
+    data = [{"column1": 1, "column2": "a"}]
+    name = f"test_list_exists_{uuid4()}"
+
+    # Create the first datasource
+    datasource1 = Datasource.from_list(name, data)
+    assert datasource1.length == 1
+
+    # Try to create again with if_exists="error" (should raise)
+    with pytest.raises(ValueError):
+        Datasource.from_list(name, data, if_exists="error")
+
+    # Try to create again with if_exists="open" (should return existing)
+    datasource2 = Datasource.from_list(name, data, if_exists="open")
+    assert datasource2.id == datasource1.id
+    assert datasource2.name == datasource1.name
+
+
+def test_from_dict_already_exists():
+    # Test the if_exists parameter with from_dict
+    data = {"column1": [1], "column2": ["a"]}
+    name = f"test_dict_exists_{uuid4()}"
+
+    # Create the first datasource
+    datasource1 = Datasource.from_dict(name, data)
+    assert datasource1.length == 1
+
+    # Try to create again with if_exists="error" (should raise)
+    with pytest.raises(ValueError):
+        Datasource.from_dict(name, data, if_exists="error")
+
+    # Try to create again with if_exists="open" (should return existing)
+    datasource2 = Datasource.from_dict(name, data, if_exists="open")
+    assert datasource2.id == datasource1.id
+    assert datasource2.name == datasource1.name
+
+
+def test_from_pandas_already_exists():
+    # Test the if_exists parameter with from_pandas
+    df = pd.DataFrame({"column1": [1], "column2": ["a"]})
+    name = f"test_pandas_exists_{uuid4()}"
+
+    # Create the first datasource
+    datasource1 = Datasource.from_pandas(name, df)
+    assert datasource1.length == 1
+
+    # Try to create again with if_exists="error" (should raise)
+    with pytest.raises(ValueError):
+        Datasource.from_pandas(name, df, if_exists="error")
+
+    # Try to create again with if_exists="open" (should return existing)
+    datasource2 = Datasource.from_pandas(name, df, if_exists="open")
+    assert datasource2.id == datasource1.id
+    assert datasource2.name == datasource1.name
+
+
+def test_from_arrow_already_exists():
+    # Test the if_exists parameter with from_arrow
+    table = pa.table({"column1": [1], "column2": ["a"]})
+    name = f"test_arrow_exists_{uuid4()}"
+
+    # Create the first datasource
+    datasource1 = Datasource.from_arrow(name, table)
+    assert datasource1.length == 1
+
+    # Try to create again with if_exists="error" (should raise)
+    with pytest.raises(ValueError):
+        Datasource.from_arrow(name, table, if_exists="error")
+
+    # Try to create again with if_exists="open" (should return existing)
+    datasource2 = Datasource.from_arrow(name, table, if_exists="open")
+    assert datasource2.id == datasource1.id
+    assert datasource2.name == datasource1.name
+
+
+def test_from_disk_csv():
+    # Test creating datasource from CSV file
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
+        f.write("column1,column2\n1,a\n2,b\n3,c")
+        f.flush()
+
+    try:
+        datasource = Datasource.from_disk(f"test_csv_{uuid4()}", f.name)
+        assert datasource.length == 3
+        assert "column1" in datasource.columns
+        assert "column2" in datasource.columns
+    finally:
+        os.unlink(f.name)
+
+
+def test_from_disk_json():
+    # Test creating datasource from JSON file
+    import json
+
+    data = [{"column1": 1, "column2": "a"}, {"column1": 2, "column2": "b"}]
+
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
+        json.dump(data, f)
+        f.flush()
+
+    try:
+        datasource = Datasource.from_disk(f"test_json_{uuid4()}", f.name)
+        assert datasource.length == 2
+        assert "column1" in datasource.columns
+        assert "column2" in datasource.columns
+    finally:
+        os.unlink(f.name)
+
+
+def test_from_disk_already_exists():
+    # Test the if_exists parameter with from_disk
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
+        f.write("column1,column2\n1,a")
+        f.flush()
+
+    try:
+        name = f"test_disk_exists_{uuid4()}"
+
+        # Create the first datasource
+        datasource1 = Datasource.from_disk(name, f.name)
+        assert datasource1.length == 1
+
+        # Try to create again with if_exists="error" (should raise)
+        with pytest.raises(ValueError):
+            Datasource.from_disk(name, f.name, if_exists="error")
+
+        # Try to create again with if_exists="open" (should return existing)
+        datasource2 = Datasource.from_disk(name, f.name, if_exists="open")
+        assert datasource2.id == datasource1.id
+        assert datasource2.name == datasource1.name
+    finally:
+        os.unlink(f.name)
orca_sdk/embedding_model.py CHANGED
@@ -23,7 +23,7 @@ from ._generated_api_client.models import (
     PretrainedEmbeddingModelMetadata,
     PretrainedEmbeddingModelName,
 )
-from ._utils.common import CreateMode, DropMode
+from ._utils.common import UNSET, CreateMode, DropMode
 from .datasource import Datasource
 from .job import Job, Status

@@ -36,40 +36,58 @@ class _EmbeddingModel:
     embedding_dim: int
     max_seq_length: int
     uses_context: bool
+    supports_instructions: bool

-    def __init__(self, *, name: str, embedding_dim: int, max_seq_length: int, uses_context: bool):
+    def __init__(
+        self, *, name: str, embedding_dim: int, max_seq_length: int, uses_context: bool, supports_instructions: bool
+    ):
         self.name = name
         self.embedding_dim = embedding_dim
         self.max_seq_length = max_seq_length
         self.uses_context = uses_context
+        self.supports_instructions = supports_instructions

     @classmethod
     @abstractmethod
     def all(cls) -> Sequence[_EmbeddingModel]:
         pass

+    def _get_instruction_error_message(self) -> str:
+        """Get error message for instruction not supported"""
+        if isinstance(self, FinetunedEmbeddingModel):
+            return f"Model {self.name} does not support instructions. Instruction-following is only supported by models based on instruction-supporting models."
+        else:
+            return f"Model {self.name} does not support instructions. Instruction-following is only supported by instruction-supporting models."
+
     @overload
-    def embed(self, value: str, max_seq_length: int | None = None) -> list[float]:
+    def embed(self, value: str, max_seq_length: int | None = None, prompt: str | None = None) -> list[float]:
         pass

     @overload
-    def embed(self, value: list[str], max_seq_length: int | None = None) -> list[list[float]]:
+    def embed(
+        self, value: list[str], max_seq_length: int | None = None, prompt: str | None = None
+    ) -> list[list[float]]:
         pass

-    def embed(self, value: str | list[str], max_seq_length: int | None = None) -> list[float] | list[list[float]]:
+    def embed(
+        self, value: str | list[str], max_seq_length: int | None = None, prompt: str | None = None
+    ) -> list[float] | list[list[float]]:
         """
         Generate embeddings for a value or list of values

         Params:
             value: The value or list of values to embed
             max_seq_length: The maximum sequence length to truncate the input to
+            prompt: Optional prompt for prompt-following embedding models.

         Returns:
             A matrix of floats representing the embedding for each value if the input is a list of
             values, or a list of floats representing the embedding for the single value if the
             input is a single value
         """
-        request = EmbedRequest(values=value if isinstance(value, list) else [value], max_seq_length=max_seq_length)
+        request = EmbedRequest(
+            values=value if isinstance(value, list) else [value], max_seq_length=max_seq_length, prompt=prompt
+        )
         if isinstance(self, PretrainedEmbeddingModel):
             embeddings = embed_with_pretrained_model_gpu(self._model_name, body=request)
         elif isinstance(self, FinetunedEmbeddingModel):
@@ -152,17 +170,27 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
     - **`GIST_LARGE`**: GIST-Large embedding model from Hugging Face ([avsolatorio/GIST-large-Embedding-v0](https://huggingface.co/avsolatorio/GIST-large-Embedding-v0))
     - **`MXBAI_LARGE`**: Mixbreas's Large embedding model from Hugging Face ([mixedbread-ai/mxbai-embed-large-v1](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1))
     - **`QWEN2_1_5B`**: Alibaba's Qwen2-1.5B instruction-tuned embedding model from Hugging Face ([Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct))
+    - **`BGE_BASE`**: BAAI's BGE-Base instruction-tuned embedding model from Hugging Face ([BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5))
+
+    **Instruction Support:**

+    Some models support instruction-following for better task-specific embeddings. You can check if a model supports instructions
+    using the `supports_instructions` attribute.

     Examples:
         >>> PretrainedEmbeddingModel.CDE_SMALL
         PretrainedEmbeddingModel({name: CDE_SMALL, embedding_dim: 768, max_seq_length: 512})

+        >>> # Using instruction with an instruction-supporting model
+        >>> model = PretrainedEmbeddingModel.E5_LARGE
+        >>> embeddings = model.embed("Hello world", prompt="Represent this sentence for retrieval:")
+
     Attributes:
         name: Name of the pretrained embedding model
         embedding_dim: Dimension of the embeddings that are generated by the model
         max_seq_length: Maximum input length (in tokens not characters) that this model can process. Inputs that are longer will be truncated during the embedding process
         uses_context: Whether the pretrained embedding model uses context
+        supports_instructions: Whether this model supports instruction-following
     """

     # Define descriptors for model access with IDE autocomplete
@@ -175,17 +203,22 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
     GIST_LARGE = _ModelDescriptor("GIST_LARGE")
     MXBAI_LARGE = _ModelDescriptor("MXBAI_LARGE")
     QWEN2_1_5B = _ModelDescriptor("QWEN2_1_5B")
+    BGE_BASE = _ModelDescriptor("BGE_BASE")

     _model_name: PretrainedEmbeddingModelName

     def __init__(self, metadata: PretrainedEmbeddingModelMetadata):
         # for internal use only, do not document
         self._model_name = metadata.name
+
         super().__init__(
             name=metadata.name.value,
             embedding_dim=metadata.embedding_dim,
             max_seq_length=metadata.max_seq_length,
             uses_context=metadata.uses_context,
+            supports_instructions=(
+                bool(metadata.supports_instructions) if metadata.supports_instructions is not UNSET else False
+            ),
         )

     def __eq__(self, other) -> bool:
@@ -209,9 +242,11 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
     @classmethod
     def _get(cls, name: PretrainedEmbeddingModelName | str) -> PretrainedEmbeddingModel:
         # for internal use only, do not document - we want people to use dot notation to get the model
-        if str(name) not in cls._instances:
-            cls._instances[str(name)] = cls(get_pretrained_embedding_model(cast(PretrainedEmbeddingModelName, name)))
-        return cls._instances[str(name)]
+        cache_key = str(name)
+        if cache_key not in cls._instances:
+            metadata = get_pretrained_embedding_model(cast(PretrainedEmbeddingModelName, name))
+            cls._instances[cache_key] = cls(metadata)
+        return cls._instances[cache_key]

     @classmethod
     def open(cls, name: str) -> PretrainedEmbeddingModel:
@@ -231,9 +266,9 @@ class PretrainedEmbeddingModel(_EmbeddingModel):
         >>> model = PretrainedEmbeddingModel.open("GTE_BASE")
         """
         try:
-            # Use getattr to access the descriptor which will initialize the model
-            return getattr(cls, name)
-        except AttributeError:
+            # Always use the _get method which handles caching properly
+            return cls._get(name)
+        except (KeyError, AttributeError):
             raise ValueError(f"Unknown model name: {name}")

     @classmethod
@@ -385,11 +420,13 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
         self.updated_at = metadata.updated_at
         self.base_model_name = metadata.base_model
         self._status = Status(metadata.finetuning_status.value)
+
         super().__init__(
             name=metadata.name,
             embedding_dim=metadata.embedding_dim,
             max_seq_length=metadata.max_seq_length,
             uses_context=metadata.uses_context,
+            supports_instructions=self.base_model.supports_instructions,
         )

     def __eq__(self, other) -> bool:
@@ -434,7 +471,8 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
         Raises:
             LookupError: If the finetuned embedding model does not exist
         """
-        return cls(get_finetuned_embedding_model(name))
+        metadata = get_finetuned_embedding_model(name)
+        return cls(metadata)

     @classmethod
     def exists(cls, name_or_id: str) -> bool:
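The embedding_model.py changes add an optional prompt argument to embed and a supports_instructions flag on both pretrained and finetuned models. A minimal sketch of how this might be used, following the docstring example above (the prompt string is illustrative and the calls assume an authenticated client):

    from orca_sdk.embedding_model import PretrainedEmbeddingModel

    model = PretrainedEmbeddingModel.open("GTE_BASE")
    if model.supports_instructions:
        # Instruction-supporting models accept a task-specific prompt
        embedding = model.embed("Hello world", prompt="Represent this sentence for retrieval:")
    else:
        embedding = model.embed("Hello world")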