orca-sdk 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. orca_sdk/__init__.py +30 -0
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +634 -0
  4. orca_sdk/_shared/metrics_test.py +570 -0
  5. orca_sdk/_utils/__init__.py +0 -0
  6. orca_sdk/_utils/analysis_ui.py +196 -0
  7. orca_sdk/_utils/analysis_ui_style.css +51 -0
  8. orca_sdk/_utils/auth.py +65 -0
  9. orca_sdk/_utils/auth_test.py +31 -0
  10. orca_sdk/_utils/common.py +37 -0
  11. orca_sdk/_utils/data_parsing.py +129 -0
  12. orca_sdk/_utils/data_parsing_test.py +244 -0
  13. orca_sdk/_utils/pagination.py +126 -0
  14. orca_sdk/_utils/pagination_test.py +132 -0
  15. orca_sdk/_utils/prediction_result_ui.css +18 -0
  16. orca_sdk/_utils/prediction_result_ui.py +110 -0
  17. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  18. orca_sdk/_utils/value_parser.py +45 -0
  19. orca_sdk/_utils/value_parser_test.py +39 -0
  20. orca_sdk/async_client.py +4104 -0
  21. orca_sdk/classification_model.py +1165 -0
  22. orca_sdk/classification_model_test.py +887 -0
  23. orca_sdk/client.py +4096 -0
  24. orca_sdk/conftest.py +382 -0
  25. orca_sdk/credentials.py +217 -0
  26. orca_sdk/credentials_test.py +121 -0
  27. orca_sdk/datasource.py +576 -0
  28. orca_sdk/datasource_test.py +463 -0
  29. orca_sdk/embedding_model.py +712 -0
  30. orca_sdk/embedding_model_test.py +206 -0
  31. orca_sdk/job.py +343 -0
  32. orca_sdk/job_test.py +108 -0
  33. orca_sdk/memoryset.py +3811 -0
  34. orca_sdk/memoryset_test.py +1150 -0
  35. orca_sdk/regression_model.py +841 -0
  36. orca_sdk/regression_model_test.py +595 -0
  37. orca_sdk/telemetry.py +742 -0
  38. orca_sdk/telemetry_test.py +119 -0
  39. orca_sdk-0.1.9.dist-info/METADATA +98 -0
  40. orca_sdk-0.1.9.dist-info/RECORD +41 -0
  41. orca_sdk-0.1.9.dist-info/WHEEL +4 -0
@@ -0,0 +1,206 @@
1
+ import logging
2
+ from typing import get_args
3
+ from uuid import uuid4
4
+
5
+ import pytest
6
+
7
+ from .datasource import Datasource
8
+ from .embedding_model import (
9
+ ClassificationMetrics,
10
+ FinetunedEmbeddingModel,
11
+ PretrainedEmbeddingModel,
12
+ PretrainedEmbeddingModelName,
13
+ )
14
+ from .job import Status
15
+ from .memoryset import LabeledMemoryset
16
+
17
+
18
def test_open_pretrained_model():
    """The GTE_BASE handle resolves to a pretrained model with the expected metadata."""
    gte = PretrainedEmbeddingModel.GTE_BASE
    assert gte is not None
    assert isinstance(gte, PretrainedEmbeddingModel)
    assert gte.name == "GTE_BASE"
    assert gte.embedding_dim == 768
    assert gte.max_seq_length == 8192
    # repeated attribute access must hand back the same singleton instance
    assert gte is PretrainedEmbeddingModel.GTE_BASE


def test_open_pretrained_model_unauthenticated(unauthenticated_client):
    """Embedding through a pretrained model is rejected without a valid API key."""
    with unauthenticated_client.use():
        with pytest.raises(ValueError, match="Invalid API key"):
            PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline")


def test_open_pretrained_model_not_found():
    """Looking up an unknown pretrained model name raises LookupError."""
    with pytest.raises(LookupError):
        PretrainedEmbeddingModel._get("INVALID_MODEL")  # type: ignore


