orca-sdk 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +30 -0
- orca_sdk/_shared/__init__.py +10 -0
- orca_sdk/_shared/metrics.py +634 -0
- orca_sdk/_shared/metrics_test.py +570 -0
- orca_sdk/_utils/__init__.py +0 -0
- orca_sdk/_utils/analysis_ui.py +196 -0
- orca_sdk/_utils/analysis_ui_style.css +51 -0
- orca_sdk/_utils/auth.py +65 -0
- orca_sdk/_utils/auth_test.py +31 -0
- orca_sdk/_utils/common.py +37 -0
- orca_sdk/_utils/data_parsing.py +129 -0
- orca_sdk/_utils/data_parsing_test.py +244 -0
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/_utils/prediction_result_ui.css +18 -0
- orca_sdk/_utils/prediction_result_ui.py +110 -0
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/_utils/value_parser.py +45 -0
- orca_sdk/_utils/value_parser_test.py +39 -0
- orca_sdk/async_client.py +4104 -0
- orca_sdk/classification_model.py +1165 -0
- orca_sdk/classification_model_test.py +887 -0
- orca_sdk/client.py +4096 -0
- orca_sdk/conftest.py +382 -0
- orca_sdk/credentials.py +217 -0
- orca_sdk/credentials_test.py +121 -0
- orca_sdk/datasource.py +576 -0
- orca_sdk/datasource_test.py +463 -0
- orca_sdk/embedding_model.py +712 -0
- orca_sdk/embedding_model_test.py +206 -0
- orca_sdk/job.py +343 -0
- orca_sdk/job_test.py +108 -0
- orca_sdk/memoryset.py +3811 -0
- orca_sdk/memoryset_test.py +1150 -0
- orca_sdk/regression_model.py +841 -0
- orca_sdk/regression_model_test.py +595 -0
- orca_sdk/telemetry.py +742 -0
- orca_sdk/telemetry_test.py +119 -0
- orca_sdk-0.1.9.dist-info/METADATA +98 -0
- orca_sdk-0.1.9.dist-info/RECORD +41 -0
- orca_sdk-0.1.9.dist-info/WHEEL +4 -0
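Of the files above, only orca_sdk/datasource_test.py (+463) is shown in full below. It exercises the Datasource class end to end: creation from Hugging Face datasets, lists, dicts, pandas DataFrames, pyarrow Tables, and files on disk; the open/all/exists/drop lifecycle; query with pagination, shuffling, and filters; and downloads as a dataset directory, JSON, or CSV. As a quick orientation, a round trip suggested by the test calls looks roughly like this (a sketch inferred from the tests, not official documentation; the top-level import assumes orca_sdk/__init__.py re-exports Datasource):

from orca_sdk import Datasource  # assumption: re-exported at package level

# Create (or reopen) a named datasource from column data
ds = Datasource.from_dict("my_datasource", {"text": ["a", "b"], "label": [0, 1]}, if_exists="open")

# Query with filters and pagination
rows = ds.query(filters=[("label", "==", 1)], limit=10)

# Clean up, ignoring the error if the datasource is already gone
Datasource.drop("my_datasource", if_not_exists="ignore")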
orca_sdk/datasource_test.py (new file)

@@ -0,0 +1,463 @@
import json
import os
import tempfile
from typing import cast
from uuid import uuid4

import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from datasets import Dataset

from .datasource import Datasource


def test_create_datasource(datasource, hf_dataset):
    assert datasource is not None
    assert datasource.name == "test_datasource"
    assert datasource.length == len(hf_dataset)


def test_create_datasource_unauthenticated(unauthenticated_client, hf_dataset):
    with unauthenticated_client.use():
        with pytest.raises(ValueError, match="Invalid API key"):
            Datasource.from_hf_dataset("test_datasource", hf_dataset)


def test_create_datasource_already_exists_error(hf_dataset, datasource):
    with pytest.raises(ValueError):
        Datasource.from_hf_dataset("test_datasource", hf_dataset, if_exists="error")


def test_create_datasource_already_exists_return(hf_dataset, datasource):
    returned_dataset = Datasource.from_hf_dataset("test_datasource", hf_dataset, if_exists="open")
    assert returned_dataset is not None
    assert returned_dataset.name == "test_datasource"
    assert returned_dataset.length == len(hf_dataset)


def test_open_datasource(datasource):
    fetched_datasource = Datasource.open(datasource.name)
    assert fetched_datasource is not None
    assert fetched_datasource.name == datasource.name
    assert fetched_datasource.length == len(datasource)


def test_open_datasource_unauthenticated(unauthenticated_client, datasource):
    with unauthenticated_client.use():
        with pytest.raises(ValueError, match="Invalid API key"):
            Datasource.open("test_datasource")


def test_open_datasource_invalid_input():
    with pytest.raises(ValueError, match=r"Invalid input:.*"):
        Datasource.open("not valid id")


def test_open_datasource_not_found():
    with pytest.raises(LookupError):
        Datasource.open(str(uuid4()))


def test_open_datasource_unauthorized(unauthorized_client, datasource):
    with unauthorized_client.use():
        with pytest.raises(LookupError):
            Datasource.open(datasource.id)


def test_all_datasources(datasource):
    datasources = Datasource.all()
    assert len(datasources) > 0
    # Compare against the fixture's name; the original shadowed `datasource`
    # in the generator, making the assertion trivially true
    assert any(ds.name == datasource.name for ds in datasources)


def test_all_datasources_unauthenticated(unauthenticated_client):
    with unauthenticated_client.use():
        with pytest.raises(ValueError, match="Invalid API key"):
            Datasource.all()


def test_drop_datasource(hf_dataset):
    Datasource.from_hf_dataset("datasource_to_delete", hf_dataset)
    assert Datasource.exists("datasource_to_delete")
    Datasource.drop("datasource_to_delete")
    assert not Datasource.exists("datasource_to_delete")


def test_drop_datasource_unauthenticated(datasource, unauthenticated_client):
    with unauthenticated_client.use():
        with pytest.raises(ValueError, match="Invalid API key"):
            Datasource.drop(datasource.id)


def test_drop_datasource_not_found():
    with pytest.raises(LookupError):
        Datasource.drop(str(uuid4()))
    # ignores the error if specified
    Datasource.drop(str(uuid4()), if_not_exists="ignore")


def test_drop_datasource_unauthorized(datasource, unauthorized_client):
    with unauthorized_client.use():
        with pytest.raises(LookupError):
            Datasource.drop(datasource.id)


def test_drop_datasource_invalid_input():
    with pytest.raises(ValueError, match=r"Invalid input:.*"):
        Datasource.drop("not valid id")


def test_from_list():
    # Test creating a datasource from a list of dictionaries
    data = [
        {"column1": 1, "column2": "a"},
        {"column1": 2, "column2": "b"},
        {"column1": 3, "column2": "c"},
    ]
    datasource = Datasource.from_list(f"test_list_{uuid4()}", data)
    assert datasource.name.startswith("test_list_")
    assert datasource.length == 3
    assert "column1" in datasource.columns
    assert "column2" in datasource.columns


def test_from_dict():
    # Test creating a datasource from a dictionary of columns
    data = {
        "column1": [1, 2, 3],
        "column2": ["a", "b", "c"],
    }
    datasource = Datasource.from_dict(f"test_dict_{uuid4()}", data)
    assert datasource.name.startswith("test_dict_")
    assert datasource.length == 3
    assert "column1" in datasource.columns
    assert "column2" in datasource.columns


def test_from_pandas():
    # Test creating a datasource from a pandas DataFrame
    df = pd.DataFrame(
        {
            "column1": [1, 2, 3],
            "column2": ["a", "b", "c"],
        }
    )
    datasource = Datasource.from_pandas(f"test_pandas_{uuid4()}", df)
    assert datasource.name.startswith("test_pandas_")
    assert datasource.length == 3
    assert "column1" in datasource.columns
    assert "column2" in datasource.columns


def test_from_arrow():
    # Test creating a datasource from a pyarrow Table
    table = pa.table(
        {
            "column1": [1, 2, 3],
            "column2": ["a", "b", "c"],
        }
    )
    datasource = Datasource.from_arrow(f"test_arrow_{uuid4()}", table)
    assert datasource.name.startswith("test_arrow_")
    assert datasource.length == 3
    assert "column1" in datasource.columns
    assert "column2" in datasource.columns


def test_from_list_already_exists():
    # Test the if_exists parameter with from_list
    data = [{"column1": 1, "column2": "a"}]
    name = f"test_list_exists_{uuid4()}"

    # Create the first datasource
    datasource1 = Datasource.from_list(name, data)
    assert datasource1.length == 1

    # Try to create again with if_exists="error" (should raise)
    with pytest.raises(ValueError):
        Datasource.from_list(name, data, if_exists="error")

    # Try to create again with if_exists="open" (should return existing)
    datasource2 = Datasource.from_list(name, data, if_exists="open")
    assert datasource2.id == datasource1.id
    assert datasource2.name == datasource1.name


def test_from_dict_already_exists():
    # Test the if_exists parameter with from_dict
    data = {"column1": [1], "column2": ["a"]}
    name = f"test_dict_exists_{uuid4()}"

    # Create the first datasource
    datasource1 = Datasource.from_dict(name, data)
    assert datasource1.length == 1

    # Try to create again with if_exists="error" (should raise)
    with pytest.raises(ValueError):
        Datasource.from_dict(name, data, if_exists="error")

    # Try to create again with if_exists="open" (should return existing)
    datasource2 = Datasource.from_dict(name, data, if_exists="open")
    assert datasource2.id == datasource1.id
    assert datasource2.name == datasource1.name


def test_from_pandas_already_exists():
    # Test the if_exists parameter with from_pandas
    df = pd.DataFrame({"column1": [1], "column2": ["a"]})
    name = f"test_pandas_exists_{uuid4()}"

    # Create the first datasource
    datasource1 = Datasource.from_pandas(name, df)
    assert datasource1.length == 1

    # Try to create again with if_exists="error" (should raise)
    with pytest.raises(ValueError):
        Datasource.from_pandas(name, df, if_exists="error")

    # Try to create again with if_exists="open" (should return existing)
    datasource2 = Datasource.from_pandas(name, df, if_exists="open")
    assert datasource2.id == datasource1.id
    assert datasource2.name == datasource1.name


def test_from_arrow_already_exists():
    # Test the if_exists parameter with from_arrow
    table = pa.table({"column1": [1], "column2": ["a"]})
    name = f"test_arrow_exists_{uuid4()}"

    # Create the first datasource
    datasource1 = Datasource.from_arrow(name, table)
    assert datasource1.length == 1

    # Try to create again with if_exists="error" (should raise)
    with pytest.raises(ValueError):
        Datasource.from_arrow(name, table, if_exists="error")

    # Try to create again with if_exists="open" (should return existing)
    datasource2 = Datasource.from_arrow(name, table, if_exists="open")
    assert datasource2.id == datasource1.id
    assert datasource2.name == datasource1.name


def test_from_disk_csv():
    # Test creating a datasource from a CSV file
    with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
        f.write("column1,column2\n1,a\n2,b\n3,c")
        f.flush()

    try:
        datasource = Datasource.from_disk(f"test_csv_{uuid4()}", f.name)
        assert datasource.length == 3
        assert "column1" in datasource.columns
        assert "column2" in datasource.columns
    finally:
        os.unlink(f.name)


def test_from_disk_json():
    # Test creating a datasource from a JSON file (json is already imported
    # at module level; the original re-imported it here)
    data = [{"column1": 1, "column2": "a"}, {"column1": 2, "column2": "b"}]

    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
        json.dump(data, f)
        f.flush()

    try:
        datasource = Datasource.from_disk(f"test_json_{uuid4()}", f.name)
        assert datasource.length == 2
        assert "column1" in datasource.columns
        assert "column2" in datasource.columns
    finally:
        os.unlink(f.name)


def test_from_disk_already_exists():
    # Test the if_exists parameter with from_disk
    with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
        f.write("column1,column2\n1,a")
        f.flush()

    try:
        name = f"test_disk_exists_{uuid4()}"

        # Create the first datasource
        datasource1 = Datasource.from_disk(name, f.name)
        assert datasource1.length == 1

        # Try to create again with if_exists="error" (should raise)
        with pytest.raises(ValueError):
            Datasource.from_disk(name, f.name, if_exists="error")

        # Try to create again with if_exists="open" (should return existing)
        datasource2 = Datasource.from_disk(name, f.name, if_exists="open")
        assert datasource2.id == datasource1.id
        assert datasource2.name == datasource1.name
    finally:
        os.unlink(f.name)


def test_query_datasource_rows():
    """Test querying rows from a datasource with pagination and shuffle."""
    # Create a new datasource with 5 entries for testing
    test_data = [{"id": i, "name": f"item_{i}"} for i in range(5)]
    datasource = Datasource.from_list(name="test_query_datasource", data=test_data)

    # Test basic query
    rows = datasource.query(limit=3)
    assert len(rows) == 3
    assert all(isinstance(row, dict) for row in rows)

    # Test offset
    offset_rows = datasource.query(offset=2, limit=2)
    assert len(offset_rows) == 2
    assert offset_rows[0]["id"] == 2

    # Test shuffle
    shuffled_rows = datasource.query(limit=5, shuffle=True)
    assert len(shuffled_rows) == 5
    # Note: can spuriously fail in the rare case the shuffle returns the original order
    assert not all(row["id"] == i for i, row in enumerate(shuffled_rows))

    # Test that shuffling with the same seed is deterministic
    assert datasource.query(limit=5, shuffle=True, shuffle_seed=42) == datasource.query(
        limit=5, shuffle=True, shuffle_seed=42
    )


def test_query_datasource_with_filters():
    """Test querying datasource rows with various filter operators."""
    # Create a datasource with test data
    test_data = [
        {"name": "Alice", "age": 25, "city": "New York", "score": 85.5},
        {"name": "Bob", "age": 30, "city": "San Francisco", "score": 90.0},
        {"name": "Charlie", "age": 35, "city": "Chicago", "score": 75.5},
        {"name": "Diana", "age": 28, "city": "Boston", "score": 88.0},
        {"name": "Eve", "age": 32, "city": "New York", "score": 92.0},
    ]
    datasource = Datasource.from_list(name=f"test_filter_datasource_{uuid4()}", data=test_data)

    # Test == operator
    rows = datasource.query(filters=[("city", "==", "New York")])
    assert len(rows) == 2
    assert all(row["city"] == "New York" for row in rows)

    # Test > operator
    rows = datasource.query(filters=[("age", ">", 30)])
    assert len(rows) == 2
    assert all(row["age"] > 30 for row in rows)

    # Test >= operator
    rows = datasource.query(filters=[("score", ">=", 88.0)])
    assert len(rows) == 3
    assert all(row["score"] >= 88.0 for row in rows)

    # Test < operator
    rows = datasource.query(filters=[("age", "<", 30)])
    assert len(rows) == 2
    assert all(row["age"] < 30 for row in rows)

    # Test in operator
    rows = datasource.query(filters=[("city", "in", ["New York", "Boston"])])
    assert len(rows) == 3
    assert all(row["city"] in ["New York", "Boston"] for row in rows)

    # Test not in operator
    rows = datasource.query(filters=[("city", "not in", ["New York", "Boston"])])
    assert len(rows) == 2
    assert all(row["city"] not in ["New York", "Boston"] for row in rows)

    # Test like operator
    rows = datasource.query(filters=[("name", "like", "li")])
    assert len(rows) == 2
    assert all("li" in row["name"].lower() for row in rows)

    # Test multiple filters (AND logic)
    rows = datasource.query(filters=[("city", "==", "New York"), ("age", ">", 26)])
    assert len(rows) == 1
    assert rows[0]["name"] == "Eve"

    # Test filter combined with pagination
    rows = datasource.query(filters=[("age", ">=", 28)], limit=2, offset=1)
    assert len(rows) == 2


def test_query_datasource_with_none_filters():
    """Test filtering for None values."""
    test_data = [
        {"name": "Alice", "age": 25, "label": "A"},
        {"name": "Bob", "age": 30, "label": None},
        {"name": "Charlie", "age": 35, "label": "C"},
        {"name": "Diana", "age": None, "label": "D"},
        {"name": "Eve", "age": 32, "label": None},
    ]
    datasource = Datasource.from_list(name=f"test_none_filter_{uuid4()}", data=test_data)

    # Test == None
    rows = datasource.query(filters=[("label", "==", None)])
    assert len(rows) == 2
    assert all(row["label"] is None for row in rows)

    # Test != None
    rows = datasource.query(filters=[("label", "!=", None)])
    assert len(rows) == 3
    assert all(row["label"] is not None for row in rows)

    # Test that None values are excluded by comparison operators
    rows = datasource.query(filters=[("age", ">", 25)])
    assert len(rows) == 3
    assert all(row["age"] is not None and row["age"] > 25 for row in rows)


def test_query_datasource_filter_invalid_column():
    """Test that querying with an invalid column raises an error."""
    test_data = [{"name": "Alice", "age": 25}]
    datasource = Datasource.from_list(name=f"test_invalid_filter_{uuid4()}", data=test_data)

    with pytest.raises(ValueError):
        datasource.query(filters=[("invalid_column", "==", "test")])


def test_to_list(hf_dataset, datasource):
    assert datasource.to_list() == hf_dataset.to_list()


def test_download_datasource(hf_dataset, datasource):
    with tempfile.TemporaryDirectory() as temp_dir:
        # Dataset download
        datasource.download(temp_dir)
        downloaded_hf_dataset_dir = f"{temp_dir}/{datasource.name}"
        assert os.path.exists(downloaded_hf_dataset_dir)
        assert os.path.isdir(downloaded_hf_dataset_dir)
        assert not os.path.exists(f"{downloaded_hf_dataset_dir}.zip")
        dataset_from_downloaded_hf_dataset = Dataset.load_from_disk(downloaded_hf_dataset_dir)
        assert dataset_from_downloaded_hf_dataset.column_names == hf_dataset.column_names
        assert dataset_from_downloaded_hf_dataset.to_dict() == hf_dataset.to_dict()

        # JSON download
        datasource.download(temp_dir, file_type="json")
        downloaded_json_file = f"{temp_dir}/{datasource.name}.json"
        assert os.path.exists(downloaded_json_file)
        with open(downloaded_json_file, "r") as f:
            content = json.load(f)
        assert content == hf_dataset.to_list()

        # CSV download
        datasource.download(temp_dir, file_type="csv")
        downloaded_csv_file = f"{temp_dir}/{datasource.name}.csv"
        assert os.path.exists(downloaded_csv_file)
        dataset_from_downloaded_csv = cast(Dataset, Dataset.from_csv(downloaded_csv_file))
        assert dataset_from_downloaded_csv.column_names == hf_dataset.column_names
        assert (
            dataset_from_downloaded_csv.remove_columns("score").to_dict()
            == hf_dataset.remove_columns("score").to_dict()
        )
        # Replace None with NaN for comparison, since the CSV round trip
        # turns missing values into NaN
        assert np.allclose(
            np.array([np.nan if v is None else float(v) for v in dataset_from_downloaded_csv["score"]], dtype=float),
            np.array([np.nan if v is None else float(v) for v in hf_dataset["score"]], dtype=float),
            equal_nan=True,
        )
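The fixtures used above (datasource, hf_dataset, unauthenticated_client, unauthorized_client) are not defined in this file; they presumably come from orca_sdk/conftest.py (+382 in this release). A minimal sketch of what the two data fixtures might look like, purely hypothetical since conftest.py is not shown here:

from typing import Iterator

import pytest
from datasets import Dataset

from .datasource import Datasource


@pytest.fixture
def hf_dataset() -> Dataset:
    # Hypothetical: a small dataset with the float "score" column that
    # test_download_datasource compares NaN-aware after the CSV round trip
    return Dataset.from_dict({"text": ["x", "y", "z"], "score": [1.0, None, 3.0]})


@pytest.fixture
def datasource(hf_dataset) -> Iterator[Datasource]:
    # Hypothetical: the shared "test_datasource" the tests open by name
    ds = Datasource.from_hf_dataset("test_datasource", hf_dataset, if_exists="open")
    yield ds
    Datasource.drop("test_datasource", if_not_exists="ignore")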