orca-sdk 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
- orca_sdk/__init__.py +10 -4
- orca_sdk/_shared/__init__.py +10 -0
- orca_sdk/_shared/metrics.py +393 -0
- orca_sdk/_shared/metrics_test.py +273 -0
- orca_sdk/_utils/analysis_ui.py +12 -10
- orca_sdk/_utils/analysis_ui_style.css +0 -3
- orca_sdk/_utils/auth.py +27 -29
- orca_sdk/_utils/data_parsing.py +28 -2
- orca_sdk/_utils/data_parsing_test.py +15 -15
- orca_sdk/_utils/pagination.py +126 -0
- orca_sdk/_utils/pagination_test.py +132 -0
- orca_sdk/_utils/prediction_result_ui.py +67 -21
- orca_sdk/_utils/tqdm_file_reader.py +12 -0
- orca_sdk/_utils/value_parser.py +45 -0
- orca_sdk/_utils/value_parser_test.py +39 -0
- orca_sdk/classification_model.py +439 -129
- orca_sdk/classification_model_test.py +334 -104
- orca_sdk/client.py +3747 -0
- orca_sdk/conftest.py +164 -19
- orca_sdk/credentials.py +120 -18
- orca_sdk/credentials_test.py +20 -0
- orca_sdk/datasource.py +259 -68
- orca_sdk/datasource_test.py +242 -0
- orca_sdk/embedding_model.py +425 -82
- orca_sdk/embedding_model_test.py +39 -13
- orca_sdk/job.py +337 -0
- orca_sdk/job_test.py +108 -0
- orca_sdk/memoryset.py +1341 -305
- orca_sdk/memoryset_test.py +350 -111
- orca_sdk/regression_model.py +684 -0
- orca_sdk/regression_model_test.py +369 -0
- orca_sdk/telemetry.py +449 -143
- orca_sdk/telemetry_test.py +43 -24
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/METADATA +34 -16
- orca_sdk-0.1.2.dist-info/RECORD +40 -0
- {orca_sdk-0.1.1.dist-info → orca_sdk-0.1.2.dist-info}/WHEEL +1 -1
- orca_sdk/_generated_api_client/__init__.py +0 -3
- orca_sdk/_generated_api_client/api/__init__.py +0 -193
- orca_sdk/_generated_api_client/api/auth/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/auth/check_authentication_auth_get.py +0 -128
- orca_sdk/_generated_api_client/api/auth/create_api_key_auth_api_key_post.py +0 -170
- orca_sdk/_generated_api_client/api/auth/delete_api_key_auth_api_key_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/auth/delete_org_auth_org_delete.py +0 -130
- orca_sdk/_generated_api_client/api/auth/list_api_keys_auth_api_key_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/classification_model/create_evaluation_classification_model_model_name_or_id_evaluation_post.py +0 -183
- orca_sdk/_generated_api_client/api/classification_model/create_model_classification_model_post.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py +0 -168
- orca_sdk/_generated_api_client/api/classification_model/delete_model_classification_model_name_or_id_delete.py +0 -154
- orca_sdk/_generated_api_client/api/classification_model/get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py +0 -170
- orca_sdk/_generated_api_client/api/classification_model/get_model_classification_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/classification_model/list_evaluations_classification_model_model_name_or_id_evaluation_get.py +0 -161
- orca_sdk/_generated_api_client/api/classification_model/list_models_classification_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/classification_model/predict_gpu_classification_model_name_or_id_prediction_post.py +0 -190
- orca_sdk/_generated_api_client/api/datasource/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/datasource/create_datasource_datasource_post.py +0 -167
- orca_sdk/_generated_api_client/api/datasource/delete_datasource_datasource_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/get_datasource_datasource_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/datasource/list_datasources_datasource_get.py +0 -127
- orca_sdk/_generated_api_client/api/default/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/default/healthcheck_get.py +0 -118
- orca_sdk/_generated_api_client/api/default/healthcheck_gpu_get.py +0 -118
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/create_finetuned_embedding_model_finetuned_embedding_model_post.py +0 -168
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/delete_finetuned_embedding_model_finetuned_embedding_model_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/embed_with_finetuned_model_gpu_finetuned_embedding_model_name_or_id_embedding_post.py +0 -189
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/get_finetuned_embedding_model_finetuned_embedding_model_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/finetuned_embedding_model/list_finetuned_embedding_models_finetuned_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/create_analysis_memoryset_name_or_id_analysis_post.py +0 -183
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +0 -168
- orca_sdk/_generated_api_client/api/memoryset/delete_memories_memoryset_name_or_id_memories_delete_post.py +0 -181
- orca_sdk/_generated_api_client/api/memoryset/delete_memory_memoryset_name_or_id_memory_memory_id_delete.py +0 -167
- orca_sdk/_generated_api_client/api/memoryset/delete_memoryset_memoryset_name_or_id_delete.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/get_analysis_memoryset_name_or_id_analysis_analysis_task_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +0 -169
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +0 -184
- orca_sdk/_generated_api_client/api/memoryset/list_analyses_memoryset_name_or_id_analysis_get.py +0 -260
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +0 -127
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +0 -193
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +0 -188
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +0 -191
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +0 -187
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/embed_with_pretrained_model_gpu_pretrained_embedding_model_model_name_embedding_post.py +0 -188
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/get_pretrained_embedding_model_pretrained_embedding_model_model_name_get.py +0 -157
- orca_sdk/_generated_api_client/api/pretrained_embedding_model/list_pretrained_embedding_models_pretrained_embedding_model_get.py +0 -127
- orca_sdk/_generated_api_client/api/task/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/task/abort_task_task_task_id_abort_delete.py +0 -154
- orca_sdk/_generated_api_client/api/task/get_task_status_task_task_id_status_get.py +0 -156
- orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +0 -243
- orca_sdk/_generated_api_client/api/telemetry/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/telemetry/drop_feedback_category_with_data_telemetry_feedback_category_name_or_id_delete.py +0 -162
- orca_sdk/_generated_api_client/api/telemetry/get_feedback_category_telemetry_feedback_category_name_or_id_get.py +0 -156
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +0 -157
- orca_sdk/_generated_api_client/api/telemetry/list_feedback_categories_telemetry_feedback_category_get.py +0 -127
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +0 -175
- orca_sdk/_generated_api_client/api/telemetry/record_prediction_feedback_telemetry_prediction_feedback_put.py +0 -171
- orca_sdk/_generated_api_client/api/telemetry/update_prediction_telemetry_prediction_prediction_id_patch.py +0 -181
- orca_sdk/_generated_api_client/client.py +0 -216
- orca_sdk/_generated_api_client/errors.py +0 -38
- orca_sdk/_generated_api_client/models/__init__.py +0 -159
- orca_sdk/_generated_api_client/models/analyze_neighbor_labels_result.py +0 -84
- orca_sdk/_generated_api_client/models/api_key_metadata.py +0 -118
- orca_sdk/_generated_api_client/models/base_model.py +0 -55
- orca_sdk/_generated_api_client/models/body_create_datasource_datasource_post.py +0 -176
- orca_sdk/_generated_api_client/models/classification_evaluation_result.py +0 -114
- orca_sdk/_generated_api_client/models/clone_labeled_memoryset_request.py +0 -150
- orca_sdk/_generated_api_client/models/column_info.py +0 -114
- orca_sdk/_generated_api_client/models/column_type.py +0 -14
- orca_sdk/_generated_api_client/models/conflict_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/create_api_key_request.py +0 -99
- orca_sdk/_generated_api_client/models/create_api_key_response.py +0 -126
- orca_sdk/_generated_api_client/models/create_labeled_memoryset_request.py +0 -259
- orca_sdk/_generated_api_client/models/create_rac_model_request.py +0 -209
- orca_sdk/_generated_api_client/models/datasource_metadata.py +0 -142
- orca_sdk/_generated_api_client/models/delete_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/embed_request.py +0 -127
- orca_sdk/_generated_api_client/models/embedding_finetuning_method.py +0 -9
- orca_sdk/_generated_api_client/models/evaluation_request.py +0 -180
- orca_sdk/_generated_api_client/models/evaluation_response.py +0 -140
- orca_sdk/_generated_api_client/models/feedback_type.py +0 -9
- orca_sdk/_generated_api_client/models/field_validation_error.py +0 -103
- orca_sdk/_generated_api_client/models/filter_item.py +0 -231
- orca_sdk/_generated_api_client/models/filter_item_field_type_0_item.py +0 -15
- orca_sdk/_generated_api_client/models/filter_item_field_type_2_item_type_1.py +0 -16
- orca_sdk/_generated_api_client/models/filter_item_op.py +0 -16
- orca_sdk/_generated_api_client/models/find_duplicates_analysis_result.py +0 -70
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request.py +0 -259
- orca_sdk/_generated_api_client/models/finetune_embedding_model_request_training_args.py +0 -66
- orca_sdk/_generated_api_client/models/finetuned_embedding_model_metadata.py +0 -166
- orca_sdk/_generated_api_client/models/get_memories_request.py +0 -70
- orca_sdk/_generated_api_client/models/internal_server_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/label_class_metrics.py +0 -108
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup.py +0 -274
- orca_sdk/_generated_api_client/models/label_prediction_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/label_prediction_result.py +0 -101
- orca_sdk/_generated_api_client/models/label_prediction_with_memories_and_feedback.py +0 -232
- orca_sdk/_generated_api_client/models/labeled_memory.py +0 -197
- orca_sdk/_generated_api_client/models/labeled_memory_insert.py +0 -108
- orca_sdk/_generated_api_client/models/labeled_memory_insert_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_lookup.py +0 -258
- orca_sdk/_generated_api_client/models/labeled_memory_lookup_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metadata.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memory_metrics.py +0 -277
- orca_sdk/_generated_api_client/models/labeled_memory_update.py +0 -171
- orca_sdk/_generated_api_client/models/labeled_memory_update_metadata_type_0.py +0 -68
- orca_sdk/_generated_api_client/models/labeled_memoryset_metadata.py +0 -195
- orca_sdk/_generated_api_client/models/list_analyses_memoryset_name_or_id_analysis_get_type_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_memories_request.py +0 -104
- orca_sdk/_generated_api_client/models/list_predictions_request.py +0 -234
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_0.py +0 -9
- orca_sdk/_generated_api_client/models/list_predictions_request_sort_item_item_type_1.py +0 -9
- orca_sdk/_generated_api_client/models/lookup_request.py +0 -81
- orca_sdk/_generated_api_client/models/memoryset_analysis_request.py +0 -83
- orca_sdk/_generated_api_client/models/memoryset_analysis_request_type.py +0 -9
- orca_sdk/_generated_api_client/models/memoryset_analysis_response.py +0 -180
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_config.py +0 -66
- orca_sdk/_generated_api_client/models/memoryset_analysis_response_type.py +0 -9
- orca_sdk/_generated_api_client/models/not_found_error_response.py +0 -100
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +0 -20
- orca_sdk/_generated_api_client/models/prediction_feedback.py +0 -157
- orca_sdk/_generated_api_client/models/prediction_feedback_category.py +0 -115
- orca_sdk/_generated_api_client/models/prediction_feedback_request.py +0 -122
- orca_sdk/_generated_api_client/models/prediction_feedback_result.py +0 -102
- orca_sdk/_generated_api_client/models/prediction_request.py +0 -169
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +0 -97
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +0 -11
- orca_sdk/_generated_api_client/models/rac_head_type.py +0 -11
- orca_sdk/_generated_api_client/models/rac_model_metadata.py +0 -191
- orca_sdk/_generated_api_client/models/service_unavailable_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/task.py +0 -198
- orca_sdk/_generated_api_client/models/task_status.py +0 -14
- orca_sdk/_generated_api_client/models/task_status_info.py +0 -133
- orca_sdk/_generated_api_client/models/unauthenticated_error_response.py +0 -72
- orca_sdk/_generated_api_client/models/unauthorized_error_response.py +0 -80
- orca_sdk/_generated_api_client/models/unprocessable_input_error_response.py +0 -94
- orca_sdk/_generated_api_client/models/update_prediction_request.py +0 -93
- orca_sdk/_generated_api_client/py.typed +0 -1
- orca_sdk/_generated_api_client/types.py +0 -56
- orca_sdk/_utils/task.py +0 -73
- orca_sdk-0.1.1.dist-info/RECORD +0 -175
orca_sdk/_shared/metrics_test.py
ADDED

@@ -0,0 +1,273 @@
+"""
+IMPORTANT:
+- This is a shared file between OrcaLib and the OrcaSDK.
+- Please ensure that it does not have any dependencies on the OrcaLib code.
+- Make sure to edit this file in orcalib/shared and NOT in orca_sdk, since it will be overwritten there.
+"""
+
+from typing import Literal
+
+import numpy as np
+import pytest
+import sklearn.metrics
+
+from .metrics import (
+    calculate_classification_metrics,
+    calculate_pr_curve,
+    calculate_regression_metrics,
+    calculate_roc_curve,
+    softmax,
+)
+
+
+def test_binary_metrics():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.3, 0.2])
+
+    metrics = calculate_classification_metrics(y_true, y_score)
+
+    assert metrics.accuracy == 0.8
+    assert metrics.f1_score == 0.8
+    assert metrics.roc_auc is not None
+    assert metrics.roc_auc > 0.8
+    assert metrics.roc_auc < 1.0
+    assert metrics.pr_auc is not None
+    assert metrics.pr_auc > 0.8
+    assert metrics.pr_auc < 1.0
+    assert metrics.loss is not None
+    assert metrics.loss > 0.0
+
+
+def test_multiclass_metrics_with_2_classes():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([[0.9, 0.1], [0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
+
+    metrics = calculate_classification_metrics(y_true, y_score)
+
+    assert metrics.accuracy == 0.8
+    assert metrics.f1_score == 0.8
+    assert metrics.roc_auc is not None
+    assert metrics.roc_auc > 0.8
+    assert metrics.roc_auc < 1.0
+    assert metrics.pr_auc is not None
+    assert metrics.pr_auc > 0.8
+    assert metrics.pr_auc < 1.0
+    assert metrics.loss is not None
+    assert metrics.loss > 0.0
+
+
+@pytest.mark.parametrize(
+    "average, multiclass",
+    [("micro", "ovr"), ("macro", "ovr"), ("weighted", "ovr"), ("micro", "ovo"), ("macro", "ovo"), ("weighted", "ovo")],
+)
+def test_multiclass_metrics_with_3_classes(
+    average: Literal["micro", "macro", "weighted"], multiclass: Literal["ovr", "ovo"]
+):
+    y_true = np.array([0, 1, 1, 0, 2])
+    y_score = np.array([[0.9, 0.1, 0.0], [0.1, 0.9, 0.0], [0.2, 0.8, 0.0], [0.7, 0.3, 0.0], [0.0, 0.0, 1.0]])
+
+    metrics = calculate_classification_metrics(y_true, y_score, average=average, multi_class=multiclass)
+
+    assert metrics.accuracy == 1.0
+    assert metrics.f1_score == 1.0
+    assert metrics.roc_auc is not None
+    assert metrics.roc_auc > 0.8
+    assert metrics.pr_auc is None
+    assert metrics.loss is not None
+    assert metrics.loss > 0.0
+
+
+def test_does_not_modify_logits_unless_necessary():
+    logits = np.array([[0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
+    expected_labels = [0, 1, 0, 1]
+    assert calculate_classification_metrics(expected_labels, logits).loss == sklearn.metrics.log_loss(
+        expected_labels, logits
+    )
+
+
+def test_normalizes_logits_if_necessary():
+    logits = np.array([[1.2, 3.9], [1.2, 5.8], [1.2, 2.7], [1.2, 1.3]])
+    expected_labels = [0, 1, 0, 1]
+    assert calculate_classification_metrics(expected_labels, logits).loss == sklearn.metrics.log_loss(
+        expected_labels, logits / logits.sum(axis=1, keepdims=True)
+    )
+
+
+def test_softmaxes_logits_if_necessary():
+    logits = np.array([[-1.2, 3.9], [1.2, -5.8], [1.2, 2.7], [1.2, 1.3]])
+    expected_labels = [0, 1, 0, 1]
+    assert calculate_classification_metrics(expected_labels, logits).loss == sklearn.metrics.log_loss(
+        expected_labels, softmax(logits)
+    )
+
+
+def test_handles_nan_logits():
+    logits = np.array([[np.nan, np.nan], [np.nan, np.nan], [0.1, 0.9], [0.2, 0.8]])
+    expected_labels = [0, 1, 0, 1]
+    metrics = calculate_classification_metrics(expected_labels, logits)
+    assert metrics.loss is None
+    assert metrics.accuracy == 0.25
+    assert metrics.f1_score == 0.25
+    assert metrics.roc_auc is None
+    assert metrics.pr_auc is None
+    assert metrics.pr_curve is None
+    assert metrics.roc_curve is None
+    assert metrics.coverage == 0.5
+
+
+def test_precision_recall_curve():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+    pr_curve = calculate_pr_curve(y_true, y_score)
+
+    assert len(pr_curve["precisions"]) == len(pr_curve["recalls"]) == len(pr_curve["thresholds"]) == 6
+    assert np.allclose(pr_curve["precisions"][0], 0.6)
+    assert np.allclose(pr_curve["recalls"][0], 1.0)
+    assert np.allclose(pr_curve["precisions"][-1], 1.0)
+    assert np.allclose(pr_curve["recalls"][-1], 0.0)
+
+    # test that thresholds are sorted
+    assert np.all(np.diff(pr_curve["thresholds"]) >= 0)
+
+
+def test_roc_curve():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+    roc_curve = calculate_roc_curve(y_true, y_score)
+
+    assert (
+        len(roc_curve["false_positive_rates"])
+        == len(roc_curve["true_positive_rates"])
+        == len(roc_curve["thresholds"])
+        == 6
+    )
+    assert roc_curve["false_positive_rates"][0] == 1.0
+    assert roc_curve["true_positive_rates"][0] == 1.0
+    assert roc_curve["false_positive_rates"][-1] == 0.0
+    assert roc_curve["true_positive_rates"][-1] == 0.0
+
+    # test that thresholds are sorted
+    assert np.all(np.diff(roc_curve["thresholds"]) >= 0)
+
+
+def test_log_loss_handles_missing_classes_in_y_true():
+    # y_true contains only a subset of classes, but predictions include an extra class column
+    y_true = np.array([0, 1, 0, 1])
+    y_score = np.array(
+        [
+            [0.7, 0.2, 0.1],
+            [0.1, 0.8, 0.1],
+            [0.6, 0.3, 0.1],
+            [0.2, 0.7, 0.1],
+        ]
+    )
+
+    metrics = calculate_classification_metrics(y_true, y_score)
+    expected_loss = sklearn.metrics.log_loss(y_true, y_score, labels=[0, 1, 2])
+
+    assert metrics.loss is not None
+    assert np.allclose(metrics.loss, expected_loss)
+
+
+def test_precision_recall_curve_max_length():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+    pr_curve = calculate_pr_curve(y_true, y_score, max_length=5)
+    assert len(pr_curve["precisions"]) == len(pr_curve["recalls"]) == len(pr_curve["thresholds"]) == 5
+
+    assert np.allclose(pr_curve["precisions"][0], 0.6)
+    assert np.allclose(pr_curve["recalls"][0], 1.0)
+    assert np.allclose(pr_curve["precisions"][-1], 1.0)
+    assert np.allclose(pr_curve["recalls"][-1], 0.0)
+
+    # test that thresholds are sorted
+    assert np.all(np.diff(pr_curve["thresholds"]) >= 0)
+
+
+def test_roc_curve_max_length():
+    y_true = np.array([0, 1, 1, 0, 1])
+    y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
+
+    roc_curve = calculate_roc_curve(y_true, y_score, max_length=5)
+    assert (
+        len(roc_curve["false_positive_rates"])
+        == len(roc_curve["true_positive_rates"])
+        == len(roc_curve["thresholds"])
+        == 5
+    )
+    assert np.allclose(roc_curve["false_positive_rates"][0], 1.0)
+    assert np.allclose(roc_curve["true_positive_rates"][0], 1.0)
+    assert np.allclose(roc_curve["false_positive_rates"][-1], 0.0)
+    assert np.allclose(roc_curve["true_positive_rates"][-1], 0.0)
+
+    # test that thresholds are sorted
+    assert np.all(np.diff(roc_curve["thresholds"]) >= 0)
+
+
+# Regression Metrics Tests
+def test_perfect_regression_predictions():
+    y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
+    y_pred = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
+
+    metrics = calculate_regression_metrics(y_true, y_pred)
+
+    assert metrics.mse == 0.0
+    assert metrics.rmse == 0.0
+    assert metrics.mae == 0.0
+    assert metrics.r2 == 1.0
+    assert metrics.explained_variance == 1.0
+    assert metrics.loss == 0.0
+    assert metrics.anomaly_score_mean is None
+    assert metrics.anomaly_score_median is None
+    assert metrics.anomaly_score_variance is None
+
+
+def test_basic_regression_metrics():
+    y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
+    y_pred = np.array([1.1, 1.9, 3.2, 3.8, 5.1], dtype=np.float32)
+
+    metrics = calculate_regression_metrics(y_true, y_pred)
+
+    # Check that all metrics are reasonable
+    assert metrics.mse > 0.0
+    assert metrics.rmse == pytest.approx(np.sqrt(metrics.mse))
+    assert metrics.mae > 0.0
+    assert 0.0 <= metrics.r2 <= 1.0
+    assert 0.0 <= metrics.explained_variance <= 1.0
+    assert metrics.loss == metrics.mse
+
+    # Check specific values based on the data
+    expected_mse = np.mean((y_true - y_pred) ** 2)
+    assert metrics.mse == pytest.approx(expected_mse)
+
+    expected_mae = np.mean(np.abs(y_true - y_pred))
+    assert metrics.mae == pytest.approx(expected_mae)
+
+
+def test_regression_metrics_with_anomaly_scores():
+    y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
+    y_pred = np.array([1.1, 1.9, 3.2, 3.8, 5.1], dtype=np.float32)
+    anomaly_scores = [0.1, 0.2, 0.15, 0.3, 0.25]
+
+    metrics = calculate_regression_metrics(y_true, y_pred, anomaly_scores)
+
+    assert metrics.anomaly_score_mean == pytest.approx(np.mean(anomaly_scores))
+    assert metrics.anomaly_score_median == pytest.approx(np.median(anomaly_scores))
+    assert metrics.anomaly_score_variance == pytest.approx(np.var(anomaly_scores))
+
+
+def test_regression_metrics_handles_nans():
+    y_true = np.array([1.0, 2.0, 3.0], dtype=np.float32)
+    y_pred = np.array([1.1, 1.9, np.nan], dtype=np.float32)
+
+    metrics = calculate_regression_metrics(y_true, y_pred)
+
+    assert np.allclose(metrics.coverage, 0.6666666666666666)
+    assert metrics.mse > 0.0
+    assert metrics.rmse > 0.0
+    assert metrics.mae > 0.0
+    assert 0.0 <= metrics.r2 <= 1.0
+    assert 0.0 <= metrics.explained_variance <= 1.0
orca_sdk/_utils/analysis_ui.py
CHANGED
@@ -28,9 +28,7 @@ def display_suggested_memory_relabels(memoryset: LabeledMemoryset):
         filters=[("metrics.neighbor_predicted_label_matches_current_label", "==", False)]
     )
     # Sort memories by confidence score (higher confidence first)
-    suggested_relabels.sort(
-        key=lambda x: (x.metrics and x.metrics.neighbor_predicted_label_confidence) or 0.0, reverse=True
-    )
+    suggested_relabels.sort(key=lambda x: (x.metrics.get("neighbor_predicted_label_confidence", 0.0)), reverse=True)
 
     def update_approved(memory_id: str, selected: bool, current_memory_relabel_map: dict[str, RelabelStatus]):
         current_memory_relabel_map[memory_id]["approved"] = selected
@@ -72,9 +70,9 @@ def display_suggested_memory_relabels(memoryset: LabeledMemoryset):
             current_memory_relabel_map[mem_id]["new_label"] = new_label
         confidence = "--"
         current_metrics = current_memory_relabel_map[mem_id]["full_memory"].metrics
-        if current_metrics and new_label == current_metrics.neighbor_predicted_label:
+        if current_metrics and new_label == current_metrics.get("neighbor_predicted_label"):
            confidence = (
-                round(current_metrics.neighbor_predicted_label_confidence
+                round(current_metrics.get("neighbor_predicted_label_confidence", 0.0), 2) if current_metrics else 0
            )
        return (
            gr.HTML(
@@ -101,8 +99,8 @@ def display_suggested_memory_relabels(memoryset: LabeledMemoryset):
                 memory_id=mem.memory_id,
                 approved=False,
                 new_label=(
-                    mem.metrics.neighbor_predicted_label
-                    if (mem.metrics and isinstance(mem.metrics.neighbor_predicted_label, int))
+                    mem.metrics.get("neighbor_predicted_label")
+                    if (mem.metrics and isinstance(mem.metrics.get("neighbor_predicted_label"), int))
                     else None
                 ),
                 full_memory=mem,
@@ -150,7 +148,11 @@ def display_suggested_memory_relabels(memoryset: LabeledMemoryset):
         )
         for i, memory_relabel in enumerate(current_memory_relabel_map.values()):
             mem = memory_relabel["full_memory"]
-
+            predicted_label = mem.metrics["neighbor_predicted_label"]
+            predicted_label_name = label_names[predicted_label]
+            predicted_label_confidence = mem.metrics.get("neighbor_predicted_label_confidence", 0)
+
+            with gr.Row(equal_height=True, variant="panel"):
                 with gr.Column(scale=9):
                     assert isinstance(mem.value, str)
                     gr.Markdown(mem.value, label="Value", height=50)
@@ -160,12 +162,12 @@ def display_suggested_memory_relabels(memoryset: LabeledMemoryset):
                    dropdown = gr.Dropdown(
                        choices=[f"{label_name} ({i})" for i, label_name in enumerate(label_names)],
                        label="SuggestedLabel",
-                        value=f"{
+                        value=f"{predicted_label_name} ({predicted_label})",
                        interactive=True,
                        container=False,
                    )
                    confidence = gr.HTML(
-                        f"<p style='font-size: 10px; color: #888;'>Confidence: {
+                        f"<p style='font-size: 10px; color: #888;'>Confidence: {predicted_label_confidence:.2f}</p>",
                        elem_classes="no-padding",
                    )
                    dropdown.change(
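
Every change in this file follows the same pattern: a memory's metrics field is now a plain dict, so values are read with dict lookups and .get() defaults instead of attribute access. A minimal sketch of the new style, where mem stands in for any memory object returned by the memoryset:

    # metrics is a plain dict now; .get() tolerates keys the analysis has not produced yet
    predicted = mem.metrics.get("neighbor_predicted_label")  # None when absent
    confidence = mem.metrics.get("neighbor_predicted_label_confidence", 0.0)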
orca_sdk/_utils/auth.py
CHANGED
@@ -2,61 +2,59 @@
 
 import logging
 import os
-from typing import List
+from typing import List, Literal
 
 from dotenv import load_dotenv
 
-from ..
-
-    create_api_key,
-    delete_api_key,
-    delete_org,
-    list_api_keys,
-)
-from .._generated_api_client.client import headers_context, set_base_url, set_headers
-from .._generated_api_client.models import CreateApiKeyRequest
-from .._generated_api_client.models.api_key_metadata import ApiKeyMetadata
+from ..client import ApiKeyMetadata, orca_api
+from ..credentials import OrcaCredentials
 from .common import DropMode
 
 load_dotenv()  # this needs to be here to ensure env is populated before accessing it
+
+# the defaults here must match nautilus and lighthouse config defaults
 _ORCA_ROOT_ACCESS_API_KEY = os.environ.get("ORCA_ROOT_ACCESS_API_KEY", "00000000-0000-0000-0000-000000000000")
+_DEFAULT_ORG_ID = os.environ.get("DEFAULT_ORG_ID", "10e50000-0000-4000-a000-a78dca14af3a")
 
 
-def _create_api_key(org_id: str, name: str) -> str:
+def _create_api_key(org_id: str, name: str, scopes: list[Literal["ADMINISTER", "PREDICT"]] = ["ADMINISTER"]) -> str:
     """Creates an API key for the given organization"""
-
-
-
+    response = orca_api.POST(
+        "/auth/api_key",
+        json={"name": name, "scope": scopes},
+        headers={"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id},
+    )
+    return response["api_key"]
 
 
 def _list_api_keys(org_id: str) -> List[ApiKeyMetadata]:
     """Lists all API keys for the given organization"""
-
-    return list_api_keys()
+    return orca_api.GET("/auth/api_key", headers={"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id})
 
 
 def _delete_api_key(org_id: str, name: str, if_not_exists: DropMode = "error") -> None:
     """Deletes the API key with the given name from the organization"""
-
-
-
-
-
-
+    try:
+        orca_api.DELETE(
+            "/auth/api_key/{name_or_id}",
+            params={"name_or_id": name},
+            headers={"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id},
+        )
+    except LookupError:
+        if if_not_exists == "error":
+            raise
 
 
 def _delete_org(org_id: str) -> None:
     """Deletes the organization"""
-
-    delete_org()
+    orca_api.DELETE("/auth/org", headers={"Api-Key": _ORCA_ROOT_ACCESS_API_KEY, "Org-Id": org_id})
 
 
-def _authenticate_local_api(org_id: str =
+def _authenticate_local_api(org_id: str = _DEFAULT_ORG_ID, api_key_name: str = "local") -> None:
     """Connect to the local API at http://localhost:1584/ and authenticate with a new API key"""
-    set_base_url("http://localhost:1584/")
     _delete_api_key(org_id, api_key_name, if_not_exists="ignore")
-
-
+    OrcaCredentials.set_api_url("http://localhost:1584")
+    OrcaCredentials.set_api_key(_create_api_key(org_id, api_key_name))
     logging.info(f"Authenticated against local API at 'http://localhost:1584' with '{api_key_name}' API key")
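
These helpers now go through the consolidated orca_api client (added as orca_sdk/client.py in this release) using HTTP-verb methods in place of the removed generated-client functions. A minimal sketch of the calling convention as used above, with hypothetical header values:

    # assumed absolute path for the relative `from ..client import orca_api` above
    from orca_sdk.client import orca_api

    headers = {"Api-Key": "<root-access-key>", "Org-Id": "<org-id>"}  # hypothetical values

    # the JSON body goes in `json`; the parsed response body is returned directly
    response = orca_api.POST("/auth/api_key", json={"name": "ci", "scope": ["ADMINISTER"]}, headers=headers)
    api_key = response["api_key"]

    # path templates are filled from `params`; _delete_api_key above expects a
    # missing key to surface as a LookupError
    orca_api.DELETE("/auth/api_key/{name_or_id}", params={"name_or_id": "ci"}, headers=headers)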
orca_sdk/_utils/data_parsing.py
CHANGED
@@ -1,6 +1,7 @@
 import pickle
 from dataclasses import asdict, is_dataclass
 from os import PathLike
+from tempfile import TemporaryDirectory
 from typing import Any, cast
 
 from datasets import Dataset
@@ -40,7 +41,24 @@ def parse_batch(batch: Any, column_names: list[str] | None = None) -> list[dict]
     return [{key: batch[key][idx] for key in keys} for idx in range(batch_size)]
 
 
-def hf_dataset_from_torch(
+def hf_dataset_from_torch(
+    torch_data: TorchDataLoader | TorchDataset, column_names: list[str] | None = None, ignore_cache=False
+) -> Dataset:
+    """
+    Create a HuggingFace Dataset from a PyTorch DataLoader or Dataset.
+
+    NOTE: It's important to ignore the cached files when testing (i.e., ignore_cache=True), because
+    cached results can ignore changes you've made to tests. This can make a test appear to succeed
+    when it's actually broken or vice versa.
+
+    Params:
+        torch_data: A PyTorch DataLoader or Dataset object to create the HuggingFace Dataset from.
+        column_names: Optional list of column names to use for the dataset. If not provided,
+            the column names will be inferred from the data.
+        ignore_cache: If True, the dataset will not be cached on disk.
+    Returns:
+        A HuggingFace Dataset object containing the data from the PyTorch DataLoader or Dataset.
+    """
     if isinstance(torch_data, TorchDataLoader):
         dataloader = torch_data
     else:
@@ -50,7 +68,15 @@ def hf_dataset_from_torch(torch_data: TorchDataLoader | TorchDataset, column_nam
         for batch in dataloader:
             yield from parse_batch(batch, column_names=column_names)
 
-
+    if ignore_cache:
+        with TemporaryDirectory() as temp_dir:
+            ds = Dataset.from_generator(generator, cache_dir=temp_dir)
+    else:
+        ds = Dataset.from_generator(generator)
+
+    if not isinstance(ds, Dataset):
+        raise ValueError(f"Failed to create dataset from generator: {type(ds)}")
+    return ds
 
 
 def hf_dataset_from_disk(file_path: str | PathLike) -> Dataset:
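
The expanded docstring above covers the new ignore_cache flag, which routes the HuggingFace Arrow cache into a TemporaryDirectory so repeated test runs cannot pick up stale cached results. A minimal usage sketch with a hypothetical toy dataset:

    from torch.utils.data import Dataset as TorchDataset

    # assumed absolute import path for the module shown in this diff
    from orca_sdk._utils.data_parsing import hf_dataset_from_torch

    class ToyDataset(TorchDataset):  # hypothetical stand-in for a real dataset
        def __getitem__(self, i):
            return {"value": f"item {i}", "label": i % 2}

        def __len__(self):
            return 4

    # ignore_cache=True builds the dataset in a temporary cache dir that is discarded
    hf_dataset = hf_dataset_from_torch(ToyDataset(), ignore_cache=True)
    assert set(hf_dataset.column_names) == {"value", "label"}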
orca_sdk/_utils/data_parsing_test.py
CHANGED

@@ -29,11 +29,11 @@ class PytorchDictDataset(TorchDataset):
 def test_hf_dataset_from_torch_dict():
     # Given a Pytorch dataset that returns a dictionary for each item
     dataset = PytorchDictDataset()
-    hf_dataset = hf_dataset_from_torch(dataset)
+    hf_dataset = hf_dataset_from_torch(dataset, ignore_cache=True)
     # Then the HF dataset should be created successfully
     assert isinstance(hf_dataset, Dataset)
     assert len(hf_dataset) == len(dataset)
-    assert set(hf_dataset.column_names) == {"
+    assert set(hf_dataset.column_names) == {"value", "label", "key", "score", "source_id"}
 
 
 class PytorchTupleDataset(TorchDataset):
@@ -41,7 +41,7 @@ class PytorchTupleDataset(TorchDataset):
         self.data = SAMPLE_DATA
 
     def __getitem__(self, i):
-        return self.data[i]["
+        return self.data[i]["value"], self.data[i]["label"]
 
     def __len__(self):
         return len(self.data)
@@ -51,11 +51,11 @@ def test_hf_dataset_from_torch_tuple():
     # Given a Pytorch dataset that returns a tuple for each item
     dataset = PytorchTupleDataset()
     # And the correct number of column names passed in
-    hf_dataset = hf_dataset_from_torch(dataset, column_names=["
+    hf_dataset = hf_dataset_from_torch(dataset, column_names=["value", "label"], ignore_cache=True)
     # Then the HF dataset should be created successfully
     assert isinstance(hf_dataset, Dataset)
     assert len(hf_dataset) == len(dataset)
-    assert hf_dataset.column_names == ["
+    assert hf_dataset.column_names == ["value", "label"]
 
 
 def test_hf_dataset_from_torch_tuple_error():
@@ -63,7 +63,7 @@ def test_hf_dataset_from_torch_tuple_error():
     dataset = PytorchTupleDataset()
     # Then the HF dataset should raise an error if no column names are passed in
     with pytest.raises(DatasetGenerationError):
-        hf_dataset_from_torch(dataset)
+        hf_dataset_from_torch(dataset, ignore_cache=True)
 
 
 def test_hf_dataset_from_torch_tuple_error_not_enough_columns():
@@ -71,7 +71,7 @@ def test_hf_dataset_from_torch_tuple_error_not_enough_columns():
     dataset = PytorchTupleDataset()
     # Then the HF dataset should raise an error if not enough column names are passed in
     with pytest.raises(DatasetGenerationError):
-        hf_dataset_from_torch(dataset, column_names=["value"])
+        hf_dataset_from_torch(dataset, column_names=["value"], ignore_cache=True)
 
 
 DatasetTuple = namedtuple("DatasetTuple", ["value", "label"])
@@ -82,7 +82,7 @@ class PytorchNamedTupleDataset(TorchDataset):
         self.data = SAMPLE_DATA
 
     def __getitem__(self, i):
-        return DatasetTuple(self.data[i]["
+        return DatasetTuple(self.data[i]["value"], self.data[i]["label"])
 
     def __len__(self):
         return len(self.data)
@@ -92,7 +92,7 @@ def test_hf_dataset_from_torch_named_tuple():
     # Given a Pytorch dataset that returns a namedtuple for each item
     dataset = PytorchNamedTupleDataset()
     # And no column names are passed in
-    hf_dataset = hf_dataset_from_torch(dataset)
+    hf_dataset = hf_dataset_from_torch(dataset, ignore_cache=True)
     # Then the HF dataset should be created successfully
     assert isinstance(hf_dataset, Dataset)
     assert len(hf_dataset) == len(dataset)
@@ -110,7 +110,7 @@ class PytorchDataclassDataset(TorchDataset):
         self.data = SAMPLE_DATA
 
     def __getitem__(self, i):
-        return DatasetItem(text=self.data[i]["
+        return DatasetItem(text=self.data[i]["value"], label=self.data[i]["label"])
 
     def __len__(self):
         return len(self.data)
@@ -119,7 +119,7 @@ class PytorchDataclassDataset(TorchDataset):
 def test_hf_dataset_from_torch_dataclass():
     # Given a Pytorch dataset that returns a dataclass for each item
     dataset = PytorchDataclassDataset()
-    hf_dataset = hf_dataset_from_torch(dataset)
+    hf_dataset = hf_dataset_from_torch(dataset, ignore_cache=True)
     # Then the HF dataset should be created successfully
     assert isinstance(hf_dataset, Dataset)
     assert len(hf_dataset) == len(dataset)
@@ -131,7 +131,7 @@ class PytorchInvalidDataset(TorchDataset):
         self.data = SAMPLE_DATA
 
     def __getitem__(self, i):
-        return [self.data[i]["
+        return [self.data[i]["value"], self.data[i]["label"]]
 
     def __len__(self):
         return len(self.data)
@@ -142,7 +142,7 @@ def test_hf_dataset_from_torch_invalid_dataset():
     dataset = PytorchInvalidDataset()
     # Then the HF dataset should raise an error
     with pytest.raises(DatasetGenerationError):
-        hf_dataset_from_torch(dataset)
+        hf_dataset_from_torch(dataset, ignore_cache=True)
 
 
 def test_hf_dataset_from_torchdataloader():
@@ -150,10 +150,10 @@ def test_hf_dataset_from_torchdataloader():
     dataset = PytorchDictDataset()
 
     def collate_fn(x: list[dict]):
-        return {"value": [item["
+        return {"value": [item["value"] for item in x], "label": [item["label"] for item in x]}
 
     dataloader = TorchDataLoader(dataset, batch_size=3, collate_fn=collate_fn)
-    hf_dataset = hf_dataset_from_torch(dataloader)
+    hf_dataset = hf_dataset_from_torch(dataloader, ignore_cache=True)
     # Then the HF dataset should be created successfully
     assert isinstance(hf_dataset, Dataset)
     assert len(hf_dataset) == len(dataset)