orca-sdk 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. orca_sdk/__init__.py +10 -4
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +393 -0
  4. orca_sdk/_shared/metrics_test.py +273 -0
  5. orca_sdk/_utils/analysis_ui.py +12 -10
  6. orca_sdk/_utils/analysis_ui_style.css +0 -3
  7. orca_sdk/_utils/auth.py +31 -29
  8. orca_sdk/_utils/data_parsing.py +28 -2
  9. orca_sdk/_utils/data_parsing_test.py +15 -15
  10. orca_sdk/_utils/pagination.py +126 -0
  11. orca_sdk/_utils/pagination_test.py +132 -0
  12. orca_sdk/_utils/prediction_result_ui.py +67 -21
  13. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  14. orca_sdk/_utils/value_parser.py +45 -0
  15. orca_sdk/_utils/value_parser_test.py +39 -0
  16. orca_sdk/async_client.py +3795 -0
  17. orca_sdk/classification_model.py +601 -129
  18. orca_sdk/classification_model_test.py +415 -117
  19. orca_sdk/client.py +3787 -0
  20. orca_sdk/conftest.py +184 -38
  21. orca_sdk/credentials.py +162 -20
  22. orca_sdk/credentials_test.py +100 -16
  23. orca_sdk/datasource.py +268 -68
  24. orca_sdk/datasource_test.py +266 -18
  25. orca_sdk/embedding_model.py +434 -82
  26. orca_sdk/embedding_model_test.py +66 -33
  27. orca_sdk/job.py +343 -0
  28. orca_sdk/job_test.py +108 -0
  29. orca_sdk/memoryset.py +1690 -324
  30. orca_sdk/memoryset_test.py +456 -119
  31. orca_sdk/regression_model.py +694 -0
  32. orca_sdk/regression_model_test.py +378 -0
  33. orca_sdk/telemetry.py +460 -143
  34. orca_sdk/telemetry_test.py +43 -24
  35. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/METADATA +34 -16
  36. orca_sdk-0.1.3.dist-info/RECORD +41 -0
  37. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/WHEEL +1 -1
  38. orca_sdk/_generated_api_client/__init__.py +0 -3
  39. orca_sdk/_generated_api_client/api/__init__.py +0 -193
  40. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  41. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
  42. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
  43. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
  44. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
  45. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
  46. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  47. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
  48. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
  49. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  50. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
  51. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  52. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
  53. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
  54. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
  55. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
  56. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  57. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
  58. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
  59. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
  60. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
  61. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  62. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
  63. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
  64. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  65. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
  66. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
  67. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
  68. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
  69. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
  70. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  71. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
  72. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
  73. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
  74. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
  75. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
  76. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
  77. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
  78. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
  79. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
  80. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
  81. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
  82. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
  83. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
  84. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
  85. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
  86. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
  87. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
  88. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  89. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
  90. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
  91. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
  92. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  93. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
  94. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
  95. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
  96. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  97. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
  98. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
  99. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
  100. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
  101. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
  102. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
  103. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
  104. orca_sdk/_generated_api_client/client.py +0 -216
  105. orca_sdk/_generated_api_client/errors.py +0 -38
  106. orca_sdk/_generated_api_client/models/__init__.py +0 -159
  107. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
  108. orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
  109. orca_sdk/_generated_api_client/models/base_model.py +0 -55
  110. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
  111. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
  112. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
  113. orca_sdk/_generated_api_client/models/column_info.py +0 -114
  114. orca_sdk/_generated_api_client/models/column_type.py +0 -14
  115. orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
  116. orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
  117. orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
  118. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
  119. orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
  120. orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
  121. orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
  122. orca_sdk/_generated_api_client/models/embed_request.py +0 -127
  123. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
  124. orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
  125. orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
  126. orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
  127. orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
  128. orca_sdk/_generated_api_client/models/filter_item.py +0 -231
  129. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
  130. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
  131. orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
  132. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
  133. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
  134. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
  135. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
  136. orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
  137. orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
  138. orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
  139. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
  140. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
  141. orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
  142. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
  143. orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
  144. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
  145. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
  146. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
  147. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
  148. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
  149. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
  150. orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
  151. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
  152. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
  153. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
  154. orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
  155. orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
  156. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
  157. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
  158. orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
  159. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
  160. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
  161. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
  162. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
  163. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
  164. orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
  165. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
  166. orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
  167. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
  168. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
  169. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
  170. orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
  171. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
  172. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
  173. orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
  174. orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
  175. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
  176. orca_sdk/_generated_api_client/models/task.py +0 -198
  177. orca_sdk/_generated_api_client/models/task_status.py +0 -14
  178. orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
  179. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
  180. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
  181. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
  182. orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
  183. orca_sdk/_generated_api_client/py.typed +0 -1
  184. orca_sdk/_generated_api_client/types.py +0 -56
  185. orca_sdk/_utils/task.py +0 -73
  186. orca_sdk-0.1.1.dist-info/RECORD +0 -175
