orca-sdk 0.0.78__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. orca_sdk/__init__.py +24 -0
  2. orca_sdk/_generated_api_client/__init__.py +3 -0
  3. orca_sdk/_generated_api_client/api/__init__.py +205 -0
  4. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  5. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +130 -0
  6. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +172 -0
  7. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +158 -0
  8. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +132 -0
  9. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +129 -0
  10. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  11. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +185 -0
  12. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +172 -0
  13. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +170 -0
  14. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +156 -0
  15. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +172 -0
  16. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +158 -0
  17. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +163 -0
  18. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +129 -0
  19. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +192 -0
  20. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  21. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +169 -0
  22. orca_sdk/_generated_api_client/api/datasource/create_embedding_evaluation_datasource_name_or_id_embedding_evaluation_post.py +185 -0
  23. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +158 -0
  24. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +158 -0
  25. orca_sdk/_generated_api_client/api/datasource/get_embedding_evaluation_datasource_name_or_id_embedding_evaluation_task_id_get.py +171 -0
  26. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +129 -0
  27. orca_sdk/_generated_api_client/api/datasource/list_embedding_evaluations_datasource_name_or_id_embedding_evaluation_get.py +237 -0
  28. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  29. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +120 -0
  30. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +120 -0
  31. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  32. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +170 -0
  33. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +158 -0
  34. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +191 -0
  35. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +158 -0
  36. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +129 -0
  37. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  38. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +183 -0
  39. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +185 -0
  40. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +170 -0
  41. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +183 -0
  42. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +169 -0
  43. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +158 -0
  44. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +171 -0
  45. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +190 -0
  46. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +171 -0
  47. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +158 -0
  48. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +186 -0
  49. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +262 -0
  50. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +129 -0
  51. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +195 -0
  52. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +190 -0
  53. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +193 -0
  54. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +189 -0
  55. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  56. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +194 -0
  57. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +163 -0
  58. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +129 -0
  59. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  60. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +156 -0
  61. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +158 -0
  62. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +245 -0
  63. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  64. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +164 -0
  65. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +158 -0
  66. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +159 -0
  67. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +129 -0
  68. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +177 -0
  69. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +173 -0
  70. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +183 -0
  71. orca_sdk/_generated_api_client/client.py +216 -0
  72. orca_sdk/_generated_api_client/errors.py +38 -0
  73. orca_sdk/_generated_api_client/models/__init__.py +179 -0
  74. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +116 -0
  75. orca_sdk/_generated_api_client/models/api_key_metadata.py +137 -0
  76. orca_sdk/_generated_api_client/models/api_key_metadata_scope_item.py +9 -0
  77. orca_sdk/_generated_api_client/models/base_model.py +55 -0
  78. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +176 -0
  79. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +147 -0
  80. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +150 -0
  81. orca_sdk/_generated_api_client/models/column_info.py +114 -0
  82. orca_sdk/_generated_api_client/models/column_type.py +14 -0
  83. orca_sdk/_generated_api_client/models/conflict_error_response.py +80 -0
  84. orca_sdk/_generated_api_client/models/create_api_key_request.py +120 -0
  85. orca_sdk/_generated_api_client/models/create_api_key_request_scope_item.py +9 -0
  86. orca_sdk/_generated_api_client/models/create_api_key_response.py +145 -0
  87. orca_sdk/_generated_api_client/models/create_api_key_response_scope_item.py +9 -0
  88. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +279 -0
  89. orca_sdk/_generated_api_client/models/create_rac_model_request.py +209 -0
  90. orca_sdk/_generated_api_client/models/datasource_metadata.py +142 -0
  91. orca_sdk/_generated_api_client/models/delete_memories_request.py +70 -0
  92. orca_sdk/_generated_api_client/models/embed_request.py +127 -0
  93. orca_sdk/_generated_api_client/models/embedding_evaluation_request.py +179 -0
  94. orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +148 -0
  95. orca_sdk/_generated_api_client/models/embedding_evaluation_result.py +86 -0
  96. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +9 -0
  97. orca_sdk/_generated_api_client/models/embedding_model_result.py +114 -0
  98. orca_sdk/_generated_api_client/models/evaluation_request.py +180 -0
  99. orca_sdk/_generated_api_client/models/evaluation_response.py +140 -0
  100. orca_sdk/_generated_api_client/models/feedback_type.py +9 -0
  101. orca_sdk/_generated_api_client/models/field_validation_error.py +103 -0
  102. orca_sdk/_generated_api_client/models/filter_item.py +231 -0
  103. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +15 -0
  104. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +20 -0
  105. orca_sdk/_generated_api_client/models/filter_item_op.py +16 -0
  106. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +70 -0
  107. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +259 -0
  108. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +66 -0
  109. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +166 -0
  110. orca_sdk/_generated_api_client/models/get_memories_request.py +70 -0
  111. orca_sdk/_generated_api_client/models/internal_server_error_response.py +80 -0
  112. orca_sdk/_generated_api_client/models/label_class_metrics.py +108 -0
  113. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +274 -0
  114. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +68 -0
  115. orca_sdk/_generated_api_client/models/label_prediction_result.py +115 -0
  116. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +246 -0
  117. orca_sdk/_generated_api_client/models/labeled_memory.py +197 -0
  118. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +128 -0
  119. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +68 -0
  120. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +258 -0
  121. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +68 -0
  122. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +68 -0
  123. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +237 -0
  124. orca_sdk/_generated_api_client/models/labeled_memory_update.py +171 -0
  125. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +68 -0
  126. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +195 -0
  127. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +9 -0
  128. orca_sdk/_generated_api_client/models/list_memories_request.py +104 -0
  129. orca_sdk/_generated_api_client/models/list_predictions_request.py +257 -0
  130. orca_sdk/_generated_api_client/models/lookup_request.py +81 -0
  131. orca_sdk/_generated_api_client/models/memory_metrics.py +156 -0
  132. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +83 -0
  133. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +9 -0
  134. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +180 -0
  135. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +66 -0
  136. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +9 -0
  137. orca_sdk/_generated_api_client/models/not_found_error_response.py +100 -0
  138. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +21 -0
  139. orca_sdk/_generated_api_client/models/precision_recall_curve.py +94 -0
  140. orca_sdk/_generated_api_client/models/prediction_feedback.py +157 -0
  141. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +115 -0
  142. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +122 -0
  143. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +102 -0
  144. orca_sdk/_generated_api_client/models/prediction_request.py +169 -0
  145. orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_0.py +10 -0
  146. orca_sdk/_generated_api_client/models/prediction_sort_item_item_type_1.py +9 -0
  147. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +97 -0
  148. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +12 -0
  149. orca_sdk/_generated_api_client/models/rac_head_type.py +11 -0
  150. orca_sdk/_generated_api_client/models/rac_model_metadata.py +191 -0
  151. orca_sdk/_generated_api_client/models/roc_curve.py +94 -0
  152. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +80 -0
  153. orca_sdk/_generated_api_client/models/task.py +198 -0
  154. orca_sdk/_generated_api_client/models/task_status.py +14 -0
  155. orca_sdk/_generated_api_client/models/task_status_info.py +133 -0
  156. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +72 -0
  157. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +80 -0
  158. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +94 -0
  159. orca_sdk/_generated_api_client/models/update_prediction_request.py +93 -0
  160. orca_sdk/_generated_api_client/py.typed +1 -0
  161. orca_sdk/_generated_api_client/types.py +56 -0
  162. orca_sdk/_utils/__init__.py +0 -0
  163. orca_sdk/_utils/analysis_ui.py +192 -0
  164. orca_sdk/_utils/analysis_ui_style.css +54 -0
  165. orca_sdk/_utils/auth.py +68 -0
  166. orca_sdk/_utils/auth_test.py +31 -0
  167. orca_sdk/_utils/common.py +37 -0
  168. orca_sdk/_utils/data_parsing.py +99 -0
  169. orca_sdk/_utils/data_parsing_test.py +244 -0
  170. orca_sdk/_utils/prediction_result_ui.css +18 -0
  171. orca_sdk/_utils/prediction_result_ui.py +64 -0
  172. orca_sdk/_utils/task.py +73 -0
  173. orca_sdk/classification_model.py +508 -0
  174. orca_sdk/classification_model_test.py +272 -0
  175. orca_sdk/conftest.py +116 -0
  176. orca_sdk/credentials.py +126 -0
  177. orca_sdk/credentials_test.py +37 -0
  178. orca_sdk/datasource.py +333 -0
  179. orca_sdk/datasource_test.py +96 -0
  180. orca_sdk/embedding_model.py +347 -0
  181. orca_sdk/embedding_model_test.py +176 -0
  182. orca_sdk/memoryset.py +1209 -0
  183. orca_sdk/memoryset_test.py +287 -0
  184. orca_sdk/telemetry.py +398 -0
  185. orca_sdk/telemetry_test.py +109 -0
  186. orca_sdk-0.0.78.dist-info/METADATA +79 -0
  187. orca_sdk-0.0.78.dist-info/RECORD +188 -0
  188. orca_sdk-0.0.78.dist-info/WHEEL +4 -0
