orca-sdk 0.0.94__py3-none-any.whl → 0.0.95__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. orca_sdk/__init__.py +13 -4
  2. orca_sdk/_generated_api_client/api/__init__.py +80 -34
  3. orca_sdk/_generated_api_client/api/classification_model/create_classification_model_classification_model_post.py +170 -0
  4. orca_sdk/_generated_api_client/api/classification_model/{get_model_classification_model_name_or_id_get.py → delete_classification_model_classification_model_name_or_id_delete.py} +20 -20
  5. orca_sdk/_generated_api_client/api/classification_model/{delete_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py → delete_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_delete.py} +4 -4
  6. orca_sdk/_generated_api_client/api/classification_model/{create_evaluation_classification_model_model_name_or_id_evaluation_post.py → evaluate_classification_model_classification_model_model_name_or_id_evaluation_post.py} +14 -14
  7. orca_sdk/_generated_api_client/api/classification_model/get_classification_model_classification_model_name_or_id_get.py +156 -0
  8. orca_sdk/_generated_api_client/api/classification_model/{get_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py → get_classification_model_evaluation_classification_model_model_name_or_id_evaluation_task_id_get.py} +16 -16
  9. orca_sdk/_generated_api_client/api/classification_model/{list_evaluations_classification_model_model_name_or_id_evaluation_get.py → list_classification_model_evaluations_classification_model_model_name_or_id_evaluation_get.py} +16 -16
  10. orca_sdk/_generated_api_client/api/classification_model/list_classification_models_classification_model_get.py +127 -0
  11. orca_sdk/_generated_api_client/api/classification_model/{predict_gpu_classification_model_name_or_id_prediction_post.py → predict_label_gpu_classification_model_name_or_id_prediction_post.py} +14 -14
  12. orca_sdk/_generated_api_client/api/classification_model/update_classification_model_classification_model_name_or_id_patch.py +183 -0
  13. orca_sdk/_generated_api_client/api/datasource/download_datasource_datasource_name_or_id_download_get.py +24 -0
  14. orca_sdk/_generated_api_client/api/memoryset/clone_memoryset_memoryset_name_or_id_clone_post.py +22 -22
  15. orca_sdk/_generated_api_client/api/memoryset/create_memoryset_memoryset_post.py +22 -22
  16. orca_sdk/_generated_api_client/api/memoryset/get_memories_memoryset_name_or_id_memories_get_post.py +38 -16
  17. orca_sdk/_generated_api_client/api/memoryset/get_memory_memoryset_name_or_id_memory_memory_id_get.py +29 -12
  18. orca_sdk/_generated_api_client/api/memoryset/get_memoryset_memoryset_name_or_id_get.py +12 -12
  19. orca_sdk/_generated_api_client/api/memoryset/insert_memories_gpu_memoryset_name_or_id_memory_post.py +17 -14
  20. orca_sdk/_generated_api_client/api/memoryset/list_memorysets_memoryset_get.py +72 -19
  21. orca_sdk/_generated_api_client/api/memoryset/memoryset_lookup_gpu_memoryset_name_or_id_lookup_post.py +31 -12
  22. orca_sdk/_generated_api_client/api/memoryset/potential_duplicate_groups_memoryset_name_or_id_potential_duplicate_groups_get.py +49 -20
  23. orca_sdk/_generated_api_client/api/memoryset/query_memoryset_memoryset_name_or_id_memories_post.py +38 -16
  24. orca_sdk/_generated_api_client/api/memoryset/update_memories_gpu_memoryset_name_or_id_memories_patch.py +54 -29
  25. orca_sdk/_generated_api_client/api/memoryset/update_memory_gpu_memoryset_name_or_id_memory_patch.py +44 -26
  26. orca_sdk/_generated_api_client/api/memoryset/update_memoryset_memoryset_name_or_id_patch.py +22 -22
  27. orca_sdk/_generated_api_client/api/predictive_model/__init__.py +0 -0
  28. orca_sdk/_generated_api_client/api/predictive_model/list_predictive_models_predictive_model_get.py +150 -0
  29. orca_sdk/_generated_api_client/api/regression_model/__init__.py +0 -0
  30. orca_sdk/_generated_api_client/api/{classification_model/create_model_classification_model_post.py → regression_model/create_regression_model_regression_model_post.py} +27 -27
  31. orca_sdk/_generated_api_client/api/regression_model/delete_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_delete.py +168 -0
  32. orca_sdk/_generated_api_client/api/{classification_model/delete_model_classification_model_name_or_id_delete.py → regression_model/delete_regression_model_regression_model_name_or_id_delete.py} +5 -5
  33. orca_sdk/_generated_api_client/api/regression_model/evaluate_regression_model_regression_model_model_name_or_id_evaluation_post.py +183 -0
  34. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_evaluation_regression_model_model_name_or_id_evaluation_task_id_get.py +170 -0
  35. orca_sdk/_generated_api_client/api/regression_model/get_regression_model_regression_model_name_or_id_get.py +156 -0
  36. orca_sdk/_generated_api_client/api/regression_model/list_regression_model_evaluations_regression_model_model_name_or_id_evaluation_get.py +161 -0
  37. orca_sdk/_generated_api_client/api/{classification_model/list_models_classification_model_get.py → regression_model/list_regression_models_regression_model_get.py} +17 -17
  38. orca_sdk/_generated_api_client/api/regression_model/predict_score_gpu_regression_model_name_or_id_prediction_post.py +190 -0
  39. orca_sdk/_generated_api_client/api/{classification_model/update_model_classification_model_name_or_id_patch.py → regression_model/update_regression_model_regression_model_name_or_id_patch.py} +27 -27
  40. orca_sdk/_generated_api_client/api/task/get_task_task_task_id_get.py +156 -0
  41. orca_sdk/_generated_api_client/api/telemetry/get_prediction_telemetry_prediction_prediction_id_get.py +35 -12
  42. orca_sdk/_generated_api_client/api/telemetry/list_memories_with_feedback_telemetry_memories_post.py +20 -12
  43. orca_sdk/_generated_api_client/api/telemetry/list_predictions_telemetry_prediction_post.py +35 -12
  44. orca_sdk/_generated_api_client/models/__init__.py +84 -24
  45. orca_sdk/_generated_api_client/models/base_score_prediction_result.py +108 -0
  46. orca_sdk/_generated_api_client/models/{evaluation_request.py → classification_evaluation_request.py} +13 -45
  47. orca_sdk/_generated_api_client/models/{classification_evaluation_result.py → classification_metrics.py} +106 -56
  48. orca_sdk/_generated_api_client/models/{rac_model_metadata.py → classification_model_metadata.py} +51 -43
  49. orca_sdk/_generated_api_client/models/{prediction_request.py → classification_prediction_request.py} +31 -6
  50. orca_sdk/_generated_api_client/models/{clone_labeled_memoryset_request.py → clone_memoryset_request.py} +5 -5
  51. orca_sdk/_generated_api_client/models/column_info.py +31 -0
  52. orca_sdk/_generated_api_client/models/{create_rac_model_request.py → create_classification_model_request.py} +25 -57
  53. orca_sdk/_generated_api_client/models/{create_labeled_memoryset_request.py → create_memoryset_request.py} +73 -56
  54. orca_sdk/_generated_api_client/models/create_memoryset_request_index_params.py +66 -0
  55. orca_sdk/_generated_api_client/models/create_memoryset_request_index_type.py +13 -0
  56. orca_sdk/_generated_api_client/models/create_regression_model_request.py +137 -0
  57. orca_sdk/_generated_api_client/models/embedding_evaluation_payload.py +187 -0
  58. orca_sdk/_generated_api_client/models/embedding_evaluation_response.py +10 -0
  59. orca_sdk/_generated_api_client/models/evaluation_response.py +22 -9
  60. orca_sdk/_generated_api_client/models/evaluation_response_classification_metrics.py +140 -0
  61. orca_sdk/_generated_api_client/models/evaluation_response_regression_metrics.py +140 -0
  62. orca_sdk/_generated_api_client/models/memory_type.py +9 -0
  63. orca_sdk/_generated_api_client/models/{labeled_memoryset_metadata.py → memoryset_metadata.py} +73 -13
  64. orca_sdk/_generated_api_client/models/memoryset_metadata_index_params.py +55 -0
  65. orca_sdk/_generated_api_client/models/memoryset_metadata_index_type.py +13 -0
  66. orca_sdk/_generated_api_client/models/{labeled_memoryset_update.py → memoryset_update.py} +19 -31
  67. orca_sdk/_generated_api_client/models/not_found_error_response_resource_type_0.py +1 -0
  68. orca_sdk/_generated_api_client/models/{paginated_labeled_memory_with_feedback_metrics.py → paginated_union_labeled_memory_with_feedback_metrics_scored_memory_with_feedback_metrics.py} +37 -10
  69. orca_sdk/_generated_api_client/models/{precision_recall_curve.py → pr_curve.py} +5 -13
  70. orca_sdk/_generated_api_client/models/{rac_model_update.py → predictive_model_update.py} +14 -5
  71. orca_sdk/_generated_api_client/models/pretrained_embedding_model_metadata.py +11 -1
  72. orca_sdk/_generated_api_client/models/rar_head_type.py +8 -0
  73. orca_sdk/_generated_api_client/models/regression_evaluation_request.py +148 -0
  74. orca_sdk/_generated_api_client/models/regression_metrics.py +172 -0
  75. orca_sdk/_generated_api_client/models/regression_model_metadata.py +177 -0
  76. orca_sdk/_generated_api_client/models/regression_prediction_request.py +195 -0
  77. orca_sdk/_generated_api_client/models/roc_curve.py +0 -8
  78. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup.py +196 -0
  79. orca_sdk/_generated_api_client/models/score_prediction_memory_lookup_metadata.py +68 -0
  80. orca_sdk/_generated_api_client/models/score_prediction_with_memories_and_feedback.py +252 -0
  81. orca_sdk/_generated_api_client/models/scored_memory.py +172 -0
  82. orca_sdk/_generated_api_client/models/scored_memory_insert.py +128 -0
  83. orca_sdk/_generated_api_client/models/scored_memory_insert_metadata.py +68 -0
  84. orca_sdk/_generated_api_client/models/scored_memory_lookup.py +180 -0
  85. orca_sdk/_generated_api_client/models/scored_memory_lookup_metadata.py +68 -0
  86. orca_sdk/_generated_api_client/models/scored_memory_metadata.py +68 -0
  87. orca_sdk/_generated_api_client/models/scored_memory_update.py +171 -0
  88. orca_sdk/_generated_api_client/models/scored_memory_update_metadata_type_0.py +68 -0
  89. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics.py +193 -0
  90. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_feedback_metrics.py +68 -0
  91. orca_sdk/_generated_api_client/models/scored_memory_with_feedback_metrics_metadata.py +68 -0
  92. orca_sdk/_generated_api_client/models/update_prediction_request.py +20 -0
  93. orca_sdk/_shared/__init__.py +9 -1
  94. orca_sdk/_shared/metrics.py +257 -87
  95. orca_sdk/_shared/metrics_test.py +136 -77
  96. orca_sdk/_utils/data_parsing.py +0 -3
  97. orca_sdk/_utils/data_parsing_test.py +0 -3
  98. orca_sdk/_utils/prediction_result_ui.py +55 -23
  99. orca_sdk/classification_model.py +183 -175
  100. orca_sdk/classification_model_test.py +147 -157
  101. orca_sdk/conftest.py +76 -26
  102. orca_sdk/datasource_test.py +0 -1
  103. orca_sdk/embedding_model.py +136 -14
  104. orca_sdk/embedding_model_test.py +10 -6
  105. orca_sdk/job.py +329 -0
  106. orca_sdk/job_test.py +48 -0
  107. orca_sdk/memoryset.py +882 -161
  108. orca_sdk/memoryset_test.py +56 -23
  109. orca_sdk/regression_model.py +647 -0
  110. orca_sdk/regression_model_test.py +338 -0
  111. orca_sdk/telemetry.py +223 -106
  112. orca_sdk/telemetry_test.py +34 -30
  113. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.95.dist-info}/METADATA +2 -4
  114. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.95.dist-info}/RECORD +115 -69
  115. orca_sdk/_utils/task.py +0 -73
  116. {orca_sdk-0.0.94.dist-info → orca_sdk-0.0.95.dist-info}/WHEEL +0 -0
