orca-sdk 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. orca_sdk/__init__.py +30 -0
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +634 -0
  4. orca_sdk/_shared/metrics_test.py +570 -0
  5. orca_sdk/_utils/__init__.py +0 -0
  6. orca_sdk/_utils/analysis_ui.py +196 -0
  7. orca_sdk/_utils/analysis_ui_style.css +51 -0
  8. orca_sdk/_utils/auth.py +65 -0
  9. orca_sdk/_utils/auth_test.py +31 -0
  10. orca_sdk/_utils/common.py +37 -0
  11. orca_sdk/_utils/data_parsing.py +129 -0
  12. orca_sdk/_utils/data_parsing_test.py +244 -0
  13. orca_sdk/_utils/pagination.py +126 -0
  14. orca_sdk/_utils/pagination_test.py +132 -0
  15. orca_sdk/_utils/prediction_result_ui.css +18 -0
  16. orca_sdk/_utils/prediction_result_ui.py +110 -0
  17. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  18. orca_sdk/_utils/value_parser.py +45 -0
  19. orca_sdk/_utils/value_parser_test.py +39 -0
  20. orca_sdk/async_client.py +4104 -0
  21. orca_sdk/classification_model.py +1165 -0
  22. orca_sdk/classification_model_test.py +887 -0
  23. orca_sdk/client.py +4096 -0
  24. orca_sdk/conftest.py +382 -0
  25. orca_sdk/credentials.py +217 -0
  26. orca_sdk/credentials_test.py +121 -0
  27. orca_sdk/datasource.py +576 -0
  28. orca_sdk/datasource_test.py +463 -0
  29. orca_sdk/embedding_model.py +712 -0
  30. orca_sdk/embedding_model_test.py +206 -0
  31. orca_sdk/job.py +343 -0
  32. orca_sdk/job_test.py +108 -0
  33. orca_sdk/memoryset.py +3811 -0
  34. orca_sdk/memoryset_test.py +1150 -0
  35. orca_sdk/regression_model.py +841 -0
  36. orca_sdk/regression_model_test.py +595 -0
  37. orca_sdk/telemetry.py +742 -0
  38. orca_sdk/telemetry_test.py +119 -0
  39. orca_sdk-0.1.9.dist-info/METADATA +98 -0
  40. orca_sdk-0.1.9.dist-info/RECORD +41 -0
  41. orca_sdk-0.1.9.dist-info/WHEEL +4 -0
