orca-sdk 0.0.94__py3-none-any.whl → 0.0.96__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- orca_sdk/__init__.py +13 -4
- orca_sdk/_generated_api_client/api/__init__.py +80 -34
- orca_sdk/_generated_api_client/api/classification_model/create_classification_model_gpu_classification_model_post.py +170 -0
- orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
- orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
- orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
- orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
- orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
- orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
- orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
- orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
- orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +24 -0
- orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
- orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
- orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
- orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
- orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
- orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
- orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
- orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
- orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
- orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
- orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
- orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
- orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
- orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
- orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
- orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_gpu_regression_model_post.py} +27 -27
- orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
- orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
- orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
- orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
- orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
- orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
- orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
- orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
- orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
- orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
- orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
- orca_sdk/_generated_api_client/models/__init__.py +84 -24
- orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
- orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
- orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
- orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
- orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
- orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
- orca_sdk/_generated_api_client/models/column_info.py +31 -0
- orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
- orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
- orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
- orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
- orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
- orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
- orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
- orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
- orca_sdk/_generated_api_client/models/memory_type.py +9 -0
- orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
- orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
- orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
- orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
- orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
- orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
- orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
- orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
- orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
- orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
- orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
- orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
- orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
- orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
- orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
- orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
- orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
- orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
- orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
- orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
- orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
- orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
- orca_sdk/_shared/__init__.py +9 -1
- orca_sdk/_shared/metrics.py +257 -87
- orca_sdk/_shared/metrics_test.py +136 -77
- orca_sdk/_utils/data_parsing.py +0 -3
- orca_sdk/_utils/data_parsing_test.py +0 -3
- orca_sdk/_utils/prediction_result_ui.py +55 -23
- orca_sdk/classification_model.py +183 -172
- orca_sdk/classification_model_test.py +147 -157
- orca_sdk/conftest.py +76 -26
- orca_sdk/datasource_test.py +0 -1
- orca_sdk/embedding_model.py +136 -14
- orca_sdk/embedding_model_test.py +10 -6
- orca_sdk/job.py +329 -0
- orca_sdk/job_test.py +48 -0
- orca_sdk/memoryset.py +882 -161
- orca_sdk/memoryset_test.py +56 -23
- orca_sdk/regression_model.py +647 -0
- orca_sdk/regression_model_test.py +337 -0
- orca_sdk/telemetry.py +223 -106
- orca_sdk/telemetry_test.py +34 -30
- {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/METADATA +2 -4
- {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/RECORD +115 -69
- orca_sdk/_utils/task.py +0 -73
- {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.96.dist-info}/WHEEL +0 -0
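Most of the churn above is mechanical renaming: the generated client drops the old "RAC model" naming in favor of "classification model", "labeled memoryset" becomes plain "memoryset", and a parallel regression model surface is added alongside new job and scored-memory modules. For illustration, a hypothetical import migration with class names inferred from the renamed module files above (the diff does not show the exported symbols themselves, so treat these names as assumptions):

# orca-sdk 0.0.94 (names inferred from the old module file names)
# from orca_sdk._generated_api_client.models import CreateRacModelRequest, LabeledMemorysetMetadata

# orca-sdk 0.0.96 (names inferred from the renamed module file names)
from orca_sdk._generated_api_client.models import CreateClassificationModelRequest, MemorysetMetadata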
orca_sdk/_shared/metrics_test.py
CHANGED
@@ -9,13 +9,13 @@ from typing import Literal
 
 import numpy as np
 import pytest
+import sklearn.metrics
 
 from .metrics import (
-    …
+    calculate_classification_metrics,
     calculate_pr_curve,
+    calculate_regression_metrics,
     calculate_roc_curve,
-    classification_scores,
-    compute_classifier_metrics,
     softmax,
 )
 
@@ -24,36 +24,36 @@ def test_binary_metrics():
     y_true = np.array([0, 1, 1, 0, 1])
     y_score = np.array([0.1, 0.9, 0.8, 0.3, 0.2])
 
-    metrics = …
+    metrics = calculate_classification_metrics(y_true, y_score)
 
-    assert metrics …
-    assert metrics …
-    assert metrics …
-    assert metrics …
-    assert metrics …
-    assert metrics …
-    assert metrics …
-    assert metrics …
-    assert metrics …
-    assert metrics …
+    assert metrics.accuracy == 0.8
+    assert metrics.f1_score == 0.8
+    assert metrics.roc_auc is not None
+    assert metrics.roc_auc > 0.8
+    assert metrics.roc_auc < 1.0
+    assert metrics.pr_auc is not None
+    assert metrics.pr_auc > 0.8
+    assert metrics.pr_auc < 1.0
+    assert metrics.loss is not None
+    assert metrics.loss > 0.0
 
 
 def test_multiclass_metrics_with_2_classes():
     y_true = np.array([0, 1, 1, 0, 1])
     y_score = np.array([[0.9, 0.1], [0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
 
-    metrics = …
+    metrics = calculate_classification_metrics(y_true, y_score)
 
-    assert metrics …
-    assert metrics …
-    assert metrics …
-    assert metrics …
-    assert metrics …
-    assert metrics …
-    assert metrics …
-    assert metrics …
-    assert metrics …
-    assert metrics …
+    assert metrics.accuracy == 0.8
+    assert metrics.f1_score == 0.8
+    assert metrics.roc_auc is not None
+    assert metrics.roc_auc > 0.8
+    assert metrics.roc_auc < 1.0
+    assert metrics.pr_auc is not None
+    assert metrics.pr_auc > 0.8
+    assert metrics.pr_auc < 1.0
+    assert metrics.loss is not None
+    assert metrics.loss > 0.0
 
 
 @pytest.mark.parametrize(
@@ -66,104 +66,163 @@ def test_multiclass_metrics_with_3_classes(
     y_true = np.array([0, 1, 1, 0, 2])
     y_score = np.array([[0.9, 0.1, 0.0], [0.1, 0.9, 0.0], [0.2, 0.8, 0.0], [0.7, 0.3, 0.0], [0.0, 0.0, 1.0]])
 
-    metrics = …
+    metrics = calculate_classification_metrics(y_true, y_score, average=average, multi_class=multiclass)
 
-    assert metrics …
-    assert metrics …
-    assert metrics …
-    assert metrics …
-    assert metrics …
-    assert metrics …
-    assert metrics …
+    assert metrics.accuracy == 1.0
+    assert metrics.f1_score == 1.0
+    assert metrics.roc_auc is not None
+    assert metrics.roc_auc > 0.8
+    assert metrics.pr_auc is None
+    assert metrics.loss is not None
+    assert metrics.loss > 0.0
 
 
 def test_does_not_modify_logits_unless_necessary():
     logits = np.array([[0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
-    …
-    …
-    …
+    expected_labels = [0, 1, 0, 1]
+    assert calculate_classification_metrics(expected_labels, logits).loss == sklearn.metrics.log_loss(
+        expected_labels, logits
+    )
 
 
 def test_normalizes_logits_if_necessary():
     logits = np.array([[1.2, 3.9], [1.2, 5.8], [1.2, 2.7], [1.2, 1.3]])
-    …
-    …
-    …
-        metrics["log_loss"] == classification_scores(references, logits / logits.sum(axis=1, keepdims=True))["log_loss"]
+    expected_labels = [0, 1, 0, 1]
+    assert calculate_classification_metrics(expected_labels, logits).loss == sklearn.metrics.log_loss(
+        expected_labels, logits / logits.sum(axis=1, keepdims=True)
     )
 
 
 def test_softmaxes_logits_if_necessary():
     logits = np.array([[-1.2, 3.9], [1.2, -5.8], [1.2, 2.7], [1.2, 1.3]])
-    …
-    …
-    …
+    expected_labels = [0, 1, 0, 1]
+    assert calculate_classification_metrics(expected_labels, logits).loss == sklearn.metrics.log_loss(
+        expected_labels, softmax(logits)
+    )
 
 
 def test_precision_recall_curve():
     y_true = np.array([0, 1, 1, 0, 1])
     y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
 
-    …
-    assert precision is not None
-    assert recall is not None
-    assert thresholds is not None
+    pr_curve = calculate_pr_curve(y_true, y_score)
 
-    assert len( …
-    assert …
-    assert …
-    assert …
-    assert …
+    assert len(pr_curve["precisions"]) == len(pr_curve["recalls"]) == len(pr_curve["thresholds"]) == 6
+    assert np.allclose(pr_curve["precisions"][0], 0.6)
+    assert np.allclose(pr_curve["recalls"][0], 1.0)
+    assert np.allclose(pr_curve["precisions"][-1], 1.0)
+    assert np.allclose(pr_curve["recalls"][-1], 0.0)
 
     # test that thresholds are sorted
-    assert np.all(np.diff(thresholds) >= 0)
+    assert np.all(np.diff(pr_curve["thresholds"]) >= 0)
 
 
 def test_roc_curve():
     y_true = np.array([0, 1, 1, 0, 1])
     y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
 
-    …
-    assert fpr is not None
-    assert tpr is not None
-    assert thresholds is not None
+    roc_curve = calculate_roc_curve(y_true, y_score)
 
-    assert …
-    …
-    …
-    …
-    …
+    assert (
+        len(roc_curve["false_positive_rates"])
+        == len(roc_curve["true_positive_rates"])
+        == len(roc_curve["thresholds"])
+        == 6
+    )
+    assert roc_curve["false_positive_rates"][0] == 1.0
+    assert roc_curve["true_positive_rates"][0] == 1.0
+    assert roc_curve["false_positive_rates"][-1] == 0.0
+    assert roc_curve["true_positive_rates"][-1] == 0.0
 
     # test that thresholds are sorted
-    assert np.all(np.diff(thresholds) >= 0)
+    assert np.all(np.diff(roc_curve["thresholds"]) >= 0)
 
 
 def test_precision_recall_curve_max_length():
     y_true = np.array([0, 1, 1, 0, 1])
     y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
 
-    …
-    assert len( …
+    pr_curve = calculate_pr_curve(y_true, y_score, max_length=5)
+    assert len(pr_curve["precisions"]) == len(pr_curve["recalls"]) == len(pr_curve["thresholds"]) == 5
 
-    assert …
-    assert …
-    assert …
-    assert …
+    assert np.allclose(pr_curve["precisions"][0], 0.6)
+    assert np.allclose(pr_curve["recalls"][0], 1.0)
+    assert np.allclose(pr_curve["precisions"][-1], 1.0)
+    assert np.allclose(pr_curve["recalls"][-1], 0.0)
 
     # test that thresholds are sorted
-    assert np.all(np.diff(thresholds) >= 0)
+    assert np.all(np.diff(pr_curve["thresholds"]) >= 0)
 
 
 def test_roc_curve_max_length():
     y_true = np.array([0, 1, 1, 0, 1])
     y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
 
-    …
-    assert …
-    …
-    …
-    …
-    …
+    roc_curve = calculate_roc_curve(y_true, y_score, max_length=5)
+    assert (
+        len(roc_curve["false_positive_rates"])
+        == len(roc_curve["true_positive_rates"])
+        == len(roc_curve["thresholds"])
+        == 5
+    )
+    assert np.allclose(roc_curve["false_positive_rates"][0], 1.0)
+    assert np.allclose(roc_curve["true_positive_rates"][0], 1.0)
+    assert np.allclose(roc_curve["false_positive_rates"][-1], 0.0)
+    assert np.allclose(roc_curve["true_positive_rates"][-1], 0.0)
 
     # test that thresholds are sorted
-    assert np.all(np.diff(thresholds) >= 0)
+    assert np.all(np.diff(roc_curve["thresholds"]) >= 0)
+
+
+# Regression Metrics Tests
+
+
+def test_perfect_regression_predictions():
+    y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
+    y_pred = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
+
+    metrics = calculate_regression_metrics(y_true, y_pred)
+
+    assert metrics.mse == 0.0
+    assert metrics.rmse == 0.0
+    assert metrics.mae == 0.0
+    assert metrics.r2 == 1.0
+    assert metrics.explained_variance == 1.0
+    assert metrics.loss == 0.0
+    assert metrics.anomaly_score_mean is None
+    assert metrics.anomaly_score_median is None
+    assert metrics.anomaly_score_variance is None
+
+
+def test_basic_regression_metrics():
+    y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
+    y_pred = np.array([1.1, 1.9, 3.2, 3.8, 5.1], dtype=np.float32)
+
+    metrics = calculate_regression_metrics(y_true, y_pred)
+
+    # Check that all metrics are reasonable
+    assert metrics.mse > 0.0
+    assert metrics.rmse == pytest.approx(np.sqrt(metrics.mse))
+    assert metrics.mae > 0.0
+    assert 0.0 <= metrics.r2 <= 1.0
+    assert 0.0 <= metrics.explained_variance <= 1.0
+    assert metrics.loss == metrics.mse
+
+    # Check specific values based on the data
+    expected_mse = np.mean((y_true - y_pred) ** 2)
+    assert metrics.mse == pytest.approx(expected_mse)
+
+    expected_mae = np.mean(np.abs(y_true - y_pred))
+    assert metrics.mae == pytest.approx(expected_mae)
+
+
+def test_regression_metrics_with_anomaly_scores():
+    y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
+    y_pred = np.array([1.1, 1.9, 3.2, 3.8, 5.1], dtype=np.float32)
+    anomaly_scores = [0.1, 0.2, 0.15, 0.3, 0.25]
+
+    metrics = calculate_regression_metrics(y_true, y_pred, anomaly_scores)
+
+    assert metrics.anomaly_score_mean == pytest.approx(np.mean(anomaly_scores))
+    assert metrics.anomaly_score_median == pytest.approx(np.median(anomaly_scores))
+    assert metrics.anomaly_score_variance == pytest.approx(np.var(anomaly_scores))
orca_sdk/_utils/data_parsing.py
CHANGED
@@ -1,4 +1,3 @@
-import logging
 import pickle
 from dataclasses import asdict, is_dataclass
 from os import PathLike
@@ -9,8 +8,6 @@ from datasets import Dataset
 from torch.utils.data import DataLoader as TorchDataLoader
 from torch.utils.data import Dataset as TorchDataset
 
-logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
-
 
 def parse_dict_like(item: Any, column_names: list[str] | None = None) -> dict:
     if isinstance(item, dict):
orca_sdk/_utils/data_parsing_test.py
CHANGED
@@ -1,5 +1,4 @@
 import json
-import logging
 import pickle
 import tempfile
 from collections import namedtuple
@@ -15,8 +14,6 @@ from torch.utils.data import Dataset as TorchDataset
 from ..conftest import SAMPLE_DATA
 from .data_parsing import hf_dataset_from_disk, hf_dataset_from_torch
 
-logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
-
 
 class PytorchDictDataset(TorchDataset):
     def __init__(self):
orca_sdk/_utils/prediction_result_ui.py
CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 import re
 from pathlib import Path
@@ -5,14 +7,13 @@ from typing import TYPE_CHECKING
 
 import gradio as gr
 
-from ..memoryset import LabeledMemoryLookup
+from ..memoryset import LabeledMemoryLookup, ScoredMemoryLookup, LabeledMemoryset
 
 if TYPE_CHECKING:
-    from ..telemetry import LabelPrediction
+    from ..telemetry import _Prediction
 
 
-def inspect_prediction_result(prediction_result: "LabelPrediction"):
-    label_names = prediction_result.memoryset.label_names
+def inspect_prediction_result(prediction_result: _Prediction):
 
     def update_label(val: str, memory: LabeledMemoryLookup, progress=gr.Progress(track_tqdm=True)):
         progress(0)
@@ -26,6 +27,12 @@ def inspect_prediction_result(prediction_result: "LabelPrediction"):
         else:
             logging.error(f"Invalid label format: {val}")
 
+    def update_score(val: float, memory: ScoredMemoryLookup, progress=gr.Progress(track_tqdm=True)):
+        progress(0)
+        memory.update(score=val)
+        progress(1)
+        return "✅ Changes saved"
+
     with gr.Blocks(
         fill_width=True,
         title="Prediction Results",
@@ -33,14 +40,21 @@ def inspect_prediction_result(prediction_result: "LabelPrediction"):
     ) as prediction_result_ui:
         gr.Markdown("# Prediction Results")
         gr.Markdown(f"**Input:** {prediction_result.input_value}")
-        …
+
+        if isinstance(prediction_result.memoryset, LabeledMemoryset) and prediction_result.label is not None:
+            label_names = prediction_result.memoryset.label_names
+            gr.Markdown(f"**Prediction:** {label_names[prediction_result.label]} ({prediction_result.label})")
+        else:
+            gr.Markdown(f"**Prediction:** {prediction_result.score:.2f}")
+
         gr.Markdown("### Memory Lookups")
 
         with gr.Row(equal_height=True, variant="panel"):
             with gr.Column(scale=7):
                 gr.Markdown("**Value**")
             with gr.Column(scale=3, min_width=150):
-                gr.Markdown("**Label**")
+                gr.Markdown("**Label**" if prediction_result.label is not None else "**Score**")
+
         for i, mem_lookup in enumerate(prediction_result.memory_lookups):
             with gr.Row(equal_height=True, variant="panel", elem_classes="white" if i % 2 == 0 else None):
                 with gr.Column(scale=7):
@@ -48,27 +62,45 @@ def inspect_prediction_result(prediction_result: "LabelPrediction"):
                     (
                         mem_lookup.value
                         if isinstance(mem_lookup.value, str)
-                        else "Time series data"
-                        if isinstance(mem_lookup.value, list)
-                        else "Image data"
+                        else "Time series data" if isinstance(mem_lookup.value, list) else "Image data"
                     ),
                     label="Value",
                     height=50,
                 )
                 with gr.Column(scale=3, min_width=150):
-                    …
-                    …
-                    label …
-                    … (11 more lines, truncated in this view)
+                    if (
+                        isinstance(prediction_result.memoryset, LabeledMemoryset)
+                        and prediction_result.label is not None
+                        and isinstance(mem_lookup, LabeledMemoryLookup)
+                    ):
+                        label_names = prediction_result.memoryset.label_names
+                        dropdown = gr.Dropdown(
+                            choices=[f"{label_name} ({i})" for i, label_name in enumerate(label_names)],
+                            label="Label",
+                            value=f"{label_names[mem_lookup.label]} ({mem_lookup.label})",
+                            interactive=True,
+                            container=False,
+                        )
+                        changes_saved = gr.HTML(lambda: "", elem_classes="success no-padding", every=15)
+                        dropdown.change(
+                            lambda val, mem=mem_lookup: update_label(val, mem),
+                            inputs=[dropdown],
+                            outputs=[changes_saved],
+                            show_progress="full",
+                        )
+                    elif prediction_result.score is not None and isinstance(mem_lookup, ScoredMemoryLookup):
+                        input = gr.Number(
+                            value=mem_lookup.score,
+                            label="Score",
+                            interactive=True,
+                            container=False,
+                        )
+                        changes_saved = gr.HTML(lambda: "", elem_classes="success no-padding", every=15)
+                        input.change(
+                            lambda val, mem=mem_lookup: update_score(val, mem),
+                            inputs=[input],
+                            outputs=[changes_saved],
+                            show_progress="full",
+                        )
 
     prediction_result_ui.launch()