@@ -0,0 +1,347 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import abstractmethod
4
+ from datetime import datetime
5
+ from typing import TYPE_CHECKING, Sequence, cast, overload
6
+
7
+ from ._generated_api_client.api import (
8
+ create_finetuned_embedding_model,
9
+ delete_finetuned_embedding_model,
10
+ embed_with_finetuned_model_gpu,
11
+ embed_with_pretrained_model_gpu,
12
+ get_finetuned_embedding_model,
13
+ get_pretrained_embedding_model,
14
+ list_finetuned_embedding_models,
15
+ list_pretrained_embedding_models,
16
+ )
17
+ from ._generated_api_client.models import (
18
+ EmbeddingFinetuningMethod,
19
+ EmbedRequest,
20
+ FinetunedEmbeddingModelMetadata,
21
+ FinetuneEmbeddingModelRequest,
22
+ FinetuneEmbeddingModelRequestTrainingArgs,
23
+ PretrainedEmbeddingModelMetadata,
24
+ PretrainedEmbeddingModelName,
25
+ )
26
+ from ._utils.common import CreateMode, DropMode
27
+ from ._utils.task import TaskStatus, wait_for_task
28
+ from .datasource import Datasource
29
+
30
+ if TYPE_CHECKING:
31
+ from .memoryset import LabeledMemoryset
32
+
33
+
34
+ class _EmbeddingModel:
35
+ name: str
36
+ embedding_dim: int
37
+ max_seq_length: int
38
+ uses_context: bool
39
+
40
+ def __init__(self, *, name: str, embedding_dim: int, max_seq_length: int, uses_context: bool):
41
+ self.name = name
42
+ self.embedding_dim = embedding_dim
43
+ self.max_seq_length = max_seq_length
44
+ self.uses_context = uses_context
45
+
46
+ @classmethod
47
+ @abstractmethod
48
+ def all(cls) -> Sequence[_EmbeddingModel]:
49
+ pass
50
+
51
+ @overload
52
+ def embed(self, value: str, max_seq_length: int | None = None) -> list[float]:
53
+ pass
54
+
55
+ @overload
56
+ def embed(self, value: list[str], max_seq_length: int | None = None) -> list[list[float]]:
57
+ pass
58
+
59
+ def embed(self, value: str | list[str], max_seq_length: int | None = None) -> list[float] | list[list[float]]:
60
+ """
61
+ Generate embeddings for a value or list of values
62
+
63
+ Params:
64
+ value: The value or list of values to embed
65
+ max_seq_length: The maximum sequence length to truncate the input to
66
+
67
+ Returns:
68
+ A matrix of floats representing the embedding for each value if the input is a list of
69
+ values, or a list of floats representing the embedding for the single value if the
70
+ input is a single value
71
+ """
72
+ request = EmbedRequest(values=value if isinstance(value, list) else [value], max_seq_length=max_seq_length)
73
+ if isinstance(self, PretrainedEmbeddingModel):
74
+ embeddings = embed_with_pretrained_model_gpu(self._model_name, body=request)
75
+ elif isinstance(self, FinetunedEmbeddingModel):
76
+ embeddings = embed_with_finetuned_model_gpu(self.id, body=request)
77
+ else:
78
+ raise ValueError("Invalid embedding model")
79
+ return embeddings if isinstance(value, list) else embeddings[0]
80
+
81
+
82
class _PretrainedEmbeddingModelMeta(type):
    """Metaclass that resolves pretrained model names accessed as class attributes"""

    def __getattr__(cls, name: str) -> PretrainedEmbeddingModel:
        # Python only calls __getattr__ after regular class attribute lookup has failed
        is_model_name = cls != FinetunedEmbeddingModel and name in PretrainedEmbeddingModelName.__members__
        if not is_model_name:
            raise AttributeError(f"'{cls.__name__}' object has no attribute '{name}'")
        return PretrainedEmbeddingModel._get(name)
