orca-sdk 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
orca_sdk/datasource.py CHANGED
@@ -1,14 +1,13 @@
 from __future__ import annotations

 import logging
-import os
 import tempfile
 import zipfile
 from datetime import datetime
 from io import BytesIO
 from os import PathLike
 from pathlib import Path
-from typing import Literal, Union, cast
+from typing import Any, Literal, Union, cast

 import pandas as pd
 import pyarrow as pa
@@ -488,6 +487,50 @@ class Datasource:
     def __len__(self) -> int:
         return self.length

+    def query(
+        self,
+        offset: int = 0,
+        limit: int = 100,
+        shuffle: bool = False,
+        shuffle_seed: int | None = None,
+        filters: list[tuple[str, Literal["==", "!=", ">", ">=", "<", "<=", "in", "not in", "like"], Any]] = [],
+    ) -> list[dict[str, Any]]:
+        """
+        Query the datasource for rows with pagination and filtering support.
+
+        Params:
+            offset: Number of rows to skip
+            limit: Maximum number of rows to return
+            shuffle: Whether to shuffle the dataset before pagination
+            shuffle_seed: Seed for shuffling (for reproducible results)
+            filters: List of filter tuples. Each tuple contains:
+                - field (str): Column name to filter on
+                - op (str): Operator ("==", "!=", ">", ">=", "<", "<=", "in", "not in", "like")
+                - value: Value to compare against
+
+        Returns:
+            List of rows from the datasource
+
+        Examples:
+            >>> datasource.query(filters=[("age", ">", 25)])
+            >>> datasource.query(filters=[("city", "in", ["NYC", "LA"])])
+            >>> datasource.query(filters=[("name", "like", "John")])
+        """
+
+        client = OrcaClient._resolve_client()
+        response = client.POST(
+            "/datasource/{name_or_id}/rows",
+            params={"name_or_id": self.id},
+            json={
+                "limit": limit,
+                "offset": offset,
+                "shuffle": shuffle,
+                "shuffle_seed": shuffle_seed,
+                "filters": [{"field": field, "op": op, "value": value} for field, op, value in filters],
+            },
+        )
+        return response
+
     def download(
         self, output_dir: str | PathLike, file_type: Literal["hf_dataset", "json", "csv"] = "hf_dataset"
     ) -> None:
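The new `Datasource.query` method adds server-side pagination and filtering. Below is a minimal usage sketch; it assumes credentials are already configured and that `Datasource` is importable from the top-level `orca_sdk` package, neither of which is shown in this diff.

```python
# Minimal sketch of the 0.1.4 query API; names below are illustrative.
from orca_sdk import Datasource  # top-level import path is an assumption

people = Datasource.from_list(
    name="people_example",
    data=[
        {"name": "Alice", "age": 25, "city": "New York"},
        {"name": "Bob", "age": 30, "city": "San Francisco"},
        {"name": "Eve", "age": 32, "city": "New York"},
    ],
)

# Each filter is a (field, op, value) tuple; multiple filters are ANDed together and
# serialized as {"field": ..., "op": ..., "value": ...} for POST /datasource/{name_or_id}/rows.
rows = people.query(
    limit=10,
    filters=[("city", "==", "New York"), ("age", ">", 26)],
)
for row in rows:
    print(row["name"], row["age"])
```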
@@ -301,6 +301,126 @@ def test_from_disk_already_exists():
     os.unlink(f.name)


+def test_query_datasource_rows():
+    """Test querying rows from a datasource with pagination and shuffle."""
+    # Create a new dataset with 5 entries for testing
+    test_data = [{"id": i, "name": f"item_{i}"} for i in range(5)]
+    datasource = Datasource.from_list(name="test_query_datasource", data=test_data)
+
+    # Test basic query
+    rows = datasource.query(limit=3)
+    assert len(rows) == 3
+    assert all(isinstance(row, dict) for row in rows)
+
+    # Test offset
+    offset_rows = datasource.query(offset=2, limit=2)
+    assert len(offset_rows) == 2
+    assert offset_rows[0]["id"] == 2
+
+    # Test shuffle
+    shuffled_rows = datasource.query(limit=5, shuffle=True)
+    assert len(shuffled_rows) == 5
+    assert not all(row["id"] == i for i, row in enumerate(shuffled_rows))
+
+    # Test shuffle with seed
+    assert datasource.query(limit=5, shuffle=True, shuffle_seed=42) == datasource.query(
+        limit=5, shuffle=True, shuffle_seed=42
+    )
+
+
+def test_query_datasource_with_filters():
+    """Test querying datasource rows with various filter operators."""
+    # Create a datasource with test data
+    test_data = [
+        {"name": "Alice", "age": 25, "city": "New York", "score": 85.5},
+        {"name": "Bob", "age": 30, "city": "San Francisco", "score": 90.0},
+        {"name": "Charlie", "age": 35, "city": "Chicago", "score": 75.5},
+        {"name": "Diana", "age": 28, "city": "Boston", "score": 88.0},
+        {"name": "Eve", "age": 32, "city": "New York", "score": 92.0},
+    ]
+    datasource = Datasource.from_list(name=f"test_filter_datasource_{uuid4()}", data=test_data)
+
+    # Test == operator
+    rows = datasource.query(filters=[("city", "==", "New York")])
+    assert len(rows) == 2
+    assert all(row["city"] == "New York" for row in rows)
+
+    # Test > operator
+    rows = datasource.query(filters=[("age", ">", 30)])
+    assert len(rows) == 2
+    assert all(row["age"] > 30 for row in rows)
+
+    # Test >= operator
+    rows = datasource.query(filters=[("score", ">=", 88.0)])
+    assert len(rows) == 3
+    assert all(row["score"] >= 88.0 for row in rows)
+
+    # Test < operator
+    rows = datasource.query(filters=[("age", "<", 30)])
+    assert len(rows) == 2
+    assert all(row["age"] < 30 for row in rows)
+
+    # Test in operator
+    rows = datasource.query(filters=[("city", "in", ["New York", "Boston"])])
+    assert len(rows) == 3
+    assert all(row["city"] in ["New York", "Boston"] for row in rows)
+
+    # Test not in operator
+    rows = datasource.query(filters=[("city", "not in", ["New York", "Boston"])])
+    assert len(rows) == 2
+    assert all(row["city"] not in ["New York", "Boston"] for row in rows)
+
+    # Test like operator
+    rows = datasource.query(filters=[("name", "like", "li")])
+    assert len(rows) == 2
+    assert all("li" in row["name"].lower() for row in rows)
+
+    # Test multiple filters (AND logic)
+    rows = datasource.query(filters=[("city", "==", "New York"), ("age", ">", 26)])
+    assert len(rows) == 1
+    assert rows[0]["name"] == "Eve"
+
+    # Test filter with pagination
+    rows = datasource.query(filters=[("age", ">=", 28)], limit=2, offset=1)
+    assert len(rows) == 2
+
+
+def test_query_datasource_with_none_filters():
+    """Test filtering for None values."""
+    test_data = [
+        {"name": "Alice", "age": 25, "label": "A"},
+        {"name": "Bob", "age": 30, "label": None},
+        {"name": "Charlie", "age": 35, "label": "C"},
+        {"name": "Diana", "age": None, "label": "D"},
+        {"name": "Eve", "age": 32, "label": None},
+    ]
+    datasource = Datasource.from_list(name=f"test_none_filter_{uuid4()}", data=test_data)
+
+    # Test == None
+    rows = datasource.query(filters=[("label", "==", None)])
+    assert len(rows) == 2
+    assert all(row["label"] is None for row in rows)
+
+    # Test != None
+    rows = datasource.query(filters=[("label", "!=", None)])
+    assert len(rows) == 3
+    assert all(row["label"] is not None for row in rows)
+
+    # Test that None values are excluded from comparison operators
+    rows = datasource.query(filters=[("age", ">", 25)])
+    assert len(rows) == 3
+    assert all(row["age"] is not None and row["age"] > 25 for row in rows)
+
+
+def test_query_datasource_filter_invalid_column():
+    """Test that querying with an invalid column raises an error."""
+    test_data = [{"name": "Alice", "age": 25}]
+    datasource = Datasource.from_list(name=f"test_invalid_filter_{uuid4()}", data=test_data)
+
+    with pytest.raises(ValueError):
+        datasource.query(filters=[("invalid_column", "==", "test")])
+
+
 def test_to_list(hf_dataset, datasource):
     assert datasource.to_list() == hf_dataset.to_list()

@@ -20,7 +20,7 @@ from .datasource import Datasource
 from .job import Job, Status

 if TYPE_CHECKING:
-    from .memoryset import LabeledMemoryset
+    from .memoryset import LabeledMemoryset, ScoredMemoryset


 class EmbeddingModelBase(ABC):
@@ -110,7 +110,7 @@ class EmbeddingModelBase(ABC):
         label_column: str,
         score_column: None = None,
         eval_datasource: Datasource | None = None,
-        subsample: int | None = None,
+        subsample: int | float | None = None,
         neighbor_count: int = 5,
         batch_size: int = 32,
         weigh_memories: bool = True,
@@ -127,7 +127,7 @@ class EmbeddingModelBase(ABC):
         label_column: str,
         score_column: None = None,
         eval_datasource: Datasource | None = None,
-        subsample: int | None = None,
+        subsample: int | float | None = None,
         neighbor_count: int = 5,
         batch_size: int = 32,
         weigh_memories: bool = True,
@@ -144,7 +144,7 @@ class EmbeddingModelBase(ABC):
         label_column: None = None,
         score_column: str,
         eval_datasource: Datasource | None = None,
-        subsample: int | None = None,
+        subsample: int | float | None = None,
         neighbor_count: int = 5,
         batch_size: int = 32,
         weigh_memories: bool = True,
@@ -161,7 +161,7 @@ class EmbeddingModelBase(ABC):
         label_column: None = None,
         score_column: str,
         eval_datasource: Datasource | None = None,
-        subsample: int | None = None,
+        subsample: int | float | None = None,
         neighbor_count: int = 5,
         batch_size: int = 32,
         weigh_memories: bool = True,
@@ -177,7 +177,7 @@ class EmbeddingModelBase(ABC):
         label_column: str | None = None,
         score_column: str | None = None,
         eval_datasource: Datasource | None = None,
-        subsample: int | None = None,
+        subsample: int | float | None = None,
         neighbor_count: int = 5,
         batch_size: int = 32,
         weigh_memories: bool = True,
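These overloads widen `subsample` from `int | None` to `int | float | None`. A hedged sketch of the difference is below; the method name `evaluate` and the float-as-fraction reading are assumptions inferred from the signatures, not something this diff states.

```python
def evaluate_on_subsample(embedding_model, datasource):
    """Sketch of the widened subsample parameter; the `evaluate` method name is assumed."""
    # int: an absolute number of rows, as already accepted in 0.1.3
    metrics_abs = embedding_model.evaluate(datasource, label_column="label", subsample=5000)
    # float now type-checks as well; presumably a fraction of the datasource
    metrics_frac = embedding_model.evaluate(datasource, label_column="label", subsample=0.1)
    return metrics_abs, metrics_frac
```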
@@ -192,6 +192,7 @@ class EmbeddingModelBase(ABC):
         """
         Evaluate the finetuned embedding model
         """
+
         payload: EmbeddingEvaluationRequest = {
             "datasource_name_or_id": datasource.id,
             "datasource_label_column": label_column,
@@ -219,17 +220,17 @@ class EmbeddingModelBase(ABC):
         else:
             raise ValueError("Invalid embedding model")

-        def get_result(task_id: str) -> ClassificationMetrics | RegressionMetrics:
+        def get_result(job_id: str) -> ClassificationMetrics | RegressionMetrics:
             client = OrcaClient._resolve_client()
             if isinstance(self, PretrainedEmbeddingModel):
                 res = client.GET(
-                    "/pretrained_embedding_model/{model_name}/evaluation/{task_id}",
-                    params={"model_name": self.name, "task_id": task_id},
+                    "/pretrained_embedding_model/{model_name}/evaluation/{job_id}",
+                    params={"model_name": self.name, "job_id": job_id},
                 )["result"]
             elif isinstance(self, FinetunedEmbeddingModel):
                 res = client.GET(
-                    "/finetuned_embedding_model/{name_or_id}/evaluation/{task_id}",
-                    params={"name_or_id": self.id, "task_id": task_id},
+                    "/finetuned_embedding_model/{name_or_id}/evaluation/{job_id}",
+                    params={"name_or_id": self.id, "job_id": job_id},
                 )["result"]
             else:
                 raise ValueError("Invalid embedding model")
@@ -263,7 +264,7 @@ class EmbeddingModelBase(ABC):
             )
         )

-        job = Job(response["task_id"], lambda: get_result(response["task_id"]))
+        job = Job(response["job_id"], lambda: get_result(response["job_id"]))
         return job if background else job.result()


@@ -462,12 +463,13 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
     def finetune(
         self,
         name: str,
-        train_datasource: Datasource | LabeledMemoryset,
+        train_datasource: Datasource | LabeledMemoryset | ScoredMemoryset,
         *,
         eval_datasource: Datasource | None = None,
         label_column: str = "label",
+        score_column: str = "score",
         value_column: str = "value",
-        training_method: EmbeddingFinetuningMethod = "classification",
+        training_method: EmbeddingFinetuningMethod | None = None,
         training_args: dict | None = None,
         if_exists: CreateMode = "error",
         background: Literal[True],
@@ -478,12 +480,13 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
     def finetune(
         self,
         name: str,
-        train_datasource: Datasource | LabeledMemoryset,
+        train_datasource: Datasource | LabeledMemoryset | ScoredMemoryset,
         *,
         eval_datasource: Datasource | None = None,
         label_column: str = "label",
+        score_column: str = "score",
         value_column: str = "value",
-        training_method: EmbeddingFinetuningMethod = "classification",
+        training_method: EmbeddingFinetuningMethod | None = None,
         training_args: dict | None = None,
         if_exists: CreateMode = "error",
         background: Literal[False] = False,
@@ -493,12 +496,13 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
     def finetune(
         self,
         name: str,
-        train_datasource: Datasource | LabeledMemoryset,
+        train_datasource: Datasource | LabeledMemoryset | ScoredMemoryset,
         *,
         eval_datasource: Datasource | None = None,
         label_column: str = "label",
+        score_column: str = "score",
         value_column: str = "value",
-        training_method: EmbeddingFinetuningMethod = "classification",
+        training_method: EmbeddingFinetuningMethod | None = None,
         training_args: dict | None = None,
         if_exists: CreateMode = "error",
         background: bool = False,
@@ -510,9 +514,10 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
             name: Name of the finetuned embedding model
             train_datasource: Data to train on
             eval_datasource: Optionally provide data to evaluate on
-            label_column: Column name of the label
+            label_column: Column name of the label.
+            score_column: Column name of the score (for regression when training on scored data).
             value_column: Column name of the value
-            training_method: Training method to use
+            training_method: Optional training method override. If omitted, Lighthouse defaults apply.
             training_args: Optional override for Hugging Face [`TrainingArguments`][transformers.TrainingArguments].
                 If not provided, reasonable training arguments will be used for the specified training method
             if_exists: What to do if a finetuned embedding model with the same name already exists, defaults to
@@ -543,19 +548,22 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):

             return existing

-        from .memoryset import LabeledMemoryset
+        from .memoryset import LabeledMemoryset, ScoredMemoryset

         payload: FinetuneEmbeddingModelRequest = {
             "name": name,
             "base_model": self.name,
             "label_column": label_column,
+            "score_column": score_column,
             "value_column": value_column,
-            "training_method": training_method,
             "training_args": training_args or {},
         }
+        if training_method is not None:
+            payload["training_method"] = training_method
+
         if isinstance(train_datasource, Datasource):
             payload["train_datasource_name_or_id"] = train_datasource.id
-        elif isinstance(train_datasource, LabeledMemoryset):
+        elif isinstance(train_datasource, (LabeledMemoryset, ScoredMemoryset)):
             payload["train_memoryset_name_or_id"] = train_datasource.id
         if eval_datasource is not None:
             payload["eval_datasource_name_or_id"] = eval_datasource.id
@@ -566,7 +574,7 @@ class PretrainedEmbeddingModel(EmbeddingModelBase):
             json=payload,
         )
         job = Job(
-            res["finetuning_task_id"],
+            res["finetuning_job_id"],
             lambda: FinetunedEmbeddingModel.open(res["id"]),
         )
         return job if background else job.result()
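With these changes, finetuning now accepts a ScoredMemoryset as training data, takes a `score_column`, and only sends `training_method` when it is explicitly provided. A hedged sketch follows; how the PretrainedEmbeddingModel instance is obtained is not shown in this diff and is left out, and the ScoredMemoryset import path is inferred from the relative imports above.

```python
from orca_sdk.memoryset import ScoredMemoryset  # import path inferred from `.memoryset` above


def finetune_on_scores(base_model, scored_memories: ScoredMemoryset):
    """Sketch: finetune a pretrained embedding model on scored data (new in 0.1.4)."""
    job = base_model.finetune(
        "my-scored-finetune",   # illustrative name for the finetuned model
        scored_memories,        # ScoredMemoryset is now accepted alongside Datasource / LabeledMemoryset
        score_column="score",   # new parameter; column used when training on scored data
        # training_method left unset: it is only included in the request when explicitly provided
        background=True,        # return a Job instead of blocking
    )
    return job.result()         # resolves via FinetunedEmbeddingModel.open(...) as shown above
```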
orca_sdk/job.py CHANGED
@@ -17,7 +17,7 @@ class JobConfig(TypedDict):


 class Status(Enum):
-    """Status of a cloud job in the task queue"""
+    """Status of a cloud job in the job queue"""

     # the INITIALIZED state should never be returned by the API
     INITIALIZED = "INITIALIZED"
@@ -141,8 +141,8 @@ class Job(Generic[TResult]):
             List of jobs matching the given filters
         """
         client = OrcaClient._resolve_client()
-        paginated_tasks = client.GET(
-            "/task",
+        paginated_jobs = client.GET(
+            "/job",
             params={
                 "status": (
                     [s.value for s in status]
@@ -175,7 +175,7 @@ class Job(Generic[TResult]):
                     obj,
                 )[-1]
             )(t)
-            for t in paginated_tasks["items"]
+            for t in paginated_jobs["items"]
         ]

     def __init__(self, id: str, get_value: Callable[[], TResult | None] | None = None):
@@ -188,29 +188,29 @@ class Job(Generic[TResult]):
         """
         self.id = id
         client = OrcaClient._resolve_client()
-        task = client.GET("/task/{task_id}", params={"task_id": id})
+        job = client.GET("/job/{job_id}", params={"job_id": id})

         def default_get_value():
             client = OrcaClient._resolve_client()
-            return cast(TResult | None, client.GET("/task/{task_id}", params={"task_id": id})["result"])
+            return cast(TResult | None, client.GET("/job/{job_id}", params={"job_id": id})["result"])

         self._get_value = get_value or default_get_value
-        self.type = task["type"]
-        self.status = Status(task["status"])
-        self.steps_total = task["steps_total"]
-        self.steps_completed = task["steps_completed"]
-        self.exception = task["exception"]
+        self.type = job["type"]
+        self.status = Status(job["status"])
+        self.steps_total = job["steps_total"]
+        self.steps_completed = job["steps_completed"]
+        self.exception = job["exception"]
         self.value = (
             None
-            if task["status"] != "COMPLETED"
+            if job["status"] != "COMPLETED"
             else (
                 get_value()
                 if get_value is not None
-                else cast(TResult, task["result"]) if task["result"] is not None else None
+                else cast(TResult, job["result"]) if job["result"] is not None else None
             )
         )
-        self.updated_at = datetime.fromisoformat(task["updated_at"])
-        self.created_at = datetime.fromisoformat(task["created_at"])
+        self.updated_at = datetime.fromisoformat(job["updated_at"])
+        self.created_at = datetime.fromisoformat(job["created_at"])
         self.refreshed_at = datetime.now()

     def refresh(self, throttle: float = 0):
@@ -227,7 +227,7 @@ class Job(Generic[TResult]):
         self.refreshed_at = current_time

         client = OrcaClient._resolve_client()
-        status_info = client.GET("/task/{task_id}/status", params={"task_id": self.id})
+        status_info = client.GET("/job/{job_id}/status", params={"job_id": self.id})
         self.status = Status(status_info["status"])
         if status_info["steps_total"] is not None:
             self.steps_total = status_info["steps_total"]
@@ -339,5 +339,5 @@ def abort(self, show_progress: bool = False, refresh_interval: int = 1, max_wait
             max_wait: Maximum time to wait for the job to abort in seconds
         """
         client = OrcaClient._resolve_client()
-        client.DELETE("/task/{task_id}/abort", params={"task_id": self.id})
+        client.DELETE("/job/{job_id}/abort", params={"job_id": self.id})
         self.wait(show_progress, refresh_interval, max_wait)
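The job.py changes are a mechanical rename of the client-facing routes and keys from task to job. The sketch below reattaches to a job by id under the new routes; the helper function and its comments are illustrative, based only on the constructor, refresh, and result calls visible above.

```python
from orca_sdk.job import Job  # orca_sdk.job inferred from the file path above


def check_job(job_id: str):
    """Sketch: reattach to a server-side job by id under the renamed /job routes."""
    job = Job(job_id)          # constructor fetches GET /job/{job_id}
    job.refresh(throttle=1)    # polls GET /job/{job_id}/status; throttle limits how often the API is hit
    print(job.status, job.steps_completed, job.steps_total)
    return job.result()        # the job's result, fetched via GET /job/{job_id} as in default_get_value
```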