orca-sdk 0.0.94__py3-none-any.whl → 0.0.95__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116) hide show
  1. orca_sdk/__init__.py +13 -4
  2. orca_sdk/_generated_api_client/api/__init__.py +80 -34
  3. orca_sdk/_generated_api_client/api/classification_model/create_classification_model_classification_model_post.py +170 -0
  4. orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
  5. orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
  6. orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
  7. orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
  8. orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
  9. orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
  10. orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
  11. orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
  12. orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
  13. orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +24 -0
  14. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
  15. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
  16. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
  17. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
  18. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
  19. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
  20. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
  21. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
  22. orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
  23. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
  24. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
  25. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
  26. orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
  27. orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
  28. orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
  29. orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
  30. orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_regression_model_post.py} +27 -27
  31. orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
  32. orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
  33. orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
  34. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
  35. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
  36. orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
  37. orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
  38. orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
  39. orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
  40. orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
  41. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
  42. orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
  43. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
  44. orca_sdk/_generated_api_client/models/__init__.py +84 -24
  45. orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
  46. orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
  47. orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
  48. orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
  49. orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
  50. orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
  51. orca_sdk/_generated_api_client/models/column_info.py +31 -0
  52. orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
  53. orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
  54. orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
  55. orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
  56. orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
  57. orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
  58. orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
  59. orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
  60. orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
  61. orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
  62. orca_sdk/_generated_api_client/models/memory_type.py +9 -0
  63. orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
  64. orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
  65. orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
  66. orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
  67. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
  68. orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
  69. orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
  70. orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
  71. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
  72. orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
  73. orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
  74. orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
  75. orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
  76. orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
  77. orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
  78. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
  79. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
  80. orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
  81. orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
  82. orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
  83. orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
  84. orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
  85. orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
  86. orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
  87. orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
  88. orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
  89. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
  90. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
  91. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
  92. orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
  93. orca_sdk/_shared/__init__.py +9 -1
  94. orca_sdk/_shared/metrics.py +257 -87
  95. orca_sdk/_shared/metrics_test.py +136 -77
  96. orca_sdk/_utils/data_parsing.py +0 -3
  97. orca_sdk/_utils/data_parsing_test.py +0 -3
  98. orca_sdk/_utils/prediction_result_ui.py +55 -23
  99. orca_sdk/classification_model.py +183 -175
  100. orca_sdk/classification_model_test.py +147 -157
  101. orca_sdk/conftest.py +76 -26
  102. orca_sdk/datasource_test.py +0 -1
  103. orca_sdk/embedding_model.py +136 -14
  104. orca_sdk/embedding_model_test.py +10 -6
  105. orca_sdk/job.py +329 -0
  106. orca_sdk/job_test.py +48 -0
  107. orca_sdk/memoryset.py +882 -161
  108. orca_sdk/memoryset_test.py +56 -23
  109. orca_sdk/regression_model.py +647 -0
  110. orca_sdk/regression_model_test.py +338 -0
  111. orca_sdk/telemetry.py +223 -106
  112. orca_sdk/telemetry_test.py +34 -30
  113. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.95.dist-info}/METADATA +2 -4
  114. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.95.dist-info}/RECORD +115 -69
  115. orca_sdk/_utils/task.py +0 -73
  116. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.95.dist-info}/WHEEL +0 -0
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from abc import abstractmethod
4
4
  from datetime import datetime
5
- from typing import TYPE_CHECKING, Sequence, cast, overload
5
+ from typing import TYPE_CHECKING, Literal, Sequence, cast, overload
6
6
 
