orca-sdk 0.0.90__py3-none-any.whl → 0.0.91__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24):
  1. orca_sdk/_generated_api_client/api/__init__.py +12 -0
  2. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +12 -12
  3. orca_sdk/_generated_api_client/api/classification_model/update_model_classification_model_name_or_id_patch.py +183 -0
  4. orca_sdk/_generated_api_client/api/memoryset/batch_delete_memoryset_batch_delete_memoryset_post.py +168 -0
  5. orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +183 -0
  6. orca_sdk/_generated_api_client/models/__init__.py +8 -2
  7. orca_sdk/_generated_api_client/models/{label_prediction_result.py → base_label_prediction_result.py} +15 -8
  8. orca_sdk/_generated_api_client/models/delete_memorysets_request.py +70 -0
  9. orca_sdk/_generated_api_client/models/labeled_memoryset_update.py +113 -0
  10. orca_sdk/_generated_api_client/models/prediction_request.py +9 -0
  11. orca_sdk/_generated_api_client/models/rac_model_update.py +82 -0
  12. orca_sdk/_utils/analysis_ui.py +1 -1
  13. orca_sdk/_utils/analysis_ui_style.css +0 -3
  14. orca_sdk/classification_model.py +47 -4
  15. orca_sdk/classification_model_test.py +26 -0
  16. orca_sdk/conftest.py +13 -1
  17. orca_sdk/embedding_model.py +2 -0
  18. orca_sdk/memoryset.py +13 -0
  19. orca_sdk/memoryset_test.py +27 -6
  20. orca_sdk/telemetry.py +10 -2
  21. orca_sdk/telemetry_test.py +6 -0
  22. {orca_sdk-0.0.90.dist-info → orca_sdk-0.0.91.dist-info}/METADATA +1 -1
  23. {orca_sdk-0.0.90.dist-info → orca_sdk-0.0.91.dist-info}/RECORD +24 -18
  24. {orca_sdk-0.0.90.dist-info → orca_sdk-0.0.91.dist-info}/WHEEL +0 -0
@@ -3,6 +3,7 @@
3
3
  from .analyze_neighbor_labels_result import AnalyzeNeighborLabelsResult
4
4
  from .api_key_metadata import ApiKeyMetadata
5
5
  from .api_key_metadata_scope_item import ApiKeyMetadataScopeItem
6
+ from .base_label_prediction_result import BaseLabelPredictionResult
6
7
  from .base_model import BaseModel
7
8
  from .body_create_datasource_datasource_post import BodyCreateDatasourceDatasourcePost
8
9
  from .classification_evaluation_result import ClassificationEvaluationResult
@@ -19,6 +20,7 @@ from .create_labeled_memoryset_request import CreateLabeledMemorysetRequest
19
20
  from .create_rac_model_request import CreateRACModelRequest
20
21
  from .datasource_metadata import DatasourceMetadata
21
22
  from .delete_memories_request import DeleteMemoriesRequest
23
+ from .delete_memorysets_request import DeleteMemorysetsRequest
22
24
  from .embed_request import EmbedRequest
23
25
  from .embedding_evaluation_request import EmbeddingEvaluationRequest
24
26
  from .embedding_evaluation_response import EmbeddingEvaluationResponse
@@ -41,7 +43,6 @@ from .internal_server_error_response import InternalServerErrorResponse
41
43
  from .label_class_metrics import LabelClassMetrics
42
44
  from .label_prediction_memory_lookup import LabelPredictionMemoryLookup
43
45
  from .label_prediction_memory_lookup_metadata import LabelPredictionMemoryLookupMetadata
44
- from .label_prediction_result import LabelPredictionResult
45
46
  from .label_prediction_with_memories_and_feedback import LabelPredictionWithMemoriesAndFeedback
46
47
  from .labeled_memory import LabeledMemory
47
48
  from .labeled_memory_insert import LabeledMemoryInsert