def test_all_pretrained_models():
    """Every name declared in PretrainedEmbeddingModelName is served by the API."""
    available = PretrainedEmbeddingModel.all()
    assert len(available) > 1
    expected_names = get_args(PretrainedEmbeddingModelName)
    if len(available) != len(expected_names):
        logging.warning("Please regenerate the SDK client! Some pretrained model names are not exposed yet.")
    served_names = [model.name for model in available]
    assert all(name in served_names for name in expected_names)
46
+
47
+
48
def test_embed_text():
    """Embedding a sentence yields a 768-dimensional list of floats."""
    vector = PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline", max_seq_length=32)
    assert vector is not None
    assert isinstance(vector, list)
    assert len(vector) == 768
    assert isinstance(vector[0], float)


def test_embed_text_unauthenticated(unauthenticated_client):
    """embed() fails with a clear error when the API key is invalid."""
    with unauthenticated_client.use():
        with pytest.raises(ValueError, match="Invalid API key"):
            PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline", max_seq_length=32)


def test_evaluate_pretrained_model(datasource: Datasource):
    """Evaluating a pretrained model returns classification metrics above chance level."""
    result = PretrainedEmbeddingModel.GTE_BASE.evaluate(datasource=datasource, label_column="label")
    assert result is not None
    assert isinstance(result, ClassificationMetrics)
    assert result.accuracy > 0.5
67
+
68
+
69
@pytest.fixture(scope="session")
def finetuned_model(datasource) -> FinetunedEmbeddingModel:
    """Session-scoped finetuned DISTILBERT model shared across the tests below."""
    return PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource)


def test_finetune_model_with_datasource(finetuned_model: FinetunedEmbeddingModel):
    """Finetuning from a datasource produces a completed model with DISTILBERT geometry."""
    assert finetuned_model is not None
    assert finetuned_model.name == "test_finetuned_model"
    assert finetuned_model.base_model == PretrainedEmbeddingModel.DISTILBERT
    assert finetuned_model.embedding_dim == 768
    assert finetuned_model.max_seq_length == 512
    assert finetuned_model._status == Status.COMPLETED


def test_finetune_model_with_memoryset(readonly_memoryset: LabeledMemoryset):
    """Finetuning also accepts a labeled memoryset as the training source."""
    model = PretrainedEmbeddingModel.DISTILBERT.finetune(
        "test_finetuned_model_from_memoryset", readonly_memoryset
    )
    assert model is not None
    assert model.name == "test_finetuned_model_from_memoryset"
    assert model.base_model == PretrainedEmbeddingModel.DISTILBERT
    assert model.embedding_dim == 768
    assert model.max_seq_length == 512
    assert model._status == Status.COMPLETED
93
+
94
+
95
def test_finetune_model_already_exists_error(datasource: Datasource, finetuned_model):
    """Reusing an existing finetuned model name raises by default."""
    with pytest.raises(ValueError):
        PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource)


def test_finetune_model_already_exists_return(datasource: Datasource, finetuned_model):
    """With if_exists set to "open" the existing model is returned, but only for a matching base model."""
    # opening under a different base model is still an error
    with pytest.raises(ValueError):
        PretrainedEmbeddingModel.GTE_BASE.finetune("test_finetuned_model", datasource, if_exists="open")

    reopened = PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource, if_exists="open")
    assert reopened is not None
    assert reopened.name == "test_finetuned_model"
    assert reopened.base_model == PretrainedEmbeddingModel.DISTILBERT
    assert reopened.embedding_dim == 768
    assert reopened.max_seq_length == 512
    assert reopened._status == Status.COMPLETED


def test_finetune_model_unauthenticated(unauthenticated_client, datasource: Datasource):
    """Finetuning is rejected without a valid API key."""
    with unauthenticated_client.use():
        with pytest.raises(ValueError, match="Invalid API key"):
            PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model_unauthenticated", datasource)


def test_use_finetuned_model_in_memoryset(datasource: Datasource, finetuned_model: FinetunedEmbeddingModel):
    """A finetuned model can back a freshly created memoryset."""
    created = LabeledMemoryset.create(
        "test_memoryset_finetuned_model",
        datasource=datasource,
        embedding_model=finetuned_model,
    )
    assert created is not None
    assert created.name == "test_memoryset_finetuned_model"
    assert created.embedding_model == finetuned_model
    assert created.length == datasource.length