7
7
  from ._generated_api_client.api import (
8
8
  create_finetuned_embedding_model,
@@ -24,8 +24,8 @@ from ._generated_api_client.models import (
24
24
  PretrainedEmbeddingModelName,
25
25
  )
26
26
  from ._utils.common import CreateMode, DropMode
27
- from ._utils.task import TaskStatus, wait_for_task
28
27
  from .datasource import Datasource
28
+ from .job import Job, Status
29
29
 
30
30
  if TYPE_CHECKING:
31
31
  from .memoryset import LabeledMemoryset
@@ -79,15 +79,61 @@ class _EmbeddingModel:
79
79
  return embeddings if isinstance(value, list) else embeddings[0]
80
80
 
81
81
 
82
- class _PretrainedEmbeddingModelMeta(type):
83
- def __getattr__(cls, name: str) -> PretrainedEmbeddingModel:
84
- if cls != FinetunedEmbeddingModel and name in PretrainedEmbeddingModelName.__members__:
85
- return PretrainedEmbeddingModel._get(name)
86
- else:
87
- raise AttributeError(f"'{cls.__name__}' object has no attribute '{name}'")
82
+ class _ModelDescriptor:
83
+ """
84
+ Descriptor for lazily loading embedding models with IDE autocomplete support.
88
85
 
86
+ This class implements the descriptor protocol to provide lazy loading of embedding models
87
+ while maintaining IDE autocomplete functionality. It delays the actual loading of models
88
+ until they are accessed, which improves startup performance.
89
89
 
90
- class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingModelMeta):
90
+ The descriptor pattern works by defining how attribute access is handled. When a class
91
+ attribute using this descriptor is accessed, the __get__ method is called, which then
92
+ retrieves or initializes the actual model on first access.
93
+ """
94
+
95
+ def __init__(self, name: str):
96
+ """
97
+ Initialize a model descriptor.
98
+
99
+ Args:
100
+ name: The name of the embedding model in PretrainedEmbeddingModelName
101
+ """
102
+ self.name = name
103
+ self.model = None # Model is loaded lazily on first access
104
+
105
+ def __get__(self, instance, owner_class):
106
+ """
107
+ Descriptor protocol method called when the attribute is accessed.
108
+
109
+ This method implements lazy loading - the actual model is only initialized
110
+ the first time it's accessed. Subsequent accesses will use the cached model.
111
+
112
+ Args:
113
+ instance: The instance the attribute was accessed from, or None if accessed from the class
114
+ owner_class: The class that owns the descriptor
115
+
116
+ Returns:
117
+ The initialized embedding model
118
+
119
+ Raises:
120
+ AttributeError: If no model with the given name exists
121
+ """
122
+ # When accessed from an instance, redirect to class access
123
+ if instance is not None:
124
+ return self.__get__(None, owner_class)
125
+
126
+ # Load the model on first access
127
+ if self.model is None:
128
+ try:
129
+ self.model = PretrainedEmbeddingModel._get(self.name)
130
+ except (KeyError, AttributeError):
131
+ raise AttributeError(f"No embedding model named {self.name}")
132
+
133
+ return self.model
134
+
135
+
136
+ class PretrainedEmbeddingModel(_EmbeddingModel):
91
137
  """
92
138
  A pretrained embedding model
93
139
 
@@ -102,6 +148,11 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
102
148
  - **`GTE_BASE`**: Alibaba's GTE model from Hugging Face ([Alibaba-NLP/gte-base-en-v1.5](https://huggingface.co/Alibaba-NLP/gte-base-en-v1.5))
103
149
  - **`DISTILBERT`**: DistilBERT embedding model from Hugging Face ([distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased))
104
150
  - **`GTE_SMALL`**: GTE-Small embedding model from Hugging Face ([Supabase/gte-small](https://huggingface.co/Supabase/gte-small))
151
+ - **`E5_LARGE`**: E5-Large instruction-tuned embedding model from Hugging Face ([intfloat/multilingual-e5-large-instruct](https://huggingface.co/intfloat/multilingual-e5-large-instruct))
152
+ - **`GIST_LARGE`**: GIST-Large embedding model from Hugging Face ([avsolatorio/GIST-large-Embedding-v0](https://huggingface.co/avsolatorio/GIST-large-Embedding-v0))
153
+ - **`MXBAI_LARGE`**: Mixbreas's Large embedding model from Hugging Face ([mixedbread-ai/mxbai-embed-large-v1](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1))
154
+ - **`QWEN2_1_5B`**: Alibaba's Qwen2-1.5B instruction-tuned embedding model from Hugging Face ([Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct))
155
+
105
156
 
106
157
  Examples:
107
158
  >>> PretrainedEmbeddingModel.CDE_SMALL
@@ -114,6 +165,17 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
114
165
  uses_context: Whether the pretrained embedding model uses context
115
166
  """
116
167
 
168
+ # Define descriptors for model access with IDE autocomplete
169
+ CDE_SMALL = _ModelDescriptor("CDE_SMALL")
170
+ CLIP_BASE = _ModelDescriptor("CLIP_BASE")
171
+ GTE_BASE = _ModelDescriptor("GTE_BASE")
172
+ DISTILBERT = _ModelDescriptor("DISTILBERT")
173
+ GTE_SMALL = _ModelDescriptor("GTE_SMALL")
174
+ E5_LARGE = _ModelDescriptor("E5_LARGE")
175
+ GIST_LARGE = _ModelDescriptor("GIST_LARGE")
176
+ MXBAI_LARGE = _ModelDescriptor("MXBAI_LARGE")
177
+ QWEN2_1_5B = _ModelDescriptor("QWEN2_1_5B")
178
+
117
179
  _model_name: PretrainedEmbeddingModelName
118
180
 
119
181
  def __init__(self, metadata: PretrainedEmbeddingModelMetadata):
@@ -151,6 +213,29 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
151
213
  cls._instances[str(name)] = cls(get_pretrained_embedding_model(cast(PretrainedEmbeddingModelName, name)))
152
214
  return cls._instances[str(name)]
153
215
 
216
+ @classmethod
217
+ def open(cls, name: str) -> PretrainedEmbeddingModel:
218
+ """
219
+ Open an embedding model by name.
220
+
221
+ This is an alternative method to access models for environments
222
+ where IDE autocomplete for model names is not available.
223
+
224
+ Params:
225
+ name: Name of the model to open (e.g., "GTE_BASE", "CLIP_BASE")
226
+
227
+ Returns:
228
+ The embedding model instance
229
+
230
+ Examples:
231
+ >>> model = PretrainedEmbeddingModel.open("GTE_BASE")
232
+ """
233
+ try:
234
+ # Use getattr to access the descriptor which will initialize the model
235
+ return getattr(cls, name)
236
+ except AttributeError:
237
+ raise ValueError(f"Unknown model name: {name}")
238
+
154
239
  @classmethod
155
240
  def exists(cls, name: str) -> bool:
156
241
  """
@@ -164,6 +249,23 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
164
249
  """
165
250
  return name in PretrainedEmbeddingModelName
166
251
 
252
+ @overload
253
+ def finetune(
254
+ self,
255
+ name: str,
256
+ train_datasource: Datasource | LabeledMemoryset,
257
+ *,
258
+ eval_datasource: Datasource | None = None,
259
+ label_column: str = "label",
260
+ value_column: str = "value",
261
+ training_method: EmbeddingFinetuningMethod | str = EmbeddingFinetuningMethod.CLASSIFICATION,
262
+ training_args: dict | None = None,
263
+ if_exists: CreateMode = "error",
264
+ background: Literal[True],
265
+ ) -> Job[FinetunedEmbeddingModel]:
266
+ pass
267
+
268
+ @overload
167
269
  def finetune(
168
270
  self,
169
271
  name: str,
@@ -175,7 +277,23 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
175
277
  training_method: EmbeddingFinetuningMethod | str = EmbeddingFinetuningMethod.CLASSIFICATION,
176
278
  training_args: dict | None = None,
177
279
  if_exists: CreateMode = "error",
280
+ background: Literal[False] = False,
178
281
  ) -> FinetunedEmbeddingModel:
282
+ pass
283
+
284
+ def finetune(
285
+ self,
286
+ name: str,
287
+ train_datasource: Datasource | LabeledMemoryset,
288
+ *,
289
+ eval_datasource: Datasource | None = None,
290
+ label_column: str = "label",
291
+ value_column: str = "value",
292
+ training_method: EmbeddingFinetuningMethod | str = EmbeddingFinetuningMethod.CLASSIFICATION,
293
+ training_args: dict | None = None,
294
+ if_exists: CreateMode = "error",
295
+ background: bool = False,
296
+ ) -> FinetunedEmbeddingModel | Job[FinetunedEmbeddingModel]:
179
297
  """
180
298
  Finetune an embedding model
181
299
 
@@ -190,6 +308,7 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
190
308
  If not provided, reasonable training arguments will be used for the specified training method
191
309
  if_exists: What to do if a finetuned embedding model with the same name already exists, defaults to
192
310
  `"error"`. Other option is `"open"` to open the existing finetuned embedding model.
311
+ background: Whether to run the operation in the background and return a job handle
193
312
 
194
313
  Returns:
195
314
  The finetuned embedding model
@@ -233,8 +352,11 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
233
352
  training_args=(FinetuneEmbeddingModelRequestTrainingArgs.from_dict(training_args or {})),
234
353
  ),
235
354
  )
236
- wait_for_task(res.finetuning_task_id, description="Finetuning embedding model")
237
- return FinetunedEmbeddingModel.open(res.id)
355
+ job = Job(
356
+ res.finetuning_task_id,
357
+ lambda: FinetunedEmbeddingModel.open(res.id),
358
+ )
359
+ return job if background else job.result()
238
360
 
239
361
 
240
362
  class FinetunedEmbeddingModel(_EmbeddingModel):
@@ -254,7 +376,7 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
254
376
  id: str
255
377
  created_at: datetime
256
378
  updated_at: datetime
257
- _status: TaskStatus
379
+ _status: Status
258
380
 
259
381
  def __init__(self, metadata: FinetunedEmbeddingModelMetadata):
260
382
  # for internal use only, do not document
@@ -262,7 +384,7 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
262
384
  self.created_at = metadata.created_at
263
385
  self.updated_at = metadata.updated_at
264
386
  self.base_model_name = metadata.base_model
265
- self._status = metadata.finetuning_status
387
+ self._status = Status(metadata.finetuning_status.value)
266
388
  super().__init__(
267
389
  name=metadata.name,
268
390
  embedding_dim=metadata.embedding_dim,
@@ -344,6 +466,6 @@ class FinetunedEmbeddingModel(_EmbeddingModel):
344
466
  """
345
467
  try:
346
468
  delete_finetuned_embedding_model(name_or_id)
347
- except LookupError:
469
+ except (LookupError, RuntimeError):
348
470
  if if_not_exists == "error":
349
471
  raise
@@ -1,3 +1,4 @@
1
+ import logging
1
2
  from uuid import uuid4
2
3
 
3
4
  import pytest
@@ -7,8 +8,8 @@ from .embedding_model import (
7
8
  FinetunedEmbeddingModel,
8
9
  PretrainedEmbeddingModel,
9
10
  PretrainedEmbeddingModelName,
10
- TaskStatus,
11
11
  )
12
+ from .job import Status
12
13
  from .memoryset import LabeledMemoryset
13
14
 
14
15
 
@@ -34,8 +35,11 @@ def test_open_pretrained_model_not_found():
34
35
 
35
36
  def test_all_pretrained_models():
36
37
  models = PretrainedEmbeddingModel.all()
37
- assert len(models) == len(PretrainedEmbeddingModelName)
38
- assert all(m.name in PretrainedEmbeddingModelName.__members__ for m in models)
38
+ assert len(models) > 1
39
+ if len(models) != len(PretrainedEmbeddingModelName):
40
+ logging.warning("Please regenerate the SDK client! Some pretrained model names are not exposed yet.")
41
+ model_names = [m.name for m in models]
42
+ assert all(enum_member in model_names for enum_member in PretrainedEmbeddingModelName.__members__)
39
43
 
40
44
 
41
45
  def test_embed_text():
@@ -62,7 +66,7 @@ def test_finetune_model_with_datasource(finetuned_model: FinetunedEmbeddingModel
62
66
  assert finetuned_model.base_model == PretrainedEmbeddingModel.DISTILBERT
63
67
  assert finetuned_model.embedding_dim == 768
64
68
  assert finetuned_model.max_seq_length == 512
65
- assert finetuned_model._status == TaskStatus.COMPLETED
69
+ assert finetuned_model._status == Status.COMPLETED
66
70
 
67
71
 
68
72
  def test_finetune_model_with_memoryset(readonly_memoryset: LabeledMemoryset):
@@ -74,7 +78,7 @@ def test_finetune_model_with_memoryset(readonly_memoryset: LabeledMemoryset):
74
78
  assert finetuned_model.base_model == PretrainedEmbeddingModel.DISTILBERT
75
79
  assert finetuned_model.embedding_dim == 768
76
80
  assert finetuned_model.max_seq_length == 512
77
- assert finetuned_model._status == TaskStatus.COMPLETED
81
+ assert finetuned_model._status == Status.COMPLETED
78
82
 
79
83
 
80
84
  def test_finetune_model_already_exists_error(datasource: Datasource, finetuned_model):
@@ -96,7 +100,7 @@ def test_finetune_model_already_exists_return(datasource: Datasource, finetuned_
96
100
  assert new_model.base_model == PretrainedEmbeddingModel.DISTILBERT
97
101
  assert new_model.embedding_dim == 768
98
102
  assert new_model.max_seq_length == 512
99
- assert new_model._status == TaskStatus.COMPLETED
103
+ assert new_model._status == Status.COMPLETED
100
104
 
101
105
 
102
106
  def test_finetune_model_unauthenticated(unauthenticated, datasource: Datasource):
orca_sdk/job.py ADDED
@@ -0,0 +1,329 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from datetime import datetime, timedelta
5
+ from enum import Enum
6
+ from typing import Callable, Generic, TypedDict, TypeVar, cast
7
+
8
+ from tqdm.auto import tqdm
9
+
10
+ from ._generated_api_client.api import abort_task, get_task, get_task_status, list_tasks
11
+ from ._generated_api_client.models import TaskStatus
12
+
13
+
14
+ class JobConfig(TypedDict):
15
+ refresh_interval: int
16
+ show_progress: bool
17
+ max_wait: int
18
+
19
+
20
+ class Status(Enum):
21
+ """Status of a cloud job in the task queue"""
22
+
23
+ # the INITIALIZED state should never be returned by the API
24
+
25
+ DISPATCHED = "DISPATCHED"
26
+ """The job has been queued and is waiting to be processed"""
27
+
28
+ PROCESSING = "PROCESSING"
29
+ """The job is being processed"""
30
+
31
+ COMPLETED = "COMPLETED"
32
+ """The job has been completed successfully"""
33
+
34
+ FAILED = "FAILED"
35
+ """The job has failed"""
36
+
37
+ ABORTING = "ABORTING"
38
+ """The job is being aborted"""
39
+
40
+ ABORTED = "ABORTED"
41
+ """The job has been aborted"""
42
+
43
+
44
+ TResult = TypeVar("TResult")
45
+
46
+
47
+ class Job(Generic[TResult]):
48
+ """
49
+ Handle to a job that is run in the OrcaCloud
50
+
51
+ Attributes:
52
+ id: Unique identifier for the job
53
+ type: Type of the job
54
+ status: Current status of the job
55
+ steps_total: Total number of steps in the job, present if the job started processing
56
+ steps_completed: Number of steps completed in the job, present if the job started processing
57
+ completion: Percentage of the job that has been completed, present if the job started processing
58
+ exception: Exception that occurred during the job, present if the status is `FAILED`
59
+ value: Value of the result of the job, present if the status is `COMPLETED`
60
+ created_at: When the job was queued for processing
61
+ updated_at: When the job was last updated
62
+ refreshed_at: When the job status was last refreshed
63
+
64
+ Note:
65
+ Accessing status and related attributes will refresh the job status in the background.
66
+ """
67
+
68
+ id: str
69
+ type: str
70
+ status: Status
71
+ steps_total: int | None
72
+ steps_completed: int | None
73
+ exception: str | None
74
+ value: TResult | None
75
+ updated_at: datetime
76
+ created_at: datetime
77
+ refreshed_at: datetime
78
+
79
+ @property
80
+ def completion(self) -> float:
81
+ """
82
+ Percentage of the job that has been completed, present if the job started processing
83
+ """
84
+ return (self.steps_completed or 0) / self.steps_total if self.steps_total is not None else 0
85
+
86
+ # Global configuration for all jobs
87
+ config: JobConfig = {
88
+ "refresh_interval": 3,
89
+ "show_progress": True,
90
+ "max_wait": 60 * 60,
91
+ }
92
+
93
+ def __repr__(self) -> str:
94
+ return "Job({" + f" type: {self.type}, status: {self.status}, completion: {self.completion:.0%} " + "})"
95
+
96
+ @classmethod
97
+ def set_config(
98
+ cls, *, refresh_interval: int | None = None, show_progress: bool | None = None, max_wait: int | None = None
99
+ ):
100
+ """
101
+ Set global configuration for running jobs
102
+
103
+ Args:
104
+ refresh_interval: Time to wait between polling the job status in seconds, default is 3
105
+ show_progress: Whether to show a progress bar when calling the wait method, default is True
106
+ max_wait: Maximum time to wait for the job to complete in seconds, default is 1 hour
107
+ """
108
+ if refresh_interval is not None:
109
+ cls.config["refresh_interval"] = refresh_interval
110
+ if show_progress is not None:
111
+ cls.config["show_progress"] = show_progress
112
+ if max_wait is not None:
113
+ cls.config["max_wait"] = max_wait
114
+
115
+ @classmethod
116
+ def query(
117
+ cls,
118
+ status: Status | list[Status] | None = None,
119
+ type: str | list[str] | None = None,
120
+ limit: int | None = None,
121
+ offset: int = 0,
122
+ start: datetime | None = None,
123
+ end: datetime | None = None,
124
+ ) -> list[Job]:
125
+ """
126
+ Query the job queue for jobs matching the given filters
127
+
128
+ Args:
129
+ status: Optional status or list of statuses to filter by
130
+ type: Optional type or list of types to filter by
131
+ limit: Maximum number of jobs to return
132
+ offset: Offset into the list of jobs to return
133
+ start: Optional minimum creation time of the jobs to query for
134
+ end: Optional maximum creation time of the jobs to query for
135
+
136
+ Returns:
137
+ List of jobs matching the given filters
138
+ """
139
+ tasks = list_tasks(
140
+ status=(
141
+ [TaskStatus(s.value) for s in status]
142
+ if isinstance(status, list)
143
+ else TaskStatus(status.value) if isinstance(status, Status) else None
144
+ ),
145
+ type=type,
146
+ limit=limit,
147
+ offset=offset,
148
+ start_timestamp=start,
149
+ end_timestamp=end,
150
+ )
151
+
152
+ # can't use constructor because it makes an API call, so we construct the objects manually
153
+ return [
154
+ (
155
+ lambda t: (
156
+ obj := cls.__new__(cls),
157
+ setattr(obj, "id", t.id),
158
+ setattr(obj, "type", t.type),
159
+ setattr(obj, "status", Status(t.status.value)),
160
+ setattr(obj, "steps_total", t.steps_total),
161
+ setattr(obj, "steps_completed", t.steps_completed),
162
+ setattr(obj, "exception", t.exception),
163
+ setattr(obj, "value", cast(TResult, t.result.to_dict()) if t.result is not None else None),
164
+ setattr(obj, "updated_at", t.updated_at),
165
+ setattr(obj, "created_at", t.created_at),
166
+ setattr(obj, "refreshed_at", datetime.now()),
167
+ obj,
168
+ )[-1]
169
+ )(t)
170
+ for t in tasks
171
+ ]
172
+
173
+ def __init__(self, id: str, get_value: Callable[[], TResult | None] | None = None):
174
+ """
175
+ Create a handle to a job in the job queue
176
+
177
+ Args:
178
+ id: Unique identifier for the job
179
+ get_value: Optional function to customize how the value is resolved, if not provided the result will be a dict
180
+ """
181
+ self.id = id
182
+ task = get_task(self.id)
183
+
184
+ self._get_value = get_value or (
185
+ lambda: (r := get_task(id).result) and (cast(TResult, r.to_dict()) if r else None)
186
+ )
187
+ self.type = task.type
188
+ self.status = Status(task.status.value)
189
+ self.steps_total = task.steps_total
190
+ self.steps_completed = task.steps_completed
191
+ self.exception = task.exception
192
+ self.value = (
193
+ None
194
+ if task.status != TaskStatus.COMPLETED
195
+ else (
196
+ get_value()
197
+ if get_value is not None
198
+ else cast(TResult, task.result.to_dict()) if task.result is not None else None
199
+ )
200
+ )
201
+ self.updated_at = task.updated_at
202
+ self.created_at = task.created_at
203
+ self.refreshed_at = datetime.now()
204
+
205
+ def refresh(self, throttle: float = 0):
206
+ """
207
+ Refresh the status and progress of the job
208
+
209
+ Params:
210
+ throttle: Minimum time in seconds between refreshes
211
+ """
212
+ current_time = datetime.now()
213
+ # Skip refresh if last refresh was too recent
214
+ if (current_time - self.refreshed_at) < timedelta(seconds=throttle):
215
+ return
216
+ self.refreshed_at = current_time
217
+
218
+ status_info = get_task_status(self.id)
219
+ self.status = Status(status_info.status.value)
220
+ if status_info.steps_total is not None:
221
+ self.steps_total = status_info.steps_total
222
+ if status_info.steps_completed is not None:
223
+ self.steps_completed = status_info.steps_completed
224
+
225
+ self.exception = status_info.exception
226
+ self.updated_at = status_info.updated_at
227
+
228
+ if status_info.status == TaskStatus.COMPLETED:
229
+ self.value = self._get_value()
230
+
231
+ def __getattribute__(self, name: str):
232
+ # if the attribute is not immutable, refresh the job if it hasn't been refreshed recently
233
+ if name in ["status", "updated_at", "steps_total", "steps_completed", "exception", "value"]:
234
+ self.refresh(self.config["refresh_interval"])
235
+ return super().__getattribute__(name)
236
+
237
    def wait(
        self, show_progress: bool | None = None, refresh_interval: int | None = None, max_wait: int | None = None
    ) -> None:
        """
        Block until the job is complete

        Params:
            show_progress: Show a progress bar while waiting for the job to complete
            refresh_interval: Polling interval in seconds while waiting for the job to complete
            max_wait: Maximum time to wait for the job to complete in seconds

        Note:
            The defaults for the config parameters can be set globally using the
            [`set_config`][orca_sdk.Job.set_config] method.

            This method will not return the result or raise an exception if the job fails. Call
            [`result`][orca_sdk.Job.result] instead if you want to get the result.

        Raises:
            RuntimeError: If the job times out
        """
        start_time = time.time()
        # fall back to the globally configured defaults for any parameter not given explicitly
        show_progress = show_progress if show_progress is not None else self.config["show_progress"]
        refresh_interval = refresh_interval if refresh_interval is not None else self.config["refresh_interval"]
        max_wait = max_wait if max_wait is not None else self.config["max_wait"]
        pbar = None
        # NOTE: reads of self.status / self.steps_* below go through __getattribute__,
        # which transparently refreshes the job state (throttled by the refresh interval)
        while True:
            # setup progress bar if steps total is known
            if not pbar and self.steps_total is not None and show_progress:
                # e.g. "EVALUATE_MODEL" -> "evaluate model" for the bar description
                desc = " ".join(self.type.split("_")).lower()
                pbar = tqdm(total=self.steps_total, desc=desc)

            # return if job is complete
            if self.status in [Status.COMPLETED, Status.FAILED, Status.ABORTED]:
                if pbar:
                    # fill the bar to 100% before closing so it doesn't end partway
                    pbar.update(self.steps_total - pbar.n)
                    pbar.close()
                return

            # raise error if job timed out
            if (time.time() - start_time) > max_wait:
                raise RuntimeError(f"Job {self.id} timed out after {max_wait}s")

            # update progress bar
            if pbar and self.steps_completed is not None:
                pbar.update(self.steps_completed - pbar.n)

            # sleep before retrying
            time.sleep(refresh_interval)
286
+
287
+ def result(
288
+ self, show_progress: bool | None = None, refresh_interval: int | None = None, max_wait: int | None = None
289
+ ) -> TResult:
290
+ """
291
+ Block until the job is complete and return the result value
292
+
293
+ Params:
294
+ show_progress: Show a progress bar while waiting for the job to complete
295
+ refresh_interval: Polling interval in seconds while waiting for the job to complete
296
+ max_wait: Maximum time to wait for the job to complete in seconds
297
+
298
+ Note:
299
+ The defaults for the config parameters can be set globally using the
300
+ [`set_config`][orca_sdk.Job.set_config] method.
301
+
302
+ This method will raise an exception if the job fails. Use [`wait`][orca_sdk.Job.wait]
303
+ if you just want to wait for the job to complete without raising errors on failure.
304
+
305
+ Returns:
306
+ The result value of the job
307
+
308
+ Raises:
309
+ RuntimeError: If the job fails or times out
310
+ """
311
+ if self.value is not None:
312
+ return self.value
313
+ self.wait(show_progress, refresh_interval, max_wait)
314
+ if self.status != Status.COMPLETED:
315
+ raise RuntimeError(f"Job failed with exception: {self.exception}")
316
+ assert self.value is not None
317
+ return self.value
318
+
319
+ def abort(self, show_progress: bool = False, refresh_interval: int = 1, max_wait: int = 20) -> None:
320
+ """
321
+ Abort the job
322
+
323
+ Params:
324
+ show_progress: Optionally show a progress bar while waiting for the job to abort
325
+ refresh_interval: Polling interval in seconds while waiting for the job to abort
326
+ max_wait: Maximum time to wait for the job to abort in seconds
327
+ """
328
+ abort_task(self.id)
329
+ self.wait(show_progress, refresh_interval, max_wait)
orca_sdk/job_test.py ADDED
@@ -0,0 +1,48 @@
1
+ import time
2
+
3
+ from .classification_model import ClassificationModel
4
+ from .datasource import Datasource
5
+ from .job import Job, Status
6
+
7
+
8
def test_job_creation(classification_model: ClassificationModel, datasource: Datasource):
    """Launching a background evaluation hands back a fully populated job handle."""
    job = classification_model.evaluate(datasource, background=True)
    assert job.id is not None
    assert job.type == "EVALUATE_MODEL"
    # the job starts in a non-terminal state
    assert job.status in (Status.DISPATCHED, Status.PROCESSING)
    for timestamp in (job.created_at, job.updated_at, job.refreshed_at):
        assert timestamp is not None
    # the new job should show up when querying by type
    matching = Job.query(limit=5, type="EVALUATE_MODEL")
    assert len(matching) >= 1
17
+
18
+
19
def test_job_result(classification_model: ClassificationModel, datasource: Datasource):
    """result() blocks until completion and returns the evaluation payload."""
    job = classification_model.evaluate(datasource, background=True)
    assert job.result(show_progress=False) is not None
    assert job.status == Status.COMPLETED
    # progress counters should be final and consistent once complete
    assert job.steps_completed is not None
    assert job.steps_completed == job.steps_total
26
+
27
+
28
def test_job_wait(classification_model: ClassificationModel, datasource: Datasource):
    """wait() blocks until a terminal state and leaves the value populated."""
    job = classification_model.evaluate(datasource, background=True)
    job.wait(show_progress=False)
    assert job.status == Status.COMPLETED
    # once complete, progress counters agree and the result value is available
    assert job.steps_completed is not None
    assert job.steps_completed == job.steps_total
    assert job.value is not None
35
+
36
+
37
def test_job_refresh(classification_model: ClassificationModel, datasource: Datasource):
    """Attribute access lazily refreshes the job; refresh() does so immediately."""
    job = classification_model.evaluate(datasource, background=True)
    previous = job.refreshed_at
    # reading a mutable attribute should trigger a refresh once the interval has elapsed
    # NOTE(review): set_config mutates global config and is never restored here,
    # which may leak into other tests — confirm this is intended
    Job.set_config(refresh_interval=1)
    time.sleep(1)
    job.status
    assert job.refreshed_at > previous
    previous = job.refreshed_at
    # an explicit refresh() call uses throttle=0 and refreshes immediately
    job.refresh()
    assert job.refreshed_at > previous