orca-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. orca_sdk/__init__.py +19 -0
  2. orca_sdk/_generated_api_client/__init__.py +3 -0
  3. orca_sdk/_generated_api_client/api/__init__.py +193 -0
  4. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  5. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +128 -0
  6. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +170 -0
  7. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +156 -0
  8. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +130 -0
  9. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +127 -0
  10. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  11. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +183 -0
  12. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +170 -0
  13. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
  14. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +154 -0
  15. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +170 -0
  16. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +156 -0
  17. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +161 -0
  18. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +127 -0
  19. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +190 -0
  20. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  21. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +167 -0
  22. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +156 -0
  23. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +156 -0
  24. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +127 -0
  25. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  26. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +118 -0
  27. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +118 -0
  28. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  29. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +168 -0
  30. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +156 -0
  31. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +189 -0
  32. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +156 -0
  33. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +127 -0
  34. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  35. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +181 -0
  36. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +183 -0
  37. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +168 -0
  38. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +181 -0
  39. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +167 -0
  40. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +156 -0
  41. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +169 -0
  42. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +188 -0
  43. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +169 -0
  44. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +156 -0
  45. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +184 -0
  46. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +260 -0
  47. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +127 -0
  48. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +193 -0
  49. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +188 -0
  50. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +191 -0
  51. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +187 -0
  52. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  53. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +188 -0
  54. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +157 -0
  55. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +127 -0
  56. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  57. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +154 -0
  58. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +156 -0
  59. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +243 -0
  60. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  61. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +162 -0
  62. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +156 -0
  63. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +157 -0
  64. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +127 -0
  65. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +175 -0
  66. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +171 -0
  67. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +181 -0
  68. orca_sdk/_generated_api_client/client.py +216 -0
  69. orca_sdk/_generated_api_client/errors.py +38 -0
  70. orca_sdk/_generated_api_client/models/__init__.py +159 -0
  71. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +84 -0
  72. orca_sdk/_generated_api_client/models/api_key_metadata.py +118 -0
  73. orca_sdk/_generated_api_client/models/base_model.py +55 -0
  74. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +176 -0
  75. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +114 -0
  76. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +150 -0
  77. orca_sdk/_generated_api_client/models/column_info.py +114 -0
  78. orca_sdk/_generated_api_client/models/column_type.py +14 -0
  79. orca_sdk/_generated_api_client/models/conflict_error_response.py +80 -0
  80. orca_sdk/_generated_api_client/models/create_api_key_request.py +99 -0
  81. orca_sdk/_generated_api_client/models/create_api_key_response.py +126 -0
  82. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +259 -0
  83. orca_sdk/_generated_api_client/models/create_rac_model_request.py +209 -0
  84. orca_sdk/_generated_api_client/models/datasource_metadata.py +142 -0
  85. orca_sdk/_generated_api_client/models/delete_memories_request.py +70 -0
  86. orca_sdk/_generated_api_client/models/embed_request.py +127 -0
  87. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +9 -0
  88. orca_sdk/_generated_api_client/models/evaluation_request.py +180 -0
  89. orca_sdk/_generated_api_client/models/evaluation_response.py +140 -0
  90. orca_sdk/_generated_api_client/models/feedback_type.py +9 -0
  91. orca_sdk/_generated_api_client/models/field_validation_error.py +103 -0
  92. orca_sdk/_generated_api_client/models/filter_item.py +231 -0
  93. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +15 -0
  94. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +16 -0
  95. orca_sdk/_generated_api_client/models/filter_item_op.py +16 -0
  96. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +70 -0
  97. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +259 -0
  98. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +66 -0
  99. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +166 -0
  100. orca_sdk/_generated_api_client/models/get_memories_request.py +70 -0
  101. orca_sdk/_generated_api_client/models/internal_server_error_response.py +80 -0
  102. orca_sdk/_generated_api_client/models/label_class_metrics.py +108 -0
  103. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +274 -0
  104. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +68 -0
  105. orca_sdk/_generated_api_client/models/label_prediction_result.py +101 -0
  106. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +232 -0
  107. orca_sdk/_generated_api_client/models/labeled_memory.py +197 -0
  108. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +108 -0
  109. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +68 -0
  110. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +258 -0
  111. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +68 -0
  112. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +68 -0
  113. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +277 -0
  114. orca_sdk/_generated_api_client/models/labeled_memory_update.py +171 -0
  115. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +68 -0
  116. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +195 -0
  117. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +9 -0
  118. orca_sdk/_generated_api_client/models/list_memories_request.py +104 -0
  119. orca_sdk/_generated_api_client/models/list_predictions_request.py +234 -0
  120. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +9 -0
  121. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +9 -0
  122. orca_sdk/_generated_api_client/models/lookup_request.py +81 -0
  123. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +83 -0
  124. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +9 -0
  125. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +180 -0
  126. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +66 -0
  127. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +9 -0
  128. orca_sdk/_generated_api_client/models/not_found_error_response.py +100 -0
  129. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +20 -0
  130. orca_sdk/_generated_api_client/models/prediction_feedback.py +157 -0
  131. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +115 -0
  132. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +122 -0
  133. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +102 -0
  134. orca_sdk/_generated_api_client/models/prediction_request.py +169 -0
  135. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +97 -0
  136. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +11 -0
  137. orca_sdk/_generated_api_client/models/rac_head_type.py +11 -0
  138. orca_sdk/_generated_api_client/models/rac_model_metadata.py +191 -0
  139. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +80 -0
  140. orca_sdk/_generated_api_client/models/task.py +198 -0
  141. orca_sdk/_generated_api_client/models/task_status.py +14 -0
  142. orca_sdk/_generated_api_client/models/task_status_info.py +133 -0
  143. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +72 -0
  144. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +80 -0
  145. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +94 -0
  146. orca_sdk/_generated_api_client/models/update_prediction_request.py +93 -0
  147. orca_sdk/_generated_api_client/py.typed +1 -0
  148. orca_sdk/_generated_api_client/types.py +56 -0
  149. orca_sdk/_utils/__init__.py +0 -0
  150. orca_sdk/_utils/analysis_ui.py +194 -0
  151. orca_sdk/_utils/analysis_ui_style.css +54 -0
  152. orca_sdk/_utils/auth.py +63 -0
  153. orca_sdk/_utils/auth_test.py +31 -0
  154. orca_sdk/_utils/common.py +37 -0
  155. orca_sdk/_utils/data_parsing.py +99 -0
  156. orca_sdk/_utils/data_parsing_test.py +244 -0
  157. orca_sdk/_utils/prediction_result_ui.css +18 -0
  158. orca_sdk/_utils/prediction_result_ui.py +64 -0
  159. orca_sdk/_utils/task.py +73 -0
  160. orca_sdk/classification_model.py +499 -0
  161. orca_sdk/classification_model_test.py +266 -0
  162. orca_sdk/conftest.py +117 -0
  163. orca_sdk/datasource.py +333 -0
  164. orca_sdk/datasource_test.py +95 -0
  165. orca_sdk/embedding_model.py +336 -0
  166. orca_sdk/embedding_model_test.py +173 -0
  167. orca_sdk/labeled_memoryset.py +1154 -0
  168. orca_sdk/labeled_memoryset_test.py +271 -0
  169. orca_sdk/orca_credentials.py +75 -0
  170. orca_sdk/orca_credentials_test.py +37 -0
  171. orca_sdk/telemetry.py +386 -0
  172. orca_sdk/telemetry_test.py +100 -0
  173. orca_sdk-0.1.0.dist-info/METADATA +39 -0
  174. orca_sdk-0.1.0.dist-info/RECORD +175 -0
  175. orca_sdk-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,336 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import abstractmethod
