orca-sdk 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
orca_sdk/memoryset.py CHANGED
@@ -16,11 +16,7 @@ from typing import (
     overload,
 )
 
-import pandas as pd
-import pyarrow as pa
 from datasets import Dataset
-from torch.utils.data import DataLoader as TorchDataLoader
-from torch.utils.data import Dataset as TorchDataset
 
 from ._utils.common import UNSET, CreateMode, DropMode
 from .async_client import OrcaAsyncClient
@@ -30,6 +26,7 @@ from .client import (
     CreateMemorysetFromDatasourceRequest,
     CreateMemorysetRequest,
     FilterItem,
+    LabeledBatchMemoryUpdatePatch,
 )
 from .client import LabeledMemory as LabeledMemoryResponse
 from .client import (
@@ -49,6 +46,7 @@ from .client import (
     MemorysetUpdate,
     MemoryType,
     OrcaClient,
+    ScoredBatchMemoryUpdatePatch,
 )
 from .client import ScoredMemory as ScoredMemoryResponse
 from .client import (
@@ -74,6 +72,12 @@ from .job import Job, Status
 from .telemetry import ClassificationPrediction, RegressionPrediction
 
 if TYPE_CHECKING:
+    # peer dependencies that are used for types only
+    from pandas import DataFrame as PandasDataFrame  # type: ignore
+    from pyarrow import Table as PyArrowTable  # type: ignore
+    from torch.utils.data import DataLoader as TorchDataLoader  # type: ignore
+    from torch.utils.data import Dataset as TorchDataset  # type: ignore
+
     from .classification_model import ClassificationModel
     from .regression_model import RegressionModel
 
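Note on the import change above: pandas, pyarrow, and torch become type-only peer dependencies, so importing the SDK no longer requires them at runtime. A minimal standalone sketch of the pattern (the names below are illustrative, not SDK code):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # evaluated by type checkers only, never at runtime
        from pandas import DataFrame as PandasDataFrame

    def first_rows(df: "PandasDataFrame", n: int = 5) -> "PandasDataFrame":
        # string annotations keep the import deferred; pandas is only
        # needed once a caller actually passes a DataFrame in
        return df.head(n)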
@@ -94,7 +98,21 @@ FilterOperation = Literal["==", "!=", ">", ">=", "<", "<=", "in", "not in", "lik
 Operations that can be used in a filter expression.
 """
 
-FilterValue = str | int | float | bool | datetime | None | list[str | None] | list[int] | list[float] | list[bool]
+FilterValue = (
+    str
+    | int
+    | float
+    | bool
+    | datetime
+    | list[None]
+    | list[str]
+    | list[str | None]
+    | list[int]
+    | list[int | None]
+    | list[float]
+    | list[bool]
+    | None
+)
 """
 Values that can be used in a filter expression.
 """
@@ -134,7 +152,21 @@ def _is_metric_column(column: str):
     return column in ["feedback_metrics", "lookup"]
 
 
-def _parse_filter_item_from_tuple(input: FilterItemTuple) -> FilterItem | TelemetryFilterItem:
+@overload
+def _parse_filter_item_from_tuple(input: FilterItemTuple, allow_metric_fields: Literal[False]) -> FilterItem:
+    pass
+
+
+@overload
+def _parse_filter_item_from_tuple(
+    input: FilterItemTuple, allow_metric_fields: Literal[True] = True
+) -> FilterItem | TelemetryFilterItem:
+    pass
+
+
+def _parse_filter_item_from_tuple(
+    input: FilterItemTuple, allow_metric_fields: bool = True
+) -> FilterItem | TelemetryFilterItem:
     field = input[0].split(".")
     if (
         len(field) == 1
@@ -146,6 +178,8 @@ def _parse_filter_item_from_tuple(input: FilterItemTuple) -> FilterItem | TelemetryFilterItem:
     if isinstance(value, datetime):
         value = value.isoformat()
     if _is_metric_column(field[0]):
+        if not allow_metric_fields:
+            raise ValueError(f"Cannot filter on {field[0]} - metric fields are not supported")
         if not (
             (isinstance(value, list) and all(isinstance(v, float) or isinstance(v, int) for v in value))
             or isinstance(value, float)
@@ -165,7 +199,7 @@ def _parse_filter_item_from_tuple(input: FilterItemTuple) -> FilterItem | TelemetryFilterItem:
         return TelemetryFilterItem(field=cast(TelemetryField, tuple(field)), op=op, value=value)
 
     # Convert list to tuple for FilterItem field type
-    return FilterItem(field=tuple(field), op=op, value=value)  # type: ignore[assignment]
+    return FilterItem(field=tuple[Any, ...](field), op=op, value=value)
 
 
 def _parse_sort_item_from_tuple(
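The overloads narrow the return type: with allow_metric_fields=False the parser is statically guaranteed to produce a plain FilterItem, and metric fields fail fast. A hedged sketch (the metric name is illustrative):

    # telemetry paths still accept metric fields
    item = _parse_filter_item_from_tuple(("feedback_metrics.accuracy", ">", 0.5))

    # non-telemetry paths can opt out and get an early, descriptive error
    _parse_filter_item_from_tuple(
        ("feedback_metrics.accuracy", ">", 0.5), allow_metric_fields=False
    )
    # ValueError: Cannot filter on feedback_metrics - metric fields are not supported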
@@ -238,17 +272,29 @@ def _parse_memory_insert(memory: dict[str, Any], type: MemoryType) -> LabeledMem
     }
 
 
-def _parse_memory_update(update: dict[str, Any], type: MemoryType) -> LabeledMemoryUpdate | ScoredMemoryUpdate:
-    if "memory_id" not in update:
-        raise ValueError("memory_id must be specified in the update dictionary")
-    memory_id = update["memory_id"]
-    if not isinstance(memory_id, str):
-        raise ValueError("memory_id must be a string")
-    payload: LabeledMemoryUpdate | ScoredMemoryUpdate = {"memory_id": memory_id}
-    if "value" in update:
-        if not isinstance(update["value"], str):
-            raise ValueError("value must be a string or unset")
-        payload["value"] = update["value"]
+def _extract_metadata_for_patch(update: dict[str, Any], exclude_keys: set[str]) -> dict[str, Any] | None:
+    """Extract metadata from update dict for patch operations.
+
+    Returns the metadata dict to include in the payload, or None if metadata should be omitted
+    (to preserve existing metadata on the server).
+    """
+    if "metadata" in update and update["metadata"] is not None:
+        # User explicitly provided metadata dict (could be {} to clear all metadata)
+        metadata = update["metadata"]
+        if not isinstance(metadata, dict):
+            raise ValueError("metadata must be a dict")
+        return metadata
+    # Extract metadata from top-level keys, only include if non-empty
+    metadata = {k: v for k, v in update.items() if k not in DEFAULT_COLUMN_NAMES | exclude_keys}
+    if any(k in metadata for k in FORBIDDEN_METADATA_COLUMN_NAMES):
+        raise ValueError(f"Cannot update the following metadata keys: {', '.join(FORBIDDEN_METADATA_COLUMN_NAMES)}")
+    return metadata if metadata else None
+
+
+def _parse_memory_update_patch(
+    update: dict[str, Any], type: MemoryType
+) -> LabeledBatchMemoryUpdatePatch | ScoredBatchMemoryUpdatePatch:
+    payload: LabeledBatchMemoryUpdatePatch | ScoredBatchMemoryUpdatePatch = {}
     if "source_id" in update:
         source_id = update["source_id"]
         if source_id is not None and not isinstance(source_id, str):
@@ -261,31 +307,41 @@ def _parse_memory_update(update: dict[str, Any], type: MemoryType) -> LabeledMemoryUpdate | ScoredMemoryUpdate:
             payload["partition_id"] = partition_id
     match type:
         case "LABELED":
-            payload = cast(LabeledMemoryUpdate, payload)
+            payload = cast(LabeledBatchMemoryUpdatePatch, payload)
             if "label" in update:
                 if not isinstance(update["label"], int):
                     raise ValueError("label must be an integer or unset")
                 payload["label"] = update["label"]
-            metadata = {k: v for k, v in update.items() if k not in DEFAULT_COLUMN_NAMES | {"memory_id", "label"}}
-            if any(k in metadata for k in FORBIDDEN_METADATA_COLUMN_NAMES):
-                raise ValueError(
-                    f"Cannot update the following metadata keys: {', '.join(FORBIDDEN_METADATA_COLUMN_NAMES)}"
-                )
-            payload["metadata"] = metadata
+            metadata = _extract_metadata_for_patch(update, {"memory_id", "label", "metadata"})
+            if metadata is not None:
+                payload["metadata"] = metadata
             return payload
         case "SCORED":
-            payload = cast(ScoredMemoryUpdate, payload)
+            payload = cast(ScoredBatchMemoryUpdatePatch, payload)
             if "score" in update:
                 if not isinstance(update["score"], (int, float)):
                     raise ValueError("score must be a number or unset")
                 payload["score"] = update["score"]
-            metadata = {k: v for k, v in update.items() if k not in DEFAULT_COLUMN_NAMES | {"memory_id", "score"}}
-            if any(k in metadata for k in FORBIDDEN_METADATA_COLUMN_NAMES):
-                raise ValueError(
-                    f"Cannot update the following metadata keys: {', '.join(FORBIDDEN_METADATA_COLUMN_NAMES)}"
-                )
-            payload["metadata"] = metadata
-            return cast(ScoredMemoryUpdate, payload)
+            metadata = _extract_metadata_for_patch(update, {"memory_id", "score", "metadata"})
+            if metadata is not None:
+                payload["metadata"] = metadata
+            return payload
+
+
+def _parse_memory_update(update: dict[str, Any], type: MemoryType) -> LabeledMemoryUpdate | ScoredMemoryUpdate:
+    if "memory_id" not in update:
+        raise ValueError("memory_id must be specified in the update dictionary")
+    memory_id = update["memory_id"]
+    if not isinstance(memory_id, str):
+        raise ValueError("memory_id must be a string")
+    payload: LabeledMemoryUpdate | ScoredMemoryUpdate = {"memory_id": memory_id}
+    if "value" in update:
+        if not isinstance(update["value"], str):
+            raise ValueError("value must be a string or unset")
+        payload["value"] = update["value"]
+    for key, value in _parse_memory_update_patch(update, type).items():
+        payload[key] = value
+    return payload
 
 
 class MemoryBase(ABC):
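The behavioral point of this refactor: a patch that never mentions metadata now omits the key entirely, so the server keeps existing metadata instead of having it overwritten with an empty dict. A hedged sketch of the three cases (assuming "tag" is not one of the DEFAULT_COLUMN_NAMES):

    _parse_memory_update_patch({"label": 2}, type="LABELED")
    # -> {"label": 2}; metadata omitted, server-side metadata preserved

    _parse_memory_update_patch({"label": 2, "metadata": {}}, type="LABELED")
    # -> {"label": 2, "metadata": {}}; explicit empty dict clears all metadata

    _parse_memory_update_patch({"label": 2, "tag": "happy"}, type="LABELED")
    # -> {"label": 2, "metadata": {"tag": "happy"}}; extra top-level keys still fold into metadata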
@@ -1817,7 +1873,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
     def from_pandas(
         cls,
         name: str,
-        dataframe: pd.DataFrame,
+        dataframe: PandasDataFrame,
         *,
         background: Literal[True],
         **kwargs: Any,
@@ -1829,7 +1885,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
     def from_pandas(
         cls,
         name: str,
-        dataframe: pd.DataFrame,
+        dataframe: PandasDataFrame,
         *,
         background: Literal[False] = False,
         **kwargs: Any,
@@ -1840,7 +1896,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
     def from_pandas(
         cls,
         name: str,
-        dataframe: pd.DataFrame,
+        dataframe: PandasDataFrame,
         *,
         background: bool = False,
         **kwargs: Any,
@@ -1883,7 +1939,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
     def from_arrow(
         cls,
         name: str,
-        pyarrow_table: pa.Table,
+        pyarrow_table: PyArrowTable,
         *,
         background: Literal[True],
         **kwargs: Any,
@@ -1895,7 +1951,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
     def from_arrow(
         cls,
         name: str,
-        pyarrow_table: pa.Table,
+        pyarrow_table: PyArrowTable,
         *,
         background: Literal[False] = False,
         **kwargs: Any,
@@ -1906,7 +1962,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
     def from_arrow(
         cls,
         name: str,
-        pyarrow_table: pa.Table,
+        pyarrow_table: PyArrowTable,
         *,
         background: bool = False,
         **kwargs: Any,
@@ -2090,7 +2146,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
     ]
 
     @classmethod
-    def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
+    def drop(cls, name_or_id: str, if_not_exists: DropMode = "error", cascade: bool = False):
         """
         Delete a memoryset from the OrcaCloud
 
@@ -2098,13 +2154,16 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
             name_or_id: Name or id of the memoryset
             if_not_exists: What to do if the memoryset does not exist, defaults to `"error"`.
                 Other options are `"ignore"` to do nothing if the memoryset does not exist.
+            cascade: If True, also delete all associated predictive models and predictions.
+                Defaults to False.
 
         Raises:
             LookupError: If the memoryset does not exist and if_not_exists is `"error"`
+            RuntimeError: If the memoryset has associated models and cascade is False
         """
         try:
             client = OrcaClient._resolve_client()
-            client.DELETE("/memoryset/{name_or_id}", params={"name_or_id": name_or_id})
+            client.DELETE("/memoryset/{name_or_id}", params={"name_or_id": name_or_id, "cascade": cascade})
             logging.info(f"Deleted memoryset {name_or_id}")
         except LookupError:
             if if_not_exists == "error":
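Hedged usage of the new cascade flag (the memoryset name is illustrative):

    # raises RuntimeError if classification or regression models still use the memoryset
    LabeledMemoryset.drop("my_memoryset")

    # deletes the memoryset together with its associated models and predictions
    LabeledMemoryset.drop("my_memoryset", cascade=True)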
@@ -2436,10 +2495,6 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
         filters: list[FilterItemTuple] = [],
         with_feedback_metrics: bool = False,
         sort: list[TelemetrySortItem] | None = None,
-        partition_id: str | None = None,
-        partition_filter_mode: Literal[
-            "ignore_partitions", "include_global", "exclude_global", "only_global"
-        ] = "include_global",
     ) -> list[MemoryT]:
         """
         Query the memoryset for memories that match the filters
@@ -2460,26 +2515,16 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
                 LabeledMemory({ label: <negative: 0>, value: "I am sad" }),
             ]
         """
-        parsed_filters = [
-            _parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
-        ]
 
+        client = OrcaClient._resolve_client()
         if with_feedback_metrics:
-            if partition_id:
-                raise ValueError("Partition ID is not supported when with_feedback_metrics is True")
-            if partition_filter_mode != "include_global":
-                raise ValueError(
-                    f"Partition filter mode {partition_filter_mode} is not supported when with_feedback_metrics is True. Only 'include_global' is supported."
-                )
-
-            client = OrcaClient._resolve_client()
            response = client.POST(
                 "/telemetry/memories",
                 json={
                     "memoryset_id": self.id,
                     "offset": offset,
                     "limit": limit,
-                    "filters": parsed_filters,
+                    "filters": [_parse_filter_item_from_tuple(filter) for filter in filters],
                     "sort": [_parse_sort_item_from_tuple(item) for item in sort] if sort else None,
                 },
             )
@@ -2497,16 +2542,13 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
        if sort:
             logging.warning("Sorting is not supported when with_feedback_metrics is False. Sort value will be ignored.")
 
-        client = OrcaClient._resolve_client()
         response = client.POST(
             "/memoryset/{name_or_id}/memories",
             params={"name_or_id": self.id},
             json={
                 "offset": offset,
                 "limit": limit,
-                "filters": cast(list[FilterItem], parsed_filters),
-                "partition_id": partition_id,
-                "partition_filter_mode": partition_filter_mode,
+                "filters": [_parse_filter_item_from_tuple(filter, allow_metric_fields=False) for filter in filters],
             },
         )
         return [
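Net effect on query: the partition parameters are gone, and the non-telemetry path parses filters with allow_metric_fields=False, so metric filters are rejected client-side instead of reaching the API. A hedged sketch (the metric name is illustrative):

    # fine: the telemetry endpoint supports metric fields
    memoryset.query(
        filters=[("feedback_metrics.accuracy", ">", 0.5)], with_feedback_metrics=True
    )

    # now raises ValueError before any request is sent
    memoryset.query(filters=[("feedback_metrics.accuracy", ">", 0.5)])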
@@ -2524,11 +2566,16 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
         filters: list[FilterItemTuple] = [],
         with_feedback_metrics: bool = False,
         sort: list[TelemetrySortItem] | None = None,
-    ) -> pd.DataFrame:
+    ) -> PandasDataFrame:
         """
         Convert the memoryset to a pandas DataFrame
         """
-        return pd.DataFrame(
+        try:
+            from pandas import DataFrame as PandasDataFrame  # type: ignore
+        except ImportError:
+            raise ImportError("Install pandas to use this method")
+
+        return PandasDataFrame(
             [
                 memory.to_dict()
                 for memory in self.query(
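to_pandas shows the runtime half of the peer-dependency change: pandas is imported lazily inside the method, so it only has to be installed when the method is actually called. A hedged sketch:

    import orca_sdk.memoryset  # imports cleanly even without pandas, pyarrow, or torch

    df = memoryset.to_pandas()
    # without pandas installed this raises
    # ImportError: Install pandas to use this method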
@@ -2699,18 +2746,28 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
         ]
 
     @overload
-    def update(self, updates: dict[str, Any], *, batch_size: int = 32) -> MemoryT:
+    def update(self, updates: dict[str, Any] | Iterable[dict[str, Any]], *, batch_size: int = 32) -> int:
         pass
 
     @overload
-    def update(self, updates: Iterable[dict[str, Any]], *, batch_size: int = 32) -> list[MemoryT]:
+    def update(
+        self,
+        *,
+        filters: list[FilterItemTuple],
+        patch: dict[str, Any],
+    ) -> int:
         pass
 
     def update(
-        self, updates: dict[str, Any] | Iterable[dict[str, Any]], *, batch_size: int = 32
-    ) -> MemoryT | list[MemoryT]:
+        self,
+        updates: dict[str, Any] | Iterable[dict[str, Any]] | None = None,
+        *,
+        batch_size: int = 32,
+        filters: list[FilterItemTuple] | None = None,
+        patch: dict[str, Any] | None = None,
+    ) -> int:
         """
-        Update one or multiple memories in the memoryset
+        Update one or multiple memories in the memoryset.
 
         Params:
             updates: List of updates to apply to the memories. Each update should be a dictionary
@@ -2723,10 +2780,12 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
             - `partition_id`: Optional new partition ID of the memory
             - `...`: Optional new values for metadata properties
 
-            batch_size: Number of memories to update in a single API call
+            filters: Filters to match memories against. Each filter is a tuple of (field, operation, value).
+            patch: Patch to apply to matching memories (only used with filters).
+            batch_size: Number of memories to update in a single API call (only used with updates)
 
         Returns:
-            Updated memory or list of updated memories
+            The number of memories updated.
 
         Examples:
             Update a single memory:
@@ -2742,32 +2801,57 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
             ...     {"memory_id": m.memory_id, "label": 2}
             ...     for m in memoryset.query(filters=[("tag", "==", "happy")])
             ... )
+
+            Update all memories matching a filter:
+            >>> memoryset.update(filters=[("label", "==", 0)], patch={"label": 1})
         """
         if batch_size <= 0 or batch_size > 500:
             raise ValueError("batch_size must be between 1 and 500")
+
         client = OrcaClient._resolve_client()
-        updates_list = cast(list[dict[str, Any]], [updates]) if isinstance(updates, dict) else list(updates)
-        # update memories in batches to avoid API timeouts
-        updated_memories: list[MemoryT] = []
-        for i in range(0, len(updates_list), batch_size):
-            batch = updates_list[i : i + batch_size]
-            response = client.PATCH(
-                "/gpu/memoryset/{name_or_id}/memories",
-                params={"name_or_id": self.id},
-                json=cast(
-                    list[LabeledMemoryUpdate] | list[ScoredMemoryUpdate],
-                    [_parse_memory_update(update, type=self.memory_type) for update in batch],
-                ),
-            )
-            updated_memories.extend(
-                cast(
-                    MemoryT,
-                    (LabeledMemory(self.id, memory) if "label" in memory else ScoredMemory(self.id, memory)),
+
+        # Convert updates to list
+        single_update = isinstance(updates, dict)
+        updates_list: list[dict[str, Any]] | None
+        if single_update:
+            updates_list = [updates]  # type: ignore[list-item]
+        elif updates is not None:
+            updates_list = [u for u in updates]  # type: ignore[misc]
+        else:
+            updates_list = None
+
+        # Batch updates to avoid API timeouts
+        if updates_list and len(updates_list) > batch_size:
+            updated_count: int = 0
+            for i in range(0, len(updates_list), batch_size):
+                batch = updates_list[i : i + batch_size]
+                response = client.PATCH(
+                    "/gpu/memoryset/{name_or_id}/memories",
+                    params={"name_or_id": self.id},
+                    json={"updates": [_parse_memory_update(update, type=self.memory_type) for update in batch]},
                 )
-            for memory in response
-            )
+                updated_count += response["updated_count"]
+            return updated_count
 
-        return updated_memories[0] if isinstance(updates, dict) else updated_memories
+        # Single request for all other cases
+        response = client.PATCH(
+            "/gpu/memoryset/{name_or_id}/memories",
+            params={"name_or_id": self.id},
+            json={
+                "updates": (
+                    [_parse_memory_update(update, type=self.memory_type) for update in updates_list]
+                    if updates_list is not None
+                    else None
+                ),
+                "filters": (
+                    [_parse_filter_item_from_tuple(filter, allow_metric_fields=False) for filter in filters]
+                    if filters is not None
+                    else None
+                ),
+                "patch": _parse_memory_update_patch(patch, type=self.memory_type) if patch is not None else None,
+            },
+        )
+        return response["updated_count"]
 
     def get_cascading_edits_suggestions(
         self,
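Note the breaking change: update() now returns the number of memories updated for every form of the call, not the updated memory objects. A hedged migration sketch:

    # 0.1.9: memory = memoryset.update({"memory_id": mid, "label": 2})
    # 0.1.11: an int comes back instead
    count = memoryset.update({"memory_id": mid, "label": 2})  # -> 1

    # new bulk form: patch everything matching a filter, no ids needed
    count = memoryset.update(filters=[("tag", "==", "happy")], patch={"label": 2})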
@@ -2826,37 +2910,128 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
             },
         )
 
-    def delete(self, memory_id: str | Iterable[str], *, batch_size: int = 32) -> None:
+    @overload
+    def delete(self, memory_id: str | Iterable[str], *, batch_size: int = 32) -> int:
+        pass
+
+    @overload
+    def delete(
+        self,
+        *,
+        filters: list[FilterItemTuple],
+    ) -> int:
+        pass
+
+    def delete(
+        self,
+        memory_id: str | Iterable[str] | None = None,
+        *,
+        batch_size: int = 32,
+        filters: list[FilterItemTuple] | None = None,
+    ) -> int:
         """
-        Delete memories from the memoryset
+        Delete memories from the memoryset.
+
 
         Params:
             memory_id: unique identifiers of the memories to delete
-            batch_size: Number of memories to delete in a single API call
+            filters: Filters to match memories against. Each filter is a tuple of (field, operation, value).
+            batch_size: Number of memories to delete in a single API call (only used with memory_id)
+
+        Returns:
+            The number of memories deleted.
 
         Examples:
-            Delete a single memory:
+            Delete a single memory by ID:
             >>> memoryset.delete("0195019a-5bc7-7afb-b902-5945ee1fb766")
 
-            Delete multiple memories:
+            Delete multiple memories by ID:
             >>> memoryset.delete([
             ...     "0195019a-5bc7-7afb-b902-5945ee1fb766",
             ...     "019501a1-ea08-76b2-9f62-95e4800b4841",
-            ... )
+            ... ])
+
+            Delete all memories matching a filter:
+            >>> deleted_count = memoryset.delete(filters=[("label", "==", 0)])
 
         """
         if batch_size <= 0 or batch_size > 500:
             raise ValueError("batch_size must be between 1 and 500")
+        if memory_id is not None and filters is not None:
+            raise ValueError("Cannot specify memory_ids together with filters")
+
         client = OrcaClient._resolve_client()
-        memory_ids = [memory_id] if isinstance(memory_id, str) else list(memory_id)
-        # delete memories in batches to avoid API timeouts
-        for i in range(0, len(memory_ids), batch_size):
-            batch = memory_ids[i : i + batch_size]
-            client.POST(
-                "/memoryset/{name_or_id}/memories/delete", params={"name_or_id": self.id}, json={"memory_ids": batch}
-            )
-        logging.info(f"Deleted {len(memory_ids)} memories from memoryset.")
-        self.refresh()
+
+        # Convert memory_id to list
+        if isinstance(memory_id, str):
+            memory_ids = [memory_id]
+        elif memory_id is not None:
+            memory_ids = list(memory_id)
+        else:
+            memory_ids = None
+
+        # Batch memory_id deletions to avoid API timeouts
+        if memory_ids and len(memory_ids) > batch_size:
+            total_deleted = 0
+            for i in range(0, len(memory_ids), batch_size):
+                batch = memory_ids[i : i + batch_size]
+                response = client.POST(
+                    "/memoryset/{name_or_id}/memories/delete",
+                    params={"name_or_id": self.id},
+                    json={"memory_ids": batch},
+                )
+                total_deleted += response.get("deleted_count", 0)
+            if total_deleted > 0:
+                logging.info(f"Deleted {total_deleted} memories from memoryset.")
+                self.refresh()
+            return total_deleted
+
+        # Single request for all other cases
+        response = client.POST(
+            "/memoryset/{name_or_id}/memories/delete",
+            params={"name_or_id": self.id},
+            json={
+                "memory_ids": memory_ids,
+                "filters": (
+                    [_parse_filter_item_from_tuple(filter, allow_metric_fields=False) for filter in filters]
+                    if filters is not None
+                    else None
+                ),
+            },
+        )
+        deleted_count = response["deleted_count"]
+        logging.info(f"Deleted {deleted_count} memories from memoryset.")
+        if deleted_count > 0:
+            self.refresh()
+        return deleted_count
+
+    def truncate(self, *, partition_id: str | None = UNSET) -> int:
+        """
+        Delete all memories from the memoryset or a specified partition.
+
+        Params:
+            partition_id: Optional partition ID to truncate, `None` refers to the global partition.
+
+        Returns:
+            The number of deleted memories.
+        """
+        client = OrcaClient._resolve_client()
+        response = client.POST(
+            "/memoryset/{name_or_id}/memories/delete",
+            params={"name_or_id": self.id},
+            json={
+                "filters": (
+                    [FilterItem(field=("partition_id",), op="==", value=partition_id)]
+                    if partition_id is not UNSET
+                    else [FilterItem(field=("memory_id",), op="!=", value=None)]  # match all
+                ),
+            },
+        )
+        deleted_count = response["deleted_count"]
+        logging.info(f"Deleted {deleted_count} memories from memoryset.")
+        if deleted_count > 0:
+            self.refresh()
+        return deleted_count
 
     @overload
     def analyze(
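delete() gains the same filter form and count return, and truncate() is a thin wrapper over the same delete endpoint using either a partition filter or a match-all filter. A hedged usage sketch (the partition id is illustrative):

    memoryset.delete(filters=[("label", "==", 0)])  # -> number of deleted memories

    memoryset.truncate(partition_id="tenant_a")  # clear one partition
    memoryset.truncate(partition_id=None)        # clear the global partition
    memoryset.truncate()                         # clear the entire memoryset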
@@ -3003,10 +3178,21 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
         job = Job(analysis["job_id"], get_analysis_result)
         return job if background else job.result()
 
-    def get_potential_duplicate_groups(self) -> list[list[MemoryT]]:
-        """Group potential duplicates in the memoryset"""
+    def get_potential_duplicate_groups(self) -> list[list[MemoryT]] | None:
+        """
+        Group potential duplicates in the memoryset.
+
+        Returns:
+            List of groups of potentially duplicate memories, where each group is a list of memories.
+            Returns None if duplicate analysis has not been run on this memoryset yet.
+
+        Raises:
+            LookupError: If the memoryset does not exist.
+        """
         client = OrcaClient._resolve_client()
         response = client.GET("/memoryset/{name_or_id}/potential_duplicate_groups", params={"name_or_id": self.id})
+        if response is None:
+            return None
         return [
             [cast(MemoryT, LabeledMemory(self.id, m) if "label" in m else ScoredMemory(self.id, m)) for m in ms]
             for ms in response
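Callers should now handle None, which signals that duplicate analysis has not been run yet. A hedged sketch (analyze arguments omitted; see its overloads above):

    groups = memoryset.get_potential_duplicate_groups()
    if groups is None:
        memoryset.analyze()  # run an analysis first, then re-query
        groups = memoryset.get_potential_duplicate_groups()
    for group in groups or []:
        print([m.memory_id for m in group])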
@@ -3434,6 +3620,22 @@ class LabeledMemoryset(MemorysetBase[LabeledMemory, LabeledMemoryLookup]):
 
         display_suggested_memory_relabels(self)
 
+    @property
+    def classification_models(self) -> list[ClassificationModel]:
+        """
+        List all classification models that use this memoryset
+
+        Returns:
+            List of classification models associated with this memoryset
+        """
+        from .classification_model import ClassificationModel
+
+        client = OrcaClient._resolve_client()
+        return [
+            ClassificationModel(metadata)
+            for metadata in client.GET("/classification_model", params={"memoryset_name_or_id": str(self.id)})
+        ]
+
 
 class ScoredMemoryset(MemorysetBase[ScoredMemory, ScoredMemoryLookup]):
     """
@@ -3809,3 +4011,19 @@ class ScoredMemoryset(MemorysetBase[ScoredMemory, ScoredMemoryLookup]):
             subsample=subsample,
             memory_type="SCORED",
         )
+
+    @property
+    def regression_models(self) -> list[RegressionModel]:
+        """
+        List all regression models that use this memoryset
+
+        Returns:
+            List of regression models associated with this memoryset
+        """
+        from .regression_model import RegressionModel
+
+        client = OrcaClient._resolve_client()
+        return [
+            RegressionModel(metadata)
+            for metadata in client.GET("/regression_model", params={"memoryset_name_or_id": str(self.id)})
+        ]
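The two properties are symmetric and pair naturally with drop(cascade=...): they let you list the models a non-cascading drop would trip over before deleting anything. A hedged sketch:

    for model in memoryset.classification_models:  # LabeledMemoryset
        print(model)
    # ScoredMemoryset exposes the same idea via memoryset.regression_models

    LabeledMemoryset.drop(memoryset.id, cascade=True)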