@@ -1,6 +1,14 @@
1
+ import json
2
+ import os
3
+ import tempfile
4
+ from typing import cast
1
5
  from uuid import uuid4
2
6
 
7
+ import numpy as np
8
+ import pandas as pd
9
+ import pyarrow as pa
3
10
  import pytest
11
+ from datasets import Dataset
4
12
 
5
13
  from .datasource import Datasource
6
14
 
@@ -11,9 +19,10 @@ def test_create_datasource(datasource, hf_dataset):
11
19
  assert datasource.length == len(hf_dataset)
12
20
 
13
21
 
14
- def test_create_datasource_unauthenticated(unauthenticated, hf_dataset):
15
- with pytest.raises(ValueError, match="Invalid API key"):
16
- Datasource.from_hf_dataset("test_datasource", hf_dataset)
22
def test_create_datasource_unauthenticated(unauthenticated_client, hf_dataset):
    """Creating a datasource while using a client with an invalid API key raises ValueError."""
    with unauthenticated_client.use(), pytest.raises(ValueError, match="Invalid API key"):
        Datasource.from_hf_dataset("test_datasource", hf_dataset)
17
26
 
18
27
 
19
28
  def test_create_datasource_already_exists_error(hf_dataset, datasource):
@@ -35,9 +44,10 @@ def test_open_datasource(datasource):
35
44
  assert fetched_datasource.length == len(datasource)
36
45
 
37
46
 
38
- def test_open_datasource_unauthenticated(datasource, unauthenticated):
39
- with pytest.raises(ValueError, match="Invalid API key"):
40
- Datasource.open("test_datasource")
47
def test_open_datasource_unauthenticated(unauthenticated_client, datasource):
    """Opening a datasource through a client with an invalid API key raises ValueError."""
    with unauthenticated_client.use(), pytest.raises(ValueError, match="Invalid API key"):
        Datasource.open("test_datasource")
41
51
 
42
52
 
43
53
  def test_open_datasource_invalid_input():
@@ -50,9 +60,10 @@ def test_open_datasource_not_found():
50
60
  Datasource.open(str(uuid4()))
51
61
 
52
62
 
53
- def test_open_datasource_unauthorized(datasource, unauthorized):
54
- with pytest.raises(LookupError):
55
- Datasource.open(datasource.id)
63
def test_open_datasource_unauthorized(unauthorized_client, datasource):
    """A client without access to the datasource cannot open it (LookupError)."""
    with unauthorized_client.use(), pytest.raises(LookupError):
        Datasource.open(datasource.id)
56
67
 
57
68
 
58
69
  def test_all_datasources(datasource):
@@ -61,9 +72,10 @@ def test_all_datasources(datasource):
61
72
  assert any(datasource.name == datasource.name for datasource in datasources)