4
+ from datetime import datetime
5
+ from typing import TYPE_CHECKING, Sequence, cast, overload
6
+
7
+ from ._generated_api_client.api import (
8
+ create_finetuned_embedding_model,
9
+ delete_finetuned_embedding_model,
10
+ embed_with_finetuned_model_gpu,
11
+ embed_with_pretrained_model_gpu,
12
+ get_finetuned_embedding_model,
13
+ get_pretrained_embedding_model,
14
+ list_finetuned_embedding_models,
15
+ list_pretrained_embedding_models,
16
+ )
17
+ from ._generated_api_client.models import (
18
+ EmbeddingFinetuningMethod,
19
+ EmbedRequest,
20
+ FinetunedEmbeddingModelMetadata,
21
+ FinetuneEmbeddingModelRequest,
22
+ FinetuneEmbeddingModelRequestTrainingArgs,
23
+ PretrainedEmbeddingModelMetadata,
24
+ PretrainedEmbeddingModelName,
25
+ )
26
+ from ._utils.common import CreateMode, DropMode
27
+ from ._utils.task import TaskStatus, wait_for_task
28
+ from .datasource import Datasource
29
+
30
+ if TYPE_CHECKING:
31
+ from .labeled_memoryset import LabeledMemoryset
32
+
33
+
34
+ class _EmbeddingModel:
35
+ name: str
36
+ embedding_dim: int
37
+ max_seq_length: int
38
+ uses_context: bool
39
+
40
+ def __init__(self, *, name: str, embedding_dim: int, max_seq_length: int, uses_context: bool):
41
+ self.name = name
42
+ self.embedding_dim = embedding_dim
43
+ self.max_seq_length = max_seq_length
44
+ self.uses_context = uses_context
45
+
46
+ @classmethod
47
+ @abstractmethod
48
+ def all(cls) -> Sequence[_EmbeddingModel]:
49
+ pass
50
+
51
+ @overload
52
+ def embed(self, value: str, max_seq_length: int | None = None) -> list[float]:
53
+ pass
54
+
55
+ @overload
56
+ def embed(self, value: list[str], max_seq_length: int | None = None) -> list[list[float]]:
57
+ pass
58
+
59
+ def embed(self, value: str | list[str], max_seq_length: int | None = None) -> list[float] | list[list[float]]:
60
+ request = EmbedRequest(values=value if isinstance(value, list) else [value], max_seq_length=max_seq_length)
61
+ if isinstance(self, PretrainedEmbeddingModel):
62
+ embeddings = embed_with_pretrained_model_gpu(self._model_name, body=request)
63
+ elif isinstance(self, FinetunedEmbeddingModel):
64
+ embeddings = embed_with_finetuned_model_gpu(self.id, body=request)
65
+ else:
66
+ raise ValueError("Invalid embedding model")
67
+ return embeddings if isinstance(value, list) else embeddings[0]
68
+
69
+
70
class _PretrainedEmbeddingModelMeta(type):
    """
    Metaclass that lets members of `PretrainedEmbeddingModelName` be read as class attributes

    `PretrainedEmbeddingModel.<NAME>` resolves to the (cached) handle returned by
    `PretrainedEmbeddingModel._get(<NAME>)`.
    """

    def __getattr__(cls, name: str) -> PretrainedEmbeddingModel:
        # only reached when normal attribute lookup on the class has already failed
        if cls == FinetunedEmbeddingModel or name not in PretrainedEmbeddingModelName.__members__:
            raise AttributeError(f"'{cls.__name__}' object has no attribute '{name}'")
        return PretrainedEmbeddingModel._get(name)
