orca_sdk-0.1.9-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. orca_sdk/__init__.py +30 -0
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +634 -0
  4. orca_sdk/_shared/metrics_test.py +570 -0
  5. orca_sdk/_utils/__init__.py +0 -0
  6. orca_sdk/_utils/analysis_ui.py +196 -0
  7. orca_sdk/_utils/analysis_ui_style.css +51 -0
  8. orca_sdk/_utils/auth.py +65 -0
  9. orca_sdk/_utils/auth_test.py +31 -0
  10. orca_sdk/_utils/common.py +37 -0
  11. orca_sdk/_utils/data_parsing.py +129 -0
  12. orca_sdk/_utils/data_parsing_test.py +244 -0
  13. orca_sdk/_utils/pagination.py +126 -0
  14. orca_sdk/_utils/pagination_test.py +132 -0
  15. orca_sdk/_utils/prediction_result_ui.css +18 -0
  16. orca_sdk/_utils/prediction_result_ui.py +110 -0
  17. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  18. orca_sdk/_utils/value_parser.py +45 -0
  19. orca_sdk/_utils/value_parser_test.py +39 -0
  20. orca_sdk/async_client.py +4104 -0
  21. orca_sdk/classification_model.py +1165 -0
  22. orca_sdk/classification_model_test.py +887 -0
  23. orca_sdk/client.py +4096 -0
  24. orca_sdk/conftest.py +382 -0
  25. orca_sdk/credentials.py +217 -0
  26. orca_sdk/credentials_test.py +121 -0
  27. orca_sdk/datasource.py +576 -0
  28. orca_sdk/datasource_test.py +463 -0
  29. orca_sdk/embedding_model.py +712 -0
  30. orca_sdk/embedding_model_test.py +206 -0
  31. orca_sdk/job.py +343 -0
  32. orca_sdk/job_test.py +108 -0
  33. orca_sdk/memoryset.py +3811 -0
  34. orca_sdk/memoryset_test.py +1150 -0
  35. orca_sdk/regression_model.py +841 -0
  36. orca_sdk/regression_model_test.py +595 -0
  37. orca_sdk/telemetry.py +742 -0
  38. orca_sdk/telemetry_test.py +119 -0
  39. orca_sdk-0.1.9.dist-info/METADATA +98 -0
  40. orca_sdk-0.1.9.dist-info/RECORD +41 -0
  41. orca_sdk-0.1.9.dist-info/WHEEL +4 -0
orca_sdk/_utils/data_parsing_test.py
@@ -0,0 +1,244 @@
+ import json
+ import pickle
+ import tempfile
+ from collections import namedtuple
+ from dataclasses import dataclass
+
+ import pandas as pd
+ import pytest
+ from datasets import Dataset
+ from datasets.exceptions import DatasetGenerationError
+ from torch.utils.data import DataLoader as TorchDataLoader
+ from torch.utils.data import Dataset as TorchDataset
+
+ from ..conftest import SAMPLE_DATA
+ from .data_parsing import hf_dataset_from_disk, hf_dataset_from_torch
+
+
+ class PytorchDictDataset(TorchDataset):
+     def __init__(self):
+         self.data = SAMPLE_DATA
+
+     def __getitem__(self, i):
+         return self.data[i]
+
+     def __len__(self):
+         return len(self.data)
+
+
+ def test_hf_dataset_from_torch_dict():
+     # Given a Pytorch dataset that returns a dictionary for each item
+     dataset = PytorchDictDataset()
+     hf_dataset = hf_dataset_from_torch(dataset)
+     # Then the HF dataset should be created successfully
+     assert isinstance(hf_dataset, Dataset)
+     assert len(hf_dataset) == len(dataset)
+     assert set(hf_dataset.column_names) == {"value", "label", "key", "score", "source_id", "partition_id"}
+
+
+ class PytorchTupleDataset(TorchDataset):
+     def __init__(self):
+         self.data = SAMPLE_DATA
+
+     def __getitem__(self, i):
+         return self.data[i]["value"], self.data[i]["label"]
+
+     def __len__(self):
+         return len(self.data)
+
+
+ def test_hf_dataset_from_torch_tuple():
+     # Given a Pytorch dataset that returns a tuple for each item
+     dataset = PytorchTupleDataset()
+     # And the correct number of column names passed in
+     hf_dataset = hf_dataset_from_torch(dataset, column_names=["value", "label"])
+     # Then the HF dataset should be created successfully
+     assert isinstance(hf_dataset, Dataset)
+     assert len(hf_dataset) == len(dataset)
+     assert hf_dataset.column_names == ["value", "label"]
+
+
+ def test_hf_dataset_from_torch_tuple_error():
+     # Given a Pytorch dataset that returns a tuple for each item
+     dataset = PytorchTupleDataset()
+     # Then the HF dataset should raise an error if no column names are passed in
+     with pytest.raises(DatasetGenerationError):
+         hf_dataset_from_torch(dataset)
+
+
+ def test_hf_dataset_from_torch_tuple_error_not_enough_columns():
+     # Given a Pytorch dataset that returns a tuple for each item
+     dataset = PytorchTupleDataset()
+     # Then the HF dataset should raise an error if not enough column names are passed in
+     with pytest.raises(DatasetGenerationError):
+         hf_dataset_from_torch(dataset, column_names=["value"])
+
+
+ DatasetTuple = namedtuple("DatasetTuple", ["value", "label"])
+
+
+ class PytorchNamedTupleDataset(TorchDataset):
+     def __init__(self):
+         self.data = SAMPLE_DATA
+
+     def __getitem__(self, i):
+         return DatasetTuple(self.data[i]["value"], self.data[i]["label"])
+
+     def __len__(self):
+         return len(self.data)
+
+
+ def test_hf_dataset_from_torch_named_tuple():
+     # Given a Pytorch dataset that returns a namedtuple for each item
+     dataset = PytorchNamedTupleDataset()
+     # And no column names are passed in
+     hf_dataset = hf_dataset_from_torch(dataset)
+     # Then the HF dataset should be created successfully
+     assert isinstance(hf_dataset, Dataset)
+     assert len(hf_dataset) == len(dataset)
+     assert hf_dataset.column_names == ["value", "label"]
+
+
+ @dataclass
+ class DatasetItem:
+     text: str
+     label: int
+
+
+ class PytorchDataclassDataset(TorchDataset):
+     def __init__(self):
+         self.data = SAMPLE_DATA
+
+     def __getitem__(self, i):
+         return DatasetItem(text=self.data[i]["value"], label=self.data[i]["label"])
+
+     def __len__(self):
+         return len(self.data)
+
+
+ def test_hf_dataset_from_torch_dataclass():
+     # Given a Pytorch dataset that returns a dataclass for each item
+     dataset = PytorchDataclassDataset()
+     hf_dataset = hf_dataset_from_torch(dataset)
+     # Then the HF dataset should be created successfully
+     assert isinstance(hf_dataset, Dataset)
+     assert len(hf_dataset) == len(dataset)
+     assert hf_dataset.column_names == ["text", "label"]
+
+
+ class PytorchInvalidDataset(TorchDataset):
+     def __init__(self):
+         self.data = SAMPLE_DATA
+
+     def __getitem__(self, i):
+         return [self.data[i]["value"], self.data[i]["label"]]
+
+     def __len__(self):
+         return len(self.data)
+
+
+ def test_hf_dataset_from_torch_invalid_dataset():
+     # Given a Pytorch dataset that returns a list for each item
+     dataset = PytorchInvalidDataset()
+     # Then the HF dataset should raise an error
+     with pytest.raises(DatasetGenerationError):
+         hf_dataset_from_torch(dataset)
+
+
+ def test_hf_dataset_from_torchdataloader():
+     # Given a Pytorch dataloader that returns a column-oriented batch of items
+     dataset = PytorchDictDataset()
+
+     def collate_fn(x: list[dict]):
+         return {"value": [item["value"] for item in x], "label": [item["label"] for item in x]}
+
+     dataloader = TorchDataLoader(dataset, batch_size=3, collate_fn=collate_fn)
+     hf_dataset = hf_dataset_from_torch(dataloader)
+     # Then the HF dataset should be created successfully
+     assert isinstance(hf_dataset, Dataset)
+     assert len(hf_dataset) == len(dataset)
+     assert hf_dataset.column_names == ["value", "label"]
+
+
+ def test_hf_dataset_from_disk_pickle_list():
+     with tempfile.NamedTemporaryFile(suffix=".pkl") as temp_file:
+         # Given a pickle file with test data that is a list
+         test_data = [{"value": f"test_{i}", "label": i % 2} for i in range(30)]
+         with open(temp_file.name, "wb") as f:
+             pickle.dump(test_data, f)
+         dataset = hf_dataset_from_disk(temp_file.name)
+         # Then the HF dataset should be created successfully
+         assert isinstance(dataset, Dataset)
+         assert len(dataset) == 30
+         assert dataset.column_names == ["value", "label"]
+
+
+ def test_hf_dataset_from_disk_pickle_dict():
+     with tempfile.NamedTemporaryFile(suffix=".pkl") as temp_file:
+         # Given a pickle file with test data that is a dict
+         test_data = {"value": [f"test_{i}" for i in range(30)], "label": [i % 2 for i in range(30)]}
+         with open(temp_file.name, "wb") as f:
+             pickle.dump(test_data, f)
+         dataset = hf_dataset_from_disk(temp_file.name)
+         # Then the HF dataset should be created successfully
+         assert isinstance(dataset, Dataset)
+         assert len(dataset) == 30
+         assert dataset.column_names == ["value", "label"]
+
+
+ def test_hf_dataset_from_disk_json():
+     with tempfile.NamedTemporaryFile(suffix=".json") as temp_file:
+         # Given a JSON file with test data
+         test_data = [{"value": f"test_{i}", "label": i % 2} for i in range(30)]
+         with open(temp_file.name, "w") as f:
+             json.dump(test_data, f)
+         dataset = hf_dataset_from_disk(temp_file.name)
+         # Then the HF dataset should be created successfully
+         assert isinstance(dataset, Dataset)
+         assert len(dataset) == 30
+         assert dataset.column_names == ["value", "label"]
+
+
+ def test_hf_dataset_from_disk_jsonl():
+     with tempfile.NamedTemporaryFile(suffix=".jsonl") as temp_file:
+         # Given a JSONL file with test data
+         test_data = [{"value": f"test_{i}", "label": i % 2} for i in range(30)]
+         with open(temp_file.name, "w") as f:
+             for item in test_data:
+                 f.write(json.dumps(item) + "\n")
+         dataset = hf_dataset_from_disk(temp_file.name)
+         # Then the HF dataset should be created successfully
+         assert isinstance(dataset, Dataset)
+         assert len(dataset) == 30
+         assert dataset.column_names == ["value", "label"]
+
+
+ def test_hf_dataset_from_disk_csv():
+     with tempfile.NamedTemporaryFile(suffix=".csv") as temp_file:
+         # Given a CSV file with test data
+         test_data = [{"value": f"test_{i}", "label": i % 2} for i in range(30)]
+         with open(temp_file.name, "w") as f:
+             f.write("value,label\n")
+             for item in test_data:
+                 f.write(f"{item['value']},{item['label']}\n")
+         dataset = hf_dataset_from_disk(temp_file.name)
+         # Then the HF dataset should be created successfully
+         assert isinstance(dataset, Dataset)
+         assert len(dataset) == 30
+         assert dataset.column_names == ["value", "label"]
+
+
+ def test_hf_dataset_from_disk_parquet():
+     with tempfile.NamedTemporaryFile(suffix=".parquet") as temp_file:
+         # Given a Parquet file with test data
+         data = {
+             "value": [f"test_{i}" for i in range(30)],
+             "label": [i % 2 for i in range(30)],
+         }
+         df = pd.DataFrame(data)
+         df.to_parquet(temp_file.name)
+         dataset = hf_dataset_from_disk(temp_file.name)
+         # Then the HF dataset should be created successfully
+         assert isinstance(dataset, Dataset)
+         assert len(dataset) == 30
+         assert dataset.column_names == ["value", "label"]
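For reference, a minimal usage sketch of the two helpers these tests exercise, hf_dataset_from_torch and hf_dataset_from_disk from the private module orca_sdk/_utils/data_parsing.py. The dataset class, sample values, and file path below are illustrative, not part of the package.

# Illustrative sketch only; mirrors the behavior covered by the tests above.
from torch.utils.data import Dataset as TorchDataset

from orca_sdk._utils.data_parsing import hf_dataset_from_disk, hf_dataset_from_torch


class TupleDataset(TorchDataset):
    """Hypothetical dataset whose items are plain tuples, so column names must be supplied."""

    def __init__(self):
        self.samples = [(f"text_{i}", i % 2) for i in range(10)]

    def __getitem__(self, i):
        return self.samples[i]

    def __len__(self):
        return len(self.samples)


# tuple items require explicit column names; dict, namedtuple, and dataclass items infer them
hf_dataset = hf_dataset_from_torch(TupleDataset(), column_names=["value", "label"])
assert hf_dataset.column_names == ["value", "label"]

# files are loaded based on their extension: .pkl, .json, .jsonl, .csv, or .parquet
# hf_dataset = hf_dataset_from_disk("path/to/train.jsonl")  # hypothetical path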
orca_sdk/_utils/pagination.py
@@ -0,0 +1,126 @@
+ from __future__ import annotations
+
+ from typing import Callable, Generic, Iterator, TypedDict, TypeVar, cast, overload
+
+ T = TypeVar("T")
+ R = TypeVar("R")
+
+
+ class Page(TypedDict, Generic[T]):
+     items: list[T]
+     count: int
+
+
+ class _PagedIterable(Generic[T, R]):
+     def __init__(
+         self,
+         fetch: Callable[[int, int], Page[T]],
+         *,
+         transform: Callable[[T], R] | None = None,
+         page_size: int = 100,
+     ) -> None:
+         """
+         Iterate over a paginated endpoint.
+
+         Parameters:
+             fetch: function to fetch a page from the endpoint `(limit: int, offset: int) -> TypedDict[{items: list[T], count: int}]`
+             transform: optional function to transform item types `(item: T) -> R`, defaults to identity
+             page_size: maximum number of items to fetch per page
+         """
+         self.fetch = fetch
+         self.transform = transform or (lambda x: cast(R, x))
+         self.page_size = page_size
+         self.offset = 0  # tracks how much has been yielded, not fetched
+         self.page = fetch(self.page_size, self.offset)  # fetch first page to populate count
+         self.count = self.page["count"]
+
+     def __iter__(self) -> Iterator[R]:
+         if self.offset >= self.count:
+             self.offset = 0
+             if len(self.page["items"]) < self.count:
+                 # refetch first page unless we are still on the first page
+                 self.page = self.fetch(self.page_size, self.offset)
+
+         # yield prefetched first page
+         if self.offset == 0:
+             yield from map(self.transform, self.page["items"])
+             self.offset += len(self.page["items"])
+
+         # yield remaining pages one by one
+         while self.offset < self.count:
+             self.page = self.fetch(self.page_size, self.offset)
+             yield from map(self.transform, self.page["items"])
+             self.offset += len(self.page["items"])
+
+     @overload
+     def __getitem__(self, key: int) -> R:
+         pass
+
+     @overload
+     def __getitem__(self, key: slice) -> list[R]:
+         pass
+
+     def __getitem__(self, key: int | slice) -> R | list[R]:
+         if isinstance(key, int):
+             effective_key = key
+             if effective_key < 0:
+                 effective_key += self.count
+             if not 0 <= effective_key < self.count:
+                 raise IndexError(f"Index {key} out of range")
+             # if key is on current page, return item
+             if self.offset <= effective_key < self.offset + len(self.page["items"]):
+                 return self.transform(self.page["items"][effective_key - self.offset])
+             # otherwise, fetch and return the single item
+             return self.transform(self.fetch(1, effective_key)["items"][0])
+
+         elif isinstance(key, slice):
+             start, stop, step = key.indices(self.count)
+             if step != 1:
+                 raise ValueError("Stepped slicing is not supported")
+             start = start + self.count if start < 0 else start or 0
+             stop = stop + self.count if stop < 0 else stop or self.count
+             if start >= self.count or stop > self.count:
+                 raise IndexError(f"Slice {key} out of range")
+             limit = min(self.page_size, stop - start)
+             if limit <= 0:
+                 return []
+             items = []
+             for i in range(start, stop, limit):
+                 page = self.fetch(limit, i)
+                 items.extend(map(self.transform, page["items"]))
+             return items
+
+     def __len__(self) -> int:
+         return self.count
+
+
+ # type checking workaround until python 3.13 allows declaring the class as PagedIterable[T, R = T]
+
+
+ @overload
+ def PagedIterable(
+     fetch: Callable[[int, int], Page[T]],
+     *,
+     transform: None = None,
+     page_size: int = 100,
+ ) -> _PagedIterable[T, T]:
+     pass
+
+
+ @overload
+ def PagedIterable(
+     fetch: Callable[[int, int], Page[T]],
+     *,
+     transform: Callable[[T], R],
+     page_size: int = 100,
+ ) -> _PagedIterable[T, R]:
+     pass
+
+
+ def PagedIterable(
+     fetch: Callable[[int, int], Page[T]],
+     *,
+     transform: Callable[[T], R] | None = None,
+     page_size: int = 100,
+ ) -> _PagedIterable[T, R]:
+     return _PagedIterable(fetch, transform=transform, page_size=page_size)
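A minimal sketch of how a caller might consume this iterable, assuming a hypothetical in-memory endpoint in place of a real HTTP call; RECORDS and fetch_page below are illustrative and not part of the package.

from __future__ import annotations

from orca_sdk._utils.pagination import Page, PagedIterable

# hypothetical stand-in for a remote collection; a real fetch would issue an HTTP request
RECORDS = [{"id": i, "value": f"item_{i}"} for i in range(250)]


def fetch_page(limit: int, offset: int) -> Page[dict]:
    # the endpoint contract: (limit, offset) -> {"items": [...], "count": total}
    return {"items": RECORDS[offset : offset + limit], "count": len(RECORDS)}


records = PagedIterable(fetch_page, transform=lambda r: r["value"], page_size=100)
assert len(records) == 250          # count is known once the constructor fetches the first page
assert records[0] == "item_0"       # served from the prefetched first page
assert records[-1] == "item_249"    # an index off the current page triggers a single-item fetch
assert records[100:103] == ["item_100", "item_101", "item_102"]
for value in records:               # iterating fetches the remaining pages lazily
    pass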
orca_sdk/_utils/pagination_test.py
@@ -0,0 +1,132 @@
+ import pytest
+
+ from .pagination import Page, PagedIterable
+
+
+ class MockEndpoint:
+     """Mock paginated endpoint for testing"""
+
+     def __init__(self, total_items: int):
+         self.items = list(range(total_items))
+         self.fetch_count = 0
+
+     def fetch(self, limit: int, offset: int) -> Page[int]:
+         self.fetch_count += 1
+         end_index = min(offset + limit, len(self.items))
+         items = self.items[offset:end_index]
+         return {"items": items, "count": len(self.items)}
+
+
+ def test_basic_pagination():
+     # Given a mock endpoint with 5 items
+     endpoint = MockEndpoint(5)
+     # When doing a paginated iteration
+     paginated = PagedIterable(endpoint.fetch, page_size=2)
+     # Then we should be able to iterate through all items
+     assert list(paginated) == [0, 1, 2, 3, 4]
+     # And the length should be correct
+     assert len(paginated) == 5
+     # And 3 requests: [0,1], [2,3], [4] should have been made, one for each page
+     assert endpoint.fetch_count == 3
+
+
+ def test_empty_results():
+     # Given an empty mock endpoint
+     endpoint = MockEndpoint(0)
+     # When doing a paginated iteration
+     paginated = PagedIterable(endpoint.fetch, page_size=5)
+     # Then we should get an empty list
+     assert list(paginated) == []
+     # And the length should be 0
+     assert len(paginated) == 0
+     # And only one request should have been made, for the first page
+     assert endpoint.fetch_count == 1
+
+
+ def test_transform_function():
+     # Given a mock endpoint with 4 items
+     endpoint = MockEndpoint(4)
+     # And a transform function that doubles the items
+     transform = lambda x: f"2x={2*x}"
+     # When doing a paginated iteration with a transform function
+     paginated = PagedIterable(endpoint.fetch, transform=transform, page_size=2)
+     # Then we should get the transformed items
+     assert list(paginated) == ["2x=0", "2x=2", "2x=4", "2x=6"]
+
+
+ def test_multiple_iterations():
+     # Given a mock endpoint with 5 items
+     endpoint = MockEndpoint(5)
+     # When we do 2 paginated iterations
+     paginated = PagedIterable(endpoint.fetch, page_size=2)
+     result1 = list(paginated)
+     result2 = list(paginated)
+     # Then we should get the same items twice
+     assert result1 == result2 == [0, 1, 2, 3, 4]
+     # And 6 requests should have been made, 3 for each iteration
+     assert endpoint.fetch_count == 6
+
+
+ def test_single_page_optimization():
+     # Given a mock endpoint with 5 items
+     endpoint = MockEndpoint(5)
+     # When doing a paginated iteration with a limit that is greater than the number of items
+     paginated = PagedIterable(endpoint.fetch, page_size=10)
+     # Then we should get all items
+     assert list(paginated) == [0, 1, 2, 3, 4]
+     # And the length should be 5
+     assert len(paginated) == 5
+     # And only one request should have been made
+     assert endpoint.fetch_count == 1
+     # And a second iteration should not make any additional requests
+     assert list(paginated) == [0, 1, 2, 3, 4]
+     assert endpoint.fetch_count == 1
+
+
+ def test_indexing():
+     # Given a mock endpoint with 7 items
+     endpoint = MockEndpoint(7)
+     # When creating a paginated iterable with page size 3
+     paginated = PagedIterable(endpoint.fetch, page_size=3)
+     # Then we should be able to access items by index
+     assert paginated[0] == 0
+     assert paginated[2] == 2
+     assert paginated[6] == 6
+     # And negative indices should work
+     assert paginated[-1] == 6
+     # And accessing out of bounds should raise IndexError
+     with pytest.raises(IndexError):
+         paginated[7]
+     with pytest.raises(IndexError):
+         paginated[-8]
+     # And transforms are applied
+     assert PagedIterable(endpoint.fetch, transform=lambda x: x * 10, page_size=3)[1] == 10
+
+
+ def test_slicing():
+     # Given a mock endpoint with 10 items
+     endpoint = MockEndpoint(10)
+     # When creating a paginated iterable
+     paginated = PagedIterable(endpoint.fetch, page_size=3)
+     # Then we should be able to slice it
+     assert list(paginated[2:5]) == [2, 3, 4]
+     assert list(paginated[:3]) == [0, 1, 2]
+     assert list(paginated[7:]) == [7, 8, 9]
+     # And negative indices should work
+     assert list(paginated[:-5]) == [0, 1, 2, 3, 4, 5]
+     assert list(paginated[-3:]) == [7, 8, 9]
+     assert list(paginated[-5:-2]) == [5, 6, 7]
+     # And empty slices should work
+     assert list(paginated[5:5]) == []
+     # And slicing with a start and stop that are out of bounds should raise IndexError
+     with pytest.raises(IndexError):
+         list(paginated[20:25])
+     # And slicing with a step other than 1 should raise ValueError
+     with pytest.raises(ValueError):
+         list(paginated[::2])
+     with pytest.raises(ValueError):
+         list(paginated[1:8:3])
+     with pytest.raises(ValueError):
+         list(paginated[::-1])
+     # And transforms are applied
+     assert list(PagedIterable(endpoint.fetch, transform=lambda x: x * 10, page_size=3)[1:3]) == [10, 20]
orca_sdk/_utils/prediction_result_ui.css
@@ -0,0 +1,18 @@
+ .white {
+   background-color: white;
+ }
+ .success {
+   color: gray;
+   font-size: 12px;
+   height: 24px;
+ }
+ .html-container:has(.no-padding) {
+   padding: 0;
+   height: 24px;
+ }
+ .progress-bar {
+   background-color: #2b9a66;
+ }
+ .progress-level-inner {
+   display: none;
+ }
orca_sdk/_utils/prediction_result_ui.py
@@ -0,0 +1,110 @@
+ from __future__ import annotations
+
+ import logging
+ import re
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ import gradio as gr
+
+ from ..memoryset import LabeledMemoryLookup, LabeledMemoryset, ScoredMemoryLookup
+
+ if TYPE_CHECKING:
+     from ..telemetry import PredictionBase
+
+
+ def inspect_prediction_result(prediction_result: PredictionBase):
+
+     def update_label(val: str, memory: LabeledMemoryLookup, progress=gr.Progress(track_tqdm=True)):
+         progress(0)
+         match = re.search(r".*\((\d+)\)$", val)
+         if match:
+             progress(0.5)
+             new_label = int(match.group(1))
+             memory.update(label=new_label)
+             progress(1)
+             return "&#9989; Changes saved"
+         else:
+             logging.error(f"Invalid label format: {val}")
+
+     def update_score(val: float, memory: ScoredMemoryLookup, progress=gr.Progress(track_tqdm=True)):
+         progress(0)
+         memory.update(score=val)
+         progress(1)
+         return "&#9989; Changes saved"
+
+     with gr.Blocks(
+         fill_width=True,
+         title="Prediction Results",
+         css_paths=str(Path(__file__).parent / "prediction_result_ui.css"),
+     ) as prediction_result_ui:
+         gr.Markdown("# Prediction Results")
+         gr.Markdown(f"**Input:** {prediction_result.input_value}")
+
+         if isinstance(prediction_result.memoryset, LabeledMemoryset) and prediction_result.label is not None:
+             label_names = prediction_result.memoryset.label_names
+             gr.Markdown(f"**Prediction:** {label_names[prediction_result.label]} ({prediction_result.label})")
+         else:
+             gr.Markdown(f"**Prediction:** {prediction_result.score:.2f}")
+
+         gr.Markdown("### Memory Lookups")
+
+         with gr.Row(equal_height=True, variant="panel"):
+             with gr.Column(scale=7):
+                 gr.Markdown("**Value**")
+             with gr.Column(scale=3, min_width=150):
+                 gr.Markdown("**Label**" if prediction_result.label is not None else "**Score**")
+
+         for i, mem_lookup in enumerate(prediction_result.memory_lookups):
+             with gr.Row(equal_height=True, variant="panel", elem_classes="white" if i % 2 == 0 else None):
+                 with gr.Column(scale=7):
+                     gr.Markdown(
+                         (
+                             mem_lookup.value
+                             if isinstance(mem_lookup.value, str)
+                             else "Time series data" if isinstance(mem_lookup.value, list) else "Image data"
+                         ),
+                         label="Value",
+                         height=50,
+                     )
+                 with gr.Column(scale=3, min_width=150):
+                     if (
+                         isinstance(prediction_result.memoryset, LabeledMemoryset)
+                         and prediction_result.label is not None
+                         and isinstance(mem_lookup, LabeledMemoryLookup)
+                     ):
+                         label_names = prediction_result.memoryset.label_names
+                         dropdown = gr.Dropdown(
+                             choices=[f"{label_name} ({i})" for i, label_name in enumerate(label_names)],
+                             label="Label",
+                             value=(
+                                 f"{label_names[mem_lookup.label]} ({mem_lookup.label})"
+                                 if mem_lookup.label is not None
+                                 else "None"
+                             ),
+                             interactive=True,
+                             container=False,
+                         )
+                         changes_saved = gr.HTML(lambda: "", elem_classes="success no-padding", every=15)
+                         dropdown.change(
+                             lambda val, mem=mem_lookup: update_label(val, mem),
+                             inputs=[dropdown],
+                             outputs=[changes_saved],
+                             show_progress="full",
+                         )
+                     elif prediction_result.score is not None and isinstance(mem_lookup, ScoredMemoryLookup):
+                         input = gr.Number(
+                             value=mem_lookup.score,
+                             label="Score",
+                             interactive=True,
+                             container=False,
+                         )
+                         changes_saved = gr.HTML(lambda: "", elem_classes="success no-padding", every=15)
+                         input.change(
+                             lambda val, mem=mem_lookup: update_score(val, mem),
+                             inputs=[input],
+                             outputs=[changes_saved],
+                             show_progress="full",
+                         )
+
+     prediction_result_ui.launch()
orca_sdk/_utils/tqdm_file_reader.py
@@ -0,0 +1,12 @@
+ class TqdmFileReader:
+     def __init__(self, file_obj, pbar):
+         self.file_obj = file_obj
+         self.pbar = pbar
+
+     def read(self, size=-1):
+         data = self.file_obj.read(size)
+         self.pbar.update(len(data))
+         return data
+
+     def __getattr__(self, attr):
+         return getattr(self.file_obj, attr)
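A sketch of the intended use: wrap a file object so that read progress is reported through a tqdm bar, while every other attribute is forwarded to the underlying file. The file path below is hypothetical.

import os

from tqdm import tqdm

from orca_sdk._utils.tqdm_file_reader import TqdmFileReader

path = "payload.bin"  # hypothetical file to stream, e.g. for an upload
with open(path, "rb") as file, tqdm(total=os.path.getsize(path), unit="B", unit_scale=True) as pbar:
    reader = TqdmFileReader(file, pbar)
    while chunk := reader.read(1024 * 1024):
        pass  # each read() call advances the progress bar; __getattr__ proxies everything else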
orca_sdk/_utils/value_parser.py
@@ -0,0 +1,45 @@
+ import base64
+ import io
+ from typing import cast
+
+ import numpy as np
+ from numpy.typing import NDArray
+ from PIL import Image as pil
+
+ ValueType = str | pil.Image | NDArray[np.float32]
+ """
+ The type of a value in a memoryset
+
+ - `str`: string
+ - `pil.Image`: image
+ - `NDArray[np.float32]`: univariate or multivariate timeseries
+ """
+
+
+ def decode_value(value: str) -> ValueType:
+     if value.startswith("data:image"):
+         header, data = value.split(",", 1)
+         return pil.open(io.BytesIO(base64.b64decode(data)))
+
+     if value.startswith("data:numpy"):
+         header, data = value.split(",", 1)
+         return np.load(io.BytesIO(base64.b64decode(data)))
+
+     return value
+
+
+ def encode_value(value: ValueType) -> str:
+     if isinstance(value, pil.Image):
+         header = f"data:image/{value.format.lower()};base64," if value.format else "data:image;base64,"
+         buffer = io.BytesIO()
+         value.save(buffer, format=value.format)
+         bytes = buffer.getvalue()
+         return header + base64.b64encode(bytes).decode("utf-8")
+
+     if isinstance(value, np.ndarray):
+         header = f"data:numpy/{value.dtype.name};base64,"
+         buffer = io.BytesIO()
+         np.save(buffer, value)
+         return header + base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+     return value
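A round-trip sketch of the encoding scheme: non-string values are serialized to base64 data URIs by encode_value and restored by decode_value, while plain strings pass through unchanged. The array below is illustrative.

import numpy as np

from orca_sdk._utils.value_parser import decode_value, encode_value

series = np.arange(6, dtype=np.float32).reshape(2, 3)  # a small multivariate timeseries
encoded = encode_value(series)
assert encoded.startswith("data:numpy/float32;base64,")

decoded = decode_value(encoded)
assert isinstance(decoded, np.ndarray)
assert np.array_equal(decoded, series)

# plain strings are returned as-is by both functions
assert decode_value(encode_value("hello")) == "hello"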