orca-sdk 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff compares the contents of two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
@@ -19,9 +19,10 @@ def test_create_datasource(datasource, hf_dataset):
19
19
  assert datasource.length == len(hf_dataset)
20
20
 
21
21
 
22
- def test_create_datasource_unauthenticated(unauthenticated, hf_dataset):
23
- with pytest.raises(ValueError, match="Invalid API key"):
24
- Datasource.from_hf_dataset("test_datasource", hf_dataset)
22
+ def test_create_datasource_unauthenticated(unauthenticated_client, hf_dataset):
23
+ with unauthenticated_client.use():
24
+ with pytest.raises(ValueError, match="Invalid API key"):
25
+ Datasource.from_hf_dataset("test_datasource", hf_dataset)
25
26
 
26
27
 
27
28
  def test_create_datasource_already_exists_error(hf_dataset, datasource):
@@ -43,9 +44,10 @@ def test_open_datasource(datasource):
43
44
  assert fetched_datasource.length == len(datasource)
44
45
 
45
46
 
46
- def test_open_datasource_unauthenticated(datasource, unauthenticated):
47
- with pytest.raises(ValueError, match="Invalid API key"):
48
- Datasource.open("test_datasource")
47
+ def test_open_datasource_unauthenticated(unauthenticated_client, datasource):
48
+ with unauthenticated_client.use():
49
+ with pytest.raises(ValueError, match="Invalid API key"):
50
+ Datasource.open("test_datasource")
49
51
 
50
52
 
51
53
  def test_open_datasource_invalid_input():
@@ -58,9 +60,10 @@ def test_open_datasource_not_found():
58
60
  Datasource.open(str(uuid4()))
59
61
 
60
62
 
61
- def test_open_datasource_unauthorized(datasource, unauthorized):
62
- with pytest.raises(LookupError):
63
- Datasource.open(datasource.id)
63
+ def test_open_datasource_unauthorized(unauthorized_client, datasource):
64
+ with unauthorized_client.use():
65
+ with pytest.raises(LookupError):
66
+ Datasource.open(datasource.id)
64
67
 
65
68
 
66
69
  def test_all_datasources(datasource):
@@ -69,9 +72,10 @@ def test_all_datasources(datasource):
69
72
  assert any(datasource.name == datasource.name for datasource in datasources)
70
73
 
71
74
 
72
- def test_all_datasources_unauthenticated(unauthenticated):
73
- with pytest.raises(ValueError, match="Invalid API key"):
74
- Datasource.all()
75
+ def test_all_datasources_unauthenticated(unauthenticated_client):
76
+ with unauthenticated_client.use():
77
+ with pytest.raises(ValueError, match="Invalid API key"):
78
+ Datasource.all()
75
79
 
76
80
 
77
81
  def test_drop_datasource(hf_dataset):
@@ -81,9 +85,10 @@ def test_drop_datasource(hf_dataset):
81
85
  assert not Datasource.exists("datasource_to_delete")
82
86
 
83
87
 
84
- def test_drop_datasource_unauthenticated(datasource, unauthenticated):
85
- with pytest.raises(ValueError, match="Invalid API key"):
86
- Datasource.drop(datasource.id)
88
+ def test_drop_datasource_unauthenticated(datasource, unauthenticated_client):
89
+ with unauthenticated_client.use():
90
+ with pytest.raises(ValueError, match="Invalid API key"):
91
+ Datasource.drop(datasource.id)
87
92
 
88
93
 
89
94
  def test_drop_datasource_not_found():
@@ -93,9 +98,10 @@ def test_drop_datasource_not_found():
93
98
  Datasource.drop(str(uuid4()), if_not_exists="ignore")
94
99
 
95
100
 
96
- def test_drop_datasource_unauthorized(datasource, unauthorized):
97
- with pytest.raises(LookupError):
98
- Datasource.drop(datasource.id)
101
+ def test_drop_datasource_unauthorized(datasource, unauthorized_client):
102
+ with unauthorized_client.use():
103
+ with pytest.raises(LookupError):
104
+ Datasource.drop(datasource.id)
99
105
 
100
106
 
101
107
  def test_drop_datasource_invalid_input():
