orca-sdk 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +30 -0
- orca_sdk/_shared/__init__.py +10 -0
- orca_sdk/_shared/metrics.py +634 -0
- orca_sdk/_shared/metrics_test.py +570 -0
- orca_sdk/_utils/__init__.py +0 -0
- orca_sdk/_utils/analysis_ui.py +196 -0
- orca_sdk/_utils/analysis_ui_style.css +51 -0
- orca_sdk/_utils/auth.py +65 -0
- orca_sdk/_utils/auth_test.py +31 -0
- orca_sdk/_utils/common.py +37 -0
- orca_sdk/_utils/data_parsing.py +129 -0
- orca_sdk/_utils/data_parsing_test.py +244 -0
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/_utils/prediction_result_ui.css +18 -0
- orca_sdk/_utils/prediction_result_ui.py +110 -0
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/_utils/value_parser.py +45 -0
- orca_sdk/_utils/value_parser_test.py +39 -0
- orca_sdk/async_client.py +4104 -0
- orca_sdk/classification_model.py +1165 -0
- orca_sdk/classification_model_test.py +887 -0
- orca_sdk/client.py +4096 -0
- orca_sdk/conftest.py +382 -0
- orca_sdk/credentials.py +217 -0
- orca_sdk/credentials_test.py +121 -0
- orca_sdk/datasource.py +576 -0
- orca_sdk/datasource_test.py +463 -0
- orca_sdk/embedding_model.py +712 -0
- orca_sdk/embedding_model_test.py +206 -0
- orca_sdk/job.py +343 -0
- orca_sdk/job_test.py +108 -0
- orca_sdk/memoryset.py +3811 -0
- orca_sdk/memoryset_test.py +1150 -0
- orca_sdk/regression_model.py +841 -0
- orca_sdk/regression_model_test.py +595 -0
- orca_sdk/telemetry.py +742 -0
- orca_sdk/telemetry_test.py +119 -0
- orca_sdk-0.1.9.dist-info/METADATA +98 -0
- orca_sdk-0.1.9.dist-info/RECORD +41 -0
- orca_sdk-0.1.9.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import re
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import TypedDict, cast
|
|
7
|
+
|
|
8
|
+
import gradio as gr
|
|
9
|
+
|
|
10
|
+
from ..memoryset import LabeledMemory, LabeledMemoryset
|
|
11
|
+
|
|
12
|
+
# Suppress all httpx logs
|
|
13
|
+
logging.getLogger("httpx").setLevel(logging.CRITICAL)
|
|
14
|
+
|
|
15
|
+
# Optionally suppress other libraries Gradio might use
|
|
16
|
+
logging.getLogger("gradio").setLevel(logging.CRITICAL)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class RelabelStatus(TypedDict):
    """Per-memory state tracked by the relabeling UI while the user reviews suggestions."""

    # ID of the memory this entry refers to (also the key in the state map).
    memory_id: str
    # Whether the user has checked this row's "approve" checkbox.
    approved: bool
    # Label index the row will be updated to when applied; None when no valid
    # integer suggestion was available from the memory's metrics.
    new_label: int | None
    # Full memory object, kept so the UI can read value/label/metrics without re-querying.
    full_memory: LabeledMemory
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def display_suggested_memory_relabels(memoryset: LabeledMemoryset):
    """Launch an interactive Gradio app for reviewing and applying suggested label updates.

    Queries the memoryset for memories whose neighbor-predicted label disagrees with
    their current label, then renders a table where each row can be individually
    approved (or all at once) and a replacement label chosen. "Apply Selected"
    writes approved label changes back to the memoryset and removes those rows.

    Params:
        memoryset: The labeled memoryset to analyze and update in place.
    """
    # Only memories flagged by analysis as mislabeled (predicted label != current label).
    suggested_relabels = memoryset.query(
        filters=[("metrics.neighbor_predicted_label_matches_current_label", "==", False)]
    )
    # Sort memories by confidence score (higher confidence first)
    suggested_relabels.sort(key=lambda x: (x.metrics.get("neighbor_predicted_label_confidence", 0.0)), reverse=True)

    def update_approved(memory_id: str, selected: bool, current_memory_relabel_map: dict[str, RelabelStatus]):
        # Event handler: toggle a single row's approval flag in the shared state map.
        current_memory_relabel_map[memory_id]["approved"] = selected
        return current_memory_relabel_map

    def approve_all(current_all_approved, selected: bool):
        # Event handler: set every row's approval flag to the header checkbox value.
        # NOTE(review): despite its name, `current_all_approved` receives the state
        # map here (it is wired to memory_relabel_map below) — confirm intent.
        for mem_id in current_all_approved:
            current_all_approved[mem_id]["approved"] = selected
        return current_all_approved, selected

    def apply_selected(current_memory_relabel_map: dict[str, RelabelStatus], progress=gr.Progress(track_tqdm=True)):
        # Persist all approved relabels to the memoryset, then drop them from the UI state.
        progress(0, desc="Processing label updates...")
        to_be_deleted = []
        approved_relabels = [mem for mem in current_memory_relabel_map.values() if mem["approved"]]
        for memory in progress.tqdm(approved_relabels, desc="Applying label updates..."):
            memory = cast(RelabelStatus, memory)
            new_label = memory["new_label"]
            assert isinstance(new_label, int)
            memoryset.update(
                {
                    "memory_id": memory["memory_id"],
                    "label": new_label,
                }
            )
            # Deletions are deferred so we never mutate the map while iterating it.
            to_be_deleted.append(memory["memory_id"])
        for mem_id in to_be_deleted:
            del current_memory_relabel_map[mem_id]
        return (
            current_memory_relabel_map,
            # Refresh the header title so the remaining-count stays accurate.
            gr.HTML(
                f"<h1 style='display: inline-block; position: fixed; z-index: 1000; left: 36px; top: 14px;'>Suggested Label Updates: {len(current_memory_relabel_map)}</h1>",
            ),
        )

    def update_label(mem_id: str, label: str, current_memory_relabel_map: dict[str, RelabelStatus]):
        # Event handler: parse the numeric label id out of a "name (id)" dropdown choice
        # and record it; also recompute the confidence hint shown under the dropdown.
        match = re.search(r".*\((\d+)\)$", label)
        if match:
            new_label = int(match.group(1))
            current_memory_relabel_map[mem_id]["new_label"] = new_label
            # Confidence is only meaningful when the chosen label equals the predicted one.
            confidence = "--"
            current_metrics = current_memory_relabel_map[mem_id]["full_memory"].metrics
            if current_metrics and new_label == current_metrics.get("neighbor_predicted_label"):
                confidence = (
                    round(current_metrics.get("neighbor_predicted_label_confidence", 0.0), 2) if current_metrics else 0
                )
            return (
                gr.HTML(
                    f"<p style='font-size: 10px; color: #888;'>Confidence: {confidence}</p>",
                    elem_classes="no-padding",
                ),
                current_memory_relabel_map,
            )
        else:
            # NOTE(review): returns None on parse failure, so the outputs are not
            # updated in that case — confirm Gradio handles this as intended.
            logging.error(f"Invalid label format: {label}")

    with gr.Blocks(
        fill_width=True,
        title="Suggested Label Updates",
        css_paths=str(Path(__file__).parent / "analysis_ui_style.css"),
    ) as demo:
        label_names = memoryset.label_names

        # Reactive state: refresh toggle, header "approve all" flag, and the map of
        # memory_id -> RelabelStatus driving the table.
        refresh = gr.State(False)
        all_approved = gr.State(False)
        memory_relabel_map = gr.State(
            {
                mem.memory_id: RelabelStatus(
                    memory_id=mem.memory_id,
                    approved=False,
                    # Pre-fill with the predicted label when it is a valid int.
                    new_label=(
                        mem.metrics.get("neighbor_predicted_label")
                        if (mem.metrics and isinstance(mem.metrics.get("neighbor_predicted_label"), int))
                        else None
                    ),
                    full_memory=mem,
                )
                for mem in suggested_relabels
            }
        )

        @gr.render(
            inputs=[memory_relabel_map, all_approved],
            triggers=[demo.load, refresh.change, all_approved.change, memory_relabel_map.change],  # type: ignore[arg-type]
        )
        def render_table(current_memory_relabel_map, current_all_approved):
            # Re-rendered whenever the state map or the approve-all flag changes.
            if len(current_memory_relabel_map):
                # Fixed header: remaining-count title + "Apply Selected" button.
                with gr.Group(elem_classes="header"):
                    title = gr.HTML(
                        f"<h1 style='display: inline-block; position: fixed; z-index: 1000; left: 36px; top: 14px;'>Suggested Label Updates: {len(current_memory_relabel_map)}</h1>"
                    )
                    apply_selected_button = gr.Button("Apply Selected", elem_classes="button")
                    apply_selected_button.click(
                        apply_selected,
                        inputs=[memory_relabel_map],
                        outputs=[memory_relabel_map, title],
                        show_progress="full",
                    )
                # Column headers row, including the "approve all" checkbox.
                with gr.Row(equal_height=True, variant="panel", elem_classes="margin-top"):
                    with gr.Column(scale=9):
                        gr.Markdown("**Value**")
                    with gr.Column(scale=2, min_width=90):
                        gr.Markdown("**Current Label**")
                    with gr.Column(scale=3, min_width=150):
                        gr.Markdown("**Suggested Label**", elem_classes="centered")
                    with gr.Column(scale=2, min_width=50):
                        approve_all_checkbox = gr.Checkbox(
                            show_label=False,
                            value=current_all_approved,
                            label="",
                            container=False,
                            elem_classes="centered",
                        )
                        approve_all_checkbox.change(
                            approve_all,
                            inputs=[memory_relabel_map, approve_all_checkbox],
                            outputs=[memory_relabel_map, all_approved],
                        )
                # One row per suggested relabel.
                for i, memory_relabel in enumerate(current_memory_relabel_map.values()):
                    mem = memory_relabel["full_memory"]
                    predicted_label = mem.metrics["neighbor_predicted_label"]
                    predicted_label_name = label_names[predicted_label]
                    predicted_label_confidence = mem.metrics.get("neighbor_predicted_label_confidence", 0)

                    with gr.Row(equal_height=True, variant="panel"):
                        with gr.Column(scale=9):
                            assert isinstance(mem.value, str)
                            gr.Markdown(mem.value, label="Value", height=50)
                        with gr.Column(scale=2, min_width=90):
                            gr.Markdown(f"{mem.label_name} ({mem.label})", label="Current Label", height=50)
                        with gr.Column(scale=3, min_width=150):
                            dropdown = gr.Dropdown(
                                choices=[f"{label_name} ({i})" for i, label_name in enumerate(label_names)],
                                label="SuggestedLabel",
                                value=f"{predicted_label_name} ({predicted_label})",
                                interactive=True,
                                container=False,
                            )
                            confidence = gr.HTML(
                                f"<p style='font-size: 10px; color: #888;'>Confidence: {predicted_label_confidence:.2f}</p>",
                                elem_classes="no-padding",
                            )
                            # mem_id is bound as a lambda default to avoid the
                            # late-binding closure pitfall inside this loop.
                            dropdown.change(
                                lambda val, map, mem_id=mem.memory_id: update_label(mem_id, val, map),
                                inputs=[dropdown, memory_relabel_map],
                                outputs=[confidence, memory_relabel_map],
                            )
                        with gr.Column(scale=2, min_width=50):
                            checkbox = gr.Checkbox(
                                show_label=False,
                                label="",
                                value=current_memory_relabel_map[mem.memory_id]["approved"],
                                container=False,
                                elem_classes="centered",
                                interactive=True,
                            )
                            checkbox.input(
                                lambda selected, map, mem_id=mem.memory_id: update_approved(mem_id, selected, map),
                                inputs=[checkbox, memory_relabel_map],
                                outputs=[memory_relabel_map],
                            )

            else:
                gr.HTML("<h1>No suggested label updates</h1>")

    demo.launch()
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
/* Stylesheet for the suggested-relabel review UI (analysis_ui.py). */