76
+
77
+
78
class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingModelMeta):
    """
    A pretrained embedding model

    **Models:**

    OrcaCloud supports a select number of small to medium sized embedding models that perform well on the
    [Hugging Face MTEB Leaderboard](https://huggingface.co/spaces/mteb/leaderboard).
    These can be accessed as class attributes (any member name of `PretrainedEmbeddingModelName` resolves
    to a model handle). We currently support:

    - **`CDE_SMALL`**: Context-aware CDE small model from Hugging Face ([jxm/cde-small-v1](https://huggingface.co/jxm/cde-small-v1))
    - **`CLIP_BASE`**: Multi-modal CLIP model from Hugging Face ([sentence-transformers/clip-ViT-L-14](https://huggingface.co/sentence-transformers/clip-ViT-L-14))
    - **`GTE_BASE`**: Alibaba's GTE model from Hugging Face ([Alibaba-NLP/gte-base-en-v1.5](https://huggingface.co/Alibaba-NLP/gte-base-en-v1.5))

    Examples:
        >>> PretrainedEmbeddingModel.CDE_SMALL
        PretrainedEmbeddingModel({name: CDE_SMALL, embedding_dim: 768, max_seq_length: 512})

    Attributes:
        name: Name of the pretrained embedding model
        embedding_dim: Dimension of the embeddings that are generated by the model
        max_seq_length: Maximum input length (in tokens not characters) that this model can process. Inputs that are longer will be truncated during the embedding process
        uses_context: Whether the pretrained embedding model uses context
    """

    _model_name: PretrainedEmbeddingModelName

    def __init__(self, metadata: PretrainedEmbeddingModelMetadata):
        # for internal use only, do not document
        self._model_name = metadata.name
        super().__init__(
            name=metadata.name.value,
            embedding_dim=metadata.embedding_dim,
            max_seq_length=metadata.max_seq_length,
            uses_context=metadata.uses_context,
        )

    def __eq__(self, other) -> bool:
        return isinstance(other, PretrainedEmbeddingModel) and self.name == other.name

    def __hash__(self) -> int:
        # Defining __eq__ implicitly sets __hash__ to None, which would make handles unusable as dict
        # keys or set members; hash on the same field __eq__ compares so equal handles hash equally.
        return hash(self.name)

    def __repr__(self) -> str:
        return f"PretrainedEmbeddingModel({{name: {self.name}, embedding_dim: {self.embedding_dim}, max_seq_length: {self.max_seq_length}}})"

    @classmethod
    def all(cls) -> list[PretrainedEmbeddingModel]:
        """
        List all pretrained embedding models in the OrcaCloud

        Returns:
            A list of all pretrained embedding models available in the OrcaCloud
        """
        return [cls(metadata) for metadata in list_pretrained_embedding_models()]

    # Handles keyed by model name, so repeated `PretrainedEmbeddingModel.<NAME>` lookups return the
    # same object and fetch the model metadata only once per name
    _instances: dict[str, PretrainedEmbeddingModel] = {}

    @classmethod
    def _get(cls, name: PretrainedEmbeddingModelName | str) -> PretrainedEmbeddingModel:
        # for internal use only, do not document - we want people to use dot notation to get the model
        if str(name) not in cls._instances:
            cls._instances[str(name)] = cls(get_pretrained_embedding_model(cast(PretrainedEmbeddingModelName, name)))
        return cls._instances[str(name)]

    @classmethod
    def exists(cls, name: str | PretrainedEmbeddingModelName) -> bool:
        """
        Check if a pretrained embedding model exists by name

        Params:
            name: The name of the pretrained embedding model (a `PretrainedEmbeddingModelName` member is also accepted)

        Returns:
            True if the pretrained embedding model exists, False otherwise
        """
        if isinstance(name, PretrainedEmbeddingModelName):
            return True
        # `"x" in SomeEnum` raises TypeError for non-member operands before Python 3.12, so compare
        # against member names and values explicitly instead of relying on Enum containment.
        return name in PretrainedEmbeddingModelName.__members__ or any(
            member.value == name for member in PretrainedEmbeddingModelName
        )

    def finetune(
        self,
        name: str,
        train_datasource: Datasource | LabeledMemoryset,
        *,
        eval_datasource: Datasource | None = None,
        label_column: str = "label",
        value_column: str = "value",
        training_method: EmbeddingFinetuningMethod | str = EmbeddingFinetuningMethod.CLASSIFICATION,
        training_args: dict | None = None,
        if_exists: CreateMode = "error",
    ) -> FinetunedEmbeddingModel:
        """
        Finetune an embedding model

        Params:
            name: Name of the finetuned embedding model
            train_datasource: Data to train on
            eval_datasource: Optionally provide data to evaluate on
            label_column: Column name of the label
            value_column: Column name of the value
            training_method: Training method to use
            training_args: Optional override for Hugging Face [`TrainingArguments`](transformers.TrainingArguments).
                If not provided, reasonable training arguments will be used for the specified training method
            if_exists: What to do if a finetuned embedding model with the same name already exists, defaults to
                `"error"`. Other option is `"open"` to open the existing finetuned embedding model.

        Returns:
            The finetuned embedding model

        Raises:
            ValueError: If the finetuned embedding model already exists and `if_exists` is `"error"` or if it is `"open"`
                but the base model param does not match the existing model
            TypeError: If `train_datasource` is neither a `Datasource` nor a `LabeledMemoryset`

        Examples:
            >>> datasource = Datasource.open("my_datasource")
            >>> model = PretrainedEmbeddingModel.CLIP_BASE
            >>> model.finetune("my_finetuned_model", datasource)
        """
        exists = FinetunedEmbeddingModel.exists(name)

        if exists and if_exists == "error":
            raise ValueError(f"Finetuned embedding model '{name}' already exists")
        elif exists and if_exists == "open":
            existing = FinetunedEmbeddingModel.open(name)

            if existing.base_model_name != self._model_name:
                raise ValueError(f"Finetuned embedding model '{name}' already exists, but with different base model")

            return existing

        # imported here rather than at module level (where it only appears under TYPE_CHECKING),
        # presumably to avoid an import cycle with labeled_memoryset
        from .labeled_memoryset import LabeledMemoryset

        # Exactly one of the two ids is sent, depending on the kind of training data passed in. An
        # explicit check (instead of `assert`) keeps this validation active under `python -O`.
        if isinstance(train_datasource, Datasource):
            train_datasource_id, train_memoryset_id = train_datasource.id, None
        elif isinstance(train_datasource, LabeledMemoryset):
            train_datasource_id, train_memoryset_id = None, train_datasource.id
        else:
            raise TypeError(
                "train_datasource must be a Datasource or LabeledMemoryset, "
                f"got {type(train_datasource).__name__}"
            )

        res = create_finetuned_embedding_model(
            body=FinetuneEmbeddingModelRequest(
                name=name,
                base_model=self._model_name,
                train_memoryset_id=train_memoryset_id,
                train_datasource_id=train_datasource_id,
                eval_datasource_id=eval_datasource.id if eval_datasource is not None else None,
                label_column=label_column,
                value_column=value_column,
                training_method=EmbeddingFinetuningMethod(training_method),
                training_args=(FinetuneEmbeddingModelRequestTrainingArgs.from_dict(training_args or {})),
            ),
        )
        # block until the server-side finetuning task finishes, then return a fresh handle
        wait_for_task(res.finetuning_task_id, description="Finetuning embedding model")
        return FinetunedEmbeddingModel.open(res.id)