@@ -295,6 +301,126 @@ def test_from_disk_already_exists():
295
301
  os.unlink(f.name)
296
302
 
297
303
 
304
+ def test_query_datasource_rows():
305
+ """Test querying rows from a datasource with pagination and shuffle."""
306
+ # Create a new dataset with 5 entries for testing
307
+ test_data = [{"id": i, "name": f"item_{i}"} for i in range(5)]
308
+ datasource = Datasource.from_list(name="test_query_datasource", data=test_data)
309
+
310
+ # Test basic query
311
+ rows = datasource.query(limit=3)
312
+ assert len(rows) == 3
313
+ assert all(isinstance(row, dict) for row in rows)
314
+
315
+ # Test offset
316
+ offset_rows = datasource.query(offset=2, limit=2)
317
+ assert len(offset_rows) == 2
318
+ assert offset_rows[0]["id"] == 2
319
+
320
+ # Test shuffle
321
+ shuffled_rows = datasource.query(limit=5, shuffle=True)
322
+ assert len(shuffled_rows) == 5
323
+ assert not all(row["id"] == i for i, row in enumerate(shuffled_rows))
324
+
325
+ # Test shuffle with seed
326
+ assert datasource.query(limit=5, shuffle=True, shuffle_seed=42) == datasource.query(
327
+ limit=5, shuffle=True, shuffle_seed=42
328
+ )
329
+
330
+
331
+ def test_query_datasource_with_filters():
332
+ """Test querying datasource rows with various filter operators."""
333
+ # Create a datasource with test data
334
+ test_data = [
335
+ {"name": "Alice", "age": 25, "city": "New York", "score": 85.5},
336
+ {"name": "Bob", "age": 30, "city": "San Francisco", "score": 90.0},
337
+ {"name": "Charlie", "age": 35, "city": "Chicago", "score": 75.5},
338
+ {"name": "Diana", "age": 28, "city": "Boston", "score": 88.0},
339
+ {"name": "Eve", "age": 32, "city": "New York", "score": 92.0},
340
+ ]
341
+ datasource = Datasource.from_list(name=f"test_filter_datasource_{uuid4()}", data=test_data)
342
+
343
+ # Test == operator
344
+ rows = datasource.query(filters=[("city", "==", "New York")])
345
+ assert len(rows) == 2
346
+ assert all(row["city"] == "New York" for row in rows)
347
+
348
+ # Test > operator
349
+ rows = datasource.query(filters=[("age", ">", 30)])
350
+ assert len(rows) == 2
351
+ assert all(row["age"] > 30 for row in rows)
352
+
353
+ # Test >= operator
354
+ rows = datasource.query(filters=[("score", ">=", 88.0)])
355
+ assert len(rows) == 3
356
+ assert all(row["score"] >= 88.0 for row in rows)
357
+
358
+ # Test < operator
359
+ rows = datasource.query(filters=[("age", "<", 30)])
360
+ assert len(rows) == 2
361
+ assert all(row["age"] < 30 for row in rows)
362
+
363
+ # Test in operator
364
+ rows = datasource.query(filters=[("city", "in", ["New York", "Boston"])])
365
+ assert len(rows) == 3
366
+ assert all(row["city"] in ["New York", "Boston"] for row in rows)
367
+
368
+ # Test not in operator
369
+ rows = datasource.query(filters=[("city", "not in", ["New York", "Boston"])])
370
+ assert len(rows) == 2
371
+ assert all(row["city"] not in ["New York", "Boston"] for row in rows)
372
+
373
+ # Test like operator
374
+ rows = datasource.query(filters=[("name", "like", "li")])
375
+ assert len(rows) == 2
376
+ assert all("li" in row["name"].lower() for row in rows)
377
+
378
+ # Test multiple filters (AND logic)
379
+ rows = datasource.query(filters=[("city", "==", "New York"), ("age", ">", 26)])
380
+ assert len(rows) == 1
381
+ assert rows[0]["name"] == "Eve"
382
+
383
+ # Test filter with pagination
384
+ rows = datasource.query(filters=[("age", ">=", 28)], limit=2, offset=1)
385
+ assert len(rows) == 2
386
+
387
+
388
+ def test_query_datasource_with_none_filters():
389
+ """Test filtering for None values."""
390
+ test_data = [
391
+ {"name": "Alice", "age": 25, "label": "A"},
392
+ {"name": "Bob", "age": 30, "label": None},
393
+ {"name": "Charlie", "age": 35, "label": "C"},
394
+ {"name": "Diana", "age": None, "label": "D"},
395
+ {"name": "Eve", "age": 32, "label": None},
396
+ ]
397
+ datasource = Datasource.from_list(name=f"test_none_filter_{uuid4()}", data=test_data)
398
+
399
+ # Test == None
400
+ rows = datasource.query(filters=[("label", "==", None)])
401
+ assert len(rows) == 2
402
+ assert all(row["label"] is None for row in rows)
403
+
404
+ # Test != None
405
+ rows = datasource.query(filters=[("label", "!=", None)])
406
+ assert len(rows) == 3
407
+ assert all(row["label"] is not None for row in rows)
408
+
409
+ # Test that None values are excluded from comparison operators
410
+ rows = datasource.query(filters=[("age", ">", 25)])
411
+ assert len(rows) == 3
412
+ assert all(row["age"] is not None and row["age"] > 25 for row in rows)
413
+
414
+
415
+ def test_query_datasource_filter_invalid_column():
416
+ """Test that querying with an invalid column raises an error."""
417
+ test_data = [{"name": "Alice", "age": 25}]
418
+ datasource = Datasource.from_list(name=f"test_invalid_filter_{uuid4()}", data=test_data)
419
+
420
+ with pytest.raises(ValueError):
421
+ datasource.query(filters=[("invalid_column", "==", "test")])
422
+
423
+
298
424
  def test_to_list(hf_dataset, datasource):
