orca-sdk 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +30 -0
- orca_sdk/_shared/__init__.py +10 -0
- orca_sdk/_shared/metrics.py +634 -0
- orca_sdk/_shared/metrics_test.py +570 -0
- orca_sdk/_utils/__init__.py +0 -0
- orca_sdk/_utils/analysis_ui.py +196 -0
- orca_sdk/_utils/analysis_ui_style.css +51 -0
- orca_sdk/_utils/auth.py +65 -0
- orca_sdk/_utils/auth_test.py +31 -0
- orca_sdk/_utils/common.py +37 -0
- orca_sdk/_utils/data_parsing.py +129 -0
- orca_sdk/_utils/data_parsing_test.py +244 -0
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/_utils/prediction_result_ui.css +18 -0
- orca_sdk/_utils/prediction_result_ui.py +110 -0
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/_utils/value_parser.py +45 -0
- orca_sdk/_utils/value_parser_test.py +39 -0
- orca_sdk/async_client.py +4104 -0
- orca_sdk/classification_model.py +1165 -0
- orca_sdk/classification_model_test.py +887 -0
- orca_sdk/client.py +4096 -0
- orca_sdk/conftest.py +382 -0
- orca_sdk/credentials.py +217 -0
- orca_sdk/credentials_test.py +121 -0
- orca_sdk/datasource.py +576 -0
- orca_sdk/datasource_test.py +463 -0
- orca_sdk/embedding_model.py +712 -0
- orca_sdk/embedding_model_test.py +206 -0
- orca_sdk/job.py +343 -0
- orca_sdk/job_test.py +108 -0
- orca_sdk/memoryset.py +3811 -0
- orca_sdk/memoryset_test.py +1150 -0
- orca_sdk/regression_model.py +841 -0
- orca_sdk/regression_model_test.py +595 -0
- orca_sdk/telemetry.py +742 -0
- orca_sdk/telemetry_test.py +119 -0
- orca_sdk-0.1.9.dist-info/METADATA +98 -0
- orca_sdk-0.1.9.dist-info/RECORD +41 -0
- orca_sdk-0.1.9.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import get_args
|
|
3
|
+
from uuid import uuid4
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
|
|
7
|
+
from .datasource import Datasource
|
|
8
|
+
from .embedding_model import (
|
|
9
|
+
ClassificationMetrics,
|
|
10
|
+
FinetunedEmbeddingModel,
|
|
11
|
+
PretrainedEmbeddingModel,
|
|
12
|
+
PretrainedEmbeddingModelName,
|
|
13
|
+
)
|
|
14
|
+
from .job import Status
|
|
15
|
+
from .memoryset import LabeledMemoryset
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def test_open_pretrained_model():
    """The GTE_BASE member exposes the expected static model metadata."""
    gte = PretrainedEmbeddingModel.GTE_BASE
    assert gte is not None
    assert isinstance(gte, PretrainedEmbeddingModel)
    assert gte.name == "GTE_BASE"
    assert gte.embedding_dim == 768
    assert gte.max_seq_length == 8192
    # looking the member up again must hand back the very same object
    assert gte is PretrainedEmbeddingModel.GTE_BASE
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def test_open_pretrained_model_unauthenticated(unauthenticated_client):
    """Embedding through a client without a valid API key is rejected."""
    with unauthenticated_client.use(), pytest.raises(ValueError, match="Invalid API key"):
        PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline")
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def test_open_pretrained_model_not_found():
    """Looking up an unknown pretrained model name raises LookupError."""
    unknown_name = "INVALID_MODEL"
    with pytest.raises(LookupError):
        PretrainedEmbeddingModel._get(unknown_name)  # type: ignore
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def test_all_pretrained_models():
    """Every model name declared in `PretrainedEmbeddingModelName` is returned by `all()`."""
    available = PretrainedEmbeddingModel.all()
    assert len(available) > 1
    declared_names = get_args(PretrainedEmbeddingModelName)
    if len(available) != len(declared_names):
        # counts differ: presumably the generated client's literal is out of date with the API
        logging.warning("Please regenerate the SDK client! Some pretrained model names are not exposed yet.")
    available_names = {model.name for model in available}
    assert all(name in available_names for name in declared_names)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def test_embed_text():
    """Embedding a short sentence with GTE_BASE yields a 768-long list of floats."""
    vector = PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline", max_seq_length=32)
    assert vector is not None
    assert isinstance(vector, list)
    assert len(vector) == 768
    first_component = vector[0]
    assert isinstance(first_component, float)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def test_embed_text_unauthenticated(unauthenticated_client):
    """Embedding with an explicit max_seq_length still requires a valid API key."""
    with unauthenticated_client.use(), pytest.raises(ValueError, match="Invalid API key"):
        PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline", max_seq_length=32)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def test_evaluate_pretrained_model(datasource: Datasource):
    """Evaluating GTE_BASE on the labeled datasource gives classification metrics with accuracy above 0.5."""
    result = PretrainedEmbeddingModel.GTE_BASE.evaluate(datasource=datasource, label_column="label")
    assert result is not None
    assert isinstance(result, ClassificationMetrics)
    assert result.accuracy > 0.5
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@pytest.fixture(scope="session")
def finetuned_model(datasource) -> FinetunedEmbeddingModel:
    """Session-wide DISTILBERT model finetuned on the shared test datasource."""
    base_model = PretrainedEmbeddingModel.DISTILBERT
    return base_model.finetune("test_finetuned_model", datasource)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def test_finetune_model_with_datasource(finetuned_model: FinetunedEmbeddingModel):
    """A model finetuned from a datasource keeps its name and DISTILBERT's dimensions."""
    model = finetuned_model
    assert model is not None
    assert model.name == "test_finetuned_model"
    assert model.base_model == PretrainedEmbeddingModel.DISTILBERT
    assert model.embedding_dim == 768
    assert model.max_seq_length == 512
    assert model._status == Status.COMPLETED
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def test_finetune_model_with_memoryset(readonly_memoryset: LabeledMemoryset):
    """Finetuning DISTILBERT directly from a memoryset yields a completed model."""
    model_name = "test_finetuned_model_from_memoryset"
    finetuned_model = PretrainedEmbeddingModel.DISTILBERT.finetune(model_name, readonly_memoryset)
    assert finetuned_model is not None
    assert finetuned_model.name == model_name
    assert finetuned_model.base_model == PretrainedEmbeddingModel.DISTILBERT
    assert finetuned_model.embedding_dim == 768
    assert finetuned_model.max_seq_length == 512
    assert finetuned_model._status == Status.COMPLETED
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def test_finetune_model_already_exists_error(datasource: Datasource, finetuned_model):
    """Finetuning under a name that is already taken fails by default."""
    taken_name = "test_finetuned_model"
    with pytest.raises(ValueError):
        PretrainedEmbeddingModel.DISTILBERT.finetune(taken_name, datasource)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def test_finetune_model_already_exists_return(datasource: Datasource, finetuned_model):
    """`if_exists="open"` returns the existing model, but not when requested from a different base model."""
    existing_name = "test_finetuned_model"
    # presumably rejected because the existing model was finetuned from DISTILBERT, not GTE_BASE
    with pytest.raises(ValueError):
        PretrainedEmbeddingModel.GTE_BASE.finetune(existing_name, datasource, if_exists="open")

    reopened = PretrainedEmbeddingModel.DISTILBERT.finetune(existing_name, datasource, if_exists="open")
    assert reopened is not None
    assert reopened.name == existing_name
    assert reopened.base_model == PretrainedEmbeddingModel.DISTILBERT
    assert reopened.embedding_dim == 768
    assert reopened.max_seq_length == 512
    assert reopened._status == Status.COMPLETED
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def test_finetune_model_unauthenticated(unauthenticated_client, datasource: Datasource):
    """Starting a finetune without a valid API key is rejected."""
    with unauthenticated_client.use(), pytest.raises(ValueError, match="Invalid API key"):
        PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model_unauthenticated", datasource)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def test_use_finetuned_model_in_memoryset(datasource: Datasource, finetuned_model: FinetunedEmbeddingModel):
    """A memoryset can be created on top of a finetuned embedding model."""
    memoryset_name = "test_memoryset_finetuned_model"
    memoryset = LabeledMemoryset.create(memoryset_name, datasource=datasource, embedding_model=finetuned_model)
    assert memoryset is not None
    assert memoryset.name == memoryset_name
    assert memoryset.embedding_model == finetuned_model
    assert memoryset.length == datasource.length
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def test_open_finetuned_model(finetuned_model: FinetunedEmbeddingModel):
    """Opening a finetuned model by name returns an equal handle with the same metadata."""
    reopened = FinetunedEmbeddingModel.open(finetuned_model.name)
    assert isinstance(reopened, FinetunedEmbeddingModel)
    assert reopened.id == finetuned_model.id
    assert reopened.name == finetuned_model.name
    assert reopened.base_model == PretrainedEmbeddingModel.DISTILBERT
    assert reopened.embedding_dim == 768
    assert reopened.max_seq_length == 512
    assert reopened == finetuned_model
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def test_embed_finetuned_model(finetuned_model: FinetunedEmbeddingModel):
    """A finetuned DISTILBERT model produces 768-long float embeddings."""
    vector = finetuned_model.embed("I love this airline")
    assert vector is not None
    assert isinstance(vector, list)
    assert len(vector) == 768
    first_component = vector[0]
    assert isinstance(first_component, float)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def test_all_finetuned_models(finetuned_model: FinetunedEmbeddingModel):
    """The session's finetuned model appears in the list of all finetuned models."""
    models = FinetunedEmbeddingModel.all()
    assert len(models) > 0
    listed_names = {model.name for model in models}
    assert finetuned_model.name in listed_names
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def test_all_finetuned_models_unauthenticated(unauthenticated_client):
    """Listing finetuned models without a valid API key is rejected."""
    with unauthenticated_client.use(), pytest.raises(ValueError, match="Invalid API key"):
        FinetunedEmbeddingModel.all()
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def test_all_finetuned_models_unauthorized(unauthorized_client, finetuned_model: FinetunedEmbeddingModel):
    """An unauthorized client does not see the session's finetuned model in its listing."""
    with unauthorized_client.use():
        visible_models = FinetunedEmbeddingModel.all()
        assert finetuned_model not in visible_models
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def test_drop_finetuned_model(datasource: Datasource):
    """A finetuned model can be dropped and cannot be opened afterwards."""
    doomed_name = "finetuned_model_to_delete"
    PretrainedEmbeddingModel.DISTILBERT.finetune(doomed_name, datasource)
    assert FinetunedEmbeddingModel.open(doomed_name)
    FinetunedEmbeddingModel.drop(doomed_name)
    with pytest.raises(LookupError):
        FinetunedEmbeddingModel.open(doomed_name)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def test_drop_finetuned_model_unauthenticated(unauthenticated_client):
    """Dropping a finetuned model without a valid API key is rejected.

    The previous version called `finetune` instead of `drop`, so it only duplicated
    `test_finetune_model_unauthenticated` and never exercised `drop`. A random name is
    used so that nothing can be deleted even if authentication were wrongly skipped.
    """
    with unauthenticated_client.use():
        with pytest.raises(ValueError, match="Invalid API key"):
            FinetunedEmbeddingModel.drop(str(uuid4()))
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def test_drop_finetuned_model_not_found():
    """Dropping an unknown model raises LookupError unless `if_not_exists="ignore"` is passed."""
    with pytest.raises(LookupError):
        FinetunedEmbeddingModel.drop(str(uuid4()))
    # with "ignore" the same situation must not raise
    FinetunedEmbeddingModel.drop(str(uuid4()), if_not_exists="ignore")
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def test_drop_finetuned_model_unauthorized(unauthorized_client, finetuned_model: FinetunedEmbeddingModel):
    """An unauthorized client gets a LookupError when trying to drop the session's model."""
    target_id = finetuned_model.id
    with unauthorized_client.use(), pytest.raises(LookupError):
        FinetunedEmbeddingModel.drop(target_id)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def test_supports_instructions():
    """GTE_BASE reports no instruction support while BGE_BASE does."""
    assert not PretrainedEmbeddingModel.GTE_BASE.supports_instructions
    assert PretrainedEmbeddingModel.BGE_BASE.supports_instructions
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def test_use_explicit_instruction_prompt():
    """Passing an instruction prompt to BGE_BASE changes the resulting embedding."""
    model = PretrainedEmbeddingModel.BGE_BASE
    assert model.supports_instructions
    text = "Hello world"
    with_prompt = model.embed(text, prompt="Represent this sentence for sentiment retrieval:")
    without_prompt = model.embed(text)
    assert with_prompt != without_prompt
