orca-sdk 0.0.94__py3-none-any.whl → 0.0.95__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +13 -4
- orca_sdk/_generated_api_client/api/__init__.py +80 -34
- orca_sdk/_generated_api_client/api/classification_model/create_classification_model_classification_model_post.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
- orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
- orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
- orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
- orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
- orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
- orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
- orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +24 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
- orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
- orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
- orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
- orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_regression_model_post.py} +27 -27
- orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
- orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
- orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
- orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
- orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
- orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
- orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
- orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
- orca_sdk/_generated_api_client/models/__init__.py +84 -24
- orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
- orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
- orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
- orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
- orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
- orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
- orca_sdk/_generated_api_client/models/column_info.py +31 -0
- orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
- orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
- orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
- orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
- orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
- orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
- orca_sdk/_generated_api_client/models/memory_type.py +9 -0
- orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
- orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
- orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
- orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
- orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
- orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
- orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
- orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
- orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
- orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
- orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
- orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
- orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
- orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
- orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
- orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
- orca_sdk/_shared/__init__.py +9 -1
- orca_sdk/_shared/metrics.py +257 -87
- orca_sdk/_shared/metrics_test.py +136 -77
- orca_sdk/_utils/data_parsing.py +0 -3
- orca_sdk/_utils/data_parsing_test.py +0 -3
- orca_sdk/_utils/prediction_result_ui.py +55 -23
- orca_sdk/classification_model.py +183 -175
- orca_sdk/classification_model_test.py +147 -157
- orca_sdk/conftest.py +76 -26
- orca_sdk/datasource_test.py +0 -1
- orca_sdk/embedding_model.py +136 -14
- orca_sdk/embedding_model_test.py +10 -6
- orca_sdk/job.py +329 -0
- orca_sdk/job_test.py +48 -0
- orca_sdk/memoryset.py +882 -161
- orca_sdk/memoryset_test.py +56 -23
- orca_sdk/regression_model.py +647 -0
- orca_sdk/regression_model_test.py +338 -0
- orca_sdk/telemetry.py +223 -106
- orca_sdk/telemetry_test.py +34 -30
- {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.95.dist-info}/METADATA +2 -4
- {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.95.dist-info}/RECORD +115 -69
- orca_sdk/_utils/task.py +0 -73
- {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.95.dist-info}/WHEEL +0 -0
orca_sdk/embedding_model.py
CHANGED
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from abc import abstractmethod
|
|
4
4
|
from datetime import datetime
|
|
5
|
-
from typing import TYPE_CHECKING, Sequence, cast, overload
|
|
5
|
+
from typing import TYPE_CHECKING, Literal, Sequence, cast, overload
|
|
6
6
|
|
|
7
7
|
from ._generated_api_client.api import (
|
|
8
8
|
create_finetuned_embedding_model,
|
|
@@ -24,8 +24,8 @@ from ._generated_api_client.models import (
|
|
|
24
24
|
PretrainedEmbeddingModelName,
|
|
25
25
|
)
|
|
26
26
|
from ._utils.common import CreateMode, DropMode
|
|
27
|
-
from ._utils.task import TaskStatus, wait_for_task
|
|
28
27
|
from .datasource import Datasource
|
|
28
|
+
from .job import Job, Status
|
|
29
29
|
|
|
30
30
|
if TYPE_CHECKING:
|
|
31
31
|
from .memoryset import LabeledMemoryset
|
|
@@ -79,15 +79,61 @@ class _EmbeddingModel:
|
|
|
79
79
|
return embeddings if isinstance(value, list) else embeddings[0]
|
|
80
80
|
|
|
81
81
|
|
|
82
|
-
class
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
return PretrainedEmbeddingModel._get(name)
|
|
86
|
-
else:
|
|
87
|
-
raise AttributeError(f"'{cls.__name__}' object has no attribute '{name}'")
|
|
82
|
+
class _ModelDescriptor:
|
|
83
|
+
"""
|
|
84
|
+
Descriptor for lazily loading embedding models with IDE autocomplete support.
|
|
88
85
|
|
|
86
|
+
This class implements the descriptor protocol to provide lazy loading of embedding models
|
|
87
|
+
while maintaining IDE autocomplete functionality. It delays the actual loading of models
|
|
88
|
+
until they are accessed, which improves startup performance.
|
|
89
89
|
|
|
90
|
-
|
|
90
|
+
The descriptor pattern works by defining how attribute access is handled. When a class
|
|
91
|
+
attribute using this descriptor is accessed, the __get__ method is called, which then
|
|
92
|
+
retrieves or initializes the actual model on first access.
|
|
93
|
+
"""
|
|
94
|
+
|
|
95
|
+
def __init__(self, name: str):
|
|
96
|
+
"""
|
|
97
|
+
Initialize a model descriptor.
|
|
98
|
+
|
|
99
|
+
Args:
|
|
100
|
+
name: The name of the embedding model in PretrainedEmbeddingModelName
|
|
101
|
+
"""
|
|
102
|
+
self.name = name
|
|
103
|
+
self.model = None # Model is loaded lazily on first access
|
|
104
|
+
|
|
105
|
+
def __get__(self, instance, owner_class):
|
|
106
|
+
"""
|
|
107
|
+
Descriptor protocol method called when the attribute is accessed.
|
|
108
|
+
|
|
109
|
+
This method implements lazy loading - the actual model is only initialized
|
|
110
|
+
the first time it's accessed. Subsequent accesses will use the cached model.
|
|
111
|
+
|
|
112
|
+
Args:
|
|
113
|
+
instance: The instance the attribute was accessed from, or None if accessed from the class
|
|
114
|
+
owner_class: The class that owns the descriptor
|
|
115
|
+
|
|
116
|
+
Returns:
|
|
117
|
+
The initialized embedding model
|
|
118
|
+
|
|
119
|
+
Raises:
|
|
120
|
+
AttributeError: If no model with the given name exists
|
|
121
|
+
"""
|
|
122
|
+
# When accessed from an instance, redirect to class access
|
|
123
|
+
if instance is not None:
|
|
124
|
+
return self.__get__(None, owner_class)
|
|
125
|
+
|
|
126
|
+
# Load the model on first access
|
|
127
|
+
if self.model is None:
|
|
128
|
+
try:
|
|
129
|
+
self.model = PretrainedEmbeddingModel._get(self.name)
|
|
130
|
+
except (KeyError, AttributeError):
|
|
131
|
+
raise AttributeError(f"No embedding model named {self.name}")
|
|
132
|
+
|
|
133
|
+
return self.model
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class PretrainedEmbeddingModel(_EmbeddingModel):
|
|
91
137
|
"""
|
|
92
138
|
A pretrained embedding model
|
|
93
139
|
|
|
@@ -102,6 +148,11 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
|
|
|
102
148
|
- **`GTE_BASE`**: Alibaba's GTE model from Hugging Face ([Alibaba-NLP/gte-base-en-v1.5](https://huggingface.co/Alibaba-NLP/gte-base-en-v1.5))
|
|
103
149
|
- **`DISTILBERT`**: DistilBERT embedding model from Hugging Face ([distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased))
|
|
104
150
|
- **`GTE_SMALL`**: GTE-Small embedding model from Hugging Face ([Supabase/gte-small](https://huggingface.co/Supabase/gte-small))
|
|
151
|
+
- **`E5_LARGE`**: E5-Large instruction-tuned embedding model from Hugging Face ([intfloat/multilingual-e5-large-instruct](https://huggingface.co/intfloat/multilingual-e5-large-instruct))
|
|
152
|
+
- **`GIST_LARGE`**: GIST-Large embedding model from Hugging Face ([avsolatorio/GIST-large-Embedding-v0](https://huggingface.co/avsolatorio/GIST-large-Embedding-v0))
|
|
153
|
+
- **`MXBAI_LARGE`**: Mixbreas's Large embedding model from Hugging Face ([mixedbread-ai/mxbai-embed-large-v1](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1))
|
|
154
|
+
- **`QWEN2_1_5B`**: Alibaba's Qwen2-1.5B instruction-tuned embedding model from Hugging Face ([Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct))
|
|
155
|
+
|
|
105
156
|
|
|
106
157
|
Examples:
|
|
107
158
|
>>> PretrainedEmbeddingModel.CDE_SMALL
|
|
@@ -114,6 +165,17 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
|
|
|
114
165
|
uses_context: Whether the pretrained embedding model uses context
|
|
115
166
|
"""
|
|
116
167
|
|
|
168
|
+
# Define descriptors for model access with IDE autocomplete
|
|
169
|
+
CDE_SMALL = _ModelDescriptor("CDE_SMALL")
|
|
170
|
+
CLIP_BASE = _ModelDescriptor("CLIP_BASE")
|
|
171
|
+
GTE_BASE = _ModelDescriptor("GTE_BASE")
|
|
172
|
+
DISTILBERT = _ModelDescriptor("DISTILBERT")
|
|
173
|
+
GTE_SMALL = _ModelDescriptor("GTE_SMALL")
|
|
174
|
+
E5_LARGE = _ModelDescriptor("E5_LARGE")
|
|
175
|
+
GIST_LARGE = _ModelDescriptor("GIST_LARGE")
|
|
176
|
+
MXBAI_LARGE = _ModelDescriptor("MXBAI_LARGE")
|
|
177
|
+
QWEN2_1_5B = _ModelDescriptor("QWEN2_1_5B")
|
|
178
|
+
|
|
117
179
|
_model_name: PretrainedEmbeddingModelName
|
|
118
180
|
|
|
119
181
|
def __init__(self, metadata: PretrainedEmbeddingModelMetadata):
|
|
@@ -151,6 +213,29 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
|
|
|
151
213
|
cls._instances[str(name)] = cls(get_pretrained_embedding_model(cast(PretrainedEmbeddingModelName, name)))
|
|
152
214
|
return cls._instances[str(name)]
|
|
153
215
|
|
|
216
|
+
@classmethod
|
|
217
|
+
def open(cls, name: str) -> PretrainedEmbeddingModel:
|
|
218
|
+
"""
|
|
219
|
+
Open an embedding model by name.
|
|
220
|
+
|
|
221
|
+
This is an alternative method to access models for environments
|
|
222
|
+
where IDE autocomplete for model names is not available.
|
|
223
|
+
|
|
224
|
+
Params:
|
|
225
|
+
name: Name of the model to open (e.g., "GTE_BASE", "CLIP_BASE")
|
|
226
|
+
|
|
227
|
+
Returns:
|
|
228
|
+
The embedding model instance
|
|
229
|
+
|
|
230
|
+
Examples:
|
|
231
|
+
>>> model = PretrainedEmbeddingModel.open("GTE_BASE")
|
|
232
|
+
"""
|
|
233
|
+
try:
|
|
234
|
+
# Use getattr to access the descriptor which will initialize the model
|
|
235
|
+
return getattr(cls, name)
|
|
236
|
+
except AttributeError:
|
|
237
|
+
raise ValueError(f"Unknown model name: {name}")
|
|
238
|
+
|
|
154
239
|
@classmethod
|
|
155
240
|
def exists(cls, name: str) -> bool:
|
|
156
241
|
"""
|
|
@@ -164,6 +249,23 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
|
|
|
164
249
|
"""
|
|
165
250
|
return name in PretrainedEmbeddingModelName
|
|
166
251
|
|
|
252
|
+
@overload
|
|
253
|
+
def finetune(
|
|
254
|
+
self,
|
|
255
|
+
name: str,
|
|
256
|
+
train_datasource: Datasource | LabeledMemoryset,
|
|
257
|
+
*,
|
|
258
|
+
eval_datasource: Datasource | None = None,
|
|
259
|
+
label_column: str = "label",
|
|
260
|
+
value_column: str = "value",
|
|
261
|
+
training_method: EmbeddingFinetuningMethod | str = EmbeddingFinetuningMethod.CLASSIFICATION,
|
|
262
|
+
training_args: dict | None = None,
|
|
263
|
+
if_exists: CreateMode = "error",
|
|
264
|
+
background: Literal[True],
|
|
265
|
+
) -> Job[FinetunedEmbeddingModel]:
|
|
266
|
+
pass
|
|
267
|
+
|
|
268
|
+
@overload
|
|
167
269
|
def finetune(
|
|
168
270
|
self,
|
|
169
271
|
name: str,
|
|
@@ -175,7 +277,23 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
|
|
|
175
277
|
training_method: EmbeddingFinetuningMethod | str = EmbeddingFinetuningMethod.CLASSIFICATION,
|
|
176
278
|
training_args: dict | None = None,
|
|
177
279
|
if_exists: CreateMode = "error",
|
|
280
|
+
background: Literal[False] = False,
|
|
178
281
|
) -> FinetunedEmbeddingModel:
|
|
282
|
+
pass
|
|
283
|
+
|
|
284
|
+
def finetune(
|
|
285
|
+
self,
|
|
286
|
+
name: str,
|
|
287
|
+
train_datasource: Datasource | LabeledMemoryset,
|
|
288
|
+
*,
|
|
289
|
+
eval_datasource: Datasource | None = None,
|
|
290
|
+
label_column: str = "label",
|
|
291
|
+
value_column: str = "value",
|
|
292
|
+
training_method: EmbeddingFinetuningMethod | str = EmbeddingFinetuningMethod.CLASSIFICATION,
|
|
293
|
+
training_args: dict | None = None,
|
|
294
|
+
if_exists: CreateMode = "error",
|
|
295
|
+
background: bool = False,
|
|
296
|
+
) -> FinetunedEmbeddingModel | Job[FinetunedEmbeddingModel]:
|
|
179
297
|
"""
|
|
180
298
|
Finetune an embedding model
|
|
181
299
|
|
|
@@ -190,6 +308,7 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
|
|
|
190
308
|
If not provided, reasonable training arguments will be used for the specified training method
|
|
191
309
|
if_exists: What to do if a finetuned embedding model with the same name already exists, defaults to
|
|
192
310
|
`"error"`. Other option is `"open"` to open the existing finetuned embedding model.
|
|
311
|
+
background: Whether to run the operation in the background and return a job handle
|
|
193
312
|
|
|
194
313
|
Returns:
|
|
195
314
|
The finetuned embedding model
|
|
@@ -233,8 +352,11 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
|
|
|
233
352
|
training_args=(FinetuneEmbeddingModelRequestTrainingArgs.from_dict(training_args or {})),
|
|
234
353
|
),
|
|
235
354
|
)
|
|
236
|
-
|
|
237
|
-
|
|
355
|
+
job = Job(
|
|
356
|
+
res.finetuning_task_id,
|
|
357
|
+
lambda: FinetunedEmbeddingModel.open(res.id),
|
|
358
|
+
)
|
|
359
|
+
return job if background else job.result()
|
|
238
360
|
|
|
239
361
|
|
|
240
362
|
class FinetunedEmbeddingModel(_EmbeddingModel):
|
|
@@ -254,7 +376,7 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
|
|
|
254
376
|
id: str
|
|
255
377
|
created_at: datetime
|
|
256
378
|
updated_at: datetime
|
|
257
|
-
_status:
|
|
379
|
+
_status: Status
|
|
258
380
|
|
|
259
381
|
def __init__(self, metadata: FinetunedEmbeddingModelMetadata):
|
|
260
382
|
# for internal use only, do not document
|
|
@@ -262,7 +384,7 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
|
|
|
262
384
|
self.created_at = metadata.created_at
|
|
263
385
|
self.updated_at = metadata.updated_at
|
|
264
386
|
self.base_model_name = metadata.base_model
|
|
265
|
-
self._status = metadata.finetuning_status
|
|
387
|
+
self._status = Status(metadata.finetuning_status.value)
|
|
266
388
|
super().__init__(
|
|
267
389
|
name=metadata.name,
|
|
268
390
|
embedding_dim=metadata.embedding_dim,
|
|
@@ -344,6 +466,6 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
|
|
|
344
466
|
"""
|
|
345
467
|
try:
|
|
346
468
|
delete_finetuned_embedding_model(name_or_id)
|
|
347
|
-
except LookupError:
|
|
469
|
+
except (LookupError, RuntimeError):
|
|
348
470
|
if if_not_exists == "error":
|
|
349
471
|
raise
|
orca_sdk/embedding_model_test.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
from uuid import uuid4
|
|
2
3
|
|
|
3
4
|
import pytest
|
|
@@ -7,8 +8,8 @@ from .embedding_model import (
|
|
|
7
8
|
FinetunedEmbeddingModel,
|
|
8
9
|
PretrainedEmbeddingModel,
|
|
9
10
|
PretrainedEmbeddingModelName,
|
|
10
|
-
TaskStatus,
|
|
11
11
|
)
|
|
12
|
+
from .job import Status
|
|
12
13
|
from .memoryset import LabeledMemoryset
|
|
13
14
|
|
|
14
15
|
|
|
@@ -34,8 +35,11 @@ def test_open_pretrained_model_not_found():
|
|
|
34
35
|
|
|
35
36
|
def test_all_pretrained_models():
|
|
36
37
|
models = PretrainedEmbeddingModel.all()
|
|
37
|
-
assert len(models)
|
|
38
|
-
|
|
38
|
+
assert len(models) > 1
|
|
39
|
+
if len(models) != len(PretrainedEmbeddingModelName):
|
|
40
|
+
logging.warning("Please regenerate the SDK client! Some pretrained model names are not exposed yet.")
|
|
41
|
+
model_names = [m.name for m in models]
|
|
42
|
+
assert all(enum_member in model_names for enum_member in PretrainedEmbeddingModelName.__members__)
|
|
39
43
|
|
|
40
44
|
|
|
41
45
|
def test_embed_text():
|
|
@@ -62,7 +66,7 @@ def test_finetune_model_with_datasource(finetuned_model: FinetunedEmbeddingModel
|
|
|
62
66
|
assert finetuned_model.base_model == PretrainedEmbeddingModel.DISTILBERT
|
|
63
67
|
assert finetuned_model.embedding_dim == 768
|
|
64
68
|
assert finetuned_model.max_seq_length == 512
|
|
65
|
-
assert finetuned_model._status ==
|
|
69
|
+
assert finetuned_model._status == Status.COMPLETED
|
|
66
70
|
|
|
67
71
|
|
|
68
72
|
def test_finetune_model_with_memoryset(readonly_memoryset: LabeledMemoryset):
|
|
@@ -74,7 +78,7 @@ def test_finetune_model_with_memoryset(readonly_memoryset: LabeledMemoryset):
|
|
|
74
78
|
assert finetuned_model.base_model == PretrainedEmbeddingModel.DISTILBERT
|
|
75
79
|
assert finetuned_model.embedding_dim == 768
|
|
76
80
|
assert finetuned_model.max_seq_length == 512
|
|
77
|
-
assert finetuned_model._status ==
|
|
81
|
+
assert finetuned_model._status == Status.COMPLETED
|
|
78
82
|
|
|
79
83
|
|
|
80
84
|
def test_finetune_model_already_exists_error(datasource: Datasource, finetuned_model):
|
|
@@ -96,7 +100,7 @@ def test_finetune_model_already_exists_return(datasource: Datasource, finetuned_
|
|
|
96
100
|
assert new_model.base_model == PretrainedEmbeddingModel.DISTILBERT
|
|
97
101
|
assert new_model.embedding_dim == 768
|
|
98
102
|
assert new_model.max_seq_length == 512
|
|
99
|
-
assert new_model._status ==
|
|
103
|
+
assert new_model._status == Status.COMPLETED
|
|
100
104
|
|
|
101
105
|
|
|
102
106
|
def test_finetune_model_unauthenticated(unauthenticated, datasource: Datasource):
|
orca_sdk/job.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from datetime import datetime, timedelta
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import Callable, Generic, TypedDict, TypeVar, cast
|
|
7
|
+
|
|
8
|
+
from tqdm.auto import tqdm
|
|
9
|
+
|
|
10
|
+
from ._generated_api_client.api import abort_task, get_task, get_task_status, list_tasks
|
|
11
|
+
from ._generated_api_client.models import TaskStatus
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class JobConfig(TypedDict):
|
|
15
|
+
refresh_interval: int
|
|
16
|
+
show_progress: bool
|
|
17
|
+
max_wait: int
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Status(Enum):
|
|
21
|
+
"""Status of a cloud job in the task queue"""
|
|
22
|
+
|
|
23
|
+
# the INITIALIZED state should never be returned by the API
|
|
24
|
+
|
|
25
|
+
DISPATCHED = "DISPATCHED"
|
|
26
|
+
"""The job has been queued and is waiting to be processed"""
|
|
27
|
+
|
|
28
|
+
PROCESSING = "PROCESSING"
|
|
29
|
+
"""The job is being processed"""
|
|
30
|
+
|
|
31
|
+
COMPLETED = "COMPLETED"
|
|
32
|
+
"""The job has been completed successfully"""
|
|
33
|
+
|
|
34
|
+
FAILED = "FAILED"
|
|
35
|
+
"""The job has failed"""
|
|
36
|
+
|
|
37
|
+
ABORTING = "ABORTING"
|
|
38
|
+
"""The job is being aborted"""
|
|
39
|
+
|
|
40
|
+
ABORTED = "ABORTED"
|
|
41
|
+
"""The job has been aborted"""
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
TResult = TypeVar("TResult")
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class Job(Generic[TResult]):
|
|
48
|
+
"""
|
|
49
|
+
Handle to a job that is run in the OrcaCloud
|
|
50
|
+
|
|
51
|
+
Attributes:
|
|
52
|
+
id: Unique identifier for the job
|
|
53
|
+
type: Type of the job
|
|
54
|
+
status: Current status of the job
|
|
55
|
+
steps_total: Total number of steps in the job, present if the job started processing
|
|
56
|
+
steps_completed: Number of steps completed in the job, present if the job started processing
|
|
57
|
+
completion: Percentage of the job that has been completed, present if the job started processing
|
|
58
|
+
exception: Exception that occurred during the job, present if the status is `FAILED`
|
|
59
|
+
value: Value of the result of the job, present if the status is `COMPLETED`
|
|
60
|
+
created_at: When the job was queued for processing
|
|
61
|
+
updated_at: When the job was last updated
|
|
62
|
+
refreshed_at: When the job status was last refreshed
|
|
63
|
+
|
|
64
|
+
Note:
|
|
65
|
+
Accessing status and related attributes will refresh the job status in the background.
|
|
66
|
+
"""
|
|
67
|
+
|
|
68
|
+
id: str
|
|
69
|
+
type: str
|
|
70
|
+
status: Status
|
|
71
|
+
steps_total: int | None
|
|
72
|
+
steps_completed: int | None
|
|
73
|
+
exception: str | None
|
|
74
|
+
value: TResult | None
|
|
75
|
+
updated_at: datetime
|
|
76
|
+
created_at: datetime
|
|
77
|
+
refreshed_at: datetime
|
|
78
|
+
|
|
79
|
+
@property
|
|
80
|
+
def completion(self) -> float:
|
|
81
|
+
"""
|
|
82
|
+
Percentage of the job that has been completed, present if the job started processing
|
|
83
|
+
"""
|
|
84
|
+
return (self.steps_completed or 0) / self.steps_total if self.steps_total is not None else 0
|
|
85
|
+
|
|
86
|
+
# Global configuration for all jobs
|
|
87
|
+
config: JobConfig = {
|
|
88
|
+
"refresh_interval": 3,
|
|
89
|
+
"show_progress": True,
|
|
90
|
+
"max_wait": 60 * 60,
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
def __repr__(self) -> str:
|
|
94
|
+
return "Job({" + f" type: {self.type}, status: {self.status}, completion: {self.completion:.0%} " + "})"
|
|
95
|
+
|
|
96
|
+
@classmethod
|
|
97
|
+
def set_config(
|
|
98
|
+
cls, *, refresh_interval: int | None = None, show_progress: bool | None = None, max_wait: int | None = None
|
|
99
|
+
):
|
|
100
|
+
"""
|
|
101
|
+
Set global configuration for running jobs
|
|
102
|
+
|
|
103
|
+
Args:
|
|
104
|
+
refresh_interval: Time to wait between polling the job status in seconds, default is 3
|
|
105
|
+
show_progress: Whether to show a progress bar when calling the wait method, default is True
|
|
106
|
+
max_wait: Maximum time to wait for the job to complete in seconds, default is 1 hour
|
|
107
|
+
"""
|
|
108
|
+
if refresh_interval is not None:
|
|
109
|
+
cls.config["refresh_interval"] = refresh_interval
|
|
110
|
+
if show_progress is not None:
|
|
111
|
+
cls.config["show_progress"] = show_progress
|
|
112
|
+
if max_wait is not None:
|
|
113
|
+
cls.config["max_wait"] = max_wait
|
|
114
|
+
|
|
115
|
+
@classmethod
|
|
116
|
+
def query(
|
|
117
|
+
cls,
|
|
118
|
+
status: Status | list[Status] | None = None,
|
|
119
|
+
type: str | list[str] | None = None,
|
|
120
|
+
limit: int | None = None,
|
|
121
|
+
offset: int = 0,
|
|
122
|
+
start: datetime | None = None,
|
|
123
|
+
end: datetime | None = None,
|
|
124
|
+
) -> list[Job]:
|
|
125
|
+
"""
|
|
126
|
+
Query the job queue for jobs matching the given filters
|
|
127
|
+
|
|
128
|
+
Args:
|
|
129
|
+
status: Optional status or list of statuses to filter by
|
|
130
|
+
type: Optional type or list of types to filter by
|
|
131
|
+
limit: Maximum number of jobs to return
|
|
132
|
+
offset: Offset into the list of jobs to return
|
|
133
|
+
start: Optional minimum creation time of the jobs to query for
|
|
134
|
+
end: Optional maximum creation time of the jobs to query for
|
|
135
|
+
|
|
136
|
+
Returns:
|
|
137
|
+
List of jobs matching the given filters
|
|
138
|
+
"""
|
|
139
|
+
tasks = list_tasks(
|
|
140
|
+
status=(
|
|
141
|
+
[TaskStatus(s.value) for s in status]
|
|
142
|
+
if isinstance(status, list)
|
|
143
|
+
else TaskStatus(status.value) if isinstance(status, Status) else None
|
|
144
|
+
),
|
|
145
|
+
type=type,
|
|
146
|
+
limit=limit,
|
|
147
|
+
offset=offset,
|
|
148
|
+
start_timestamp=start,
|
|
149
|
+
end_timestamp=end,
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
# can't use constructor because it makes an API call, so we construct the objects manually
|
|
153
|
+
return [
|
|
154
|
+
(
|
|
155
|
+
lambda t: (
|
|
156
|
+
obj := cls.__new__(cls),
|
|
157
|
+
setattr(obj, "id", t.id),
|
|
158
|
+
setattr(obj, "type", t.type),
|
|
159
|
+
setattr(obj, "status", Status(t.status.value)),
|
|
160
|
+
setattr(obj, "steps_total", t.steps_total),
|
|
161
|
+
setattr(obj, "steps_completed", t.steps_completed),
|
|
162
|
+
setattr(obj, "exception", t.exception),
|
|
163
|
+
setattr(obj, "value", cast(TResult, t.result.to_dict()) if t.result is not None else None),
|
|
164
|
+
setattr(obj, "updated_at", t.updated_at),
|
|
165
|
+
setattr(obj, "created_at", t.created_at),
|
|
166
|
+
setattr(obj, "refreshed_at", datetime.now()),
|
|
167
|
+
obj,
|
|
168
|
+
)[-1]
|
|
169
|
+
)(t)
|
|
170
|
+
for t in tasks
|
|
171
|
+
]
|
|
172
|
+
|
|
173
|
+
def __init__(self, id: str, get_value: Callable[[], TResult | None] | None = None):
    """
    Create a handle to a job in the job queue

    Args:
        id: Unique identifier for the job
        get_value: Optional function to customize how the value is resolved, if not provided the result will be a dict
    """
    self.id = id
    task = get_task(self.id)

    if get_value is None:
        # default resolver: re-fetch the task and expose its raw result as a dict
        self._get_value = lambda: (r := get_task(id).result) and (cast(TResult, r.to_dict()) if r else None)
    else:
        self._get_value = get_value
    self.type = task.type
    self.status = Status(task.status.value)
    self.steps_total = task.steps_total
    self.steps_completed = task.steps_completed
    self.exception = task.exception
    # the result value is only meaningful once the task has completed
    if task.status != TaskStatus.COMPLETED:
        self.value = None
    elif get_value is not None:
        self.value = get_value()
    elif task.result is not None:
        self.value = cast(TResult, task.result.to_dict())
    else:
        self.value = None
    self.updated_at = task.updated_at
    self.created_at = task.created_at
    self.refreshed_at = datetime.now()
|
|
204
|
+
|
|
205
|
+
def refresh(self, throttle: float = 0):
    """
    Refresh the status and progress of the job

    Params:
        throttle: Minimum time in seconds between refreshes
    """
    now = datetime.now()
    # honor the throttle window: skip if the last refresh was too recent
    if (now - self.refreshed_at) < timedelta(seconds=throttle):
        return
    self.refreshed_at = now

    info = get_task_status(self.id)
    self.status = Status(info.status.value)
    # progress counters may be absent mid-flight; only overwrite when present
    if info.steps_total is not None:
        self.steps_total = info.steps_total
    if info.steps_completed is not None:
        self.steps_completed = info.steps_completed

    self.exception = info.exception
    self.updated_at = info.updated_at

    if info.status == TaskStatus.COMPLETED:
        self.value = self._get_value()
|
|
230
|
+
|
|
231
|
+
def __getattribute__(self, name: str):
    # mutable (server-driven) attributes trigger a throttled refresh before
    # being read; all other attributes are returned directly
    if name in {"status", "updated_at", "steps_total", "steps_completed", "exception", "value"}:
        self.refresh(self.config["refresh_interval"])
    return super().__getattribute__(name)
|
|
236
|
+
|
|
237
|
+
def wait(
    self, show_progress: bool | None = None, refresh_interval: int | None = None, max_wait: int | None = None
) -> None:
    """
    Block until the job is complete

    Params:
        show_progress: Show a progress bar while waiting for the job to complete
        refresh_interval: Polling interval in seconds while waiting for the job to complete
        max_wait: Maximum time to wait for the job to complete in seconds

    Note:
        The defaults for the config parameters can be set globally using the
        [`set_config`][orca_sdk.Job.set_config] method.

        This method will not return the result or raise an exception if the job fails. Call
        [`result`][orca_sdk.Job.result] instead if you want to get the result.

    Raises:
        RuntimeError: If the job times out
    """
    began = time.time()
    # fall back to the globally configured defaults for unset parameters
    if show_progress is None:
        show_progress = self.config["show_progress"]
    if refresh_interval is None:
        refresh_interval = self.config["refresh_interval"]
    if max_wait is None:
        max_wait = self.config["max_wait"]
    progress_bar = None
    while True:
        # lazily create the progress bar once the total step count is known
        if not progress_bar and self.steps_total is not None and show_progress:
            label = " ".join(self.type.split("_")).lower()
            progress_bar = tqdm(total=self.steps_total, desc=label)

        # terminal state reached: finish the bar and stop polling
        if self.status in [Status.COMPLETED, Status.FAILED, Status.ABORTED]:
            if progress_bar:
                progress_bar.update(self.steps_total - progress_bar.n)
                progress_bar.close()
            return

        # give up once the maximum wait time has elapsed
        if (time.time() - began) > max_wait:
            raise RuntimeError(f"Job {self.id} timed out after {max_wait}s")

        # advance the bar to the latest reported step count
        if progress_bar and self.steps_completed is not None:
            progress_bar.update(self.steps_completed - progress_bar.n)

        # back off before polling again
        time.sleep(refresh_interval)
|
|
286
|
+
|
|
287
|
+
def result(
    self, show_progress: bool | None = None, refresh_interval: int | None = None, max_wait: int | None = None
) -> TResult:
    """
    Block until the job is complete and return the result value

    Params:
        show_progress: Show a progress bar while waiting for the job to complete
        refresh_interval: Polling interval in seconds while waiting for the job to complete
        max_wait: Maximum time to wait for the job to complete in seconds

    Note:
        The defaults for the config parameters can be set globally using the
        [`set_config`][orca_sdk.Job.set_config] method.

        This method will raise an exception if the job fails. Use [`wait`][orca_sdk.Job.wait]
        if you just want to wait for the job to complete without raising errors on failure.

    Returns:
        The result value of the job

    Raises:
        RuntimeError: If the job fails, times out, or completes without a result value
    """
    # return the cached value if the result was already resolved
    if self.value is not None:
        return self.value
    self.wait(show_progress, refresh_interval, max_wait)
    if self.status != Status.COMPLETED:
        raise RuntimeError(f"Job failed with exception: {self.exception}")
    # raise explicitly instead of using `assert`, which is stripped under
    # `python -O` and would let a None value escape the documented contract
    if self.value is None:
        raise RuntimeError(f"Job {self.id} completed but no result value is available")
    return self.value
|
|
318
|
+
|
|
319
|
+
def abort(self, show_progress: bool = False, refresh_interval: int = 1, max_wait: int = 20) -> None:
    """
    Abort the job

    Params:
        show_progress: Optionally show a progress bar while waiting for the job to abort
        refresh_interval: Polling interval in seconds while waiting for the job to abort
        max_wait: Maximum time to wait for the job to abort in seconds
    """
    # request cancellation, then poll until the job reaches a terminal state
    abort_task(self.id)
    self.wait(show_progress, refresh_interval, max_wait)
|
orca_sdk/job_test.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
import time
|
|
2
|
+
|
|
3
|
+
from .classification_model import ClassificationModel
|
|
4
|
+
from .datasource import Datasource
|
|
5
|
+
from .job import Job, Status
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def test_job_creation(classification_model: ClassificationModel, datasource: Datasource):
    """A background evaluation yields a fully-populated job handle that is queryable."""
    handle = classification_model.evaluate(datasource, background=True)
    assert handle.id is not None
    assert handle.type == "EVALUATE_MODEL"
    assert handle.status in [Status.DISPATCHED, Status.PROCESSING]
    assert handle.created_at is not None
    assert handle.updated_at is not None
    assert handle.refreshed_at is not None
    assert len(Job.query(limit=5, type="EVALUATE_MODEL")) >= 1
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def test_job_result(classification_model: ClassificationModel, datasource: Datasource):
    """result() blocks until completion and returns a non-None value."""
    handle = classification_model.evaluate(datasource, background=True)
    outcome = handle.result(show_progress=False)
    assert outcome is not None
    assert handle.status == Status.COMPLETED
    assert handle.steps_completed is not None
    assert handle.steps_completed == handle.steps_total
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def test_job_wait(classification_model: ClassificationModel, datasource: Datasource):
    """wait() blocks until the job completes and the value becomes available."""
    handle = classification_model.evaluate(datasource, background=True)
    handle.wait(show_progress=False)
    assert handle.status == Status.COMPLETED
    assert handle.steps_completed is not None
    assert handle.steps_completed == handle.steps_total
    assert handle.value is not None
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def test_job_refresh(classification_model: ClassificationModel, datasource: Datasource):
    """Attribute access refreshes lazily after the interval; refresh() refreshes immediately."""
    handle = classification_model.evaluate(datasource, background=True)
    previous = handle.refreshed_at
    # accessing the status attribute should refresh the job after the refresh interval
    Job.set_config(refresh_interval=1)
    time.sleep(1)
    handle.status
    assert handle.refreshed_at > previous
    previous = handle.refreshed_at
    # calling refresh() should immediately refresh the job
    handle.refresh()
    assert handle.refreshed_at > previous
|