299
425
  assert datasource.to_list() == hf_dataset.to_list()
300
426
 
@@ -12,15 +12,15 @@ from .client import (
12
12
  EmbedRequest,
13
13
  FinetunedEmbeddingModelMetadata,
14
14
  FinetuneEmbeddingModelRequest,
15
+ OrcaClient,
15
16
  PretrainedEmbeddingModelMetadata,
16
17
  PretrainedEmbeddingModelName,
17
- orca_api,
18
18
  )
19
19
  from .datasource import Datasource
20
20
  from .job import Job, Status
21
21
 
22
22
  if TYPE_CHECKING:
23
- from .memoryset import LabeledMemoryset
23
+ from .memoryset import LabeledMemoryset, ScoredMemoryset
24
24
 
25
25
 
26
26
  class EmbeddingModelBase(ABC):
@@ -82,15 +82,16 @@ class EmbeddingModelBase(ABC):
82
82
  "max_seq_length": max_seq_length,
83
83
  "prompt": prompt,
84
84
  }
85
+ client = OrcaClient._resolve_client()
85
86
  if isinstance(self, PretrainedEmbeddingModel):
86
- embeddings = orca_api.POST(
87
+ embeddings = client.POST(
87
88
  "/gpu/pretrained_embedding_model/{model_name}/embedding",
88
89
  params={"model_name": cast(PretrainedEmbeddingModelName, self.name)},
89
90
  json=payload,
90
91
  timeout=30, # may be slow in case of cold start
91
92
  )
92
93
  elif isinstance(self, FinetunedEmbeddingModel):
93
- embeddings = orca_api.POST(
94
+ embeddings = client.POST(
94
95
  "/gpu/finetuned_embedding_model/{name_or_id}/embedding",
95
96
  params={"name_or_id": self.id},
96
97
  json=payload,
@@ -109,7 +110,7 @@ class EmbeddingModelBase(ABC):
109
110
  label_column: str,
110
111
  score_column: None = None,
111
112
  eval_datasource: Datasource | None = None,
112
- subsample: int | None = None,
113
+ subsample: int | float | None = None,
113
114
  neighbor_count: int = 5,
114
115
  batch_size: int = 32,
115
116
  weigh_memories: bool = True,
@@ -126,7 +127,7 @@ class EmbeddingModelBase(ABC):
126
127
  label_column: str,
127
128
  score_column: None = None,
128
129
  eval_datasource: Datasource | None = None,
129
- subsample: int | None = None,
130
+ subsample: int | float | None = None,
130
131
  neighbor_count: int = 5,
131
132
  batch_size: int = 32,
132
133
  weigh_memories: bool = True,
@@ -143,7 +144,7 @@ class EmbeddingModelBase(ABC):
143
144
  label_column: None = None,
144
145
  score_column: str,
145
146
  eval_datasource: Datasource | None = None,
146
- subsample: int | None = None,
147
+ subsample: int | float | None = None,
147
148
  neighbor_count: int = 5,
148
149
  batch_size: int = 32,
149
150
  weigh_memories: bool = True,
@@ -160,7 +161,7 @@ class EmbeddingModelBase(ABC):
160
161
  label_column: None = None,
161
162
  score_column: str,
162
163
  eval_datasource: Datasource | None = None,
163
- subsample: int | None = None,
164
+ subsample: int | float | None = None,
164
165
  neighbor_count: int = 5,
165
166
  batch_size: int = 32,
166
167
  weigh_memories: bool = True,
@@ -176,7 +177,7 @@ class EmbeddingModelBase(ABC):
176
177
  label_column: str | None = None,
177
178
  score_column: str | None = None,
178
179
  eval_datasource: Datasource | None = None,
179
- subsample: int | None = None,
180
+ subsample: int | float | None = None,
180
181
  neighbor_count: int = 5,
181
182
  batch_size: int = 32,
182
183
  weigh_memories: bool = True,
@@ -191,6 +192,7 @@ class EmbeddingModelBase(ABC):
191
192
  """
192
193
  Evaluate the finetuned embedding model
193
194
  """
195
+
194
196
  payload: EmbeddingEvaluationRequest = {
195
197
  "datasource_name_or_id": datasource.id,
196
198
  "datasource_label_column": label_column,
@@ -202,14 +204,15 @@ class EmbeddingModelBase(ABC):
202
204
  "batch_size": batch_size,
203
205
  "weigh_memories": weigh_memories,
204
206
  }
207
+ client = OrcaClient._resolve_client()
205
208
  if isinstance(self, PretrainedEmbeddingModel):
206
- response = orca_api.POST(
209
+ response = client.POST(
207
210
  "/pretrained_embedding_model/{model_name}/evaluation",
208
211
  params={"model_name": self.name},
209
212
  json=payload,
210
213
  )
211
214
  elif isinstance(self, FinetunedEmbeddingModel):
212
- response = orca_api.POST(
215
+ response = client.POST(
213
216
  "/finetuned_embedding_model/{name_or_id}/evaluation",
214
217
  params={"name_or_id": self.id},
215
218
  json=payload,
@@ -217,16 +220,17 @@ class EmbeddingModelBase(ABC):
217
220
  else:
218
221
  raise ValueError("Invalid embedding model")
219
222
 
220
- def get_result(task_id: str) -> ClassificationMetrics | RegressionMetrics:
223
+ def get_result(job_id: str) -> ClassificationMetrics | RegressionMetrics:
224
+ client = OrcaClient._resolve_client()
221
225
  if isinstance(self, PretrainedEmbeddingModel):
222
- res = orca_api.GET(
223
- "/pretrained_embedding_model/{model_name}/evaluation/{task_id}",
224
- params={"model_name": self.name, "task_id": task_id},
226
+ res = client.GET(
227
+ "/pretrained_embedding_model/{model_name}/evaluation/{job_id}",
228
+ params={"model_name": self.name, "job_id": job_id},
225
229
  )["result"]
226
230
  elif isinstance(self, FinetunedEmbeddingModel):
227
- res = orca_api.GET(
228
- "/finetuned_embedding_model/{name_or_id}/evaluation/{task_id}",
229
- params={"name_or_id": self.id, "task_id": task_id},
231
+ res = client.GET(
232
+ "/finetuned_embedding_model/{name_or_id}/evaluation/{job_id}",
233
+ params={"name_or_id": self.id, "job_id": job_id},
230
234
  )["result"]
231
235
  else:
232
236
  raise ValueError("Invalid embedding model")
@@ -260,7 +264,7 @@ class EmbeddingModelBase(ABC):
260
264
  )
261
265
  )
262
266
 
263
- job = Job(response["task_id"], lambda: get_result(response["task_id"]))
267
+ job = Job(response["job_id"], lambda: get_result(response["job_id"]))
264
268
  return job if background else job.result()
265
269
 
266
270
 
@@ -401,7 +405,8 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
401
405
  Returns:
402
406
  A list of all pretrained embedding models available in the OrcaCloud
403
407
  """
404
- return [cls(metadata) for metadata in orca_api.GET("/pretrained_embedding_model")]
408
+ client = OrcaClient._resolve_client()
409
+ return [cls(metadata) for metadata in client.GET("/pretrained_embedding_model")]
405
410
 
406
411
  _instances: dict[str, PretrainedEmbeddingModel] = {}
407
412
 
@@ -410,7 +415,8 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
410
415
  # for internal use only, do not document - we want people to use dot notation to get the model
411
416
  cache_key = str(name)
412
417
  if cache_key not in cls._instances:
413
- metadata = orca_api.GET(
418
+ client = OrcaClient._resolve_client()
419
+ metadata = client.GET(
414
420
  "/pretrained_embedding_model/{model_name}",
415
421
  params={"model_name": name},
416
422
  )
@@ -457,12 +463,13 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
457
463
  def finetune(
458
464
  self,
459
465
  name: str,
460
- train_datasource: Datasource | LabeledMemoryset,
466
+ train_datasource: Datasource | LabeledMemoryset | ScoredMemoryset,
461
467
  *,
462
468
  eval_datasource: Datasource | None = None,
463
469
  label_column: str = "label",
470
+ score_column: str = "score",
464
471
  value_column: str = "value",
465
- training_method: EmbeddingFinetuningMethod = "classification",
472
+ training_method: EmbeddingFinetuningMethod | None = None,
466
473
  training_args: dict | None = None,
467
474
  if_exists: CreateMode = "error",
468
475
  background: Literal[True],
@@ -473,12 +480,13 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
473
480
  def finetune(
474
481
  self,
475
482
  name: str,
476
- train_datasource: Datasource | LabeledMemoryset,
483
+ train_datasource: Datasource | LabeledMemoryset | ScoredMemoryset,
477
484
  *,
478
485
  eval_datasource: Datasource | None = None,
479
486
  label_column: str = "label",
487
+ score_column: str = "score",
480
488
  value_column: str = "value",
481
- training_method: EmbeddingFinetuningMethod = "classification",
489
+ training_method: EmbeddingFinetuningMethod | None = None,
482
490
  training_args: dict | None = None,
483
491
  if_exists: CreateMode = "error",
484
492
  background: Literal[False] = False,
@@ -488,12 +496,13 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
488
496
  def finetune(
489
497
  self,
490
498
  name: str,
491
- train_datasource: Datasource | LabeledMemoryset,
499
+ train_datasource: Datasource | LabeledMemoryset | ScoredMemoryset,
492
500
  *,
493
501
  eval_datasource: Datasource | None = None,
494
502
  label_column: str = "label",
503
+ score_column: str = "score",
495
504
  value_column: str = "value",
496
- training_method: EmbeddingFinetuningMethod = "classification",
505
+ training_method: EmbeddingFinetuningMethod | None = None,
497
506
  training_args: dict | None = None,
498
507
  if_exists: CreateMode = "error",
499
508
  background: bool = False,
@@ -505,9 +514,10 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
505
514
  name: Name of the finetuned embedding model
506
515
  train_datasource: Data to train on
507
516
  eval_datasource: Optionally provide data to evaluate on
508
- label_column: Column name of the label
517
+ label_column: Column name of the label.
518
+ score_column: Column name of the score (for regression when training on scored data).
509
519
  value_column: Column name of the value
510
- training_method: Training method to use
520
+ training_method: Optional training method override. If omitted, Lighthouse defaults apply.
511
521
  training_args: Optional override for Hugging Face [`TrainingArguments`][transformers.TrainingArguments].
512
522
  If not provided, reasonable training arguments will be used for the specified training method
513
523
  if_exists: What to do if a finetuned embedding model with the same name already exists, defaults to
@@ -538,29 +548,33 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
538
548
 
539
549
  return existing
540
550
 
541
- from .memoryset import LabeledMemoryset
551
+ from .memoryset import LabeledMemoryset, ScoredMemoryset
542
552
 
543
553
  payload: FinetuneEmbeddingModelRequest = {
544
554
  "name": name,
545
555
  "base_model": self.name,
546
556
  "label_column": label_column,
557
+ "score_column": score_column,
547
558
  "value_column": value_column,
548
- "training_method": training_method,
549
559
  "training_args": training_args or {},
550
560
  }
561
+ if training_method is not None:
562
+ payload["training_method"] = training_method
563
+
551
564
  if isinstance(train_datasource, Datasource):
552
565
  payload["train_datasource_name_or_id"] = train_datasource.id
553
- elif isinstance(train_datasource, LabeledMemoryset):
566
+ elif isinstance(train_datasource, (LabeledMemoryset, ScoredMemoryset)):
554
567
  payload["train_memoryset_name_or_id"] = train_datasource.id
555
568
  if eval_datasource is not None:
556
569
  payload["eval_datasource_name_or_id"] = eval_datasource.id
557
570
 
558
- res = orca_api.POST(
571
+ client = OrcaClient._resolve_client()
572
+ res = client.POST(
559
573
  "/finetuned_embedding_model",
560
574
  json=payload,
561
575
  )
562
576
  job = Job(
563
- res["finetuning_task_id"],
577
+ res["finetuning_job_id"],
564
578
  lambda: FinetunedEmbeddingModel.open(res["id"]),
565
579
  )
566
580
  return job if background else job.result()
@@ -630,7 +644,8 @@ class FinetunedEmbeddingModel(EmbeddingModelBase):
630
644
  Returns:
631
645
  A list of all finetuned embedding model handles in the OrcaCloud
632
646
  """
633
- return [cls(metadata) for metadata in orca_api.GET("/finetuned_embedding_model")]
647
+ client = OrcaClient._resolve_client()
648
+ return [cls(metadata) for metadata in client.GET("/finetuned_embedding_model")]
634
649
 
635
650
  @classmethod
636
651
  def open(cls, name: str) -> FinetunedEmbeddingModel:
@@ -646,7 +661,8 @@ class FinetunedEmbeddingModel(EmbeddingModelBase):
646
661
  Raises:
647
662
  LookupError: If the finetuned embedding model does not exist
648
663
  """
649
- metadata = orca_api.GET(
664
+ client = OrcaClient._resolve_client()
665
+ metadata = client.GET(
650
666
  "/finetuned_embedding_model/{name_or_id}",
651
667
  params={"name_or_id": name},
652
668
  )
@@ -681,7 +697,8 @@ class FinetunedEmbeddingModel(EmbeddingModelBase):
681
697
  LookupError: If the finetuned embedding model does not exist and `if_not_exists` is `"error"`
682
698
  """
683
699
  try:
684
- orca_api.DELETE(
700
+ client = OrcaClient._resolve_client()
701
+ client.DELETE(
685
702
  "/finetuned_embedding_model/{name_or_id}",
686
703
  params={"name_or_id": name_or_id},
687
704
  )
@@ -25,9 +25,10 @@ def test_open_pretrained_model():
25
25
  assert model is PretrainedEmbeddingModel.GTE_BASE
26
26
 
27
27
 
28
- def test_open_pretrained_model_unauthenticated(unauthenticated):
29
- with pytest.raises(ValueError, match="Invalid API key"):
30
- PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline")
28
+ def test_open_pretrained_model_unauthenticated(unauthenticated_client):
29
+ with unauthenticated_client.use():
30
+ with pytest.raises(ValueError, match="Invalid API key"):
31
+ PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline")
31
32
 
32
33
 
33
34
  def test_open_pretrained_model_not_found():
@@ -52,9 +53,10 @@ def test_embed_text():
52
53
  assert isinstance(embedding[0], float)
53
54
 
54
55
 
55
- def test_embed_text_unauthenticated(unauthenticated):
56
- with pytest.raises(ValueError, match="Invalid API key"):
57
- PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline", max_seq_length=32)
56
+ def test_embed_text_unauthenticated(unauthenticated_client):
57
+ with unauthenticated_client.use():
58
+ with pytest.raises(ValueError, match="Invalid API key"):
59
+ PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline", max_seq_length=32)
58
60
 
59
61
 
60
62
  def test_evaluate_pretrained_model(datasource: Datasource):
@@ -108,9 +110,10 @@ def test_finetune_model_already_exists_return(datasource: Datasource, finetuned_
108
110
  assert new_model._status == Status.COMPLETED
109
111
 
110
112
 
111
- def test_finetune_model_unauthenticated(unauthenticated, datasource: Datasource):
112
- with pytest.raises(ValueError, match="Invalid API key"):
113
- PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model_unauthenticated", datasource)
113
+ def test_finetune_model_unauthenticated(unauthenticated_client, datasource: Datasource):
114
+ with unauthenticated_client.use():
115
+ with pytest.raises(ValueError, match="Invalid API key"):
116
+ PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model_unauthenticated", datasource)
114
117
 
115
118
 
116
119
  def test_use_finetuned_model_in_memoryset(datasource: Datasource, finetuned_model: FinetunedEmbeddingModel):
@@ -150,13 +153,15 @@ def test_all_finetuned_models(finetuned_model: FinetunedEmbeddingModel):
150
153
  assert any(model.name == finetuned_model.name for model in models)
151
154
 
152
155
 
153
- def test_all_finetuned_models_unauthenticated(unauthenticated):
154
- with pytest.raises(ValueError, match="Invalid API key"):
155
- FinetunedEmbeddingModel.all()
156
+ def test_all_finetuned_models_unauthenticated(unauthenticated_client):
157
+ with unauthenticated_client.use():
158
+ with pytest.raises(ValueError, match="Invalid API key"):
159
+ FinetunedEmbeddingModel.all()
156
160
 
157
161
 
158
- def test_all_finetuned_models_unauthorized(unauthorized, finetuned_model: FinetunedEmbeddingModel):
159
- assert finetuned_model not in FinetunedEmbeddingModel.all()
162
+ def test_all_finetuned_models_unauthorized(unauthorized_client, finetuned_model: FinetunedEmbeddingModel):
163
+ with unauthorized_client.use():
164
+ assert finetuned_model not in FinetunedEmbeddingModel.all()
160
165
 
161
166
 
162
167
  def test_drop_finetuned_model(datasource: Datasource):
@@ -167,9 +172,10 @@ def test_drop_finetuned_model(datasource: Datasource):
167
172
  FinetunedEmbeddingModel.open("finetuned_model_to_delete")
168
173
 
169
174
 
170
- def test_drop_finetuned_model_unauthenticated(unauthenticated, datasource: Datasource):
171
- with pytest.raises(ValueError, match="Invalid API key"):
172
- PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource)
175
+ def test_drop_finetuned_model_unauthenticated(unauthenticated_client, datasource: Datasource):
176
+ with unauthenticated_client.use():
177
+ with pytest.raises(ValueError, match="Invalid API key"):
178
+ PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource)
173
179
 
174
180
 
175
181
  def test_drop_finetuned_model_not_found():
@@ -179,9 +185,10 @@ def test_drop_finetuned_model_not_found():
179
185
  FinetunedEmbeddingModel.drop(str(uuid4()), if_not_exists="ignore")
180
186
 
181
187
 
182
- def test_drop_finetuned_model_unauthorized(unauthorized, finetuned_model: FinetunedEmbeddingModel):
183
- with pytest.raises(LookupError):
184
- FinetunedEmbeddingModel.drop(finetuned_model.id)
188
+ def test_drop_finetuned_model_unauthorized(unauthorized_client, finetuned_model: FinetunedEmbeddingModel):
189
+ with unauthorized_client.use():
190
+ with pytest.raises(LookupError):
191
+ FinetunedEmbeddingModel.drop(finetuned_model.id)
185
192
 
186
193
 
187
194
  def test_supports_instructions():