orca-sdk 0.0.95__py3-none-any.whl → 0.0.97__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +1 -5
- orca_sdk/_generated_api_client/api/__init__.py +22 -2
- orca_sdk/_generated_api_client/api/{datasource/create_datasource_datasource_post.py → auth/create_org_plan_auth_org_plan_post.py} +32 -31
- orca_sdk/_generated_api_client/api/auth/get_org_plan_auth_org_plan_get.py +122 -0
- orca_sdk/_generated_api_client/api/auth/update_org_plan_auth_org_plan_put.py +168 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_from_content_datasource_post.py +224 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_from_files_datasource_upload_post.py +229 -0
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +21 -26
- orca_sdk/_generated_api_client/api/telemetry/generate_memory_suggestions_telemetry_prediction_prediction_id_memory_suggestions_post.py +239 -0
- orca_sdk/_generated_api_client/api/telemetry/get_action_recommendation_telemetry_prediction_prediction_id_action_get.py +192 -0
- orca_sdk/_generated_api_client/models/__init__.py +54 -4
- orca_sdk/_generated_api_client/models/action_recommendation.py +82 -0
- orca_sdk/_generated_api_client/models/action_recommendation_action.py +11 -0
- orca_sdk/_generated_api_client/models/add_memory_recommendations.py +85 -0
- orca_sdk/_generated_api_client/models/add_memory_suggestion.py +79 -0
- orca_sdk/_generated_api_client/models/body_create_datasource_from_files_datasource_upload_post.py +145 -0
- orca_sdk/_generated_api_client/models/class_representatives.py +92 -0
- orca_sdk/_generated_api_client/models/classification_model_metadata.py +14 -0
- orca_sdk/_generated_api_client/models/clone_memoryset_request.py +40 -0
- orca_sdk/_generated_api_client/models/constraint_violation_error_response.py +8 -7
- orca_sdk/_generated_api_client/models/constraint_violation_error_response_status_code.py +8 -0
- orca_sdk/_generated_api_client/models/create_classification_model_request.py +40 -0
- orca_sdk/_generated_api_client/models/create_datasource_from_content_request.py +101 -0
- orca_sdk/_generated_api_client/models/create_memoryset_request.py +40 -0
- orca_sdk/_generated_api_client/models/create_org_plan_request.py +73 -0
- orca_sdk/_generated_api_client/models/create_org_plan_request_tier.py +11 -0
- orca_sdk/_generated_api_client/models/create_regression_model_request.py +20 -0
- orca_sdk/_generated_api_client/models/embed_request.py +20 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +28 -10
- orca_sdk/_generated_api_client/models/embedding_evaluation_request.py +28 -10
- orca_sdk/_generated_api_client/models/embedding_model_result.py +9 -0
- orca_sdk/_generated_api_client/models/filter_item.py +31 -23
- orca_sdk/_generated_api_client/models/filter_item_field_type_1_item_type_0.py +8 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_0.py +8 -0
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +2 -0
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +8 -7
- orca_sdk/_generated_api_client/models/internal_server_error_response_status_code.py +8 -0
- orca_sdk/_generated_api_client/models/labeled_memory.py +5 -5
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +16 -16
- orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +5 -5
- orca_sdk/_generated_api_client/models/lookup_request.py +20 -0
- orca_sdk/_generated_api_client/models/memory_metrics.py +98 -0
- orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +33 -0
- orca_sdk/_generated_api_client/models/memoryset_class_patterns_analysis_config.py +79 -0
- orca_sdk/_generated_api_client/models/memoryset_class_patterns_metrics.py +138 -0
- orca_sdk/_generated_api_client/models/memoryset_metadata.py +42 -0
- orca_sdk/_generated_api_client/models/memoryset_metrics.py +33 -0
- orca_sdk/_generated_api_client/models/memoryset_update.py +20 -0
- orca_sdk/_generated_api_client/models/not_found_error_response.py +6 -7
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
- orca_sdk/_generated_api_client/models/not_found_error_response_status_code.py +8 -0
- orca_sdk/_generated_api_client/models/org_plan.py +99 -0
- orca_sdk/_generated_api_client/models/org_plan_tier.py +11 -0
- orca_sdk/_generated_api_client/models/paginated_task.py +108 -0
- orca_sdk/_generated_api_client/models/predictive_model_update.py +20 -0
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +8 -0
- orca_sdk/_generated_api_client/models/regression_model_metadata.py +14 -0
- orca_sdk/_generated_api_client/models/scored_memory_update.py +9 -9
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +8 -7
- orca_sdk/_generated_api_client/models/service_unavailable_error_response_status_code.py +8 -0
- orca_sdk/_generated_api_client/models/telemetry_field_type_0_item_type_0.py +8 -0
- orca_sdk/_generated_api_client/models/telemetry_field_type_1_item_type_0.py +8 -0
- orca_sdk/_generated_api_client/models/telemetry_field_type_1_item_type_1.py +8 -0
- orca_sdk/_generated_api_client/models/telemetry_filter_item.py +42 -30
- orca_sdk/_generated_api_client/models/telemetry_sort_options.py +42 -30
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +8 -7
- orca_sdk/_generated_api_client/models/unauthenticated_error_response_status_code.py +8 -0
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +8 -7
- orca_sdk/_generated_api_client/models/unauthorized_error_response_status_code.py +8 -0
- orca_sdk/_generated_api_client/models/update_org_plan_request.py +73 -0
- orca_sdk/_generated_api_client/models/update_org_plan_request_tier.py +11 -0
- orca_sdk/_shared/metrics.py +1 -1
- orca_sdk/classification_model.py +4 -1
- orca_sdk/classification_model_test.py +53 -0
- orca_sdk/credentials.py +15 -1
- orca_sdk/datasource.py +180 -41
- orca_sdk/datasource_test.py +194 -0
- orca_sdk/embedding_model.py +51 -13
- orca_sdk/embedding_model_test.py +27 -0
- orca_sdk/job.py +15 -14
- orca_sdk/job_test.py +34 -0
- orca_sdk/memoryset.py +47 -7
- orca_sdk/regression_model_test.py +0 -1
- orca_sdk/telemetry.py +94 -3
- {orca_sdk-0.0.95.dist-info → orca_sdk-0.0.97.dist-info}/METADATA +18 -1
- {orca_sdk-0.0.95.dist-info → orca_sdk-0.0.97.dist-info}/RECORD +87 -56
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -207
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -246
- {orca_sdk-0.0.95.dist-info → orca_sdk-0.0.97.dist-info}/WHEEL +0 -0
|
@@ -11,12 +11,15 @@ The main change is:
|
|
|
11
11
|
# flake8: noqa: C901
|
|
12
12
|
|
|
13
13
|
from enum import Enum
|
|
14
|
-
from typing import Any, List,
|
|
14
|
+
from typing import Any, List, Type, TypeVar, Union, cast
|
|
15
15
|
|
|
16
16
|
from attrs import define as _attrs_define
|
|
17
17
|
from attrs import field as _attrs_field
|
|
18
18
|
|
|
19
|
+
from ..models.telemetry_field_type_0_item_type_0 import TelemetryFieldType0ItemType0
|
|
19
20
|
from ..models.telemetry_field_type_0_item_type_2 import TelemetryFieldType0ItemType2
|
|
21
|
+
from ..models.telemetry_field_type_1_item_type_0 import TelemetryFieldType1ItemType0
|
|
22
|
+
from ..models.telemetry_field_type_1_item_type_1 import TelemetryFieldType1ItemType1
|
|
20
23
|
from ..models.telemetry_sort_options_direction import TelemetrySortOptionsDirection
|
|
21
24
|
|
|
22
25
|
T = TypeVar("T", bound="TelemetrySortOptions")
|
|
@@ -26,25 +29,31 @@ T = TypeVar("T", bound="TelemetrySortOptions")
|
|
|
26
29
|
class TelemetrySortOptions:
|
|
27
30
|
"""
|
|
28
31
|
Attributes:
|
|
29
|
-
field (Union[List[Union[
|
|
30
|
-
|
|
32
|
+
field (Union[List[Union[TelemetryFieldType0ItemType0, TelemetryFieldType0ItemType2, str]],
|
|
33
|
+
List[Union[TelemetryFieldType1ItemType0, TelemetryFieldType1ItemType1]]]):
|
|
31
34
|
direction (TelemetrySortOptionsDirection):
|
|
32
35
|
"""
|
|
33
36
|
|
|
34
37
|
field: Union[
|
|
35
|
-
List[Union[
|
|
36
|
-
List[Union[
|
|
38
|
+
List[Union[TelemetryFieldType0ItemType0, TelemetryFieldType0ItemType2, str]],
|
|
39
|
+
List[Union[TelemetryFieldType1ItemType0, TelemetryFieldType1ItemType1]],
|
|
37
40
|
]
|
|
38
41
|
direction: TelemetrySortOptionsDirection
|
|
39
42
|
additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
|
|
40
43
|
|
|
41
44
|
def to_dict(self) -> dict[str, Any]:
|
|
42
|
-
field:
|
|
45
|
+
field: List[str]
|
|
43
46
|
if isinstance(self.field, list):
|
|
44
47
|
field = []
|
|
45
48
|
for componentsschemas_telemetry_field_type_0_item_data in self.field:
|
|
46
|
-
componentsschemas_telemetry_field_type_0_item:
|
|
47
|
-
if isinstance(componentsschemas_telemetry_field_type_0_item_data,
|
|
49
|
+
componentsschemas_telemetry_field_type_0_item: str
|
|
50
|
+
if isinstance(componentsschemas_telemetry_field_type_0_item_data, TelemetryFieldType0ItemType0):
|
|
51
|
+
componentsschemas_telemetry_field_type_0_item = (
|
|
52
|
+
componentsschemas_telemetry_field_type_0_item_data.value
|
|
53
|
+
if isinstance(componentsschemas_telemetry_field_type_0_item_data, Enum)
|
|
54
|
+
else componentsschemas_telemetry_field_type_0_item_data
|
|
55
|
+
)
|
|
56
|
+
elif isinstance(componentsschemas_telemetry_field_type_0_item_data, TelemetryFieldType0ItemType2):
|
|
48
57
|
componentsschemas_telemetry_field_type_0_item = (
|
|
49
58
|
componentsschemas_telemetry_field_type_0_item_data.value
|
|
50
59
|
if isinstance(componentsschemas_telemetry_field_type_0_item_data, Enum)
|
|
@@ -74,8 +83,8 @@ class TelemetrySortOptions:
|
|
|
74
83
|
def _parse_field(
|
|
75
84
|
data: object,
|
|
76
85
|
) -> Union[
|
|
77
|
-
List[Union[
|
|
78
|
-
List[Union[
|
|
86
|
+
List[Union[TelemetryFieldType0ItemType0, TelemetryFieldType0ItemType2, str]],
|
|
87
|
+
List[Union[TelemetryFieldType1ItemType0, TelemetryFieldType1ItemType1]],
|
|
79
88
|
]:
|
|
80
89
|
try:
|
|
81
90
|
if not isinstance(data, list):
|
|
@@ -86,13 +95,15 @@ class TelemetrySortOptions:
|
|
|
86
95
|
|
|
87
96
|
def _parse_componentsschemas_telemetry_field_type_0_item(
|
|
88
97
|
data: object,
|
|
89
|
-
) -> Union[
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
98
|
+
) -> Union[TelemetryFieldType0ItemType0, TelemetryFieldType0ItemType2, str]:
|
|
99
|
+
try:
|
|
100
|
+
if not isinstance(data, str):
|
|
101
|
+
raise TypeError()
|
|
102
|
+
componentsschemas_telemetry_field_type_0_item_type_0 = TelemetryFieldType0ItemType0(data)
|
|
103
|
+
|
|
104
|
+
return componentsschemas_telemetry_field_type_0_item_type_0
|
|
105
|
+
except: # noqa: E722
|
|
106
|
+
pass
|
|
96
107
|
try:
|
|
97
108
|
if not isinstance(data, str):
|
|
98
109
|
raise TypeError()
|
|
@@ -101,7 +112,7 @@ class TelemetrySortOptions:
|
|
|
101
112
|
return componentsschemas_telemetry_field_type_0_item_type_2
|
|
102
113
|
except: # noqa: E722
|
|
103
114
|
pass
|
|
104
|
-
return cast(Union[
|
|
115
|
+
return cast(Union[TelemetryFieldType0ItemType0, TelemetryFieldType0ItemType2, str], data)
|
|
105
116
|
|
|
106
117
|
componentsschemas_telemetry_field_type_0_item = (
|
|
107
118
|
_parse_componentsschemas_telemetry_field_type_0_item(
|
|
@@ -122,18 +133,19 @@ class TelemetrySortOptions:
|
|
|
122
133
|
|
|
123
134
|
def _parse_componentsschemas_telemetry_field_type_1_item(
|
|
124
135
|
data: object,
|
|
125
|
-
) -> Union[
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
136
|
+
) -> Union[TelemetryFieldType1ItemType0, TelemetryFieldType1ItemType1]:
|
|
137
|
+
try:
|
|
138
|
+
if not isinstance(data, str):
|
|
139
|
+
raise TypeError()
|
|
140
|
+
componentsschemas_telemetry_field_type_1_item_type_0 = TelemetryFieldType1ItemType0(data)
|
|
141
|
+
|
|
142
|
+
return componentsschemas_telemetry_field_type_1_item_type_0
|
|
143
|
+
except: # noqa: E722
|
|
144
|
+
pass
|
|
145
|
+
if not isinstance(data, str):
|
|
146
|
+
raise TypeError()
|
|
147
|
+
componentsschemas_telemetry_field_type_1_item_type_1 = TelemetryFieldType1ItemType1(data)
|
|
148
|
+
|
|
137
149
|
return componentsschemas_telemetry_field_type_1_item_type_1
|
|
138
150
|
|
|
139
151
|
componentsschemas_telemetry_field_type_1_item = _parse_componentsschemas_telemetry_field_type_1_item(
|
|
@@ -10,11 +10,14 @@ The main change is:
|
|
|
10
10
|
|
|
11
11
|
# flake8: noqa: C901
|
|
12
12
|
|
|
13
|
-
from
|
|
13
|
+
from enum import Enum
|
|
14
|
+
from typing import Any, Type, TypeVar
|
|
14
15
|
|
|
15
16
|
from attrs import define as _attrs_define
|
|
16
17
|
from attrs import field as _attrs_field
|
|
17
18
|
|
|
19
|
+
from ..models.unauthenticated_error_response_status_code import UnauthenticatedErrorResponseStatusCode
|
|
20
|
+
|
|
18
21
|
T = TypeVar("T", bound="UnauthenticatedErrorResponse")
|
|
19
22
|
|
|
20
23
|
|
|
@@ -22,14 +25,14 @@ T = TypeVar("T", bound="UnauthenticatedErrorResponse")
|
|
|
22
25
|
class UnauthenticatedErrorResponse:
|
|
23
26
|
"""
|
|
24
27
|
Attributes:
|
|
25
|
-
status_code (
|
|
28
|
+
status_code (UnauthenticatedErrorResponseStatusCode):
|
|
26
29
|
"""
|
|
27
30
|
|
|
28
|
-
status_code:
|
|
31
|
+
status_code: UnauthenticatedErrorResponseStatusCode
|
|
29
32
|
additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
|
|
30
33
|
|
|
31
34
|
def to_dict(self) -> dict[str, Any]:
|
|
32
|
-
status_code = self.status_code
|
|
35
|
+
status_code = self.status_code.value if isinstance(self.status_code, Enum) else self.status_code
|
|
33
36
|
|
|
34
37
|
field_dict: dict[str, Any] = {}
|
|
35
38
|
field_dict.update(self.additional_properties)
|
|
@@ -44,9 +47,7 @@ class UnauthenticatedErrorResponse:
|
|
|
44
47
|
@classmethod
|
|
45
48
|
def from_dict(cls: Type[T], src_dict: dict[str, Any]) -> T:
|
|
46
49
|
d = src_dict.copy()
|
|
47
|
-
status_code =
|
|
48
|
-
if status_code != 401:
|
|
49
|
-
raise ValueError(f"status_code must match const 401, got '{status_code}'")
|
|
50
|
+
status_code = UnauthenticatedErrorResponseStatusCode(d.pop("status_code"))
|
|
50
51
|
|
|
51
52
|
unauthenticated_error_response = cls(
|
|
52
53
|
status_code=status_code,
|
|
@@ -10,11 +10,14 @@ The main change is:
|
|
|
10
10
|
|
|
11
11
|
# flake8: noqa: C901
|
|
12
12
|
|
|
13
|
-
from
|
|
13
|
+
from enum import Enum
|
|
14
|
+
from typing import Any, Type, TypeVar
|
|
14
15
|
|
|
15
16
|
from attrs import define as _attrs_define
|
|
16
17
|
from attrs import field as _attrs_field
|
|
17
18
|
|
|
19
|
+
from ..models.unauthorized_error_response_status_code import UnauthorizedErrorResponseStatusCode
|
|
20
|
+
|
|
18
21
|
T = TypeVar("T", bound="UnauthorizedErrorResponse")
|
|
19
22
|
|
|
20
23
|
|
|
@@ -22,16 +25,16 @@ T = TypeVar("T", bound="UnauthorizedErrorResponse")
|
|
|
22
25
|
class UnauthorizedErrorResponse:
|
|
23
26
|
"""
|
|
24
27
|
Attributes:
|
|
25
|
-
status_code (
|
|
28
|
+
status_code (UnauthorizedErrorResponseStatusCode):
|
|
26
29
|
reason (str):
|
|
27
30
|
"""
|
|
28
31
|
|
|
29
|
-
status_code:
|
|
32
|
+
status_code: UnauthorizedErrorResponseStatusCode
|
|
30
33
|
reason: str
|
|
31
34
|
additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
|
|
32
35
|
|
|
33
36
|
def to_dict(self) -> dict[str, Any]:
|
|
34
|
-
status_code = self.status_code
|
|
37
|
+
status_code = self.status_code.value if isinstance(self.status_code, Enum) else self.status_code
|
|
35
38
|
|
|
36
39
|
reason = self.reason
|
|
37
40
|
|
|
@@ -49,9 +52,7 @@ class UnauthorizedErrorResponse:
|
|
|
49
52
|
@classmethod
|
|
50
53
|
def from_dict(cls: Type[T], src_dict: dict[str, Any]) -> T:
|
|
51
54
|
d = src_dict.copy()
|
|
52
|
-
status_code =
|
|
53
|
-
if status_code != 403:
|
|
54
|
-
raise ValueError(f"status_code must match const 403, got '{status_code}'")
|
|
55
|
+
status_code = UnauthorizedErrorResponseStatusCode(d.pop("status_code"))
|
|
55
56
|
|
|
56
57
|
reason = d.pop("reason")
|
|
57
58
|
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This file is generated by the openapi-python-client tool via the generate_api_client.py script
|
|
3
|
+
|
|
4
|
+
It is a customized template from the openapi-python-client tool's default template:
|
|
5
|
+
https://github.com/openapi-generators/openapi-python-client/blob/861ef5622f10fc96d240dc9becb0edf94e61446c/openapi_python_client/templates/model.py.jinja
|
|
6
|
+
|
|
7
|
+
The main change is:
|
|
8
|
+
- Fix typing issues
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
# flake8: noqa: C901
|
|
12
|
+
|
|
13
|
+
from enum import Enum
|
|
14
|
+
from typing import Any, Type, TypeVar
|
|
15
|
+
|
|
16
|
+
from attrs import define as _attrs_define
|
|
17
|
+
from attrs import field as _attrs_field
|
|
18
|
+
|
|
19
|
+
from ..models.update_org_plan_request_tier import UpdateOrgPlanRequestTier
|
|
20
|
+
|
|
21
|
+
T = TypeVar("T", bound="UpdateOrgPlanRequest")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@_attrs_define
|
|
25
|
+
class UpdateOrgPlanRequest:
|
|
26
|
+
"""
|
|
27
|
+
Attributes:
|
|
28
|
+
tier (UpdateOrgPlanRequestTier):
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
tier: UpdateOrgPlanRequestTier
|
|
32
|
+
additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
|
|
33
|
+
|
|
34
|
+
def to_dict(self) -> dict[str, Any]:
|
|
35
|
+
tier = self.tier.value if isinstance(self.tier, Enum) else self.tier
|
|
36
|
+
|
|
37
|
+
field_dict: dict[str, Any] = {}
|
|
38
|
+
field_dict.update(self.additional_properties)
|
|
39
|
+
field_dict.update(
|
|
40
|
+
{
|
|
41
|
+
"tier": tier,
|
|
42
|
+
}
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
return field_dict
|
|
46
|
+
|
|
47
|
+
@classmethod
|
|
48
|
+
def from_dict(cls: Type[T], src_dict: dict[str, Any]) -> T:
|
|
49
|
+
d = src_dict.copy()
|
|
50
|
+
tier = UpdateOrgPlanRequestTier(d.pop("tier"))
|
|
51
|
+
|
|
52
|
+
update_org_plan_request = cls(
|
|
53
|
+
tier=tier,
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
update_org_plan_request.additional_properties = d
|
|
57
|
+
return update_org_plan_request
|
|
58
|
+
|
|
59
|
+
@property
|
|
60
|
+
def additional_keys(self) -> list[str]:
|
|
61
|
+
return list(self.additional_properties.keys())
|
|
62
|
+
|
|
63
|
+
def __getitem__(self, key: str) -> Any:
|
|
64
|
+
return self.additional_properties[key]
|
|
65
|
+
|
|
66
|
+
def __setitem__(self, key: str, value: Any) -> None:
|
|
67
|
+
self.additional_properties[key] = value
|
|
68
|
+
|
|
69
|
+
def __delitem__(self, key: str) -> None:
|
|
70
|
+
del self.additional_properties[key]
|
|
71
|
+
|
|
72
|
+
def __contains__(self, key: str) -> bool:
|
|
73
|
+
return key in self.additional_properties
|
orca_sdk/_shared/metrics.py
CHANGED
|
@@ -24,7 +24,7 @@ def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
|
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
# We don't want to depend on transformers just for the eval_pred type in orca_sdk
|
|
27
|
-
def transform_eval_pred(eval_pred: Any) -> tuple[NDArray
|
|
27
|
+
def transform_eval_pred(eval_pred: Any) -> tuple[NDArray, NDArray[np.float32]]:
|
|
28
28
|
# convert results from Trainer compute_metrics param for use in calculate_classification_metrics
|
|
29
29
|
logits, references = eval_pred # transformers.trainer_utils.EvalPrediction
|
|
30
30
|
if isinstance(logits, tuple):
|
orca_sdk/classification_model.py
CHANGED
|
@@ -35,7 +35,10 @@ from ._generated_api_client.models import (
|
|
|
35
35
|
from ._generated_api_client.models import (
|
|
36
36
|
PredictionSortItemItemType1 as PredictionSortDirection,
|
|
37
37
|
)
|
|
38
|
-
from ._generated_api_client.models import
|
|
38
|
+
from ._generated_api_client.models import (
|
|
39
|
+
PredictiveModelUpdate,
|
|
40
|
+
RACHeadType,
|
|
41
|
+
)
|
|
39
42
|
from ._generated_api_client.types import UNSET as CLIENT_UNSET
|
|
40
43
|
from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
|
|
41
44
|
from ._utils.common import UNSET, CreateMode, DropMode
|
|
@@ -374,3 +374,56 @@ def test_explain(writable_memoryset: LabeledMemoryset):
|
|
|
374
374
|
raise e
|
|
375
375
|
finally:
|
|
376
376
|
ClassificationModel.drop("test_model_for_explain")
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
@skip_in_ci("We don't have Anthropic API key in CI")
|
|
380
|
+
def test_action_recommendation(writable_memoryset: LabeledMemoryset):
|
|
381
|
+
"""Test getting action recommendations for predictions"""
|
|
382
|
+
|
|
383
|
+
writable_memoryset.analyze(
|
|
384
|
+
{"name": "neighbor", "neighbor_counts": [1, 3]},
|
|
385
|
+
lookup_count=3,
|
|
386
|
+
)
|
|
387
|
+
|
|
388
|
+
model = ClassificationModel.create(
|
|
389
|
+
"test_model_for_action",
|
|
390
|
+
writable_memoryset,
|
|
391
|
+
num_classes=2,
|
|
392
|
+
memory_lookup_count=3,
|
|
393
|
+
description="This is a test model for action recommendations",
|
|
394
|
+
)
|
|
395
|
+
|
|
396
|
+
# Make a prediction with expected label to simulate incorrect prediction
|
|
397
|
+
prediction = model.predict("Do you love soup?", expected_labels=1)
|
|
398
|
+
|
|
399
|
+
try:
|
|
400
|
+
# Get action recommendation
|
|
401
|
+
action, rationale = prediction.recommend_action()
|
|
402
|
+
|
|
403
|
+
assert action is not None
|
|
404
|
+
assert rationale is not None
|
|
405
|
+
assert action in ["remove_duplicates", "detect_mislabels", "add_memories", "finetuning"]
|
|
406
|
+
assert len(rationale) > 10
|
|
407
|
+
|
|
408
|
+
# Test memory suggestions
|
|
409
|
+
memory_suggestions = prediction.generate_memory_suggestions(num_memories=2)
|
|
410
|
+
|
|
411
|
+
assert memory_suggestions is not None
|
|
412
|
+
assert len(memory_suggestions) == 2
|
|
413
|
+
|
|
414
|
+
for suggestion in memory_suggestions:
|
|
415
|
+
assert isinstance(suggestion, dict)
|
|
416
|
+
assert "value" in suggestion
|
|
417
|
+
assert "label" in suggestion
|
|
418
|
+
assert isinstance(suggestion["value"], str)
|
|
419
|
+
assert len(suggestion["value"]) > 0
|
|
420
|
+
assert isinstance(suggestion["label"], int)
|
|
421
|
+
assert 0 <= suggestion["label"] < len(model.memoryset.label_names)
|
|
422
|
+
|
|
423
|
+
except Exception as e:
|
|
424
|
+
if "ANTHROPIC_API_KEY" in str(e):
|
|
425
|
+
logging.info("Skipping agent tests because ANTHROPIC_API_KEY is not set")
|
|
426
|
+
else:
|
|
427
|
+
raise e
|
|
428
|
+
finally:
|
|
429
|
+
ClassificationModel.drop("test_model_for_action")
|
orca_sdk/credentials.py
CHANGED
|
@@ -106,6 +106,20 @@ class OrcaCredentials:
|
|
|
106
106
|
"""
|
|
107
107
|
delete_api_key(name_or_id=name)
|
|
108
108
|
|
|
109
|
+
@staticmethod
|
|
110
|
+
def set_headers(headers: dict[str, str]):
|
|
111
|
+
"""
|
|
112
|
+
Add or override default HTTP headers for all Orca API requests.
|
|
113
|
+
|
|
114
|
+
Args:
|
|
115
|
+
headers: Header names with their string values
|
|
116
|
+
|
|
117
|
+
Notes:
|
|
118
|
+
New keys are merged into the existing headers; this will overwrite headers with the
|
|
119
|
+
same name, but leave other headers untouched.
|
|
120
|
+
"""
|
|
121
|
+
set_headers(get_headers() | headers)
|
|
122
|
+
|
|
109
123
|
@staticmethod
|
|
110
124
|
def set_api_key(api_key: str, check_validity: bool = True):
|
|
111
125
|
"""
|
|
@@ -121,6 +135,6 @@ class OrcaCredentials:
|
|
|
121
135
|
Raises:
|
|
122
136
|
ValueError: if the API key is invalid and `check_validity` is True
|
|
123
137
|
"""
|
|
124
|
-
set_headers(
|
|
138
|
+
OrcaCredentials.set_headers({"Api-Key": api_key})
|
|
125
139
|
if check_validity:
|
|
126
140
|
check_authentication()
|