224
+
225
+
226
class FinetunedEmbeddingModel(_EmbeddingModel):
    """
    A finetuned embedding model in the OrcaCloud

    Attributes:
        name: Name of the finetuned embedding model
        embedding_dim: Dimension of the embeddings that are generated by the model
        max_seq_length: Maximum input length (in tokens not characters) that this model can process. Inputs that are longer will be truncated during the embedding process
        uses_context: Whether the model uses the memoryset to contextualize embeddings (acts akin to inverse document frequency in TFIDF features)
        id: Unique identifier of the finetuned embedding model
        base_model: Base model the finetuned embedding model was trained on
        created_at: When the model was finetuned
        updated_at: Last-updated timestamp reported by the API for this model
    """

    id: str
    created_at: datetime
    updated_at: datetime
    base_model_name: PretrainedEmbeddingModelName
    _status: TaskStatus

    def __init__(self, metadata: FinetunedEmbeddingModelMetadata):
        # for internal use only, do not document
        self.id = metadata.id
        self.created_at = metadata.created_at
        self.updated_at = metadata.updated_at
        self.base_model_name = metadata.base_model
        self._status = metadata.finetuning_status
        super().__init__(
            name=metadata.name,
            embedding_dim=metadata.embedding_dim,
            max_seq_length=metadata.max_seq_length,
            uses_context=metadata.uses_context,
        )

    def __eq__(self, other) -> bool:
        return isinstance(other, FinetunedEmbeddingModel) and self.id == other.id

    def __hash__(self) -> int:
        # Defining __eq__ implicitly sets __hash__ to None; hash on the id that __eq__ compares so
        # handles can be used as dict keys / set members and equal handles hash equally.
        return hash(self.id)

    def __repr__(self) -> str:
        return (
            "FinetunedEmbeddingModel({\n"
            f" name: {self.name},\n"
            f" embedding_dim: {self.embedding_dim},\n"
            f" max_seq_length: {self.max_seq_length},\n"
            f" status: {self._status},\n"
            f" base_model: PretrainedEmbeddingModel.{self.base_model_name.value}\n"
            "})"
        )

    @property
    def base_model(self) -> PretrainedEmbeddingModel:
        """Pretrained model the finetuned embedding model was based on"""
        return PretrainedEmbeddingModel._get(self.base_model_name)

    @classmethod
    def all(cls) -> list[FinetunedEmbeddingModel]:
        """
        List all finetuned embedding model handles in the OrcaCloud

        Returns:
            A list of all finetuned embedding model handles in the OrcaCloud
        """
        return [cls(metadata) for metadata in list_finetuned_embedding_models()]

    @classmethod
    def open(cls, name: str) -> FinetunedEmbeddingModel:
        """
        Get a handle to a finetuned embedding model in the OrcaCloud

        Params:
            name: The name or unique identifier of a finetuned embedding model

        Returns:
            A handle to the finetuned embedding model in the OrcaCloud

        Raises:
            LookupError: If the finetuned embedding model does not exist
        """
        return cls(get_finetuned_embedding_model(name))

    @classmethod
    def exists(cls, name_or_id: str) -> bool:
        """
        Check if a finetuned embedding model with the given name or id exists.

        Params:
            name_or_id: The name or id of the finetuned embedding model

        Returns:
            True if the finetuned embedding model exists, False otherwise
        """
        try:
            cls.open(name_or_id)
            return True
        except LookupError:
            return False

    @classmethod
    def drop(cls, name_or_id: str, *, if_not_exists: DropMode = "error"):
        """
        Delete the finetuned embedding model from the OrcaCloud

        Params:
            name_or_id: The name or id of the finetuned embedding model
            if_not_exists: What to do if the model does not exist, defaults to `"error"`. Any other mode
                (e.g. `"ignore"`) silently skips a missing model.

        Raises:
            LookupError: If the finetuned embedding model does not exist and `if_not_exists` is `"error"`
        """
        try:
            delete_finetuned_embedding_model(name_or_id)
        except LookupError:
            # only "not found" is tolerable; other errors (e.g. auth failures) always propagate
            if if_not_exists == "error":
                raise