129
+
130
+
131
def test_open_finetuned_model(finetuned_model: FinetunedEmbeddingModel):
    """open() resolves a finetuned model by name to a handle equal to the original."""
    reopened = FinetunedEmbeddingModel.open(finetuned_model.name)
    assert isinstance(reopened, FinetunedEmbeddingModel)
    assert reopened.id == finetuned_model.id
    assert reopened.name == finetuned_model.name
    assert reopened.base_model == PretrainedEmbeddingModel.DISTILBERT
    assert reopened.embedding_dim == 768
    assert reopened.max_seq_length == 512
    assert reopened == finetuned_model


def test_embed_finetuned_model(finetuned_model: FinetunedEmbeddingModel):
    """A finetuned model embeds text into a 768-dimensional float vector."""
    vector = finetuned_model.embed("I love this airline")
    assert vector is not None
    assert isinstance(vector, list)
    assert len(vector) == 768
    assert isinstance(vector[0], float)


def test_all_finetuned_models(finetuned_model: FinetunedEmbeddingModel):
    """all() includes models created in this session."""
    listed = FinetunedEmbeddingModel.all()
    assert len(listed) > 0
    assert any(entry.name == finetuned_model.name for entry in listed)


def test_all_finetuned_models_unauthenticated(unauthenticated_client):
    """Listing finetuned models requires a valid API key."""
    with unauthenticated_client.use():
        with pytest.raises(ValueError, match="Invalid API key"):
            FinetunedEmbeddingModel.all()


def test_all_finetuned_models_unauthorized(unauthorized_client, finetuned_model: FinetunedEmbeddingModel):
    """Models owned by another org are not visible to an unauthorized client."""
    with unauthorized_client.use():
        assert finetuned_model not in FinetunedEmbeddingModel.all()
165
+
166
+
167
def test_drop_finetuned_model(datasource: Datasource):
    """drop() removes a finetuned model so it can no longer be opened."""
    PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource)
    assert FinetunedEmbeddingModel.open("finetuned_model_to_delete")
    FinetunedEmbeddingModel.drop("finetuned_model_to_delete")
    with pytest.raises(LookupError):
        FinetunedEmbeddingModel.open("finetuned_model_to_delete")


def test_drop_finetuned_model_unauthenticated(unauthenticated_client, datasource: Datasource):
    """Finetuning (the setup step for a drop) is rejected without a valid API key."""
    with unauthenticated_client.use():
        with pytest.raises(ValueError, match="Invalid API key"):
            PretrainedEmbeddingModel.DISTILBERT.finetune("finetuned_model_to_delete", datasource)


def test_drop_finetuned_model_not_found():
    """Dropping an unknown id raises unless if_not_exists is set to "ignore"."""
    with pytest.raises(LookupError):
        FinetunedEmbeddingModel.drop(str(uuid4()))
    # ignores error if specified
    FinetunedEmbeddingModel.drop(str(uuid4()), if_not_exists="ignore")


def test_drop_finetuned_model_unauthorized(unauthorized_client, finetuned_model: FinetunedEmbeddingModel):
    """An unauthorized client cannot see, and so cannot drop, the model."""
    with unauthorized_client.use():
        with pytest.raises(LookupError):
            FinetunedEmbeddingModel.drop(finetuned_model.id)
192
+
193
+
194
def test_supports_instructions():
    """Only instruction-tuned models report supports_instructions."""
    plain = PretrainedEmbeddingModel.GTE_BASE
    assert not plain.supports_instructions

    instructable = PretrainedEmbeddingModel.BGE_BASE
    assert instructable.supports_instructions


def test_use_explicit_instruction_prompt():
    """Passing an explicit instruction prompt changes the resulting embedding."""
    bge = PretrainedEmbeddingModel.BGE_BASE
    assert bge.supports_instructions
    text = "Hello world"
    assert bge.embed(text, prompt="Represent this sentence for sentiment retrieval:") != bge.embed(text)