@@ -0,0 +1,108 @@
1
+ """
2
+ This file is generated by the openapi-python-client tool via the generate_api_client.py script
3
+
4
+ It is a customized template from the openapi-python-client tool's default template:
5
+ https://github.com/openapi-generators/openapi-python-client/blob/861ef5622f10fc96d240dc9becb0edf94e61446c/openapi_python_client/templates/model.py.jinja
6
+
7
+ The main change is:
8
+ - Fix typing issues
9
+ """
10
+
11
+ # flake8: noqa: C901
12
+
13
+ from typing import Any, Type, TypeVar, Union, cast
14
+
15
+ from attrs import define as _attrs_define
16
+ from attrs import field as _attrs_field
17
+
18
T = TypeVar("T", bound="BaseScorePredictionResult")


@_attrs_define
class BaseScorePredictionResult:
    """Predicted score and confidence for a single input.

    Attributes:
        prediction_id (Union[None, str]): Serialized under the ``"prediction_id"`` key; may be ``None``.
        confidence (float): Serialized under the ``"confidence"`` key.
        anomaly_score (Union[None, float]): Serialized under the ``"anomaly_score"`` key; may be ``None``.
        score (float): Serialized under the ``"score"`` key.

    Keys of an input dict that are not one of the four fields above are kept
    in ``additional_properties`` and written back out by ``to_dict``.
    """

    prediction_id: Union[None, str]
    confidence: float
    anomaly_score: Union[None, float]
    score: float
    # Not an __init__ argument; holds unrecognized keys captured by from_dict.
    additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return a new plain dict representing this result.

        Extra keys from ``additional_properties`` are laid down first; the four
        modelled fields are applied afterwards, so a same-named extra key keeps
        its position but takes the field's value.
        """
        return {
            **self.additional_properties,
            "prediction_id": self.prediction_id,
            "confidence": self.confidence,
            "anomaly_score": self.anomaly_score,
            "score": self.score,
        }

    @classmethod
    def from_dict(cls: Type[T], src_dict: dict[str, Any]) -> T:
        """Build an instance from a dict shaped like the output of ``to_dict``.

        ``src_dict`` is shallow-copied and never mutated. All four field keys
        are required; a missing one raises ``KeyError``. Whatever keys remain
        after the fields are removed become ``additional_properties``.
        """
        remaining = src_dict.copy()

        # Keys are removed in field-declaration order, so when several are
        # missing the KeyError names the earliest declared one.
        prediction_id = cast(Union[None, str], remaining.pop("prediction_id"))
        confidence = remaining.pop("confidence")
        anomaly_score = cast(Union[None, float], remaining.pop("anomaly_score"))
        score = remaining.pop("score")

        instance = cls(
            prediction_id=prediction_id,
            confidence=confidence,
            anomaly_score=anomaly_score,
            score=score,
        )
        instance.additional_properties = remaining
        return instance

    @property
    def additional_keys(self) -> list[str]:
        """Names of the extra (non-field) properties, in insertion order."""
        return list(self.additional_properties.keys())

    def __getitem__(self, key: str) -> Any:
        """Look up an extra property; raises ``KeyError`` if absent."""
        return self.additional_properties[key]

    def __setitem__(self, key: str, value: Any) -> None:
        """Add or overwrite an extra property."""
        self.additional_properties[key] = value

    def __delitem__(self, key: str) -> None:
        """Remove an extra property; raises ``KeyError`` if absent."""
        del self.additional_properties[key]

    def __contains__(self, key: str) -> bool:
        """Report whether ``key`` is an extra property (modelled fields are not included)."""
        return key in self.additional_properties
@@ -17,48 +17,36 @@ from attrs import field as _attrs_field
17
17
 
18
18
  from ..types import UNSET, Unset
19
19
 
20
- T = TypeVar("T", bound="EvaluationRequest")
20
+ T = TypeVar("T", bound="ClassificationEvaluationRequest")
21
21
 
22
22
 
23
23
  @_attrs_define
24
- class EvaluationRequest:
24
+ class ClassificationEvaluationRequest:
25
25
  """