@@ -0,0 +1,712 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from datetime import datetime
5
+ from typing import TYPE_CHECKING, Literal, Sequence, cast, get_args, overload
6
+
7
+ from ._shared.metrics import ClassificationMetrics, RegressionMetrics
8
+ from ._utils.common import UNSET, CreateMode, DropMode
9
+ from .client import (
10
+ EmbeddingEvaluationRequest,
11
+ EmbeddingFinetuningMethod,
12
+ EmbedRequest,
13
+ FinetunedEmbeddingModelMetadata,
14
+ FinetuneEmbeddingModelRequest,
15
+ OrcaClient,
16
+ PretrainedEmbeddingModelMetadata,
17
+ PretrainedEmbeddingModelName,
18
+ )
19
+ from .datasource import Datasource
20
+ from .job import Job, Status
21
+
22
+ if TYPE_CHECKING:
23
+ from .memoryset import LabeledMemoryset, ScoredMemoryset
24
+
25
+
26
class EmbeddingModelBase(ABC):
    """
    Abstract base class shared by pretrained and finetuned embedding models.

    Holds the metadata attributes common to both model kinds and implements
    `embed` and `evaluate`, which dispatch to different API routes depending on
    whether `self` is a `PretrainedEmbeddingModel` or a `FinetunedEmbeddingModel`.
    """

    # Shared metadata, populated by __init__ from subclass-provided values.
    embedding_dim: int  # dimension of the embeddings the model generates
    max_seq_length: int  # maximum input length in tokens; longer inputs are truncated
    num_params: int  # number of parameters in the model
    uses_context: bool  # whether the model uses context to embed values
    supports_instructions: bool  # whether the model supports instruction/prompt-following

    def __init__(
        self,
        *,
        name: str,
        embedding_dim: int,
        max_seq_length: int,
        num_params: int,
        uses_context: bool,
        supports_instructions: bool,
    ):
        # NOTE(review): `name` is accepted but intentionally not stored here —
        # both subclasses assign `self.name` themselves (with their own, more
        # specific type annotation) before calling super().__init__().
        self.embedding_dim = embedding_dim
        self.max_seq_length = max_seq_length
        self.num_params = num_params
        self.uses_context = uses_context
        self.supports_instructions = supports_instructions

    @classmethod
    @abstractmethod
    def all(cls) -> Sequence[EmbeddingModelBase]:
        """List all embedding models of this kind available in the OrcaCloud."""
        pass

    def _get_instruction_error_message(self) -> str:
        """Get error message for instruction not supported"""
        if isinstance(self, FinetunedEmbeddingModel):
            return f"Model {self.name} does not support instructions. Instruction-following is only supported by models based on instruction-supporting models."
        elif isinstance(self, PretrainedEmbeddingModel):
            return f"Model {self.name} does not support instructions. Instruction-following is only supported by instruction-supporting models."
        else:
            raise ValueError("Invalid embedding model")

    @overload
    def embed(self, value: str, max_seq_length: int | None = None, prompt: str | None = None) -> list[float]:
        pass

    @overload
    def embed(
        self, value: list[str], max_seq_length: int | None = None, prompt: str | None = None
    ) -> list[list[float]]:
        pass

    def embed(
        self, value: str | list[str], max_seq_length: int | None = None, prompt: str | None = None
    ) -> list[float] | list[list[float]]:
        """
        Generate embeddings for a value or list of values

        Params:
            value: The value or list of values to embed
            max_seq_length: The maximum sequence length to truncate the input to
            prompt: Optional prompt for prompt-following embedding models.

        Returns:
            A matrix of floats representing the embedding for each value if the input is a list of
            values, or a list of floats representing the embedding for the single value if the
            input is a single value

        Raises:
            ValueError: If `self` is neither a pretrained nor a finetuned embedding model
        """
        # The API always takes a list of values; a single value is wrapped here
        # and unwrapped again before returning.
        payload: EmbedRequest = {
            "values": value if isinstance(value, list) else [value],
            "max_seq_length": max_seq_length,
            "prompt": prompt,
        }
        client = OrcaClient._resolve_client()
        # Dispatch to the route matching the concrete model kind; pretrained
        # models are addressed by name, finetuned models by id.
        if isinstance(self, PretrainedEmbeddingModel):
            embeddings = client.POST(
                "/gpu/pretrained_embedding_model/{model_name}/embedding",
                params={"model_name": cast(PretrainedEmbeddingModelName, self.name)},
                json=payload,
                timeout=30,  # may be slow in case of cold start
            )
        elif isinstance(self, FinetunedEmbeddingModel):
            embeddings = client.POST(
                "/gpu/finetuned_embedding_model/{name_or_id}/embedding",
                params={"name_or_id": self.id},
                json=payload,
                timeout=30,  # may be slow in case of cold start
            )
        else:
            raise ValueError("Invalid embedding model")
        return embeddings if isinstance(value, list) else embeddings[0]

    @overload
    def evaluate(
        self,
        datasource: Datasource,
        *,
        value_column: str = "value",
        label_column: str,
        score_column: None = None,
        eval_datasource: Datasource | None = None,
        subsample: int | float | None = None,
        neighbor_count: int = 5,
        batch_size: int = 32,
        weigh_memories: bool = True,
        background: Literal[True],
    ) -> Job[ClassificationMetrics]:
        pass

    @overload
    def evaluate(
        self,
        datasource: Datasource,
        *,
        value_column: str = "value",
        label_column: str,
        score_column: None = None,
        eval_datasource: Datasource | None = None,
        subsample: int | float | None = None,
        neighbor_count: int = 5,
        batch_size: int = 32,
        weigh_memories: bool = True,
        background: Literal[False] = False,
    ) -> ClassificationMetrics:
        pass

    @overload
    def evaluate(
        self,
        datasource: Datasource,
        *,
        value_column: str = "value",
        label_column: None = None,
        score_column: str,
        eval_datasource: Datasource | None = None,
        subsample: int | float | None = None,
        neighbor_count: int = 5,
        batch_size: int = 32,
        weigh_memories: bool = True,
        background: Literal[True],
    ) -> Job[RegressionMetrics]:
        pass

    @overload
    def evaluate(
        self,
        datasource: Datasource,
        *,
        value_column: str = "value",
        label_column: None = None,
        score_column: str,
        eval_datasource: Datasource | None = None,
        subsample: int | float | None = None,
        neighbor_count: int = 5,
        batch_size: int = 32,
        weigh_memories: bool = True,
        background: Literal[False] = False,
    ) -> RegressionMetrics:
        pass

    def evaluate(
        self,
        datasource: Datasource,
        *,
        value_column: str = "value",
        label_column: str | None = None,
        score_column: str | None = None,
        eval_datasource: Datasource | None = None,
        subsample: int | float | None = None,
        neighbor_count: int = 5,
        batch_size: int = 32,
        weigh_memories: bool = True,
        background: bool = False,
    ) -> (
        ClassificationMetrics
        | RegressionMetrics
        | Job[ClassificationMetrics]
        | Job[RegressionMetrics]
        | Job[ClassificationMetrics | RegressionMetrics]
    ):
        """
        Evaluate the embedding model on a datasource.

        Per the overloads, pass `label_column` to get classification metrics or
        `score_column` to get regression metrics. The evaluation runs as a job
        on the server; with `background=True` a job handle is returned instead
        of blocking until the result is available.

        Params:
            datasource: The datasource to evaluate on
            value_column: Column containing the values to embed
            label_column: Column containing classification labels, if any
            score_column: Column containing regression scores, if any
            eval_datasource: Optional separate datasource to evaluate on
            subsample: Optional subsampling of the datasource
            neighbor_count: Number of neighbors to use
            batch_size: Batch size for embedding
            weigh_memories: Whether to weigh memories during evaluation
            background: Whether to return a job handle instead of blocking

        Returns:
            The computed metrics, or a job resolving to them if `background` is true

        Raises:
            ValueError: If `self` is neither a pretrained nor a finetuned embedding model
        """
        # NOTE(review): the overloads require exactly one of label_column /
        # score_column, but this is not validated at runtime — confirm the
        # server rejects payloads with both or neither.

        payload: EmbeddingEvaluationRequest = {
            "datasource_name_or_id": datasource.id,
            "datasource_label_column": label_column,
            "datasource_value_column": value_column,
            "datasource_score_column": score_column,
            "eval_datasource_name_or_id": eval_datasource.id if eval_datasource is not None else None,
            "subsample": subsample,
            "neighbor_count": neighbor_count,
            "batch_size": batch_size,
            "weigh_memories": weigh_memories,
        }
        client = OrcaClient._resolve_client()
        # Kick off the evaluation job on the route matching the model kind.
        if isinstance(self, PretrainedEmbeddingModel):
            response = client.POST(
                "/pretrained_embedding_model/{model_name}/evaluation",
                params={"model_name": self.name},
                json=payload,
            )
        elif isinstance(self, FinetunedEmbeddingModel):
            response = client.POST(
                "/finetuned_embedding_model/{name_or_id}/evaluation",
                params={"name_or_id": self.id},
                json=payload,
            )
        else:
            raise ValueError("Invalid embedding model")

        def get_result(job_id: str) -> ClassificationMetrics | RegressionMetrics:
            # Fetch the finished job's result and parse it into the metrics type.
            client = OrcaClient._resolve_client()
            if isinstance(self, PretrainedEmbeddingModel):
                res = client.GET(
                    "/pretrained_embedding_model/{model_name}/evaluation/{job_id}",
                    params={"model_name": self.name, "job_id": job_id},
                )["result"]
            elif isinstance(self, FinetunedEmbeddingModel):
                res = client.GET(
                    "/finetuned_embedding_model/{name_or_id}/evaluation/{job_id}",
                    params={"name_or_id": self.id, "job_id": job_id},
                )["result"]
            else:
                raise ValueError("Invalid embedding model")
            assert res is not None
            # The presence of an "mse" key distinguishes regression results
            # from classification results.
            return (
                RegressionMetrics(
                    coverage=res.get("coverage"),
                    mse=res.get("mse"),
                    rmse=res.get("rmse"),
                    mae=res.get("mae"),
                    r2=res.get("r2"),
                    explained_variance=res.get("explained_variance"),
                    loss=res.get("loss"),
                    anomaly_score_mean=res.get("anomaly_score_mean"),
                    anomaly_score_median=res.get("anomaly_score_median"),
                    anomaly_score_variance=res.get("anomaly_score_variance"),
                )
                if "mse" in res
                else ClassificationMetrics(
                    coverage=res.get("coverage"),
                    f1_score=res.get("f1_score"),
                    accuracy=res.get("accuracy"),
                    loss=res.get("loss"),
                    anomaly_score_mean=res.get("anomaly_score_mean"),
                    anomaly_score_median=res.get("anomaly_score_median"),
                    anomaly_score_variance=res.get("anomaly_score_variance"),
                    roc_auc=res.get("roc_auc"),
                    pr_auc=res.get("pr_auc"),
                    pr_curve=res.get("pr_curve"),
                    roc_curve=res.get("roc_curve"),
                )
            )

        job = Job(response["job_id"], lambda: get_result(response["job_id"]))
        return job if background else job.result()