88
+
89
+
90
class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingModelMeta):
    """
    A pretrained embedding model

    **Models:**

    OrcaCloud supports a select number of small to medium sized embedding models that perform well on the
    [Hugging Face MTEB Leaderboard](https://huggingface.co/spaces/mteb/leaderboard).
    These can be accessed as class attributes. We currently support:

    - **`CDE_SMALL`**: Context-aware CDE small model from Hugging Face ([jxm/cde-small-v1](https://huggingface.co/jxm/cde-small-v1))
    - **`CLIP_BASE`**: Multi-modal CLIP model from Hugging Face ([sentence-transformers/clip-ViT-L-14](https://huggingface.co/sentence-transformers/clip-ViT-L-14))
    - **`GTE_BASE`**: Alibaba's GTE model from Hugging Face ([Alibaba-NLP/gte-base-en-v1.5](https://huggingface.co/Alibaba-NLP/gte-base-en-v1.5))

    Examples:
        >>> PretrainedEmbeddingModel.CDE_SMALL
        PretrainedEmbeddingModel({name: CDE_SMALL, embedding_dim: 768, max_seq_length: 512})

    Attributes:
        name: Name of the pretrained embedding model
        embedding_dim: Dimension of the embeddings that are generated by the model
        max_seq_length: Maximum input length (in tokens not characters) that this model can process. Inputs that are longer will be truncated during the embedding process
        uses_context: Whether the pretrained embedding model uses context
    """

    _model_name: PretrainedEmbeddingModelName

    def __init__(self, metadata: PretrainedEmbeddingModelMetadata):
        # for internal use only, do not document
        self._model_name = metadata.name
        super().__init__(
            name=metadata.name.value,
            embedding_dim=metadata.embedding_dim,
            max_seq_length=metadata.max_seq_length,
            uses_context=metadata.uses_context,
        )

    def __eq__(self, other) -> bool:
        return isinstance(other, PretrainedEmbeddingModel) and self.name == other.name

    def __hash__(self) -> int:
        # Defining __eq__ without __hash__ makes instances unhashable, so hash on the same
        # key that __eq__ compares to keep models usable in sets and as dict keys.
        return hash(self.name)

    def __repr__(self) -> str:
        return f"PretrainedEmbeddingModel({{name: {self.name}, embedding_dim: {self.embedding_dim}, max_seq_length: {self.max_seq_length}}})"

    @classmethod
    def all(cls) -> list[PretrainedEmbeddingModel]:
        """
        List all pretrained embedding models in the OrcaCloud

        Returns:
            A list of all pretrained embedding models available in the OrcaCloud
        """
        return [cls(metadata) for metadata in list_pretrained_embedding_models()]

    # cache of model handles shared by the whole class, keyed by `str(name)`, see `_get`
    _instances: dict[str, PretrainedEmbeddingModel] = {}

    @classmethod
    def _get(cls, name: PretrainedEmbeddingModelName | str) -> PretrainedEmbeddingModel:
        # for internal use only, do not document - we want people to use dot notation to get the model
        key = str(name)
        if key not in cls._instances:
            # the cast only satisfies the type checker, plain strings are passed to the API unchanged
            cls._instances[key] = cls(get_pretrained_embedding_model(cast(PretrainedEmbeddingModelName, name)))
        return cls._instances[key]

    @classmethod
    def exists(cls, name: str) -> bool:
        """
        Check if a pretrained embedding model exists by name

        Params:
            name: The name of the pretrained embedding model

        Returns:
            True if the pretrained embedding model exists, False otherwise
        """
        # `name in PretrainedEmbeddingModelName` raises a TypeError for plain strings before
        # Python 3.12, so check the member names like the class attribute lookup does.
        return name in PretrainedEmbeddingModelName.__members__

    def finetune(
        self,
        name: str,
        train_datasource: Datasource | LabeledMemoryset,
        *,
        eval_datasource: Datasource | None = None,
        label_column: str = "label",
        value_column: str = "value",
        training_method: EmbeddingFinetuningMethod | str = EmbeddingFinetuningMethod.CLASSIFICATION,
        training_args: dict | None = None,
        if_exists: CreateMode = "error",
    ) -> FinetunedEmbeddingModel:
        """
        Finetune an embedding model

        Params:
            name: Name of the finetuned embedding model
            train_datasource: Data to train on
            eval_datasource: Optionally provide data to evaluate on
            label_column: Column name of the label
            value_column: Column name of the value
            training_method: Training method to use
            training_args: Optional override for Hugging Face [`TrainingArguments`](transformers.TrainingArguments).
                If not provided, reasonable training arguments will be used for the specified training method
            if_exists: What to do if a finetuned embedding model with the same name already exists, defaults to
                `"error"`. Other option is `"open"` to open the existing finetuned embedding model.

        Returns:
            The finetuned embedding model

        Raises:
            ValueError: If the finetuned embedding model already exists and `if_exists` is `"error"` or if it is `"open"`
                but the base model param does not match the existing model
            TypeError: If `train_datasource` does not provide a datasource or memoryset id to train on

        Examples:
            >>> datasource = Datasource.open("my_datasource")
            >>> model = PretrainedEmbeddingModel.CLIP_BASE
            >>> model.finetune("my_finetuned_model", datasource)
        """
        exists = FinetunedEmbeddingModel.exists(name)

        if exists and if_exists == "error":
            raise ValueError(f"Finetuned embedding model '{name}' already exists")
        elif exists and if_exists == "open":
            existing = FinetunedEmbeddingModel.open(name)

            if existing.base_model_name != self._model_name:
                raise ValueError(f"Finetuned embedding model '{name}' already exists, but with different base model")

            return existing

        # the module level import is only available to the type checker, so the class is imported here
        from .memoryset import LabeledMemoryset

        train_datasource_id = train_datasource.id if isinstance(train_datasource, Datasource) else None
        train_memoryset_id = train_datasource.id if isinstance(train_datasource, LabeledMemoryset) else None
        # explicit check instead of an assert so the validation is not stripped when running with -O
        if train_datasource_id is None and train_memoryset_id is None:
            raise TypeError("train_datasource must be a Datasource or a LabeledMemoryset")
        res = create_finetuned_embedding_model(
            body=FinetuneEmbeddingModelRequest(
                name=name,
                base_model=self._model_name,
                train_memoryset_id=train_memoryset_id,
                train_datasource_id=train_datasource_id,
                eval_datasource_id=eval_datasource.id if eval_datasource is not None else None,
                label_column=label_column,
                value_column=value_column,
                training_method=EmbeddingFinetuningMethod(training_method),
                training_args=(FinetuneEmbeddingModelRequestTrainingArgs.from_dict(training_args or {})),
            ),
        )
        wait_for_task(res.finetuning_task_id, description="Finetuning embedding model")
        return FinetunedEmbeddingModel.open(res.id)
