rasa-pro 3.14.0rc4__py3-none-any.whl → 3.15.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rasa/agents/agent_manager.py +7 -5
- rasa/agents/protocol/a2a/a2a_agent.py +13 -11
- rasa/agents/protocol/mcp/mcp_base_agent.py +49 -11
- rasa/agents/validation.py +4 -2
- rasa/builder/config.py +4 -0
- rasa/builder/copilot/copilot.py +28 -9
- rasa/builder/copilot/copilot_templated_message_provider.py +1 -1
- rasa/builder/copilot/models.py +171 -4
- rasa/builder/document_retrieval/inkeep_document_retrieval.py +2 -0
- rasa/builder/download.py +1 -1
- rasa/builder/service.py +101 -24
- rasa/builder/telemetry/__init__.py +0 -0
- rasa/builder/telemetry/copilot_langfuse_telemetry.py +384 -0
- rasa/builder/{copilot/telemetry.py → telemetry/copilot_segment_telemetry.py} +21 -3
- rasa/builder/validation_service.py +4 -0
- rasa/cli/arguments/data.py +9 -0
- rasa/cli/data.py +72 -6
- rasa/cli/interactive.py +3 -0
- rasa/cli/llm_fine_tuning.py +1 -0
- rasa/cli/project_templates/defaults.py +1 -0
- rasa/cli/validation/bot_config.py +2 -0
- rasa/constants.py +2 -1
- rasa/core/actions/action_exceptions.py +1 -1
- rasa/core/agent.py +4 -1
- rasa/core/available_agents.py +1 -1
- rasa/core/exceptions.py +1 -1
- rasa/core/featurizers/tracker_featurizers.py +3 -2
- rasa/core/persistor.py +7 -7
- rasa/core/policies/flows/agent_executor.py +84 -4
- rasa/core/policies/flows/flow_exceptions.py +5 -2
- rasa/core/policies/flows/flow_executor.py +23 -8
- rasa/core/policies/flows/mcp_tool_executor.py +7 -1
- rasa/core/policies/rule_policy.py +1 -1
- rasa/core/run.py +15 -4
- rasa/dialogue_understanding/commands/cancel_flow_command.py +1 -1
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +1 -1
- rasa/e2e_test/e2e_config.py +4 -3
- rasa/engine/recipes/default_components.py +16 -6
- rasa/graph_components/validators/default_recipe_validator.py +10 -4
- rasa/model_manager/runner_service.py +1 -1
- rasa/nlu/classifiers/diet_classifier.py +2 -0
- rasa/privacy/privacy_config.py +1 -1
- rasa/shared/agents/auth/auth_strategy/oauth2_auth_strategy.py +4 -7
- rasa/shared/core/slots.py +55 -24
- rasa/shared/core/training_data/story_reader/story_reader.py +1 -1
- rasa/shared/exceptions.py +23 -2
- rasa/shared/providers/llm/litellm_router_llm_client.py +2 -2
- rasa/shared/utils/common.py +9 -1
- rasa/shared/utils/llm.py +21 -4
- rasa/shared/utils/mcp/server_connection.py +7 -4
- rasa/studio/download.py +3 -0
- rasa/studio/prompts.py +1 -0
- rasa/studio/upload.py +4 -0
- rasa/utils/common.py +9 -0
- rasa/utils/endpoints.py +2 -0
- rasa/utils/installation_utils.py +111 -0
- rasa/utils/log_utils.py +20 -1
- rasa/utils/tensorflow/callback.py +2 -0
- rasa/utils/train_utils.py +2 -0
- rasa/version.py +1 -1
- {rasa_pro-3.14.0rc4.dist-info → rasa_pro-3.15.0a1.dist-info}/METADATA +4 -2
- {rasa_pro-3.14.0rc4.dist-info → rasa_pro-3.15.0a1.dist-info}/RECORD +65 -62
- {rasa_pro-3.14.0rc4.dist-info → rasa_pro-3.15.0a1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.14.0rc4.dist-info → rasa_pro-3.15.0a1.dist-info}/WHEEL +0 -0
- {rasa_pro-3.14.0rc4.dist-info → rasa_pro-3.15.0a1.dist-info}/entry_points.txt +0 -0
rasa/agents/agent_manager.py
CHANGED
@@ -15,7 +15,7 @@ structlogger = structlog.get_logger()
 
 
 class AgentManager(metaclass=Singleton):
-    """High-level agent management with protocol abstraction"""
+    """High-level agent management with protocol abstraction."""
 
     agents: ClassVar[Dict[AgentIdentifier, AgentProtocol]] = {}
 
@@ -66,9 +66,11 @@ class AgentManager(metaclass=Singleton):
     async def connect_agent(
         self, agent_name: str, protocol_type: ProtocolType, config: AgentConfig
     ) -> None:
-        """Connect to agent using specified protocol
-
-
+        """Connect to agent using specified protocol.
+
+        Also, load the default resources and persist the agent to the manager
+        in a ready-to-use state so that it can be used immediately
+        to send messages to the agent.
 
         Args:
             agent_name: The name of the agent.
@@ -109,7 +111,7 @@ class AgentManager(metaclass=Singleton):
             agent_id=str(agent_identifier),
             event_info=event_info,
         )
-        raise AgentInitializationException(e) from e
+        raise AgentInitializationException(e, suppress_stack_trace=True) from e
 
     async def run_agent(
         self, agent_name: str, protocol_type: ProtocolType, context: AgentInput
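The `suppress_stack_trace=True` argument above is new; it lines up with the `rasa/shared/exceptions.py` change (+23 -2) in the file list. A minimal sketch of an exception carrying such a flag, as an illustration of the shape rather than the actual rasa.shared.exceptions implementation:

# Hypothetical sketch: an exception with a suppress_stack_trace flag so that
# log handlers can skip printing a traceback for errors that are already
# reported in a user-friendly form. The constructor wiring is assumed.
class AgentInitializationException(Exception):
    def __init__(self, message: object, suppress_stack_trace: bool = False) -> None:
        super().__init__(message)
        self.suppress_stack_trace = suppress_stack_trace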
rasa/agents/protocol/a2a/a2a_agent.py
CHANGED

@@ -65,7 +65,7 @@ structlogger = structlog.get_logger()
 
 
 class A2AAgent(AgentProtocol):
-    """A2A client implementation"""
+    """A2A client implementation."""
 
     __SUPPORTED_OUTPUT_MODES: ClassVar[list[str]] = [
         "text",
@@ -169,7 +169,8 @@ class A2AAgent(AgentProtocol):
                 error=str(exception),
             )
             raise AgentInitializationException(
-                f"Failed to initialize A2A client
+                f"Failed to initialize A2A client "
+                f"for agent '{self._name}': {exception}",
             ) from exception
 
         await self._perform_health_check()
@@ -180,7 +181,7 @@ class A2AAgent(AgentProtocol):
         )
 
     async def disconnect(self) -> None:
-        """We don't need to explicitly disconnect the A2A client"""
+        """We don't need to explicitly disconnect the A2A client."""
         return
 
     # ============================================================================
@@ -297,7 +298,7 @@ class A2AAgent(AgentProtocol):
     def _handle_send_message_response(
         self, agent_input: AgentInput, response: ClientEvent | Message
     ) -> Optional[AgentOutput]:
-        """Handle possible response types from the A2A client
+        """Handle possible response types from the A2A client.
 
         In case of streaming, the response can be either exactly *one* Message,
         or a *series* of tuples of (Task, Optional[TaskUpdateEvent]).
@@ -410,8 +411,8 @@ class A2AAgent(AgentProtocol):
         agent_input: AgentInput,
         task: Task,
     ) -> Optional[AgentOutput]:
-        """If
-
+        """If task status is terminal (e.g. completed, failed) return AgentOutput.
+
         If the task is still in progress (i.e., submitted, working), return None,
         so that the streaming or pooling agent can continue to wait for updates.
         """
@@ -655,6 +656,7 @@ class A2AAgent(AgentProtocol):
     @staticmethod
     def _generate_completed_response_message(task: Task) -> str:
         """Generate a response message for a completed task.
+
         In case of completed tasks, the final message might be in
         the task status message or in the artifacts (or both).
         """
@@ -728,19 +730,19 @@ class A2AAgent(AgentProtocol):
 
         except FileNotFoundError as e:
             raise AgentInitializationException(
-                f"Agent card file not found: {agent_card_path}"
+                f"Agent card file not found: {agent_card_path}",
             ) from e
         except (IOError, PermissionError) as e:
             raise AgentInitializationException(
-                f"Error reading agent card file {agent_card_path}: {e}"
+                f"Error reading agent card file {agent_card_path}: {e}",
             ) from e
         except json.JSONDecodeError as e:
             raise AgentInitializationException(
-                f"Invalid JSON in agent card file {agent_card_path}: {e}"
+                f"Invalid JSON in agent card file {agent_card_path}: {e}",
             ) from e
         except ValidationError as e:
             raise AgentInitializationException(
-                f"Failed to load agent card from {agent_card_path}: {e}"
+                f"Failed to load agent card from {agent_card_path}: {e}",
             ) from e
 
     @staticmethod
@@ -798,7 +800,7 @@ class A2AAgent(AgentProtocol):
 
         raise AgentInitializationException(
             f"Failed to resolve agent card from {agent_card_path} after "
-            f"{max_retries} attempts."
+            f"{max_retries} attempts.",
         )
 
     # ============================================================================
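The reworded docstrings above separate terminal task states (completed, failed) from in-progress ones (submitted, working). A standalone sketch of that split; the `TaskState` enum here is a stand-in mirroring the states named in the docstrings, not the a2a library's own type:

from enum import Enum
from typing import Optional


class TaskState(str, Enum):
    SUBMITTED = "submitted"
    WORKING = "working"
    COMPLETED = "completed"
    FAILED = "failed"


TERMINAL_STATES = {TaskState.COMPLETED, TaskState.FAILED}


def handle_task_status(state: TaskState) -> Optional[str]:
    """Return output for terminal tasks; None keeps the poll/stream loop waiting."""
    if state in TERMINAL_STATES:
        return f"task finished in state: {state.value}"
    return None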
rasa/agents/protocol/mcp/mcp_base_agent.py
CHANGED

@@ -1,9 +1,10 @@
 import json
 from abc import abstractmethod
-from datetime import datetime
+from datetime import datetime, timedelta
 from inspect import isawaitable
 from typing import Any, Dict, List, Optional, Tuple
 
+import anyio
 import structlog
 from jinja2 import Template
 from mcp import ListToolsResult
@@ -75,6 +76,8 @@ class MCPBaseAgent(AgentProtocol):
 
     MAX_ITERATIONS = 10
 
+    TOOL_CALL_DEFAULT_TIMEOUT = 10  # seconds
+
     # ============================================================================
     # Initialization & Setup
     # ============================================================================
@@ -288,14 +291,16 @@ class MCPBaseAgent(AgentProtocol):
                 event_info="All connection attempts failed.",
             )
             raise AgentInitializationException(
-                f"
-                f"
+                f"Agent `{self._name}` failed to initialize. Failed to connect "
+                f"to MCP servers after {self._max_retries} attempts. {ce!s}"
             ) from ce
         except (Exception, AuthenticationError) as e:
             if isinstance(e, AuthenticationError):
-                event_info =
+                event_info = (
+                    f"Authentication error during agent initialization. {e!s}"
+                )
             else:
-                event_info = "Unexpected error during agent initialization."
+                event_info = f"Unexpected error during agent initialization. {e!s}"
             structlogger.error(
                 "mcp_agent.connect.unexpected_exception",
                 event_info=event_info,
@@ -303,7 +308,7 @@ class MCPBaseAgent(AgentProtocol):
                 agent_name=self._name,
                 agent_id=str(make_agent_identifier(self._name, self.protocol_type)),
             )
-            raise AgentInitializationException(
+            raise AgentInitializationException(event_info) from e
 
     async def connect_to_server(self, server_config: AgentMCPServerConfig) -> None:
         server_name = server_config.name
@@ -325,7 +330,7 @@ class MCPBaseAgent(AgentProtocol):
         except Exception as e:
             event_info = (
                 f"Agent `{self._name}` failed to connect to MCP server - "
-                f"`{server_name}` @ `{server_config.url}
+                f"`{server_name}` @ `{server_config.url}`"
             )
             structlogger.error(
                 "mcp_agent.connect.failed_to_connect",
@@ -335,7 +340,9 @@ class MCPBaseAgent(AgentProtocol):
                 agent_name=self._name,
                 agent_id=str(make_agent_identifier(self._name, self.protocol_type)),
             )
-
+
+            # Wrap exceptions with extra info and raise the same type of exception.
+            raise type(e)(f"{event_info} : {e!s}") from e
 
     async def connect_to_servers(self) -> None:
         """Connect to MCP servers."""
@@ -624,7 +631,11 @@ class MCPBaseAgent(AgentProtocol):
         connection = self._server_connections[server_id]
         try:
             session = await connection.ensure_active_session()
-            result = await session.call_tool(
+            result = await session.call_tool(
+                tool_name,
+                arguments,
+                read_timeout_seconds=timedelta(seconds=self.TOOL_CALL_DEFAULT_TIMEOUT),
+            )
             return AgentToolResult.from_mcp_tool_result(tool_name, result)
         except Exception as e:
             return AgentToolResult(
@@ -637,6 +648,21 @@ class MCPBaseAgent(AgentProtocol):
                 ),
             )
 
+    async def _run_custom_tool(
+        self, custom_tool: CustomToolSchema, arguments: Dict[str, Any]
+    ) -> AgentToolResult:
+        """Run a custom tool and return the result.
+
+        Args:
+            custom_tool: The custom tool schema containing the tool executor.
+            arguments: The arguments to pass to the tool executor.
+
+        Returns:
+            The result of the tool execution as an AgentToolResult.
+        """
+        result = custom_tool.tool_executor(arguments)
+        return await result if isawaitable(result) else result
+
     async def _execute_tool_call(
         self, tool_name: str, arguments: Dict[str, Any]
     ) -> AgentToolResult:
@@ -655,8 +681,20 @@ class MCPBaseAgent(AgentProtocol):
         try:
             for custom_tool in self._custom_tools:
                 if custom_tool.tool_name == tool_name:
-
-
+                    try:
+                        with anyio.fail_after(self.TOOL_CALL_DEFAULT_TIMEOUT):
+                            return await self._run_custom_tool(custom_tool, arguments)
+
+                    except TimeoutError:
+                        return AgentToolResult(
+                            tool_name=tool_name,
+                            result=None,
+                            is_error=True,
+                            error_message=(
+                                f"Built-in tool `{tool_name}` timed out after "
+                                f"{self.TOOL_CALL_DEFAULT_TIMEOUT} seconds."
+                            ),
+                        )
         except Exception as e:
             return AgentToolResult(
                 tool_name=tool_name,
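The new `TOOL_CALL_DEFAULT_TIMEOUT` is applied twice above: as `read_timeout_seconds` on the MCP session's `call_tool`, and via `anyio.fail_after()` around custom tool executors, which may be sync or async (hence the `isawaitable` check). A self-contained sketch of the second pattern; `run_tool` and the demo tool are illustrative names, and only `anyio` plus the standard library are assumed:

import asyncio
from inspect import isawaitable
from typing import Any, Callable, Dict

import anyio

TOOL_CALL_DEFAULT_TIMEOUT = 10  # seconds, mirroring the class constant above


async def run_tool(tool_executor: Callable[..., Any], arguments: Dict[str, Any]) -> Any:
    try:
        # anyio.fail_after raises TimeoutError if the block overruns the deadline.
        with anyio.fail_after(TOOL_CALL_DEFAULT_TIMEOUT):
            result = tool_executor(arguments)
            # Sync executors return a value; async ones return an awaitable.
            return await result if isawaitable(result) else result
    except TimeoutError:
        return {"is_error": True, "error_message": "tool call timed out"}


async def _demo() -> None:
    async def echo_tool(args: Dict[str, Any]) -> str:
        await asyncio.sleep(0.1)
        return f"echo: {args['text']}"

    print(await run_tool(echo_tool, {"text": "hi"}))


if __name__ == "__main__":
    asyncio.run(_demo())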
rasa/agents/validation.py
CHANGED
@@ -347,9 +347,11 @@ def _handle_pydantic_validation_error(
 def _validate_endpoint_references(agent_config: AgentConfig) -> None:
     """Validate that LLM and MCP server references in agent config are valid."""
     agent_name = agent_config.agent.name
-
-    # Get available endpoints
     endpoints = Configuration.get_instance().endpoints
+    if not endpoints.config_file_path:
+        # If no endpoints were loaded (e.g., `data validate` without --endpoints), skip
+        # endpoint reference checks
+        return
 
     # Validate LLM configuration references
     if agent_config.configuration and agent_config.configuration.llm:
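The guard makes endpoint-reference validation a no-op when no endpoints file was loaded. A simplified stand-in for illustration; the real `Configuration` and endpoints objects are richer than this:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Endpoints:
    # None when e.g. `rasa data validate` runs without --endpoints
    config_file_path: Optional[str] = None


def validate_endpoint_references(endpoints: Endpoints) -> None:
    if not endpoints.config_file_path:
        # Nothing was loaded, so LLM/MCP server references cannot be checked.
        return
    # ... the real LLM and MCP server reference checks would run here ...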
rasa/builder/config.py
CHANGED
@@ -13,6 +13,10 @@ OPENAI_VECTOR_STORE_ID = os.getenv(
 )
 OPENAI_MAX_VECTOR_RESULTS = int(os.getenv("OPENAI_MAX_VECTOR_RESULTS", "10"))
 OPENAI_TIMEOUT = int(os.getenv("OPENAI_TIMEOUT", "30"))
+# OpenAI Token Pricing Configuration (per 1,000 tokens)
+COPILOT_INPUT_TOKEN_PRICE = float(os.getenv("COPILOT_INPUT_TOKEN_PRICE", "0.002"))
+COPILOT_OUTPUT_TOKEN_PRICE = float(os.getenv("COPILOT_OUTPUT_TOKEN_PRICE", "0.0005"))
+COPILOT_CACHED_TOKEN_PRICE = float(os.getenv("COPILOT_CACHED_TOKEN_PRICE", "0.002"))
 
 # Server Configuration
 BUILDER_SERVER_HOST = os.getenv("SERVER_HOST", "0.0.0.0")
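Because these prices are read with `os.getenv` at module import time, overrides have to be in the environment before `rasa.builder.config` is first imported. A small sketch with illustrative override values:

import os

# Must be set before rasa.builder.config is imported anywhere.
os.environ["COPILOT_INPUT_TOKEN_PRICE"] = "0.0025"
os.environ["COPILOT_OUTPUT_TOKEN_PRICE"] = "0.01"
os.environ["COPILOT_CACHED_TOKEN_PRICE"] = "0.00125"

from rasa.builder import config

print(config.COPILOT_INPUT_TOKEN_PRICE)  # 0.0025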
rasa/builder/copilot/copilot.py
CHANGED
@@ -42,6 +42,7 @@ from rasa.builder.exceptions import (
     DocumentRetrievalError,
 )
 from rasa.builder.shared.tracker_context import TrackerContext
+from rasa.builder.telemetry.copilot_langfuse_telemetry import CopilotLangfuseTelemetry
 from rasa.shared.constants import PACKAGE_NAME
 
 structlogger = structlog.get_logger()
@@ -72,7 +73,11 @@ class Copilot:
         )
 
         # The final stream chunk includes usage statistics.
-        self.usage_statistics = UsageStatistics(
+        self.usage_statistics = UsageStatistics(
+            input_token_price=config.COPILOT_INPUT_TOKEN_PRICE,
+            output_token_price=config.COPILOT_OUTPUT_TOKEN_PRICE,
+            cached_token_price=config.COPILOT_CACHED_TOKEN_PRICE,
+        )
 
     @asynccontextmanager
     async def _get_client(self) -> AsyncGenerator[openai.AsyncOpenAI, None]:
@@ -94,6 +99,16 @@ class Copilot:
                 error=str(exc),
             )
 
+    @property
+    def llm_config(self) -> Dict[str, Any]:
+        """The LLM config used to generate the response."""
+        return {
+            "model": config.OPENAI_MODEL,
+            "temperature": config.OPENAI_TEMPERATURE,
+            "stream": True,
+            "stream_options": {"include_usage": True},
+        }
+
     async def search_rasa_documentation(
         self,
         context: CopilotContext,
@@ -108,7 +123,9 @@ class Copilot:
         """
         try:
             query = self._create_documentation_search_query(context)
-
+            documents = await self._inkeep_document_retrieval.retrieve_documents(query)
+            # TODO: Log documentation retrieval to Langfuse
+            return documents
         except DocumentRetrievalError as e:
             structlogger.error(
                 "copilot.search_rasa_documentation.error",
@@ -145,11 +162,12 @@
             Exception: If an unexpected error occurs.
         """
         relevant_documents = await self.search_rasa_documentation(context)
-        messages = await self._build_messages(context, relevant_documents)
         tracker_event_attachments = self._extract_tracker_event_attachments(
             context.copilot_chat_history[-1]
         )
+        messages = await self._build_messages(context, relevant_documents)
 
+        # TODO: Delete this after Langfuse is implemented
        support_evidence = CopilotGenerationContext(
            relevant_documents=relevant_documents,
            system_message=messages[0],
@@ -163,6 +181,7 @@
             support_evidence,
         )
 
+    @CopilotLangfuseTelemetry.trace_copilot_streaming_generation
     async def _stream_response(
         self, messages: List[Dict[str, Any]]
     ) -> AsyncGenerator[str, None]:
@@ -172,13 +191,10 @@
         try:
             async with self._get_client() as client:
                 stream = await client.chat.completions.create(
-
-
-                    temperature=config.OPENAI_TEMPERATURE,
-                    stream=True,
-                    stream_options={"include_usage": True},
+                    messages=messages,
+                    **self.llm_config,
                 )
-                async for chunk in stream:
+                async for chunk in stream:  # type: ignore[attr-defined]
                     # The final chunk, which contains the usage statistics,
                     # arrives with an empty `choices` list.
                     if not chunk.choices:
@@ -189,6 +205,7 @@
                     delta = chunk.choices[0].delta
                     if delta and delta.content:
                         yield delta.content
+
         except openai.OpenAIError as e:
             structlogger.exception("copilot.stream_response.api_error", error=str(e))
             raise CopilotStreamError(
@@ -559,4 +576,6 @@
         """Extract the tracker event attachments from the message."""
         if not isinstance(message, UserChatMessage):
             return []
+        # TODO: Log tracker event attachments to Langfuse only in the case of the
+        # User chat message.
         return message.get_content_blocks_by_type(EventContent)
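The refactored `_stream_response` relies on a documented OpenAI API behavior: with `stream_options={"include_usage": True}`, the final stream chunk arrives with an empty `choices` list and carries the usage statistics. A minimal standalone sketch of that consumption pattern (the model name is a placeholder; `OPENAI_API_KEY` is read from the environment):

import asyncio

import openai


async def stream_with_usage() -> None:
    client = openai.AsyncOpenAI()
    stream = await client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    async for chunk in stream:
        if not chunk.choices:
            # Final chunk: no content delta, only usage statistics.
            print(chunk.usage)
            continue
        delta = chunk.choices[0].delta
        if delta and delta.content:
            print(delta.content, end="")


if __name__ == "__main__":
    asyncio.run(stream_with_usage())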
rasa/builder/copilot/models.py
CHANGED
@@ -3,6 +3,7 @@ from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Type, TypeVar, Union
 
 import structlog
+from openai.types.chat import ChatCompletion
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 from pydantic import (
     BaseModel,
@@ -612,16 +613,171 @@ class TrainingErrorLog(CopilotOutput):
 
 
 class UsageStatistics(BaseModel):
-
-
-
-
+    """Usage statistics for a copilot generation."""
+
+    # Token usage statistics
+    prompt_tokens: Optional[int] = Field(
+        default=None,
+        description=(
+            "Total number of prompt tokens used to generate completion. "
+            "Should include cached prompt tokens."
+        ),
+    )
+    completion_tokens: Optional[int] = Field(
+        default=None,
+        description="Number of generated tokens.",
+    )
+    total_tokens: Optional[int] = Field(
+        default=None,
+        description="Total number of tokens used (input + output).",
+    )
+    cached_prompt_tokens: Optional[int] = Field(
+        default=None,
+        description="Number of cached prompt tokens.",
+    )
+    model: Optional[str] = Field(
+        default=None,
+        description="The model used to generate the response.",
+    )
+
+    # Token prices
+    input_token_price: float = Field(
+        default=0.0,
+        description="Price per 1K input tokens in dollars.",
+    )
+    output_token_price: float = Field(
+        default=0.0,
+        description="Price per 1K output tokens in dollars.",
+    )
+    cached_token_price: float = Field(
+        default=0.0,
+        description="Price per 1K cached tokens in dollars.",
+    )
+
+    @property
+    def non_cached_prompt_tokens(self) -> Optional[int]:
+        """Get the non-cached prompt tokens."""
+        if self.cached_prompt_tokens is not None and self.prompt_tokens is not None:
+            return self.prompt_tokens - self.cached_prompt_tokens
+        return self.prompt_tokens
+
+    @property
+    def non_cached_cost(self) -> Optional[float]:
+        """Calculate the non-cached token cost based on configured pricing."""
+        if self.non_cached_prompt_tokens is None:
+            return None
+        if self.non_cached_prompt_tokens == 0:
+            return 0.0
+
+        return (self.non_cached_prompt_tokens / 1000.0) * self.input_token_price
+
+    @property
+    def cached_cost(self) -> Optional[float]:
+        """Calculate the cached token cost based on configured pricing."""
+        if self.cached_prompt_tokens is None:
+            return None
+        if self.cached_prompt_tokens == 0:
+            return 0.0
+
+        return (self.cached_prompt_tokens / 1000.0) * self.cached_token_price
+
+    @property
+    def input_cost(self) -> Optional[float]:
+        """Calculate the input token cost based on configured pricing.
+
+        The calculation takes into account the cached prompt tokens (if available) too.
+        """
+        # If both non-cached and cached costs are None, there's no input cost
+        if self.non_cached_cost is None and self.cached_cost is None:
+            return None
+
+        # If only non-cached cost is available, return it
+        if self.non_cached_cost is not None and self.cached_cost is None:
+            return self.non_cached_cost
+
+        # If only cached cost is available, return it
+        if self.non_cached_cost is None and self.cached_cost is not None:
+            return self.cached_cost
+
+        # If both are available, return the sum
+        return self.non_cached_cost + self.cached_cost  # type: ignore[operator]
+
+    @property
+    def output_cost(self) -> Optional[float]:
+        """Calculate the output token cost based on configured pricing."""
+        if self.completion_tokens is None:
+            return None
+        if self.completion_tokens == 0:
+            return 0.0
+
+        return (self.completion_tokens / 1000.0) * self.output_token_price
+
+    @property
+    def total_cost(self) -> Optional[float]:
+        """Calculate the total cost based on configured pricing.
+
+        Returns:
+            Total cost in dollars, or None if insufficient data.
+        """
+        if self.input_cost is None or self.output_cost is None:
+            return None
+
+        return self.input_cost + self.output_cost
+
+    def update_token_prices(
+        self,
+        input_token_price: float,
+        output_token_price: float,
+        cached_token_price: float,
+    ) -> None:
+        """Update token prices with provided values.
+
+        Args:
+            input_token_price: Price per 1K input tokens in dollars.
+            output_token_price: Price per 1K output tokens in dollars.
+            cached_token_price: Price per 1K cached tokens in dollars.
+        """
+        self.input_token_price = input_token_price
+        self.output_token_price = output_token_price
+        self.cached_token_price = cached_token_price
+
+    @classmethod
+    def from_chat_completion_response(
+        cls,
+        response: ChatCompletion,
+        input_token_price: float = 0.0,
+        output_token_price: float = 0.0,
+        cached_token_price: float = 0.0,
+    ) -> Optional["UsageStatistics"]:
+        """Create a UsageStatistics object from a ChatCompletionChunk."""
+        if not (usage := getattr(response, "usage", None)):
+            return None
+
+        usage_statistics = cls(
+            input_token_price=input_token_price,
+            output_token_price=output_token_price,
+            cached_token_price=cached_token_price,
+        )
+
+        usage_statistics.prompt_tokens = usage.prompt_tokens
+        usage_statistics.completion_tokens = usage.completion_tokens
+        usage_statistics.total_tokens = usage.total_tokens
+        usage_statistics.model = getattr(response, "model", None)
+
+        # Extract cached tokens if available
+        if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details:
+            usage_statistics.cached_prompt_tokens = getattr(
+                usage.prompt_tokens_details, "cached_tokens", None
+            )
+
+        return usage_statistics
 
     def reset(self) -> None:
         """Reset usage statistics to their default values."""
         self.prompt_tokens = None
         self.completion_tokens = None
         self.total_tokens = None
+        self.cached_prompt_tokens = None
         self.model = None
 
     def update_from_stream_chunk(self, chunk: ChatCompletionChunk) -> None:
@@ -630,14 +786,25 @@ class UsageStatistics(BaseModel):
         Args:
             chunk: The OpenAI stream chunk containing usage statistics.
         """
+        # Reset the usage statistics to their default values
+        self.reset()
+
+        # If the chunk has no usage statistics, return
         if not (usage := getattr(chunk, "usage", None)):
             return
 
+        # Update the usage statistics with the values from the chunk
         self.prompt_tokens = usage.prompt_tokens
         self.completion_tokens = usage.completion_tokens
         self.total_tokens = usage.total_tokens
         self.model = getattr(chunk, "model", None)
 
+        # Extract cached tokens if available
+        if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details:
+            self.cached_prompt_tokens = getattr(
+                usage.prompt_tokens_details, "cached_tokens", None
+            )
+
 
 class SigningContext(BaseModel):
     secret: Optional[str] = Field(None)
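A worked example of the cost properties above, using the default prices from `rasa/builder/config.py`; the token counts are illustrative. Prices are per 1K tokens, so 4,000 non-cached prompt tokens at $0.002/1K cost $0.008:

from rasa.builder.copilot.models import UsageStatistics

stats = UsageStatistics(
    input_token_price=0.002,    # $ per 1K non-cached prompt tokens
    output_token_price=0.0005,  # $ per 1K completion tokens
    cached_token_price=0.002,   # $ per 1K cached prompt tokens
)
stats.prompt_tokens = 12_000        # includes the cached tokens
stats.cached_prompt_tokens = 8_000
stats.completion_tokens = 1_000

assert stats.non_cached_prompt_tokens == 4_000
print(stats.input_cost)   # (4 * 0.002) + (8 * 0.002) = 0.024
print(stats.output_cost)  # 1 * 0.0005 = 0.0005
print(stats.total_cost)   # ~0.0245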
rasa/builder/document_retrieval/inkeep_document_retrieval.py
CHANGED

@@ -17,6 +17,7 @@ from rasa.builder.document_retrieval.constants import (
 )
 from rasa.builder.document_retrieval.models import Document
 from rasa.builder.exceptions import DocumentRetrievalError
+from rasa.builder.telemetry.copilot_langfuse_telemetry import CopilotLangfuseTelemetry
 from rasa.shared.utils.io import read_json_file
 
 structlogger = structlog.get_logger()
@@ -88,6 +89,7 @@ class InKeepDocumentRetrieval:
             )
             raise e
 
+    @CopilotLangfuseTelemetry.trace_document_retrieval_generation
     async def _call_inkeep_rag_api(
         self, query: str, temperature: float, timeout: float
     ) -> ChatCompletion:
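`CopilotLangfuseTelemetry` lives in the new 384-line `rasa/builder/telemetry/copilot_langfuse_telemetry.py`, whose body this diff does not show. As a rough, assumed sketch of the general shape an async-method tracing decorator like `trace_document_retrieval_generation` could take (orientation only, not the real implementation):

import functools
import time
from typing import Any, Awaitable, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Awaitable[Any]])


def trace_async(fn: F) -> F:
    """Assumed shape: wrap an async method and record timing around the await."""

    @functools.wraps(fn)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        start = time.perf_counter()
        try:
            return await fn(*args, **kwargs)
        finally:
            # A real implementation would report to Langfuse instead of printing.
            print(f"{fn.__qualname__} took {time.perf_counter() - start:.3f}s")

    return wrapper  # type: ignore[return-value]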
rasa/builder/download.py
CHANGED
@@ -27,7 +27,7 @@ def _get_pyproject_toml_content(project_id: str) -> str:
     version = "0.1.0"
     description = "Add your description for your Rasa bot here"
     readme = "README.md"
-    dependencies = ["rasa-pro>=3.
+    dependencies = ["rasa-pro>=3.14"]
     requires-python = ">={sys.version_info.major}.{sys.version_info.minor}"
     """
     )