278
+
279
+
280
+ class _ModelDescriptor:
281
+ """
282
+ Descriptor for lazily loading embedding models with IDE autocomplete support.
283
+
284
+ This class implements the descriptor protocol to provide lazy loading of embedding models
285
+ while maintaining IDE autocomplete functionality. It delays the actual loading of models
286
+ until they are accessed, which improves startup performance.
287
+
288
+ The descriptor pattern works by defining how attribute access is handled. When a class
289
+ attribute using this descriptor is accessed, the __get__ method is called, which then
290
+ retrieves or initializes the actual model on first access.
291
+ """
292
+
293
+ def __init__(self, name: str):
294
+ """
295
+ Initialize a model descriptor.
296
+
297
+ Args:
298
+ name: The name of the embedding model in PretrainedEmbeddingModelName
299
+ """
300
+ self.name = name
301
+ self.model = None # Model is loaded lazily on first access
302
+
303
+ def __get__(self, instance, owner_class):
304
+ """
305
+ Descriptor protocol method called when the attribute is accessed.
306
+
307
+ This method implements lazy loading - the actual model is only initialized
308
+ the first time it's accessed. Subsequent accesses will use the cached model.
309
+
310
+ Args:
311
+ instance: The instance the attribute was accessed from, or None if accessed from the class
312
+ owner_class: The class that owns the descriptor
313
+
314
+ Returns:
315
+ The initialized embedding model
316
+
317
+ Raises:
318
+ AttributeError: If no model with the given name exists
319
+ """
320
+ # When accessed from an instance, redirect to class access
321
+ if instance is not None:
322
+ return self.__get__(None, owner_class)
323
+
324
+ # Load the model on first access
325
+ if self.model is None:
326
+ try:
327
+ self.model = PretrainedEmbeddingModel._get(cast(PretrainedEmbeddingModelName, self.name))
328
+ except (KeyError, AttributeError):
329
+ raise AttributeError(f"No embedding model named {self.name}")
330
+
331
+ return self.model
332
+
333
+
334
class PretrainedEmbeddingModel(EmbeddingModelBase):
    """
    A pretrained embedding model

    **Models:**

    OrcaCloud supports a select number of small to medium sized embedding models that perform well on the
    [Hugging Face MTEB Leaderboard](https://huggingface.co/spaces/mteb/leaderboard).
    These can be accessed as class attributes. We currently support:

    - **`CDE_SMALL`**: Context-aware CDE small model from Hugging Face ([jxm/cde-small-v1](https://huggingface.co/jxm/cde-small-v1))
    - **`CLIP_BASE`**: Multi-modal CLIP model from Hugging Face ([sentence-transformers/clip-ViT-L-14](https://huggingface.co/sentence-transformers/clip-ViT-L-14))
    - **`GTE_BASE`**: Alibaba's GTE model from Hugging Face ([Alibaba-NLP/gte-base-en-v1.5](https://huggingface.co/Alibaba-NLP/gte-base-en-v1.5))
    - **`DISTILBERT`**: DistilBERT embedding model from Hugging Face ([distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased))
    - **`GTE_SMALL`**: GTE-Small embedding model from Hugging Face ([Supabase/gte-small](https://huggingface.co/Supabase/gte-small))
    - **`E5_LARGE`**: E5-Large instruction-tuned embedding model from Hugging Face ([intfloat/multilingual-e5-large-instruct](https://huggingface.co/intfloat/multilingual-e5-large-instruct))
    - **`GIST_LARGE`**: GIST-Large embedding model from Hugging Face ([avsolatorio/GIST-large-Embedding-v0](https://huggingface.co/avsolatorio/GIST-large-Embedding-v0))
    - **`MXBAI_LARGE`**: Mixedbread's Large embedding model from Hugging Face ([mixedbread-ai/mxbai-embed-large-v1](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1))
    - **`BGE_BASE`**: BAAI's BGE-Base instruction-tuned embedding model from Hugging Face ([BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5))

    **Instruction Support:**

    Some models support instruction-following for better task-specific embeddings. You can check if a model supports instructions
    using the `supports_instructions` attribute.

    Examples:
        >>> PretrainedEmbeddingModel.CDE_SMALL
        PretrainedEmbeddingModel({name: CDE_SMALL, embedding_dim: 768, max_seq_length: 512})

        >>> # Using instruction with an instruction-supporting model
        >>> model = PretrainedEmbeddingModel.E5_LARGE
        >>> embeddings = model.embed("Hello world", prompt="Represent this sentence for retrieval:")

    Attributes:
        name: Name of the pretrained embedding model
        embedding_dim: Dimension of the embeddings that are generated by the model
        max_seq_length: Maximum input length (in tokens not characters) that this model can process. Inputs that are longer will be truncated during the embedding process
        num_params: Number of parameters in the model
        uses_context: Whether the pretrained embedding model uses context
        supports_instructions: Whether this model supports instruction-following
    """

    # Define descriptors for model access with IDE autocomplete; each one
    # lazily fetches and caches the model on first attribute access.
    CDE_SMALL = _ModelDescriptor("CDE_SMALL")
    CLIP_BASE = _ModelDescriptor("CLIP_BASE")
    GTE_BASE = _ModelDescriptor("GTE_BASE")
    DISTILBERT = _ModelDescriptor("DISTILBERT")
    GTE_SMALL = _ModelDescriptor("GTE_SMALL")
    E5_LARGE = _ModelDescriptor("E5_LARGE")
    GIST_LARGE = _ModelDescriptor("GIST_LARGE")
    MXBAI_LARGE = _ModelDescriptor("MXBAI_LARGE")
    BGE_BASE = _ModelDescriptor("BGE_BASE")

    name: PretrainedEmbeddingModelName

    def __init__(self, metadata: PretrainedEmbeddingModelMetadata):
        # for internal use only, do not document
        self.name = metadata["name"]
        super().__init__(
            name=metadata["name"],
            embedding_dim=metadata["embedding_dim"],
            max_seq_length=metadata["max_seq_length"],
            num_params=metadata["num_params"],
            uses_context=metadata["uses_context"],
            # Older server responses may omit this key; default to False then.
            supports_instructions=(
                bool(metadata["supports_instructions"]) if "supports_instructions" in metadata else False
            ),
        )

    def __eq__(self, other) -> bool:
        # Pretrained models are identified solely by name.
        return isinstance(other, PretrainedEmbeddingModel) and self.name == other.name

    def __repr__(self) -> str:
        return f"PretrainedEmbeddingModel({{name: {self.name}, embedding_dim: {self.embedding_dim}, max_seq_length: {self.max_seq_length}, num_params: {self.num_params/1000000:.0f}M}})"

    @classmethod
    def all(cls) -> list[PretrainedEmbeddingModel]:
        """
        List all pretrained embedding models in the OrcaCloud

        Returns:
            A list of all pretrained embedding models available in the OrcaCloud
        """
        client = OrcaClient._resolve_client()
        return [cls(metadata) for metadata in client.GET("/pretrained_embedding_model")]

    # Process-wide cache of resolved models, keyed by model name.
    _instances: dict[str, PretrainedEmbeddingModel] = {}

    @classmethod
    def _get(cls, name: PretrainedEmbeddingModelName) -> PretrainedEmbeddingModel:
        # for internal use only, do not document - we want people to use dot notation to get the model
        cache_key = str(name)
        if cache_key not in cls._instances:
            client = OrcaClient._resolve_client()
            metadata = client.GET(
                "/pretrained_embedding_model/{model_name}",
                params={"model_name": name},
            )
            cls._instances[cache_key] = cls(metadata)
        return cls._instances[cache_key]

    @classmethod
    def open(cls, name: PretrainedEmbeddingModelName) -> PretrainedEmbeddingModel:
        """
        Open an embedding model by name.

        This is an alternative method to access models for environments
        where IDE autocomplete for model names is not available.

        Params:
            name: Name of the model to open (e.g., "GTE_BASE", "CLIP_BASE")

        Returns:
            The embedding model instance

        Raises:
            ValueError: If no model with the given name exists

        Examples:
            >>> model = PretrainedEmbeddingModel.open("GTE_BASE")
        """
        try:
            # Always use the _get method which handles caching properly
            return cls._get(name)
        except (KeyError, AttributeError):
            raise ValueError(f"Unknown model name: {name}")

    @classmethod
    def exists(cls, name: str) -> bool:
        """
        Check if a pretrained embedding model exists by name

        Params:
            name: The name of the pretrained embedding model

        Returns:
            True if the pretrained embedding model exists, False otherwise
        """
        # Purely a local check against the literal type's values; no API call.
        return name in get_args(PretrainedEmbeddingModelName)

    @overload
    def finetune(
        self,
        name: str,
        train_datasource: Datasource | LabeledMemoryset | ScoredMemoryset,
        *,
        eval_datasource: Datasource | None = None,
        label_column: str = "label",
        score_column: str = "score",
        value_column: str = "value",
        training_method: EmbeddingFinetuningMethod | None = None,
        training_args: dict | None = None,
        if_exists: CreateMode = "error",
        background: Literal[True],
    ) -> Job[FinetunedEmbeddingModel]:
        pass

    @overload
    def finetune(
        self,
        name: str,
        train_datasource: Datasource | LabeledMemoryset | ScoredMemoryset,
        *,
        eval_datasource: Datasource | None = None,
        label_column: str = "label",
        score_column: str = "score",
        value_column: str = "value",
        training_method: EmbeddingFinetuningMethod | None = None,
        training_args: dict | None = None,
        if_exists: CreateMode = "error",
        background: Literal[False] = False,
    ) -> FinetunedEmbeddingModel:
        pass

    def finetune(
        self,
        name: str,
        train_datasource: Datasource | LabeledMemoryset | ScoredMemoryset,
        *,
        eval_datasource: Datasource | None = None,
        label_column: str = "label",
        score_column: str = "score",
        value_column: str = "value",
        training_method: EmbeddingFinetuningMethod | None = None,
        training_args: dict | None = None,
        if_exists: CreateMode = "error",
        background: bool = False,
    ) -> FinetunedEmbeddingModel | Job[FinetunedEmbeddingModel]:
        """
        Finetune an embedding model

        Params:
            name: Name of the finetuned embedding model
            train_datasource: Data to train on
            eval_datasource: Optionally provide data to evaluate on
            label_column: Column name of the label.
            score_column: Column name of the score (for regression when training on scored data).
            value_column: Column name of the value
            training_method: Optional training method override. If omitted, Lighthouse defaults apply.
            training_args: Optional override for Hugging Face [`TrainingArguments`][transformers.TrainingArguments].
                If not provided, reasonable training arguments will be used for the specified training method
            if_exists: What to do if a finetuned embedding model with the same name already exists, defaults to
                `"error"`. Other option is `"open"` to open the existing finetuned embedding model.
            background: Whether to run the operation in the background and return a job handle

        Returns:
            The finetuned embedding model

        Raises:
            ValueError: If the finetuned embedding model already exists and `if_exists` is `"error"` or if it is `"open"`
                but the base model param does not match the existing model

        Examples:
            >>> datasource = Datasource.open("my_datasource")
            >>> model = PretrainedEmbeddingModel.CLIP_BASE
            >>> model.finetune("my_finetuned_model", datasource)
        """
        exists = FinetunedEmbeddingModel.exists(name)

        if exists and if_exists == "error":
            raise ValueError(f"Finetuned embedding model '{name}' already exists")
        elif exists and if_exists == "open":
            existing = FinetunedEmbeddingModel.open(name)

            # "open" only succeeds if the existing model was trained on the
            # same base model this method was called on.
            if existing.base_model.name != self.name:
                raise ValueError(f"Finetuned embedding model '{name}' already exists, but with different base model")

            return existing

        # Imported locally to avoid a circular import with .memoryset.
        from .memoryset import LabeledMemoryset, ScoredMemoryset

        payload: FinetuneEmbeddingModelRequest = {
            "name": name,
            "base_model": self.name,
            "label_column": label_column,
            "score_column": score_column,
            "value_column": value_column,
            "training_args": training_args or {},
        }
        if training_method is not None:
            payload["training_method"] = training_method

        # NOTE(review): a train_datasource that is neither a Datasource nor a
        # memoryset leaves the payload without a train source — presumably
        # rejected server-side; confirm.
        if isinstance(train_datasource, Datasource):
            payload["train_datasource_name_or_id"] = train_datasource.id
        elif isinstance(train_datasource, (LabeledMemoryset, ScoredMemoryset)):
            payload["train_memoryset_name_or_id"] = train_datasource.id
        if eval_datasource is not None:
            payload["eval_datasource_name_or_id"] = eval_datasource.id

        client = OrcaClient._resolve_client()
        res = client.POST(
            "/finetuned_embedding_model",
            json=payload,
        )
        # The job resolves to a handle on the newly created finetuned model.
        job = Job(
            res["finetuning_job_id"],
            lambda: FinetunedEmbeddingModel.open(res["id"]),
        )
        return job if background else job.result()