26
26
  Attributes:
27
+ datasource_id (str):
27
28
  datasource_label_column (str):
28
29
  datasource_value_column (str):
29
- datasource_id (Union[None, Unset, str]):
30
- datasource_name (Union[None, Unset, str]):
31
30
  memoryset_override_id (Union[None, Unset, str]):
32
31
  record_telemetry (Union[Unset, bool]): Default: False.
33
32
  telemetry_tags (Union[List[str], None, Unset]):
34
33
  """
35
34
 
35
+ datasource_id: str
36
36
  datasource_label_column: str
37
37
  datasource_value_column: str
38
- datasource_id: Union[None, Unset, str] = UNSET
39
- datasource_name: Union[None, Unset, str] = UNSET
40
38
  memoryset_override_id: Union[None, Unset, str] = UNSET
41
39
  record_telemetry: Union[Unset, bool] = False
42
40
  telemetry_tags: Union[List[str], None, Unset] = UNSET
43
41
  additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
44
42
 
45
43
  def to_dict(self) -> dict[str, Any]:
44
+ datasource_id = self.datasource_id
45
+
46
46
  datasource_label_column = self.datasource_label_column
47
47
 
48
48
  datasource_value_column = self.datasource_value_column
49
49
 
50
- datasource_id: Union[None, Unset, str]
51
- if isinstance(self.datasource_id, Unset):
52
- datasource_id = UNSET
53
- else:
54
- datasource_id = self.datasource_id
55
-
56
- datasource_name: Union[None, Unset, str]
57
- if isinstance(self.datasource_name, Unset):
58
- datasource_name = UNSET
59
- else:
60
- datasource_name = self.datasource_name
61
-
62
50
  memoryset_override_id: Union[None, Unset, str]
63
51
  if isinstance(self.memoryset_override_id, Unset):
64
52
  memoryset_override_id = UNSET
@@ -80,14 +68,11 @@ class EvaluationRequest:
80
68
  field_dict.update(self.additional_properties)
81
69
  field_dict.update(
82
70
  {
71
+ "datasource_id": datasource_id,
83
72
  "datasource_label_column": datasource_label_column,
84
73
  "datasource_value_column": datasource_value_column,
85
74
  }
86
75
  )
87
- if datasource_id is not UNSET:
88
- field_dict["datasource_id"] = datasource_id
89
- if datasource_name is not UNSET:
90
- field_dict["datasource_name"] = datasource_name
91
76
  if memoryset_override_id is not UNSET:
92
77
  field_dict["memoryset_override_id"] = memoryset_override_id
93
78
  if record_telemetry is not UNSET:
@@ -100,28 +85,12 @@ class EvaluationRequest:
100
85
  @classmethod
101
86
  def from_dict(cls: Type[T], src_dict: dict[str, Any]) -> T:
102
87
  d = src_dict.copy()
88
+ datasource_id = d.pop("datasource_id")
89
+
103
90
  datasource_label_column = d.pop("datasource_label_column")
104
91
 
105
92
  datasource_value_column = d.pop("datasource_value_column")
106
93
 
107
- def _parse_datasource_id(data: object) -> Union[None, Unset, str]:
108
- if data is None:
109
- return data
110
- if isinstance(data, Unset):
111
- return data
112
- return cast(Union[None, Unset, str], data)
113
-
114
- datasource_id = _parse_datasource_id(d.pop("datasource_id", UNSET))
115
-
116
- def _parse_datasource_name(data: object) -> Union[None, Unset, str]:
117
- if data is None:
118
- return data
119
- if isinstance(data, Unset):
120
- return data
121
- return cast(Union[None, Unset, str], data)
122
-
123
- datasource_name = _parse_datasource_name(d.pop("datasource_name", UNSET))
124
-
125
94
  def _parse_memoryset_override_id(data: object) -> Union[None, Unset, str]:
126
95
  if data is None:
127
96
  return data
@@ -150,18 +119,17 @@ class EvaluationRequest:
150
119
 
151
120
  telemetry_tags = _parse_telemetry_tags(d.pop("telemetry_tags", UNSET))
152
121
 
153
- evaluation_request = cls(
122
+ classification_evaluation_request = cls(
123
+ datasource_id=datasource_id,
154
124
  datasource_label_column=datasource_label_column,
155
125
  datasource_value_column=datasource_value_column,
156
- datasource_id=datasource_id,
157
- datasource_name=datasource_name,
158
126
  memoryset_override_id=memoryset_override_id,
159
127
  record_telemetry=record_telemetry,
160
128
  telemetry_tags=telemetry_tags,
161
129
  )
162
130
 
163
- evaluation_request.additional_properties = d
164
- return evaluation_request
131
+ classification_evaluation_request.additional_properties = d
132
+ return classification_evaluation_request
165
133
 
166
134
  @property
167
135
  def additional_keys(self) -> list[str]:
@@ -18,39 +18,43 @@ from attrs import field as _attrs_field
18
18
  from ..types import UNSET, Unset
19
19
 
20
20
  if TYPE_CHECKING:
21
- from ..models.precision_recall_curve import PrecisionRecallCurve
21
+ from ..models.pr_curve import PRCurve
22
22
  from ..models.roc_curve import ROCCurve
23
23
 
24
24
 
25
- T = TypeVar("T", bound="ClassificationEvaluationResult")
25
+ T = TypeVar("T", bound="ClassificationMetrics")
26
26
 
27
27
 
28
28
  @_attrs_define
29
- class ClassificationEvaluationResult:
29
+ class ClassificationMetrics:
30
30
  """
