orca-sdk 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/_utils/analysis_ui.py +1 -1
- orca_sdk/_utils/data_parsing.py +16 -12
- orca_sdk/_utils/data_parsing_test.py +8 -8
- orca_sdk/async_client.py +96 -28
- orca_sdk/classification_model.py +184 -104
- orca_sdk/classification_model_test.py +8 -4
- orca_sdk/client.py +96 -28
- orca_sdk/credentials.py +8 -10
- orca_sdk/datasource.py +3 -3
- orca_sdk/memoryset.py +64 -38
- orca_sdk/memoryset_test.py +5 -3
- orca_sdk/regression_model.py +124 -67
- orca_sdk/regression_model_test.py +8 -4
- {orca_sdk-0.1.8.dist-info → orca_sdk-0.1.10.dist-info}/METADATA +4 -4
- {orca_sdk-0.1.8.dist-info → orca_sdk-0.1.10.dist-info}/RECORD +16 -16
- {orca_sdk-0.1.8.dist-info → orca_sdk-0.1.10.dist-info}/WHEEL +0 -0
orca_sdk/client.py
CHANGED
|
@@ -17,6 +17,7 @@ from typing import (
|
|
|
17
17
|
Mapping,
|
|
18
18
|
NotRequired,
|
|
19
19
|
Self,
|
|
20
|
+
TypeAlias,
|
|
20
21
|
TypedDict,
|
|
21
22
|
cast,
|
|
22
23
|
overload,
|
|
@@ -153,11 +154,14 @@ class ClusterMetrics(TypedDict):
|
|
|
153
154
|
memory_count: int
|
|
154
155
|
|
|
155
156
|
|
|
156
|
-
ColumnType = Literal["STRING", "FLOAT", "INT", "BOOL", "ENUM", "IMAGE", "OTHER"]
|
|
157
|
+
ColumnType: TypeAlias = Literal["STRING", "FLOAT", "INT", "BOOL", "ENUM", "IMAGE", "OTHER"]
|
|
158
|
+
"""
|
|
159
|
+
The type of a column in a datasource
|
|
160
|
+
"""
|
|
157
161
|
|
|
158
162
|
|
|
159
163
|
class ConstraintViolationErrorResponse(TypedDict):
|
|
160
|
-
status_code:
|
|
164
|
+
status_code: Literal[409]
|
|
161
165
|
constraint: str
|
|
162
166
|
|
|
163
167
|
|
|
@@ -168,6 +172,7 @@ class CountPredictionsRequest(TypedDict):
|
|
|
168
172
|
start_timestamp: NotRequired[str | None]
|
|
169
173
|
end_timestamp: NotRequired[str | None]
|
|
170
174
|
memory_id: NotRequired[str | None]
|
|
175
|
+
expected_label_match: NotRequired[bool | None]
|
|
171
176
|
|
|
172
177
|
|
|
173
178
|
class CreateApiKeyRequest(TypedDict):
|
|
@@ -230,7 +235,7 @@ class EmbeddingEvaluationRequest(TypedDict):
|
|
|
230
235
|
weigh_memories: NotRequired[bool]
|
|
231
236
|
|
|
232
237
|
|
|
233
|
-
EmbeddingFinetuningMethod = Literal["classification", "regression", "batch_triplet_loss"]
|
|
238
|
+
EmbeddingFinetuningMethod: TypeAlias = Literal["classification", "regression", "batch_triplet_loss"]
|
|
234
239
|
|
|
235
240
|
|
|
236
241
|
class FeedbackMetrics(TypedDict):
|
|
@@ -238,13 +243,55 @@ class FeedbackMetrics(TypedDict):
|
|
|
238
243
|
count: int
|
|
239
244
|
|
|
240
245
|
|
|
241
|
-
FeedbackType = Literal["CONTINUOUS", "BINARY"]
|
|
246
|
+
FeedbackType: TypeAlias = Literal["CONTINUOUS", "BINARY"]
|
|
242
247
|
|
|
243
248
|
|
|
244
249
|
class FilterItem(TypedDict):
|
|
245
|
-
field:
|
|
246
|
-
|
|
250
|
+
field: (
|
|
251
|
+
tuple[
|
|
252
|
+
Literal[
|
|
253
|
+
"memory_id",
|
|
254
|
+
"value",
|
|
255
|
+
"label",
|
|
256
|
+
"metadata",
|
|
257
|
+
"source_id",
|
|
258
|
+
"partition_id",
|
|
259
|
+
"created_at",
|
|
260
|
+
"updated_at",
|
|
261
|
+
"edited_at",
|
|
262
|
+
"metrics",
|
|
263
|
+
"score",
|
|
264
|
+
"labels",
|
|
265
|
+
]
|
|
266
|
+
]
|
|
267
|
+
| tuple[Literal["metadata"], str]
|
|
268
|
+
| tuple[
|
|
269
|
+
Literal["metrics"],
|
|
270
|
+
Literal[
|
|
271
|
+
"cluster",
|
|
272
|
+
"embedding_2d",
|
|
273
|
+
"is_duplicate",
|
|
274
|
+
"duplicate_memory_ids",
|
|
275
|
+
"has_potential_duplicates",
|
|
276
|
+
"potential_duplicate_memory_ids",
|
|
277
|
+
"anomaly_score",
|
|
278
|
+
"neighbor_label_logits",
|
|
279
|
+
"neighbor_predicted_label",
|
|
280
|
+
"neighbor_predicted_label_ambiguity",
|
|
281
|
+
"neighbor_predicted_label_confidence",
|
|
282
|
+
"current_label_neighbor_confidence",
|
|
283
|
+
"normalized_neighbor_label_entropy",
|
|
284
|
+
"neighbor_predicted_label_matches_current_label",
|
|
285
|
+
"spread",
|
|
286
|
+
"uniformity",
|
|
287
|
+
"concept_id",
|
|
288
|
+
"subconcept_id",
|
|
289
|
+
],
|
|
290
|
+
]
|
|
291
|
+
)
|
|
292
|
+
op: Literal["==", "!=", ">", ">=", "<", "<=", "in", "not in", "like", "contains all", "contains any"]
|
|
247
293
|
value: str | int | float | bool | list[str | None] | list[int] | list[float] | list[bool] | None
|
|
294
|
+
transform: NotRequired[Literal["length"]]
|
|
248
295
|
|
|
249
296
|
|
|
250
297
|
class GetDatasourceRowCountRequest(TypedDict):
|
|
@@ -272,12 +319,17 @@ class HealthyResponse(TypedDict):
|
|
|
272
319
|
|
|
273
320
|
|
|
274
321
|
class InternalServerErrorResponse(TypedDict):
|
|
275
|
-
status_code:
|
|
322
|
+
status_code: Literal[500]
|
|
276
323
|
message: str
|
|
277
324
|
request_id: str
|
|
278
325
|
|
|
279
326
|
|
|
280
|
-
JobStatus = Literal[
|
|
327
|
+
JobStatus: TypeAlias = Literal[
|
|
328
|
+
"INITIALIZED", "DISPATCHED", "WAITING", "PROCESSING", "COMPLETED", "FAILED", "ABORTING", "ABORTED"
|
|
329
|
+
]
|
|
330
|
+
"""
|
|
331
|
+
Status of job in the job queue
|
|
332
|
+
"""
|
|
281
333
|
|
|
282
334
|
|
|
283
335
|
class JobStatusInfo(TypedDict):
|
|
@@ -344,7 +396,7 @@ class MemoryMetrics(TypedDict):
|
|
|
344
396
|
has_potential_duplicates: NotRequired[bool]
|
|
345
397
|
potential_duplicate_memory_ids: NotRequired[list[str] | None]
|
|
346
398
|
cluster: NotRequired[int]
|
|
347
|
-
embedding_2d: NotRequired[
|
|
399
|
+
embedding_2d: NotRequired[tuple[float, float]]
|
|
348
400
|
anomaly_score: NotRequired[float]
|
|
349
401
|
neighbor_label_logits: NotRequired[list[float] | None]
|
|
350
402
|
neighbor_predicted_label: NotRequired[int | None]
|
|
@@ -359,7 +411,7 @@ class MemoryMetrics(TypedDict):
|
|
|
359
411
|
subconcept_id: NotRequired[int | None]
|
|
360
412
|
|
|
361
413
|
|
|
362
|
-
MemoryType = Literal["LABELED", "SCORED"]
|
|
414
|
+
MemoryType: TypeAlias = Literal["LABELED", "SCORED"]
|
|
363
415
|
|
|
364
416
|
|
|
365
417
|
class MemorysetClassPatternsAnalysisConfig(TypedDict):
|
|
@@ -465,7 +517,7 @@ class MemorysetUpdate(TypedDict):
|
|
|
465
517
|
|
|
466
518
|
|
|
467
519
|
class NotFoundErrorResponse(TypedDict):
|
|
468
|
-
status_code:
|
|
520
|
+
status_code: Literal[404]
|
|
469
521
|
resource: (
|
|
470
522
|
Literal[
|
|
471
523
|
"org",
|
|
@@ -545,7 +597,7 @@ class PredictionFeedbackResult(TypedDict):
|
|
|
545
597
|
new_category_ids: list[str]
|
|
546
598
|
|
|
547
599
|
|
|
548
|
-
PredictionSort = list[
|
|
600
|
+
PredictionSort: TypeAlias = list[tuple[Literal["timestamp", "confidence", "anomaly_score"], Literal["asc", "desc"]]]
|
|
549
601
|
|
|
550
602
|
|
|
551
603
|
class PredictiveModelUpdate(TypedDict):
|
|
@@ -554,15 +606,18 @@ class PredictiveModelUpdate(TypedDict):
|
|
|
554
606
|
locked: NotRequired[bool]
|
|
555
607
|
|
|
556
608
|
|
|
557
|
-
PretrainedEmbeddingModelName = Literal[
|
|
609
|
+
PretrainedEmbeddingModelName: TypeAlias = Literal[
|
|
558
610
|
"CLIP_BASE", "GTE_BASE", "CDE_SMALL", "DISTILBERT", "GTE_SMALL", "MXBAI_LARGE", "E5_LARGE", "BGE_BASE", "GIST_LARGE"
|
|
559
611
|
]
|
|
612
|
+
"""
|
|
613
|
+
Names of pretrained embedding models that are supported by OrcaCloud
|
|
614
|
+
"""
|
|
560
615
|
|
|
561
616
|
|
|
562
|
-
RACHeadType = Literal["KNN", "MMOE", "FF", "BMMOE"]
|
|
617
|
+
RACHeadType: TypeAlias = Literal["KNN", "MMOE", "FF", "BMMOE"]
|
|
563
618
|
|
|
564
619
|
|
|
565
|
-
RARHeadType = Literal["MMOE", "KNN"]
|
|
620
|
+
RARHeadType: TypeAlias = Literal["MMOE", "KNN"]
|
|
566
621
|
|
|
567
622
|
|
|
568
623
|
class ROCCurve(TypedDict):
|
|
@@ -669,6 +724,7 @@ class ScorePredictionWithMemoriesAndFeedback(TypedDict):
|
|
|
669
724
|
tags: list[str]
|
|
670
725
|
explanation: str | None
|
|
671
726
|
memory_id: str | None
|
|
727
|
+
is_in_dense_neighborhood: NotRequired[bool | None]
|
|
672
728
|
feedbacks: list[PredictionFeedback]
|
|
673
729
|
|
|
674
730
|
|
|
@@ -740,7 +796,7 @@ class ScoredMemoryWithFeedbackMetrics(TypedDict):
|
|
|
740
796
|
|
|
741
797
|
|
|
742
798
|
class ServiceUnavailableErrorResponse(TypedDict):
|
|
743
|
-
status_code:
|
|
799
|
+
status_code: Literal[503]
|
|
744
800
|
service: str
|
|
745
801
|
|
|
746
802
|
|
|
@@ -752,7 +808,9 @@ class SubConceptMetrics(TypedDict):
|
|
|
752
808
|
memory_count: int
|
|
753
809
|
|
|
754
810
|
|
|
755
|
-
TelemetryField =
|
|
811
|
+
TelemetryField: TypeAlias = (
|
|
812
|
+
tuple[Literal["feedback_metrics"], str, Literal["avg", "count"]] | tuple[Literal["lookup"], Literal["count"]]
|
|
813
|
+
)
|
|
756
814
|
|
|
757
815
|
|
|
758
816
|
class TelemetryFilterItem(TypedDict):
|
|
@@ -767,11 +825,11 @@ class TelemetrySortOptions(TypedDict):
|
|
|
767
825
|
|
|
768
826
|
|
|
769
827
|
class UnauthenticatedErrorResponse(TypedDict):
|
|
770
|
-
status_code:
|
|
828
|
+
status_code: Literal[401]
|
|
771
829
|
|
|
772
830
|
|
|
773
831
|
class UnauthorizedErrorResponse(TypedDict):
|
|
774
|
-
status_code:
|
|
832
|
+
status_code: Literal[403]
|
|
775
833
|
reason: str
|
|
776
834
|
|
|
777
835
|
|
|
@@ -792,7 +850,10 @@ class ValidationError(TypedDict):
|
|
|
792
850
|
type: str
|
|
793
851
|
|
|
794
852
|
|
|
795
|
-
WorkerStatus = Literal["IDLE", "BUSY", "DRAINING", "SHUTDOWN", "CRASHED"]
|
|
853
|
+
WorkerStatus: TypeAlias = Literal["IDLE", "BUSY", "DRAINING", "SHUTDOWN", "CRASHED"]
|
|
854
|
+
"""
|
|
855
|
+
Status of worker in the worker pool
|
|
856
|
+
"""
|
|
796
857
|
|
|
797
858
|
|
|
798
859
|
class GetTestErrorByStatusCodeParams(TypedDict):
|
|
@@ -868,7 +929,7 @@ class PostGpuMemorysetByNameOrIdMemoryParams(TypedDict):
|
|
|
868
929
|
name_or_id: str
|
|
869
930
|
|
|
870
931
|
|
|
871
|
-
PostGpuMemorysetByNameOrIdMemoryRequest = list[LabeledMemoryInsert] | list[ScoredMemoryInsert]
|
|
932
|
+
PostGpuMemorysetByNameOrIdMemoryRequest: TypeAlias = list[LabeledMemoryInsert] | list[ScoredMemoryInsert]
|
|
872
933
|
|
|
873
934
|
|
|
874
935
|
class PatchGpuMemorysetByNameOrIdMemoriesParams(TypedDict):
|
|
@@ -1104,6 +1165,10 @@ class GetWorkerByWorkerIdParams(TypedDict):
|
|
|
1104
1165
|
|
|
1105
1166
|
class GetTelemetryPredictionByPredictionIdParams(TypedDict):
|
|
1106
1167
|
prediction_id: str
|
|
1168
|
+
calc_neighborhood_density: NotRequired[bool]
|
|
1169
|
+
"""
|
|
1170
|
+
Calculate neighborhood density
|
|
1171
|
+
"""
|
|
1107
1172
|
|
|
1108
1173
|
|
|
1109
1174
|
class PatchTelemetryPredictionByPredictionIdParams(TypedDict):
|
|
@@ -1142,7 +1207,7 @@ class DeleteTelemetryFeedbackCategoryByNameOrIdParams(TypedDict):
|
|
|
1142
1207
|
name_or_id: str
|
|
1143
1208
|
|
|
1144
1209
|
|
|
1145
|
-
PutTelemetryPredictionFeedbackRequest = list[PredictionFeedbackRequest]
|
|
1210
|
+
PutTelemetryPredictionFeedbackRequest: TypeAlias = list[PredictionFeedbackRequest]
|
|
1146
1211
|
|
|
1147
1212
|
|
|
1148
1213
|
class GetAgentsBootstrapClassificationModelByJobIdParams(TypedDict):
|
|
@@ -1195,6 +1260,8 @@ class ClassificationMetrics(TypedDict):
|
|
|
1195
1260
|
pr_auc: NotRequired[float | None]
|
|
1196
1261
|
pr_curve: NotRequired[PRCurve | None]
|
|
1197
1262
|
roc_curve: NotRequired[ROCCurve | None]
|
|
1263
|
+
confusion_matrix: NotRequired[list[list[int]] | None]
|
|
1264
|
+
warnings: NotRequired[list[str]]
|
|
1198
1265
|
|
|
1199
1266
|
|
|
1200
1267
|
class ClassificationModelMetadata(TypedDict):
|
|
@@ -1418,7 +1485,7 @@ class HTTPValidationError(TypedDict):
|
|
|
1418
1485
|
|
|
1419
1486
|
|
|
1420
1487
|
class InvalidInputErrorResponse(TypedDict):
|
|
1421
|
-
status_code:
|
|
1488
|
+
status_code: Literal[422]
|
|
1422
1489
|
validation_issues: list[FieldValidationError]
|
|
1423
1490
|
|
|
1424
1491
|
|
|
@@ -1478,6 +1545,7 @@ class LabelPredictionWithMemoriesAndFeedback(TypedDict):
|
|
|
1478
1545
|
tags: list[str]
|
|
1479
1546
|
explanation: str | None
|
|
1480
1547
|
memory_id: str | None
|
|
1548
|
+
is_in_dense_neighborhood: NotRequired[bool | None]
|
|
1481
1549
|
feedbacks: list[PredictionFeedback]
|
|
1482
1550
|
|
|
1483
1551
|
|
|
@@ -1549,10 +1617,10 @@ class ListPredictionsRequest(TypedDict):
|
|
|
1549
1617
|
start_timestamp: NotRequired[str | None]
|
|
1550
1618
|
end_timestamp: NotRequired[str | None]
|
|
1551
1619
|
memory_id: NotRequired[str | None]
|
|
1620
|
+
expected_label_match: NotRequired[bool | None]
|
|
1552
1621
|
limit: NotRequired[int]
|
|
1553
1622
|
offset: NotRequired[int | None]
|
|
1554
1623
|
sort: NotRequired[PredictionSort]
|
|
1555
|
-
expected_label_match: NotRequired[bool | None]
|
|
1556
1624
|
|
|
1557
1625
|
|
|
1558
1626
|
class MemorysetAnalysisConfigs(TypedDict):
|
|
@@ -1631,10 +1699,10 @@ class WorkerInfo(TypedDict):
|
|
|
1631
1699
|
config: dict[str, str | float | int | bool | dict[str, str] | None]
|
|
1632
1700
|
|
|
1633
1701
|
|
|
1634
|
-
PatchGpuMemorysetByNameOrIdMemoryRequest = LabeledMemoryUpdate | ScoredMemoryUpdate
|
|
1702
|
+
PatchGpuMemorysetByNameOrIdMemoryRequest: TypeAlias = LabeledMemoryUpdate | ScoredMemoryUpdate
|
|
1635
1703
|
|
|
1636
1704
|
|
|
1637
|
-
PatchGpuMemorysetByNameOrIdMemoriesRequest = list[LabeledMemoryUpdate] | list[ScoredMemoryUpdate]
|
|
1705
|
+
PatchGpuMemorysetByNameOrIdMemoriesRequest: TypeAlias = list[LabeledMemoryUpdate] | list[ScoredMemoryUpdate]
|
|
1638
1706
|
|
|
1639
1707
|
|
|
1640
1708
|
class CascadingEditSuggestion(TypedDict):
|
|
@@ -1862,7 +1930,7 @@ class OrcaClient(Client):
|
|
|
1862
1930
|
follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
|
|
1863
1931
|
timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
|
|
1864
1932
|
extensions: RequestExtensions | None = None,
|
|
1865
|
-
) ->
|
|
1933
|
+
) -> Literal[True]:
|
|
1866
1934
|
"""Return true only when called with a valid root API key; otherwise 401 Unauthenticated."""
|
|
1867
1935
|
pass
|
|
1868
1936
|
|
|
@@ -1896,7 +1964,7 @@ class OrcaClient(Client):
|
|
|
1896
1964
|
follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
|
|
1897
1965
|
timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
|
|
1898
1966
|
extensions: RequestExtensions | None = None,
|
|
1899
|
-
) ->
|
|
1967
|
+
) -> Literal[True]:
|
|
1900
1968
|
"""Returns true if the api key header is valid for the org (will be false for admin api key)"""
|
|
1901
1969
|
pass
|
|
1902
1970
|
|
orca_sdk/credentials.py
CHANGED
|
@@ -1,10 +1,8 @@
|
|
|
1
|
-
import os
|
|
2
1
|
from datetime import datetime
|
|
3
|
-
from typing import Literal
|
|
2
|
+
from typing import Literal
|
|
4
3
|
|
|
5
4
|
import httpx
|
|
6
|
-
from httpx import ConnectError, Headers
|
|
7
|
-
from typing_extensions import deprecated
|
|
5
|
+
from httpx import ConnectError, Headers
|
|
8
6
|
|
|
9
7
|
from .async_client import OrcaAsyncClient
|
|
10
8
|
from .client import OrcaClient
|
|
@@ -132,9 +130,6 @@ class OrcaCredentials:
|
|
|
132
130
|
client = OrcaClient._resolve_client()
|
|
133
131
|
client.DELETE("/auth/api_key/{name_or_id}", params={"name_or_id": name})
|
|
134
132
|
|
|
135
|
-
# TODO: remove deprecated methods after 2026-01-01
|
|
136
|
-
|
|
137
|
-
@deprecated("Use `OrcaClient.api_key` instead")
|
|
138
133
|
@staticmethod
|
|
139
134
|
def set_api_key(api_key: str, check_validity: bool = True):
|
|
140
135
|
"""
|
|
@@ -158,21 +153,25 @@ class OrcaCredentials:
|
|
|
158
153
|
async_client = OrcaAsyncClient._resolve_client()
|
|
159
154
|
async_client.api_key = api_key
|
|
160
155
|
|
|
161
|
-
@deprecated("Use `OrcaClient.base_url` instead")
|
|
162
156
|
@staticmethod
|
|
163
157
|
def get_api_url() -> str:
|
|
164
158
|
"""
|
|
165
159
|
Get the base URL of the Orca API that is currently being used
|
|
166
160
|
"""
|
|
167
161
|
client = OrcaClient._resolve_client()
|
|
162
|
+
async_client = OrcaAsyncClient._resolve_client()
|
|
163
|
+
if client.base_url != async_client.base_url:
|
|
164
|
+
raise RuntimeError("The base URL of the sync and async clients do not match")
|
|
168
165
|
return str(client.base_url)
|
|
169
166
|
|
|
170
|
-
@deprecated("Use `OrcaClient.base_url` instead")
|
|
171
167
|
@staticmethod
|
|
172
168
|
def set_api_url(url: str, check_validity: bool = True):
|
|
173
169
|
"""
|
|
174
170
|
Set the base URL for the Orca API
|
|
175
171
|
|
|
172
|
+
Note:
|
|
173
|
+
The base URL can also be provided by setting the `ORCA_API_URL` environment variable
|
|
174
|
+
|
|
176
175
|
Args:
|
|
177
176
|
url: The base URL to set
|
|
178
177
|
check_validity: Whether to check if there is an API running at the given base URL
|
|
@@ -197,7 +196,6 @@ class OrcaCredentials:
|
|
|
197
196
|
if check_validity:
|
|
198
197
|
OrcaCredentials.is_healthy()
|
|
199
198
|
|
|
200
|
-
@deprecated("Use `OrcaClient.headers` instead")
|
|
201
199
|
@staticmethod
|
|
202
200
|
def set_api_headers(headers: dict[str, str]):
|
|
203
201
|
"""
|
orca_sdk/datasource.py
CHANGED
|
@@ -202,10 +202,10 @@ class Datasource:
|
|
|
202
202
|
ValueError: If a datasource already exists and if_exists is `"error"`
|
|
203
203
|
"""
|
|
204
204
|
if description is None or isinstance(description, str):
|
|
205
|
-
description = {dataset_name: description for dataset_name in dataset_dict.keys()}
|
|
205
|
+
description = {str(dataset_name): description for dataset_name in dataset_dict.keys()}
|
|
206
206
|
return {
|
|
207
|
-
dataset_name: cls.from_hf_dataset(
|
|
208
|
-
f"{name}_{dataset_name}", dataset, if_exists=if_exists, description=description[dataset_name]
|
|
207
|
+
str(dataset_name): cls.from_hf_dataset(
|
|
208
|
+
f"{name}_{dataset_name}", dataset, if_exists=if_exists, description=description[str(dataset_name)]
|
|
209
209
|
)
|
|
210
210
|
for dataset_name, dataset in dataset_dict.items()
|
|
211
211
|
}
|
orca_sdk/memoryset.py
CHANGED
|
@@ -32,13 +32,16 @@ from .client import (
|
|
|
32
32
|
FilterItem,
|
|
33
33
|
)
|
|
34
34
|
from .client import LabeledMemory as LabeledMemoryResponse
|
|
35
|
-
from .client import
|
|
35
|
+
from .client import (
|
|
36
|
+
LabeledMemoryInsert,
|
|
37
|
+
)
|
|
36
38
|
from .client import LabeledMemoryLookup as LabeledMemoryLookupResponse
|
|
37
39
|
from .client import (
|
|
38
40
|
LabeledMemoryUpdate,
|
|
39
41
|
LabeledMemoryWithFeedbackMetrics,
|
|
40
42
|
LabelPredictionMemoryLookup,
|
|
41
43
|
LabelPredictionWithMemoriesAndFeedback,
|
|
44
|
+
ListPredictionsRequest,
|
|
42
45
|
MemoryMetrics,
|
|
43
46
|
MemorysetAnalysisConfigs,
|
|
44
47
|
MemorysetMetadata,
|
|
@@ -46,16 +49,18 @@ from .client import (
|
|
|
46
49
|
MemorysetUpdate,
|
|
47
50
|
MemoryType,
|
|
48
51
|
OrcaClient,
|
|
49
|
-
PredictionFeedback,
|
|
50
52
|
)
|
|
51
53
|
from .client import ScoredMemory as ScoredMemoryResponse
|
|
52
|
-
from .client import
|
|
54
|
+
from .client import (
|
|
55
|
+
ScoredMemoryInsert,
|
|
56
|
+
)
|
|
53
57
|
from .client import ScoredMemoryLookup as ScoredMemoryLookupResponse
|
|
54
58
|
from .client import (
|
|
55
59
|
ScoredMemoryUpdate,
|
|
56
60
|
ScoredMemoryWithFeedbackMetrics,
|
|
57
61
|
ScorePredictionMemoryLookup,
|
|
58
62
|
ScorePredictionWithMemoriesAndFeedback,
|
|
63
|
+
TelemetryField,
|
|
59
64
|
TelemetryFilterItem,
|
|
60
65
|
TelemetrySortOptions,
|
|
61
66
|
)
|
|
@@ -157,9 +162,10 @@ def _parse_filter_item_from_tuple(input: FilterItemTuple) -> FilterItem | Teleme
|
|
|
157
162
|
raise ValueError("Like filters are not supported on metric columns")
|
|
158
163
|
op = cast(Literal["==", "!=", ">", ">=", "<", "<=", "in", "not in"], op)
|
|
159
164
|
value = cast(float | int | list[float] | list[int], value)
|
|
160
|
-
return TelemetryFilterItem(field=field, op=op, value=value)
|
|
165
|
+
return TelemetryFilterItem(field=cast(TelemetryField, tuple(field)), op=op, value=value)
|
|
161
166
|
|
|
162
|
-
|
|
167
|
+
# Convert list to tuple for FilterItem field type
|
|
168
|
+
return FilterItem(field=tuple(field), op=op, value=value) # type: ignore[assignment]
|
|
163
169
|
|
|
164
170
|
|
|
165
171
|
def _parse_sort_item_from_tuple(
|
|
@@ -183,7 +189,8 @@ def _parse_sort_item_from_tuple(
|
|
|
183
189
|
raise ValueError("Lookup must follow the format `lookup.count`")
|
|
184
190
|
if field[1] != "count":
|
|
185
191
|
raise ValueError("Lookup can only be sorted on count")
|
|
186
|
-
|
|
192
|
+
# Convert list to tuple for TelemetryField type
|
|
193
|
+
return TelemetrySortOptions(field=cast(TelemetryField, tuple(field)), direction=input[1])
|
|
187
194
|
|
|
188
195
|
|
|
189
196
|
def _parse_memory_insert(memory: dict[str, Any], type: MemoryType) -> LabeledMemoryInsert | ScoredMemoryInsert:
|
|
@@ -593,16 +600,18 @@ class LabeledMemory(MemoryBase):
|
|
|
593
600
|
"""
|
|
594
601
|
|
|
595
602
|
client = OrcaClient._resolve_client()
|
|
603
|
+
request_json: ListPredictionsRequest = {
|
|
604
|
+
"memory_id": self.memory_id,
|
|
605
|
+
"limit": limit,
|
|
606
|
+
"offset": offset,
|
|
607
|
+
"tag": tag,
|
|
608
|
+
"expected_label_match": expected_label_match,
|
|
609
|
+
}
|
|
610
|
+
if sort:
|
|
611
|
+
request_json["sort"] = sort
|
|
596
612
|
predictions_data = client.POST(
|
|
597
613
|
"/telemetry/prediction",
|
|
598
|
-
json=
|
|
599
|
-
"memory_id": self.memory_id,
|
|
600
|
-
"limit": limit,
|
|
601
|
-
"offset": offset,
|
|
602
|
-
"sort": [list(sort_item) for sort_item in sort],
|
|
603
|
-
"tag": tag,
|
|
604
|
-
"expected_label_match": expected_label_match,
|
|
605
|
-
},
|
|
614
|
+
json=request_json,
|
|
606
615
|
)
|
|
607
616
|
|
|
608
617
|
# Filter to only classification predictions and convert to ClassificationPrediction objects
|
|
@@ -808,16 +817,18 @@ class ScoredMemory(MemoryBase):
|
|
|
808
817
|
List of RegressionPrediction objects that used this memory
|
|
809
818
|
"""
|
|
810
819
|
client = OrcaClient._resolve_client()
|
|
820
|
+
request_json: ListPredictionsRequest = {
|
|
821
|
+
"memory_id": self.memory_id,
|
|
822
|
+
"limit": limit,
|
|
823
|
+
"offset": offset,
|
|
824
|
+
"tag": tag,
|
|
825
|
+
"expected_label_match": expected_label_match,
|
|
826
|
+
}
|
|
827
|
+
if sort:
|
|
828
|
+
request_json["sort"] = sort
|
|
811
829
|
predictions_data = client.POST(
|
|
812
830
|
"/telemetry/prediction",
|
|
813
|
-
json=
|
|
814
|
-
"memory_id": self.memory_id,
|
|
815
|
-
"limit": limit,
|
|
816
|
-
"offset": offset,
|
|
817
|
-
"sort": [list(sort_item) for sort_item in sort],
|
|
818
|
-
"tag": tag,
|
|
819
|
-
"expected_label_match": expected_label_match,
|
|
820
|
-
},
|
|
831
|
+
json=request_json,
|
|
821
832
|
)
|
|
822
833
|
|
|
823
834
|
# Filter to only regression predictions and convert to RegressionPrediction objects
|
|
@@ -940,8 +951,6 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
940
951
|
index_params: dict[str, Any]
|
|
941
952
|
hidden: bool
|
|
942
953
|
|
|
943
|
-
_batch_size = 32 # max number of memories to insert/update/delete in a single API call
|
|
944
|
-
|
|
945
954
|
def __init__(self, metadata: MemorysetMetadata):
|
|
946
955
|
# for internal use only, do not document
|
|
947
956
|
if metadata["pretrained_embedding_model_name"]:
|
|
@@ -2532,7 +2541,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
2532
2541
|
]
|
|
2533
2542
|
)
|
|
2534
2543
|
|
|
2535
|
-
def insert(self, items: Iterable[dict[str, Any]] | dict[str, Any]) -> None:
|
|
2544
|
+
def insert(self, items: Iterable[dict[str, Any]] | dict[str, Any], *, batch_size: int = 32) -> None:
|
|
2536
2545
|
"""
|
|
2537
2546
|
Insert memories into the memoryset
|
|
2538
2547
|
|
|
@@ -2546,17 +2555,21 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
2546
2555
|
- `source_id`: Optional unique ID of the memory in a system of reference
|
|
2547
2556
|
- `...`: Any other metadata to store for the memory
|
|
2548
2557
|
|
|
2558
|
+
batch_size: Number of memories to insert in a single API call
|
|
2559
|
+
|
|
2549
2560
|
Examples:
|
|
2550
2561
|
>>> memoryset.insert([
|
|
2551
2562
|
... {"value": "I am happy", "label": 1, "source_id": "data_123", "partition_id": "user_1", "tag": "happy"},
|
|
2552
2563
|
... {"value": "I am sad", "label": 0, "source_id": "data_124", "partition_id": "user_1", "tag": "sad"},
|
|
2553
2564
|
... ])
|
|
2554
2565
|
"""
|
|
2566
|
+
if batch_size <= 0 or batch_size > 500:
|
|
2567
|
+
raise ValueError("batch_size must be between 1 and 500")
|
|
2555
2568
|
client = OrcaClient._resolve_client()
|
|
2556
2569
|
items = cast(list[dict[str, Any]], [items]) if isinstance(items, dict) else list(items)
|
|
2557
2570
|
# insert memories in batches to avoid API timeouts
|
|
2558
|
-
for i in range(0, len(items),
|
|
2559
|
-
batch = items[i : i +
|
|
2571
|
+
for i in range(0, len(items), batch_size):
|
|
2572
|
+
batch = items[i : i + batch_size]
|
|
2560
2573
|
client.POST(
|
|
2561
2574
|
"/gpu/memoryset/{name_or_id}/memory",
|
|
2562
2575
|
params={"name_or_id": self.id},
|
|
@@ -2568,7 +2581,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
2568
2581
|
|
|
2569
2582
|
self.refresh()
|
|
2570
2583
|
|
|
2571
|
-
async def ainsert(self, items: Iterable[dict[str, Any]] | dict[str, Any]) -> None:
|
|
2584
|
+
async def ainsert(self, items: Iterable[dict[str, Any]] | dict[str, Any], *, batch_size: int = 32) -> None:
|
|
2572
2585
|
"""
|
|
2573
2586
|
Asynchronously insert memories into the memoryset
|
|
2574
2587
|
|
|
@@ -2583,17 +2596,21 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
2583
2596
|
- `partition_id`: Optional partition ID of the memory
|
|
2584
2597
|
- `...`: Any other metadata to store for the memory
|
|
2585
2598
|
|
|
2599
|
+
batch_size: Number of memories to insert in a single API call
|
|
2600
|
+
|
|
2586
2601
|
Examples:
|
|
2587
2602
|
>>> await memoryset.ainsert([
|
|
2588
2603
|
... {"value": "I am happy", "label": 1, "source_id": "data_123", "partition_id": "user_1", "tag": "happy"},
|
|
2589
2604
|
... {"value": "I am sad", "label": 0, "source_id": "data_124", "partition_id": "user_1", "tag": "sad"},
|
|
2590
2605
|
... ])
|
|
2591
2606
|
"""
|
|
2607
|
+
if batch_size <= 0 or batch_size > 500:
|
|
2608
|
+
raise ValueError("batch_size must be between 1 and 500")
|
|
2592
2609
|
client = OrcaAsyncClient._resolve_client()
|
|
2593
2610
|
items = cast(list[dict[str, Any]], [items]) if isinstance(items, dict) else list(items)
|
|
2594
2611
|
# insert memories in batches to avoid API timeouts
|
|
2595
|
-
for i in range(0, len(items),
|
|
2596
|
-
batch = items[i : i +
|
|
2612
|
+
for i in range(0, len(items), batch_size):
|
|
2613
|
+
batch = items[i : i + batch_size]
|
|
2597
2614
|
await client.POST(
|
|
2598
2615
|
"/gpu/memoryset/{name_or_id}/memory",
|
|
2599
2616
|
params={"name_or_id": self.id},
|
|
@@ -2682,14 +2699,16 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
2682
2699
|
]
|
|
2683
2700
|
|
|
2684
2701
|
@overload
|
|
2685
|
-
def update(self, updates: dict[str, Any]) -> MemoryT:
|
|
2702
|
+
def update(self, updates: dict[str, Any], *, batch_size: int = 32) -> MemoryT:
|
|
2686
2703
|
pass
|
|
2687
2704
|
|
|
2688
2705
|
@overload
|
|
2689
|
-
def update(self, updates: Iterable[dict[str, Any]]) -> list[MemoryT]:
|
|
2706
|
+
def update(self, updates: Iterable[dict[str, Any]], *, batch_size: int = 32) -> list[MemoryT]:
|
|
2690
2707
|
pass
|
|
2691
2708
|
|
|
2692
|
-
def update(
|
|
2709
|
+
def update(
|
|
2710
|
+
self, updates: dict[str, Any] | Iterable[dict[str, Any]], *, batch_size: int = 32
|
|
2711
|
+
) -> MemoryT | list[MemoryT]:
|
|
2693
2712
|
"""
|
|
2694
2713
|
Update one or multiple memories in the memoryset
|
|
2695
2714
|
|
|
@@ -2704,6 +2723,8 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
2704
2723
|
- `partition_id`: Optional new partition ID of the memory
|
|
2705
2724
|
- `...`: Optional new values for metadata properties
|
|
2706
2725
|
|
|
2726
|
+
batch_size: Number of memories to update in a single API call
|
|
2727
|
+
|
|
2707
2728
|
Returns:
|
|
2708
2729
|
Updated memory or list of updated memories
|
|
2709
2730
|
|
|
@@ -2722,12 +2743,14 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
2722
2743
|
... for m in memoryset.query(filters=[("tag", "==", "happy")])
|
|
2723
2744
|
... )
|
|
2724
2745
|
"""
|
|
2746
|
+
if batch_size <= 0 or batch_size > 500:
|
|
2747
|
+
raise ValueError("batch_size must be between 1 and 500")
|
|
2725
2748
|
client = OrcaClient._resolve_client()
|
|
2726
2749
|
updates_list = cast(list[dict[str, Any]], [updates]) if isinstance(updates, dict) else list(updates)
|
|
2727
2750
|
# update memories in batches to avoid API timeouts
|
|
2728
2751
|
updated_memories: list[MemoryT] = []
|
|
2729
|
-
for i in range(0, len(updates_list),
|
|
2730
|
-
batch = updates_list[i : i +
|
|
2752
|
+
for i in range(0, len(updates_list), batch_size):
|
|
2753
|
+
batch = updates_list[i : i + batch_size]
|
|
2731
2754
|
response = client.PATCH(
|
|
2732
2755
|
"/gpu/memoryset/{name_or_id}/memories",
|
|
2733
2756
|
params={"name_or_id": self.id},
|
|
@@ -2803,12 +2826,13 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
2803
2826
|
},
|
|
2804
2827
|
)
|
|
2805
2828
|
|
|
2806
|
-
def delete(self, memory_id: str | Iterable[str]) -> None:
|
|
2829
|
+
def delete(self, memory_id: str | Iterable[str], *, batch_size: int = 32) -> None:
|
|
2807
2830
|
"""
|
|
2808
2831
|
Delete memories from the memoryset
|
|
2809
2832
|
|
|
2810
2833
|
Params:
|
|
2811
2834
|
memory_id: unique identifiers of the memories to delete
|
|
2835
|
+
batch_size: Number of memories to delete in a single API call
|
|
2812
2836
|
|
|
2813
2837
|
Examples:
|
|
2814
2838
|
Delete a single memory:
|
|
@@ -2821,11 +2845,13 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
2821
2845
|
... )
|
|
2822
2846
|
|
|
2823
2847
|
"""
|
|
2848
|
+
if batch_size <= 0 or batch_size > 500:
|
|
2849
|
+
raise ValueError("batch_size must be between 1 and 500")
|
|
2824
2850
|
client = OrcaClient._resolve_client()
|
|
2825
2851
|
memory_ids = [memory_id] if isinstance(memory_id, str) else list(memory_id)
|
|
2826
2852
|
# delete memories in batches to avoid API timeouts
|
|
2827
|
-
for i in range(0, len(memory_ids),
|
|
2828
|
-
batch = memory_ids[i : i +
|
|
2853
|
+
for i in range(0, len(memory_ids), batch_size):
|
|
2854
|
+
batch = memory_ids[i : i + batch_size]
|
|
2829
2855
|
client.POST(
|
|
2830
2856
|
"/memoryset/{name_or_id}/memories/delete", params={"name_or_id": self.id}, json={"memory_ids": batch}
|
|
2831
2857
|
)
|