@@ -0,0 +1,173 @@
1
+ from uuid import uuid4
2
+
3
+ import pytest
4
+
5
+ from .datasource import Datasource
6
+ from .embedding_model import (
7
+ FinetunedEmbeddingModel,
8
+ PretrainedEmbeddingModel,
9
+ PretrainedEmbeddingModelName,
10
+ TaskStatus,
11
+ )
12
+ from .labeled_memoryset import LabeledMemoryset
13
+
14
+
15
def test_open_pretrained_model():
    gte = PretrainedEmbeddingModel.GTE_BASE
    assert gte is not None
    assert isinstance(gte, PretrainedEmbeddingModel)
    assert (gte.name, gte.embedding_dim, gte.max_seq_length) == ("GTE_BASE", 768, 8192)
    # a second attribute access must hand back the cached handle, not a new object
    assert PretrainedEmbeddingModel.GTE_BASE is gte


def test_open_pretrained_model_unauthenticated(unauthenticated):
    # `embed` always calls the API, even if the GTE_BASE handle itself is served from the class cache
    with pytest.raises(ValueError, match="Invalid API key"):
        PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline")


def test_open_pretrained_model_not_found():
    with pytest.raises(LookupError):
        PretrainedEmbeddingModel._get("INVALID_MODEL")


