orca-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +19 -0
- orca_sdk/_generated_api_client/__init__.py +3 -0
- orca_sdk/_generated_api_client/api/__init__.py +193 -0
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +128 -0
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +170 -0
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +130 -0
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +127 -0
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +183 -0
- orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
- orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +154 -0
- orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +161 -0
- orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +190 -0
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +167 -0
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +127 -0
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +118 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +118 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +168 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +189 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +181 -0
- orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +183 -0
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +168 -0
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +181 -0
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +167 -0
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +156 -0
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +169 -0
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +188 -0
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +169 -0
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +184 -0
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +260 -0
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +127 -0
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +193 -0
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +188 -0
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +191 -0
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +187 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +188 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +157 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +154 -0
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +156 -0
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +243 -0
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +162 -0
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +157 -0
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +127 -0
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +175 -0
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +171 -0
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +181 -0
- orca_sdk/_generated_api_client/client.py +216 -0
- orca_sdk/_generated_api_client/errors.py +38 -0
- orca_sdk/_generated_api_client/models/__init__.py +159 -0
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +84 -0
- orca_sdk/_generated_api_client/models/api_key_metadata.py +118 -0
- orca_sdk/_generated_api_client/models/base_model.py +55 -0
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +176 -0
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +114 -0
- orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +150 -0
- orca_sdk/_generated_api_client/models/column_info.py +114 -0
- orca_sdk/_generated_api_client/models/column_type.py +14 -0
- orca_sdk/_generated_api_client/models/conflict_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/create_api_key_request.py +99 -0
- orca_sdk/_generated_api_client/models/create_api_key_response.py +126 -0
- orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +259 -0
- orca_sdk/_generated_api_client/models/create_rac_model_request.py +209 -0
- orca_sdk/_generated_api_client/models/datasource_metadata.py +142 -0
- orca_sdk/_generated_api_client/models/delete_memories_request.py +70 -0
- orca_sdk/_generated_api_client/models/embed_request.py +127 -0
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +9 -0
- orca_sdk/_generated_api_client/models/evaluation_request.py +180 -0
- orca_sdk/_generated_api_client/models/evaluation_response.py +140 -0
- orca_sdk/_generated_api_client/models/feedback_type.py +9 -0
- orca_sdk/_generated_api_client/models/field_validation_error.py +103 -0
- orca_sdk/_generated_api_client/models/filter_item.py +231 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +15 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +16 -0
- orca_sdk/_generated_api_client/models/filter_item_op.py +16 -0
- orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +70 -0
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +259 -0
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +66 -0
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +166 -0
- orca_sdk/_generated_api_client/models/get_memories_request.py +70 -0
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/label_class_metrics.py +108 -0
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +274 -0
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/label_prediction_result.py +101 -0
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +232 -0
- orca_sdk/_generated_api_client/models/labeled_memory.py +197 -0
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +108 -0
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +258 -0
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +277 -0
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +171 -0
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +68 -0
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +195 -0
- orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +9 -0
- orca_sdk/_generated_api_client/models/list_memories_request.py +104 -0
- orca_sdk/_generated_api_client/models/list_predictions_request.py +234 -0
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +9 -0
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +9 -0
- orca_sdk/_generated_api_client/models/lookup_request.py +81 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +83 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +9 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +180 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +66 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +9 -0
- orca_sdk/_generated_api_client/models/not_found_error_response.py +100 -0
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +20 -0
- orca_sdk/_generated_api_client/models/prediction_feedback.py +157 -0
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +115 -0
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +122 -0
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +102 -0
- orca_sdk/_generated_api_client/models/prediction_request.py +169 -0
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +97 -0
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +11 -0
- orca_sdk/_generated_api_client/models/rac_head_type.py +11 -0
- orca_sdk/_generated_api_client/models/rac_model_metadata.py +191 -0
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/task.py +198 -0
- orca_sdk/_generated_api_client/models/task_status.py +14 -0
- orca_sdk/_generated_api_client/models/task_status_info.py +133 -0
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +72 -0
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +80 -0
- orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +94 -0
- orca_sdk/_generated_api_client/models/update_prediction_request.py +93 -0
- orca_sdk/_generated_api_client/py.typed +1 -0
- orca_sdk/_generated_api_client/types.py +56 -0
- orca_sdk/_utils/__init__.py +0 -0
- orca_sdk/_utils/analysis_ui.py +194 -0
- orca_sdk/_utils/analysis_ui_style.css +54 -0
- orca_sdk/_utils/auth.py +63 -0
- orca_sdk/_utils/auth_test.py +31 -0
- orca_sdk/_utils/common.py +37 -0
- orca_sdk/_utils/data_parsing.py +99 -0
- orca_sdk/_utils/data_parsing_test.py +244 -0
- orca_sdk/_utils/prediction_result_ui.css +18 -0
- orca_sdk/_utils/prediction_result_ui.py +64 -0
- orca_sdk/_utils/task.py +73 -0
- orca_sdk/classification_model.py +499 -0
- orca_sdk/classification_model_test.py +266 -0
- orca_sdk/conftest.py +117 -0
- orca_sdk/datasource.py +333 -0
- orca_sdk/datasource_test.py +95 -0
- orca_sdk/embedding_model.py +336 -0
- orca_sdk/embedding_model_test.py +173 -0
- orca_sdk/labeled_memoryset.py +1154 -0
- orca_sdk/labeled_memoryset_test.py +271 -0
- orca_sdk/orca_credentials.py +75 -0
- orca_sdk/orca_credentials_test.py +37 -0
- orca_sdk/telemetry.py +386 -0
- orca_sdk/telemetry_test.py +100 -0
- orca_sdk-0.1.0.dist-info/METADATA +39 -0
- orca_sdk-0.1.0.dist-info/RECORD +175 -0
- orca_sdk-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,336 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import abstractmethod
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import TYPE_CHECKING, Sequence, cast, overload
|
|
6
|
+
|
|
7
|
+
from ._generated_api_client.api import (
|
|
8
|
+
create_finetuned_embedding_model,
|
|
9
|
+
delete_finetuned_embedding_model,
|
|
10
|
+
embed_with_finetuned_model_gpu,
|
|
11
|
+
embed_with_pretrained_model_gpu,
|
|
12
|
+
get_finetuned_embedding_model,
|
|
13
|
+
get_pretrained_embedding_model,
|
|
14
|
+
list_finetuned_embedding_models,
|
|
15
|
+
list_pretrained_embedding_models,
|
|
16
|
+
)
|
|
17
|
+
from ._generated_api_client.models import (
|
|
18
|
+
EmbeddingFinetuningMethod,
|
|
19
|
+
EmbedRequest,
|
|
20
|
+
FinetunedEmbeddingModelMetadata,
|
|
21
|
+
FinetuneEmbeddingModelRequest,
|
|
22
|
+
FinetuneEmbeddingModelRequestTrainingArgs,
|
|
23
|
+
PretrainedEmbeddingModelMetadata,
|
|
24
|
+
PretrainedEmbeddingModelName,
|
|
25
|
+
)
|
|
26
|
+
from ._utils.common import CreateMode, DropMode
|
|
27
|
+
from ._utils.task import TaskStatus, wait_for_task
|
|
28
|
+
from .datasource import Datasource
|
|
29
|
+
|
|
30
|
+
if TYPE_CHECKING:
|
|
31
|
+
from .labeled_memoryset import LabeledMemoryset
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class _EmbeddingModel:
    """
    Shared base for `PretrainedEmbeddingModel` and `FinetunedEmbeddingModel` handles.

    Stores the metadata both kinds of model expose and implements `embed`, which routes the request
    to the embedding endpoint that matches the concrete subclass.
    """

    name: str
    embedding_dim: int
    max_seq_length: int
    uses_context: bool

    def __init__(self, *, name: str, embedding_dim: int, max_seq_length: int, uses_context: bool):
        self.name = name
        self.embedding_dim = embedding_dim
        self.max_seq_length = max_seq_length
        self.uses_context = uses_context

    @classmethod
    @abstractmethod
    def all(cls) -> Sequence[_EmbeddingModel]:
        # Each subclass provides its own listing. Note that this class does not use ABCMeta, so the
        # decorator is documentation only and is not enforced at instantiation time.
        pass

    @overload
    def embed(self, value: str, max_seq_length: int | None = None) -> list[float]:
        pass

    @overload
    def embed(self, value: list[str], max_seq_length: int | None = None) -> list[list[float]]:
        pass

    def embed(self, value: str | list[str], max_seq_length: int | None = None) -> list[float] | list[list[float]]:
        """
        Generate embeddings for one or more strings

        Params:
            value: A single string or a list of strings to embed
            max_seq_length: Optional maximum sequence length forwarded with the embedding request

        Returns:
            One embedding when `value` is a string, or a list with one embedding per input when it is a list

        Raises:
            ValueError: If called on an object that is neither a pretrained nor a finetuned embedding model
        """
        is_batch = isinstance(value, list)
        request = EmbedRequest(values=value if is_batch else [value], max_seq_length=max_seq_length)

        # Each model kind is served by its own endpoint, addressed by a different identifier.
        if isinstance(self, PretrainedEmbeddingModel):
            vectors = embed_with_pretrained_model_gpu(self._model_name, body=request)
        elif isinstance(self, FinetunedEmbeddingModel):
            vectors = embed_with_finetuned_model_gpu(self.id, body=request)
        else:
            raise ValueError("Invalid embedding model")

        if is_batch:
            return vectors
        # A single string was wrapped in a one-element batch above, so unwrap it here.
        return vectors[0]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class _PretrainedEmbeddingModelMeta(type):
    """
    Metaclass that exposes pretrained model names as class attributes.

    `__getattr__` is only consulted when normal attribute lookup on the class fails. An attribute whose
    name is a member of `PretrainedEmbeddingModelName` resolves to the cached `PretrainedEmbeddingModel`
    for that name. Any other name raises `AttributeError`.
    """

    def __getattr__(cls, name: str) -> PretrainedEmbeddingModel:
        # Same checks (and evaluation order) as a positive test: class first, then the member name.
        if cls == FinetunedEmbeddingModel or name not in PretrainedEmbeddingModelName.__members__:
            raise AttributeError(f"'{cls.__name__}' object has no attribute '{name}'")
        return PretrainedEmbeddingModel._get(name)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingModelMeta):
    """
    A pretrained embedding model

    **Models:**

    OrcaCloud supports a select number of small to medium sized embedding models that perform well on the
    [Hugging Face MTEB Leaderboard](https://huggingface.co/spaces/mteb/leaderboard).
    These can be accessed as class attributes. Every member of `PretrainedEmbeddingModelName` is available
    this way. Supported models include:

    - **`CDE_SMALL`**: Context-aware CDE small model from Hugging Face ([jxm/cde-small-v1](https://huggingface.co/jxm/cde-small-v1))
    - **`CLIP_BASE`**: Multi-modal CLIP model from Hugging Face ([sentence-transformers/clip-ViT-L-14](https://huggingface.co/sentence-transformers/clip-ViT-L-14))
    - **`GTE_BASE`**: Alibaba's GTE model from Hugging Face ([Alibaba-NLP/gte-base-en-v1.5](https://huggingface.co/Alibaba-NLP/gte-base-en-v1.5))

    Examples:
        >>> PretrainedEmbeddingModel.CDE_SMALL
        PretrainedEmbeddingModel({name: CDE_SMALL, embedding_dim: 768, max_seq_length: 512})

    Attributes:
        name: Name of the pretrained embedding model
        embedding_dim: Dimension of the embeddings that are generated by the model
        max_seq_length: Maximum input length (in tokens not characters) that this model can process. Inputs that are longer will be truncated during the embedding process
        uses_context: Whether the pretrained embedding model uses context
    """

    _model_name: PretrainedEmbeddingModelName

    def __init__(self, metadata: PretrainedEmbeddingModelMetadata):
        # for internal use only, do not document
        self._model_name = metadata.name
        super().__init__(
            name=metadata.name.value,
            embedding_dim=metadata.embedding_dim,
            max_seq_length=metadata.max_seq_length,
            uses_context=metadata.uses_context,
        )

    def __eq__(self, other) -> bool:
        return isinstance(other, PretrainedEmbeddingModel) and self.name == other.name

    def __hash__(self) -> int:
        # Defining __eq__ alone sets __hash__ to None, which makes instances unhashable.
        # Hash on the same key that __eq__ compares so instances can be used in sets and dict keys.
        return hash(self.name)

    def __repr__(self) -> str:
        return f"PretrainedEmbeddingModel({{name: {self.name}, embedding_dim: {self.embedding_dim}, max_seq_length: {self.max_seq_length}}})"

    @classmethod
    def all(cls) -> list[PretrainedEmbeddingModel]:
        """
        List all pretrained embedding models in the OrcaCloud

        Returns:
            A list of all pretrained embedding models available in the OrcaCloud
        """
        return [cls(metadata) for metadata in list_pretrained_embedding_models()]

    # Cache of handles keyed by model name. `_get` fills it so that repeated class-attribute access
    # (e.g. `PretrainedEmbeddingModel.GTE_BASE`) returns the same object without another API call.
    _instances: dict[str, PretrainedEmbeddingModel] = {}

    @classmethod
    def _get(cls, name: PretrainedEmbeddingModelName | str) -> PretrainedEmbeddingModel:
        # for internal use only, do not document - we want people to use dot notation to get the model
        if str(name) not in cls._instances:
            cls._instances[str(name)] = cls(get_pretrained_embedding_model(cast(PretrainedEmbeddingModelName, name)))
        return cls._instances[str(name)]

    @classmethod
    def exists(cls, name: str) -> bool:
        """
        Check if a pretrained embedding model exists by name

        Params:
            name: The name of the pretrained embedding model

        Returns:
            True if the pretrained embedding model exists, False otherwise
        """
        # Do not use `name in PretrainedEmbeddingModelName`: on Python < 3.12, membership tests on an
        # Enum class with a plain `str` raise TypeError instead of returning False. Instead, check the
        # member names (as the class-attribute lookup does) and the member values (which `name` holds).
        return name in PretrainedEmbeddingModelName.__members__ or any(
            name == member.value for member in PretrainedEmbeddingModelName
        )

    def finetune(
        self,
        name: str,
        train_datasource: Datasource | LabeledMemoryset,
        *,
        eval_datasource: Datasource | None = None,
        label_column: str = "label",
        value_column: str = "value",
        training_method: EmbeddingFinetuningMethod | str = EmbeddingFinetuningMethod.CLASSIFICATION,
        training_args: dict | None = None,
        if_exists: CreateMode = "error",
    ) -> FinetunedEmbeddingModel:
        """
        Finetune an embedding model

        Params:
            name: Name of the finetuned embedding model
            train_datasource: Data to train on, either a `Datasource` or a `LabeledMemoryset`
            eval_datasource: Optionally provide data to evaluate on
            label_column: Column name of the label
            value_column: Column name of the value
            training_method: Training method to use
            training_args: Optional override for Hugging Face [`TrainingArguments`](transformers.TrainingArguments).
                If not provided, reasonable training arguments will be used for the specified training method
            if_exists: What to do if a finetuned embedding model with the same name already exists, defaults to
                `"error"`. Other option is `"open"` to open the existing finetuned embedding model.

        Returns:
            The finetuned embedding model

        Raises:
            ValueError: If the finetuned embedding model already exists and `if_exists` is `"error"` or if it is `"open"`
                but the base model param does not match the existing model
            TypeError: If `train_datasource` is neither a `Datasource` nor a `LabeledMemoryset`

        Examples:
            >>> datasource = Datasource.open("my_datasource")
            >>> model = PretrainedEmbeddingModel.CLIP_BASE
            >>> model.finetune("my_finetuned_model", datasource)
        """
        exists = FinetunedEmbeddingModel.exists(name)

        if exists and if_exists == "error":
            raise ValueError(f"Finetuned embedding model '{name}' already exists")
        elif exists and if_exists == "open":
            existing = FinetunedEmbeddingModel.open(name)

            if existing.base_model_name != self._model_name:
                raise ValueError(f"Finetuned embedding model '{name}' already exists, but with different base model")

            return existing

        # Imported locally to avoid an import cycle: labeled_memoryset imports this module.
        from .labeled_memoryset import LabeledMemoryset

        # The API takes either a datasource id or a memoryset id, depending on the kind of training data.
        train_datasource_id = train_datasource.id if isinstance(train_datasource, Datasource) else None
        train_memoryset_id = train_datasource.id if isinstance(train_datasource, LabeledMemoryset) else None
        # Explicit check rather than `assert`, which is stripped when Python runs with -O.
        if train_datasource_id is None and train_memoryset_id is None:
            raise TypeError(
                "train_datasource must be a Datasource or LabeledMemoryset, "
                f"got {type(train_datasource).__name__}"
            )
        res = create_finetuned_embedding_model(
            body=FinetuneEmbeddingModelRequest(
                name=name,
                base_model=self._model_name,
                train_memoryset_id=train_memoryset_id,
                train_datasource_id=train_datasource_id,
                eval_datasource_id=eval_datasource.id if eval_datasource is not None else None,
                label_column=label_column,
                value_column=value_column,
                training_method=EmbeddingFinetuningMethod(training_method),
                training_args=(FinetuneEmbeddingModelRequestTrainingArgs.from_dict(training_args or {})),
            ),
        )
        wait_for_task(res.finetuning_task_id, description="Finetuning embedding model")
        return FinetunedEmbeddingModel.open(res.id)
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
class FinetunedEmbeddingModel(_EmbeddingModel):
    """
    A finetuned embedding model in the OrcaCloud

    Attributes:
        name: Name of the finetuned embedding model
        embedding_dim: Dimension of the embeddings that are generated by the model
        max_seq_length: Maximum input length (in tokens not characters) that this model can process. Inputs that are longer will be truncated during the embedding process
        uses_context: Whether the model uses the memoryset to contextualize embeddings (acts akin to inverse document frequency in TFIDF features)
        id: Unique identifier of the finetuned embedding model
        base_model: Base model the finetuned embedding model was trained on
        base_model_name: Name of the pretrained base model, as reported in the model metadata
        created_at: When the model was finetuned
        updated_at: When the model metadata was last updated
    """

    id: str
    created_at: datetime
    updated_at: datetime
    base_model_name: PretrainedEmbeddingModelName
    _status: TaskStatus

    def __init__(self, metadata: FinetunedEmbeddingModelMetadata):
        # for internal use only, do not document
        self.id = metadata.id
        self.created_at = metadata.created_at
        self.updated_at = metadata.updated_at
        self.base_model_name = metadata.base_model
        # Status of the finetuning task, taken from the metadata at the time this handle was created.
        self._status = metadata.finetuning_status
        super().__init__(
            name=metadata.name,
            embedding_dim=metadata.embedding_dim,
            max_seq_length=metadata.max_seq_length,
            uses_context=metadata.uses_context,
        )

    def __eq__(self, other) -> bool:
        return isinstance(other, FinetunedEmbeddingModel) and self.id == other.id

    def __hash__(self) -> int:
        # Defining __eq__ alone sets __hash__ to None, which makes instances unhashable.
        # Hash on the same key that __eq__ compares.
        return hash(self.id)

    def __repr__(self) -> str:
        return (
            "FinetunedEmbeddingModel({\n"
            f" name: {self.name},\n"
            f" embedding_dim: {self.embedding_dim},\n"
            f" max_seq_length: {self.max_seq_length},\n"
            f" status: {self._status},\n"
            f" base_model: PretrainedEmbeddingModel.{self.base_model_name.value}\n"
            "})"
        )

    @property
    def base_model(self) -> PretrainedEmbeddingModel:
        """Pretrained model the finetuned embedding model was based on"""
        return PretrainedEmbeddingModel._get(self.base_model_name)

    @classmethod
    def all(cls) -> list[FinetunedEmbeddingModel]:
        """
        List all finetuned embedding model handles in the OrcaCloud

        Returns:
            A list of all finetuned embedding model handles in the OrcaCloud
        """
        return [cls(metadata) for metadata in list_finetuned_embedding_models()]

    @classmethod
    def open(cls, name: str) -> FinetunedEmbeddingModel:
        """
        Get a handle to a finetuned embedding model in the OrcaCloud

        Params:
            name: The name or unique identifier of a finetuned embedding model

        Returns:
            A handle to the finetuned embedding model in the OrcaCloud

        Raises:
            LookupError: If the finetuned embedding model does not exist
        """
        return cls(get_finetuned_embedding_model(name))

    @classmethod
    def exists(cls, name_or_id: str) -> bool:
        """
        Check if a finetuned embedding model with the given name or id exists.

        Params:
            name_or_id: The name or id of the finetuned embedding model

        Returns:
            True if the finetuned embedding model exists, False otherwise
        """
        try:
            cls.open(name_or_id)
            return True
        except LookupError:
            return False

    @classmethod
    def drop(cls, name_or_id: str, *, if_not_exists: DropMode = "error"):
        """
        Delete the finetuned embedding model from the OrcaCloud

        Params:
            name_or_id: The name or id of the finetuned embedding model
            if_not_exists: What to do if the finetuned embedding model does not exist, defaults to `"error"`.
                Any other value silently ignores a missing model.

        Raises:
            LookupError: If the finetuned embedding model does not exist and `if_not_exists` is `"error"`
        """
        try:
            delete_finetuned_embedding_model(name_or_id)
        except LookupError:
            if if_not_exists == "error":
                raise
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
from uuid import uuid4
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from .datasource import Datasource
|
|
6
|
+
from .embedding_model import (
|
|
7
|
+
FinetunedEmbeddingModel,
|
|
8
|
+
PretrainedEmbeddingModel,
|
|
9
|
+
PretrainedEmbeddingModelName,
|
|
10
|
+
TaskStatus,
|
|
11
|
+
)
|
|
12
|
+
from .labeled_memoryset import LabeledMemoryset
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def test_open_pretrained_model():
    """Class-attribute access returns a populated, cached handle for the pretrained model."""
    gte = PretrainedEmbeddingModel.GTE_BASE
    assert gte is not None
    assert isinstance(gte, PretrainedEmbeddingModel)
    assert (gte.name, gte.embedding_dim, gte.max_seq_length) == ("GTE_BASE", 768, 8192)
    # a second lookup must hit the cache and return the very same object
    assert gte is PretrainedEmbeddingModel.GTE_BASE
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def test_open_pretrained_model_unauthenticated(unauthenticated):
    """Without valid credentials, using a pretrained model fails with an invalid-API-key error."""
    # The model lookup stays inside the context because it may itself call the API.
    with pytest.raises(ValueError, match="Invalid API key"):
        gte = PretrainedEmbeddingModel.GTE_BASE
        gte.embed("I love this airline")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def test_open_pretrained_model_not_found():
    """Looking up an unknown pretrained model name raises LookupError."""
    unknown_name = "INVALID_MODEL"
    with pytest.raises(LookupError):
        PretrainedEmbeddingModel._get(unknown_name)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def test_all_pretrained_models():
    """Listing returns exactly one handle per known pretrained model name."""
    known_names = PretrainedEmbeddingModelName.__members__
    listed = PretrainedEmbeddingModel.all()
    assert len(listed) == len(PretrainedEmbeddingModelName)
    assert all(model.name in known_names for model in listed)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def test_embed_text():
    """Embedding a single string yields one float vector of the model's dimension."""
    vector = PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline", max_seq_length=32)
    assert vector is not None
    assert isinstance(vector, list) and len(vector) == 768
    assert isinstance(vector[0], float)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def test_embed_text_unauthenticated(unauthenticated):
    """Embedding with a length limit still fails cleanly when unauthenticated."""
    text = "I love this airline"
    with pytest.raises(ValueError, match="Invalid API key"):
        PretrainedEmbeddingModel.GTE_BASE.embed(text, max_seq_length=32)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@pytest.fixture(scope="session")
