orca-sdk 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
orca_sdk/memoryset.py CHANGED
@@ -4,7 +4,17 @@ import logging
4
4
  from abc import ABC
5
5
  from datetime import datetime, timedelta
6
6
  from os import PathLike
7
- from typing import Any, Generic, Iterable, Literal, Self, TypeVar, cast, overload
7
+ from typing import (
8
+ TYPE_CHECKING,
9
+ Any,
10
+ Generic,
11
+ Iterable,
12
+ Literal,
13
+ Self,
14
+ TypeVar,
15
+ cast,
16
+ overload,
17
+ )
8
18
 
9
19
  import pandas as pd
10
20
  import pyarrow as pa
@@ -13,11 +23,11 @@ from torch.utils.data import DataLoader as TorchDataLoader
13
23
  from torch.utils.data import Dataset as TorchDataset
14
24
 
15
25
  from ._utils.common import UNSET, CreateMode, DropMode
26
+ from .async_client import OrcaAsyncClient
16
27
  from .client import (
17
28
  CascadingEditSuggestion,
18
29
  CloneMemorysetRequest,
19
30
  CreateMemorysetRequest,
20
- EmbeddingModelResult,
21
31
  FilterItem,
22
32
  )
23
33
  from .client import LabeledMemory as LabeledMemoryResponse
@@ -29,12 +39,15 @@ from .client import (
29
39
  LabeledMemoryUpdate,
30
40
  LabeledMemoryWithFeedbackMetrics,
31
41
  LabelPredictionMemoryLookup,
42
+ LabelPredictionWithMemoriesAndFeedback,
32
43
  MemoryMetrics,
33
44
  MemorysetAnalysisConfigs,
34
45
  MemorysetMetadata,
35
46
  MemorysetMetrics,
36
47
  MemorysetUpdate,
37
48
  MemoryType,
49
+ OrcaClient,
50
+ PredictionFeedback,
38
51
  )
39
52
  from .client import ScoredMemory as ScoredMemoryResponse
