orca-sdk 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +10 -4
- orca_sdk/_shared/__init__.py +10 -0
- orca_sdk/_shared/metrics.py +393 -0
- orca_sdk/_shared/metrics_test.py +273 -0
- orca_sdk/_utils/analysis_ui.py +12 -10
- orca_sdk/_utils/analysis_ui_style.css +0 -3
- orca_sdk/_utils/auth.py +31 -29
- orca_sdk/_utils/data_parsing.py +28 -2
- orca_sdk/_utils/data_parsing_test.py +15 -15
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/_utils/prediction_result_ui.py +67 -21
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/_utils/value_parser.py +45 -0
- orca_sdk/_utils/value_parser_test.py +39 -0
- orca_sdk/async_client.py +3795 -0
- orca_sdk/classification_model.py +601 -129
- orca_sdk/classification_model_test.py +415 -117
- orca_sdk/client.py +3787 -0
- orca_sdk/conftest.py +184 -38
- orca_sdk/credentials.py +162 -20
- orca_sdk/credentials_test.py +100 -16
- orca_sdk/datasource.py +268 -68
- orca_sdk/datasource_test.py +266 -18
- orca_sdk/embedding_model.py +434 -82
- orca_sdk/embedding_model_test.py +66 -33
- orca_sdk/job.py +343 -0
- orca_sdk/job_test.py +108 -0
- orca_sdk/memoryset.py +1690 -324
- orca_sdk/memoryset_test.py +456 -119
- orca_sdk/regression_model.py +694 -0
- orca_sdk/regression_model_test.py +378 -0
- orca_sdk/telemetry.py +460 -143
- orca_sdk/telemetry_test.py +43 -24
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/METADATA +34 -16
- orca_sdk-0.1.3.dist-info/RECORD +41 -0
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/WHEEL +1 -1
- orca_sdk/_generated_api_client/__init__.py +0 -3
- orca_sdk/_generated_api_client/api/__init__.py +0 -193
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
- orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
- orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
- orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
- orca_sdk/_generated_api_client/client.py +0 -216
- orca_sdk/_generated_api_client/errors.py +0 -38
- orca_sdk/_generated_api_client/models/__init__.py +0 -159
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
- orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
- orca_sdk/_generated_api_client/models/base_model.py +0 -55
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
- orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
- orca_sdk/_generated_api_client/models/column_info.py +0 -114
- orca_sdk/_generated_api_client/models/column_type.py +0 -14
- orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
- orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
- orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
- orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
- orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
- orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/embed_request.py +0 -127
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
- orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
- orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
- orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
- orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
- orca_sdk/_generated_api_client/models/filter_item.py +0 -231
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
- orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
- orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
- orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
- orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
- orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
- orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
- orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
- orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
- orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
- orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
- orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
- orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
- orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/task.py +0 -198
- orca_sdk/_generated_api_client/models/task_status.py +0 -14
- orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
- orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
- orca_sdk/_generated_api_client/py.typed +0 -1
- orca_sdk/_generated_api_client/types.py +0 -56
- orca_sdk/_utils/task.py +0 -73
- orca_sdk-0.1.1.dist-info/RECORD +0 -175
orca_sdk/embedding_model.py
CHANGED
|
@@ -1,93 +1,327 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from abc import abstractmethod
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
4
|
from datetime import datetime
|
|
5
|
-
from typing import TYPE_CHECKING, Sequence, cast, overload
|
|
6
|
-
|
|
7
|
-
from .
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
embed_with_pretrained_model_gpu,
|
|
12
|
-
get_finetuned_embedding_model,
|
|
13
|
-
get_pretrained_embedding_model,
|
|
14
|
-
list_finetuned_embedding_models,
|
|
15
|
-
list_pretrained_embedding_models,
|
|
16
|
-
)
|
|
17
|
-
from ._generated_api_client.models import (
|
|
5
|
+
from typing import TYPE_CHECKING, Literal, Sequence, cast, get_args, overload
|
|
6
|
+
|
|
7
|
+
from ._shared.metrics import ClassificationMetrics, RegressionMetrics
|
|
8
|
+
from ._utils.common import UNSET, CreateMode, DropMode
|
|
9
|
+
from .client import (
|
|
10
|
+
EmbeddingEvaluationRequest,
|
|
18
11
|
EmbeddingFinetuningMethod,
|
|
19
12
|
EmbedRequest,
|
|
20
13
|
FinetunedEmbeddingModelMetadata,
|
|
21
14
|
FinetuneEmbeddingModelRequest,
|
|
22
|
-
|
|
15
|
+
OrcaClient,
|
|
23
16
|
PretrainedEmbeddingModelMetadata,
|
|
24
17
|
PretrainedEmbeddingModelName,
|
|
25
18
|
)
|
|
26
|
-
from ._utils.common import CreateMode, DropMode
|
|
27
|
-
from ._utils.task import TaskStatus, wait_for_task
|
|
28
19
|
from .datasource import Datasource
|
|
20
|
+
from .job import Job, Status
|
|
29
21
|
|
|
30
22
|
if TYPE_CHECKING:
|
|
31
23
|
from .memoryset import LabeledMemoryset
|
|
32
24
|
|
|
33
25
|
|
|
34
|
-
class
|
|
35
|
-
name: str
|
|
26
|
+
class EmbeddingModelBase(ABC):
|
|
36
27
|
embedding_dim: int
|
|
37
28
|
max_seq_length: int
|
|
38
29
|
uses_context: bool
|
|
30
|
+
supports_instructions: bool
|
|
39
31
|
|
|
40
|
-
def __init__(
|
|
41
|
-
self
|
|
32
|
+
def __init__(
|
|
33
|
+
self, *, name: str, embedding_dim: int, max_seq_length: int, uses_context: bool, supports_instructions: bool
|
|
34
|
+
):
|
|
42
35
|
self.embedding_dim = embedding_dim
|
|
43
36
|
self.max_seq_length = max_seq_length
|
|
44
37
|
self.uses_context = uses_context
|
|
38
|
+
self.supports_instructions = supports_instructions
|
|
45
39
|
|
|
46
40
|
@classmethod
|
|
47
41
|
@abstractmethod
|
|
48
|
-
def all(cls) -> Sequence[
|
|
42
|
+
def all(cls) -> Sequence[EmbeddingModelBase]:
|
|
49
43
|
pass
|
|
50
44
|
|
|
45
|
+
def _get_instruction_error_message(self) -> str:
|
|
46
|
+
"""Get error message for instruction not supported"""
|
|
47
|
+
if isinstance(self, FinetunedEmbeddingModel):
|
|
48
|
+
return f"Model {self.name} does not support instructions. Instruction-following is only supported by models based on instruction-supporting models."
|
|
49
|
+
elif isinstance(self, PretrainedEmbeddingModel):
|
|
50
|
+
return f"Model {self.name} does not support instructions. Instruction-following is only supported by instruction-supporting models."
|
|
51
|
+
else:
|
|
52
|
+
raise ValueError("Invalid embedding model")
|
|
53
|
+
|
|
51
54
|
@overload
|
|
52
|
-
def embed(self, value: str, max_seq_length: int | None = None) -> list[float]:
|
|
55
|
+
def embed(self, value: str, max_seq_length: int | None = None, prompt: str | None = None) -> list[float]:
|
|
53
56
|
pass
|
|
54
57
|
|
|
55
58
|
@overload
|
|
56
|
-
def embed(
|
|
59
|
+
def embed(
|
|
60
|
+
self, value: list[str], max_seq_length: int | None = None, prompt: str | None = None
|
|
61
|
+
) -> list[list[float]]:
|
|
57
62
|
pass
|
|
58
63
|
|
|
59
|
-
def embed(
|
|
64
|
+
def embed(
|
|
65
|
+
self, value: str | list[str], max_seq_length: int | None = None, prompt: str | None = None
|
|
66
|
+
) -> list[float] | list[list[float]]:
|
|
60
67
|
"""
|
|
61
68
|
Generate embeddings for a value or list of values
|
|
62
69
|
|
|
63
70
|
Params:
|
|
64
71
|
value: The value or list of values to embed
|
|
65
72
|
max_seq_length: The maximum sequence length to truncate the input to
|
|
73
|
+
prompt: Optional prompt for prompt-following embedding models.
|
|
66
74
|
|
|
67
75
|
Returns:
|
|
68
76
|
A matrix of floats representing the embedding for each value if the input is a list of
|
|
69
77
|
values, or a list of floats representing the embedding for the single value if the
|
|
70
78
|
input is a single value
|
|
71
79
|
"""
|
|
72
|
-
|
|
80
|
+
payload: EmbedRequest = {
|
|
81
|
+
"values": value if isinstance(value, list) else [value],
|
|
82
|
+
"max_seq_length": max_seq_length,
|
|
83
|
+
"prompt": prompt,
|
|
84
|
+
}
|
|
85
|
+
client = OrcaClient._resolve_client()
|
|
73
86
|
if isinstance(self, PretrainedEmbeddingModel):
|
|
74
|
-
embeddings =
|
|
87
|
+
embeddings = client.POST(
|
|
88
|
+
"/gpu/pretrained_embedding_model/{model_name}/embedding",
|
|
89
|
+
params={"model_name": cast(PretrainedEmbeddingModelName, self.name)},
|
|
90
|
+
json=payload,
|
|
91
|
+
timeout=30, # may be slow in case of cold start
|
|
92
|
+
)
|
|
75
93
|
elif isinstance(self, FinetunedEmbeddingModel):
|
|
76
|
-
embeddings =
|
|
94
|
+
embeddings = client.POST(
|
|
95
|
+
"/gpu/finetuned_embedding_model/{name_or_id}/embedding",
|
|
96
|
+
params={"name_or_id": self.id},
|
|
97
|
+
json=payload,
|
|
98
|
+
timeout=30, # may be slow in case of cold start
|
|
99
|
+
)
|
|
77
100
|
else:
|
|
78
101
|
raise ValueError("Invalid embedding model")
|
|
79
102
|
return embeddings if isinstance(value, list) else embeddings[0]
|
|
80
103
|
|
|
104
|
+
@overload
|
|
105
|
+
def evaluate(
|
|
106
|
+
self,
|
|
107
|
+
datasource: Datasource,
|
|
108
|
+
*,
|
|
109
|
+
value_column: str = "value",
|
|
110
|
+
label_column: str,
|
|
111
|
+
score_column: None = None,
|
|
112
|
+
eval_datasource: Datasource | None = None,
|
|
113
|
+
subsample: int | None = None,
|
|
114
|
+
neighbor_count: int = 5,
|
|
115
|
+
batch_size: int = 32,
|
|
116
|
+
weigh_memories: bool = True,
|
|
117
|
+
background: Literal[True],
|
|
118
|
+
) -> Job[ClassificationMetrics]:
|
|
119
|
+
pass
|
|
81
120
|
|
|
82
|
-
|
|
83
|
-
def
|
|
84
|
-
|
|
85
|
-
|
|
121
|
+
@overload
|
|
122
|
+
def evaluate(
|
|
123
|
+
self,
|
|
124
|
+
datasource: Datasource,
|
|
125
|
+
*,
|
|
126
|
+
value_column: str = "value",
|
|
127
|
+
label_column: str,
|
|
128
|
+
score_column: None = None,
|
|
129
|
+
eval_datasource: Datasource | None = None,
|
|
130
|
+
subsample: int | None = None,
|
|
131
|
+
neighbor_count: int = 5,
|
|
132
|
+
batch_size: int = 32,
|
|
133
|
+
weigh_memories: bool = True,
|
|
134
|
+
background: Literal[False] = False,
|
|
135
|
+
) -> ClassificationMetrics:
|
|
136
|
+
pass
|
|
137
|
+
|
|
138
|
+
@overload
|
|
139
|
+
def evaluate(
|
|
140
|
+
self,
|
|
141
|
+
datasource: Datasource,
|
|
142
|
+
*,
|
|
143
|
+
value_column: str = "value",
|
|
144
|
+
label_column: None = None,
|
|
145
|
+
score_column: str,
|
|
146
|
+
eval_datasource: Datasource | None = None,
|
|
147
|
+
subsample: int | None = None,
|
|
148
|
+
neighbor_count: int = 5,
|
|
149
|
+
batch_size: int = 32,
|
|
150
|
+
weigh_memories: bool = True,
|
|
151
|
+
background: Literal[True],
|
|
152
|
+
) -> Job[RegressionMetrics]:
|
|
153
|
+
pass
|
|
154
|
+
|
|
155
|
+
@overload
|
|
156
|
+
def evaluate(
|
|
157
|
+
self,
|
|
158
|
+
datasource: Datasource,
|
|
159
|
+
*,
|
|
160
|
+
value_column: str = "value",
|
|
161
|
+
label_column: None = None,
|
|
162
|
+
score_column: str,
|
|
163
|
+
eval_datasource: Datasource | None = None,
|
|
164
|
+
subsample: int | None = None,
|
|
165
|
+
neighbor_count: int = 5,
|
|
166
|
+
batch_size: int = 32,
|
|
167
|
+
weigh_memories: bool = True,
|
|
168
|
+
background: Literal[False] = False,
|
|
169
|
+
) -> RegressionMetrics:
|
|
170
|
+
pass
|
|
171
|
+
|
|
172
|
+
def evaluate(
|
|
173
|
+
self,
|
|
174
|
+
datasource: Datasource,
|
|
175
|
+
*,
|
|
176
|
+
value_column: str = "value",
|
|
177
|
+
label_column: str | None = None,
|
|
178
|
+
score_column: str | None = None,
|
|
179
|
+
eval_datasource: Datasource | None = None,
|
|
180
|
+
subsample: int | None = None,
|
|
181
|
+
neighbor_count: int = 5,
|
|
182
|
+
batch_size: int = 32,
|
|
183
|
+
weigh_memories: bool = True,
|
|
184
|
+
background: bool = False,
|
|
185
|
+
) -> (
|
|
186
|
+
ClassificationMetrics
|
|
187
|
+
| RegressionMetrics
|
|
188
|
+
| Job[ClassificationMetrics]
|
|
189
|
+
| Job[RegressionMetrics]
|
|
190
|
+
| Job[ClassificationMetrics | RegressionMetrics]
|
|
191
|
+
):
|
|
192
|
+
"""
|
|
193
|
+
Evaluate the embedding model (pretrained or finetuned) on a datasource
|
|
194
|
+
"""
|
|
195
|
+
payload: EmbeddingEvaluationRequest = {
|
|
196
|
+
"datasource_name_or_id": datasource.id,
|
|
197
|
+
"datasource_label_column": label_column,
|
|
198
|
+
"datasource_value_column": value_column,
|
|
199
|
+
"datasource_score_column": score_column,
|
|
200
|
+
"eval_datasource_name_or_id": eval_datasource.id if eval_datasource is not None else None,
|
|
201
|
+
"subsample": subsample,
|
|
202
|
+
"neighbor_count": neighbor_count,
|
|
203
|
+
"batch_size": batch_size,
|
|
204
|
+
"weigh_memories": weigh_memories,
|
|
205
|
+
}
|
|
206
|
+
client = OrcaClient._resolve_client()
|
|
207
|
+
if isinstance(self, PretrainedEmbeddingModel):
|
|
208
|
+
response = client.POST(
|
|
209
|
+
"/pretrained_embedding_model/{model_name}/evaluation",
|
|
210
|
+
params={"model_name": self.name},
|
|
211
|
+
json=payload,
|
|
212
|
+
)
|
|
213
|
+
elif isinstance(self, FinetunedEmbeddingModel):
|
|
214
|
+
response = client.POST(
|
|
215
|
+
"/finetuned_embedding_model/{name_or_id}/evaluation",
|
|
216
|
+
params={"name_or_id": self.id},
|
|
217
|
+
json=payload,
|
|
218
|
+
)
|
|
86
219
|
else:
|
|
87
|
-
raise
|
|
220
|
+
raise ValueError("Invalid embedding model")
|
|
88
221
|
|
|
222
|
+
def get_result(task_id: str) -> ClassificationMetrics | RegressionMetrics:
|
|
223
|
+
client = OrcaClient._resolve_client()
|
|
224
|
+
if isinstance(self, PretrainedEmbeddingModel):
|
|
225
|
+
res = client.GET(
|
|
226
|
+
"/pretrained_embedding_model/{model_name}/evaluation/{task_id}",
|
|
227
|
+
params={"model_name": self.name, "task_id": task_id},
|
|
228
|
+
)["result"]
|
|
229
|
+
elif isinstance(self, FinetunedEmbeddingModel):
|
|
230
|
+
res = client.GET(
|
|
231
|
+
"/finetuned_embedding_model/{name_or_id}/evaluation/{task_id}",
|
|
232
|
+
params={"name_or_id": self.id, "task_id": task_id},
|
|
233
|
+
)["result"]
|
|
234
|
+
else:
|
|
235
|
+
raise ValueError("Invalid embedding model")
|
|
236
|
+
assert res is not None
|
|
237
|
+
return (
|
|
238
|
+
RegressionMetrics(
|
|
239
|
+
coverage=res.get("coverage"),
|
|
240
|
+
mse=res.get("mse"),
|
|
241
|
+
rmse=res.get("rmse"),
|
|
242
|
+
mae=res.get("mae"),
|
|
243
|
+
r2=res.get("r2"),
|
|
244
|
+
explained_variance=res.get("explained_variance"),
|
|
245
|
+
loss=res.get("loss"),
|
|
246
|
+
anomaly_score_mean=res.get("anomaly_score_mean"),
|
|
247
|
+
anomaly_score_median=res.get("anomaly_score_median"),
|
|
248
|
+
anomaly_score_variance=res.get("anomaly_score_variance"),
|
|
249
|
+
)
|
|
250
|
+
if "mse" in res
|
|
251
|
+
else ClassificationMetrics(
|
|
252
|
+
coverage=res.get("coverage"),
|
|
253
|
+
f1_score=res.get("f1_score"),
|
|
254
|
+
accuracy=res.get("accuracy"),
|
|
255
|
+
loss=res.get("loss"),
|
|
256
|
+
anomaly_score_mean=res.get("anomaly_score_mean"),
|
|
257
|
+
anomaly_score_median=res.get("anomaly_score_median"),
|
|
258
|
+
anomaly_score_variance=res.get("anomaly_score_variance"),
|
|
259
|
+
roc_auc=res.get("roc_auc"),
|
|
260
|
+
pr_auc=res.get("pr_auc"),
|
|
261
|
+
pr_curve=res.get("pr_curve"),
|
|
262
|
+
roc_curve=res.get("roc_curve"),
|
|
263
|
+
)
|
|
264
|
+
)
|
|
265
|
+
|
|
266
|
+
job = Job(response["task_id"], lambda: get_result(response["task_id"]))
|
|
267
|
+
return job if background else job.result()
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
class _ModelDescriptor:
|
|
271
|
+
"""
|
|
272
|
+
Descriptor for lazily loading embedding models with IDE autocomplete support.
|
|
273
|
+
|
|
274
|
+
This class implements the descriptor protocol to provide lazy loading of embedding models
|
|
275
|
+
while maintaining IDE autocomplete functionality. It delays the actual loading of models
|
|
276
|
+
until they are accessed, which improves startup performance.
|
|
89
277
|
|
|
90
|
-
|
|
278
|
+
The descriptor pattern works by defining how attribute access is handled. When a class
|
|
279
|
+
attribute using this descriptor is accessed, the __get__ method is called, which then
|
|
280
|
+
retrieves or initializes the actual model on first access.
|
|
281
|
+
"""
|
|
282
|
+
|
|
283
|
+
def __init__(self, name: str):
|
|
284
|
+
"""
|
|
285
|
+
Initialize a model descriptor.
|
|
286
|
+
|
|
287
|
+
Args:
|
|
288
|
+
name: The name of the embedding model in PretrainedEmbeddingModelName
|
|
289
|
+
"""
|
|
290
|
+
self.name = name
|
|
291
|
+
self.model = None # Model is loaded lazily on first access
|
|
292
|
+
|
|
293
|
+
def __get__(self, instance, owner_class):
|
|
294
|
+
"""
|
|
295
|
+
Descriptor protocol method called when the attribute is accessed.
|
|
296
|
+
|
|
297
|
+
This method implements lazy loading - the actual model is only initialized
|
|
298
|
+
the first time it's accessed. Subsequent accesses will use the cached model.
|
|
299
|
+
|
|
300
|
+
Args:
|
|
301
|
+
instance: The instance the attribute was accessed from, or None if accessed from the class
|
|
302
|
+
owner_class: The class that owns the descriptor
|
|
303
|
+
|
|
304
|
+
Returns:
|
|
305
|
+
The initialized embedding model
|
|
306
|
+
|
|
307
|
+
Raises:
|
|
308
|
+
AttributeError: If no model with the given name exists
|
|
309
|
+
"""
|
|
310
|
+
# When accessed from an instance, redirect to class access
|
|
311
|
+
if instance is not None:
|
|
312
|
+
return self.__get__(None, owner_class)
|
|
313
|
+
|
|
314
|
+
# Load the model on first access
|
|
315
|
+
if self.model is None:
|
|
316
|
+
try:
|
|
317
|
+
self.model = PretrainedEmbeddingModel._get(cast(PretrainedEmbeddingModelName, self.name))
|
|
318
|
+
except (KeyError, AttributeError):
|
|
319
|
+
raise AttributeError(f"No embedding model named {self.name}")
|
|
320
|
+
|
|
321
|
+
return self.model
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
class PretrainedEmbeddingModel(EmbeddingModelBase):
|
|
91
325
|
"""
|
|
92
326
|
A pretrained embedding model
|
|
93
327
|
|
|
@@ -100,28 +334,60 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
|
|
|
100
334
|
- **`CDE_SMALL`**: Context-aware CDE small model from Hugging Face ([jxm/cde-small-v1](https://huggingface.co/jxm/cde-small-v1))
|
|
101
335
|
- **`CLIP_BASE`**: Multi-modal CLIP model from Hugging Face ([sentence-transformers/clip-ViT-L-14](https://huggingface.co/sentence-transformers/clip-ViT-L-14))
|
|
102
336
|
- **`GTE_BASE`**: Alibaba's GTE model from Hugging Face ([Alibaba-NLP/gte-base-en-v1.5](https://huggingface.co/Alibaba-NLP/gte-base-en-v1.5))
|
|
337
|
+
- **`DISTILBERT`**: DistilBERT embedding model from Hugging Face ([distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased))
|
|
338
|
+
- **`GTE_SMALL`**: GTE-Small embedding model from Hugging Face ([Supabase/gte-small](https://huggingface.co/Supabase/gte-small))
|
|
339
|
+
- **`E5_LARGE`**: E5-Large instruction-tuned embedding model from Hugging Face ([intfloat/multilingual-e5-large-instruct](https://huggingface.co/intfloat/multilingual-e5-large-instruct))
|
|
340
|
+
- **`GIST_LARGE`**: GIST-Large embedding model from Hugging Face ([avsolatorio/GIST-large-Embedding-v0](https://huggingface.co/avsolatorio/GIST-large-Embedding-v0))
|
|
341
|
+
- **`MXBAI_LARGE`**: Mixedbread's Large embedding model from Hugging Face ([mixedbread-ai/mxbai-embed-large-v1](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1))
|
|
342
|
+
- **`QWEN2_1_5B`**: Alibaba's Qwen2-1.5B instruction-tuned embedding model from Hugging Face ([Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct))
|
|
343
|
+
- **`BGE_BASE`**: BAAI's BGE-Base instruction-tuned embedding model from Hugging Face ([BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5))
|
|
344
|
+
|
|
345
|
+
**Instruction Support:**
|
|
346
|
+
|
|
347
|
+
Some models support instruction-following for better task-specific embeddings. You can check if a model supports instructions
|
|
348
|
+
using the `supports_instructions` attribute.
|
|
103
349
|
|
|
104
350
|
Examples:
|
|
105
351
|
>>> PretrainedEmbeddingModel.CDE_SMALL
|
|
106
352
|
PretrainedEmbeddingModel({name: CDE_SMALL, embedding_dim: 768, max_seq_length: 512})
|
|
107
353
|
|
|
354
|
+
>>> # Using instruction with an instruction-supporting model
|
|
355
|
+
>>> model = PretrainedEmbeddingModel.E5_LARGE
|
|
356
|
+
>>> embeddings = model.embed("Hello world", prompt="Represent this sentence for retrieval:")
|
|
357
|
+
|
|
108
358
|
Attributes:
|
|
109
359
|
name: Name of the pretrained embedding model
|
|
110
360
|
embedding_dim: Dimension of the embeddings that are generated by the model
|
|
111
361
|
max_seq_length: Maximum input length (in tokens not characters) that this model can process. Inputs that are longer will be truncated during the embedding process
|
|
112
362
|
uses_context: Whether the pretrained embedding model uses context
|
|
363
|
+
supports_instructions: Whether this model supports instruction-following
|
|
113
364
|
"""
|
|
114
365
|
|
|
115
|
-
|
|
366
|
+
# Define descriptors for model access with IDE autocomplete
|
|
367
|
+
CDE_SMALL = _ModelDescriptor("CDE_SMALL")
|
|
368
|
+
CLIP_BASE = _ModelDescriptor("CLIP_BASE")
|
|
369
|
+
GTE_BASE = _ModelDescriptor("GTE_BASE")
|
|
370
|
+
DISTILBERT = _ModelDescriptor("DISTILBERT")
|
|
371
|
+
GTE_SMALL = _ModelDescriptor("GTE_SMALL")
|
|
372
|
+
E5_LARGE = _ModelDescriptor("E5_LARGE")
|
|
373
|
+
GIST_LARGE = _ModelDescriptor("GIST_LARGE")
|
|
374
|
+
MXBAI_LARGE = _ModelDescriptor("MXBAI_LARGE")
|
|
375
|
+
QWEN2_1_5B = _ModelDescriptor("QWEN2_1_5B")
|
|
376
|
+
BGE_BASE = _ModelDescriptor("BGE_BASE")
|
|
377
|
+
|
|
378
|
+
name: PretrainedEmbeddingModelName
|
|
116
379
|
|
|
117
380
|
def __init__(self, metadata: PretrainedEmbeddingModelMetadata):
|
|
118
381
|
# for internal use only, do not document
|
|
119
|
-
self.
|
|
382
|
+
self.name = metadata["name"]
|
|
120
383
|
super().__init__(
|
|
121
|
-
name=metadata
|
|
122
|
-
embedding_dim=metadata
|
|
123
|
-
max_seq_length=metadata
|
|
124
|
-
uses_context=metadata
|
|
384
|
+
name=metadata["name"],
|
|
385
|
+
embedding_dim=metadata["embedding_dim"],
|
|
386
|
+
max_seq_length=metadata["max_seq_length"],
|
|
387
|
+
uses_context=metadata["uses_context"],
|
|
388
|
+
supports_instructions=(
|
|
389
|
+
bool(metadata["supports_instructions"]) if "supports_instructions" in metadata else False
|
|
390
|
+
),
|
|
125
391
|
)
|
|
126
392
|
|
|
127
393
|
def __eq__(self, other) -> bool:
|
|
@@ -138,16 +404,46 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
|
|
|
138
404
|
Returns:
|
|
139
405
|
A list of all pretrained embedding models available in the OrcaCloud
|
|
140
406
|
"""
|
|
141
|
-
|
|
407
|
+
client = OrcaClient._resolve_client()
|
|
408
|
+
return [cls(metadata) for metadata in client.GET("/pretrained_embedding_model")]
|
|
142
409
|
|
|
143
410
|
_instances: dict[str, PretrainedEmbeddingModel] = {}
|
|
144
411
|
|
|
145
412
|
@classmethod
|
|
146
|
-
def _get(cls, name: PretrainedEmbeddingModelName
|
|
413
|
+
def _get(cls, name: PretrainedEmbeddingModelName) -> PretrainedEmbeddingModel:
|
|
147
414
|
# for internal use only, do not document - we want people to use dot notation to get the model
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
415
|
+
cache_key = str(name)
|
|
416
|
+
if cache_key not in cls._instances:
|
|
417
|
+
client = OrcaClient._resolve_client()
|
|
418
|
+
metadata = client.GET(
|
|
419
|
+
"/pretrained_embedding_model/{model_name}",
|
|
420
|
+
params={"model_name": name},
|
|
421
|
+
)
|
|
422
|
+
cls._instances[cache_key] = cls(metadata)
|
|
423
|
+
return cls._instances[cache_key]
|
|
424
|
+
|
|
425
|
+
@classmethod
|
|
426
|
+
def open(cls, name: PretrainedEmbeddingModelName) -> PretrainedEmbeddingModel:
|
|
427
|
+
"""
|
|
428
|
+
Open an embedding model by name.
|
|
429
|
+
|
|
430
|
+
This is an alternative method to access models for environments
|
|
431
|
+
where IDE autocomplete for model names is not available.
|
|
432
|
+
|
|
433
|
+
Params:
|
|
434
|
+
name: Name of the model to open (e.g., "GTE_BASE", "CLIP_BASE")
|
|
435
|
+
|
|
436
|
+
Returns:
|
|
437
|
+
The embedding model instance
|
|
438
|
+
|
|
439
|
+
Examples:
|
|
440
|
+
>>> model = PretrainedEmbeddingModel.open("GTE_BASE")
|
|
441
|
+
"""
|
|
442
|
+
try:
|
|
443
|
+
# Always use the _get method which handles caching properly
|
|
444
|
+
return cls._get(name)
|
|
445
|
+
except (KeyError, AttributeError):
|
|
446
|
+
raise ValueError(f"Unknown model name: {name}")
|
|
151
447
|
|
|
152
448
|
@classmethod
|
|
153
449
|
def exists(cls, name: str) -> bool:
|
|
@@ -160,8 +456,25 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
|
|
|
160
456
|
Returns:
|
|
161
457
|
True if the pretrained embedding model exists, False otherwise
|
|
162
458
|
"""
|
|
163
|
-
return name in PretrainedEmbeddingModelName
|
|
459
|
+
return name in get_args(PretrainedEmbeddingModelName)
|
|
460
|
+
|
|
461
|
+
@overload
|
|
462
|
+
def finetune(
|
|
463
|
+
self,
|
|
464
|
+
name: str,
|
|
465
|
+
train_datasource: Datasource | LabeledMemoryset,
|
|
466
|
+
*,
|
|
467
|
+
eval_datasource: Datasource | None = None,
|
|
468
|
+
label_column: str = "label",
|
|
469
|
+
value_column: str = "value",
|
|
470
|
+
training_method: EmbeddingFinetuningMethod = "classification",
|
|
471
|
+
training_args: dict | None = None,
|
|
472
|
+
if_exists: CreateMode = "error",
|
|
473
|
+
background: Literal[True],
|
|
474
|
+
) -> Job[FinetunedEmbeddingModel]:
|
|
475
|
+
pass
|
|
164
476
|
|
|
477
|
+
@overload
|
|
165
478
|
def finetune(
|
|
166
479
|
self,
|
|
167
480
|
name: str,
|
|
@@ -170,10 +483,26 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
|
|
|
170
483
|
eval_datasource: Datasource | None = None,
|
|
171
484
|
label_column: str = "label",
|
|
172
485
|
value_column: str = "value",
|
|
173
|
-
training_method: EmbeddingFinetuningMethod
|
|
486
|
+
training_method: EmbeddingFinetuningMethod = "classification",
|
|
174
487
|
training_args: dict | None = None,
|
|
175
488
|
if_exists: CreateMode = "error",
|
|
489
|
+
background: Literal[False] = False,
|
|
176
490
|
) -> FinetunedEmbeddingModel:
|
|
491
|
+
pass
|
|
492
|
+
|
|
493
|
+
def finetune(
|
|
494
|
+
self,
|
|
495
|
+
name: str,
|
|
496
|
+
train_datasource: Datasource | LabeledMemoryset,
|
|
497
|
+
*,
|
|
498
|
+
eval_datasource: Datasource | None = None,
|
|
499
|
+
label_column: str = "label",
|
|
500
|
+
value_column: str = "value",
|
|
501
|
+
training_method: EmbeddingFinetuningMethod = "classification",
|
|
502
|
+
training_args: dict | None = None,
|
|
503
|
+
if_exists: CreateMode = "error",
|
|
504
|
+
background: bool = False,
|
|
505
|
+
) -> FinetunedEmbeddingModel | Job[FinetunedEmbeddingModel]:
|
|
177
506
|
"""
|
|
178
507
|
Finetune an embedding model
|
|
179
508
|
|
|
@@ -184,10 +513,11 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
|
|
|
184
513
|
label_column: Column name of the label
|
|
185
514
|
value_column: Column name of the value
|
|
186
515
|
training_method: Training method to use
|
|
187
|
-
training_args: Optional override for Hugging Face [`TrainingArguments`]
|
|
516
|
+
training_args: Optional override for Hugging Face [`TrainingArguments`][transformers.TrainingArguments].
|
|
188
517
|
If not provided, reasonable training arguments will be used for the specified training method
|
|
189
518
|
if_exists: What to do if a finetuned embedding model with the same name already exists, defaults to
|
|
190
519
|
`"error"`. Other option is `"open"` to open the existing finetuned embedding model.
|
|
520
|
+
background: Whether to run the operation in the background and return a job handle
|
|
191
521
|
|
|
192
522
|
Returns:
|
|
193
523
|
The finetuned embedding model, or a job handle for the finetuning job if `background` is `True`
|
|
@@ -208,34 +538,41 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
|
|
|
208
538
|
elif exists and if_exists == "open":
|
|
209
539
|
existing = FinetunedEmbeddingModel.open(name)
|
|
210
540
|
|
|
211
|
-
if existing.base_model_name != self.
|
|
541
|
+
if existing.base_model_name != self.name:
|
|
212
542
|
raise ValueError(f"Finetuned embedding model '{name}' already exists, but with different base model")
|
|
213
543
|
|
|
214
544
|
return existing
|
|
215
545
|
|
|
216
546
|
from .memoryset import LabeledMemoryset
|
|
217
547
|
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
548
|
+
payload: FinetuneEmbeddingModelRequest = {
|
|
549
|
+
"name": name,
|
|
550
|
+
"base_model": self.name,
|
|
551
|
+
"label_column": label_column,
|
|
552
|
+
"value_column": value_column,
|
|
553
|
+
"training_method": training_method,
|
|
554
|
+
"training_args": training_args or {},
|
|
555
|
+
}
|
|
556
|
+
if isinstance(train_datasource, Datasource):
|
|
557
|
+
payload["train_datasource_name_or_id"] = train_datasource.id
|
|
558
|
+
elif isinstance(train_datasource, LabeledMemoryset):
|
|
559
|
+
payload["train_memoryset_name_or_id"] = train_datasource.id
|
|
560
|
+
if eval_datasource is not None:
|
|
561
|
+
payload["eval_datasource_name_or_id"] = eval_datasource.id
|
|
562
|
+
|
|
563
|
+
client = OrcaClient._resolve_client()
|
|
564
|
+
res = client.POST(
|
|
565
|
+
"/finetuned_embedding_model",
|
|
566
|
+
json=payload,
|
|
233
567
|
)
|
|
234
|
-
|
|
235
|
-
|
|
568
|
+
job = Job(
|
|
569
|
+
res["finetuning_task_id"],
|
|
570
|
+
lambda: FinetunedEmbeddingModel.open(res["id"]),
|
|
571
|
+
)
|
|
572
|
+
return job if background else job.result()
|
|
236
573
|
|
|
237
574
|
|
|
238
|
-
class FinetunedEmbeddingModel(
|
|
575
|
+
class FinetunedEmbeddingModel(EmbeddingModelBase):
|
|
239
576
|
"""
|
|
240
577
|
A finetuned embedding model in the OrcaCloud
|
|
241
578
|
|
|
@@ -250,22 +587,27 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
|
|
|
250
587
|
"""
|
|
251
588
|
|
|
252
589
|
id: str
|
|
590
|
+
name: str
|
|
253
591
|
created_at: datetime
|
|
254
592
|
updated_at: datetime
|
|
255
|
-
|
|
593
|
+
base_model_name: PretrainedEmbeddingModelName
|
|
594
|
+
_status: Status
|
|
256
595
|
|
|
257
596
|
def __init__(self, metadata: FinetunedEmbeddingModelMetadata):
|
|
258
597
|
# for internal use only, do not document
|
|
259
|
-
self.id = metadata
|
|
260
|
-
self.
|
|
261
|
-
self.
|
|
262
|
-
self.
|
|
263
|
-
self.
|
|
598
|
+
self.id = metadata["id"]
|
|
599
|
+
self.name = metadata["name"]
|
|
600
|
+
self.created_at = datetime.fromisoformat(metadata["created_at"])
|
|
601
|
+
self.updated_at = datetime.fromisoformat(metadata["updated_at"])
|
|
602
|
+
self.base_model_name = metadata["base_model"]
|
|
603
|
+
self._status = Status(metadata["finetuning_status"])
|
|
604
|
+
|
|
264
605
|
super().__init__(
|
|
265
|
-
name=metadata
|
|
266
|
-
embedding_dim=metadata
|
|
267
|
-
max_seq_length=metadata
|
|
268
|
-
uses_context=metadata
|
|
606
|
+
name=metadata["name"],
|
|
607
|
+
embedding_dim=metadata["embedding_dim"],
|
|
608
|
+
max_seq_length=metadata["max_seq_length"],
|
|
609
|
+
uses_context=metadata["uses_context"],
|
|
610
|
+
supports_instructions=self.base_model.supports_instructions,
|
|
269
611
|
)
|
|
270
612
|
|
|
271
613
|
def __eq__(self, other) -> bool:
|
|
@@ -277,7 +619,7 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
|
|
|
277
619
|
f" name: {self.name},\n"
|
|
278
620
|
f" embedding_dim: {self.embedding_dim},\n"
|
|
279
621
|
f" max_seq_length: {self.max_seq_length},\n"
|
|
280
|
-
f" base_model: PretrainedEmbeddingModel.{self.base_model_name
|
|
622
|
+
f" base_model: PretrainedEmbeddingModel.{self.base_model_name}\n"
|
|
281
623
|
"})"
|
|
282
624
|
)
|
|
283
625
|
|
|
@@ -294,7 +636,8 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
|
|
|
294
636
|
Returns:
|
|
295
637
|
A list of all finetuned embedding model handles in the OrcaCloud
|
|
296
638
|
"""
|
|
297
|
-
|
|
639
|
+
client = OrcaClient._resolve_client()
|
|
640
|
+
return [cls(metadata) for metadata in client.GET("/finetuned_embedding_model")]
|
|
298
641
|
|
|
299
642
|
@classmethod
|
|
300
643
|
def open(cls, name: str) -> FinetunedEmbeddingModel:
|
|
@@ -310,7 +653,12 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
|
|
|
310
653
|
Raises:
|
|
311
654
|
LookupError: If the finetuned embedding model does not exist
|
|
312
655
|
"""
|
|
313
|
-
|
|
656
|
+
client = OrcaClient._resolve_client()
|
|
657
|
+
metadata = client.GET(
|
|
658
|
+
"/finetuned_embedding_model/{name_or_id}",
|
|
659
|
+
params={"name_or_id": name},
|
|
660
|
+
)
|
|
661
|
+
return cls(metadata)
|
|
314
662
|
|
|
315
663
|
@classmethod
|
|
316
664
|
def exists(cls, name_or_id: str) -> bool:
|
|
@@ -341,7 +689,11 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
|
|
|
341
689
|
LookupError: If the finetuned embedding model does not exist and `if_not_exists` is `"error"`
|
|
342
690
|
"""
|
|
343
691
|
try:
|
|
344
|
-
|
|
345
|
-
|
|
692
|
+
client = OrcaClient._resolve_client()
|
|
693
|
+
client.DELETE(
|
|
694
|
+
"/finetuned_embedding_model/{name_or_id}",
|
|
695
|
+
params={"name_or_id": name_or_id},
|
|
696
|
+
)
|
|
697
|
+
except (LookupError, RuntimeError):
|
|
346
698
|
if if_not_exists == "error":
|
|
347
699
|
raise
|