62
73
 
63
74
 
64
- def test_all_datasources_unauthenticated(unauthenticated):
65
- with pytest.raises(ValueError, match="Invalid API key"):
66
- Datasource.all()
75
def test_all_datasources_unauthenticated(unauthenticated_client):
    """Listing datasources with an invalid API key raises ValueError."""
    with unauthenticated_client.use(), pytest.raises(ValueError, match="Invalid API key"):
        Datasource.all()
67
79
 
68
80
 
69
81
  def test_drop_datasource(hf_dataset):
@@ -73,9 +85,10 @@ def test_drop_datasource(hf_dataset):
73
85
  assert not Datasource.exists("datasource_to_delete")
74
86
 
75
87
 
76
- def test_drop_datasource_unauthenticated(datasource, unauthenticated):
77
- with pytest.raises(ValueError, match="Invalid API key"):
78
- Datasource.drop(datasource.id)
88
def test_drop_datasource_unauthenticated(datasource, unauthenticated_client):
    """Dropping a datasource with an invalid API key raises ValueError."""
    with unauthenticated_client.use(), pytest.raises(ValueError, match="Invalid API key"):
        Datasource.drop(datasource.id)
79
92
 
80
93
 
81
94
  def test_drop_datasource_not_found():
@@ -85,11 +98,246 @@ def test_drop_datasource_not_found():
85
98
  Datasource.drop(str(uuid4()), if_not_exists="ignore")
86
99
 
87
100
 
88
- def test_drop_datasource_unauthorized(datasource, unauthorized):
89
- with pytest.raises(LookupError):
90
- Datasource.drop(datasource.id)
101
def test_drop_datasource_unauthorized(datasource, unauthorized_client):
    """A client without access to the datasource cannot drop it (LookupError)."""
    with unauthorized_client.use(), pytest.raises(LookupError):
        Datasource.drop(datasource.id)
91
105
 
92
106
 
93
107
def test_drop_datasource_invalid_input():
    """Dropping with a malformed identifier is rejected as invalid input."""
    malformed_identifier = "not valid id"
    with pytest.raises(ValueError, match=r"Invalid input:.*"):
        Datasource.drop(malformed_identifier)
110
+
111
+
112
def test_from_list():
    """A datasource can be created from a list of row dictionaries."""
    rows = [
        {"column1": 1, "column2": "a"},
        {"column1": 2, "column2": "b"},
        {"column1": 3, "column2": "c"},
    ]
    created = Datasource.from_list(f"test_list_{uuid4()}", rows)

    assert created.name.startswith("test_list_")
    assert created.length == 3
    for column_name in ("column1", "column2"):
        assert column_name in created.columns
124
+
125
+
126
def test_from_dict():
    """A datasource can be created from a mapping of column name to column values."""
    columns_by_name = {
        "column1": [1, 2, 3],
        "column2": ["a", "b", "c"],
    }
    created = Datasource.from_dict(f"test_dict_{uuid4()}", columns_by_name)

    assert created.name.startswith("test_dict_")
    assert created.length == 3
    for column_name in ("column1", "column2"):
        assert column_name in created.columns
137
+
138
+
139
def test_from_pandas():
    """A datasource can be created from a pandas DataFrame."""
    frame = pd.DataFrame({"column1": [1, 2, 3], "column2": ["a", "b", "c"]})
    created = Datasource.from_pandas(f"test_pandas_{uuid4()}", frame)

    assert created.name.startswith("test_pandas_")
    assert created.length == 3
    for column_name in ("column1", "column2"):
        assert column_name in created.columns
152
+
153
+
154
def test_from_arrow():
    """A datasource can be created from a pyarrow Table."""
    arrow_table = pa.table({"column1": [1, 2, 3], "column2": ["a", "b", "c"]})
    created = Datasource.from_arrow(f"test_arrow_{uuid4()}", arrow_table)

    assert created.name.startswith("test_arrow_")
    assert created.length == 3
    for column_name in ("column1", "column2"):
        assert column_name in created.columns