236
+
237
+
238
class FinetunedEmbeddingModel(_EmbeddingModel):
    """
    A finetuned embedding model in the OrcaCloud

    Attributes:
        name: Name of the finetuned embedding model
        embedding_dim: Dimension of the embeddings that are generated by the model
        max_seq_length: Maximum input length (in tokens not characters) that this model can process. Inputs that are longer will be truncated during the embedding process
        uses_context: Whether the model uses the memoryset to contextualize embeddings (acts akin to inverse document frequency in TFIDF features)
        id: Unique identifier of the finetuned embedding model
        base_model: Base model the finetuned embedding model was trained on
        base_model_name: Name of the base model the finetuned embedding model was trained on
        created_at: When the model was finetuned
        updated_at: Last update timestamp as reported in the model metadata
    """

    id: str
    created_at: datetime
    updated_at: datetime
    base_model_name: PretrainedEmbeddingModelName
    _status: TaskStatus

    def __init__(self, metadata: FinetunedEmbeddingModelMetadata):
        # for internal use only, do not document
        self.id = metadata.id
        self.created_at = metadata.created_at
        self.updated_at = metadata.updated_at
        self.base_model_name = metadata.base_model
        self._status = metadata.finetuning_status
        super().__init__(
            name=metadata.name,
            embedding_dim=metadata.embedding_dim,
            max_seq_length=metadata.max_seq_length,
            uses_context=metadata.uses_context,
        )

    def __eq__(self, other) -> bool:
        return isinstance(other, FinetunedEmbeddingModel) and self.id == other.id

    def __hash__(self) -> int:
        # Defining __eq__ without __hash__ makes instances unhashable, so hash on the same
        # key that __eq__ compares to keep models usable in sets and as dict keys.
        return hash(self.id)

    def __repr__(self) -> str:
        return (
            "FinetunedEmbeddingModel({\n"
            f"    name: {self.name},\n"
            f"    embedding_dim: {self.embedding_dim},\n"
            f"    max_seq_length: {self.max_seq_length},\n"
            f"    base_model: PretrainedEmbeddingModel.{self.base_model_name.value}\n"
            "})"
        )

    @property
    def base_model(self) -> PretrainedEmbeddingModel:
        """Pretrained model the finetuned embedding model was based on"""
        return PretrainedEmbeddingModel._get(self.base_model_name)

    @classmethod
    def all(cls) -> list[FinetunedEmbeddingModel]:
        """
        List all finetuned embedding model handles in the OrcaCloud

        Returns:
            A list of all finetuned embedding model handles in the OrcaCloud
        """
        return [cls(metadata) for metadata in list_finetuned_embedding_models()]

    @classmethod
    def open(cls, name: str) -> FinetunedEmbeddingModel:
        """
        Get a handle to a finetuned embedding model in the OrcaCloud

        Params:
            name: The name or unique identifier of a finetuned embedding model

        Returns:
            A handle to the finetuned embedding model in the OrcaCloud

        Raises:
            LookupError: If the finetuned embedding model does not exist
        """
        return cls(get_finetuned_embedding_model(name))

    @classmethod
    def exists(cls, name_or_id: str) -> bool:
        """
        Check if a finetuned embedding model with the given name or id exists.

        Params:
            name_or_id: The name or id of the finetuned embedding model

        Returns:
            True if the finetuned embedding model exists, False otherwise
        """
        try:
            cls.open(name_or_id)
        except LookupError:
            return False
        return True

    @classmethod
    def drop(cls, name_or_id: str, *, if_not_exists: DropMode = "error"):
        """
        Delete the finetuned embedding model from the OrcaCloud

        Params:
            name_or_id: The name or id of the finetuned embedding model
            if_not_exists: What to do if the finetuned embedding model does not exist, defaults to
                `"error"`. Any other value suppresses the lookup error.

        Raises:
            LookupError: If the finetuned embedding model does not exist and `if_not_exists` is `"error"`
        """
        try:
            delete_finetuned_embedding_model(name_or_id)
        except LookupError:
            if if_not_exists == "error":
                raise