def test_all_pretrained_models():
    models = PretrainedEmbeddingModel.all()
    known_names = PretrainedEmbeddingModelName.__members__
    assert len(models) == len(PretrainedEmbeddingModelName)
    assert all(m.name in known_names for m in models)


def test_embed_text():
    vector = PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline", max_seq_length=32)
    assert vector is not None
    assert isinstance(vector, list)
    assert len(vector) == 768
    assert isinstance(vector[0], float)


def test_embed_text_unauthenticated(unauthenticated):
    with pytest.raises(ValueError, match="Invalid API key"):
        PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline", max_seq_length=32)
52
+
53
+
54
@pytest.fixture(scope="session")
def finetuned_model(datasource) -> FinetunedEmbeddingModel:
    """Session-wide DISTILBERT finetune built once from the shared `datasource` fixture."""
    model = PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource)
    return model


def _check_distilbert_finetune(model: FinetunedEmbeddingModel, expected_name: str) -> None:
    # Expectations shared by every test that produces a completed DISTILBERT finetune
    assert model is not None
    assert model.name == expected_name
    assert model.base_model == PretrainedEmbeddingModel.DISTILBERT
    assert model.embedding_dim == 768
    assert model.max_seq_length == 512
    assert model._status == TaskStatus.COMPLETED


def test_finetune_model_with_datasource(finetuned_model: FinetunedEmbeddingModel):
    _check_distilbert_finetune(finetuned_model, "test_finetuned_model")


def test_finetune_model_with_memoryset(memoryset: LabeledMemoryset):
    model_name = "test_finetuned_model_from_memoryset"
    trained = PretrainedEmbeddingModel.DISTILBERT.finetune(model_name, memoryset)
    _check_distilbert_finetune(trained, model_name)


def test_finetune_model_already_exists_error(datasource: Datasource, finetuned_model):
    # both the default mode and an explicit "error" mode must refuse to reuse the existing name
    for extra_kwargs in ({}, {"if_exists": "error"}):
        with pytest.raises(ValueError):
            PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource, **extra_kwargs)


