orca-sdk 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- orca_sdk/__init__.py +10 -4
- orca_sdk/_shared/__init__.py +10 -0
- orca_sdk/_shared/metrics.py +393 -0
- orca_sdk/_shared/metrics_test.py +273 -0
- orca_sdk/_utils/analysis_ui.py +12 -10
- orca_sdk/_utils/analysis_ui_style.css +0 -3
- orca_sdk/_utils/auth.py +27 -29
- orca_sdk/_utils/data_parsing.py +28 -2
- orca_sdk/_utils/data_parsing_test.py +15 -15
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/_utils/prediction_result_ui.py +67 -21
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/_utils/value_parser.py +45 -0
- orca_sdk/_utils/value_parser_test.py +39 -0
- orca_sdk/classification_model.py +439 -129
- orca_sdk/classification_model_test.py +334 -104
- orca_sdk/client.py +3747 -0
- orca_sdk/conftest.py +164 -19
- orca_sdk/credentials.py +120 -18
- orca_sdk/credentials_test.py +20 -0
- orca_sdk/datasource.py +259 -68
- orca_sdk/datasource_test.py +242 -0
- orca_sdk/embedding_model.py +425 -82
- orca_sdk/embedding_model_test.py +39 -13
- orca_sdk/job.py +337 -0
- orca_sdk/job_test.py +108 -0
- orca_sdk/memoryset.py +1341 -305
- orca_sdk/memoryset_test.py +350 -111
- orca_sdk/regression_model.py +684 -0
- orca_sdk/regression_model_test.py +369 -0
- orca_sdk/telemetry.py +449 -143
- orca_sdk/telemetry_test.py +43 -24
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/METADATA +34 -16
- orca_sdk-0.1.2.dist-info/RECORD +40 -0
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/WHEEL +1 -1
- orca_sdk/_generated_api_client/__init__.py +0 -3
- orca_sdk/_generated_api_client/api/__init__.py +0 -193
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
- orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
- orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
- orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
- orca_sdk/_generated_api_client/client.py +0 -216
- orca_sdk/_generated_api_client/errors.py +0 -38
- orca_sdk/_generated_api_client/models/__init__.py +0 -159
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
- orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
- orca_sdk/_generated_api_client/models/base_model.py +0 -55
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
- orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
- orca_sdk/_generated_api_client/models/column_info.py +0 -114
- orca_sdk/_generated_api_client/models/column_type.py +0 -14
- orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
- orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
- orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
- orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
- orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
- orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/embed_request.py +0 -127
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
- orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
- orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
- orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
- orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
- orca_sdk/_generated_api_client/models/filter_item.py +0 -231
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
- orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
- orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
- orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
- orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
- orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
- orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
- orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
- orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
- orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
- orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
- orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
- orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
- orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/task.py +0 -198
- orca_sdk/_generated_api_client/models/task_status.py +0 -14
- orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
- orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
- orca_sdk/_generated_api_client/py.typed +0 -1
- orca_sdk/_generated_api_client/types.py +0 -56
- orca_sdk/_utils/task.py +0 -73
- orca_sdk-0.1.1.dist-info/RECORD +0 -175
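The headline change in 0.1.2 is the removal of the generated API client (`orca_sdk/_generated_api_client/*`, one module per endpoint) in favor of a single hand-written `orca_sdk/client.py`, plus a new `orca_sdk/job.py` for background tasks. The sketch below illustrates the new calling convention inferred only from the `embedding_model.py` diff that follows; the real `client.py` is ~3,700 lines and is not reproduced here, and the public import path and exact signatures are assumptions.

    # Sketch only: mirrors the calls visible in the diff below, not a verified API.
    from orca_sdk.client import orca_api  # assumed import path

    # 0.1.1 style (removed): one generated function per endpoint, e.g.
    # list_pretrained_embedding_models from the _generated_api_client package.
    # 0.1.2 style (added): HTTP-verb helpers with templated paths.
    models = orca_api.GET("/pretrained_embedding_model")
    embedding = orca_api.POST(
        "/gpu/pretrained_embedding_model/{model_name}/embedding",
        params={"model_name": "GTE_BASE"},
        json={"values": ["hello world"], "max_seq_length": None, "prompt": None},
        timeout=30,  # the diff notes GPU endpoints may be slow on cold start
    )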
orca_sdk/embedding_model.py
CHANGED
@@ -1,93 +1,324 @@
 from __future__ import annotations

-from abc import abstractmethod
+from abc import ABC, abstractmethod
 from datetime import datetime
-from typing import TYPE_CHECKING, Sequence, cast, overload
-
-from .
-
-
-
-    embed_with_pretrained_model_gpu,
-    get_finetuned_embedding_model,
-    get_pretrained_embedding_model,
-    list_finetuned_embedding_models,
-    list_pretrained_embedding_models,
-)
-from ._generated_api_client.models import (
+from typing import TYPE_CHECKING, Literal, Sequence, cast, get_args, overload
+
+from ._shared.metrics import ClassificationMetrics, RegressionMetrics
+from ._utils.common import UNSET, CreateMode, DropMode
+from .client import (
+    EmbeddingEvaluationRequest,
     EmbeddingFinetuningMethod,
     EmbedRequest,
     FinetunedEmbeddingModelMetadata,
     FinetuneEmbeddingModelRequest,
-    FinetuneEmbeddingModelRequestTrainingArgs,
     PretrainedEmbeddingModelMetadata,
     PretrainedEmbeddingModelName,
+    orca_api,
 )
-from ._utils.common import CreateMode, DropMode
-from ._utils.task import TaskStatus, wait_for_task
 from .datasource import Datasource
+from .job import Job, Status

 if TYPE_CHECKING:
     from .memoryset import LabeledMemoryset


-class
-    name: str
+class EmbeddingModelBase(ABC):
     embedding_dim: int
     max_seq_length: int
     uses_context: bool
+    supports_instructions: bool

-    def __init__(
-        self
+    def __init__(
+        self, *, name: str, embedding_dim: int, max_seq_length: int, uses_context: bool, supports_instructions: bool
+    ):
         self.embedding_dim = embedding_dim
         self.max_seq_length = max_seq_length
         self.uses_context = uses_context
+        self.supports_instructions = supports_instructions

     @classmethod
     @abstractmethod
-    def all(cls) -> Sequence[
+    def all(cls) -> Sequence[EmbeddingModelBase]:
         pass

+    def _get_instruction_error_message(self) -> str:
+        """Get error message for instruction not supported"""
+        if isinstance(self, FinetunedEmbeddingModel):
+            return f"Model {self.name} does not support instructions. Instruction-following is only supported by models based on instruction-supporting models."
+        elif isinstance(self, PretrainedEmbeddingModel):
+            return f"Model {self.name} does not support instructions. Instruction-following is only supported by instruction-supporting models."
+        else:
+            raise ValueError("Invalid embedding model")
+
     @overload
-    def embed(self, value: str, max_seq_length: int | None = None) -> list[float]:
+    def embed(self, value: str, max_seq_length: int | None = None, prompt: str | None = None) -> list[float]:
         pass

     @overload
-    def embed(
+    def embed(
+        self, value: list[str], max_seq_length: int | None = None, prompt: str | None = None
+    ) -> list[list[float]]:
         pass

-    def embed(
+    def embed(
+        self, value: str | list[str], max_seq_length: int | None = None, prompt: str | None = None
+    ) -> list[float] | list[list[float]]:
         """
         Generate embeddings for a value or list of values

         Params:
             value: The value or list of values to embed
             max_seq_length: The maximum sequence length to truncate the input to
+            prompt: Optional prompt for prompt-following embedding models.

         Returns:
             A matrix of floats representing the embedding for each value if the input is a list of
             values, or a list of floats representing the embedding for the single value if the
             input is a single value
         """
-
+        payload: EmbedRequest = {
+            "values": value if isinstance(value, list) else [value],
+            "max_seq_length": max_seq_length,
+            "prompt": prompt,
+        }
         if isinstance(self, PretrainedEmbeddingModel):
-            embeddings =
+            embeddings = orca_api.POST(
+                "/gpu/pretrained_embedding_model/{model_name}/embedding",
+                params={"model_name": cast(PretrainedEmbeddingModelName, self.name)},
+                json=payload,
+                timeout=30,  # may be slow in case of cold start
+            )
         elif isinstance(self, FinetunedEmbeddingModel):
-            embeddings =
+            embeddings = orca_api.POST(
+                "/gpu/finetuned_embedding_model/{name_or_id}/embedding",
+                params={"name_or_id": self.id},
+                json=payload,
+                timeout=30,  # may be slow in case of cold start
+            )
         else:
             raise ValueError("Invalid embedding model")
         return embeddings if isinstance(value, list) else embeddings[0]

+    @overload
+    def evaluate(
+        self,
+        datasource: Datasource,
+        *,
+        value_column: str = "value",
+        label_column: str,
+        score_column: None = None,
+        eval_datasource: Datasource | None = None,
+        subsample: int | None = None,
+        neighbor_count: int = 5,
+        batch_size: int = 32,
+        weigh_memories: bool = True,
+        background: Literal[True],
+    ) -> Job[ClassificationMetrics]:
+        pass

-
-    def
-
-
+    @overload
+    def evaluate(
+        self,
+        datasource: Datasource,
+        *,
+        value_column: str = "value",
+        label_column: str,
+        score_column: None = None,
+        eval_datasource: Datasource | None = None,
+        subsample: int | None = None,
+        neighbor_count: int = 5,
+        batch_size: int = 32,
+        weigh_memories: bool = True,
+        background: Literal[False] = False,
+    ) -> ClassificationMetrics:
+        pass
+
+    @overload
+    def evaluate(
+        self,
+        datasource: Datasource,
+        *,
+        value_column: str = "value",
+        label_column: None = None,
+        score_column: str,
+        eval_datasource: Datasource | None = None,
+        subsample: int | None = None,
+        neighbor_count: int = 5,
+        batch_size: int = 32,
+        weigh_memories: bool = True,
+        background: Literal[True],
+    ) -> Job[RegressionMetrics]:
+        pass
+
+    @overload
+    def evaluate(
+        self,
+        datasource: Datasource,
+        *,
+        value_column: str = "value",
+        label_column: None = None,
+        score_column: str,
+        eval_datasource: Datasource | None = None,
+        subsample: int | None = None,
+        neighbor_count: int = 5,
+        batch_size: int = 32,
+        weigh_memories: bool = True,
+        background: Literal[False] = False,
+    ) -> RegressionMetrics:
+        pass
+
+    def evaluate(
+        self,
+        datasource: Datasource,
+        *,
+        value_column: str = "value",
+        label_column: str | None = None,
+        score_column: str | None = None,
+        eval_datasource: Datasource | None = None,
+        subsample: int | None = None,
+        neighbor_count: int = 5,
+        batch_size: int = 32,
+        weigh_memories: bool = True,
+        background: bool = False,
+    ) -> (
+        ClassificationMetrics
+        | RegressionMetrics
+        | Job[ClassificationMetrics]
+        | Job[RegressionMetrics]
+        | Job[ClassificationMetrics | RegressionMetrics]
+    ):
+        """
+        Evaluate the finetuned embedding model
+        """
+        payload: EmbeddingEvaluationRequest = {
+            "datasource_name_or_id": datasource.id,
+            "datasource_label_column": label_column,
+            "datasource_value_column": value_column,
+            "datasource_score_column": score_column,
+            "eval_datasource_name_or_id": eval_datasource.id if eval_datasource is not None else None,
+            "subsample": subsample,
+            "neighbor_count": neighbor_count,
+            "batch_size": batch_size,
+            "weigh_memories": weigh_memories,
+        }
+        if isinstance(self, PretrainedEmbeddingModel):
+            response = orca_api.POST(
+                "/pretrained_embedding_model/{model_name}/evaluation",
+                params={"model_name": self.name},
+                json=payload,
+            )
+        elif isinstance(self, FinetunedEmbeddingModel):
+            response = orca_api.POST(
+                "/finetuned_embedding_model/{name_or_id}/evaluation",
+                params={"name_or_id": self.id},
+                json=payload,
+            )
         else:
-            raise
+            raise ValueError("Invalid embedding model")

+        def get_result(task_id: str) -> ClassificationMetrics | RegressionMetrics:
+            if isinstance(self, PretrainedEmbeddingModel):
+                res = orca_api.GET(
+                    "/pretrained_embedding_model/{model_name}/evaluation/{task_id}",
+                    params={"model_name": self.name, "task_id": task_id},
+                )["result"]
+            elif isinstance(self, FinetunedEmbeddingModel):
+                res = orca_api.GET(
+                    "/finetuned_embedding_model/{name_or_id}/evaluation/{task_id}",
+                    params={"name_or_id": self.id, "task_id": task_id},
+                )["result"]
+            else:
+                raise ValueError("Invalid embedding model")
+            assert res is not None
+            return (
+                RegressionMetrics(
+                    coverage=res.get("coverage"),
+                    mse=res.get("mse"),
+                    rmse=res.get("rmse"),
+                    mae=res.get("mae"),
+                    r2=res.get("r2"),
+                    explained_variance=res.get("explained_variance"),
+                    loss=res.get("loss"),
+                    anomaly_score_mean=res.get("anomaly_score_mean"),
+                    anomaly_score_median=res.get("anomaly_score_median"),
+                    anomaly_score_variance=res.get("anomaly_score_variance"),
+                )
+                if "mse" in res
+                else ClassificationMetrics(
+                    coverage=res.get("coverage"),
+                    f1_score=res.get("f1_score"),
+                    accuracy=res.get("accuracy"),
+                    loss=res.get("loss"),
+                    anomaly_score_mean=res.get("anomaly_score_mean"),
+                    anomaly_score_median=res.get("anomaly_score_median"),
+                    anomaly_score_variance=res.get("anomaly_score_variance"),
+                    roc_auc=res.get("roc_auc"),
+                    pr_auc=res.get("pr_auc"),
+                    pr_curve=res.get("pr_curve"),
+                    roc_curve=res.get("roc_curve"),
+                )
+            )
+
+        job = Job(response["task_id"], lambda: get_result(response["task_id"]))
+        return job if background else job.result()
+
+
+class _ModelDescriptor:
+    """
+    Descriptor for lazily loading embedding models with IDE autocomplete support.
+
+    This class implements the descriptor protocol to provide lazy loading of embedding models
+    while maintaining IDE autocomplete functionality. It delays the actual loading of models
+    until they are accessed, which improves startup performance.

-
+    The descriptor pattern works by defining how attribute access is handled. When a class
+    attribute using this descriptor is accessed, the __get__ method is called, which then
+    retrieves or initializes the actual model on first access.
+    """
+
+    def __init__(self, name: str):
+        """
+        Initialize a model descriptor.
+
+        Args:
+            name: The name of the embedding model in PretrainedEmbeddingModelName
+        """
+        self.name = name
+        self.model = None  # Model is loaded lazily on first access
+
+    def __get__(self, instance, owner_class):
+        """
+        Descriptor protocol method called when the attribute is accessed.
+
+        This method implements lazy loading - the actual model is only initialized
+        the first time it's accessed. Subsequent accesses will use the cached model.
+
+        Args:
+            instance: The instance the attribute was accessed from, or None if accessed from the class
+            owner_class: The class that owns the descriptor
+
+        Returns:
+            The initialized embedding model
+
+        Raises:
+            AttributeError: If no model with the given name exists
+        """
+        # When accessed from an instance, redirect to class access
+        if instance is not None:
+            return self.__get__(None, owner_class)
+
+        # Load the model on first access
+        if self.model is None:
+            try:
+                self.model = PretrainedEmbeddingModel._get(cast(PretrainedEmbeddingModelName, self.name))
+            except (KeyError, AttributeError):
+                raise AttributeError(f"No embedding model named {self.name}")
+
+        return self.model
+
+
+class PretrainedEmbeddingModel(EmbeddingModelBase):
     """
     A pretrained embedding model

@@ -100,28 +331,60 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
     - **`CDE_SMALL`**: Context-aware CDE small model from Hugging Face ([jxm/cde-small-v1](https://huggingface.co/jxm/cde-small-v1))
     - **`CLIP_BASE`**: Multi-modal CLIP model from Hugging Face ([sentence-transformers/clip-ViT-L-14](https://huggingface.co/sentence-transformers/clip-ViT-L-14))
     - **`GTE_BASE`**: Alibaba's GTE model from Hugging Face ([Alibaba-NLP/gte-base-en-v1.5](https://huggingface.co/Alibaba-NLP/gte-base-en-v1.5))
+    - **`DISTILBERT`**: DistilBERT embedding model from Hugging Face ([distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased))
+    - **`GTE_SMALL`**: GTE-Small embedding model from Hugging Face ([Supabase/gte-small](https://huggingface.co/Supabase/gte-small))
+    - **`E5_LARGE`**: E5-Large instruction-tuned embedding model from Hugging Face ([intfloat/multilingual-e5-large-instruct](https://huggingface.co/intfloat/multilingual-e5-large-instruct))
+    - **`GIST_LARGE`**: GIST-Large embedding model from Hugging Face ([avsolatorio/GIST-large-Embedding-v0](https://huggingface.co/avsolatorio/GIST-large-Embedding-v0))
+    - **`MXBAI_LARGE`**: Mixbreas's Large embedding model from Hugging Face ([mixedbread-ai/mxbai-embed-large-v1](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1))
+    - **`QWEN2_1_5B`**: Alibaba's Qwen2-1.5B instruction-tuned embedding model from Hugging Face ([Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct))
+    - **`BGE_BASE`**: BAAI's BGE-Base instruction-tuned embedding model from Hugging Face ([BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5))
+
+    **Instruction Support:**
+
+    Some models support instruction-following for better task-specific embeddings. You can check if a model supports instructions
+    using the `supports_instructions` attribute.

     Examples:
         >>> PretrainedEmbeddingModel.CDE_SMALL
         PretrainedEmbeddingModel({name: CDE_SMALL, embedding_dim: 768, max_seq_length: 512})

+        >>> # Using instruction with an instruction-supporting model
+        >>> model = PretrainedEmbeddingModel.E5_LARGE
+        >>> embeddings = model.embed("Hello world", prompt="Represent this sentence for retrieval:")
+
     Attributes:
         name: Name of the pretrained embedding model
         embedding_dim: Dimension of the embeddings that are generated by the model
         max_seq_length: Maximum input length (in tokens not characters) that this model can process. Inputs that are longer will be truncated during the embedding process
         uses_context: Whether the pretrained embedding model uses context
+        supports_instructions: Whether this model supports instruction-following
     """

-
+    # Define descriptors for model access with IDE autocomplete
+    CDE_SMALL = _ModelDescriptor("CDE_SMALL")
+    CLIP_BASE = _ModelDescriptor("CLIP_BASE")
+    GTE_BASE = _ModelDescriptor("GTE_BASE")
+    DISTILBERT = _ModelDescriptor("DISTILBERT")
+    GTE_SMALL = _ModelDescriptor("GTE_SMALL")
+    E5_LARGE = _ModelDescriptor("E5_LARGE")
+    GIST_LARGE = _ModelDescriptor("GIST_LARGE")
+    MXBAI_LARGE = _ModelDescriptor("MXBAI_LARGE")
+    QWEN2_1_5B = _ModelDescriptor("QWEN2_1_5B")
+    BGE_BASE = _ModelDescriptor("BGE_BASE")
+
+    name: PretrainedEmbeddingModelName

     def __init__(self, metadata: PretrainedEmbeddingModelMetadata):
         # for internal use only, do not document
-        self.
+        self.name = metadata["name"]
         super().__init__(
-            name=metadata
-            embedding_dim=metadata
-            max_seq_length=metadata
-            uses_context=metadata
+            name=metadata["name"],
+            embedding_dim=metadata["embedding_dim"],
+            max_seq_length=metadata["max_seq_length"],
+            uses_context=metadata["uses_context"],
+            supports_instructions=(
+                bool(metadata["supports_instructions"]) if "supports_instructions" in metadata else False
+            ),
         )

     def __eq__(self, other) -> bool:
@@ -138,16 +401,44 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
         Returns:
             A list of all pretrained embedding models available in the OrcaCloud
         """
-        return [cls(metadata) for metadata in
+        return [cls(metadata) for metadata in orca_api.GET("/pretrained_embedding_model")]

     _instances: dict[str, PretrainedEmbeddingModel] = {}

     @classmethod
-    def _get(cls, name: PretrainedEmbeddingModelName
+    def _get(cls, name: PretrainedEmbeddingModelName) -> PretrainedEmbeddingModel:
         # for internal use only, do not document - we want people to use dot notation to get the model
-
-
-
+        cache_key = str(name)
+        if cache_key not in cls._instances:
+            metadata = orca_api.GET(
+                "/pretrained_embedding_model/{model_name}",
+                params={"model_name": name},
+            )
+            cls._instances[cache_key] = cls(metadata)
+        return cls._instances[cache_key]
+
+    @classmethod
+    def open(cls, name: PretrainedEmbeddingModelName) -> PretrainedEmbeddingModel:
+        """
+        Open an embedding model by name.
+
+        This is an alternative method to access models for environments
+        where IDE autocomplete for model names is not available.
+
+        Params:
+            name: Name of the model to open (e.g., "GTE_BASE", "CLIP_BASE")
+
+        Returns:
+            The embedding model instance
+
+        Examples:
+            >>> model = PretrainedEmbeddingModel.open("GTE_BASE")
+        """
+        try:
+            # Always use the _get method which handles caching properly
+            return cls._get(name)
+        except (KeyError, AttributeError):
+            raise ValueError(f"Unknown model name: {name}")

     @classmethod
     def exists(cls, name: str) -> bool:
@@ -160,8 +451,25 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
         Returns:
             True if the pretrained embedding model exists, False otherwise
         """
-        return name in PretrainedEmbeddingModelName
+        return name in get_args(PretrainedEmbeddingModelName)
+
+    @overload
+    def finetune(
+        self,
+        name: str,
+        train_datasource: Datasource | LabeledMemoryset,
+        *,
+        eval_datasource: Datasource | None = None,
+        label_column: str = "label",
+        value_column: str = "value",
+        training_method: EmbeddingFinetuningMethod = "classification",
+        training_args: dict | None = None,
+        if_exists: CreateMode = "error",
+        background: Literal[True],
+    ) -> Job[FinetunedEmbeddingModel]:
+        pass

+    @overload
     def finetune(
         self,
         name: str,
@@ -170,10 +478,26 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
         eval_datasource: Datasource | None = None,
         label_column: str = "label",
         value_column: str = "value",
-        training_method: EmbeddingFinetuningMethod
+        training_method: EmbeddingFinetuningMethod = "classification",
         training_args: dict | None = None,
         if_exists: CreateMode = "error",
+        background: Literal[False] = False,
     ) -> FinetunedEmbeddingModel:
+        pass
+
+    def finetune(
+        self,
+        name: str,
+        train_datasource: Datasource | LabeledMemoryset,
+        *,
+        eval_datasource: Datasource | None = None,
+        label_column: str = "label",
+        value_column: str = "value",
+        training_method: EmbeddingFinetuningMethod = "classification",
+        training_args: dict | None = None,
+        if_exists: CreateMode = "error",
+        background: bool = False,
+    ) -> FinetunedEmbeddingModel | Job[FinetunedEmbeddingModel]:
         """
         Finetune an embedding model

@@ -184,10 +508,11 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
             label_column: Column name of the label
             value_column: Column name of the value
             training_method: Training method to use
-            training_args: Optional override for Hugging Face [`TrainingArguments`]
+            training_args: Optional override for Hugging Face [`TrainingArguments`][transformers.TrainingArguments].
                 If not provided, reasonable training arguments will be used for the specified training method
             if_exists: What to do if a finetuned embedding model with the same name already exists, defaults to
                 `"error"`. Other option is `"open"` to open the existing finetuned embedding model.
+            background: Whether to run the operation in the background and return a job handle

         Returns:
             The finetuned embedding model
@@ -208,34 +533,40 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
         elif exists and if_exists == "open":
             existing = FinetunedEmbeddingModel.open(name)

-            if existing.base_model_name != self.
+            if existing.base_model_name != self.name:
                 raise ValueError(f"Finetuned embedding model '{name}' already exists, but with different base model")

             return existing

         from .memoryset import LabeledMemoryset

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        payload: FinetuneEmbeddingModelRequest = {
+            "name": name,
+            "base_model": self.name,
+            "label_column": label_column,
+            "value_column": value_column,
+            "training_method": training_method,
+            "training_args": training_args or {},
+        }
+        if isinstance(train_datasource, Datasource):
+            payload["train_datasource_name_or_id"] = train_datasource.id
+        elif isinstance(train_datasource, LabeledMemoryset):
+            payload["train_memoryset_name_or_id"] = train_datasource.id
+        if eval_datasource is not None:
+            payload["eval_datasource_name_or_id"] = eval_datasource.id
+
+        res = orca_api.POST(
+            "/finetuned_embedding_model",
+            json=payload,
         )
-
-
+        job = Job(
+            res["finetuning_task_id"],
+            lambda: FinetunedEmbeddingModel.open(res["id"]),
+        )
+        return job if background else job.result()


-class FinetunedEmbeddingModel(
+class FinetunedEmbeddingModel(EmbeddingModelBase):
     """
     A finetuned embedding model in the OrcaCloud

@@ -250,22 +581,27 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
     """

     id: str
+    name: str
     created_at: datetime
     updated_at: datetime
-
+    base_model_name: PretrainedEmbeddingModelName
+    _status: Status

     def __init__(self, metadata: FinetunedEmbeddingModelMetadata):
         # for internal use only, do not document
-        self.id = metadata
-        self.
-        self.
-        self.
-        self.
+        self.id = metadata["id"]
+        self.name = metadata["name"]
+        self.created_at = datetime.fromisoformat(metadata["created_at"])
+        self.updated_at = datetime.fromisoformat(metadata["updated_at"])
+        self.base_model_name = metadata["base_model"]
+        self._status = Status(metadata["finetuning_status"])
+
         super().__init__(
-            name=metadata
-            embedding_dim=metadata
-            max_seq_length=metadata
-            uses_context=metadata
+            name=metadata["name"],
+            embedding_dim=metadata["embedding_dim"],
+            max_seq_length=metadata["max_seq_length"],
+            uses_context=metadata["uses_context"],
+            supports_instructions=self.base_model.supports_instructions,
         )

     def __eq__(self, other) -> bool:
@@ -277,7 +613,7 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
             f" name: {self.name},\n"
             f" embedding_dim: {self.embedding_dim},\n"
             f" max_seq_length: {self.max_seq_length},\n"
-            f" base_model: PretrainedEmbeddingModel.{self.base_model_name
+            f" base_model: PretrainedEmbeddingModel.{self.base_model_name}\n"
             "})"
         )

@@ -294,7 +630,7 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
         Returns:
             A list of all finetuned embedding model handles in the OrcaCloud
         """
-        return [cls(metadata) for metadata in
+        return [cls(metadata) for metadata in orca_api.GET("/finetuned_embedding_model")]

     @classmethod
     def open(cls, name: str) -> FinetunedEmbeddingModel:
@@ -310,7 +646,11 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
         Raises:
             LookupError: If the finetuned embedding model does not exist
         """
-
+        metadata = orca_api.GET(
+            "/finetuned_embedding_model/{name_or_id}",
+            params={"name_or_id": name},
+        )
+        return cls(metadata)

     @classmethod
     def exists(cls, name_or_id: str) -> bool:
@@ -341,7 +681,10 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
             LookupError: If the finetuned embedding model does not exist and `if_not_exists` is `"error"`
         """
         try:
-
-
+            orca_api.DELETE(
+                "/finetuned_embedding_model/{name_or_id}",
+                params={"name_or_id": name_or_id},
+            )
+        except (LookupError, RuntimeError):
             if if_not_exists == "error":
                 raise
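Taken together, this diff adds instruction prompts to `embed`, an `evaluate` method with classification/regression overloads, an `open()` accessor backed by lazily loaded model descriptors, and a `background=True` mode that returns a `Job` handle instead of blocking. A hedged usage sketch based only on the signatures and docstring examples above; the top-level export and the datasource setup are assumptions, not verified against 0.1.2:

    # Sketch only: names outside this diff (package exports, Datasource setup) are assumed.
    from orca_sdk import PretrainedEmbeddingModel  # assumed re-export via orca_sdk/__init__.py

    model = PretrainedEmbeddingModel.open("GTE_BASE")  # new open() accessor
    e5 = PretrainedEmbeddingModel.E5_LARGE             # lazy _ModelDescriptor access
    vector = e5.embed("Hello world", prompt="Represent this sentence for retrieval:")

    datasource = ...  # a Datasource handle; creating one is not covered by this diff

    # Blocking call returns ClassificationMetrics when a label_column is given ...
    metrics = model.evaluate(datasource, label_column="label")
    # ... while background=True returns a Job handle that can be resolved later.
    job = model.evaluate(datasource, label_column="label", background=True)
    metrics = job.result()

    # finetune() gained the same background flag and returns a FinetunedEmbeddingModel.
    finetuned = model.finetune("my-finetuned-gte", datasource, label_column="label")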
|