orca-sdk 0.0.93__py3-none-any.whl → 0.0.94__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,86 @@
+ """
+ This file is generated by the openapi-python-client tool via the generate_api_client.py script
+
+ It is a customized template from the openapi-python-client tool's default template:
+ https://github.com/openapi-generators/openapi-python-client/blob/861ef5622f10fc96d240dc9becb0edf94e61446c/openapi_python_client/templates/model.py.jinja
+
+ The main change is:
+ - Fix typing issues
+ """
+
+ # flake8: noqa: C901
+
+ from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union
+
+ from attrs import define as _attrs_define
+ from attrs import field as _attrs_field
+
+ from ..types import UNSET, Unset
+
+ if TYPE_CHECKING:
+     from ..models.validation_error import ValidationError
+
+
+ T = TypeVar("T", bound="HTTPValidationError")
+
+
+ @_attrs_define
+ class HTTPValidationError:
+     """
+     Attributes:
+         detail (Union[Unset, List['ValidationError']]):
+     """
+
+     detail: Union[Unset, List["ValidationError"]] = UNSET
+     additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
+
+     def to_dict(self) -> dict[str, Any]:
+         detail: Union[Unset, List[Dict[str, Any]]] = UNSET
+         if not isinstance(self.detail, Unset):
+             detail = []
+             for detail_item_data in self.detail:
+                 detail_item = detail_item_data.to_dict()
+                 detail.append(detail_item)
+
+         field_dict: dict[str, Any] = {}
+         field_dict.update(self.additional_properties)
+         field_dict.update({})
+         if detail is not UNSET:
+             field_dict["detail"] = detail
+
+         return field_dict
+
+     @classmethod
+     def from_dict(cls: Type[T], src_dict: dict[str, Any]) -> T:
+         from ..models.validation_error import ValidationError
+
+         d = src_dict.copy()
+         detail = []
+         _detail = d.pop("detail", UNSET)
+         for detail_item_data in _detail or []:
+             detail_item = ValidationError.from_dict(detail_item_data)
+
+             detail.append(detail_item)
+
+         http_validation_error = cls(
+             detail=detail,
+         )
+
+         http_validation_error.additional_properties = d
+         return http_validation_error
+
+     @property
+     def additional_keys(self) -> list[str]:
+         return list(self.additional_properties.keys())
+
+     def __getitem__(self, key: str) -> Any:
+         return self.additional_properties[key]
+
+     def __setitem__(self, key: str, value: Any) -> None:
+         self.additional_properties[key] = value
+
+     def __delitem__(self, key: str) -> None:
+         del self.additional_properties[key]
+
+     def __contains__(self, key: str) -> bool:
+         return key in self.additional_properties
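For reference, the generated model keeps unrecognized keys across a from_dict/to_dict round trip via additional_properties. A minimal sketch of that behavior, using a hypothetical 422-style payload (the item shape follows the ValidationError model added later in this diff):

payload = {
    "detail": [{"loc": ["body", "name"], "msg": "field required", "type": "value_error.missing"}],
    "request_id": "abc123",  # unknown key: lands in additional_properties
}
err = HTTPValidationError.from_dict(payload)
assert err.detail[0].msg == "field required"
assert err.additional_properties == {"request_id": "abc123"}
assert err.to_dict()["request_id"] == "abc123"  # unknown keys round-trip back out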
@@ -10,11 +10,13 @@ The main change is:
  
  # flake8: noqa: C901
  
+ import datetime
  from enum import Enum
  from typing import Any, List, Type, TypeVar, Union, cast
  
  from attrs import define as _attrs_define
  from attrs import field as _attrs_field
+ from dateutil.parser import isoparse
  
  from ..models.prediction_sort_item_item_type_0 import PredictionSortItemItemType0
  from ..models.prediction_sort_item_item_type_1 import PredictionSortItemItemType1
@@ -30,6 +32,8 @@ class ListPredictionsRequest:
          model_id (Union[None, Unset, str]):
          tag (Union[None, Unset, str]):
          prediction_ids (Union[List[str], None, Unset]):
+         start_timestamp (Union[None, Unset, datetime.datetime]):
+         end_timestamp (Union[None, Unset, datetime.datetime]):
          limit (Union[None, Unset, int]):
          offset (Union[None, Unset, int]): Default: 0.
          sort (Union[Unset, List[List[Union[PredictionSortItemItemType0, PredictionSortItemItemType1]]]]):
@@ -39,6 +43,8 @@ class ListPredictionsRequest:
      model_id: Union[None, Unset, str] = UNSET
      tag: Union[None, Unset, str] = UNSET
      prediction_ids: Union[List[str], None, Unset] = UNSET
+     start_timestamp: Union[None, Unset, datetime.datetime] = UNSET
+     end_timestamp: Union[None, Unset, datetime.datetime] = UNSET
      limit: Union[None, Unset, int] = UNSET
      offset: Union[None, Unset, int] = 0
      sort: Union[Unset, List[List[Union[PredictionSortItemItemType0, PredictionSortItemItemType1]]]] = UNSET
@@ -67,6 +73,22 @@ class ListPredictionsRequest:
          else:
              prediction_ids = self.prediction_ids
  
+         start_timestamp: Union[None, Unset, str]
+         if isinstance(self.start_timestamp, Unset):
+             start_timestamp = UNSET
+         elif isinstance(self.start_timestamp, datetime.datetime):
+             start_timestamp = self.start_timestamp.isoformat()
+         else:
+             start_timestamp = self.start_timestamp
+
+         end_timestamp: Union[None, Unset, str]
+         if isinstance(self.end_timestamp, Unset):
+             end_timestamp = UNSET
+         elif isinstance(self.end_timestamp, datetime.datetime):
+             end_timestamp = self.end_timestamp.isoformat()
+         else:
+             end_timestamp = self.end_timestamp
+
          limit: Union[None, Unset, int]
          if isinstance(self.limit, Unset):
              limit = UNSET
@@ -118,6 +140,10 @@ class ListPredictionsRequest:
              field_dict["tag"] = tag
          if prediction_ids is not UNSET:
              field_dict["prediction_ids"] = prediction_ids
+         if start_timestamp is not UNSET:
+             field_dict["start_timestamp"] = start_timestamp
+         if end_timestamp is not UNSET:
+             field_dict["end_timestamp"] = end_timestamp
          if limit is not UNSET:
              field_dict["limit"] = limit
          if offset is not UNSET:
@@ -168,6 +194,40 @@ class ListPredictionsRequest:
  
          prediction_ids = _parse_prediction_ids(d.pop("prediction_ids", UNSET))
  
+         def _parse_start_timestamp(data: object) -> Union[None, Unset, datetime.datetime]:
+             if data is None:
+                 return data
+             if isinstance(data, Unset):
+                 return data
+             try:
+                 if not isinstance(data, str):
+                     raise TypeError()
+                 start_timestamp_type_0 = isoparse(data)
+
+                 return start_timestamp_type_0
+             except:  # noqa: E722
+                 pass
+             return cast(Union[None, Unset, datetime.datetime], data)
+
+         start_timestamp = _parse_start_timestamp(d.pop("start_timestamp", UNSET))
+
+         def _parse_end_timestamp(data: object) -> Union[None, Unset, datetime.datetime]:
+             if data is None:
+                 return data
+             if isinstance(data, Unset):
+                 return data
+             try:
+                 if not isinstance(data, str):
+                     raise TypeError()
+                 end_timestamp_type_0 = isoparse(data)
+
+                 return end_timestamp_type_0
+             except:  # noqa: E722
+                 pass
+             return cast(Union[None, Unset, datetime.datetime], data)
+
+         end_timestamp = _parse_end_timestamp(d.pop("end_timestamp", UNSET))
+
          def _parse_limit(data: object) -> Union[None, Unset, int]:
              if data is None:
                  return data
@@ -231,6 +291,8 @@ class ListPredictionsRequest:
              model_id=model_id,
              tag=tag,
              prediction_ids=prediction_ids,
+             start_timestamp=start_timestamp,
+             end_timestamp=end_timestamp,
              limit=limit,
              offset=offset,
              sort=sort,
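The new timestamp fields serialize to ISO 8601 strings in to_dict and are parsed back into datetimes in from_dict via dateutil's isoparse, so request bodies round-trip cleanly. A small sketch under those assumptions (the values are made up):

import datetime

req = ListPredictionsRequest(
    start_timestamp=datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc),
    end_timestamp=None,  # None passes through unchanged
)
body = req.to_dict()
assert body["start_timestamp"] == "2024-01-01T00:00:00+00:00"
assert body["end_timestamp"] is None

parsed = ListPredictionsRequest.from_dict(body)
assert parsed.start_timestamp == req.start_timestamp  # isoparse restores the datetime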
@@ -13,7 +13,6 @@ The main change is:
  from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar, Union, cast
  
  from attrs import define as _attrs_define
- from attrs import field as _attrs_field
  
  from ..types import UNSET, Unset
  
@@ -44,7 +43,6 @@ class MemorysetAnalysisConfigs:
      duplicate: Union["MemorysetDuplicateAnalysisConfig", None, Unset] = UNSET
      projection: Union["MemorysetProjectionAnalysisConfig", None, Unset] = UNSET
      cluster: Union["MemorysetClusterAnalysisConfig", None, Unset] = UNSET
-     additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
  
      def to_dict(self) -> dict[str, Any]:
          from ..models.memoryset_cluster_analysis_config import MemorysetClusterAnalysisConfig
@@ -94,7 +92,6 @@ class MemorysetAnalysisConfigs:
          cluster = self.cluster
  
          field_dict: dict[str, Any] = {}
-         field_dict.update(self.additional_properties)
          field_dict.update({})
          if neighbor is not UNSET:
              field_dict["neighbor"] = neighbor
@@ -212,21 +209,4 @@ class MemorysetAnalysisConfigs:
              cluster=cluster,
          )
  
-         memoryset_analysis_configs.additional_properties = d
          return memoryset_analysis_configs
-
-     @property
-     def additional_keys(self) -> list[str]:
-         return list(self.additional_properties.keys())
-
-     def __getitem__(self, key: str) -> Any:
-         return self.additional_properties[key]
-
-     def __setitem__(self, key: str, value: Any) -> None:
-         self.additional_properties[key] = value
-
-     def __delitem__(self, key: str) -> None:
-         del self.additional_properties[key]
-
-     def __contains__(self, key: str) -> bool:
-         return key in self.additional_properties
@@ -2,11 +2,16 @@ from enum import Enum
  
  
  class PretrainedEmbeddingModelName(str, Enum):
+     BGE_BASE = "BGE_BASE"
      CDE_SMALL = "CDE_SMALL"
      CLIP_BASE = "CLIP_BASE"
      DISTILBERT = "DISTILBERT"
+     E5_LARGE = "E5_LARGE"
+     GIST_LARGE = "GIST_LARGE"
      GTE_BASE = "GTE_BASE"
      GTE_SMALL = "GTE_SMALL"
+     MXBAI_LARGE = "MXBAI_LARGE"
+     QWEN2_1_5B = "QWEN2_1_5B"
  
      def __str__(self) -> str:
          return str(self.value)
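Since the enum subclasses str, the new members compare equal to their string values and serialize as plain strings. For instance:

name = PretrainedEmbeddingModelName.MXBAI_LARGE
assert name == "MXBAI_LARGE"       # str subclass: equal to its raw value
assert str(name) == "MXBAI_LARGE"  # __str__ returns the value for serialization
assert PretrainedEmbeddingModelName("E5_LARGE") is PretrainedEmbeddingModelName.E5_LARGE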
@@ -0,0 +1,99 @@
+ """
+ This file is generated by the openapi-python-client tool via the generate_api_client.py script
+
+ It is a customized template from the openapi-python-client tool's default template:
+ https://github.com/openapi-generators/openapi-python-client/blob/861ef5622f10fc96d240dc9becb0edf94e61446c/openapi_python_client/templates/model.py.jinja
+
+ The main change is:
+ - Fix typing issues
+ """
+
+ # flake8: noqa: C901
+
+ from typing import Any, List, Type, TypeVar, Union, cast
+
+ from attrs import define as _attrs_define
+ from attrs import field as _attrs_field
+
+ T = TypeVar("T", bound="ValidationError")
+
+
+ @_attrs_define
+ class ValidationError:
+     """
+     Attributes:
+         loc (List[Union[int, str]]):
+         msg (str):
+         type (str):
+     """
+
+     loc: List[Union[int, str]]
+     msg: str
+     type: str
+     additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
+
+     def to_dict(self) -> dict[str, Any]:
+         loc = []
+         for loc_item_data in self.loc:
+             loc_item: Union[int, str]
+             loc_item = loc_item_data
+             loc.append(loc_item)
+
+         msg = self.msg
+
+         type = self.type
+
+         field_dict: dict[str, Any] = {}
+         field_dict.update(self.additional_properties)
+         field_dict.update(
+             {
+                 "loc": loc,
+                 "msg": msg,
+                 "type": type,
+             }
+         )
+
+         return field_dict
+
+     @classmethod
+     def from_dict(cls: Type[T], src_dict: dict[str, Any]) -> T:
+         d = src_dict.copy()
+         loc = []
+         _loc = d.pop("loc")
+         for loc_item_data in _loc:
+
+             def _parse_loc_item(data: object) -> Union[int, str]:
+                 return cast(Union[int, str], data)
+
+             loc_item = _parse_loc_item(loc_item_data)
+
+             loc.append(loc_item)
+
+         msg = d.pop("msg")
+
+         type = d.pop("type")
+
+         validation_error = cls(
+             loc=loc,
+             msg=msg,
+             type=type,
+         )
+
+         validation_error.additional_properties = d
+         return validation_error
+
+     @property
+     def additional_keys(self) -> list[str]:
+         return list(self.additional_properties.keys())
+
+     def __getitem__(self, key: str) -> Any:
+         return self.additional_properties[key]
+
+     def __setitem__(self, key: str, value: Any) -> None:
+         self.additional_properties[key] = value
+
+     def __delitem__(self, key: str) -> None:
+         del self.additional_properties[key]
+
+     def __contains__(self, key: str) -> bool:
+         return key in self.additional_properties
@@ -382,7 +382,9 @@ class ClassificationModel:
              expected_labels=(
                  expected_labels
                  if isinstance(expected_labels, list)
-                 else [expected_labels] if expected_labels is not None else None
+                 else [expected_labels]
+                 if expected_labels is not None
+                 else None
              ),
              tags=list(tags),
              save_telemetry=save_telemetry,
@@ -403,8 +405,9 @@
                  memoryset=self.memoryset,
                  model=self,
                  logits=prediction.logits,
+                 input_value=input_value,
              )
-             for prediction in response
+             for prediction, input_value in zip(response, value if isinstance(value, list) else [value])
          ]
          self._last_prediction_was_batch = isinstance(value, list)
          self._last_prediction = predictions[-1]
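The rewritten comprehension pairs each prediction with the input that produced it; wrapping a single (non-list) value in a one-element list makes zip line up in both the single and batch cases. A reduced sketch of just that pairing logic, with names simplified from the diff:

def pair_inputs(response: list, value):
    # A lone input is wrapped so it zips 1:1 with the lone response item.
    inputs = value if isinstance(value, list) else [value]
    return list(zip(response, inputs))

assert pair_inputs(["p1"], "Do you love soup?") == [("p1", "Do you love soup?")]
assert pair_inputs(["p1", "p2"], ["a", "b"]) == [("p1", "a"), ("p2", "b")]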
@@ -480,7 +483,6 @@ class ClassificationModel:
          predictions: list[LabelPrediction],
          expected_labels: list[int],
      ) -> ClassificationEvaluationResult:
-
          targets_array = np.array(expected_labels)
          predictions_array = np.array([p.label for p in predictions])
  
@@ -1,3 +1,5 @@
+ import logging
+ import os
  from uuid import uuid4
  
  import numpy as np
@@ -9,6 +11,11 @@ from .datasource import Datasource
  from .embedding_model import PretrainedEmbeddingModel
  from .memoryset import LabeledMemoryset
  
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+
+
+ SKIP_IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
+
  
  def test_create_model(model: ClassificationModel, readonly_memoryset: LabeledMemoryset):
      assert model is not None
@@ -338,3 +345,42 @@ def test_last_prediction_with_single(model: ClassificationModel):
      assert model.last_prediction.prediction_id == prediction.prediction_id
      assert model.last_prediction.input_value == "Do you love soup?"
      assert model._last_prediction_was_batch is False
+
+
+ @pytest.mark.skipif(
+     SKIP_IN_GITHUB_ACTIONS, reason="Skipping explanation test because in CI we don't have Anthropic API key"
+ )
+ def test_explain(writable_memoryset: LabeledMemoryset):
+
+     writable_memoryset.analyze(
+         {"name": "neighbor", "neighbor_counts": [1, 3]},
+         lookup_count=3,
+     )
+
+     model = ClassificationModel.create(
+         "test_model_for_explain",
+         writable_memoryset,
+         num_classes=2,
+         memory_lookup_count=3,
+         description="This is a test model for explain",
+     )
+
+     predictions = model.predict(["Do you love soup?", "Are cats cute?"])
+     assert len(predictions) == 2
+
+     try:
+         explanation = predictions[0].explanation
+         print(explanation)
+         assert explanation is not None
+         assert len(explanation) > 10
+         assert "soup" in explanation.lower()
+     except Exception as e:
+         if "ANTHROPIC_API_KEY" in str(e):
+             logging.info("Skipping explanation test because ANTHROPIC_API_KEY is not set on server")
+         else:
+             raise e
+     finally:
+         try:
+             ClassificationModel.drop("test_model_for_explain")
+         except Exception as e:
+             logging.info(f"Failed to drop test model for explain: {e}")
orca_sdk/conftest.py CHANGED
@@ -176,6 +176,7 @@ def writable_memoryset(datasource: Datasource, api_key: str) -> Generator[Labele
  
      if memory_ids:
          memoryset.delete(memory_ids)
+     memoryset.refresh()
      assert len(memoryset) == 0
      memoryset.insert(SAMPLE_DATA)
      # If the test dropped the memoryset, do nothing — it will be recreated on the next use.
orca_sdk/datasource.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
  
  import logging
  import tempfile
+ import zipfile
  from datetime import datetime
  from os import PathLike
  from pathlib import Path
@@ -84,6 +85,39 @@ class Datasource:
              + "})"
          )
  
+     def download(self, output_path: str | PathLike) -> None:
+         """
+         Download the datasource as a ZIP archive and extract its contents to the specified path.
+
+         Params:
+             output_path: The local file path or directory where the downloaded files will be saved.
+
+         Returns:
+             None
+
+         Raises:
+             RuntimeError: If the download fails.
+         """
+
+         output_path = Path(output_path)
+         client = get_client().get_httpx_client()
+         url = f"/datasource/{self.id}/download"
+         response = client.get(url)
+         if response.status_code == 404:
+             raise LookupError(f"Datasource {self.id} not found")
+         if response.status_code != 200:
+             raise RuntimeError(f"Failed to download datasource: {response.status_code} {response.text}")
+
+         with tempfile.NamedTemporaryFile(suffix=".zip") as tmp_zip:
+             tmp_zip.write(response.content)
+             tmp_zip.flush()
+             with zipfile.ZipFile(tmp_zip.name, "r") as zf:
+                 output_path.mkdir(parents=True, exist_ok=True)
+                 for file in zf.namelist():
+                     out_file = output_path / Path(file).name
+                     with zf.open(file) as af:
+                         out_file.write_bytes(af.read())
+
      @classmethod
      def from_hf_dataset(
          cls, name: str, dataset: Dataset, if_exists: CreateMode = "error", description: str | None = None
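A quick usage sketch for the new download method, assuming an existing Datasource instance named datasource (how it is obtained is not shown here). Note that the method extracts the archive rather than saving the ZIP itself, so output_path is treated as a directory and created if missing:

import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as temp_dir:
    datasource.download(temp_dir)  # fetches the ZIP and writes the extracted files into temp_dir
    print(sorted(p.name for p in Path(temp_dir).iterdir()))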
@@ -1,3 +1,5 @@
+ import os
+ import tempfile
  from uuid import uuid4
  
  import pytest
@@ -94,3 +96,10 @@ def test_drop_datasource_unauthorized(datasource, unauthorized):
  def test_drop_datasource_invalid_input():
      with pytest.raises(ValueError, match=r"Invalid input:.*"):
          Datasource.drop("not valid id")
+
+
+ def test_download_datasource(datasource):
+     with tempfile.TemporaryDirectory() as temp_dir:
+         output_path = os.path.join(temp_dir, "datasource.zip")
+         datasource.download(output_path)
+         assert os.path.exists(output_path)
@@ -281,8 +281,10 @@ def test_insert_memories(writable_memoryset: LabeledMemoryset):
              dict(value="cats are fun to play with", label=1),
          ]
      )
+     writable_memoryset.refresh()
      assert writable_memoryset.length == prev_length + 2
      writable_memoryset.insert(dict(value="tomato soup is my favorite", label=0, key="test", source_id="test"))
+     writable_memoryset.refresh()
      assert writable_memoryset.length == prev_length + 3
      last_memory = writable_memoryset[-1]
      assert last_memory.value == "tomato soup is my favorite"
orca_sdk/telemetry.py CHANGED
@@ -149,6 +149,7 @@ class LabelPrediction:
          model: ClassificationModel | str,
          telemetry: LabelPredictionWithMemoriesAndFeedback | None = None,
          logits: list[float] | None = None,
+         input_value: str | list[list[float]] | None = None,
      ):
          # for internal use only, do not document
          from .classification_model import ClassificationModel
@@ -162,15 +163,14 @@ class LabelPrediction:
          self.model = ClassificationModel.open(model) if isinstance(model, str) else model
          self.__telemetry = telemetry if telemetry else None
          self.logits = logits
+         self._input_value = input_value
  
      def __repr__(self):
          return (
              "LabelPrediction({"
              + f"label: <{self.label_name}: {self.label}>, "
              + f"confidence: {self.confidence:.2f}, "
-             + f"anomaly_score: {self.anomaly_score:.2f}, "
-             if self.anomaly_score is not None
-             else ""
+             + (f"anomaly_score: {self.anomaly_score:.2f}, " if self.anomaly_score is not None else "")
              + f"input_value: '{str(self.input_value)[:100] + '...' if len(str(self.input_value)) > 100 else self.input_value}'"
              + "})"
          )
@@ -188,6 +188,8 @@ class LabelPrediction:
  
      @property
      def input_value(self) -> str | list[list[float]] | None:
+         if self._input_value is not None:
+             return self._input_value
          return self._telemetry.input_value
  
      @property
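The net effect of the property change: a value captured locally at predict() time short-circuits the telemetry lookup, which otherwise requires fetching the prediction record from the server. A self-contained toy illustrating the same fallback pattern (not the SDK class itself):

class FallbackExample:
    def __init__(self, input_value=None):
        self._input_value = input_value
        self.telemetry_fetches = 0

    @property
    def _telemetry_value(self):
        self.telemetry_fetches += 1  # stands in for the lazy server fetch
        return "fetched-from-server"

    @property
    def input_value(self):
        if self._input_value is not None:
            return self._input_value  # local value wins; no fetch needed
        return self._telemetry_value

local = FallbackExample(input_value="Do you love soup?")
assert local.input_value == "Do you love soup?" and local.telemetry_fetches == 0
remote = FallbackExample()
assert remote.input_value == "fetched-from-server" and remote.telemetry_fetches == 1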
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: orca_sdk
- Version: 0.0.93
+ Version: 0.0.94
  Summary: SDK for interacting with Orca Services
  License: Apache-2.0
  Author: Orca DB Inc.