orca-sdk 0.1.10__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
orca_sdk/__init__.py CHANGED
@@ -2,7 +2,7 @@
  OrcaSDK is a Python library for building and using retrieval augmented models in the OrcaCloud.
  """

- from ._utils.common import UNSET, CreateMode, DropMode
+ from ._utils.common import UNSET, CreateMode, DropMode, logger
  from .classification_model import ClassificationMetrics, ClassificationModel
  from .client import OrcaClient
  from .credentials import OrcaCredentials
@@ -23,8 +23,8 @@ from .memoryset import (
      ScoredMemoryLookup,
      ScoredMemoryset,
  )
- from .regression_model import RegressionModel
+ from .regression_model import RegressionMetrics, RegressionModel
  from .telemetry import ClassificationPrediction, FeedbackCategory, RegressionPrediction

  # only specify things that should show up on the root page of the reference docs because they are in private modules
- __all__ = ["UNSET", "CreateMode", "DropMode"]
+ __all__ = ["UNSET", "CreateMode", "DropMode", "logger"]
@@ -5,7 +5,10 @@ import re
  from pathlib import Path
  from typing import TypedDict, cast

- import gradio as gr
+ try:
+     import gradio as gr  # type: ignore
+ except ImportError as e:
+     raise ImportError("gradio is required for UI features. Install it with: pip install orca_sdk[ui]") from e

  from ..memoryset import LabeledMemory, LabeledMemoryset

orca_sdk/_utils/auth.py CHANGED
@@ -1,13 +1,12 @@
  """This module contains internal utils for managing api keys in tests"""

- import logging
  import os
  from typing import List, Literal

  from dotenv import load_dotenv

  from ..client import ApiKeyMetadata, OrcaClient
- from .common import DropMode
+ from .common import DropMode, logger

  load_dotenv()  # this needs to be here to ensure env is populated before accessing it

@@ -59,7 +58,7 @@ def _authenticate_local_api(org_id: str = _DEFAULT_ORG_ID, api_key_name: str = "
      client = OrcaClient._resolve_client()
      client.base_url = "http://localhost:1584"
      client.headers.update({"Api-Key": _create_api_key(org_id, api_key_name)})
-     logging.info(f"Authenticated against local API at 'http://localhost:1584' with '{api_key_name}' API key")
+     logger.info(f"Authenticated against local API at 'http://localhost:1584' with '{api_key_name}' API key")


  __all__ = ["_create_api_key", "_delete_api_key", "_delete_org", "_list_api_keys", "_authenticate_local_api"]
orca_sdk/_utils/common.py CHANGED
@@ -1,4 +1,21 @@
- from typing import Any, Literal
+ import logging
+ from typing import Any, Iterable, Iterator, Literal, TypeVar
+
+ try:
+     from itertools import batched
+ except ImportError:
+     # Polyfill for Python <3.12
+
+     from itertools import islice
+
+     _BatchT = TypeVar("_BatchT")
+
+     def batched(iterable: Iterable[_BatchT], n: int) -> Iterator[tuple[_BatchT, ...]]:
+         """Batch an iterable into chunks of size n (backfill for Python <3.12)."""
+         it = iter(iterable)
+         while batch := tuple(islice(it, n)):
+             yield batch
+

  CreateMode = Literal["error", "open"]
  """
@@ -35,3 +52,9 @@ UNSET: Any = _UnsetSentinel()
  """
  Default value to indicate that no update should be applied to a field and it should not be set to None
  """
+
+ logger = logging.getLogger("orca_sdk")
+ """
+ Logger for the Orca SDK.
+ """
+ logger.addHandler(logging.NullHandler())
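
The new module-level logger ships with only a logging.NullHandler(), so the SDK stays silent unless the host application configures logging itself. A minimal sketch of how a consumer might surface SDK messages such as the _authenticate_local_api info line above (the handler configuration shown here is illustrative, not part of the package):

    import logging

    from orca_sdk import logger  # re-exported from orca_sdk._utils.common in this release

    # attach a root handler; records from the "orca_sdk" logger propagate to it
    logging.basicConfig(format="%(name)s %(levelname)s: %(message)s")
    # allow INFO-level records from the SDK through
    logger.setLevel(logging.INFO)
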
@@ -5,7 +5,10 @@ import re
  from pathlib import Path
  from typing import TYPE_CHECKING

- import gradio as gr
+ try:
+     import gradio as gr  # type: ignore
+ except ImportError as e:
+     raise ImportError("gradio is required for UI features. Install it with: pip install orca_sdk[ui]") from e

  from ..memoryset import LabeledMemoryLookup, LabeledMemoryset, ScoredMemoryLookup

@@ -0,0 +1,77 @@
+ from __future__ import annotations
+
+ from dataclasses import asdict, is_dataclass
+ from typing import TYPE_CHECKING, Any
+
+ if TYPE_CHECKING:
+     # peer dependencies that are used for types only
+     from torch.utils.data import DataLoader as TorchDataLoader  # type: ignore
+     from torch.utils.data import Dataset as TorchDataset  # type: ignore
+
+
+ def parse_dict_like(item: Any, column_names: list[str] | None = None) -> dict:
+     if isinstance(item, dict):
+         return item
+
+     if isinstance(item, tuple):
+         if column_names is not None:
+             if len(item) != len(column_names):
+                 raise ValueError(
+                     f"Tuple length ({len(item)}) does not match number of column names ({len(column_names)})"
+                 )
+             return {column_names[i]: item[i] for i in range(len(item))}
+         elif hasattr(item, "_fields") and all(isinstance(field, str) for field in item._fields):  # type: ignore
+             return {field: getattr(item, field) for field in item._fields}  # type: ignore
+         else:
+             raise ValueError("For datasets that return unnamed tuples, please provide column_names argument")
+
+     if is_dataclass(item) and not isinstance(item, type):
+         return asdict(item)
+
+     raise ValueError(f"Cannot parse {type(item)}")
+
+
+ def parse_batch(batch: Any, column_names: list[str] | None = None) -> list[dict]:
+     if isinstance(batch, list):
+         return [parse_dict_like(item, column_names) for item in batch]
+
+     batch = parse_dict_like(batch, column_names)
+     keys = list(batch.keys())
+     batch_size = len(batch[keys[0]])
+     for key in keys:
+         if not len(batch[key]) == batch_size:
+             raise ValueError(f"Batch must consist of values of the same length, but {key} has length {len(batch[key])}")
+     return [{key: batch[key][idx] for key in keys} for idx in range(batch_size)]
+
+
+ def list_from_torch(
+     torch_data: TorchDataLoader | TorchDataset,
+     column_names: list[str] | None = None,
+ ) -> list[dict]:
+     """
+     Convert a PyTorch DataLoader or Dataset to a list of dictionaries.
+
+     Params:
+         torch_data: A PyTorch DataLoader or Dataset object to convert.
+         column_names: Optional list of column names to use for the data. If not provided,
+             the column names will be inferred from the data.
+     Returns:
+         A list of dictionaries containing the data from the PyTorch DataLoader or Dataset.
+     """
+     # peer dependency that is guaranteed to exist if the user provided a torch dataset
+     from torch.utils.data import DataLoader as TorchDataLoader  # type: ignore
+
+     if isinstance(torch_data, TorchDataLoader):
+         dataloader = torch_data
+     else:
+         dataloader = TorchDataLoader(torch_data, batch_size=1, collate_fn=lambda x: x)
+
+     # Collect data from the dataloader into a list
+     data_list = []
+     try:
+         for batch in dataloader:
+             data_list.extend(parse_batch(batch, column_names=column_names))
+     except ValueError as e:
+         raise ValueError(str(e)) from e
+
+     return data_list
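
The new helper accepts either a torch Dataset or DataLoader and normalizes every item into a plain dict, requiring column_names only when items are unnamed tuples. A minimal usage sketch (the import path is an assumption inferred from the relative import in the test file below; the diff does not show where the module lives in the package):

    from torch.utils.data import Dataset

    from orca_sdk._utils.torch_parsing import list_from_torch  # assumed path, not confirmed by this diff

    class PairDataset(Dataset):
        def __init__(self):
            self.items = [("hello", 0), ("world", 1)]

        def __getitem__(self, i):
            return self.items[i]

        def __len__(self):
            return len(self.items)

    # unnamed tuples require explicit column names
    rows = list_from_torch(PairDataset(), column_names=["value", "label"])
    # rows == [{"value": "hello", "label": 0}, {"value": "world", "label": 1}]
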
@@ -0,0 +1,142 @@
+ from collections import namedtuple
+ from dataclasses import dataclass
+
+ import pytest
+
+ from .torch_parsing import list_from_torch
+
+ pytest.importorskip("torch")
+
+ from torch.utils.data import DataLoader as TorchDataLoader  # noqa: E402
+ from torch.utils.data import Dataset as TorchDataset  # noqa: E402
+
+
+ def test_list_from_torch_dict_dataset(data: list[dict]):
+     class PytorchDictDataset(TorchDataset):
+         def __init__(self):
+             self.data = data
+
+         def __getitem__(self, i):
+             return self.data[i]
+
+         def __len__(self):
+             return len(self.data)
+
+     dataset = PytorchDictDataset()
+     data_list = list_from_torch(dataset)
+
+     assert isinstance(data_list, list)
+     assert len(data_list) == len(dataset)
+     assert set(list(data_list[0].keys())) == {"value", "label", "key", "score", "source_id", "partition_id"}
+
+
+ def test_list_from_torch_dataloader(data: list[dict]):
+     class PytorchDictDataset(TorchDataset):
+         def __init__(self):
+             self.data = data
+
+         def __getitem__(self, i):
+             return self.data[i]
+
+         def __len__(self):
+             return len(self.data)
+
+     dataset = PytorchDictDataset()
+
+     def collate_fn(x: list[dict]):
+         return {"value": [item["value"] for item in x], "label": [item["label"] for item in x]}
+
+     dataloader = TorchDataLoader(dataset, batch_size=3, collate_fn=collate_fn)
+     data_list = list_from_torch(dataloader)
+
+     assert isinstance(data_list, list)
+     assert len(data_list) == len(dataset)
+     assert list(data_list[0].keys()) == ["value", "label"]
+
+
+ def test_list_from_torch_tuple_dataset(data: list[dict]):
+     class PytorchTupleDataset(TorchDataset):
+         def __init__(self):
+             self.data = data
+
+         def __getitem__(self, i):
+             return self.data[i]["value"], self.data[i]["label"]
+
+         def __len__(self):
+             return len(self.data)
+
+     dataset = PytorchTupleDataset()
+
+     # raises error if no column names are passed in
+     with pytest.raises(ValueError):
+         list_from_torch(dataset)
+
+     # raises error if not enough column names are passed in
+     with pytest.raises(ValueError):
+         list_from_torch(dataset, column_names=["value"])
+
+     # creates list if correct number of column names are passed in
+     data_list = list_from_torch(dataset, column_names=["value", "label"])
+     assert isinstance(data_list, list)
+     assert len(data_list) == len(dataset)
+     assert list(data_list[0].keys()) == ["value", "label"]
+
+
+ def test_list_from_torch_named_tuple_dataset(data: list[dict]):
+     # Given a Pytorch dataset that returns a namedtuple for each item
+     DatasetTuple = namedtuple("DatasetTuple", ["value", "label"])
+
+     class PytorchNamedTupleDataset(TorchDataset):
+         def __init__(self):
+             self.data = data
+
+         def __getitem__(self, i):
+             return DatasetTuple(self.data[i]["value"], self.data[i]["label"])
+
+         def __len__(self):
+             return len(self.data)
+
+     dataset = PytorchNamedTupleDataset()
+     data_list = list_from_torch(dataset)
+     assert isinstance(data_list, list)
+     assert len(data_list) == len(dataset)
+     assert list(data_list[0].keys()) == ["value", "label"]
+
+
+ def test_list_from_torch_dataclass_dataset(data: list[dict]):
+     @dataclass
+     class DatasetItem:
+         text: str
+         label: int
+
+     class PytorchDataclassDataset(TorchDataset):
+         def __init__(self):
+             self.data = data
+
+         def __getitem__(self, i):
+             return DatasetItem(text=self.data[i]["value"], label=self.data[i]["label"])
+
+         def __len__(self):
+             return len(self.data)
+
+     dataset = PytorchDataclassDataset()
+     data_list = list_from_torch(dataset)
+     assert isinstance(data_list, list)
+     assert len(data_list) == len(dataset)
+     assert list(data_list[0].keys()) == ["text", "label"]
+
+
+ def test_list_from_torch_invalid_dataset(data: list[dict]):
+     class PytorchInvalidDataset(TorchDataset):
+         def __init__(self):
+             self.data = data
+
+         def __getitem__(self, i):
+             return [self.data[i]["value"], self.data[i]["label"]]
+
+         def __len__(self):
+             return len(self.data)
+
+     dataset = PytorchInvalidDataset()
+     with pytest.raises(ValueError):
+         list_from_torch(dataset)
@@ -1,27 +1,43 @@
+ from __future__ import annotations
+
  import base64
  import io
- from typing import cast
+ from typing import TYPE_CHECKING, Any

- import numpy as np
- from numpy.typing import NDArray
- from PIL import Image as pil
+ if TYPE_CHECKING:
+     # peer dependencies that are used for types only
+     import numpy as np  # type: ignore
+     from numpy.typing import NDArray  # type: ignore
+     from PIL import Image as pil  # type: ignore

- ValueType = str | pil.Image | NDArray[np.float32]
- """
- The type of a value in a memoryset
+     ValueType = str | pil.Image | NDArray[np.float32]
+     """
+     The type of a value in a memoryset

- - `str`: string
- - `pil.Image`: image
- - `NDArray[np.float32]`: univariate or multivariate timeseries
- """
+     - `str`: string
+     - `pil.Image`: image
+     - `NDArray[np.float32]`: univariate or multivariate timeseries
+     """
+ else:
+     ValueType = Any


  def decode_value(value: str) -> ValueType:
      if value.startswith("data:image"):
+         try:
+             from PIL import Image as pil  # type: ignore
+         except ImportError as e:
+             raise ImportError("Install Pillow to use image values") from e
+
          header, data = value.split(",", 1)
          return pil.open(io.BytesIO(base64.b64decode(data)))

      if value.startswith("data:numpy"):
+         try:
+             import numpy as np  # type: ignore
+         except ImportError as e:
+             raise ImportError("Install numpy to use timeseries values") from e
+
          header, data = value.split(",", 1)
          return np.load(io.BytesIO(base64.b64decode(data)))

@@ -29,17 +45,28 @@ def decode_value(value: str) -> ValueType:


  def encode_value(value: ValueType) -> str:
-     if isinstance(value, pil.Image):
-         header = f"data:image/{value.format.lower()};base64," if value.format else "data:image;base64,"
+     try:
+         from PIL import Image as pil  # type: ignore
+     except ImportError:
+         pil = None  # type: ignore[assignment]
+
+     try:
+         import numpy as np  # type: ignore
+     except ImportError:
+         np = None  # type: ignore[assignment]
+
+     if pil is not None and isinstance(value, pil.Image):
+         header = f"data:image/{value.format.lower()};base64," if value.format else "data:image;base64,"  # type: ignore[union-attr]
          buffer = io.BytesIO()
-         value.save(buffer, format=value.format)
+         value.save(buffer, format=value.format)  # type: ignore[union-attr]
          bytes = buffer.getvalue()
          return header + base64.b64encode(bytes).decode("utf-8")

-     if isinstance(value, np.ndarray):
-         header = f"data:numpy/{value.dtype.name};base64,"
+     if np is not None and isinstance(value, np.ndarray):
+         header = f"data:numpy/{value.dtype.name};base64,"  # type: ignore[union-attr]
          buffer = io.BytesIO()
          np.save(buffer, value)
          return header + base64.b64encode(buffer.getvalue()).decode("utf-8")

-     return value
+     # Value is already a string, or an unhandled type (fall back to str conversion)
+     return value if isinstance(value, str) else str(value)
@@ -1,5 +1,4 @@
- import numpy as np
- from PIL import Image as pil
+ import pytest

  from .value_parser import decode_value, encode_value

@@ -13,6 +12,7 @@ def test_string_parsing():


  def test_image_parsing():
+     pil = pytest.importorskip("PIL.Image")
      img = pil.new("RGB", (10, 10), color="red")
      img.format = "PNG"

@@ -22,10 +22,11 @@ def test_image_parsing():

      decoded = decode_value(encoded)
      assert isinstance(decoded, pil.Image)
-     assert decoded.size == img.size
+     assert decoded.size == img.size  # type: ignore[union-attr]


  def test_timeseries_parsing():
+     np = pytest.importorskip("numpy")
      timeseries = np.random.rand(20, 3).astype(np.float32)

      encoded = encode_value(timeseries)
@@ -34,6 +35,6 @@ def test_timeseries_parsing():

      decoded = decode_value(encoded)
      assert isinstance(decoded, np.ndarray)
-     assert decoded.shape == timeseries.shape
-     assert decoded.dtype == timeseries.dtype
+     assert decoded.shape == timeseries.shape  # type: ignore[union-attr]
+     assert decoded.dtype == timeseries.dtype  # type: ignore[union-attr]
      assert np.allclose(decoded, timeseries)