|
orca_sdk/job.py
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from datetime import datetime, timedelta
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import Callable, Generic, TypedDict, TypeVar, cast
|
|
7
|
+
|
|
8
|
+
from tqdm.auto import tqdm
|
|
9
|
+
|
|
10
|
+
from .client import OrcaClient
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Shape of the global job settings stored in `Job.config` and updated by `Job.set_config`:
#   refresh_interval: seconds between polls of the job status
#   show_progress: whether `Job.wait` displays a progress bar
#   max_wait: seconds `Job.wait` waits before raising a timeout error
JobConfig = TypedDict(
    "JobConfig",
    {
        "refresh_interval": int,
        "show_progress": bool,
        "max_wait": int,
    },
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Status(Enum):
    """Lifecycle state of a job in the OrcaCloud job queue"""

    # NOTE: the API is not expected to ever report INITIALIZED
    INITIALIZED = "INITIALIZED"
    """Job has been initialized"""

    DISPATCHED = "DISPATCHED"
    """Job is queued and waiting to be picked up for processing"""

    WAITING = "WAITING"
    """Job is blocked until the jobs it depends on have completed"""

    PROCESSING = "PROCESSING"
    """Job is currently being processed"""

    COMPLETED = "COMPLETED"
    """Job finished successfully"""

    FAILED = "FAILED"
    """Job finished with a failure"""

    ABORTING = "ABORTING"
    """Job is in the process of being aborted"""

    ABORTED = "ABORTED"
    """Job has been aborted"""
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
TResult = TypeVar("TResult")


class Job(Generic[TResult]):
    """
    Handle to a job that is run in the OrcaCloud

    Attributes:
        id: Unique identifier for the job
        type: Type of the job
        status: Current status of the job
        steps_total: Total number of steps in the job, present if the job started processing
        steps_completed: Number of steps completed in the job, present if the job started processing
        completion: Fraction of the job's steps that have been completed, 0 until a non-zero total is known
        exception: Exception that occurred during the job, present if the status is `FAILED`
        value: Value of the result of the job, present if the status is `COMPLETED`
        created_at: When the job was queued for processing
        updated_at: When the job was last updated
        refreshed_at: When the job status was last refreshed

    Note:
        Reading `status`, `updated_at`, `steps_total`, `steps_completed`, `exception` or `value`
        synchronously re-fetches the job from the API whenever the last refresh is older than the
        global `refresh_interval` (see [`set_config`][orca_sdk.Job.set_config]).
    """

    id: str
    type: str
    status: Status
    steps_total: int | None
    steps_completed: int | None
    exception: str | None
    value: TResult | None
    updated_at: datetime
    created_at: datetime
    refreshed_at: datetime
    # resolves the result value once the job is COMPLETED; set by __init__ and _from_api
    _get_value: Callable[[], TResult | None]

    @property
    def completion(self) -> float:
        """
        Fraction of the job's steps that have been completed, present if the job started processing

        Returns 0 while `steps_total` is unknown or zero, so this never divides by zero.
        """
        steps_total = self.steps_total
        if not steps_total:
            return 0.0
        return (self.steps_completed or 0) / steps_total

    # Global configuration for all jobs
    config: JobConfig = {
        "refresh_interval": 3,
        "show_progress": True,
        "max_wait": 60 * 60,
    }

    def __repr__(self) -> str:
        return "Job({" + f" type: {self.type}, status: {self.status}, completion: {self.completion:.0%} " + "})"

    @classmethod
    def set_config(
        cls, *, refresh_interval: int | None = None, show_progress: bool | None = None, max_wait: int | None = None
    ):
        """
        Set global configuration for running jobs

        Args:
            refresh_interval: Time to wait between polling the job status in seconds, default is 3
            show_progress: Whether to show a progress bar when calling the wait method, default is True
            max_wait: Maximum time to wait for the job to complete in seconds, default is 1 hour
        """
        if refresh_interval is not None:
            cls.config["refresh_interval"] = refresh_interval
        if show_progress is not None:
            cls.config["show_progress"] = show_progress
        if max_wait is not None:
            cls.config["max_wait"] = max_wait

    @classmethod
    def query(
        cls,
        status: Status | list[Status] | None = None,
        type: str | list[str] | None = None,
        limit: int = 100,
        offset: int = 0,
        start: datetime | None = None,
        end: datetime | None = None,
    ) -> list[Job]:
        """
        Query the job queue for jobs matching the given filters

        Args:
            status: Optional status or list of statuses to filter by
            type: Optional type or list of types to filter by
            limit: Maximum number of jobs to return
            offset: Offset into the list of jobs to return
            start: Optional minimum creation time of the jobs to query for
            end: Optional maximum creation time of the jobs to query for

        Returns:
            List of jobs matching the given filters
        """
        client = OrcaClient._resolve_client()

        status_param: str | list[str] | None
        if isinstance(status, list):
            status_param = [s.value for s in status]
        elif isinstance(status, Status):
            status_param = status.value
        else:
            status_param = None

        paginated_jobs = client.GET(
            "/job",
            params={
                "status": status_param,
                "type": type,
                "limit": limit,
                "offset": offset,
                "start_timestamp": start.isoformat() if start is not None else None,
                "end_timestamp": end.isoformat() if end is not None else None,
            },
        )

        # can't use the constructor because it makes an API call per job, so build handles from the list response
        return [cls._from_api(job) for job in paginated_jobs["items"]]

    @classmethod
    def _from_api(cls, job: dict) -> Job:
        """
        Build a handle from a job payload without making another API request

        Params:
            job: Job payload as returned in the `items` of `GET /job`
        """
        obj = cls.__new__(cls)
        obj.id = job["id"]
        # without this, refresh() would fail with AttributeError once the job reaches COMPLETED
        obj._get_value = obj._fetch_result
        obj.type = job["type"]
        obj.status = Status(job["status"])
        obj.steps_total = job["steps_total"]
        obj.steps_completed = job["steps_completed"]
        obj.exception = job["exception"]
        obj.value = cast(TResult, job["result"]) if job["result"] is not None else None
        obj.updated_at = datetime.fromisoformat(job["updated_at"])
        obj.created_at = datetime.fromisoformat(job["created_at"])
        obj.refreshed_at = datetime.now()
        return obj

    def _fetch_result(self) -> TResult | None:
        """Fetch the raw `result` field of this job from the API (default value resolver)"""
        client = OrcaClient._resolve_client()
        # string form avoids evaluating `TypeVar | None` at runtime, which needs Python 3.10+
        return cast("TResult | None", client.GET("/job/{job_id}", params={"job_id": self.id})["result"])

    def __init__(self, id: str, get_value: Callable[[], TResult | None] | None = None):
        """
        Create a handle to a job in the job queue

        Params:
            id: Unique identifier for the job
            get_value: Optional function to customize how the value is resolved, if not provided the
                raw `result` field of the job is used
        """
        self.id = id
        client = OrcaClient._resolve_client()
        job = client.GET("/job/{job_id}", params={"job_id": id})

        self._get_value = get_value or self._fetch_result
        self.type = job["type"]
        self.status = Status(job["status"])
        self.steps_total = job["steps_total"]
        self.steps_completed = job["steps_completed"]
        self.exception = job["exception"]
        if job["status"] != "COMPLETED":
            self.value = None
        elif get_value is not None:
            self.value = get_value()
        else:
            self.value = cast(TResult, job["result"]) if job["result"] is not None else None
        self.updated_at = datetime.fromisoformat(job["updated_at"])
        self.created_at = datetime.fromisoformat(job["created_at"])
        self.refreshed_at = datetime.now()

    def refresh(self, throttle: float = 0):
        """
        Refresh the status and progress of the job

        Once the job is `COMPLETED`, the value is re-resolved as part of the refresh.

        Params:
            throttle: Minimum time in seconds between refreshes
        """
        current_time = datetime.now()
        # Skip refresh if last refresh was too recent
        if (current_time - self.refreshed_at) < timedelta(seconds=throttle):
            return
        self.refreshed_at = current_time

        client = OrcaClient._resolve_client()
        status_info = client.GET("/job/{job_id}/status", params={"job_id": self.id})
        self.status = Status(status_info["status"])
        # keep previously known step counts if the status endpoint does not report them
        if status_info["steps_total"] is not None:
            self.steps_total = status_info["steps_total"]
        if status_info["steps_completed"] is not None:
            self.steps_completed = status_info["steps_completed"]

        self.exception = status_info["exception"]
        self.updated_at = datetime.fromisoformat(status_info["updated_at"])

        if status_info["status"] == "COMPLETED":
            self.value = self._get_value()

    def __getattribute__(self, name: str):
        # if the attribute is not immutable, refresh the job if it hasn't been refreshed recently
        if name in {"status", "updated_at", "steps_total", "steps_completed", "exception", "value"}:
            self.refresh(self.config["refresh_interval"])
        return super().__getattribute__(name)

    def wait(
        self, show_progress: bool | None = None, refresh_interval: int | None = None, max_wait: int | None = None
    ) -> None:
        """
        Block until the job is complete

        Params:
            show_progress: Show a progress bar while waiting for the job to complete
            refresh_interval: Polling interval in seconds while waiting for the job to complete
            max_wait: Maximum time to wait for the job to complete in seconds

        Note:
            The defaults for the config parameters can be set globally using the
            [`set_config`][orca_sdk.Job.set_config] method.

            This method will not return the result or raise an exception if the job fails. Call
            [`result`][orca_sdk.Job.result] instead if you want to get the result.

        Raises:
            RuntimeError: If the job times out
        """
        start_time = time.time()
        show_progress = show_progress if show_progress is not None else self.config["show_progress"]
        refresh_interval = refresh_interval if refresh_interval is not None else self.config["refresh_interval"]
        max_wait = max_wait if max_wait is not None else self.config["max_wait"]
        # compare against None explicitly: a tqdm bar with total=0 is falsy
        pbar = None
        try:
            while True:
                # poll at the requested interval; attribute reads alone only refresh at the global config interval
                self.refresh(refresh_interval)

                # setup progress bar if steps total is known
                if pbar is None and self.steps_total is not None and show_progress:
                    desc = " ".join(self.type.split("_")).lower()
                    pbar = tqdm(total=self.steps_total, desc=desc)

                # return if job is complete
                if self.status in [Status.COMPLETED, Status.FAILED, Status.ABORTED]:
                    if pbar is not None:
                        pbar.update(self.steps_total - pbar.n)
                    return

                # raise error if job timed out
                if (time.time() - start_time) > max_wait:
                    raise RuntimeError(f"Job {self.id} timed out after {max_wait}s")

                # update progress bar
                if pbar is not None and self.steps_completed is not None:
                    pbar.update(self.steps_completed - pbar.n)

                # sleep before retrying
                time.sleep(refresh_interval)
        finally:
            # close the bar on every exit path, including timeouts and interrupts
            if pbar is not None:
                pbar.close()

    def result(
        self, show_progress: bool | None = None, refresh_interval: int | None = None, max_wait: int | None = None
    ) -> TResult:
        """
        Block until the job is complete and return the result value

        Params:
            show_progress: Show a progress bar while waiting for the job to complete
            refresh_interval: Polling interval in seconds while waiting for the job to complete
            max_wait: Maximum time to wait for the job to complete in seconds

        Note:
            The defaults for the config parameters can be set globally using the
            [`set_config`][orca_sdk.Job.set_config] method.

            This method will raise an exception if the job fails. Use [`wait`][orca_sdk.Job.wait]
            if you just want to wait for the job to complete without raising errors on failure.

        Returns:
            The result value of the job

        Raises:
            RuntimeError: If the job fails or times out
        """
        if self.value is not None:
            return self.value
        self.wait(show_progress, refresh_interval, max_wait)
        if self.status != Status.COMPLETED:
            raise RuntimeError(f"Job failed with exception: {self.exception}")
        assert self.value is not None
        return self.value

    def abort(self, show_progress: bool = False, refresh_interval: int = 1, max_wait: int = 20) -> None:
        """
        Abort the job

        Params:
            show_progress: Optionally show a progress bar while waiting for the job to abort
            refresh_interval: Polling interval in seconds while waiting for the job to abort
            max_wait: Maximum time to wait for the job to abort in seconds
        """
        client = OrcaClient._resolve_client()
        client.DELETE("/job/{job_id}/abort", params={"job_id": self.id})
        self.wait(show_progress, refresh_interval, max_wait)
|
orca_sdk/job_test.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
import time
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
from datasets import Dataset
|
|
5
|
+
|
|
6
|
+
from .classification_model import ClassificationModel
|
|
7
|
+
from .datasource import Datasource
|
|
8
|
+
from .job import Job, Status
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@pytest.fixture(scope="session")
def datasource_without_nones(hf_dataset: Dataset):
    """Session-wide datasource holding only the rows of `hf_dataset` whose label is not None."""
    labeled_rows = hf_dataset.filter(lambda row: row["label"] is not None)
    return Datasource.from_hf_dataset("test_datasource_without_nones", labeled_rows)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def wait_for_jobs_status(job_ids, expected_statuses, timeout=10, poll_interval=0.2):
    """
    Poll until every job in `job_ids` is in one of `expected_statuses`.

    Raises:
        TimeoutError: if that has not happened after `timeout` seconds
    """
    started = time.time()
    while time.time() - started < timeout:
        # fresh handles each round, so every status comes straight from the API
        snapshot = [Job(job_id) for job_id in job_ids]
        if all(job.status in expected_statuses for job in snapshot):
            return
        time.sleep(poll_interval)
    raise TimeoutError(f"Jobs did not reach statuses {expected_statuses} within {timeout} seconds")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def test_job_creation(classification_model: ClassificationModel, datasource_without_nones: Datasource):
    """A background evaluation returns a job handle with populated metadata."""
    job = classification_model.evaluate(datasource_without_nones, background=True)
    assert job.id is not None
    assert job.type == "EVALUATE_MODEL"
    assert job.status in (Status.DISPATCHED, Status.PROCESSING)
    for timestamp in (job.created_at, job.updated_at, job.refreshed_at):
        assert timestamp is not None
    recent_evaluations = Job.query(limit=5, type="EVALUATE_MODEL")
    assert len(recent_evaluations) >= 1
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def test_job_result(classification_model: ClassificationModel, datasource_without_nones: Datasource):
    """`result()` blocks until completion and returns a value with all steps done."""
    job = classification_model.evaluate(datasource_without_nones, background=True)
    value = job.result(show_progress=False)
    assert value is not None
    assert job.status == Status.COMPLETED
    assert job.steps_completed is not None
    assert job.steps_completed == job.steps_total
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def test_job_wait(classification_model: ClassificationModel, datasource_without_nones: Datasource):
    """`wait()` blocks until completion, after which progress and value are populated."""
    evaluation_job = classification_model.evaluate(datasource_without_nones, background=True)
    evaluation_job.wait(show_progress=False)
    assert evaluation_job.status == Status.COMPLETED
    assert evaluation_job.steps_completed is not None
    assert evaluation_job.steps_completed == evaluation_job.steps_total
    assert evaluation_job.value is not None
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def test_job_refresh(classification_model: ClassificationModel, datasource_without_nones: Datasource):
    """Reading `status` refreshes after the interval elapses, and `refresh()` refreshes immediately."""
    job = classification_model.evaluate(datasource_without_nones, background=True)
    last_refreshed_at = job.refreshed_at
    # Job.config is class-level global state: restore it so the shorter interval doesn't leak into other tests
    previous_refresh_interval = Job.config["refresh_interval"]
    Job.set_config(refresh_interval=1)
    try:
        # accessing the status attribute should refresh the job after the refresh interval
        time.sleep(1)
        _ = job.status
        assert job.refreshed_at > last_refreshed_at
    finally:
        Job.set_config(refresh_interval=previous_refresh_interval)
    last_refreshed_at = job.refreshed_at
    # calling refresh() should immediately refresh the job
    job.refresh()
    assert job.refreshed_at > last_refreshed_at
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def test_job_query_pagination(classification_model: ClassificationModel, datasource_without_nones: Datasource):
    """Test pagination with Job.query() method"""
    # Create multiple jobs to test pagination
    created_ids = [classification_model.evaluate(datasource_without_nones, background=True).id for _ in range(3)]

    # Wait for jobs to be at least PROCESSING or COMPLETED
    wait_for_jobs_status(created_ids, expected_statuses=[Status.PROCESSING, Status.COMPLETED])

    # limit caps the page size
    first_page = Job.query(type="EVALUATE_MODEL", limit=2)
    assert len(first_page) == 2

    # offset shifts the window by one job
    shifted_page = Job.query(type="EVALUATE_MODEL", limit=2, offset=1)
    assert len(shifted_page) == 2

    # the two windows may overlap, but should not be identical
    first_ids = {job.id for job in first_page}
    shifted_ids = {job.id for job in shifted_page}
    assert first_ids ^ shifted_ids

    # a single status filter only returns jobs in that status
    for job in Job.query(status=Status.PROCESSING, limit=10):
        assert job.status == Status.PROCESSING

    # a list of statuses returns jobs in any of them
    allowed_statuses = [Status.PROCESSING, Status.COMPLETED]
    for job in Job.query(status=allowed_statuses, limit=10):
        assert job.status in allowed_statuses
|