orca-sdk 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +1 -1
- orca_sdk/_utils/auth.py +12 -8
- orca_sdk/async_client.py +3942 -0
- orca_sdk/classification_model.py +218 -20
- orca_sdk/classification_model_test.py +96 -28
- orca_sdk/client.py +899 -712
- orca_sdk/conftest.py +37 -36
- orca_sdk/credentials.py +54 -14
- orca_sdk/credentials_test.py +92 -28
- orca_sdk/datasource.py +64 -12
- orca_sdk/datasource_test.py +144 -18
- orca_sdk/embedding_model.py +54 -37
- orca_sdk/embedding_model_test.py +27 -20
- orca_sdk/job.py +27 -21
- orca_sdk/memoryset.py +823 -205
- orca_sdk/memoryset_test.py +315 -33
- orca_sdk/regression_model.py +59 -15
- orca_sdk/regression_model_test.py +35 -26
- orca_sdk/telemetry.py +76 -26
- {orca_sdk-0.1.2.dist-info → orca_sdk-0.1.4.dist-info}/METADATA +1 -1
- orca_sdk-0.1.4.dist-info/RECORD +41 -0
- orca_sdk-0.1.2.dist-info/RECORD +0 -40
- {orca_sdk-0.1.2.dist-info → orca_sdk-0.1.4.dist-info}/WHEEL +0 -0
orca_sdk/datasource_test.py
CHANGED
|
@@ -19,9 +19,10 @@ def test_create_datasource(datasource, hf_dataset):
|
|
|
19
19
|
assert datasource.length == len(hf_dataset)
|
|
20
20
|
|
|
21
21
|
|
|
22
|
-
def test_create_datasource_unauthenticated(
|
|
23
|
-
with
|
|
24
|
-
|
|
22
|
+
def test_create_datasource_unauthenticated(unauthenticated_client, hf_dataset):
|
|
23
|
+
with unauthenticated_client.use():
|
|
24
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
25
|
+
Datasource.from_hf_dataset("test_datasource", hf_dataset)
|
|
25
26
|
|
|
26
27
|
|
|
27
28
|
def test_create_datasource_already_exists_error(hf_dataset, datasource):
|
|
@@ -43,9 +44,10 @@ def test_open_datasource(datasource):
|
|
|
43
44
|
assert fetched_datasource.length == len(datasource)
|
|
44
45
|
|
|
45
46
|
|
|
46
|
-
def test_open_datasource_unauthenticated(
|
|
47
|
-
with
|
|
48
|
-
|
|
47
|
+
def test_open_datasource_unauthenticated(unauthenticated_client, datasource):
|
|
48
|
+
with unauthenticated_client.use():
|
|
49
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
50
|
+
Datasource.open("test_datasource")
|
|
49
51
|
|
|
50
52
|
|
|
51
53
|
def test_open_datasource_invalid_input():
|
|
@@ -58,9 +60,10 @@ def test_open_datasource_not_found():
|
|
|
58
60
|
Datasource.open(str(uuid4()))
|
|
59
61
|
|
|
60
62
|
|
|
61
|
-
def test_open_datasource_unauthorized(
|
|
62
|
-
with
|
|
63
|
-
|
|
63
|
+
def test_open_datasource_unauthorized(unauthorized_client, datasource):
|
|
64
|
+
with unauthorized_client.use():
|
|
65
|
+
with pytest.raises(LookupError):
|
|
66
|
+
Datasource.open(datasource.id)
|
|
64
67
|
|
|
65
68
|
|
|
66
69
|
def test_all_datasources(datasource):
|
|
@@ -69,9 +72,10 @@ def test_all_datasources(datasource):
|
|
|
69
72
|
assert any(datasource.name == datasource.name for datasource in datasources)
|
|
70
73
|
|
|
71
74
|
|
|
72
|
-
def test_all_datasources_unauthenticated(
|
|
73
|
-
with
|
|
74
|
-
|
|
75
|
+
def test_all_datasources_unauthenticated(unauthenticated_client):
|
|
76
|
+
with unauthenticated_client.use():
|
|
77
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
78
|
+
Datasource.all()
|
|
75
79
|
|
|
76
80
|
|
|
77
81
|
def test_drop_datasource(hf_dataset):
|
|
@@ -81,9 +85,10 @@ def test_drop_datasource(hf_dataset):
|
|
|
81
85
|
assert not Datasource.exists("datasource_to_delete")
|
|
82
86
|
|
|
83
87
|
|
|
84
|
-
def test_drop_datasource_unauthenticated(datasource,
|
|
85
|
-
with
|
|
86
|
-
|
|
88
|
+
def test_drop_datasource_unauthenticated(datasource, unauthenticated_client):
|
|
89
|
+
with unauthenticated_client.use():
|
|
90
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
91
|
+
Datasource.drop(datasource.id)
|
|
87
92
|
|
|
88
93
|
|
|
89
94
|
def test_drop_datasource_not_found():
|
|
@@ -93,9 +98,10 @@ def test_drop_datasource_not_found():
|
|
|
93
98
|
Datasource.drop(str(uuid4()), if_not_exists="ignore")
|
|
94
99
|
|
|
95
100
|
|
|
96
|
-
def test_drop_datasource_unauthorized(datasource,
|
|
97
|
-
with
|
|
98
|
-
|
|
101
|
+
def test_drop_datasource_unauthorized(datasource, unauthorized_client):
|
|
102
|
+
with unauthorized_client.use():
|
|
103
|
+
with pytest.raises(LookupError):
|
|
104
|
+
Datasource.drop(datasource.id)
|
|
99
105
|
|
|
100
106
|
|
|
101
107
|
def test_drop_datasource_invalid_input():
|
|
@@ -295,6 +301,126 @@ def test_from_disk_already_exists():
|
|
|
295
301
|
os.unlink(f.name)
|
|
296
302
|
|
|
297
303
|
|
|
304
|
+
def test_query_datasource_rows():
|
|
305
|
+
"""Test querying rows from a datasource with pagination and shuffle."""
|
|
306
|
+
# Create a new dataset with 5 entries for testing
|
|
307
|
+
test_data = [{"id": i, "name": f"item_{i}"} for i in range(5)]
|
|
308
|
+
datasource = Datasource.from_list(name="test_query_datasource", data=test_data)
|
|
309
|
+
|
|
310
|
+
# Test basic query
|
|
311
|
+
rows = datasource.query(limit=3)
|
|
312
|
+
assert len(rows) == 3
|
|
313
|
+
assert all(isinstance(row, dict) for row in rows)
|
|
314
|
+
|
|
315
|
+
# Test offset
|
|
316
|
+
offset_rows = datasource.query(offset=2, limit=2)
|
|
317
|
+
assert len(offset_rows) == 2
|
|
318
|
+
assert offset_rows[0]["id"] == 2
|
|
319
|
+
|
|
320
|
+
# Test shuffle
|
|
321
|
+
shuffled_rows = datasource.query(limit=5, shuffle=True)
|
|
322
|
+
assert len(shuffled_rows) == 5
|
|
323
|
+
assert not all(row["id"] == i for i, row in enumerate(shuffled_rows))
|
|
324
|
+
|
|
325
|
+
# Test shuffle with seed
|
|
326
|
+
assert datasource.query(limit=5, shuffle=True, shuffle_seed=42) == datasource.query(
|
|
327
|
+
limit=5, shuffle=True, shuffle_seed=42
|
|
328
|
+
)
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def test_query_datasource_with_filters():
|
|
332
|
+
"""Test querying datasource rows with various filter operators."""
|
|
333
|
+
# Create a datasource with test data
|
|
334
|
+
test_data = [
|
|
335
|
+
{"name": "Alice", "age": 25, "city": "New York", "score": 85.5},
|
|
336
|
+
{"name": "Bob", "age": 30, "city": "San Francisco", "score": 90.0},
|
|
337
|
+
{"name": "Charlie", "age": 35, "city": "Chicago", "score": 75.5},
|
|
338
|
+
{"name": "Diana", "age": 28, "city": "Boston", "score": 88.0},
|
|
339
|
+
{"name": "Eve", "age": 32, "city": "New York", "score": 92.0},
|
|
340
|
+
]
|
|
341
|
+
datasource = Datasource.from_list(name=f"test_filter_datasource_{uuid4()}", data=test_data)
|
|
342
|
+
|
|
343
|
+
# Test == operator
|
|
344
|
+
rows = datasource.query(filters=[("city", "==", "New York")])
|
|
345
|
+
assert len(rows) == 2
|
|
346
|
+
assert all(row["city"] == "New York" for row in rows)
|
|
347
|
+
|
|
348
|
+
# Test > operator
|
|
349
|
+
rows = datasource.query(filters=[("age", ">", 30)])
|
|
350
|
+
assert len(rows) == 2
|
|
351
|
+
assert all(row["age"] > 30 for row in rows)
|
|
352
|
+
|
|
353
|
+
# Test >= operator
|
|
354
|
+
rows = datasource.query(filters=[("score", ">=", 88.0)])
|
|
355
|
+
assert len(rows) == 3
|
|
356
|
+
assert all(row["score"] >= 88.0 for row in rows)
|
|
357
|
+
|
|
358
|
+
# Test < operator
|
|
359
|
+
rows = datasource.query(filters=[("age", "<", 30)])
|
|
360
|
+
assert len(rows) == 2
|
|
361
|
+
assert all(row["age"] < 30 for row in rows)
|
|
362
|
+
|
|
363
|
+
# Test in operator
|
|
364
|
+
rows = datasource.query(filters=[("city", "in", ["New York", "Boston"])])
|
|
365
|
+
assert len(rows) == 3
|
|
366
|
+
assert all(row["city"] in ["New York", "Boston"] for row in rows)
|
|
367
|
+
|
|
368
|
+
# Test not in operator
|
|
369
|
+
rows = datasource.query(filters=[("city", "not in", ["New York", "Boston"])])
|
|
370
|
+
assert len(rows) == 2
|
|
371
|
+
assert all(row["city"] not in ["New York", "Boston"] for row in rows)
|
|
372
|
+
|
|
373
|
+
# Test like operator
|
|
374
|
+
rows = datasource.query(filters=[("name", "like", "li")])
|
|
375
|
+
assert len(rows) == 2
|
|
376
|
+
assert all("li" in row["name"].lower() for row in rows)
|
|
377
|
+
|
|
378
|
+
# Test multiple filters (AND logic)
|
|
379
|
+
rows = datasource.query(filters=[("city", "==", "New York"), ("age", ">", 26)])
|
|
380
|
+
assert len(rows) == 1
|
|
381
|
+
assert rows[0]["name"] == "Eve"
|
|
382
|
+
|
|
383
|
+
# Test filter with pagination
|
|
384
|
+
rows = datasource.query(filters=[("age", ">=", 28)], limit=2, offset=1)
|
|
385
|
+
assert len(rows) == 2
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def test_query_datasource_with_none_filters():
|
|
389
|
+
"""Test filtering for None values."""
|
|
390
|
+
test_data = [
|
|
391
|
+
{"name": "Alice", "age": 25, "label": "A"},
|
|
392
|
+
{"name": "Bob", "age": 30, "label": None},
|
|
393
|
+
{"name": "Charlie", "age": 35, "label": "C"},
|
|
394
|
+
{"name": "Diana", "age": None, "label": "D"},
|
|
395
|
+
{"name": "Eve", "age": 32, "label": None},
|
|
396
|
+
]
|
|
397
|
+
datasource = Datasource.from_list(name=f"test_none_filter_{uuid4()}", data=test_data)
|
|
398
|
+
|
|
399
|
+
# Test == None
|
|
400
|
+
rows = datasource.query(filters=[("label", "==", None)])
|
|
401
|
+
assert len(rows) == 2
|
|
402
|
+
assert all(row["label"] is None for row in rows)
|
|
403
|
+
|
|
404
|
+
# Test != None
|
|
405
|
+
rows = datasource.query(filters=[("label", "!=", None)])
|
|
406
|
+
assert len(rows) == 3
|
|
407
|
+
assert all(row["label"] is not None for row in rows)
|
|
408
|
+
|
|
409
|
+
# Test that None values are excluded from comparison operators
|
|
410
|
+
rows = datasource.query(filters=[("age", ">", 25)])
|
|
411
|
+
assert len(rows) == 3
|
|
412
|
+
assert all(row["age"] is not None and row["age"] > 25 for row in rows)
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
def test_query_datasource_filter_invalid_column():
|
|
416
|
+
"""Test that querying with an invalid column raises an error."""
|
|
417
|
+
test_data = [{"name": "Alice", "age": 25}]
|
|
418
|
+
datasource = Datasource.from_list(name=f"test_invalid_filter_{uuid4()}", data=test_data)
|
|
419
|
+
|
|
420
|
+
with pytest.raises(ValueError):
|
|
421
|
+
datasource.query(filters=[("invalid_column", "==", "test")])
|
|
422
|
+
|
|
423
|
+
|
|
298
424
|
def test_to_list(hf_dataset, datasource):
|
|
299
425
|
assert datasource.to_list() == hf_dataset.to_list()
|
|
300
426
|
|
orca_sdk/embedding_model.py
CHANGED
|
@@ -12,15 +12,15 @@ from .client import (
|
|
|
12
12
|
EmbedRequest,
|
|
13
13
|
FinetunedEmbeddingModelMetadata,
|
|
14
14
|
FinetuneEmbeddingModelRequest,
|
|
15
|
+
OrcaClient,
|
|
15
16
|
PretrainedEmbeddingModelMetadata,
|
|
16
17
|
PretrainedEmbeddingModelName,
|
|
17
|
-
orca_api,
|
|
18
18
|
)
|
|
19
19
|
from .datasource import Datasource
|
|
20
20
|
from .job import Job, Status
|
|
21
21
|
|
|
22
22
|
if TYPE_CHECKING:
|
|
23
|
-
from .memoryset import LabeledMemoryset
|
|
23
|
+
from .memoryset import LabeledMemoryset, ScoredMemoryset
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
class EmbeddingModelBase(ABC):
|
|
@@ -82,15 +82,16 @@ class EmbeddingModelBase(ABC):
|
|
|
82
82
|
"max_seq_length": max_seq_length,
|
|
83
83
|
"prompt": prompt,
|
|
84
84
|
}
|
|
85
|
+
client = OrcaClient._resolve_client()
|
|
85
86
|
if isinstance(self, PretrainedEmbeddingModel):
|
|
86
|
-
embeddings =
|
|
87
|
+
embeddings = client.POST(
|
|
87
88
|
"/gpu/pretrained_embedding_model/{model_name}/embedding",
|
|
88
89
|
params={"model_name": cast(PretrainedEmbeddingModelName, self.name)},
|
|
89
90
|
json=payload,
|
|
90
91
|
timeout=30, # may be slow in case of cold start
|
|
91
92
|
)
|
|
92
93
|
elif isinstance(self, FinetunedEmbeddingModel):
|
|
93
|
-
embeddings =
|
|
94
|
+
embeddings = client.POST(
|
|
94
95
|
"/gpu/finetuned_embedding_model/{name_or_id}/embedding",
|
|
95
96
|
params={"name_or_id": self.id},
|
|
96
97
|
json=payload,
|
|
@@ -109,7 +110,7 @@ class EmbeddingModelBase(ABC):
|
|
|
109
110
|
label_column: str,
|
|
110
111
|
score_column: None = None,
|
|
111
112
|
eval_datasource: Datasource | None = None,
|
|
112
|
-
subsample: int | None = None,
|
|
113
|
+
subsample: int | float | None = None,
|
|
113
114
|
neighbor_count: int = 5,
|
|
114
115
|
batch_size: int = 32,
|
|
115
116
|
weigh_memories: bool = True,
|
|
@@ -126,7 +127,7 @@ class EmbeddingModelBase(ABC):
|
|
|
126
127
|
label_column: str,
|
|
127
128
|
score_column: None = None,
|
|
128
129
|
eval_datasource: Datasource | None = None,
|
|
129
|
-
subsample: int | None = None,
|
|
130
|
+
subsample: int | float | None = None,
|
|
130
131
|
neighbor_count: int = 5,
|
|
131
132
|
batch_size: int = 32,
|
|
132
133
|
weigh_memories: bool = True,
|
|
@@ -143,7 +144,7 @@ class EmbeddingModelBase(ABC):
|
|
|
143
144
|
label_column: None = None,
|
|
144
145
|
score_column: str,
|
|
145
146
|
eval_datasource: Datasource | None = None,
|
|
146
|
-
subsample: int | None = None,
|
|
147
|
+
subsample: int | float | None = None,
|
|
147
148
|
neighbor_count: int = 5,
|
|
148
149
|
batch_size: int = 32,
|
|
149
150
|
weigh_memories: bool = True,
|
|
@@ -160,7 +161,7 @@ class EmbeddingModelBase(ABC):
|
|
|
160
161
|
label_column: None = None,
|
|
161
162
|
score_column: str,
|
|
162
163
|
eval_datasource: Datasource | None = None,
|
|
163
|
-
subsample: int | None = None,
|
|
164
|
+
subsample: int | float | None = None,
|
|
164
165
|
neighbor_count: int = 5,
|
|
165
166
|
batch_size: int = 32,
|
|
166
167
|
weigh_memories: bool = True,
|
|
@@ -176,7 +177,7 @@ class EmbeddingModelBase(ABC):
|
|
|
176
177
|
label_column: str | None = None,
|
|
177
178
|
score_column: str | None = None,
|
|
178
179
|
eval_datasource: Datasource | None = None,
|
|
179
|
-
subsample: int | None = None,
|
|
180
|
+
subsample: int | float | None = None,
|
|
180
181
|
neighbor_count: int = 5,
|
|
181
182
|
batch_size: int = 32,
|
|
182
183
|
weigh_memories: bool = True,
|
|
@@ -191,6 +192,7 @@ class EmbeddingModelBase(ABC):
|
|
|
191
192
|
"""
|
|
192
193
|
Evaluate the finetuned embedding model
|
|
193
194
|
"""
|
|
195
|
+
|
|
194
196
|
payload: EmbeddingEvaluationRequest = {
|
|
195
197
|
"datasource_name_or_id": datasource.id,
|
|
196
198
|
"datasource_label_column": label_column,
|
|
@@ -202,14 +204,15 @@ class EmbeddingModelBase(ABC):
|
|
|
202
204
|
"batch_size": batch_size,
|
|
203
205
|
"weigh_memories": weigh_memories,
|
|
204
206
|
}
|
|
207
|
+
client = OrcaClient._resolve_client()
|
|
205
208
|
if isinstance(self, PretrainedEmbeddingModel):
|
|
206
|
-
response =
|
|
209
|
+
response = client.POST(
|
|
207
210
|
"/pretrained_embedding_model/{model_name}/evaluation",
|
|
208
211
|
params={"model_name": self.name},
|
|
209
212
|
json=payload,
|
|
210
213
|
)
|
|
211
214
|
elif isinstance(self, FinetunedEmbeddingModel):
|
|
212
|
-
response =
|
|
215
|
+
response = client.POST(
|
|
213
216
|
"/finetuned_embedding_model/{name_or_id}/evaluation",
|
|
214
217
|
params={"name_or_id": self.id},
|
|
215
218
|
json=payload,
|
|
@@ -217,16 +220,17 @@ class EmbeddingModelBase(ABC):
|
|
|
217
220
|
else:
|
|
218
221
|
raise ValueError("Invalid embedding model")
|
|
219
222
|
|
|
220
|
-
def get_result(
|
|
223
|
+
def get_result(job_id: str) -> ClassificationMetrics | RegressionMetrics:
|
|
224
|
+
client = OrcaClient._resolve_client()
|
|
221
225
|
if isinstance(self, PretrainedEmbeddingModel):
|
|
222
|
-
res =
|
|
223
|
-
"/pretrained_embedding_model/{model_name}/evaluation/{
|
|
224
|
-
params={"model_name": self.name, "
|
|
226
|
+
res = client.GET(
|
|
227
|
+
"/pretrained_embedding_model/{model_name}/evaluation/{job_id}",
|
|
228
|
+
params={"model_name": self.name, "job_id": job_id},
|
|
225
229
|
)["result"]
|
|
226
230
|
elif isinstance(self, FinetunedEmbeddingModel):
|
|
227
|
-
res =
|
|
228
|
-
"/finetuned_embedding_model/{name_or_id}/evaluation/{
|
|
229
|
-
params={"name_or_id": self.id, "
|
|
231
|
+
res = client.GET(
|
|
232
|
+
"/finetuned_embedding_model/{name_or_id}/evaluation/{job_id}",
|
|
233
|
+
params={"name_or_id": self.id, "job_id": job_id},
|
|
230
234
|
)["result"]
|
|
231
235
|
else:
|
|
232
236
|
raise ValueError("Invalid embedding model")
|
|
@@ -260,7 +264,7 @@ class EmbeddingModelBase(ABC):
|
|
|
260
264
|
)
|
|
261
265
|
)
|
|
262
266
|
|
|
263
|
-
job = Job(response["
|
|
267
|
+
job = Job(response["job_id"], lambda: get_result(response["job_id"]))
|
|
264
268
|
return job if background else job.result()
|
|
265
269
|
|
|
266
270
|
|
|
@@ -401,7 +405,8 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
|
|
|
401
405
|
Returns:
|
|
402
406
|
A list of all pretrained embedding models available in the OrcaCloud
|
|
403
407
|
"""
|
|
404
|
-
|
|
408
|
+
client = OrcaClient._resolve_client()
|
|
409
|
+
return [cls(metadata) for metadata in client.GET("/pretrained_embedding_model")]
|
|
405
410
|
|
|
406
411
|
_instances: dict[str, PretrainedEmbeddingModel] = {}
|
|
407
412
|
|
|
@@ -410,7 +415,8 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
|
|
|
410
415
|
# for internal use only, do not document - we want people to use dot notation to get the model
|
|
411
416
|
cache_key = str(name)
|
|
412
417
|
if cache_key not in cls._instances:
|
|
413
|
-
|
|
418
|
+
client = OrcaClient._resolve_client()
|
|
419
|
+
metadata = client.GET(
|
|
414
420
|
"/pretrained_embedding_model/{model_name}",
|
|
415
421
|
params={"model_name": name},
|
|
416
422
|
)
|
|
@@ -457,12 +463,13 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
|
|
|
457
463
|
def finetune(
|
|
458
464
|
self,
|
|
459
465
|
name: str,
|
|
460
|
-
train_datasource: Datasource | LabeledMemoryset,
|
|
466
|
+
train_datasource: Datasource | LabeledMemoryset | ScoredMemoryset,
|
|
461
467
|
*,
|
|
462
468
|
eval_datasource: Datasource | None = None,
|
|
463
469
|
label_column: str = "label",
|
|
470
|
+
score_column: str = "score",
|
|
464
471
|
value_column: str = "value",
|
|
465
|
-
training_method: EmbeddingFinetuningMethod =
|
|
472
|
+
training_method: EmbeddingFinetuningMethod | None = None,
|
|
466
473
|
training_args: dict | None = None,
|
|
467
474
|
if_exists: CreateMode = "error",
|
|
468
475
|
background: Literal[True],
|
|
@@ -473,12 +480,13 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
|
|
|
473
480
|
def finetune(
|
|
474
481
|
self,
|
|
475
482
|
name: str,
|
|
476
|
-
train_datasource: Datasource | LabeledMemoryset,
|
|
483
|
+
train_datasource: Datasource | LabeledMemoryset | ScoredMemoryset,
|
|
477
484
|
*,
|
|
478
485
|
eval_datasource: Datasource | None = None,
|
|
479
486
|
label_column: str = "label",
|
|
487
|
+
score_column: str = "score",
|
|
480
488
|
value_column: str = "value",
|
|
481
|
-
training_method: EmbeddingFinetuningMethod =
|
|
489
|
+
training_method: EmbeddingFinetuningMethod | None = None,
|
|
482
490
|
training_args: dict | None = None,
|
|
483
491
|
if_exists: CreateMode = "error",
|
|
484
492
|
background: Literal[False] = False,
|
|
@@ -488,12 +496,13 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
|
|
|
488
496
|
def finetune(
|
|
489
497
|
self,
|
|
490
498
|
name: str,
|
|
491
|
-
train_datasource: Datasource | LabeledMemoryset,
|
|
499
|
+
train_datasource: Datasource | LabeledMemoryset | ScoredMemoryset,
|
|
492
500
|
*,
|
|
493
501
|
eval_datasource: Datasource | None = None,
|
|
494
502
|
label_column: str = "label",
|
|
503
|
+
score_column: str = "score",
|
|
495
504
|
value_column: str = "value",
|
|
496
|
-
training_method: EmbeddingFinetuningMethod =
|
|
505
|
+
training_method: EmbeddingFinetuningMethod | None = None,
|
|
497
506
|
training_args: dict | None = None,
|
|
498
507
|
if_exists: CreateMode = "error",
|
|
499
508
|
background: bool = False,
|
|
@@ -505,9 +514,10 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
|
|
|
505
514
|
name: Name of the finetuned embedding model
|
|
506
515
|
train_datasource: Data to train on
|
|
507
516
|
eval_datasource: Optionally provide data to evaluate on
|
|
508
|
-
label_column: Column name of the label
|
|
517
|
+
label_column: Column name of the label.
|
|
518
|
+
score_column: Column name of the score (for regression when training on scored data).
|
|
509
519
|
value_column: Column name of the value
|
|
510
|
-
training_method:
|
|
520
|
+
training_method: Optional training method override. If omitted, Lighthouse defaults apply.
|
|
511
521
|
training_args: Optional override for Hugging Face [`TrainingArguments`][transformers.TrainingArguments].
|
|
512
522
|
If not provided, reasonable training arguments will be used for the specified training method
|
|
513
523
|
if_exists: What to do if a finetuned embedding model with the same name already exists, defaults to
|
|
@@ -538,29 +548,33 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
|
|
|
538
548
|
|
|
539
549
|
return existing
|
|
540
550
|
|
|
541
|
-
from .memoryset import LabeledMemoryset
|
|
551
|
+
from .memoryset import LabeledMemoryset, ScoredMemoryset
|
|
542
552
|
|
|
543
553
|
payload: FinetuneEmbeddingModelRequest = {
|
|
544
554
|
"name": name,
|
|
545
555
|
"base_model": self.name,
|
|
546
556
|
"label_column": label_column,
|
|
557
|
+
"score_column": score_column,
|
|
547
558
|
"value_column": value_column,
|
|
548
|
-
"training_method": training_method,
|
|
549
559
|
"training_args": training_args or {},
|
|
550
560
|
}
|
|
561
|
+
if training_method is not None:
|
|
562
|
+
payload["training_method"] = training_method
|
|
563
|
+
|
|
551
564
|
if isinstance(train_datasource, Datasource):
|
|
552
565
|
payload["train_datasource_name_or_id"] = train_datasource.id
|
|
553
|
-
elif isinstance(train_datasource, LabeledMemoryset):
|
|
566
|
+
elif isinstance(train_datasource, (LabeledMemoryset, ScoredMemoryset)):
|
|
554
567
|
payload["train_memoryset_name_or_id"] = train_datasource.id
|
|
555
568
|
if eval_datasource is not None:
|
|
556
569
|
payload["eval_datasource_name_or_id"] = eval_datasource.id
|
|
557
570
|
|
|
558
|
-
|
|
571
|
+
client = OrcaClient._resolve_client()
|
|
572
|
+
res = client.POST(
|
|
559
573
|
"/finetuned_embedding_model",
|
|
560
574
|
json=payload,
|
|
561
575
|
)
|
|
562
576
|
job = Job(
|
|
563
|
-
res["
|
|
577
|
+
res["finetuning_job_id"],
|
|
564
578
|
lambda: FinetunedEmbeddingModel.open(res["id"]),
|
|
565
579
|
)
|
|
566
580
|
return job if background else job.result()
|
|
@@ -630,7 +644,8 @@ class FinetunedEmbeddingModel(EmbeddingModelBase):
|
|
|
630
644
|
Returns:
|
|
631
645
|
A list of all finetuned embedding model handles in the OrcaCloud
|
|
632
646
|
"""
|
|
633
|
-
|
|
647
|
+
client = OrcaClient._resolve_client()
|
|
648
|
+
return [cls(metadata) for metadata in client.GET("/finetuned_embedding_model")]
|
|
634
649
|
|
|
635
650
|
@classmethod
|
|
636
651
|
def open(cls, name: str) -> FinetunedEmbeddingModel:
|
|
@@ -646,7 +661,8 @@ class FinetunedEmbeddingModel(EmbeddingModelBase):
|
|
|
646
661
|
Raises:
|
|
647
662
|
LookupError: If the finetuned embedding model does not exist
|
|
648
663
|
"""
|
|
649
|
-
|
|
664
|
+
client = OrcaClient._resolve_client()
|
|
665
|
+
metadata = client.GET(
|
|
650
666
|
"/finetuned_embedding_model/{name_or_id}",
|
|
651
667
|
params={"name_or_id": name},
|
|
652
668
|
)
|
|
@@ -681,7 +697,8 @@ class FinetunedEmbeddingModel(EmbeddingModelBase):
|
|
|
681
697
|
LookupError: If the finetuned embedding model does not exist and `if_not_exists` is `"error"`
|
|
682
698
|
"""
|
|
683
699
|
try:
|
|
684
|
-
|
|
700
|
+
client = OrcaClient._resolve_client()
|
|
701
|
+
client.DELETE(
|
|
685
702
|
"/finetuned_embedding_model/{name_or_id}",
|
|
686
703
|
params={"name_or_id": name_or_id},
|
|
687
704
|
)
|
orca_sdk/embedding_model_test.py
CHANGED
|
@@ -25,9 +25,10 @@ def test_open_pretrained_model():
|
|
|
25
25
|
assert model is PretrainedEmbeddingModel.GTE_BASE
|
|
26
26
|
|
|
27
27
|
|
|
28
|
-
def test_open_pretrained_model_unauthenticated(
|
|
29
|
-
with
|
|
30
|
-
|
|
28
|
+
def test_open_pretrained_model_unauthenticated(unauthenticated_client):
|
|
29
|
+
with unauthenticated_client.use():
|
|
30
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
31
|
+
PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline")
|
|
31
32
|
|
|
32
33
|
|
|
33
34
|
def test_open_pretrained_model_not_found():
|
|
@@ -52,9 +53,10 @@ def test_embed_text():
|
|
|
52
53
|
assert isinstance(embedding[0], float)
|
|
53
54
|
|
|
54
55
|
|
|
55
|
-
def test_embed_text_unauthenticated(
|
|
56
|
-
with
|
|
57
|
-
|
|
56
|
+
def test_embed_text_unauthenticated(unauthenticated_client):
|
|
57
|
+
with unauthenticated_client.use():
|
|
58
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
59
|
+
PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline", max_seq_length=32)
|
|
58
60
|
|
|
59
61
|
|
|
60
62
|
def test_evaluate_pretrained_model(datasource: Datasource):
|
|
@@ -108,9 +110,10 @@ def test_finetune_model_already_exists_return(datasource: Datasource, finetuned_
|
|
|
108
110
|
assert new_model._status == Status.COMPLETED
|
|
109
111
|
|
|
110
112
|
|
|
111
|
-
def test_finetune_model_unauthenticated(
|
|
112
|
-
with
|
|
113
|
-
|
|
113
|
+
def test_finetune_model_unauthenticated(unauthenticated_client, datasource: Datasource):
|
|
114
|
+
with unauthenticated_client.use():
|
|
115
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
116
|
+
PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model_unauthenticated", datasource)
|
|
114
117
|
|
|
115
118
|
|
|
116
119
|
def test_use_finetuned_model_in_memoryset(datasource: Datasource, finetuned_model: FinetunedEmbeddingModel):
|
|
@@ -150,13 +153,15 @@ def test_all_finetuned_models(finetuned_model: FinetunedEmbeddingModel):
|
|
|
150
153
|
assert any(model.name == finetuned_model.name for model in models)
|
|
151
154
|
|
|
152
155
|
|
|
153
|
-
def test_all_finetuned_models_unauthenticated(
|
|
154
|
-
with
|
|
155
|
-
|
|
156
|
+
def test_all_finetuned_models_unauthenticated(unauthenticated_client):
|
|
157
|
+
with unauthenticated_client.use():
|
|
158
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
159
|
+
FinetunedEmbeddingModel.all()
|
|
156
160
|
|
|
157
161
|
|
|
158
|
-
def test_all_finetuned_models_unauthorized(
|
|
159
|
-
|
|
162
|
+
def test_all_finetuned_models_unauthorized(unauthorized_client, finetuned_model: FinetunedEmbeddingModel):
|
|
163
|
+
with unauthorized_client.use():
|
|
164
|
+
assert finetuned_model not in FinetunedEmbeddingModel.all()
|
|
160
165
|
|
|
161
166
|
|
|
162
167
|
def test_drop_finetuned_model(datasource: Datasource):
|
|
@@ -167,9 +172,10 @@ def test_drop_finetuned_model(datasource: Datasource):
|
|
|
167
172
|
FinetunedEmbeddingModel.open("finetuned_model_to_delete")
|
|
168
173
|
|
|
169
174
|
|
|
170
|
-
def test_drop_finetuned_model_unauthenticated(
|
|
171
|
-
with
|
|
172
|
-
|
|
175
|
+
def test_drop_finetuned_model_unauthenticated(unauthenticated_client, datasource: Datasource):
|
|
176
|
+
with unauthenticated_client.use():
|
|
177
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
178
|
+
PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource)
|
|
173
179
|
|
|
174
180
|
|
|
175
181
|
def test_drop_finetuned_model_not_found():
|
|
@@ -179,9 +185,10 @@ def test_drop_finetuned_model_not_found():
|
|
|
179
185
|
FinetunedEmbeddingModel.drop(str(uuid4()), if_not_exists="ignore")
|
|
180
186
|
|
|
181
187
|
|
|
182
|
-
def test_drop_finetuned_model_unauthorized(
|
|
183
|
-
with
|
|
184
|
-
|
|
188
|
+
def test_drop_finetuned_model_unauthorized(unauthorized_client, finetuned_model: FinetunedEmbeddingModel):
|
|
189
|
+
with unauthorized_client.use():
|
|
190
|
+
with pytest.raises(LookupError):
|
|
191
|
+
FinetunedEmbeddingModel.drop(finetuned_model.id)
|
|
185
192
|
|
|
186
193
|
|
|
187
194
|
def test_supports_instructions():
|