orca-sdk 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff shows the contents of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Files changed (186)
  1. orca_sdk/__init__.py +10 -4
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +393 -0
  4. orca_sdk/_shared/metrics_test.py +273 -0
  5. orca_sdk/_utils/analysis_ui.py +12 -10
  6. orca_sdk/_utils/analysis_ui_style.css +0 -3
  7. orca_sdk/_utils/auth.py +31 -29
  8. orca_sdk/_utils/data_parsing.py +28 -2
  9. orca_sdk/_utils/data_parsing_test.py +15 -15
  10. orca_sdk/_utils/pagination.py +126 -0
  11. orca_sdk/_utils/pagination_test.py +132 -0
  12. orca_sdk/_utils/prediction_result_ui.py +67 -21
  13. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  14. orca_sdk/_utils/value_parser.py +45 -0
  15. orca_sdk/_utils/value_parser_test.py +39 -0
  16. orca_sdk/async_client.py +3795 -0
  17. orca_sdk/classification_model.py +601 -129
  18. orca_sdk/classification_model_test.py +415 -117
  19. orca_sdk/client.py +3787 -0
  20. orca_sdk/conftest.py +184 -38
  21. orca_sdk/credentials.py +162 -20
  22. orca_sdk/credentials_test.py +100 -16
  23. orca_sdk/datasource.py +268 -68
  24. orca_sdk/datasource_test.py +266 -18
  25. orca_sdk/embedding_model.py +434 -82
  26. orca_sdk/embedding_model_test.py +66 -33
  27. orca_sdk/job.py +343 -0
  28. orca_sdk/job_test.py +108 -0
  29. orca_sdk/memoryset.py +1690 -324
  30. orca_sdk/memoryset_test.py +456 -119
  31. orca_sdk/regression_model.py +694 -0
  32. orca_sdk/regression_model_test.py +378 -0
  33. orca_sdk/telemetry.py +460 -143
  34. orca_sdk/telemetry_test.py +43 -24
  35. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/METADATA +34 -16
  36. orca_sdk-0.1.3.dist-info/RECORD +41 -0
  37. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.3.dist-info}/WHEEL +1 -1
  38. orca_sdk/_generated_api_client/__init__.py +0 -3
  39. orca_sdk/_generated_api_client/api/__init__.py +0 -193
  40. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  41. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
  42. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
  43. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
  44. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
  45. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
  46. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  47. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
  48. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
  49. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  50. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
  51. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  52. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
  53. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
  54. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
  55. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
  56. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  57. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
  58. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
  59. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
  60. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
  61. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  62. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
  63. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
  64. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  65. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
  66. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
  67. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
  68. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
  69. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
  70. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  71. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
  72. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
  73. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
  74. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
  75. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
  76. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
  77. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
  78. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
  79. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
  80. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
  81. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
  82. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
  83. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
  84. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
  85. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
  86. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
  87. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
  88. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  89. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
  90. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
  91. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
  92. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  93. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
  94. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
  95. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
  96. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  97. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
  98. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
  99. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
  100. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
  101. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
  102. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
  103. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
  104. orca_sdk/_generated_api_client/client.py +0 -216
  105. orca_sdk/_generated_api_client/errors.py +0 -38
  106. orca_sdk/_generated_api_client/models/__init__.py +0 -159
  107. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
  108. orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
  109. orca_sdk/_generated_api_client/models/base_model.py +0 -55
  110. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
  111. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
  112. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
  113. orca_sdk/_generated_api_client/models/column_info.py +0 -114
  114. orca_sdk/_generated_api_client/models/column_type.py +0 -14
  115. orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
  116. orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
  117. orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
  118. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
  119. orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
  120. orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
  121. orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
  122. orca_sdk/_generated_api_client/models/embed_request.py +0 -127
  123. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
  124. orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
  125. orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
  126. orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
  127. orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
  128. orca_sdk/_generated_api_client/models/filter_item.py +0 -231
  129. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
  130. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
  131. orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
  132. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
  133. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
  134. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
  135. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
  136. orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
  137. orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
  138. orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
  139. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
  140. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
  141. orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
  142. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
  143. orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
  144. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
  145. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
  146. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
  147. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
  148. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
  149. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
  150. orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
  151. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
  152. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
  153. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
  154. orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
  155. orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
  156. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
  157. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
  158. orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
  159. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
  160. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
  161. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
  162. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
  163. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
  164. orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
  165. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
  166. orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
  167. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
  168. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
  169. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
  170. orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
  171. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
  172. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
  173. orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
  174. orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
  175. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
  176. orca_sdk/_generated_api_client/models/task.py +0 -198
  177. orca_sdk/_generated_api_client/models/task_status.py +0 -14
  178. orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
  179. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
  180. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
  181. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
  182. orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
  183. orca_sdk/_generated_api_client/py.typed +0 -1
  184. orca_sdk/_generated_api_client/types.py +0 -56
  185. orca_sdk/_utils/task.py +0 -73
  186. orca_sdk-0.1.1.dist-info/RECORD +0 -175
orca_sdk/_shared/metrics_test.py ADDED
@@ -0,0 +1,273 @@
+ """
+ IMPORTANT:
+ - This is a shared file between OrcaLib and the OrcaSDK.
+ - Please ensure that it does not have any dependencies on the OrcaLib code.
+ - Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
+ """
+
+ from typing import Literal
+
+ import numpy as np
+ import pytest
+ import sklearn.metrics
+
+ from .metrics import (
+     calculate_classification_metrics,
+     calculate_pr_curve,
+     calculate_regression_metrics,
+     calculate_roc_curve,
+     softmax,
+ )
+
+
+ def test_binary_metrics():
+     y_true = np.array([0, 1, 1, 0, 1])
+     y_score = np.array([0.1, 0.9, 0.8, 0.3, 0.2])
+
+     metrics = calculate_classification_metrics(y_true, y_score)
+
+     assert metrics.accuracy == 0.8
+     assert metrics.f1_score == 0.8
+     assert metrics.roc_auc is not None
+     assert metrics.roc_auc > 0.8
+     assert metrics.roc_auc < 1.0
+     assert metrics.pr_auc is not None
+     assert metrics.pr_auc > 0.8
+     assert metrics.pr_auc < 1.0
+     assert metrics.loss is not None
+     assert metrics.loss > 0.0
+
+
+ def test_multiclass_metrics_with_2_classes():
+     y_true = np.array([0, 1, 1, 0, 1])
+     y_score = np.array([[0.9, 0.1], [0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
+
+     metrics = calculate_classification_metrics(y_true, y_score)
+
+     assert metrics.accuracy == 0.8
+     assert metrics.f1_score == 0.8
+     assert metrics.roc_auc is not None
+     assert metrics.roc_auc > 0.8
+     assert metrics.roc_auc < 1.0
+     assert metrics.pr_auc is not None
+     assert metrics.pr_auc > 0.8
+     assert metrics.pr_auc < 1.0
+     assert metrics.loss is not None
+     assert metrics.loss > 0.0
+
+
+ @pytest.mark.parametrize(
+     "average, multiclass",
+     [("micro", "ovr"), ("macro", "ovr"), ("weighted", "ovr"), ("micro", "ovo"), ("macro", "ovo"), ("weighted", "ovo")],
+ )
+ def test_multiclass_metrics_with_3_classes(
+     average: Literal["micro", "macro", "weighted"], multiclass: Literal["ovr", "ovo"]
+ ):
+     y_true = np.array([0, 1, 1, 0, 2])
+     y_score = np.array([[0.9, 0.1, 0.0], [0.1, 0.9, 0.0], [0.2, 0.8, 0.0], [0.7, 0.3, 0.0], [0.0, 0.0, 1.0]])
+
+     metrics = calculate_classification_metrics(y_true, y_score, average=average, multi_class=multiclass)
+
+     assert metrics.accuracy == 1.0
+     assert metrics.f1_score == 1.0
+     assert metrics.roc_auc is not None
+     assert metrics.roc_auc > 0.8
+     assert metrics.pr_auc is None
+     assert metrics.loss is not None
+     assert metrics.loss > 0.0
+
+
+ def test_does_not_modify_logits_unless_necessary():
+     logits = np.array([[0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
+     expected_labels = [0, 1, 0, 1]
+     assert calculate_classification_metrics(expected_labels, logits).loss == sklearn.metrics.log_loss(
+         expected_labels, logits
+     )
+
+
+ def test_normalizes_logits_if_necessary():
+     logits = np.array([[1.2, 3.9], [1.2, 5.8], [1.2, 2.7], [1.2, 1.3]])
+     expected_labels = [0, 1, 0, 1]
+     assert calculate_classification_metrics(expected_labels, logits).loss == sklearn.metrics.log_loss(
+         expected_labels, logits / logits.sum(axis=1, keepdims=True)
+     )
+
+
+ def test_softmaxes_logits_if_necessary():
+     logits = np.array([[-1.2, 3.9], [1.2, -5.8], [1.2, 2.7], [1.2, 1.3]])
+     expected_labels = [0, 1, 0, 1]
+     assert calculate_classification_metrics(expected_labels, logits).loss == sklearn.metrics.log_loss(
+         expected_labels, softmax(logits)
+     )
+
+
+ def test_handles_nan_logits():
+     logits = np.array([[np.nan, np.nan], [np.nan, np.nan], [0.1, 0.9], [0.2, 0.8]])
+     expected_labels = [0, 1, 0, 1]
+     metrics = calculate_classification_metrics(expected_labels, logits)
+     assert metrics.loss is None
+     assert metrics.accuracy == 0.25
+     assert metrics.f1_score == 0.25
+     assert metrics.roc_auc is None
+     assert metrics.pr_auc is None
+     assert metrics.pr_curve is None
+     assert metrics.roc_curve is None
+     assert metrics.coverage == 0.5
+
+
+ def test_precision_recall_curve():
+     y_true = np.array([0, 1, 1, 0, 1])
+     y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+     pr_curve = calculate_pr_curve(y_true, y_score)
+
+     assert len(pr_curve["precisions"]) == len(pr_curve["recalls"]) == len(pr_curve["thresholds"]) == 6
+     assert np.allclose(pr_curve["precisions"][0], 0.6)
+     assert np.allclose(pr_curve["recalls"][0], 1.0)
+     assert np.allclose(pr_curve["precisions"][-1], 1.0)
+     assert np.allclose(pr_curve["recalls"][-1], 0.0)
+
+     # test that thresholds are sorted
+     assert np.all(np.diff(pr_curve["thresholds"]) >= 0)
+
+
+ def test_roc_curve():
+     y_true = np.array([0, 1, 1, 0, 1])
+     y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+     roc_curve = calculate_roc_curve(y_true, y_score)
+
+     assert (
+         len(roc_curve["false_positive_rates"])
+         == len(roc_curve["true_positive_rates"])
+         == len(roc_curve["thresholds"])
+         == 6
+     )
+     assert roc_curve["false_positive_rates"][0] == 1.0
+     assert roc_curve["true_positive_rates"][0] == 1.0
+     assert roc_curve["false_positive_rates"][-1] == 0.0
+     assert roc_curve["true_positive_rates"][-1] == 0.0
+
+     # test that thresholds are sorted
+     assert np.all(np.diff(roc_curve["thresholds"]) >= 0)
+
+
+ def test_log_loss_handles_missing_classes_in_y_true():
+     # y_true contains only a subset of classes, but predictions include an extra class column
+     y_true = np.array([0, 1, 0, 1])
+     y_score = np.array(
+         [
+             [0.7, 0.2, 0.1],
+             [0.1, 0.8, 0.1],
+             [0.6, 0.3, 0.1],
+             [0.2, 0.7, 0.1],
+         ]
+     )
+
+     metrics = calculate_classification_metrics(y_true, y_score)
+     expected_loss = sklearn.metrics.log_loss(y_true, y_score, labels=[0, 1, 2])
+
+     assert metrics.loss is not None
+     assert np.allclose(metrics.loss, expected_loss)
+
+
+ def test_precision_recall_curve_max_length():
+     y_true = np.array([0, 1, 1, 0, 1])
+     y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+     pr_curve = calculate_pr_curve(y_true, y_score, max_length=5)
+     assert len(pr_curve["precisions"]) == len(pr_curve["recalls"]) == len(pr_curve["thresholds"]) == 5
+
+     assert np.allclose(pr_curve["precisions"][0], 0.6)
+     assert np.allclose(pr_curve["recalls"][0], 1.0)
+     assert np.allclose(pr_curve["precisions"][-1], 1.0)
+     assert np.allclose(pr_curve["recalls"][-1], 0.0)
+
+     # test that thresholds are sorted
+     assert np.all(np.diff(pr_curve["thresholds"]) >= 0)
+
+
+ def test_roc_curve_max_length():
+     y_true = np.array([0, 1, 1, 0, 1])
+     y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+     roc_curve = calculate_roc_curve(y_true, y_score, max_length=5)
+     assert (
+         len(roc_curve["false_positive_rates"])
+         == len(roc_curve["true_positive_rates"])
+         == len(roc_curve["thresholds"])
+         == 5
+     )
+     assert np.allclose(roc_curve["false_positive_rates"][0], 1.0)
+     assert np.allclose(roc_curve["true_positive_rates"][0], 1.0)
+     assert np.allclose(roc_curve["false_positive_rates"][-1], 0.0)
+     assert np.allclose(roc_curve["true_positive_rates"][-1], 0.0)
+
+     # test that thresholds are sorted
+     assert np.all(np.diff(roc_curve["thresholds"]) >= 0)
+
+
+ # Regression Metrics Tests
+ def test_perfect_regression_predictions():
+     y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
+     y_pred = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
+
+     metrics = calculate_regression_metrics(y_true, y_pred)
+
+     assert metrics.mse == 0.0
+     assert metrics.rmse == 0.0
+     assert metrics.mae == 0.0
+     assert metrics.r2 == 1.0
+     assert metrics.explained_variance == 1.0
+     assert metrics.loss == 0.0
+     assert metrics.anomaly_score_mean is None
+     assert metrics.anomaly_score_median is None
+     assert metrics.anomaly_score_variance is None
+
+
+ def test_basic_regression_metrics():
+     y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
+     y_pred = np.array([1.1, 1.9, 3.2, 3.8, 5.1], dtype=np.float32)
+
+     metrics = calculate_regression_metrics(y_true, y_pred)
+
+     # Check that all metrics are reasonable
+     assert metrics.mse > 0.0
+     assert metrics.rmse == pytest.approx(np.sqrt(metrics.mse))
+     assert metrics.mae > 0.0
+     assert 0.0 <= metrics.r2 <= 1.0
+     assert 0.0 <= metrics.explained_variance <= 1.0
+     assert metrics.loss == metrics.mse
+
+     # Check specific values based on the data
+     expected_mse = np.mean((y_true - y_pred) ** 2)
+     assert metrics.mse == pytest.approx(expected_mse)
+
+     expected_mae = np.mean(np.abs(y_true - y_pred))
+     assert metrics.mae == pytest.approx(expected_mae)
+
+
+ def test_regression_metrics_with_anomaly_scores():
+     y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
+     y_pred = np.array([1.1, 1.9, 3.2, 3.8, 5.1], dtype=np.float32)
+     anomaly_scores = [0.1, 0.2, 0.15, 0.3, 0.25]
+
+     metrics = calculate_regression_metrics(y_true, y_pred, anomaly_scores)
+
+     assert metrics.anomaly_score_mean == pytest.approx(np.mean(anomaly_scores))
+     assert metrics.anomaly_score_median == pytest.approx(np.median(anomaly_scores))
+     assert metrics.anomaly_score_variance == pytest.approx(np.var(anomaly_scores))
+
+
+ def test_regression_metrics_handles_nans():
+     y_true = np.array([1.0, 2.0, 3.0], dtype=np.float32)
+     y_pred = np.array([1.1, 1.9, np.nan], dtype=np.float32)
+
+     metrics = calculate_regression_metrics(y_true, y_pred)
+
+     assert np.allclose(metrics.coverage, 0.6666666666666666)
+     assert metrics.mse > 0.0
+     assert metrics.rmse > 0.0
+     assert metrics.mae > 0.0
+     assert 0.0 <= metrics.r2 <= 1.0
+     assert 0.0 <= metrics.explained_variance <= 1.0
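The `orca_sdk/_shared/metrics.py` module these tests exercise (+393 lines) is not reproduced in this diff view, but the three logit-handling tests above pin down how scores are expected to be converted to probabilities before log loss is computed: rows that already form a distribution are passed through, non-negative rows are rescaled to sum to 1, and anything with negative entries is softmaxed. A minimal sketch of that behaviour, using a hypothetical `to_probabilities` helper name (only `softmax` is confirmed by the import in the test file):

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    # Row-wise softmax with max-subtraction for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def to_probabilities(logits: np.ndarray) -> np.ndarray:
    # Rows already form a valid distribution: leave them untouched
    # (matches test_does_not_modify_logits_unless_necessary).
    if np.all(logits >= 0) and np.allclose(logits.sum(axis=1), 1.0):
        return logits
    # Non-negative but not normalized: rescale each row to sum to 1
    # (matches test_normalizes_logits_if_necessary).
    if np.all(logits >= 0):
        return logits / logits.sum(axis=1, keepdims=True)
    # Negative entries present: treat as raw logits and apply softmax
    # (matches test_softmaxes_logits_if_necessary).
    return softmax(logits)
```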
orca_sdk/_utils/analysis_ui.py CHANGED
@@ -28,9 +28,7 @@ def display_suggested_memory_relabels(memoryset: LabeledMemoryset):
          filters=[("metrics.neighbor_predicted_label_matches_current_label", "==", False)]
      )
      # Sort memories by confidence score (higher confidence first)
-     suggested_relabels.sort(
-         key=lambda x: (x.metrics and x.metrics.neighbor_predicted_label_confidence) or 0.0, reverse=True
-     )
+     suggested_relabels.sort(key=lambda x: (x.metrics.get("neighbor_predicted_label_confidence", 0.0)), reverse=True)

      def update_approved(memory_id: str, selected: bool, current_memory_relabel_map: dict[str, RelabelStatus]):
          current_memory_relabel_map[memory_id]["approved"] = selected
@@ -72,9 +70,9 @@ def display_suggested_memory_relabels(memoryset: LabeledMemoryset):
          current_memory_relabel_map[mem_id]["new_label"] = new_label
          confidence = "--"
          current_metrics = current_memory_relabel_map[mem_id]["full_memory"].metrics
-         if current_metrics and new_label == current_metrics.neighbor_predicted_label:
+         if current_metrics and new_label == current_metrics.get("neighbor_predicted_label"):
              confidence = (
-                 round(current_metrics.neighbor_predicted_label_confidence or 0.0, 2) if current_metrics else 0
+                 round(current_metrics.get("neighbor_predicted_label_confidence", 0.0), 2) if current_metrics else 0
              )
          return (
              gr.HTML(
@@ -101,8 +99,8 @@ def display_suggested_memory_relabels(memoryset: LabeledMemoryset):
                  memory_id=mem.memory_id,
                  approved=False,
                  new_label=(
-                     mem.metrics.neighbor_predicted_label
-                     if (mem.metrics and isinstance(mem.metrics.neighbor_predicted_label, int))
+                     mem.metrics.get("neighbor_predicted_label")
+                     if (mem.metrics and isinstance(mem.metrics.get("neighbor_predicted_label"), int))
                      else None
                  ),
                  full_memory=mem,
@@ -150,7 +148,11 @@ def display_suggested_memory_relabels(memoryset: LabeledMemoryset):
      )
      for i, memory_relabel in enumerate(current_memory_relabel_map.values()):
          mem = memory_relabel["full_memory"]
-         with gr.Row(equal_height=True, variant="panel", elem_classes="white" if i % 2 == 0 else None):
+         predicted_label = mem.metrics["neighbor_predicted_label"]
+         predicted_label_name = label_names[predicted_label]
+         predicted_label_confidence = mem.metrics.get("neighbor_predicted_label_confidence", 0)
+
+         with gr.Row(equal_height=True, variant="panel"):
              with gr.Column(scale=9):
                  assert isinstance(mem.value, str)
                  gr.Markdown(mem.value, label="Value", height=50)
@@ -160,12 +162,12 @@ def display_suggested_memory_relabels(memoryset: LabeledMemoryset):
                  dropdown = gr.Dropdown(
                      choices=[f"{label_name} ({i})" for i, label_name in enumerate(label_names)],
                      label="SuggestedLabel",
-                     value=f"{label_names[mem.metrics.neighbor_predicted_label]} ({mem.metrics.neighbor_predicted_label})",
+                     value=f"{predicted_label_name} ({predicted_label})",
                      interactive=True,
                      container=False,
                  )
                  confidence = gr.HTML(
-                     f"<p style='font-size: 10px; color: #888;'>Confidence: {mem.metrics.neighbor_predicted_label_confidence:.2f}</p>",
+                     f"<p style='font-size: 10px; color: #888;'>Confidence: {predicted_label_confidence:.2f}</p>",
                      elem_classes="no-padding",
                  )
                  dropdown.change(
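The hunks above reflect an API change surfaced elsewhere in this release: a memory's `metrics` is now a plain dictionary (possibly `None` or missing keys) rather than an object with typed attributes, so the UI reads values with `.get()` and falls back to defaults. A small self-contained illustration of that access pattern, using a hypothetical `Memory` stand-in rather than the SDK's own class:

```python
from dataclasses import dataclass


@dataclass
class Memory:
    # Hypothetical stand-in for the SDK's memory object, for illustration only.
    value: str
    metrics: dict | None = None


def confidence_of(memory: Memory) -> float:
    # Mirrors the new sort key: missing metrics or a missing key count as 0.0 confidence.
    return (memory.metrics or {}).get("neighbor_predicted_label_confidence", 0.0)


memories = [
    Memory("borderline example", {"neighbor_predicted_label_confidence": 0.42}),
    Memory("no analysis yet", None),
    Memory("clear mislabel", {"neighbor_predicted_label_confidence": 0.93}),
]
memories.sort(key=confidence_of, reverse=True)  # highest-confidence suggestions first
```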
orca_sdk/_utils/analysis_ui_style.css CHANGED
@@ -1,6 +1,3 @@
- .white {
-     background-color: white;
- }
  .centered input {
      margin: auto;
  }
orca_sdk/_utils/auth.py CHANGED
@@ -2,61 +2,63 @@

  import logging
  import os
- from typing import List
+ from typing import List, Literal

  from dotenv import load_dotenv

- from .._generated_api_client.api import (
-     check_authentication,
-     create_api_key,
-     delete_api_key,
-     delete_org,
-     list_api_keys,
- )
- from .._generated_api_client.client import headers_context, set_base_url, set_headers
- from .._generated_api_client.models import CreateApiKeyRequest
- from .._generated_api_client.models.api_key_metadata import ApiKeyMetadata
+ from ..client import ApiKeyMetadata, OrcaClient
  from .common import DropMode

  load_dotenv() # this needs to be here to ensure env is populated before accessing it
+
+ # the defaults here must match nautilus and lighthouse config defaults
  _ORCA_ROOT_ACCESS_API_KEY = os.environ.get("ORCA_ROOT_ACCESS_API_KEY", "00000000-0000-0000-0000-000000000000")
+ _DEFAULT_ORG_ID = os.environ.get("DEFAULT_ORG_ID", "10e50000-0000-4000-a000-a78dca14af3a")


- def _create_api_key(org_id: str, name: str) -> str:
+ def _create_api_key(org_id: str, name: str, scopes: list[Literal["ADMINISTER", "PREDICT"]] = ["ADMINISTER"]) -> str:
      """Creates an API key for the given organization"""
-     with headers_context({"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id}):
-         res = create_api_key(body=CreateApiKeyRequest(name=name))
-         return res.api_key
+     client = OrcaClient._resolve_client()
+     response = client.POST(
+         "/auth/api_key",
+         json={"name": name, "scope": scopes},
+         headers={"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id},
+     )
+     return response["api_key"]


  def _list_api_keys(org_id: str) -> List[ApiKeyMetadata]:
      """Lists all API keys for the given organization"""
-     with headers_context({"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id}):
-         return list_api_keys()
+     client = OrcaClient._resolve_client()
+     return client.GET("/auth/api_key", headers={"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id})


  def _delete_api_key(org_id: str, name: str, if_not_exists: DropMode = "error") -> None:
      """Deletes the API key with the given name from the organization"""
-     with headers_context({"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id}):
-         try:
-             delete_api_key(name_or_id=name)
-         except LookupError:
-             if if_not_exists == "error":
-                 raise
+     try:
+         client = OrcaClient._resolve_client()
+         client.DELETE(
+             "/auth/api_key/{name_or_id}",
+             params={"name_or_id": name},
+             headers={"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id},
+         )
+     except LookupError:
+         if if_not_exists == "error":
+             raise


  def _delete_org(org_id: str) -> None:
      """Deletes the organization"""
-     with headers_context({"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id}):
-         delete_org()
+     client = OrcaClient._resolve_client()
+     client.DELETE("/auth/org", headers={"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id})


- def _authenticate_local_api(org_id: str = "10e50000-0000-4000-a000-a78dca14af3a", api_key_name: str = "local") -> None:
+ def _authenticate_local_api(org_id: str = _DEFAULT_ORG_ID, api_key_name: str = "local") -> None:
      """Connect to the local API at http://localhost:1584/ and authenticate with a new API key"""
-     set_base_url("http://localhost:1584/")
      _delete_api_key(org_id, api_key_name, if_not_exists="ignore")
-     set_headers({"Api-Key": _create_api_key(org_id, api_key_name)})
-     check_authentication()
+     client = OrcaClient._resolve_client()
+     client.base_url = "http://localhost:1584"
+     client.headers.update({"Api-Key": _create_api_key(org_id, api_key_name)})
      logging.info(f"Authenticated against local API at 'http://localhost:1584' with '{api_key_name}' API key")

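The net effect of this hunk is that the internal auth helpers no longer depend on the removed `_generated_api_client` package and instead route every call through the shared `OrcaClient` via its `GET`/`POST`/`DELETE` methods. A hedged usage sketch of the refactored helpers, assuming a local Orca API on http://localhost:1584 and a matching `ORCA_ROOT_ACCESS_API_KEY`; these are underscore-prefixed internal functions, shown only to illustrate the new flow:

```python
from orca_sdk._utils.auth import (
    _authenticate_local_api,
    _create_api_key,
    _delete_api_key,
    _list_api_keys,
)

# Point the resolved OrcaClient at the local API and install a fresh "local" API key.
_authenticate_local_api()

# The scopes parameter is new in 0.1.3; issue an extra key limited to predictions.
org_id = "10e50000-0000-4000-a000-a78dca14af3a"  # default org id from the diff above
predict_key = _create_api_key(org_id, "ci-predict", scopes=["PREDICT"])

print(_list_api_keys(org_id))

# Clean up; if_not_exists="ignore" makes the call safe to repeat.
_delete_api_key(org_id, "ci-predict", if_not_exists="ignore")
```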
orca_sdk/_utils/data_parsing.py CHANGED
@@ -1,6 +1,7 @@
  import pickle
  from dataclasses import asdict, is_dataclass
  from os import PathLike
+ from tempfile import TemporaryDirectory
  from typing import Any, cast

  from datasets import Dataset
@@ -40,7 +41,24 @@ def parse_batch(batch: Any, column_names: list[str] | None = None) -> list[dict]
      return [{key: batch[key][idx] for key in keys} for idx in range(batch_size)]


- def hf_dataset_from_torch(torch_data: TorchDataLoader | TorchDataset, column_names: list[str] | None = None) -> Dataset:
+ def hf_dataset_from_torch(
+     torch_data: TorchDataLoader | TorchDataset, column_names: list[str] | None = None, ignore_cache=False
+ ) -> Dataset:
+     """
+     Create a HuggingFace Dataset from a PyTorch DataLoader or Dataset.
+
+     NOTE: It's important to ignore the cached files when testing (i.e., ignore_cache=True), because
+     cached results can ignore changes you've made to tests. This can make a test appear to succeed
+     when it's actually broken or vice versa.
+
+     Params:
+         torch_data: A PyTorch DataLoader or Dataset object to create the HuggingFace Dataset from.
+         column_names: Optional list of column names to use for the dataset. If not provided,
+             the column names will be inferred from the data.
+         ignore_cache: If True, the dataset will not be cached on disk.
+     Returns:
+         A HuggingFace Dataset object containing the data from the PyTorch DataLoader or Dataset.
+     """
      if isinstance(torch_data, TorchDataLoader):
          dataloader = torch_data
      else:
@@ -50,7 +68,15 @@ def hf_dataset_from_torch(torch_data: TorchDataLoader | TorchDataset, column_nam
          for batch in dataloader:
              yield from parse_batch(batch, column_names=column_names)

-     return cast(Dataset, Dataset.from_generator(generator))
+     if ignore_cache:
+         with TemporaryDirectory() as temp_dir:
+             ds = Dataset.from_generator(generator, cache_dir=temp_dir)
+     else:
+         ds = Dataset.from_generator(generator)
+
+     if not isinstance(ds, Dataset):
+         raise ValueError(f"Failed to create dataset from generator: {type(ds)}")
+     return ds


  def hf_dataset_from_disk(file_path: str | PathLike) -> Dataset:
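For context, `hf_dataset_from_torch` can now be told to bypass the Hugging Face on-disk cache, which matters in tests where a stale cache would mask changes; the updated tests below rely on exactly this. A minimal usage sketch of the internal helper; `DemoDataset` is a hypothetical stand-in, not part of the SDK:

```python
from torch.utils.data import Dataset as TorchDataset

from orca_sdk._utils.data_parsing import hf_dataset_from_torch


class DemoDataset(TorchDataset):
    # Tiny dict-returning dataset, just to exercise the conversion.
    def __init__(self):
        self.rows = [{"value": "great product", "label": 1}, {"value": "broke in a day", "label": 0}]

    def __getitem__(self, i):
        return self.rows[i]

    def __len__(self):
        return len(self.rows)


# ignore_cache=True writes the intermediate Arrow cache to a temporary directory,
# so repeated runs never reuse stale generator output.
ds = hf_dataset_from_torch(DemoDataset(), ignore_cache=True)
assert set(ds.column_names) == {"value", "label"}
```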
orca_sdk/_utils/data_parsing_test.py CHANGED
@@ -29,11 +29,11 @@ class PytorchDictDataset(TorchDataset):
  def test_hf_dataset_from_torch_dict():
      # Given a Pytorch dataset that returns a dictionary for each item
      dataset = PytorchDictDataset()
-     hf_dataset = hf_dataset_from_torch(dataset)
+     hf_dataset = hf_dataset_from_torch(dataset, ignore_cache=True)
      # Then the HF dataset should be created successfully
      assert isinstance(hf_dataset, Dataset)
      assert len(hf_dataset) == len(dataset)
-     assert set(hf_dataset.column_names) == {"text", "label", "key", "score", "source_id"}
+     assert set(hf_dataset.column_names) == {"value", "label", "key", "score", "source_id"}


  class PytorchTupleDataset(TorchDataset):
@@ -41,7 +41,7 @@ class PytorchTupleDataset(TorchDataset):
          self.data = SAMPLE_DATA

      def __getitem__(self, i):
-         return self.data[i]["text"], self.data[i]["label"]
+         return self.data[i]["value"], self.data[i]["label"]

      def __len__(self):
          return len(self.data)
@@ -51,11 +51,11 @@ def test_hf_dataset_from_torch_tuple():
      # Given a Pytorch dataset that returns a tuple for each item
      dataset = PytorchTupleDataset()
      # And the correct number of column names passed in
-     hf_dataset = hf_dataset_from_torch(dataset, column_names=["text", "label"])
+     hf_dataset = hf_dataset_from_torch(dataset, column_names=["value", "label"], ignore_cache=True)
      # Then the HF dataset should be created successfully
      assert isinstance(hf_dataset, Dataset)
      assert len(hf_dataset) == len(dataset)
-     assert hf_dataset.column_names == ["text", "label"]
+     assert hf_dataset.column_names == ["value", "label"]


  def test_hf_dataset_from_torch_tuple_error():
@@ -63,7 +63,7 @@ def test_hf_dataset_from_torch_tuple_error():
      dataset = PytorchTupleDataset()
      # Then the HF dataset should raise an error if no column names are passed in
      with pytest.raises(DatasetGenerationError):
-         hf_dataset_from_torch(dataset)
+         hf_dataset_from_torch(dataset, ignore_cache=True)


  def test_hf_dataset_from_torch_tuple_error_not_enough_columns():
@@ -71,7 +71,7 @@ def test_hf_dataset_from_torch_tuple_error_not_enough_columns():
      dataset = PytorchTupleDataset()
      # Then the HF dataset should raise an error if not enough column names are passed in
      with pytest.raises(DatasetGenerationError):
-         hf_dataset_from_torch(dataset, column_names=["value"])
+         hf_dataset_from_torch(dataset, column_names=["value"], ignore_cache=True)


  DatasetTuple = namedtuple("DatasetTuple", ["value", "label"])
@@ -82,7 +82,7 @@ class PytorchNamedTupleDataset(TorchDataset):
          self.data = SAMPLE_DATA

      def __getitem__(self, i):
-         return DatasetTuple(self.data[i]["text"], self.data[i]["label"])
+         return DatasetTuple(self.data[i]["value"], self.data[i]["label"])

      def __len__(self):
          return len(self.data)
@@ -92,7 +92,7 @@ def test_hf_dataset_from_torch_named_tuple():
      # Given a Pytorch dataset that returns a namedtuple for each item
      dataset = PytorchNamedTupleDataset()
      # And no column names are passed in
-     hf_dataset = hf_dataset_from_torch(dataset)
+     hf_dataset = hf_dataset_from_torch(dataset, ignore_cache=True)
      # Then the HF dataset should be created successfully
      assert isinstance(hf_dataset, Dataset)
      assert len(hf_dataset) == len(dataset)
@@ -110,7 +110,7 @@ class PytorchDataclassDataset(TorchDataset):
          self.data = SAMPLE_DATA

      def __getitem__(self, i):
-         return DatasetItem(text=self.data[i]["text"], label=self.data[i]["label"])
+         return DatasetItem(text=self.data[i]["value"], label=self.data[i]["label"])

      def __len__(self):
          return len(self.data)
@@ -119,7 +119,7 @@ class PytorchDataclassDataset(TorchDataset):
  def test_hf_dataset_from_torch_dataclass():
      # Given a Pytorch dataset that returns a dataclass for each item
      dataset = PytorchDataclassDataset()
-     hf_dataset = hf_dataset_from_torch(dataset)
+     hf_dataset = hf_dataset_from_torch(dataset, ignore_cache=True)
      # Then the HF dataset should be created successfully
      assert isinstance(hf_dataset, Dataset)
      assert len(hf_dataset) == len(dataset)
@@ -131,7 +131,7 @@ class PytorchInvalidDataset(TorchDataset):
          self.data = SAMPLE_DATA

      def __getitem__(self, i):
-         return [self.data[i]["text"], self.data[i]["label"]]
+         return [self.data[i]["value"], self.data[i]["label"]]

      def __len__(self):
          return len(self.data)
@@ -142,7 +142,7 @@ def test_hf_dataset_from_torch_invalid_dataset():
      dataset = PytorchInvalidDataset()
      # Then the HF dataset should raise an error
      with pytest.raises(DatasetGenerationError):
-         hf_dataset_from_torch(dataset)
+         hf_dataset_from_torch(dataset, ignore_cache=True)


  def test_hf_dataset_from_torchdataloader():
@@ -150,10 +150,10 @@ def test_hf_dataset_from_torchdataloader():
      dataset = PytorchDictDataset()

      def collate_fn(x: list[dict]):
-         return {"value": [item["text"] for item in x], "label": [item["label"] for item in x]}
+         return {"value": [item["value"] for item in x], "label": [item["label"] for item in x]}

      dataloader = TorchDataLoader(dataset, batch_size=3, collate_fn=collate_fn)
-     hf_dataset = hf_dataset_from_torch(dataloader)
+     hf_dataset = hf_dataset_from_torch(dataloader, ignore_cache=True)
      # Then the HF dataset should be created successfully
      assert isinstance(hf_dataset, Dataset)
      assert len(hf_dataset) == len(dataset)