orca_sdk/job.py ADDED
@@ -0,0 +1,343 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from datetime import datetime, timedelta
5
+ from enum import Enum
6
+ from typing import Callable, Generic, TypedDict, TypeVar, cast
7
+
8
+ from tqdm.auto import tqdm
9
+
10
+ from .client import OrcaClient
11
+
12
+
13
class JobConfig(TypedDict):
    """Global polling configuration shared by all Job instances (see Job.config / Job.set_config)."""

    refresh_interval: int  # seconds between status polls
    show_progress: bool  # whether wait() renders a tqdm progress bar
    max_wait: int  # maximum seconds to block before timing out
17
+
18
+
19
class Status(Enum):
    """Status of a cloud job in the job queue"""

    # the INITIALIZED state should never be returned by the API
    INITIALIZED = "INITIALIZED"
    """The job has been initialized"""

    DISPATCHED = "DISPATCHED"
    """The job has been queued and is waiting to be processed"""

    WAITING = "WAITING"
    """The job is waiting for dependencies to complete"""

    PROCESSING = "PROCESSING"
    """The job is being processed"""

    # COMPLETED, FAILED, and ABORTED are the terminal states that Job.wait() returns on
    COMPLETED = "COMPLETED"
    """The job has been completed successfully"""

    FAILED = "FAILED"
    """The job has failed"""

    ABORTING = "ABORTING"
    """The job is being aborted"""

    ABORTED = "ABORTED"
    """The job has been aborted"""
