signalwire-agents 0.1.28__py3-none-any.whl → 0.1.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- signalwire_agents/__init__.py +1 -1
- signalwire_agents/core/auth_handler.py +233 -0
- signalwire_agents/core/config_loader.py +259 -0
- signalwire_agents/core/contexts.py +75 -0
- signalwire_agents/core/security_config.py +333 -0
- signalwire_agents/core/swml_service.py +19 -25
- signalwire_agents/search/search_service.py +200 -11
- {signalwire_agents-0.1.28.dist-info → signalwire_agents-0.1.29.dist-info}/METADATA +1 -1
- {signalwire_agents-0.1.28.dist-info → signalwire_agents-0.1.29.dist-info}/RECORD +13 -10
- {signalwire_agents-0.1.28.dist-info → signalwire_agents-0.1.29.dist-info}/WHEEL +0 -0
- {signalwire_agents-0.1.28.dist-info → signalwire_agents-0.1.29.dist-info}/entry_points.txt +0 -0
- {signalwire_agents-0.1.28.dist-info → signalwire_agents-0.1.29.dist-info}/licenses/LICENSE +0 -0
- {signalwire_agents-0.1.28.dist-info → signalwire_agents-0.1.29.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,333 @@
|
|
1
|
+
"""
|
2
|
+
Copyright (c) 2025 SignalWire
|
3
|
+
|
4
|
+
This file is part of the SignalWire AI Agents SDK.
|
5
|
+
|
6
|
+
Licensed under the MIT License.
|
7
|
+
See LICENSE file in the project root for full license information.
|
8
|
+
"""
|
9
|
+
|
10
|
+
import os
|
11
|
+
import secrets
|
12
|
+
from typing import Dict, Any, Optional, Tuple, List, Union
|
13
|
+
from signalwire_agents.core.logging_config import get_logger
|
14
|
+
from signalwire_agents.core.config_loader import ConfigLoader
|
15
|
+
|
16
|
+
logger = get_logger("security_config")
|
17
|
+
|
18
|
+
|
19
|
+
class SecurityConfig:
    """
    Unified security configuration for SignalWire services.

    This class provides centralized security settings that can be used by
    both SWML and Search services, ensuring consistent security behavior.

    Values are resolved in three layers, each overriding the previous one:

    1. the built-in ``DEFAULTS`` (``_set_defaults``),
    2. ``SWML_*`` environment variables (``load_from_env``),
    3. the ``security`` section of a config file, if one is found
       (``_load_config_file``).
    """

    # Security environment variable names
    SSL_ENABLED = 'SWML_SSL_ENABLED'
    SSL_CERT_PATH = 'SWML_SSL_CERT_PATH'
    SSL_KEY_PATH = 'SWML_SSL_KEY_PATH'
    SSL_DOMAIN = 'SWML_DOMAIN'
    SSL_VERIFY_MODE = 'SWML_SSL_VERIFY_MODE'

    # Additional security settings
    ALLOWED_HOSTS = 'SWML_ALLOWED_HOSTS'
    CORS_ORIGINS = 'SWML_CORS_ORIGINS'
    MAX_REQUEST_SIZE = 'SWML_MAX_REQUEST_SIZE'
    RATE_LIMIT = 'SWML_RATE_LIMIT'
    REQUEST_TIMEOUT = 'SWML_REQUEST_TIMEOUT'
    USE_HSTS = 'SWML_USE_HSTS'
    HSTS_MAX_AGE = 'SWML_HSTS_MAX_AGE'

    # Authentication
    BASIC_AUTH_USER = 'SWML_BASIC_AUTH_USER'
    BASIC_AUTH_PASSWORD = 'SWML_BASIC_AUTH_PASSWORD'

    # Defaults (secure by default)
    DEFAULTS = {
        SSL_ENABLED: False,  # Off by default, but secure when enabled
        SSL_VERIFY_MODE: 'CERT_REQUIRED',
        ALLOWED_HOSTS: '*',  # Accept all hosts by default for backward compatibility
        CORS_ORIGINS: '*',  # Accept all origins by default for backward compatibility
        MAX_REQUEST_SIZE: 10 * 1024 * 1024,  # 10MB
        RATE_LIMIT: 60,  # Requests per minute
        REQUEST_TIMEOUT: 30,  # Seconds
        USE_HSTS: True,  # Enable HSTS when HTTPS is on
        HSTS_MAX_AGE: 31536000,  # 1 year
    }

    # Text spellings treated as "true" for boolean settings; the same set
    # load_from_env accepts for SWML_SSL_ENABLED.
    _TRUE_STRINGS = ('true', '1', 'yes')

    # Integer-valued keys accepted in the config file's ``security`` section.
    _INT_SETTINGS = ('max_request_size', 'rate_limit', 'request_timeout', 'hsts_max_age')

    def __init__(self, config_file: Optional[str] = None, service_name: Optional[str] = None):
        """
        Initialize security configuration.

        Args:
            config_file: Optional path to config file
            service_name: Optional service name for finding service-specific config

        Raises:
            ValueError: if an integer setting (env var or config file) is not
                a valid integer.
        """
        # First, set defaults
        self._set_defaults()

        # Then load from environment variables (backward compatibility)
        self.load_from_env()

        # Finally, apply config file if available (highest priority)
        self._load_config_file(config_file, service_name)

    def _set_defaults(self):
        """Set default values for all configuration attributes."""
        # SSL configuration
        self.ssl_enabled = self.DEFAULTS[self.SSL_ENABLED]
        self.ssl_cert_path = None
        self.ssl_key_path = None
        self.domain = None
        self.ssl_verify_mode = self.DEFAULTS[self.SSL_VERIFY_MODE]

        # Additional settings
        self.allowed_hosts = self._parse_list(self.DEFAULTS[self.ALLOWED_HOSTS])
        self.cors_origins = self._parse_list(self.DEFAULTS[self.CORS_ORIGINS])
        self.max_request_size = self.DEFAULTS[self.MAX_REQUEST_SIZE]
        self.rate_limit = self.DEFAULTS[self.RATE_LIMIT]
        self.request_timeout = self.DEFAULTS[self.REQUEST_TIMEOUT]
        self.use_hsts = self.DEFAULTS[self.USE_HSTS]
        self.hsts_max_age = self.DEFAULTS[self.HSTS_MAX_AGE]

        # Authentication
        self.basic_auth_user = None
        self.basic_auth_password = None
        # Password generated by get_basic_auth() when none is configured;
        # cached so every call returns the same credentials.
        self._generated_password = None

    def _load_config_file(self, config_file: Optional[str], service_name: Optional[str]):
        """
        Apply the ``security`` section of a config file, if one is available.

        When *config_file* is falsy, ``ConfigLoader.find_config_file`` is asked
        to locate one for *service_name*. Missing files, files without config,
        and empty sections are ignored; a non-mapping section is logged and
        ignored.
        """
        # Find config file
        if not config_file:
            config_file = ConfigLoader.find_config_file(service_name)

        if not config_file:
            return

        # Load config
        config_loader = ConfigLoader([config_file])
        if not config_loader.has_config():
            return

        logger.info("loading_config_from_file", file=config_file)

        section = config_loader.get_section('security')
        if not section:
            return
        if not isinstance(section, dict):
            logger.warning("security_section_ignored", file=config_file, reason="not a mapping")
            return

        self._apply_security_section(section)

    def _apply_security_section(self, section: Dict[str, Any]):
        """
        Apply settings from a config file's ``security`` section.

        Only keys present in *section* are applied; every other attribute
        keeps its current value. Booleans given as text (e.g. ``"false"``)
        are parsed instead of being judged by truthiness, which would treat
        any non-empty string as True.

        Raises:
            ValueError: if an integer setting is not a valid integer.
        """
        if 'ssl_enabled' in section:
            self.ssl_enabled = self._parse_bool(section['ssl_enabled'])
        if 'use_hsts' in section:
            self.use_hsts = self._parse_bool(section['use_hsts'])

        # Plain values are taken verbatim.
        for key in ('ssl_cert_path', 'ssl_key_path', 'domain', 'ssl_verify_mode'):
            if key in section:
                setattr(self, key, section[key])

        for key in ('allowed_hosts', 'cors_origins'):
            if key in section:
                setattr(self, key, self._parse_list(section[key]))

        for key in self._INT_SETTINGS:
            if key in section:
                setattr(self, key, self._parse_int(section[key], f"security.{key}"))

        # Authentication: security.auth.basic.{user,password}
        auth_config = section.get('auth', {})
        basic_auth = auth_config.get('basic', {}) if isinstance(auth_config, dict) else None
        if isinstance(basic_auth, dict):
            if 'user' in basic_auth:
                self.basic_auth_user = basic_auth['user']
            if 'password' in basic_auth:
                self.basic_auth_password = basic_auth['password']

    def load_from_env(self):
        """
        Load configuration from ``SWML_*`` environment variables.

        Every attribute is reassigned: unset variables fall back to
        ``DEFAULTS`` (or ``None`` for paths, domain and credentials), so
        calling this again discards values that came from a config file.

        Raises:
            ValueError: if a numeric variable is not a valid integer.
        """
        # SSL configuration
        ssl_enabled_env = os.environ.get(self.SSL_ENABLED, '').lower()
        self.ssl_enabled = ssl_enabled_env in self._TRUE_STRINGS
        self.ssl_cert_path = os.environ.get(self.SSL_CERT_PATH)
        self.ssl_key_path = os.environ.get(self.SSL_KEY_PATH)
        self.domain = os.environ.get(self.SSL_DOMAIN)
        self.ssl_verify_mode = os.environ.get(self.SSL_VERIFY_MODE, self.DEFAULTS[self.SSL_VERIFY_MODE])

        # Additional security settings
        self.allowed_hosts = self._parse_list(os.environ.get(self.ALLOWED_HOSTS, self.DEFAULTS[self.ALLOWED_HOSTS]))
        self.cors_origins = self._parse_list(os.environ.get(self.CORS_ORIGINS, self.DEFAULTS[self.CORS_ORIGINS]))
        self.max_request_size = self._int_from_env(self.MAX_REQUEST_SIZE)
        self.rate_limit = self._int_from_env(self.RATE_LIMIT)
        self.request_timeout = self._int_from_env(self.REQUEST_TIMEOUT)

        # HSTS: when the variable is set, anything except "false" enables it.
        use_hsts_env = os.environ.get(self.USE_HSTS, '').lower()
        self.use_hsts = use_hsts_env != 'false' if use_hsts_env else self.DEFAULTS[self.USE_HSTS]
        self.hsts_max_age = self._int_from_env(self.HSTS_MAX_AGE)

        # Authentication
        self.basic_auth_user = os.environ.get(self.BASIC_AUTH_USER)
        self.basic_auth_password = os.environ.get(self.BASIC_AUTH_PASSWORD)

    def _int_from_env(self, name: str) -> int:
        """Read integer setting *name* from the environment, defaulting to ``DEFAULTS[name]``."""
        return self._parse_int(os.environ.get(name, self.DEFAULTS[name]), name)

    @staticmethod
    def _parse_int(value: Any, name: str) -> int:
        """
        Convert *value* to ``int``.

        Raises:
            ValueError: if *value* is not a valid integer; the message names
                the setting *name* so the misconfiguration is easy to find.
        """
        try:
            return int(value)
        except ValueError as err:
            raise ValueError(f"{name} must be an integer, got {value!r}") from err

    @classmethod
    def _parse_bool(cls, value: Any) -> bool:
        """Coerce a config value to bool; text is true only for 'true'/'1'/'yes' (any case)."""
        if isinstance(value, str):
            return value.strip().lower() in cls._TRUE_STRINGS
        return bool(value)

    def _parse_list(self, value: Union[str, list, tuple, None]) -> list:
        """
        Parse a comma-separated string (env var) or a list (config file).

        Items are stripped and empty items dropped; ``None`` yields ``[]``.
        """
        if value is None:
            return []
        if isinstance(value, (list, tuple)):
            items = [str(item) for item in value]
        else:
            items = str(value).split(',')
        stripped = (item.strip() for item in items)
        return [item for item in stripped if item]

    def validate_ssl_config(self) -> Tuple[bool, Optional[str]]:
        """
        Validate SSL configuration.

        Returns:
            Tuple of (is_valid, error_message). Always valid when SSL is
            disabled; otherwise both paths must be set and point at files.
        """
        if not self.ssl_enabled:
            return True, None

        if not self.ssl_cert_path:
            return False, "SSL enabled but SWML_SSL_CERT_PATH not set"

        if not self.ssl_key_path:
            return False, "SSL enabled but SWML_SSL_KEY_PATH not set"

        # isfile (not exists) so a directory path is rejected up front.
        if not os.path.isfile(self.ssl_cert_path):
            return False, f"SSL certificate file not found: {self.ssl_cert_path}"

        if not os.path.isfile(self.ssl_key_path):
            return False, f"SSL key file not found: {self.ssl_key_path}"

        return True, None

    def get_ssl_context_kwargs(self) -> Dict[str, Any]:
        """
        Get SSL context kwargs for uvicorn.

        Returns:
            ``{'ssl_certfile', 'ssl_keyfile'}`` when SSL is enabled and valid;
            otherwise an empty dict (an invalid config is logged as an error).
        """
        if not self.ssl_enabled:
            return {}

        is_valid, error = self.validate_ssl_config()
        if not is_valid:
            logger.error("ssl_validation_failed", error=error)
            return {}

        return {
            'ssl_certfile': self.ssl_cert_path,
            'ssl_keyfile': self.ssl_key_path,
            # Additional SSL options can be added here
        }

    def get_basic_auth(self) -> Tuple[str, str]:
        """
        Get basic auth credentials, generating a password if none is set.

        The username defaults to ``"signalwire"``. A generated password is
        created once per instance and reused, so every caller sees the same
        credentials.

        Returns:
            Tuple of (username, password)
        """
        username = self.basic_auth_user or "signalwire"
        password = self.basic_auth_password
        if not password:
            password = getattr(self, '_generated_password', None)
            if password is None:
                password = secrets.token_urlsafe(32)
                self._generated_password = password
        return username, password

    def get_security_headers(self, is_https: bool = False) -> Dict[str, str]:
        """
        Get security headers to add to responses.

        Args:
            is_https: Whether the connection is over HTTPS

        Returns:
            Dictionary of security headers (HSTS only over HTTPS when enabled)
        """
        headers = {
            'X-Content-Type-Options': 'nosniff',
            'X-Frame-Options': 'DENY',
            'X-XSS-Protection': '1; mode=block',
            'Referrer-Policy': 'strict-origin-when-cross-origin',
        }

        # Add HSTS header if HTTPS and enabled
        if is_https and self.use_hsts:
            headers['Strict-Transport-Security'] = f'max-age={self.hsts_max_age}; includeSubDomains'

        return headers

    def should_allow_host(self, host: str) -> bool:
        """
        Check if a host is allowed.

        Host names are case-insensitive, so the comparison ignores case.

        Args:
            host: The host to check (without port)

        Returns:
            True if ``'*'`` is allowed or *host* matches an allowed host
        """
        if '*' in self.allowed_hosts:
            return True

        wanted = host.lower()
        return any(wanted == str(allowed).lower() for allowed in self.allowed_hosts)

    def get_cors_config(self) -> Dict[str, Any]:
        """
        Get CORS configuration for FastAPI's ``CORSMiddleware``.

        Credentials are only allowed for an explicit origin list: browsers
        refuse credentialed responses carrying ``Access-Control-Allow-Origin: *``,
        and Starlette works around that by reflecting the caller's Origin,
        which would let any site make credentialed requests.

        Returns:
            Dictionary of CORS settings
        """
        origins = self.cors_origins
        return {
            'allow_origins': origins,
            'allow_credentials': '*' not in origins,
            'allow_methods': ['*'],
            'allow_headers': ['*'],
        }

    def get_url_scheme(self) -> str:
        """Get the URL scheme ('https' or 'http') based on SSL configuration."""
        return 'https' if self.ssl_enabled else 'http'

    def log_config(self, service_name: str):
        """Log the current security configuration (never the credentials themselves)."""
        logger.info(
            "security_config_loaded",
            service=service_name,
            ssl_enabled=self.ssl_enabled,
            domain=self.domain,
            allowed_hosts=self.allowed_hosts,
            cors_origins=self.cors_origins,
            max_request_size=self.max_request_size,
            rate_limit=self.rate_limit,
            use_hsts=self.use_hsts,
            has_basic_auth=bool(self.basic_auth_user and self.basic_auth_password)
        )
|
329
|
+
|
330
|
+
|
331
|
+
# Global instance for easy access (backward compatibility)
# Services can create their own instances with specific config files
#
# NOTE: this instance is built when the module is first imported, so the import
# itself reads the SWML_* environment variables and may look up and load a
# config file (SecurityConfig() calls ConfigLoader.find_config_file(None)).
security_config = SecurityConfig()
|
@@ -42,6 +42,7 @@ except ImportError:
|
|
42
42
|
|
43
43
|
from signalwire_agents.utils.schema_utils import SchemaUtils
|
44
44
|
from signalwire_agents.core.swml_handler import VerbHandlerRegistry, SWMLVerbHandler
|
45
|
+
from signalwire_agents.core.security_config import SecurityConfig
|
45
46
|
|
46
47
|
|
47
48
|
class SWMLService:
|
@@ -65,7 +66,8 @@ class SWMLService:
|
|
65
66
|
host: str = "0.0.0.0",
|
66
67
|
port: int = 3000,
|
67
68
|
basic_auth: Optional[Tuple[str, str]] = None,
|
68
|
-
schema_path: Optional[str] = None
|
69
|
+
schema_path: Optional[str] = None,
|
70
|
+
config_file: Optional[str] = None
|
69
71
|
):
|
70
72
|
"""
|
71
73
|
Initialize a new SWML service
|
@@ -77,22 +79,26 @@ class SWMLService:
|
|
77
79
|
port: Port to bind the web server to
|
78
80
|
basic_auth: Optional (username, password) tuple for basic auth
|
79
81
|
schema_path: Optional path to the schema file
|
82
|
+
config_file: Optional path to configuration file
|
80
83
|
"""
|
81
84
|
self.name = name
|
82
85
|
self.route = route.rstrip("/") # Ensure no trailing slash
|
83
86
|
self.host = host
|
84
87
|
self.port = port
|
85
88
|
|
86
|
-
# Initialize SSL configuration from environment variables
|
87
|
-
ssl_enabled_env = os.environ.get('SWML_SSL_ENABLED', '').lower()
|
88
|
-
self.ssl_enabled = ssl_enabled_env in ('true', '1', 'yes')
|
89
|
-
self.domain = os.environ.get('SWML_DOMAIN')
|
90
|
-
self.ssl_cert_path = os.environ.get('SWML_SSL_CERT_PATH')
|
91
|
-
self.ssl_key_path = os.environ.get('SWML_SSL_KEY_PATH')
|
92
|
-
|
93
89
|
# Initialize logger for this instance FIRST before using it
|
94
90
|
self.log = logger.bind(service=name)
|
95
91
|
|
92
|
+
# Load unified security configuration with optional config file
|
93
|
+
self.security = SecurityConfig(config_file=config_file, service_name=name)
|
94
|
+
self.security.log_config("SWMLService")
|
95
|
+
|
96
|
+
# For backward compatibility, expose SSL settings as instance attributes
|
97
|
+
self.ssl_enabled = self.security.ssl_enabled
|
98
|
+
self.domain = self.security.domain
|
99
|
+
self.ssl_cert_path = self.security.ssl_cert_path
|
100
|
+
self.ssl_key_path = self.security.ssl_key_path
|
101
|
+
|
96
102
|
# Initialize proxy detection attributes
|
97
103
|
self._proxy_url_base = os.environ.get('SWML_PROXY_URL_BASE')
|
98
104
|
self._proxy_url_base_from_env = bool(self._proxy_url_base) # Track if it came from environment
|
@@ -108,18 +114,8 @@ class SWMLService:
|
|
108
114
|
# Use provided credentials
|
109
115
|
self._basic_auth = basic_auth
|
110
116
|
else:
|
111
|
-
#
|
112
|
-
|
113
|
-
env_pass = os.environ.get('SWML_BASIC_AUTH_PASSWORD')
|
114
|
-
|
115
|
-
if env_user and env_pass:
|
116
|
-
# Use environment variables
|
117
|
-
self._basic_auth = (env_user, env_pass)
|
118
|
-
else:
|
119
|
-
# Generate random credentials as fallback
|
120
|
-
username = f"user_{secrets.token_hex(4)}"
|
121
|
-
password = secrets.token_urlsafe(16)
|
122
|
-
self._basic_auth = (username, password)
|
117
|
+
# Use unified security config for auth credentials
|
118
|
+
self._basic_auth = self.security.get_basic_auth()
|
123
119
|
|
124
120
|
# Find the schema file if not provided
|
125
121
|
if schema_path is None:
|
@@ -768,11 +764,9 @@ class SWMLService:
|
|
768
764
|
|
769
765
|
# Validate SSL configuration if enabled
|
770
766
|
if self.ssl_enabled:
|
771
|
-
|
772
|
-
|
773
|
-
self.
|
774
|
-
elif not ssl_key_path or not os.path.exists(ssl_key_path):
|
775
|
-
self.log.warning("ssl_key_not_found", path=ssl_key_path)
|
767
|
+
is_valid, error = self.security.validate_ssl_config()
|
768
|
+
if not is_valid:
|
769
|
+
self.log.warning("ssl_config_invalid", error=error)
|
776
770
|
self.ssl_enabled = False
|
777
771
|
elif not self.domain:
|
778
772
|
self.log.warning("ssl_domain_not_specified")
|
@@ -8,15 +8,23 @@ See LICENSE file in the project root for full license information.
|
|
8
8
|
"""
|
9
9
|
|
10
10
|
import logging
|
11
|
-
from typing import Dict, Any, List, Optional
|
11
|
+
from typing import Dict, Any, List, Optional, Tuple
|
12
12
|
|
13
13
|
try:
|
14
|
-
from fastapi import FastAPI, HTTPException
|
14
|
+
from fastapi import FastAPI, HTTPException, Request, Response, Depends
|
15
|
+
from fastapi.middleware.cors import CORSMiddleware
|
16
|
+
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
15
17
|
from pydantic import BaseModel
|
16
18
|
except ImportError:
|
17
19
|
FastAPI = None
|
18
20
|
HTTPException = None
|
19
21
|
BaseModel = None
|
22
|
+
Request = None
|
23
|
+
Response = None
|
24
|
+
Depends = None
|
25
|
+
CORSMiddleware = None
|
26
|
+
HTTPBasic = None
|
27
|
+
HTTPBasicCredentials = None
|
20
28
|
|
21
29
|
try:
|
22
30
|
from sentence_transformers import SentenceTransformer
|
@@ -25,8 +33,11 @@ except ImportError:
|
|
25
33
|
|
26
34
|
from .query_processor import preprocess_query
|
27
35
|
from .search_engine import SearchEngine
|
36
|
+
from signalwire_agents.core.security_config import SecurityConfig
|
37
|
+
from signalwire_agents.core.config_loader import ConfigLoader
|
38
|
+
from signalwire_agents.core.logging_config import get_logger
|
28
39
|
|
29
|
-
logger =
|
40
|
+
logger = get_logger("search_service")
|
30
41
|
|
31
42
|
# Pydantic models for API
|
32
43
|
if BaseModel:
|
@@ -73,14 +84,30 @@ else:
|
|
73
84
|
class SearchService:
|
74
85
|
"""Local search service with HTTP API"""
|
75
86
|
|
76
|
-
def __init__(self, port: int = 8001, indexes: Dict[str, str] = None
|
87
|
+
def __init__(self, port: int = 8001, indexes: Dict[str, str] = None,
|
88
|
+
basic_auth: Optional[Tuple[str, str]] = None,
|
89
|
+
config_file: Optional[str] = None):
|
90
|
+
# Load configuration first
|
91
|
+
self._load_config(config_file)
|
92
|
+
|
93
|
+
# Override with constructor params if provided
|
77
94
|
self.port = port
|
78
|
-
|
95
|
+
if indexes is not None:
|
96
|
+
self.indexes = indexes
|
97
|
+
|
79
98
|
self.search_engines = {}
|
80
99
|
self.model = None
|
81
100
|
|
101
|
+
# Load security configuration with optional config file
|
102
|
+
self.security = SecurityConfig(config_file=config_file, service_name="search")
|
103
|
+
self.security.log_config("SearchService")
|
104
|
+
|
105
|
+
# Set up authentication
|
106
|
+
self._basic_auth = basic_auth or self.security.get_basic_auth()
|
107
|
+
|
82
108
|
if FastAPI:
|
83
109
|
self.app = FastAPI(title="SignalWire Local Search Service")
|
110
|
+
self._setup_security()
|
84
111
|
self._setup_routes()
|
85
112
|
else:
|
86
113
|
self.app = None
|
@@ -88,22 +115,131 @@ class SearchService:
|
|
88
115
|
|
89
116
|
self._load_resources()
|
90
117
|
|
118
|
+
def _load_config(self, config_file: Optional[str]):
    """
    Populate service settings from a config file when one can be found.

    ``self.indexes`` is always reset to an empty dict first. When
    *config_file* is falsy, ``ConfigLoader.find_config_file("search")`` is
    consulted. Only the ``service`` section is read: ``port`` (coerced with
    ``int``) and ``indexes`` (applied only when it is a dict).
    """
    self.indexes = {}

    path = config_file or ConfigLoader.find_config_file("search")
    if not path:
        return

    loader = ConfigLoader([path])
    if not loader.has_config():
        return

    logger.info("loading_config_from_file", file=path)

    section = loader.get_section('service')
    if not section:
        return

    if 'port' in section:
        self.port = int(section['port'])

    configured_indexes = section['indexes'] if 'indexes' in section else None
    if isinstance(configured_indexes, dict):
        self.indexes = configured_indexes
|
145
|
+
|
146
|
+
def _setup_security(self):
    """
    Register CORS, security-header and Host-validation middleware.

    Does nothing when ``self.app`` is unset. Middleware is registered in a
    fixed order: CORS (only if ``CORSMiddleware`` is importable), then the
    security-headers hook, then the Host check.
    """
    app = self.app
    if not app:
        return

    if CORSMiddleware:
        cors_settings = self.security.get_cors_config()
        app.add_middleware(CORSMiddleware, **cors_settings)

    @app.middleware("http")
    async def add_security_headers(request: Request, call_next):
        # Let the request run first, then decorate whatever response came back.
        response = await call_next(request)
        over_tls = request.url.scheme == "https"
        for name, value in self.security.get_security_headers(over_tls).items():
            response.headers[name] = value
        return response

    @app.middleware("http")
    async def validate_host(request: Request, call_next):
        # Drop any ":port" suffix; a missing Host header skips the check.
        hostname = request.headers.get("host", "").split(":")[0]
        if not hostname or self.security.should_allow_host(hostname):
            return await call_next(request)
        return Response(content="Invalid host", status_code=400)
|
179
|
+
|
180
|
+
def _get_current_username(self, credentials: HTTPBasicCredentials = None) -> Optional[str]:
    """
    Validate HTTP Basic credentials against the service's configured pair.

    Args:
        credentials: Credentials supplied by FastAPI's ``HTTPBasic``
            dependency, or ``None`` when no credentials are available.

    Returns:
        The authenticated username, or ``None`` when *credentials* is falsy.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Basic`` when the username
            or password does not match.
    """
    if not credentials:
        return None

    import secrets

    correct_username, correct_password = self._basic_auth

    # compare_digest only accepts ASCII-only str, so a non-ASCII credential
    # would raise TypeError (a 500). Comparing UTF-8 bytes keeps the
    # constant-time check and turns such input into a normal 401.
    username_correct = secrets.compare_digest(
        credentials.username.encode("utf-8"), correct_username.encode("utf-8")
    )
    password_correct = secrets.compare_digest(
        credentials.password.encode("utf-8"), correct_password.encode("utf-8")
    )

    # Both comparisons run before deciding, so timing doesn't reveal which failed.
    if not (username_correct and password_correct):
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication credentials",
            headers={"WWW-Authenticate": "Basic"},
        )

    return credentials.username
|
200
|
+
|
91
201
|
def _setup_routes(self):
|
92
202
|
"""Setup FastAPI routes"""
|
93
203
|
if not self.app:
|
94
204
|
return
|
205
|
+
|
206
|
+
# Create security dependency if HTTPBasic is available
|
207
|
+
security = HTTPBasic() if HTTPBasic else None
|
208
|
+
|
209
|
+
# Create dependency for authenticated routes
|
210
|
+
def get_authenticated():
|
211
|
+
if security:
|
212
|
+
return security
|
213
|
+
return None
|
95
214
|
|
96
215
|
@self.app.post("/search", response_model=SearchResponse)
|
97
|
-
async def search(
|
216
|
+
async def search(
|
217
|
+
request: SearchRequest,
|
218
|
+
credentials: HTTPBasicCredentials = None if not security else Depends(security)
|
219
|
+
):
|
220
|
+
if security:
|
221
|
+
self._get_current_username(credentials)
|
98
222
|
return await self._handle_search(request)
|
99
223
|
|
100
224
|
@self.app.get("/health")
|
101
225
|
async def health():
|
102
|
-
return {
|
226
|
+
return {
|
227
|
+
"status": "healthy",
|
228
|
+
"indexes": list(self.indexes.keys()),
|
229
|
+
"ssl_enabled": self.security.ssl_enabled,
|
230
|
+
"auth_required": bool(security)
|
231
|
+
}
|
103
232
|
|
104
233
|
@self.app.post("/reload_index")
|
105
|
-
async def reload_index(
|
234
|
+
async def reload_index(
|
235
|
+
index_name: str,
|
236
|
+
index_path: str,
|
237
|
+
credentials: HTTPBasicCredentials = None if not security else Depends(security)
|
238
|
+
):
|
106
239
|
"""Reload or add new index"""
|
240
|
+
if security:
|
241
|
+
self._get_current_username(credentials)
|
242
|
+
|
107
243
|
self.indexes[index_name] = index_path
|
108
244
|
self.search_engines[index_name] = SearchEngine(index_path, self.model)
|
109
245
|
return {"status": "reloaded", "index": index_name}
|
@@ -235,14 +371,67 @@ class SearchService:
|
|
235
371
|
'query_analysis': response.query_analysis
|
236
372
|
}
|
237
373
|
|
238
|
-
def start(self
|
239
|
-
|
374
|
+
def start(self, host: str = "0.0.0.0", port: Optional[int] = None,
          ssl_cert: Optional[str] = None, ssl_key: Optional[str] = None):
    """
    Start the service with optional HTTPS support.

    Blocks inside ``uvicorn.run`` until the server stops.

    Args:
        host: Host to bind to (default: "0.0.0.0")
        port: Port to bind to (default: self.port). An explicit ``0`` is
            passed through unchanged.
        ssl_cert: Path to SSL certificate file (overrides environment)
        ssl_key: Path to SSL key file (overrides environment)

    Raises:
        RuntimeError: if FastAPI or uvicorn is not available.
    """
    if not self.app:
        raise RuntimeError("FastAPI not available. Cannot start HTTP service.")

    # Fall back only when no port was given; `port or self.port` would also
    # throw away an explicit 0.
    if port is None:
        port = self.port

    # Get SSL configuration: an explicit cert/key pair wins over the security config
    if ssl_cert and ssl_key:
        ssl_kwargs = {
            'ssl_certfile': ssl_cert,
            'ssl_keyfile': ssl_key
        }
    else:
        if ssl_cert or ssl_key:
            # Half a pair cannot be used; report it instead of ignoring it silently.
            logger.warning("ssl_args_incomplete", ssl_cert=ssl_cert, ssl_key=ssl_key)
        ssl_kwargs = self.security.get_ssl_context_kwargs()

    # Build startup URL
    scheme = "https" if ssl_kwargs else "http"
    startup_url = f"{scheme}://{host}:{port}"

    # Get auth credentials
    username, password = self._basic_auth

    # Log startup information (the password is not included in the log)
    logger.info(
        "starting_search_service",
        url=startup_url,
        ssl_enabled=bool(ssl_kwargs),
        indexes=list(self.indexes.keys()),
        username=username
    )

    # Print user-friendly startup message.
    # NOTE(review): this prints the Basic Auth password to stdout.
    print(f"\nSignalWire Search Service starting...")
    print(f"URL: {startup_url}")
    print(f"Indexes: {', '.join(self.indexes.keys()) if self.indexes else 'None'}")
    print(f"Basic Auth: {username}:{password}")
    if ssl_kwargs:
        print(f"SSL: Enabled")
    print("")

    # Only the import is guarded: an ImportError raised while the server runs
    # must not be misreported as "uvicorn not available".
    try:
        import uvicorn
    except ImportError as err:
        raise RuntimeError("uvicorn not available. Cannot start HTTP service.") from err

    uvicorn.run(
        self.app,
        host=host,
        port=port,
        **ssl_kwargs
    )
|
248
437
|
|