orca-sdk 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff shows the contents of two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions exactly as they appear in that public registry.
@@ -12,9 +12,9 @@ from .client import (
12
12
  EmbedRequest,
13
13
  FinetunedEmbeddingModelMetadata,
14
14
  FinetuneEmbeddingModelRequest,
15
+ OrcaClient,
15
16
  PretrainedEmbeddingModelMetadata,
16
17
  PretrainedEmbeddingModelName,
17
- orca_api,
18
18
  )
19
19
  from .datasource import Datasource
20
20
  from .job import Job, Status
@@ -82,15 +82,16 @@ class EmbeddingModelBase(ABC):
82
82
  "max_seq_length": max_seq_length,
83
83
  "prompt": prompt,
84
84
  }
85
+ client = OrcaClient._resolve_client()
85
86
  if isinstance(self, PretrainedEmbeddingModel):
86
- embeddings = orca_api.POST(
87
+ embeddings = client.POST(
87
88
  "/gpu/pretrained_embedding_model/{model_name}/embedding",
88
89
  params={"model_name": cast(PretrainedEmbeddingModelName, self.name)},
89
90
  json=payload,
90
91
  timeout=30, # may be slow in case of cold start
91
92
  )
92
93
  elif isinstance(self, FinetunedEmbeddingModel):
93
- embeddings = orca_api.POST(
94
+ embeddings = client.POST(
94
95
  "/gpu/finetuned_embedding_model/{name_or_id}/embedding",
95
96
  params={"name_or_id": self.id},
96
97
  json=payload,
@@ -202,14 +203,15 @@ class EmbeddingModelBase(ABC):
202
203
  "batch_size": batch_size,
203
204
  "weigh_memories": weigh_memories,
204
205
  }
206
+ client = OrcaClient._resolve_client()
205
207
  if isinstance(self, PretrainedEmbeddingModel):
206
- response = orca_api.POST(
208
+ response = client.POST(
207
209
  "/pretrained_embedding_model/{model_name}/evaluation",
208
210
  params={"model_name": self.name},
209
211
  json=payload,
210
212
  )
211
213
  elif isinstance(self, FinetunedEmbeddingModel):
212
- response = orca_api.POST(
214
+ response = client.POST(
213
215
  "/finetuned_embedding_model/{name_or_id}/evaluation",
214
216
  params={"name_or_id": self.id},
215
217
  json=payload,
@@ -218,13 +220,14 @@ class EmbeddingModelBase(ABC):
218
220
  raise ValueError("Invalid embedding model")
219
221
 
220
222
  def get_result(task_id: str) -> ClassificationMetrics | RegressionMetrics:
223
+ client = OrcaClient._resolve_client()
221
224
  if isinstance(self, PretrainedEmbeddingModel):
222
- res = orca_api.GET(
225
+ res = client.GET(
223
226
  "/pretrained_embedding_model/{model_name}/evaluation/{task_id}",
224
227
  params={"model_name": self.name, "task_id": task_id},
225
228
  )["result"]
226
229
  elif isinstance(self, FinetunedEmbeddingModel):
227
- res = orca_api.GET(
230
+ res = client.GET(
228
231
  "/finetuned_embedding_model/{name_or_id}/evaluation/{task_id}",
229
232
  params={"name_or_id": self.id, "task_id": task_id},
230
233
  )["result"]
@@ -401,7 +404,8 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
401
404
  Returns:
402
405
  A list of all pretrained embedding models available in the OrcaCloud
403
406
  """
404
- return [cls(metadata) for metadata in orca_api.GET("/pretrained_embedding_model")]
407
+ client = OrcaClient._resolve_client()
408
+ return [cls(metadata) for metadata in client.GET("/pretrained_embedding_model")]
405
409
 
406
410
  _instances: dict[str, PretrainedEmbeddingModel] = {}
407
411
 
@@ -410,7 +414,8 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
410
414
  # for internal use only, do not document - we want people to use dot notation to get the model
411
415
  cache_key = str(name)
412
416
  if cache_key not in cls._instances:
413
- metadata = orca_api.GET(
417
+ client = OrcaClient._resolve_client()
418
+ metadata = client.GET(
414
419
  "/pretrained_embedding_model/{model_name}",
415
420
  params={"model_name": name},
416
421
  )
@@ -555,7 +560,8 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
555
560
  if eval_datasource is not None:
556
561
  payload["eval_datasource_name_or_id"] = eval_datasource.id
557
562
 
558
- res = orca_api.POST(
563
+ client = OrcaClient._resolve_client()
564
+ res = client.POST(
559
565
  "/finetuned_embedding_model",
560
566
  json=payload,
561
567
  )
@@ -630,7 +636,8 @@ class FinetunedEmbeddingModel(EmbeddingModelBase):
630
636
  Returns:
631
637
  A list of all finetuned embedding model handles in the OrcaCloud
632
638
  """
633
- return [cls(metadata) for metadata in orca_api.GET("/finetuned_embedding_model")]
639
+ client = OrcaClient._resolve_client()
640
+ return [cls(metadata) for metadata in client.GET("/finetuned_embedding_model")]
634
641
 
635
642
  @classmethod
636
643
  def open(cls, name: str) -> FinetunedEmbeddingModel:
@@ -646,7 +653,8 @@ class FinetunedEmbeddingModel(EmbeddingModelBase):
646
653
  Raises:
647
654
  LookupError: If the finetuned embedding model does not exist
648
655
  """
649
- metadata = orca_api.GET(
656
+ client = OrcaClient._resolve_client()
657
+ metadata = client.GET(
650
658
  "/finetuned_embedding_model/{name_or_id}",
651
659
  params={"name_or_id": name},
652
660
  )
@@ -681,7 +689,8 @@ class FinetunedEmbeddingModel(EmbeddingModelBase):
681
689
  LookupError: If the finetuned embedding model does not exist and `if_not_exists` is `"error"`
682
690
  """
683
691
  try:
684
- orca_api.DELETE(
692
+ client = OrcaClient._resolve_client()
693
+ client.DELETE(
685
694
  "/finetuned_embedding_model/{name_or_id}",
686
695
  params={"name_or_id": name_or_id},
687
696
  )
@@ -25,9 +25,10 @@ def test_open_pretrained_model():
25
25
  assert model is PretrainedEmbeddingModel.GTE_BASE
26
26
 
27
27
 
28
- def test_open_pretrained_model_unauthenticated(unauthenticated):
29
- with pytest.raises(ValueError, match="Invalid API key"):
30
- PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline")
28
+ def test_open_pretrained_model_unauthenticated(unauthenticated_client):
29
+ with unauthenticated_client.use():
30
+ with pytest.raises(ValueError, match="Invalid API key"):
31
+ PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline")
31
32
 
32
33
 
33
34
  def test_open_pretrained_model_not_found():
@@ -52,9 +53,10 @@ def test_embed_text():
52
53
  assert isinstance(embedding[0], float)
53
54
 
54
55
 
55
- def test_embed_text_unauthenticated(unauthenticated):
56
- with pytest.raises(ValueError, match="Invalid API key"):
57
- PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline", max_seq_length=32)
56
+ def test_embed_text_unauthenticated(unauthenticated_client):
57
+ with unauthenticated_client.use():
58
+ with pytest.raises(ValueError, match="Invalid API key"):
59
+ PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline", max_seq_length=32)
58
60
 
59
61
 
60
62
  def test_evaluate_pretrained_model(datasource: Datasource):
@@ -108,9 +110,10 @@ def test_finetune_model_already_exists_return(datasource: Datasource, finetuned_
108
110
  assert new_model._status == Status.COMPLETED
109
111
 
110
112
 
111
- def test_finetune_model_unauthenticated(unauthenticated, datasource: Datasource):
112
- with pytest.raises(ValueError, match="Invalid API key"):
113
- PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model_unauthenticated", datasource)
113
+ def test_finetune_model_unauthenticated(unauthenticated_client, datasource: Datasource):
114
+ with unauthenticated_client.use():
115
+ with pytest.raises(ValueError, match="Invalid API key"):
116
+ PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model_unauthenticated", datasource)
114
117
 
115
118
 
116
119
  def test_use_finetuned_model_in_memoryset(datasource: Datasource, finetuned_model: FinetunedEmbeddingModel):
@@ -150,13 +153,15 @@ def test_all_finetuned_models(finetuned_model: FinetunedEmbeddingModel):
150
153
  assert any(model.name == finetuned_model.name for model in models)
151
154
 
152
155
 
153
- def test_all_finetuned_models_unauthenticated(unauthenticated):
154
- with pytest.raises(ValueError, match="Invalid API key"):
155
- FinetunedEmbeddingModel.all()
156
+ def test_all_finetuned_models_unauthenticated(unauthenticated_client):
157
+ with unauthenticated_client.use():
158
+ with pytest.raises(ValueError, match="Invalid API key"):
159
+ FinetunedEmbeddingModel.all()
156
160
 
157
161
 
158
- def test_all_finetuned_models_unauthorized(unauthorized, finetuned_model: FinetunedEmbeddingModel):
159
- assert finetuned_model not in FinetunedEmbeddingModel.all()
162
+ def test_all_finetuned_models_unauthorized(unauthorized_client, finetuned_model: FinetunedEmbeddingModel):
163
+ with unauthorized_client.use():
164
+ assert finetuned_model not in FinetunedEmbeddingModel.all()
160
165
 
161
166
 
162
167
  def test_drop_finetuned_model(datasource: Datasource):
@@ -167,9 +172,10 @@ def test_drop_finetuned_model(datasource: Datasource):
167
172
  FinetunedEmbeddingModel.open("finetuned_model_to_delete")
168
173
 
169
174
 
170
- def test_drop_finetuned_model_unauthenticated(unauthenticated, datasource: Datasource):
171
- with pytest.raises(ValueError, match="Invalid API key"):
172
- PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource)
175
+ def test_drop_finetuned_model_unauthenticated(unauthenticated_client, datasource: Datasource):
176
+ with unauthenticated_client.use():
177
+ with pytest.raises(ValueError, match="Invalid API key"):
178
+ PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource)
173
179
 
174
180
 
175
181
  def test_drop_finetuned_model_not_found():
@@ -179,9 +185,10 @@ def test_drop_finetuned_model_not_found():
179
185
  FinetunedEmbeddingModel.drop(str(uuid4()), if_not_exists="ignore")
180
186
 
181
187
 
182
- def test_drop_finetuned_model_unauthorized(unauthorized, finetuned_model: FinetunedEmbeddingModel):
183
- with pytest.raises(LookupError):
184
- FinetunedEmbeddingModel.drop(finetuned_model.id)
188
+ def test_drop_finetuned_model_unauthorized(unauthorized_client, finetuned_model: FinetunedEmbeddingModel):
189
+ with unauthorized_client.use():
190
+ with pytest.raises(LookupError):
191
+ FinetunedEmbeddingModel.drop(finetuned_model.id)
185
192
 
186
193
 
187
194
  def test_supports_instructions():
orca_sdk/job.py CHANGED
@@ -7,7 +7,7 @@ from typing import Callable, Generic, TypedDict, TypeVar, cast
7
7
 
8
8
  from tqdm.auto import tqdm
9
9
 
10
- from .client import orca_api
10
+ from .client import OrcaClient
11
11
 
12
12
 
13
13
  class JobConfig(TypedDict):
@@ -140,7 +140,8 @@ class Job(Generic[TResult]):
140
140
  Returns:
141
141
  List of jobs matching the given filters
142
142
  """
143
- paginated_tasks = orca_api.GET(
143
+ client = OrcaClient._resolve_client()
144
+ paginated_tasks = client.GET(
144
145
  "/task",
145
146
  params={
146
147
  "status": (
@@ -186,11 +187,14 @@ class Job(Generic[TResult]):
186
187
  get_value: Optional function to customize how the value is resolved, if not provided the result will be a dict
187
188
  """
188
189
  self.id = id
189
- task = orca_api.GET("/task/{task_id}", params={"task_id": id})
190
+ client = OrcaClient._resolve_client()
191
+ task = client.GET("/task/{task_id}", params={"task_id": id})
190
192
 
191
- self._get_value = get_value or (
192
- lambda: cast(TResult | None, orca_api.GET("/task/{task_id}", params={"task_id": id})["result"])
193
- )
193
+ def default_get_value():
194
+ client = OrcaClient._resolve_client()
195
+ return cast(TResult | None, client.GET("/task/{task_id}", params={"task_id": id})["result"])
196
+
197
+ self._get_value = get_value or default_get_value
194
198
  self.type = task["type"]
195
199
  self.status = Status(task["status"])
196
200
  self.steps_total = task["steps_total"]
@@ -222,7 +226,8 @@ class Job(Generic[TResult]):
222
226
  return
223
227
  self.refreshed_at = current_time
224
228
 
225
- status_info = orca_api.GET("/task/{task_id}/status", params={"task_id": self.id})
229
+ client = OrcaClient._resolve_client()
230
+ status_info = client.GET("/task/{task_id}/status", params={"task_id": self.id})
226
231
  self.status = Status(status_info["status"])
227
232
  if status_info["steps_total"] is not None:
228
233
  self.steps_total = status_info["steps_total"]
@@ -333,5 +338,6 @@ def abort(self, show_progress: bool = False, refresh_interval: int = 1, max_wait
333
338
  refresh_interval: Polling interval in seconds while waiting for the job to abort
334
339
  max_wait: Maximum time to wait for the job to abort in seconds
335
340
  """
336
- orca_api.DELETE("/task/{task_id}/abort", params={"task_id": self.id})
341
+ client = OrcaClient._resolve_client()
342
+ client.DELETE("/task/{task_id}/abort", params={"task_id": self.id})
337
343
  self.wait(show_progress, refresh_interval, max_wait)