31
31
  Attributes:
32
32
  f1_score (float):
33
33
  accuracy (float):
34
34
  loss (float):
35
- precision_recall_curve (Union['PrecisionRecallCurve', None]):
36
- roc_curve (Union['ROCCurve', None]):
37
35
  anomaly_score_mean (Union[None, Unset, float]):
38
36
  anomaly_score_median (Union[None, Unset, float]):
39
37
  anomaly_score_variance (Union[None, Unset, float]):
38
+ roc_auc (Union[None, Unset, float]):
39
+ pr_auc (Union[None, Unset, float]):
40
+ pr_curve (Union['PRCurve', None, Unset]):
41
+ roc_curve (Union['ROCCurve', None, Unset]):
40
42
  """
41
43
 
42
44
  f1_score: float
43
45
  accuracy: float
44
46
  loss: float
45
- precision_recall_curve: Union["PrecisionRecallCurve", None]
46
- roc_curve: Union["ROCCurve", None]
47
47
  anomaly_score_mean: Union[None, Unset, float] = UNSET
48
48
  anomaly_score_median: Union[None, Unset, float] = UNSET
49
49
  anomaly_score_variance: Union[None, Unset, float] = UNSET
50
+ roc_auc: Union[None, Unset, float] = UNSET
51
+ pr_auc: Union[None, Unset, float] = UNSET
52
+ pr_curve: Union["PRCurve", None, Unset] = UNSET
53
+ roc_curve: Union["ROCCurve", None, Unset] = UNSET
50
54
  additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
51
55
 
52
56
  def to_dict(self) -> dict[str, Any]:
53
- from ..models.precision_recall_curve import PrecisionRecallCurve
57
+ from ..models.pr_curve import PRCurve
54
58
  from ..models.roc_curve import ROCCurve
55
59
 
56
60
  f1_score = self.f1_score
@@ -59,18 +63,6 @@ class ClassificationEvaluationResult:
59
63
 
60
64
  loss = self.loss
61
65
 
62
- precision_recall_curve: Union[Dict[str, Any], None]
63
- if isinstance(self.precision_recall_curve, PrecisionRecallCurve):
64
- precision_recall_curve = self.precision_recall_curve.to_dict()
65
- else:
66
- precision_recall_curve = self.precision_recall_curve
67
-
68
- roc_curve: Union[Dict[str, Any], None]
69
- if isinstance(self.roc_curve, ROCCurve):
70
- roc_curve = self.roc_curve.to_dict()
71
- else:
72
- roc_curve = self.roc_curve
73
-
74
66
  anomaly_score_mean: Union[None, Unset, float]
75
67
  if isinstance(self.anomaly_score_mean, Unset):
76
68
  anomaly_score_mean = UNSET
@@ -89,6 +81,34 @@ class ClassificationEvaluationResult:
89
81
  else:
90
82
  anomaly_score_variance = self.anomaly_score_variance
91
83
 
84
+ roc_auc: Union[None, Unset, float]
85
+ if isinstance(self.roc_auc, Unset):
86
+ roc_auc = UNSET
87
+ else:
88
+ roc_auc = self.roc_auc
89
+
90
+ pr_auc: Union[None, Unset, float]
91
+ if isinstance(self.pr_auc, Unset):
92
+ pr_auc = UNSET
93
+ else:
94
+ pr_auc = self.pr_auc
95
+
96
+ pr_curve: Union[Dict[str, Any], None, Unset]
97
+ if isinstance(self.pr_curve, Unset):
98
+ pr_curve = UNSET
99
+ elif isinstance(self.pr_curve, PRCurve):
100
+ pr_curve = self.pr_curve.to_dict()
101
+ else:
102
+ pr_curve = self.pr_curve
103
+
104
+ roc_curve: Union[Dict[str, Any], None, Unset]
105
+ if isinstance(self.roc_curve, Unset):
106
+ roc_curve = UNSET
107
+ elif isinstance(self.roc_curve, ROCCurve):
108
+ roc_curve = self.roc_curve.to_dict()
109
+ else:
110
+ roc_curve = self.roc_curve
111
+
92
112
  field_dict: dict[str, Any] = {}
93
113
  field_dict.update(self.additional_properties)
94
114
  field_dict.update(
@@ -96,8 +116,6 @@ class ClassificationEvaluationResult:
96
116
  "f1_score": f1_score,
97
117
  "accuracy": accuracy,
98
118
  "loss": loss,
99
- "precision_recall_curve": precision_recall_curve,
100
- "roc_curve": roc_curve,
101
119
  }
102
120
  )
103
121
  if anomaly_score_mean is not UNSET:
@@ -106,12 +124,20 @@ class ClassificationEvaluationResult:
106
124
  field_dict["anomaly_score_median"] = anomaly_score_median
107
125
  if anomaly_score_variance is not UNSET:
108
126
  field_dict["anomaly_score_variance"] = anomaly_score_variance
127
+ if roc_auc is not UNSET:
128
+ field_dict["roc_auc"] = roc_auc
129
+ if pr_auc is not UNSET:
130
+ field_dict["pr_auc"] = pr_auc
131
+ if pr_curve is not UNSET:
132
+ field_dict["pr_curve"] = pr_curve
133
+ if roc_curve is not UNSET:
134
+ field_dict["roc_curve"] = roc_curve
109
135
 
110
136
  return field_dict
111
137
 
112
138
  @classmethod
113
139
  def from_dict(cls: Type[T], src_dict: dict[str, Any]) -> T:
114
- from ..models.precision_recall_curve import PrecisionRecallCurve
140
+ from ..models.pr_curve import PRCurve
115
141
  from ..models.roc_curve import ROCCurve
116
142
 
117
143
  d = src_dict.copy()
@@ -121,76 +147,100 @@ class ClassificationEvaluationResult:
121
147
 
122
148
  loss = d.pop("loss")
123
149
 
124
- def _parse_precision_recall_curve(data: object) -> Union["PrecisionRecallCurve", None]:
150
+ def _parse_anomaly_score_mean(data: object) -> Union[None, Unset, float]:
125
151
  if data is None:
126
152
  return data
127
- try:
128
- if not isinstance(data, dict):
129
- raise TypeError()
130
- precision_recall_curve_type_0 = PrecisionRecallCurve.from_dict(data)
131
-
132
- return precision_recall_curve_type_0
133
- except: # noqa: E722
134
- pass
135
- return cast(Union["PrecisionRecallCurve", None], data)
153
+ if isinstance(data, Unset):
154
+ return data
155
+ return cast(Union[None, Unset, float], data)
136
156
 
137
- precision_recall_curve = _parse_precision_recall_curve(d.pop("precision_recall_curve"))
157
+ anomaly_score_mean = _parse_anomaly_score_mean(d.pop("anomaly_score_mean", UNSET))
138
158
 
139
- def _parse_roc_curve(data: object) -> Union["ROCCurve", None]:
159
+ def _parse_anomaly_score_median(data: object) -> Union[None, Unset, float]:
140
160
  if data is None:
141
161
  return data
142
- try:
143
- if not isinstance(data, dict):
144
- raise TypeError()
145
- roc_curve_type_0 = ROCCurve.from_dict(data)
146
-
147
- return roc_curve_type_0
148
- except: # noqa: E722
149
- pass
150
- return cast(Union["ROCCurve", None], data)
162
+ if isinstance(data, Unset):
163
+ return data
164
+ return cast(Union[None, Unset, float], data)
151
165
 
152
- roc_curve = _parse_roc_curve(d.pop("roc_curve"))
166
+ anomaly_score_median = _parse_anomaly_score_median(d.pop("anomaly_score_median", UNSET))
153
167
 
154
- def _parse_anomaly_score_mean(data: object) -> Union[None, Unset, float]:
168
+ def _parse_anomaly_score_variance(data: object) -> Union[None, Unset, float]:
155
169
  if data is None:
156
170
  return data
157
171
  if isinstance(data, Unset):
158
172
  return data
159
173
  return cast(Union[None, Unset, float], data)
160
174
 
161
- anomaly_score_mean = _parse_anomaly_score_mean(d.pop("anomaly_score_mean", UNSET))
175
+ anomaly_score_variance = _parse_anomaly_score_variance(d.pop("anomaly_score_variance", UNSET))
162
176
 
163
- def _parse_anomaly_score_median(data: object) -> Union[None, Unset, float]:
177
+ def _parse_roc_auc(data: object) -> Union[None, Unset, float]:
164
178
  if data is None:
165
179
  return data
166
180
  if isinstance(data, Unset):
167
181
  return data
168
182
  return cast(Union[None, Unset, float], data)
169
183
 
170
- anomaly_score_median = _parse_anomaly_score_median(d.pop("anomaly_score_median", UNSET))
184
+ roc_auc = _parse_roc_auc(d.pop("roc_auc", UNSET))
171
185
 
172
- def _parse_anomaly_score_variance(data: object) -> Union[None, Unset, float]:
186
+ def _parse_pr_auc(data: object) -> Union[None, Unset, float]:
173
187
  if data is None:
174
188
  return data
175
189
  if isinstance(data, Unset):
176
190
  return data
177
191
  return cast(Union[None, Unset, float], data)
178
192
 
179
- anomaly_score_variance = _parse_anomaly_score_variance(d.pop("anomaly_score_variance", UNSET))
193
+ pr_auc = _parse_pr_auc(d.pop("pr_auc", UNSET))
194
+
195
+ def _parse_pr_curve(data: object) -> Union["PRCurve", None, Unset]:
196
+ if data is None:
197
+ return data
198
+ if isinstance(data, Unset):
199
+ return data
200
+ try:
201
+ if not isinstance(data, dict):
202
+ raise TypeError()
203
+ pr_curve_type_0 = PRCurve.from_dict(data)
204
+
205
+ return pr_curve_type_0
206
+ except: # noqa: E722
207
+ pass
208
+ return cast(Union["PRCurve", None, Unset], data)
209
+
210
+ pr_curve = _parse_pr_curve(d.pop("pr_curve", UNSET))
180
211
 
181
- classification_evaluation_result = cls(
212
+ def _parse_roc_curve(data: object) -> Union["ROCCurve", None, Unset]:
213
+ if data is None:
214
+ return data
215
+ if isinstance(data, Unset):
216
+ return data
217
+ try:
218
+ if not isinstance(data, dict):
219
+ raise TypeError()
220
+ roc_curve_type_0 = ROCCurve.from_dict(data)
221
+
222
+ return roc_curve_type_0
223
+ except: # noqa: E722
224
+ pass
225
+ return cast(Union["ROCCurve", None, Unset], data)
226
+
227
+ roc_curve = _parse_roc_curve(d.pop("roc_curve", UNSET))
228
+
229
+ classification_metrics = cls(
182
230
  f1_score=f1_score,
183
231
  accuracy=accuracy,
184
232
  loss=loss,
185
- precision_recall_curve=precision_recall_curve,
186
- roc_curve=roc_curve,
187
233
  anomaly_score_mean=anomaly_score_mean,
188
234
  anomaly_score_median=anomaly_score_median,
189
235
  anomaly_score_variance=anomaly_score_variance,
236
+ roc_auc=roc_auc,
237
+ pr_auc=pr_auc,
238
+ pr_curve=pr_curve,
239
+ roc_curve=roc_curve,
190
240
  )
191
241
 
192
- classification_evaluation_result.additional_properties = d
193
- return classification_evaluation_result
242
+ classification_metrics.additional_properties = d
243
+ return classification_metrics
194
244
 
195
245
  @property
196
246
  def additional_keys(self) -> list[str]:
@@ -20,11 +20,11 @@ from dateutil.parser import isoparse
20
20
 
21
21
  from ..models.rac_head_type import RACHeadType
22
22
 
23
- T = TypeVar("T", bound="RACModelMetadata")
23
+ T = TypeVar("T", bound="ClassificationModelMetadata")
24
24
 
25
25
 
26
26
  @_attrs_define
27
- class RACModelMetadata:
27
+ class ClassificationModelMetadata:
28
28
  """