@@ -0,0 +1,176 @@
1
+ from uuid import uuid4
2
+
3
+ import pytest
4
+
5
+ from .datasource import Datasource
6
+ from .embedding_model import (
7
+ FinetunedEmbeddingModel,
8
+ PretrainedEmbeddingModel,
9
+ PretrainedEmbeddingModelName,
10
+ TaskStatus,
11
+ )
12
+ from .memoryset import LabeledMemoryset
13
+
14
+
15
def test_open_pretrained_model():
    """Check the GTE_BASE handle's type, name and dimensions, and that repeated access yields the same object."""
    gte_base = PretrainedEmbeddingModel.GTE_BASE
    assert gte_base is not None
    assert isinstance(gte_base, PretrainedEmbeddingModel)
    assert gte_base.name == "GTE_BASE"
    assert gte_base.embedding_dim == 768
    assert gte_base.max_seq_length == 8192
    # Accessing the attribute a second time must hand back the identical object
    assert gte_base is PretrainedEmbeddingModel.GTE_BASE
23
+
24
+
25
def test_open_pretrained_model_unauthenticated(unauthenticated):
    """Under the `unauthenticated` fixture, embedding with GTE_BASE should raise ValueError("Invalid API key")."""
    expect_invalid_key = pytest.raises(ValueError, match="Invalid API key")
    with expect_invalid_key:
        PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline")
28
+
29
+
30
def test_open_pretrained_model_not_found():
    """Looking up an unknown pretrained model name via `_get` should raise LookupError."""
    unknown_name = "INVALID_MODEL"
    with pytest.raises(LookupError):
        PretrainedEmbeddingModel._get(unknown_name)
33
+
34
+
35
def test_all_pretrained_models():
    """`all()` should return as many handles as there are model names, each carrying a known name."""
    listed = PretrainedEmbeddingModel.all()
    assert len(listed) == len(PretrainedEmbeddingModelName)
    for handle in listed:
        assert handle.name in PretrainedEmbeddingModelName.__members__
39
+
40
+
41
def test_embed_text():
    """Embedding a single string with GTE_BASE should yield a list of 768 floats."""
    vector = PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline", max_seq_length=32)
    assert vector is not None
    assert isinstance(vector, list)
    assert len(vector) == 768
    # Only the first component is type-checked
    assert isinstance(vector[0], float)
47
+
48
+
49
def test_embed_text_unauthenticated(unauthenticated):
    """Under the `unauthenticated` fixture, `embed` with a sequence limit should raise ValueError("Invalid API key")."""
    expect_invalid_key = pytest.raises(ValueError, match="Invalid API key")
    with expect_invalid_key:
        PretrainedEmbeddingModel.GTE_BASE.embed("I love this airline", max_seq_length=32)
52
+
53
+
54
@pytest.fixture(scope="session")
def finetuned_model(datasource) -> FinetunedEmbeddingModel:
    """Session-scoped model "test_finetuned_model", finetuned from DISTILBERT on the `text` column of `datasource`."""
    base = PretrainedEmbeddingModel.DISTILBERT
    return base.finetune("test_finetuned_model", datasource, value_column="text")
57
+
58
+
59
def test_finetune_model_with_datasource(finetuned_model: FinetunedEmbeddingModel):
    """The fixture-created model should report its name, base model, dimensions and a completed status."""
    model = finetuned_model
    assert model is not None
    assert model.name == "test_finetuned_model"
    assert model.base_model == PretrainedEmbeddingModel.DISTILBERT
    assert model.embedding_dim == 768
    assert model.max_seq_length == 512
    # Reads the private status field to confirm the finetuning task finished
    assert model._status == TaskStatus.COMPLETED
66
+
67
+
68
def test_finetune_model_with_memoryset(memoryset: LabeledMemoryset):
    """Finetuning DISTILBERT directly from a memoryset should produce a completed model with the expected metadata."""
    model_name = "test_finetuned_model_from_memoryset"
    model = PretrainedEmbeddingModel.DISTILBERT.finetune(model_name, memoryset)
    assert model is not None
    assert model.name == "test_finetuned_model_from_memoryset"
    assert model.base_model == PretrainedEmbeddingModel.DISTILBERT
    assert model.embedding_dim == 768
    assert model.max_seq_length == 512
    assert model._status == TaskStatus.COMPLETED
76
+
77
+
78
def test_finetune_model_already_exists_error(datasource: Datasource, finetuned_model):
    """Finetuning again under the name used by the `finetuned_model` fixture should raise ValueError."""
    taken_name = "test_finetuned_model"
    with pytest.raises(ValueError):
        PretrainedEmbeddingModel.DISTILBERT.finetune(taken_name, datasource, value_column="text")
81
+
82
+
83
def test_finetune_model_already_exists_return(datasource: Datasource, finetuned_model):
    """With `if_exists="open"`, finetuning under an existing name should return the existing model."""
    existing_name = "test_finetuned_model"

    # Same name requested from GTE_BASE is expected to fail -- presumably because the
    # existing model was finetuned from DISTILBERT
    with pytest.raises(ValueError):
        PretrainedEmbeddingModel.GTE_BASE.finetune(existing_name, datasource, if_exists="open", value_column="text")

    reopened = PretrainedEmbeddingModel.DISTILBERT.finetune(
        existing_name, datasource, if_exists="open", value_column="text"
    )
    assert reopened is not None
    assert reopened.name == "test_finetuned_model"
    assert reopened.base_model == PretrainedEmbeddingModel.DISTILBERT
    assert reopened.embedding_dim == 768
    assert reopened.max_seq_length == 512
    assert reopened._status == TaskStatus.COMPLETED
98
+
99
+
100
def test_finetune_model_unauthenticated(unauthenticated, datasource: Datasource):
    """Under the `unauthenticated` fixture, starting a finetune should raise ValueError("Invalid API key")."""
    model_name = "test_finetuned_model_unauthenticated"
    with pytest.raises(ValueError, match="Invalid API key"):
        PretrainedEmbeddingModel.DISTILBERT.finetune(model_name, datasource, value_column="text")
105
+
106
+
107
def test_use_finetuned_model_in_memoryset(datasource: Datasource, finetuned_model: FinetunedEmbeddingModel):
    """A memoryset created with a finetuned embedding model should record that model and ingest every datasource row."""
    memoryset_name = "test_memoryset_finetuned_model"
    created = LabeledMemoryset.create(
        memoryset_name,
        datasource,
        embedding_model=finetuned_model,
        value_column="text",
    )
    assert created is not None
    assert created.name == "test_memoryset_finetuned_model"
    assert created.embedding_model == finetuned_model
    assert created.length == datasource.length
