rasa-pro 3.14.1__py3-none-any.whl → 3.15.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (69)
  1. rasa/builder/config.py +4 -0
  2. rasa/builder/constants.py +5 -0
  3. rasa/builder/copilot/copilot.py +28 -9
  4. rasa/builder/copilot/models.py +251 -32
  5. rasa/builder/document_retrieval/inkeep_document_retrieval.py +2 -0
  6. rasa/builder/download.py +111 -1
  7. rasa/builder/evaluator/__init__.py +0 -0
  8. rasa/builder/evaluator/constants.py +15 -0
  9. rasa/builder/evaluator/copilot_executor.py +89 -0
  10. rasa/builder/evaluator/dataset/models.py +173 -0
  11. rasa/builder/evaluator/exceptions.py +4 -0
  12. rasa/builder/evaluator/response_classification/__init__.py +0 -0
  13. rasa/builder/evaluator/response_classification/constants.py +66 -0
  14. rasa/builder/evaluator/response_classification/evaluator.py +346 -0
  15. rasa/builder/evaluator/response_classification/langfuse_runner.py +463 -0
  16. rasa/builder/evaluator/response_classification/models.py +61 -0
  17. rasa/builder/evaluator/scripts/__init__.py +0 -0
  18. rasa/builder/evaluator/scripts/run_response_classification_evaluator.py +152 -0
  19. rasa/builder/jobs.py +208 -1
  20. rasa/builder/logging_utils.py +25 -24
  21. rasa/builder/main.py +6 -1
  22. rasa/builder/models.py +23 -0
  23. rasa/builder/project_generator.py +29 -10
  24. rasa/builder/service.py +205 -46
  25. rasa/builder/telemetry/__init__.py +0 -0
  26. rasa/builder/telemetry/copilot_langfuse_telemetry.py +384 -0
  27. rasa/builder/{copilot/telemetry.py → telemetry/copilot_segment_telemetry.py} +21 -3
  28. rasa/builder/training_service.py +13 -1
  29. rasa/builder/validation_service.py +2 -1
  30. rasa/constants.py +1 -0
  31. rasa/core/actions/action_clean_stack.py +32 -0
  32. rasa/core/actions/constants.py +4 -0
  33. rasa/core/actions/custom_action_executor.py +70 -12
  34. rasa/core/actions/grpc_custom_action_executor.py +41 -2
  35. rasa/core/actions/http_custom_action_executor.py +49 -25
  36. rasa/core/channels/voice_stream/voice_channel.py +14 -2
  37. rasa/core/policies/flows/flow_executor.py +20 -6
  38. rasa/core/run.py +15 -4
  39. rasa/dialogue_understanding/generator/llm_based_command_generator.py +6 -3
  40. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +15 -7
  41. rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +15 -8
  42. rasa/dialogue_understanding/processor/command_processor.py +49 -7
  43. rasa/e2e_test/e2e_config.py +4 -3
  44. rasa/engine/recipes/default_components.py +16 -6
  45. rasa/graph_components/validators/default_recipe_validator.py +10 -4
  46. rasa/nlu/classifiers/diet_classifier.py +2 -0
  47. rasa/shared/core/slots.py +55 -24
  48. rasa/shared/providers/_configs/azure_openai_client_config.py +4 -5
  49. rasa/shared/providers/_configs/default_litellm_client_config.py +4 -4
  50. rasa/shared/providers/_configs/litellm_router_client_config.py +3 -2
  51. rasa/shared/providers/_configs/openai_client_config.py +5 -7
  52. rasa/shared/providers/_configs/rasa_llm_client_config.py +4 -4
  53. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +4 -4
  54. rasa/shared/providers/llm/_base_litellm_client.py +42 -14
  55. rasa/shared/providers/llm/litellm_router_llm_client.py +38 -15
  56. rasa/shared/providers/llm/self_hosted_llm_client.py +34 -32
  57. rasa/shared/utils/common.py +9 -1
  58. rasa/shared/utils/configs.py +5 -8
  59. rasa/utils/common.py +9 -0
  60. rasa/utils/endpoints.py +8 -0
  61. rasa/utils/installation_utils.py +111 -0
  62. rasa/utils/tensorflow/callback.py +2 -0
  63. rasa/utils/train_utils.py +2 -0
  64. rasa/version.py +1 -1
  65. {rasa_pro-3.14.1.dist-info → rasa_pro-3.15.0a3.dist-info}/METADATA +15 -13
  66. {rasa_pro-3.14.1.dist-info → rasa_pro-3.15.0a3.dist-info}/RECORD +69 -53
  67. {rasa_pro-3.14.1.dist-info → rasa_pro-3.15.0a3.dist-info}/NOTICE +0 -0
  68. {rasa_pro-3.14.1.dist-info → rasa_pro-3.15.0a3.dist-info}/WHEEL +0 -0
  69. {rasa_pro-3.14.1.dist-info → rasa_pro-3.15.0a3.dist-info}/entry_points.txt +0 -0
rasa/builder/evaluator/copilot_executor.py
@@ -0,0 +1,89 @@
+"""Copilot execution utilities for evaluators.
+
+This module provides utilities for running copilot operations in evaluation contexts,
+independent of specific evaluation frameworks like Langfuse.
+"""
+
+from typing import List, Optional
+
+import structlog
+from pydantic import BaseModel
+
+from rasa.builder.config import COPILOT_HANDLER_ROLLING_BUFFER_SIZE
+from rasa.builder.copilot.models import (
+    CopilotContext,
+    CopilotGenerationContext,
+    GeneratedContent,
+    ReferenceSection,
+    ResponseCategory,
+)
+from rasa.builder.llm_service import llm_service
+
+structlogger = structlog.get_logger()
+
+
+class CopilotRunResult(BaseModel):
+    """Result from running the copilot with response handler."""
+
+    complete_response: Optional[str]
+    response_category: Optional[ResponseCategory]
+    reference_section: Optional[ReferenceSection]
+    generation_context: CopilotGenerationContext
+
+
+async def run_copilot_with_response_handler(
+    context: CopilotContext,
+) -> Optional[CopilotRunResult]:
+    """Run the copilot with response handler on the given context.
+
+    This function encapsulates the core copilot execution logic. It handles:
+    - Instantiating the copilot and response handler
+    - Generating a response and extracting the reference section from the given context
+    - Returning structured results
+
+    Args:
+        context: The copilot context to process.
+
+    Returns:
+        CopilotRunResult containing the complete response, category, and generation
+        context, or None if execution fails.
+
+    Raises:
+        Any exceptions from the copilot or response handler execution.
+    """
+    # Instantiate the copilot and response handler
+    copilot = llm_service.instantiate_copilot()
+    copilot_response_handler = llm_service.instantiate_handler(
+        COPILOT_HANDLER_ROLLING_BUFFER_SIZE
+    )
+
+    # Call the copilot to generate a response and handle it with the response
+    # handler
+    (original_stream, generation_context) = await copilot.generate_response(context)
+    intercepted_stream = copilot_response_handler.handle_response(original_stream)
+
+    # Exhaust the stream to get the complete response for evaluation
+    response_chunks: List[str] = []
+    response_category = None
+    async for chunk in intercepted_stream:
+        if not isinstance(chunk, GeneratedContent):
+            continue
+        response_chunks.append(chunk.content)
+        response_category = chunk.response_category
+
+    complete_response = "".join(response_chunks) if response_chunks else None
+
+    # Extract the reference section from the response handler
+    if generation_context.relevant_documents:
+        reference_section = copilot_response_handler.extract_references(
+            generation_context.relevant_documents
+        )
+    else:
+        reference_section = None
+
+    return CopilotRunResult(
+        complete_response=complete_response,
+        response_category=response_category,
+        reference_section=reference_section,
+        generation_context=generation_context,
+    )
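
Not part of the diff: a minimal sketch of how this helper might be driven from an evaluation script. The CopilotContext keyword arguments mirror the ones used by to_copilot_context() in the dataset models below; any additional required fields or runtime configuration (LLM credentials, etc.) are assumptions.

    # Illustrative only -- not part of the released package contents above.
    import asyncio

    from rasa.builder.copilot.models import CopilotContext
    from rasa.builder.evaluator.copilot_executor import run_copilot_with_response_handler


    async def evaluate_single_item() -> None:
        # Field names follow to_copilot_context() in dataset/models.py; the
        # empty defaults are placeholders for illustration.
        context = CopilotContext(
            tracker_context=None,
            assistant_logs="",
            assistant_files={},
            copilot_chat_history=[],
        )
        result = await run_copilot_with_response_handler(context)
        if result is not None:
            print(result.response_category, result.complete_response)


    asyncio.run(evaluate_single_item())
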
rasa/builder/evaluator/dataset/models.py
@@ -0,0 +1,173 @@
+from typing import Any, Dict, List, Optional, Union
+
+import structlog
+from pydantic import BaseModel, Field, field_validator
+
+from rasa.builder.copilot.models import (
+    ChatMessage,
+    CopilotContext,
+    EventContent,
+    ReferenceEntry,
+    ResponseCategory,
+    create_chat_message_from_dict,
+)
+from rasa.builder.document_retrieval.models import Document
+from rasa.builder.shared.tracker_context import TrackerContext
+
+structlogger = structlog.get_logger()
+
+
+class DatasetInput(BaseModel):
+    """Model for the input field of a dataset entry."""
+
+    message: Optional[str] = None
+    tracker_event_attachments: List[EventContent] = Field(default_factory=list)
+
+
+class DatasetExpectedOutput(BaseModel):
+    """Model for the expected_output field of a dataset entry."""
+
+    answer: str
+    response_category: ResponseCategory
+    references: list[ReferenceEntry]
+
+
+class DatasetMetadataCopilotAdditionalContext(BaseModel):
+    """Model for the copilot_additional_context in metadata."""
+
+    relevant_documents: List[Document] = Field(default_factory=list)
+    relevant_assistant_files: Dict[str, str] = Field(default_factory=dict)
+    assistant_tracker_context: Optional[Dict[str, Any]] = None
+    assistant_logs: str = Field(default="")
+    copilot_chat_history: List[ChatMessage] = Field(default_factory=list)
+
+    @field_validator("copilot_chat_history", mode="before")
+    @classmethod
+    def parse_chat_history(
+        cls, v: Union[List[Dict[str, Any]], List[ChatMessage]]
+    ) -> List[ChatMessage]:
+        """Manually parse chat history messages based on role field."""
+        # If already parsed ChatMessage objects, return them as-is
+        if (
+            v
+            and isinstance(v, list)
+            and all(isinstance(item, ChatMessage) for item in v)
+        ):
+            return v  # type: ignore[return-value]
+
+        # Check for mixed types (some ChatMessage, some not)
+        if (
+            v
+            and isinstance(v, list)
+            and any(isinstance(item, ChatMessage) for item in v)
+        ):
+            message = (
+                "Mixed types in copilot_chat_history: cannot mix ChatMessage objects "
+                "with other types."
+            )
+            structlogger.error(
+                "dataset_entry.parse_chat_history.mixed_types",
+                event_info=message,
+                chat_history_types=[type(item) for item in v],
+            )
+            raise ValueError(message)
+
+        # Otherwise, parse from dictionaries
+        parsed_messages: List[ChatMessage] = []
+        for message_data in v:
+            chat_message = create_chat_message_from_dict(message_data)
+            parsed_messages.append(chat_message)
+        return parsed_messages
+
+
+class DatasetMetadata(BaseModel):
+    """Model for the metadata field of a dataset entry."""
+
+    ids: Dict[str, str] = Field(default_factory=dict)
+    copilot_additional_context: DatasetMetadataCopilotAdditionalContext = Field(
+        default_factory=DatasetMetadataCopilotAdditionalContext
+    )
+
+
+class DatasetEntry(BaseModel):
+    """Pydantic model for dataset entries from Langfuse ExperimentItem."""
+
+    # Basic fields from ExperimentItem
+    id: str
+    input: DatasetInput
+    expected_output: DatasetExpectedOutput
+    metadata: DatasetMetadata
+
+    def to_copilot_context(self) -> CopilotContext:
+        """Create a CopilotContext from the dataset entry.
+
+        Raises:
+            ValueError: If the metadata is None, as it's required for creating a valid
+                CopilotContext.
+
+        Returns:
+            CopilotContext with all the context information.
+        """
+        if self.metadata is None:
+            message = (
+                f"Cannot create CopilotContext from dataset item with id: {self.id}. "
+                f"Metadata is required but was None."
+            )
+            structlogger.error(
+                "dataset_entry.to_copilot_context.metadata_is_none",
+                event_info=message,
+                item_id=self.id,
+                item_metadata=self.metadata,
+            )
+            raise ValueError(message)
+
+        # Parse tracker context if available
+        tracker_context = None
+        if (
+            self.metadata.copilot_additional_context.assistant_tracker_context
+            is not None
+        ):
+            tracker_context = TrackerContext(
+                **self.metadata.copilot_additional_context.assistant_tracker_context
+            )
+
+        return CopilotContext(
+            tracker_context=tracker_context,
+            assistant_logs=self.metadata.copilot_additional_context.assistant_logs,
+            assistant_files=self.metadata.copilot_additional_context.relevant_assistant_files,
+            copilot_chat_history=self.metadata.copilot_additional_context.copilot_chat_history,
+        )
+
+    @classmethod
+    def from_raw_data(
+        cls,
+        id: str,
+        input_data: Dict[str, Any],
+        expected_output_data: Dict[str, Any],
+        metadata_data: Dict[str, Any],
+    ) -> "DatasetEntry":
+        """Create a DatasetEntry from raw dictionary data.
+
+        Args:
+            id: The dataset entry ID.
+            input_data: Raw input dictionary.
+            expected_output_data: Raw expected output dictionary.
+            metadata_data: Raw metadata dictionary with all the additional context
+                used to generate the Copilot response.
+
+        Returns:
+            DatasetEntry with parsed data.
+        """
+        # Use Pydantic's model_validate to parse nested structures
+        dataset_input = DatasetInput.model_validate(input_data)
+        dataset_expected_output = DatasetExpectedOutput.model_validate(
+            expected_output_data
+        )
+        dataset_metadata = DatasetMetadata.model_validate(metadata_data)
+
+        return cls(
+            id=id,
+            input=dataset_input,
+            expected_output=dataset_expected_output,
+            metadata=dataset_metadata,
+        )
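
Not part of the diff: a short sketch of turning a raw Langfuse dataset item into a CopilotContext with these models. All dictionary values are placeholders; only the field names used here appear in the models above.

    # Illustrative only -- the raw values below are made-up placeholders.
    from rasa.builder.copilot.models import ResponseCategory
    from rasa.builder.evaluator.dataset.models import DatasetEntry

    entry = DatasetEntry.from_raw_data(
        id="item-1",
        input_data={"message": "How do I add a new flow?"},
        expected_output_data={
            "answer": "Add a flow definition under your flows data ...",
            "response_category": ResponseCategory.COPILOT,
            "references": [],
        },
        metadata_data={
            "ids": {},
            "copilot_additional_context": {
                "relevant_documents": [],
                "relevant_assistant_files": {"flows.yml": "flows: {}"},
                "assistant_logs": "",
                "copilot_chat_history": [],
            },
        },
    )

    # Build the context that run_copilot_with_response_handler() expects.
    context = entry.to_copilot_context()
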
rasa/builder/evaluator/exceptions.py
@@ -0,0 +1,4 @@
+class EvaluationError(Exception):
+    """Base exception for evaluation-related errors."""
+
+    pass
rasa/builder/evaluator/response_classification/constants.py
@@ -0,0 +1,66 @@
+"""Constants for the response classification evaluator."""
+
+from typing import List, Literal
+
+# Options for averaging methods for Response Classification Evaluator
+MICRO_AVERAGING_METHOD: Literal["micro"] = "micro"
+MACRO_AVERAGING_METHOD: Literal["macro"] = "macro"
+WEIGHTED_AVERAGING_METHOD: Literal["weighted"] = "weighted"
+
+AVERAGING_METHODS: List[Literal["micro", "macro", "weighted"]] = [
+    MICRO_AVERAGING_METHOD,
+    MACRO_AVERAGING_METHOD,
+    WEIGHTED_AVERAGING_METHOD,
+]
+
+# Overall evaluation metric names
+MICRO_PRECISION_METRIC = "micro_precision"
+MACRO_PRECISION_METRIC = "macro_precision"
+WEIGHTED_PRECISION_METRIC = "weighted_precision"
+
+MICRO_RECALL_METRIC = "micro_recall"
+MACRO_RECALL_METRIC = "macro_recall"
+WEIGHTED_RECALL_METRIC = "weighted_recall"
+
+MICRO_F1_METRIC = "micro_f1"
+MACRO_F1_METRIC = "macro_f1"
+WEIGHTED_F1_METRIC = "weighted_f1"
+
+# Metric name for items skipped due to invalid data
+SKIP_COUNT_METRIC = "skipped_items"
+
+# Per-class evaluation metric name templates
+PER_CLASS_PRECISION_METRIC_TEMPLATE = "{category}_precision"
+PER_CLASS_RECALL_METRIC_TEMPLATE = "{category}_recall"
+PER_CLASS_F1_METRIC_TEMPLATE = "{category}_f1"
+PER_CLASS_SUPPORT_METRIC_TEMPLATE = "{category}_support"
+
+# Description templates for evaluation metrics
+MICRO_PRECISION_DESCRIPTION = "Micro Precision: {value:.3f}"
+MACRO_PRECISION_DESCRIPTION = "Macro Precision: {value:.3f}"
+WEIGHTED_PRECISION_DESCRIPTION = "Weighted Precision: {value:.3f}"
+
+MICRO_RECALL_DESCRIPTION = "Micro Recall: {value:.3f}"
+MACRO_RECALL_DESCRIPTION = "Macro Recall: {value:.3f}"
+WEIGHTED_RECALL_DESCRIPTION = "Weighted Recall: {value:.3f}"
+
+MICRO_F1_DESCRIPTION = "Micro F1: {value:.3f}"
+MACRO_F1_DESCRIPTION = "Macro F1: {value:.3f}"
+WEIGHTED_F1_DESCRIPTION = "Weighted F1: {value:.3f}"
+
+# Skip count metric description
+SKIP_COUNT_DESCRIPTION = "Skipped {value} items due to invalid data"
+
+# Per-class description templates
+PER_CLASS_PRECISION_DESCRIPTION = "[{category}] Precision: {value:.3f}"
+PER_CLASS_RECALL_DESCRIPTION = "[{category}] Recall: {value:.3f}"
+PER_CLASS_F1_DESCRIPTION = "[{category}] F1: {value:.3f}"
+PER_CLASS_SUPPORT_DESCRIPTION = "[{category}] Support: {value}"
+
+# Experiment configuration
+EXPERIMENT_NAME = "Copilot Response Classification Evaluation"
+EXPERIMENT_DESCRIPTION = (
+    "Evaluating Copilot response classification performance with per-class metrics "
+    "and overall averages (micro, macro, weighted). The metrics reported are: "
+    "precision, recall, F1, support."
+)
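
Not part of the diff: the templates above are plain str.format patterns; for example, filling a per-class name and description might look like this (the "copilot" category string is illustrative).

    # Illustrative only: filling the metric-name and description templates.
    from rasa.builder.evaluator.response_classification.constants import (
        PER_CLASS_F1_DESCRIPTION,
        PER_CLASS_F1_METRIC_TEMPLATE,
    )

    name = PER_CLASS_F1_METRIC_TEMPLATE.format(category="copilot")  # "copilot_f1"
    text = PER_CLASS_F1_DESCRIPTION.format(category="copilot", value=0.9127)  # "[copilot] F1: 0.913"
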
rasa/builder/evaluator/response_classification/evaluator.py
@@ -0,0 +1,346 @@
+from typing import Dict, List, Literal, Optional
+
+import structlog
+
+from rasa.builder.copilot.models import ResponseCategory
+from rasa.builder.evaluator.response_classification.constants import (
+    MACRO_AVERAGING_METHOD,
+    MICRO_AVERAGING_METHOD,
+    WEIGHTED_AVERAGING_METHOD,
+)
+from rasa.builder.evaluator.response_classification.models import (
+    ClassificationResult,
+    MetricsSummary,
+    OverallClassificationMetrics,
+    PerClassMetrics,
+)
+
+structlogger = structlog.get_logger()
+
+
+class ResponseClassificationEvaluator:
+    def __init__(self):  # type: ignore[no-untyped-def]
+        self._classes: List[ResponseCategory] = [
+            ResponseCategory.COPILOT,
+            ResponseCategory.OUT_OF_SCOPE_DETECTION,
+            ResponseCategory.ROLEPLAY_DETECTION,
+            ResponseCategory.KNOWLEDGE_BASE_ACCESS_REQUESTED,
+            ResponseCategory.ERROR_FALLBACK,
+            # TODO: Add the greetings and goodbyes as support once the orchestrator
+            # approach is implemented
+        ]
+        self._true_positives_per_class: Dict[ResponseCategory, int] = {
+            clazz: 0 for clazz in self._classes
+        }
+        self._false_positives_per_class: Dict[ResponseCategory, int] = {
+            clazz: 0 for clazz in self._classes
+        }
+        self._false_negatives_per_class: Dict[ResponseCategory, int] = {
+            clazz: 0 for clazz in self._classes
+        }
+        self._support_per_class: Dict[ResponseCategory, int] = {
+            clazz: 0 for clazz in self._classes
+        }
+
+        self._evaluated = False
+
+    @property
+    def metrics_summary(self) -> Optional[MetricsSummary]:
+        """Get the metrics summary.
+
+        Returns:
+            MetricsSummary with structured per-class and overall metrics if
+            the evaluator has been run on the data, otherwise None.
+        """
+        if not self._evaluated:
+            structlogger.warning(
+                "evaluator.response_classification_evaluator"
+                ".metrics_summary.not_evaluated",
+                event_info="Evaluator not evaluated. Returning empty metrics summary.",
+            )
+            return None
+
+        return self._get_metrics_summary()
+
+    def reset(self) -> None:
+        self._true_positives_per_class = {clazz: 0 for clazz in self._classes}
+        self._false_positives_per_class = {clazz: 0 for clazz in self._classes}
+        self._false_negatives_per_class = {clazz: 0 for clazz in self._classes}
+        self._support_per_class = {clazz: 0 for clazz in self._classes}
+        self._evaluated = False
+
+    def evaluate(self, item_results: List[ClassificationResult]) -> MetricsSummary:
+        """Evaluate the classifier on the given item results."""
+        if self._evaluated:
+            structlogger.warning(
+                "evaluator.response_classification_evaluator.evaluate.already_evaluated",
+                event_info="Evaluator already evaluated. Resetting evaluator.",
+            )
+            self.reset()
+
+        for result in item_results:
+            # Skip and raise a warning if the class is not in the list of classes
+            if result.expected not in self._classes:
+                structlogger.warning(
+                    "evaluator.response_classification_evaluator"
+                    ".evaluate.class_not_recognized",
+                    event_info=(
+                        f"Class '{result.expected}' is not recognized. "
+                        f"Skipping evaluation for this class."
+                    ),
+                    expected_class=result.expected,
+                    classes=self._classes,
+                )
+                continue
+
+            # Update support for the expected class
+            if result.expected in self._support_per_class:
+                self._support_per_class[result.expected] += 1
+
+            # Calculate TP, FP, FN per class
+            for clazz in self._classes:
+                if result.prediction == clazz and result.expected == clazz:
+                    self._true_positives_per_class[clazz] += 1
+
+                elif result.prediction == clazz and result.expected != clazz:
+                    self._false_positives_per_class[clazz] += 1
+
+                elif result.prediction != clazz and result.expected == clazz:
+                    self._false_negatives_per_class[clazz] += 1
+
+        self._evaluated = True
+        return self._get_metrics_summary()
+
+    def calculate_precision_per_class(self, clazz: ResponseCategory) -> float:
+        """Calculate precision for a specific response category."""
+        tp = self._true_positives_per_class.get(clazz, 0)
+        fp = self._false_positives_per_class.get(clazz, 0)
+
+        if tp + fp == 0:
+            return 0.0
+
+        return tp / (tp + fp)
+
+    def calculate_recall_per_class(self, clazz: ResponseCategory) -> float:
+        """Calculate recall for a specific response category."""
+        tp = self._true_positives_per_class.get(clazz, 0)
+        fn = self._false_negatives_per_class.get(clazz, 0)
+
+        if tp + fn == 0:
+            return 0.0
+
+        return tp / (tp + fn)
+
+    def calculate_f1_per_class(self, clazz: ResponseCategory) -> float:
+        """Calculate F1 score for a specific response category."""
+        precision = self.calculate_precision_per_class(clazz)
+        recall = self.calculate_recall_per_class(clazz)
+
+        if precision + recall == 0:
+            return 0.0
+
+        return 2 * (precision * recall) / (precision + recall)
+
+    def calculate_precision(
+        self, average: Literal["micro", "macro", "weighted"] = MICRO_AVERAGING_METHOD
+    ) -> float:
+        """Calculate precision with specified averaging method."""
+        if average == MICRO_AVERAGING_METHOD:
+            return self._calculate_micro_precision()
+        elif average == MACRO_AVERAGING_METHOD:
+            return self._calculate_macro_precision()
+        elif average == WEIGHTED_AVERAGING_METHOD:
+            return self._calculate_weighted_avg_precision()
+        else:
+            raise ValueError(f"Invalid averaging method: {average}")
+
+    def _calculate_micro_precision(self) -> float:
+        """Calculate micro-averaged precision.
+
+        Calculates the metric globally by aggregating the total true positives and
+        false positives across all classes. Each sample contributes equally to the
+        final score.
+        """
+        total_tp = sum(self._true_positives_per_class.values())
+        total_fp = sum(self._false_positives_per_class.values())
+
+        if total_tp + total_fp == 0:
+            return 0.0
+
+        return total_tp / (total_tp + total_fp)
+
+    def _calculate_macro_precision(self) -> float:
+        """Calculate macro-averaged precision.
+
+        Calculates the metric independently for each class and then takes the
+        unweighted average. Each class contributes equally.
+        """
+        precisions = [
+            self.calculate_precision_per_class(clazz) for clazz in self._classes
+        ]
+        return sum(precisions) / len(precisions) if precisions else 0.0
+
+    def _calculate_weighted_avg_precision(self) -> float:
+        """Calculate weighted-averaged precision.
+
+        Calculates the metric independently for each class and then takes the average
+        weighted by the class support (number of true samples per class).
+        """
+        total_support = sum(self._support_per_class.values())
+        if total_support == 0:
+            return 0.0
+
+        weighted_sum = 0.0
+        for clazz in self._classes:
+            precision = self.calculate_precision_per_class(clazz)
+            support = self._support_per_class.get(clazz, 0)
+            weighted_sum += precision * support
+
+        return weighted_sum / total_support
+
+    def calculate_recall(
+        self, average: Literal["micro", "macro", "weighted"] = MICRO_AVERAGING_METHOD
+    ) -> float:
+        """Calculate recall with specified averaging method."""
+        if average == MICRO_AVERAGING_METHOD:
+            return self._calculate_micro_recall()
+        elif average == MACRO_AVERAGING_METHOD:
+            return self._calculate_macro_recall()
+        elif average == WEIGHTED_AVERAGING_METHOD:
+            return self._calculate_weighted_avg_recall()
+        else:
+            raise ValueError(f"Invalid averaging method: {average}")
+
+    def _calculate_micro_recall(self) -> float:
+        """Calculate micro-averaged recall.
+
+        Calculates the metric globally by aggregating the total true positives and
+        false negatives across all classes. Each sample contributes equally to the
+        final score.
+        """
+        total_tp = sum(self._true_positives_per_class.values())
+        total_fn = sum(self._false_negatives_per_class.values())
+
+        if total_tp + total_fn == 0:
+            return 0.0
+
+        return total_tp / (total_tp + total_fn)
+
+    def _calculate_macro_recall(self) -> float:
+        """Calculate macro-averaged recall.
+
+        Calculates the metric independently for each class and then takes the
+        unweighted average. Each class contributes equally.
+        """
+        recalls = [self.calculate_recall_per_class(clazz) for clazz in self._classes]
+        return sum(recalls) / len(recalls) if recalls else 0.0
+
+    def _calculate_weighted_avg_recall(self) -> float:
+        """Calculate weighted-averaged recall.
+
+        Calculates the metric independently for each class and then takes the average
+        weighted by the class support (number of true samples per class).
+        """
+        total_support = sum(self._support_per_class.values())
+        if total_support == 0:
+            return 0.0
+
+        weighted_sum = 0.0
+        for clazz in self._classes:
+            recall = self.calculate_recall_per_class(clazz)
+            support = self._support_per_class.get(clazz, 0)
+            weighted_sum += recall * support
+
+        return weighted_sum / total_support
+
+    def calculate_f1(
+        self, average: Literal["micro", "macro", "weighted"] = MICRO_AVERAGING_METHOD
+    ) -> float:
+        """Calculate F1 score with specified averaging method."""
+        if average == MICRO_AVERAGING_METHOD:
+            return self._calculate_micro_f1()
+        elif average == MACRO_AVERAGING_METHOD:
+            return self._calculate_macro_f1()
+        elif average == WEIGHTED_AVERAGING_METHOD:
+            return self._calculate_weighted_avg_f1()
+        else:
+            raise ValueError(f"Invalid averaging method: {average}")
+
+    def _calculate_micro_f1(self) -> float:
+        """Calculate micro-averaged F1 score.
+
+        Calculates the metric globally by aggregating the total true positives, false
+        positives, and false negatives across all classes. Each sample contributes
+        equally to the final score.
+        """
+        micro_precision = self._calculate_micro_precision()
+        micro_recall = self._calculate_micro_recall()
+
+        if micro_precision + micro_recall == 0:
+            return 0.0
+
+        return 2 * (micro_precision * micro_recall) / (micro_precision + micro_recall)
+
+    def _calculate_macro_f1(self) -> float:
+        """Calculate macro-averaged F1 score.
+
+        Calculates the metric independently for each class and then takes the
+        unweighted average. Each class contributes equally.
+        """
+        f1_scores = [self.calculate_f1_per_class(clazz) for clazz in self._classes]
+        return sum(f1_scores) / len(f1_scores) if f1_scores else 0.0
+
+    def _calculate_weighted_avg_f1(self) -> float:
+        """Calculate weighted F1 score.
+
+        Calculates the metric independently for each class and then takes the average
+        weighted by the class support (number of true samples per class).
+        """
+        total_support = sum(self._support_per_class.values())
+        if total_support == 0:
+            return 0.0
+
+        weighted_sum = 0.0
+        for clazz in self._classes:
+            f1 = self.calculate_f1_per_class(clazz)
+            support = self._support_per_class.get(clazz, 0)
+            weighted_sum += f1 * support
+
+        return weighted_sum / total_support
+
+    def _get_metrics_summary(self) -> MetricsSummary:
+        """Get the metrics summary without Optional wrapper.
+
+        This method assumes the evaluator has been evaluated and will always
+        return a MetricsSummary.
+        """
+        # Build per-class metrics
+        per_class_metrics: Dict[ResponseCategory, PerClassMetrics] = {}
+        for clazz in self._classes:
+            per_class_metrics[clazz] = PerClassMetrics(
+                precision=self.calculate_precision_per_class(clazz),
+                recall=self.calculate_recall_per_class(clazz),
+                f1=self.calculate_f1_per_class(clazz),
+                support=self._support_per_class.get(clazz, 0),
+                true_positives=self._true_positives_per_class.get(clazz, 0),
+                false_positives=self._false_positives_per_class.get(clazz, 0),
+                false_negatives=self._false_negatives_per_class.get(clazz, 0),
+            )
+
+        # Build overall metrics
+        overall_metrics = OverallClassificationMetrics(
+            micro_precision=self.calculate_precision(MICRO_AVERAGING_METHOD),
+            macro_precision=self.calculate_precision(MACRO_AVERAGING_METHOD),
+            weighted_avg_precision=self.calculate_precision(WEIGHTED_AVERAGING_METHOD),
+            micro_recall=self.calculate_recall(MICRO_AVERAGING_METHOD),
+            macro_recall=self.calculate_recall(MACRO_AVERAGING_METHOD),
+            weighted_avg_recall=self.calculate_recall(WEIGHTED_AVERAGING_METHOD),
+            micro_f1=self.calculate_f1(MICRO_AVERAGING_METHOD),
+            macro_f1=self.calculate_f1(MACRO_AVERAGING_METHOD),
+            weighted_avg_f1=self.calculate_f1(WEIGHTED_AVERAGING_METHOD),
+            support=sum(self._support_per_class.values()),
+            true_positives=sum(self._true_positives_per_class.values()),
+            false_positives=sum(self._false_positives_per_class.values()),
+            false_negatives=sum(self._false_negatives_per_class.values()),
+        )
+        return MetricsSummary(per_class=per_class_metrics, overall=overall_metrics)
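
Not part of the diff: a minimal sketch of driving the evaluator directly, assuming ClassificationResult (defined in response_classification/models.py, not shown in this diff) can be constructed with the prediction and expected fields that evaluate() reads.

    # Illustrative only; ClassificationResult's exact constructor is assumed.
    from rasa.builder.copilot.models import ResponseCategory
    from rasa.builder.evaluator.response_classification.evaluator import (
        ResponseClassificationEvaluator,
    )
    from rasa.builder.evaluator.response_classification.models import ClassificationResult

    results = [
        ClassificationResult(
            prediction=ResponseCategory.COPILOT, expected=ResponseCategory.COPILOT
        ),
        ClassificationResult(
            prediction=ResponseCategory.COPILOT,
            expected=ResponseCategory.OUT_OF_SCOPE_DETECTION,
        ),
    ]

    summary = ResponseClassificationEvaluator().evaluate(results)
    # One true positive and one false positive for COPILOT -> micro precision 0.5
    print(summary.overall.micro_precision)
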