29
29
  Attributes:
30
30
  id (str):
@@ -32,16 +32,17 @@ class RACModelMetadata:
32
32
  name (str):
33
33
  description (Union[None, str]):
34
34
  version (int):
35
- num_classes (int):
36
- head_type (RACHeadType):
37
35
  memoryset_id (str):
38
36
  memory_lookup_count (int):
39
- weigh_memories (Union[None, bool]):
40
- min_memory_weight (Union[None, float]):
41
37
  storage_path (str):
42
38
  memoryset_collection_name (str):
43
39
  created_at (datetime.datetime):
44
40
  updated_at (datetime.datetime):
41
+ locked (bool):
42
+ num_classes (int):
43
+ head_type (RACHeadType):
44
+ weigh_memories (Union[None, bool]):
45
+ min_memory_weight (Union[None, float]):
45
46
  """
46
47
 
47
48
  id: str
@@ -49,16 +50,17 @@ class RACModelMetadata:
49
50
  name: str
50
51
  description: Union[None, str]
51
52
  version: int
52
- num_classes: int
53
- head_type: RACHeadType
54
53
  memoryset_id: str
55
54
  memory_lookup_count: int
56
- weigh_memories: Union[None, bool]
57
- min_memory_weight: Union[None, float]
58
55
  storage_path: str
59
56
  memoryset_collection_name: str
60
57
  created_at: datetime.datetime
61
58
  updated_at: datetime.datetime
59
+ locked: bool
60
+ num_classes: int
61
+ head_type: RACHeadType
62
+ weigh_memories: Union[None, bool]
63
+ min_memory_weight: Union[None, float]
62
64
  additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
63
65
 
64
66
  def to_dict(self) -> dict[str, Any]:
@@ -73,20 +75,10 @@ class RACModelMetadata:
73
75
 
74
76
  version = self.version
75
77
 
76
- num_classes = self.num_classes
77
-
78
- head_type = self.head_type.value if isinstance(self.head_type, Enum) else self.head_type
79
-
80
78
  memoryset_id = self.memoryset_id
81
79
 
82
80
  memory_lookup_count = self.memory_lookup_count
83
81
 
84
- weigh_memories: Union[None, bool]
85
- weigh_memories = self.weigh_memories
86
-
87
- min_memory_weight: Union[None, float]
88
- min_memory_weight = self.min_memory_weight
89
-
90
82
  storage_path = self.storage_path
91
83
 
92
84
  memoryset_collection_name = self.memoryset_collection_name
@@ -95,6 +87,18 @@ class RACModelMetadata:
95
87
 
96
88
  updated_at = self.updated_at.isoformat()
97
89
 
90
+ locked = self.locked
91
+
92
+ num_classes = self.num_classes
93
+
94
+ head_type = self.head_type.value if isinstance(self.head_type, Enum) else self.head_type
95
+
96
+ weigh_memories: Union[None, bool]
97
+ weigh_memories = self.weigh_memories
98
+
99
+ min_memory_weight: Union[None, float]
100
+ min_memory_weight = self.min_memory_weight
101
+
98
102
  field_dict: dict[str, Any] = {}
99
103
  field_dict.update(self.additional_properties)
100
104
  field_dict.update(
@@ -104,16 +108,17 @@ class RACModelMetadata:
104
108
  "name": name,
105
109
  "description": description,
106
110
  "version": version,
107
- "num_classes": num_classes,
108
- "head_type": head_type,
109
111
  "memoryset_id": memoryset_id,
110
112
  "memory_lookup_count": memory_lookup_count,
111
- "weigh_memories": weigh_memories,
112
- "min_memory_weight": min_memory_weight,
113
113
  "storage_path": storage_path,
114
114
  "memoryset_collection_name": memoryset_collection_name,
115
115
  "created_at": created_at,
116
116
  "updated_at": updated_at,
117
+ "locked": locked,
118
+ "num_classes": num_classes,
119
+ "head_type": head_type,
120
+ "weigh_memories": weigh_memories,
121
+ "min_memory_weight": min_memory_weight,
117
122
  }
118
123
  )
119
124
 
@@ -137,14 +142,24 @@ class RACModelMetadata:
137
142
 
138
143
  version = d.pop("version")
139
144
 
140
- num_classes = d.pop("num_classes")
141
-
142
- head_type = RACHeadType(d.pop("head_type"))
143
-
144
145
  memoryset_id = d.pop("memoryset_id")
145
146
 
146
147
  memory_lookup_count = d.pop("memory_lookup_count")
147
148
 
149
+ storage_path = d.pop("storage_path")
150
+
151
+ memoryset_collection_name = d.pop("memoryset_collection_name")
152
+
153
+ created_at = isoparse(d.pop("created_at"))
154
+
155
+ updated_at = isoparse(d.pop("updated_at"))
156
+
157
+ locked = d.pop("locked")
158
+
159
+ num_classes = d.pop("num_classes")
160
+
161
+ head_type = RACHeadType(d.pop("head_type"))
162
+
148
163
  def _parse_weigh_memories(data: object) -> Union[None, bool]:
149
164
  if data is None:
150
165
  return data
@@ -159,34 +174,27 @@ class RACModelMetadata:
159
174
 
160
175
  min_memory_weight = _parse_min_memory_weight(d.pop("min_memory_weight"))
161
176
 
162
- storage_path = d.pop("storage_path")
163
-
164
- memoryset_collection_name = d.pop("memoryset_collection_name")
165
-
166
- created_at = isoparse(d.pop("created_at"))
167
-
168
- updated_at = isoparse(d.pop("updated_at"))
169
-
170
- rac_model_metadata = cls(
177
+ classification_model_metadata = cls(
171
178
  id=id,
172
179
  org_id=org_id,
173
180
  name=name,
174
181
  description=description,
175
182
  version=version,
176
- num_classes=num_classes,
177
- head_type=head_type,
178
183
  memoryset_id=memoryset_id,
179
184
  memory_lookup_count=memory_lookup_count,
180
- weigh_memories=weigh_memories,
181
- min_memory_weight=min_memory_weight,
182
185
  storage_path=storage_path,
183
186
  memoryset_collection_name=memoryset_collection_name,
184
187
  created_at=created_at,
185
188
  updated_at=updated_at,
189
+ locked=locked,
190
+ num_classes=num_classes,
191
+ head_type=head_type,
192
+ weigh_memories=weigh_memories,
193
+ min_memory_weight=min_memory_weight,
186
194
  )
187
195
 
188
- rac_model_metadata.additional_properties = d
189
- return rac_model_metadata
196
+ classification_model_metadata.additional_properties = d
197
+ return classification_model_metadata
190
198
 
191
199
  @property
192
200
  def additional_keys(self) -> list[str]: