orca-sdk 0.0.95__py3-none-any.whl → 0.0.97__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. orca_sdk/__init__.py +1 -5
  2. orca_sdk/_generated_api_client/api/__init__.py +22 -2
  3. orca_sdk/_generated_api_client/api/{datasource/create_datasource_datasource_post.py → auth/create_org_plan_auth_org_plan_post.py} +32 -31
  4. orca_sdk/_generated_api_client/api/auth/get_org_plan_auth_org_plan_get.py +122 -0
  5. orca_sdk/_generated_api_client/api/auth/update_org_plan_auth_org_plan_put.py +168 -0
  6. orca_sdk/_generated_api_client/api/datasource/create_datasource_from_content_datasource_post.py +224 -0
  7. orca_sdk/_generated_api_client/api/datasource/create_datasource_from_files_datasource_upload_post.py +229 -0
  8. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +21 -26
  9. orca_sdk/_generated_api_client/api/telemetry/generate_memory_suggestions_telemetry_prediction_prediction_id_memory_suggestions_post.py +239 -0
  10. orca_sdk/_generated_api_client/api/telemetry/get_action_recommendation_telemetry_prediction_prediction_id_action_get.py +192 -0
  11. orca_sdk/_generated_api_client/models/__init__.py +54 -4
  12. orca_sdk/_generated_api_client/models/action_recommendation.py +82 -0
  13. orca_sdk/_generated_api_client/models/action_recommendation_action.py +11 -0
  14. orca_sdk/_generated_api_client/models/add_memory_recommendations.py +85 -0
  15. orca_sdk/_generated_api_client/models/add_memory_suggestion.py +79 -0
  16. orca_sdk/_generated_api_client/models/body_create_datasource_from_files_datasource_upload_post.py +145 -0
  17. orca_sdk/_generated_api_client/models/class_representatives.py +92 -0
  18. orca_sdk/_generated_api_client/models/classification_model_metadata.py +14 -0
  19. orca_sdk/_generated_api_client/models/clone_memoryset_request.py +40 -0
  20. orca_sdk/_generated_api_client/models/constraint_violation_error_response.py +8 -7
  21. orca_sdk/_generated_api_client/models/constraint_violation_error_response_status_code.py +8 -0
  22. orca_sdk/_generated_api_client/models/create_classification_model_request.py +40 -0
  23. orca_sdk/_generated_api_client/models/create_datasource_from_content_request.py +101 -0
  24. orca_sdk/_generated_api_client/models/create_memoryset_request.py +40 -0
  25. orca_sdk/_generated_api_client/models/create_org_plan_request.py +73 -0
  26. orca_sdk/_generated_api_client/models/create_org_plan_request_tier.py +11 -0
  27. orca_sdk/_generated_api_client/models/create_regression_model_request.py +20 -0
  28. orca_sdk/_generated_api_client/models/embed_request.py +20 -0
  29. orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +28 -10
  30. orca_sdk/_generated_api_client/models/embedding_evaluation_request.py +28 -10
  31. orca_sdk/_generated_api_client/models/embedding_model_result.py +9 -0
  32. orca_sdk/_generated_api_client/models/filter_item.py +31 -23
  33. orca_sdk/_generated_api_client/models/filter_item_field_type_1_item_type_0.py +8 -0
  34. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_0.py +8 -0
  35. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +2 -0
  36. orca_sdk/_generated_api_client/models/internal_server_error_response.py +8 -7
  37. orca_sdk/_generated_api_client/models/internal_server_error_response_status_code.py +8 -0
  38. orca_sdk/_generated_api_client/models/labeled_memory.py +5 -5
  39. orca_sdk/_generated_api_client/models/labeled_memory_update.py +16 -16
  40. orca_sdk/_generated_api_client/models/labeled_memory_with_feedback_metrics.py +5 -5
  41. orca_sdk/_generated_api_client/models/lookup_request.py +20 -0
  42. orca_sdk/_generated_api_client/models/memory_metrics.py +98 -0
  43. orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +33 -0
  44. orca_sdk/_generated_api_client/models/memoryset_class_patterns_analysis_config.py +79 -0
  45. orca_sdk/_generated_api_client/models/memoryset_class_patterns_metrics.py +138 -0
  46. orca_sdk/_generated_api_client/models/memoryset_metadata.py +42 -0
  47. orca_sdk/_generated_api_client/models/memoryset_metrics.py +33 -0
  48. orca_sdk/_generated_api_client/models/memoryset_update.py +20 -0
  49. orca_sdk/_generated_api_client/models/not_found_error_response.py +6 -7
  50. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
  51. orca_sdk/_generated_api_client/models/not_found_error_response_status_code.py +8 -0
  52. orca_sdk/_generated_api_client/models/org_plan.py +99 -0
  53. orca_sdk/_generated_api_client/models/org_plan_tier.py +11 -0
  54. orca_sdk/_generated_api_client/models/paginated_task.py +108 -0
  55. orca_sdk/_generated_api_client/models/predictive_model_update.py +20 -0
  56. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +8 -0
  57. orca_sdk/_generated_api_client/models/regression_model_metadata.py +14 -0
  58. orca_sdk/_generated_api_client/models/scored_memory_update.py +9 -9
  59. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +8 -7
  60. orca_sdk/_generated_api_client/models/service_unavailable_error_response_status_code.py +8 -0
  61. orca_sdk/_generated_api_client/models/telemetry_field_type_0_item_type_0.py +8 -0
  62. orca_sdk/_generated_api_client/models/telemetry_field_type_1_item_type_0.py +8 -0
  63. orca_sdk/_generated_api_client/models/telemetry_field_type_1_item_type_1.py +8 -0
  64. orca_sdk/_generated_api_client/models/telemetry_filter_item.py +42 -30
  65. orca_sdk/_generated_api_client/models/telemetry_sort_options.py +42 -30
  66. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +8 -7
  67. orca_sdk/_generated_api_client/models/unauthenticated_error_response_status_code.py +8 -0
  68. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +8 -7
  69. orca_sdk/_generated_api_client/models/unauthorized_error_response_status_code.py +8 -0
  70. orca_sdk/_generated_api_client/models/update_org_plan_request.py +73 -0
  71. orca_sdk/_generated_api_client/models/update_org_plan_request_tier.py +11 -0
  72. orca_sdk/_shared/metrics.py +1 -1
  73. orca_sdk/classification_model.py +4 -1
  74. orca_sdk/classification_model_test.py +53 -0
  75. orca_sdk/credentials.py +15 -1
  76. orca_sdk/datasource.py +180 -41
  77. orca_sdk/datasource_test.py +194 -0
  78. orca_sdk/embedding_model.py +51 -13
  79. orca_sdk/embedding_model_test.py +27 -0
  80. orca_sdk/job.py +15 -14
  81. orca_sdk/job_test.py +34 -0
  82. orca_sdk/memoryset.py +47 -7
  83. orca_sdk/regression_model_test.py +0 -1
  84. orca_sdk/telemetry.py +94 -3
  85. {orca_sdk-0.0.95.dist-info → orca_sdk-0.0.97.dist-info}/METADATA +18 -1
  86. {orca_sdk-0.0.95.dist-info → orca_sdk-0.0.97.dist-info}/RECORD +87 -56
  87. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -207
  88. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -246
  89. {orca_sdk-0.0.95.dist-info → orca_sdk-0.0.97.dist-info}/WHEEL +0 -0
