orca-sdk 0.0.102__py3-none-any.whl → 0.0.104__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/_shared/metrics.py +31 -9
- orca_sdk/_shared/metrics_test.py +30 -4
- orca_sdk/_utils/auth.py +1 -1
- orca_sdk/_utils/prediction_result_ui.py +5 -1
- orca_sdk/classification_model.py +32 -1
- orca_sdk/classification_model_test.py +19 -1
- orca_sdk/client.py +541 -442
- orca_sdk/conftest.py +14 -2
- orca_sdk/credentials.py +48 -49
- orca_sdk/credentials_test.py +5 -5
- orca_sdk/datasource.py +1 -1
- orca_sdk/datasource_test.py +6 -1
- orca_sdk/embedding_model.py +28 -1
- orca_sdk/job.py +4 -1
- orca_sdk/job_test.py +20 -10
- orca_sdk/memoryset.py +40 -28
- orca_sdk/memoryset_test.py +26 -2
- orca_sdk/regression_model.py +29 -1
- orca_sdk/regression_model_test.py +18 -1
- {orca_sdk-0.0.102.dist-info → orca_sdk-0.0.104.dist-info}/METADATA +15 -14
- orca_sdk-0.0.104.dist-info/RECORD +40 -0
- {orca_sdk-0.0.102.dist-info → orca_sdk-0.0.104.dist-info}/WHEEL +1 -1
- orca_sdk-0.0.102.dist-info/RECORD +0 -40
orca_sdk/conftest.py
CHANGED
|
@@ -24,7 +24,7 @@ os.environ["ORCA_SAVE_TELEMETRY_SYNCHRONOUSLY"] = "true"
|
|
|
24
24
|
|
|
25
25
|
def skip_in_prod(reason: str):
|
|
26
26
|
"""Custom decorator to skip tests when running against production API"""
|
|
27
|
-
PROD_API_URLs = ["https://api.orcadb.ai", "https://api.
|
|
27
|
+
PROD_API_URLs = ["https://api.orcadb.ai", "https://api.staging.orcadb.ai"]
|
|
28
28
|
return pytest.mark.skipif(
|
|
29
29
|
os.environ["ORCA_API_URL"] in PROD_API_URLs,
|
|
30
30
|
reason=reason,
|
|
@@ -45,7 +45,7 @@ def _create_org_id():
|
|
|
45
45
|
|
|
46
46
|
|
|
47
47
|
@pytest.fixture()
|
|
48
|
-
def
|
|
48
|
+
def api_url_reset():
|
|
49
49
|
original_base_url = orca_api.base_url
|
|
50
50
|
yield
|
|
51
51
|
orca_api.base_url = original_base_url
|
|
@@ -113,6 +113,18 @@ SAMPLE_DATA = [
|
|
|
113
113
|
{"value": "cats have nine lives", "label": 1, "key": "g2", "score": 0.9, "source_id": "s14"},
|
|
114
114
|
{"value": "tomato soup with grilled cheese", "label": 0, "key": "g1", "score": 0.1, "source_id": "s15"},
|
|
115
115
|
{"value": "cats are independent animals", "label": 1, "key": "g2", "score": 0.9, "source_id": "s16"},
|
|
116
|
+
{"value": "the beach is always fun", "label": None, "key": "g3", "score": None, "source_id": "s17"},
|
|
117
|
+
{"value": "i love the beach", "label": None, "key": "g3", "score": None, "source_id": "s18"},
|
|
118
|
+
{"value": "the ocean is healing", "label": None, "key": "g3", "score": None, "source_id": "s19"},
|
|
119
|
+
{
|
|
120
|
+
"value": "sandy feet, sand between my toes at the beach",
|
|
121
|
+
"label": None,
|
|
122
|
+
"key": "g3",
|
|
123
|
+
"score": None,
|
|
124
|
+
"source_id": "s20",
|
|
125
|
+
},
|
|
126
|
+
{"value": "i am such a beach bum", "label": None, "key": "g3", "score": None, "source_id": "s21"},
|
|
127
|
+
{"value": "i will always want to be at the beach", "label": None, "key": "g3", "score": None, "source_id": "s22"},
|
|
116
128
|
]
|
|
117
129
|
|
|
118
130
|
|
orca_sdk/credentials.py
CHANGED
|
@@ -35,11 +35,33 @@ class OrcaCredentials:
|
|
|
35
35
|
"""
|
|
36
36
|
|
|
37
37
|
@staticmethod
|
|
38
|
-
def
|
|
38
|
+
def is_authenticated() -> bool:
|
|
39
39
|
"""
|
|
40
|
-
|
|
40
|
+
Check if you are authenticated to interact with the Orca API
|
|
41
|
+
|
|
42
|
+
Returns:
|
|
43
|
+
True if you are authenticated, False otherwise
|
|
41
44
|
"""
|
|
42
|
-
|
|
45
|
+
try:
|
|
46
|
+
return orca_api.GET("/auth")
|
|
47
|
+
except ValueError as e:
|
|
48
|
+
if "Invalid API key" in str(e):
|
|
49
|
+
return False
|
|
50
|
+
raise e
|
|
51
|
+
|
|
52
|
+
@staticmethod
|
|
53
|
+
def is_healthy() -> bool:
|
|
54
|
+
"""
|
|
55
|
+
Check whether the API is healthy
|
|
56
|
+
|
|
57
|
+
Returns:
|
|
58
|
+
True if the API is healthy, False otherwise
|
|
59
|
+
"""
|
|
60
|
+
try:
|
|
61
|
+
orca_api.GET("/check/healthy")
|
|
62
|
+
except Exception:
|
|
63
|
+
return False
|
|
64
|
+
return True
|
|
43
65
|
|
|
44
66
|
@staticmethod
|
|
45
67
|
def list_api_keys() -> list[ApiKeyInfo]:
|
|
@@ -58,21 +80,6 @@ class OrcaCredentials:
|
|
|
58
80
|
for api_key in orca_api.GET("/auth/api_key")
|
|
59
81
|
]
|
|
60
82
|
|
|
61
|
-
@staticmethod
|
|
62
|
-
def is_authenticated() -> bool:
|
|
63
|
-
"""
|
|
64
|
-
Check if you are authenticated to interact with the Orca API
|
|
65
|
-
|
|
66
|
-
Returns:
|
|
67
|
-
True if you are authenticated, False otherwise
|
|
68
|
-
"""
|
|
69
|
-
try:
|
|
70
|
-
return orca_api.GET("/auth")
|
|
71
|
-
except ValueError as e:
|
|
72
|
-
if "Invalid API key" in str(e):
|
|
73
|
-
return False
|
|
74
|
-
raise e
|
|
75
|
-
|
|
76
83
|
@staticmethod
|
|
77
84
|
def create_api_key(name: str, scopes: set[Scope] = {"ADMINISTER"}) -> str:
|
|
78
85
|
"""
|
|
@@ -104,20 +111,6 @@ class OrcaCredentials:
|
|
|
104
111
|
"""
|
|
105
112
|
orca_api.DELETE("/auth/api_key/{name_or_id}", params={"name_or_id": name})
|
|
106
113
|
|
|
107
|
-
@staticmethod
|
|
108
|
-
def set_headers(headers: dict[str, str]):
|
|
109
|
-
"""
|
|
110
|
-
Add or override default HTTP headers for all Orca API requests.
|
|
111
|
-
|
|
112
|
-
Params:
|
|
113
|
-
headers: Mapping of header names to their string values
|
|
114
|
-
|
|
115
|
-
Notes:
|
|
116
|
-
New keys are merged into the existing headers, this will overwrite headers with the
|
|
117
|
-
same name, but leave other headers untouched.
|
|
118
|
-
"""
|
|
119
|
-
orca_api.headers.update(Headers(headers))
|
|
120
|
-
|
|
121
114
|
@staticmethod
|
|
122
115
|
def set_api_key(api_key: str, check_validity: bool = True):
|
|
123
116
|
"""
|
|
@@ -133,17 +126,24 @@ class OrcaCredentials:
|
|
|
133
126
|
Raises:
|
|
134
127
|
ValueError: if the API key is invalid and `check_validity` is True
|
|
135
128
|
"""
|
|
136
|
-
OrcaCredentials.
|
|
129
|
+
OrcaCredentials.set_api_headers({"Api-Key": api_key})
|
|
137
130
|
if check_validity:
|
|
138
131
|
orca_api.GET("/auth")
|
|
139
132
|
|
|
140
133
|
@staticmethod
|
|
141
|
-
def
|
|
134
|
+
def get_api_url() -> str:
|
|
135
|
+
"""
|
|
136
|
+
Get the base URL of the Orca API that is currently being used
|
|
137
|
+
"""
|
|
138
|
+
return str(orca_api.base_url)
|
|
139
|
+
|
|
140
|
+
@staticmethod
|
|
141
|
+
def set_api_url(url: str, check_validity: bool = True):
|
|
142
142
|
"""
|
|
143
143
|
Set the base URL for the Orca API
|
|
144
144
|
|
|
145
145
|
Args:
|
|
146
|
-
|
|
146
|
+
url: The base URL to set
|
|
147
147
|
check_validity: Whether to check if there is an API running at the given base URL
|
|
148
148
|
|
|
149
149
|
Raises:
|
|
@@ -152,27 +152,26 @@ class OrcaCredentials:
|
|
|
152
152
|
# check if the base url is reachable before setting it
|
|
153
153
|
if check_validity:
|
|
154
154
|
try:
|
|
155
|
-
httpx.get(
|
|
155
|
+
httpx.get(url, timeout=1)
|
|
156
156
|
except ConnectError as e:
|
|
157
|
-
raise ValueError(f"No API found at {
|
|
157
|
+
raise ValueError(f"No API found at {url}") from e
|
|
158
158
|
|
|
159
|
-
orca_api.base_url =
|
|
159
|
+
orca_api.base_url = url
|
|
160
160
|
|
|
161
161
|
# check if the api passes the health check
|
|
162
162
|
if check_validity:
|
|
163
|
-
|
|
163
|
+
OrcaCredentials.is_healthy()
|
|
164
164
|
|
|
165
165
|
@staticmethod
|
|
166
|
-
def
|
|
166
|
+
def set_api_headers(headers: dict[str, str]):
|
|
167
167
|
"""
|
|
168
|
-
|
|
168
|
+
Add or override default HTTP headers for all Orca API requests.
|
|
169
169
|
|
|
170
|
-
|
|
171
|
-
|
|
170
|
+
Params:
|
|
171
|
+
headers: Mapping of header names to their string values
|
|
172
|
+
|
|
173
|
+
Notes:
|
|
174
|
+
New keys are merged into the existing headers, this will overwrite headers with the
|
|
175
|
+
same name, but leave other headers untouched.
|
|
172
176
|
"""
|
|
173
|
-
|
|
174
|
-
orca_api.GET("/")
|
|
175
|
-
orca_api.GET("/gpu/")
|
|
176
|
-
except Exception:
|
|
177
|
-
return False
|
|
178
|
-
return True
|
|
177
|
+
orca_api.headers.update(Headers(headers))
|
orca_sdk/credentials_test.py
CHANGED
|
@@ -38,20 +38,20 @@ def test_set_invalid_api_key(api_key):
|
|
|
38
38
|
assert not OrcaCredentials.is_authenticated()
|
|
39
39
|
|
|
40
40
|
|
|
41
|
-
def
|
|
42
|
-
OrcaCredentials.
|
|
41
|
+
def test_set_api_url(api_url_reset):
|
|
42
|
+
OrcaCredentials.set_api_url("http://api.orcadb.ai")
|
|
43
43
|
assert str(orca_api.base_url) == "http://api.orcadb.ai"
|
|
44
44
|
|
|
45
45
|
|
|
46
46
|
def test_set_invalid_base_url():
|
|
47
47
|
with pytest.raises(ValueError, match="No API found at http://localhost:1582"):
|
|
48
|
-
OrcaCredentials.
|
|
48
|
+
OrcaCredentials.set_api_url("http://localhost:1582")
|
|
49
49
|
|
|
50
50
|
|
|
51
51
|
def test_is_healthy():
|
|
52
52
|
assert OrcaCredentials.is_healthy()
|
|
53
53
|
|
|
54
54
|
|
|
55
|
-
def test_is_healthy_false(
|
|
56
|
-
OrcaCredentials.
|
|
55
|
+
def test_is_healthy_false(api_url_reset):
|
|
56
|
+
OrcaCredentials.set_api_url("http://localhost:1582", check_validity=False)
|
|
57
57
|
assert not OrcaCredentials.is_healthy()
|
orca_sdk/datasource.py
CHANGED
|
@@ -499,7 +499,7 @@ class Datasource:
|
|
|
499
499
|
with open(output_path, "wb") as download_file:
|
|
500
500
|
with orca_api.stream("GET", f"/datasource/{self.id}/download", params={"file_type": file_type}) as response:
|
|
501
501
|
total_chunks = int(response.headers["X-Total-Chunks"]) if "X-Total-Chunks" in response.headers else None
|
|
502
|
-
with tqdm(desc=
|
|
502
|
+
with tqdm(desc="Downloading", total=total_chunks, disable=total_chunks is None) as progress:
|
|
503
503
|
for chunk in response.iter_bytes():
|
|
504
504
|
download_file.write(chunk)
|
|
505
505
|
progress.update(1)
|
orca_sdk/datasource_test.py
CHANGED
|
@@ -329,4 +329,9 @@ def test_download_datasource(hf_dataset, datasource):
|
|
|
329
329
|
dataset_from_downloaded_csv.remove_columns("score").to_dict()
|
|
330
330
|
== hf_dataset.remove_columns("score").to_dict()
|
|
331
331
|
)
|
|
332
|
-
|
|
332
|
+
# Replace None with NaN for comparison
|
|
333
|
+
assert np.allclose(
|
|
334
|
+
np.array([np.nan if v is None else float(v) for v in dataset_from_downloaded_csv["score"]], dtype=float),
|
|
335
|
+
np.array([np.nan if v is None else float(v) for v in hf_dataset["score"]], dtype=float),
|
|
336
|
+
equal_nan=True,
|
|
337
|
+
)
|
orca_sdk/embedding_model.py
CHANGED
|
@@ -231,7 +231,34 @@ class EmbeddingModelBase(ABC):
|
|
|
231
231
|
else:
|
|
232
232
|
raise ValueError("Invalid embedding model")
|
|
233
233
|
assert res is not None
|
|
234
|
-
return
|
|
234
|
+
return (
|
|
235
|
+
RegressionMetrics(
|
|
236
|
+
coverage=res.get("coverage"),
|
|
237
|
+
mse=res.get("mse"),
|
|
238
|
+
rmse=res.get("rmse"),
|
|
239
|
+
mae=res.get("mae"),
|
|
240
|
+
r2=res.get("r2"),
|
|
241
|
+
explained_variance=res.get("explained_variance"),
|
|
242
|
+
loss=res.get("loss"),
|
|
243
|
+
anomaly_score_mean=res.get("anomaly_score_mean"),
|
|
244
|
+
anomaly_score_median=res.get("anomaly_score_median"),
|
|
245
|
+
anomaly_score_variance=res.get("anomaly_score_variance"),
|
|
246
|
+
)
|
|
247
|
+
if "mse" in res
|
|
248
|
+
else ClassificationMetrics(
|
|
249
|
+
coverage=res.get("coverage"),
|
|
250
|
+
f1_score=res.get("f1_score"),
|
|
251
|
+
accuracy=res.get("accuracy"),
|
|
252
|
+
loss=res.get("loss"),
|
|
253
|
+
anomaly_score_mean=res.get("anomaly_score_mean"),
|
|
254
|
+
anomaly_score_median=res.get("anomaly_score_median"),
|
|
255
|
+
anomaly_score_variance=res.get("anomaly_score_variance"),
|
|
256
|
+
roc_auc=res.get("roc_auc"),
|
|
257
|
+
pr_auc=res.get("pr_auc"),
|
|
258
|
+
pr_curve=res.get("pr_curve"),
|
|
259
|
+
roc_curve=res.get("roc_curve"),
|
|
260
|
+
)
|
|
261
|
+
)
|
|
235
262
|
|
|
236
263
|
job = Job(response["task_id"], lambda: get_result(response["task_id"]))
|
|
237
264
|
return job if background else job.result()
|
orca_sdk/job.py
CHANGED
|
@@ -7,7 +7,7 @@ from typing import Callable, Generic, TypedDict, TypeVar, cast
|
|
|
7
7
|
|
|
8
8
|
from tqdm.auto import tqdm
|
|
9
9
|
|
|
10
|
-
from .client import
|
|
10
|
+
from .client import orca_api
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class JobConfig(TypedDict):
|
|
@@ -26,6 +26,9 @@ class Status(Enum):
|
|
|
26
26
|
DISPATCHED = "DISPATCHED"
|
|
27
27
|
"""The job has been queued and is waiting to be processed"""
|
|
28
28
|
|
|
29
|
+
WAITING = "WAITING"
|
|
30
|
+
"""The job is waiting for dependencies to complete"""
|
|
31
|
+
|
|
29
32
|
PROCESSING = "PROCESSING"
|
|
30
33
|
"""The job is being processed"""
|
|
31
34
|
|
orca_sdk/job_test.py
CHANGED
|
@@ -1,10 +1,20 @@
|
|
|
1
1
|
import time
|
|
2
2
|
|
|
3
|
+
import pytest
|
|
4
|
+
from datasets import Dataset
|
|
5
|
+
|
|
3
6
|
from .classification_model import ClassificationModel
|
|
4
7
|
from .datasource import Datasource
|
|
5
8
|
from .job import Job, Status
|
|
6
9
|
|
|
7
10
|
|
|
11
|
+
@pytest.fixture(scope="session")
|
|
12
|
+
def datasource_without_nones(hf_dataset: Dataset):
|
|
13
|
+
return Datasource.from_hf_dataset(
|
|
14
|
+
"test_datasource_without_nones", hf_dataset.filter(lambda x: x["label"] is not None)
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
8
18
|
def wait_for_jobs_status(job_ids, expected_statuses, timeout=10, poll_interval=0.2):
|
|
9
19
|
"""
|
|
10
20
|
Wait until all jobs reach one of the expected statuses or timeout is reached.
|
|
@@ -18,8 +28,8 @@ def wait_for_jobs_status(job_ids, expected_statuses, timeout=10, poll_interval=0
|
|
|
18
28
|
raise TimeoutError(f"Jobs did not reach statuses {expected_statuses} within {timeout} seconds")
|
|
19
29
|
|
|
20
30
|
|
|
21
|
-
def test_job_creation(classification_model: ClassificationModel,
|
|
22
|
-
job = classification_model.evaluate(
|
|
31
|
+
def test_job_creation(classification_model: ClassificationModel, datasource_without_nones: Datasource):
|
|
32
|
+
job = classification_model.evaluate(datasource_without_nones, background=True)
|
|
23
33
|
assert job.id is not None
|
|
24
34
|
assert job.type == "EVALUATE_MODEL"
|
|
25
35
|
assert job.status in [Status.DISPATCHED, Status.PROCESSING]
|
|
@@ -29,8 +39,8 @@ def test_job_creation(classification_model: ClassificationModel, datasource: Dat
|
|
|
29
39
|
assert len(Job.query(limit=5, type="EVALUATE_MODEL")) >= 1
|
|
30
40
|
|
|
31
41
|
|
|
32
|
-
def test_job_result(classification_model: ClassificationModel,
|
|
33
|
-
job = classification_model.evaluate(
|
|
42
|
+
def test_job_result(classification_model: ClassificationModel, datasource_without_nones: Datasource):
|
|
43
|
+
job = classification_model.evaluate(datasource_without_nones, background=True)
|
|
34
44
|
result = job.result(show_progress=False)
|
|
35
45
|
assert result is not None
|
|
36
46
|
assert job.status == Status.COMPLETED
|
|
@@ -38,8 +48,8 @@ def test_job_result(classification_model: ClassificationModel, datasource: Datas
|
|
|
38
48
|
assert job.steps_completed == job.steps_total
|
|
39
49
|
|
|
40
50
|
|
|
41
|
-
def test_job_wait(classification_model: ClassificationModel,
|
|
42
|
-
job = classification_model.evaluate(
|
|
51
|
+
def test_job_wait(classification_model: ClassificationModel, datasource_without_nones: Datasource):
|
|
52
|
+
job = classification_model.evaluate(datasource_without_nones, background=True)
|
|
43
53
|
job.wait(show_progress=False)
|
|
44
54
|
assert job.status == Status.COMPLETED
|
|
45
55
|
assert job.steps_completed is not None
|
|
@@ -47,8 +57,8 @@ def test_job_wait(classification_model: ClassificationModel, datasource: Datasou
|
|
|
47
57
|
assert job.value is not None
|
|
48
58
|
|
|
49
59
|
|
|
50
|
-
def test_job_refresh(classification_model: ClassificationModel,
|
|
51
|
-
job = classification_model.evaluate(
|
|
60
|
+
def test_job_refresh(classification_model: ClassificationModel, datasource_without_nones: Datasource):
|
|
61
|
+
job = classification_model.evaluate(datasource_without_nones, background=True)
|
|
52
62
|
last_refreshed_at = job.refreshed_at
|
|
53
63
|
# accessing the status attribute should refresh the job after the refresh interval
|
|
54
64
|
Job.set_config(refresh_interval=1)
|
|
@@ -61,12 +71,12 @@ def test_job_refresh(classification_model: ClassificationModel, datasource: Data
|
|
|
61
71
|
assert job.refreshed_at > last_refreshed_at
|
|
62
72
|
|
|
63
73
|
|
|
64
|
-
def test_job_query_pagination(classification_model: ClassificationModel,
|
|
74
|
+
def test_job_query_pagination(classification_model: ClassificationModel, datasource_without_nones: Datasource):
|
|
65
75
|
"""Test pagination with Job.query() method"""
|
|
66
76
|
# Create multiple jobs to test pagination
|
|
67
77
|
jobs_created = []
|
|
68
78
|
for i in range(3):
|
|
69
|
-
job = classification_model.evaluate(
|
|
79
|
+
job = classification_model.evaluate(datasource_without_nones, background=True)
|
|
70
80
|
jobs_created.append(job.id)
|
|
71
81
|
|
|
72
82
|
# Wait for jobs to be at least PROCESSING or COMPLETED
|
orca_sdk/memoryset.py
CHANGED
|
@@ -21,7 +21,9 @@ from .client import (
|
|
|
21
21
|
FilterItem,
|
|
22
22
|
)
|
|
23
23
|
from .client import LabeledMemory as LabeledMemoryResponse
|
|
24
|
-
from .client import
|
|
24
|
+
from .client import (
|
|
25
|
+
LabeledMemoryInsert,
|
|
26
|
+
)
|
|
25
27
|
from .client import LabeledMemoryLookup as LabeledMemoryLookupResponse
|
|
26
28
|
from .client import (
|
|
27
29
|
LabeledMemoryUpdate,
|
|
@@ -35,7 +37,9 @@ from .client import (
|
|
|
35
37
|
MemoryType,
|
|
36
38
|
)
|
|
37
39
|
from .client import ScoredMemory as ScoredMemoryResponse
|
|
38
|
-
from .client import
|
|
40
|
+
from .client import (
|
|
41
|
+
ScoredMemoryInsert,
|
|
42
|
+
)
|
|
39
43
|
from .client import ScoredMemoryLookup as ScoredMemoryLookupResponse
|
|
40
44
|
from .client import (
|
|
41
45
|
ScoredMemoryUpdate,
|
|
@@ -177,7 +181,7 @@ def _parse_memory_insert(memory: dict[str, Any], type: MemoryType) -> LabeledMem
|
|
|
177
181
|
match type:
|
|
178
182
|
case "LABELED":
|
|
179
183
|
label = memory.get("label")
|
|
180
|
-
if not isinstance(label, int):
|
|
184
|
+
if label is not None and not isinstance(label, int):
|
|
181
185
|
raise ValueError("Memory label must be an integer")
|
|
182
186
|
metadata = {k: v for k, v in memory.items() if k not in DEFAULT_COLUMN_NAMES | {"label"}}
|
|
183
187
|
if any(k in metadata for k in FORBIDDEN_METADATA_COLUMN_NAMES):
|
|
@@ -187,7 +191,7 @@ def _parse_memory_insert(memory: dict[str, Any], type: MemoryType) -> LabeledMem
|
|
|
187
191
|
return {"value": value, "label": label, "source_id": source_id, "metadata": metadata}
|
|
188
192
|
case "SCORED":
|
|
189
193
|
score = memory.get("score")
|
|
190
|
-
if not isinstance(score, (int, float)):
|
|
194
|
+
if score is not None and not isinstance(score, (int, float)):
|
|
191
195
|
raise ValueError("Memory score must be a number")
|
|
192
196
|
metadata = {k: v for k, v in memory.items() if k not in DEFAULT_COLUMN_NAMES | {"score"}}
|
|
193
197
|
if any(k in metadata for k in FORBIDDEN_METADATA_COLUMN_NAMES):
|
|
@@ -288,27 +292,13 @@ class MemoryBase(ABC):
|
|
|
288
292
|
raise AttributeError(f"{key} is not a valid attribute")
|
|
289
293
|
return self.metadata[key]
|
|
290
294
|
|
|
291
|
-
def
|
|
295
|
+
def _update(
|
|
292
296
|
self,
|
|
293
297
|
*,
|
|
294
298
|
value: str = UNSET,
|
|
295
299
|
source_id: str | None = UNSET,
|
|
296
300
|
**metadata: None | bool | float | int | str,
|
|
297
301
|
) -> Self:
|
|
298
|
-
"""
|
|
299
|
-
Update the memory with new values
|
|
300
|
-
|
|
301
|
-
Note:
|
|
302
|
-
If a field is not provided, it will default to [UNSET][orca_sdk.UNSET] and not be updated.
|
|
303
|
-
|
|
304
|
-
Params:
|
|
305
|
-
value: New value of the memory
|
|
306
|
-
source_id: New source ID of the memory
|
|
307
|
-
**metadata: New values for metadata properties
|
|
308
|
-
|
|
309
|
-
Returns:
|
|
310
|
-
The updated memory
|
|
311
|
-
"""
|
|
312
302
|
response = orca_api.PATCH(
|
|
313
303
|
"/gpu/memoryset/{name_or_id}/memory",
|
|
314
304
|
params={"name_or_id": self.memoryset_id},
|
|
@@ -368,7 +358,7 @@ class LabeledMemory(MemoryBase):
|
|
|
368
358
|
* **`...`** (<code>[str][str] | [float][float] | [int][int] | [bool][bool] | None</code>): All metadata properties can be accessed as attributes
|
|
369
359
|
"""
|
|
370
360
|
|
|
371
|
-
label: int
|
|
361
|
+
label: int | None
|
|
372
362
|
label_name: str | None
|
|
373
363
|
memory_type = "LABELED"
|
|
374
364
|
|
|
@@ -403,7 +393,7 @@ class LabeledMemory(MemoryBase):
|
|
|
403
393
|
self,
|
|
404
394
|
*,
|
|
405
395
|
value: str = UNSET,
|
|
406
|
-
label: int = UNSET,
|
|
396
|
+
label: int | None = UNSET,
|
|
407
397
|
source_id: str | None = UNSET,
|
|
408
398
|
**metadata: None | bool | float | int | str,
|
|
409
399
|
) -> LabeledMemory:
|
|
@@ -422,7 +412,7 @@ class LabeledMemory(MemoryBase):
|
|
|
422
412
|
Returns:
|
|
423
413
|
The updated memory
|
|
424
414
|
"""
|
|
425
|
-
|
|
415
|
+
self._update(value=value, label=label, source_id=source_id, **metadata)
|
|
426
416
|
return self
|
|
427
417
|
|
|
428
418
|
def to_dict(self) -> dict[str, Any]:
|
|
@@ -507,7 +497,7 @@ class ScoredMemory(MemoryBase):
|
|
|
507
497
|
* **`...`** (<code>[str][str] | [float][float] | [int][int] | [bool][bool] | None</code>): All metadata properties can be accessed as attributes
|
|
508
498
|
"""
|
|
509
499
|
|
|
510
|
-
score: float
|
|
500
|
+
score: float | None
|
|
511
501
|
memory_type = "SCORED"
|
|
512
502
|
|
|
513
503
|
def __init__(
|
|
@@ -540,7 +530,7 @@ class ScoredMemory(MemoryBase):
|
|
|
540
530
|
self,
|
|
541
531
|
*,
|
|
542
532
|
value: str = UNSET,
|
|
543
|
-
score: float = UNSET,
|
|
533
|
+
score: float | None = UNSET,
|
|
544
534
|
source_id: str | None = UNSET,
|
|
545
535
|
**metadata: None | bool | float | int | str,
|
|
546
536
|
) -> ScoredMemory:
|
|
@@ -559,7 +549,7 @@ class ScoredMemory(MemoryBase):
|
|
|
559
549
|
Returns:
|
|
560
550
|
The updated memory
|
|
561
551
|
"""
|
|
562
|
-
|
|
552
|
+
self._update(value=value, score=score, source_id=source_id, **metadata)
|
|
563
553
|
return self
|
|
564
554
|
|
|
565
555
|
def to_dict(self) -> dict[str, Any]:
|
|
@@ -645,6 +635,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
645
635
|
embedding_model: EmbeddingModelBase
|
|
646
636
|
index_type: IndexType
|
|
647
637
|
index_params: dict[str, Any]
|
|
638
|
+
hidden: bool
|
|
648
639
|
|
|
649
640
|
def __init__(self, metadata: MemorysetMetadata):
|
|
650
641
|
# for internal use only, do not document
|
|
@@ -665,6 +656,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
665
656
|
self.index_type = metadata["index_type"]
|
|
666
657
|
self.index_params = metadata["index_params"]
|
|
667
658
|
self.memory_type = metadata["memory_type"]
|
|
659
|
+
self.hidden = metadata["hidden"]
|
|
668
660
|
|
|
669
661
|
def __eq__(self, other) -> bool:
|
|
670
662
|
return isinstance(other, MemorysetBase) and self.id == other.id
|
|
@@ -699,6 +691,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
699
691
|
index_params: dict[str, Any] = {},
|
|
700
692
|
if_exists: CreateMode = "error",
|
|
701
693
|
background: Literal[True],
|
|
694
|
+
hidden: bool = False,
|
|
702
695
|
) -> Job[Self]:
|
|
703
696
|
pass
|
|
704
697
|
|
|
@@ -723,6 +716,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
723
716
|
index_params: dict[str, Any] = {},
|
|
724
717
|
if_exists: CreateMode = "error",
|
|
725
718
|
background: Literal[False] = False,
|
|
719
|
+
hidden: bool = False,
|
|
726
720
|
) -> Self:
|
|
727
721
|
pass
|
|
728
722
|
|
|
@@ -746,6 +740,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
746
740
|
index_params: dict[str, Any] = {},
|
|
747
741
|
if_exists: CreateMode = "error",
|
|
748
742
|
background: bool = False,
|
|
743
|
+
hidden: bool = False,
|
|
749
744
|
) -> Self | Job[Self]:
|
|
750
745
|
"""
|
|
751
746
|
Create a new memoryset in the OrcaCloud
|
|
@@ -783,6 +778,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
783
778
|
if_exists: What to do if a memoryset with the same name already exists, defaults to
|
|
784
779
|
`"error"`. Other option is `"open"` to open the existing memoryset.
|
|
785
780
|
background: Whether to run the operation non-blocking and return a job handle
|
|
781
|
+
hidden: Whether the memoryset should be hidden
|
|
786
782
|
|
|
787
783
|
Returns:
|
|
788
784
|
Handle to the new memoryset in the OrcaCloud
|
|
@@ -820,6 +816,7 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
820
816
|
"remove_duplicates": remove_duplicates,
|
|
821
817
|
"index_type": index_type,
|
|
822
818
|
"index_params": index_params,
|
|
819
|
+
"hidden": hidden,
|
|
823
820
|
}
|
|
824
821
|
if prompt is not None:
|
|
825
822
|
payload["prompt"] = prompt
|
|
@@ -1272,14 +1269,20 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
1272
1269
|
return False
|
|
1273
1270
|
|
|
1274
1271
|
@classmethod
|
|
1275
|
-
def all(cls) -> list[Self]:
|
|
1272
|
+
def all(cls, show_hidden: bool = False) -> list[Self]:
|
|
1276
1273
|
"""
|
|
1277
1274
|
Get a list of handles to all memorysets in the OrcaCloud
|
|
1278
1275
|
|
|
1276
|
+
Params:
|
|
1277
|
+
show_hidden: Whether to include hidden memorysets in results, defaults to `False`
|
|
1278
|
+
|
|
1279
1279
|
Returns:
|
|
1280
1280
|
List of handles to all memorysets in the OrcaCloud
|
|
1281
1281
|
"""
|
|
1282
|
-
return [
|
|
1282
|
+
return [
|
|
1283
|
+
cls(metadata)
|
|
1284
|
+
for metadata in orca_api.GET("/memoryset", params={"type": cls.memory_type, "show_hidden": show_hidden})
|
|
1285
|
+
]
|
|
1283
1286
|
|
|
1284
1287
|
@classmethod
|
|
1285
1288
|
def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
|
|
@@ -1301,7 +1304,14 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
1301
1304
|
if if_not_exists == "error":
|
|
1302
1305
|
raise
|
|
1303
1306
|
|
|
1304
|
-
def set(
|
|
1307
|
+
def set(
|
|
1308
|
+
self,
|
|
1309
|
+
*,
|
|
1310
|
+
name: str = UNSET,
|
|
1311
|
+
description: str | None = UNSET,
|
|
1312
|
+
label_names: list[str] = UNSET,
|
|
1313
|
+
hidden: bool = UNSET,
|
|
1314
|
+
):
|
|
1305
1315
|
"""
|
|
1306
1316
|
Update editable attributes of the memoryset
|
|
1307
1317
|
|
|
@@ -1320,6 +1330,8 @@ class MemorysetBase(Generic[MemoryT, MemoryLookupT], ABC):
|
|
|
1320
1330
|
payload["description"] = description
|
|
1321
1331
|
if label_names is not UNSET:
|
|
1322
1332
|
payload["label_names"] = label_names
|
|
1333
|
+
if hidden is not UNSET:
|
|
1334
|
+
payload["hidden"] = hidden
|
|
1323
1335
|
|
|
1324
1336
|
orca_api.PATCH("/memoryset/{name_or_id}", params={"name_or_id": self.id}, json=payload)
|
|
1325
1337
|
self.refresh()
|
orca_sdk/memoryset_test.py
CHANGED
|
@@ -122,6 +122,26 @@ def test_all_memorysets(readonly_memoryset: LabeledMemoryset):
|
|
|
122
122
|
assert any(memoryset.name == readonly_memoryset.name for memoryset in memorysets)
|
|
123
123
|
|
|
124
124
|
|
|
125
|
+
def test_all_memorysets_hidden(
|
|
126
|
+
readonly_memoryset: LabeledMemoryset,
|
|
127
|
+
):
|
|
128
|
+
# Create a hidden memoryset
|
|
129
|
+
hidden_memoryset = LabeledMemoryset.clone(readonly_memoryset, "test_hidden_memoryset")
|
|
130
|
+
hidden_memoryset.set(hidden=True)
|
|
131
|
+
|
|
132
|
+
# Test that show_hidden=False excludes hidden memorysets
|
|
133
|
+
visible_memorysets = LabeledMemoryset.all(show_hidden=False)
|
|
134
|
+
assert len(visible_memorysets) > 0
|
|
135
|
+
assert readonly_memoryset in visible_memorysets
|
|
136
|
+
assert hidden_memoryset not in visible_memorysets
|
|
137
|
+
|
|
138
|
+
# Test that show_hidden=True includes hidden memorysets
|
|
139
|
+
all_memorysets = LabeledMemoryset.all(show_hidden=True)
|
|
140
|
+
assert len(all_memorysets) == len(visible_memorysets) + 1
|
|
141
|
+
assert readonly_memoryset in all_memorysets
|
|
142
|
+
assert hidden_memoryset in all_memorysets
|
|
143
|
+
|
|
144
|
+
|
|
125
145
|
def test_all_memorysets_unauthenticated(unauthenticated):
|
|
126
146
|
with pytest.raises(ValueError, match="Invalid API key"):
|
|
127
147
|
LabeledMemoryset.all()
|
|
@@ -167,6 +187,9 @@ def test_update_memoryset_attributes(writable_memoryset: LabeledMemoryset):
|
|
|
167
187
|
writable_memoryset.set(label_names=["New label 1", "New label 2"])
|
|
168
188
|
assert writable_memoryset.label_names == ["New label 1", "New label 2"]
|
|
169
189
|
|
|
190
|
+
writable_memoryset.set(hidden=True)
|
|
191
|
+
assert writable_memoryset.hidden is True
|
|
192
|
+
|
|
170
193
|
|
|
171
194
|
def test_search(readonly_memoryset: LabeledMemoryset):
|
|
172
195
|
memory_lookups = readonly_memoryset.search(["i love soup", "cats are cute"])
|
|
@@ -364,7 +387,7 @@ def test_clone_memoryset(readonly_memoryset: LabeledMemoryset):
|
|
|
364
387
|
|
|
365
388
|
def test_embedding_evaluation(eval_datasource: Datasource):
|
|
366
389
|
results = LabeledMemoryset.run_embedding_evaluation(
|
|
367
|
-
eval_datasource, embedding_models=["CDE_SMALL"], neighbor_count=
|
|
390
|
+
eval_datasource, embedding_models=["CDE_SMALL"], neighbor_count=3
|
|
368
391
|
)
|
|
369
392
|
assert isinstance(results, list)
|
|
370
393
|
assert len(results) == 1
|
|
@@ -465,13 +488,14 @@ def test_drop_memoryset(writable_memoryset: LabeledMemoryset):
|
|
|
465
488
|
|
|
466
489
|
|
|
467
490
|
def test_scored_memoryset(scored_memoryset: ScoredMemoryset):
|
|
468
|
-
assert scored_memoryset.length ==
|
|
491
|
+
assert scored_memoryset.length == 22
|
|
469
492
|
assert isinstance(scored_memoryset[0], ScoredMemory)
|
|
470
493
|
assert scored_memoryset[0].value == "i love soup"
|
|
471
494
|
assert scored_memoryset[0].score is not None
|
|
472
495
|
assert scored_memoryset[0].metadata == {"key": "g1", "source_id": "s1", "label": 0}
|
|
473
496
|
lookup = scored_memoryset.search("i love soup", count=1)
|
|
474
497
|
assert len(lookup) == 1
|
|
498
|
+
assert lookup[0].score is not None
|
|
475
499
|
assert lookup[0].score < 0.11
|
|
476
500
|
|
|
477
501
|
|