590
+
591
+
592
class FinetunedEmbeddingModel(EmbeddingModelBase):
    """
    A finetuned embedding model in the OrcaCloud

    Attributes:
        name: Name of the finetuned embedding model
        embedding_dim: Dimension of the embeddings that are generated by the model
        max_seq_length: Maximum input length (in tokens not characters) that this model can process. Inputs that are longer will be truncated during the embedding process
        uses_context: Whether the model uses the memoryset to contextualize embeddings (acts akin to inverse document frequency in TFIDF features)
        id: Unique identifier of the finetuned embedding model
        base_model: Base model the finetuned embedding model was trained on
        created_at: When the model was finetuned
    """

    id: str  # unique identifier of the finetuned embedding model
    name: str  # name of the finetuned embedding model
    created_at: datetime  # when the model was finetuned
    updated_at: datetime  # when the model was last updated
    base_model: PretrainedEmbeddingModel  # base model the finetuning started from
    _status: Status  # finetuning job status, parsed from server metadata

    def __init__(self, metadata: FinetunedEmbeddingModelMetadata):
        # for internal use only, do not document
        self.id = metadata["id"]
        self.name = metadata["name"]
        self.created_at = datetime.fromisoformat(metadata["created_at"])
        self.updated_at = datetime.fromisoformat(metadata["updated_at"])
        # Resolves (and caches) the base model via the pretrained-model cache.
        self.base_model = PretrainedEmbeddingModel._get(metadata["base_model"])
        self._status = Status(metadata["finetuning_status"])

        super().__init__(
            name=metadata["name"],
            embedding_dim=metadata["embedding_dim"],
            max_seq_length=metadata["max_seq_length"],
            # num_params and instruction support are inherited from the base model.
            num_params=self.base_model.num_params,
            uses_context=metadata["uses_context"],
            supports_instructions=self.base_model.supports_instructions,
        )

    def __eq__(self, other) -> bool:
        # Finetuned models are identified by their unique id.
        # NOTE(review): defining __eq__ without __hash__ makes instances
        # unhashable — confirm they are never used as dict keys / set members.
        return isinstance(other, FinetunedEmbeddingModel) and self.id == other.id

    def __repr__(self) -> str:
        return (
            "FinetunedEmbeddingModel({\n"
            f"    name: {self.name},\n"
            f"    embedding_dim: {self.embedding_dim},\n"
            f"    max_seq_length: {self.max_seq_length},\n"
            f"    base_model: PretrainedEmbeddingModel.{self.base_model.name}\n"
            "})"
        )

    @classmethod
    def all(cls) -> list[FinetunedEmbeddingModel]:
        """
        List all finetuned embedding model handles in the OrcaCloud

        Returns:
            A list of all finetuned embedding model handles in the OrcaCloud
        """
        client = OrcaClient._resolve_client()
        return [cls(metadata) for metadata in client.GET("/finetuned_embedding_model")]

    @classmethod
    def open(cls, name: str) -> FinetunedEmbeddingModel:
        """
        Get a handle to a finetuned embedding model in the OrcaCloud

        Params:
            name: The name or unique identifier of a finetuned embedding model

        Returns:
            A handle to the finetuned embedding model in the OrcaCloud

        Raises:
            LookupError: If the finetuned embedding model does not exist
        """
        client = OrcaClient._resolve_client()
        metadata = client.GET(
            "/finetuned_embedding_model/{name_or_id}",
            params={"name_or_id": name},
        )
        return cls(metadata)

    @classmethod
    def exists(cls, name_or_id: str) -> bool:
        """
        Check if a finetuned embedding model with the given name or id exists.

        Params:
            name_or_id: The name or id of the finetuned embedding model

        Returns:
            True if the finetuned embedding model exists, False otherwise
        """
        # Existence is probed by attempting to open the model (EAFP).
        try:
            cls.open(name_or_id)
            return True
        except LookupError:
            return False

    @classmethod
    def drop(cls, name_or_id: str, *, if_not_exists: DropMode = "error"):
        """
        Delete the finetuned embedding model from the OrcaCloud

        Params:
            name_or_id: The name or id of the finetuned embedding model
            if_not_exists: What to do if the model does not exist; `"error"` re-raises,
                any other value silently ignores the missing model

        Raises:
            LookupError: If the finetuned embedding model does not exist and `if_not_exists` is `"error"`
        """
        try:
            client = OrcaClient._resolve_client()
            client.DELETE(
                "/finetuned_embedding_model/{name_or_id}",
                params={"name_or_id": name_or_id},
            )
        except LookupError:
            if if_not_exists == "error":
                raise