167
+
168
+
169
def test_from_list_already_exists():
    """``if_exists`` governs what from_list does when the target name is taken."""
    rows = [{"column1": 1, "column2": "a"}]
    ds_name = f"test_list_exists_{uuid4()}"

    original = Datasource.from_list(ds_name, rows)
    assert original.length == 1

    # if_exists="error": a second create under the same name must be refused.
    with pytest.raises(ValueError):
        Datasource.from_list(ds_name, rows, if_exists="error")

    # if_exists="open": the already-existing datasource is handed back.
    reopened = Datasource.from_list(ds_name, rows, if_exists="open")
    assert (reopened.id, reopened.name) == (original.id, original.name)
186
+
187
+
188
def test_from_dict_already_exists():
    """``if_exists`` governs what from_dict does when the target name is taken."""
    columns_by_name = {"column1": [1], "column2": ["a"]}
    ds_name = f"test_dict_exists_{uuid4()}"

    original = Datasource.from_dict(ds_name, columns_by_name)
    assert original.length == 1

    # if_exists="error": a second create under the same name must be refused.
    with pytest.raises(ValueError):
        Datasource.from_dict(ds_name, columns_by_name, if_exists="error")

    # if_exists="open": the already-existing datasource is handed back.
    reopened = Datasource.from_dict(ds_name, columns_by_name, if_exists="open")
    assert (reopened.id, reopened.name) == (original.id, original.name)
205
+
206
+
207
def test_from_pandas_already_exists():
    """``if_exists`` governs what from_pandas does when the target name is taken."""
    frame = pd.DataFrame({"column1": [1], "column2": ["a"]})
    ds_name = f"test_pandas_exists_{uuid4()}"

    original = Datasource.from_pandas(ds_name, frame)
    assert original.length == 1

    # if_exists="error": a second create under the same name must be refused.
    with pytest.raises(ValueError):
        Datasource.from_pandas(ds_name, frame, if_exists="error")

    # if_exists="open": the already-existing datasource is handed back.
    reopened = Datasource.from_pandas(ds_name, frame, if_exists="open")
    assert (reopened.id, reopened.name) == (original.id, original.name)
224
+
225
+
226
def test_from_arrow_already_exists():
    """``if_exists`` governs what from_arrow does when the target name is taken."""
    arrow_table = pa.table({"column1": [1], "column2": ["a"]})
    ds_name = f"test_arrow_exists_{uuid4()}"

    original = Datasource.from_arrow(ds_name, arrow_table)
    assert original.length == 1

    # if_exists="error": a second create under the same name must be refused.
    with pytest.raises(ValueError):
        Datasource.from_arrow(ds_name, arrow_table, if_exists="error")

    # if_exists="open": the already-existing datasource is handed back.
    reopened = Datasource.from_arrow(ds_name, arrow_table, if_exists="open")
    assert (reopened.id, reopened.name) == (original.id, original.name)
243
+
244
+
245
def test_from_disk_csv():
    """A datasource can be loaded from a CSV file on disk."""
    # delete=False so the path survives the ``with`` block; removed in ``finally``.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as csv_file:
        csv_file.write("column1,column2\n1,a\n2,b\n3,c")
        csv_file.flush()
    csv_path = csv_file.name

    try:
        loaded = Datasource.from_disk(f"test_csv_{uuid4()}", csv_path)
        assert loaded.length == 3
        for column_name in ("column1", "column2"):
            assert column_name in loaded.columns
    finally:
        os.unlink(csv_path)