@@ -56,6 +57,7 @@ from .labeled_memory_with_feedback_metrics import LabeledMemoryWithFeedbackMetri
56
57
  from .labeled_memory_with_feedback_metrics_feedback_metrics import LabeledMemoryWithFeedbackMetricsFeedbackMetrics
57
58
  from .labeled_memory_with_feedback_metrics_metadata import LabeledMemoryWithFeedbackMetricsMetadata
58
59
  from .labeled_memoryset_metadata import LabeledMemorysetMetadata
60
+ from .labeled_memoryset_update import LabeledMemorysetUpdate
59
61
  from .list_memories_request import ListMemoriesRequest
60
62
  from .list_predictions_request import ListPredictionsRequest
61
63
  from .lookup_request import LookupRequest
@@ -93,6 +95,7 @@ from .pretrained_embedding_model_metadata import PretrainedEmbeddingModelMetadat
93
95
  from .pretrained_embedding_model_name import PretrainedEmbeddingModelName
94
96
  from .rac_head_type import RACHeadType
95
97
  from .rac_model_metadata import RACModelMetadata
98
+ from .rac_model_update import RACModelUpdate
96
99
  from .roc_curve import ROCCurve
97
100
  from .service_unavailable_error_response import ServiceUnavailableErrorResponse
98
101
  from .task import Task
