orca-sdk 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff shows the changes between publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the package contents as published in the public registry.
orca_sdk/client.py CHANGED
@@ -17,6 +17,7 @@ from typing import (
17
17
  Mapping,
18
18
  NotRequired,
19
19
  Self,
20
+ TypeAlias,
20
21
  TypedDict,
21
22
  cast,
22
23
  overload,
@@ -153,11 +154,14 @@ class ClusterMetrics(TypedDict):
153
154
  memory_count: int
154
155
 
155
156
 
156
- ColumnType = Literal["STRING", "FLOAT", "INT", "BOOL", "ENUM", "IMAGE", "OTHER"]
157
+ ColumnType: TypeAlias = Literal["STRING", "FLOAT", "INT", "BOOL", "ENUM", "IMAGE", "OTHER"]
158
+ """
159
+ The type of a column in a datasource
160
+ """
157
161
 
158
162
 
159
163
  class ConstraintViolationErrorResponse(TypedDict):
160
- status_code: NotRequired[int]
164
+ status_code: Literal[409]
161
165
  constraint: str
162
166
 
163
167
 
@@ -168,6 +172,7 @@ class CountPredictionsRequest(TypedDict):
168
172
  start_timestamp: NotRequired[str | None]
169
173
  end_timestamp: NotRequired[str | None]
170
174
  memory_id: NotRequired[str | None]
175
+ expected_label_match: NotRequired[bool | None]
171
176
 
172
177
 
173
178
  class CreateApiKeyRequest(TypedDict):
@@ -230,7 +235,7 @@ class EmbeddingEvaluationRequest(TypedDict):
230
235
  weigh_memories: NotRequired[bool]
231
236
 
232
237
 
233
- EmbeddingFinetuningMethod = Literal["classification", "regression", "batch_triplet_loss"]
238
+ EmbeddingFinetuningMethod: TypeAlias = Literal["classification", "regression", "batch_triplet_loss"]
234
239
 
235
240
 
236
241
  class FeedbackMetrics(TypedDict):
@@ -238,13 +243,55 @@ class FeedbackMetrics(TypedDict):
238
243
  count: int
239
244
 
240
245
 
241
- FeedbackType = Literal["CONTINUOUS", "BINARY"]
246
+ FeedbackType: TypeAlias = Literal["CONTINUOUS", "BINARY"]
242
247
 
243
248
 
244
249
  class FilterItem(TypedDict):
245
- field: list
246
- op: Literal["==", "!=", ">", ">=", "<", "<=", "in", "not in", "like"]
250
+ field: (
251
+ tuple[
252
+ Literal[
253
+ "memory_id",
254
+ "value",
255
+ "label",
256
+ "metadata",
257
+ "source_id",
258
+ "partition_id",
259
+ "created_at",
260
+ "updated_at",
261
+ "edited_at",
262
+ "metrics",
263
+ "score",
264
+ "labels",
265
+ ]
266
+ ]
267
+ | tuple[Literal["metadata"], str]
268
+ | tuple[
269
+ Literal["metrics"],
270
+ Literal[
271
+ "cluster",
272
+ "embedding_2d",
273
+ "is_duplicate",
274
+ "duplicate_memory_ids",
275
+ "has_potential_duplicates",
276
+ "potential_duplicate_memory_ids",
277
+ "anomaly_score",
278
+ "neighbor_label_logits",
279
+ "neighbor_predicted_label",
280
+ "neighbor_predicted_label_ambiguity",
281
+ "neighbor_predicted_label_confidence",
282
+ "current_label_neighbor_confidence",
283
+ "normalized_neighbor_label_entropy",
284
+ "neighbor_predicted_label_matches_current_label",
285
+ "spread",
286
+ "uniformity",
287
+ "concept_id",
288
+ "subconcept_id",
289
+ ],
290
+ ]
291
+ )
292
+ op: Literal["==", "!=", ">", ">=", "<", "<=", "in", "not in", "like", "contains all", "contains any"]
247
293
  value: str | int | float | bool | list[str | None] | list[int] | list[float] | list[bool] | None
294
+ transform: NotRequired[Literal["length"]]
248
295
 
249
296
 
250
297
  class GetDatasourceRowCountRequest(TypedDict):
@@ -272,12 +319,17 @@ class HealthyResponse(TypedDict):
272
319
 
273
320
 
274
321
  class InternalServerErrorResponse(TypedDict):
275
- status_code: NotRequired[int]
322
+ status_code: Literal[500]
276
323
  message: str
277
324
  request_id: str
278
325
 
279
326
 
280
- JobStatus = Literal["INITIALIZED", "DISPATCHED", "WAITING", "PROCESSING", "COMPLETED", "FAILED", "ABORTING", "ABORTED"]
327
+ JobStatus: TypeAlias = Literal[
328
+ "INITIALIZED", "DISPATCHED", "WAITING", "PROCESSING", "COMPLETED", "FAILED", "ABORTING", "ABORTED"
329
+ ]
330
+ """
331
+ Status of job in the job queue
332
+ """
281
333
 
282
334
 
283
335
  class JobStatusInfo(TypedDict):
@@ -344,7 +396,7 @@ class MemoryMetrics(TypedDict):
344
396
  has_potential_duplicates: NotRequired[bool]
345
397
  potential_duplicate_memory_ids: NotRequired[list[str] | None]
346
398
  cluster: NotRequired[int]
347
- embedding_2d: NotRequired[list]
399
+ embedding_2d: NotRequired[tuple[float, float]]
348
400
  anomaly_score: NotRequired[float]
349
401
  neighbor_label_logits: NotRequired[list[float] | None]
350
402
  neighbor_predicted_label: NotRequired[int | None]
@@ -359,7 +411,7 @@ class MemoryMetrics(TypedDict):
359
411
  subconcept_id: NotRequired[int | None]
360
412
 
361
413
 
362
- MemoryType = Literal["LABELED", "SCORED"]
414
+ MemoryType: TypeAlias = Literal["LABELED", "SCORED"]
363
415
 
364
416
 
365
417
  class MemorysetClassPatternsAnalysisConfig(TypedDict):
@@ -465,7 +517,7 @@ class MemorysetUpdate(TypedDict):
465
517
 
466
518
 
467
519
  class NotFoundErrorResponse(TypedDict):
468
- status_code: NotRequired[int]
520
+ status_code: Literal[404]
469
521
  resource: (
470
522
  Literal[
471
523
  "org",
@@ -545,7 +597,7 @@ class PredictionFeedbackResult(TypedDict):
545
597
  new_category_ids: list[str]
546
598
 
547
599
 
548
- PredictionSort = list[list]
600
+ PredictionSort: TypeAlias = list[tuple[Literal["timestamp", "confidence", "anomaly_score"], Literal["asc", "desc"]]]
549
601
 
550
602
 
551
603
  class PredictiveModelUpdate(TypedDict):
@@ -554,15 +606,18 @@ class PredictiveModelUpdate(TypedDict):
554
606
  locked: NotRequired[bool]
555
607
 
556
608
 
557
- PretrainedEmbeddingModelName = Literal[
609
+ PretrainedEmbeddingModelName: TypeAlias = Literal[
558
610
  "CLIP_BASE", "GTE_BASE", "CDE_SMALL", "DISTILBERT", "GTE_SMALL", "MXBAI_LARGE", "E5_LARGE", "BGE_BASE", "GIST_LARGE"
559
611
  ]
612
+ """
613
+ Names of pretrained embedding models that are supported by OrcaCloud
614
+ """
560
615
 
561
616
 
562
- RACHeadType = Literal["KNN", "MMOE", "FF", "BMMOE"]
617
+ RACHeadType: TypeAlias = Literal["KNN", "MMOE", "FF", "BMMOE"]
563
618
 
564
619
 
565
- RARHeadType = Literal["MMOE", "KNN"]
620
+ RARHeadType: TypeAlias = Literal["MMOE", "KNN"]
566
621
 
567
622
 
568
623
  class ROCCurve(TypedDict):
@@ -669,6 +724,7 @@ class ScorePredictionWithMemoriesAndFeedback(TypedDict):
669
724
  tags: list[str]
670
725
  explanation: str | None
671
726
  memory_id: str | None
727
+ is_in_dense_neighborhood: NotRequired[bool | None]
672
728
  feedbacks: list[PredictionFeedback]
673
729
 
674
730
 
@@ -740,7 +796,7 @@ class ScoredMemoryWithFeedbackMetrics(TypedDict):
740
796
 
741
797
 
742
798
  class ServiceUnavailableErrorResponse(TypedDict):
743
- status_code: NotRequired[int]
799
+ status_code: Literal[503]
744
800
  service: str
745
801
 
746
802
 
@@ -752,7 +808,9 @@ class SubConceptMetrics(TypedDict):
752
808
  memory_count: int
753
809
 
754
810
 
755
- TelemetryField = list
811
+ TelemetryField: TypeAlias = (
812
+ tuple[Literal["feedback_metrics"], str, Literal["avg", "count"]] | tuple[Literal["lookup"], Literal["count"]]
813
+ )
756
814
 
757
815
 
758
816
  class TelemetryFilterItem(TypedDict):
@@ -767,11 +825,11 @@ class TelemetrySortOptions(TypedDict):
767
825
 
768
826
 
769
827
  class UnauthenticatedErrorResponse(TypedDict):
770
- status_code: NotRequired[int]
828
+ status_code: Literal[401]
771
829
 
772
830
 
773
831
  class UnauthorizedErrorResponse(TypedDict):
774
- status_code: NotRequired[int]
832
+ status_code: Literal[403]
775
833
  reason: str
776
834
 
777
835
 
@@ -792,7 +850,10 @@ class ValidationError(TypedDict):
792
850
  type: str
793
851
 
794
852
 
795
- WorkerStatus = Literal["IDLE", "BUSY", "DRAINING", "SHUTDOWN", "CRASHED"]
853
+ WorkerStatus: TypeAlias = Literal["IDLE", "BUSY", "DRAINING", "SHUTDOWN", "CRASHED"]
854
+ """
855
+ Status of worker in the worker pool
856
+ """
796
857
 
797
858
 
798
859
  class GetTestErrorByStatusCodeParams(TypedDict):
@@ -868,7 +929,7 @@ class PostGpuMemorysetByNameOrIdMemoryParams(TypedDict):
868
929
  name_or_id: str
869
930
 
870
931
 
871
- PostGpuMemorysetByNameOrIdMemoryRequest = list[LabeledMemoryInsert] | list[ScoredMemoryInsert]
932
+ PostGpuMemorysetByNameOrIdMemoryRequest: TypeAlias = list[LabeledMemoryInsert] | list[ScoredMemoryInsert]
872
933
 
873
934
 
874
935
  class PatchGpuMemorysetByNameOrIdMemoriesParams(TypedDict):
@@ -1104,6 +1165,10 @@ class GetWorkerByWorkerIdParams(TypedDict):
1104
1165
 
1105
1166
  class GetTelemetryPredictionByPredictionIdParams(TypedDict):
1106
1167
  prediction_id: str
1168
+ calc_neighborhood_density: NotRequired[bool]
1169
+ """
1170
+ Calculate neighborhood density
1171
+ """
1107
1172
 
1108
1173
 
1109
1174
  class PatchTelemetryPredictionByPredictionIdParams(TypedDict):
@@ -1142,7 +1207,7 @@ class DeleteTelemetryFeedbackCategoryByNameOrIdParams(TypedDict):
1142
1207
  name_or_id: str
1143
1208
 
1144
1209
 
1145
- PutTelemetryPredictionFeedbackRequest = list[PredictionFeedbackRequest]
1210
+ PutTelemetryPredictionFeedbackRequest: TypeAlias = list[PredictionFeedbackRequest]
1146
1211
 
1147
1212
 
1148
1213
  class GetAgentsBootstrapClassificationModelByJobIdParams(TypedDict):
@@ -1195,6 +1260,8 @@ class ClassificationMetrics(TypedDict):
1195
1260
  pr_auc: NotRequired[float | None]
1196
1261
  pr_curve: NotRequired[PRCurve | None]
1197
1262
  roc_curve: NotRequired[ROCCurve | None]
1263
+ confusion_matrix: NotRequired[list[list[int]] | None]
1264
+ warnings: NotRequired[list[str]]
1198
1265
 
1199
1266
 
1200
1267
  class ClassificationModelMetadata(TypedDict):
@@ -1418,7 +1485,7 @@ class HTTPValidationError(TypedDict):
1418
1485
 
1419
1486
 
1420
1487
  class InvalidInputErrorResponse(TypedDict):
1421
- status_code: NotRequired[int]
1488
+ status_code: Literal[422]
1422
1489
  validation_issues: list[FieldValidationError]
1423
1490
 
1424
1491
 
@@ -1478,6 +1545,7 @@ class LabelPredictionWithMemoriesAndFeedback(TypedDict):
1478
1545
  tags: list[str]
1479
1546
  explanation: str | None
1480
1547
  memory_id: str | None
1548
+ is_in_dense_neighborhood: NotRequired[bool | None]
1481
1549
  feedbacks: list[PredictionFeedback]
1482
1550
 
1483
1551
 
@@ -1549,10 +1617,10 @@ class ListPredictionsRequest(TypedDict):
1549
1617
  start_timestamp: NotRequired[str | None]
1550
1618
  end_timestamp: NotRequired[str | None]
1551
1619
  memory_id: NotRequired[str | None]
1620
+ expected_label_match: NotRequired[bool | None]
1552
1621
  limit: NotRequired[int]
1553
1622
  offset: NotRequired[int | None]
1554
1623
  sort: NotRequired[PredictionSort]
1555
- expected_label_match: NotRequired[bool | None]
1556
1624
 
1557
1625
 
1558
1626
  class MemorysetAnalysisConfigs(TypedDict):
@@ -1631,10 +1699,10 @@ class WorkerInfo(TypedDict):
1631
1699
  config: dict[str, str | float | int | bool | dict[str, str] | None]
1632
1700
 
1633
1701
 
1634
- PatchGpuMemorysetByNameOrIdMemoryRequest = LabeledMemoryUpdate | ScoredMemoryUpdate
1702
+ PatchGpuMemorysetByNameOrIdMemoryRequest: TypeAlias = LabeledMemoryUpdate | ScoredMemoryUpdate
1635
1703
 
1636
1704
 
1637
- PatchGpuMemorysetByNameOrIdMemoriesRequest = list[LabeledMemoryUpdate] | list[ScoredMemoryUpdate]
1705
+ PatchGpuMemorysetByNameOrIdMemoriesRequest: TypeAlias = list[LabeledMemoryUpdate] | list[ScoredMemoryUpdate]
1638
1706
 
1639
1707
 
1640
1708
  class CascadingEditSuggestion(TypedDict):
@@ -1862,7 +1930,7 @@ class OrcaClient(Client):
1862
1930
  follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
1863
1931
  timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
1864
1932
  extensions: RequestExtensions | None = None,
1865
- ) -> bool:
1933
+ ) -> Literal[True]:
1866
1934
  """Return true only when called with a valid root API key; otherwise 401 Unauthenticated."""
1867
1935
  pass
1868
1936
 
@@ -1896,7 +1964,7 @@ class OrcaClient(Client):
1896
1964
  follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
1897
1965
  timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
1898
1966
  extensions: RequestExtensions | None = None,
1899
- ) -> bool:
1967
+ ) -> Literal[True]:
1900
1968
  """Returns true if the api key header is valid for the org (will be false for admin api key)"""
1901
1969
  pass
1902
1970
 
orca_sdk/credentials.py CHANGED
@@ -1,10 +1,8 @@
1
- import os
2
1
  from datetime import datetime
3
- from typing import Literal, NamedTuple
2
+ from typing import Literal
4
3
 
5
4
  import httpx
6
- from httpx import ConnectError, Headers, HTTPTransport
7
- from typing_extensions import deprecated
5
+ from httpx import ConnectError, Headers
8
6
 
9
7
  from .async_client import OrcaAsyncClient
10
8
  from .client import OrcaClient
@@ -132,9 +130,6 @@ class OrcaCredentials:
132
130
  client = OrcaClient._resolve_client()
133
131
  client.DELETE("/auth/api_key/{name_or_id}", params={"name_or_id": name})
134
132
 
135
- # TODO: remove deprecated methods after 2026-01-01
136
-
137
- @deprecated("Use `OrcaClient.api_key` instead")
138
133
  @staticmethod
139
134
  def set_api_key(api_key: str, check_validity: bool = True):
140
135
  """
@@ -158,21 +153,25 @@ class OrcaCredentials:
158
153
  async_client = OrcaAsyncClient._resolve_client()
159
154
  async_client.api_key = api_key
160
155
 
161
- @deprecated("Use `OrcaClient.base_url` instead")
162
156
  @staticmethod
163
157
  def get_api_url() -> str:
164
158
  """
165
159
  Get the base URL of the Orca API that is currently being used
166
160
  """
167
161
  client = OrcaClient._resolve_client()
162
+ async_client = OrcaAsyncClient._resolve_client()
163
+ if client.base_url != async_client.base_url:
164
+ raise RuntimeError("The base URL of the sync and async clients do not match")
168
165
  return str(client.base_url)
169
166
 
170
- @deprecated("Use `OrcaClient.base_url` instead")
171
167
  @staticmethod
172
168
  def set_api_url(url: str, check_validity: bool = True):
173
169
  """
174
170
  Set the base URL for the Orca API
175
171
 
172
+ Note:
173
+ The base URL can also be provided by setting the `ORCA_API_URL` environment variable
174
+
176
175
  Args:
177
176
  url: The base URL to set
178
177
  check_validity: Whether to check if there is an API running at the given base URL
@@ -197,7 +196,6 @@ class OrcaCredentials:
197
196
  if check_validity:
198
197
  OrcaCredentials.is_healthy()
199
198
 
200
- @deprecated("Use `OrcaClient.headers` instead")
201
199
  @staticmethod
202
200
  def set_api_headers(headers: dict[str, str]):
203
201
  """
orca_sdk/datasource.py CHANGED
@@ -202,10 +202,10 @@ class Datasource:
202
202
  ValueError: If a datasource already exists and if_exists is `"error"`
203
203
  """
204
204
  if description is None or isinstance(description, str):
205
- description = {dataset_name: description for dataset_name in dataset_dict.keys()}
205
+ description = {str(dataset_name): description for dataset_name in dataset_dict.keys()}
206
206
  return {
207
- dataset_name: cls.from_hf_dataset(
208
- f"{name}_{dataset_name}", dataset, if_exists=if_exists, description=description[dataset_name]
207
+ str(dataset_name): cls.from_hf_dataset(
208
+ f"{name}_{dataset_name}", dataset, if_exists=if_exists, description=description[str(dataset_name)]
209
209
  )
210
210
  for dataset_name, dataset in dataset_dict.items()
211
211
  }
orca_sdk/memoryset.py CHANGED
@@ -32,13 +32,16 @@ from .client import (
32
32
  FilterItem,
33
33
  )
34
34
  from .client import LabeledMemory as LabeledMemoryResponse
35
- from .client import LabeledMemoryInsert
35
+ from .client import (
36
+ LabeledMemoryInsert,
37
+ )
36
38
  from .client import LabeledMemoryLookup as LabeledMemoryLookupResponse
37
39
  from .client import (
38
40
  LabeledMemoryUpdate,
39
41
  LabeledMemoryWithFeedbackMetrics,
40
42
  LabelPredictionMemoryLookup,
41
43
  LabelPredictionWithMemoriesAndFeedback,
44
+ ListPredictionsRequest,
42
45
  MemoryMetrics,
43
46
  MemorysetAnalysisConfigs,
44
47
  MemorysetMetadata,
@@ -46,16 +49,18 @@ from .client import (
46
49
  MemorysetUpdate,
47
50
  MemoryType,
48
51
  OrcaClient,
49
- PredictionFeedback,
50
52
  )
51
53
  from .client import ScoredMemory as ScoredMemoryResponse
52
- from .client import ScoredMemoryInsert
54
+ from .client import (
55
+ ScoredMemoryInsert,
56
+ )
53
57
  from .client import ScoredMemoryLookup as ScoredMemoryLookupResponse
54
58
  from .client import (
55
59
  ScoredMemoryUpdate,
56
60
  ScoredMemoryWithFeedbackMetrics,
57
61
  ScorePredictionMemoryLookup,
58
62
  ScorePredictionWithMemoriesAndFeedback,
63
+ TelemetryField,
59
64
  TelemetryFilterItem,
60
65
  TelemetrySortOptions,
61
66
  )
@@ -157,9 +162,10 @@ def _parse_filter_item_from_tuple(input: FilterItemTuple) -> FilterItem | Teleme
157
162
  raise ValueError("Like filters are not supported on metric columns")
158
163
  op = cast(Literal["==", "!=", ">", ">=", "<", "<=", "in", "not in"], op)
159
164
  value = cast(float | int | list[float] | list[int], value)
160
- return TelemetryFilterItem(field=field, op=op, value=value)
165
+ return TelemetryFilterItem(field=cast(TelemetryField, tuple(field)), op=op, value=value)
161
166
 
162
- return FilterItem(field=field, op=op, value=value)
167
+ # Convert list to tuple for FilterItem field type
168
+ return FilterItem(field=tuple(field), op=op, value=value) # type: ignore[assignment]
163
169
 
164
170
 
165
171
  def _parse_sort_item_from_tuple(
@@ -183,7 +189,8 @@ def _parse_sort_item_from_tuple(
183
189
  raise ValueError("Lookup must follow the format `lookup.count`")
184
190
  if field[1] != "count":
185
191
  raise ValueError("Lookup can only be sorted on count")
186
- return TelemetrySortOptions(field=field, direction=input[1])
192
+ # Convert list to tuple for TelemetryField type
193
+ return TelemetrySortOptions(field=cast(TelemetryField, tuple(field)), direction=input[1])
187
194
 
188
195
 
189
196
  def _parse_memory_insert(memory: dict[str, Any], type: MemoryType) -> LabeledMemoryInsert | ScoredMemoryInsert:
@@ -593,16 +600,18 @@ class LabeledMemory(MemoryBase):
593
600
  """
594
601
 
595
602
  client = OrcaClient._resolve_client()
603
+ request_json: ListPredictionsRequest = {
604
+ "memory_id": self.memory_id,
605
+ "limit": limit,
606
+ "offset": offset,
607
+ "tag": tag,
608
+ "expected_label_match": expected_label_match,
609
+ }
610
+ if sort:
611
+ request_json["sort"] = sort
596
612
  predictions_data = client.POST(
597
613
  "/telemetry/prediction",
598
- json={
599
- "memory_id": self.memory_id,
600
- "limit": limit,
601
- "offset": offset,
602
- "sort": [list(sort_item) for sort_item in sort],
603
- "tag": tag,
604
- "expected_label_match": expected_label_match,
605
- },
614
+ json=request_json,
606
615
  )
607
616
 
608
617
  # Filter to only classification predictions and convert to ClassificationPrediction objects
@@ -808,16 +817,18 @@ class ScoredMemory(MemoryBase):
808
817
  List of RegressionPrediction objects that used this memory
809
818
  """
810
819
  client = OrcaClient._resolve_client()
820
+ request_json: ListPredictionsRequest = {
821
+ "memory_id": self.memory_id,
822
+ "limit": limit,
823
+ "offset": offset,
824
+ "tag": tag,
825
+ "expected_label_match": expected_label_match,
826
+ }
827
+ if sort:
828
+ request_json["sort"] = sort
811
829
  predictions_data = client.POST(
812
830
  "/telemetry/prediction",
813
- json={
814
- "memory_id": self.memory_id,
815
- "limit": limit,
816
- "offset": offset,
817
- "sort": [list(sort_item) for sort_item in sort],
818
- "tag": tag,
819
- "expected_label_match": expected_label_match,
820
- },
831
+ json=request_json,
821
832
  )
822
833
 
823
834
  # Filter to only regression predictions and convert to RegressionPrediction objects
@@ -940,8 +951,6 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
940
951
  index_params: dict[str, Any]
941
952
  hidden: bool
942
953
 
943
- _batch_size = 32 # max number of memories to insert/update/delete in a single API call
944
-
945
954
  def __init__(self, metadata: MemorysetMetadata):
946
955
  # for internal use only, do not document
947
956
  if metadata["pretrained_embedding_model_name"]:
@@ -2532,7 +2541,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
2532
2541
  ]
2533
2542
  )
2534
2543
 
2535
- def insert(self, items: Iterable[dict[str, Any]] | dict[str, Any]) -> None:
2544
+ def insert(self, items: Iterable[dict[str, Any]] | dict[str, Any], *, batch_size: int = 32) -> None:
2536
2545
  """
2537
2546
  Insert memories into the memoryset
2538
2547
 
@@ -2546,17 +2555,21 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
2546
2555
  - `source_id`: Optional unique ID of the memory in a system of reference
2547
2556
  - `...`: Any other metadata to store for the memory
2548
2557
 
2558
+ batch_size: Number of memories to insert in a single API call
2559
+
2549
2560
  Examples:
2550
2561
  >>> memoryset.insert([
2551
2562
  ... {"value": "I am happy", "label": 1, "source_id": "data_123", "partition_id": "user_1", "tag": "happy"},
2552
2563
  ... {"value": "I am sad", "label": 0, "source_id": "data_124", "partition_id": "user_1", "tag": "sad"},
2553
2564
  ... ])
2554
2565
  """
2566
+ if batch_size <= 0 or batch_size > 500:
2567
+ raise ValueError("batch_size must be between 1 and 500")
2555
2568
  client = OrcaClient._resolve_client()
2556
2569
  items = cast(list[dict[str, Any]], [items]) if isinstance(items, dict) else list(items)
2557
2570
  # insert memories in batches to avoid API timeouts
2558
- for i in range(0, len(items), self._batch_size):
2559
- batch = items[i : i + self._batch_size]
2571
+ for i in range(0, len(items), batch_size):
2572
+ batch = items[i : i + batch_size]
2560
2573
  client.POST(
2561
2574
  "/gpu/memoryset/{name_or_id}/memory",
2562
2575
  params={"name_or_id": self.id},
@@ -2568,7 +2581,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
2568
2581
 
2569
2582
  self.refresh()
2570
2583
 
2571
- async def ainsert(self, items: Iterable[dict[str, Any]] | dict[str, Any]) -> None:
2584
+ async def ainsert(self, items: Iterable[dict[str, Any]] | dict[str, Any], *, batch_size: int = 32) -> None:
2572
2585
  """
2573
2586
  Asynchronously insert memories into the memoryset
2574
2587
 
@@ -2583,17 +2596,21 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
2583
2596
  - `partition_id`: Optional partition ID of the memory
2584
2597
  - `...`: Any other metadata to store for the memory
2585
2598
 
2599
+ batch_size: Number of memories to insert in a single API call
2600
+
2586
2601
  Examples:
2587
2602
  >>> await memoryset.ainsert([
2588
2603
  ... {"value": "I am happy", "label": 1, "source_id": "data_123", "partition_id": "user_1", "tag": "happy"},
2589
2604
  ... {"value": "I am sad", "label": 0, "source_id": "data_124", "partition_id": "user_1", "tag": "sad"},
2590
2605
  ... ])
2591
2606
  """
2607
+ if batch_size <= 0 or batch_size > 500:
2608
+ raise ValueError("batch_size must be between 1 and 500")
2592
2609
  client = OrcaAsyncClient._resolve_client()
2593
2610
  items = cast(list[dict[str, Any]], [items]) if isinstance(items, dict) else list(items)
2594
2611
  # insert memories in batches to avoid API timeouts
2595
- for i in range(0, len(items), self._batch_size):
2596
- batch = items[i : i + self._batch_size]
2612
+ for i in range(0, len(items), batch_size):
2613
+ batch = items[i : i + batch_size]
2597
2614
  await client.POST(
2598
2615
  "/gpu/memoryset/{name_or_id}/memory",
2599
2616
  params={"name_or_id": self.id},
@@ -2682,14 +2699,16 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
2682
2699
  ]
2683
2700
 
2684
2701
  @overload
2685
- def update(self, updates: dict[str, Any]) -> MemoryT:
2702
+ def update(self, updates: dict[str, Any], *, batch_size: int = 32) -> MemoryT:
2686
2703
  pass
2687
2704
 
2688
2705
  @overload
2689
- def update(self, updates: Iterable[dict[str, Any]]) -> list[MemoryT]:
2706
+ def update(self, updates: Iterable[dict[str, Any]], *, batch_size: int = 32) -> list[MemoryT]:
2690
2707
  pass
2691
2708
 
2692
- def update(self, updates: dict[str, Any] | Iterable[dict[str, Any]]) -> MemoryT | list[MemoryT]:
2709
+ def update(
2710
+ self, updates: dict[str, Any] | Iterable[dict[str, Any]], *, batch_size: int = 32
2711
+ ) -> MemoryT | list[MemoryT]:
2693
2712
  """
2694
2713
  Update one or multiple memories in the memoryset
2695
2714
 
@@ -2704,6 +2723,8 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
2704
2723
  - `partition_id`: Optional new partition ID of the memory
2705
2724
  - `...`: Optional new values for metadata properties
2706
2725
 
2726
+ batch_size: Number of memories to update in a single API call
2727
+
2707
2728
  Returns:
2708
2729
  Updated memory or list of updated memories
2709
2730
 
@@ -2722,12 +2743,14 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
2722
2743
  ... for m in memoryset.query(filters=[("tag", "==", "happy")])
2723
2744
  ... )
2724
2745
  """
2746
+ if batch_size <= 0 or batch_size > 500:
2747
+ raise ValueError("batch_size must be between 1 and 500")
2725
2748
  client = OrcaClient._resolve_client()
2726
2749
  updates_list = cast(list[dict[str, Any]], [updates]) if isinstance(updates, dict) else list(updates)
2727
2750
  # update memories in batches to avoid API timeouts
2728
2751
  updated_memories: list[MemoryT] = []
2729
- for i in range(0, len(updates_list), self._batch_size):
2730
- batch = updates_list[i : i + self._batch_size]
2752
+ for i in range(0, len(updates_list), batch_size):
2753
+ batch = updates_list[i : i + batch_size]
2731
2754
  response = client.PATCH(
2732
2755
  "/gpu/memoryset/{name_or_id}/memories",
2733
2756
  params={"name_or_id": self.id},
@@ -2803,12 +2826,13 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
2803
2826
  },
2804
2827
  )
2805
2828
 
2806
- def delete(self, memory_id: str | Iterable[str]) -> None:
2829
+ def delete(self, memory_id: str | Iterable[str], *, batch_size: int = 32) -> None:
2807
2830
  """
2808
2831
  Delete memories from the memoryset
2809
2832
 
2810
2833
  Params:
2811
2834
  memory_id: unique identifiers of the memories to delete
2835
+ batch_size: Number of memories to delete in a single API call
2812
2836
 
2813
2837
  Examples:
2814
2838
  Delete a single memory:
@@ -2821,11 +2845,13 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
2821
2845
  ... )
2822
2846
 
2823
2847
  """
2848
+ if batch_size <= 0 or batch_size > 500:
2849
+ raise ValueError("batch_size must be between 1 and 500")
2824
2850
  client = OrcaClient._resolve_client()
2825
2851
  memory_ids = [memory_id] if isinstance(memory_id, str) else list(memory_id)
2826
2852
  # delete memories in batches to avoid API timeouts
2827
- for i in range(0, len(memory_ids), self._batch_size):
2828
- batch = memory_ids[i : i + self._batch_size]
2853
+ for i in range(0, len(memory_ids), batch_size):
2854
+ batch = memory_ids[i : i + batch_size]
2829
2855
  client.POST(
2830
2856
  "/memoryset/{name_or_id}/memories/delete", params={"name_or_id": self.id}, json={"memory_ids": batch}
2831
2857
  )