258
+
259
+
260
def test_from_disk_json():
    """A datasource can be loaded from a JSON file containing a list of records.

    Writes two records to a temporary ``.json`` file, loads it with
    ``Datasource.from_disk`` and checks the row count and column names.
    The temporary file is always removed afterwards.
    """
    records = [{"column1": 1, "column2": "a"}, {"column1": 2, "column2": "b"}]

    # delete=False so the path survives the ``with`` block; removed in ``finally``.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as json_file:
        json.dump(records, json_file)
        json_file.flush()

    try:
        loaded = Datasource.from_disk(f"test_json_{uuid4()}", json_file.name)
        assert loaded.length == 2
        assert "column1" in loaded.columns
        assert "column2" in loaded.columns
    finally:
        os.unlink(json_file.name)
277
+
278
+
279
def test_from_disk_already_exists():
    """``if_exists`` governs what from_disk does when the target name is taken."""
    # delete=False so the path survives the ``with`` block; removed in ``finally``.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as csv_file:
        csv_file.write("column1,column2\n1,a")
        csv_file.flush()
    csv_path = csv_file.name

    try:
        ds_name = f"test_disk_exists_{uuid4()}"

        original = Datasource.from_disk(ds_name, csv_path)
        assert original.length == 1

        # if_exists="error": a second create under the same name must be refused.
        with pytest.raises(ValueError):
            Datasource.from_disk(ds_name, csv_path, if_exists="error")

        # if_exists="open": the already-existing datasource is handed back.
        reopened = Datasource.from_disk(ds_name, csv_path, if_exists="open")
        assert (reopened.id, reopened.name) == (original.id, original.name)
    finally:
        os.unlink(csv_path)
302
+
303
+
304
def test_to_list(hf_dataset, datasource):
    """``to_list`` yields the same rows as the dataset the datasource was built from."""
    actual_rows = datasource.to_list()
    expected_rows = hf_dataset.to_list()
    assert actual_rows == expected_rows
306
+
307
+
308
def test_download_datasource(hf_dataset, datasource):
    """Downloading a datasource round-trips its contents in each supported format.

    Exercises the default download (a Hugging Face dataset directory) and the
    ``json`` and ``csv`` file types. Each result is compared against the
    ``hf_dataset`` fixture that the ``datasource`` fixture was created from.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # Default download: a dataset directory named after the datasource,
        # with no leftover .zip archive beside it.
        datasource.download(temp_dir)
        hf_dataset_dir = os.path.join(temp_dir, datasource.name)
        assert os.path.exists(hf_dataset_dir)
        assert os.path.isdir(hf_dataset_dir)
        assert not os.path.exists(f"{hf_dataset_dir}.zip")
        dataset_from_dir = Dataset.load_from_disk(hf_dataset_dir)
        assert dataset_from_dir.column_names == hf_dataset.column_names
        assert dataset_from_dir.to_dict() == hf_dataset.to_dict()

        # JSON download: a single <name>.json file holding a list of records.
        datasource.download(temp_dir, file_type="json")
        json_path = os.path.join(temp_dir, f"{datasource.name}.json")
        assert os.path.exists(json_path)
        with open(json_path, "r", encoding="utf-8") as json_file:
            content = json.load(json_file)
        assert content == hf_dataset.to_list()

        # CSV download: a single <name>.csv file.
        datasource.download(temp_dir, file_type="csv")
        csv_path = os.path.join(temp_dir, f"{datasource.name}.csv")
        assert os.path.exists(csv_path)
        dataset_from_csv = cast(Dataset, Dataset.from_csv(csv_path))
        assert dataset_from_csv.column_names == hf_dataset.column_names
        assert dataset_from_csv.remove_columns("score").to_dict() == hf_dataset.remove_columns("score").to_dict()
        # "score" may contain missing values (None); compare numerically with
        # None mapped to NaN so missing entries are treated as equal.
        assert np.allclose(
            np.array([np.nan if v is None else float(v) for v in dataset_from_csv["score"]], dtype=float),
            np.array([np.nan if v is None else float(v) for v in hf_dataset["score"]], dtype=float),
            equal_nan=True,
        )