rasa-pro 3.14.1__py3-none-any.whl → 3.15.0.dev20251027__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only.
- rasa/builder/config.py +4 -0
- rasa/builder/copilot/copilot.py +28 -9
- rasa/builder/copilot/models.py +251 -32
- rasa/builder/document_retrieval/inkeep_document_retrieval.py +2 -0
- rasa/builder/download.py +1 -1
- rasa/builder/evaluator/__init__.py +0 -0
- rasa/builder/evaluator/constants.py +15 -0
- rasa/builder/evaluator/copilot_executor.py +89 -0
- rasa/builder/evaluator/dataset/models.py +173 -0
- rasa/builder/evaluator/exceptions.py +4 -0
- rasa/builder/evaluator/response_classification/__init__.py +0 -0
- rasa/builder/evaluator/response_classification/constants.py +66 -0
- rasa/builder/evaluator/response_classification/evaluator.py +346 -0
- rasa/builder/evaluator/response_classification/langfuse_runner.py +463 -0
- rasa/builder/evaluator/response_classification/models.py +61 -0
- rasa/builder/evaluator/scripts/__init__.py +0 -0
- rasa/builder/evaluator/scripts/run_response_classification_evaluator.py +152 -0
- rasa/builder/service.py +101 -24
- rasa/builder/telemetry/__init__.py +0 -0
- rasa/builder/telemetry/copilot_langfuse_telemetry.py +384 -0
- rasa/builder/{copilot/telemetry.py → telemetry/copilot_segment_telemetry.py} +21 -3
- rasa/constants.py +1 -0
- rasa/core/policies/flows/flow_executor.py +20 -6
- rasa/core/run.py +15 -4
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +15 -7
- rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +15 -8
- rasa/e2e_test/e2e_config.py +4 -3
- rasa/engine/recipes/default_components.py +16 -6
- rasa/graph_components/validators/default_recipe_validator.py +10 -4
- rasa/nlu/classifiers/diet_classifier.py +2 -0
- rasa/shared/core/slots.py +55 -24
- rasa/shared/utils/common.py +9 -1
- rasa/utils/common.py +9 -0
- rasa/utils/endpoints.py +2 -0
- rasa/utils/installation_utils.py +111 -0
- rasa/utils/tensorflow/callback.py +2 -0
- rasa/utils/train_utils.py +2 -0
- rasa/version.py +1 -1
- {rasa_pro-3.14.1.dist-info → rasa_pro-3.15.0.dev20251027.dist-info}/METADATA +4 -2
- {rasa_pro-3.14.1.dist-info → rasa_pro-3.15.0.dev20251027.dist-info}/RECORD +43 -28
- {rasa_pro-3.14.1.dist-info → rasa_pro-3.15.0.dev20251027.dist-info}/NOTICE +0 -0
- {rasa_pro-3.14.1.dist-info → rasa_pro-3.15.0.dev20251027.dist-info}/WHEEL +0 -0
- {rasa_pro-3.14.1.dist-info → rasa_pro-3.15.0.dev20251027.dist-info}/entry_points.txt +0 -0

rasa/builder/evaluator/dataset/models.py
@@ -0,0 +1,173 @@
from typing import Any, Dict, List, Optional, Union

import structlog
from pydantic import BaseModel, Field, field_validator

from rasa.builder.copilot.models import (
    ChatMessage,
    CopilotContext,
    EventContent,
    ReferenceEntry,
    ResponseCategory,
    create_chat_message_from_dict,
)
from rasa.builder.document_retrieval.models import Document
from rasa.builder.shared.tracker_context import TrackerContext

structlogger = structlog.get_logger()


class DatasetInput(BaseModel):
    """Model for the input field of a dataset entry."""

    message: Optional[str] = None
    tracker_event_attachments: List[EventContent] = Field(default_factory=list)


class DatasetExpectedOutput(BaseModel):
    """Model for the expected_output field of a dataset entry."""

    answer: str
    response_category: ResponseCategory
    references: list[ReferenceEntry]


class DatasetMetadataCopilotAdditionalContext(BaseModel):
    """Model for the copilot_additional_context in metadata."""

    relevant_documents: List[Document] = Field(default_factory=list)
    relevant_assistant_files: Dict[str, str] = Field(default_factory=dict)
    assistant_tracker_context: Optional[Dict[str, Any]] = None
    assistant_logs: str = Field(default="")
    copilot_chat_history: List[ChatMessage] = Field(default_factory=list)

    @field_validator("copilot_chat_history", mode="before")
    @classmethod
    def parse_chat_history(
        cls, v: Union[List[Dict[str, Any]], List[ChatMessage]]
    ) -> List[ChatMessage]:
        """Manually parse chat history messages based on role field."""
        # If already parsed ChatMessage objects, return them as-is
        if (
            v
            and isinstance(v, list)
            and all(isinstance(item, ChatMessage) for item in v)
        ):
            return v  # type: ignore[return-value]

        # Check for mixed types (some ChatMessage, some not)
        if (
            v
            and isinstance(v, list)
            and any(isinstance(item, ChatMessage) for item in v)
        ):
            message = (
                "Mixed types in copilot_chat_history: cannot mix ChatMessage objects "
                "with other types."
            )
            structlogger.error(
                "dataset_entry.parse_chat_history.mixed_types",
                event_info=message,
                chat_history_types=[type(item) for item in v],
            )
            raise ValueError(message)

        # Otherwise, parse from dictionaries
        parsed_messages: List[ChatMessage] = []
        for message_data in v:
            chat_message = create_chat_message_from_dict(message_data)
            parsed_messages.append(chat_message)
        return parsed_messages


class DatasetMetadata(BaseModel):
    """Model for the metadata field of a dataset entry."""

    ids: Dict[str, str] = Field(default_factory=dict)
    copilot_additional_context: DatasetMetadataCopilotAdditionalContext = Field(
        default_factory=DatasetMetadataCopilotAdditionalContext
    )


class DatasetEntry(BaseModel):
    """Pydantic model for dataset entries from Langfuse ExperimentItem."""

    # Basic fields from ExperimentItem
    id: str
    input: DatasetInput
    expected_output: DatasetExpectedOutput
    metadata: DatasetMetadata

    def to_copilot_context(self) -> CopilotContext:
        """Create a CopilotContext from the dataset entry.

        Raises:
            ValueError: If the metadata is None, as it's required for creating a valid
                CopilotContext.

        Returns:
            CopilotContext with all the context information.
        """
        if self.metadata is None:
            message = (
                f"Cannot create CopilotContext from dataset item with id: {self.id}. "
                f"Metadata is required but was None."
            )
            structlogger.error(
                "dataset_entry.to_copilot_context.metadata_is_none",
                event_info=message,
                item_id=self.id,
                item_metadata=self.metadata,
            )
            raise ValueError(message)

        # Parse tracker context if available
        tracker_context = None
        if (
            self.metadata.copilot_additional_context.assistant_tracker_context
            is not None
        ):
            tracker_context = TrackerContext(
                **self.metadata.copilot_additional_context.assistant_tracker_context
            )

        return CopilotContext(
            tracker_context=tracker_context,
            assistant_logs=self.metadata.copilot_additional_context.assistant_logs,
            assistant_files=self.metadata.copilot_additional_context.relevant_assistant_files,
            copilot_chat_history=self.metadata.copilot_additional_context.copilot_chat_history,
        )

    @classmethod
    def from_raw_data(
        cls,
        id: str,
        input_data: Dict[str, Any],
        expected_output_data: Dict[str, Any],
        metadata_data: Dict[str, Any],
    ) -> "DatasetEntry":
        """Create a DatasetEntry from raw dictionary data.

        Args:
            id: The dataset entry ID.
            input_data: Raw input dictionary.
            expected_output_data: Raw expected output dictionary.
            metadata_data: Raw metadata dictionary with all the additional context
                used to generate the Copilot response.

        Returns:
            DatasetEntry with parsed data.
        """
        # Use Pydantic's model_validate to parse nested structures
        dataset_input = DatasetInput.model_validate(input_data)
        dataset_expected_output = DatasetExpectedOutput.model_validate(
            expected_output_data
        )
        dataset_metadata = DatasetMetadata.model_validate(metadata_data)

        return cls(
            id=id,
            input=dataset_input,
            expected_output=dataset_expected_output,
            metadata=dataset_metadata,
        )
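
For orientation (not part of the diff): a minimal sketch of how these models might be used, assuming the raw dictionaries come from a Langfuse dataset item. All field values below are invented, and the string "copilot" is only an assumed serialized form of ResponseCategory.

from rasa.builder.evaluator.dataset.models import DatasetEntry

# Hypothetical raw item; every value here is illustrative only.
entry = DatasetEntry.from_raw_data(
    id="item-1",
    input_data={"message": "How do I add a new flow?"},
    expected_output_data={
        "answer": "You can add a flow by ...",
        "response_category": "copilot",  # assumed serialized ResponseCategory value
        "references": [],
    },
    metadata_data={
        "ids": {"trace_id": "abc123"},
        "copilot_additional_context": {
            "relevant_assistant_files": {"flows.yml": "flows: ..."},
            "assistant_logs": "",
            "copilot_chat_history": [],
        },
    },
)

# Rebuild the context the Copilot saw when the dataset item was recorded.
context = entry.to_copilot_context()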

rasa/builder/evaluator/response_classification/constants.py
@@ -0,0 +1,66 @@
"""Constants for the response classification evaluator."""

from typing import List, Literal

# Options for averaging methods for the Response Classification Evaluator
MICRO_AVERAGING_METHOD: Literal["micro"] = "micro"
MACRO_AVERAGING_METHOD: Literal["macro"] = "macro"
WEIGHTED_AVERAGING_METHOD: Literal["weighted"] = "weighted"

AVERAGING_METHODS: List[Literal["micro", "macro", "weighted"]] = [
    MICRO_AVERAGING_METHOD,
    MACRO_AVERAGING_METHOD,
    WEIGHTED_AVERAGING_METHOD,
]

# Overall evaluation metric names
MICRO_PRECISION_METRIC = "micro_precision"
MACRO_PRECISION_METRIC = "macro_precision"
WEIGHTED_PRECISION_METRIC = "weighted_precision"

MICRO_RECALL_METRIC = "micro_recall"
MACRO_RECALL_METRIC = "macro_recall"
WEIGHTED_RECALL_METRIC = "weighted_recall"

MICRO_F1_METRIC = "micro_f1"
MACRO_F1_METRIC = "macro_f1"
WEIGHTED_F1_METRIC = "weighted_f1"

# Name of the metric counting items skipped due to invalid data
SKIP_COUNT_METRIC = "skipped_items"

# Per-class evaluation metric name templates
PER_CLASS_PRECISION_METRIC_TEMPLATE = "{category}_precision"
PER_CLASS_RECALL_METRIC_TEMPLATE = "{category}_recall"
PER_CLASS_F1_METRIC_TEMPLATE = "{category}_f1"
PER_CLASS_SUPPORT_METRIC_TEMPLATE = "{category}_support"

# Description templates for evaluation metrics
MICRO_PRECISION_DESCRIPTION = "Micro Precision: {value:.3f}"
MACRO_PRECISION_DESCRIPTION = "Macro Precision: {value:.3f}"
WEIGHTED_PRECISION_DESCRIPTION = "Weighted Precision: {value:.3f}"

MICRO_RECALL_DESCRIPTION = "Micro Recall: {value:.3f}"
MACRO_RECALL_DESCRIPTION = "Macro Recall: {value:.3f}"
WEIGHTED_RECALL_DESCRIPTION = "Weighted Recall: {value:.3f}"

MICRO_F1_DESCRIPTION = "Micro F1: {value:.3f}"
MACRO_F1_DESCRIPTION = "Macro F1: {value:.3f}"
WEIGHTED_F1_DESCRIPTION = "Weighted F1: {value:.3f}"

# Skip count metric description
SKIP_COUNT_DESCRIPTION = "Skipped {value} items due to invalid data"

# Per-class description templates
PER_CLASS_PRECISION_DESCRIPTION = "[{category}] Precision: {value:.3f}"
PER_CLASS_RECALL_DESCRIPTION = "[{category}] Recall: {value:.3f}"
PER_CLASS_F1_DESCRIPTION = "[{category}] F1: {value:.3f}"
PER_CLASS_SUPPORT_DESCRIPTION = "[{category}] Support: {value}"

# Experiment configuration
EXPERIMENT_NAME = "Copilot Response Classification Evaluation"
EXPERIMENT_DESCRIPTION = (
    "Evaluating Copilot response classification performance with per-class metrics "
    "and overall averages (micro, macro, weighted). The metrics reported are: "
    "precision, recall, F1, support."
)
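
A small illustration (not part of the diff) of how the templates above expand with str.format; the category label and the numbers are made-up example values:

from rasa.builder.evaluator.response_classification.constants import (
    PER_CLASS_F1_DESCRIPTION,
    PER_CLASS_F1_METRIC_TEMPLATE,
    SKIP_COUNT_DESCRIPTION,
)

# "copilot", 0.875, and 3 are example values only.
PER_CLASS_F1_METRIC_TEMPLATE.format(category="copilot")           # -> "copilot_f1"
PER_CLASS_F1_DESCRIPTION.format(category="copilot", value=0.875)  # -> "[copilot] F1: 0.875"
SKIP_COUNT_DESCRIPTION.format(value=3)                            # -> "Skipped 3 items due to invalid data"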

rasa/builder/evaluator/response_classification/evaluator.py
@@ -0,0 +1,346 @@
from typing import Dict, List, Literal, Optional

import structlog

from rasa.builder.copilot.models import ResponseCategory
from rasa.builder.evaluator.response_classification.constants import (
    MACRO_AVERAGING_METHOD,
    MICRO_AVERAGING_METHOD,
    WEIGHTED_AVERAGING_METHOD,
)
from rasa.builder.evaluator.response_classification.models import (
    ClassificationResult,
    MetricsSummary,
    OverallClassificationMetrics,
    PerClassMetrics,
)

structlogger = structlog.get_logger()


class ResponseClassificationEvaluator:
    def __init__(self):  # type: ignore[no-untyped-def]
        self._classes: List[ResponseCategory] = [
            ResponseCategory.COPILOT,
            ResponseCategory.OUT_OF_SCOPE_DETECTION,
            ResponseCategory.ROLEPLAY_DETECTION,
            ResponseCategory.KNOWLEDGE_BASE_ACCESS_REQUESTED,
            ResponseCategory.ERROR_FALLBACK,
            # TODO: Add the greetings and goodbyes as support once the orchestrator
            # approach is implemented
        ]
        self._true_positives_per_class: Dict[ResponseCategory, int] = {
            clazz: 0 for clazz in self._classes
        }
        self._false_positives_per_class: Dict[ResponseCategory, int] = {
            clazz: 0 for clazz in self._classes
        }
        self._false_negatives_per_class: Dict[ResponseCategory, int] = {
            clazz: 0 for clazz in self._classes
        }
        self._support_per_class: Dict[ResponseCategory, int] = {
            clazz: 0 for clazz in self._classes
        }

        self._evaluated = False

    @property
    def metrics_summary(self) -> Optional[MetricsSummary]:
        """Get the metrics summary.

        Returns:
            MetricsSummary with structured per-class and overall metrics if
            the evaluator has been run on the data, otherwise None.
        """
        if not self._evaluated:
            structlogger.warning(
                "evaluator.response_classification_evaluator"
                ".metrics_summary.not_evaluated",
                event_info="Evaluator has not been run yet. Returning None.",
            )
            return None

        return self._get_metrics_summary()

    def reset(self) -> None:
        self._true_positives_per_class = {clazz: 0 for clazz in self._classes}
        self._false_positives_per_class = {clazz: 0 for clazz in self._classes}
        self._false_negatives_per_class = {clazz: 0 for clazz in self._classes}
        self._support_per_class = {clazz: 0 for clazz in self._classes}
        self._evaluated = False

    def evaluate(self, item_results: List[ClassificationResult]) -> MetricsSummary:
        """Evaluate the classifier on the given item results."""
        if self._evaluated:
            structlogger.warning(
                "evaluator.response_classification_evaluator.evaluate.already_evaluated",
                event_info="Evaluator already evaluated. Resetting evaluator.",
            )
            self.reset()

        for result in item_results:
            # Skip the item and log a warning if the expected class is not recognized
            if result.expected not in self._classes:
                structlogger.warning(
                    "evaluator.response_classification_evaluator"
                    ".evaluate.class_not_recognized",
                    event_info=(
                        f"Class '{result.expected}' is not recognized. "
                        f"Skipping evaluation for this class."
                    ),
                    expected_class=result.expected,
                    classes=self._classes,
                )
                continue

            # Update support for the expected class
            if result.expected in self._support_per_class:
                self._support_per_class[result.expected] += 1

            # Calculate TP, FP, FN per class
            for clazz in self._classes:
                if result.prediction == clazz and result.expected == clazz:
                    self._true_positives_per_class[clazz] += 1

                elif result.prediction == clazz and result.expected != clazz:
                    self._false_positives_per_class[clazz] += 1

                elif result.prediction != clazz and result.expected == clazz:
                    self._false_negatives_per_class[clazz] += 1

        self._evaluated = True
        return self._get_metrics_summary()

    def calculate_precision_per_class(self, clazz: ResponseCategory) -> float:
        """Calculate precision for a specific response category."""
        tp = self._true_positives_per_class.get(clazz, 0)
        fp = self._false_positives_per_class.get(clazz, 0)

        if tp + fp == 0:
            return 0.0

        return tp / (tp + fp)

    def calculate_recall_per_class(self, clazz: ResponseCategory) -> float:
        """Calculate recall for a specific response category."""
        tp = self._true_positives_per_class.get(clazz, 0)
        fn = self._false_negatives_per_class.get(clazz, 0)

        if tp + fn == 0:
            return 0.0

        return tp / (tp + fn)

    def calculate_f1_per_class(self, clazz: ResponseCategory) -> float:
        """Calculate F1 score for a specific response category."""
        precision = self.calculate_precision_per_class(clazz)
        recall = self.calculate_recall_per_class(clazz)

        if precision + recall == 0:
            return 0.0

        return 2 * (precision * recall) / (precision + recall)

    def calculate_precision(
        self, average: Literal["micro", "macro", "weighted"] = MICRO_AVERAGING_METHOD
    ) -> float:
        """Calculate precision with specified averaging method."""
        if average == MICRO_AVERAGING_METHOD:
            return self._calculate_micro_precision()
        elif average == MACRO_AVERAGING_METHOD:
            return self._calculate_macro_precision()
        elif average == WEIGHTED_AVERAGING_METHOD:
            return self._calculate_weighted_avg_precision()
        else:
            raise ValueError(f"Invalid averaging method: {average}")

    def _calculate_micro_precision(self) -> float:
        """Calculate micro-averaged precision.

        Calculates the metric globally by aggregating the total true positives and
        false positives across all classes. Each sample contributes equally to the
        final score.
        """
        total_tp = sum(self._true_positives_per_class.values())
        total_fp = sum(self._false_positives_per_class.values())

        if total_tp + total_fp == 0:
            return 0.0

        return total_tp / (total_tp + total_fp)

    def _calculate_macro_precision(self) -> float:
        """Calculate macro-averaged precision.

        Calculates the metric independently for each class and then takes the
        unweighted average. Each class contributes equally.
        """
        precisions = [
            self.calculate_precision_per_class(clazz) for clazz in self._classes
        ]
        return sum(precisions) / len(precisions) if precisions else 0.0

    def _calculate_weighted_avg_precision(self) -> float:
        """Calculate weighted-averaged precision.

        Calculates the metric independently for each class and then takes the average
        weighted by the class support (number of true samples per class).
        """
        total_support = sum(self._support_per_class.values())
        if total_support == 0:
            return 0.0

        weighted_sum = 0.0
        for clazz in self._classes:
            precision = self.calculate_precision_per_class(clazz)
            support = self._support_per_class.get(clazz, 0)
            weighted_sum += precision * support

        return weighted_sum / total_support

    def calculate_recall(
        self, average: Literal["micro", "macro", "weighted"] = MICRO_AVERAGING_METHOD
    ) -> float:
        """Calculate recall with specified averaging method."""
        if average == MICRO_AVERAGING_METHOD:
            return self._calculate_micro_recall()
        elif average == MACRO_AVERAGING_METHOD:
            return self._calculate_macro_recall()
        elif average == WEIGHTED_AVERAGING_METHOD:
            return self._calculate_weighted_avg_recall()
        else:
            raise ValueError(f"Invalid averaging method: {average}")

    def _calculate_micro_recall(self) -> float:
        """Calculate micro-averaged recall.

        Calculates the metric globally by aggregating the total true positives and
        false negatives across all classes. Each sample contributes equally to the
        final score.
        """
        total_tp = sum(self._true_positives_per_class.values())
        total_fn = sum(self._false_negatives_per_class.values())

        if total_tp + total_fn == 0:
            return 0.0

        return total_tp / (total_tp + total_fn)

    def _calculate_macro_recall(self) -> float:
        """Calculate macro-averaged recall.

        Calculates the metric independently for each class and then takes the
        unweighted average. Each class contributes equally.
        """
        recalls = [self.calculate_recall_per_class(clazz) for clazz in self._classes]
        return sum(recalls) / len(recalls) if recalls else 0.0

    def _calculate_weighted_avg_recall(self) -> float:
        """Calculate weighted-averaged recall.

        Calculates the metric independently for each class and then takes the average
        weighted by the class support (number of true samples per class).
        """
        total_support = sum(self._support_per_class.values())
        if total_support == 0:
            return 0.0

        weighted_sum = 0.0
        for clazz in self._classes:
            recall = self.calculate_recall_per_class(clazz)
            support = self._support_per_class.get(clazz, 0)
            weighted_sum += recall * support

        return weighted_sum / total_support

    def calculate_f1(
        self, average: Literal["micro", "macro", "weighted"] = MICRO_AVERAGING_METHOD
    ) -> float:
        """Calculate F1 score with specified averaging method."""
        if average == MICRO_AVERAGING_METHOD:
            return self._calculate_micro_f1()
        elif average == MACRO_AVERAGING_METHOD:
            return self._calculate_macro_f1()
        elif average == WEIGHTED_AVERAGING_METHOD:
            return self._calculate_weighted_avg_f1()
        else:
            raise ValueError(f"Invalid averaging method: {average}")

    def _calculate_micro_f1(self) -> float:
        """Calculate micro-averaged F1 score.

        Calculates the metric globally by aggregating the total true positives, false
        positives, and false negatives across all classes. Each sample contributes
        equally to the final score.
        """
        micro_precision = self._calculate_micro_precision()
        micro_recall = self._calculate_micro_recall()

        if micro_precision + micro_recall == 0:
            return 0.0

        return 2 * (micro_precision * micro_recall) / (micro_precision + micro_recall)

    def _calculate_macro_f1(self) -> float:
        """Calculate macro-averaged F1 score.

        Calculates the metric independently for each class and then takes the
        unweighted average. Each class contributes equally.
        """
        f1_scores = [self.calculate_f1_per_class(clazz) for clazz in self._classes]
        return sum(f1_scores) / len(f1_scores) if f1_scores else 0.0

    def _calculate_weighted_avg_f1(self) -> float:
        """Calculate weighted F1 score.

        Calculates the metric independently for each class and then takes the average
        weighted by the class support (number of true samples per class).
        """
        total_support = sum(self._support_per_class.values())
        if total_support == 0:
            return 0.0

        weighted_sum = 0.0
        for clazz in self._classes:
            f1 = self.calculate_f1_per_class(clazz)
            support = self._support_per_class.get(clazz, 0)
            weighted_sum += f1 * support

        return weighted_sum / total_support

    def _get_metrics_summary(self) -> MetricsSummary:
        """Get the metrics summary without Optional wrapper.

        This method assumes the evaluator has been evaluated and will always
        return a MetricsSummary.
        """
        # Build per-class metrics
        per_class_metrics: Dict[ResponseCategory, PerClassMetrics] = {}
        for clazz in self._classes:
            per_class_metrics[clazz] = PerClassMetrics(
                precision=self.calculate_precision_per_class(clazz),
                recall=self.calculate_recall_per_class(clazz),
                f1=self.calculate_f1_per_class(clazz),
                support=self._support_per_class.get(clazz, 0),
                true_positives=self._true_positives_per_class.get(clazz, 0),
                false_positives=self._false_positives_per_class.get(clazz, 0),
                false_negatives=self._false_negatives_per_class.get(clazz, 0),
            )

        # Build overall metrics
        overall_metrics = OverallClassificationMetrics(
            micro_precision=self.calculate_precision(MICRO_AVERAGING_METHOD),
            macro_precision=self.calculate_precision(MACRO_AVERAGING_METHOD),
            weighted_avg_precision=self.calculate_precision(WEIGHTED_AVERAGING_METHOD),
            micro_recall=self.calculate_recall(MICRO_AVERAGING_METHOD),
            macro_recall=self.calculate_recall(MACRO_AVERAGING_METHOD),
            weighted_avg_recall=self.calculate_recall(WEIGHTED_AVERAGING_METHOD),
            micro_f1=self.calculate_f1(MICRO_AVERAGING_METHOD),
            macro_f1=self.calculate_f1(MACRO_AVERAGING_METHOD),
            weighted_avg_f1=self.calculate_f1(WEIGHTED_AVERAGING_METHOD),
            support=sum(self._support_per_class.values()),
            true_positives=sum(self._true_positives_per_class.values()),
            false_positives=sum(self._false_positives_per_class.values()),
            false_negatives=sum(self._false_negatives_per_class.values()),
        )
        return MetricsSummary(per_class=per_class_metrics, overall=overall_metrics)
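
A minimal usage sketch (not part of the diff), assuming ClassificationResult from response_classification/models.py (added in this release but not shown here) can be constructed with prediction and expected keyword arguments, as the evaluate() loop implies; the attribute access on the returned MetricsSummary follows the fields populated above.

from rasa.builder.copilot.models import ResponseCategory
from rasa.builder.evaluator.response_classification.evaluator import (
    ResponseClassificationEvaluator,
)
from rasa.builder.evaluator.response_classification.models import ClassificationResult

# Two hypothetical items: one correct prediction, one confusion between classes.
results = [
    ClassificationResult(
        prediction=ResponseCategory.COPILOT,
        expected=ResponseCategory.COPILOT,
    ),
    ClassificationResult(
        prediction=ResponseCategory.COPILOT,
        expected=ResponseCategory.OUT_OF_SCOPE_DETECTION,
    ),
]

evaluator = ResponseClassificationEvaluator()
summary = evaluator.evaluate(results)

# Overall and per-class metrics come back as a MetricsSummary; the same summary
# is also available via evaluator.metrics_summary once evaluate() has run.
print(summary.overall.micro_precision)
print(summary.per_class[ResponseCategory.COPILOT].f1)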