orca-sdk 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +1 -1
- orca_sdk/_utils/auth.py +12 -8
- orca_sdk/async_client.py +3795 -0
- orca_sdk/classification_model.py +176 -14
- orca_sdk/classification_model_test.py +96 -28
- orca_sdk/client.py +515 -475
- orca_sdk/conftest.py +37 -36
- orca_sdk/credentials.py +54 -14
- orca_sdk/credentials_test.py +92 -28
- orca_sdk/datasource.py +19 -10
- orca_sdk/datasource_test.py +24 -18
- orca_sdk/embedding_model.py +22 -13
- orca_sdk/embedding_model_test.py +27 -20
- orca_sdk/job.py +14 -8
- orca_sdk/memoryset.py +513 -183
- orca_sdk/memoryset_test.py +130 -32
- orca_sdk/regression_model.py +21 -11
- orca_sdk/regression_model_test.py +35 -26
- orca_sdk/telemetry.py +24 -13
- {orca_sdk-0.1.2.dist-info → orca_sdk-0.1.3.dist-info}/METADATA +1 -1
- orca_sdk-0.1.3.dist-info/RECORD +41 -0
- orca_sdk-0.1.2.dist-info/RECORD +0 -40
- {orca_sdk-0.1.2.dist-info → orca_sdk-0.1.3.dist-info}/WHEEL +0 -0
orca_sdk/memoryset_test.py
CHANGED
|
@@ -5,7 +5,7 @@ import pytest
|
|
|
5
5
|
from datasets.arrow_dataset import Dataset
|
|
6
6
|
|
|
7
7
|
from .classification_model import ClassificationModel
|
|
8
|
-
from .conftest import skip_in_prod
|
|
8
|
+
from .conftest import skip_in_ci, skip_in_prod
|
|
9
9
|
from .datasource import Datasource
|
|
10
10
|
from .embedding_model import PretrainedEmbeddingModel
|
|
11
11
|
from .memoryset import LabeledMemoryset, ScoredMemory, ScoredMemoryset, Status
|
|
@@ -39,9 +39,10 @@ def test_create_memoryset(readonly_memoryset: LabeledMemoryset, hf_dataset: Data
|
|
|
39
39
|
assert readonly_memoryset.index_params == {"n_lists": 100}
|
|
40
40
|
|
|
41
41
|
|
|
42
|
-
def test_create_memoryset_unauthenticated(
|
|
43
|
-
with
|
|
44
|
-
|
|
42
|
+
def test_create_memoryset_unauthenticated(unauthenticated_client, datasource):
|
|
43
|
+
with unauthenticated_client.use():
|
|
44
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
45
|
+
LabeledMemoryset.create("test_memoryset", datasource)
|
|
45
46
|
|
|
46
47
|
|
|
47
48
|
def test_create_memoryset_invalid_input(datasource):
|
|
@@ -87,6 +88,30 @@ def test_create_memoryset_already_exists_open(hf_dataset, label_names, readonly_
|
|
|
87
88
|
assert opened_memoryset.length == len(hf_dataset)
|
|
88
89
|
|
|
89
90
|
|
|
91
|
+
def test_if_exists_error_no_datasource_creation(
|
|
92
|
+
readonly_memoryset: LabeledMemoryset,
|
|
93
|
+
):
|
|
94
|
+
memoryset_name = readonly_memoryset.name
|
|
95
|
+
datasource_name = f"{memoryset_name}_datasource"
|
|
96
|
+
Datasource.drop(datasource_name, if_not_exists="ignore")
|
|
97
|
+
assert not Datasource.exists(datasource_name)
|
|
98
|
+
with pytest.raises(ValueError):
|
|
99
|
+
LabeledMemoryset.from_list(memoryset_name, [{"value": "new value", "label": 0}], if_exists="error")
|
|
100
|
+
assert not Datasource.exists(datasource_name)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def test_if_exists_open_reuses_existing_datasource(
|
|
104
|
+
readonly_memoryset: LabeledMemoryset,
|
|
105
|
+
):
|
|
106
|
+
memoryset_name = readonly_memoryset.name
|
|
107
|
+
datasource_name = f"{memoryset_name}_datasource"
|
|
108
|
+
Datasource.drop(datasource_name, if_not_exists="ignore")
|
|
109
|
+
assert not Datasource.exists(datasource_name)
|
|
110
|
+
reopened = LabeledMemoryset.from_list(memoryset_name, [{"value": "new value", "label": 0}], if_exists="open")
|
|
111
|
+
assert reopened.id == readonly_memoryset.id
|
|
112
|
+
assert not Datasource.exists(datasource_name)
|
|
113
|
+
|
|
114
|
+
|
|
90
115
|
def test_open_memoryset(readonly_memoryset, hf_dataset):
|
|
91
116
|
fetched_memoryset = LabeledMemoryset.open(readonly_memoryset.name)
|
|
92
117
|
assert fetched_memoryset is not None
|
|
@@ -96,9 +121,10 @@ def test_open_memoryset(readonly_memoryset, hf_dataset):
|
|
|
96
121
|
assert fetched_memoryset.index_params == {"n_lists": 100}
|
|
97
122
|
|
|
98
123
|
|
|
99
|
-
def test_open_memoryset_unauthenticated(
|
|
100
|
-
with
|
|
101
|
-
|
|
124
|
+
def test_open_memoryset_unauthenticated(unauthenticated_client, readonly_memoryset):
|
|
125
|
+
with unauthenticated_client.use():
|
|
126
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
127
|
+
LabeledMemoryset.open(readonly_memoryset.name)
|
|
102
128
|
|
|
103
129
|
|
|
104
130
|
def test_open_memoryset_not_found():
|
|
@@ -111,9 +137,10 @@ def test_open_memoryset_invalid_input():
|
|
|
111
137
|
LabeledMemoryset.open("not valid id")
|
|
112
138
|
|
|
113
139
|
|
|
114
|
-
def test_open_memoryset_unauthorized(
|
|
115
|
-
with
|
|
116
|
-
|
|
140
|
+
def test_open_memoryset_unauthorized(unauthorized_client, readonly_memoryset):
|
|
141
|
+
with unauthorized_client.use():
|
|
142
|
+
with pytest.raises(LookupError):
|
|
143
|
+
LabeledMemoryset.open(readonly_memoryset.name)
|
|
117
144
|
|
|
118
145
|
|
|
119
146
|
def test_all_memorysets(readonly_memoryset: LabeledMemoryset):
|
|
@@ -142,18 +169,21 @@ def test_all_memorysets_hidden(
|
|
|
142
169
|
assert hidden_memoryset in all_memorysets
|
|
143
170
|
|
|
144
171
|
|
|
145
|
-
def test_all_memorysets_unauthenticated(
|
|
146
|
-
with
|
|
147
|
-
|
|
172
|
+
def test_all_memorysets_unauthenticated(unauthenticated_client):
|
|
173
|
+
with unauthenticated_client.use():
|
|
174
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
175
|
+
LabeledMemoryset.all()
|
|
148
176
|
|
|
149
177
|
|
|
150
|
-
def test_all_memorysets_unauthorized(
|
|
151
|
-
|
|
178
|
+
def test_all_memorysets_unauthorized(unauthorized_client, readonly_memoryset):
|
|
179
|
+
with unauthorized_client.use():
|
|
180
|
+
assert readonly_memoryset not in LabeledMemoryset.all()
|
|
152
181
|
|
|
153
182
|
|
|
154
|
-
def test_drop_memoryset_unauthenticated(
|
|
155
|
-
with
|
|
156
|
-
|
|
183
|
+
def test_drop_memoryset_unauthenticated(unauthenticated_client, readonly_memoryset):
|
|
184
|
+
with unauthenticated_client.use():
|
|
185
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
186
|
+
LabeledMemoryset.drop(readonly_memoryset.name)
|
|
157
187
|
|
|
158
188
|
|
|
159
189
|
def test_drop_memoryset_not_found():
|
|
@@ -163,9 +193,10 @@ def test_drop_memoryset_not_found():
|
|
|
163
193
|
LabeledMemoryset.drop(str(uuid4()), if_not_exists="ignore")
|
|
164
194
|
|
|
165
195
|
|
|
166
|
-
def test_drop_memoryset_unauthorized(
|
|
167
|
-
with
|
|
168
|
-
|
|
196
|
+
def test_drop_memoryset_unauthorized(unauthorized_client, readonly_memoryset):
|
|
197
|
+
with unauthorized_client.use():
|
|
198
|
+
with pytest.raises(LookupError):
|
|
199
|
+
LabeledMemoryset.drop(readonly_memoryset.name)
|
|
169
200
|
|
|
170
201
|
|
|
171
202
|
def test_update_memoryset_attributes(writable_memoryset: LabeledMemoryset):
|
|
@@ -327,6 +358,7 @@ def test_insert_memories(writable_memoryset: LabeledMemoryset):
|
|
|
327
358
|
|
|
328
359
|
|
|
329
360
|
@skip_in_prod("Production memorysets do not have session consistency guarantees")
|
|
361
|
+
@skip_in_ci("CI environment may not have session consistency guarantees")
|
|
330
362
|
def test_update_memories(writable_memoryset: LabeledMemoryset, hf_dataset: Dataset):
|
|
331
363
|
# We've combined the update tests into one to avoid multiple expensive requests for a writable_memoryset
|
|
332
364
|
|
|
@@ -385,17 +417,6 @@ def test_clone_memoryset(readonly_memoryset: LabeledMemoryset):
|
|
|
385
417
|
assert cloned_memoryset.insertion_status == Status.COMPLETED
|
|
386
418
|
|
|
387
419
|
|
|
388
|
-
def test_embedding_evaluation(eval_datasource: Datasource):
|
|
389
|
-
results = LabeledMemoryset.run_embedding_evaluation(
|
|
390
|
-
eval_datasource, embedding_models=["CDE_SMALL"], neighbor_count=3
|
|
391
|
-
)
|
|
392
|
-
assert isinstance(results, list)
|
|
393
|
-
assert len(results) == 1
|
|
394
|
-
assert results[0] is not None
|
|
395
|
-
assert results[0]["embedding_model_name"] == "CDE_SMALL"
|
|
396
|
-
assert results[0]["embedding_model_path"] == "OrcaDB/cde-small-v1"
|
|
397
|
-
|
|
398
|
-
|
|
399
420
|
@pytest.fixture(scope="function")
|
|
400
421
|
async def test_group_potential_duplicates(writable_memoryset: LabeledMemoryset):
|
|
401
422
|
writable_memoryset.insert(
|
|
@@ -508,3 +529,80 @@ def test_update_scored_memory(scored_memoryset: ScoredMemoryset):
|
|
|
508
529
|
assert scored_memoryset[0].label == 3
|
|
509
530
|
memory.update(label=4)
|
|
510
531
|
assert scored_memoryset[0].label == 4
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
@pytest.mark.asyncio
|
|
535
|
+
async def test_insert_memories_async_single(writable_memoryset: LabeledMemoryset):
|
|
536
|
+
"""Test async insertion of a single memory"""
|
|
537
|
+
await writable_memoryset.arefresh()
|
|
538
|
+
prev_length = writable_memoryset.length
|
|
539
|
+
|
|
540
|
+
await writable_memoryset.ainsert(dict(value="async tomato soup is my favorite", label=0, key="async_test"))
|
|
541
|
+
|
|
542
|
+
await writable_memoryset.arefresh()
|
|
543
|
+
assert writable_memoryset.length == prev_length + 1
|
|
544
|
+
last_memory = writable_memoryset[-1]
|
|
545
|
+
assert last_memory.value == "async tomato soup is my favorite"
|
|
546
|
+
assert last_memory.label == 0
|
|
547
|
+
assert last_memory.metadata["key"] == "async_test"
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
@pytest.mark.asyncio
|
|
551
|
+
async def test_insert_memories_async_batch(writable_memoryset: LabeledMemoryset):
|
|
552
|
+
"""Test async insertion of multiple memories"""
|
|
553
|
+
await writable_memoryset.arefresh()
|
|
554
|
+
prev_length = writable_memoryset.length
|
|
555
|
+
|
|
556
|
+
await writable_memoryset.ainsert(
|
|
557
|
+
[
|
|
558
|
+
dict(value="async batch soup is delicious", label=0, key="batch_test_1"),
|
|
559
|
+
dict(value="async batch cats are adorable", label=1, key="batch_test_2"),
|
|
560
|
+
]
|
|
561
|
+
)
|
|
562
|
+
|
|
563
|
+
await writable_memoryset.arefresh()
|
|
564
|
+
assert writable_memoryset.length == prev_length + 2
|
|
565
|
+
|
|
566
|
+
# Check the inserted memories
|
|
567
|
+
last_two_memories = writable_memoryset[-2:]
|
|
568
|
+
values = [memory.value for memory in last_two_memories]
|
|
569
|
+
labels = [memory.label for memory in last_two_memories]
|
|
570
|
+
keys = [memory.metadata.get("key") for memory in last_two_memories]
|
|
571
|
+
|
|
572
|
+
assert "async batch soup is delicious" in values
|
|
573
|
+
assert "async batch cats are adorable" in values
|
|
574
|
+
assert 0 in labels
|
|
575
|
+
assert 1 in labels
|
|
576
|
+
assert "batch_test_1" in keys
|
|
577
|
+
assert "batch_test_2" in keys
|
|
578
|
+
|
|
579
|
+
|
|
580
|
+
@pytest.mark.asyncio
|
|
581
|
+
async def test_insert_memories_async_with_source_id(writable_memoryset: LabeledMemoryset):
|
|
582
|
+
"""Test async insertion with source_id and metadata"""
|
|
583
|
+
await writable_memoryset.arefresh()
|
|
584
|
+
prev_length = writable_memoryset.length
|
|
585
|
+
|
|
586
|
+
await writable_memoryset.ainsert(
|
|
587
|
+
dict(
|
|
588
|
+
value="async soup with source id", label=0, source_id="async_source_123", custom_field="async_custom_value"
|
|
589
|
+
)
|
|
590
|
+
)
|
|
591
|
+
|
|
592
|
+
await writable_memoryset.arefresh()
|
|
593
|
+
assert writable_memoryset.length == prev_length + 1
|
|
594
|
+
last_memory = writable_memoryset[-1]
|
|
595
|
+
assert last_memory.value == "async soup with source id"
|
|
596
|
+
assert last_memory.label == 0
|
|
597
|
+
assert last_memory.source_id == "async_source_123"
|
|
598
|
+
assert last_memory.metadata["custom_field"] == "async_custom_value"
|
|
599
|
+
|
|
600
|
+
|
|
601
|
+
@pytest.mark.asyncio
|
|
602
|
+
async def test_insert_memories_async_unauthenticated(
|
|
603
|
+
unauthenticated_async_client, writable_memoryset: LabeledMemoryset
|
|
604
|
+
):
|
|
605
|
+
"""Test async insertion with invalid authentication"""
|
|
606
|
+
with unauthenticated_async_client.use():
|
|
607
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
608
|
+
await writable_memoryset.ainsert(dict(value="this should fail", label=0))
|
orca_sdk/regression_model.py
CHANGED
|
@@ -10,10 +10,10 @@ from datasets import Dataset
|
|
|
10
10
|
from ._shared.metrics import RegressionMetrics, calculate_regression_metrics
|
|
11
11
|
from ._utils.common import UNSET, CreateMode, DropMode
|
|
12
12
|
from .client import (
|
|
13
|
+
OrcaClient,
|
|
13
14
|
PredictiveModelUpdate,
|
|
14
15
|
RARHeadType,
|
|
15
16
|
RegressionModelMetadata,
|
|
16
|
-
orca_api,
|
|
17
17
|
)
|
|
18
18
|
from .datasource import Datasource
|
|
19
19
|
from .job import Job
|
|
@@ -154,7 +154,8 @@ class RegressionModel:
|
|
|
154
154
|
|
|
155
155
|
return existing
|
|
156
156
|
|
|
157
|
-
|
|
157
|
+
client = OrcaClient._resolve_client()
|
|
158
|
+
metadata = client.POST(
|
|
158
159
|
"/regression_model",
|
|
159
160
|
json={
|
|
160
161
|
"name": name,
|
|
@@ -179,7 +180,8 @@ class RegressionModel:
|
|
|
179
180
|
Raises:
|
|
180
181
|
LookupError: If the regression model does not exist
|
|
181
182
|
"""
|
|
182
|
-
|
|
183
|
+
client = OrcaClient._resolve_client()
|
|
184
|
+
return cls(client.GET("/regression_model/{name_or_id}", params={"name_or_id": name}))
|
|
183
185
|
|
|
184
186
|
@classmethod
|
|
185
187
|
def exists(cls, name_or_id: str) -> bool:
|
|
@@ -206,7 +208,8 @@ class RegressionModel:
|
|
|
206
208
|
Returns:
|
|
207
209
|
List of handles to all regression models in the OrcaCloud
|
|
208
210
|
"""
|
|
209
|
-
|
|
211
|
+
client = OrcaClient._resolve_client()
|
|
212
|
+
return [cls(metadata) for metadata in client.GET("/regression_model")]
|
|
210
213
|
|
|
211
214
|
@classmethod
|
|
212
215
|
def drop(cls, name_or_id: str, if_not_exists: DropMode = "error"):
|
|
@@ -225,7 +228,8 @@ class RegressionModel:
|
|
|
225
228
|
LookupError: If the regression model does not exist and if_not_exists is `"error"`
|
|
226
229
|
"""
|
|
227
230
|
try:
|
|
228
|
-
|
|
231
|
+
client = OrcaClient._resolve_client()
|
|
232
|
+
client.DELETE("/regression_model/{name_or_id}", params={"name_or_id": name_or_id})
|
|
229
233
|
logging.info(f"Deleted model {name_or_id}")
|
|
230
234
|
except LookupError:
|
|
231
235
|
if if_not_exists == "error":
|
|
@@ -261,7 +265,8 @@ class RegressionModel:
|
|
|
261
265
|
update["description"] = description
|
|
262
266
|
if locked is not UNSET:
|
|
263
267
|
update["locked"] = locked
|
|
264
|
-
|
|
268
|
+
client = OrcaClient._resolve_client()
|
|
269
|
+
client.PATCH("/regression_model/{name_or_id}", params={"name_or_id": self.id}, json=update)
|
|
265
270
|
self.refresh()
|
|
266
271
|
|
|
267
272
|
def lock(self) -> None:
|
|
@@ -334,7 +339,8 @@ class RegressionModel:
|
|
|
334
339
|
raise ValueError("timeout_seconds must be a positive integer")
|
|
335
340
|
|
|
336
341
|
telemetry_on, telemetry_sync = _get_telemetry_config(save_telemetry)
|
|
337
|
-
|
|
342
|
+
client = OrcaClient._resolve_client()
|
|
343
|
+
response = client.POST(
|
|
338
344
|
"/gpu/regression_model/{name_or_id}/prediction",
|
|
339
345
|
params={"name_or_id": self.id},
|
|
340
346
|
json={
|
|
@@ -409,7 +415,8 @@ class RegressionModel:
|
|
|
409
415
|
>>> predictions = model.predictions(sort=[("confidence", "desc")], offset=1, limit=1)
|
|
410
416
|
[RegressionPrediction({score: 4.2, confidence: 0.90, anomaly_score: 0.1, input_value: 'Good service'})]
|
|
411
417
|
"""
|
|
412
|
-
|
|
418
|
+
client = OrcaClient._resolve_client()
|
|
419
|
+
predictions = client.POST(
|
|
413
420
|
"/telemetry/prediction",
|
|
414
421
|
json={
|
|
415
422
|
"model_id": self.id,
|
|
@@ -446,7 +453,8 @@ class RegressionModel:
|
|
|
446
453
|
tags: set[str] | None,
|
|
447
454
|
background: bool = False,
|
|
448
455
|
) -> RegressionMetrics | Job[RegressionMetrics]:
|
|
449
|
-
|
|
456
|
+
client = OrcaClient._resolve_client()
|
|
457
|
+
response = client.POST(
|
|
450
458
|
"/regression_model/{model_name_or_id}/evaluation",
|
|
451
459
|
params={"model_name_or_id": self.id},
|
|
452
460
|
json={
|
|
@@ -460,7 +468,8 @@ class RegressionModel:
|
|
|
460
468
|
)
|
|
461
469
|
|
|
462
470
|
def get_value():
|
|
463
|
-
|
|
471
|
+
client = OrcaClient._resolve_client()
|
|
472
|
+
res = client.GET(
|
|
464
473
|
"/regression_model/{model_name_or_id}/evaluation/{task_id}",
|
|
465
474
|
params={"model_name_or_id": self.id, "task_id": response["task_id"]},
|
|
466
475
|
)
|
|
@@ -676,7 +685,8 @@ class RegressionModel:
|
|
|
676
685
|
ValueError: If the value does not match previous value types for the category, or is a
|
|
677
686
|
[`float`][float] that is not between `-1.0` and `+1.0`.
|
|
678
687
|
"""
|
|
679
|
-
|
|
688
|
+
client = OrcaClient._resolve_client()
|
|
689
|
+
client.PUT(
|
|
680
690
|
"/telemetry/prediction/feedback",
|
|
681
691
|
json=[
|
|
682
692
|
_parse_feedback(f) for f in (cast(list[dict], [feedback]) if isinstance(feedback, dict) else feedback)
|
|
@@ -36,9 +36,10 @@ def test_create_model_already_exists_return(scored_memoryset, regression_model:
|
|
|
36
36
|
assert new_model.memory_lookup_count == 3
|
|
37
37
|
|
|
38
38
|
|
|
39
|
-
def test_create_model_unauthenticated(
|
|
40
|
-
with
|
|
41
|
-
|
|
39
|
+
def test_create_model_unauthenticated(unauthenticated_client, scored_memoryset: ScoredMemoryset):
|
|
40
|
+
with unauthenticated_client.use():
|
|
41
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
42
|
+
RegressionModel.create("test_regression_model", scored_memoryset)
|
|
42
43
|
|
|
43
44
|
|
|
44
45
|
def test_get_model(regression_model: RegressionModel):
|
|
@@ -50,9 +51,10 @@ def test_get_model(regression_model: RegressionModel):
|
|
|
50
51
|
assert fetched_model == regression_model
|
|
51
52
|
|
|
52
53
|
|
|
53
|
-
def test_get_model_unauthenticated(
|
|
54
|
-
with
|
|
55
|
-
|
|
54
|
+
def test_get_model_unauthenticated(unauthenticated_client):
|
|
55
|
+
with unauthenticated_client.use():
|
|
56
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
57
|
+
RegressionModel.open("test_regression_model")
|
|
56
58
|
|
|
57
59
|
|
|
58
60
|
def test_get_model_invalid_input():
|
|
@@ -65,9 +67,10 @@ def test_get_model_not_found():
|
|
|
65
67
|
RegressionModel.open(str(uuid4()))
|
|
66
68
|
|
|
67
69
|
|
|
68
|
-
def test_get_model_unauthorized(
|
|
69
|
-
with
|
|
70
|
-
|
|
70
|
+
def test_get_model_unauthorized(unauthorized_client, regression_model: RegressionModel):
|
|
71
|
+
with unauthorized_client.use():
|
|
72
|
+
with pytest.raises(LookupError):
|
|
73
|
+
RegressionModel.open(regression_model.name)
|
|
71
74
|
|
|
72
75
|
|
|
73
76
|
def test_list_models(regression_model: RegressionModel):
|
|
@@ -76,13 +79,15 @@ def test_list_models(regression_model: RegressionModel):
|
|
|
76
79
|
assert any(model.name == regression_model.name for model in models)
|
|
77
80
|
|
|
78
81
|
|
|
79
|
-
def test_list_models_unauthenticated(
|
|
80
|
-
with
|
|
81
|
-
|
|
82
|
+
def test_list_models_unauthenticated(unauthenticated_client):
|
|
83
|
+
with unauthenticated_client.use():
|
|
84
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
85
|
+
RegressionModel.all()
|
|
82
86
|
|
|
83
87
|
|
|
84
|
-
def test_list_models_unauthorized(
|
|
85
|
-
|
|
88
|
+
def test_list_models_unauthorized(unauthorized_client, regression_model: RegressionModel):
|
|
89
|
+
with unauthorized_client.use():
|
|
90
|
+
assert RegressionModel.all() == []
|
|
86
91
|
|
|
87
92
|
|
|
88
93
|
def test_update_model_attributes(regression_model: RegressionModel):
|
|
@@ -113,9 +118,10 @@ def test_delete_model(scored_memoryset: ScoredMemoryset):
|
|
|
113
118
|
RegressionModel.open("regression_model_to_delete")
|
|
114
119
|
|
|
115
120
|
|
|
116
|
-
def test_delete_model_unauthenticated(
|
|
117
|
-
with
|
|
118
|
-
|
|
121
|
+
def test_delete_model_unauthenticated(unauthenticated_client, regression_model: RegressionModel):
|
|
122
|
+
with unauthenticated_client.use():
|
|
123
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
124
|
+
RegressionModel.drop(regression_model.name)
|
|
119
125
|
|
|
120
126
|
|
|
121
127
|
def test_delete_model_not_found():
|
|
@@ -125,9 +131,10 @@ def test_delete_model_not_found():
|
|
|
125
131
|
RegressionModel.drop(str(uuid4()), if_not_exists="ignore")
|
|
126
132
|
|
|
127
133
|
|
|
128
|
-
def test_delete_model_unauthorized(
|
|
129
|
-
with
|
|
130
|
-
|
|
134
|
+
def test_delete_model_unauthorized(unauthorized_client, regression_model: RegressionModel):
|
|
135
|
+
with unauthorized_client.use():
|
|
136
|
+
with pytest.raises(LookupError):
|
|
137
|
+
RegressionModel.drop(regression_model.name)
|
|
131
138
|
|
|
132
139
|
|
|
133
140
|
def test_delete_memoryset_before_model_constraint_violation(hf_dataset):
|
|
@@ -204,14 +211,16 @@ def test_regression_prediction_has_no_score(regression_model: RegressionModel):
|
|
|
204
211
|
assert prediction.score is None
|
|
205
212
|
|
|
206
213
|
|
|
207
|
-
def test_predict_unauthenticated(
|
|
208
|
-
with
|
|
209
|
-
|
|
214
|
+
def test_predict_unauthenticated(unauthenticated_client, regression_model: RegressionModel):
|
|
215
|
+
with unauthenticated_client.use():
|
|
216
|
+
with pytest.raises(ValueError, match="Invalid API key"):
|
|
217
|
+
regression_model.predict(["This is excellent!", "This is terrible!"])
|
|
210
218
|
|
|
211
219
|
|
|
212
|
-
def test_predict_unauthorized(
|
|
213
|
-
with
|
|
214
|
-
|
|
220
|
+
def test_predict_unauthorized(unauthorized_client, regression_model: RegressionModel):
|
|
221
|
+
with unauthorized_client.use():
|
|
222
|
+
with pytest.raises(LookupError):
|
|
223
|
+
regression_model.predict(["This is excellent!", "This is terrible!"])
|
|
215
224
|
|
|
216
225
|
|
|
217
226
|
def test_predict_constraint_violation(scored_memoryset: ScoredMemoryset):
|
orca_sdk/telemetry.py
CHANGED
|
@@ -11,11 +11,11 @@ from httpx import Timeout
|
|
|
11
11
|
from ._utils.common import UNSET
|
|
12
12
|
from .client import (
|
|
13
13
|
LabelPredictionWithMemoriesAndFeedback,
|
|
14
|
+
OrcaClient,
|
|
14
15
|
PredictionFeedbackCategory,
|
|
15
16
|
PredictionFeedbackRequest,
|
|
16
17
|
ScorePredictionWithMemoriesAndFeedback,
|
|
17
18
|
UpdatePredictionRequest,
|
|
18
|
-
orca_api,
|
|
19
19
|
)
|
|
20
20
|
from .memoryset import (
|
|
21
21
|
LabeledMemoryLookup,
|
|
@@ -98,7 +98,8 @@ class FeedbackCategory:
|
|
|
98
98
|
Returns:
|
|
99
99
|
List with information about all existing feedback categories.
|
|
100
100
|
"""
|
|
101
|
-
|
|
101
|
+
client = OrcaClient._resolve_client()
|
|
102
|
+
return [FeedbackCategory(category) for category in client.GET("/telemetry/feedback_category")]
|
|
102
103
|
|
|
103
104
|
@classmethod
|
|
104
105
|
def drop(cls, name: str) -> None:
|
|
@@ -115,7 +116,8 @@ class FeedbackCategory:
|
|
|
115
116
|
Raises:
|
|
116
117
|
LookupError: If the category is not found.
|
|
117
118
|
"""
|
|
118
|
-
|
|
119
|
+
client = OrcaClient._resolve_client()
|
|
120
|
+
client.DELETE("/telemetry/feedback_category/{name_or_id}", params={"name_or_id": name})
|
|
119
121
|
logging.info(f"Deleted feedback category {name} with all associated feedback")
|
|
120
122
|
|
|
121
123
|
def __repr__(self):
|
|
@@ -190,7 +192,8 @@ class PredictionBase(ABC):
|
|
|
190
192
|
if self.__telemetry is None:
|
|
191
193
|
if self.prediction_id is None:
|
|
192
194
|
raise ValueError("Cannot fetch telemetry with no prediction ID")
|
|
193
|
-
|
|
195
|
+
client = OrcaClient._resolve_client()
|
|
196
|
+
self.__telemetry = client.GET(
|
|
194
197
|
"/telemetry/prediction/{prediction_id}", params={"prediction_id": self.prediction_id}
|
|
195
198
|
)
|
|
196
199
|
return self.__telemetry
|
|
@@ -229,7 +232,8 @@ class PredictionBase(ABC):
|
|
|
229
232
|
@property
|
|
230
233
|
def explanation(self) -> str:
|
|
231
234
|
if self._telemetry["explanation"] is None:
|
|
232
|
-
|
|
235
|
+
client = OrcaClient._resolve_client()
|
|
236
|
+
self._telemetry["explanation"] = client.GET(
|
|
233
237
|
"/telemetry/prediction/{prediction_id}/explanation",
|
|
234
238
|
params={"prediction_id": self._telemetry["prediction_id"]},
|
|
235
239
|
parse_as="text",
|
|
@@ -247,7 +251,8 @@ class PredictionBase(ABC):
|
|
|
247
251
|
if not refresh and self._telemetry["explanation"] is not None:
|
|
248
252
|
print(self._telemetry["explanation"])
|
|
249
253
|
else:
|
|
250
|
-
|
|
254
|
+
client = OrcaClient._resolve_client()
|
|
255
|
+
with client.stream(
|
|
251
256
|
"GET",
|
|
252
257
|
f"/telemetry/prediction/{self.prediction_id}/explanation?refresh={refresh}",
|
|
253
258
|
timeout=Timeout(connect=3, read=None),
|
|
@@ -341,14 +346,15 @@ class PredictionBase(ABC):
|
|
|
341
346
|
telemetry=prediction,
|
|
342
347
|
)
|
|
343
348
|
|
|
349
|
+
client = OrcaClient._resolve_client()
|
|
344
350
|
if isinstance(prediction_id, str):
|
|
345
351
|
return create_prediction(
|
|
346
|
-
|
|
352
|
+
client.GET("/telemetry/prediction/{prediction_id}", params={"prediction_id": prediction_id})
|
|
347
353
|
)
|
|
348
354
|
else:
|
|
349
355
|
return [
|
|
350
356
|
create_prediction(prediction)
|
|
351
|
-
for prediction in
|
|
357
|
+
for prediction in client.POST("/telemetry/prediction", json={"prediction_ids": list(prediction_id)})
|
|
352
358
|
]
|
|
353
359
|
|
|
354
360
|
def refresh(self):
|
|
@@ -374,7 +380,8 @@ class PredictionBase(ABC):
|
|
|
374
380
|
payload["expected_label"] = expected_label
|
|
375
381
|
if expected_score is not UNSET:
|
|
376
382
|
payload["expected_score"] = expected_score
|
|
377
|
-
|
|
383
|
+
client = OrcaClient._resolve_client()
|
|
384
|
+
client.PATCH(
|
|
378
385
|
"/telemetry/prediction/{prediction_id}", params={"prediction_id": self.prediction_id}, json=payload
|
|
379
386
|
)
|
|
380
387
|
self.refresh()
|
|
@@ -431,7 +438,8 @@ class PredictionBase(ABC):
|
|
|
431
438
|
ValueError: If the value does not match previous value types for the category, or is a
|
|
432
439
|
[`float`][float] that is not between `-1.0` and `+1.0`.
|
|
433
440
|
"""
|
|
434
|
-
|
|
441
|
+
client = OrcaClient._resolve_client()
|
|
442
|
+
client.PUT(
|
|
435
443
|
"/telemetry/prediction/feedback",
|
|
436
444
|
json=[
|
|
437
445
|
_parse_feedback(
|
|
@@ -454,7 +462,8 @@ class PredictionBase(ABC):
|
|
|
454
462
|
if self.prediction_id is None:
|
|
455
463
|
raise ValueError("Cannot delete feedback with no prediction ID")
|
|
456
464
|
|
|
457
|
-
|
|
465
|
+
client = OrcaClient._resolve_client()
|
|
466
|
+
client.PUT(
|
|
458
467
|
"/telemetry/prediction/feedback",
|
|
459
468
|
json=[PredictionFeedbackRequest(prediction_id=self.prediction_id, category_name=category, value=None)],
|
|
460
469
|
)
|
|
@@ -571,7 +580,8 @@ class ClassificationPrediction(PredictionBase):
|
|
|
571
580
|
if self.prediction_id is None:
|
|
572
581
|
raise ValueError("Cannot get action recommendation with no prediction ID")
|
|
573
582
|
|
|
574
|
-
|
|
583
|
+
client = OrcaClient._resolve_client()
|
|
584
|
+
response = client.GET(
|
|
575
585
|
"/telemetry/prediction/{prediction_id}/action",
|
|
576
586
|
params={"prediction_id": self.prediction_id},
|
|
577
587
|
timeout=30,
|
|
@@ -611,7 +621,8 @@ class ClassificationPrediction(PredictionBase):
|
|
|
611
621
|
if self.prediction_id is None:
|
|
612
622
|
raise ValueError("Cannot generate memory suggestions with no prediction ID")
|
|
613
623
|
|
|
614
|
-
|
|
624
|
+
client = OrcaClient._resolve_client()
|
|
625
|
+
response = client.GET(
|
|
615
626
|
"/telemetry/prediction/{prediction_id}/memory_suggestions",
|
|
616
627
|
params={"prediction_id": self.prediction_id, "num_memories": num_memories},
|
|
617
628
|
timeout=30,
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
orca_sdk/__init__.py,sha256=xyjNwkLQXaX8A-UYgGwYDjv2btOXArT_yiMTfmW7KA8,1003
|
|
2
|
+
orca_sdk/_shared/__init__.py,sha256=3Kt0Hu3QLI5FEp9nqGTxqAm3hAoBJKcagfaGQZ-lbJQ,223
|
|
3
|
+
orca_sdk/_shared/metrics.py,sha256=LEZfAUWUtUWv_WWy9F_yjGLlUQHQpmR9WxG2fbKxa7U,14419
|
|
4
|
+
orca_sdk/_shared/metrics_test.py,sha256=Rw1MaH37FppNsMnW8Ir9vMd8xxnZt3eo2Iypx1igtBI,9440
|
|
5
|
+
orca_sdk/_utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
6
|
+
orca_sdk/_utils/analysis_ui.py,sha256=nT-M_YcNRCVPQzvuqYNFKnNHhYkADYBvq1GlIUePrWw,9232
|
|
7
|
+
orca_sdk/_utils/analysis_ui_style.css,sha256=q_ba_-_KtgztepHg829zLzypaxKayl7ySC1-oYDzV3k,836
|
|
8
|
+
orca_sdk/_utils/auth.py,sha256=nC252O171_3_wn4KBAN7kg8GNvoZFiQ5Xtzkrm5dWDo,2645
|
|
9
|
+
orca_sdk/_utils/auth_test.py,sha256=ygVWv1Ex53LaxIP7p2hzPHl8l9qYyBD5IGmEFJMps6s,1056
|
|
10
|
+
orca_sdk/_utils/common.py,sha256=wUm2pNDWytEecC5WiDWd02-yCZw3Akx0bIutG4lHsFA,805
|
|
11
|
+
orca_sdk/_utils/data_parsing.py,sha256=gkAwWEC8qRt3vRUObe7n7Pr0azOayNwc2yFY04WFp7E,5220
|
|
12
|
+
orca_sdk/_utils/data_parsing_test.py,sha256=fNEYzPzE1jt3KWE2Kj91KqIeuv-L5REHFAa98zkNGSQ,8962
|
|
13
|
+
orca_sdk/_utils/pagination.py,sha256=986z0QPZixrZeurJWorF6eMgnTRdDF84AagEA6qNbMw,4245
|
|
14
|
+
orca_sdk/_utils/pagination_test.py,sha256=BUylCrcHnwoKEBmMUzVr0lwLpA35ivcCwdBK4rMw9y8,4887
|
|
15
|
+
orca_sdk/_utils/prediction_result_ui.css,sha256=sqBlkRLnovb5X5EcUDdB6iGpH63nVRlTW4uAmXuD0WM,258
|
|
16
|
+
orca_sdk/_utils/prediction_result_ui.py,sha256=Ur_FY7dz3oWNmtPiP3Wl3yRlEMgK8q9UfT-SDu9UPxA,4805
|
|
17
|
+
orca_sdk/_utils/tqdm_file_reader.py,sha256=Lw7Cg1UgNuRUoN6jjqZb-IlV00H-kbRcrZLdudr1GxE,324
|
|
18
|
+
orca_sdk/_utils/value_parser.py,sha256=c3qMABCCDQcIjn9N1orYYnlRwDW9JWdGwW_2TDZPLdI,1286
|
|
19
|
+
orca_sdk/_utils/value_parser_test.py,sha256=OybsiC-Obi32RRi9NIuwrVBRAnlyPMV1xVAaevSrb7M,1079
|
|
20
|
+
orca_sdk/async_client.py,sha256=HK52VxltotpDdq-aTgsCHQPsDAYzOSZDxdlbOnal99c,125459
|
|
21
|
+
orca_sdk/classification_model.py,sha256=WJM6oLBuGrxleTWakc-ZgSRfNyiZxb6-GIMH-S7k12w,39700
|
|
22
|
+
orca_sdk/classification_model_test.py,sha256=_gaDg8QB0h0ByN4UwTk2fIIDXE4UzahuJBjz7NSPK28,23605
|
|
23
|
+
orca_sdk/client.py,sha256=dcGBnzIwaU74CMzUh1ObKJbVmZekF5n57gQY6YcQwHE,124550
|
|
24
|
+
orca_sdk/conftest.py,sha256=RtINF1xea2iMycMkpMXIOOqRbfWeIZsceSAemhBmgNE,9761
|
|
25
|
+
orca_sdk/credentials.py,sha256=80_1r8n5jruEvN_E629SaRrRhKvF_NhWUEZyZzPXkqQ,6620
|
|
26
|
+
orca_sdk/credentials_test.py,sha256=TLbXJMz3IlThvtSrHeLM7jRsKnrncA_ahOTpHg15Ei4,4089
|
|
27
|
+
orca_sdk/datasource.py,sha256=DJt1Hr8iwaTFbtFD1aqbUPytpmjPr39qISeqSumoraM,20668
|
|
28
|
+
orca_sdk/datasource_test.py,sha256=yBR0NbsAzChV97pSOU0IvlfF5_WbMe49wZeWNXxwNl4,12128
|
|
29
|
+
orca_sdk/embedding_model.py,sha256=IQCpGUUlKHtz33Ld1-Ag8eLMk72qT7K-cHDjBJGJqhQ,27689
|
|
30
|
+
orca_sdk/embedding_model_test.py,sha256=-NItbNb3tTVj5jAvSi3WjV3FP448q08lmT5iObg9vwA,8133
|
|
31
|
+
orca_sdk/job.py,sha256=BOHg9ksVcN26VtAmuA2cNjGed_Gsx2zbdCO6FBZjuqI,13119
|
|
32
|
+
orca_sdk/job_test.py,sha256=nRSWxd_1UIfrj9oMVvrXjt6OBkBpddYAjb2y6P-DTUg,4327
|
|
33
|
+
orca_sdk/memoryset.py,sha256=angGB6OUJRDoBa2xzl4WsYhNsRRj99dEKEaAKUqgxO8,100113
|
|
34
|
+
orca_sdk/memoryset_test.py,sha256=wqoHXP60CBvtsReCunQNUxj6_ZDT67TTTluguQapigs,25368
|
|
35
|
+
orca_sdk/regression_model.py,sha256=kIT2i4XMrTBZXXVqDENoYILLF7Zqa8o2ndhraXHUPbY,26437
|
|
36
|
+
orca_sdk/regression_model_test.py,sha256=slwxbty_vL9d24OCn5xN61eKyri5GS7Jv2YmpEOMTrM,15856
|
|
37
|
+
orca_sdk/telemetry.py,sha256=C0rTudfAV3_t_uADATrl06d7vk-Sgop24FiSSqYhqmc,26209
|
|
38
|
+
orca_sdk/telemetry_test.py,sha256=eT66C5lFdNg-pQdo2I__BP7Tn5fTc9aTkVo9ZhWwhU0,5519
|
|
39
|
+
orca_sdk-0.1.3.dist-info/METADATA,sha256=tbUzJDcZGUOkwPeToA74JXtkJErdpmtx7UtL1mhzm_M,3659
|
|
40
|
+
orca_sdk-0.1.3.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
|
|
41
|
+
orca_sdk-0.1.3.dist-info/RECORD,,
|