118
+
119
+
120
def test_open_finetuned_model(finetuned_model: FinetunedEmbeddingModel):
    """Opening the fixture model by name should return an equal handle with matching id, name and dimensions."""
    reopened = FinetunedEmbeddingModel.open(finetuned_model.name)
    assert isinstance(reopened, FinetunedEmbeddingModel)
    assert reopened.id == finetuned_model.id
    assert reopened.name == finetuned_model.name
    assert reopened.base_model == PretrainedEmbeddingModel.DISTILBERT
    assert reopened.embedding_dim == 768
    assert reopened.max_seq_length == 512
    # Handles are also expected to compare equal as whole objects
    assert reopened == finetuned_model
129
+
130
+
131
def test_embed_finetuned_model(finetuned_model: FinetunedEmbeddingModel):
    """Embedding a single string with the finetuned model should yield a list of 768 floats."""
    vector = finetuned_model.embed("I love this airline")
    assert vector is not None
    assert isinstance(vector, list)
    assert len(vector) == 768
    # Only the first component is type-checked
    assert isinstance(vector[0], float)
137
+
138
+
139
def test_all_finetuned_models(finetuned_model: FinetunedEmbeddingModel):
    """`FinetunedEmbeddingModel.all()` should be non-empty and include the fixture model's name."""
    listed = FinetunedEmbeddingModel.all()
    assert len(listed) > 0
    listed_names = [handle.name for handle in listed]
    assert finetuned_model.name in listed_names
143
+
144
+
145
def test_all_finetuned_models_unauthenticated(unauthenticated):
    """Under the `unauthenticated` fixture, listing finetuned models should raise ValueError("Invalid API key")."""
    expect_invalid_key = pytest.raises(ValueError, match="Invalid API key")
    with expect_invalid_key:
        FinetunedEmbeddingModel.all()
148
+
149
+
150
def test_all_finetuned_models_unauthorized(unauthorized, finetuned_model: FinetunedEmbeddingModel):
    """Under the `unauthorized` fixture, the fixture model should be absent from the listing."""
    visible_models = FinetunedEmbeddingModel.all()
    assert finetuned_model not in visible_models
152
+
153
+
154
def test_drop_finetuned_model(datasource: Datasource):
    """A freshly finetuned model can be opened, and after `drop` opening it raises LookupError."""
    doomed_name = "finetuned_model_to_delete"
    PretrainedEmbeddingModel.DISTILBERT.finetune(doomed_name, datasource, value_column="text")
    # Sanity check: the model is reachable before it is dropped
    assert FinetunedEmbeddingModel.open(doomed_name)
    FinetunedEmbeddingModel.drop(doomed_name)
    with pytest.raises(LookupError):
        FinetunedEmbeddingModel.open(doomed_name)
160
+
161
+
162
def test_drop_finetuned_model_unauthenticated(unauthenticated, datasource: Datasource):
    """
    Under the `unauthenticated` fixture, dropping a finetuned model should raise ValueError("Invalid API key").

    The previous body called `finetune(...)` instead of `drop(...)`, which duplicated
    `test_finetune_model_unauthenticated` and left the unauthenticated drop path untested.
    """
    # `datasource` is no longer used by the body; it stays in the signature so the
    # fixtures requested by this test are unchanged.
    # NOTE(review): assumes the API key is rejected before the model name is looked up,
    # so the model does not need to exist -- confirm against the service.
    with pytest.raises(ValueError, match="Invalid API key"):
        FinetunedEmbeddingModel.drop("finetuned_model_to_delete")
165
+
166
+
167
def test_drop_finetuned_model_not_found():
    """Dropping a random, unknown id raises LookupError unless `if_not_exists="ignore"` is passed."""
    first_unknown_id = str(uuid4())
    with pytest.raises(LookupError):
        FinetunedEmbeddingModel.drop(first_unknown_id)
    # A second random id: with "ignore" the missing model must not raise
    second_unknown_id = str(uuid4())
    FinetunedEmbeddingModel.drop(second_unknown_id, if_not_exists="ignore")
172
+
173
+
174
def test_drop_finetuned_model_unauthorized(unauthorized, finetuned_model: FinetunedEmbeddingModel):
    """Under the `unauthorized` fixture, dropping the fixture model by id should raise LookupError."""
    target_id = finetuned_model.id
    with pytest.raises(LookupError):
        FinetunedEmbeddingModel.drop(target_id)