orca-sdk 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +10 -4
- orca_sdk/_shared/__init__.py +10 -0
- orca_sdk/_shared/metrics.py +393 -0
- orca_sdk/_shared/metrics_test.py +273 -0
- orca_sdk/_utils/analysis_ui.py +12 -10
- orca_sdk/_utils/analysis_ui_style.css +0 -3
- orca_sdk/_utils/auth.py +27 -29
- orca_sdk/_utils/data_parsing.py +28 -2
- orca_sdk/_utils/data_parsing_test.py +15 -15
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/_utils/prediction_result_ui.py +67 -21
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/_utils/value_parser.py +45 -0
- orca_sdk/_utils/value_parser_test.py +39 -0
- orca_sdk/classification_model.py +439 -129
- orca_sdk/classification_model_test.py +334 -104
- orca_sdk/client.py +3747 -0
- orca_sdk/conftest.py +164 -19
- orca_sdk/credentials.py +120 -18
- orca_sdk/credentials_test.py +20 -0
- orca_sdk/datasource.py +259 -68
- orca_sdk/datasource_test.py +242 -0
- orca_sdk/embedding_model.py +425 -82
- orca_sdk/embedding_model_test.py +39 -13
- orca_sdk/job.py +337 -0
- orca_sdk/job_test.py +108 -0
- orca_sdk/memoryset.py +1341 -305
- orca_sdk/memoryset_test.py +350 -111
- orca_sdk/regression_model.py +684 -0
- orca_sdk/regression_model_test.py +369 -0
- orca_sdk/telemetry.py +449 -143
- orca_sdk/telemetry_test.py +43 -24
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/METADATA +34 -16
- orca_sdk-0.1.2.dist-info/RECORD +40 -0
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/WHEEL +1 -1
- orca_sdk/_generated_api_client/__init__.py +0 -3
- orca_sdk/_generated_api_client/api/__init__.py +0 -193
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
- orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
- orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
- orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
- orca_sdk/_generated_api_client/client.py +0 -216
- orca_sdk/_generated_api_client/errors.py +0 -38
- orca_sdk/_generated_api_client/models/__init__.py +0 -159
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
- orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
- orca_sdk/_generated_api_client/models/base_model.py +0 -55
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
- orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
- orca_sdk/_generated_api_client/models/column_info.py +0 -114
- orca_sdk/_generated_api_client/models/column_type.py +0 -14
- orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
- orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
- orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
- orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
- orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
- orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/embed_request.py +0 -127
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
- orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
- orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
- orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
- orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
- orca_sdk/_generated_api_client/models/filter_item.py +0 -231
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
- orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
- orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
- orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
- orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
- orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
- orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
- orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
- orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
- orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
- orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
- orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
- orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
- orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/task.py +0 -198
- orca_sdk/_generated_api_client/models/task_status.py +0 -14
- orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
- orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
- orca_sdk/_generated_api_client/py.typed +0 -1
- orca_sdk/_generated_api_client/types.py +0 -56
- orca_sdk/_utils/task.py +0 -73
- orca_sdk-0.1.1.dist-info/RECORD +0 -175
orca_sdk/datasource_test.py
CHANGED
|
@@ -1,6 +1,14 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
import tempfile
|
|
4
|
+
from typing import cast
|
|
1
5
|
from uuid import uuid4
|
|
2
6
|
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import pyarrow as pa
|
|
3
10
|
import pytest
|
|
11
|
+
from datasets import Dataset
|
|
4
12
|
|
|
5
13
|
from .datasource import Datasource
|
|
6
14
|
|
|
@@ -93,3 +101,237 @@ def test_drop_datasource_unauthorized(datasource, unauthorized):
|
|
|
93
101
|
def test_drop_datasource_invalid_input():
    """Dropping with a malformed identifier raises a ValueError mentioning "Invalid input:"."""
    bogus_identifier = "not valid id"
    with pytest.raises(ValueError, match=r"Invalid input:.*"):
        Datasource.drop(bogus_identifier)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def test_from_list():
    """Create a datasource from a list of row dictionaries and check its name, size and columns."""
    name = f"test_list_{uuid4()}"
    rows = [
        {"column1": 1, "column2": "a"},
        {"column1": 2, "column2": "b"},
        {"column1": 3, "column2": "c"},
    ]
    datasource = Datasource.from_list(name, rows)
    # Pin the exact name we asked for; a prefix check would let a mangled or
    # suffixed name slip through.
    assert datasource.name == name
    assert datasource.length == 3
    assert "column1" in datasource.columns
    assert "column2" in datasource.columns
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def test_from_dict():
    """Create a datasource from a column-oriented dictionary and check its name, size and columns."""
    name = f"test_dict_{uuid4()}"
    columns = {
        "column1": [1, 2, 3],
        "column2": ["a", "b", "c"],
    }
    datasource = Datasource.from_dict(name, columns)
    # Pin the exact name we asked for rather than only its prefix.
    assert datasource.name == name
    assert datasource.length == 3
    assert "column1" in datasource.columns
    assert "column2" in datasource.columns
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def test_from_pandas():
    """Create a datasource from a pandas DataFrame and check its name, size and columns."""
    name = f"test_pandas_{uuid4()}"
    df = pd.DataFrame(
        {
            "column1": [1, 2, 3],
            "column2": ["a", "b", "c"],
        }
    )
    datasource = Datasource.from_pandas(name, df)
    # Pin the exact name we asked for rather than only its prefix.
    assert datasource.name == name
    assert datasource.length == 3
    assert "column1" in datasource.columns
    assert "column2" in datasource.columns
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def test_from_arrow():
    """Create a datasource from a pyarrow Table and check its name, size and columns."""
    name = f"test_arrow_{uuid4()}"
    table = pa.table(
        {
            "column1": [1, 2, 3],
            "column2": ["a", "b", "c"],
        }
    )
    datasource = Datasource.from_arrow(name, table)
    # Pin the exact name we asked for rather than only its prefix.
    assert datasource.name == name
    assert datasource.length == 3
    assert "column1" in datasource.columns
    assert "column2" in datasource.columns
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def test_from_list_already_exists():
    """Exercise the ``if_exists`` parameter of ``Datasource.from_list``.

    Re-creating under an existing name must raise ``ValueError`` with
    ``if_exists="error"`` and return the existing datasource with ``if_exists="open"``.
    """
    rows = [{"column1": 1, "column2": "a"}]
    ds_name = f"test_list_exists_{uuid4()}"

    original = Datasource.from_list(ds_name, rows)
    assert original.length == 1

    with pytest.raises(ValueError):
        Datasource.from_list(ds_name, rows, if_exists="error")

    reopened = Datasource.from_list(ds_name, rows, if_exists="open")
    assert reopened.id == original.id
    assert reopened.name == original.name
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def test_from_dict_already_exists():
    """Exercise the ``if_exists`` parameter of ``Datasource.from_dict``.

    Re-creating under an existing name must raise ``ValueError`` with
    ``if_exists="error"`` and return the existing datasource with ``if_exists="open"``.
    """
    columns = {"column1": [1], "column2": ["a"]}
    ds_name = f"test_dict_exists_{uuid4()}"

    original = Datasource.from_dict(ds_name, columns)
    assert original.length == 1

    with pytest.raises(ValueError):
        Datasource.from_dict(ds_name, columns, if_exists="error")

    reopened = Datasource.from_dict(ds_name, columns, if_exists="open")
    assert reopened.id == original.id
    assert reopened.name == original.name
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def test_from_pandas_already_exists():
    """Exercise the ``if_exists`` parameter of ``Datasource.from_pandas``.

    Re-creating under an existing name must raise ``ValueError`` with
    ``if_exists="error"`` and return the existing datasource with ``if_exists="open"``.
    """
    frame = pd.DataFrame({"column1": [1], "column2": ["a"]})
    ds_name = f"test_pandas_exists_{uuid4()}"

    original = Datasource.from_pandas(ds_name, frame)
    assert original.length == 1

    with pytest.raises(ValueError):
        Datasource.from_pandas(ds_name, frame, if_exists="error")

    reopened = Datasource.from_pandas(ds_name, frame, if_exists="open")
    assert reopened.id == original.id
    assert reopened.name == original.name
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def test_from_arrow_already_exists():
    """Exercise the ``if_exists`` parameter of ``Datasource.from_arrow``.

    Re-creating under an existing name must raise ``ValueError`` with
    ``if_exists="error"`` and return the existing datasource with ``if_exists="open"``.
    """
    arrow_table = pa.table({"column1": [1], "column2": ["a"]})
    ds_name = f"test_arrow_exists_{uuid4()}"

    original = Datasource.from_arrow(ds_name, arrow_table)
    assert original.length == 1

    with pytest.raises(ValueError):
        Datasource.from_arrow(ds_name, arrow_table, if_exists="error")

    reopened = Datasource.from_arrow(ds_name, arrow_table, if_exists="open")
    assert reopened.id == original.id
    assert reopened.name == original.name
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def test_from_disk_csv():
    """Create a datasource from a CSV file on disk and check its size and columns."""
    # A temporary directory (instead of NamedTemporaryFile(delete=False) plus a
    # manual unlink that only starts after the write) guarantees the fixture file
    # is removed even if writing it fails. The inner `with` closes the file
    # before from_disk reads it, so no explicit flush is needed.
    with tempfile.TemporaryDirectory() as tmp_dir:
        csv_path = os.path.join(tmp_dir, "data.csv")
        with open(csv_path, "w", encoding="utf-8") as f:
            f.write("column1,column2\n1,a\n2,b\n3,c")

        datasource = Datasource.from_disk(f"test_csv_{uuid4()}", csv_path)
        assert datasource.length == 3
        assert "column1" in datasource.columns
        assert "column2" in datasource.columns
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def test_from_disk_json():
    """Create a datasource from a JSON file (a list of row objects) and check its size and columns."""
    # `json` is already imported at module level; the former function-local
    # re-import was redundant.
    rows = [{"column1": 1, "column2": "a"}, {"column1": 2, "column2": "b"}]

    # A temporary directory guarantees cleanup even if writing the fixture fails.
    # The inner `with` closes the file before from_disk reads it.
    with tempfile.TemporaryDirectory() as tmp_dir:
        json_path = os.path.join(tmp_dir, "data.json")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(rows, f)

        datasource = Datasource.from_disk(f"test_json_{uuid4()}", json_path)
        assert datasource.length == 2
        assert "column1" in datasource.columns
        assert "column2" in datasource.columns
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def test_from_disk_already_exists():
    """Exercise the ``if_exists`` parameter of ``Datasource.from_disk``.

    Re-creating under an existing name must raise ``ValueError`` with
    ``if_exists="error"`` and return the existing datasource with ``if_exists="open"``.
    """
    # A temporary directory guarantees the fixture file is removed even if
    # writing it fails, which the previous delete=False + later unlink did not.
    with tempfile.TemporaryDirectory() as tmp_dir:
        csv_path = os.path.join(tmp_dir, "data.csv")
        with open(csv_path, "w", encoding="utf-8") as f:
            f.write("column1,column2\n1,a")

        name = f"test_disk_exists_{uuid4()}"

        # Create the first datasource
        datasource1 = Datasource.from_disk(name, csv_path)
        assert datasource1.length == 1

        # Try to create again with if_exists="error" (should raise)
        with pytest.raises(ValueError):
            Datasource.from_disk(name, csv_path, if_exists="error")

        # Try to create again with if_exists="open" (should return existing)
        datasource2 = Datasource.from_disk(name, csv_path, if_exists="open")
        assert datasource2.id == datasource1.id
        assert datasource2.name == datasource1.name
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def test_to_list(hf_dataset, datasource):
    """``datasource.to_list()`` must equal ``hf_dataset.to_list()``.

    hf_dataset is presumably the fixture the datasource fixture was built from;
    see conftest to confirm.
    """
    actual_rows = datasource.to_list()
    expected_rows = hf_dataset.to_list()
    assert actual_rows == expected_rows
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def test_download_datasource(hf_dataset, datasource):
    """Download the datasource in each format and compare the result with ``hf_dataset``.

    Covers three cases:
    - the default download: an on-disk Hugging Face dataset directory, not a zip;
    - ``file_type="json"``: a single file containing a list of rows;
    - ``file_type="csv"``: non-``score`` columns are compared exactly and ``score``
      numerically, presumably because the CSV round trip can alter float
      representation and missing values.
    """

    def scores_as_floats(values):
        # Map None to NaN so np.allclose(..., equal_nan=True) treats missing scores as equal.
        return np.array([np.nan if v is None else float(v) for v in values], dtype=float)

    with tempfile.TemporaryDirectory() as temp_dir:
        # Default format: Hugging Face dataset directory.
        datasource.download(temp_dir)
        hf_dir = f"{temp_dir}/{datasource.name}"
        assert os.path.exists(hf_dir)
        assert os.path.isdir(hf_dir)
        assert not os.path.exists(f"{hf_dir}.zip")
        reloaded = Dataset.load_from_disk(hf_dir)
        assert reloaded.column_names == hf_dataset.column_names
        assert reloaded.to_dict() == hf_dataset.to_dict()

        # JSON format.
        datasource.download(temp_dir, file_type="json")
        json_path = f"{temp_dir}/{datasource.name}.json"
        assert os.path.exists(json_path)
        with open(json_path, "r") as fh:
            loaded_rows = json.load(fh)
        assert loaded_rows == hf_dataset.to_list()

        # CSV format.
        datasource.download(temp_dir, file_type="csv")
        csv_path = f"{temp_dir}/{datasource.name}.csv"
        assert os.path.exists(csv_path)
        from_csv = cast(Dataset, Dataset.from_csv(csv_path))
        assert from_csv.column_names == hf_dataset.column_names
        csv_without_score = from_csv.remove_columns("score").to_dict()
        expected_without_score = hf_dataset.remove_columns("score").to_dict()
        assert csv_without_score == expected_without_score
        assert np.allclose(
            scores_as_floats(from_csv["score"]),
            scores_as_floats(hf_dataset["score"]),
            equal_nan=True,
        )
|