@@ -11,12 +11,15 @@ The main change is:
11
11
  # flake8: noqa: C901
12
12
 
13
13
  from enum import Enum
14
- from typing import Any, List, Literal, Type, TypeVar, Union, cast
14
+ from typing import Any, List, Type, TypeVar, Union, cast
15
15
 
16
16
  from attrs import define as _attrs_define
17
17
  from attrs import field as _attrs_field
18
18
 
19
+ from ..models.telemetry_field_type_0_item_type_0 import TelemetryFieldType0ItemType0
19
20
  from ..models.telemetry_field_type_0_item_type_2 import TelemetryFieldType0ItemType2
21
+ from ..models.telemetry_field_type_1_item_type_0 import TelemetryFieldType1ItemType0
22
+ from ..models.telemetry_field_type_1_item_type_1 import TelemetryFieldType1ItemType1
20
23
  from ..models.telemetry_sort_options_direction import TelemetrySortOptionsDirection
21
24
 
22
25
  T = TypeVar("T", bound="TelemetrySortOptions")
@@ -26,25 +29,31 @@ T = TypeVar("T", bound="TelemetrySortOptions")
26
29
  class TelemetrySortOptions:
27
30
  """
28
31
  Attributes:
29
- field (Union[List[Union[Literal['count'], Literal['lookup']]], List[Union[Literal['feedback_metrics'],
30
- TelemetryFieldType0ItemType2, str]]]):
32
+ field (Union[List[Union[TelemetryFieldType0ItemType0, TelemetryFieldType0ItemType2, str]],
33
+ List[Union[TelemetryFieldType1ItemType0, TelemetryFieldType1ItemType1]]]):
31
34
  direction (TelemetrySortOptionsDirection):
32
35
  """