46
+
47
+
48
+ TResult = TypeVar("TResult")
49
+
50
+
51
class Job(Generic[TResult]):
    """
    Handle to a job that is run in the OrcaCloud

    Attributes:
        id: Unique identifier for the job
        type: Type of the job
        status: Current status of the job
        steps_total: Total number of steps in the job, present if the job started processing
        steps_completed: Number of steps completed in the job, present if the job started processing
        completion: Percentage of the job that has been completed, present if the job started processing
        exception: Exception that occurred during the job, present if the status is `FAILED`
        value: Value of the result of the job, present if the status is `COMPLETED`
        created_at: When the job was queued for processing
        updated_at: When the job was last updated
        refreshed_at: When the job status was last refreshed

    Note:
        Accessing status and related attributes will refresh the job status in the background.
    """

    id: str
    type: str
    status: Status
    steps_total: int | None
    steps_completed: int | None
    exception: str | None
    value: TResult | None
    updated_at: datetime
    created_at: datetime
    refreshed_at: datetime

    # Global configuration for all jobs
    config: JobConfig = {
        "refresh_interval": 3,
        "show_progress": True,
        "max_wait": 60 * 60,
    }

    @property
    def completion(self) -> float:
        """
        Percentage of the job that has been completed, present if the job started processing
        """
        # read once into a local so the throttled refresh triggered by attribute
        # access (see __getattribute__) runs at most once here
        total = self.steps_total
        if not total:
            # covers both "not started yet" (None) and a reported total of 0 steps;
            # the latter previously raised ZeroDivisionError
            return 0
        return (self.steps_completed or 0) / total

    def __repr__(self) -> str:
        return "Job({" + f" type: {self.type}, status: {self.status}, completion: {self.completion:.0%} " + "})"

    @classmethod
    def set_config(
        cls, *, refresh_interval: int | None = None, show_progress: bool | None = None, max_wait: int | None = None
    ):
        """
        Set global configuration for running jobs

        Args:
            refresh_interval: Time to wait between polling the job status in seconds, default is 3
            show_progress: Whether to show a progress bar when calling the wait method, default is True
            max_wait: Maximum time to wait for the job to complete in seconds, default is 1 hour
        """
        if refresh_interval is not None:
            cls.config["refresh_interval"] = refresh_interval
        if show_progress is not None:
            cls.config["show_progress"] = show_progress
        if max_wait is not None:
            cls.config["max_wait"] = max_wait

    @classmethod
    def _from_job_payload(cls, payload) -> Job:
        """
        Build a job handle from an API payload without the extra API round-trip the
        constructor makes. Used by query().
        """
        job = cls.__new__(cls)
        job.id = payload["id"]
        job.type = payload["type"]
        job.status = Status(payload["status"])
        job.steps_total = payload["steps_total"]
        job.steps_completed = payload["steps_completed"]
        job.exception = payload["exception"]
        job.value = cast(TResult, payload["result"]) if payload["result"] is not None else None
        job.updated_at = datetime.fromisoformat(payload["updated_at"])
        job.created_at = datetime.fromisoformat(payload["created_at"])
        job.refreshed_at = datetime.now()

        # install the same default value resolver as __init__, so that a later
        # refresh() on a completed query-built job does not fail with an
        # AttributeError on the missing _get_value
        def default_get_value(job_id: str = payload["id"]):
            client = OrcaClient._resolve_client()
            return cast(TResult | None, client.GET("/job/{job_id}", params={"job_id": job_id})["result"])

        job._get_value = default_get_value
        return job

    @classmethod
    def query(
        cls,
        status: Status | list[Status] | None = None,
        type: str | list[str] | None = None,
        limit: int = 100,
        offset: int = 0,
        start: datetime | None = None,
        end: datetime | None = None,
    ) -> list[Job]:
        """
        Query the job queue for jobs matching the given filters

        Args:
            status: Optional status or list of statuses to filter by
            type: Optional type or list of types to filter by
            limit: Maximum number of jobs to return
            offset: Offset into the list of jobs to return
            start: Optional minimum creation time of the jobs to query for
            end: Optional maximum creation time of the jobs to query for

        Returns:
            List of jobs matching the given filters
        """
        # normalize the status filter to the wire format (value string(s) or None)
        if isinstance(status, list):
            status_param = [s.value for s in status]
        elif isinstance(status, Status):
            status_param = status.value
        else:
            status_param = None

        client = OrcaClient._resolve_client()
        paginated_jobs = client.GET(
            "/job",
            params={
                "status": status_param,
                "type": type,
                "limit": limit,
                "offset": offset,
                "start_timestamp": start.isoformat() if start is not None else None,
                "end_timestamp": end.isoformat() if end is not None else None,
            },
        )

        # can't use the constructor because it makes an API call per job
        return [cls._from_job_payload(item) for item in paginated_jobs["items"]]

    def __init__(self, id: str, get_value: Callable[[], TResult | None] | None = None):
        """
        Create a handle to a job in the job queue

        Args:
            id: Unique identifier for the job
            get_value: Optional function to customize how the value is resolved, if not provided the result will be a dict
        """
        self.id = id
        client = OrcaClient._resolve_client()
        job = client.GET("/job/{job_id}", params={"job_id": id})

        def default_get_value():
            client = OrcaClient._resolve_client()
            return cast(TResult | None, client.GET("/job/{job_id}", params={"job_id": id})["result"])

        self._get_value = get_value or default_get_value
        self.type = job["type"]
        self.status = Status(job["status"])
        self.steps_total = job["steps_total"]
        self.steps_completed = job["steps_completed"]
        self.exception = job["exception"]
        # only resolve the value for completed jobs; a custom resolver takes precedence
        if job["status"] != "COMPLETED":
            self.value = None
        elif get_value is not None:
            self.value = get_value()
        else:
            self.value = cast(TResult, job["result"]) if job["result"] is not None else None
        self.updated_at = datetime.fromisoformat(job["updated_at"])
        self.created_at = datetime.fromisoformat(job["created_at"])
        self.refreshed_at = datetime.now()

    def refresh(self, throttle: float = 0):
        """
        Refresh the status and progress of the job

        Params:
            throttle: Minimum time in seconds between refreshes
        """
        current_time = datetime.now()
        # Skip the API round-trip if the last refresh was too recent
        if (current_time - self.refreshed_at) < timedelta(seconds=throttle):
            return
        self.refreshed_at = current_time

        client = OrcaClient._resolve_client()
        status_info = client.GET("/job/{job_id}/status", params={"job_id": self.id})
        self.status = Status(status_info["status"])
        # only overwrite the progress counters when the API actually reports them
        if status_info["steps_total"] is not None:
            self.steps_total = status_info["steps_total"]
        if status_info["steps_completed"] is not None:
            self.steps_completed = status_info["steps_completed"]

        self.exception = status_info["exception"]
        self.updated_at = datetime.fromisoformat(status_info["updated_at"])

        if status_info["status"] == "COMPLETED":
            self.value = self._get_value()

    def __getattribute__(self, name: str):
        # mutable job attributes are transparently refreshed (throttled) on access;
        # everything else goes straight through
        if name in ("status", "updated_at", "steps_total", "steps_completed", "exception", "value"):
            self.refresh(self.config["refresh_interval"])
        return super().__getattribute__(name)

    def wait(
        self, show_progress: bool | None = None, refresh_interval: int | None = None, max_wait: int | None = None
    ) -> None:
        """
        Block until the job is complete

        Params:
            show_progress: Show a progress bar while waiting for the job to complete
            refresh_interval: Polling interval in seconds while waiting for the job to complete
            max_wait: Maximum time to wait for the job to complete in seconds

        Note:
            The defaults for the config parameters can be set globally using the
            [`set_config`][orca_sdk.Job.set_config] method.

            This method will not return the result or raise an exception if the job fails. Call
            [`result`][orca_sdk.Job.result] instead if you want to get the result.

        Raises:
            RuntimeError: If the job times out
        """
        start_time = time.time()
        show_progress = show_progress if show_progress is not None else self.config["show_progress"]
        refresh_interval = refresh_interval if refresh_interval is not None else self.config["refresh_interval"]
        max_wait = max_wait if max_wait is not None else self.config["max_wait"]
        pbar = None
        while True:
            # setup progress bar once steps_total is known
            if not pbar and self.steps_total is not None and show_progress:
                desc = " ".join(self.type.split("_")).lower()
                pbar = tqdm(total=self.steps_total, desc=desc)

            # return if job reached a terminal state
            if self.status in [Status.COMPLETED, Status.FAILED, Status.ABORTED]:
                if pbar:
                    pbar.update(self.steps_total - pbar.n)
                    pbar.close()
                return

            # raise error if job timed out
            if (time.time() - start_time) > max_wait:
                raise RuntimeError(f"Job {self.id} timed out after {max_wait}s")

            # update progress bar
            if pbar and self.steps_completed is not None:
                pbar.update(self.steps_completed - pbar.n)

            # sleep before retrying
            time.sleep(refresh_interval)

    def result(
        self, show_progress: bool | None = None, refresh_interval: int | None = None, max_wait: int | None = None
    ) -> TResult:
        """
        Block until the job is complete and return the result value

        Params:
            show_progress: Show a progress bar while waiting for the job to complete
            refresh_interval: Polling interval in seconds while waiting for the job to complete
            max_wait: Maximum time to wait for the job to complete in seconds

        Note:
            The defaults for the config parameters can be set globally using the
            [`set_config`][orca_sdk.Job.set_config] method.

            This method will raise an exception if the job fails. Use [`wait`][orca_sdk.Job.wait]
            if you just want to wait for the job to complete without raising errors on failure.

        Returns:
            The result value of the job

        Raises:
            RuntimeError: If the job fails or times out
        """
        if self.value is not None:
            return self.value
        self.wait(show_progress, refresh_interval, max_wait)
        if self.status != Status.COMPLETED:
            raise RuntimeError(f"Job failed with exception: {self.exception}")
        # NOTE(review): a completed job whose result is legitimately null would trip this
        # assertion — confirm the API never returns a null result for COMPLETED jobs
        assert self.value is not None
        return self.value

    def abort(self, show_progress: bool = False, refresh_interval: int = 1, max_wait: int = 20) -> None:
        """
        Abort the job

        Params:
            show_progress: Optionally show a progress bar while waiting for the job to abort
            refresh_interval: Polling interval in seconds while waiting for the job to abort
            max_wait: Maximum time to wait for the job to abort in seconds
        """
        client = OrcaClient._resolve_client()
        client.DELETE("/job/{job_id}/abort", params={"job_id": self.id})
        self.wait(show_progress, refresh_interval, max_wait)