40
53
  from .client import (
@@ -45,9 +58,9 @@ from .client import (
45
58
  ScoredMemoryUpdate,
46
59
  ScoredMemoryWithFeedbackMetrics,
47
60
  ScorePredictionMemoryLookup,
61
+ ScorePredictionWithMemoriesAndFeedback,
48
62
  TelemetryFilterItem,
49
63
  TelemetrySortOptions,
50
- orca_api,
51
64
  )
52
65
  from .datasource import Datasource
53
66
  from .embedding_model import (
@@ -56,6 +69,11 @@ from .embedding_model import (
56
69
  PretrainedEmbeddingModel,
57
70
  )
58
71
  from .job import Job, Status
72
+ from .telemetry import ClassificationPrediction, RegressionPrediction
73
+
74
+ if TYPE_CHECKING:
75
+ from .classification_model import ClassificationModel
76
+ from .regression_model import RegressionModel
59
77
 
60
78
  TelemetrySortItem = tuple[str, Literal["asc", "desc"]]
61
79
  """
@@ -74,7 +92,7 @@ FilterOperation = Literal["==", "!=", ">", ">=", "<", "<=", "in", "not in", "lik
74
92
  Operations that can be used in a filter expression.
75
93
  """
76
94
 
77
- FilterValue = str | int | float | bool | datetime | None | list[str] | list[int] | list[float] | list[bool]
95
+ FilterValue = str | int | float | bool | datetime | None | list[str | None] | list[int] | list[float] | list[bool]
78
96
  """
79
97
  Values that can be used in a filter expression.
80
98
  """
@@ -292,6 +310,110 @@ class MemoryBase(ABC):
292
310
  raise AttributeError(f"{key} is not a valid attribute")
293
311
  return self.metadata[key]
294
312
 
313
+ def _convert_to_classification_prediction(
314
+ self,
315
+ prediction: LabelPredictionWithMemoriesAndFeedback,
316
+ *,
317
+ memoryset: LabeledMemoryset,
318
+ model: ClassificationModel,
319
+ ) -> ClassificationPrediction:
320
+ """
321
+ Convert internal prediction TypedDict to ClassificationPrediction object.
322
+ """
323
+ input_value = prediction.get("input_value")
324
+ input_value_str: str | None = None
325
+ if input_value is not None:
326
+ input_value_str = input_value.decode("utf-8") if isinstance(input_value, bytes) else input_value
327
+
328
+ return ClassificationPrediction(
329
+ prediction_id=prediction["prediction_id"],
330
+ label=prediction.get("label"),
331
+ label_name=prediction.get("label_name"),
332
+ score=None,
333
+ confidence=prediction["confidence"],
334
+ anomaly_score=prediction["anomaly_score"],
335
+ memoryset=memoryset,
336
+ model=model,
337
+ telemetry=prediction,
338
+ logits=prediction.get("logits"),
339
+ input_value=input_value_str,
340
+ )
341
+
342
+ def _convert_to_regression_prediction(
343
+ self,
344
+ prediction: ScorePredictionWithMemoriesAndFeedback,
345
+ *,
346
+ memoryset: ScoredMemoryset,
347
+ model: RegressionModel,
348
+ ) -> RegressionPrediction:
349
+ """
350
+ Convert internal prediction TypedDict to RegressionPrediction object.
351
+ """
352
+ input_value = prediction.get("input_value")
353
+ input_value_str: str | None = None
354
+ if input_value is not None:
355
+ input_value_str = input_value.decode("utf-8") if isinstance(input_value, bytes) else input_value
356
+
357
+ return RegressionPrediction(
358
+ prediction_id=prediction["prediction_id"],
359
+ label=None,
360
+ label_name=None,
361
+ score=prediction.get("score"),
362
+ confidence=prediction["confidence"],
363
+ anomaly_score=prediction["anomaly_score"],
364
+ memoryset=memoryset,
365
+ model=model,
366
+ telemetry=prediction,
367
+ logits=None,
368
+ input_value=input_value_str,
369
+ )
370
+
371
+ def feedback(self) -> dict[str, list[bool] | list[float]]:
372
+ """
373
+ Get feedback metrics computed from predictions that used this memory.
374
+
375
+ Returns a dictionary where:
376
+ - Keys are feedback category names
377
+ - Values are lists of feedback values (you may want to look at mean on the raw data)
378
+ """
379
+ # Collect all feedbacks by category, paginating through all predictions
380
+ feedback_by_category: dict[str, list[bool] | list[float]] = {}
381
+ batch_size = 500
382
+ offset = 0
383
+
384
+ while True:
385
+ predictions_batch = self.predictions(limit=batch_size, offset=offset)
386
+
387
+ if not predictions_batch:
388
+ break
389
+
390
+ for prediction in predictions_batch:
391
+ telemetry = prediction._telemetry
392
+ if "feedbacks" not in telemetry:
393
+ continue
394
+
395
+ for fb in telemetry["feedbacks"]:
396
+ category_name = fb["category_name"]
397
+ value = fb["value"]
398
+ # Convert BINARY (1/0) to boolean, CONTINUOUS to float
399
+ if fb["category_type"] == "BINARY":
400
+ value = bool(value)
401
+ if category_name not in feedback_by_category:
402
+ feedback_by_category[category_name] = []
403
+ cast(list[bool], feedback_by_category[category_name]).append(value)
404
+ else:
405
+ value = float(value)
406
+ if category_name not in feedback_by_category:
407
+ feedback_by_category[category_name] = []
408
+ cast(list[float], feedback_by_category[category_name]).append(value)
409
+
410
+ if len(predictions_batch) < batch_size:
411
+ break
412
+
413
+ offset += batch_size
414
+
415
+ return feedback_by_category
416
+
295
417
  def _update(
296
418
  self,
297
419
  *,
@@ -299,7 +421,8 @@ class MemoryBase(ABC):
299
421
  source_id: str | None = UNSET,
300
422
  **metadata: None | bool | float | int | str,
301
423
  ) -> Self:
302
- response = orca_api.PATCH(
424
+ client = OrcaClient._resolve_client()
425
+ response = client.PATCH(
303
426
  "/gpu/memoryset/{name_or_id}/memory",
304
427
  params={"name_or_id": self.memoryset_id},
305
428
  json=_parse_memory_update(
@@ -415,6 +538,75 @@ class LabeledMemory(MemoryBase):
415
538
  self._update(value=value, label=label, source_id=source_id, **metadata)
416
539
  return self
417
540
 
541
+ def predictions(
542
+ self,
543
+ limit: int = 100,
544
+ offset: int = 0,
545
+ tag: str | None = None,
546
+ sort: list[tuple[Literal["anomaly_score", "confidence", "timestamp"], Literal["asc", "desc"]]] = [],
547
+ expected_label_match: bool | None = None,
548
+ ) -> list[ClassificationPrediction]:
549
+ """
550
+ Get classification predictions that used this memory.
551
+
552
+ Args:
553
+ limit: Maximum number of predictions to return (default: 100)
554
+ offset: Number of predictions to skip for pagination (default: 0)
555
+ tag: Optional tag filter to only include predictions with this tag
556
+ sort: List of (field, direction) tuples for sorting results.
557
+ Valid fields: "anomaly_score", "confidence", "timestamp".
558
+ Valid directions: "asc", "desc"
559
+ expected_label_match: Filter by prediction correctness:
560
+ - True: only return correct predictions (label == expected_label)
561
+ - False: only return incorrect predictions (label != expected_label)
562
+ - None: return all predictions (default)
563
+
564
+ Returns:
565
+ List of ClassificationPrediction objects that used this memory
566
+ """
567
+
568
+ client = OrcaClient._resolve_client()
569
+ predictions_data = client.POST(
570
+ "/telemetry/prediction",
571
+ json={
572
+ "memory_id": self.memory_id,
573
+ "limit": limit,
574
+ "offset": offset,
575
+ "sort": [list(sort_item) for sort_item in sort],
576
+ "tag": tag,
577
+ "expected_label_match": expected_label_match,
578
+ },
579
+ )
580
+
581
+ # Filter to only classification predictions and convert to ClassificationPrediction objects
582
+ classification_predictions = [
583
+ cast(LabelPredictionWithMemoriesAndFeedback, p) for p in predictions_data if "label" in p
584
+ ]
585
+
586
+ from .classification_model import ClassificationModel
587
+
588
+ memorysets: dict[str, LabeledMemoryset] = {}
589
+ models: dict[str, ClassificationModel] = {}
590
+
591
+ def resolve_memoryset(memoryset_id: str) -> LabeledMemoryset:
592
+ if memoryset_id not in memorysets:
593
+ memorysets[memoryset_id] = LabeledMemoryset.open(memoryset_id)
594
+ return memorysets[memoryset_id]
595
+
596
+ def resolve_model(model_id: str) -> ClassificationModel:
597
+ if model_id not in models:
598
+ models[model_id] = ClassificationModel.open(model_id)
599
+ return models[model_id]
600
+
601
+ return [
602
+ self._convert_to_classification_prediction(
603
+ p,
604
+ memoryset=resolve_memoryset(p["memoryset_id"]),
605
+ model=resolve_model(p["model_id"]),
606
+ )
607
+ for p in classification_predictions
608
+ ]
609
+
418
610
  def to_dict(self) -> dict[str, Any]:
419
611
  """
420
612
  Convert the memory to a dictionary
@@ -456,7 +648,11 @@ class LabeledMemoryLookup(LabeledMemory):
456
648
  lookup_score: float
457
649
  attention_weight: float | None
458
650
 
459
- def __init__(self, memoryset_id: str, memory_lookup: LabeledMemoryLookupResponse | LabelPredictionMemoryLookup):
651
+ def __init__(
652
+ self,
653
+ memoryset_id: str,
654
+ memory_lookup: LabeledMemoryLookupResponse | LabelPredictionMemoryLookup,
655
+ ):
460
656
  # for internal use only, do not document
461
657
  super().__init__(memoryset_id, memory_lookup)
462
658
  self.lookup_score = memory_lookup["lookup_score"]
@@ -552,6 +748,75 @@ class ScoredMemory(MemoryBase):
552
748
  self._update(value=value, score=score, source_id=source_id, **metadata)
553
749
  return self
554
750
 
751
+ def predictions(
752
+ self,
753
+ limit: int = 100,
754
+ offset: int = 0,
755
+ tag: str | None = None,
756
+ sort: list[tuple[Literal["anomaly_score", "confidence", "timestamp"], Literal["asc", "desc"]]] = [],
757
+ expected_label_match: bool | None = None,
758
+ ) -> list[RegressionPrediction]:
759
+ """
760
+ Get regression predictions that used this memory.
761
+
762
+ Args:
763
+ limit: Maximum number of predictions to return (default: 100)
764
+ offset: Number of predictions to skip for pagination (default: 0)
765
+ tag: Optional tag filter to only include predictions with this tag
766
+ sort: List of (field, direction) tuples for sorting results.
767
+ Valid fields: "anomaly_score", "confidence", "timestamp".
768
+ Valid directions: "asc", "desc"
769
+ expected_label_match: Filter by prediction correctness:
770
+ - True: only return correct predictions (score close to expected_score)
771
+ - False: only return incorrect predictions (score differs from expected_score)
772
+ - None: return all predictions (default)
773
+ Note: For regression, "correctness" is based on score proximity to expected_score.
774
+
775
+ Returns:
776
+ List of RegressionPrediction objects that used this memory
777
+ """
778
+ client = OrcaClient._resolve_client()
779
+ predictions_data = client.POST(
780
+ "/telemetry/prediction",
781
+ json={
782
+ "memory_id": self.memory_id,
783
+ "limit": limit,
784
+ "offset": offset,
785
+ "sort": [list(sort_item) for sort_item in sort],
786
+ "tag": tag,
787
+ "expected_label_match": expected_label_match,
788
+ },
789
+ )
790
+
791
+ # Filter to only regression predictions and convert to RegressionPrediction objects
792
+ regression_predictions = [
793
+ cast(ScorePredictionWithMemoriesAndFeedback, p) for p in predictions_data if "score" in p
794
+ ]
795
+
796
+ from .regression_model import RegressionModel
797
+
798
+ memorysets: dict[str, ScoredMemoryset] = {}
799
+ models: dict[str, RegressionModel] = {}
800
+
801
+ def resolve_memoryset(memoryset_id: str) -> ScoredMemoryset:
802
+ if memoryset_id not in memorysets:
803
+ memorysets[memoryset_id] = ScoredMemoryset.open(memoryset_id)
804
+ return memorysets[memoryset_id]
805
+
806
+ def resolve_model(model_id: str) -> RegressionModel:
807
+ if model_id not in models:
808
+ models[model_id] = RegressionModel.open(model_id)
809
+ return models[model_id]
810
+
811
+ return [
812
+ self._convert_to_regression_prediction(
813
+ p,
814
+ memoryset=resolve_memoryset(p["memoryset_id"]),
815
+ model=resolve_model(p["model_id"]),
816
+ )
817
+ for p in regression_predictions
818
+ ]
819
+
555
820
  def to_dict(self) -> dict[str, Any]:
556
821
  """
557
822
  Convert the memory to a dictionary
@@ -588,7 +853,11 @@ class ScoredMemoryLookup(ScoredMemory):
588
853
  lookup_score: float
589
854
  attention_weight: float | None
590
855
 
591
- def __init__(self, memoryset_id: str, memory_lookup: ScoredMemoryLookupResponse | ScorePredictionMemoryLookup):
856
+ def __init__(
857
+ self,
858
+ memoryset_id: str,
859
+ memory_lookup: ScoredMemoryLookupResponse | ScorePredictionMemoryLookup,
860
+ ):
592
861
  # for internal use only, do not document
593
862
  super().__init__(memoryset_id, memory_lookup)
594
863
  self.lookup_score = memory_lookup["lookup_score"]
@@ -637,6 +906,8 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
637
906
  index_params: dict[str, Any]
638
907
  hidden: bool
639
908
 
909
+ _batch_size = 32 # max number of memories to insert/update/delete in a single API call
910
+
640
911
  def __init__(self, metadata: MemorysetMetadata):
641
912
  # for internal use only, do not document
642
913
  if metadata["pretrained_embedding_model_name"]:
@@ -670,55 +941,48 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
670
941
  "})"
671
942
  )
672
943
 
673
- @overload
674
944
  @classmethod
675
- def create(
945
+ def _handle_if_exists(
676
946
  cls,
677
947
  name: str,
678
- datasource: Datasource,
679
948
  *,
680
- embedding_model: FinetunedEmbeddingModel | PretrainedEmbeddingModel | None = None,
681
- value_column: str = "value",
682
- label_column: str | None = None,
683
- score_column: str | None = None,
684
- source_id_column: str | None = None,
685
- description: str | None = None,
686
- label_names: list[str] | None = None,
687
- max_seq_length_override: int | None = None,
688
- prompt: str | None = None,
689
- remove_duplicates: bool = True,
690
- index_type: IndexType = "FLAT",
691
- index_params: dict[str, Any] = {},
692
- if_exists: CreateMode = "error",
693
- background: Literal[True],
694
- hidden: bool = False,
695
- ) -> Job[Self]:
696
- pass
949
+ if_exists: CreateMode,
950
+ label_names: list[str] | None,
951
+ embedding_model: PretrainedEmbeddingModel | FinetunedEmbeddingModel | None,
952
+ ) -> Self | None:
953
+ """
954
+ Handle common `if_exists` logic shared by all creator-style helpers.
697
955
 
698
- @overload
699
- @classmethod
700
- def create(
701
- cls,
702
- name: str,
703
- datasource: Datasource,
704
- *,
705
- embedding_model: FinetunedEmbeddingModel | PretrainedEmbeddingModel | None = None,
706
- value_column: str = "value",
707
- label_column: str | None = None,
708
- score_column: str | None = None,
709
- source_id_column: str | None = None,
710
- description: str | None = None,
711
- label_names: list[str] | None = None,
712
- max_seq_length_override: int | None = None,
713
- prompt: str | None = None,
714
- remove_duplicates: bool = True,
715
- index_type: IndexType = "FLAT",
716
- index_params: dict[str, Any] = {},
717
- if_exists: CreateMode = "error",
718
- background: Literal[False] = False,
719
- hidden: bool = False,
720
- ) -> Self:
721
- pass
956
+ Returns the already-existing memoryset when `if_exists == "open"`, raises for `"error"`,
957
+ and returns `None` when the memoryset does not yet exist.
958
+ """
959
+ if not cls.exists(name):
960
+ return None
961
+ if if_exists == "error":
962
+ raise ValueError(f"Memoryset with name {name} already exists")
963
+
964
+ existing = cls.open(name)
965
+
966
+ if label_names is not None and hasattr(existing, "label_names"):
967
+ existing_label_names = getattr(existing, "label_names")
968
+ if label_names != existing_label_names:
969
+ requested = ", ".join(label_names)
970
+ existing_joined = ", ".join(existing_label_names)
971
+ raise ValueError(
972
+ f"Memoryset {name} already exists with label names [{existing_joined}] "
973
+ f"(requested: [{requested}])."
974
+ )
975
+
976
+ if embedding_model is not None and embedding_model != existing.embedding_model:
977
+ existing_model = existing.embedding_model
978
+ existing_model_name = getattr(existing_model, "name", getattr(existing_model, "path", str(existing_model)))
979
+ requested_name = getattr(embedding_model, "name", getattr(embedding_model, "path", str(embedding_model)))
980
+ raise ValueError(
981
+ f"Memoryset {name} already exists with embedding_model {existing_model_name} "
982
+ f"(requested: {requested_name})."
983
+ )
984
+
985
+ return existing
722
986
 
723
987
  @classmethod
724
988
  def create(
@@ -741,6 +1005,8 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
741
1005
  if_exists: CreateMode = "error",
742
1006
  background: bool = False,
743
1007
  hidden: bool = False,
1008
+ subsample: int | float | None = None,
1009
+ memory_type: MemoryType | None = None,
744
1010
  ) -> Self | Job[Self]:
745
1011
  """
746
1012
  Create a new memoryset in the OrcaCloud
@@ -754,8 +1020,9 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
754
1020
  embedding_model: Embedding model to use for embedding memory values for semantic search.
755
1021
  If not provided, a default embedding model for the memoryset will be used.
756
1022
  value_column: Name of the column in the datasource that contains the memory values
757
- label_column: Name of the column in the datasource that contains the memory labels,
758
- these must be contiguous integers starting from 0
1023
+ label_column: Name of the column in the datasource that contains the memory labels.
1024
+ Must contain categorical values as integers or strings. String labels will be
1025
+ converted to integers with the unique strings extracted as `label_names`
759
1026
  score_column: Name of the column in the datasource that contains the memory scores
760
1027
  source_id_column: Optional name of the column in the datasource that contains the ids in
761
1028
  the system of reference
@@ -763,9 +1030,9 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
763
1030
  so make sure it is concise and describes the contents of your memoryset not the
764
1031
  datasource or the embedding model.
765
1032
  label_names: List of human-readable names for the labels in the memoryset, must match
766
- the number of labels in the `label_column`. Will be automatically inferred if a
767
- [Dataset][datasets.Dataset] with a [`ClassLabel`][datasets.ClassLabel] feature for
768
- labels is used as the datasource
1033
+ the number of labels in the `label_column`. Will be automatically inferred if string
1034
+ labels are provided or if a [Dataset][datasets.Dataset] with a
1035
+ [`ClassLabel`][datasets.ClassLabel] feature for labels is used as the datasource
769
1036
  max_seq_length_override: Maximum sequence length of values in the memoryset, if the
770
1037
  value is longer than this it will be truncated, will default to the model's max
771
1038
  sequence length if not provided
@@ -779,7 +1046,10 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
779
1046
  `"error"`. Other option is `"open"` to open the existing memoryset.
780
1047
  background: Whether to run the operation non-blocking and return a job handle
781
1048
  hidden: Whether the memoryset should be hidden
782
-
1049
+ subsample: Optional number (int) of rows to insert or fraction (float in (0, 1]) of the
1050
+ datasource to insert. Use to limit the size of the initial memoryset.
1051
+ memory_type: Type of memoryset to create, defaults to `"LABELED"` if `label_column` is provided,
1052
+ and `"SCORED"` if `score_column` is provided, must be specified for other cases.
783
1053
  Returns:
784
1054
  Handle to the new memoryset in the OrcaCloud
785
1055
 
@@ -790,18 +1060,14 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
790
1060
  if embedding_model is None:
791
1061
  embedding_model = PretrainedEmbeddingModel.GTE_BASE
792
1062
 
793
- if label_column is None and score_column is None:
794
- raise ValueError("label_column or score_column must be provided")
795
-
796
- if cls.exists(name):
797
- if if_exists == "error":
798
- raise ValueError(f"Memoryset with name {name} already exists")
799
- elif if_exists == "open":
800
- existing = cls.open(name)
801
- for attribute in {"label_names", "embedding_model"}:
802
- if locals()[attribute] is not None and locals()[attribute] != getattr(existing, attribute):
803
- raise ValueError(f"Memoryset with name {name} already exists with a different {attribute}.")
804
- return existing
1063
+ existing = cls._handle_if_exists(
1064
+ name,
1065
+ if_exists=if_exists,
1066
+ label_names=label_names,
1067
+ embedding_model=embedding_model,
1068
+ )
1069
+ if existing is not None:
1070
+ return existing
805
1071
 
806
1072
  payload: CreateMemorysetRequest = {
807
1073
  "name": name,
@@ -818,6 +1084,10 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
818
1084
  "index_params": index_params,
819
1085
  "hidden": hidden,
820
1086
  }
1087
+ if memory_type is not None:
1088
+ payload["memory_type"] = memory_type
1089
+ if subsample is not None:
1090
+ payload["subsample"] = subsample
821
1091
  if prompt is not None:
822
1092
  payload["prompt"] = prompt
823
1093
  if isinstance(embedding_model, PretrainedEmbeddingModel):
@@ -826,8 +1096,9 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
826
1096
  payload["finetuned_embedding_model_name_or_id"] = embedding_model.id
827
1097
  else:
828
1098
  raise ValueError("Invalid embedding model")
829
- response = orca_api.POST("/memoryset", json=payload)
830
- job = Job(response["insertion_task_id"], lambda: cls.open(response["id"]))
1099
+ client = OrcaClient._resolve_client()
1100
+ response = client.POST("/memoryset", json=payload)
1101
+ job = Job(response["insertion_job_id"], lambda: cls.open(response["id"]))
831
1102
  return job if background else job.result()
832
1103
 
833
1104
  @overload
@@ -862,6 +1133,16 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
862
1133
  Returns:
863
1134
  Handle to the new memoryset in the OrcaCloud
864
1135
  """
1136
+ if_exists = kwargs.get("if_exists", "error")
1137
+ existing = cls._handle_if_exists(
1138
+ name,
1139
+ if_exists=if_exists,
1140
+ label_names=kwargs.get("label_names"),
1141
+ embedding_model=kwargs.get("embedding_model"),
1142
+ )
1143
+ if existing is not None:
1144
+ return existing
1145
+
865
1146
  datasource = Datasource.from_hf_dataset(
866
1147
  f"{name}_datasource", hf_dataset, if_exists=kwargs.get("if_exists", "error")
867
1148
  )
@@ -926,6 +1207,16 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
926
1207
  Returns:
927
1208
  Handle to the new memoryset in the OrcaCloud
928
1209
  """
1210
+ if_exists = kwargs.get("if_exists", "error")
1211
+ existing = cls._handle_if_exists(
1212
+ name,
1213
+ if_exists=if_exists,
1214
+ label_names=kwargs.get("label_names"),
1215
+ embedding_model=kwargs.get("embedding_model"),
1216
+ )
1217
+ if existing is not None:
1218
+ return existing
1219
+
929
1220
  datasource = Datasource.from_pytorch(
930
1221
  f"{name}_datasource", torch_data, column_names=column_names, if_exists=kwargs.get("if_exists", "error")
931
1222
  )
@@ -990,6 +1281,16 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
990
1281
  ... {"value": "world", "label": 1, "tag": "tag2"},
991
1282
  ... ])
992
1283
  """
1284
+ if_exists = kwargs.get("if_exists", "error")
1285
+ existing = cls._handle_if_exists(
1286
+ name,
1287
+ if_exists=if_exists,
1288
+ label_names=kwargs.get("label_names"),
1289
+ embedding_model=kwargs.get("embedding_model"),
1290
+ )
1291
+ if existing is not None:
1292
+ return existing
1293
+
993
1294
  datasource = Datasource.from_list(f"{name}_datasource", data, if_exists=kwargs.get("if_exists", "error"))
994
1295
  kwargs["background"] = background
995
1296
  return cls.create(name, datasource, **kwargs)
@@ -1053,6 +1354,16 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
1053
1354
  ... "tag": ["tag1", "tag2"],
1054
1355
  ... })
1055
1356
  """
1357
+ if_exists = kwargs.get("if_exists", "error")
1358
+ existing = cls._handle_if_exists(
1359
+ name,
1360
+ if_exists=if_exists,
1361
+ label_names=kwargs.get("label_names"),
1362
+ embedding_model=kwargs.get("embedding_model"),
1363
+ )
1364
+ if existing is not None:
1365
+ return existing
1366
+
1056
1367
  datasource = Datasource.from_dict(f"{name}_datasource", data, if_exists=kwargs.get("if_exists", "error"))
1057
1368
  kwargs["background"] = background
1058
1369
  return cls.create(name, datasource, **kwargs)
@@ -1109,6 +1420,16 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
1109
1420
  Returns:
1110
1421
  Handle to the new memoryset in the OrcaCloud
1111
1422
  """
1423
+ if_exists = kwargs.get("if_exists", "error")
1424
+ existing = cls._handle_if_exists(
1425
+ name,
1426
+ if_exists=if_exists,
1427
+ label_names=kwargs.get("label_names"),
1428
+ embedding_model=kwargs.get("embedding_model"),
1429
+ )
1430
+ if existing is not None:
1431
+ return existing
1432
+
1112
1433
  datasource = Datasource.from_pandas(f"{name}_datasource", dataframe, if_exists=kwargs.get("if_exists", "error"))
1113
1434
  kwargs["background"] = background
1114
1435
  return cls.create(name, datasource, **kwargs)
@@ -1165,6 +1486,16 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
1165
1486
  Returns:
1166
1487
  Handle to the new memoryset in the OrcaCloud
1167
1488
  """
1489
+ if_exists = kwargs.get("if_exists", "error")
1490
+ existing = cls._handle_if_exists(
1491
+ name,
1492
+ if_exists=if_exists,
1493
+ label_names=kwargs.get("label_names"),
1494
+ embedding_model=kwargs.get("embedding_model"),
1495
+ )
1496
+ if existing is not None:
1497
+ return existing
1498
+
1168
1499
  datasource = Datasource.from_arrow(
1169
1500
  f"{name}_datasource", pyarrow_table, if_exists=kwargs.get("if_exists", "error")
1170
1501
  )
@@ -1230,6 +1561,16 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
1230
1561
  Returns:
1231
1562
  Handle to the new memoryset in the OrcaCloud
1232
1563
  """
1564
+ if_exists = kwargs.get("if_exists", "error")
1565
+ existing = cls._handle_if_exists(
1566
+ name,
1567
+ if_exists=if_exists,
1568
+ label_names=kwargs.get("label_names"),
1569
+ embedding_model=kwargs.get("embedding_model"),
1570
+ )
1571
+ if existing is not None:
1572
+ return existing
1573
+
1233
1574
  datasource = Datasource.from_disk(f"{name}_datasource", file_path, if_exists=kwargs.get("if_exists", "error"))
1234
1575
  kwargs["background"] = background
1235
1576
  return cls.create(name, datasource, **kwargs)
@@ -1248,7 +1589,26 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
1248
1589
  Raises:
1249
1590
  LookupError: If the memoryset does not exist
1250
1591
  """
1251
- metadata = orca_api.GET("/memoryset/{name_or_id}", params={"name_or_id": name})
1592
+ client = OrcaClient._resolve_client()
1593
+ metadata = client.GET("/memoryset/{name_or_id}", params={"name_or_id": name})
1594
+ return cls(metadata)
1595
+
1596
+ @classmethod
1597
+ async def aopen(cls, name: str) -> Self:
1598
+ """
1599
+ Asynchronously get a handle to a memoryset in the OrcaCloud
1600
+
1601
+ Params:
1602
+ name: Name or unique identifier of the memoryset
1603
+
1604
+ Returns:
1605
+ Handle to the existing memoryset in the OrcaCloud
1606
+
1607
+ Raises:
1608
+ LookupError: If the memoryset does not exist
1609
+ """
1610
+ client = OrcaAsyncClient._resolve_client()
1611
+ metadata = await client.GET("/memoryset/{name_or_id}", params={"name_or_id": name})
1252
1612
  return cls(metadata)
1253
1613
 
1254
1614
  @classmethod
@@ -1279,9 +1639,10 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
1279
1639
  Returns:
1280
1640
  List of handles to all memorysets in the OrcaCloud
1281
1641
  """
1642
+ client = OrcaClient._resolve_client()
1282
1643
  return [
1283
1644
  cls(metadata)
1284
- for metadata in orca_api.GET("/memoryset", params={"type": cls.memory_type, "show_hidden": show_hidden})
1645
+ for metadata in client.GET("/memoryset", params={"type": cls.memory_type, "show_hidden": show_hidden})
1285
1646
  ]
1286
1647
 
1287
1648
  @classmethod
@@ -1298,7 +1659,8 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
1298
1659
  LookupError: If the memoryset does not exist and if_not_exists is `"error"`
1299
1660
  """
1300
1661
  try:
1301
- orca_api.DELETE("/memoryset/{name_or_id}", params={"name_or_id": name_or_id})
1662
+ client = OrcaClient._resolve_client()
1663
+ client.DELETE("/memoryset/{name_or_id}", params={"name_or_id": name_or_id})
1302
1664
  logging.info(f"Deleted memoryset {name_or_id}")
1303
1665
  except LookupError:
1304
1666
  if if_not_exists == "error":
@@ -1333,7 +1695,8 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
1333
1695
  if hidden is not UNSET:
1334
1696
  payload["hidden"] = hidden
1335
1697
 
1336
- orca_api.PATCH("/memoryset/{name_or_id}", params={"name_or_id": self.id}, json=payload)
1698
+ client = OrcaClient._resolve_client()
1699
+ client.PATCH("/memoryset/{name_or_id}", params={"name_or_id": self.id}, json=payload)
1337
1700
  self.refresh()
1338
1701
 
1339
1702
  @overload
@@ -1425,9 +1788,10 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
1425
1788
  elif isinstance(embedding_model, FinetunedEmbeddingModel):
1426
1789
  payload["finetuned_embedding_model_name_or_id"] = embedding_model.id
1427
1790
 
1428
- metadata = orca_api.POST("/memoryset/{name_or_id}/clone", params={"name_or_id": self.id}, json=payload)
1791
+ client = OrcaClient._resolve_client()
1792
+ metadata = client.POST("/memoryset/{name_or_id}/clone", params={"name_or_id": self.id}, json=payload)
1429
1793
  job = Job(
1430
- metadata["insertion_task_id"],
1794
+ metadata["insertion_job_id"],
1431
1795
  lambda: self.open(metadata["id"]),
1432
1796
  )
1433
1797
  return job if background else job.result()
@@ -1556,7 +1920,8 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
1556
1920
  ],
1557
1921
  ]
1558
1922
  """
1559
- response = orca_api.POST(
1923
+ client = OrcaClient._resolve_client()
1924
+ response = client.POST(
1560
1925
  "/gpu/memoryset/{name_or_id}/lookup",
1561
1926
  params={"name_or_id": self.id},
1562
1927
  json={
@@ -1613,7 +1978,8 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
1613
1978
  ]
1614
1979
 
1615
1980
  if with_feedback_metrics:
1616
- response = orca_api.POST(
1981
+ client = OrcaClient._resolve_client()
1982
+ response = client.POST(
1617
1983
  "/telemetry/memories",
1618
1984
  json={
1619
1985
  "memoryset_id": self.id,
@@ -1637,7 +2003,8 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
1637
2003
  if sort:
1638
2004
  logging.warning("Sorting is not supported when with_feedback_metrics is False. Sort value will be ignored.")
1639
2005
 
1640
- response = orca_api.POST(
2006
+ client = OrcaClient._resolve_client()
2007
+ response = client.POST(
1641
2008
  "/memoryset/{name_or_id}/memories",
1642
2009
  params={"name_or_id": self.id},
1643
2010
  json={
@@ -1698,19 +2065,74 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
1698
2065
  ... {"value": "I am sad", "label": 0, "source_id": "user_124", "tag": "sad"},
1699
2066
  ... ])
1700
2067
  """
1701
- orca_api.POST(
1702
- "/gpu/memoryset/{name_or_id}/memory",
1703
- params={"name_or_id": self.id},
1704
- json=cast(
1705
- list[LabeledMemoryInsert] | list[ScoredMemoryInsert],
1706
- [
1707
- _parse_memory_insert(memory, type=self.memory_type)
1708
- for memory in (cast(list[dict[str, Any]], [items]) if isinstance(items, dict) else items)
1709
- ],
1710
- ),
1711
- )
2068
+ client = OrcaClient._resolve_client()
2069
+ items = cast(list[dict[str, Any]], [items]) if isinstance(items, dict) else list(items)
2070
+ # insert memories in batches to avoid API timeouts
2071
+ for i in range(0, len(items), self._batch_size):
2072
+ batch = items[i : i + self._batch_size]
2073
+ client.POST(
2074
+ "/gpu/memoryset/{name_or_id}/memory",
2075
+ params={"name_or_id": self.id},
2076
+ json=cast(
2077
+ list[LabeledMemoryInsert] | list[ScoredMemoryInsert],
2078
+ [_parse_memory_insert(item, type=self.memory_type) for item in batch],
2079
+ ),
2080
+ )
2081
+
1712
2082
  self.refresh()
1713
2083
 
2084
+ async def ainsert(self, items: Iterable[dict[str, Any]] | dict[str, Any]) -> None:
2085
+ """
2086
+ Asynchronously insert memories into the memoryset
2087
+
2088
+ Params:
2089
+ items: List of memories to insert into the memoryset. This should be a list of
2090
+ dictionaries with the following keys:
2091
+
2092
+ - `value`: Value of the memory
2093
+ - `label`: Label of the memory
2094
+ - `score`: Score of the memory
2095
+ - `source_id`: Optional unique ID of the memory in a system of reference
2096
+ - `...`: Any other metadata to store for the memory
2097
+
2098
+ Examples:
2099
+ >>> await memoryset.ainsert([
2100
+ ... {"value": "I am happy", "label": 1, "source_id": "user_123", "tag": "happy"},
2101
+ ... {"value": "I am sad", "label": 0, "source_id": "user_124", "tag": "sad"},
2102
+ ... ])
2103
+ """
2104
+ client = OrcaAsyncClient._resolve_client()
2105
+ items = cast(list[dict[str, Any]], [items]) if isinstance(items, dict) else list(items)
2106
+ # insert memories in batches to avoid API timeouts
2107
+ for i in range(0, len(items), self._batch_size):
2108
+ batch = items[i : i + self._batch_size]
2109
+ await client.POST(
2110
+ "/gpu/memoryset/{name_or_id}/memory",
2111
+ params={"name_or_id": self.id},
2112
+ json=cast(
2113
+ list[LabeledMemoryInsert] | list[ScoredMemoryInsert],
2114
+ [_parse_memory_insert(item, type=self.memory_type) for item in batch],
2115
+ ),
2116
+ )
2117
+
2118
+ await self.arefresh()
2119
+
2120
+ async def arefresh(self, throttle: float = 0):
2121
+ """
2122
+ Asynchronously refresh the information about the memoryset from the OrcaCloud
2123
+
2124
+ Params:
2125
+ throttle: Minimum time in seconds between refreshes
2126
+ """
2127
+ current_time = datetime.now()
2128
+ # Skip refresh if last refresh was too recent
2129
+ if (current_time - self._last_refresh) < timedelta(seconds=throttle):
2130
+ return
2131
+
2132
+ refreshed_memoryset = await type(self).aopen(self.id)
2133
+ self.__dict__.update(refreshed_memoryset.__dict__)
2134
+ self._last_refresh = current_time
2135
+
1714
2136
  @overload
1715
2137
  def get(self, memory_id: str) -> MemoryT: # type: ignore -- this takes precedence
1716
2138
  pass
@@ -1748,7 +2170,8 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
1748
2170
  ]
1749
2171
  """
1750
2172
  if isinstance(memory_id, str):
1751
- response = orca_api.GET(
2173
+ client = OrcaClient._resolve_client()
2174
+ response = client.GET(
1752
2175
  "/memoryset/{name_or_id}/memory/{memory_id}", params={"name_or_id": self.id, "memory_id": memory_id}
1753
2176
  )
1754
2177
  return cast(
@@ -1756,7 +2179,8 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
1756
2179
  (LabeledMemory(self.id, response) if "label" in response else ScoredMemory(self.id, response)),
1757
2180
  )
1758
2181
  else:
1759
- response = orca_api.POST(
2182
+ client = OrcaClient._resolve_client()
2183
+ response = client.POST(
1760
2184
  "/memoryset/{name_or_id}/memories/get",
1761
2185
  params={"name_or_id": self.id},
1762
2186
  json={"memory_ids": list(memory_id)},
@@ -1809,24 +2233,28 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
1809
2233
  ... for m in memoryset.query(filters=[("tag", "==", "happy")])
1810
2234
  ... )
1811
2235
  """
1812
- response = orca_api.PATCH(
1813
- "/gpu/memoryset/{name_or_id}/memories",
1814
- params={"name_or_id": self.id},
1815
- json=cast(
1816
- list[LabeledMemoryUpdate] | list[ScoredMemoryUpdate],
1817
- [
1818
- _parse_memory_update(update, type=self.memory_type)
1819
- for update in (cast(list[dict[str, Any]], [updates]) if isinstance(updates, dict) else updates)
1820
- ],
1821
- ),
1822
- )
1823
- updated_memories = [
1824
- cast(
1825
- MemoryT,
1826
- (LabeledMemory(self.id, memory) if "label" in memory else ScoredMemory(self.id, memory)),
2236
+ client = OrcaClient._resolve_client()
2237
+ updates_list = cast(list[dict[str, Any]], [updates]) if isinstance(updates, dict) else list(updates)
2238
+ # update memories in batches to avoid API timeouts
2239
+ updated_memories: list[MemoryT] = []
2240
+ for i in range(0, len(updates_list), self._batch_size):
2241
+ batch = updates_list[i : i + self._batch_size]
2242
+ response = client.PATCH(
2243
+ "/gpu/memoryset/{name_or_id}/memories",
2244
+ params={"name_or_id": self.id},
2245
+ json=cast(
2246
+ list[LabeledMemoryUpdate] | list[ScoredMemoryUpdate],
2247
+ [_parse_memory_update(update, type=self.memory_type) for update in batch],
2248
+ ),
1827
2249
  )
1828
- for memory in response
1829
- ]
2250
+ updated_memories.extend(
2251
+ cast(
2252
+ MemoryT,
2253
+ (LabeledMemory(self.id, memory) if "label" in memory else ScoredMemory(self.id, memory)),
2254
+ )
2255
+ for memory in response
2256
+ )
2257
+
1830
2258
  return updated_memories[0] if isinstance(updates, dict) else updated_memories
1831
2259
 
1832
2260
  def get_cascading_edits_suggestions(
@@ -1869,7 +2297,8 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
1869
2297
  A list of CascadingEditSuggestion objects, each containing a neighbor and the suggested new label.
1870
2298
  """
1871
2299
  # TODO: properly integrate this with memory edits and return something that can be applied
1872
- return orca_api.POST(
2300
+ client = OrcaClient._resolve_client()
2301
+ return client.POST(
1873
2302
  "/memoryset/{name_or_id}/memory/{memory_id}/cascading_edits",
1874
2303
  params={"name_or_id": self.id, "memory_id": memory.memory_id},
1875
2304
  json={
@@ -1903,10 +2332,14 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
1903
2332
  ... )
1904
2333
 
1905
2334
  """
2335
+ client = OrcaClient._resolve_client()
1906
2336
  memory_ids = [memory_id] if isinstance(memory_id, str) else list(memory_id)
1907
- orca_api.POST(
1908
- "/memoryset/{name_or_id}/memories/delete", params={"name_or_id": self.id}, json={"memory_ids": memory_ids}
1909
- )
2337
+ # delete memories in batches to avoid API timeouts
2338
+ for i in range(0, len(memory_ids), self._batch_size):
2339
+ batch = memory_ids[i : i + self._batch_size]
2340
+ client.POST(
2341
+ "/memoryset/{name_or_id}/memories/delete", params={"name_or_id": self.id}, json={"memory_ids": batch}
2342
+ )
1910
2343
  logging.info(f"Deleted {len(memory_ids)} memories from memoryset.")
1911
2344
  self.refresh()
1912
2345
 
@@ -1951,7 +2384,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
1951
2384
  - **`"duplicate"`**: Find potentially duplicate memories in the memoryset
1952
2385
  - **`"cluster"`**: Cluster the memories in the memoryset
1953
2386
  - **`"label"`**: Analyze the labels to find potential mislabelings
1954
- - **`"neighbor"`**: Analyze the neighbors to populate anomaly scores
2387
+ - **`"distribution"`**: Analyze the embedding distribution to populate
1955
2388
  - **`"projection"`**: Create a 2D projection of the embeddings for visualization
1956
2389
 
1957
2390
  lookup_count: Number of memories to lookup for each memory in the memoryset
@@ -2017,7 +2450,8 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
2017
2450
  raise ValueError(error_msg)
2018
2451
  configs[name] = analysis
2019
2452
 
2020
- analysis = orca_api.POST(
2453
+ client = OrcaClient._resolve_client()
2454
+ analysis = client.POST(
2021
2455
  "/memoryset/{name_or_id}/analysis",
2022
2456
  params={"name_or_id": self.id},
2023
2457
  json={
@@ -2026,134 +2460,193 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
2026
2460
  "clear_metrics": clear_metrics,
2027
2461
  },
2028
2462
  )
2029
- job = Job(
2030
- analysis["task_id"],
2031
- lambda: orca_api.GET(
2032
- "/memoryset/{name_or_id}/analysis/{analysis_task_id}",
2033
- params={"name_or_id": self.id, "analysis_task_id": analysis["task_id"]},
2034
- )["results"],
2035
- )
2463
+
2464
+ def get_analysis_result():
2465
+ client = OrcaClient._resolve_client()
2466
+ return client.GET(
2467
+ "/memoryset/{name_or_id}/analysis/{analysis_job_id}",
2468
+ params={"name_or_id": self.id, "analysis_job_id": analysis["job_id"]},
2469
+ )["results"]
2470
+
2471
+ job = Job(analysis["job_id"], get_analysis_result)
2036
2472
  return job if background else job.result()
2037
2473
 
2038
2474
  def get_potential_duplicate_groups(self) -> list[list[MemoryT]]:
2039
2475
  """Group potential duplicates in the memoryset"""
2040
- response = orca_api.GET("/memoryset/{name_or_id}/potential_duplicate_groups", params={"name_or_id": self.id})
2476
+ client = OrcaClient._resolve_client()
2477
+ response = client.GET("/memoryset/{name_or_id}/potential_duplicate_groups", params={"name_or_id": self.id})
2041
2478
  return [
2042
2479
  [cast(MemoryT, LabeledMemory(self.id, m) if "label" in m else ScoredMemory(self.id, m)) for m in ms]
2043
2480
  for ms in response
2044
2481
  ]
2045
2482
 
2483
+
2484
+ class LabeledMemoryset(MemorysetBase[LabeledMemory, LabeledMemoryLookup]):
2485
+ """
2486
+ A Handle to a collection of memories with labels in the OrcaCloud
2487
+
2488
+ Attributes:
2489
+ id: Unique identifier for the memoryset
2490
+ name: Unique name of the memoryset
2491
+ description: Description of the memoryset
2492
+ label_names: Names for the class labels in the memoryset
2493
+ length: Number of memories in the memoryset
2494
+ embedding_model: Embedding model used to embed the memory values for semantic search
2495
+ created_at: When the memoryset was created, automatically generated on create
2496
+ updated_at: When the memoryset was last updated, automatically updated on updates
2497
+ """
2498
+
2499
+ label_names: list[str]
2500
+ memory_type: MemoryType = "LABELED"
2501
+
2502
+ def __init__(self, metadata: MemorysetMetadata):
2503
+ super().__init__(metadata)
2504
+ assert metadata["label_names"] is not None
2505
+ self.label_names = metadata["label_names"]
2506
+
2507
+ def __eq__(self, other) -> bool:
2508
+ return isinstance(other, LabeledMemoryset) and self.id == other.id
2509
+
2046
2510
  @overload
2047
- @staticmethod
2048
- def run_embedding_evaluation(
2511
+ @classmethod
2512
+ def create(
2513
+ cls,
2514
+ name: str,
2049
2515
  datasource: Datasource,
2050
2516
  *,
2517
+ embedding_model: FinetunedEmbeddingModel | PretrainedEmbeddingModel | None = None,
2051
2518
  value_column: str = "value",
2052
- label_column: str = "label",
2519
+ label_column: str | None = "label",
2053
2520
  source_id_column: str | None = None,
2054
- neighbor_count: int = 5,
2055
- embedding_models: list[str] | None = None,
2521
+ description: str | None = None,
2522
+ label_names: list[str] | None = None,
2523
+ max_seq_length_override: int | None = None,
2524
+ prompt: str | None = None,
2525
+ remove_duplicates: bool = True,
2526
+ index_type: IndexType = "FLAT",
2527
+ index_params: dict[str, Any] = {},
2528
+ if_exists: CreateMode = "error",
2056
2529
  background: Literal[True],
2057
- ) -> Job[list[EmbeddingModelResult]]:
2530
+ hidden: bool = False,
2531
+ subsample: int | float | None = None,
2532
+ ) -> Job[Self]:
2058
2533
  pass
2059
2534
 
2060
2535
  @overload
2061
- @staticmethod
2062
- def run_embedding_evaluation(
2536
+ @classmethod
2537
+ def create(
2538
+ cls,
2539
+ name: str,
2063
2540
  datasource: Datasource,
2064
2541
  *,
2542
+ embedding_model: FinetunedEmbeddingModel | PretrainedEmbeddingModel | None = None,
2065
2543
  value_column: str = "value",
2066
- label_column: str = "label",
2544
+ label_column: str | None = "label",
2067
2545
  source_id_column: str | None = None,
2068
- neighbor_count: int = 5,
2069
- embedding_models: list[str] | None = None,
2546
+ description: str | None = None,
2547
+ label_names: list[str] | None = None,
2548
+ max_seq_length_override: int | None = None,
2549
+ prompt: str | None = None,
2550
+ remove_duplicates: bool = True,
2551
+ index_type: IndexType = "FLAT",
2552
+ index_params: dict[str, Any] = {},
2553
+ if_exists: CreateMode = "error",
2070
2554
  background: Literal[False] = False,
2071
- ) -> list[EmbeddingModelResult]:
2555
+ hidden: bool = False,
2556
+ subsample: int | float | None = None,
2557
+ ) -> Self:
2072
2558
  pass
2073
2559
 
2074
- @staticmethod
2075
- def run_embedding_evaluation(
2560
+ @classmethod
2561
+ def create( # type: ignore[override]
2562
+ cls,
2563
+ name: str,
2076
2564
  datasource: Datasource,
2077
2565
  *,
2566
+ embedding_model: FinetunedEmbeddingModel | PretrainedEmbeddingModel | None = None,
2078
2567
  value_column: str = "value",
2079
- label_column: str = "label",
2568
+ label_column: str | None = "label",
2080
2569
  source_id_column: str | None = None,
2081
- neighbor_count: int = 5,
2082
- embedding_models: list[str] | None = None,
2570
+ description: str | None = None,
2571
+ label_names: list[str] | None = None,
2572
+ max_seq_length_override: int | None = None,
2573
+ prompt: str | None = None,
2574
+ remove_duplicates: bool = True,
2575
+ index_type: IndexType = "FLAT",
2576
+ index_params: dict[str, Any] = {},
2577
+ if_exists: CreateMode = "error",
2083
2578
  background: bool = False,
2084
- ) -> Job[list[EmbeddingModelResult]] | list[EmbeddingModelResult]:
2579
+ hidden: bool = False,
2580
+ subsample: int | float | None = None,
2581
+ ) -> Self | Job[Self]:
2085
2582
  """
2086
- Test the quality of embeddings for the datasource by computing metrics such as prediction accuracy.
2583
+ Create a new labeled memoryset in the OrcaCloud
2584
+
2585
+ All columns from the datasource that are not specified in the `value_column`,
2586
+ `label_column`, or `source_id_column` will be stored as metadata in the memoryset.
2087
2587
 
2088
2588
  Params:
2089
- datasource: The datasource to run the embedding evaluation on
2589
+ name: Name for the new memoryset (must be unique)
2590
+ datasource: Source data to populate the memories in the memoryset
2591
+ embedding_model: Embedding model to use for embedding memory values for semantic search.
2592
+ If not provided, a default embedding model for the memoryset will be used.
2090
2593
  value_column: Name of the column in the datasource that contains the memory values
2091
- label_column: Name of the column in the datasource that contains the memory labels,
2092
- these must be contiguous integers starting from 0
2594
+ label_column: Name of the column in the datasource that contains the memory labels.
2595
+ Must contain categorical values as integers or strings. String labels will be
2596
+ converted to integers with the unique strings extracted as `label_names`. To create
2597
+ a memoryset with all none labels, set to `None`.
2093
2598
  source_id_column: Optional name of the column in the datasource that contains the ids in
2094
2599
  the system of reference
2095
- neighbor_count: The number of neighbors to select for prediction
2096
- embedding_models: Optional list of embedding model keys to evaluate, if not provided all
2097
- available embedding models will be used
2600
+ description: Optional description for the memoryset, this will be used in agentic flows,
2601
+ so make sure it is concise and describes the contents of your memoryset not the
2602
+ datasource or the embedding model.
2603
+ label_names: List of human-readable names for the labels in the memoryset, must match
2604
+ the number of labels in the `label_column`. Will be automatically inferred if string
2605
+ labels are provided or if a [Dataset][datasets.Dataset] with a
2606
+ [`ClassLabel`][datasets.ClassLabel] feature for labels is used as the datasource
2607
+ max_seq_length_override: Maximum sequence length of values in the memoryset, if the
2608
+ value is longer than this it will be truncated, will default to the model's max
2609
+ sequence length if not provided
2610
+ prompt: Optional prompt to use when embedding documents/memories for storage
2611
+ remove_duplicates: Whether to remove duplicates from the datasource before inserting
2612
+ into the memoryset
2613
+ index_type: Type of vector index to use for the memoryset, defaults to `"FLAT"`. Valid
2614
+ values are `"FLAT"`, `"IVF_FLAT"`, `"IVF_SQ8"`, `"IVF_PQ"`, `"HNSW"`, and `"DISKANN"`.
2615
+ index_params: Parameters for the vector index, defaults to `{}`
2616
+ if_exists: What to do if a memoryset with the same name already exists, defaults to
2617
+ `"error"`. Other option is `"open"` to open the existing memoryset.
2618
+ background: Whether to run the operation none blocking and return a job handle
2619
+ hidden: Whether the memoryset should be hidden
2098
2620
 
2099
2621
  Returns:
2100
- A dictionary containing the results of the embedding evaluation
2101
- """
2622
+ Handle to the new memoryset in the OrcaCloud
2102
2623
 
2103
- response = orca_api.POST(
2104
- "/datasource/{name_or_id}/embedding_evaluation",
2105
- params={"name_or_id": datasource.id},
2106
- json={
2107
- "value_column": value_column,
2108
- "label_column": label_column,
2109
- "source_id_column": source_id_column,
2110
- "neighbor_count": neighbor_count,
2111
- "embedding_models": embedding_models,
2112
- },
2624
+ Raises:
2625
+ ValueError: If the memoryset already exists and if_exists is `"error"` or if it is
2626
+ `"open"` and the params do not match those of the existing memoryset.
2627
+ """
2628
+ return super().create(
2629
+ name,
2630
+ datasource,
2631
+ label_column=label_column,
2632
+ score_column=None,
2633
+ embedding_model=embedding_model,
2634
+ value_column=value_column,
2635
+ source_id_column=source_id_column,
2636
+ description=description,
2637
+ label_names=label_names,
2638
+ max_seq_length_override=max_seq_length_override,
2639
+ prompt=prompt,
2640
+ remove_duplicates=remove_duplicates,
2641
+ index_type=index_type,
2642
+ index_params=index_params,
2643
+ if_exists=if_exists,
2644
+ background=background,
2645
+ hidden=hidden,
2646
+ subsample=subsample,
2647
+ memory_type="LABELED",
2113
2648
  )
2114
2649
 
2115
- def get_value() -> list[EmbeddingModelResult]:
2116
- res = orca_api.GET(
2117
- "/datasource/{name_or_id}/embedding_evaluation/{task_id}",
2118
- params={"name_or_id": datasource.id, "task_id": response["task_id"]},
2119
- )
2120
- assert res["result"] is not None
2121
- return res["result"]["evaluation_results"]
2122
-
2123
- job = Job(response["task_id"], get_value)
2124
- return job if background else job.result()
2125
-
2126
-
2127
- class LabeledMemoryset(MemorysetBase[LabeledMemory, LabeledMemoryLookup]):
2128
- """
2129
- A Handle to a collection of memories with labels in the OrcaCloud
2130
-
2131
- Attributes:
2132
- id: Unique identifier for the memoryset
2133
- name: Unique name of the memoryset
2134
- description: Description of the memoryset
2135
- label_names: Names for the class labels in the memoryset
2136
- length: Number of memories in the memoryset
2137
- embedding_model: Embedding model used to embed the memory values for semantic search
2138
- created_at: When the memoryset was created, automatically generated on create
2139
- updated_at: When the memoryset was last updated, automatically updated on updates
2140
- """
2141
-
2142
- label_names: list[str]
2143
- memory_type: MemoryType = "LABELED"
2144
-
2145
- def __init__(self, metadata: MemorysetMetadata):
2146
- super().__init__(metadata)
2147
- assert metadata["label_names"] is not None
2148
- self.label_names = metadata["label_names"]
2149
-
2150
- def __eq__(self, other) -> bool:
2151
- return isinstance(other, LabeledMemoryset) and self.id == other.id
2152
-
2153
- @classmethod
2154
- def create(cls, name: str, datasource: Datasource, *, label_column: str | None = "label", **kwargs):
2155
- return super().create(name, datasource, label_column=label_column, score_column=None, **kwargs)
2156
-
2157
2650
  def display_label_analysis(self):
2158
2651
  """
2159
2652
  Display an interactive UI to review and act upon the label analysis results
@@ -2185,6 +2678,131 @@ class ScoredMemoryset(MemorysetBase[ScoredMemory, ScoredMemoryLookup]):
2185
2678
  def __eq__(self, other) -> bool:
2186
2679
  return isinstance(other, ScoredMemoryset) and self.id == other.id
2187
2680
 
2681
+ @overload
2682
+ @classmethod
2683
+ def create(
2684
+ cls,
2685
+ name: str,
2686
+ datasource: Datasource,
2687
+ *,
2688
+ embedding_model: FinetunedEmbeddingModel | PretrainedEmbeddingModel | None = None,
2689
+ value_column: str = "value",
2690
+ score_column: str | None = "score",
2691
+ source_id_column: str | None = None,
2692
+ description: str | None = None,
2693
+ max_seq_length_override: int | None = None,
2694
+ prompt: str | None = None,
2695
+ remove_duplicates: bool = True,
2696
+ index_type: IndexType = "FLAT",
2697
+ index_params: dict[str, Any] = {},
2698
+ if_exists: CreateMode = "error",
2699
+ background: Literal[True],
2700
+ hidden: bool = False,
2701
+ subsample: int | float | None = None,
2702
+ ) -> Job[Self]:
2703
+ pass
2704
+
2705
+ @overload
2706
+ @classmethod
2707
+ def create(
2708
+ cls,
2709
+ name: str,
2710
+ datasource: Datasource,
2711
+ *,
2712
+ embedding_model: FinetunedEmbeddingModel | PretrainedEmbeddingModel | None = None,
2713
+ score_column: str | None = "score",
2714
+ value_column: str = "value",
2715
+ source_id_column: str | None = None,
2716
+ description: str | None = None,
2717
+ max_seq_length_override: int | None = None,
2718
+ prompt: str | None = None,
2719
+ remove_duplicates: bool = True,
2720
+ index_type: IndexType = "FLAT",
2721
+ index_params: dict[str, Any] = {},
2722
+ if_exists: CreateMode = "error",
2723
+ background: Literal[False] = False,
2724
+ hidden: bool = False,
2725
+ subsample: int | float | None = None,
2726
+ ) -> Self:
2727
+ pass
2728
+
2188
2729
  @classmethod
2189
- def create(cls, name: str, datasource: Datasource, *, score_column: str | None = "score", **kwargs):
2190
- return super().create(name, datasource, score_column=score_column, label_column=None, **kwargs)
2730
+ def create( # type: ignore[override]
2731
+ cls,
2732
+ name: str,
2733
+ datasource: Datasource,
2734
+ *,
2735
+ embedding_model: FinetunedEmbeddingModel | PretrainedEmbeddingModel | None = None,
2736
+ value_column: str = "value",
2737
+ score_column: str | None = "score",
2738
+ source_id_column: str | None = None,
2739
+ description: str | None = None,
2740
+ max_seq_length_override: int | None = None,
2741
+ prompt: str | None = None,
2742
+ remove_duplicates: bool = True,
2743
+ index_type: IndexType = "FLAT",
2744
+ index_params: dict[str, Any] = {},
2745
+ if_exists: CreateMode = "error",
2746
+ background: bool = False,
2747
+ hidden: bool = False,
2748
+ subsample: int | float | None = None,
2749
+ ) -> Self | Job[Self]:
2750
+ """
2751
+ Create a new scored memoryset in the OrcaCloud
2752
+
2753
+ All columns from the datasource that are not specified in the `value_column`,
2754
+ `score_column`, or `source_id_column` will be stored as metadata in the memoryset.
2755
+
2756
+ Params:
2757
+ name: Name for the new memoryset (must be unique)
2758
+ datasource: Source data to populate the memories in the memoryset
2759
+ embedding_model: Embedding model to use for embedding memory values for semantic search.
2760
+ If not provided, a default embedding model for the memoryset will be used.
2761
+ value_column: Name of the column in the datasource that contains the memory values
2762
+ score_column: Name of the column in the datasource that contains the memory scores. Must
2763
+ contain numerical values. To create a memoryset with all none scores, set to `None`.
2764
+ source_id_column: Optional name of the column in the datasource that contains the ids in
2765
+ the system of reference
2766
+ description: Optional description for the memoryset, this will be used in agentic flows,
2767
+ so make sure it is concise and describes the contents of your memoryset not the
2768
+ datasource or the embedding model.
2769
+ max_seq_length_override: Maximum sequence length of values in the memoryset, if the
2770
+ value is longer than this it will be truncated, will default to the model's max
2771
+ sequence length if not provided
2772
+ prompt: Optional prompt to use when embedding documents/memories for storage
2773
+ remove_duplicates: Whether to remove duplicates from the datasource before inserting
2774
+ into the memoryset
2775
+ index_type: Type of vector index to use for the memoryset, defaults to `"FLAT"`. Valid
2776
+ values are `"FLAT"`, `"IVF_FLAT"`, `"IVF_SQ8"`, `"IVF_PQ"`, `"HNSW"`, and `"DISKANN"`.
2777
+ index_params: Parameters for the vector index, defaults to `{}`
2778
+ if_exists: What to do if a memoryset with the same name already exists, defaults to
2779
+ `"error"`. Other option is `"open"` to open the existing memoryset.
2780
+ background: Whether to run the operation none blocking and return a job handle
2781
+ hidden: Whether the memoryset should be hidden
2782
+
2783
+ Returns:
2784
+ Handle to the new memoryset in the OrcaCloud
2785
+
2786
+ Raises:
2787
+ ValueError: If the memoryset already exists and if_exists is `"error"` or if it is
2788
+ `"open"` and the params do not match those of the existing memoryset.
2789
+ """
2790
+ return super().create(
2791
+ name,
2792
+ datasource,
2793
+ embedding_model=embedding_model,
2794
+ value_column=value_column,
2795
+ score_column=score_column,
2796
+ source_id_column=source_id_column,
2797
+ description=description,
2798
+ max_seq_length_override=max_seq_length_override,
2799
+ prompt=prompt,
2800
+ remove_duplicates=remove_duplicates,
2801
+ index_type=index_type,
2802
+ index_params=index_params,
2803
+ if_exists=if_exists,
2804
+ background=background,
2805
+ hidden=hidden,
2806
+ subsample=subsample,
2807
+ memory_type="SCORED",
2808
+ )