signalwire-agents 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- signalwire_agents/__init__.py +1 -1
- signalwire_agents/core/agent_base.py +1265 -1252
- signalwire_agents/core/security/session_manager.py +174 -86
- signalwire_agents/core/swml_service.py +90 -49
- signalwire_agents/prefabs/concierge.py +9 -2
- signalwire_agents/prefabs/faq_bot.py +3 -0
- signalwire_agents/prefabs/info_gatherer.py +3 -0
- signalwire_agents/prefabs/receptionist.py +3 -0
- signalwire_agents/prefabs/survey.py +9 -2
- {signalwire_agents-0.1.6.dist-info → signalwire_agents-0.1.7.dist-info}/METADATA +2 -1
- {signalwire_agents-0.1.6.dist-info → signalwire_agents-0.1.7.dist-info}/RECORD +15 -15
- {signalwire_agents-0.1.6.data → signalwire_agents-0.1.7.data}/data/schema.json +0 -0
- {signalwire_agents-0.1.6.dist-info → signalwire_agents-0.1.7.dist-info}/WHEEL +0 -0
- {signalwire_agents-0.1.6.dist-info → signalwire_agents-0.1.7.dist-info}/licenses/LICENSE +0 -0
- {signalwire_agents-0.1.6.dist-info → signalwire_agents-0.1.7.dist-info}/top_level.txt +0 -0
@@ -14,73 +14,92 @@ Session manager for handling call sessions and security tokens
|
|
14
14
|
from typing import Dict, Any, Optional, Tuple
|
15
15
|
import secrets
|
16
16
|
import time
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
"""
|
22
|
-
Represents a single call session with associated tokens and state
|
23
|
-
"""
|
24
|
-
def __init__(self, call_id: str):
|
25
|
-
self.call_id = call_id
|
26
|
-
self.tokens: Dict[str, str] = {} # function_name -> token
|
27
|
-
self.state = "pending" # pending, active, expired
|
28
|
-
self.started_at = datetime.now()
|
29
|
-
self.metadata: Dict[str, Any] = {} # Custom state for the call
|
17
|
+
import hmac
|
18
|
+
import hashlib
|
19
|
+
import base64
|
20
|
+
from datetime import datetime, timedelta
|
30
21
|
|
31
22
|
|
32
23
|
class SessionManager:
|
33
24
|
"""
|
34
|
-
Manages
|
25
|
+
Manages security tokens for function calls
|
26
|
+
|
27
|
+
This implementation is completely stateless - it does not track call sessions
|
28
|
+
or store any information in memory. All validation is done using cryptographic
|
29
|
+
signatures with the tokens containing all necessary information.
|
35
30
|
"""
|
36
|
-
def __init__(self, token_expiry_secs: int =
|
31
|
+
def __init__(self, token_expiry_secs: int = 3600, secret_key: Optional[str] = None):
|
37
32
|
"""
|
38
33
|
Initialize the session manager
|
39
34
|
|
40
35
|
Args:
|
41
|
-
token_expiry_secs: Seconds until tokens expire (default:
|
36
|
+
token_expiry_secs: Seconds until tokens expire (default: 60 minutes)
|
37
|
+
secret_key: Secret key for signing tokens (generated if not provided)
|
42
38
|
"""
|
43
|
-
self._active_calls: Dict[str, CallSession] = {}
|
44
39
|
self.token_expiry_secs = token_expiry_secs
|
40
|
+
# Use provided secret key or generate a secure one
|
41
|
+
self.secret_key = secret_key or secrets.token_hex(32)
|
45
42
|
|
46
43
|
def create_session(self, call_id: Optional[str] = None) -> str:
|
47
44
|
"""
|
48
|
-
Create a new
|
45
|
+
Create a new session ID if one isn't provided
|
49
46
|
|
50
47
|
Args:
|
51
48
|
call_id: Optional call ID, generated if not provided
|
52
49
|
|
53
50
|
Returns:
|
54
|
-
The call_id for the
|
51
|
+
The call_id for the session
|
55
52
|
"""
|
56
53
|
# Generate call_id if not provided
|
57
54
|
if not call_id:
|
58
55
|
call_id = secrets.token_urlsafe(16)
|
59
56
|
|
60
|
-
# Create new session
|
61
|
-
self._active_calls[call_id] = CallSession(call_id)
|
62
57
|
return call_id
|
63
58
|
|
64
59
|
def generate_token(self, function_name: str, call_id: str) -> str:
|
65
60
|
"""
|
66
|
-
Generate a secure token for a function call
|
61
|
+
Generate a secure self-contained token for a function call
|
67
62
|
|
68
63
|
Args:
|
69
64
|
function_name: Name of the function to generate a token for
|
70
65
|
call_id: Call session ID
|
71
66
|
|
72
67
|
Returns:
|
73
|
-
A secure
|
74
|
-
|
75
|
-
Raises:
|
76
|
-
ValueError: If the call session does not exist
|
68
|
+
A secure token
|
77
69
|
"""
|
78
|
-
|
79
|
-
|
70
|
+
# Create token parts
|
71
|
+
expiry = int(time.time()) + self.token_expiry_secs
|
72
|
+
nonce = secrets.token_hex(4)
|
73
|
+
|
74
|
+
# Create the message to sign
|
75
|
+
message = f"{call_id}:{function_name}:{expiry}:{nonce}"
|
76
|
+
|
77
|
+
# Sign the message
|
78
|
+
signature = hmac.new(
|
79
|
+
self.secret_key.encode(),
|
80
|
+
message.encode(),
|
81
|
+
hashlib.sha256
|
82
|
+
).hexdigest()[:16] # Use first 16 chars of signature for shorter tokens
|
80
83
|
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
+
# Combine all parts into the token
|
85
|
+
token = f"{call_id}.{function_name}.{expiry}.{nonce}.{signature}"
|
86
|
+
|
87
|
+
# Base64 encode for URL safety
|
88
|
+
return base64.urlsafe_b64encode(token.encode()).decode()
|
89
|
+
|
90
|
+
# Alias for generate_token to maintain backward compatibility
|
91
|
+
def create_tool_token(self, function_name: str, call_id: str) -> str:
|
92
|
+
"""
|
93
|
+
Alias for generate_token to maintain backward compatibility
|
94
|
+
|
95
|
+
Args:
|
96
|
+
function_name: Name of the function to generate a token for
|
97
|
+
call_id: Call session ID
|
98
|
+
|
99
|
+
Returns:
|
100
|
+
A secure token
|
101
|
+
"""
|
102
|
+
return self.generate_token(function_name, call_id)
|
84
103
|
|
85
104
|
def validate_token(self, call_id: str, function_name: str, token: str) -> bool:
|
86
105
|
"""
|
@@ -94,86 +113,155 @@ class SessionManager:
|
|
94
113
|
Returns:
|
95
114
|
True if valid, False otherwise
|
96
115
|
"""
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
# Check if token matches and is not expired
|
102
|
-
expected_token = session.tokens.get(function_name)
|
103
|
-
if not expected_token or expected_token != token:
|
104
|
-
return False
|
116
|
+
try:
|
117
|
+
# Decode the token
|
118
|
+
decoded_token = base64.urlsafe_b64decode(token.encode()).decode()
|
105
119
|
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
120
|
+
# Split the token parts
|
121
|
+
parts = decoded_token.split('.')
|
122
|
+
if len(parts) != 5:
|
123
|
+
return False
|
124
|
+
|
125
|
+
token_call_id, token_function, token_expiry, token_nonce, token_signature = parts
|
112
126
|
|
113
|
-
|
127
|
+
# Special case: if call_id is None or empty, use the call_id from the token
|
128
|
+
# This helps with scenarios where the call_id isn't provided in the request
|
129
|
+
if not call_id:
|
130
|
+
call_id = token_call_id
|
131
|
+
|
132
|
+
# Verify the function matches
|
133
|
+
if token_function != function_name:
|
134
|
+
return False
|
135
|
+
|
136
|
+
# Check if the token has expired
|
137
|
+
expiry = int(token_expiry)
|
138
|
+
if expiry < time.time():
|
139
|
+
return False
|
140
|
+
|
141
|
+
# Recreate the message and verify the signature
|
142
|
+
message = f"{token_call_id}:{token_function}:{token_expiry}:{token_nonce}"
|
143
|
+
expected_signature = hmac.new(
|
144
|
+
self.secret_key.encode(),
|
145
|
+
message.encode(),
|
146
|
+
hashlib.sha256
|
147
|
+
).hexdigest()[:16]
|
148
|
+
|
149
|
+
if token_signature != expected_signature:
|
150
|
+
return False
|
151
|
+
|
152
|
+
# Finally, verify the call_id matches unless we're in special case
|
153
|
+
# This check is done last to ensure the token is otherwise valid
|
154
|
+
if token_call_id != call_id:
|
155
|
+
return False
|
156
|
+
|
157
|
+
return True
|
158
|
+
except Exception:
|
159
|
+
# Any exception during validation means the token is invalid
|
160
|
+
return False
|
114
161
|
|
115
|
-
|
162
|
+
# Alias for validate_token to maintain backward compatibility
|
163
|
+
def validate_tool_token(self, function_name: str, token: str, call_id: str) -> bool:
|
116
164
|
"""
|
117
|
-
|
165
|
+
Alias for validate_token to maintain backward compatibility
|
118
166
|
|
119
167
|
Args:
|
168
|
+
function_name: Name of the function being called
|
169
|
+
token: Token to validate
|
120
170
|
call_id: Call session ID
|
121
171
|
|
122
172
|
Returns:
|
123
|
-
True if
|
173
|
+
True if valid, False otherwise
|
174
|
+
"""
|
175
|
+
# Reorder parameters to match validate_token signature (call_id first, then function_name)
|
176
|
+
return self.validate_token(call_id=call_id, function_name=function_name, token=token)
|
177
|
+
|
178
|
+
# Legacy methods that now don't track state but provide API compatibility
|
179
|
+
|
180
|
+
def activate_session(self, call_id: str) -> bool:
|
181
|
+
"""
|
182
|
+
Legacy method, does nothing but returns success
|
124
183
|
"""
|
125
|
-
session = self._active_calls.get(call_id)
|
126
|
-
if not session:
|
127
|
-
return False
|
128
|
-
|
129
|
-
session.state = "active"
|
130
184
|
return True
|
131
185
|
|
132
186
|
def end_session(self, call_id: str) -> bool:
|
133
187
|
"""
|
134
|
-
|
135
|
-
|
136
|
-
Args:
|
137
|
-
call_id: Call session ID
|
138
|
-
|
139
|
-
Returns:
|
140
|
-
True if successful, False otherwise
|
188
|
+
Legacy method, does nothing but returns success
|
141
189
|
"""
|
142
|
-
|
143
|
-
del self._active_calls[call_id]
|
144
|
-
return True
|
145
|
-
return False
|
190
|
+
return True
|
146
191
|
|
147
192
|
def get_session_metadata(self, call_id: str) -> Optional[Dict[str, Any]]:
|
148
193
|
"""
|
149
|
-
|
150
|
-
|
151
|
-
Args:
|
152
|
-
call_id: Call session ID
|
153
|
-
|
154
|
-
Returns:
|
155
|
-
Metadata dict or None if session not found
|
194
|
+
Legacy method, always returns empty metadata
|
156
195
|
"""
|
157
|
-
|
158
|
-
if not session:
|
159
|
-
return None
|
160
|
-
return session.metadata
|
196
|
+
return {}
|
161
197
|
|
162
198
|
def set_session_metadata(self, call_id: str, key: str, value: Any) -> bool:
|
163
199
|
"""
|
164
|
-
|
200
|
+
Legacy method, does nothing but returns success
|
201
|
+
"""
|
202
|
+
return True
|
203
|
+
|
204
|
+
def debug_token(self, token: str) -> Dict[str, Any]:
|
205
|
+
"""
|
206
|
+
Debug a token without validating it
|
207
|
+
|
208
|
+
This method decodes the token and extracts its components for debugging purposes
|
209
|
+
without performing validation.
|
165
210
|
|
166
211
|
Args:
|
167
|
-
|
168
|
-
key: Metadata key
|
169
|
-
value: Metadata value
|
212
|
+
token: The token to debug
|
170
213
|
|
171
214
|
Returns:
|
172
|
-
|
215
|
+
Dictionary with token components and analysis
|
173
216
|
"""
|
174
|
-
|
175
|
-
|
176
|
-
|
217
|
+
try:
|
218
|
+
# Decode the token
|
219
|
+
decoded_token = base64.urlsafe_b64decode(token.encode()).decode()
|
177
220
|
|
178
|
-
|
179
|
-
|
221
|
+
# Split the token parts
|
222
|
+
parts = decoded_token.split('.')
|
223
|
+
if len(parts) != 5:
|
224
|
+
return {
|
225
|
+
"valid_format": False,
|
226
|
+
"parts_count": len(parts),
|
227
|
+
"decoded": decoded_token
|
228
|
+
}
|
229
|
+
|
230
|
+
token_call_id, token_function, token_expiry, token_nonce, token_signature = parts
|
231
|
+
|
232
|
+
# Check expiration
|
233
|
+
current_time = int(time.time())
|
234
|
+
try:
|
235
|
+
expiry = int(token_expiry)
|
236
|
+
is_expired = expiry < current_time
|
237
|
+
expires_in = expiry - current_time if not is_expired else 0
|
238
|
+
expiry_date = datetime.fromtimestamp(expiry).isoformat()
|
239
|
+
except ValueError:
|
240
|
+
expiry = None
|
241
|
+
is_expired = None
|
242
|
+
expires_in = None
|
243
|
+
expiry_date = None
|
244
|
+
|
245
|
+
return {
|
246
|
+
"valid_format": True,
|
247
|
+
"components": {
|
248
|
+
"call_id": token_call_id,
|
249
|
+
"function": token_function,
|
250
|
+
"expiry": token_expiry,
|
251
|
+
"expiry_date": expiry_date,
|
252
|
+
"nonce": token_nonce,
|
253
|
+
"signature": token_signature
|
254
|
+
},
|
255
|
+
"status": {
|
256
|
+
"current_time": current_time,
|
257
|
+
"is_expired": is_expired,
|
258
|
+
"expires_in_seconds": expires_in
|
259
|
+
}
|
260
|
+
}
|
261
|
+
except Exception as e:
|
262
|
+
# Any exception during parsing
|
263
|
+
return {
|
264
|
+
"valid_format": False,
|
265
|
+
"error": str(e),
|
266
|
+
"token": token
|
267
|
+
}
|
@@ -565,58 +565,41 @@ class SWMLService:
|
|
565
565
|
|
566
566
|
def as_router(self) -> APIRouter:
|
567
567
|
"""
|
568
|
-
|
568
|
+
Create a FastAPI router for this service
|
569
569
|
|
570
570
|
Returns:
|
571
|
-
FastAPI router
|
571
|
+
APIRouter: FastAPI router
|
572
572
|
"""
|
573
|
-
router = APIRouter()
|
573
|
+
router = APIRouter(redirect_slashes=False)
|
574
574
|
|
575
|
-
# Root endpoint
|
576
|
-
@router.get("")
|
577
|
-
@router.post("")
|
578
|
-
async def handle_root_no_slash(request: Request, response: Response):
|
579
|
-
"""Handle GET/POST requests to the root endpoint"""
|
580
|
-
return await self._handle_request(request, response)
|
581
|
-
|
582
|
-
# Root endpoint - with trailing slash
|
575
|
+
# Root endpoint with and without trailing slash
|
583
576
|
@router.get("/")
|
584
577
|
@router.post("/")
|
585
|
-
async def
|
586
|
-
"""Handle
|
578
|
+
async def handle_root(request: Request, response: Response):
|
579
|
+
"""Handle requests to the root endpoint"""
|
587
580
|
return await self._handle_request(request, response)
|
588
581
|
|
589
|
-
#
|
582
|
+
# Register routing callbacks as needed
|
590
583
|
if hasattr(self, '_routing_callbacks') and self._routing_callbacks:
|
591
584
|
for callback_path, callback_fn in self._routing_callbacks.items():
|
592
|
-
# Skip the root path
|
585
|
+
# Skip the root path which is already handled
|
593
586
|
if callback_path == "/":
|
594
587
|
continue
|
595
588
|
|
596
|
-
# Register
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
589
|
+
# Register both versions: with and without trailing slash
|
590
|
+
path = callback_path.rstrip("/")
|
591
|
+
path_with_slash = f"{path}/"
|
592
|
+
|
593
|
+
@router.get(path)
|
594
|
+
@router.get(path_with_slash)
|
595
|
+
@router.post(path)
|
596
|
+
@router.post(path_with_slash)
|
597
|
+
async def handle_callback(request: Request, response: Response, cb_path=callback_path):
|
598
|
+
"""Handle requests to callback endpoints"""
|
599
|
+
# Store the callback path in the request state
|
602
600
|
request.state.callback_path = cb_path
|
603
601
|
return await self._handle_request(request, response)
|
604
|
-
|
605
|
-
# Register the endpoint with trailing slash if it doesn't already have one
|
606
|
-
if not callback_path.endswith('/'):
|
607
|
-
slash_path = f"{callback_path}/"
|
608
|
-
|
609
|
-
@router.get(slash_path)
|
610
|
-
@router.post(slash_path)
|
611
|
-
async def handle_callback_with_slash(request: Request, response: Response, cb_path=callback_path):
|
612
|
-
"""Handle GET/POST requests to a registered callback path with trailing slash"""
|
613
|
-
# Store the callback path in request state for _handle_request to use
|
614
|
-
request.state.callback_path = cb_path
|
615
|
-
return await self._handle_request(request, response)
|
616
|
-
|
617
|
-
self.log.info("callback_endpoint_registered", path=callback_path)
|
618
602
|
|
619
|
-
self._router = router
|
620
603
|
return router
|
621
604
|
|
622
605
|
def register_routing_callback(self, callback_fn: Callable[[Request, Dict[str, Any]], Optional[str]],
|
@@ -789,20 +772,22 @@ class SWMLService:
|
|
789
772
|
Start a web server for this service
|
790
773
|
|
791
774
|
Args:
|
792
|
-
host:
|
793
|
-
port:
|
775
|
+
host: Host to bind to (defaults to self.host)
|
776
|
+
port: Port to bind to (defaults to self.port)
|
794
777
|
ssl_cert: Path to SSL certificate file
|
795
|
-
ssl_key: Path to SSL
|
796
|
-
ssl_enabled: Whether to enable SSL
|
797
|
-
domain: Domain name for
|
778
|
+
ssl_key: Path to SSL key file
|
779
|
+
ssl_enabled: Whether to enable SSL
|
780
|
+
domain: Domain name for SSL certificate
|
798
781
|
"""
|
799
782
|
import uvicorn
|
800
783
|
|
801
|
-
#
|
802
|
-
self.ssl_enabled = ssl_enabled if ssl_enabled is not None else
|
803
|
-
|
804
|
-
|
805
|
-
|
784
|
+
# Store SSL configuration
|
785
|
+
self.ssl_enabled = ssl_enabled if ssl_enabled is not None else False
|
786
|
+
self.domain = domain
|
787
|
+
|
788
|
+
# Set SSL paths
|
789
|
+
ssl_cert_path = ssl_cert
|
790
|
+
ssl_key_path = ssl_key
|
806
791
|
|
807
792
|
# Validate SSL configuration if enabled
|
808
793
|
if self.ssl_enabled:
|
@@ -817,9 +802,64 @@ class SWMLService:
|
|
817
802
|
# We'll continue, but URLs might not be correctly generated
|
818
803
|
|
819
804
|
if self._app is None:
|
820
|
-
|
805
|
+
# Use redirect_slashes=False to be consistent with AgentBase
|
806
|
+
app = FastAPI(redirect_slashes=False)
|
821
807
|
router = self.as_router()
|
822
|
-
|
808
|
+
|
809
|
+
# Normalize the route to ensure it starts with a slash and doesn't end with one
|
810
|
+
# This avoids the FastAPI error about prefixes ending with slashes
|
811
|
+
normalized_route = "/" + self.route.strip("/")
|
812
|
+
|
813
|
+
# Include router with the normalized prefix
|
814
|
+
app.include_router(router, prefix=normalized_route)
|
815
|
+
|
816
|
+
# Add a catch-all route handler that will handle both /path and /path/ formats
|
817
|
+
# This provides the same behavior without using a trailing slash in the prefix
|
818
|
+
@app.get("/{full_path:path}")
|
819
|
+
@app.post("/{full_path:path}")
|
820
|
+
async def handle_all_routes(request: Request, response: Response, full_path: str):
|
821
|
+
# Get our route path without leading slash for comparison
|
822
|
+
route_path = normalized_route.lstrip("/")
|
823
|
+
route_with_slash = route_path + "/"
|
824
|
+
|
825
|
+
# Log the incoming path for debugging
|
826
|
+
print(f"Catch-all received: '{full_path}', route: '{route_path}'")
|
827
|
+
|
828
|
+
# Check for exact match to our route (without trailing slash)
|
829
|
+
if full_path == route_path:
|
830
|
+
# This is our exact route - handle it directly
|
831
|
+
return await self._handle_request(request, response)
|
832
|
+
|
833
|
+
# Check for our route with a trailing slash or subpaths
|
834
|
+
elif full_path == route_with_slash or full_path.startswith(route_with_slash):
|
835
|
+
# This is our route with a trailing slash
|
836
|
+
# Extract the path after our route prefix
|
837
|
+
sub_path = full_path[len(route_with_slash):]
|
838
|
+
|
839
|
+
# Forward to the appropriate handler in our router
|
840
|
+
if not sub_path:
|
841
|
+
# Root endpoint
|
842
|
+
return await self._handle_request(request, response)
|
843
|
+
|
844
|
+
# Check for routing callbacks if there are any
|
845
|
+
if hasattr(self, '_routing_callbacks'):
|
846
|
+
for callback_path, callback_fn in self._routing_callbacks.items():
|
847
|
+
cb_path_clean = callback_path.strip("/")
|
848
|
+
if sub_path == cb_path_clean or sub_path.startswith(cb_path_clean + "/"):
|
849
|
+
# Store the callback path in request state for handlers to use
|
850
|
+
request.state.callback_path = callback_path
|
851
|
+
return await self._handle_request(request, response)
|
852
|
+
|
853
|
+
# Not our route or not matching our patterns
|
854
|
+
print(f"No match for path: '{full_path}'")
|
855
|
+
return {"error": "Path not found"}
|
856
|
+
|
857
|
+
# Print all routes for debugging
|
858
|
+
print(f"All routes for {self.name}:")
|
859
|
+
for route in app.routes:
|
860
|
+
if hasattr(route, "path"):
|
861
|
+
print(f" {route.path}")
|
862
|
+
|
823
863
|
self._app = app
|
824
864
|
|
825
865
|
host = host or self.host
|
@@ -840,13 +880,14 @@ class SWMLService:
|
|
840
880
|
|
841
881
|
print(f"Service '{self.name}' is available at:")
|
842
882
|
print(f"URL: {protocol}://{display_host}{self.route}")
|
883
|
+
print(f"URL with trailing slash: {protocol}://{display_host}{self.route}/")
|
843
884
|
print(f"Basic Auth: {username}:{password}")
|
844
885
|
|
845
886
|
# Check if SIP routing is enabled and print additional info
|
846
887
|
if self._routing_callbacks:
|
847
888
|
print(f"Callback endpoints:")
|
848
889
|
for path in self._routing_callbacks:
|
849
|
-
print(f"{protocol}://{display_host}{path}")
|
890
|
+
print(f"{protocol}://{display_host}{self.route}{path}")
|
850
891
|
|
851
892
|
# Start uvicorn with or without SSL
|
852
893
|
if self.ssl_enabled and ssl_cert_path and ssl_key_path:
|
@@ -50,6 +50,9 @@ class ConciergeAgent(AgentBase):
|
|
50
50
|
hours_of_operation: Optional[Dict[str, str]] = None,
|
51
51
|
special_instructions: Optional[List[str]] = None,
|
52
52
|
welcome_message: Optional[str] = None,
|
53
|
+
name: str = "concierge",
|
54
|
+
route: str = "/concierge",
|
55
|
+
enable_state_tracking: bool = True,
|
53
56
|
**kwargs
|
54
57
|
):
|
55
58
|
"""
|
@@ -62,13 +65,17 @@ class ConciergeAgent(AgentBase):
|
|
62
65
|
hours_of_operation: Optional dictionary of operating hours
|
63
66
|
special_instructions: Optional list of special instructions
|
64
67
|
welcome_message: Optional custom welcome message
|
68
|
+
name: Agent name for the route
|
69
|
+
route: HTTP route for this agent
|
70
|
+
enable_state_tracking: Whether to enable state tracking (default: True)
|
65
71
|
**kwargs: Additional arguments for AgentBase
|
66
72
|
"""
|
67
73
|
# Initialize the base agent
|
68
74
|
super().__init__(
|
69
|
-
name=
|
70
|
-
route=
|
75
|
+
name=name,
|
76
|
+
route=route,
|
71
77
|
use_pom=True,
|
78
|
+
enable_state_tracking=enable_state_tracking,
|
72
79
|
**kwargs
|
73
80
|
)
|
74
81
|
|
@@ -51,6 +51,7 @@ class FAQBotAgent(AgentBase):
|
|
51
51
|
persona: Optional[str] = None,
|
52
52
|
name: str = "faq_bot",
|
53
53
|
route: str = "/faq",
|
54
|
+
enable_state_tracking: bool = True, # Enable state tracking by default
|
54
55
|
**kwargs
|
55
56
|
):
|
56
57
|
"""
|
@@ -65,6 +66,7 @@ class FAQBotAgent(AgentBase):
|
|
65
66
|
persona: Optional custom personality description
|
66
67
|
name: Agent name for the route
|
67
68
|
route: HTTP route for this agent
|
69
|
+
enable_state_tracking: Whether to enable state tracking (default: True)
|
68
70
|
**kwargs: Additional arguments for AgentBase
|
69
71
|
"""
|
70
72
|
# Initialize the base agent
|
@@ -72,6 +74,7 @@ class FAQBotAgent(AgentBase):
|
|
72
74
|
name=name,
|
73
75
|
route=route,
|
74
76
|
use_pom=True,
|
77
|
+
enable_state_tracking=enable_state_tracking, # Pass state tracking parameter to base
|
75
78
|
**kwargs
|
76
79
|
)
|
77
80
|
|
@@ -42,6 +42,7 @@ class InfoGathererAgent(AgentBase):
|
|
42
42
|
questions: List[Dict[str, str]],
|
43
43
|
name: str = "info_gatherer",
|
44
44
|
route: str = "/info_gatherer",
|
45
|
+
enable_state_tracking: bool = True, # Enable state tracking by default for InfoGatherer
|
45
46
|
**kwargs
|
46
47
|
):
|
47
48
|
"""
|
@@ -54,6 +55,7 @@ class InfoGathererAgent(AgentBase):
|
|
54
55
|
- confirm: (Optional) If set to True, the agent will confirm the answer before submitting
|
55
56
|
name: Agent name for the route
|
56
57
|
route: HTTP route for this agent
|
58
|
+
enable_state_tracking: Whether to enable state tracking (default: True)
|
57
59
|
**kwargs: Additional arguments for AgentBase
|
58
60
|
"""
|
59
61
|
# Initialize the base agent
|
@@ -61,6 +63,7 @@ class InfoGathererAgent(AgentBase):
|
|
61
63
|
name=name,
|
62
64
|
route=route,
|
63
65
|
use_pom=True,
|
66
|
+
enable_state_tracking=enable_state_tracking, # Pass state tracking parameter to base
|
64
67
|
**kwargs
|
65
68
|
)
|
66
69
|
|
@@ -41,6 +41,7 @@ class ReceptionistAgent(AgentBase):
|
|
41
41
|
route: str = "/receptionist",
|
42
42
|
greeting: str = "Thank you for calling. How can I help you today?",
|
43
43
|
voice: str = "elevenlabs.josh",
|
44
|
+
enable_state_tracking: bool = True, # Enable state tracking by default
|
44
45
|
**kwargs
|
45
46
|
):
|
46
47
|
"""
|
@@ -55,6 +56,7 @@ class ReceptionistAgent(AgentBase):
|
|
55
56
|
route: HTTP route for this agent
|
56
57
|
greeting: Initial greeting message
|
57
58
|
voice: Voice ID to use
|
59
|
+
enable_state_tracking: Whether to enable state tracking (default: True)
|
58
60
|
**kwargs: Additional arguments for AgentBase
|
59
61
|
"""
|
60
62
|
# Initialize the base agent
|
@@ -62,6 +64,7 @@ class ReceptionistAgent(AgentBase):
|
|
62
64
|
name=name,
|
63
65
|
route=route,
|
64
66
|
use_pom=True,
|
67
|
+
enable_state_tracking=enable_state_tracking, # Pass state tracking parameter to base
|
65
68
|
**kwargs
|
66
69
|
)
|
67
70
|
|
@@ -60,6 +60,9 @@ class SurveyAgent(AgentBase):
|
|
60
60
|
conclusion: Optional[str] = None,
|
61
61
|
brand_name: Optional[str] = None,
|
62
62
|
max_retries: int = 2,
|
63
|
+
name: str = "survey",
|
64
|
+
route: str = "/survey",
|
65
|
+
enable_state_tracking: bool = True, # Enable state tracking by default
|
63
66
|
**kwargs
|
64
67
|
):
|
65
68
|
"""
|
@@ -78,13 +81,17 @@ class SurveyAgent(AgentBase):
|
|
78
81
|
conclusion: Optional custom conclusion message
|
79
82
|
brand_name: Optional brand or company name
|
80
83
|
max_retries: Maximum number of times to retry invalid answers
|
84
|
+
name: Name for the agent (default: "survey")
|
85
|
+
route: HTTP route for the agent (default: "/survey")
|
86
|
+
enable_state_tracking: Whether to enable state tracking (default: True)
|
81
87
|
**kwargs: Additional arguments for AgentBase
|
82
88
|
"""
|
83
89
|
# Initialize the base agent
|
84
90
|
super().__init__(
|
85
|
-
name=
|
86
|
-
route=
|
91
|
+
name=name,
|
92
|
+
route=route,
|
87
93
|
use_pom=True,
|
94
|
+
enable_state_tracking=enable_state_tracking, # Pass state tracking parameter to base
|
88
95
|
**kwargs
|
89
96
|
)
|
90
97
|
|