orca_sdk/job_test.py ADDED
@@ -0,0 +1,108 @@
1
+ import time
2
+
3
+ import pytest
4
+ from datasets import Dataset
5
+
6
+ from .classification_model import ClassificationModel
7
+ from .datasource import Datasource
8
+ from .job import Job, Status
9
+
10
+
11
@pytest.fixture(scope="session")
def datasource_without_nones(hf_dataset: Dataset):
    """Datasource built from the HF dataset with unlabeled rows filtered out."""
    return Datasource.from_hf_dataset(
        "test_datasource_without_nones", hf_dataset.filter(lambda x: x["label"] is not None)
    )


def wait_for_jobs_status(job_ids, expected_statuses, timeout=10, poll_interval=0.2):
    """
    Wait until all jobs reach one of the expected statuses or timeout is reached.
    """
    deadline = time.time() + timeout
    while time.time() < deadline:
        jobs = [Job(job_id) for job_id in job_ids]
        if all(job.status in expected_statuses for job in jobs):
            return
        time.sleep(poll_interval)
    raise TimeoutError(f"Jobs did not reach statuses {expected_statuses} within {timeout} seconds")
29
+
30
+
31
def test_job_creation(classification_model: ClassificationModel, datasource_without_nones: Datasource):
    """A background evaluation immediately yields a queryable job handle."""
    job = classification_model.evaluate(datasource_without_nones, background=True)
    assert job.id is not None
    assert job.type == "EVALUATE_MODEL"
    assert job.status in [Status.DISPATCHED, Status.PROCESSING]
    assert job.created_at is not None
    assert job.updated_at is not None
    assert job.refreshed_at is not None
    assert len(Job.query(limit=5, type="EVALUATE_MODEL")) >= 1


