orca-sdk 0.0.93__py3-none-any.whl → 0.0.94__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/_generated_api_client/api/__init__.py +4 -0
- orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +148 -0
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +60 -10
- orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +10 -10
- orca_sdk/_generated_api_client/models/__init__.py +6 -0
- orca_sdk/_generated_api_client/models/count_predictions_request.py +195 -0
- orca_sdk/_generated_api_client/models/http_validation_error.py +86 -0
- orca_sdk/_generated_api_client/models/list_predictions_request.py +62 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -20
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +5 -0
- orca_sdk/_generated_api_client/models/validation_error.py +99 -0
- orca_sdk/classification_model.py +5 -3
- orca_sdk/classification_model_test.py +46 -0
- orca_sdk/conftest.py +1 -0
- orca_sdk/datasource.py +34 -0
- orca_sdk/datasource_test.py +9 -0
- orca_sdk/memoryset_test.py +2 -0
- orca_sdk/telemetry.py +5 -3
- {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.94.dist-info}/METADATA +1 -1
- {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.94.dist-info}/RECORD +21 -17
- {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.94.dist-info}/WHEEL +0 -0
orca_sdk/_generated_api_client/models/http_validation_error.py
ADDED
@@ -0,0 +1,86 @@
+"""
+This file is generated by the openapi-python-client tool via the generate_api_client.py script
+
+It is a customized template from the openapi-python-client tool's default template:
+https://github.com/openapi-generators/openapi-python-client/blob/861ef5622f10fc96d240dc9becb0edf94e61446c/openapi_python_client/templates/model.py.jinja
+
+The main change is:
+- Fix typing issues
+"""
+
+# flake8: noqa: C901
+
+from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union
+
+from attrs import define as _attrs_define
+from attrs import field as _attrs_field
+
+from ..types import UNSET, Unset
+
+if TYPE_CHECKING:
+    from ..models.validation_error import ValidationError
+
+
+T = TypeVar("T", bound="HTTPValidationError")
+
+
+@_attrs_define
+class HTTPValidationError:
+    """
+    Attributes:
+        detail (Union[Unset, List['ValidationError']]):
+    """
+
+    detail: Union[Unset, List["ValidationError"]] = UNSET
+    additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
+
+    def to_dict(self) -> dict[str, Any]:
+        detail: Union[Unset, List[Dict[str, Any]]] = UNSET
+        if not isinstance(self.detail, Unset):
+            detail = []
+            for detail_item_data in self.detail:
+                detail_item = detail_item_data.to_dict()
+                detail.append(detail_item)
+
+        field_dict: dict[str, Any] = {}
+        field_dict.update(self.additional_properties)
+        field_dict.update({})
+        if detail is not UNSET:
+            field_dict["detail"] = detail
+
+        return field_dict
+
+    @classmethod
+    def from_dict(cls: Type[T], src_dict: dict[str, Any]) -> T:
+        from ..models.validation_error import ValidationError
+
+        d = src_dict.copy()
+        detail = []
+        _detail = d.pop("detail", UNSET)
+        for detail_item_data in _detail or []:
+            detail_item = ValidationError.from_dict(detail_item_data)
+
+            detail.append(detail_item)
+
+        http_validation_error = cls(
+            detail=detail,
+        )
+
+        http_validation_error.additional_properties = d
+        return http_validation_error
+
+    @property
+    def additional_keys(self) -> list[str]:
+        return list(self.additional_properties.keys())
+
+    def __getitem__(self, key: str) -> Any:
+        return self.additional_properties[key]
+
+    def __setitem__(self, key: str, value: Any) -> None:
+        self.additional_properties[key] = value
+
+    def __delitem__(self, key: str) -> None:
+        del self.additional_properties[key]
+
+    def __contains__(self, key: str) -> bool:
+        return key in self.additional_properties
orca_sdk/_generated_api_client/models/list_predictions_request.py
CHANGED
@@ -10,11 +10,13 @@ The main change is:
 
 # flake8: noqa: C901
 
+import datetime
 from enum import Enum
 from typing import Any, List, Type, TypeVar, Union, cast
 
 from attrs import define as _attrs_define
 from attrs import field as _attrs_field
+from dateutil.parser import isoparse
 
 from ..models.prediction_sort_item_item_type_0 import PredictionSortItemItemType0
 from ..models.prediction_sort_item_item_type_1 import PredictionSortItemItemType1
@@ -30,6 +32,8 @@ class ListPredictionsRequest:
         model_id (Union[None, Unset, str]):
         tag (Union[None, Unset, str]):
         prediction_ids (Union[List[str], None, Unset]):
+        start_timestamp (Union[None, Unset, datetime.datetime]):
+        end_timestamp (Union[None, Unset, datetime.datetime]):
         limit (Union[None, Unset, int]):
         offset (Union[None, Unset, int]): Default: 0.
         sort (Union[Unset, List[List[Union[PredictionSortItemItemType0, PredictionSortItemItemType1]]]]):
@@ -39,6 +43,8 @@ class ListPredictionsRequest:
     model_id: Union[None, Unset, str] = UNSET
     tag: Union[None, Unset, str] = UNSET
     prediction_ids: Union[List[str], None, Unset] = UNSET
+    start_timestamp: Union[None, Unset, datetime.datetime] = UNSET
+    end_timestamp: Union[None, Unset, datetime.datetime] = UNSET
    limit: Union[None, Unset, int] = UNSET
    offset: Union[None, Unset, int] = 0
    sort: Union[Unset, List[List[Union[PredictionSortItemItemType0, PredictionSortItemItemType1]]]] = UNSET
@@ -67,6 +73,22 @@ class ListPredictionsRequest:
         else:
             prediction_ids = self.prediction_ids
 
+        start_timestamp: Union[None, Unset, str]
+        if isinstance(self.start_timestamp, Unset):
+            start_timestamp = UNSET
+        elif isinstance(self.start_timestamp, datetime.datetime):
+            start_timestamp = self.start_timestamp.isoformat()
+        else:
+            start_timestamp = self.start_timestamp
+
+        end_timestamp: Union[None, Unset, str]
+        if isinstance(self.end_timestamp, Unset):
+            end_timestamp = UNSET
+        elif isinstance(self.end_timestamp, datetime.datetime):
+            end_timestamp = self.end_timestamp.isoformat()
+        else:
+            end_timestamp = self.end_timestamp
+
         limit: Union[None, Unset, int]
         if isinstance(self.limit, Unset):
             limit = UNSET
@@ -118,6 +140,10 @@ class ListPredictionsRequest:
             field_dict["tag"] = tag
         if prediction_ids is not UNSET:
             field_dict["prediction_ids"] = prediction_ids
+        if start_timestamp is not UNSET:
+            field_dict["start_timestamp"] = start_timestamp
+        if end_timestamp is not UNSET:
+            field_dict["end_timestamp"] = end_timestamp
         if limit is not UNSET:
             field_dict["limit"] = limit
         if offset is not UNSET:
@@ -168,6 +194,40 @@ class ListPredictionsRequest:
 
         prediction_ids = _parse_prediction_ids(d.pop("prediction_ids", UNSET))
 
+        def _parse_start_timestamp(data: object) -> Union[None, Unset, datetime.datetime]:
+            if data is None:
+                return data
+            if isinstance(data, Unset):
+                return data
+            try:
+                if not isinstance(data, str):
+                    raise TypeError()
+                start_timestamp_type_0 = isoparse(data)
+
+                return start_timestamp_type_0
+            except:  # noqa: E722
+                pass
+            return cast(Union[None, Unset, datetime.datetime], data)
+
+        start_timestamp = _parse_start_timestamp(d.pop("start_timestamp", UNSET))
+
+        def _parse_end_timestamp(data: object) -> Union[None, Unset, datetime.datetime]:
+            if data is None:
+                return data
+            if isinstance(data, Unset):
+                return data
+            try:
+                if not isinstance(data, str):
+                    raise TypeError()
+                end_timestamp_type_0 = isoparse(data)
+
+                return end_timestamp_type_0
+            except:  # noqa: E722
+                pass
+            return cast(Union[None, Unset, datetime.datetime], data)
+
+        end_timestamp = _parse_end_timestamp(d.pop("end_timestamp", UNSET))
+
         def _parse_limit(data: object) -> Union[None, Unset, int]:
             if data is None:
                 return data
@@ -231,6 +291,8 @@ class ListPredictionsRequest:
             model_id=model_id,
             tag=tag,
             prediction_ids=prediction_ids,
+            start_timestamp=start_timestamp,
+            end_timestamp=end_timestamp,
             limit=limit,
             offset=offset,
             sort=sort,
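The practical effect of these hunks: prediction listings can now be filtered to a time window, with datetimes serialized as ISO 8601 strings on the wire. A minimal round-trip sketch against the generated model (the model_id below is a hypothetical placeholder):

import datetime

from orca_sdk._generated_api_client.models.list_predictions_request import ListPredictionsRequest

# Build a request that filters predictions to a one-day window.
request = ListPredictionsRequest(
    model_id="my-model-id",  # hypothetical id, for illustration only
    start_timestamp=datetime.datetime(2025, 1, 1),
    end_timestamp=datetime.datetime(2025, 1, 2),
)

# to_dict() serializes the datetimes to ISO 8601 strings.
payload = request.to_dict()
assert payload["start_timestamp"] == "2025-01-01T00:00:00"

# from_dict() parses them back into datetime objects via dateutil's isoparse.
parsed = ListPredictionsRequest.from_dict(payload)
assert parsed.start_timestamp == datetime.datetime(2025, 1, 1)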
orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py
CHANGED
@@ -13,7 +13,6 @@ The main change is:
 from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar, Union, cast
 
 from attrs import define as _attrs_define
-from attrs import field as _attrs_field
 
 from ..types import UNSET, Unset
 
@@ -44,7 +43,6 @@ class MemorysetAnalysisConfigs:
     duplicate: Union["MemorysetDuplicateAnalysisConfig", None, Unset] = UNSET
     projection: Union["MemorysetProjectionAnalysisConfig", None, Unset] = UNSET
     cluster: Union["MemorysetClusterAnalysisConfig", None, Unset] = UNSET
-    additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
 
     def to_dict(self) -> dict[str, Any]:
         from ..models.memoryset_cluster_analysis_config import MemorysetClusterAnalysisConfig
@@ -94,7 +92,6 @@ class MemorysetAnalysisConfigs:
         cluster = self.cluster
 
         field_dict: dict[str, Any] = {}
-        field_dict.update(self.additional_properties)
         field_dict.update({})
         if neighbor is not UNSET:
             field_dict["neighbor"] = neighbor
@@ -212,21 +209,4 @@ class MemorysetAnalysisConfigs:
             cluster=cluster,
         )
 
-        memoryset_analysis_configs.additional_properties = d
         return memoryset_analysis_configs
-
-    @property
-    def additional_keys(self) -> list[str]:
-        return list(self.additional_properties.keys())
-
-    def __getitem__(self, key: str) -> Any:
-        return self.additional_properties[key]
-
-    def __setitem__(self, key: str, value: Any) -> None:
-        self.additional_properties[key] = value
-
-    def __delitem__(self, key: str) -> None:
-        del self.additional_properties[key]
-
-    def __contains__(self, key: str) -> bool:
-        return key in self.additional_properties
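Reading these removals together: MemorysetAnalysisConfigs no longer retains unrecognized keys from a parsed payload, since the additional_properties catch-all field and its dict-style accessors are gone.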
orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py
CHANGED
@@ -2,11 +2,16 @@ from enum import Enum
 
 
 class PretrainedEmbeddingModelName(str, Enum):
+    BGE_BASE = "BGE_BASE"
     CDE_SMALL = "CDE_SMALL"
     CLIP_BASE = "CLIP_BASE"
     DISTILBERT = "DISTILBERT"
+    E5_LARGE = "E5_LARGE"
+    GIST_LARGE = "GIST_LARGE"
     GTE_BASE = "GTE_BASE"
     GTE_SMALL = "GTE_SMALL"
+    MXBAI_LARGE = "MXBAI_LARGE"
+    QWEN2_1_5B = "QWEN2_1_5B"
 
     def __str__(self) -> str:
         return str(self.value)
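Five pretrained embedding model names are added in this release. The new members behave like the existing ones: each value equals its name, and str() returns the value via the __str__ override, e.g.:

from orca_sdk._generated_api_client.models.pretrained_embedding_model_name import PretrainedEmbeddingModelName

# New in 0.0.94: BGE_BASE, E5_LARGE, GIST_LARGE, MXBAI_LARGE, QWEN2_1_5B
assert PretrainedEmbeddingModelName.BGE_BASE.value == "BGE_BASE"
assert str(PretrainedEmbeddingModelName.QWEN2_1_5B) == "QWEN2_1_5B"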
orca_sdk/_generated_api_client/models/validation_error.py
ADDED
@@ -0,0 +1,99 @@
+"""
+This file is generated by the openapi-python-client tool via the generate_api_client.py script
+
+It is a customized template from the openapi-python-client tool's default template:
+https://github.com/openapi-generators/openapi-python-client/blob/861ef5622f10fc96d240dc9becb0edf94e61446c/openapi_python_client/templates/model.py.jinja
+
+The main change is:
+- Fix typing issues
+"""
+
+# flake8: noqa: C901
+
+from typing import Any, List, Type, TypeVar, Union, cast
+
+from attrs import define as _attrs_define
+from attrs import field as _attrs_field
+
+T = TypeVar("T", bound="ValidationError")
+
+
+@_attrs_define
+class ValidationError:
+    """
+    Attributes:
+        loc (List[Union[int, str]]):
+        msg (str):
+        type (str):
+    """
+
+    loc: List[Union[int, str]]
+    msg: str
+    type: str
+    additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
+
+    def to_dict(self) -> dict[str, Any]:
+        loc = []
+        for loc_item_data in self.loc:
+            loc_item: Union[int, str]
+            loc_item = loc_item_data
+            loc.append(loc_item)
+
+        msg = self.msg
+
+        type = self.type
+
+        field_dict: dict[str, Any] = {}
+        field_dict.update(self.additional_properties)
+        field_dict.update(
+            {
+                "loc": loc,
+                "msg": msg,
+                "type": type,
+            }
+        )
+
+        return field_dict
+
+    @classmethod
+    def from_dict(cls: Type[T], src_dict: dict[str, Any]) -> T:
+        d = src_dict.copy()
+        loc = []
+        _loc = d.pop("loc")
+        for loc_item_data in _loc:
+
+            def _parse_loc_item(data: object) -> Union[int, str]:
+                return cast(Union[int, str], data)
+
+            loc_item = _parse_loc_item(loc_item_data)
+
+            loc.append(loc_item)
+
+        msg = d.pop("msg")
+
+        type = d.pop("type")
+
+        validation_error = cls(
+            loc=loc,
+            msg=msg,
+            type=type,
+        )
+
+        validation_error.additional_properties = d
+        return validation_error
+
+    @property
+    def additional_keys(self) -> list[str]:
+        return list(self.additional_properties.keys())
+
+    def __getitem__(self, key: str) -> Any:
+        return self.additional_properties[key]
+
+    def __setitem__(self, key: str, value: Any) -> None:
+        self.additional_properties[key] = value
+
+    def __delitem__(self, key: str) -> None:
+        del self.additional_properties[key]
+
+    def __contains__(self, key: str) -> bool:
+        return key in self.additional_properties
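These two new models mirror FastAPI's standard 422 validation response body, which the client previously had no typed representation for. A small round-trip sketch (the payload is a hand-written example of such a response):

from orca_sdk._generated_api_client.models.http_validation_error import HTTPValidationError

# A typical FastAPI 422 payload, written by hand for illustration.
payload = {
    "detail": [
        {"loc": ["body", "name"], "msg": "field required", "type": "value_error.missing"}
    ]
}

error = HTTPValidationError.from_dict(payload)
assert error.detail[0].msg == "field required"
assert error.detail[0].loc == ["body", "name"]

# to_dict() reproduces the original structure.
assert error.to_dict() == payload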
orca_sdk/classification_model.py
CHANGED
@@ -382,7 +382,9 @@ class ClassificationModel:
             expected_labels=(
                 expected_labels
                 if isinstance(expected_labels, list)
-                else [expected_labels]
+                else [expected_labels]
+                if expected_labels is not None
+                else None
             ),
             tags=list(tags),
             save_telemetry=save_telemetry,
@@ -403,8 +405,9 @@ class ClassificationModel:
                 memoryset=self.memoryset,
                 model=self,
                 logits=prediction.logits,
+                input_value=input_value,
             )
-            for prediction in response
+            for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
         ]
         self._last_prediction_was_batch = isinstance(value, list)
         self._last_prediction = predictions[-1]
@@ -480,7 +483,6 @@ class ClassificationModel:
         predictions: list[LabelPrediction],
         expected_labels: list[int],
     ) -> ClassificationEvaluationResult:
-
         targets_array = np.array(expected_labels)
         predictions_array = np.array([p.label for p in predictions])
 
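The net effect of the predict() changes: each returned LabelPrediction now carries the caller's original input, zipped positionally with the server's response. A sketch of the resulting behavior, assuming an existing ClassificationModel instance named model:

# Single input: the prediction remembers what it was asked about.
prediction = model.predict("Do you love soup?")
assert prediction.input_value == "Do you love soup?"

# Batch input: inputs pair up with predictions in order.
predictions = model.predict(["Do you love soup?", "Are cats cute?"])
assert predictions[1].input_value == "Are cats cute?"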
orca_sdk/classification_model_test.py
CHANGED
@@ -1,3 +1,5 @@
+import logging
+import os
 from uuid import uuid4
 
 import numpy as np
@@ -9,6 +11,11 @@ from .datasource import Datasource
 from .embedding_model import PretrainedEmbeddingModel
 from .memoryset import LabeledMemoryset
 
+logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+
+
+SKIP_IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
+
 
 def test_create_model(model: ClassificationModel, readonly_memoryset: LabeledMemoryset):
     assert model is not None
@@ -338,3 +345,42 @@ def test_last_prediction_with_single(model: ClassificationModel):
     assert model.last_prediction.prediction_id == prediction.prediction_id
     assert model.last_prediction.input_value == "Do you love soup?"
     assert model._last_prediction_was_batch is False
+
+
+@pytest.mark.skipif(
+    SKIP_IN_GITHUB_ACTIONS, reason="Skipping explanation test because in CI we don't have Anthropic API key"
+)
+def test_explain(writable_memoryset: LabeledMemoryset):
+
+    writable_memoryset.analyze(
+        {"name": "neighbor", "neighbor_counts": [1, 3]},
+        lookup_count=3,
+    )
+
+    model = ClassificationModel.create(
+        "test_model_for_explain",
+        writable_memoryset,
+        num_classes=2,
+        memory_lookup_count=3,
+        description="This is a test model for explain",
+    )
+
+    predictions = model.predict(["Do you love soup?", "Are cats cute?"])
+    assert len(predictions) == 2
+
+    try:
+        explanation = predictions[0].explanation
+        print(explanation)
+        assert explanation is not None
+        assert len(explanation) > 10
+        assert "soup" in explanation.lower()
+    except Exception as e:
+        if "ANTHROPIC_API_KEY" in str(e):
+            logging.info("Skipping explanation test because ANTHROPIC_API_KEY is not set on server")
+        else:
+            raise e
+    finally:
+        try:
+            ClassificationModel.drop("test_model_for_explain")
+        except Exception as e:
+            logging.info(f"Failed to drop test model for explain: {e}")
orca_sdk/conftest.py
CHANGED
@@ -176,6 +176,7 @@ def writable_memoryset(datasource: Datasource, api_key: str) -> Generator[Labele
 
     if memory_ids:
         memoryset.delete(memory_ids)
+        memoryset.refresh()
     assert len(memoryset) == 0
     memoryset.insert(SAMPLE_DATA)
     # If the test dropped the memoryset, do nothing — it will be recreated on the next use.
orca_sdk/datasource.py
CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import logging
 import tempfile
+import zipfile
 from datetime import datetime
 from os import PathLike
 from pathlib import Path
@@ -84,6 +85,39 @@ class Datasource:
             + "})"
         )
 
+    def download(self, output_path: str | PathLike) -> None:
+        """
+        Download the datasource as a ZIP archive and extract its files to the specified path.
+
+        Params:
+            output_path: The local file path or directory where the downloaded files will be saved.
+
+        Returns:
+            None
+
+        Raises:
+            RuntimeError: If the download fails.
+        """
+
+        output_path = Path(output_path)
+        client = get_client().get_httpx_client()
+        url = f"/datasource/{self.id}/download"
+        response = client.get(url)
+        if response.status_code == 404:
+            raise LookupError(f"Datasource {self.id} not found")
+        if response.status_code != 200:
+            raise RuntimeError(f"Failed to download datasource: {response.status_code} {response.text}")
+
+        with tempfile.NamedTemporaryFile(suffix=".zip") as tmp_zip:
+            tmp_zip.write(response.content)
+            tmp_zip.flush()
+            with zipfile.ZipFile(tmp_zip.name, "r") as zf:
+                output_path.mkdir(parents=True, exist_ok=True)
+                for file in zf.namelist():
+                    out_file = output_path / Path(file).name
+                    with zf.open(file) as af:
+                        out_file.write_bytes(af.read())
+
     @classmethod
     def from_hf_dataset(
         cls, name: str, dataset: Dataset, if_exists: CreateMode = "error", description: str | None = None
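A short usage sketch for the new method. Note that download() extracts the server's ZIP rather than saving it verbatim, so output_path becomes a directory of datasource files (the path below is hypothetical):

# `datasource` is an existing Datasource handle, e.g. one created earlier
# with Datasource.from_hf_dataset(...).
datasource.download("./my_datasource_files")
# ./my_datasource_files/ now contains the files extracted from the ZIP response.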
orca_sdk/datasource_test.py
CHANGED
@@ -1,3 +1,5 @@
+import os
+import tempfile
 from uuid import uuid4
 
 import pytest
@@ -94,3 +96,10 @@ def test_drop_datasource_unauthorized(datasource, unauthorized):
 def test_drop_datasource_invalid_input():
     with pytest.raises(ValueError, match=r"Invalid input:.*"):
         Datasource.drop("not valid id")
+
+
+def test_download_datasource(datasource):
+    with tempfile.TemporaryDirectory() as temp_dir:
+        output_path = os.path.join(temp_dir, "datasource.zip")
+        datasource.download(output_path)
+        assert os.path.exists(output_path)
orca_sdk/memoryset_test.py
CHANGED
@@ -281,8 +281,10 @@ def test_insert_memories(writable_memoryset: LabeledMemoryset):
             dict(value="cats are fun to play with", label=1),
         ]
     )
+    writable_memoryset.refresh()
     assert writable_memoryset.length == prev_length + 2
     writable_memoryset.insert(dict(value="tomato soup is my favorite", label=0, key="test", source_id="test"))
+    writable_memoryset.refresh()
     assert writable_memoryset.length == prev_length + 3
     last_memory = writable_memoryset[-1]
    assert last_memory.value == "tomato soup is my favorite"
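Both here and in the conftest fixture above, refresh() is now called after mutating the memoryset and before asserting on its length, which suggests the local handle caches length until explicitly refreshed.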
orca_sdk/telemetry.py
CHANGED
@@ -149,6 +149,7 @@ class LabelPrediction:
         model: ClassificationModel | str,
         telemetry: LabelPredictionWithMemoriesAndFeedback | None = None,
         logits: list[float] | None = None,
+        input_value: str | list[list[float]] | None = None,
     ):
         # for internal use only, do not document
         from .classification_model import ClassificationModel
@@ -162,15 +163,14 @@ class LabelPrediction:
         self.model = ClassificationModel.open(model) if isinstance(model, str) else model
         self.__telemetry = telemetry if telemetry else None
         self.logits = logits
+        self._input_value = input_value
 
     def __repr__(self):
         return (
             "LabelPrediction({"
             + f"label: <{self.label_name}: {self.label}>, "
             + f"confidence: {self.confidence:.2f}, "
-            + f"anomaly_score: {self.anomaly_score:.2f}, "
-            if self.anomaly_score is not None
-            else ""
+            + (f"anomaly_score: {self.anomaly_score:.2f}, " if self.anomaly_score is not None else "")
             + f"input_value: '{str(self.input_value)[:100] + '...' if len(str(self.input_value)) > 100 else self.input_value}'"
             + "})"
         )
@@ -188,6 +188,8 @@ class LabelPrediction:
 
     @property
     def input_value(self) -> str | list[list[float]] | None:
+        if self._input_value is not None:
+            return self._input_value
         return self._telemetry.input_value
 
     @property
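The __repr__ hunk fixes an operator-precedence bug rather than restyling: without parentheses, the `if self.anomaly_score is not None` conditional bound to the entire concatenation built so far, not just the anomaly_score fragment. A self-contained illustration of the pitfall (the names here are illustrative, not from the SDK):

score = None

# Unparenthesized: the conditional swallows the whole prefix when score is None.
broken = "a, " + "b, " + f"score: {score}, " if score is not None else "" + "c"
assert broken == "c"

# Parenthesized, as in the fix: only the optional fragment is conditional.
fixed = "a, " + "b, " + (f"score: {score}, " if score is not None else "") + "c"
assert fixed == "a, b, c"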