def finetuned_model(datasource) -> FinetunedEmbeddingModel:
    """Session-wide model finetuned from DISTILBERT on the shared `datasource` fixture."""
    base = PretrainedEmbeddingModel.DISTILBERT
    return base.finetune("test_finetuned_model", datasource)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def test_finetune_model_with_datasource(finetuned_model: FinetunedEmbeddingModel):
    """The datasource-trained fixture is a completed DISTILBERT fine-tune with DISTILBERT's dimensions."""
    model = finetuned_model
    assert model is not None
    assert model.name == "test_finetuned_model"
    assert model.base_model == PretrainedEmbeddingModel.DISTILBERT
    # a fine-tune inherits the base model's embedding size and sequence limit
    assert (model.embedding_dim, model.max_seq_length) == (768, 512)
    assert model._status == TaskStatus.COMPLETED
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def test_finetune_model_with_memoryset(memoryset: LabeledMemoryset):
    """Fine-tuning DISTILBERT from a ``LabeledMemoryset`` (instead of a datasource) also works."""
    target_name = "test_finetuned_model_from_memoryset"
    tuned = PretrainedEmbeddingModel.DISTILBERT.finetune(target_name, memoryset)
    assert tuned is not None
    assert tuned.name == target_name
    assert tuned.base_model == PretrainedEmbeddingModel.DISTILBERT
    assert (tuned.embedding_dim, tuned.max_seq_length) == (768, 512)
    assert tuned._status == TaskStatus.COMPLETED
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def test_finetune_model_already_exists_error(datasource: Datasource, finetuned_model):
    """Reusing an existing fine-tune name raises ``ValueError``.

    This holds both with the default ``if_exists`` and with an explicit ``if_exists="error"``.
    """
    duplicate_name = "test_finetuned_model"
    # first pass: no if_exists argument at all; second pass: the explicit "error" mode
    for extra_kwargs in ({}, {"if_exists": "error"}):
        with pytest.raises(ValueError):
            PretrainedEmbeddingModel.DISTILBERT.finetune(duplicate_name, datasource, **extra_kwargs)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def test_finetune_model_already_exists_return(datasource: Datasource, finetuned_model):
    """``if_exists="open"`` reopens an existing fine-tune rather than creating a new one.

    Requesting that name from ``GTE_BASE`` is rejected.
    """
    existing_name = "test_finetuned_model"

    # The fixture fine-tuned this name from DISTILBERT. Presumably the GTE_BASE request fails
    # because the base model does not match. TODO confirm against the finetune implementation.
    with pytest.raises(ValueError):
        PretrainedEmbeddingModel.GTE_BASE.finetune(existing_name, datasource, if_exists="open")

    reopened = PretrainedEmbeddingModel.DISTILBERT.finetune(existing_name, datasource, if_exists="open")
    assert reopened is not None
    assert reopened.name == existing_name
    assert reopened.base_model == PretrainedEmbeddingModel.DISTILBERT
    assert (reopened.embedding_dim, reopened.max_seq_length) == (768, 512)
    assert reopened._status == TaskStatus.COMPLETED
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def test_finetune_model_unauthenticated(unauthenticated, datasource: Datasource):
    """Fine-tuning without valid credentials fails with an invalid-API-key ``ValueError``."""
    target_name = "test_finetuned_model_unauthenticated"
    with pytest.raises(ValueError, match="Invalid API key"):
        PretrainedEmbeddingModel.DISTILBERT.finetune(target_name, datasource)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def test_use_finetuned_model_in_memoryset(datasource: Datasource, finetuned_model: FinetunedEmbeddingModel):
    """A fine-tuned model can be the embedding model of a newly created ``LabeledMemoryset``."""
    memoryset_name = "test_memoryset_finetuned_model"
    created = LabeledMemoryset.create(
        memoryset_name,
        datasource,
        embedding_model=finetuned_model,
        value_column="text",
    )
    assert created is not None
    assert created.name == memoryset_name
    assert created.embedding_model == finetuned_model
    # the memoryset is expected to hold as many entries as the datasource it was built from
    assert created.length == datasource.length
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def test_open_finetuned_model(finetuned_model: FinetunedEmbeddingModel):
    """Opening a fine-tuned model by name returns an object equal to the original, with identical metadata."""
    reopened = FinetunedEmbeddingModel.open(finetuned_model.name)
    assert isinstance(reopened, FinetunedEmbeddingModel)
    assert (reopened.id, reopened.name) == (finetuned_model.id, finetuned_model.name)
    assert reopened.base_model == PretrainedEmbeddingModel.DISTILBERT
    assert (reopened.embedding_dim, reopened.max_seq_length) == (768, 512)
    assert reopened == finetuned_model
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def test_embed_finetuned_model(finetuned_model: FinetunedEmbeddingModel):
    """Embedding a sentence with the fine-tuned model yields a 768-element list of floats."""
    sentence = "I love this airline"
    vector = finetuned_model.embed(sentence)
    assert vector is not None
    assert isinstance(vector, list)
    assert len(vector) == 768
    first_component = vector[0]
    assert isinstance(first_component, float)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def test_all_finetuned_models(finetuned_model: FinetunedEmbeddingModel):
    """Listing fine-tuned models is non-empty and includes the session fixture's model."""
    listed_names = [entry.name for entry in FinetunedEmbeddingModel.all()]
    assert len(listed_names) > 0
    assert finetuned_model.name in listed_names
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def test_all_finetuned_models_unauthenticated(unauthenticated):
    """Listing fine-tuned models without valid credentials fails with an invalid-API-key ``ValueError``."""
    expected_message = "Invalid API key"
    with pytest.raises(ValueError, match=expected_message):
        FinetunedEmbeddingModel.all()
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def test_all_finetuned_models_unauthorized(unauthorized, finetuned_model: FinetunedEmbeddingModel):
    """Under the ``unauthorized`` fixture, the fixture's fine-tuned model is not listed.

    The fixture presumably supplies credentials without access to that model. TODO confirm in conftest.
    """
    visible_models = FinetunedEmbeddingModel.all()
    assert finetuned_model not in visible_models
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def test_drop_finetuned_model(datasource: Datasource):
    """A fine-tuned model can be dropped, after which opening it raises ``LookupError``."""
    doomed_name = "finetuned_model_to_delete"
    PretrainedEmbeddingModel.DISTILBERT.finetune(doomed_name, datasource)
    # sanity check: the model is reachable before we drop it
    assert FinetunedEmbeddingModel.open(doomed_name)
    FinetunedEmbeddingModel.drop(doomed_name)
    with pytest.raises(LookupError):
        FinetunedEmbeddingModel.open(doomed_name)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def test_drop_finetuned_model_unauthenticated(unauthenticated, datasource: Datasource):
    """Dropping a fine-tuned model without valid credentials fails with an invalid-API-key ``ValueError``.

    The previous body called ``finetune`` (a copy-paste of the finetune test), so ``drop``
    was never exercised unauthenticated. ``datasource`` is no longer used, but it is kept
    so the test's fixture requirements stay unchanged.
    """
    with pytest.raises(ValueError, match="Invalid API key"):
        FinetunedEmbeddingModel.drop("finetuned_model_to_delete")
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def test_drop_finetuned_model_not_found():
    """Dropping an unknown model raises ``LookupError`` unless ``if_not_exists="ignore"`` is given."""
    # two independent random ids, matching the original's two separate uuid4() calls
    first_missing, second_missing = str(uuid4()), str(uuid4())
    with pytest.raises(LookupError):
        FinetunedEmbeddingModel.drop(first_missing)
    # with "ignore", a missing model is tolerated instead of raising
    FinetunedEmbeddingModel.drop(second_missing, if_not_exists="ignore")
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def test_drop_finetuned_model_unauthorized(unauthorized, finetuned_model: FinetunedEmbeddingModel):
    """Under the ``unauthorized`` fixture, dropping the fixture's model by id raises ``LookupError``.

    To those credentials the model appears not to exist.
    """
    target_id = finetuned_model.id
    with pytest.raises(LookupError):
        FinetunedEmbeddingModel.drop(target_id)
|