orca-sdk 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- orca_sdk/_utils/analysis_ui.py +4 -1
- orca_sdk/_utils/data_parsing.py +11 -3
- orca_sdk/_utils/data_parsing_disk_test.py +91 -0
- orca_sdk/_utils/{data_parsing_test.py → data_parsing_torch_test.py} +58 -143
- orca_sdk/_utils/prediction_result_ui.py +4 -1
- orca_sdk/_utils/value_parser.py +44 -17
- orca_sdk/_utils/value_parser_test.py +6 -5
- orca_sdk/async_client.py +78 -18
- orca_sdk/classification_model.py +1 -1
- orca_sdk/classification_model_test.py +69 -22
- orca_sdk/client.py +78 -16
- orca_sdk/conftest.py +87 -7
- orca_sdk/credentials.py +8 -10
- orca_sdk/credentials_test.py +5 -8
- orca_sdk/datasource.py +13 -8
- orca_sdk/datasource_test.py +8 -2
- orca_sdk/embedding_model.py +7 -2
- orca_sdk/embedding_model_test.py +29 -0
- orca_sdk/memoryset.py +325 -107
- orca_sdk/memoryset_test.py +87 -178
- orca_sdk/regression_model.py +1 -1
- orca_sdk/regression_model_test.py +44 -0
- orca_sdk/telemetry.py +1 -1
- {orca_sdk-0.1.9.dist-info → orca_sdk-0.1.11.dist-info}/METADATA +3 -5
- orca_sdk-0.1.11.dist-info/RECORD +42 -0
- orca_sdk-0.1.9.dist-info/RECORD +0 -41
- {orca_sdk-0.1.9.dist-info → orca_sdk-0.1.11.dist-info}/WHEEL +0 -0
orca_sdk/memoryset.py
CHANGED
|
@@ -16,11 +16,7 @@ from typing import (
|
|
|
16
16
|
overload,
|
|
17
17
|
)
|
|
18
18
|
|
|
19
|
-
import pandas as pd
|
|
20
|
-
import pyarrow as pa
|
|
21
19
|
from datasets import Dataset
|
|
22
|
-
from torch.utils.data import DataLoader as TorchDataLoader
|
|
23
|
-
from torch.utils.data import Dataset as TorchDataset
|
|
24
20
|
|
|
25
21
|
from ._utils.common import UNSET, CreateMode, DropMode
|
|
26
22
|
from .async_client import OrcaAsyncClient
|
|
@@ -30,6 +26,7 @@ from .client import (
|
|
|
30
26
|
CreateMemorysetFromDatasourceRequest,
|
|
31
27
|
CreateMemorysetRequest,
|
|
32
28
|
FilterItem,
|
|
29
|
+
LabeledBatchMemoryUpdatePatch,
|
|
33
30
|
)
|
|
34
31
|
from .client import LabeledMemory as LabeledMemoryResponse
|
|
35
32
|
from .client import (
|
|
@@ -49,6 +46,7 @@ from .client import (
|
|
|
49
46
|
MemorysetUpdate,
|
|
50
47
|
MemoryType,
|
|
51
48
|
OrcaClient,
|
|
49
|
+
ScoredBatchMemoryUpdatePatch,
|
|
52
50
|
)
|
|
53
51
|
from .client import ScoredMemory as ScoredMemoryResponse
|
|
54
52
|
from .client import (
|
|
@@ -74,6 +72,12 @@ from .job import Job, Status
|
|
|
74
72
|
from .telemetry import ClassificationPrediction, RegressionPrediction
|
|
75
73
|
|
|
76
74
|
if TYPE_CHECKING:
|
|
75
|
+
# peer dependencies that are used for types only
|
|
76
|
+
from pandas import DataFrame as PandasDataFrame # type: ignore
|
|
77
|
+
from pyarrow import Table as PyArrowTable # type: ignore
|
|
78
|
+
from torch.utils.data import DataLoader as TorchDataLoader # type: ignore
|
|
79
|
+
from torch.utils.data import Dataset as TorchDataset # type: ignore
|
|
80
|
+
|
|
77
81
|
from .classification_model import ClassificationModel
|
|
78
82
|
from .regression_model import RegressionModel
|
|
79
83
|
|
|
@@ -94,7 +98,21 @@ FilterOperation = Literal["==", "!=", ">", ">=", "<", "<=", "in", "not in", "lik
|
|
|
94
98
|
Operations that can be used in a filter expression.
|
|
95
99
|
"""
|
|
96
100
|
|
|
97
|
-
FilterValue =
|
|
101
|
+
FilterValue = (
|
|
102
|
+
str
|
|
103
|
+
| int
|
|
104
|
+
| float
|
|
105
|
+
| bool
|
|
106
|
+
| datetime
|
|
107
|
+
| list[None]
|
|
108
|
+
| list[str]
|
|
109
|
+
| list[str | None]
|
|
110
|
+
| list[int]
|
|
111
|
+
| list[int | None]
|
|
112
|
+
| list[float]
|
|
113
|
+
| list[bool]
|
|
114
|
+
| None
|
|
115
|
+
)
|
|
98
116
|
"""
|
|
99
117
|
Values that can be used in a filter expression.
|
|
100
118
|
"""
|
|
@@ -134,7 +152,21 @@ def _is_metric_column(column: str):
|
|
|
134
152
|
return column in ["feedback_metrics", "lookup"]
|
|
135
153
|
|
|
136
154
|
|
|
137
|
-
|
|
155
|
+
@overload
|
|
156
|
+
def _parse_filter_item_from_tuple(input: FilterItemTuple, allow_metric_fields: Literal[False]) -> FilterItem:
|
|
157
|
+
pass
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
@overload
|
|
161
|
+
def _parse_filter_item_from_tuple(
|
|
162
|
+
input: FilterItemTuple, allow_metric_fields: Literal[True] = True
|
|
163
|
+
) -> FilterItem | TelemetryFilterItem:
|
|
164
|
+
pass
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def _parse_filter_item_from_tuple(
|
|
168
|
+
input: FilterItemTuple, allow_metric_fields: bool = True
|
|
169
|
+
) -> FilterItem | TelemetryFilterItem:
|
|
138
170
|
field = input[0].split(".")
|
|
139
171
|
if (
|
|
140
172
|
len(field) == 1
|
|
@@ -146,6 +178,8 @@ def _parse_filter_item_from_tuple(input: FilterItemTuple) -> FilterItem | Teleme
|
|
|
146
178
|
if isinstance(value, datetime):
|
|
147
179
|
value = value.isoformat()
|
|
148
180
|
if _is_metric_column(field[0]):
|
|
181
|
+
if not allow_metric_fields:
|
|
182
|
+
raise ValueError(f"Cannot filter on {field[0]} - metric fields are not supported")
|
|
149
183
|
if not (
|
|
150
184
|
(isinstance(value, list) and all(isinstance(v, float) or isinstance(v, int) for v in value))
|
|
151
185
|
or isinstance(value, float)
|
|
@@ -165,7 +199,7 @@ def _parse_filter_item_from_tuple(input: FilterItemTuple) -> FilterItem | Teleme
|
|
|
165
199
|
return TelemetryFilterItem(field=cast(TelemetryField, tuple(field)), op=op, value=value)
|
|
166
200
|
|
|
167
201
|
# Convert list to tuple for FilterItem field type
|
|
168
|
-
return FilterItem(field=tuple(field), op=op, value=value)
|
|
202
|
+
return FilterItem(field=tuple[Any, ...](field), op=op, value=value)
|
|
169
203
|
|
|
170
204
|
|
|
171
205
|
def _parse_sort_item_from_tuple(
|
|
@@ -238,17 +272,29 @@ def _parse_memory_insert(memory: dict[str, Any], type: MemoryType) -> LabeledMem
|
|
|
238
272
|
}
|
|
239
273
|
|
|
240
274
|
|
|
241
|
-
def
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
275
|
+
def _extract_metadata_for_patch(update: dict[str, Any], exclude_keys: set[str]) -> dict[str, Any] | None:
|
|
276
|
+
"""Extract metadata from update dict for patch operations.
|
|
277
|
+
|
|
278
|
+
Returns the metadata dict to include in the payload, or None if metadata should be omitted
|
|
279
|
+
(to preserve existing metadata on the server).
|
|
280
|
+
"""
|
|
281
|
+
if "metadata" in update and update["metadata"] is not None:
|
|
282
|
+
# User explicitly provided metadata dict (could be {} to clear all metadata)
|
|
283
|
+
metadata = update["metadata"]
|
|
284
|
+
if not isinstance(metadata, dict):
|
|
285
|
+
raise ValueError("metadata must be a dict")
|
|
286
|
+
return metadata
|
|
287
|
+
# Extract metadata from top-level keys, only include if non-empty
|
|
288
|
+
metadata = {k: v for k, v in update.items() if k not in DEFAULT_COLUMN_NAMES | exclude_keys}
|
|
289
|
+
if any(k in metadata for k in FORBIDDEN_METADATA_COLUMN_NAMES):
|
|
290
|
+
raise ValueError(f"Cannot update the following metadata keys: {', '.join(FORBIDDEN_METADATA_COLUMN_NAMES)}")
|
|
291
|
+
return metadata if metadata else None
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def _parse_memory_update_patch(
|
|
295
|
+
update: dict[str, Any], type: MemoryType
|
|
296
|
+
) -> LabeledBatchMemoryUpdatePatch | ScoredBatchMemoryUpdatePatch:
|
|
297
|
+
payload: LabeledBatchMemoryUpdatePatch | ScoredBatchMemoryUpdatePatch = {}
|
|
252
298
|
if "source_id" in update:
|
|
253
299
|
source_id = update["source_id"]
|
|
254
300
|
if source_id is not None and not isinstance(source_id, str):
|
|
@@ -261,31 +307,41 @@ def _parse_memory_update(update: dict[str, Any], type: MemoryType) -> LabeledMem
|
|
|
261
307
|
payload["partition_id"] = partition_id
|
|
262
308
|
match type:
|
|
263
309
|
case "LABELED":
|
|
264
|
-
payload = cast(
|
|
310
|
+
payload = cast(LabeledBatchMemoryUpdatePatch, payload)
|
|
265
311
|
if "label" in update:
|
|
266
312
|
if not isinstance(update["label"], int):
|
|
267
313
|
raise ValueError("label must be an integer or unset")
|
|
268
314
|
payload["label"] = update["label"]
|
|
269
|
-
metadata =
|
|
270
|
-
if
|
|
271
|
-
|
|
272
|
-
f"Cannot update the following metadata keys: {', '.join(FORBIDDEN_METADATA_COLUMN_NAMES)}"
|
|
273
|
-
)
|
|
274
|
-
payload["metadata"] = metadata
|
|
315
|
+
metadata = _extract_metadata_for_patch(update, {"memory_id", "label", "metadata"})
|
|
316
|
+
if metadata is not None:
|
|
317
|
+
payload["metadata"] = metadata
|
|
275
318
|
return payload
|
|
276
319
|
case "SCORED":
|
|
277
|
-
payload = cast(
|
|
320
|
+
payload = cast(ScoredBatchMemoryUpdatePatch, payload)
|
|
278
321
|
if "score" in update:
|
|
279
322
|
if not isinstance(update["score"], (int, float)):
|
|
280
323
|
raise ValueError("score must be a number or unset")
|
|
281
324
|
payload["score"] = update["score"]
|
|
282
|
-
metadata =
|
|
283
|
-
if
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
325
|
+
metadata = _extract_metadata_for_patch(update, {"memory_id", "score", "metadata"})
|
|
326
|
+
if metadata is not None:
|
|
327
|
+
payload["metadata"] = metadata
|
|
328
|
+
return payload
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def _parse_memory_update(update: dict[str, Any], type: MemoryType) -> LabeledMemoryUpdate | ScoredMemoryUpdate:
|
|
332
|
+
if "memory_id" not in update:
|
|
333
|
+
raise ValueError("memory_id must be specified in the update dictionary")
|
|
334
|
+
memory_id = update["memory_id"]
|
|
335
|
+
if not isinstance(memory_id, str):
|
|
336
|
+
raise ValueError("memory_id must be a string")
|
|
337
|
+
payload: LabeledMemoryUpdate | ScoredMemoryUpdate = {"memory_id": memory_id}
|
|
338
|
+
if "value" in update:
|
|
339
|
+
if not isinstance(update["value"], str):
|
|
340
|
+
raise ValueError("value must be a string or unset")
|
|
341
|
+
payload["value"] = update["value"]
|
|
342
|
+
for key, value in _parse_memory_update_patch(update, type).items():
|
|
343
|
+
payload[key] = value
|
|
344
|
+
return payload
|
|
289
345
|
|
|
290
346
|
|
|
291
347
|
class MemoryBase(ABC):
|
|
@@ -1817,7 +1873,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
1817
1873
|
def from_pandas(
|
|
1818
1874
|
cls,
|
|
1819
1875
|
name: str,
|
|
1820
|
-
dataframe:
|
|
1876
|
+
dataframe: PandasDataFrame,
|
|
1821
1877
|
*,
|
|
1822
1878
|
background: Literal[True],
|
|
1823
1879
|
**kwargs: Any,
|
|
@@ -1829,7 +1885,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
1829
1885
|
def from_pandas(
|
|
1830
1886
|
cls,
|
|
1831
1887
|
name: str,
|
|
1832
|
-
dataframe:
|
|
1888
|
+
dataframe: PandasDataFrame,
|
|
1833
1889
|
*,
|
|
1834
1890
|
background: Literal[False] = False,
|
|
1835
1891
|
**kwargs: Any,
|
|
@@ -1840,7 +1896,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
1840
1896
|
def from_pandas(
|
|
1841
1897
|
cls,
|
|
1842
1898
|
name: str,
|
|
1843
|
-
dataframe:
|
|
1899
|
+
dataframe: PandasDataFrame,
|
|
1844
1900
|
*,
|
|
1845
1901
|
background: bool = False,
|
|
1846
1902
|
**kwargs: Any,
|
|
@@ -1883,7 +1939,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
1883
1939
|
def from_arrow(
|
|
1884
1940
|
cls,
|
|
1885
1941
|
name: str,
|
|
1886
|
-
pyarrow_table:
|
|
1942
|
+
pyarrow_table: PyArrowTable,
|
|
1887
1943
|
*,
|
|
1888
1944
|
background: Literal[True],
|
|
1889
1945
|
**kwargs: Any,
|
|
@@ -1895,7 +1951,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
1895
1951
|
def from_arrow(
|
|
1896
1952
|
cls,
|
|
1897
1953
|
name: str,
|
|
1898
|
-
pyarrow_table:
|
|
1954
|
+
pyarrow_table: PyArrowTable,
|
|
1899
1955
|
*,
|
|
1900
1956
|
background: Literal[False] = False,
|
|
1901
1957
|
**kwargs: Any,
|
|
@@ -1906,7 +1962,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
1906
1962
|
def from_arrow(
|
|
1907
1963
|
cls,
|
|
1908
1964
|
name: str,
|
|
1909
|
-
pyarrow_table:
|
|
1965
|
+
pyarrow_table: PyArrowTable,
|
|
1910
1966
|
*,
|
|
1911
1967
|
background: bool = False,
|
|
1912
1968
|
**kwargs: Any,
|
|
@@ -2090,7 +2146,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
2090
2146
|
]
|
|
2091
2147
|
|
|
2092
2148
|
@classmethod
|
|
2093
|
-
def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
|
|
2149
|
+
def drop(cls, name_or_id: str, if_not_exists: DropMode = "error", cascade: bool = False):
|
|
2094
2150
|
"""
|
|
2095
2151
|
Delete a memoryset from the OrcaCloud
|
|
2096
2152
|
|
|
@@ -2098,13 +2154,16 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
2098
2154
|
name_or_id: Name or id of the memoryset
|
|
2099
2155
|
if_not_exists: What to do if the memoryset does not exist, defaults to `"error"`.
|
|
2100
2156
|
Other options are `"ignore"` to do nothing if the memoryset does not exist.
|
|
2157
|
+
cascade: If True, also delete all associated predictive models and predictions.
|
|
2158
|
+
Defaults to False.
|
|
2101
2159
|
|
|
2102
2160
|
Raises:
|
|
2103
2161
|
LookupError: If the memoryset does not exist and if_not_exists is `"error"`
|
|
2162
|
+
RuntimeError: If the memoryset has associated models and cascade is False
|
|
2104
2163
|
"""
|
|
2105
2164
|
try:
|
|
2106
2165
|
client = OrcaClient._resolve_client()
|
|
2107
|
-
client.DELETE("/memoryset/{name_or_id}", params={"name_or_id": name_or_id})
|
|
2166
|
+
client.DELETE("/memoryset/{name_or_id}", params={"name_or_id": name_or_id, "cascade": cascade})
|
|
2108
2167
|
logging.info(f"Deleted memoryset {name_or_id}")
|
|
2109
2168
|
except LookupError:
|
|
2110
2169
|
if if_not_exists == "error":
|
|
@@ -2436,10 +2495,6 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
2436
2495
|
filters: list[FilterItemTuple] = [],
|
|
2437
2496
|
with_feedback_metrics: bool = False,
|
|
2438
2497
|
sort: list[TelemetrySortItem] | None = None,
|
|
2439
|
-
partition_id: str | None = None,
|
|
2440
|
-
partition_filter_mode: Literal[
|
|
2441
|
-
"ignore_partitions", "include_global", "exclude_global", "only_global"
|
|
2442
|
-
] = "include_global",
|
|
2443
2498
|
) -> list[MemoryT]:
|
|
2444
2499
|
"""
|
|
2445
2500
|
Query the memoryset for memories that match the filters
|
|
@@ -2460,26 +2515,16 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
2460
2515
|
LabeledMemory({ label: <negative: 0>, value: "I am sad" }),
|
|
2461
2516
|
]
|
|
2462
2517
|
"""
|
|
2463
|
-
parsed_filters = [
|
|
2464
|
-
_parse_filter_item_from_tuple(filter) if isinstance(filter, tuple) else filter for filter in filters
|
|
2465
|
-
]
|
|
2466
2518
|
|
|
2519
|
+
client = OrcaClient._resolve_client()
|
|
2467
2520
|
if with_feedback_metrics:
|
|
2468
|
-
if partition_id:
|
|
2469
|
-
raise ValueError("Partition ID is not supported when with_feedback_metrics is True")
|
|
2470
|
-
if partition_filter_mode != "include_global":
|
|
2471
|
-
raise ValueError(
|
|
2472
|
-
f"Partition filter mode {partition_filter_mode} is not supported when with_feedback_metrics is True. Only 'include_global' is supported."
|
|
2473
|
-
)
|
|
2474
|
-
|
|
2475
|
-
client = OrcaClient._resolve_client()
|
|
2476
2521
|
response = client.POST(
|
|
2477
2522
|
"/telemetry/memories",
|
|
2478
2523
|
json={
|
|
2479
2524
|
"memoryset_id": self.id,
|
|
2480
2525
|
"offset": offset,
|
|
2481
2526
|
"limit": limit,
|
|
2482
|
-
"filters":
|
|
2527
|
+
"filters": [_parse_filter_item_from_tuple(filter) for filter in filters],
|
|
2483
2528
|
"sort": [_parse_sort_item_from_tuple(item) for item in sort] if sort else None,
|
|
2484
2529
|
},
|
|
2485
2530
|
)
|
|
@@ -2497,16 +2542,13 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
2497
2542
|
if sort:
|
|
2498
2543
|
logging.warning("Sorting is not supported when with_feedback_metrics is False. Sort value will be ignored.")
|
|
2499
2544
|
|
|
2500
|
-
client = OrcaClient._resolve_client()
|
|
2501
2545
|
response = client.POST(
|
|
2502
2546
|
"/memoryset/{name_or_id}/memories",
|
|
2503
2547
|
params={"name_or_id": self.id},
|
|
2504
2548
|
json={
|
|
2505
2549
|
"offset": offset,
|
|
2506
2550
|
"limit": limit,
|
|
2507
|
-
"filters":
|
|
2508
|
-
"partition_id": partition_id,
|
|
2509
|
-
"partition_filter_mode": partition_filter_mode,
|
|
2551
|
+
"filters": [_parse_filter_item_from_tuple(filter, allow_metric_fields=False) for filter in filters],
|
|
2510
2552
|
},
|
|
2511
2553
|
)
|
|
2512
2554
|
return [
|
|
@@ -2524,11 +2566,16 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
2524
2566
|
filters: list[FilterItemTuple] = [],
|
|
2525
2567
|
with_feedback_metrics: bool = False,
|
|
2526
2568
|
sort: list[TelemetrySortItem] | None = None,
|
|
2527
|
-
) ->
|
|
2569
|
+
) -> PandasDataFrame:
|
|
2528
2570
|
"""
|
|
2529
2571
|
Convert the memoryset to a pandas DataFrame
|
|
2530
2572
|
"""
|
|
2531
|
-
|
|
2573
|
+
try:
|
|
2574
|
+
from pandas import DataFrame as PandasDataFrame # type: ignore
|
|
2575
|
+
except ImportError:
|
|
2576
|
+
raise ImportError("Install pandas to use this method")
|
|
2577
|
+
|
|
2578
|
+
return PandasDataFrame(
|
|
2532
2579
|
[
|
|
2533
2580
|
memory.to_dict()
|
|
2534
2581
|
for memory in self.query(
|
|
@@ -2699,18 +2746,28 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
2699
2746
|
]
|
|
2700
2747
|
|
|
2701
2748
|
@overload
|
|
2702
|
-
def update(self, updates: dict[str, Any], *, batch_size: int = 32) ->
|
|
2749
|
+
def update(self, updates: dict[str, Any] | Iterable[dict[str, Any]], *, batch_size: int = 32) -> int:
|
|
2703
2750
|
pass
|
|
2704
2751
|
|
|
2705
2752
|
@overload
|
|
2706
|
-
def update(
|
|
2753
|
+
def update(
|
|
2754
|
+
self,
|
|
2755
|
+
*,
|
|
2756
|
+
filters: list[FilterItemTuple],
|
|
2757
|
+
patch: dict[str, Any],
|
|
2758
|
+
) -> int:
|
|
2707
2759
|
pass
|
|
2708
2760
|
|
|
2709
2761
|
def update(
|
|
2710
|
-
self,
|
|
2711
|
-
|
|
2762
|
+
self,
|
|
2763
|
+
updates: dict[str, Any] | Iterable[dict[str, Any]] | None = None,
|
|
2764
|
+
*,
|
|
2765
|
+
batch_size: int = 32,
|
|
2766
|
+
filters: list[FilterItemTuple] | None = None,
|
|
2767
|
+
patch: dict[str, Any] | None = None,
|
|
2768
|
+
) -> int:
|
|
2712
2769
|
"""
|
|
2713
|
-
Update one or multiple memories in the memoryset
|
|
2770
|
+
Update one or multiple memories in the memoryset.
|
|
2714
2771
|
|
|
2715
2772
|
Params:
|
|
2716
2773
|
updates: List of updates to apply to the memories. Each update should be a dictionary
|
|
@@ -2723,10 +2780,12 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
2723
2780
|
- `partition_id`: Optional new partition ID of the memory
|
|
2724
2781
|
- `...`: Optional new values for metadata properties
|
|
2725
2782
|
|
|
2726
|
-
|
|
2783
|
+
filters: Filters to match memories against. Each filter is a tuple of (field, operation, value).
|
|
2784
|
+
patch: Patch to apply to matching memories (only used with filters).
|
|
2785
|
+
batch_size: Number of memories to update in a single API call (only used with updates)
|
|
2727
2786
|
|
|
2728
2787
|
Returns:
|
|
2729
|
-
|
|
2788
|
+
The number of memories updated.
|
|
2730
2789
|
|
|
2731
2790
|
Examples:
|
|
2732
2791
|
Update a single memory:
|
|
@@ -2742,32 +2801,57 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
2742
2801
|
... {"memory_id": m.memory_id, "label": 2}
|
|
2743
2802
|
... for m in memoryset.query(filters=[("tag", "==", "happy")])
|
|
2744
2803
|
... )
|
|
2804
|
+
|
|
2805
|
+
Update all memories matching a filter:
|
|
2806
|
+
>>> memoryset.update(filters=[("label", "==", 0)], patch={"label": 1})
|
|
2745
2807
|
"""
|
|
2746
2808
|
if batch_size <= 0 or batch_size > 500:
|
|
2747
2809
|
raise ValueError("batch_size must be between 1 and 500")
|
|
2810
|
+
|
|
2748
2811
|
client = OrcaClient._resolve_client()
|
|
2749
|
-
|
|
2750
|
-
#
|
|
2751
|
-
|
|
2752
|
-
|
|
2753
|
-
|
|
2754
|
-
|
|
2755
|
-
|
|
2756
|
-
|
|
2757
|
-
|
|
2758
|
-
|
|
2759
|
-
|
|
2760
|
-
|
|
2761
|
-
|
|
2762
|
-
|
|
2763
|
-
|
|
2764
|
-
|
|
2765
|
-
|
|
2812
|
+
|
|
2813
|
+
# Convert updates to list
|
|
2814
|
+
single_update = isinstance(updates, dict)
|
|
2815
|
+
updates_list: list[dict[str, Any]] | None
|
|
2816
|
+
if single_update:
|
|
2817
|
+
updates_list = [updates] # type: ignore[list-item]
|
|
2818
|
+
elif updates is not None:
|
|
2819
|
+
updates_list = [u for u in updates] # type: ignore[misc]
|
|
2820
|
+
else:
|
|
2821
|
+
updates_list = None
|
|
2822
|
+
|
|
2823
|
+
# Batch updates to avoid API timeouts
|
|
2824
|
+
if updates_list and len(updates_list) > batch_size:
|
|
2825
|
+
updated_count: int = 0
|
|
2826
|
+
for i in range(0, len(updates_list), batch_size):
|
|
2827
|
+
batch = updates_list[i : i + batch_size]
|
|
2828
|
+
response = client.PATCH(
|
|
2829
|
+
"/gpu/memoryset/{name_or_id}/memories",
|
|
2830
|
+
params={"name_or_id": self.id},
|
|
2831
|
+
json={"updates": [_parse_memory_update(update, type=self.memory_type) for update in batch]},
|
|
2766
2832
|
)
|
|
2767
|
-
|
|
2768
|
-
|
|
2833
|
+
updated_count += response["updated_count"]
|
|
2834
|
+
return updated_count
|
|
2769
2835
|
|
|
2770
|
-
|
|
2836
|
+
# Single request for all other cases
|
|
2837
|
+
response = client.PATCH(
|
|
2838
|
+
"/gpu/memoryset/{name_or_id}/memories",
|
|
2839
|
+
params={"name_or_id": self.id},
|
|
2840
|
+
json={
|
|
2841
|
+
"updates": (
|
|
2842
|
+
[_parse_memory_update(update, type=self.memory_type) for update in updates_list]
|
|
2843
|
+
if updates_list is not None
|
|
2844
|
+
else None
|
|
2845
|
+
),
|
|
2846
|
+
"filters": (
|
|
2847
|
+
[_parse_filter_item_from_tuple(filter, allow_metric_fields=False) for filter in filters]
|
|
2848
|
+
if filters is not None
|
|
2849
|
+
else None
|
|
2850
|
+
),
|
|
2851
|
+
"patch": _parse_memory_update_patch(patch, type=self.memory_type) if patch is not None else None,
|
|
2852
|
+
},
|
|
2853
|
+
)
|
|
2854
|
+
return response["updated_count"]
|
|
2771
2855
|
|
|
2772
2856
|
def get_cascading_edits_suggestions(
|
|
2773
2857
|
self,
|
|
@@ -2826,37 +2910,128 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
2826
2910
|
},
|
|
2827
2911
|
)
|
|
2828
2912
|
|
|
2829
|
-
|
|
2913
|
+
@overload
|
|
2914
|
+
def delete(self, memory_id: str | Iterable[str], *, batch_size: int = 32) -> int:
|
|
2915
|
+
pass
|
|
2916
|
+
|
|
2917
|
+
@overload
|
|
2918
|
+
def delete(
|
|
2919
|
+
self,
|
|
2920
|
+
*,
|
|
2921
|
+
filters: list[FilterItemTuple],
|
|
2922
|
+
) -> int:
|
|
2923
|
+
pass
|
|
2924
|
+
|
|
2925
|
+
def delete(
|
|
2926
|
+
self,
|
|
2927
|
+
memory_id: str | Iterable[str] | None = None,
|
|
2928
|
+
*,
|
|
2929
|
+
batch_size: int = 32,
|
|
2930
|
+
filters: list[FilterItemTuple] | None = None,
|
|
2931
|
+
) -> int:
|
|
2830
2932
|
"""
|
|
2831
|
-
Delete memories from the memoryset
|
|
2933
|
+
Delete memories from the memoryset.
|
|
2934
|
+
|
|
2832
2935
|
|
|
2833
2936
|
Params:
|
|
2834
2937
|
memory_id: unique identifiers of the memories to delete
|
|
2835
|
-
|
|
2938
|
+
filters: Filters to match memories against. Each filter is a tuple of (field, operation, value).
|
|
2939
|
+
batch_size: Number of memories to delete in a single API call (only used with memory_id)
|
|
2940
|
+
|
|
2941
|
+
Returns:
|
|
2942
|
+
The number of memories deleted.
|
|
2836
2943
|
|
|
2837
2944
|
Examples:
|
|
2838
|
-
Delete a single memory:
|
|
2945
|
+
Delete a single memory by ID:
|
|
2839
2946
|
>>> memoryset.delete("0195019a-5bc7-7afb-b902-5945ee1fb766")
|
|
2840
2947
|
|
|
2841
|
-
Delete multiple memories:
|
|
2948
|
+
Delete multiple memories by ID:
|
|
2842
2949
|
>>> memoryset.delete([
|
|
2843
2950
|
... "0195019a-5bc7-7afb-b902-5945ee1fb766",
|
|
2844
2951
|
... "019501a1-ea08-76b2-9f62-95e4800b4841",
|
|
2845
|
-
... )
|
|
2952
|
+
... ])
|
|
2953
|
+
|
|
2954
|
+
Delete all memories matching a filter:
|
|
2955
|
+
>>> deleted_count = memoryset.delete(filters=[("label", "==", 0)])
|
|
2846
2956
|
|
|
2847
2957
|
"""
|
|
2848
2958
|
if batch_size <= 0 or batch_size > 500:
|
|
2849
2959
|
raise ValueError("batch_size must be between 1 and 500")
|
|
2960
|
+
if memory_id is not None and filters is not None:
|
|
2961
|
+
raise ValueError("Cannot specify memory_ids together with filters")
|
|
2962
|
+
|
|
2850
2963
|
client = OrcaClient._resolve_client()
|
|
2851
|
-
|
|
2852
|
-
#
|
|
2853
|
-
|
|
2854
|
-
|
|
2855
|
-
|
|
2856
|
-
|
|
2857
|
-
|
|
2858
|
-
|
|
2859
|
-
|
|
2964
|
+
|
|
2965
|
+
# Convert memory_id to list
|
|
2966
|
+
if isinstance(memory_id, str):
|
|
2967
|
+
memory_ids = [memory_id]
|
|
2968
|
+
elif memory_id is not None:
|
|
2969
|
+
memory_ids = list(memory_id)
|
|
2970
|
+
else:
|
|
2971
|
+
memory_ids = None
|
|
2972
|
+
|
|
2973
|
+
# Batch memory_id deletions to avoid API timeouts
|
|
2974
|
+
if memory_ids and len(memory_ids) > batch_size:
|
|
2975
|
+
total_deleted = 0
|
|
2976
|
+
for i in range(0, len(memory_ids), batch_size):
|
|
2977
|
+
batch = memory_ids[i : i + batch_size]
|
|
2978
|
+
response = client.POST(
|
|
2979
|
+
"/memoryset/{name_or_id}/memories/delete",
|
|
2980
|
+
params={"name_or_id": self.id},
|
|
2981
|
+
json={"memory_ids": batch},
|
|
2982
|
+
)
|
|
2983
|
+
total_deleted += response.get("deleted_count", 0)
|
|
2984
|
+
if total_deleted > 0:
|
|
2985
|
+
logging.info(f"Deleted {total_deleted} memories from memoryset.")
|
|
2986
|
+
self.refresh()
|
|
2987
|
+
return total_deleted
|
|
2988
|
+
|
|
2989
|
+
# Single request for all other cases
|
|
2990
|
+
response = client.POST(
|
|
2991
|
+
"/memoryset/{name_or_id}/memories/delete",
|
|
2992
|
+
params={"name_or_id": self.id},
|
|
2993
|
+
json={
|
|
2994
|
+
"memory_ids": memory_ids,
|
|
2995
|
+
"filters": (
|
|
2996
|
+
[_parse_filter_item_from_tuple(filter, allow_metric_fields=False) for filter in filters]
|
|
2997
|
+
if filters is not None
|
|
2998
|
+
else None
|
|
2999
|
+
),
|
|
3000
|
+
},
|
|
3001
|
+
)
|
|
3002
|
+
deleted_count = response["deleted_count"]
|
|
3003
|
+
logging.info(f"Deleted {deleted_count} memories from memoryset.")
|
|
3004
|
+
if deleted_count > 0:
|
|
3005
|
+
self.refresh()
|
|
3006
|
+
return deleted_count
|
|
3007
|
+
|
|
3008
|
+
def truncate(self, *, partition_id: str | None = UNSET) -> int:
|
|
3009
|
+
"""
|
|
3010
|
+
Delete all memories from the memoryset or a specified partition.
|
|
3011
|
+
|
|
3012
|
+
Params:
|
|
3013
|
+
partition_id: Optional partition ID to truncate, `None` refers to the global partition.
|
|
3014
|
+
|
|
3015
|
+
Returns:
|
|
3016
|
+
The number of deleted memories.
|
|
3017
|
+
"""
|
|
3018
|
+
client = OrcaClient._resolve_client()
|
|
3019
|
+
response = client.POST(
|
|
3020
|
+
"/memoryset/{name_or_id}/memories/delete",
|
|
3021
|
+
params={"name_or_id": self.id},
|
|
3022
|
+
json={
|
|
3023
|
+
"filters": (
|
|
3024
|
+
[FilterItem(field=("partition_id",), op="==", value=partition_id)]
|
|
3025
|
+
if partition_id is not UNSET
|
|
3026
|
+
else [FilterItem(field=("memory_id",), op="!=", value=None)] # match all
|
|
3027
|
+
),
|
|
3028
|
+
},
|
|
3029
|
+
)
|
|
3030
|
+
deleted_count = response["deleted_count"]
|
|
3031
|
+
logging.info(f"Deleted {deleted_count} memories from memoryset.")
|
|
3032
|
+
if deleted_count > 0:
|
|
3033
|
+
self.refresh()
|
|
3034
|
+
return deleted_count
|
|
2860
3035
|
|
|
2861
3036
|
@overload
|
|
2862
3037
|
def analyze(
|
|
@@ -3003,10 +3178,21 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
3003
3178
|
job = Job(analysis["job_id"], get_analysis_result)
|
|
3004
3179
|
return job if background else job.result()
|
|
3005
3180
|
|
|
3006
|
-
def get_potential_duplicate_groups(self) -> list[list[MemoryT]]:
|
|
3007
|
-
"""
|
|
3181
|
+
def get_potential_duplicate_groups(self) -> list[list[MemoryT]] | None:
|
|
3182
|
+
"""
|
|
3183
|
+
Group potential duplicates in the memoryset.
|
|
3184
|
+
|
|
3185
|
+
Returns:
|
|
3186
|
+
List of groups of potentially duplicate memories, where each group is a list of memories.
|
|
3187
|
+
Returns None if duplicate analysis has not been run on this memoryset yet.
|
|
3188
|
+
|
|
3189
|
+
Raises:
|
|
3190
|
+
LookupError: If the memoryset does not exist.
|
|
3191
|
+
"""
|
|
3008
3192
|
client = OrcaClient._resolve_client()
|
|
3009
3193
|
response = client.GET("/memoryset/{name_or_id}/potential_duplicate_groups", params={"name_or_id": self.id})
|
|
3194
|
+
if response is None:
|
|
3195
|
+
return None
|
|
3010
3196
|
return [
|
|
3011
3197
|
[cast(MemoryT, LabeledMemory(self.id, m) if "label" in m else ScoredMemory(self.id, m)) for m in ms]
|
|
3012
3198
|
for ms in response
|
|
@@ -3434,6 +3620,22 @@ class LabeledMemoryset(MemorysetBase[LabeledMemory, LabeledMemoryLookup]):
|
|
|
3434
3620
|
|
|
3435
3621
|
display_suggested_memory_relabels(self)
|
|
3436
3622
|
|
|
3623
|
+
@property
|
|
3624
|
+
def classification_models(self) -> list[ClassificationModel]:
|
|
3625
|
+
"""
|
|
3626
|
+
List all classification models that use this memoryset
|
|
3627
|
+
|
|
3628
|
+
Returns:
|
|
3629
|
+
List of classification models associated with this memoryset
|
|
3630
|
+
"""
|
|
3631
|
+
from .classification_model import ClassificationModel
|
|
3632
|
+
|
|
3633
|
+
client = OrcaClient._resolve_client()
|
|
3634
|
+
return [
|
|
3635
|
+
ClassificationModel(metadata)
|
|
3636
|
+
for metadata in client.GET("/classification_model", params={"memoryset_name_or_id": str(self.id)})
|
|
3637
|
+
]
|
|
3638
|
+
|
|
3437
3639
|
|
|
3438
3640
|
class ScoredMemoryset(MemorysetBase[ScoredMemory, ScoredMemoryLookup]):
|
|
3439
3641
|
"""
|
|
@@ -3809,3 +4011,19 @@ class ScoredMemoryset(MemorysetBase[ScoredMemory, ScoredMemoryLookup]):
|
|
|
3809
4011
|
subsample=subsample,
|
|
3810
4012
|
memory_type="SCORED",
|
|
3811
4013
|
)
|
|
4014
|
+
|
|
4015
|
+
@property
|
|
4016
|
+
def regression_models(self) -> list[RegressionModel]:
|
|
4017
|
+
"""
|
|
4018
|
+
List all regression models that use this memoryset
|
|
4019
|
+
|
|
4020
|
+
Returns:
|
|
4021
|
+
List of regression models associated with this memoryset
|
|
4022
|
+
"""
|
|
4023
|
+
from .regression_model import RegressionModel
|
|
4024
|
+
|
|
4025
|
+
client = OrcaClient._resolve_client()
|
|
4026
|
+
return [
|
|
4027
|
+
RegressionModel(metadata)
|
|
4028
|
+
for metadata in client.GET("/regression_model", params={"memoryset_name_or_id": str(self.id)})
|
|
4029
|
+
]
|