orca-sdk 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. orca_sdk/__init__.py +30 -0
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +634 -0
  4. orca_sdk/_shared/metrics_test.py +570 -0
  5. orca_sdk/_utils/__init__.py +0 -0
  6. orca_sdk/_utils/analysis_ui.py +196 -0
  7. orca_sdk/_utils/analysis_ui_style.css +51 -0
  8. orca_sdk/_utils/auth.py +65 -0
  9. orca_sdk/_utils/auth_test.py +31 -0
  10. orca_sdk/_utils/common.py +37 -0
  11. orca_sdk/_utils/data_parsing.py +129 -0
  12. orca_sdk/_utils/data_parsing_test.py +244 -0
  13. orca_sdk/_utils/pagination.py +126 -0
  14. orca_sdk/_utils/pagination_test.py +132 -0
  15. orca_sdk/_utils/prediction_result_ui.css +18 -0
  16. orca_sdk/_utils/prediction_result_ui.py +110 -0
  17. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  18. orca_sdk/_utils/value_parser.py +45 -0
  19. orca_sdk/_utils/value_parser_test.py +39 -0
  20. orca_sdk/async_client.py +4104 -0
  21. orca_sdk/classification_model.py +1165 -0
  22. orca_sdk/classification_model_test.py +887 -0
  23. orca_sdk/client.py +4096 -0
  24. orca_sdk/conftest.py +382 -0
  25. orca_sdk/credentials.py +217 -0
  26. orca_sdk/credentials_test.py +121 -0
  27. orca_sdk/datasource.py +576 -0
  28. orca_sdk/datasource_test.py +463 -0
  29. orca_sdk/embedding_model.py +712 -0
  30. orca_sdk/embedding_model_test.py +206 -0
  31. orca_sdk/job.py +343 -0
  32. orca_sdk/job_test.py +108 -0
  33. orca_sdk/memoryset.py +3811 -0
  34. orca_sdk/memoryset_test.py +1150 -0
  35. orca_sdk/regression_model.py +841 -0
  36. orca_sdk/regression_model_test.py +595 -0
  37. orca_sdk/telemetry.py +742 -0
  38. orca_sdk/telemetry_test.py +119 -0
  39. orca_sdk-0.1.9.dist-info/METADATA +98 -0
  40. orca_sdk-0.1.9.dist-info/RECORD +41 -0
  41. orca_sdk-0.1.9.dist-info/WHEEL +4 -0
orca_sdk/_utils/analysis_ui.py
@@ -0,0 +1,196 @@
+ from __future__ import annotations
+
+ import logging
+ import re
+ from pathlib import Path
+ from typing import TypedDict, cast
+
+ import gradio as gr
+
+ from ..memoryset import LabeledMemory, LabeledMemoryset
+
+ # Suppress all httpx logs
+ logging.getLogger("httpx").setLevel(logging.CRITICAL)
+
+ # Optionally suppress other libraries Gradio might use
+ logging.getLogger("gradio").setLevel(logging.CRITICAL)
+
+
+ class RelabelStatus(TypedDict):
+     memory_id: str
+     approved: bool
+     new_label: int | None
+     full_memory: LabeledMemory
+
+
+ def display_suggested_memory_relabels(memoryset: LabeledMemoryset):
+     suggested_relabels = memoryset.query(
+         filters=[("metrics.neighbor_predicted_label_matches_current_label", "==", False)]
+     )
+     # Sort memories by confidence score (higher confidence first)
+     suggested_relabels.sort(key=lambda x: (x.metrics.get("neighbor_predicted_label_confidence", 0.0)), reverse=True)
+
+     def update_approved(memory_id: str, selected: bool, current_memory_relabel_map: dict[str, RelabelStatus]):
+         current_memory_relabel_map[memory_id]["approved"] = selected
+         return current_memory_relabel_map
+
+     def approve_all(current_all_approved, selected: bool):
+         for mem_id in current_all_approved:
+             current_all_approved[mem_id]["approved"] = selected
+         return current_all_approved, selected
+
+     def apply_selected(current_memory_relabel_map: dict[str, RelabelStatus], progress=gr.Progress(track_tqdm=True)):
+         progress(0, desc="Processing label updates...")
+         to_be_deleted = []
+         approved_relabels = [mem for mem in current_memory_relabel_map.values() if mem["approved"]]
+         for memory in progress.tqdm(approved_relabels, desc="Applying label updates..."):
+             memory = cast(RelabelStatus, memory)
+             new_label = memory["new_label"]
+             assert isinstance(new_label, int)
+             memoryset.update(
+                 {
+                     "memory_id": memory["memory_id"],
+                     "label": new_label,
+                 }
+             )
+             to_be_deleted.append(memory["memory_id"])
+         for mem_id in to_be_deleted:
+             del current_memory_relabel_map[mem_id]
+         return (
+             current_memory_relabel_map,
+             gr.HTML(
+                 f"<h1 style='display: inline-block; position: fixed; z-index: 1000; left: 36px; top: 14px;'>Suggested Label Updates: {len(current_memory_relabel_map)}</h1>",
+             ),
+         )
+
+     def update_label(mem_id: str, label: str, current_memory_relabel_map: dict[str, RelabelStatus]):
+         match = re.search(r".*\((\d+)\)$", label)
+         if match:
+             new_label = int(match.group(1))
+             current_memory_relabel_map[mem_id]["new_label"] = new_label
+             confidence = "--"
+             current_metrics = current_memory_relabel_map[mem_id]["full_memory"].metrics
+             if current_metrics and new_label == current_metrics.get("neighbor_predicted_label"):
+                 confidence = (
+                     round(current_metrics.get("neighbor_predicted_label_confidence", 0.0), 2) if current_metrics else 0
+                 )
+             return (
+                 gr.HTML(
+                     f"<p style='font-size: 10px; color: #888;'>Confidence: {confidence}</p>",
+                     elem_classes="no-padding",
+                 ),
+                 current_memory_relabel_map,
+             )
+         else:
+             logging.error(f"Invalid label format: {label}")
+
+     with gr.Blocks(
+         fill_width=True,
+         title="Suggested Label Updates",
+         css_paths=str(Path(__file__).parent / "analysis_ui_style.css"),
+     ) as demo:
+         label_names = memoryset.label_names
+
+         refresh = gr.State(False)
+         all_approved = gr.State(False)
+         memory_relabel_map = gr.State(
+             {
+                 mem.memory_id: RelabelStatus(
+                     memory_id=mem.memory_id,
+                     approved=False,
+                     new_label=(
+                         mem.metrics.get("neighbor_predicted_label")
+                         if (mem.metrics and isinstance(mem.metrics.get("neighbor_predicted_label"), int))
+                         else None
+                     ),
+                     full_memory=mem,
+                 )
+                 for mem in suggested_relabels
+             }
+         )
+
+         @gr.render(
+             inputs=[memory_relabel_map, all_approved],
+             triggers=[demo.load, refresh.change, all_approved.change, memory_relabel_map.change],  # type: ignore[arg-type]
+         )
+         def render_table(current_memory_relabel_map, current_all_approved):
+             if len(current_memory_relabel_map):
+                 with gr.Group(elem_classes="header"):
+                     title = gr.HTML(
+                         f"<h1 style='display: inline-block; position: fixed; z-index: 1000; left: 36px; top: 14px;'>Suggested Label Updates: {len(current_memory_relabel_map)}</h1>"
+                     )
+                     apply_selected_button = gr.Button("Apply Selected", elem_classes="button")
+                     apply_selected_button.click(
+                         apply_selected,
+                         inputs=[memory_relabel_map],
+                         outputs=[memory_relabel_map, title],
+                         show_progress="full",
+                     )
+                 with gr.Row(equal_height=True, variant="panel", elem_classes="margin-top"):
+                     with gr.Column(scale=9):
+                         gr.Markdown("**Value**")
+                     with gr.Column(scale=2, min_width=90):
+                         gr.Markdown("**Current Label**")
+                     with gr.Column(scale=3, min_width=150):
+                         gr.Markdown("**Suggested Label**", elem_classes="centered")
+                     with gr.Column(scale=2, min_width=50):
+                         approve_all_checkbox = gr.Checkbox(
+                             show_label=False,
+                             value=current_all_approved,
+                             label="",
+                             container=False,
+                             elem_classes="centered",
+                         )
+                         approve_all_checkbox.change(
+                             approve_all,
+                             inputs=[memory_relabel_map, approve_all_checkbox],
+                             outputs=[memory_relabel_map, all_approved],
+                         )
+                 for i, memory_relabel in enumerate(current_memory_relabel_map.values()):
+                     mem = memory_relabel["full_memory"]
+                     predicted_label = mem.metrics["neighbor_predicted_label"]
+                     predicted_label_name = label_names[predicted_label]
+                     predicted_label_confidence = mem.metrics.get("neighbor_predicted_label_confidence", 0)
+
+                     with gr.Row(equal_height=True, variant="panel"):
+                         with gr.Column(scale=9):
+                             assert isinstance(mem.value, str)
+                             gr.Markdown(mem.value, label="Value", height=50)
+                         with gr.Column(scale=2, min_width=90):
+                             gr.Markdown(f"{mem.label_name} ({mem.label})", label="Current Label", height=50)
+                         with gr.Column(scale=3, min_width=150):
+                             dropdown = gr.Dropdown(
+                                 choices=[f"{label_name} ({i})" for i, label_name in enumerate(label_names)],
+                                 label="Suggested Label",
+                                 value=f"{predicted_label_name} ({predicted_label})",
+                                 interactive=True,
+                                 container=False,
+                             )
+                             confidence = gr.HTML(
+                                 f"<p style='font-size: 10px; color: #888;'>Confidence: {predicted_label_confidence:.2f}</p>",
+                                 elem_classes="no-padding",
+                             )
+                             dropdown.change(
+                                 lambda val, map, mem_id=mem.memory_id: update_label(mem_id, val, map),
+                                 inputs=[dropdown, memory_relabel_map],
+                                 outputs=[confidence, memory_relabel_map],
+                             )
+                         with gr.Column(scale=2, min_width=50):
+                             checkbox = gr.Checkbox(
+                                 show_label=False,
+                                 label="",
+                                 value=current_memory_relabel_map[mem.memory_id]["approved"],
+                                 container=False,
+                                 elem_classes="centered",
+                                 interactive=True,
+                             )
+                             checkbox.input(
+                                 lambda selected, map, mem_id=mem.memory_id: update_approved(mem_id, selected, map),
+                                 inputs=[checkbox, memory_relabel_map],
+                                 outputs=[memory_relabel_map],
+                             )
+
+             else:
+                 gr.HTML("<h1>No suggested label updates</h1>")
+
+     demo.launch()
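
Everything in this file hangs off the single `display_suggested_memory_relabels` entry point, so launching the review UI is a one-liner once a memoryset is in hand. A minimal sketch of a call site, assuming credentials are already configured and that the SDK exposes a `LabeledMemoryset.open` constructor (the constructor and memoryset names are placeholders, not confirmed by this diff):

    from orca_sdk import LabeledMemoryset
    from orca_sdk._utils.analysis_ui import display_suggested_memory_relabels

    # Open an existing memoryset (name is a placeholder) and launch the Gradio
    # review UI; launch() blocks until the local server is shut down
    memoryset = LabeledMemoryset.open("my-memoryset")
    display_suggested_memory_relabels(memoryset)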
orca_sdk/_utils/analysis_ui_style.css
@@ -0,0 +1,51 @@
+ .centered input {
+   margin: auto;
+ }
+ .centered p {
+   text-align: center;
+ }
+ .button {
+   display: inline-block;
+   max-width: 250px;
+   background-color: #2b9a66;
+   color: white;
+   position: fixed;
+   z-index: 1000;
+   right: 36px;
+   border-radius: 8px;
+   top: 12px;
+ }
+ .margin-top {
+   margin-top: 60px;
+ }
+ .header {
+   position: fixed;
+   z-index: 1000;
+   height: 64px;
+   left: 0;
+   top: 0;
+   border-radius: 0;
+ }
+
+ input[type='checkbox']:checked,
+ input[type='checkbox']:checked:hover,
+ input[type='checkbox']:checked:focus {
+   background-color: #2b9a66;
+   border-color: #2b9a66;
+ }
+ input[type='checkbox']:focus {
+   border-color: #2b9a66;
+ }
+ .html-container:has(.no-padding) {
+   padding: 0;
+ }
+
+ .progress-bar {
+   background-color: #2b9a66;
+ }
+ .header .full {
+   position: fixed !important;
+   z-index: 1100;
+   background-color: #e4e4e7;
+   height: 68px;
+ }
orca_sdk/_utils/auth.py
@@ -0,0 +1,65 @@
+ """This module contains internal utils for managing api keys in tests"""
+
+ import logging
+ import os
+ from typing import List, Literal
+
+ from dotenv import load_dotenv
+
+ from ..client import ApiKeyMetadata, OrcaClient
+ from .common import DropMode
+
+ load_dotenv()  # this needs to be here to ensure env is populated before accessing it
+
+ # the defaults here must match nautilus and lighthouse config defaults
+ _ORCA_ROOT_ACCESS_API_KEY = os.environ.get("ORCA_ROOT_ACCESS_API_KEY", "00000000-0000-0000-0000-000000000000")
+ _DEFAULT_ORG_ID = os.environ.get("DEFAULT_ORG_ID", "10e50000-0000-4000-a000-a78dca14af3a")
+
+
+ def _create_api_key(org_id: str, name: str, scopes: list[Literal["ADMINISTER", "PREDICT"]] = ["ADMINISTER"]) -> str:
+     """Creates an API key for the given organization"""
+     client = OrcaClient._resolve_client()
+     response = client.POST(
+         "/auth/api_key",
+         json={"name": name, "scope": scopes},
+         headers={"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id},
+     )
+     return response["api_key"]
+
+
+ def _list_api_keys(org_id: str) -> List[ApiKeyMetadata]:
+     """Lists all API keys for the given organization"""
+     client = OrcaClient._resolve_client()
+     return client.GET("/auth/api_key", headers={"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id})
+
+
+ def _delete_api_key(org_id: str, name: str, if_not_exists: DropMode = "error") -> None:
+     """Deletes the API key with the given name from the organization"""
+     try:
+         client = OrcaClient._resolve_client()
+         client.DELETE(
+             "/auth/api_key/{name_or_id}",
+             params={"name_or_id": name},
+             headers={"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id},
+         )
+     except LookupError:
+         if if_not_exists == "error":
+             raise
+
+
+ def _delete_org(org_id: str) -> None:
+     """Deletes the organization"""
+     client = OrcaClient._resolve_client()
+     client.DELETE("/auth/org", headers={"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id})
+
+
+ def _authenticate_local_api(org_id: str = _DEFAULT_ORG_ID, api_key_name: str = "local") -> None:
+     """Connect to the local API at http://localhost:1584/ and authenticate with a new API key"""
+     _delete_api_key(org_id, api_key_name, if_not_exists="ignore")
+     client = OrcaClient._resolve_client()
+     client.base_url = "http://localhost:1584"
+     client.headers.update({"Api-Key": _create_api_key(org_id, api_key_name)})
+     logging.info(f"Authenticated against local API at 'http://localhost:1584' with '{api_key_name}' API key")
+
+
+ __all__ = ["_create_api_key", "_delete_api_key", "_delete_org", "_list_api_keys", "_authenticate_local_api"]
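
These helpers compose into a simple test-bootstrap flow: point the shared client at the local API once, then mint and clean up per-test keys under the root-access key. A minimal pytest sketch built on the functions above, assuming a local API is listening on port 1584 as the docstrings state (the fixture names are hypothetical):

    import pytest

    from orca_sdk._utils.auth import _authenticate_local_api, _create_api_key, _delete_api_key

    @pytest.fixture(scope="session", autouse=True)
    def local_api():
        # Points the shared client at http://localhost:1584 and installs a fresh "local" key
        _authenticate_local_api()

    @pytest.fixture
    def predict_only_key(org_id):
        # Mint a PREDICT-scoped key for one test, then clean it up afterwards
        key = _create_api_key(org_id, "predict-only", scopes=["PREDICT"])
        yield key
        _delete_api_key(org_id, "predict-only", if_not_exists="ignore")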
orca_sdk/_utils/auth_test.py
@@ -0,0 +1,31 @@
+ from uuid import uuid4
+
+ from ..credentials import OrcaCredentials
+ from .auth import _create_api_key, _delete_api_key, _delete_org, _list_api_keys
+
+
+ def test_list_api_keys(org_id):
+     assert len(_list_api_keys(org_id)) >= 1
+
+
+ def test_create_api_key(org_id):
+     name = f"test-{uuid4().hex[:8]}"
+     api_key = _create_api_key(org_id=org_id, name=name)
+     assert api_key is not None
+     assert name in [api_key.name for api_key in OrcaCredentials.list_api_keys()]
+
+
+ def test_delete_api_key(org_id):
+     name = f"test-{uuid4().hex[:8]}"
+     api_key = _create_api_key(org_id=org_id, name=name)
+     assert api_key is not None
+     assert name in [api_key.name for api_key in OrcaCredentials.list_api_keys()]
+     _delete_api_key(org_id=org_id, name=name)
+     assert name not in [api_key.name for api_key in OrcaCredentials.list_api_keys()]
+
+
+ def test_delete_org(other_org_id):
+     _create_api_key(org_id=other_org_id, name="test")
+     assert len(_list_api_keys(other_org_id)) >= 1
+     _delete_org(other_org_id)
+     assert len(_list_api_keys(other_org_id)) == 0
orca_sdk/_utils/common.py
@@ -0,0 +1,37 @@
+ from typing import Any, Literal
+
+ CreateMode = Literal["error", "open"]
+ """
+ Mode for creating a resource.
+
+ **Options:**
+
+ - `"error"`: raise an error if a resource with the same name already exists
+ - `"open"`: open the resource with the same name if it exists
+ """
+
+ DropMode = Literal["error", "ignore"]
+ """
+ Mode for deleting a resource.
+
+ **Options:**
+
+ - `"error"`: raise an error if the resource does not exist
+ - `"ignore"`: do nothing if the resource does not exist
+ """
+
+
+ class _UnsetSentinel:
+     """See corresponding class in orcalib.pydantic_utils"""
+
+     def __bool__(self) -> bool:
+         return False
+
+     def __repr__(self) -> str:
+         return "UNSET"
+
+
+ UNSET: Any = _UnsetSentinel()
+ """
+ Default value to indicate that no update should be applied to a field and that it should not be set to None
+ """
orca_sdk/_utils/data_parsing.py
@@ -0,0 +1,129 @@
+ import pickle
+ from dataclasses import asdict, is_dataclass
+ from os import PathLike
+ from typing import Any, cast
+
+ from datasets import Dataset
+ from datasets.exceptions import DatasetGenerationError
+ from torch.utils.data import DataLoader as TorchDataLoader
+ from torch.utils.data import Dataset as TorchDataset
+
+
+ def parse_dict_like(item: Any, column_names: list[str] | None = None) -> dict:
+     if isinstance(item, dict):
+         return item
+
+     if isinstance(item, tuple):
+         if column_names is not None:
+             if len(item) != len(column_names):
+                 raise ValueError(
+                     f"Tuple length ({len(item)}) does not match number of column names ({len(column_names)})"
+                 )
+             return {column_names[i]: item[i] for i in range(len(item))}
+         elif hasattr(item, "_fields") and all(isinstance(field, str) for field in item._fields):  # type: ignore
+             return {field: getattr(item, field) for field in item._fields}  # type: ignore
+         else:
+             raise ValueError("For datasets that return unnamed tuples, please provide a column_names argument")
+
+     if is_dataclass(item) and not isinstance(item, type):
+         return asdict(item)
+
+     raise ValueError(f"Cannot parse {type(item)}")
+
+
+ def parse_batch(batch: Any, column_names: list[str] | None = None) -> list[dict]:
+     if isinstance(batch, list):
+         return [parse_dict_like(item, column_names) for item in batch]
+
+     batch = parse_dict_like(batch, column_names)
+     keys = list(batch.keys())
+     batch_size = len(batch[keys[0]])
+     for key in keys:
+         if not len(batch[key]) == batch_size:
+             raise ValueError(f"Batch must consist of values of the same length, but {key} has length {len(batch[key])}")
+     return [{key: batch[key][idx] for key in keys} for idx in range(batch_size)]
+
+
+ def hf_dataset_from_torch(
+     torch_data: TorchDataLoader | TorchDataset,
+     column_names: list[str] | None = None,
+ ) -> Dataset:
+     """
+     Create a HuggingFace Dataset from a PyTorch DataLoader or Dataset.
+
+     NOTE: It's important to ignore the cached files when testing (i.e., ignore_cache=True), because
+     cached results can ignore changes you've made to tests. This can make a test appear to succeed
+     when it's actually broken, or vice versa.
+
+     Params:
+         torch_data: A PyTorch DataLoader or Dataset object to create the HuggingFace Dataset from.
+         column_names: Optional list of column names to use for the dataset. If not provided,
+             the column names will be inferred from the data.
+     Returns:
+         A HuggingFace Dataset object containing the data from the PyTorch DataLoader or Dataset.
+     """
+     if isinstance(torch_data, TorchDataLoader):
+         dataloader = torch_data
+     else:
+         dataloader = TorchDataLoader(torch_data, batch_size=1, collate_fn=lambda x: x)
+
+     # Collect data from the dataloader into a list to avoid serialization issues
+     # with Dataset.from_generator in Python 3.14 (see datasets issue #7839)
+     data_list = []
+     try:
+         for batch in dataloader:
+             data_list.extend(parse_batch(batch, column_names=column_names))
+     except ValueError as e:
+         raise DatasetGenerationError(str(e)) from e
+
+     ds = Dataset.from_list(data_list)
+
+     if not isinstance(ds, Dataset):
+         raise ValueError(f"Failed to create dataset from list: {type(ds)}")
+     return ds
+
+
+ def hf_dataset_from_disk(file_path: str | PathLike) -> Dataset:
+     """
+     Load a dataset from disk into a HuggingFace Dataset object.
+
+     Params:
+         file_path: Path to the file on disk to create the memoryset from. The file type will
+             be inferred from the file extension. The following file types are supported:
+
+             - .pkl: [`Pickle`][pickle] files containing lists of dictionaries or dictionaries of columns
+             - .json/.jsonl: [`JSON`][json] and JSON Lines files
+             - .csv: [`CSV`][csv] files
+             - .parquet: [`Parquet`](https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetFile.html#pyarrow.parquet.ParquetFile) files
+             - dataset directory: Directory containing a saved HuggingFace [`Dataset`][datasets.Dataset]
+
+     Returns:
+         A HuggingFace Dataset object containing the loaded data.
+
+     Raises:
+         [`ValueError`][ValueError]: If the pickle file contains unsupported data types or if
+             loading the dataset fails for any reason.
+     """
+     if str(file_path).endswith(".pkl"):
+         data = pickle.load(open(file_path, "rb"))
+         if isinstance(data, list):
+             return Dataset.from_list(data)
+         elif isinstance(data, dict):
+             return Dataset.from_dict(data)
+         else:
+             raise ValueError(f"Unsupported pickle file: {file_path}")
+     elif str(file_path).endswith(".json"):
+         hf_dataset = Dataset.from_json(file_path)
+     elif str(file_path).endswith(".jsonl"):
+         hf_dataset = Dataset.from_json(file_path)
+     elif str(file_path).endswith(".csv"):
+         hf_dataset = Dataset.from_csv(file_path)
+     elif str(file_path).endswith(".parquet"):
+         hf_dataset = Dataset.from_parquet(file_path)
+     else:
+         try:
+             hf_dataset = Dataset.load_from_disk(file_path)
+         except Exception as e:
+             raise ValueError(f"Failed to load dataset from disk: {e}") from e
+
+     return cast(Dataset, hf_dataset)
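
Because `parse_batch` understands dicts, named tuples, and dataclasses, `hf_dataset_from_torch` handles most map-style PyTorch data without adapter code; plain tuples only need an explicit `column_names`. A quick sketch (a plain list works at runtime as a map-style dataset, even though the annotation expects a torch `Dataset`):

    from orca_sdk._utils.data_parsing import hf_dataset_from_torch

    # Toy map-style dataset yielding unnamed tuples
    pairs = [("good product", 1), ("arrived broken", 0)]

    # Tuples carry no field names, so column names must be provided
    ds = hf_dataset_from_torch(pairs, column_names=["value", "label"])
    assert set(ds.column_names) == {"value", "label"}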