signalwire-agents 0.1.1__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- signalwire_agents/__init__.py +1 -1
- signalwire_agents/agent_server.py +1 -1
- signalwire_agents/core/__init__.py +29 -0
- signalwire_agents/core/agent_base.py +2541 -0
- signalwire_agents/core/function_result.py +123 -0
- signalwire_agents/core/pom_builder.py +204 -0
- signalwire_agents/core/security/__init__.py +9 -0
- signalwire_agents/core/security/session_manager.py +179 -0
- signalwire_agents/core/state/__init__.py +17 -0
- signalwire_agents/core/state/file_state_manager.py +219 -0
- signalwire_agents/core/state/state_manager.py +101 -0
- signalwire_agents/core/swaig_function.py +172 -0
- signalwire_agents/core/swml_builder.py +214 -0
- signalwire_agents/core/swml_handler.py +227 -0
- signalwire_agents/core/swml_renderer.py +368 -0
- signalwire_agents/core/swml_service.py +1057 -0
- signalwire_agents/prefabs/__init__.py +26 -0
- signalwire_agents/prefabs/concierge.py +267 -0
- signalwire_agents/prefabs/faq_bot.py +305 -0
- signalwire_agents/prefabs/info_gatherer.py +263 -0
- signalwire_agents/prefabs/receptionist.py +295 -0
- signalwire_agents/prefabs/survey.py +378 -0
- signalwire_agents/utils/__init__.py +9 -0
- signalwire_agents/utils/pom_utils.py +9 -0
- signalwire_agents/utils/schema_utils.py +357 -0
- signalwire_agents/utils/token_generators.py +9 -0
- signalwire_agents/utils/validators.py +9 -0
- {signalwire_agents-0.1.1.dist-info → signalwire_agents-0.1.5.dist-info}/METADATA +1 -1
- signalwire_agents-0.1.5.dist-info/RECORD +34 -0
- signalwire_agents-0.1.1.dist-info/RECORD +0 -9
- {signalwire_agents-0.1.1.data → signalwire_agents-0.1.5.data}/data/schema.json +0 -0
- {signalwire_agents-0.1.1.dist-info → signalwire_agents-0.1.5.dist-info}/WHEEL +0 -0
- {signalwire_agents-0.1.1.dist-info → signalwire_agents-0.1.5.dist-info}/licenses/LICENSE +0 -0
- {signalwire_agents-0.1.1.dist-info → signalwire_agents-0.1.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2541 @@
|
|
1
|
+
"""
|
2
|
+
Copyright (c) 2025 SignalWire
|
3
|
+
|
4
|
+
This file is part of the SignalWire AI Agents SDK.
|
5
|
+
|
6
|
+
Licensed under the MIT License.
|
7
|
+
See LICENSE file in the project root for full license information.
|
8
|
+
"""
|
9
|
+
|
10
|
+
"""
|
11
|
+
AgentBase - Core foundation class for all SignalWire AI Agents
|
12
|
+
"""
|
13
|
+
|
14
|
+
import functools
|
15
|
+
import inspect
|
16
|
+
import os
|
17
|
+
import sys
|
18
|
+
import uuid
|
19
|
+
import tempfile
|
20
|
+
import traceback
|
21
|
+
from typing import Dict, List, Any, Optional, Union, Callable, Tuple, Type, TypeVar
|
22
|
+
import base64
|
23
|
+
import secrets
|
24
|
+
from urllib.parse import urlparse
|
25
|
+
import json
|
26
|
+
from datetime import datetime
|
27
|
+
import re
|
28
|
+
|
29
|
+
try:
|
30
|
+
import fastapi
|
31
|
+
from fastapi import FastAPI, APIRouter, Depends, HTTPException, Query, Body, Request, Response
|
32
|
+
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
33
|
+
from pydantic import BaseModel
|
34
|
+
except ImportError:
|
35
|
+
raise ImportError(
|
36
|
+
"fastapi is required. Install it with: pip install fastapi"
|
37
|
+
)
|
38
|
+
|
39
|
+
try:
|
40
|
+
import uvicorn
|
41
|
+
except ImportError:
|
42
|
+
raise ImportError(
|
43
|
+
"uvicorn is required. Install it with: pip install uvicorn"
|
44
|
+
)
|
45
|
+
|
46
|
+
try:
|
47
|
+
import structlog
|
48
|
+
# Configure structlog only if not already configured
|
49
|
+
if not structlog.is_configured():
|
50
|
+
structlog.configure(
|
51
|
+
processors=[
|
52
|
+
structlog.stdlib.filter_by_level,
|
53
|
+
structlog.stdlib.add_logger_name,
|
54
|
+
structlog.stdlib.add_log_level,
|
55
|
+
structlog.stdlib.PositionalArgumentsFormatter(),
|
56
|
+
structlog.processors.TimeStamper(fmt="iso"),
|
57
|
+
structlog.processors.StackInfoRenderer(),
|
58
|
+
structlog.processors.format_exc_info,
|
59
|
+
structlog.processors.UnicodeDecoder(),
|
60
|
+
structlog.processors.JSONRenderer()
|
61
|
+
],
|
62
|
+
context_class=dict,
|
63
|
+
logger_factory=structlog.stdlib.LoggerFactory(),
|
64
|
+
wrapper_class=structlog.stdlib.BoundLogger,
|
65
|
+
cache_logger_on_first_use=True,
|
66
|
+
)
|
67
|
+
except ImportError:
|
68
|
+
raise ImportError(
|
69
|
+
"structlog is required. Install it with: pip install structlog"
|
70
|
+
)
|
71
|
+
|
72
|
+
from signalwire_agents.core.pom_builder import PomBuilder
|
73
|
+
from signalwire_agents.core.swaig_function import SWAIGFunction
|
74
|
+
from signalwire_agents.core.function_result import SwaigFunctionResult
|
75
|
+
from signalwire_agents.core.swml_renderer import SwmlRenderer
|
76
|
+
from signalwire_agents.core.security.session_manager import SessionManager
|
77
|
+
from signalwire_agents.core.state import StateManager, FileStateManager
|
78
|
+
from signalwire_agents.core.swml_service import SWMLService
|
79
|
+
from signalwire_agents.core.swml_handler import AIVerbHandler
|
80
|
+
|
81
|
+
# Create a logger
|
82
|
+
logger = structlog.get_logger("agent_base")
|
83
|
+
|
84
|
+
class AgentBase(SWMLService):
|
85
|
+
"""
|
86
|
+
Base class for all SignalWire AI Agents.
|
87
|
+
|
88
|
+
This class extends SWMLService and provides enhanced functionality for building agents including:
|
89
|
+
- Prompt building and customization
|
90
|
+
- SWML rendering
|
91
|
+
- SWAIG function definition and execution
|
92
|
+
- Web service for serving SWML and handling webhooks
|
93
|
+
- Security and session management
|
94
|
+
|
95
|
+
Subclassing options:
|
96
|
+
1. Simple override of get_prompt() for raw text
|
97
|
+
2. Using prompt_* methods for structured prompts
|
98
|
+
3. Declarative PROMPT_SECTIONS class attribute
|
99
|
+
"""
|
100
|
+
|
101
|
+
# Subclasses can define this to declaratively set prompt sections
|
102
|
+
PROMPT_SECTIONS = None
|
103
|
+
|
104
|
+
def __init__(
    self,
    name: str,
    route: str = "/",
    host: str = "0.0.0.0",
    port: int = 3000,
    basic_auth: Optional[Tuple[str, str]] = None,
    use_pom: bool = True,
    enable_state_tracking: bool = False,
    token_expiry_secs: int = 600,
    auto_answer: bool = True,
    record_call: bool = False,
    record_format: str = "mp4",
    record_stereo: bool = True,
    state_manager: Optional[StateManager] = None,
    default_webhook_url: Optional[str] = None,
    agent_id: Optional[str] = None,
    native_functions: Optional[List[str]] = None,
    schema_path: Optional[str] = None,
    suppress_logs: bool = False
):
    """
    Initialize a new agent

    Args:
        name: Agent name/identifier
        route: HTTP route path for this agent
        host: Host to bind the web server to
        port: Port to bind the web server to
        basic_auth: Optional (username, password) tuple for basic auth
        use_pom: Whether to use POM for prompt building
        enable_state_tracking: Whether to register startup_hook and hangup_hook SWAIG functions to track conversation state
        token_expiry_secs: Seconds until tokens expire
        auto_answer: Whether to automatically answer calls
        record_call: Whether to record calls
        record_format: Recording format
        record_stereo: Whether to record in stereo
        state_manager: Optional state manager for this agent
        default_webhook_url: Optional default webhook URL for all SWAIG functions
        agent_id: Optional unique ID for this agent, generated if not provided
        native_functions: Optional list of native functions to include in the SWAIG object
        schema_path: Optional path to the schema file
        suppress_logs: Whether to suppress structured logs

    Raises:
        ImportError: If use_pom is True and the signalwire-pom package
            cannot be imported
    """
    # schema_path is forwarded unchanged (possibly None).
    # NOTE(review): the base class is expected to locate schema.json itself
    # when no path is given - confirm against SWMLService.__init__.
    super().__init__(
        name=name,
        route=route,
        host=host,
        port=port,
        basic_auth=basic_auth,
        schema_path=schema_path
    )

    # Announce which schema file is in use unless output is suppressed
    if self.schema_utils and self.schema_utils.schema_path and not suppress_logs:
        print(f"Using schema.json at: {self.schema_utils.schema_path}")

    # Per-instance logger bound to the agent name
    self.log = logger.bind(agent=name)
    self.log.info("agent_initializing", route=route, host=host, port=port)

    # Agent-specific parameters
    self._default_webhook_url = default_webhook_url
    self._suppress_logs = suppress_logs

    # Use the caller's agent ID, or generate a random UUID string
    self.agent_id = agent_id or str(uuid.uuid4())

    # Optional proxy URL base taken from the environment
    self._proxy_url_base = os.environ.get('SWML_PROXY_URL_BASE')

    # Prompt handling state
    self._use_pom = use_pom
    self._raw_prompt = None
    self._post_prompt = None

    # The POM package is imported lazily so it is only required when used
    if self._use_pom:
        try:
            from signalwire_pom.pom import PromptObjectModel
            self.pom = PromptObjectModel()
        except ImportError as exc:
            # Chain the original error so the real import failure stays visible
            raise ImportError(
                "signalwire-pom package is required for use_pom=True. "
                "Install it with: pip install signalwire-pom"
            ) from exc
    else:
        self.pom = None

    # Tool registry (separate from SWMLService verb registry)
    self._swaig_functions: Dict[str, SWAIGFunction] = {}

    # Session manager used for tool tokens
    self._session_manager = SessionManager(token_expiry_secs=token_expiry_secs)
    self._enable_state_tracking = enable_state_tracking

    # URL override variables
    self._web_hook_url_override = None
    self._post_prompt_url_override = None

    # On instances, `tool` is the instance-bound decorator; this shadows
    # the class-level `tool` classmethod for attribute lookups on self.
    self.tool = self._tool_decorator

    # Call settings
    self._auto_answer = auto_answer
    self._record_call = record_call
    self._record_format = record_format
    self._record_stereo = record_stereo

    # Process declarative PROMPT_SECTIONS if defined in subclass
    self._process_prompt_sections()

    # Fall back to a FileStateManager when no state manager is supplied
    self._state_manager = state_manager or FileStateManager()

    # Process class-decorated tools (using @AgentBase.tool)
    self._register_class_decorated_tools()

    # Native functions to include in the SWAIG object
    self.native_functions = native_functions or []

    # Register state tracking tools if enabled
    if enable_state_tracking:
        self._register_state_tracking_tools()

    # Configuration containers, all starting empty
    self._hints = []
    self._languages = []
    self._pronounce = []
    self._params = {}
    self._global_data = {}
    self._function_includes = []
|
243
|
+
|
244
|
+
def _process_prompt_sections(self):
|
245
|
+
"""
|
246
|
+
Process declarative PROMPT_SECTIONS attribute from a subclass
|
247
|
+
|
248
|
+
This auto-vivifies section methods and bootstraps the prompt
|
249
|
+
from class declaration, allowing for declarative agents.
|
250
|
+
"""
|
251
|
+
# Skip if no PROMPT_SECTIONS defined or not using POM
|
252
|
+
cls = self.__class__
|
253
|
+
if not hasattr(cls, 'PROMPT_SECTIONS') or cls.PROMPT_SECTIONS is None or not self._use_pom:
|
254
|
+
return
|
255
|
+
|
256
|
+
sections = cls.PROMPT_SECTIONS
|
257
|
+
|
258
|
+
# If sections is a dictionary mapping section names to content
|
259
|
+
if isinstance(sections, dict):
|
260
|
+
for title, content in sections.items():
|
261
|
+
# Handle different content types
|
262
|
+
if isinstance(content, str):
|
263
|
+
# Plain text - add as body
|
264
|
+
self.prompt_add_section(title, body=content)
|
265
|
+
elif isinstance(content, list) and content: # Only add if non-empty
|
266
|
+
# List of strings - add as bullets
|
267
|
+
self.prompt_add_section(title, bullets=content)
|
268
|
+
elif isinstance(content, dict):
|
269
|
+
# Dictionary with body/bullets/subsections
|
270
|
+
body = content.get('body', '')
|
271
|
+
bullets = content.get('bullets', [])
|
272
|
+
numbered = content.get('numbered', False)
|
273
|
+
numbered_bullets = content.get('numberedBullets', False)
|
274
|
+
|
275
|
+
# Only create section if it has content
|
276
|
+
if body or bullets or 'subsections' in content:
|
277
|
+
# Create the section
|
278
|
+
self.prompt_add_section(
|
279
|
+
title,
|
280
|
+
body=body,
|
281
|
+
bullets=bullets if bullets else None,
|
282
|
+
numbered=numbered,
|
283
|
+
numbered_bullets=numbered_bullets
|
284
|
+
)
|
285
|
+
|
286
|
+
# Process subsections if any
|
287
|
+
subsections = content.get('subsections', [])
|
288
|
+
for subsection in subsections:
|
289
|
+
if 'title' in subsection:
|
290
|
+
sub_title = subsection['title']
|
291
|
+
sub_body = subsection.get('body', '')
|
292
|
+
sub_bullets = subsection.get('bullets', [])
|
293
|
+
|
294
|
+
# Only add subsection if it has content
|
295
|
+
if sub_body or sub_bullets:
|
296
|
+
self.prompt_add_subsection(
|
297
|
+
title,
|
298
|
+
sub_title,
|
299
|
+
body=sub_body,
|
300
|
+
bullets=sub_bullets if sub_bullets else None
|
301
|
+
)
|
302
|
+
# If sections is a list of section objects, use the POM format directly
|
303
|
+
elif isinstance(sections, list):
|
304
|
+
if self.pom:
|
305
|
+
# Process each section using auto-vivifying methods
|
306
|
+
for section in sections:
|
307
|
+
if 'title' in section:
|
308
|
+
title = section['title']
|
309
|
+
body = section.get('body', '')
|
310
|
+
bullets = section.get('bullets', [])
|
311
|
+
numbered = section.get('numbered', False)
|
312
|
+
numbered_bullets = section.get('numberedBullets', False)
|
313
|
+
|
314
|
+
# Only create section if it has content
|
315
|
+
if body or bullets or 'subsections' in section:
|
316
|
+
self.prompt_add_section(
|
317
|
+
title,
|
318
|
+
body=body,
|
319
|
+
bullets=bullets if bullets else None,
|
320
|
+
numbered=numbered,
|
321
|
+
numbered_bullets=numbered_bullets
|
322
|
+
)
|
323
|
+
|
324
|
+
# Process subsections if any
|
325
|
+
subsections = section.get('subsections', [])
|
326
|
+
for subsection in subsections:
|
327
|
+
if 'title' in subsection:
|
328
|
+
sub_title = subsection['title']
|
329
|
+
sub_body = subsection.get('body', '')
|
330
|
+
sub_bullets = subsection.get('bullets', [])
|
331
|
+
|
332
|
+
# Only add subsection if it has content
|
333
|
+
if sub_body or sub_bullets:
|
334
|
+
self.prompt_add_subsection(
|
335
|
+
title,
|
336
|
+
sub_title,
|
337
|
+
body=sub_body,
|
338
|
+
bullets=sub_bullets if sub_bullets else None
|
339
|
+
)
|
340
|
+
|
341
|
+
# ----------------------------------------------------------------------
|
342
|
+
# Prompt Building Methods
|
343
|
+
# ----------------------------------------------------------------------
|
344
|
+
|
345
|
+
def set_prompt_text(self, text: str) -> 'AgentBase':
|
346
|
+
"""
|
347
|
+
Set the prompt as raw text instead of using POM
|
348
|
+
|
349
|
+
Args:
|
350
|
+
text: The raw prompt text
|
351
|
+
|
352
|
+
Returns:
|
353
|
+
Self for method chaining
|
354
|
+
"""
|
355
|
+
self._raw_prompt = text
|
356
|
+
return self
|
357
|
+
|
358
|
+
def set_prompt_pom(self, pom: List[Dict[str, Any]]) -> 'AgentBase':
|
359
|
+
"""
|
360
|
+
Set the prompt as a POM dictionary
|
361
|
+
|
362
|
+
Args:
|
363
|
+
pom: POM dictionary structure
|
364
|
+
|
365
|
+
Returns:
|
366
|
+
Self for method chaining
|
367
|
+
"""
|
368
|
+
if self._use_pom:
|
369
|
+
self.pom = pom
|
370
|
+
else:
|
371
|
+
raise ValueError("use_pom must be True to use set_prompt_pom")
|
372
|
+
return self
|
373
|
+
|
374
|
+
def prompt_add_section(
|
375
|
+
self,
|
376
|
+
title: str,
|
377
|
+
body: str = "",
|
378
|
+
bullets: Optional[List[str]] = None,
|
379
|
+
numbered: bool = False,
|
380
|
+
numbered_bullets: bool = False,
|
381
|
+
subsections: Optional[List[Dict[str, Any]]] = None
|
382
|
+
) -> 'AgentBase':
|
383
|
+
"""
|
384
|
+
Add a section to the prompt
|
385
|
+
|
386
|
+
Args:
|
387
|
+
title: Section title
|
388
|
+
body: Optional section body text
|
389
|
+
bullets: Optional list of bullet points
|
390
|
+
numbered: Whether this section should be numbered
|
391
|
+
numbered_bullets: Whether bullets should be numbered
|
392
|
+
subsections: Optional list of subsection objects
|
393
|
+
|
394
|
+
Returns:
|
395
|
+
Self for method chaining
|
396
|
+
"""
|
397
|
+
if self._use_pom and self.pom:
|
398
|
+
# Create parameters for add_section based on what's supported
|
399
|
+
kwargs = {}
|
400
|
+
|
401
|
+
# Start with basic parameters
|
402
|
+
kwargs['title'] = title
|
403
|
+
kwargs['body'] = body
|
404
|
+
if bullets:
|
405
|
+
kwargs['bullets'] = bullets
|
406
|
+
|
407
|
+
# Add optional parameters if they look supported
|
408
|
+
if hasattr(self.pom, 'add_section'):
|
409
|
+
sig = inspect.signature(self.pom.add_section)
|
410
|
+
if 'numbered' in sig.parameters:
|
411
|
+
kwargs['numbered'] = numbered
|
412
|
+
if 'numberedBullets' in sig.parameters:
|
413
|
+
kwargs['numberedBullets'] = numbered_bullets
|
414
|
+
|
415
|
+
# Create the section
|
416
|
+
section = self.pom.add_section(**kwargs)
|
417
|
+
|
418
|
+
# Now add subsections if provided, by calling add_subsection on the section
|
419
|
+
if subsections:
|
420
|
+
for subsection in subsections:
|
421
|
+
if 'title' in subsection:
|
422
|
+
section.add_subsection(
|
423
|
+
title=subsection.get('title'),
|
424
|
+
body=subsection.get('body', ''),
|
425
|
+
bullets=subsection.get('bullets', [])
|
426
|
+
)
|
427
|
+
|
428
|
+
return self
|
429
|
+
|
430
|
+
def prompt_add_to_section(
|
431
|
+
self,
|
432
|
+
title: str,
|
433
|
+
body: Optional[str] = None,
|
434
|
+
bullet: Optional[str] = None,
|
435
|
+
bullets: Optional[List[str]] = None
|
436
|
+
) -> 'AgentBase':
|
437
|
+
"""
|
438
|
+
Add content to an existing section (creating it if needed)
|
439
|
+
|
440
|
+
Args:
|
441
|
+
title: Section title
|
442
|
+
body: Optional text to append to section body
|
443
|
+
bullet: Optional single bullet point to add
|
444
|
+
bullets: Optional list of bullet points to add
|
445
|
+
|
446
|
+
Returns:
|
447
|
+
Self for method chaining
|
448
|
+
"""
|
449
|
+
if self._use_pom and self.pom:
|
450
|
+
self.pom.add_to_section(
|
451
|
+
title=title,
|
452
|
+
body=body,
|
453
|
+
bullet=bullet,
|
454
|
+
bullets=bullets
|
455
|
+
)
|
456
|
+
return self
|
457
|
+
|
458
|
+
def prompt_add_subsection(
|
459
|
+
self,
|
460
|
+
parent_title: str,
|
461
|
+
title: str,
|
462
|
+
body: str = "",
|
463
|
+
bullets: Optional[List[str]] = None
|
464
|
+
) -> 'AgentBase':
|
465
|
+
"""
|
466
|
+
Add a subsection to an existing section (creating parent if needed)
|
467
|
+
|
468
|
+
Args:
|
469
|
+
parent_title: Parent section title
|
470
|
+
title: Subsection title
|
471
|
+
body: Optional subsection body text
|
472
|
+
bullets: Optional list of bullet points
|
473
|
+
|
474
|
+
Returns:
|
475
|
+
Self for method chaining
|
476
|
+
"""
|
477
|
+
if self._use_pom and self.pom:
|
478
|
+
# First find or create the parent section
|
479
|
+
parent_section = None
|
480
|
+
|
481
|
+
# Try to find the parent section by title
|
482
|
+
if hasattr(self.pom, 'sections'):
|
483
|
+
for section in self.pom.sections:
|
484
|
+
if hasattr(section, 'title') and section.title == parent_title:
|
485
|
+
parent_section = section
|
486
|
+
break
|
487
|
+
|
488
|
+
# If parent section not found, create it
|
489
|
+
if not parent_section:
|
490
|
+
parent_section = self.pom.add_section(title=parent_title)
|
491
|
+
|
492
|
+
# Now call add_subsection on the parent section object, not on POM
|
493
|
+
parent_section.add_subsection(
|
494
|
+
title=title,
|
495
|
+
body=body,
|
496
|
+
bullets=bullets or []
|
497
|
+
)
|
498
|
+
|
499
|
+
return self
|
500
|
+
|
501
|
+
# ----------------------------------------------------------------------
|
502
|
+
# Tool/Function Management
|
503
|
+
# ----------------------------------------------------------------------
|
504
|
+
|
505
|
+
def define_tool(
|
506
|
+
self,
|
507
|
+
name: str,
|
508
|
+
description: str,
|
509
|
+
parameters: Dict[str, Any],
|
510
|
+
handler: Callable,
|
511
|
+
secure: bool = True,
|
512
|
+
fillers: Optional[Dict[str, List[str]]] = None
|
513
|
+
) -> 'AgentBase':
|
514
|
+
"""
|
515
|
+
Define a SWAIG function that the AI can call
|
516
|
+
|
517
|
+
Args:
|
518
|
+
name: Function name (must be unique)
|
519
|
+
description: Function description for the AI
|
520
|
+
parameters: JSON Schema of parameters
|
521
|
+
handler: Function to call when invoked
|
522
|
+
secure: Whether to require token validation
|
523
|
+
fillers: Optional dict mapping language codes to arrays of filler phrases
|
524
|
+
|
525
|
+
Returns:
|
526
|
+
Self for method chaining
|
527
|
+
"""
|
528
|
+
if name in self._swaig_functions:
|
529
|
+
raise ValueError(f"Tool with name '{name}' already exists")
|
530
|
+
|
531
|
+
self._swaig_functions[name] = SWAIGFunction(
|
532
|
+
name=name,
|
533
|
+
description=description,
|
534
|
+
parameters=parameters,
|
535
|
+
handler=handler,
|
536
|
+
secure=secure,
|
537
|
+
fillers=fillers
|
538
|
+
)
|
539
|
+
return self
|
540
|
+
|
541
|
+
def _tool_decorator(self, name=None, **kwargs):
|
542
|
+
"""
|
543
|
+
Decorator for defining SWAIG tools in a class
|
544
|
+
|
545
|
+
Used as:
|
546
|
+
|
547
|
+
@agent.tool(name="example_function", parameters={...})
|
548
|
+
def example_function(self, param1):
|
549
|
+
# ...
|
550
|
+
"""
|
551
|
+
def decorator(func):
|
552
|
+
nonlocal name
|
553
|
+
if name is None:
|
554
|
+
name = func.__name__
|
555
|
+
|
556
|
+
parameters = kwargs.get("parameters", {})
|
557
|
+
description = kwargs.get("description", func.__doc__ or f"Function {name}")
|
558
|
+
secure = kwargs.get("secure", True)
|
559
|
+
fillers = kwargs.get("fillers", None)
|
560
|
+
|
561
|
+
self.define_tool(
|
562
|
+
name=name,
|
563
|
+
description=description,
|
564
|
+
parameters=parameters,
|
565
|
+
handler=func,
|
566
|
+
secure=secure,
|
567
|
+
fillers=fillers
|
568
|
+
)
|
569
|
+
return func
|
570
|
+
return decorator
|
571
|
+
|
572
|
+
@classmethod
|
573
|
+
def tool(cls, name=None, **kwargs):
|
574
|
+
"""
|
575
|
+
Class method decorator for defining SWAIG tools
|
576
|
+
|
577
|
+
Used as:
|
578
|
+
|
579
|
+
@AgentBase.tool(name="example_function", parameters={...})
|
580
|
+
def example_function(self, param1):
|
581
|
+
# ...
|
582
|
+
"""
|
583
|
+
def decorator(func):
|
584
|
+
setattr(func, "_is_tool", True)
|
585
|
+
setattr(func, "_tool_name", name or func.__name__)
|
586
|
+
setattr(func, "_tool_params", kwargs)
|
587
|
+
return func
|
588
|
+
return decorator
|
589
|
+
|
590
|
+
# ----------------------------------------------------------------------
|
591
|
+
# Override Points for Subclasses
|
592
|
+
# ----------------------------------------------------------------------
|
593
|
+
|
594
|
+
def get_name(self) -> str:
|
595
|
+
"""
|
596
|
+
Get the agent name
|
597
|
+
|
598
|
+
Returns:
|
599
|
+
Agent name/identifier
|
600
|
+
"""
|
601
|
+
return self.name
|
602
|
+
|
603
|
+
def get_prompt(self) -> Union[str, List[Dict[str, Any]]]:
|
604
|
+
"""
|
605
|
+
Get the prompt for the agent
|
606
|
+
|
607
|
+
Returns:
|
608
|
+
Either a string prompt or a POM object as list of dicts
|
609
|
+
"""
|
610
|
+
# If using POM, return the POM structure
|
611
|
+
if self._use_pom and self.pom:
|
612
|
+
try:
|
613
|
+
# Try different methods that might be available on the POM implementation
|
614
|
+
if hasattr(self.pom, 'render_dict'):
|
615
|
+
return self.pom.render_dict()
|
616
|
+
elif hasattr(self.pom, 'to_dict'):
|
617
|
+
return self.pom.to_dict()
|
618
|
+
elif hasattr(self.pom, 'to_list'):
|
619
|
+
return self.pom.to_list()
|
620
|
+
elif hasattr(self.pom, 'render'):
|
621
|
+
render_result = self.pom.render()
|
622
|
+
# If render returns a string, we need to convert it to JSON
|
623
|
+
if isinstance(render_result, str):
|
624
|
+
try:
|
625
|
+
import json
|
626
|
+
return json.loads(render_result)
|
627
|
+
except:
|
628
|
+
# If we can't parse as JSON, fall back to raw text
|
629
|
+
pass
|
630
|
+
return render_result
|
631
|
+
else:
|
632
|
+
# Last resort: attempt to convert the POM object directly to a list/dict
|
633
|
+
# This assumes the POM object has a reasonable __str__ or __repr__ method
|
634
|
+
pom_data = self.pom.__dict__
|
635
|
+
if '_sections' in pom_data and isinstance(pom_data['_sections'], list):
|
636
|
+
return pom_data['_sections']
|
637
|
+
# Fall through to default if nothing worked
|
638
|
+
except Exception as e:
|
639
|
+
print(f"Error rendering POM: {e}")
|
640
|
+
# Fall back to raw text if POM fails
|
641
|
+
|
642
|
+
# Return raw text (either explicitly set or default)
|
643
|
+
return self._raw_prompt or f"You are {self.name}, a helpful AI assistant."
|
644
|
+
|
645
|
+
def get_post_prompt(self) -> Optional[str]:
|
646
|
+
"""
|
647
|
+
Get the post-prompt for the agent
|
648
|
+
|
649
|
+
Returns:
|
650
|
+
Post-prompt text or None if not set
|
651
|
+
"""
|
652
|
+
return self._post_prompt
|
653
|
+
|
654
|
+
def define_tools(self) -> List[SWAIGFunction]:
|
655
|
+
"""
|
656
|
+
Define the tools this agent can use
|
657
|
+
|
658
|
+
Returns:
|
659
|
+
List of SWAIGFunction objects
|
660
|
+
|
661
|
+
This method can be overridden by subclasses.
|
662
|
+
"""
|
663
|
+
return list(self._swaig_functions.values())
|
664
|
+
|
665
|
+
def on_summary(self, summary: Optional[Dict[str, Any]], raw_data: Optional[Dict[str, Any]] = None) -> None:
|
666
|
+
"""
|
667
|
+
Called when a post-prompt summary is received
|
668
|
+
|
669
|
+
Args:
|
670
|
+
summary: The summary object or None if no summary was found
|
671
|
+
raw_data: The complete raw POST data from the request
|
672
|
+
"""
|
673
|
+
# Default implementation does nothing
|
674
|
+
pass
|
675
|
+
|
676
|
+
def on_function_call(self, name: str, args: Dict[str, Any], raw_data: Optional[Dict[str, Any]] = None) -> Any:
|
677
|
+
"""
|
678
|
+
Called when a SWAIG function is invoked
|
679
|
+
|
680
|
+
Args:
|
681
|
+
name: Function name
|
682
|
+
args: Function arguments
|
683
|
+
raw_data: Raw request data
|
684
|
+
|
685
|
+
Returns:
|
686
|
+
Function result
|
687
|
+
"""
|
688
|
+
# Check if the function is registered
|
689
|
+
if name not in self._swaig_functions:
|
690
|
+
# If the function is not found, return an error
|
691
|
+
return {"response": f"Function '{name}' not found"}
|
692
|
+
|
693
|
+
# Get the function
|
694
|
+
func = self._swaig_functions[name]
|
695
|
+
|
696
|
+
# Call the handler
|
697
|
+
try:
|
698
|
+
result = func.handler(args, raw_data)
|
699
|
+
if result is None:
|
700
|
+
# If the handler returns None, create a default response
|
701
|
+
result = SwaigFunctionResult("Function executed successfully")
|
702
|
+
return result
|
703
|
+
except Exception as e:
|
704
|
+
# If the handler raises an exception, return an error response
|
705
|
+
return {"response": f"Error executing function '{name}': {str(e)}"}
|
706
|
+
|
707
|
+
def validate_basic_auth(self, username: str, password: str) -> bool:
|
708
|
+
"""
|
709
|
+
Validate basic auth credentials
|
710
|
+
|
711
|
+
Args:
|
712
|
+
username: Username from request
|
713
|
+
password: Password from request
|
714
|
+
|
715
|
+
Returns:
|
716
|
+
True if valid, False otherwise
|
717
|
+
|
718
|
+
This method can be overridden by subclasses.
|
719
|
+
"""
|
720
|
+
return (username, password) == self._basic_auth
|
721
|
+
|
722
|
+
def _create_tool_token(self, tool_name: str, call_id: str) -> str:
|
723
|
+
"""
|
724
|
+
Create a secure token for a tool call
|
725
|
+
|
726
|
+
Args:
|
727
|
+
tool_name: Name of the tool
|
728
|
+
call_id: Call ID for this session
|
729
|
+
|
730
|
+
Returns:
|
731
|
+
Secure token string
|
732
|
+
"""
|
733
|
+
return self._session_manager.create_tool_token(tool_name, call_id)
|
734
|
+
|
735
|
+
def validate_tool_token(self, function_name: str, token: str, call_id: str) -> bool:
|
736
|
+
"""
|
737
|
+
Validate a tool token
|
738
|
+
|
739
|
+
Args:
|
740
|
+
function_name: Name of the function/tool
|
741
|
+
token: Token to validate
|
742
|
+
call_id: Call ID for the session
|
743
|
+
|
744
|
+
Returns:
|
745
|
+
True if token is valid, False otherwise
|
746
|
+
"""
|
747
|
+
# Skip validation for non-secure tools
|
748
|
+
if function_name not in self._swaig_functions:
|
749
|
+
return False
|
750
|
+
|
751
|
+
if not self._swaig_functions[function_name].secure:
|
752
|
+
return True
|
753
|
+
|
754
|
+
return self._session_manager.validate_tool_token(function_name, token, call_id)
|
755
|
+
|
756
|
+
# ----------------------------------------------------------------------
|
757
|
+
# Web Server and Routing
|
758
|
+
# ----------------------------------------------------------------------
|
759
|
+
|
760
|
+
def get_basic_auth_credentials(self, include_source: bool = False) -> Union[Tuple[str, str], Tuple[str, str, str]]:
|
761
|
+
"""
|
762
|
+
Get the basic auth credentials
|
763
|
+
|
764
|
+
Args:
|
765
|
+
include_source: Whether to include the source of the credentials
|
766
|
+
|
767
|
+
Returns:
|
768
|
+
If include_source is False:
|
769
|
+
(username, password) tuple
|
770
|
+
If include_source is True:
|
771
|
+
(username, password, source) tuple, where source is one of:
|
772
|
+
"provided", "environment", or "generated"
|
773
|
+
"""
|
774
|
+
username, password = self._basic_auth
|
775
|
+
|
776
|
+
if not include_source:
|
777
|
+
return (username, password)
|
778
|
+
|
779
|
+
# Determine source of credentials
|
780
|
+
env_user = os.environ.get('SWML_BASIC_AUTH_USER')
|
781
|
+
env_pass = os.environ.get('SWML_BASIC_AUTH_PASSWORD')
|
782
|
+
|
783
|
+
# More robust source detection
|
784
|
+
if env_user and env_pass and username == env_user and password == env_pass:
|
785
|
+
source = "environment"
|
786
|
+
elif username.startswith("user_") and len(password) > 20: # Format of generated credentials
|
787
|
+
source = "generated"
|
788
|
+
else:
|
789
|
+
source = "provided"
|
790
|
+
|
791
|
+
return (username, password, source)
|
792
|
+
|
793
|
+
def get_full_url(self, include_auth: bool = False) -> str:
|
794
|
+
"""
|
795
|
+
Get the full URL for this agent's endpoint
|
796
|
+
|
797
|
+
Args:
|
798
|
+
include_auth: Whether to include authentication credentials in the URL
|
799
|
+
|
800
|
+
Returns:
|
801
|
+
Full URL including host, port, and route (with auth if requested)
|
802
|
+
"""
|
803
|
+
# Start with the base URL (either proxy or local)
|
804
|
+
if self._proxy_url_base:
|
805
|
+
# Use the proxy URL base from environment, ensuring we don't duplicate the route
|
806
|
+
# Strip any trailing slashes from proxy base
|
807
|
+
proxy_base = self._proxy_url_base.rstrip('/')
|
808
|
+
# Make sure route starts with a slash for consistency
|
809
|
+
route = self.route if self.route.startswith('/') else f"/{self.route}"
|
810
|
+
base_url = f"{proxy_base}{route}"
|
811
|
+
else:
|
812
|
+
# Default local URL
|
813
|
+
if self.host in ("0.0.0.0", "127.0.0.1", "localhost"):
|
814
|
+
host = "localhost"
|
815
|
+
else:
|
816
|
+
host = self.host
|
817
|
+
|
818
|
+
base_url = f"http://{host}:{self.port}{self.route}"
|
819
|
+
|
820
|
+
# Add auth if requested
|
821
|
+
if include_auth:
|
822
|
+
username, password = self._basic_auth
|
823
|
+
url = urlparse(base_url)
|
824
|
+
return url._replace(netloc=f"{username}:{password}@{url.netloc}").geturl()
|
825
|
+
|
826
|
+
return base_url
|
827
|
+
|
828
|
+
def _build_webhook_url(self, endpoint: str, query_params: Optional[Dict[str, str]] = None) -> str:
    """
    Helper method to build webhook URLs consistently

    The URL always carries the basic-auth credentials. Credentials and
    query-string values are percent-encoded so that reserved characters
    (for example '@' in a password or '+', '/', '=' in a token) survive
    the round trip.

    Args:
        endpoint: The endpoint path (e.g., "swaig", "post_prompt")
        query_params: Optional query parameters to append; entries with an
            empty value and any "call_id" entry are dropped

    Returns:
        Fully constructed webhook URL
    """
    # Imported locally so this method does not depend on the file-level imports
    from urllib.parse import quote, urlencode

    # Encode each credential separately; the ':' between them is the delimiter
    username, password = self._basic_auth
    userinfo = f"{quote(username, safe='')}:{quote(password, safe='')}"

    # Base URL construction
    if hasattr(self, '_proxy_url_base') and self._proxy_url_base:
        # For proxy URLs: splice the credentials into the proxy's netloc
        url = urlparse(self._proxy_url_base.rstrip('/'))
        base = url._replace(netloc=f"{userinfo}@{url.netloc}").geturl()
    else:
        # For local URLs: wildcard/loopback bind addresses become "localhost"
        if self.host in ("0.0.0.0", "127.0.0.1", "localhost"):
            host = "localhost"
        else:
            host = self.host

        base = f"http://{userinfo}@{host}:{self.port}"

    # Ensure the endpoint has a trailing slash to prevent redirects
    if endpoint in ("swaig", "post_prompt"):
        endpoint = f"{endpoint}/"

    # Simple path - use the route directly with the endpoint
    url = f"{base}{self.route}/{endpoint}"

    # Add query parameters if any (only if they have values)
    # But NEVER add call_id parameter - it should be in the body, not the URL
    if query_params:
        filtered_params = {k: v for k, v in query_params.items() if k != "call_id" and v}
        if filtered_params:
            url = f"{url}?{urlencode(filtered_params)}"

    return url
|
879
|
+
|
880
|
+
def _render_swml(self, call_id: str = None, modifications: Optional[dict] = None) -> str:
    """
    Build this agent's SWML document (an answer verb followed by an AI verb)
    and return it rendered as a string

    Args:
        call_id: Optional call ID for session-specific tokens
        modifications: Optional dict that is merged into the finished document

    Returns:
        SWML document as a string
    """
    # Always start from an empty document
    self.reset_document()

    # A list-valued prompt is treated as POM, anything else as plain text
    prompt = self.get_prompt()
    prompt_is_pom = isinstance(prompt, list)
    post_prompt = self.get_post_prompt()

    # With state tracking enabled, a missing call ID is replaced by a new session
    if self._enable_state_tracking and call_id is None:
        call_id = self._session_manager.create_session()

    # Default webhook URL (no query parameters); an explicit override wins
    webhook_url = self._build_webhook_url("swaig", {})
    webhook_override = getattr(self, '_web_hook_url_override', None)
    if webhook_override:
        webhook_url = webhook_override

    # ---- SWAIG object -------------------------------------------------
    swaig_obj = {}

    if self._swaig_functions:
        swaig_obj["defaults"] = {"web_hook_url": webhook_url}

    if self.native_functions:
        swaig_obj["native_functions"] = self.native_functions

    if self._function_includes:
        swaig_obj["includes"] = self._function_includes

    function_entries = []
    for fn_name, fn in self._swaig_functions.items():
        # Secure functions get a per-call token, but only when a call ID exists
        token = (
            self._create_tool_token(tool_name=fn_name, call_id=call_id)
            if fn.secure and call_id
            else None
        )

        entry = {
            "function": fn_name,
            "description": fn.description,
            "parameters": {
                "type": "object",
                "properties": fn.parameters,
            },
        }

        if fn.fillers:
            entry["fillers"] = fn.fillers

        # The token travels in the function's own webhook URL (never the call_id)
        if token:
            entry["web_hook_url"] = self._build_webhook_url("swaig", {"token": token})

        function_entries.append(entry)

    if function_entries:
        swaig_obj["functions"] = function_entries

    # ---- Post-prompt URL ----------------------------------------------
    post_prompt_url = None
    if post_prompt:
        post_prompt_url = self._build_webhook_url("post_prompt", {})

        post_prompt_override = getattr(self, '_post_prompt_url_override', None)
        if post_prompt_override:
            post_prompt_url = post_prompt_override

    # ---- Verbs ----------------------------------------------------------
    self.add_answer_verb()

    # Optional AI settings: (config key, attribute holding the value),
    # listed in the order they are written into the config
    optional_settings = (
        ("hints", "_hints"),
        ("languages", "_languages"),
        ("pronounce", "_pronounce"),
        ("params", "_params"),
        ("global_data", "_global_data"),
    )

    ai_config = {}
    ai_handler = self.verb_registry.get_handler("ai")
    if not ai_handler:
        # Fallback if no handler (shouldn't happen but just in case)
        prompt_key = "text" if not prompt_is_pom else "pom"
        ai_config = {"prompt": {prompt_key: prompt}}

        if post_prompt:
            ai_config["post_prompt"] = {"text": post_prompt}
            if post_prompt_url:
                ai_config["post_prompt_url"] = post_prompt_url

        if swaig_obj:
            ai_config["SWAIG"] = swaig_obj
    else:
        try:
            # Let the registered handler build the AI verb configuration
            ai_config = ai_handler.build_config(
                prompt_text=None if prompt_is_pom else prompt,
                prompt_pom=prompt if prompt_is_pom else None,
                post_prompt=post_prompt,
                post_prompt_url=post_prompt_url,
                swaig=swaig_obj if swaig_obj else None
            )

            # Layer the optional settings on top of the handler's result
            for key, attr in optional_settings:
                value = getattr(self, attr)
                if value:
                    ai_config[key] = value
        except ValueError as e:
            if not self._suppress_logs:
                print(f"Error building AI verb configuration: {str(e)}")

    # Fill in any optional setting that is still absent from the config
    for key, attr in optional_settings:
        value = getattr(self, attr)
        if value and key not in ai_config:
            ai_config[key] = value

    self.add_verb("ai", ai_config)

    # ---- Caller-supplied modifications ----------------------------------
    if modifications and isinstance(modifications, dict):
        document = self.get_document()

        def merge_into(dst, src):
            # Dicts on both sides are merged key by key; anything else replaces
            for k, v in src.items():
                if isinstance(v, dict) and k in dst and isinstance(dst[k], dict):
                    merge_into(dst[k], v)
                else:
                    dst[k] = v

        merge_into(document, modifications)

        # The merged document is replayed into a freshly reset document,
        # section by section and verb by verb
        self.reset_document()

        for section_name, verbs in document["sections"].items():
            if section_name != "main":  # Main section is created by default
                self.add_section(section_name)

            for verb in verbs:
                for verb_name, verb_config in verb.items():
                    self.add_verb_to_section(section_name, verb_name, verb_config)

    return self.render_document()
|
1089
|
+
|
1090
|
+
def _check_basic_auth(self, request: Request) -> bool:
    """
    Check basic auth from a request

    The "Basic" scheme name is matched case-insensitively, so headers such
    as "basic ..." or "BASIC ..." are accepted as well.

    Args:
        request: FastAPI request object

    Returns:
        True if auth is valid, False otherwise (missing header, different
        scheme, undecodable credentials, or credentials that
        validate_basic_auth rejects)
    """
    auth_header = request.headers.get("Authorization")
    if not auth_header:
        return False

    # Split "<scheme> <credentials>" on the first space
    scheme, _, encoded = auth_header.partition(" ")
    if scheme.lower() != "basic":
        return False

    try:
        # Decode the base64 credentials; only the first ':' separates the
        # username from the password, so passwords may contain ':'
        credentials = base64.b64decode(encoded).decode("utf-8")
        username, password = credentials.split(":", 1)
        return self.validate_basic_auth(username, password)
    except Exception:
        # Deliberately fail closed: any decoding or validation error is
        # treated as "not authenticated" rather than surfaced as a crash
        return False
|
1111
|
+
|
1112
|
+
def as_router(self) -> APIRouter:
    """
    Get a FastAPI router for this agent

    Starts from the router built by the parent class and registers GET and
    POST handlers for the root, debug, SWAIG and post-prompt endpoints,
    each with and without a trailing slash. The router is also stored on
    self._router.

    Returns:
        FastAPI router
    """
    router = super().as_router()

    def mount(path, handler):
        # POST is registered before GET, which is the order the stacked
        # @router.get / @router.post decorators produce
        router.post(path)(handler)
        router.get(path)(handler)

    # Root endpoint (both registrations use "/")
    async def handle_root_no_slash(request: Request):
        return await self._handle_root_request(request)

    mount("/", handle_root_no_slash)

    async def handle_root_with_slash(request: Request):
        return await self._handle_root_request(request)

    mount("/", handle_root_with_slash)

    # Debug endpoint
    async def handle_debug_no_slash(request: Request):
        return await self._handle_debug_request(request)

    mount("/debug", handle_debug_no_slash)

    async def handle_debug_with_slash(request: Request):
        return await self._handle_debug_request(request)

    mount("/debug/", handle_debug_with_slash)

    # SWAIG endpoint
    async def handle_swaig_no_slash(request: Request):
        return await self._handle_swaig_request(request)

    mount("/swaig", handle_swaig_no_slash)

    async def handle_swaig_with_slash(request: Request):
        return await self._handle_swaig_request(request)

    mount("/swaig/", handle_swaig_with_slash)

    # Post-prompt endpoint
    async def handle_post_prompt_no_slash(request: Request):
        return await self._handle_post_prompt_request(request)

    mount("/post_prompt", handle_post_prompt_no_slash)

    async def handle_post_prompt_with_slash(request: Request):
        return await self._handle_post_prompt_request(request)

    mount("/post_prompt/", handle_post_prompt_with_slash)

    self._router = router
    return router
|
1172
|
+
|
1173
|
+
async def _handle_root_request(self, request: Request):
    """
    Handle GET/POST requests to the root endpoint

    After the basic-auth check, the call ID is read from the JSON body
    (POST) or from the query string (other methods). If a routing callback
    is registered for the request's callback path and returns a route, a
    307 redirect to that route is returned. Otherwise the SWML document is
    rendered and returned as JSON.

    A POST body that cannot be parsed, or that is valid JSON but not a JSON
    object, is treated as an empty body.

    Args:
        request: FastAPI request object

    Returns:
        Response with the rendered SWML, a 307 redirect, a 401 when auth
        fails, or a 500 with an error message when processing fails
    """
    # Check if this is a callback path request
    callback_path = getattr(request.state, "callback_path", None)

    req_log = self.log.bind(
        endpoint="root" if not callback_path else f"callback:{callback_path}",
        method=request.method,
        path=request.url.path
    )

    req_log.debug("endpoint_called")

    try:
        # Check auth
        if not self._check_basic_auth(request):
            req_log.warning("unauthorized_access_attempt")
            return Response(
                content=json.dumps({"error": "Unauthorized"}),
                status_code=401,
                headers={"WWW-Authenticate": "Basic"},
                media_type="application/json"
            )

        # Try to parse request body for POST
        body = {}
        call_id = None

        if request.method == "POST":
            # Check if body is empty first
            raw_body = await request.body()
            if raw_body:
                try:
                    body = await request.json()
                    req_log.debug("request_body_received", body_size=len(str(body)))
                    if body:
                        req_log.debug("request_body", body=json.dumps(body, indent=2))
                except Exception as e:
                    req_log.warning("error_parsing_request_body", error=str(e), traceback=traceback.format_exc())
                    req_log.debug("raw_request_body", body=raw_body.decode('utf-8', errors='replace'))
                    # Continue processing with empty body
                    body = {}
            else:
                req_log.debug("empty_request_body")

            # Valid JSON that is not an object (list, string, number, ...)
            # has no fields to read; handle it like an unparseable body
            # instead of failing on body.get() below
            if not isinstance(body, dict):
                req_log.warning("request_body_not_an_object", body_type=type(body).__name__)
                body = {}

            # Get call_id from body if present
            call_id = body.get("call_id")
        else:
            # Get call_id from query params for GET
            call_id = request.query_params.get("call_id")

        # Add call_id to logger if any
        if call_id:
            req_log = req_log.bind(call_id=call_id)
            req_log.debug("call_id_identified")

        # Check if this is a callback path and we need to apply routing
        if callback_path and hasattr(self, '_routing_callbacks') and callback_path in self._routing_callbacks:
            callback_fn = self._routing_callbacks[callback_path]

            if request.method == "POST" and body:
                req_log.debug("processing_routing_callback", path=callback_path)
                # Call the routing callback; a failure is logged and the
                # request falls through to normal SWML rendering
                try:
                    route = callback_fn(request, body)
                    if route is not None:
                        req_log.info("routing_request", route=route)
                        # Return a redirect to the new route
                        return Response(
                            status_code=307,  # 307 Temporary Redirect preserves the method and body
                            headers={"Location": route}
                        )
                except Exception as e:
                    req_log.error("error_in_routing_callback", error=str(e), traceback=traceback.format_exc())

        # Allow subclasses to inspect/modify the request; a failure is
        # logged and rendering continues without modifications
        modifications = None
        if body:
            try:
                modifications = self.on_swml_request(body)
                if modifications:
                    req_log.debug("request_modifications_applied")
            except Exception as e:
                req_log.error("error_in_request_modifier", error=str(e), traceback=traceback.format_exc())

        # Render SWML
        swml = self._render_swml(call_id, modifications)
        req_log.debug("swml_rendered", swml_size=len(swml))

        # Return as JSON
        req_log.info("request_successful")
        return Response(
            content=swml,
            media_type="application/json"
        )
    except Exception as e:
        # The stack trace is written to the log only; the client gets the
        # error message, matching the SWAIG and post-prompt handlers
        req_log.error("request_failed", error=str(e), traceback=traceback.format_exc())
        return Response(
            content=json.dumps({"error": str(e)}),
            status_code=500,
            media_type="application/json"
        )
|
1275
|
+
|
1276
|
+
async def _handle_debug_request(self, request: Request):
    """
    Handle GET/POST requests to the debug endpoint

    After the basic-auth check, the call ID is read from the JSON body
    (POST) or from the query string (other methods), the SWML document is
    rendered and returned as JSON with an "X-Debug: true" header.

    A POST body that cannot be parsed, or that is valid JSON but not a JSON
    object, is treated as an empty body.

    Args:
        request: FastAPI request object

    Returns:
        Response with the rendered SWML, a 401 when auth fails, or a 500
        carrying the error message and traceback when processing fails
    """
    req_log = self.log.bind(
        endpoint="debug",
        method=request.method,
        path=request.url.path
    )

    req_log.debug("endpoint_called")

    try:
        # Check auth
        if not self._check_basic_auth(request):
            req_log.warning("unauthorized_access_attempt")
            return Response(
                content=json.dumps({"error": "Unauthorized"}),
                status_code=401,
                headers={"WWW-Authenticate": "Basic"},
                media_type="application/json"
            )

        # Get call_id from either query params (GET) or body (POST)
        call_id = None
        body = {}

        if request.method == "POST":
            try:
                body = await request.json()
                req_log.debug("request_body_received", body_size=len(str(body)))
                if body:
                    req_log.debug("request_body", body=json.dumps(body, indent=2))
            except Exception as e:
                req_log.warning("error_parsing_request_body", error=str(e), traceback=traceback.format_exc())
                try:
                    body_text = await request.body()
                    req_log.debug("raw_request_body", body=body_text.decode('utf-8', errors='replace'))
                except Exception:
                    # Logging the raw body is a best-effort diagnostic.
                    # "except Exception" (not a bare except) lets task
                    # cancellation propagate out of this coroutine.
                    pass

            # Valid JSON that is not an object has no call_id and must not
            # be handed to on_swml_request; treat it as an empty body
            if not isinstance(body, dict):
                req_log.warning("request_body_not_an_object", body_type=type(body).__name__)
                body = {}

            call_id = body.get("call_id")
        else:
            call_id = request.query_params.get("call_id")

        # Add call_id to logger if any
        if call_id:
            req_log = req_log.bind(call_id=call_id)
            req_log.debug("call_id_identified")

        # Allow subclasses to inspect/modify the request
        modifications = None
        if body:
            modifications = self.on_swml_request(body)
            if modifications:
                req_log.debug("request_modifications_applied")

        # Render SWML
        swml = self._render_swml(call_id, modifications)
        req_log.debug("swml_rendered", swml_size=len(swml))

        # Return as JSON
        req_log.info("request_successful")
        return Response(
            content=swml,
            media_type="application/json",
            headers={"X-Debug": "true"}
        )
    except Exception as e:
        req_log.error("request_failed", error=str(e), traceback=traceback.format_exc())
        return Response(
            content=json.dumps({"error": str(e), "traceback": traceback.format_exc()}),
            status_code=500,
            media_type="application/json"
        )
|
1348
|
+
|
1349
|
+
async def _handle_swaig_request(self, request: Request):
    """
    Handle GET/POST requests to the SWAIG endpoint

    GET returns the rendered SWML document. Any other method is treated as
    a SWAIG function call: the function name and arguments are read from
    the JSON body and passed to on_function_call.

    A body that cannot be parsed, or that is valid JSON but not a JSON
    object, is treated as an empty body and therefore answered with
    400 "Missing function name".

    Args:
        request: FastAPI request object

    Returns:
        For GET, a Response with the rendered SWML. For function calls, the
        function result as a dict, or {"error", "function"} when the
        function raises. A Response with status 401, 400 or 500 is returned
        for auth failure, a missing function name, or an unexpected error.
    """
    req_log = self.log.bind(
        endpoint="swaig",
        method=request.method,
        path=request.url.path
    )

    req_log.debug("endpoint_called")

    try:
        # Check auth
        if not self._check_basic_auth(request):
            req_log.warning("unauthorized_access_attempt")
            return Response(
                content=json.dumps({"error": "Unauthorized"}),
                status_code=401,
                headers={"WWW-Authenticate": "Basic"},
                media_type="application/json"
            )

        # Handle differently based on method
        if request.method == "GET":
            # For GET requests, return the SWML document (same as root endpoint)
            call_id = request.query_params.get("call_id")
            swml = self._render_swml(call_id)
            req_log.debug("swml_rendered", swml_size=len(swml))
            return Response(
                content=swml,
                media_type="application/json"
            )

        # For POST requests, process SWAIG function calls
        try:
            body = await request.json()
            req_log.debug("request_body_received", body_size=len(str(body)))
            if body:
                req_log.debug("request_body", body=json.dumps(body, indent=2))
        except Exception as e:
            req_log.error("error_parsing_request_body", error=str(e), traceback=traceback.format_exc())
            body = {}

        # Valid JSON that is not an object cannot name a function; treat it
        # like an unparseable body so the caller gets a 400, not a 500
        if not isinstance(body, dict):
            req_log.warning("request_body_not_an_object", body_type=type(body).__name__)
            body = {}

        # Extract function name
        function_name = body.get("function")
        if not function_name:
            req_log.warning("missing_function_name")
            return Response(
                content=json.dumps({"error": "Missing function name"}),
                status_code=400,
                media_type="application/json"
            )

        # Add function info to logger
        req_log = req_log.bind(function=function_name)
        req_log.debug("function_call_received")

        # Extract arguments: prefer the first entry of argument.parsed,
        # otherwise try to decode argument.raw as JSON
        args = {}
        argument = body.get("argument")
        if isinstance(argument, dict):
            parsed = argument.get("parsed")
            if isinstance(parsed, list) and parsed:
                args = parsed[0]
                req_log.debug("parsed_arguments", args=json.dumps(args, indent=2))
            elif "raw" in argument:
                try:
                    args = json.loads(argument["raw"])
                    req_log.debug("raw_arguments_parsed", args=json.dumps(args, indent=2))
                except Exception as e:
                    req_log.error("error_parsing_raw_arguments", error=str(e), raw=argument["raw"])

        # Get call_id from body
        call_id = body.get("call_id")
        if call_id:
            req_log = req_log.bind(call_id=call_id)
            req_log.debug("call_id_identified")

        # Call the function
        try:
            result = self.on_function_call(function_name, args, body)

            # Convert result to dict if needed
            if isinstance(result, SwaigFunctionResult):
                result_dict = result.to_dict()
            elif isinstance(result, dict):
                result_dict = result
            else:
                result_dict = {"response": str(result)}

            req_log.info("function_executed_successfully")
            # default=str keeps this debug dump from raising on values
            # json cannot encode; a logging problem must not turn a
            # successful call into an error response
            req_log.debug("function_result", result=json.dumps(result_dict, indent=2, default=str))
            return result_dict
        except Exception as e:
            req_log.error("function_execution_error", error=str(e), traceback=traceback.format_exc())
            return {"error": str(e), "function": function_name}

    except Exception as e:
        req_log.error("request_failed", error=str(e), traceback=traceback.format_exc())
        return Response(
            content=json.dumps({"error": str(e)}),
            status_code=500,
            media_type="application/json"
        )
|
1450
|
+
|
1451
|
+
async def _handle_post_prompt_request(self, request: Request):
    """
    Handle GET/POST requests to the post_prompt endpoint

    GET returns the rendered SWML document. Any other method is treated as
    post-prompt data: the summary is extracted from the JSON body, stored
    in the call's state when a call ID and summary are both present, and
    passed to on_summary together with the body.

    A body that cannot be parsed, or that is valid JSON but not a JSON
    object, is treated as an empty body.

    Args:
        request: FastAPI request object

    Returns:
        For GET, a Response with the rendered SWML. For post-prompt data,
        {"success": True}. A Response with status 401 or 500 is returned
        for auth failure or an unexpected error.
    """
    req_log = self.log.bind(
        endpoint="post_prompt",
        method=request.method,
        path=request.url.path
    )

    # Only log if not suppressed
    if not self._suppress_logs:
        req_log.debug("endpoint_called")

    try:
        # Check auth
        if not self._check_basic_auth(request):
            req_log.warning("unauthorized_access_attempt")
            return Response(
                content=json.dumps({"error": "Unauthorized"}),
                status_code=401,
                headers={"WWW-Authenticate": "Basic"},
                media_type="application/json"
            )

        # For GET requests, return the SWML document (same as root endpoint)
        if request.method == "GET":
            call_id = request.query_params.get("call_id")
            swml = self._render_swml(call_id)
            req_log.debug("swml_rendered", swml_size=len(swml))
            return Response(
                content=swml,
                media_type="application/json"
            )

        # For POST requests, process the post-prompt data
        try:
            body = await request.json()

            # Only log if not suppressed
            if not self._suppress_logs:
                req_log.debug("request_body_received", body_size=len(str(body)))
                # Log the raw body as properly formatted JSON (not Python dict representation)
                print("POST_PROMPT_BODY: " + json.dumps(body))
        except Exception as e:
            req_log.error("error_parsing_request_body", error=str(e), traceback=traceback.format_exc())
            body = {}

        # Valid JSON that is not an object carries no summary or call_id;
        # treat it like an unparseable body instead of failing on body.get()
        if not isinstance(body, dict):
            req_log.warning("request_body_not_an_object", body_type=type(body).__name__)
            body = {}

        # Extract summary from the correct location in the request
        summary = self._find_summary_in_post_data(body, req_log)

        # Save state if call_id is provided
        call_id = body.get("call_id")
        if call_id and summary:
            req_log = req_log.bind(call_id=call_id)

            # The state manager is only used through the methods it actually
            # has; a failure here is logged and does not fail the request
            try:
                if hasattr(self._state_manager, 'get_state'):
                    state = self._state_manager.get_state(call_id) or {}
                    state["summary"] = summary
                    if hasattr(self._state_manager, 'update_state'):
                        self._state_manager.update_state(call_id, state)
                        req_log.debug("state_updated_with_summary")
            except Exception as e:
                req_log.warning("state_update_failed", error=str(e))

        # Call the summary handler with the summary and the full body
        try:
            if summary:
                self.on_summary(summary, body)
                req_log.debug("summary_handler_called_successfully")
            else:
                # If no summary found but still want to process the data
                self.on_summary(None, body)
                req_log.debug("summary_handler_called_with_null_summary")
        except Exception as e:
            req_log.error("error_in_summary_handler", error=str(e), traceback=traceback.format_exc())

        # Return success
        req_log.info("request_successful")
        return {"success": True}
    except Exception as e:
        req_log.error("request_failed", error=str(e), traceback=traceback.format_exc())
        return Response(
            content=json.dumps({"error": str(e)}),
            status_code=500,
            media_type="application/json"
        )
|
1538
|
+
|
1539
|
+
def _find_summary_in_post_data(self, body, logger):
    """
    Extensive search for the summary in the post data

    Locations are tried in this order and the first hit is returned:
      1. post_prompt_data: first item of "parsed", then "raw" decoded as
         JSON, then "substituted", then data.substituted / data.text
      2. ai_response.summary (legacy location)
      3. top-level "substituted", "summary", "content", "text", "result",
         "output"
      4. a recursive search of nested dicts for the same key names

    All diagnostic output is skipped when self._suppress_logs is set.

    Args:
        body: The POST request body
        logger: The logger instance to use

    Returns:
        The summary if found, None otherwise (including when body is not
        a dict)
    """
    verbose = not self._suppress_logs

    def announce(label, value):
        # dict/list values are printed as JSON, everything else as-is
        if not verbose:
            return
        if isinstance(value, (dict, list)):
            print(f"{label}: " + json.dumps(value))
        else:
            print(f"{label}: {value}")

    # A non-dict body has none of the fields searched below
    if not isinstance(body, dict):
        if verbose:
            print("NO_SUMMARY_FOUND")
        return None

    # 1. First check post_prompt_data (new standard location)
    post_prompt_data = body.get("post_prompt_data", {})
    if post_prompt_data:
        if verbose:
            logger.debug("checking_post_prompt_data", data_type=type(post_prompt_data).__name__)

        if isinstance(post_prompt_data, dict):
            # Check for parsed array first (this is the most common location)
            parsed = post_prompt_data.get("parsed")
            if isinstance(parsed, list) and len(parsed) > 0:
                # The summary is the first item in the parsed array
                summary = parsed[0]
                if verbose:
                    print("SUMMARY_FOUND: " + json.dumps(summary))
                return summary

            # Check raw field - it might contain a JSON string
            raw = post_prompt_data.get("raw")
            if isinstance(raw, str):
                try:
                    parsed_raw = json.loads(raw)
                except (ValueError, RecursionError):
                    # Not JSON (or too deeply nested): fall through to the
                    # remaining locations
                    pass
                else:
                    if verbose:
                        print("SUMMARY_FOUND_RAW: " + json.dumps(parsed_raw))
                    return parsed_raw

            # Direct access to substituted field
            if "substituted" in post_prompt_data:
                summary = post_prompt_data.get("substituted")
                announce("SUMMARY_FOUND_SUBSTITUTED", summary)
                return summary

            # Check for nested data structure
            data = post_prompt_data.get("data")
            if isinstance(data, dict):
                if "substituted" in data:
                    summary = data.get("substituted")
                    announce("SUMMARY_FOUND_DATA_SUBSTITUTED", summary)
                    return summary

                # Try text field
                if "text" in data:
                    summary = data.get("text")
                    announce("SUMMARY_FOUND_DATA_TEXT", summary)
                    return summary

    # 2. Check ai_response (legacy location)
    ai_response = body.get("ai_response", {})
    if ai_response and isinstance(ai_response, dict) and "summary" in ai_response:
        summary = ai_response.get("summary")
        announce("SUMMARY_FOUND_AI_RESPONSE", summary)
        return summary

    # 3. Look for direct fields at the top level
    for field in ("substituted", "summary", "content", "text", "result", "output"):
        if field in body:
            summary = body.get(field)
            announce(f"SUMMARY_FOUND_TOP_LEVEL_{field}", summary)
            return summary

    # 4. Recursively search for summary-like fields up to 3 levels deep
    summary_keys = ("summary", "substituted", "output", "result", "content", "text")

    def recursive_search(data, path="", depth=0):
        if depth > 3 or not isinstance(data, dict):  # Limit recursion depth
            return None

        # Check if any key at this level looks like it might contain a summary
        for key, value in data.items():
            if key.lower() in summary_keys:
                curr_path = f"{path}.{key}" if path else key
                if verbose:
                    logger.info(f"potential_summary_found_at_{curr_path}",
                                value_type=type(value).__name__)
                if isinstance(value, (str, dict, list)):
                    return value

        # Recursively check nested dictionaries
        for key, value in data.items():
            if isinstance(value, dict):
                curr_path = f"{path}.{key}" if path else key
                result = recursive_search(value, curr_path, depth + 1)
                if result:
                    return result

        return None

    recursive_result = recursive_search(body)
    if recursive_result:
        announce("SUMMARY_FOUND_RECURSIVE", recursive_result)
        return recursive_result

    # No summary found
    if verbose:
        print("NO_SUMMARY_FOUND")
    return None
|
1661
|
+
|
1662
|
+
def _register_routes(self, app):
    """
    Register all HTTP routes for the agent on ``app``.

    Every endpoint is registered for both GET and POST, and both with and
    without a trailing slash: the root route, ``/debug``, ``/swaig`` and
    ``/post_prompt`` (all relative to ``self.route``), followed by one
    endpoint per registered routing callback path.

    Args:
        app: Application object exposing ``get``/``post`` route decorators
            and a ``routes`` collection (the caller passes a FastAPI app)
    """
    self.log.info("registering_routes", path=self.route)

    # Root endpoint - without trailing slash
    @app.get(f"{self.route}")
    @app.post(f"{self.route}")
    async def handle_root_no_slash(request: Request):
        return await self._handle_root_request(request)

    # Root endpoint - with trailing slash
    @app.get(f"{self.route}/")
    @app.post(f"{self.route}/")
    async def handle_root_with_slash(request: Request):
        return await self._handle_root_request(request)

    # Debug endpoint - without trailing slash
    @app.get(f"{self.route}/debug")
    @app.post(f"{self.route}/debug")
    async def handle_debug_no_slash(request: Request):
        return await self._handle_debug_request(request)

    # Debug endpoint - with trailing slash
    @app.get(f"{self.route}/debug/")
    @app.post(f"{self.route}/debug/")
    async def handle_debug_with_slash(request: Request):
        return await self._handle_debug_request(request)

    # SWAIG endpoint - without trailing slash
    @app.get(f"{self.route}/swaig")
    @app.post(f"{self.route}/swaig")
    async def handle_swaig_no_slash(request: Request):
        return await self._handle_swaig_request(request)

    # SWAIG endpoint - with trailing slash
    @app.get(f"{self.route}/swaig/")
    @app.post(f"{self.route}/swaig/")
    async def handle_swaig_with_slash(request: Request):
        return await self._handle_swaig_request(request)

    # Post-prompt endpoint - without trailing slash
    @app.get(f"{self.route}/post_prompt")
    @app.post(f"{self.route}/post_prompt")
    async def handle_post_prompt_no_slash(request: Request):
        return await self._handle_post_prompt_request(request)

    # Post-prompt endpoint - with trailing slash
    @app.get(f"{self.route}/post_prompt/")
    @app.post(f"{self.route}/post_prompt/")
    async def handle_post_prompt_with_slash(request: Request):
        return await self._handle_post_prompt_request(request)

    def make_callback_endpoint(bound_path):
        # Bind the callback path through a closure. Binding it with a
        # default argument on the endpoint itself would make the web
        # framework treat it as an optional query parameter, letting any
        # client override the callback path via "?path_param=...".
        async def handle_callback(request: Request):
            # Store the callback path in request state for _handle_root_request to use
            request.state.callback_path = bound_path
            return await self._handle_root_request(request)
        return handle_callback

    # Register routes for all routing callbacks
    routing_callbacks = getattr(self, '_routing_callbacks', None)
    if routing_callbacks:
        for callback_route in routing_callbacks:
            # Skip the root path as it's already handled
            if callback_route == "/":
                continue

            self.log.info("registering_callback_route", path=callback_route)

            # Always register the path as given; add the trailing-slash
            # variant only when the path does not already end with one.
            route_variants = [callback_route]
            if not callback_route.endswith('/'):
                route_variants.append(f"{callback_route}/")

            for route_path in route_variants:
                endpoint = make_callback_endpoint(callback_route)
                app.post(route_path)(endpoint)
                app.get(route_path)(endpoint)

    # Log all registered routes
    routes = [f"{route.methods} {route.path}" for route in app.routes]
    self.log.debug("routes_registered", routes=routes)
|
1747
|
+
|
1748
|
+
def _register_class_decorated_tools(self):
    """
    Register every attribute of this instance that is marked as a tool.

    An attribute qualifies when it is callable and carries an ``_is_tool``
    attribute. Its optional ``_tool_name`` and ``_tool_params`` attributes
    supply the tool name and the ``parameters``, ``description``,
    ``secure`` and ``fillers`` settings passed on to ``define_tool``.
    """
    def bind(method):
        # Expose the method through the (args, raw_data) handler signature
        @functools.wraps(method)
        def wrapper(args, raw_data=None):
            return method(args, raw_data)
        return wrapper

    for attr_name in dir(self):
        candidate = getattr(self, attr_name)
        if not (callable(candidate) and hasattr(candidate, "_is_tool")):
            continue

        # Fall back to the attribute name and an empty option set
        tool_name = getattr(candidate, "_tool_name", attr_name)
        options = getattr(candidate, "_tool_params", {})
        fallback_description = candidate.__doc__ or f"Function {tool_name}"

        self.define_tool(
            name=tool_name,
            description=options.get("description", fallback_description),
            parameters=options.get("parameters", {}),
            handler=bind(candidate),
            secure=options.get("secure", True),
            fillers=options.get("fillers", None),
        )
|
1781
|
+
|
1782
|
+
# State Management Methods
|
1783
|
+
def get_state(self, call_id: str) -> Optional[Dict[str, Any]]:
    """
    Look up the stored state for a call.

    Delegates to the state manager's ``get_state`` when it provides one.
    Any exception raised during the lookup is logged as a warning and
    swallowed.

    Args:
        call_id: Call ID to get state for

    Returns:
        The state manager's result, or None when the manager has no
        ``get_state`` method or the lookup raised
    """
    try:
        manager = self._state_manager
        if not hasattr(manager, 'get_state'):
            return None
        return manager.get_state(call_id)
    except Exception as e:
        logger.warning("get_state_failed", error=str(e))
        return None
|
1800
|
+
|
1801
|
+
def set_state(self, call_id: str, data: Dict[str, Any]) -> bool:
    """
    Store the state for a call.

    Delegates to the state manager's ``set_state`` when it provides one.
    Any exception raised while storing is logged as a warning and
    swallowed.

    Args:
        call_id: Call ID to set state for
        data: State data to set

    Returns:
        The state manager's result, or False when the manager has no
        ``set_state`` method or the call raised
    """
    try:
        manager = self._state_manager
        if not hasattr(manager, 'set_state'):
            return False
        return manager.set_state(call_id, data)
    except Exception as e:
        logger.warning("set_state_failed", error=str(e))
        return False
|
1819
|
+
|
1820
|
+
def update_state(self, call_id: str, data: Dict[str, Any]) -> bool:
    """
    Update the state for a call.

    Uses the state manager's ``update_state`` when it provides one and
    otherwise falls back to this object's ``set_state``. Any exception is
    logged as a warning and swallowed.

    Args:
        call_id: Call ID to update state for
        data: State data to update

    Returns:
        The result of the delegated call, or False when it raised
    """
    try:
        manager = self._state_manager
        if not hasattr(manager, 'update_state'):
            # No dedicated update operation: store the data via set_state
            return self.set_state(call_id, data)
        return manager.update_state(call_id, data)
    except Exception as e:
        logger.warning("update_state_failed", error=str(e))
        return False
|
1838
|
+
|
1839
|
+
def clear_state(self, call_id: str) -> bool:
    """
    Remove the stored state for a call.

    Delegates to the state manager's ``clear_state`` when it provides one.
    Any exception is logged as a warning and swallowed.

    Args:
        call_id: Call ID to clear state for

    Returns:
        The state manager's result, or False when the manager has no
        ``clear_state`` method or the call raised
    """
    try:
        manager = self._state_manager
        if not hasattr(manager, 'clear_state'):
            return False
        return manager.clear_state(call_id)
    except Exception as e:
        logger.warning("clear_state_failed", error=str(e))
        return False
|
1856
|
+
|
1857
|
+
def cleanup_expired_state(self) -> int:
    """
    Ask the state manager to purge expired state.

    Delegates to the state manager's ``cleanup_expired`` when it provides
    one. Any exception is logged as a warning and swallowed.

    Returns:
        The state manager's result (number of entries removed), or 0 when
        the manager has no ``cleanup_expired`` method or the call raised
    """
    try:
        manager = self._state_manager
        if not hasattr(manager, 'cleanup_expired'):
            return 0
        return manager.cleanup_expired()
    except Exception as e:
        logger.warning("cleanup_expired_state_failed", error=str(e))
        return 0
|
1871
|
+
|
1872
|
+
def _register_state_tracking_tools(self):
    """
    Register the two tools used to track conversation state.

    ``startup_hook`` and ``hangup_hook`` are defined, in that order, with
    no parameters and ``secure=False``, bound to this object's
    corresponding handler methods.
    """
    hooks = (
        ("startup_hook", "Called when the conversation starts", self._startup_hook_handler),
        ("hangup_hook", "Called when the conversation ends", self._hangup_hook_handler),
    )
    for hook_name, hook_description, hook_handler in hooks:
        self.define_tool(
            name=hook_name,
            description=hook_description,
            parameters={},
            handler=hook_handler,
            secure=False
        )
|
1893
|
+
|
1894
|
+
def _startup_hook_handler(self, args, raw_data):
    """
    Handle the startup hook: activate the session and seed the call state.

    Args:
        args: Function arguments (not used by this handler)
        raw_data: Raw request data; the call ID is read from its
            "call_id" entry

    Returns:
        A SwaigFunctionResult carrying either a confirmation message or an
        error message when no call ID is available
    """
    call_id = None
    if raw_data:
        call_id = raw_data.get("call_id")
    if not call_id:
        return SwaigFunctionResult("Error: Missing call_id")

    self._session_manager.activate_session(call_id)

    # Seed the state with the start timestamp and an empty event list
    initial_state = {
        "start_time": datetime.now().isoformat(),
        "events": []
    }
    self.set_state(call_id, initial_state)

    return SwaigFunctionResult("Call started and session activated")
|
1920
|
+
|
1921
|
+
def _hangup_hook_handler(self, args, raw_data):
    """
    Handle the hangup hook: end the session and record the end time.

    Args:
        args: Function arguments (not used by this handler)
        raw_data: Raw request data; the call ID is read from its
            "call_id" entry

    Returns:
        A SwaigFunctionResult carrying either a confirmation message or an
        error message when no call ID is available
    """
    call_id = None
    if raw_data:
        call_id = raw_data.get("call_id")
    if not call_id:
        return SwaigFunctionResult("Error: Missing call_id")

    self._session_manager.end_session(call_id)

    # Stamp the end time onto the existing state (or a fresh dict if none)
    call_state = self.get_state(call_id) or {}
    call_state["end_time"] = datetime.now().isoformat()
    self.update_state(call_id, call_state)

    return SwaigFunctionResult("Call ended and session deactivated")
|
1946
|
+
|
1947
|
+
def set_post_prompt(self, text: str) -> 'AgentBase':
    """
    Store the post-prompt text for the agent.

    The value is stored as given, without validation.

    Args:
        text: Post-prompt text

    Returns:
        Self for method chaining
    """
    setattr(self, "_post_prompt", text)
    return self
|
1959
|
+
|
1960
|
+
def set_auto_answer(self, enabled: bool) -> 'AgentBase':
    """
    Store the auto-answer flag.

    The value is stored as given, without validation.

    Args:
        enabled: Whether to auto-answer

    Returns:
        Self for method chaining
    """
    setattr(self, "_auto_answer", enabled)
    return self
|
1972
|
+
|
1973
|
+
def set_call_recording(self,
                       enabled: bool,
                       format: str = "mp4",
                       stereo: bool = True) -> 'AgentBase':
    """
    Store the call recording settings.

    All three values are stored as given, without validation.

    Args:
        enabled: Whether to record calls
        format: Recording format
        stereo: Whether to record in stereo

    Returns:
        Self for method chaining
    """
    self._record_call, self._record_format, self._record_stereo = enabled, format, stereo
    return self
|
1992
|
+
|
1993
|
+
def add_native_function(self, function_name: str) -> 'AgentBase':
    """
    Enable a native function by adding its name to ``native_functions``.

    Empty or non-string names are ignored, and a name that is already
    present is not added again.

    Args:
        function_name: Name of native function to enable

    Returns:
        Self for method chaining
    """
    if not function_name or not isinstance(function_name, str):
        return self

    # A missing or empty list is replaced by a fresh list before appending
    if not self.native_functions:
        self.native_functions = []
    if function_name not in self.native_functions:
        self.native_functions.append(function_name)
    return self
|
2009
|
+
|
2010
|
+
def remove_native_function(self, function_name: str) -> 'AgentBase':
    """
    Remove a native function from the list of enabled native functions.

    A name that is not present is ignored. A missing (None) or empty
    list is tolerated, matching how add_native_function treats it.

    Args:
        function_name: Name of the native function

    Returns:
        Self for method chaining
    """
    enabled = self.native_functions
    # Guard against None/empty before the membership test; a bare
    # "in None" would raise TypeError.
    if enabled and function_name in enabled:
        enabled.remove(function_name)
    return self
|
2023
|
+
|
2024
|
+
def get_native_functions(self) -> List[str]:
    """
    Get a copy of the list of enabled native functions.

    The returned list is a new object, so mutating it does not affect the
    agent. A missing (None) list yields an empty list, matching how
    add_native_function treats it.

    Returns:
        List of native function names
    """
    return list(self.native_functions or [])
|
2032
|
+
|
2033
|
+
def has_section(self, title: str) -> bool:
    """
    Report whether the prompt contains a section with the given title.

    The lookup is delegated to the POM object; when POM is disabled or no
    POM object is set, the answer is always False.

    Args:
        title: Section title

    Returns:
        True if the section exists, False otherwise
    """
    if self._use_pom and self.pom:
        return self.pom.has_section(title)
    return False
|
2047
|
+
|
2048
|
+
def on_swml_request(self, request_data: Optional[dict] = None) -> Optional[dict]:
    """
    Hook invoked when SWML is requested, with request data when available.

    Intended to be overridden by subclasses that want to inspect the
    request or modify the SWML document. This base implementation ignores
    its argument.

    Args:
        request_data: Optional dictionary containing the parsed POST body

    Returns:
        Optional dict to modify/augment the SWML document; always None here
    """
    # No modifications by default
    return None
|
2062
|
+
|
2063
|
+
def serve(self, host: Optional[str] = None, port: Optional[int] = None) -> None:
    """
    Start a web server for this agent.

    Builds a FastAPI app with slash redirects disabled, registers the
    agent's routes on it, prints connection details and runs it under
    uvicorn. The uvicorn call blocks; a KeyboardInterrupt reaching this
    method is logged and ends the method normally.

    Args:
        host: Optional host to override the default
        port: Optional port to override the default
    """
    import copy
    import uvicorn

    # Create a FastAPI app with no automatic redirects
    app = FastAPI(redirect_slashes=False)

    # Register all routes
    self._register_routes(app)

    host = host or self.host
    port = port or self.port

    # Log the auth credentials with their source; the password is masked here
    username, password, source = self.get_basic_auth_credentials(include_source=True)
    self.log.info("starting_server",
                  url=f"http://{host}:{port}{self.route}",
                  username=username,
                  password="*" * len(password),
                  auth_source=source)

    print(f"Agent '{self.name}' is available at:")
    print(f"URL: http://{host}:{port}{self.route}")
    # NOTE(review): the password is printed in clear text to stdout even
    # though the structured log above masks it - confirm this is intended.
    print(f"Basic Auth: {username}:{password} (source: {source})")

    # Check if SIP usernames are registered and print that info
    if hasattr(self, '_sip_usernames') and self._sip_usernames:
        print(f"Registered SIP usernames: {', '.join(sorted(self._sip_usernames))}")

    # Check if callback endpoints are registered and print them
    if hasattr(self, '_routing_callbacks') and self._routing_callbacks:
        for path in sorted(self._routing_callbacks.keys()):
            if hasattr(self, '_sip_usernames') and path == "/sip":
                print(f"SIP endpoint: http://{host}:{port}{path}")
            else:
                print(f"Callback endpoint: http://{host}:{port}{path}")

    # Customize the log format on a private copy: LOGGING_CONFIG is a
    # module-level dict shared by every uvicorn user in this process, so
    # editing it in place would leak the format change to other servers.
    uvicorn_log_config = copy.deepcopy(uvicorn.config.LOGGING_CONFIG)
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    uvicorn_log_config["formatters"]["access"]["fmt"] = log_format
    uvicorn_log_config["formatters"]["default"]["fmt"] = log_format

    # Start the server
    try:
        uvicorn.run(
            app,
            host=host,
            port=port,
            log_config=uvicorn_log_config
        )
    except KeyboardInterrupt:
        self.log.info("server_shutdown")
        print("\nStopping the agent.")
|
2123
|
+
|
2124
|
+
# ----------------------------------------------------------------------
|
2125
|
+
# AI Verb Configuration Methods
|
2126
|
+
# ----------------------------------------------------------------------
|
2127
|
+
|
2128
|
+
def add_hint(self, hint: str) -> 'AgentBase':
    """
    Add a simple string hint to help the AI agent understand certain words better.

    Values that are not strings, and empty strings, are ignored.

    Args:
        hint: The hint string to add

    Returns:
        Self for method chaining
    """
    if not isinstance(hint, str) or not hint:
        return self
    self._hints.append(hint)
    return self
|
2141
|
+
|
2142
|
+
def add_hints(self, hints: List[str]) -> 'AgentBase':
    """
    Add multiple string hints.

    The argument must be a non-empty list; entries that are not strings,
    or are empty strings, are skipped.

    Args:
        hints: List of hint strings

    Returns:
        Self for method chaining
    """
    if not hints or not isinstance(hints, list):
        return self
    self._hints.extend(
        entry for entry in hints if isinstance(entry, str) and entry
    )
    return self
|
2157
|
+
|
2158
|
+
def add_pattern_hint(self,
                     hint: str,
                     pattern: str,
                     replace: str,
                     ignore_case: bool = False) -> 'AgentBase':
    """
    Add a complex hint with pattern matching.

    The hint is stored as a dictionary alongside the simple string hints.
    Nothing is added when any of hint, pattern or replace is empty.

    Args:
        hint: The hint to match
        pattern: Regular expression pattern
        replace: Text to replace the hint with
        ignore_case: Whether to ignore case when matching

    Returns:
        Self for method chaining
    """
    if not (hint and pattern and replace):
        return self

    entry = {
        "hint": hint,
        "pattern": pattern,
        "replace": replace,
        "ignore_case": ignore_case
    }
    self._hints.append(entry)
    return self
|
2183
|
+
|
2184
|
+
def add_language(self,
                 name: str,
                 code: str,
                 voice: str,
                 speech_fillers: Optional[List[str]] = None,
                 function_fillers: Optional[List[str]] = None,
                 engine: Optional[str] = None,
                 model: Optional[str] = None) -> 'AgentBase':
    """
    Add a language configuration to support multilingual conversations.

    The voice can be given three ways: as a plain voice name, as a voice
    name plus explicit ``engine``/``model`` arguments, or as a combined
    "engine.voice:model" string that is split into its parts. Explicit
    ``engine``/``model`` arguments take precedence over parsing.

    Args:
        name: Name of the language (e.g., "English", "French")
        code: Language code (e.g., "en-US", "fr-FR")
        voice: TTS voice to use. Can be a simple name (e.g., "en-US-Neural2-F")
            or a combined format "engine.voice:model" (e.g., "elevenlabs.josh:eleven_turbo_v2_5")
        speech_fillers: Optional list of filler phrases for natural speech
        function_fillers: Optional list of filler phrases during function calls
        engine: Optional explicit engine name (e.g., "elevenlabs", "rime")
        model: Optional explicit model name (e.g., "eleven_turbo_v2_5", "arcana")

    Returns:
        Self for method chaining

    Examples:
        # Simple voice name
        agent.add_language("English", "en-US", "en-US-Neural2-F")

        # Explicit parameters
        agent.add_language("English", "en-US", "josh", engine="elevenlabs", model="eleven_turbo_v2_5")

        # Combined format
        agent.add_language("English", "en-US", "elevenlabs.josh:eleven_turbo_v2_5")
    """
    entry = {"name": name, "code": code}

    if engine or model:
        # Explicit parameters win; the voice string is taken verbatim
        entry["voice"] = voice
        if engine:
            entry["engine"] = engine
        if model:
            entry["model"] = model
    elif "." in voice and ":" in voice:
        # Combined "engine.voice:model" form: cut at the first ":" and then
        # at the first "." of the left-hand part
        engine_voice, _, model_part = voice.partition(":")
        engine_part, dot, voice_part = engine_voice.partition(".")
        if dot:
            entry["voice"] = voice_part
            entry["engine"] = engine_part
            entry["model"] = model_part
        else:
            # The "." sits after the ":", so there is no engine prefix to
            # split off; keep the voice string as-is
            entry["voice"] = voice
    else:
        # Simple voice string
        entry["voice"] = voice

    if speech_fillers and function_fillers:
        entry["speech_fillers"] = speech_fillers
        entry["function_fillers"] = function_fillers
    elif speech_fillers or function_fillers:
        # Only one kind of filler supplied: use the deprecated "fillers" field
        entry["fillers"] = speech_fillers or function_fillers

    self._languages.append(entry)
    return self
|
2258
|
+
|
2259
|
+
def set_languages(self, languages: List[Dict[str, Any]]) -> 'AgentBase':
    """
    Replace all language configurations at once.

    The current configurations are left untouched when the argument is
    empty or is not a list. The list is stored by reference, not copied.

    Args:
        languages: List of language configuration dictionaries

    Returns:
        Self for method chaining
    """
    if not languages or not isinstance(languages, list):
        return self
    self._languages = languages
    return self
|
2272
|
+
|
2273
|
+
def add_pronunciation(self,
                      replace: str,
                      with_text: str,
                      ignore_case: bool = False) -> 'AgentBase':
    """
    Add a pronunciation rule to help the AI speak certain words correctly.

    Nothing is added when either text is empty. The "ignore_case" key is
    only written to the rule when the flag is truthy.

    Args:
        replace: The expression to replace
        with_text: The phonetic spelling to use instead
        ignore_case: Whether to ignore case when matching

    Returns:
        Self for method chaining
    """
    if not (replace and with_text):
        return self

    entry = {"replace": replace, "with": with_text}
    if ignore_case:
        entry["ignore_case"] = True
    self._pronounce.append(entry)
    return self
|
2298
|
+
|
2299
|
+
def set_pronunciations(self, pronunciations: List[Dict[str, Any]]) -> 'AgentBase':
    """
    Replace all pronunciation rules at once.

    The current rules are left untouched when the argument is empty or is
    not a list. The list is stored by reference, not copied.

    Args:
        pronunciations: List of pronunciation rule dictionaries

    Returns:
        Self for method chaining
    """
    if not pronunciations or not isinstance(pronunciations, list):
        return self
    self._pronounce = pronunciations
    return self
|
2312
|
+
|
2313
|
+
def set_param(self, key: str, value: Any) -> 'AgentBase':
    """
    Set a single AI parameter.

    An empty key is ignored; any value is accepted for a non-empty key.

    Args:
        key: Parameter name
        value: Parameter value

    Returns:
        Self for method chaining
    """
    if not key:
        return self
    self._params[key] = value
    return self
|
2327
|
+
|
2328
|
+
def set_params(self, params: Dict[str, Any]) -> 'AgentBase':
    """
    Set multiple AI parameters at once.

    The given pairs are merged into the existing parameters. An empty or
    non-dict argument is ignored.

    Args:
        params: Dictionary of parameter name/value pairs

    Returns:
        Self for method chaining
    """
    if not params or not isinstance(params, dict):
        return self
    self._params.update(params)
    return self
|
2341
|
+
|
2342
|
+
def set_global_data(self, data: Dict[str, Any]) -> 'AgentBase':
    """
    Replace the global data available to the AI throughout the conversation.

    The current data is left untouched when the argument is empty or is
    not a dict. The dict is stored by reference, not copied.

    Args:
        data: Dictionary of global data

    Returns:
        Self for method chaining
    """
    if not data or not isinstance(data, dict):
        return self
    self._global_data = data
    return self
|
2355
|
+
|
2356
|
+
def update_global_data(self, data: Dict[str, Any]) -> 'AgentBase':
    """
    Merge new values into the global data.

    Existing keys are overwritten by the given values. An empty or
    non-dict argument is ignored.

    Args:
        data: Dictionary of global data to update

    Returns:
        Self for method chaining
    """
    if not data or not isinstance(data, dict):
        return self
    self._global_data.update(data)
    return self
|
2369
|
+
|
2370
|
+
def set_native_functions(self, function_names: List[str]) -> 'AgentBase':
    """
    Replace the list of enabled native functions.

    Only the string entries of the given list are kept. The current list
    is left untouched when the argument is empty or is not a list.

    Args:
        function_names: List of native function names

    Returns:
        Self for method chaining
    """
    if not function_names or not isinstance(function_names, list):
        return self
    self.native_functions = [
        entry for entry in function_names if isinstance(entry, str)
    ]
    return self
|
2383
|
+
|
2384
|
+
def add_function_include(self, url: str, functions: List[str], meta_data: Optional[Dict[str, Any]] = None) -> 'AgentBase':
    """
    Add a remote function include to the SWAIG configuration.

    Nothing is added unless the URL is non-empty and ``functions`` is a
    non-empty list. ``meta_data`` is attached only when it is a non-empty
    dict.

    Args:
        url: URL to fetch remote functions from
        functions: List of function names to include
        meta_data: Optional metadata to include with the function include

    Returns:
        Self for method chaining
    """
    if not (url and functions and isinstance(functions, list)):
        return self

    entry = {"url": url, "functions": functions}
    if meta_data and isinstance(meta_data, dict):
        entry["meta_data"] = meta_data
    self._function_includes.append(entry)
    return self
|
2406
|
+
|
2407
|
+
def set_function_includes(self, includes: List[Dict[str, Any]]) -> 'AgentBase':
    """
    Replace the complete list of function includes.

    Entries are kept only when they are dicts that contain both "url" and
    "functions" keys and whose "functions" value is a list. The current
    includes are left untouched when the argument is empty or is not a
    list.

    Args:
        includes: List of include objects, each with url and functions properties

    Returns:
        Self for method chaining
    """
    if not includes or not isinstance(includes, list):
        return self

    self._function_includes = [
        entry
        for entry in includes
        if isinstance(entry, dict)
        and "url" in entry
        and "functions" in entry
        and isinstance(entry["functions"], list)
    ]
    return self
|
2427
|
+
|
2428
|
+
def enable_sip_routing(self, auto_map: bool = True, path: str = "/sip") -> 'AgentBase':
    """
    Enable SIP-based routing for this agent.

    Registers a routing callback at the given path that extracts the SIP
    username from each request body and logs whether it is registered
    with this agent. The callback never redirects: it returns None in
    every case.

    Args:
        auto_map: Whether to automatically map common SIP usernames to this agent
            (based on the agent name and route path)
        path: The path to register the SIP routing endpoint (default: "/sip")

    Returns:
        Self for method chaining
    """
    def sip_routing_callback(request: Request, body: Dict[str, Any]) -> Optional[str]:
        sip_username = self.extract_sip_username(body)
        if not sip_username:
            return None

        self.log.info("sip_username_extracted", username=sip_username)

        # Registered usernames are compared in lowercase
        is_registered = (
            hasattr(self, '_sip_usernames')
            and sip_username.lower() in self._sip_usernames
        )
        outcome = "sip_username_matched" if is_registered else "sip_username_not_matched"
        self.log.info(outcome, username=sip_username)

        # Matched: this agent already handles the route, nothing to redirect.
        # Not matched: let routing continue.
        return None

    # Register the callback with the SWMLService, specifying the path
    self.register_routing_callback(sip_routing_callback, path=path)

    # Auto-map common usernames if requested
    if auto_map:
        self.auto_map_sip_usernames()

    return self
|
2471
|
+
|
2472
|
+
def register_sip_username(self, sip_username: str) -> 'AgentBase':
    """
    Register a SIP username that should be routed to this agent.

    The username is stored in lowercase; the set of usernames is created
    on first use.

    Args:
        sip_username: SIP username to register

    Returns:
        Self for method chaining
    """
    if not hasattr(self, '_sip_usernames'):
        self._sip_usernames = set()

    normalized = sip_username.lower()
    self._sip_usernames.add(normalized)
    # The log records the username as supplied, not the lowercase form
    self.log.info("sip_username_registered", username=sip_username)

    return self
|
2489
|
+
|
2490
|
+
def auto_map_sip_usernames(self) -> 'AgentBase':
|
2491
|
+
"""
|
2492
|
+
Automatically register common SIP usernames based on this agent's
|
2493
|
+
name and route
|
2494
|
+
|
2495
|
+
Returns:
|
2496
|
+
Self for method chaining
|
2497
|
+
"""
|
2498
|
+
# Register username based on agent name
|
2499
|
+
clean_name = re.sub(r'[^a-z0-9_]', '', self.name.lower())
|
2500
|
+
if clean_name:
|
2501
|
+
self.register_sip_username(clean_name)
|
2502
|
+
|
2503
|
+
# Register username based on route (without slashes)
|
2504
|
+
clean_route = re.sub(r'[^a-z0-9_]', '', self.route.lower())
|
2505
|
+
if clean_route and clean_route != clean_name:
|
2506
|
+
self.register_sip_username(clean_route)
|
2507
|
+
|
2508
|
+
# Register common variations if they make sense
|
2509
|
+
if len(clean_name) > 3:
|
2510
|
+
# Register without vowels
|
2511
|
+
no_vowels = re.sub(r'[aeiou]', '', clean_name)
|
2512
|
+
if no_vowels != clean_name and len(no_vowels) > 2:
|
2513
|
+
self.register_sip_username(no_vowels)
|
2514
|
+
|
2515
|
+
return self
|
2516
|
+
|
2517
|
+
def set_web_hook_url(self, url: str) -> 'AgentBase':
    """
    Override the default web_hook_url with a supplied URL string.

    The value is stored as given, without validation.

    Args:
        url: The URL to use for SWAIG function webhooks

    Returns:
        Self for method chaining
    """
    setattr(self, "_web_hook_url_override", url)
    return self
|
2529
|
+
|
2530
|
+
def set_post_prompt_url(self, url: str) -> 'AgentBase':
    """
    Override the default post_prompt_url with a supplied URL string.

    The value is stored as given, without validation.

    Args:
        url: The URL to use for post-prompt summary delivery

    Returns:
        Self for method chaining
    """
    setattr(self, "_post_prompt_url_override", url)
    return self
|