asap-protocol 0.1.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,156 @@
1
+ """Bounded thread pool executor for DoS prevention.
2
+
3
+ This module provides a bounded executor that limits the number of concurrent
4
+ threads used for executing synchronous handlers. This prevents resource
5
+ exhaustion attacks by rejecting requests when the thread pool is full.
6
+
7
+ Example:
8
+ >>> from asap.transport.executors import BoundedExecutor
9
+ >>> executor = BoundedExecutor(max_threads=10)
10
+ >>> result = await loop.run_in_executor(executor, sync_handler, arg1, arg2)
11
+ """
12
+
13
+ import os
14
+ from concurrent.futures import Executor, Future, ThreadPoolExecutor
15
+ from threading import Semaphore
16
+ from typing import Callable, TypeVar
17
+
18
+ from asap.errors import ThreadPoolExhaustedError
19
+ from asap.observability import get_logger, get_metrics
20
+
21
+ # Module logger
22
+ logger = get_logger(__name__)
23
+
24
+ # Type variable for function return type
25
+ T = TypeVar("T")
26
+
27
+
28
class BoundedExecutor(Executor):
    """Thread pool executor with bounded capacity for DoS prevention.

    Wraps a ``ThreadPoolExecutor`` and guards it with a semaphore holding
    ``max_threads`` permits. Every accepted task owns one permit until its
    future finishes; when no permit is free, ``submit`` fails fast with
    ``ThreadPoolExhaustedError`` instead of queueing the task.

    Attributes:
        _executor: Underlying ThreadPoolExecutor
        _semaphore: Semaphore holding one permit per in-flight task
        max_threads: Maximum number of concurrent tasks/threads

    Example:
        >>> executor = BoundedExecutor(max_threads=10)
        >>> result = await loop.run_in_executor(executor, my_sync_function, arg1)
    """

    def __init__(self, max_threads: int | None = None) -> None:
        """Initialize bounded executor.

        Args:
            max_threads: Maximum number of concurrent threads.
                Defaults to min(32, os.cpu_count() + 4) if None.

        Raises:
            ValueError: If max_threads is less than 1
        """
        if max_threads is None:
            # os.cpu_count() may return None; treat that as a single CPU.
            cpu_count = os.cpu_count() or 1
            max_threads = min(32, cpu_count + 4)

        if max_threads < 1:
            raise ValueError(f"max_threads must be >= 1, got {max_threads}")

        self.max_threads = max_threads
        self._executor = ThreadPoolExecutor(max_workers=max_threads)
        self._semaphore = Semaphore(max_threads)

        logger.info(
            "asap.executor.created",
            max_threads=max_threads,
            cpu_count=os.cpu_count(),
        )

    def _release_permit(self, _future: Future[object]) -> None:
        """Done-callback: hand the task's semaphore permit back to the pool."""
        self._semaphore.release()

    def submit(self, fn: Callable[..., T], /, *args: object, **kwargs: object) -> Future[T]:
        """Submit a function to be executed in the thread pool.

        A semaphore permit is taken without blocking before the task is
        handed to the underlying executor. The permit is returned when the
        future completes (result, exception or cancellation), or immediately
        if the underlying executor refuses the task.

        Args:
            fn: Function to execute
            *args: Positional arguments for the function
            **kwargs: Keyword arguments for the function

        Returns:
            Future representing the execution of the function

        Raises:
            ThreadPoolExhaustedError: If no permit is available
            RuntimeError: Propagated from the underlying executor, e.g. when
                it has already been shut down
        """
        if not self._semaphore.acquire(blocking=False):
            # acquire() failed, so every permit is held: report the pool as full.
            active_threads = self.max_threads
            metrics = get_metrics()
            metrics.increment_counter(
                "asap_thread_pool_exhausted_total",
                labels={},
                value=1.0,
            )

            logger.warning(
                "asap.executor.exhausted",
                max_threads=self.max_threads,
                active_threads=active_threads,
            )

            raise ThreadPoolExhaustedError(
                max_threads=self.max_threads,
                active_threads=active_threads,
            )

        try:
            future = self._executor.submit(fn, *args, **kwargs)
        except BaseException:
            # No future exists to release the permit later, so give it back
            # now; otherwise each failed submission would shrink capacity.
            self._semaphore.release()
            raise

        # If the future is already done, the callback runs immediately.
        future.add_done_callback(self._release_permit)
        return future

    def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
        """Shutdown the executor and release resources.

        Args:
            wait: If True, wait for all pending tasks to complete
            cancel_futures: If True, cancel pending futures (Python 3.9+)
        """
        self._executor.shutdown(wait=wait, cancel_futures=cancel_futures)
        logger.info("asap.executor.shutdown", max_threads=self.max_threads)

    def __enter__(self) -> "BoundedExecutor":
        """Context manager entry."""
        return self

    def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None:
        """Context manager exit - shutdown executor."""
        self.shutdown(wait=True)