def test_job_result(classification_model: ClassificationModel, datasource_without_nones: Datasource):
    """result() blocks until completion and returns the evaluation result."""
    job = classification_model.evaluate(datasource_without_nones, background=True)
    outcome = job.result(show_progress=False)
    assert outcome is not None
    assert job.status == Status.COMPLETED
    assert job.steps_completed is not None
    assert job.steps_completed == job.steps_total


def test_job_wait(classification_model: ClassificationModel, datasource_without_nones: Datasource):
    """wait() blocks until completion and leaves the value populated."""
    job = classification_model.evaluate(datasource_without_nones, background=True)
    job.wait(show_progress=False)
    assert job.status == Status.COMPLETED
    assert job.steps_completed is not None
    assert job.steps_completed == job.steps_total
    assert job.value is not None
58
+
59
+
60
def test_job_refresh(classification_model: ClassificationModel, datasource_without_nones: Datasource):
    """Attribute access refreshes after the interval; refresh() refreshes immediately."""
    job = classification_model.evaluate(datasource_without_nones, background=True)
    previous_refresh = job.refreshed_at
    # accessing the status attribute should refresh the job after the refresh interval
    Job.set_config(refresh_interval=1)
    time.sleep(1)
    job.status
    assert job.refreshed_at > previous_refresh
    previous_refresh = job.refreshed_at
    # calling refresh() should immediately refresh the job
    job.refresh()
    assert job.refreshed_at > previous_refresh


def test_job_query_pagination(classification_model: ClassificationModel, datasource_without_nones: Datasource):
    """Test pagination with Job.query() method"""
    # Create multiple jobs to test pagination
    created_ids = []
    for _ in range(3):
        background_job = classification_model.evaluate(datasource_without_nones, background=True)
        created_ids.append(background_job.id)

    # Wait for jobs to be at least PROCESSING or COMPLETED
    wait_for_jobs_status(created_ids, expected_statuses=[Status.PROCESSING, Status.COMPLETED])

    # Test basic pagination with limit
    first_page = Job.query(type="EVALUATE_MODEL", limit=2)
    assert len(first_page) == 2

    # Test pagination with offset
    second_page = Job.query(type="EVALUATE_MODEL", limit=2, offset=1)
    assert len(second_page) == 2

    # Verify different pages contain different jobs (allowing for some overlap due to timing)
    first_ids = {job.id for job in first_page}
    second_ids = {job.id for job in second_page}

    # At least one job should be different between pages
    assert len(first_ids.symmetric_difference(second_ids)) > 0

    # Test filtering by status
    processing_jobs = Job.query(status=Status.PROCESSING, limit=10)
    for job in processing_jobs:
        assert job.status == Status.PROCESSING

    # Test filtering by multiple statuses
    multi_status_jobs = Job.query(status=[Status.PROCESSING, Status.COMPLETED], limit=10)
    for job in multi_status_jobs:
        assert job.status in [Status.PROCESSING, Status.COMPLETED]