orca-sdk 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the contents of those package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +1 -1
- orca_sdk/_utils/auth.py +12 -8
- orca_sdk/async_client.py +3795 -0
- orca_sdk/classification_model.py +176 -14
- orca_sdk/classification_model_test.py +96 -28
- orca_sdk/client.py +515 -475
- orca_sdk/conftest.py +37 -36
- orca_sdk/credentials.py +54 -14
- orca_sdk/credentials_test.py +92 -28
- orca_sdk/datasource.py +19 -10
- orca_sdk/datasource_test.py +24 -18
- orca_sdk/embedding_model.py +22 -13
- orca_sdk/embedding_model_test.py +27 -20
- orca_sdk/job.py +14 -8
- orca_sdk/memoryset.py +513 -183
- orca_sdk/memoryset_test.py +130 -32
- orca_sdk/regression_model.py +21 -11
- orca_sdk/regression_model_test.py +35 -26
- orca_sdk/telemetry.py +24 -13
- {orca_sdk-0.1.2.dist-info → orca_sdk-0.1.3.dist-info}/METADATA +1 -1
- orca_sdk-0.1.3.dist-info/RECORD +41 -0
- orca_sdk-0.1.2.dist-info/RECORD +0 -40
- {orca_sdk-0.1.2.dist-info → orca_sdk-0.1.3.dist-info}/WHEEL +0 -0
orca_sdk/embedding_model.py
CHANGED
|
@@ -12,9 +12,9 @@ from .client import (
|
|
|
12
12
|
EmbedRequest,
|
|
13
13
|
FinetunedEmbeddingModelMetadata,
|
|
14
14
|
FinetuneEmbeddingModelRequest,
|
|
15
|
+
OrcaClient,
|
|
15
16
|
PretrainedEmbeddingModelMetadata,
|
|
16
17
|
PretrainedEmbeddingModelName,
|
|
17
|
-
orca_api,
|
|
18
18
|
)
|
|
19
19
|
from .datasource import Datasource
|
|
20
20
|
from .job import Job, Status
|
|
@@ -82,15 +82,16 @@ class EmbeddingModelBase(ABC):
|
|
|
82
82
|
"max_seq_length": max_seq_length,
|
|
83
83
|
"prompt": prompt,
|
|
84
84
|
}
|
|
85
|
+
client = OrcaClient._resolve_client()
|
|
85
86
|
if isinstance(self, PretrainedEmbeddingModel):
|
|
86
|
-
embeddings = orca_api.POST(
|
|
87
|
+
embeddings = client.POST(
|
|
87
88
|
"/gpu/pretrained_embedding_model/{model_name}/embedding",
|
|
88
89
|
params={"model_name": cast(PretrainedEmbeddingModelName, self.name)},
|
|
89
90
|
json=payload,
|
|
90
91
|
timeout=30, # may be slow in case of cold start
|
|
91
92
|
)
|
|
92
93
|
elif isinstance(self, FinetunedEmbeddingModel):
|
|
93
|
-
embeddings = orca_api.POST(
|
|
94
|
+
embeddings = client.POST(
|
|
94
95
|
"/gpu/finetuned_embedding_model/{name_or_id}/embedding",
|
|
95
96
|
params={"name_or_id": self.id},
|
|
96
97
|
json=payload,
|
|
@@ -202,14 +203,15 @@ class EmbeddingModelBase(ABC):
|
|
|
202
203
|
"batch_size": batch_size,
|
|
203
204
|
"weigh_memories": weigh_memories,
|
|
204
205
|
}
|
|
206
|
+
client = OrcaClient._resolve_client()
|
|
205
207
|
if isinstance(self, PretrainedEmbeddingModel):
|
|
206
|
-
response = orca_api.POST(
|
|
208
|
+
response = client.POST(
|
|
207
209
|
"/pretrained_embedding_model/{model_name}/evaluation",
|
|
208
210
|
params={"model_name": self.name},
|
|
209
211
|
json=payload,
|
|
210
212
|
)
|
|
211
213
|
elif isinstance(self, FinetunedEmbeddingModel):
|
|
212
|
-
response = orca_api.POST(
|
|
214
|
+
response = client.POST(
|
|
213
215
|
"/finetuned_embedding_model/{name_or_id}/evaluation",
|
|
214
216
|
params={"name_or_id": self.id},
|
|
215
217
|
json=payload,
|
|
@@ -218,13 +220,14 @@ class EmbeddingModelBase(ABC):
|
|
|
218
220
|
raise ValueError("Invalid embedding model")
|
|
219
221
|
|
|
220
222
|
def get_result(task_id: str) -> ClassificationMetrics | RegressionMetrics:
|
|
223
|
+
client = OrcaClient._resolve_client()
|
|
221
224
|
if isinstance(self, PretrainedEmbeddingModel):
|
|
222
|
-
res = orca_api.GET(
|
|
225
|
+
res = client.GET(
|
|
223
226
|
"/pretrained_embedding_model/{model_name}/evaluation/{task_id}",
|
|
224
227
|
params={"model_name": self.name, "task_id": task_id},
|
|
225
228
|
)["result"]
|
|
226
229
|
elif isinstance(self, FinetunedEmbeddingModel):
|
|
227
|
-
res = orca_api.GET(
|
|
230
|
+
res = client.GET(
|
|
228
231
|
"/finetuned_embedding_model/{name_or_id}/evaluation/{task_id}",
|
|
229
232
|
params={"name_or_id": self.id, "task_id": task_id},
|
|
230
233
|
)["result"]
|
|
@@ -401,7 +404,8 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
|
|
|
401
404
|
Returns:
|
|
402
405
|
A list of all pretrained embedding models available in the OrcaCloud
|
|
403
406
|
"""
|
|
404
|
-
|
|
407
|
+
client = OrcaClient._resolve_client()
|
|
408
|
+
return [cls(metadata) for metadata in client.GET("/pretrained_embedding_model")]
|
|
405
409
|
|
|
406
410
|
_instances: dict[str, PretrainedEmbeddingModel] = {}
|
|
407
411
|
|
|
@@ -410,7 +414,8 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
|
|
|
410
414
|
# for internal use only, do not document - we want people to use dot notation to get the model
|
|
411
415
|
cache_key = str(name)
|
|
412
416
|
if cache_key not in cls._instances:
|
|
413
|
-
|
|
417
|
+
client = OrcaClient._resolve_client()
|
|
418
|
+
metadata = client.GET(
|
|
414
419
|
"/pretrained_embedding_model/{model_name}",
|
|
415
420
|
params={"model_name": name},
|
|
416
421
|
)
|
|
@@ -555,7 +560,8 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
|
|
|
555
560
|
if eval_datasource is not None:
|
|
556
561
|
payload["eval_datasource_name_or_id"] = eval_datasource.id
|
|
557
562
|
|
|
558
|
-
|
|
563
|
+
client = OrcaClient._resolve_client()
|
|
564
|
+
res = client.POST(
|
|
559
565
|
"/finetuned_embedding_model",
|
|
560
566
|
json=payload,
|
|
561
567
|
)
|
|
@@ -630,7 +636,8 @@ class FinetunedEmbeddingModel(EmbeddingModelBase):
|
|
|
630
636
|
Returns:
|
|
631
637
|
A list of all finetuned embedding model handles in the OrcaCloud
|
|
632
638
|
"""
|
|
633
|
-
|
|
639
|
+
client = OrcaClient._resolve_client()
|
|
640
|
+
return [cls(metadata) for metadata in client.GET("/finetuned_embedding_model")]
|
|
634
641
|
|
|
635
642
|
@classmethod
|
|
636
643
|
def open(cls, name: str) -> FinetunedEmbeddingModel:
|
|
@@ -646,7 +653,8 @@ class FinetunedEmbeddingModel(EmbeddingModelBase):
|
|
|
646
653
|
Raises:
|
|
647
654
|
LookupError: If the finetuned embedding model does not exist
|
|
648
655
|
"""
|
|
649
|
-
|
|
656
|
+
client = OrcaClient._resolve_client()
|
|
657
|
+
metadata = client.GET(
|
|
650
658
|
"/finetuned_embedding_model/{name_or_id}",
|
|
651
659
|
params={"name_or_id": name},
|
|
652
660
|
)
|
|
@@ -681,7 +689,8 @@ class FinetunedEmbeddingModel(EmbeddingModelBase):
|
|
|
681
689
|
LookupError: If the finetuned embedding model does not exist and `if_not_exists` is `"error"`
|
|
682
690
|
"""
|
|
683
691
|
try:
|
|
684
|
-
|
|
692
|
+
client = OrcaClient._resolve_client()
|
|
693
|
+
client.DELETE(
|
|
685
694
|
"/finetuned_embedding_model/{name_or_id}",
|
|
686
695
|
params={"name_or_id": name_or_id},
|
|
687
696
|
)
|
orca_sdk/embedding_model_test.py
CHANGED
|
@@ -25,9 +25,10 @@ def test_open_pretrained_model():
|
|
|
25
25
|
assert model is PretrainedEmbeddingModel.GTE_BASE
|
|
26
26
|
|
|
27
27
|
|
|
28
|
-
def test_open_pretrained_model_unauthenticated(
|
|
29
|
-
with
|
|
30
|
-
|
|
28
|
+
def test_open_pretrained_model_unauthenticated(unauthenticated_client):
|
|
29
|
+
with unauthenticated_client.use():
|
|
30
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
31
|
+
PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline")
|
|
31
32
|
|
|
32
33
|
|
|
33
34
|
def test_open_pretrained_model_not_found():
|
|
@@ -52,9 +53,10 @@ def test_embed_text():
|
|
|
52
53
|
assert isinstance(embedding[0], float)
|
|
53
54
|
|
|
54
55
|
|
|
55
|
-
def test_embed_text_unauthenticated(
|
|
56
|
-
with
|
|
57
|
-
|
|
56
|
+
def test_embed_text_unauthenticated(unauthenticated_client):
|
|
57
|
+
with unauthenticated_client.use():
|
|
58
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
59
|
+
PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline", max_seq_length=32)
|
|
58
60
|
|
|
59
61
|
|
|
60
62
|
def test_evaluate_pretrained_model(datasource: Datasource):
|
|
@@ -108,9 +110,10 @@ def test_finetune_model_already_exists_return(datasource: Datasource, finetuned_
|
|
|
108
110
|
assert new_model._status == Status.COMPLETED
|
|
109
111
|
|
|
110
112
|
|
|
111
|
-
def test_finetune_model_unauthenticated(
|
|
112
|
-
with
|
|
113
|
-
|
|
113
|
+
def test_finetune_model_unauthenticated(unauthenticated_client, datasource: Datasource):
|
|
114
|
+
with unauthenticated_client.use():
|
|
115
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
116
|
+
PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model_unauthenticated", datasource)
|
|
114
117
|
|
|
115
118
|
|
|
116
119
|
def test_use_finetuned_model_in_memoryset(datasource: Datasource, finetuned_model: FinetunedEmbeddingModel):
|
|
@@ -150,13 +153,15 @@ def test_all_finetuned_models(finetuned_model: FinetunedEmbeddingModel):
|
|
|
150
153
|
assert any(model.name == finetuned_model.name for model in models)
|
|
151
154
|
|
|
152
155
|
|
|
153
|
-
def test_all_finetuned_models_unauthenticated(
|
|
154
|
-
with
|
|
155
|
-
|
|
156
|
+
def test_all_finetuned_models_unauthenticated(unauthenticated_client):
|
|
157
|
+
with unauthenticated_client.use():
|
|
158
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
159
|
+
FinetunedEmbeddingModel.all()
|
|
156
160
|
|
|
157
161
|
|
|
158
|
-
def test_all_finetuned_models_unauthorized(
|
|
159
|
-
|
|
162
|
+
def test_all_finetuned_models_unauthorized(unauthorized_client, finetuned_model: FinetunedEmbeddingModel):
|
|
163
|
+
with unauthorized_client.use():
|
|
164
|
+
assert finetuned_model not in FinetunedEmbeddingModel.all()
|
|
160
165
|
|
|
161
166
|
|
|
162
167
|
def test_drop_finetuned_model(datasource: Datasource):
|
|
@@ -167,9 +172,10 @@ def test_drop_finetuned_model(datasource: Datasource):
|
|
|
167
172
|
FinetunedEmbeddingModel.open("finetuned_model_to_delete")
|
|
168
173
|
|
|
169
174
|
|
|
170
|
-
def test_drop_finetuned_model_unauthenticated(
|
|
171
|
-
with
|
|
172
|
-
|
|
175
|
+
def test_drop_finetuned_model_unauthenticated(unauthenticated_client, datasource: Datasource):
|
|
176
|
+
with unauthenticated_client.use():
|
|
177
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
178
|
+
PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource)
|
|
173
179
|
|
|
174
180
|
|
|
175
181
|
def test_drop_finetuned_model_not_found():
|
|
@@ -179,9 +185,10 @@ def test_drop_finetuned_model_not_found():
|
|
|
179
185
|
FinetunedEmbeddingModel.drop(str(uuid4()), if_not_exists="ignore")
|
|
180
186
|
|
|
181
187
|
|
|
182
|
-
def test_drop_finetuned_model_unauthorized(
|
|
183
|
-
with
|
|
184
|
-
|
|
188
|
+
def test_drop_finetuned_model_unauthorized(unauthorized_client, finetuned_model: FinetunedEmbeddingModel):
|
|
189
|
+
with unauthorized_client.use():
|
|
190
|
+
with pytest.raises(LookupError):
|
|
191
|
+
FinetunedEmbeddingModel.drop(finetuned_model.id)
|
|
185
192
|
|
|
186
193
|
|
|
187
194
|
def test_supports_instructions():
|
orca_sdk/job.py
CHANGED
|
@@ -7,7 +7,7 @@ from typing import Callable, Generic, TypedDict, TypeVar, cast
|
|
|
7
7
|
|
|
8
8
|
from tqdm.auto import tqdm
|
|
9
9
|
|
|
10
|
-
from .client import orca_api
|
|
10
|
+
from .client import OrcaClient
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class JobConfig(TypedDict):
|
|
@@ -140,7 +140,8 @@ class Job(Generic[TResult]):
|
|
|
140
140
|
Returns:
|
|
141
141
|
List of jobs matching the given filters
|
|
142
142
|
"""
|
|
143
|
-
|
|
143
|
+
client = OrcaClient._resolve_client()
|
|
144
|
+
paginated_tasks = client.GET(
|
|
144
145
|
"/task",
|
|
145
146
|
params={
|
|
146
147
|
"status": (
|
|
@@ -186,11 +187,14 @@ class Job(Generic[TResult]):
|
|
|
186
187
|
get_value: Optional function to customize how the value is resolved, if not provided the result will be a dict
|
|
187
188
|
"""
|
|
188
189
|
self.id = id
|
|
189
|
-
|
|
190
|
+
client = OrcaClient._resolve_client()
|
|
191
|
+
task = client.GET("/task/{task_id}", params={"task_id": id})
|
|
190
192
|
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
193
|
+
def default_get_value():
|
|
194
|
+
client = OrcaClient._resolve_client()
|
|
195
|
+
return cast(TResult | None, client.GET("/task/{task_id}", params={"task_id": id})["result"])
|
|
196
|
+
|
|
197
|
+
self._get_value = get_value or default_get_value
|
|
194
198
|
self.type = task["type"]
|
|
195
199
|
self.status = Status(task["status"])
|
|
196
200
|
self.steps_total = task["steps_total"]
|
|
@@ -222,7 +226,8 @@ class Job(Generic[TResult]):
|
|
|
222
226
|
return
|
|
223
227
|
self.refreshed_at = current_time
|
|
224
228
|
|
|
225
|
-
|
|
229
|
+
client = OrcaClient._resolve_client()
|
|
230
|
+
status_info = client.GET("/task/{task_id}/status", params={"task_id": self.id})
|
|
226
231
|
self.status = Status(status_info["status"])
|
|
227
232
|
if status_info["steps_total"] is not None:
|
|
228
233
|
self.steps_total = status_info["steps_total"]
|
|
@@ -333,5 +338,6 @@ def abort(self, show_progress: bool = False, refresh_interval: int = 1, max_wait
|
|
|
333
338
|
refresh_interval: Polling interval in seconds while waiting for the job to abort
|
|
334
339
|
max_wait: Maximum time to wait for the job to abort in seconds
|
|
335
340
|
"""
|
|
336
|
-
|
|
341
|
+
client = OrcaClient._resolve_client()
|
|
342
|
+
client.DELETE("/task/{task_id}/abort", params={"task_id": self.id})
|
|
337
343
|
self.wait(show_progress, refresh_interval, max_wait)
|