@@ -30,8 +30,9 @@ import asyncio
30
30
  import inspect
31
31
  import time
32
32
  from collections.abc import Awaitable
33
+ from concurrent.futures import Executor
33
34
  from threading import RLock
34
- from typing import Protocol
35
+ from typing import Protocol, cast
35
36
 
36
37
  from asap.errors import ASAPError
37
38
  from asap.models.entities import Manifest
@@ -142,6 +143,7 @@ class HandlerRegistry:
142
143
  Attributes:
143
144
  _handlers: Internal mapping of payload_type to handler function
144
145
  _lock: Reentrant lock for thread-safe operations
146
+ _executor: Optional executor for running sync handlers (for DoS prevention)
145
147
 
146
148
  Example:
147
149
  >>> registry = HandlerRegistry()
@@ -151,10 +153,17 @@ class HandlerRegistry:
151
153
  >>> response = registry.dispatch(envelope, manifest)
152
154
  """
153
155
 
154
- def __init__(self) -> None:
155
- """Initialize empty handler registry with thread-safe lock."""
156
+ def __init__(self, executor: Executor | None = None) -> None:
157
+ """Initialize empty handler registry with thread-safe lock.
158
+
159
+ Args:
160
+ executor: Optional executor for running sync handlers.
161
+ If None, uses default asyncio executor (unbounded).
162
+ Should be a BoundedExecutor instance for DoS prevention.
163
+ """
156
164
  self._handlers: dict[str, Handler] = {}
157
165
  self._lock = RLock()
166
+ self._executor: Executor | None = executor
158
167
 
159
168
  def register(self, payload_type: str, handler: Handler) -> None:
160
169
  """Register a handler for a payload type.
@@ -333,13 +342,16 @@ class HandlerRegistry:
333
342
  # Sync handler - run in thread pool to avoid blocking event loop
334
343
  # Also handle async callable objects that return awaitables
335
344
  loop = asyncio.get_event_loop()
336
- result: object = await loop.run_in_executor(None, handler, envelope, manifest)
345
+ # Use bounded executor if provided, otherwise use default (unbounded)
346
+ executor = self._executor if self._executor is not None else None
347
+ result: object = await loop.run_in_executor(executor, handler, envelope, manifest)
337
348
  # Check if result is awaitable (handles async __call__ methods)
338
349
  if inspect.isawaitable(result):
339
350
  response = await result
340
351
  else:
341
352
  # Type narrowing: result is Envelope for sync handlers
342
- response = result # type: ignore[assignment]
353
+ # After checking it's not awaitable, we know it's Envelope
354
+ response = cast(Envelope, result)
343
355
 
344
356
  duration_ms = (time.perf_counter() - start_time) * 1000
