orca-sdk 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. orca_sdk/__init__.py +10 -4
  2. orca_sdk/_shared/__init__.py +10 -0
  3. orca_sdk/_shared/metrics.py +393 -0
  4. orca_sdk/_shared/metrics_test.py +273 -0
  5. orca_sdk/_utils/analysis_ui.py +12 -10
  6. orca_sdk/_utils/analysis_ui_style.css +0 -3
  7. orca_sdk/_utils/auth.py +27 -29
  8. orca_sdk/_utils/data_parsing.py +28 -2
  9. orca_sdk/_utils/data_parsing_test.py +15 -15
  10. orca_sdk/_utils/pagination.py +126 -0
  11. orca_sdk/_utils/pagination_test.py +132 -0
  12. orca_sdk/_utils/prediction_result_ui.py +67 -21
  13. orca_sdk/_utils/tqdm_file_reader.py +12 -0
  14. orca_sdk/_utils/value_parser.py +45 -0
  15. orca_sdk/_utils/value_parser_test.py +39 -0
  16. orca_sdk/classification_model.py +439 -129
  17. orca_sdk/classification_model_test.py +334 -104
  18. orca_sdk/client.py +3747 -0
  19. orca_sdk/conftest.py +164 -19
  20. orca_sdk/credentials.py +120 -18
  21. orca_sdk/credentials_test.py +20 -0
  22. orca_sdk/datasource.py +259 -68
  23. orca_sdk/datasource_test.py +242 -0
  24. orca_sdk/embedding_model.py +425 -82
  25. orca_sdk/embedding_model_test.py +39 -13
  26. orca_sdk/job.py +337 -0
  27. orca_sdk/job_test.py +108 -0
  28. orca_sdk/memoryset.py +1341 -305
  29. orca_sdk/memoryset_test.py +350 -111
  30. orca_sdk/regression_model.py +684 -0
  31. orca_sdk/regression_model_test.py +369 -0
  32. orca_sdk/telemetry.py +449 -143
  33. orca_sdk/telemetry_test.py +43 -24
  34. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/METADATA +34 -16
  35. orca_sdk-0.1.2.dist-info/RECORD +40 -0
  36. {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/WHEEL +1 -1
  37. orca_sdk/_generated_api_client/__init__.py +0 -3
  38. orca_sdk/_generated_api_client/api/__init__.py +0 -193
  39. orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
  40. orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
  41. orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
  42. orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
  43. orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
  44. orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
  45. orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
  46. orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
  47. orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
  48. orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
  49. orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
  50. orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
  51. orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
  52. orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
  53. orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
  54. orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
  55. orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
  56. orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
  57. orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
  58. orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
  59. orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
  60. orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
  61. orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
  62. orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
  63. orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
  64. orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
  65. orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
  66. orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
  67. orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
  68. orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
  69. orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
  70. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
  71. orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
  72. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
  73. orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
  74. orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
  75. orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
  76. orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
  77. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
  78. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
  79. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
  80. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
  81. orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
  82. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
  83. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
  84. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
  85. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
  86. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
  87. orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
  88. orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
  89. orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
  90. orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
  91. orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
  92. orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
  93. orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
  94. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
  95. orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
  96. orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
  97. orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
  98. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
  99. orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
  100. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
  101. orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
  102. orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
  103. orca_sdk/_generated_api_client/client.py +0 -216
  104. orca_sdk/_generated_api_client/errors.py +0 -38
  105. orca_sdk/_generated_api_client/models/__init__.py +0 -159
  106. orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
  107. orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
  108. orca_sdk/_generated_api_client/models/base_model.py +0 -55
  109. orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
  110. orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
  111. orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
  112. orca_sdk/_generated_api_client/models/column_info.py +0 -114
  113. orca_sdk/_generated_api_client/models/column_type.py +0 -14
  114. orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
  115. orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
  116. orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
  117. orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
  118. orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
  119. orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
  120. orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
  121. orca_sdk/_generated_api_client/models/embed_request.py +0 -127
  122. orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
  123. orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
  124. orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
  125. orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
  126. orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
  127. orca_sdk/_generated_api_client/models/filter_item.py +0 -231
  128. orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
  129. orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
  130. orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
  131. orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
  132. orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
  133. orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
  134. orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
  135. orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
  136. orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
  137. orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
  138. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
  139. orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
  140. orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
  141. orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
  142. orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
  143. orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
  144. orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
  145. orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
  146. orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
  147. orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
  148. orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
  149. orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
  150. orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
  151. orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
  152. orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
  153. orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
  154. orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
  155. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
  156. orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
  157. orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
  158. orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
  159. orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
  160. orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
  161. orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
  162. orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
  163. orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
  164. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
  165. orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
  166. orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
  167. orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
  168. orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
  169. orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
  170. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
  171. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
  172. orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
  173. orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
  174. orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
  175. orca_sdk/_generated_api_client/models/task.py +0 -198
  176. orca_sdk/_generated_api_client/models/task_status.py +0 -14
  177. orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
  178. orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
  179. orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
  180. orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
  181. orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
  182. orca_sdk/_generated_api_client/py.typed +0 -1
  183. orca_sdk/_generated_api_client/types.py +0 -56
  184. orca_sdk/_utils/task.py +0 -73
  185. orca_sdk-0.1.1.dist-info/RECORD +0 -175
@@ -0,0 +1,273 @@
1
+ """
2
+ IMPORTANT:
3
+ - This is a shared file between OrcaLib and the OrcaSDK.
4
+ - Please ensure that it does not have any dependencies on the OrcaLib code.
5
+ - Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
6
+ """
7
+
8
+ from typing import Literal
9
+
10
+ import numpy as np
11
+ import pytest
12
+ import sklearn.metrics
13
+
14
+ from .metrics import (
15
+ calculate_classification_metrics,
16
+ calculate_pr_curve,
17
+ calculate_regression_metrics,
18
+ calculate_roc_curve,
19
+ softmax,
20
+ )
21
+
22
+
23
+ def test_binary_metrics():
24
+ y_true = np.array([0, 1, 1, 0, 1])
25
+ y_score = np.array([0.1, 0.9, 0.8, 0.3, 0.2])
26
+
27
+ metrics = calculate_classification_metrics(y_true, y_score)
28
+
29
+ assert metrics.accuracy == 0.8
30
+ assert metrics.f1_score == 0.8
31
+ assert metrics.roc_auc is not None
32
+ assert metrics.roc_auc > 0.8
33
+ assert metrics.roc_auc < 1.0
34
+ assert metrics.pr_auc is not None
35
+ assert metrics.pr_auc > 0.8
36
+ assert metrics.pr_auc < 1.0
37
+ assert metrics.loss is not None
38
+ assert metrics.loss > 0.0
39
+
40
+
41
+ def test_multiclass_metrics_with_2_classes():
42
+ y_true = np.array([0, 1, 1, 0, 1])
43
+ y_score = np.array([[0.9, 0.1], [0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
44
+
45
+ metrics = calculate_classification_metrics(y_true, y_score)
46
+
47
+ assert metrics.accuracy == 0.8
48
+ assert metrics.f1_score == 0.8
49
+ assert metrics.roc_auc is not None
50
+ assert metrics.roc_auc > 0.8
51
+ assert metrics.roc_auc < 1.0
52
+ assert metrics.pr_auc is not None
53
+ assert metrics.pr_auc > 0.8
54
+ assert metrics.pr_auc < 1.0
55
+ assert metrics.loss is not None
56
+ assert metrics.loss > 0.0
57
+
58
+
59
+ @pytest.mark.parametrize(
60
+ "average, multiclass",
61
+ [("micro", "ovr"), ("macro", "ovr"), ("weighted", "ovr"), ("micro", "ovo"), ("macro", "ovo"), ("weighted", "ovo")],
62
+ )
63
+ def test_multiclass_metrics_with_3_classes(
64
+ average: Literal["micro", "macro", "weighted"], multiclass: Literal["ovr", "ovo"]
65
+ ):
66
+ y_true = np.array([0, 1, 1, 0, 2])
67
+ y_score = np.array([[0.9, 0.1, 0.0], [0.1, 0.9, 0.0], [0.2, 0.8, 0.0], [0.7, 0.3, 0.0], [0.0, 0.0, 1.0]])
68
+
69
+ metrics = calculate_classification_metrics(y_true, y_score, average=average, multi_class=multiclass)
70
+
71
+ assert metrics.accuracy == 1.0
72
+ assert metrics.f1_score == 1.0
73
+ assert metrics.roc_auc is not None
74
+ assert metrics.roc_auc > 0.8
75
+ assert metrics.pr_auc is None
76
+ assert metrics.loss is not None
77
+ assert metrics.loss > 0.0
78
+
79
+
80
+ def test_does_not_modify_logits_unless_necessary():
81
+ logits = np.array([[0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
82
+ expected_labels = [0, 1, 0, 1]
83
+ assert calculate_classification_metrics(expected_labels, logits).loss == sklearn.metrics.log_loss(
84
+ expected_labels, logits
85
+ )
86
+
87
+
88
+ def test_normalizes_logits_if_necessary():
89
+ logits = np.array([[1.2, 3.9], [1.2, 5.8], [1.2, 2.7], [1.2, 1.3]])
90
+ expected_labels = [0, 1, 0, 1]
91
+ assert calculate_classification_metrics(expected_labels, logits).loss == sklearn.metrics.log_loss(
92
+ expected_labels, logits / logits.sum(axis=1, keepdims=True)
93
+ )
94
+
95
+
96
+ def test_softmaxes_logits_if_necessary():
97
+ logits = np.array([[-1.2, 3.9], [1.2, -5.8], [1.2, 2.7], [1.2, 1.3]])
98
+ expected_labels = [0, 1, 0, 1]
99
+ assert calculate_classification_metrics(expected_labels, logits).loss == sklearn.metrics.log_loss(
100
+ expected_labels, softmax(logits)
101
+ )
102
+
103
+
104
+ def test_handles_nan_logits():
105
+ logits = np.array([[np.nan, np.nan], [np.nan, np.nan], [0.1, 0.9], [0.2, 0.8]])
106
+ expected_labels = [0, 1, 0, 1]
107
+ metrics = calculate_classification_metrics(expected_labels, logits)
108
+ assert metrics.loss is None
109
+ assert metrics.accuracy == 0.25
110
+ assert metrics.f1_score == 0.25
111
+ assert metrics.roc_auc is None
112
+ assert metrics.pr_auc is None
113
+ assert metrics.pr_curve is None
114
+ assert metrics.roc_curve is None
115
+ assert metrics.coverage == 0.5
116
+
117
+
118
+ def test_precision_recall_curve():
119
+ y_true = np.array([0, 1, 1, 0, 1])
120
+ y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
121
+
122
+ pr_curve = calculate_pr_curve(y_true, y_score)
123
+
124
+ assert len(pr_curve["precisions"]) == len(pr_curve["recalls"]) == len(pr_curve["thresholds"]) == 6
125
+ assert np.allclose(pr_curve["precisions"][0], 0.6)
126
+ assert np.allclose(pr_curve["recalls"][0], 1.0)
127
+ assert np.allclose(pr_curve["precisions"][-1], 1.0)
128
+ assert np.allclose(pr_curve["recalls"][-1], 0.0)
129
+
130
+ # test that thresholds are sorted
131
+ assert np.all(np.diff(pr_curve["thresholds"]) >= 0)
132
+
133
+
134
+ def test_roc_curve():
135
+ y_true = np.array([0, 1, 1, 0, 1])
136
+ y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
137
+
138
+ roc_curve = calculate_roc_curve(y_true, y_score)
139
+
140
+ assert (
141
+ len(roc_curve["false_positive_rates"])
142
+ == len(roc_curve["true_positive_rates"])
143
+ == len(roc_curve["thresholds"])
144
+ == 6
145
+ )
146
+ assert roc_curve["false_positive_rates"][0] == 1.0
147
+ assert roc_curve["true_positive_rates"][0] == 1.0
148
+ assert roc_curve["false_positive_rates"][-1] == 0.0
149
+ assert roc_curve["true_positive_rates"][-1] == 0.0
150
+
151
+ # test that thresholds are sorted
152
+ assert np.all(np.diff(roc_curve["thresholds"]) >= 0)
153
+
154
+
155
+ def test_log_loss_handles_missing_classes_in_y_true():
156
+ # y_true contains only a subset of classes, but predictions include an extra class column
157
+ y_true = np.array([0, 1, 0, 1])
158
+ y_score = np.array(
159
+ [
160
+ [0.7, 0.2, 0.1],
161
+ [0.1, 0.8, 0.1],
162
+ [0.6, 0.3, 0.1],
163
+ [0.2, 0.7, 0.1],
164
+ ]
165
+ )
166
+
167
+ metrics = calculate_classification_metrics(y_true, y_score)
168
+ expected_loss = sklearn.metrics.log_loss(y_true, y_score, labels=[0, 1, 2])
169
+
170
+ assert metrics.loss is not None
171
+ assert np.allclose(metrics.loss, expected_loss)
172
+
173
+
174
+ def test_precision_recall_curve_max_length():
175
+ y_true = np.array([0, 1, 1, 0, 1])
176
+ y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
177
+
178
+ pr_curve = calculate_pr_curve(y_true, y_score, max_length=5)
179
+ assert len(pr_curve["precisions"]) == len(pr_curve["recalls"]) == len(pr_curve["thresholds"]) == 5
180
+
181
+ assert np.allclose(pr_curve["precisions"][0], 0.6)
182
+ assert np.allclose(pr_curve["recalls"][0], 1.0)
183
+ assert np.allclose(pr_curve["precisions"][-1], 1.0)
184
+ assert np.allclose(pr_curve["recalls"][-1], 0.0)
185
+
186
+ # test that thresholds are sorted
187
+ assert np.all(np.diff(pr_curve["thresholds"]) >= 0)
188
+
189
+
190
+ def test_roc_curve_max_length():
191
+ y_true = np.array([0, 1, 1, 0, 1])
192
+ y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
193
+
194
+ roc_curve = calculate_roc_curve(y_true, y_score, max_length=5)
195
+ assert (
196
+ len(roc_curve["false_positive_rates"])
197
+ == len(roc_curve["true_positive_rates"])
198
+ == len(roc_curve["thresholds"])
199
+ == 5
200
+ )
201
+ assert np.allclose(roc_curve["false_positive_rates"][0], 1.0)
202
+ assert np.allclose(roc_curve["true_positive_rates"][0], 1.0)
203
+ assert np.allclose(roc_curve["false_positive_rates"][-1], 0.0)
204
+ assert np.allclose(roc_curve["true_positive_rates"][-1], 0.0)
205
+
206
+ # test that thresholds are sorted
207
+ assert np.all(np.diff(roc_curve["thresholds"]) >= 0)
208
+
209
+
210
+ # Regression Metrics Tests
211
+ def test_perfect_regression_predictions():
212
+ y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
213
+ y_pred = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
214
+
215
+ metrics = calculate_regression_metrics(y_true, y_pred)
216
+
217
+ assert metrics.mse == 0.0
218
+ assert metrics.rmse == 0.0
219
+ assert metrics.mae == 0.0
220
+ assert metrics.r2 == 1.0
221
+ assert metrics.explained_variance == 1.0
222
+ assert metrics.loss == 0.0
223
+ assert metrics.anomaly_score_mean is None
224
+ assert metrics.anomaly_score_median is None
225
+ assert metrics.anomaly_score_variance is None
226
+
227
+
228
+ def test_basic_regression_metrics():
229
+ y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
230
+ y_pred = np.array([1.1, 1.9, 3.2, 3.8, 5.1], dtype=np.float32)
231
+
232
+ metrics = calculate_regression_metrics(y_true, y_pred)
233
+
234
+ # Check that all metrics are reasonable
235
+ assert metrics.mse > 0.0
236
+ assert metrics.rmse == pytest.approx(np.sqrt(metrics.mse))
237
+ assert metrics.mae > 0.0
238
+ assert 0.0 <= metrics.r2 <= 1.0
239
+ assert 0.0 <= metrics.explained_variance <= 1.0
240
+ assert metrics.loss == metrics.mse
241
+
242
+ # Check specific values based on the data
243
+ expected_mse = np.mean((y_true - y_pred) ** 2)
244
+ assert metrics.mse == pytest.approx(expected_mse)
245
+
246
+ expected_mae = np.mean(np.abs(y_true - y_pred))
247
+ assert metrics.mae == pytest.approx(expected_mae)
248
+
249
+
250
+ def test_regression_metrics_with_anomaly_scores():
251
+ y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
252
+ y_pred = np.array([1.1, 1.9, 3.2, 3.8, 5.1], dtype=np.float32)
253
+ anomaly_scores = [0.1, 0.2, 0.15, 0.3, 0.25]
254
+
255
+ metrics = calculate_regression_metrics(y_true, y_pred, anomaly_scores)
256
+
257
+ assert metrics.anomaly_score_mean == pytest.approx(np.mean(anomaly_scores))
258
+ assert metrics.anomaly_score_median == pytest.approx(np.median(anomaly_scores))
259
+ assert metrics.anomaly_score_variance == pytest.approx(np.var(anomaly_scores))
260
+
261
+
262
+ def test_regression_metrics_handles_nans():
263
+ y_true = np.array([1.0, 2.0, 3.0], dtype=np.float32)
264
+ y_pred = np.array([1.1, 1.9, np.nan], dtype=np.float32)
265
+
266
+ metrics = calculate_regression_metrics(y_true, y_pred)
267
+
268
+ assert np.allclose(metrics.coverage, 0.6666666666666666)
269
+ assert metrics.mse > 0.0
270
+ assert metrics.rmse > 0.0
271
+ assert metrics.mae > 0.0
272
+ assert 0.0 <= metrics.r2 <= 1.0
273
+ assert 0.0 <= metrics.explained_variance <= 1.0
@@ -28,9 +28,7 @@ def display_suggested_memory_relabels(memoryset: LabeledMemoryset):
28
28
  filters=[("metrics.neighbor_predicted_label_matches_current_label", "==", False)]
29
29
  )
30
30
  # Sort memories by confidence score (higher confidence first)
31
- suggested_relabels.sort(
32
- key=lambda x: (x.metrics and x.metrics.neighbor_predicted_label_confidence) or 0.0, reverse=True
33
- )
31
+ suggested_relabels.sort(key=lambda x: (x.metrics.get("neighbor_predicted_label_confidence", 0.0)), reverse=True)
34
32
 
35
33
  def update_approved(memory_id: str, selected: bool, current_memory_relabel_map: dict[str, RelabelStatus]):
36
34
  current_memory_relabel_map[memory_id]["approved"] = selected
@@ -72,9 +70,9 @@ def display_suggested_memory_relabels(memoryset: LabeledMemoryset):
72
70
  current_memory_relabel_map[mem_id]["new_label"] = new_label
73
71
  confidence = "--"
74
72
  current_metrics = current_memory_relabel_map[mem_id]["full_memory"].metrics
75
- if current_metrics and new_label == current_metrics.neighbor_predicted_label:
73
+ if current_metrics and new_label == current_metrics.get("neighbor_predicted_label"):
76
74
  confidence = (
77
- round(current_metrics.neighbor_predicted_label_confidence or 0.0, 2) if current_metrics else 0
75
+ round(current_metrics.get("neighbor_predicted_label_confidence", 0.0), 2) if current_metrics else 0
78
76
  )
79
77
  return (
80
78
  gr.HTML(
@@ -101,8 +99,8 @@ def display_suggested_memory_relabels(memoryset: LabeledMemoryset):
101
99
  memory_id=mem.memory_id,
102
100
  approved=False,
103
101
  new_label=(
104
- mem.metrics.neighbor_predicted_label
105
- if (mem.metrics and isinstance(mem.metrics.neighbor_predicted_label, int))
102
+ mem.metrics.get("neighbor_predicted_label")
103
+ if (mem.metrics and isinstance(mem.metrics.get("neighbor_predicted_label"), int))
106
104
  else None
107
105
  ),
108
106
  full_memory=mem,
@@ -150,7 +148,11 @@ def display_suggested_memory_relabels(memoryset: LabeledMemoryset):
150
148
  )
151
149
  for i, memory_relabel in enumerate(current_memory_relabel_map.values()):
152
150
  mem = memory_relabel["full_memory"]
153
- with gr.Row(equal_height=True, variant="panel", elem_classes="white" if i % 2 == 0 else None):
151
+ predicted_label = mem.metrics["neighbor_predicted_label"]
152
+ predicted_label_name = label_names[predicted_label]
153
+ predicted_label_confidence = mem.metrics.get("neighbor_predicted_label_confidence", 0)
154
+
155
+ with gr.Row(equal_height=True, variant="panel"):
154
156
  with gr.Column(scale=9):
155
157
  assert isinstance(mem.value, str)
156
158
  gr.Markdown(mem.value, label="Value", height=50)
@@ -160,12 +162,12 @@ def display_suggested_memory_relabels(memoryset: LabeledMemoryset):
160
162
  dropdown = gr.Dropdown(
161
163
  choices=[f"{label_name} ({i})" for i, label_name in enumerate(label_names)],
162
164
  label="SuggestedLabel",
163
- value=f"{label_names[mem.metrics.neighbor_predicted_label]} ({mem.metrics.neighbor_predicted_label})",
165
+ value=f"{predicted_label_name} ({predicted_label})",
164
166
  interactive=True,
165
167
  container=False,
166
168
  )
167
169
  confidence = gr.HTML(
168
- f"<p style='font-size: 10px; color: #888;'>Confidence: {mem.metrics.neighbor_predicted_label_confidence:.2f}</p>",
170
+ f"<p style='font-size: 10px; color: #888;'>Confidence: {predicted_label_confidence:.2f}</p>",
169
171
  elem_classes="no-padding",
170
172
  )
171
173
  dropdown.change(
@@ -1,6 +1,3 @@
1
- .white {
2
- background-color: white;
3
- }
4
1
  .centered input {
5
2
  margin: auto;
6
3
  }
orca_sdk/_utils/auth.py CHANGED
@@ -2,61 +2,59 @@
2
2
 
3
3
  import logging
4
4
  import os
5
- from typing import List
5
+ from typing import List, Literal
6
6
 
7
7
  from dotenv import load_dotenv
8
8
 
9
- from .._generated_api_client.api import (
10
- check_authentication,
11
- create_api_key,
12
- delete_api_key,
13
- delete_org,
14
- list_api_keys,
15
- )
16
- from .._generated_api_client.client import headers_context, set_base_url, set_headers
17
- from .._generated_api_client.models import CreateApiKeyRequest
18
- from .._generated_api_client.models.api_key_metadata import ApiKeyMetadata
9
+ from ..client import ApiKeyMetadata, orca_api
10
+ from ..credentials import OrcaCredentials
19
11
  from .common import DropMode
20
12
 
21
13
  load_dotenv() # this needs to be here to ensure env is populated before accessing it
14
+
15
+ # the defaults here must match nautilus and lighthouse config defaults
22
16
  _ORCA_ROOT_ACCESS_API_KEY = os.environ.get("ORCA_ROOT_ACCESS_API_KEY", "00000000-0000-0000-0000-000000000000")
17
+ _DEFAULT_ORG_ID = os.environ.get("DEFAULT_ORG_ID", "10e50000-0000-4000-a000-a78dca14af3a")
23
18
 
24
19
 
25
- def _create_api_key(org_id: str, name: str) -> str:
20
+ def _create_api_key(org_id: str, name: str, scopes: list[Literal["ADMINISTER", "PREDICT"]] = ["ADMINISTER"]) -> str:
26
21
  """Creates an API key for the given organization"""
27
- with headers_context({"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id}):
28
- res = create_api_key(body=CreateApiKeyRequest(name=name))
29
- return res.api_key
22
+ response = orca_api.POST(
23
+ "/auth/api_key",
24
+ json={"name": name, "scope": scopes},
25
+ headers={"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id},
26
+ )
27
+ return response["api_key"]
30
28
 
31
29
 
32
30
  def _list_api_keys(org_id: str) -> List[ApiKeyMetadata]:
33
31
  """Lists all API keys for the given organization"""
34
- with headers_context({"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id}):
35
- return list_api_keys()
32
+ return orca_api.GET("/auth/api_key", headers={"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id})
36
33
 
37
34
 
38
35
  def _delete_api_key(org_id: str, name: str, if_not_exists: DropMode = "error") -> None:
39
36
  """Deletes the API key with the given name from the organization"""
40
- with headers_context({"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id}):
41
- try:
42
- delete_api_key(name_or_id=name)
43
- except LookupError:
44
- if if_not_exists == "error":
45
- raise
37
+ try:
38
+ orca_api.DELETE(
39
+ "/auth/api_key/{name_or_id}",
40
+ params={"name_or_id": name},
41
+ headers={"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id},
42
+ )
43
+ except LookupError:
44
+ if if_not_exists == "error":
45
+ raise
46
46
 
47
47
 
48
48
  def _delete_org(org_id: str) -> None:
49
49
  """Deletes the organization"""
50
- with headers_context({"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id}):
51
- delete_org()
50
+ orca_api.DELETE("/auth/org", headers={"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id})
52
51
 
53
52
 
54
- def _authenticate_local_api(org_id: str = "10e50000-0000-4000-a000-a78dca14af3a", api_key_name: str = "local") -> None:
53
+ def _authenticate_local_api(org_id: str = _DEFAULT_ORG_ID, api_key_name: str = "local") -> None:
55
54
  """Connect to the local API at http://localhost:1584/ and authenticate with a new API key"""
56
- set_base_url("http://localhost:1584/")
57
55
  _delete_api_key(org_id, api_key_name, if_not_exists="ignore")
58
- set_headers({"Api-Key": _create_api_key(org_id, api_key_name)})
59
- check_authentication()
56
+ OrcaCredentials.set_api_url("http://localhost:1584")
57
+ OrcaCredentials.set_api_key(_create_api_key(org_id, api_key_name))
60
58
  logging.info(f"Authenticated against local API at 'http://localhost:1584' with '{api_key_name}' API key")
61
59
 
62
60
 
@@ -1,6 +1,7 @@
1
1
  import pickle
2
2
  from dataclasses import asdict, is_dataclass
3
3
  from os import PathLike
4
+ from tempfile import TemporaryDirectory
4
5
  from typing import Any, cast
5
6
 
6
7
  from datasets import Dataset
@@ -40,7 +41,24 @@ def parse_batch(batch: Any, column_names: list[str] | None = None) -> list[dict]
40
41
  return [{key: batch[key][idx] for key in keys} for idx in range(batch_size)]
41
42
 
42
43
 
43
- def hf_dataset_from_torch(torch_data: TorchDataLoader | TorchDataset, column_names: list[str] | None = None) -> Dataset:
44
+ def hf_dataset_from_torch(
45
+ torch_data: TorchDataLoader | TorchDataset, column_names: list[str] | None = None, ignore_cache=False
46
+ ) -> Dataset:
47
+ """
48
+ Create a HuggingFace Dataset from a PyTorch DataLoader or Dataset.
49
+
50
+ NOTE: It's important to ignore the cached files when testing (i.e., ignore_cache=True), because
51
+ cached results can mask changes you've made to tests. This can make a test appear to succeed
52
+ when it's actually broken or vice versa.
53
+
54
+ Params:
55
+ torch_data: A PyTorch DataLoader or Dataset object to create the HuggingFace Dataset from.
56
+ column_names: Optional list of column names to use for the dataset. If not provided,
57
+ the column names will be inferred from the data.
58
+ ignore_cache: If True, the dataset cache is written to a temporary directory and discarded instead of being persisted on disk.
59
+ Returns:
60
+ A HuggingFace Dataset object containing the data from the PyTorch DataLoader or Dataset.
61
+ """
44
62
  if isinstance(torch_data, TorchDataLoader):
45
63
  dataloader = torch_data
46
64
  else:
@@ -50,7 +68,15 @@ def hf_dataset_from_torch(torch_data: TorchDataLoader | TorchDataset, column_nam
50
68
  for batch in dataloader:
51
69
  yield from parse_batch(batch, column_names=column_names)
52
70
 
53
- return cast(Dataset, Dataset.from_generator(generator))
71
+ if ignore_cache:
72
+ with TemporaryDirectory() as temp_dir:
73
+ ds = Dataset.from_generator(generator, cache_dir=temp_dir)
74
+ else:
75
+ ds = Dataset.from_generator(generator)
76
+
77
+ if not isinstance(ds, Dataset):
78
+ raise ValueError(f"Failed to create dataset from generator: {type(ds)}")
79
+ return ds
54
80
 
55
81
 
56
82
  def hf_dataset_from_disk(file_path: str | PathLike) -> Dataset:
@@ -29,11 +29,11 @@ class PytorchDictDataset(TorchDataset):
29
29
  def test_hf_dataset_from_torch_dict():
30
30
  # Given a Pytorch dataset that returns a dictionary for each item
31
31
  dataset = PytorchDictDataset()
32
- hf_dataset = hf_dataset_from_torch(dataset)
32
+ hf_dataset = hf_dataset_from_torch(dataset, ignore_cache=True)
33
33
  # Then the HF dataset should be created successfully
34
34
  assert isinstance(hf_dataset, Dataset)
35
35
  assert len(hf_dataset) == len(dataset)
36
- assert set(hf_dataset.column_names) == {"text", "label", "key", "score", "source_id"}
36
+ assert set(hf_dataset.column_names) == {"value", "label", "key", "score", "source_id"}
37
37
 
38
38
 
39
39
  class PytorchTupleDataset(TorchDataset):
@@ -41,7 +41,7 @@ class PytorchTupleDataset(TorchDataset):
41
41
  self.data = SAMPLE_DATA
42
42
 
43
43
  def __getitem__(self, i):
44
- return self.data[i]["text"], self.data[i]["label"]
44
+ return self.data[i]["value"], self.data[i]["label"]
45
45
 
46
46
  def __len__(self):
47
47
  return len(self.data)
@@ -51,11 +51,11 @@ def test_hf_dataset_from_torch_tuple():
51
51
  # Given a Pytorch dataset that returns a tuple for each item
52
52
  dataset = PytorchTupleDataset()
53
53
  # And the correct number of column names passed in
54
- hf_dataset = hf_dataset_from_torch(dataset, column_names=["text", "label"])
54
+ hf_dataset = hf_dataset_from_torch(dataset, column_names=["value", "label"], ignore_cache=True)
55
55
  # Then the HF dataset should be created successfully
56
56
  assert isinstance(hf_dataset, Dataset)
57
57
  assert len(hf_dataset) == len(dataset)
58
- assert hf_dataset.column_names == ["text", "label"]
58
+ assert hf_dataset.column_names == ["value", "label"]
59
59
 
60
60
 
61
61
  def test_hf_dataset_from_torch_tuple_error():
@@ -63,7 +63,7 @@ def test_hf_dataset_from_torch_tuple_error():
63
63
  dataset = PytorchTupleDataset()
64
64
  # Then the HF dataset should raise an error if no column names are passed in
65
65
  with pytest.raises(DatasetGenerationError):
66
- hf_dataset_from_torch(dataset)
66
+ hf_dataset_from_torch(dataset, ignore_cache=True)
67
67
 
68
68
 
69
69
  def test_hf_dataset_from_torch_tuple_error_not_enough_columns():
@@ -71,7 +71,7 @@ def test_hf_dataset_from_torch_tuple_error_not_enough_columns():
71
71
  dataset = PytorchTupleDataset()
72
72
  # Then the HF dataset should raise an error if not enough column names are passed in
73
73
  with pytest.raises(DatasetGenerationError):
74
- hf_dataset_from_torch(dataset, column_names=["value"])
74
+ hf_dataset_from_torch(dataset, column_names=["value"], ignore_cache=True)
75
75
 
76
76
 
77
77
  DatasetTuple = namedtuple("DatasetTuple", ["value", "label"])
@@ -82,7 +82,7 @@ class PytorchNamedTupleDataset(TorchDataset):
82
82
  self.data = SAMPLE_DATA
83
83
 
84
84
  def __getitem__(self, i):
85
- return DatasetTuple(self.data[i]["text"], self.data[i]["label"])
85
+ return DatasetTuple(self.data[i]["value"], self.data[i]["label"])
86
86
 
87
87
  def __len__(self):
88
88
  return len(self.data)
@@ -92,7 +92,7 @@ def test_hf_dataset_from_torch_named_tuple():
92
92
  # Given a Pytorch dataset that returns a namedtuple for each item
93
93
  dataset = PytorchNamedTupleDataset()
94
94
  # And no column names are passed in
95
- hf_dataset = hf_dataset_from_torch(dataset)
95
+ hf_dataset = hf_dataset_from_torch(dataset, ignore_cache=True)
96
96
  # Then the HF dataset should be created successfully
97
97
  assert isinstance(hf_dataset, Dataset)
98
98
  assert len(hf_dataset) == len(dataset)
@@ -110,7 +110,7 @@ class PytorchDataclassDataset(TorchDataset):
110
110
  self.data = SAMPLE_DATA
111
111
 
112
112
  def __getitem__(self, i):
113
- return DatasetItem(text=self.data[i]["text"], label=self.data[i]["label"])
113
+ return DatasetItem(text=self.data[i]["value"], label=self.data[i]["label"])
114
114
 
115
115
  def __len__(self):
116
116
  return len(self.data)
@@ -119,7 +119,7 @@ class PytorchDataclassDataset(TorchDataset):
119
119
  def test_hf_dataset_from_torch_dataclass():
120
120
  # Given a Pytorch dataset that returns a dataclass for each item
121
121
  dataset = PytorchDataclassDataset()
122
- hf_dataset = hf_dataset_from_torch(dataset)
122
+ hf_dataset = hf_dataset_from_torch(dataset, ignore_cache=True)
123
123
  # Then the HF dataset should be created successfully
124
124
  assert isinstance(hf_dataset, Dataset)
125
125
  assert len(hf_dataset) == len(dataset)
@@ -131,7 +131,7 @@ class PytorchInvalidDataset(TorchDataset):
131
131
  self.data = SAMPLE_DATA
132
132
 
133
133
  def __getitem__(self, i):
134
- return [self.data[i]["text"], self.data[i]["label"]]
134
+ return [self.data[i]["value"], self.data[i]["label"]]
135
135
 
136
136
  def __len__(self):
137
137
  return len(self.data)
@@ -142,7 +142,7 @@ def test_hf_dataset_from_torch_invalid_dataset():
142
142
  dataset = PytorchInvalidDataset()
143
143
  # Then the HF dataset should raise an error
144
144
  with pytest.raises(DatasetGenerationError):
145
- hf_dataset_from_torch(dataset)
145
+ hf_dataset_from_torch(dataset, ignore_cache=True)
146
146
 
147
147
 
148
148
  def test_hf_dataset_from_torchdataloader():
@@ -150,10 +150,10 @@ def test_hf_dataset_from_torchdataloader():
150
150
  dataset = PytorchDictDataset()
151
151
 
152
152
  def collate_fn(x: list[dict]):
153
- return {"value": [item["text"] for item in x], "label": [item["label"] for item in x]}
153
+ return {"value": [item["value"] for item in x], "label": [item["label"] for item in x]}
154
154
 
155
155
  dataloader = TorchDataLoader(dataset, batch_size=3, collate_fn=collate_fn)
156
- hf_dataset = hf_dataset_from_torch(dataloader)
156
+ hf_dataset = hf_dataset_from_torch(dataloader, ignore_cache=True)
157
157
  # Then the HF dataset should be created successfully
158
158
  assert isinstance(hf_dataset, Dataset)
159
159
  assert len(hf_dataset) == len(dataset)