orca-sdk 0.0.93__py3-none-any.whl → 0.0.95__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (125)
  1. orca_sdk/__init__.py +13 -4
  2. orca_sdk/_generated_api_client/api/__init__.py +84 -34
  3. orca_sdk/_generated_api_client/api/classification_model/create_classification_model_classification_model_post.py +170 -0
  4. orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
  5. orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
  6. orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
  7. orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
  8. orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
  9. orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
  10. orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
  11. orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
  12. orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
  13. orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +172 -0
  14. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
  15. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
  16. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
  17. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
  18. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
  19. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
  20. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
  21. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
  22. orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
  23. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
  24. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
  25. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
  26. orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
  27. orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
  28. orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
  29. orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
  30. orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_regression_model_post.py} +27 -27
  31. orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
  32. orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
  33. orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
  34. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
  35. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
  36. orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
  37. orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
  38. orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
  39. orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
  40. orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
  41. orca_sdk/_generated_api_client/api/task/list_tasks_task_get.py +60 -10
  42. orca_sdk/_generated_api_client/api/telemetry/count_predictions_telemetry_prediction_count_post.py +10 -10
  43. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
  44. orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
  45. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
  46. orca_sdk/_generated_api_client/models/__init__.py +90 -24
  47. orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
  48. orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
  49. orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
  50. orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
  51. orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
  52. orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
  53. orca_sdk/_generated_api_client/models/column_info.py +31 -0
  54. orca_sdk/_generated_api_client/models/count_predictions_request.py +195 -0
  55. orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
  56. orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
  57. orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
  58. orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
  59. orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
  60. orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
  61. orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
  62. orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
  63. orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
  64. orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
  65. orca_sdk/_generated_api_client/models/http_validation_error.py +86 -0
  66. orca_sdk/_generated_api_client/models/list_predictions_request.py +62 -0
  67. orca_sdk/_generated_api_client/models/memory_type.py +9 -0
  68. orca_sdk/_generated_api_client/models/memoryset_analysis_configs.py +0 -20
  69. orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
  70. orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
  71. orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
  72. orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
  73. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
  74. orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
  75. orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
  76. orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
  77. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
  78. orca_sdk/_generated_api_client/models/pretrained_embedding_model_name.py +5 -0
  79. orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
  80. orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
  81. orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
  82. orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
  83. orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
  84. orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
  85. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
  86. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
  87. orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
  88. orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
  89. orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
  90. orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
  91. orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
  92. orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
  93. orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
  94. orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
  95. orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
  96. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
  97. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
  98. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
  99. orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
  100. orca_sdk/_generated_api_client/models/validation_error.py +99 -0
  101. orca_sdk/_shared/__init__.py +9 -1
  102. orca_sdk/_shared/metrics.py +257 -87
  103. orca_sdk/_shared/metrics_test.py +136 -77
  104. orca_sdk/_utils/data_parsing.py +0 -3
  105. orca_sdk/_utils/data_parsing_test.py +0 -3
  106. orca_sdk/_utils/prediction_result_ui.py +55 -23
  107. orca_sdk/classification_model.py +184 -174
  108. orca_sdk/classification_model_test.py +178 -142
  109. orca_sdk/conftest.py +77 -26
  110. orca_sdk/datasource.py +34 -0
  111. orca_sdk/datasource_test.py +9 -1
  112. orca_sdk/embedding_model.py +136 -14
  113. orca_sdk/embedding_model_test.py +10 -6
  114. orca_sdk/job.py +329 -0
  115. orca_sdk/job_test.py +48 -0
  116. orca_sdk/memoryset.py +882 -161
  117. orca_sdk/memoryset_test.py +58 -23
  118. orca_sdk/regression_model.py +647 -0
  119. orca_sdk/regression_model_test.py +338 -0
  120. orca_sdk/telemetry.py +225 -106
  121. orca_sdk/telemetry_test.py +34 -30
  122. {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/METADATA +2 -4
  123. {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/RECORD +124 -74
  124. orca_sdk/_utils/task.py +0 -73
  125. {orca_sdk-0.0.93.dist-info → orca_sdk-0.0.95.dist-info}/WHEEL +0 -0
--- orca_sdk/_shared/metrics_test.py
+++ orca_sdk/_shared/metrics_test.py
@@ -9,13 +9,13 @@ from typing import Literal
 
 import numpy as np
 import pytest
+import sklearn.metrics
 
 from .metrics import (
-    EvalPrediction,
+    calculate_classification_metrics,
     calculate_pr_curve,
+    calculate_regression_metrics,
     calculate_roc_curve,
-    classification_scores,
-    compute_classifier_metrics,
     softmax,
 )
 
@@ -24,36 +24,36 @@ def test_binary_metrics():
     y_true = np.array([0, 1, 1, 0, 1])
     y_score = np.array([0.1, 0.9, 0.8, 0.3, 0.2])
 
-    metrics = classification_scores(y_true, y_score)
+    metrics = calculate_classification_metrics(y_true, y_score)
 
-    assert metrics["accuracy"] == 0.8
-    assert metrics["f1_score"] == 0.8
-    assert metrics["roc_auc"] is not None
-    assert metrics["roc_auc"] > 0.8
-    assert metrics["roc_auc"] < 1.0
-    assert metrics["pr_auc"] is not None
-    assert metrics["pr_auc"] > 0.8
-    assert metrics["pr_auc"] < 1.0
-    assert metrics["log_loss"] is not None
-    assert metrics["log_loss"] > 0.0
+    assert metrics.accuracy == 0.8
+    assert metrics.f1_score == 0.8
+    assert metrics.roc_auc is not None
+    assert metrics.roc_auc > 0.8
+    assert metrics.roc_auc < 1.0
+    assert metrics.pr_auc is not None
+    assert metrics.pr_auc > 0.8
+    assert metrics.pr_auc < 1.0
+    assert metrics.loss is not None
+    assert metrics.loss > 0.0
 
 
 def test_multiclass_metrics_with_2_classes():
     y_true = np.array([0, 1, 1, 0, 1])
     y_score = np.array([[0.9, 0.1], [0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
 
-    metrics = classification_scores(y_true, y_score)
+    metrics = calculate_classification_metrics(y_true, y_score)
 
-    assert metrics["accuracy"] == 0.8
-    assert metrics["f1_score"] == 0.8
-    assert metrics["roc_auc"] is not None
-    assert metrics["roc_auc"] > 0.8
-    assert metrics["roc_auc"] < 1.0
-    assert metrics["pr_auc"] is not None
-    assert metrics["pr_auc"] > 0.8
-    assert metrics["pr_auc"] < 1.0
-    assert metrics["log_loss"] is not None
-    assert metrics["log_loss"] > 0.0
+    assert metrics.accuracy == 0.8
+    assert metrics.f1_score == 0.8
+    assert metrics.roc_auc is not None
+    assert metrics.roc_auc > 0.8
+    assert metrics.roc_auc < 1.0
+    assert metrics.pr_auc is not None
+    assert metrics.pr_auc > 0.8
+    assert metrics.pr_auc < 1.0
+    assert metrics.loss is not None
+    assert metrics.loss > 0.0
 
 
 @pytest.mark.parametrize(
@@ -66,104 +66,163 @@ def test_multiclass_metrics_with_3_classes(
     y_true = np.array([0, 1, 1, 0, 2])
     y_score = np.array([[0.9, 0.1, 0.0], [0.1, 0.9, 0.0], [0.2, 0.8, 0.0], [0.7, 0.3, 0.0], [0.0, 0.0, 1.0]])
 
-    metrics = classification_scores(y_true, y_score, average=average, multi_class=multiclass)
+    metrics = calculate_classification_metrics(y_true, y_score, average=average, multi_class=multiclass)
 
-    assert metrics["accuracy"] == 1.0
-    assert metrics["f1_score"] == 1.0
-    assert metrics["roc_auc"] is not None
-    assert metrics["roc_auc"] > 0.8
-    assert metrics["pr_auc"] is None
-    assert metrics["log_loss"] is not None
-    assert metrics["log_loss"] > 0.0
+    assert metrics.accuracy == 1.0
+    assert metrics.f1_score == 1.0
+    assert metrics.roc_auc is not None
+    assert metrics.roc_auc > 0.8
+    assert metrics.pr_auc is None
+    assert metrics.loss is not None
+    assert metrics.loss > 0.0
 
 
 def test_does_not_modify_logits_unless_necessary():
     logits = np.array([[0.1, 0.9], [0.2, 0.8], [0.7, 0.3], [0.8, 0.2]])
-    references = np.array([0, 1, 0, 1])
-    metrics = compute_classifier_metrics(EvalPrediction(logits, references))
-    assert metrics["log_loss"] == classification_scores(references, logits)["log_loss"]
+    expected_labels = [0, 1, 0, 1]
+    assert calculate_classification_metrics(expected_labels, logits).loss == sklearn.metrics.log_loss(
+        expected_labels, logits
+    )
 
 
 def test_normalizes_logits_if_necessary():
     logits = np.array([[1.2, 3.9], [1.2, 5.8], [1.2, 2.7], [1.2, 1.3]])
-    references = np.array([0, 1, 0, 1])
-    metrics = compute_classifier_metrics(EvalPrediction(logits, references))
-    assert (
-        metrics["log_loss"] == classification_scores(references, logits / logits.sum(axis=1, keepdims=True))["log_loss"]
+    expected_labels = [0, 1, 0, 1]
+    assert calculate_classification_metrics(expected_labels, logits).loss == sklearn.metrics.log_loss(
+        expected_labels, logits / logits.sum(axis=1, keepdims=True)
     )
 
 
 def test_softmaxes_logits_if_necessary():
     logits = np.array([[-1.2, 3.9], [1.2, -5.8], [1.2, 2.7], [1.2, 1.3]])
-    references = np.array([0, 1, 0, 1])
-    metrics = compute_classifier_metrics(EvalPrediction(logits, references))
-    assert metrics["log_loss"] == classification_scores(references, softmax(logits))["log_loss"]
+    expected_labels = [0, 1, 0, 1]
+    assert calculate_classification_metrics(expected_labels, logits).loss == sklearn.metrics.log_loss(
+        expected_labels, softmax(logits)
+    )
 
 
 def test_precision_recall_curve():
     y_true = np.array([0, 1, 1, 0, 1])
     y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
 
-    precision, recall, thresholds = calculate_pr_curve(y_true, y_score)
-    assert precision is not None
-    assert recall is not None
-    assert thresholds is not None
+    pr_curve = calculate_pr_curve(y_true, y_score)
 
-    assert len(precision) == len(recall) == len(thresholds) == 6
-    assert precision[0] == 0.6
-    assert recall[0] == 1.0
-    assert precision[-1] == 1.0
-    assert recall[-1] == 0.0
+    assert len(pr_curve["precisions"]) == len(pr_curve["recalls"]) == len(pr_curve["thresholds"]) == 6
+    assert np.allclose(pr_curve["precisions"][0], 0.6)
+    assert np.allclose(pr_curve["recalls"][0], 1.0)
+    assert np.allclose(pr_curve["precisions"][-1], 1.0)
+    assert np.allclose(pr_curve["recalls"][-1], 0.0)
 
     # test that thresholds are sorted
-    assert np.all(np.diff(thresholds) >= 0)
+    assert np.all(np.diff(pr_curve["thresholds"]) >= 0)
 
 
 def test_roc_curve():
     y_true = np.array([0, 1, 1, 0, 1])
     y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
 
-    fpr, tpr, thresholds = calculate_roc_curve(y_true, y_score)
-    assert fpr is not None
-    assert tpr is not None
-    assert thresholds is not None
+    roc_curve = calculate_roc_curve(y_true, y_score)
 
-    assert len(fpr) == len(tpr) == len(thresholds) == 6
-    assert fpr[0] == 1.0
-    assert tpr[0] == 1.0
-    assert fpr[-1] == 0.0
-    assert tpr[-1] == 0.0
+    assert (
+        len(roc_curve["false_positive_rates"])
+        == len(roc_curve["true_positive_rates"])
+        == len(roc_curve["thresholds"])
+        == 6
+    )
+    assert roc_curve["false_positive_rates"][0] == 1.0
+    assert roc_curve["true_positive_rates"][0] == 1.0
+    assert roc_curve["false_positive_rates"][-1] == 0.0
+    assert roc_curve["true_positive_rates"][-1] == 0.0
 
     # test that thresholds are sorted
-    assert np.all(np.diff(thresholds) >= 0)
+    assert np.all(np.diff(roc_curve["thresholds"]) >= 0)
 
 
 def test_precision_recall_curve_max_length():
     y_true = np.array([0, 1, 1, 0, 1])
     y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
 
-    precision, recall, thresholds = calculate_pr_curve(y_true, y_score, max_length=5)
-    assert len(precision) == len(recall) == len(thresholds) == 5
+    pr_curve = calculate_pr_curve(y_true, y_score, max_length=5)
+    assert len(pr_curve["precisions"]) == len(pr_curve["recalls"]) == len(pr_curve["thresholds"]) == 5
 
-    assert precision[0] == 0.6
-    assert recall[0] == 1.0
-    assert precision[-1] == 1.0
-    assert recall[-1] == 0.0
+    assert np.allclose(pr_curve["precisions"][0], 0.6)
+    assert np.allclose(pr_curve["recalls"][0], 1.0)
+    assert np.allclose(pr_curve["precisions"][-1], 1.0)
+    assert np.allclose(pr_curve["recalls"][-1], 0.0)
 
     # test that thresholds are sorted
-    assert np.all(np.diff(thresholds) >= 0)
+    assert np.all(np.diff(pr_curve["thresholds"]) >= 0)
 
 
 def test_roc_curve_max_length():
     y_true = np.array([0, 1, 1, 0, 1])
     y_score = np.array([0.1, 0.9, 0.8, 0.6, 0.2])
 
-    fpr, tpr, thresholds = calculate_roc_curve(y_true, y_score, max_length=5)
-    assert len(fpr) == len(tpr) == len(thresholds) == 5
-    assert fpr[0] == 1.0
-    assert tpr[0] == 1.0
-    assert fpr[-1] == 0.0
-    assert tpr[-1] == 0.0
+    roc_curve = calculate_roc_curve(y_true, y_score, max_length=5)
+    assert (
+        len(roc_curve["false_positive_rates"])
+        == len(roc_curve["true_positive_rates"])
+        == len(roc_curve["thresholds"])
+        == 5
+    )
+    assert np.allclose(roc_curve["false_positive_rates"][0], 1.0)
+    assert np.allclose(roc_curve["true_positive_rates"][0], 1.0)
+    assert np.allclose(roc_curve["false_positive_rates"][-1], 0.0)
+    assert np.allclose(roc_curve["true_positive_rates"][-1], 0.0)
 
     # test that thresholds are sorted
-    assert np.all(np.diff(thresholds) >= 0)
+    assert np.all(np.diff(roc_curve["thresholds"]) >= 0)
+
+
+# Regression Metrics Tests
+
+
+def test_perfect_regression_predictions():
+    y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
+    y_pred = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
+
+    metrics = calculate_regression_metrics(y_true, y_pred)
+
+    assert metrics.mse == 0.0
+    assert metrics.rmse == 0.0
+    assert metrics.mae == 0.0
+    assert metrics.r2 == 1.0
+    assert metrics.explained_variance == 1.0
+    assert metrics.loss == 0.0
+    assert metrics.anomaly_score_mean is None
+    assert metrics.anomaly_score_median is None
+    assert metrics.anomaly_score_variance is None
+
+
+def test_basic_regression_metrics():
+    y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
+    y_pred = np.array([1.1, 1.9, 3.2, 3.8, 5.1], dtype=np.float32)
+
+    metrics = calculate_regression_metrics(y_true, y_pred)
+
+    # Check that all metrics are reasonable
+    assert metrics.mse > 0.0
+    assert metrics.rmse == pytest.approx(np.sqrt(metrics.mse))
+    assert metrics.mae > 0.0
+    assert 0.0 <= metrics.r2 <= 1.0
+    assert 0.0 <= metrics.explained_variance <= 1.0
+    assert metrics.loss == metrics.mse
+
+    # Check specific values based on the data
+    expected_mse = np.mean((y_true - y_pred) ** 2)
+    assert metrics.mse == pytest.approx(expected_mse)
+
+    expected_mae = np.mean(np.abs(y_true - y_pred))
+    assert metrics.mae == pytest.approx(expected_mae)
+
+
+def test_regression_metrics_with_anomaly_scores():
+    y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
+    y_pred = np.array([1.1, 1.9, 3.2, 3.8, 5.1], dtype=np.float32)
+    anomaly_scores = [0.1, 0.2, 0.15, 0.3, 0.25]
+
+    metrics = calculate_regression_metrics(y_true, y_pred, anomaly_scores)
+
+    assert metrics.anomaly_score_mean == pytest.approx(np.mean(anomaly_scores))
+    assert metrics.anomaly_score_median == pytest.approx(np.median(anomaly_scores))
+    assert metrics.anomaly_score_variance == pytest.approx(np.var(anomaly_scores))
--- orca_sdk/_utils/data_parsing.py
+++ orca_sdk/_utils/data_parsing.py
@@ -1,4 +1,3 @@
-import logging
 import pickle
 from dataclasses import asdict, is_dataclass
 from os import PathLike
@@ -9,8 +8,6 @@ from datasets import Dataset
 from torch.utils.data import DataLoader as TorchDataLoader
 from torch.utils.data import Dataset as TorchDataset
 
-logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
-
 
 def parse_dict_like(item: Any, column_names: list[str] | None = None) -> dict:
     if isinstance(item, dict):
--- orca_sdk/_utils/data_parsing_test.py
+++ orca_sdk/_utils/data_parsing_test.py
@@ -1,5 +1,4 @@
 import json
-import logging
 import pickle
 import tempfile
 from collections import namedtuple
@@ -15,8 +14,6 @@ from torch.utils.data import Dataset as TorchDataset
 from ..conftest import SAMPLE_DATA
 from .data_parsing import hf_dataset_from_disk, hf_dataset_from_torch
 
-logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
-
 
 class PytorchDictDataset(TorchDataset):
     def __init__(self):
--- orca_sdk/_utils/prediction_result_ui.py
+++ orca_sdk/_utils/prediction_result_ui.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import logging
 import re
 from pathlib import Path
@@ -5,14 +7,13 @@ from typing import TYPE_CHECKING
 
 import gradio as gr
 
-from ..memoryset import LabeledMemoryLookup
+from ..memoryset import LabeledMemoryLookup, ScoredMemoryLookup, LabeledMemoryset
 
 if TYPE_CHECKING:
-    from ..telemetry import LabelPrediction
+    from ..telemetry import _Prediction
 
 
-def inspect_prediction_result(prediction_result: "LabelPrediction"):
-    label_names = prediction_result.memoryset.label_names
+def inspect_prediction_result(prediction_result: _Prediction):
 
     def update_label(val: str, memory: LabeledMemoryLookup, progress=gr.Progress(track_tqdm=True)):
         progress(0)
@@ -26,6 +27,12 @@ def inspect_prediction_result(prediction_result: "LabelPrediction"):
         else:
             logging.error(f"Invalid label format: {val}")
 
+    def update_score(val: float, memory: ScoredMemoryLookup, progress=gr.Progress(track_tqdm=True)):
+        progress(0)
+        memory.update(score=val)
+        progress(1)
+        return "✅ Changes saved"
+
     with gr.Blocks(
         fill_width=True,
         title="Prediction Results",
@@ -33,14 +40,21 @@ def inspect_prediction_result(prediction_result: "LabelPrediction"):
     ) as prediction_result_ui:
         gr.Markdown("# Prediction Results")
         gr.Markdown(f"**Input:** {prediction_result.input_value}")
-        gr.Markdown(f"**Prediction:** {label_names[prediction_result.label]} ({prediction_result.label})")
+
+        if isinstance(prediction_result.memoryset, LabeledMemoryset) and prediction_result.label is not None:
+            label_names = prediction_result.memoryset.label_names
+            gr.Markdown(f"**Prediction:** {label_names[prediction_result.label]} ({prediction_result.label})")
+        else:
+            gr.Markdown(f"**Prediction:** {prediction_result.score:.2f}")
+
         gr.Markdown("### Memory Lookups")
 
         with gr.Row(equal_height=True, variant="panel"):
             with gr.Column(scale=7):
                 gr.Markdown("**Value**")
             with gr.Column(scale=3, min_width=150):
-                gr.Markdown("**Label**")
+                gr.Markdown("**Label**" if prediction_result.label is not None else "**Score**")
+
         for i, mem_lookup in enumerate(prediction_result.memory_lookups):
             with gr.Row(equal_height=True, variant="panel", elem_classes="white" if i % 2 == 0 else None):
                 with gr.Column(scale=7):
@@ -48,27 +62,45 @@ def inspect_prediction_result(prediction_result: "LabelPrediction"):
                         (
                             mem_lookup.value
                             if isinstance(mem_lookup.value, str)
-                            else "Time series data"
-                            if isinstance(mem_lookup.value, list)
-                            else "Image data"
+                            else "Time series data" if isinstance(mem_lookup.value, list) else "Image data"
                         ),
                         label="Value",
                         height=50,
                     )
                 with gr.Column(scale=3, min_width=150):
-                    dropdown = gr.Dropdown(
-                        choices=[f"{label_name} ({i})" for i, label_name in enumerate(label_names)],
-                        label="Label",
-                        value=f"{label_names[mem_lookup.label]} ({mem_lookup.label})",
-                        interactive=True,
-                        container=False,
-                    )
-                    changes_saved = gr.HTML(lambda: "", elem_classes="success no-padding", every=15)
-                    dropdown.change(
-                        lambda val, mem_lookup=mem_lookup: update_label(val, mem_lookup),
-                        inputs=[dropdown],
-                        outputs=[changes_saved],
-                        show_progress="full",
-                    )
+                    if (
+                        isinstance(prediction_result.memoryset, LabeledMemoryset)
+                        and prediction_result.label is not None
+                        and isinstance(mem_lookup, LabeledMemoryLookup)
+                    ):
+                        label_names = prediction_result.memoryset.label_names
+                        dropdown = gr.Dropdown(
+                            choices=[f"{label_name} ({i})" for i, label_name in enumerate(label_names)],
+                            label="Label",
+                            value=f"{label_names[mem_lookup.label]} ({mem_lookup.label})",
+                            interactive=True,
+                            container=False,
+                        )
+                        changes_saved = gr.HTML(lambda: "", elem_classes="success no-padding", every=15)
+                        dropdown.change(
+                            lambda val, mem=mem_lookup: update_label(val, mem),
+                            inputs=[dropdown],
+                            outputs=[changes_saved],
+                            show_progress="full",
+                        )
+                    elif prediction_result.score is not None and isinstance(mem_lookup, ScoredMemoryLookup):
+                        input = gr.Number(
+                            value=mem_lookup.score,
+                            label="Score",
+                            interactive=True,
+                            container=False,
+                        )
+                        changes_saved = gr.HTML(lambda: "", elem_classes="success no-padding", every=15)
+                        input.change(
+                            lambda val, mem=mem_lookup: update_score(val, mem),
+                            inputs=[input],
+                            outputs=[changes_saved],
+                            show_progress="full",
+                        )
 
     prediction_result_ui.launch()