/* Center checkbox inputs and text inside .centered cells. */
.centered input {
    margin: auto;
}
.centered p {
    text-align: center;
}
/* "Apply Selected" button: pinned to the top-right of the viewport. */
.button {
    display: inline-block;
    max-width: 250px;
    background-color: #2b9a66;
    color: white;
    position: fixed;
    z-index: 1000;
    right: 36px;
    border-radius: 8px;
    top: 12px;
}
/* Push the first table row below the fixed header. */
.margin-top {
    margin-top: 60px;
}
/* Fixed header bar holding the title and apply button. */
.header {
    position: fixed;
    z-index: 1000;
    height: 64px;
    left: 0;
    top: 0;
    border-radius: 0;
}

/* Brand-green checkboxes in all checked/focus states. */
input[type='checkbox']:checked,
input[type='checkbox']:checked:hover,
input[type='checkbox']:checked:focus {
    background-color: #2b9a66;
    border-color: #2b9a66;
}
input[type='checkbox']:focus {
    border-color: #2b9a66;
}
/* Remove padding around the small confidence hint under each dropdown. */
.html-container:has(.no-padding) {
    padding: 0;
}

/* Brand-green progress bar shown while applying updates. */
.progress-bar {
    background-color: #2b9a66;
}
/* Keep the full-width progress overlay pinned above the header. */
.header .full {
    position: fixed !important;
    z-index: 1100;
    background-color: #e4e4e7;
    height: 68px;
}
|
orca_sdk/_utils/auth.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"""This module contains internal utils for managing api keys in tests"""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
from typing import List, Literal
|
|
6
|
+
|
|
7
|
+
from dotenv import load_dotenv
|
|
8
|
+
|
|
9
|
+
from ..client import ApiKeyMetadata, OrcaClient
|
|
10
|
+
from .common import DropMode
|
|
11
|
+
|
|
12
|
+
load_dotenv() # this needs to be here to ensure env is populated before accessing it
|
|
13
|
+
|
|
14
|
+
# the defaults here must match nautilus and lighthouse config defaults
|
|
15
|
+
_ORCA_ROOT_ACCESS_API_KEY = os.environ.get("ORCA_ROOT_ACCESS_API_KEY", "00000000-0000-0000-0000-000000000000")
|
|
16
|
+
_DEFAULT_ORG_ID = os.environ.get("DEFAULT_ORG_ID", "10e50000-0000-4000-a000-a78dca14af3a")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _create_api_key(org_id: str, name: str, scopes: list[Literal["ADMINISTER", "PREDICT"]] = ["ADMINISTER"]) -> str:
    """Creates an API key for the given organization.

    Params:
        org_id: Organization to create the key under.
        name: Human-readable name for the key.
        scopes: Scopes granted to the key; defaults to full admin.
            NOTE(review): mutable default argument — safe only while no caller mutates it.

    Returns:
        The newly created API key string.
    """
    client = OrcaClient._resolve_client()
    # Authenticates with the root-access key, acting on behalf of the target org.
    response = client.POST(
        "/auth/api_key",
        json={"name": name, "scope": scopes},
        headers={"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id},
    )
    return response["api_key"]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _list_api_keys(org_id: str) -> List[ApiKeyMetadata]:
    """Lists all API keys for the given organization.

    Uses the root-access key so it can enumerate keys for any org (test helper only).
    """
    client = OrcaClient._resolve_client()
    return client.GET("/auth/api_key", headers={"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id})
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _delete_api_key(org_id: str, name: str, if_not_exists: DropMode = "error") -> None:
    """Deletes the API key with the given name from the organization.

    Params:
        org_id: Organization the key belongs to.
        name: Name (or id) of the key to delete.
        if_not_exists: "error" re-raises when the key is missing; "ignore" swallows it.

    Raises:
        LookupError: If the key does not exist and if_not_exists == "error".
    """
    try:
        client = OrcaClient._resolve_client()
        client.DELETE(
            "/auth/api_key/{name_or_id}",
            params={"name_or_id": name},
            headers={"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id},
        )
    except LookupError:
        # Only a missing key is tolerated; all other errors propagate.
        if if_not_exists == "error":
            raise
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _delete_org(org_id: str) -> None:
    """Deletes the organization (test helper — authenticates with the root-access key)."""
    client = OrcaClient._resolve_client()
    client.DELETE("/auth/org", headers={"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id})
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _authenticate_local_api(org_id: str = _DEFAULT_ORG_ID, api_key_name: str = "local") -> None:
    """Connect to the local API at http://localhost:1584/ and authenticate with a new API key.

    Deletes any stale key with the same name first so repeated calls are idempotent,
    then points the shared client at localhost and installs a freshly minted key.
    """
    _delete_api_key(org_id, api_key_name, if_not_exists="ignore")
    client = OrcaClient._resolve_client()
    # Mutates the shared client in place: every subsequent call targets the local API.
    client.base_url = "http://localhost:1584"
    client.headers.update({"Api-Key": _create_api_key(org_id, api_key_name)})
    logging.info(f"Authenticated against local API at 'http://localhost:1584' with '{api_key_name}' API key")
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
__all__ = ["_create_api_key", "_delete_api_key", "_delete_org", "_list_api_keys", "_authenticate_local_api"]
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from uuid import uuid4
|
|
2
|
+
|
|
3
|
+
from ..credentials import OrcaCredentials
|
|
4
|
+
from .auth import _create_api_key, _delete_api_key, _delete_org, _list_api_keys
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def test_list_api_keys(org_id):
    """The org fixture always has at least one key, so listing must be non-empty."""
    assert len(_list_api_keys(org_id)) >= 1
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def test_create_api_key(org_id):
    """A newly created key is returned and appears in the credentials listing."""
    # Random suffix avoids collisions between test runs against a shared backend.
    name = f"test-{uuid4().hex[:8]}"
    api_key = _create_api_key(org_id=org_id, name=name)
    assert api_key is not None
    assert name in [api_key.name for api_key in OrcaCredentials.list_api_keys()]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def test_delete_api_key(org_id):
    """Deleting a key removes it from the credentials listing."""
    name = f"test-{uuid4().hex[:8]}"
    # Arrange: create a key and confirm it is visible.
    api_key = _create_api_key(org_id=org_id, name=name)
    assert api_key is not None
    assert name in [api_key.name for api_key in OrcaCredentials.list_api_keys()]
    # Act + assert: delete it and confirm it is gone.
    _delete_api_key(org_id=org_id, name=name)
    assert name not in [api_key.name for api_key in OrcaCredentials.list_api_keys()]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def test_delete_org(other_org_id):
    """Deleting an org removes all of its API keys."""
    # Uses a separate org fixture so the main test org is not destroyed.
    _create_api_key(org_id=other_org_id, name="test")
    assert len(_list_api_keys(other_org_id)) >= 1
    _delete_org(other_org_id)
    assert len(_list_api_keys(other_org_id)) == 0
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from typing import Any, Literal
|
|
2
|
+
|
|
3
|
+
CreateMode = Literal["error", "open"]
|
|
4
|
+
"""
|
|
5
|
+
Mode for creating a resource.
|
|
6
|
+
|
|
7
|
+
**Options:**
|
|
8
|
+
|
|
9
|
+
- `"error"`: raise an error if a resource with the same name already exists
|
|
10
|
+
- `"open"`: open the resource with the same name if it exists
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
DropMode = Literal["error", "ignore"]
|
|
14
|
+
"""
|
|
15
|
+
Mode for deleting a resource.
|
|
16
|
+
|
|
17
|
+
**Options:**
|
|
18
|
+
|
|
19
|
+
- `"error"`: raise an error if the resource does not exist
|
|
20
|
+
- `"ignore"`: do nothing if the resource does not exist
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class _UnsetSentinel:
|
|
25
|
+
"""See corresponding class in orcalib.pydantic_utils"""
|
|
26
|
+
|
|
27
|
+
def __bool__(self) -> bool:
|
|
28
|
+
return False
|
|
29
|
+
|
|
30
|
+
def __repr__(self) -> str:
|
|
31
|
+
return "UNSET"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
UNSET: Any = _UnsetSentinel()
|
|
35
|
+
"""
|
|
36
|
+
Default value to indicate that no update should be applied to a field and it should not be set to None
|
|
37
|
+
"""
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
import pickle
|
|
2
|
+
from dataclasses import asdict, is_dataclass
|
|
3
|
+
from os import PathLike
|
|
4
|
+
from typing import Any, cast
|
|
5
|
+
|
|
6
|
+
from datasets import Dataset
|
|
7
|
+
from datasets.exceptions import DatasetGenerationError
|
|
8
|
+
from torch.utils.data import DataLoader as TorchDataLoader
|
|
9
|
+
from torch.utils.data import Dataset as TorchDataset
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def parse_dict_like(item: Any, column_names: list[str] | None = None) -> dict:
|
|
13
|
+
if isinstance(item, dict):
|
|
14
|
+
return item
|
|
15
|
+
|
|
16
|
+
if isinstance(item, tuple):
|
|
17
|
+
if column_names is not None:
|
|
18
|
+
if len(item) != len(column_names):
|
|
19
|
+
raise ValueError(
|
|
20
|
+
f"Tuple length ({len(item)}) does not match number of column names ({len(column_names)})"
|
|
21
|
+
)
|
|
22
|
+
return {column_names[i]: item[i] for i in range(len(item))}
|
|
23
|
+
elif hasattr(item, "_fields") and all(isinstance(field, str) for field in item._fields): # type: ignore
|
|
24
|
+
return {field: getattr(item, field) for field in item._fields} # type: ignore
|
|
25
|
+
else:
|
|
26
|
+
raise ValueError("For datasets that return unnamed tuples, please provide column_names argument")
|
|
27
|
+
|
|
28
|
+
if is_dataclass(item) and not isinstance(item, type):
|
|
29
|
+
return asdict(item)
|
|
30
|
+
|
|
31
|
+
raise ValueError(f"Cannot parse {type(item)}")
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def parse_batch(batch: Any, column_names: list[str] | None = None) -> list[dict]:
    """Normalize a batch into a list of row dicts.

    A list batch is treated as row-oriented (one record per element); anything
    else is treated as column-oriented (one equal-length sequence per column)
    and transposed into rows.

    Raises:
        ValueError: If column sequences have mismatched lengths, or an element
            cannot be parsed by `parse_dict_like`.
    """
    if isinstance(batch, list):
        # Row-oriented: parse each record independently.
        return [parse_dict_like(element, column_names) for element in batch]

    # Column-oriented: first normalize the container, then transpose.
    columns = parse_dict_like(batch, column_names)
    names = list(columns.keys())
    expected_len = len(columns[names[0]])
    for name in names:
        if not len(columns[name]) == expected_len:
            raise ValueError(f"Batch must consist of values of the same length, but {name} has length {len(columns[name])}")
    return [{name: columns[name][row] for name in names} for row in range(expected_len)]
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def hf_dataset_from_torch(
    torch_data: TorchDataLoader | TorchDataset,
    column_names: list[str] | None = None,
) -> Dataset:
    """
    Create a HuggingFace Dataset from a PyTorch DataLoader or Dataset.

    NOTE(review): it's important to ignore cached files when testing (i.e., ignore_cache=True),
    because cached results can ignore changes you've made to tests. This can make a test appear
    to succeed when it's actually broken or vice versa. (This note references an `ignore_cache`
    flag that is not a parameter of this function — presumably it refers to test fixtures;
    confirm and relocate if so.)

    Params:
        torch_data: A PyTorch DataLoader or Dataset object to create the HuggingFace Dataset from.
        column_names: Optional list of column names to use for the dataset. If not provided,
            the column names will be inferred from the data.
    Returns:
        A HuggingFace Dataset object containing the data from the PyTorch DataLoader or Dataset.
    Raises:
        DatasetGenerationError: If a batch cannot be normalized into rows by `parse_batch`.
    """
    if isinstance(torch_data, TorchDataLoader):
        dataloader = torch_data
    else:
        # Wrap a bare Dataset; identity collate keeps each item unmodified.
        dataloader = TorchDataLoader(torch_data, batch_size=1, collate_fn=lambda x: x)

    # Collect data from the dataloader into a list to avoid serialization issues
    # with Dataset.from_generator in Python 3.14 (see datasets issue #7839)
    data_list = []
    try:
        for batch in dataloader:
            data_list.extend(parse_batch(batch, column_names=column_names))
    except ValueError as e:
        # Surface parsing problems with the exception type callers of datasets expect.
        raise DatasetGenerationError(str(e)) from e

    ds = Dataset.from_list(data_list)

    # Defensive check that from_list produced a Dataset and not some other wrapper.
    if not isinstance(ds, Dataset):
        raise ValueError(f"Failed to create dataset from list: {type(ds)}")
    return ds
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def hf_dataset_from_disk(file_path: str | PathLike) -> Dataset:
    """
    Load a dataset from disk into a HuggingFace Dataset object.

    Params:
        file_path: Path to the file on disk to create the memoryset from. The file type will
            be inferred from the file extension. The following file types are supported:

            - .pkl: [`Pickle`][pickle] files containing lists of dictionaries or dictionaries of columns
            - .json/.jsonl: [`JSON`][json] and JSON Lines files
            - .csv: [`CSV`][csv] files
            - .parquet: [`Parquet`](https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetFile.html#pyarrow.parquet.ParquetFile) files
            - dataset directory: Directory containing a saved HuggingFace [`Dataset`][datasets.Dataset]

    Returns:
        A HuggingFace Dataset object containing the loaded data.

    Raises:
        [`ValueError`][ValueError]: If the pickle file contains unsupported data types or if
            loading the dataset fails for any reason.
    """
    path_str = str(file_path)
    if path_str.endswith(".pkl"):
        # Fix: use a context manager so the file handle is closed (the original
        # `pickle.load(open(...))` leaked it).
        # SECURITY: pickle.load executes arbitrary code — only load trusted files.
        with open(file_path, "rb") as f:
            data = pickle.load(f)
        if isinstance(data, list):
            return Dataset.from_list(data)
        elif isinstance(data, dict):
            return Dataset.from_dict(data)
        else:
            raise ValueError(f"Unsupported pickle file: {file_path}")
    elif path_str.endswith((".json", ".jsonl")):
        # from_json handles both JSON and JSON Lines input.
        hf_dataset = Dataset.from_json(file_path)
    elif path_str.endswith(".csv"):
        hf_dataset = Dataset.from_csv(file_path)
    elif path_str.endswith(".parquet"):
        hf_dataset = Dataset.from_parquet(file_path)
    else:
        # No recognized extension: assume a saved HuggingFace dataset directory.
        try:
            hf_dataset = Dataset.load_from_disk(file_path)
        except Exception as e:
            # Fix: chain the original exception so the root cause is preserved.
            raise ValueError(f"Failed to load dataset from disk: {e}") from e

    return cast(Dataset, hf_dataset)
|