signalwire-agents 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- signalwire_agents/__init__.py +1 -1
- signalwire_agents/core/agent_base.py +1267 -1238
- signalwire_agents/core/security/session_manager.py +174 -86
- signalwire_agents/core/swml_service.py +195 -50
- signalwire_agents/prefabs/concierge.py +9 -2
- signalwire_agents/prefabs/faq_bot.py +3 -0
- signalwire_agents/prefabs/info_gatherer.py +3 -0
- signalwire_agents/prefabs/receptionist.py +3 -0
- signalwire_agents/prefabs/survey.py +9 -2
- {signalwire_agents-0.1.5.dist-info → signalwire_agents-0.1.7.dist-info}/METADATA +2 -1
- {signalwire_agents-0.1.5.dist-info → signalwire_agents-0.1.7.dist-info}/RECORD +15 -15
- {signalwire_agents-0.1.5.data → signalwire_agents-0.1.7.data}/data/schema.json +0 -0
- {signalwire_agents-0.1.5.dist-info → signalwire_agents-0.1.7.dist-info}/WHEEL +0 -0
- {signalwire_agents-0.1.5.dist-info → signalwire_agents-0.1.7.dist-info}/licenses/LICENSE +0 -0
- {signalwire_agents-0.1.5.dist-info → signalwire_agents-0.1.7.dist-info}/top_level.txt +0 -0
@@ -14,73 +14,92 @@ Session manager for handling call sessions and security tokens
|
|
14
14
|
from typing import Dict, Any, Optional, Tuple
|
15
15
|
import secrets
|
16
16
|
import time
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
"""
|
22
|
-
Represents a single call session with associated tokens and state
|
23
|
-
"""
|
24
|
-
def __init__(self, call_id: str):
|
25
|
-
self.call_id = call_id
|
26
|
-
self.tokens: Dict[str, str] = {} # function_name -> token
|
27
|
-
self.state = "pending" # pending, active, expired
|
28
|
-
self.started_at = datetime.now()
|
29
|
-
self.metadata: Dict[str, Any] = {} # Custom state for the call
|
17
|
+
import hmac
|
18
|
+
import hashlib
|
19
|
+
import base64
|
20
|
+
from datetime import datetime, timedelta
|
30
21
|
|
31
22
|
|
32
23
|
class SessionManager:
|
33
24
|
"""
|
34
|
-
Manages
|
25
|
+
Manages security tokens for function calls
|
26
|
+
|
27
|
+
This implementation is completely stateless - it does not track call sessions
|
28
|
+
or store any information in memory. All validation is done using cryptographic
|
29
|
+
signatures with the tokens containing all necessary information.
|
35
30
|
"""
|
36
|
-
def __init__(self, token_expiry_secs: int =
|
31
|
+
def __init__(self, token_expiry_secs: int = 3600, secret_key: Optional[str] = None):
|
37
32
|
"""
|
38
33
|
Initialize the session manager
|
39
34
|
|
40
35
|
Args:
|
41
|
-
token_expiry_secs: Seconds until tokens expire (default:
|
36
|
+
token_expiry_secs: Seconds until tokens expire (default: 60 minutes)
|
37
|
+
secret_key: Secret key for signing tokens (generated if not provided)
|
42
38
|
"""
|
43
|
-
self._active_calls: Dict[str, CallSession] = {}
|
44
39
|
self.token_expiry_secs = token_expiry_secs
|
40
|
+
# Use provided secret key or generate a secure one
|
41
|
+
self.secret_key = secret_key or secrets.token_hex(32)
|
45
42
|
|
46
43
|
def create_session(self, call_id: Optional[str] = None) -> str:
|
47
44
|
"""
|
48
|
-
Create a new
|
45
|
+
Create a new session ID if one isn't provided
|
49
46
|
|
50
47
|
Args:
|
51
48
|
call_id: Optional call ID, generated if not provided
|
52
49
|
|
53
50
|
Returns:
|
54
|
-
The call_id for the
|
51
|
+
The call_id for the session
|
55
52
|
"""
|
56
53
|
# Generate call_id if not provided
|
57
54
|
if not call_id:
|
58
55
|
call_id = secrets.token_urlsafe(16)
|
59
56
|
|
60
|
-
# Create new session
|
61
|
-
self._active_calls[call_id] = CallSession(call_id)
|
62
57
|
return call_id
|
63
58
|
|
64
59
|
def generate_token(self, function_name: str, call_id: str) -> str:
|
65
60
|
"""
|
66
|
-
Generate a secure token for a function call
|
61
|
+
Generate a secure self-contained token for a function call
|
67
62
|
|
68
63
|
Args:
|
69
64
|
function_name: Name of the function to generate a token for
|
70
65
|
call_id: Call session ID
|
71
66
|
|
72
67
|
Returns:
|
73
|
-
A secure
|
74
|
-
|
75
|
-
Raises:
|
76
|
-
ValueError: If the call session does not exist
|
68
|
+
A secure token
|
77
69
|
"""
|
78
|
-
|
79
|
-
|
70
|
+
# Create token parts
|
71
|
+
expiry = int(time.time()) + self.token_expiry_secs
|
72
|
+
nonce = secrets.token_hex(4)
|
73
|
+
|
74
|
+
# Create the message to sign
|
75
|
+
message = f"{call_id}:{function_name}:{expiry}:{nonce}"
|
76
|
+
|
77
|
+
# Sign the message
|
78
|
+
signature = hmac.new(
|
79
|
+
self.secret_key.encode(),
|
80
|
+
message.encode(),
|
81
|
+
hashlib.sha256
|
82
|
+
).hexdigest()[:16] # Use first 16 chars of signature for shorter tokens
|
80
83
|
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
+
# Combine all parts into the token
|
85
|
+
token = f"{call_id}.{function_name}.{expiry}.{nonce}.{signature}"
|
86
|
+
|
87
|
+
# Base64 encode for URL safety
|
88
|
+
return base64.urlsafe_b64encode(token.encode()).decode()
|
89
|
+
|
90
|
+
# Alias for generate_token to maintain backward compatibility
|
91
|
+
def create_tool_token(self, function_name: str, call_id: str) -> str:
|
92
|
+
"""
|
93
|
+
Alias for generate_token to maintain backward compatibility
|
94
|
+
|
95
|
+
Args:
|
96
|
+
function_name: Name of the function to generate a token for
|
97
|
+
call_id: Call session ID
|
98
|
+
|
99
|
+
Returns:
|
100
|
+
A secure token
|
101
|
+
"""
|
102
|
+
return self.generate_token(function_name, call_id)
|
84
103
|
|
85
104
|
def validate_token(self, call_id: str, function_name: str, token: str) -> bool:
|
86
105
|
"""
|
@@ -94,86 +113,155 @@ class SessionManager:
|
|
94
113
|
Returns:
|
95
114
|
True if valid, False otherwise
|
96
115
|
"""
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
# Check if token matches and is not expired
|
102
|
-
expected_token = session.tokens.get(function_name)
|
103
|
-
if not expected_token or expected_token != token:
|
104
|
-
return False
|
116
|
+
try:
|
117
|
+
# Decode the token
|
118
|
+
decoded_token = base64.urlsafe_b64decode(token.encode()).decode()
|
105
119
|
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
120
|
+
# Split the token parts
|
121
|
+
parts = decoded_token.split('.')
|
122
|
+
if len(parts) != 5:
|
123
|
+
return False
|
124
|
+
|
125
|
+
token_call_id, token_function, token_expiry, token_nonce, token_signature = parts
|
112
126
|
|
113
|
-
|
127
|
+
# Special case: if call_id is None or empty, use the call_id from the token
|
128
|
+
# This helps with scenarios where the call_id isn't provided in the request
|
129
|
+
if not call_id:
|
130
|
+
call_id = token_call_id
|
131
|
+
|
132
|
+
# Verify the function matches
|
133
|
+
if token_function != function_name:
|
134
|
+
return False
|
135
|
+
|
136
|
+
# Check if the token has expired
|
137
|
+
expiry = int(token_expiry)
|
138
|
+
if expiry < time.time():
|
139
|
+
return False
|
140
|
+
|
141
|
+
# Recreate the message and verify the signature
|
142
|
+
message = f"{token_call_id}:{token_function}:{token_expiry}:{token_nonce}"
|
143
|
+
expected_signature = hmac.new(
|
144
|
+
self.secret_key.encode(),
|
145
|
+
message.encode(),
|
146
|
+
hashlib.sha256
|
147
|
+
).hexdigest()[:16]
|
148
|
+
|
149
|
+
if token_signature != expected_signature:
|
150
|
+
return False
|
151
|
+
|
152
|
+
# Finally, verify the call_id matches unless we're in special case
|
153
|
+
# This check is done last to ensure the token is otherwise valid
|
154
|
+
if token_call_id != call_id:
|
155
|
+
return False
|
156
|
+
|
157
|
+
return True
|
158
|
+
except Exception:
|
159
|
+
# Any exception during validation means the token is invalid
|
160
|
+
return False
|
114
161
|
|
115
|
-
|
162
|
+
# Alias for validate_token to maintain backward compatibility
|
163
|
+
def validate_tool_token(self, function_name: str, token: str, call_id: str) -> bool:
|
116
164
|
"""
|
117
|
-
|
165
|
+
Alias for validate_token to maintain backward compatibility
|
118
166
|
|
119
167
|
Args:
|
168
|
+
function_name: Name of the function being called
|
169
|
+
token: Token to validate
|
120
170
|
call_id: Call session ID
|
121
171
|
|
122
172
|
Returns:
|
123
|
-
True if
|
173
|
+
True if valid, False otherwise
|
174
|
+
"""
|
175
|
+
# Reorder parameters to match validate_token signature (call_id first, then function_name)
|
176
|
+
return self.validate_token(call_id=call_id, function_name=function_name, token=token)
|
177
|
+
|
178
|
+
# Legacy methods that now don't track state but provide API compatibility
|
179
|
+
|
180
|
+
def activate_session(self, call_id: str) -> bool:
|
181
|
+
"""
|
182
|
+
Legacy method, does nothing but returns success
|
124
183
|
"""
|
125
|
-
session = self._active_calls.get(call_id)
|
126
|
-
if not session:
|
127
|
-
return False
|
128
|
-
|
129
|
-
session.state = "active"
|
130
184
|
return True
|
131
185
|
|
132
186
|
def end_session(self, call_id: str) -> bool:
|
133
187
|
"""
|
134
|
-
|
135
|
-
|
136
|
-
Args:
|
137
|
-
call_id: Call session ID
|
138
|
-
|
139
|
-
Returns:
|
140
|
-
True if successful, False otherwise
|
188
|
+
Legacy method, does nothing but returns success
|
141
189
|
"""
|
142
|
-
|
143
|
-
del self._active_calls[call_id]
|
144
|
-
return True
|
145
|
-
return False
|
190
|
+
return True
|
146
191
|
|
147
192
|
def get_session_metadata(self, call_id: str) -> Optional[Dict[str, Any]]:
|
148
193
|
"""
|
149
|
-
|
150
|
-
|
151
|
-
Args:
|
152
|
-
call_id: Call session ID
|
153
|
-
|
154
|
-
Returns:
|
155
|
-
Metadata dict or None if session not found
|
194
|
+
Legacy method, always returns empty metadata
|
156
195
|
"""
|
157
|
-
|
158
|
-
if not session:
|
159
|
-
return None
|
160
|
-
return session.metadata
|
196
|
+
return {}
|
161
197
|
|
162
198
|
def set_session_metadata(self, call_id: str, key: str, value: Any) -> bool:
|
163
199
|
"""
|
164
|
-
|
200
|
+
Legacy method, does nothing but returns success
|
201
|
+
"""
|
202
|
+
return True
|
203
|
+
|
204
|
+
def debug_token(self, token: str) -> Dict[str, Any]:
|
205
|
+
"""
|
206
|
+
Debug a token without validating it
|
207
|
+
|
208
|
+
This method decodes the token and extracts its components for debugging purposes
|
209
|
+
without performing validation.
|
165
210
|
|
166
211
|
Args:
|
167
|
-
|
168
|
-
key: Metadata key
|
169
|
-
value: Metadata value
|
212
|
+
token: The token to debug
|
170
213
|
|
171
214
|
Returns:
|
172
|
-
|
215
|
+
Dictionary with token components and analysis
|
173
216
|
"""
|
174
|
-
|
175
|
-
|
176
|
-
|
217
|
+
try:
|
218
|
+
# Decode the token
|
219
|
+
decoded_token = base64.urlsafe_b64decode(token.encode()).decode()
|
177
220
|
|
178
|
-
|
179
|
-
|
221
|
+
# Split the token parts
|
222
|
+
parts = decoded_token.split('.')
|
223
|
+
if len(parts) != 5:
|
224
|
+
return {
|
225
|
+
"valid_format": False,
|
226
|
+
"parts_count": len(parts),
|
227
|
+
"decoded": decoded_token
|
228
|
+
}
|
229
|
+
|
230
|
+
token_call_id, token_function, token_expiry, token_nonce, token_signature = parts
|
231
|
+
|
232
|
+
# Check expiration
|
233
|
+
current_time = int(time.time())
|
234
|
+
try:
|
235
|
+
expiry = int(token_expiry)
|
236
|
+
is_expired = expiry < current_time
|
237
|
+
expires_in = expiry - current_time if not is_expired else 0
|
238
|
+
expiry_date = datetime.fromtimestamp(expiry).isoformat()
|
239
|
+
except ValueError:
|
240
|
+
expiry = None
|
241
|
+
is_expired = None
|
242
|
+
expires_in = None
|
243
|
+
expiry_date = None
|
244
|
+
|
245
|
+
return {
|
246
|
+
"valid_format": True,
|
247
|
+
"components": {
|
248
|
+
"call_id": token_call_id,
|
249
|
+
"function": token_function,
|
250
|
+
"expiry": token_expiry,
|
251
|
+
"expiry_date": expiry_date,
|
252
|
+
"nonce": token_nonce,
|
253
|
+
"signature": token_signature
|
254
|
+
},
|
255
|
+
"status": {
|
256
|
+
"current_time": current_time,
|
257
|
+
"is_expired": is_expired,
|
258
|
+
"expires_in_seconds": expires_in
|
259
|
+
}
|
260
|
+
}
|
261
|
+
except Exception as e:
|
262
|
+
# Any exception during parsing
|
263
|
+
return {
|
264
|
+
"valid_format": False,
|
265
|
+
"error": str(e),
|
266
|
+
"token": token
|
267
|
+
}
|
@@ -126,6 +126,11 @@ class SWMLService:
|
|
126
126
|
self.ssl_enabled = False
|
127
127
|
self.domain = None
|
128
128
|
|
129
|
+
# Initialize proxy detection attributes
|
130
|
+
self._proxy_url_base = os.environ.get('SWML_PROXY_URL_BASE')
|
131
|
+
self._proxy_detection_done = False
|
132
|
+
self._proxy_debug = os.environ.get('SWML_PROXY_DEBUG', '').lower() in ('true', '1', 'yes')
|
133
|
+
|
129
134
|
# Initialize logger for this instance
|
130
135
|
self.log = logger.bind(service=name)
|
131
136
|
self.log.info("service_initializing", route=self.route, host=host, port=port)
|
@@ -560,58 +565,41 @@ class SWMLService:
|
|
560
565
|
|
561
566
|
def as_router(self) -> APIRouter:
|
562
567
|
"""
|
563
|
-
|
568
|
+
Create a FastAPI router for this service
|
564
569
|
|
565
570
|
Returns:
|
566
|
-
FastAPI router
|
571
|
+
APIRouter: FastAPI router
|
567
572
|
"""
|
568
|
-
router = APIRouter()
|
573
|
+
router = APIRouter(redirect_slashes=False)
|
569
574
|
|
570
|
-
# Root endpoint
|
571
|
-
@router.get("")
|
572
|
-
@router.post("")
|
573
|
-
async def handle_root_no_slash(request: Request, response: Response):
|
574
|
-
"""Handle GET/POST requests to the root endpoint"""
|
575
|
-
return await self._handle_request(request, response)
|
576
|
-
|
577
|
-
# Root endpoint - with trailing slash
|
575
|
+
# Root endpoint with and without trailing slash
|
578
576
|
@router.get("/")
|
579
577
|
@router.post("/")
|
580
|
-
async def
|
581
|
-
"""Handle
|
578
|
+
async def handle_root(request: Request, response: Response):
|
579
|
+
"""Handle requests to the root endpoint"""
|
582
580
|
return await self._handle_request(request, response)
|
583
581
|
|
584
|
-
#
|
582
|
+
# Register routing callbacks as needed
|
585
583
|
if hasattr(self, '_routing_callbacks') and self._routing_callbacks:
|
586
584
|
for callback_path, callback_fn in self._routing_callbacks.items():
|
587
|
-
# Skip the root path
|
585
|
+
# Skip the root path which is already handled
|
588
586
|
if callback_path == "/":
|
589
587
|
continue
|
590
588
|
|
591
|
-
# Register
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
589
|
+
# Register both versions: with and without trailing slash
|
590
|
+
path = callback_path.rstrip("/")
|
591
|
+
path_with_slash = f"{path}/"
|
592
|
+
|
593
|
+
@router.get(path)
|
594
|
+
@router.get(path_with_slash)
|
595
|
+
@router.post(path)
|
596
|
+
@router.post(path_with_slash)
|
597
|
+
async def handle_callback(request: Request, response: Response, cb_path=callback_path):
|
598
|
+
"""Handle requests to callback endpoints"""
|
599
|
+
# Store the callback path in the request state
|
597
600
|
request.state.callback_path = cb_path
|
598
601
|
return await self._handle_request(request, response)
|
599
|
-
|
600
|
-
# Register the endpoint with trailing slash if it doesn't already have one
|
601
|
-
if not callback_path.endswith('/'):
|
602
|
-
slash_path = f"{callback_path}/"
|
603
|
-
|
604
|
-
@router.get(slash_path)
|
605
|
-
@router.post(slash_path)
|
606
|
-
async def handle_callback_with_slash(request: Request, response: Response, cb_path=callback_path):
|
607
|
-
"""Handle GET/POST requests to a registered callback path with trailing slash"""
|
608
|
-
# Store the callback path in request state for _handle_request to use
|
609
|
-
request.state.callback_path = cb_path
|
610
|
-
return await self._handle_request(request, response)
|
611
|
-
|
612
|
-
self.log.info("callback_endpoint_registered", path=callback_path)
|
613
602
|
|
614
|
-
self._router = router
|
615
603
|
return router
|
616
604
|
|
617
605
|
def register_routing_callback(self, callback_fn: Callable[[Request, Dict[str, Any]], Optional[str]],
|
@@ -691,6 +679,11 @@ class SWMLService:
|
|
691
679
|
Returns:
|
692
680
|
Response with SWML document or error
|
693
681
|
"""
|
682
|
+
# Auto-detect proxy on first request if not explicitly configured
|
683
|
+
if not self._proxy_detection_done and not self._proxy_url_base:
|
684
|
+
self._detect_proxy_from_request(request)
|
685
|
+
self._proxy_detection_done = True
|
686
|
+
|
694
687
|
# Check auth
|
695
688
|
if not self._check_basic_auth(request):
|
696
689
|
response.headers["WWW-Authenticate"] = "Basic"
|
@@ -779,20 +772,22 @@ class SWMLService:
|
|
779
772
|
Start a web server for this service
|
780
773
|
|
781
774
|
Args:
|
782
|
-
host:
|
783
|
-
port:
|
775
|
+
host: Host to bind to (defaults to self.host)
|
776
|
+
port: Port to bind to (defaults to self.port)
|
784
777
|
ssl_cert: Path to SSL certificate file
|
785
|
-
ssl_key: Path to SSL
|
786
|
-
ssl_enabled: Whether to enable SSL
|
787
|
-
domain: Domain name for
|
778
|
+
ssl_key: Path to SSL key file
|
779
|
+
ssl_enabled: Whether to enable SSL
|
780
|
+
domain: Domain name for SSL certificate
|
788
781
|
"""
|
789
782
|
import uvicorn
|
790
783
|
|
791
|
-
#
|
792
|
-
self.ssl_enabled = ssl_enabled if ssl_enabled is not None else
|
793
|
-
|
794
|
-
|
795
|
-
|
784
|
+
# Store SSL configuration
|
785
|
+
self.ssl_enabled = ssl_enabled if ssl_enabled is not None else False
|
786
|
+
self.domain = domain
|
787
|
+
|
788
|
+
# Set SSL paths
|
789
|
+
ssl_cert_path = ssl_cert
|
790
|
+
ssl_key_path = ssl_key
|
796
791
|
|
797
792
|
# Validate SSL configuration if enabled
|
798
793
|
if self.ssl_enabled:
|
@@ -807,9 +802,64 @@ class SWMLService:
|
|
807
802
|
# We'll continue, but URLs might not be correctly generated
|
808
803
|
|
809
804
|
if self._app is None:
|
810
|
-
|
805
|
+
# Use redirect_slashes=False to be consistent with AgentBase
|
806
|
+
app = FastAPI(redirect_slashes=False)
|
811
807
|
router = self.as_router()
|
812
|
-
|
808
|
+
|
809
|
+
# Normalize the route to ensure it starts with a slash and doesn't end with one
|
810
|
+
# This avoids the FastAPI error about prefixes ending with slashes
|
811
|
+
normalized_route = "/" + self.route.strip("/")
|
812
|
+
|
813
|
+
# Include router with the normalized prefix
|
814
|
+
app.include_router(router, prefix=normalized_route)
|
815
|
+
|
816
|
+
# Add a catch-all route handler that will handle both /path and /path/ formats
|
817
|
+
# This provides the same behavior without using a trailing slash in the prefix
|
818
|
+
@app.get("/{full_path:path}")
|
819
|
+
@app.post("/{full_path:path}")
|
820
|
+
async def handle_all_routes(request: Request, response: Response, full_path: str):
|
821
|
+
# Get our route path without leading slash for comparison
|
822
|
+
route_path = normalized_route.lstrip("/")
|
823
|
+
route_with_slash = route_path + "/"
|
824
|
+
|
825
|
+
# Log the incoming path for debugging
|
826
|
+
print(f"Catch-all received: '{full_path}', route: '{route_path}'")
|
827
|
+
|
828
|
+
# Check for exact match to our route (without trailing slash)
|
829
|
+
if full_path == route_path:
|
830
|
+
# This is our exact route - handle it directly
|
831
|
+
return await self._handle_request(request, response)
|
832
|
+
|
833
|
+
# Check for our route with a trailing slash or subpaths
|
834
|
+
elif full_path == route_with_slash or full_path.startswith(route_with_slash):
|
835
|
+
# This is our route with a trailing slash
|
836
|
+
# Extract the path after our route prefix
|
837
|
+
sub_path = full_path[len(route_with_slash):]
|
838
|
+
|
839
|
+
# Forward to the appropriate handler in our router
|
840
|
+
if not sub_path:
|
841
|
+
# Root endpoint
|
842
|
+
return await self._handle_request(request, response)
|
843
|
+
|
844
|
+
# Check for routing callbacks if there are any
|
845
|
+
if hasattr(self, '_routing_callbacks'):
|
846
|
+
for callback_path, callback_fn in self._routing_callbacks.items():
|
847
|
+
cb_path_clean = callback_path.strip("/")
|
848
|
+
if sub_path == cb_path_clean or sub_path.startswith(cb_path_clean + "/"):
|
849
|
+
# Store the callback path in request state for handlers to use
|
850
|
+
request.state.callback_path = callback_path
|
851
|
+
return await self._handle_request(request, response)
|
852
|
+
|
853
|
+
# Not our route or not matching our patterns
|
854
|
+
print(f"No match for path: '{full_path}'")
|
855
|
+
return {"error": "Path not found"}
|
856
|
+
|
857
|
+
# Print all routes for debugging
|
858
|
+
print(f"All routes for {self.name}:")
|
859
|
+
for route in app.routes:
|
860
|
+
if hasattr(route, "path"):
|
861
|
+
print(f" {route.path}")
|
862
|
+
|
813
863
|
self._app = app
|
814
864
|
|
815
865
|
host = host or self.host
|
@@ -830,13 +880,14 @@ class SWMLService:
|
|
830
880
|
|
831
881
|
print(f"Service '{self.name}' is available at:")
|
832
882
|
print(f"URL: {protocol}://{display_host}{self.route}")
|
883
|
+
print(f"URL with trailing slash: {protocol}://{display_host}{self.route}/")
|
833
884
|
print(f"Basic Auth: {username}:{password}")
|
834
885
|
|
835
886
|
# Check if SIP routing is enabled and print additional info
|
836
887
|
if self._routing_callbacks:
|
837
888
|
print(f"Callback endpoints:")
|
838
889
|
for path in self._routing_callbacks:
|
839
|
-
print(f"{protocol}://{display_host}{path}")
|
890
|
+
print(f"{protocol}://{display_host}{self.route}{path}")
|
840
891
|
|
841
892
|
# Start uvicorn with or without SSL
|
842
893
|
if self.ssl_enabled and ssl_cert_path and ssl_key_path:
|
@@ -1054,4 +1105,98 @@ class SWMLService:
|
|
1054
1105
|
params = "&".join([f"{k}={v}" for k, v in filtered_params.items()])
|
1055
1106
|
url = f"{url}?{params}"
|
1056
1107
|
|
1057
|
-
return url
|
1108
|
+
return url
|
1109
|
+
|
1110
|
+
def _detect_proxy_from_request(self, request: Request) -> None:
|
1111
|
+
"""
|
1112
|
+
Detect if we're behind a proxy by examining request headers
|
1113
|
+
and auto-configure proxy_url_base if needed
|
1114
|
+
|
1115
|
+
Args:
|
1116
|
+
request: FastAPI Request object
|
1117
|
+
"""
|
1118
|
+
# First check for standard X-Forwarded headers (used by most proxies including ngrok)
|
1119
|
+
forwarded_host = request.headers.get("X-Forwarded-Host")
|
1120
|
+
forwarded_proto = request.headers.get("X-Forwarded-Proto", "http")
|
1121
|
+
|
1122
|
+
if forwarded_host:
|
1123
|
+
# Direct X-Forwarded-* headers - most common case
|
1124
|
+
self._proxy_url_base = f"{forwarded_proto}://{forwarded_host}"
|
1125
|
+
self.log.info("proxy_auto_detected", proxy_url_base=self._proxy_url_base,
|
1126
|
+
source="X-Forwarded headers")
|
1127
|
+
return
|
1128
|
+
|
1129
|
+
# If no standard headers, check other proxy detection methods
|
1130
|
+
|
1131
|
+
# Check for Forwarded header (RFC 7239)
|
1132
|
+
forwarded = request.headers.get("Forwarded")
|
1133
|
+
if forwarded:
|
1134
|
+
# Parse RFC 7239 Forwarded header
|
1135
|
+
try:
|
1136
|
+
# Extract host and proto from Forwarded: for=X;host=Y;proto=Z
|
1137
|
+
parts = [p.strip() for p in forwarded.split(';')]
|
1138
|
+
host_part = next((p for p in parts if p.startswith("host=")), None)
|
1139
|
+
proto_part = next((p for p in parts if p.startswith("proto=")), None)
|
1140
|
+
|
1141
|
+
if host_part:
|
1142
|
+
host = host_part.split('=', 1)[1].strip('"')
|
1143
|
+
proto = proto_part.split('=', 1)[1].strip('"') if proto_part else "http"
|
1144
|
+
self._proxy_url_base = f"{proto}://{host}"
|
1145
|
+
self.log.info("proxy_auto_detected", proxy_url_base=self._proxy_url_base,
|
1146
|
+
source="Forwarded header")
|
1147
|
+
return
|
1148
|
+
except Exception as e:
|
1149
|
+
self.log.warning("forwarded_header_parse_error", error=str(e))
|
1150
|
+
|
1151
|
+
# Try to detect from the URL itself for transparent proxies
|
1152
|
+
if str(request.url).startswith(("https://", "http://")) and not any(
|
1153
|
+
str(request.url).startswith(f"http://{h}") for h in ["localhost", "127.0.0.1", self.host, "0.0.0.0"]
|
1154
|
+
):
|
1155
|
+
# This is likely a transparent proxy - extract base URL
|
1156
|
+
parsed = urlparse(str(request.url))
|
1157
|
+
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
1158
|
+
self._proxy_url_base = base_url
|
1159
|
+
self.log.info("proxy_auto_detected", proxy_url_base=base_url,
|
1160
|
+
source="request URL (transparent proxy)")
|
1161
|
+
return
|
1162
|
+
|
1163
|
+
# Check for other common proxy setups
|
1164
|
+
original_host = request.headers.get("X-Original-Host") or request.headers.get("Host")
|
1165
|
+
if original_host:
|
1166
|
+
# Only use Host if it doesn't look like our local server
|
1167
|
+
local_hosts = [self.host, "localhost", "127.0.0.1", "0.0.0.0"]
|
1168
|
+
local_port = f":{self.port}"
|
1169
|
+
|
1170
|
+
# If host doesn't look like local server or doesn't contain our port
|
1171
|
+
if not any(h in original_host for h in local_hosts) and local_port not in original_host:
|
1172
|
+
proto = "https" if request.url.scheme == "https" else "http"
|
1173
|
+
self._proxy_url_base = f"{proto}://{original_host}"
|
1174
|
+
self.log.info("proxy_auto_detected", proxy_url_base=self._proxy_url_base,
|
1175
|
+
source="Host header")
|
1176
|
+
return
|
1177
|
+
|
1178
|
+
# If forward_for header exists, we're likely behind a proxy but couldn't determine the URL
|
1179
|
+
forwarded_for = request.headers.get("X-Forwarded-For")
|
1180
|
+
if forwarded_for:
|
1181
|
+
self.log.warning("proxy_detected_but_url_unknown",
|
1182
|
+
client_ip=forwarded_for,
|
1183
|
+
message="Proxy detected via X-Forwarded-For header but could not determine public URL")
|
1184
|
+
|
1185
|
+
# No proxy detected, or unable to determine the public URL
|
1186
|
+
if self._proxy_debug:
|
1187
|
+
self.log.info("proxy_detection_failed",
|
1188
|
+
message="Could not auto-detect proxy. If you are behind a proxy, set SWML_PROXY_URL_BASE manually.")
|
1189
|
+
|
1190
|
+
def manual_set_proxy_url(self, proxy_url: str) -> None:
|
1191
|
+
"""
|
1192
|
+
Manually set the proxy URL base for webhook callbacks
|
1193
|
+
|
1194
|
+
This can be called at runtime to set or update the proxy URL
|
1195
|
+
|
1196
|
+
Args:
|
1197
|
+
proxy_url: The base URL to use for webhooks (e.g., https://example.ngrok.io)
|
1198
|
+
"""
|
1199
|
+
if proxy_url:
|
1200
|
+
self._proxy_url_base = proxy_url.rstrip('/')
|
1201
|
+
self.log.info("proxy_url_manually_set", proxy_url_base=self._proxy_url_base)
|
1202
|
+
self._proxy_detection_done = True
|