33
36
 
34
37
  field: Union[
35
- List[Union[Literal["count"], Literal["lookup"]]],
36
- List[Union[Literal["feedback_metrics"], TelemetryFieldType0ItemType2, str]],
38
+ List[Union[TelemetryFieldType0ItemType0, TelemetryFieldType0ItemType2, str]],
39
+ List[Union[TelemetryFieldType1ItemType0, TelemetryFieldType1ItemType1]],
37
40
  ]
38
41
  direction: TelemetrySortOptionsDirection
39
42
  additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
40
43
 
41
44
  def to_dict(self) -> dict[str, Any]:
42
- field: Union[List[Union[Literal["count"], Literal["lookup"]]], List[Union[Literal["feedback_metrics"], str]]]
45
+ field: List[str]
43
46
  if isinstance(self.field, list):
44
47
  field = []
45
48
  for componentsschemas_telemetry_field_type_0_item_data in self.field:
46
- componentsschemas_telemetry_field_type_0_item: Union[Literal["feedback_metrics"], str]
47
- if isinstance(componentsschemas_telemetry_field_type_0_item_data, TelemetryFieldType0ItemType2):
49
+ componentsschemas_telemetry_field_type_0_item: str
50
+ if isinstance(componentsschemas_telemetry_field_type_0_item_data, TelemetryFieldType0ItemType0):
51
+ componentsschemas_telemetry_field_type_0_item = (
52
+ componentsschemas_telemetry_field_type_0_item_data.value
53
+ if isinstance(componentsschemas_telemetry_field_type_0_item_data, Enum)
54
+ else componentsschemas_telemetry_field_type_0_item_data
55
+ )
56
+ elif isinstance(componentsschemas_telemetry_field_type_0_item_data, TelemetryFieldType0ItemType2):
48
57
  componentsschemas_telemetry_field_type_0_item = (
49
58
  componentsschemas_telemetry_field_type_0_item_data.value
50
59
  if isinstance(componentsschemas_telemetry_field_type_0_item_data, Enum)
@@ -74,8 +83,8 @@ class TelemetrySortOptions:
74
83
  def _parse_field(
75
84
  data: object,
76
85
  ) -> Union[
77
- List[Union[Literal["count"], Literal["lookup"]]],
78
- List[Union[Literal["feedback_metrics"], TelemetryFieldType0ItemType2, str]],
86
+ List[Union[TelemetryFieldType0ItemType0, TelemetryFieldType0ItemType2, str]],
87
+ List[Union[TelemetryFieldType1ItemType0, TelemetryFieldType1ItemType1]],
79
88
  ]:
80
89
  try:
81
90
  if not isinstance(data, list):
@@ -86,13 +95,15 @@ class TelemetrySortOptions:
86
95
 
87
96
  def _parse_componentsschemas_telemetry_field_type_0_item(
88
97
  data: object,
89
- ) -> Union[Literal["feedback_metrics"], TelemetryFieldType0ItemType2, str]:
90
- componentsschemas_telemetry_field_type_0_item_type_0 = cast(Literal["feedback_metrics"], data)
91
- if componentsschemas_telemetry_field_type_0_item_type_0 != "feedback_metrics":
92
- raise ValueError(
93
- f"/components/schemas/TelemetryField_type_0_item_type_0 must match const 'feedback_metrics', got '{componentsschemas_telemetry_field_type_0_item_type_0}'"
94
- )
95
- return componentsschemas_telemetry_field_type_0_item_type_0
98
+ ) -> Union[TelemetryFieldType0ItemType0, TelemetryFieldType0ItemType2, str]:
99
+ try:
100
+ if not isinstance(data, str):
101
+ raise TypeError()
102
+ componentsschemas_telemetry_field_type_0_item_type_0 = TelemetryFieldType0ItemType0(data)
103
+
104
+ return componentsschemas_telemetry_field_type_0_item_type_0
105
+ except: # noqa: E722
106
+ pass
96
107
  try:
97
108
  if not isinstance(data, str):
98
109
  raise TypeError()
@@ -101,7 +112,7 @@ class TelemetrySortOptions:
101
112
  return componentsschemas_telemetry_field_type_0_item_type_2
102
113
  except: # noqa: E722
103
114
  pass
104
- return cast(Union[Literal["feedback_metrics"], TelemetryFieldType0ItemType2, str], data)
115
+ return cast(Union[TelemetryFieldType0ItemType0, TelemetryFieldType0ItemType2, str], data)
105
116
 
106
117
  componentsschemas_telemetry_field_type_0_item = (
107
118
  _parse_componentsschemas_telemetry_field_type_0_item(
@@ -122,18 +133,19 @@ class TelemetrySortOptions:
122
133
 
123
134
  def _parse_componentsschemas_telemetry_field_type_1_item(
124
135
  data: object,
125
- ) -> Union[Literal["count"], Literal["lookup"]]:
126
- componentsschemas_telemetry_field_type_1_item_type_0 = cast(Literal["lookup"], data)
127
- if componentsschemas_telemetry_field_type_1_item_type_0 != "lookup":
128
- raise ValueError(
129
- f"/components/schemas/TelemetryField_type_1_item_type_0 must match const 'lookup', got '{componentsschemas_telemetry_field_type_1_item_type_0}'"
130
- )
131
- return componentsschemas_telemetry_field_type_1_item_type_0
132
- componentsschemas_telemetry_field_type_1_item_type_1 = cast(Literal["count"], data)
133
- if componentsschemas_telemetry_field_type_1_item_type_1 != "count":
134
- raise ValueError(
135
- f"/components/schemas/TelemetryField_type_1_item_type_1 must match const 'count', got '{componentsschemas_telemetry_field_type_1_item_type_1}'"
136
- )
136
+ ) -> Union[TelemetryFieldType1ItemType0, TelemetryFieldType1ItemType1]:
137
+ try:
138
+ if not isinstance(data, str):
139
+ raise TypeError()
140
+ componentsschemas_telemetry_field_type_1_item_type_0 = TelemetryFieldType1ItemType0(data)
141
+
142
+ return componentsschemas_telemetry_field_type_1_item_type_0
143
+ except: # noqa: E722
144
+ pass
145
+ if not isinstance(data, str):
146
+ raise TypeError()
147
+ componentsschemas_telemetry_field_type_1_item_type_1 = TelemetryFieldType1ItemType1(data)
148
+
137
149
  return componentsschemas_telemetry_field_type_1_item_type_1
138
150
 
139
151
  componentsschemas_telemetry_field_type_1_item = _parse_componentsschemas_telemetry_field_type_1_item(
@@ -10,11 +10,14 @@ The main change is:
10
10
 
11
11
  # flake8: noqa: C901
12
12
 
13
- from typing import Any, Literal, Type, TypeVar, cast
13
+ from enum import Enum
14
+ from typing import Any, Type, TypeVar
14
15
 
15
16
  from attrs import define as _attrs_define
16
17
  from attrs import field as _attrs_field
17
18
 
19
+ from ..models.unauthenticated_error_response_status_code import UnauthenticatedErrorResponseStatusCode
20
+
18
21
  T = TypeVar("T", bound="UnauthenticatedErrorResponse")
19
22
 
20
23
 
@@ -22,14 +25,14 @@ T = TypeVar("T", bound="UnauthenticatedErrorResponse")
22
25
  class UnauthenticatedErrorResponse:
23
26
  """
24
27
  Attributes:
25
- status_code (Literal[401]):
28
+ status_code (UnauthenticatedErrorResponseStatusCode):
26
29
  """
27
30
 
28
- status_code: Literal[401]
31
+ status_code: UnauthenticatedErrorResponseStatusCode
29
32
  additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
30
33
 
31
34
  def to_dict(self) -> dict[str, Any]:
32
- status_code = self.status_code
35
+ status_code = self.status_code.value if isinstance(self.status_code, Enum) else self.status_code
33
36
 
34
37
  field_dict: dict[str, Any] = {}
35
38
  field_dict.update(self.additional_properties)
@@ -44,9 +47,7 @@ class UnauthenticatedErrorResponse:
44
47
  @classmethod
45
48
  def from_dict(cls: Type[T], src_dict: dict[str, Any]) -> T:
46
49
  d = src_dict.copy()
47
- status_code = cast(Literal[401], d.pop("status_code"))
48
- if status_code != 401:
49
- raise ValueError(f"status_code must match const 401, got '{status_code}'")
50
+ status_code = UnauthenticatedErrorResponseStatusCode(d.pop("status_code"))
50
51
 
51
52
  unauthenticated_error_response = cls(
52
53
  status_code=status_code,
@@ -0,0 +1,8 @@
1
+ from enum import IntEnum
2
+
3
+
4
+ class UnauthenticatedErrorResponseStatusCode(IntEnum):
5
+ VALUE_401 = 401
6
+
7
+ def __str__(self) -> str:
8
+ return str(self.value)
@@ -10,11 +10,14 @@ The main change is:
10
10
 
11
11
  # flake8: noqa: C901
12
12
 
13
- from typing import Any, Literal, Type, TypeVar, cast
13
+ from enum import Enum
14
+ from typing import Any, Type, TypeVar
14
15
 
15
16
  from attrs import define as _attrs_define
16
17
  from attrs import field as _attrs_field
17
18
 
19
+ from ..models.unauthorized_error_response_status_code import UnauthorizedErrorResponseStatusCode
20
+
18
21
  T = TypeVar("T", bound="UnauthorizedErrorResponse")
19
22
 
20
23
 
@@ -22,16 +25,16 @@ T = TypeVar("T", bound="UnauthorizedErrorResponse")
22
25
  class UnauthorizedErrorResponse:
23
26
  """
24
27
  Attributes:
25
- status_code (Literal[403]):
28
+ status_code (UnauthorizedErrorResponseStatusCode):
26
29
  reason (str):
27
30
  """
28
31
 
29
- status_code: Literal[403]
32
+ status_code: UnauthorizedErrorResponseStatusCode
30
33
  reason: str
31
34
  additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
32
35
 
33
36
  def to_dict(self) -> dict[str, Any]:
34
- status_code = self.status_code
37
+ status_code = self.status_code.value if isinstance(self.status_code, Enum) else self.status_code
35
38
 
36
39
  reason = self.reason
37
40
 
@@ -49,9 +52,7 @@ class UnauthorizedErrorResponse:
49
52
  @classmethod
50
53
  def from_dict(cls: Type[T], src_dict: dict[str, Any]) -> T:
51
54
  d = src_dict.copy()
52
- status_code = cast(Literal[403], d.pop("status_code"))
53
- if status_code != 403:
54
- raise ValueError(f"status_code must match const 403, got '{status_code}'")
55
+ status_code = UnauthorizedErrorResponseStatusCode(d.pop("status_code"))
55
56
 
56
57
  reason = d.pop("reason")
57
58
 
@@ -0,0 +1,8 @@
1
+ from enum import IntEnum
2
+
3
+
4
+ class UnauthorizedErrorResponseStatusCode(IntEnum):
5
+ VALUE_403 = 403
6
+
7
+ def __str__(self) -> str:
8
+ return str(self.value)
@@ -0,0 +1,73 @@
1
+ """
2
+ This file is generated by the openapi-python-client tool via the generate_api_client.py script
3
+
4
+ It is a customized template from the openapi-python-client tool's default template:
5
+ https://github.com/openapi-generators/openapi-python-client/blob/861ef5622f10fc96d240dc9becb0edf94e61446c/openapi_python_client/templates/model.py.jinja
6
+
7
+ The main change is:
8
+ - Fix typing issues
9
+ """
10
+
11
+ # flake8: noqa: C901
12
+
13
+ from enum import Enum
14
+ from typing import Any, Type, TypeVar
15
+
16
+ from attrs import define as _attrs_define
17
+ from attrs import field as _attrs_field
18
+
19
+ from ..models.update_org_plan_request_tier import UpdateOrgPlanRequestTier
20
+
21
+ T = TypeVar("T", bound="UpdateOrgPlanRequest")
22
+
23
+
24
+ @_attrs_define
25
+ class UpdateOrgPlanRequest:
26
+ """
27
+ Attributes:
28
+ tier (UpdateOrgPlanRequestTier):
29
+ """
30
+
31
+ tier: UpdateOrgPlanRequestTier
32
+ additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
33
+
34
+ def to_dict(self) -> dict[str, Any]:
35
+ tier = self.tier.value if isinstance(self.tier, Enum) else self.tier
36
+
37
+ field_dict: dict[str, Any] = {}
38
+ field_dict.update(self.additional_properties)
39
+ field_dict.update(
40
+ {
41
+ "tier": tier,
42
+ }
43
+ )
44
+
45
+ return field_dict
46
+
47
+ @classmethod
48
+ def from_dict(cls: Type[T], src_dict: dict[str, Any]) -> T:
49
+ d = src_dict.copy()
50
+ tier = UpdateOrgPlanRequestTier(d.pop("tier"))
51
+
52
+ update_org_plan_request = cls(
53
+ tier=tier,
54
+ )
55
+
56
+ update_org_plan_request.additional_properties = d
57
+ return update_org_plan_request
58
+
59
+ @property
60
+ def additional_keys(self) -> list[str]:
61
+ return list(self.additional_properties.keys())
62
+
63
+ def __getitem__(self, key: str) -> Any:
64
+ return self.additional_properties[key]
65
+
66
+ def __setitem__(self, key: str, value: Any) -> None:
67
+ self.additional_properties[key] = value
68
+
69
+ def __delitem__(self, key: str) -> None:
70
+ del self.additional_properties[key]
71
+
72
+ def __contains__(self, key: str) -> bool:
73
+ return key in self.additional_properties
@@ -0,0 +1,11 @@
1
+ from enum import Enum
2
+
3
+
4
+ class UpdateOrgPlanRequestTier(str, Enum):
5
+ CANCELLED = "CANCELLED"
6
+ ENTERPRISE = "ENTERPRISE"
7
+ FREE = "FREE"
8
+ PRO = "PRO"
9
+
10
+ def __str__(self) -> str:
11
+ return str(self.value)
@@ -24,7 +24,7 @@ def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
24
24
 
25
25
 
26
26
  # We don't want to depend on transformers just for the eval_pred type in orca_sdk
27
- def transform_eval_pred(eval_pred: Any) -> tuple[NDArray[np.int64], NDArray[np.float32]]:
27
+ def transform_eval_pred(eval_pred: Any) -> tuple[NDArray, NDArray[np.float32]]:
28
28
  # convert results from Trainer compute_metrics param for use in calculate_classification_metrics
29
29
  logits, references = eval_pred # transformers.trainer_utils.EvalPrediction
30
30
  if isinstance(logits, tuple):
@@ -35,7 +35,10 @@ from ._generated_api_client.models import (
35
35
  from ._generated_api_client.models import (
36
36
  PredictionSortItemItemType1 as PredictionSortDirection,
37
37
  )
38
- from ._generated_api_client.models import PredictiveModelUpdate, RACHeadType
38
+ from ._generated_api_client.models import (
39
+ PredictiveModelUpdate,
40
+ RACHeadType,
41
+ )
39
42
  from ._generated_api_client.types import UNSET as CLIENT_UNSET
40
43
  from ._shared.metrics import ClassificationMetrics, calculate_classification_metrics
41
44
  from ._utils.common import UNSET, CreateMode, DropMode
@@ -374,3 +374,56 @@ def test_explain(writable_memoryset: LabeledMemoryset):
374
374
  raise e
375
375
  finally:
376
376
  ClassificationModel.drop("test_model_for_explain")
377
+
378
+
379
+ @skip_in_ci("We don't have Anthropic API key in CI")
380
+ def test_action_recommendation(writable_memoryset: LabeledMemoryset):
381
+ """Test getting action recommendations for predictions"""
382
+
383
+ writable_memoryset.analyze(
384
+ {"name": "neighbor", "neighbor_counts": [1, 3]},
385
+ lookup_count=3,
386
+ )
387
+
388
+ model = ClassificationModel.create(
389
+ "test_model_for_action",
390
+ writable_memoryset,
391
+ num_classes=2,
392
+ memory_lookup_count=3,
393
+ description="This is a test model for action recommendations",
394
+ )
395
+
396
+ # Make a prediction with expected label to simulate incorrect prediction
397
+ prediction = model.predict("Do you love soup?", expected_labels=1)
398
+
399
+ try:
400
+ # Get action recommendation
401
+ action, rationale = prediction.recommend_action()
402
+
403
+ assert action is not None
404
+ assert rationale is not None
405
+ assert action in ["remove_duplicates", "detect_mislabels", "add_memories", "finetuning"]
406
+ assert len(rationale) > 10
407
+
408
+ # Test memory suggestions
409
+ memory_suggestions = prediction.generate_memory_suggestions(num_memories=2)
410
+
411
+ assert memory_suggestions is not None
412
+ assert len(memory_suggestions) == 2
413
+
414
+ for suggestion in memory_suggestions:
415
+ assert isinstance(suggestion, dict)
416
+ assert "value" in suggestion
417
+ assert "label" in suggestion
418
+ assert isinstance(suggestion["value"], str)
419
+ assert len(suggestion["value"]) > 0
420
+ assert isinstance(suggestion["label"], int)
421
+ assert 0 <= suggestion["label"] < len(model.memoryset.label_names)
422
+
423
+ except Exception as e:
424
+ if "ANTHROPIC_API_KEY" in str(e):
425
+ logging.info("Skipping agent tests because ANTHROPIC_API_KEY is not set")
426
+ else:
427
+ raise e
428
+ finally:
429
+ ClassificationModel.drop("test_model_for_action")
orca_sdk/credentials.py CHANGED
@@ -106,6 +106,20 @@ class OrcaCredentials:
106
106
  """
107
107
  delete_api_key(name_or_id=name)
108
108
 
109
+ @staticmethod
110
+ def set_headers(headers: dict[str, str]):
111
+ """
112
+ Add or override default HTTP headers for all Orca API requests.
113
+
114
+ Args:
115
+ headers: Header names with their string values
116
+
117
+ Notes:
118
+ New keys are merged into the existing headers; this overwrites headers with the
119
+ same name but leaves other headers untouched.
120
+ """
121
+ set_headers(get_headers() | headers)
122
+
109
123
  @staticmethod
110
124
  def set_api_key(api_key: str, check_validity: bool = True):
111
125
  """
@@ -121,6 +135,6 @@ class OrcaCredentials:
121
135
  Raises:
122
136
  ValueError: if the API key is invalid and `check_validity` is True
123
137
  """
124
- set_headers(get_headers() | {"Api-Key": api_key})
138
+ OrcaCredentials.set_headers({"Api-Key": api_key})
125
139
  if check_validity:
126
140
  check_authentication()