def test_finetune_model_already_exists_return(datasource: Datasource, finetuned_model):
    # "open" must reject an existing model whose base model differs from the requested one
    with pytest.raises(ValueError):
        PretrainedEmbeddingModel.GTE_BASE.finetune("test_finetuned_model", datasource, if_exists="open")

    reopened = PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model", datasource, if_exists="open")
    _check_distilbert_finetune(reopened, "test_finetuned_model")


def test_finetune_model_unauthenticated(unauthenticated, datasource: Datasource):
    with pytest.raises(ValueError, match="Invalid API key"):
        PretrainedEmbeddingModel.DISTILBERT.finetune("test_finetuned_model_unauthenticated", datasource)
102
+
103
+
104
def test_use_finetuned_model_in_memoryset(datasource: Datasource, finetuned_model: FinetunedEmbeddingModel):
    memoryset_name = "test_memoryset_finetuned_model"
    memoryset = LabeledMemoryset.create(
        memoryset_name, datasource, embedding_model=finetuned_model, value_column="text"
    )
    assert memoryset is not None
    assert memoryset.name == memoryset_name
    assert memoryset.embedding_model == finetuned_model
    assert memoryset.length == datasource.length


def test_open_finetuned_model(finetuned_model: FinetunedEmbeddingModel):
    reopened = FinetunedEmbeddingModel.open(finetuned_model.name)
    assert isinstance(reopened, FinetunedEmbeddingModel)
    assert (reopened.id, reopened.name) == (finetuned_model.id, finetuned_model.name)
    assert reopened.base_model == PretrainedEmbeddingModel.DISTILBERT
    assert (reopened.embedding_dim, reopened.max_seq_length) == (768, 512)
    assert reopened == finetuned_model


def test_embed_finetuned_model(finetuned_model: FinetunedEmbeddingModel):
    vector = finetuned_model.embed("I love this airline")
    assert vector is not None
    assert isinstance(vector, list)
    assert len(vector) == 768
    assert isinstance(vector[0], float)


def test_all_finetuned_models(finetuned_model: FinetunedEmbeddingModel):
    listed = FinetunedEmbeddingModel.all()
    assert len(listed) > 0
    assert finetuned_model.name in {m.name for m in listed}


def test_all_finetuned_models_unauthenticated(unauthenticated):
    with pytest.raises(ValueError, match="Invalid API key"):
        FinetunedEmbeddingModel.all()


def test_all_finetuned_models_unauthorized(unauthorized, finetuned_model: FinetunedEmbeddingModel):
    # models listed under the `unauthorized` fixture's credentials must not include this one
    visible = FinetunedEmbeddingModel.all()
    assert finetuned_model not in visible
149
+
150
+
151
def test_drop_finetuned_model(datasource: Datasource):
    target = "finetuned_model_to_delete"
    PretrainedEmbeddingModel.DISTILBERT.finetune(target, datasource)
    assert FinetunedEmbeddingModel.open(target)
    FinetunedEmbeddingModel.drop(target)
    # once dropped, opening the model must report it as missing
    with pytest.raises(LookupError):
        FinetunedEmbeddingModel.open(target)
157
+
158
+
159
def test_drop_finetuned_model_unauthenticated(unauthenticated):
    # Exercise `drop` itself (this test previously called `finetune`). `drop` only swallows
    # LookupError, so an invalid API key must surface as ValueError, even in "ignore" mode.
    with pytest.raises(ValueError, match="Invalid API key"):
        FinetunedEmbeddingModel.drop("finetuned_model_to_delete")
    with pytest.raises(ValueError, match="Invalid API key"):
        FinetunedEmbeddingModel.drop("finetuned_model_to_delete", if_not_exists="ignore")
162
+
163
+
164
def test_drop_finetuned_model_not_found():
    unknown_id = str(uuid4())
    with pytest.raises(LookupError):
        FinetunedEmbeddingModel.drop(unknown_id)
    # with if_not_exists="ignore" a missing model is skipped without raising
    FinetunedEmbeddingModel.drop(str(uuid4()), if_not_exists="ignore")


def test_drop_finetuned_model_unauthorized(unauthorized, finetuned_model: FinetunedEmbeddingModel):
    # under the `unauthorized` fixture the model is expected to look missing (LookupError)
    with pytest.raises(LookupError):
        FinetunedEmbeddingModel.drop(finetuned_model.id)