@@ -112,6 +115,7 @@ __all__ = (
112
115
  "AnalyzeNeighborLabelsResult",
113
116
  "ApiKeyMetadata",
114
117
  "ApiKeyMetadataScopeItem",
118
+ "BaseLabelPredictionResult",
115
119
  "BaseModel",
116
120
  "BodyCreateDatasourceDatasourcePost",
117
121
  "ClassificationEvaluationResult",
@@ -128,6 +132,7 @@ __all__ = (
128
132
  "CreateRACModelRequest",
129
133
  "DatasourceMetadata",
130
134
  "DeleteMemoriesRequest",
135
+ "DeleteMemorysetsRequest",
131
136
  "EmbeddingEvaluationRequest",
132
137
  "EmbeddingEvaluationResponse",
133
138
  "EmbeddingEvaluationResult",
@@ -156,6 +161,7 @@ __all__ = (
156
161
  "LabeledMemoryMetadata",
157
162
  "LabeledMemoryMetrics",
158
163
  "LabeledMemorysetMetadata",
164
+ "LabeledMemorysetUpdate",
159
165
  "LabeledMemoryUpdate",
160
166
  "LabeledMemoryUpdateMetadataType0",
161
167
  "LabeledMemoryWithFeedbackMetrics",
@@ -163,7 +169,6 @@ __all__ = (
163
169
  "LabeledMemoryWithFeedbackMetricsMetadata",
164
170
  "LabelPredictionMemoryLookup",
165
171
  "LabelPredictionMemoryLookupMetadata",
166
- "LabelPredictionResult",
167
172
  "LabelPredictionWithMemoriesAndFeedback",
168
173
  "ListMemoriesRequest",
169
174
  "ListPredictionsRequest",
@@ -202,6 +207,7 @@ __all__ = (
202
207
  "PretrainedEmbeddingModelName",
203
208
  "RACHeadType",
204
209
  "RACModelMetadata",
210
+ "RACModelUpdate",
205
211
  "ROCCurve",
206
212
  "ServiceUnavailableErrorResponse",
207
213
  "Task",
@@ -15,22 +15,22 @@ from typing import Any, Type, TypeVar, Union, cast
15
15
  from attrs import define as _attrs_define
16
16
  from attrs import field as _attrs_field
17
17
 
18
- T = TypeVar("T", bound="LabelPredictionResult")
18
+ T = TypeVar("T", bound="BaseLabelPredictionResult")
19
19
 
20
20
 
21
21
  @_attrs_define
22
- class LabelPredictionResult:
22
+ class BaseLabelPredictionResult:
23
23
  """Predicted label and confidence for a single input.
24
24
 
25
25
  Attributes:
26
- prediction_id (str):
26
+ prediction_id (Union[None, str]):
27
27
  confidence (float):
28
28
  anomaly_score (Union[None, float]):
29
29
  label (int):
30
30
  label_name (Union[None, str]):
31
31
  """
32
32
 
33
- prediction_id: str
33
+ prediction_id: Union[None, str]
34
34
  confidence: float
35
35
  anomaly_score: Union[None, float]
36
36
  label: int
@@ -38,6 +38,7 @@ class LabelPredictionResult:
38
38
  additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
39
39
 
40
40
  def to_dict(self) -> dict[str, Any]:
41
+ prediction_id: Union[None, str]
41
42
  prediction_id = self.prediction_id
42
43
 
43
44
  confidence = self.confidence
@@ -67,7 +68,13 @@ class LabelPredictionResult:
67
68
  @classmethod
68
69
  def from_dict(cls: Type[T], src_dict: dict[str, Any]) -> T:
69
70
  d = src_dict.copy()
70
- prediction_id = d.pop("prediction_id")
71
+
72
+ def _parse_prediction_id(data: object) -> Union[None, str]:
73
+ if data is None:
74
+ return data
75
+ return cast(Union[None, str], data)
76
+
77
+ prediction_id = _parse_prediction_id(d.pop("prediction_id"))
71
78
 
72
79
  confidence = d.pop("confidence")
73
80
 
@@ -87,7 +94,7 @@ class LabelPredictionResult:
87
94
 
88
95
  label_name = _parse_label_name(d.pop("label_name"))
89
96
 
90
- label_prediction_result = cls(
97
+ base_label_prediction_result = cls(
91
98
  prediction_id=prediction_id,
92
99
  confidence=confidence,
93
100
  anomaly_score=anomaly_score,
@@ -95,8 +102,8 @@ class LabelPredictionResult:
95
102
  label_name=label_name,
96
103
  )
97
104
 
98
- label_prediction_result.additional_properties = d
99
- return label_prediction_result
105
+ base_label_prediction_result.additional_properties = d
106
+ return base_label_prediction_result
100
107
 
101
108
  @property
102
109
  def additional_keys(self) -> list[str]:
@@ -0,0 +1,70 @@
1
+ """
2
+ This file is generated by the openapi-python-client tool via the generate_api_client.py script
3
+
4
+ It is a customized template from the openapi-python-client tool's default template:
5
+ https://github.com/openapi-generators/openapi-python-client/blob/861ef5622f10fc96d240dc9becb0edf94e61446c/openapi_python_client/templates/model.py.jinja
6
+
7
+ The main change is:
8
+ - Fix typing issues
9
+ """
10
+
11
+ # flake8: noqa: C901
12
+
13
+ from typing import Any, List, Type, TypeVar, cast
14
+
15
+ from attrs import define as _attrs_define
16
+ from attrs import field as _attrs_field
17
+
18
+ T = TypeVar("T", bound="DeleteMemorysetsRequest")
19
+
20
+
21
+ @_attrs_define
22
+ class DeleteMemorysetsRequest:
23
+ """
24
+ Attributes:
25
+ memoryset_ids (List[str]):
26
+ """
27
+
28
+ memoryset_ids: List[str]
29
+ additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
30
+
31
+ def to_dict(self) -> dict[str, Any]:
32
+ memoryset_ids = self.memoryset_ids
33
+
34
+ field_dict: dict[str, Any] = {}
35
+ field_dict.update(self.additional_properties)
36
+ field_dict.update(
37
+ {
38
+ "memoryset_ids": memoryset_ids,
39
+ }
40
+ )
41
+
42
+ return field_dict
43
+
44
+ @classmethod
45
+ def from_dict(cls: Type[T], src_dict: dict[str, Any]) -> T:
46
+ d = src_dict.copy()
47
+ memoryset_ids = cast(List[str], d.pop("memoryset_ids"))
48
+
49
+ delete_memorysets_request = cls(
50
+ memoryset_ids=memoryset_ids,
51
+ )
52
+
53
+ delete_memorysets_request.additional_properties = d
54
+ return delete_memorysets_request
55
+
56
+ @property
57
+ def additional_keys(self) -> list[str]:
58
+ return list(self.additional_properties.keys())
59
+
60
+ def __getitem__(self, key: str) -> Any:
61
+ return self.additional_properties[key]
62
+
63
+ def __setitem__(self, key: str, value: Any) -> None:
64
+ self.additional_properties[key] = value
65
+
66
+ def __delitem__(self, key: str) -> None:
67
+ del self.additional_properties[key]
68
+
69
+ def __contains__(self, key: str) -> bool:
70
+ return key in self.additional_properties
@@ -0,0 +1,113 @@
1
+ """
2
+ This file is generated by the openapi-python-client tool via the generate_api_client.py script
3
+
4
+ It is a customized template from the openapi-python-client tool's default template:
5
+ https://github.com/openapi-generators/openapi-python-client/blob/861ef5622f10fc96d240dc9becb0edf94e61446c/openapi_python_client/templates/model.py.jinja
6
+
7
+ The main change is:
8
+ - Fix typing issues
9
+ """
10
+
11
+ # flake8: noqa: C901
12
+
13
+ from typing import Any, List, Type, TypeVar, Union, cast
14
+
15
+ from attrs import define as _attrs_define
16
+ from attrs import field as _attrs_field
17
+
18
+ from ..types import UNSET, Unset
19
+
20
+ T = TypeVar("T", bound="LabeledMemorysetUpdate")
21
+
22
+
23
+ @_attrs_define
24
+ class LabeledMemorysetUpdate:
25
+ """
26
+ Attributes:
27
+ label_names (Union[List[str], None, Unset]):
28
+ description (Union[None, Unset, str]):
29
+ """
30
+
31
+ label_names: Union[List[str], None, Unset] = UNSET
32
+ description: Union[None, Unset, str] = UNSET
33
+ additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
34
+
35
+ def to_dict(self) -> dict[str, Any]:
36
+ label_names: Union[List[str], None, Unset]
37
+ if isinstance(self.label_names, Unset):
38
+ label_names = UNSET
39
+ elif isinstance(self.label_names, list):
40
+ label_names = self.label_names
41
+
42
+ else:
43
+ label_names = self.label_names
44
+
45
+ description: Union[None, Unset, str]
46
+ if isinstance(self.description, Unset):
47
+ description = UNSET
48
+ else:
49
+ description = self.description
50
+
51
+ field_dict: dict[str, Any] = {}
52
+ field_dict.update(self.additional_properties)
53
+ field_dict.update({})
54
+ if label_names is not UNSET:
55
+ field_dict["label_names"] = label_names
56
+ if description is not UNSET:
57
+ field_dict["description"] = description
58
+
59
+ return field_dict
60
+
61
+ @classmethod
62
+ def from_dict(cls: Type[T], src_dict: dict[str, Any]) -> T:
63
+ d = src_dict.copy()
64
+
65
+ def _parse_label_names(data: object) -> Union[List[str], None, Unset]:
66
+ if data is None:
67
+ return data
68
+ if isinstance(data, Unset):
69
+ return data
70
+ try:
71
+ if not isinstance(data, list):
72
+ raise TypeError()
73
+ label_names_type_0 = cast(List[str], data)
74
+
75
+ return label_names_type_0
76
+ except: # noqa: E722
77
+ pass
78
+ return cast(Union[List[str], None, Unset], data)
79
+
80
+ label_names = _parse_label_names(d.pop("label_names", UNSET))
81
+
82
+ def _parse_description(data: object) -> Union[None, Unset, str]:
83
+ if data is None:
84
+ return data
85
+ if isinstance(data, Unset):
86
+ return data
87
+ return cast(Union[None, Unset, str], data)
88
+
89
+ description = _parse_description(d.pop("description", UNSET))
90
+
91
+ labeled_memoryset_update = cls(
92
+ label_names=label_names,
93
+ description=description,
94
+ )
95
+
96
+ labeled_memoryset_update.additional_properties = d
97
+ return labeled_memoryset_update
98
+
99
+ @property
100
+ def additional_keys(self) -> list[str]:
101
+ return list(self.additional_properties.keys())
102
+
103
+ def __getitem__(self, key: str) -> Any:
104
+ return self.additional_properties[key]
105
+
106
+ def __setitem__(self, key: str, value: Any) -> None:
107
+ self.additional_properties[key] = value
108
+
109
+ def __delitem__(self, key: str) -> None:
110
+ del self.additional_properties[key]
111
+
112
+ def __contains__(self, key: str) -> bool:
113
+ return key in self.additional_properties
@@ -28,12 +28,14 @@ class PredictionRequest:
28
28
  expected_labels (Union[List[int], None, Unset]):
29
29
  tags (Union[Unset, List[str]]):
30
30
  memoryset_override_id (Union[None, Unset, str]):
31
+ disable_telemetry (Union[Unset, bool]): Default: False.
31
32
  """
32
33
 
33
34
  input_values: List[str]
34
35
  expected_labels: Union[List[int], None, Unset] = UNSET
35
36
  tags: Union[Unset, List[str]] = UNSET
36
37
  memoryset_override_id: Union[None, Unset, str] = UNSET
38
+ disable_telemetry: Union[Unset, bool] = False
37
39
  additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
38
40
 
39
41
  def to_dict(self) -> dict[str, Any]:
@@ -60,6 +62,8 @@ class PredictionRequest:
60
62
  else:
61
63
  memoryset_override_id = self.memoryset_override_id
62
64
 
65
+ disable_telemetry = self.disable_telemetry
66
+
63
67
  field_dict: dict[str, Any] = {}
64
68
  field_dict.update(self.additional_properties)
65
69
  field_dict.update(
@@ -73,6 +77,8 @@ class PredictionRequest:
73
77
  field_dict["tags"] = tags
74
78
  if memoryset_override_id is not UNSET:
75
79
  field_dict["memoryset_override_id"] = memoryset_override_id
80
+ if disable_telemetry is not UNSET:
81
+ field_dict["disable_telemetry"] = disable_telemetry
76
82
 
77
83
  return field_dict
78
84
 
@@ -150,11 +156,14 @@ class PredictionRequest:
150
156
 
151
157
  memoryset_override_id = _parse_memoryset_override_id(d.pop("memoryset_override_id", UNSET))
152
158
 
159
+ disable_telemetry = d.pop("disable_telemetry", UNSET)
160
+
153
161
  prediction_request = cls(
154
162
  input_values=input_values,
155
163
  expected_labels=expected_labels,
156
164
  tags=tags,
157
165
  memoryset_override_id=memoryset_override_id,
166
+ disable_telemetry=disable_telemetry,
158
167
  )
159
168
 
160
169
  prediction_request.additional_properties = d
@@ -0,0 +1,82 @@
1
+ """
2
+ This file is generated by the openapi-python-client tool via the generate_api_client.py script
3
+
4
+ It is a customized template from the openapi-python-client tool's default template:
5
+ https://github.com/openapi-generators/openapi-python-client/blob/861ef5622f10fc96d240dc9becb0edf94e61446c/openapi_python_client/templates/model.py.jinja
6
+
7
+ The main change is:
8
+ - Fix typing issues
9
+ """
10
+
11
+ # flake8: noqa: C901
12
+
13
+ from typing import Any, Type, TypeVar, Union, cast
14
+
15
+ from attrs import define as _attrs_define
16
+ from attrs import field as _attrs_field
17
+
18
+ from ..types import UNSET, Unset
19
+
20
+ T = TypeVar("T", bound="RACModelUpdate")
21
+
22
+
23
+ @_attrs_define
24
+ class RACModelUpdate:
25
+ """
26
+ Attributes:
27
+ description (Union[None, Unset, str]):
28
+ """
29
+
30
+ description: Union[None, Unset, str] = UNSET
31
+ additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
32
+
33
+ def to_dict(self) -> dict[str, Any]:
34
+ description: Union[None, Unset, str]
35
+ if isinstance(self.description, Unset):
36
+ description = UNSET
37
+ else:
38
+ description = self.description
39
+
40
+ field_dict: dict[str, Any] = {}
41
+ field_dict.update(self.additional_properties)
42
+ field_dict.update({})
43
+ if description is not UNSET:
44
+ field_dict["description"] = description
45
+
46
+ return field_dict
47
+
48
+ @classmethod
49
+ def from_dict(cls: Type[T], src_dict: dict[str, Any]) -> T:
50
+ d = src_dict.copy()
51
+
52
+ def _parse_description(data: object) -> Union[None, Unset, str]:
53
+ if data is None:
54
+ return data
55
+ if isinstance(data, Unset):
56
+ return data
57
+ return cast(Union[None, Unset, str], data)
58
+
59
+ description = _parse_description(d.pop("description", UNSET))
60
+
61
+ rac_model_update = cls(
62
+ description=description,
63
+ )
64
+
65
+ rac_model_update.additional_properties = d
66
+ return rac_model_update
67
+
68
+ @property
69
+ def additional_keys(self) -> list[str]:
70
+ return list(self.additional_properties.keys())
71
+
72
+ def __getitem__(self, key: str) -> Any:
73
+ return self.additional_properties[key]
74
+
75
+ def __setitem__(self, key: str, value: Any) -> None:
76
+ self.additional_properties[key] = value
77
+
78
+ def __delitem__(self, key: str) -> None:
79
+ del self.additional_properties[key]
80
+
81
+ def __contains__(self, key: str) -> bool:
82
+ return key in self.additional_properties
@@ -152,7 +152,7 @@ def display_suggested_memory_relabels(memoryset: LabeledMemoryset):
152
152
  predicted_label_name = label_names[predicted_label]
153
153
  predicted_label_confidence = mem.metrics.get("neighbor_predicted_label_confidence", 0)
154
154
 
155
- with gr.Row(equal_height=True, variant="panel", elem_classes="white" if i % 2 == 0 else None):
155
+ with gr.Row(equal_height=True, variant="panel"):
156
156
  with gr.Column(scale=9):
157
157
  assert isinstance(mem.value, str)
158
158
  gr.Markdown(mem.value, label="Value", height=50)
@@ -1,6 +1,3 @@
1
- .white {
2
- background-color: white;
3
- }
4
1
  .centered input {
5
2
  margin: auto;
6
3
  }
@@ -16,6 +16,7 @@ from ._generated_api_client.api import (
16
16
  list_predictions,
17
17
  predict_gpu,
18
18
  record_prediction_feedback,
19
+ update_model,
19
20
  )
20
21
  from ._generated_api_client.models import (
21
22
  CreateRACModelRequest,
@@ -31,9 +32,10 @@ from ._generated_api_client.models import (
31
32
  from ._generated_api_client.models import (
32
33
  RACHeadType,
33
34
  RACModelMetadata,
35
+ RACModelUpdate,
34
36
  )
35
37
  from ._generated_api_client.models.prediction_request import PredictionRequest
36
- from ._utils.common import CreateMode, DropMode
38
+ from ._utils.common import UNSET, CreateMode, DropMode
37
39
  from ._utils.task import wait_for_task
38
40
  from .datasource import Datasource
39
41
  from .memoryset import LabeledMemoryset
@@ -270,18 +272,53 @@ class ClassificationModel:
270
272
  if if_not_exists == "error":
271
273
  raise
272
274
 
275
+ def refresh(self):
276
+ """Refresh the model data from the OrcaCloud"""
277
+ self.__dict__.update(ClassificationModel.open(self.name).__dict__)
278
+
279
+ def update_metadata(self, *, description: str | None = UNSET) -> None:
280
+ """
281
+ Update editable classification model metadata properties.
282
+
283
+ Params:
284
+ description: Value to set for the description, defaults to `[UNSET]` if not provided.
285
+
286
+ Examples:
287
+ Update the description:
288
+ >>> model.update(description="New description")
289
+
290
+ Remove description:
291
+ >>> model.update(description=None)
292
+ """
293
+ update_model(self.id, body=RACModelUpdate(description=description))
294
+ self.refresh()
295
+
273
296
  @overload
274
297
  def predict(
275
- self, value: list[str], expected_labels: list[int] | None = None, tags: set[str] = set()
298
+ self,
299
+ value: list[str],
300
+ expected_labels: list[int] | None = None,
301
+ tags: set[str] = set(),
302
+ disable_telemetry: bool = False,
276
303
  ) -> list[LabelPrediction]:
277
304
  pass
278
305
 
279
306
  @overload
280
- def predict(self, value: str, expected_labels: int | None = None, tags: set[str] = set()) -> LabelPrediction:
307
+ def predict(
308
+ self,
309
+ value: str,
310
+ expected_labels: int | None = None,
311
+ tags: set[str] = set(),
312
+ disable_telemetry: bool = False,
313
+ ) -> LabelPrediction:
281
314
  pass
282
315
 
283
316
  def predict(
284
- self, value: list[str] | str, expected_labels: list[int] | int | None = None, tags: set[str] = set()
317
+ self,
318
+ value: list[str] | str,
319
+ expected_labels: list[int] | int | None = None,
320
+ tags: set[str] = set(),
321
+ disable_telemetry: bool = False,
285
322
  ) -> list[LabelPrediction] | LabelPrediction:
286
323
  """
287
324
  Predict label(s) for the given input value(s) grounded in similar memories
@@ -290,6 +327,7 @@ class ClassificationModel:
290
327
  value: Value(s) to get predict the labels of
291
328
  expected_labels: Expected label(s) for the given input to record for model evaluation
292
329
  tags: Tags to add to the prediction(s)
330
+ disable_telemetry: Whether to disable telemetry for the prediction(s)
293
331
 
294
332
  Returns:
295
333
  Label prediction or list of label predictions
@@ -318,8 +356,13 @@ class ClassificationModel:
318
356
  else [expected_labels] if expected_labels is not None else None
319
357
  ),
320
358
  tags=list(tags),
359
+ disable_telemetry=disable_telemetry,
321
360
  ),
322
361
  )
362
+
363
+ if not disable_telemetry and any(p.prediction_id is None for p in response):
364
+ raise RuntimeError("Failed to save prediction to database.")
365
+
323
366
  predictions = [
324
367
  LabelPrediction(
325
368
  prediction_id=prediction.prediction_id,
@@ -95,6 +95,17 @@ def test_list_models_unauthorized(unauthorized, model: ClassificationModel):
95
95
  assert ClassificationModel.all() == []
96
96
 
97
97
 
98
+ def test_update_model(model: ClassificationModel):
99
+ model.update_metadata(description="New description")
100
+ assert model.description == "New description"
101
+
102
+
103
+ def test_update_model_no_description(model: ClassificationModel):
104
+ assert model.description is not None
105
+ model.update_metadata(description=None)
106
+ assert model.description is None
107
+
108
+
98
109
  def test_delete_model(memoryset: LabeledMemoryset):
99
110
  ClassificationModel.create("model_to_delete", LabeledMemoryset.open(memoryset.name))
100
111
  assert ClassificationModel.open("model_to_delete")
@@ -168,6 +179,21 @@ def test_evaluate_with_telemetry(model):
168
179
  def test_predict(model: ClassificationModel, label_names: list[str]):
169
180
  predictions = model.predict(["Do you love soup?", "Are cats cute?"])
170
181
  assert len(predictions) == 2
182
+ assert predictions[0].prediction_id is not None
183
+ assert predictions[1].prediction_id is not None
184
+ assert predictions[0].label == 0
185
+ assert predictions[0].label_name == label_names[0]
186
+ assert 0 <= predictions[0].confidence <= 1
187
+ assert predictions[1].label == 1
188
+ assert predictions[1].label_name == label_names[1]
189
+ assert 0 <= predictions[1].confidence <= 1
190
+
191
+
192
+ def test_predict_disable_telemetry(model: ClassificationModel, label_names: list[str]):
193
+ predictions = model.predict(["Do you love soup?", "Are cats cute?"], disable_telemetry=True)
194
+ assert len(predictions) == 2
195
+ assert predictions[0].prediction_id is None
196
+ assert predictions[1].prediction_id is None
171
197
  assert predictions[0].label == 0
172
198
  assert predictions[0].label_name == label_names[0]
173
199
  assert 0 <= predictions[0].confidence <= 1
orca_sdk/conftest.py CHANGED
@@ -75,6 +75,16 @@ SAMPLE_DATA = [
75
75
  {"text": "i love cats", "label": 1, "key": "val4", "score": 0.4, "source_id": "s4"},
76
76
  {"text": "everyone loves cats", "label": 1, "key": "val5", "score": 0.5, "source_id": "s5"},
77
77
  {"text": "soup is great for the winter", "label": 0, "key": "val6", "score": 0.6, "source_id": "s6"},
78
+ {"text": "hot soup on a rainy day!", "label": 0, "key": "val7", "score": 0.7, "source_id": "s7"},
79
+ {"text": "cats sleep all day", "label": 1, "key": "val8", "score": 0.8, "source_id": "s8"},
80
+ {"text": "homemade soup recipes", "label": 0, "key": "val9", "score": 0.9, "source_id": "s9"},
81
+ {"text": "cats purr when happy", "label": 1, "key": "val10", "score": 1.0, "source_id": "s10"},
82
+ {"text": "chicken noodle soup is classic", "label": 0, "key": "val11", "score": 1.1, "source_id": "s11"},
83
+ {"text": "kittens are baby cats", "label": 1, "key": "val12", "score": 1.2, "source_id": "s12"},
84
+ {"text": "soup can be served cold too", "label": 0, "key": "val13", "score": 1.3, "source_id": "s13"},
85
+ {"text": "cats have nine lives", "label": 1, "key": "val14", "score": 1.4, "source_id": "s14"},
86
+ {"text": "tomato soup with grilled cheese", "label": 0, "key": "val15", "score": 1.5, "source_id": "s15"},
87
+ {"text": "cats are independent animals", "label": 1, "key": "val16", "score": 1.6, "source_id": "s16"},
78
88
  ]
79
89
 
80
90
 
@@ -113,4 +123,6 @@ def memoryset(datasource) -> LabeledMemoryset:
113
123
 
114
124
  @pytest.fixture(scope="session")
115
125
  def model(memoryset) -> ClassificationModel:
116
- return ClassificationModel.create("test_model", memoryset, num_classes=2, memory_lookup_count=3)
126
+ return ClassificationModel.create(
127
+ "test_model", memoryset, num_classes=2, memory_lookup_count=3, description="test_description"
128
+ )
@@ -100,6 +100,8 @@ class PretrainedEmbeddingModel(_EmbeddingModel, metaclass=_PretrainedEmbeddingMo
100
100
  - **`CDE_SMALL`**: Context-aware CDE small model from Hugging Face ([jxm/cde-small-v1](https://huggingface.co/jxm/cde-small-v1))
101
101
  - **`CLIP_BASE`**: Multi-modal CLIP model from Hugging Face ([sentence-transformers/clip-ViT-L-14](https://huggingface.co/sentence-transformers/clip-ViT-L-14))
102
102
  - **`GTE_BASE`**: Alibaba's GTE model from Hugging Face ([Alibaba-NLP/gte-base-en-v1.5](https://huggingface.co/Alibaba-NLP/gte-base-en-v1.5))
103
+ - **`DISTILBERT`**: DistilBERT embedding model from Hugging Face ([distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased))
104
+ - **`GTE_SMALL`**: GTE-Small embedding model from Hugging Face ([Supabase/gte-small](https://huggingface.co/Supabase/gte-small))
103
105
 
104
106
  Examples:
105
107
  >>> PretrainedEmbeddingModel.CDE_SMALL