345
357
  logger.debug(
@@ -1,10 +1,11 @@
1
- """Authentication middleware for ASAP protocol server.
1
+ """Authentication and rate limiting middleware for ASAP protocol server.
2
2
 
3
- This module provides authentication middleware that:
3
+ This module provides middleware that:
4
4
  - Validates Bearer tokens based on manifest configuration
5
5
  - Verifies sender identity matches authenticated agent
6
6
  - Supports custom token validation logic
7
7
  - Returns proper JSON-RPC error responses for auth failures
8
+ - Implements IP-based rate limiting to prevent DoS attacks
8
9
 
9
10
  Example:
10
11
  >>> from asap.transport.middleware import AuthenticationMiddleware, BearerTokenValidator
@@ -31,28 +32,237 @@ Example:
31
32
  >>> middleware = AuthenticationMiddleware(manifest, validator)
32
33
  """
33
34
 
34
- import hashlib
35
- from typing import Callable, Protocol
35
+ import uuid
36
+ from typing import Any, Awaitable, Callable, Protocol
37
+ from collections.abc import Sequence
36
38
 
37
39
  from fastapi import HTTPException, Request
40
+ from fastapi.responses import JSONResponse
38
41
  from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
42
+ from slowapi import Limiter
43
+ from slowapi.errors import RateLimitExceeded
44
+ from slowapi.util import get_remote_address
45
+ from starlette.middleware.base import BaseHTTPMiddleware
39
46
 
40
47
  from asap.models.entities import Manifest
41
48
  from asap.observability import get_logger
49
+ from asap.utils.sanitization import sanitize_token
42
50
 
43
51
  logger = get_logger(__name__)
44
52
 
45
53
  # Authentication header scheme
46
54
  AUTH_SCHEME_BEARER = "bearer"
47
55
 
56
+ # Rate limiting default configuration
57
+ DEFAULT_RATE_LIMIT = "100/minute"
58
+
59
+
60
+ def _get_sender_from_envelope(request: Request) -> str:
61
+ """Extract identifier from request for rate limiting.
62
+
63
+ This function implements IP-based rate limiting for the transport layer.
64
+ The rate limiter executes before the route handler parses the request body,
65
+ so the ASAP envelope is not yet available at rate limit check time.
66
+ Therefore, this function primarily returns the client IP address.
67
+
68
+ The function attempts to extract the sender from the envelope if already
69
+ parsed (for future compatibility), but in practice always falls back to
70
+ the client IP address. This IP-based approach is safer for DoS prevention
71
+ as it doesn't require parsing the request body before rate limiting.
72
+
73
+ Args:
74
+ request: FastAPI request object
75
+
76
+ Returns:
77
+ Client IP address (used as rate limiting key)
78
+
79
+ Example:
80
+ >>> sender = _get_sender_from_envelope(request)
81
+ >>> # Returns "192.168.1.1" (IP address, not sender URN)
82
+ """
83
+ # Try to extract sender from envelope if already parsed (early returns reduce complexity)
84
+ try:
85
+ # Check if envelope is stored in request state (after parsing)
86
+ if hasattr(request.state, "envelope") and request.state.envelope:
87
+ envelope = request.state.envelope
88
+ if hasattr(envelope, "sender") and isinstance(envelope.sender, str):
89
+ return envelope.sender
90
+
91
+ # Try to extract from JSON-RPC request if already parsed
92
+ if hasattr(request.state, "rpc_request"):
93
+ rpc_request = request.state.rpc_request
94
+ if (
95
+ hasattr(rpc_request, "params")
96
+ and isinstance(rpc_request.params, dict)
97
+ and "envelope" in rpc_request.params
98
+ ):
99
+ envelope_data = rpc_request.params.get("envelope")
100
+ if isinstance(envelope_data, dict) and "sender" in envelope_data:
101
+ sender = envelope_data["sender"]
102
+ if isinstance(sender, str):
103
+ return sender
104
+ except (AttributeError, KeyError, TypeError):
105
+ # Envelope not available, fall back to IP
106
+ pass
107
+
108
+ # Fallback to client IP address
109
+ remote_addr = get_remote_address(request)
110
+ # Type narrowing: get_remote_address returns str, but mypy may see it as Any
111
+ if isinstance(remote_addr, str):
112
+ return remote_addr
113
+ return str(remote_addr)
114
+
115
+
116
+ # Create the module-level rate limiter, keyed by _get_sender_from_envelope
117
+ # Note: the key function prefers a sender already parsed onto request.state; in practice
118
+ # it falls back to the client IP because rate limiting executes before request body parsing
119
+ limiter = Limiter(
120
+ key_func=_get_sender_from_envelope,
121
+ default_limits=[DEFAULT_RATE_LIMIT],
122
+ storage_uri="memory://",
123
+ )
124
+
125
+
126
def create_test_limiter(limits: Sequence[str] | None = None) -> Limiter:
    """Build a fresh limiter for use in tests.

    Each call returns a new ``Limiter`` with its own randomly named
    ``memory://`` storage URI, so counters from one test case cannot
    affect another.

    Args:
        limits: Rate limit strings to apply. When omitted, a single very
            high limit ("100000/minute") is used so tests are not throttled.

    Returns:
        New Limiter instance keyed by ``_get_sender_from_envelope``

    Example:
        >>> test_limiter = create_test_limiter(["100000/minute"])
        >>> app.state.limiter = test_limiter
    """
    effective_limits = ["100000/minute"] if limits is None else list(limits)

    # A random suffix gives every limiter a distinct storage URI.
    storage_uri = "memory://" + str(uuid.uuid4())
    return Limiter(
        key_func=_get_sender_from_envelope,
        default_limits=effective_limits,
        storage_uri=storage_uri,
    )
152
+
153
+
154
def create_limiter(limits: Sequence[str] | None = None) -> Limiter:
    """Build a fresh limiter for a production app instance.

    Each call returns a new ``Limiter`` with its own randomly named
    ``memory://`` storage URI, so several FastAPI apps in one process can
    keep independent counters.

    Args:
        limits: Rate limit strings (e.g., ["100/minute"]). When omitted,
            ``DEFAULT_RATE_LIMIT`` is the only limit applied.

    Returns:
        New Limiter instance keyed by ``_get_sender_from_envelope``

    Example:
        >>> limiter = create_limiter(["100/minute"])
        >>> app.state.limiter = limiter
    """
    effective_limits = [DEFAULT_RATE_LIMIT] if limits is None else list(limits)

    # A random suffix gives every limiter a distinct storage URI.
    storage_uri = "memory://" + str(uuid.uuid4())
    return Limiter(
        key_func=_get_sender_from_envelope,
        default_limits=effective_limits,
        storage_uri=storage_uri,
    )
181
+
182
+
183
def rate_limit_handler(request: Request, exc: Exception) -> JSONResponse:
    """Translate a rate-limit failure into a JSON-RPC 2.0 error response.

    The response always has HTTP status 429 and a JSON-RPC ``error`` object.
    For a ``RateLimitExceeded`` exception the error also carries ``data``
    (``retry_after`` and ``limit``) and the response gets a ``Retry-After``
    header. Any other exception type is logged as unexpected and answered
    with the bare error only.

    Args:
        request: FastAPI request object
        exc: Exception raised by the limiter (typed as Exception to match
            the exception-handler signature)

    Returns:
        JSONResponse with JSON-RPC error format and 429 status code

    Example:
        >>> response = rate_limit_handler(request, exc)
        >>> # Returns JSONResponse with status_code=429 and JSON-RPC error
    """
    error: dict[str, Any] = {
        "code": HTTP_TOO_MANY_REQUESTS,
        "message": ERROR_RATE_LIMIT_EXCEEDED,
    }
    extra_headers: dict[str, str] | None = None

    if isinstance(exc, RateLimitExceeded):
        # Use the exception's retry_after when it is present and converts
        # to int; otherwise fall back to 60 seconds.
        raw_retry = getattr(exc, "retry_after", None)
        try:
            retry_after = 60 if raw_retry is None else int(raw_retry)
        except (ValueError, TypeError):
            retry_after = 60

        # Report the triggering limit if the exception exposes one.
        raw_limit = getattr(exc, "limit", None)
        limit_str = DEFAULT_RATE_LIMIT if raw_limit is None else str(raw_limit)

        logger.warning(
            "asap.rate_limit.exceeded",
            sender=_get_sender_from_envelope(request),
            retry_after=retry_after,
            limit=limit_str,
        )

        error["data"] = {
            "retry_after": retry_after,
            "limit": limit_str,
        }
        extra_headers = {"Retry-After": str(retry_after)}
    else:
        # Not the exception type this handler is registered for.
        logger.warning("asap.rate_limit.unexpected_exception", exc_type=type(exc).__name__)

    return JSONResponse(
        status_code=HTTP_TOO_MANY_REQUESTS,
        content={
            "jsonrpc": "2.0",
            "id": getattr(request.state, "request_id", None),
            "error": error,
        },
        headers=extra_headers,
    )
253
+
254
+
48
255
  # HTTP status codes
49
256
  HTTP_UNAUTHORIZED = 401
50
257
  HTTP_FORBIDDEN = 403
258
+ HTTP_TOO_MANY_REQUESTS = 429
51
259
 
52
260
  # Error messages
53
261
  ERROR_AUTH_REQUIRED = "Authentication required"
54
- ERROR_INVALID_TOKEN = "Invalid authentication token"
262
+ # nosec B105: This is an error message constant, not a hardcoded password
263
+ ERROR_INVALID_TOKEN = "Invalid authentication token" # nosec B105
55
264
  ERROR_SENDER_MISMATCH = "Sender does not match authenticated identity"
265
+ ERROR_RATE_LIMIT_EXCEEDED = "Rate limit exceeded"
56
266
 
57
267
 
58
268
  class TokenValidator(Protocol):
@@ -289,12 +499,12 @@ class AuthenticationMiddleware:
289
499
  agent_id = self.validator(token)
290
500
 
291
501
  if agent_id is None:
292
- # Log token hash instead of prefix to avoid exposing token data
293
- token_hash = hashlib.sha256(token.encode()).hexdigest()[:16]
502
+ # Log sanitized token to avoid exposing full token data
503
+ token_prefix = sanitize_token(token)
294
504
  logger.warning(
295
505
  "asap.auth.invalid_token",
296
506
  manifest_id=self.manifest.id,
297
- token_hash=token_hash,
507
+ token_prefix=token_prefix,
298
508
  )
299
509
  raise HTTPException(
300
510
  status_code=HTTP_UNAUTHORIZED,
@@ -357,3 +567,89 @@ class AuthenticationMiddleware:
357
567
  "asap.auth.sender_verified",
358
568
  authenticated_agent=authenticated_agent_id,
359
569
  )
570
+
571
+
572
class SizeLimitMiddleware(BaseHTTPMiddleware):
    """Reject oversized requests based on their declared Content-Length.

    Only the ``Content-Length`` header is inspected here, before the request
    reaches any route. Requests without the header, or with a value that is
    not an integer, are passed through unchanged; per this module's design
    notes, the route handler is expected to enforce the limit on the actual
    body.

    Attributes:
        max_size: Maximum allowed request size in bytes

    Example:
        >>> from asap.transport.middleware import SizeLimitMiddleware
        >>> app.add_middleware(SizeLimitMiddleware, max_size=10 * 1024 * 1024)
    """

    def __init__(self, app: Any, max_size: int) -> None:
        """Store the size limit and wrap ``app``.

        Args:
            app: The ASGI application
            max_size: Maximum allowed request size in bytes

        Raises:
            ValueError: If max_size is less than 1
        """
        if max_size < 1:
            raise ValueError(f"max_size must be >= 1, got {max_size}")
        super().__init__(app)
        self.max_size = max_size

    async def dispatch(
        self, request: Request, call_next: Callable[[Request], Awaitable[Any]]
    ) -> Any:
        """Short-circuit with HTTP 413 when the declared size is too large.

        Args:
            request: FastAPI request object
            call_next: Next middleware or route handler

        Returns:
            A 413 JSONResponse if Content-Length exceeds ``max_size``,
            otherwise the response produced by ``call_next``
        """
        declared: int | None = None
        header_value = request.headers.get("content-length")
        if header_value:
            try:
                declared = int(header_value)
            except ValueError:
                # Unparseable header: leave validation to the route handler.
                declared = None

        if declared is not None and declared > self.max_size:
            logger.warning(
                "asap.request.size_exceeded",
                content_length=declared,
                max_size=self.max_size,
            )
            # Answer directly; no route handler has run at this point.
            return JSONResponse(
                status_code=413,
                content={
                    "detail": f"Request size ({declared} bytes) exceeds maximum ({self.max_size} bytes)"
                },
            )

        return await call_next(request)
642
+
643
+
644
+ # Public API: authentication, request-size limiting and rate-limiting components
645
+ __all__ = [
646
+ "AuthenticationMiddleware",
647
+ "BearerTokenValidator",
648
+ "TokenValidator",
649
+ "SizeLimitMiddleware",
650
+ "limiter",
651
+ "rate_limit_handler",
652
+ "create_limiter",
653
+ "create_test_limiter",
654
+ "_get_sender_from_envelope",
655
+ ]