dataenginex 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. dataenginex/README.md +35 -0
  2. dataenginex/RELEASE_NOTES.md +38 -0
  3. dataenginex/__init__.py +16 -0
  4. dataenginex/api/__init__.py +11 -0
  5. dataenginex/api/auth.py +173 -0
  6. dataenginex/api/errors.py +70 -0
  7. dataenginex/api/health.py +133 -0
  8. dataenginex/api/pagination.py +94 -0
  9. dataenginex/api/rate_limit.py +122 -0
  10. dataenginex/api/routers/__init__.py +1 -0
  11. dataenginex/api/routers/v1.py +113 -0
  12. dataenginex/core/__init__.py +36 -0
  13. dataenginex/core/medallion_architecture.py +414 -0
  14. dataenginex/core/pipeline_config.py +111 -0
  15. dataenginex/core/schemas.py +304 -0
  16. dataenginex/core/validators.py +394 -0
  17. dataenginex/data/__init__.py +22 -0
  18. dataenginex/data/connectors.py +332 -0
  19. dataenginex/data/profiler.py +217 -0
  20. dataenginex/data/registry.py +148 -0
  21. dataenginex/lakehouse/__init__.py +22 -0
  22. dataenginex/lakehouse/catalog.py +145 -0
  23. dataenginex/lakehouse/partitioning.py +99 -0
  24. dataenginex/lakehouse/storage.py +177 -0
  25. dataenginex/middleware/__init__.py +19 -0
  26. dataenginex/middleware/logging_config.py +137 -0
  27. dataenginex/middleware/metrics.py +45 -0
  28. dataenginex/middleware/metrics_middleware.py +61 -0
  29. dataenginex/middleware/request_logging.py +77 -0
  30. dataenginex/middleware/tracing.py +87 -0
  31. dataenginex/ml/__init__.py +28 -0
  32. dataenginex/ml/drift.py +165 -0
  33. dataenginex/ml/registry.py +156 -0
  34. dataenginex/ml/serving.py +141 -0
  35. dataenginex/ml/training.py +205 -0
  36. dataenginex/warehouse/__init__.py +19 -0
  37. dataenginex/warehouse/lineage.py +164 -0
  38. dataenginex/warehouse/transforms.py +206 -0
  39. dataenginex-0.3.4.dist-info/METADATA +66 -0
  40. dataenginex-0.3.4.dist-info/RECORD +41 -0
  41. dataenginex-0.3.4.dist-info/WHEEL +4 -0
dataenginex/README.md ADDED
@@ -0,0 +1,35 @@
1
+ # dataenginex
2
+
3
+ `dataenginex` is the core DataEngineX framework package for building observable, production-ready data and API services.
4
+
5
+ It provides:
6
+ - FastAPI application primitives and API extensions
7
+ - Middleware for structured logging, metrics, and tracing
8
+ - Data quality and validation utilities
9
+ - Lakehouse and warehouse building blocks
10
+ - Reusable ML support modules for model-serving workflows
11
+
12
+ ## Install
13
+
14
+ ```bash
15
+ pip install dataenginex
16
+ ```
17
+
18
+ ## Package Scope
19
+
20
+ This package is the core library from the DEX monorepo.
21
+ `careerdex` and `weatherdex` are maintained in the same repository but are not part of this package release flow.
22
+
23
+ ## Quick Usage
24
+
25
+ ```python
26
+ from dataenginex import __version__
27
+
28
+ print(__version__)
29
+ ```
30
+
31
+ ## Source and Docs
32
+
33
+ - Repository: https://github.com/data-literate/DEX
34
+ - CI/CD guide: `docs/CI_CD.md`
35
+ - Release notes: `packages/dataenginex/src/dataenginex/RELEASE_NOTES.md`
@@ -0,0 +1,38 @@
1
+ # dataenginex Release Notes
2
+
3
+ This document tracks published package releases for `dataenginex` only.
4
+ Only include changes that modify files under `packages/dataenginex/src/dataenginex/**`.
5
+
6
+ ## v0.3.4 - 2026-02-20
7
+
8
+ - Released package version `0.3.4`.
9
+ - Tag: `v0.3.4`
10
+ - Release title: `Release v0.3.4`
11
+ - Changes in this release:
12
+   - Repo hygiene updates after monorepo/package-layout migration.
13
+   - Canonicalized package/app path references across docs and workflow guidance.
14
+   - Removed standalone `careerdex` and `weatherdex` package scaffolds from `packages/`.
15
+   - Updated CI/package validation and project metadata/docs to align with current structure.
16
+
17
+ ## v0.3.3 - 2026-02-16
18
+
19
+ - Released package version `0.3.3`.
20
+ - Tag: `v0.3.3`
21
+ - Release title: `Release v0.3.3`
22
+ - Changes in this release (`packages/dataenginex/src/dataenginex` only):
23
+   - Significant package expansion and refactor across API, core, data, lakehouse, middleware, ML, and warehouse modules.
24
+   - Added/expanded API modules including auth, pagination, rate limiting, and v1 router wiring.
25
+   - Consolidated middleware layout (logging, metrics, tracing) and moved logging config under middleware.
26
+   - Package diff from `v0.2.0` to `v0.3.3`: 39 files changed, 4213 insertions, 247 deletions.
27
+
28
+ ## v0.2.0 - 2026-02-12
29
+
30
+ - Released package version `0.2.0`.
31
+ - Tag: `v0.2.0`
32
+ - Release title: `Release v0.2.0 - Production Hardening`
33
+ - Changes in this release (`packages/dataenginex/src/dataenginex` only):
34
+   - Established core package structure and initial module organization.
35
+   - Added API readiness/health behavior and improved probe handling.
36
+   - Added structured request logging with request ID tracking.
37
+   - Added Prometheus metrics and OpenTelemetry tracing support in package middleware.
38
+   - Added validation and error-handling improvements in API/core paths.
@@ -0,0 +1,16 @@
1
+ """
2
+ DataEngineX (DEX) — Core framework for data engineering projects.
3
+
4
+ Submodules:
5
+ api – FastAPI application, health checks, error handling
6
+ core – Schemas, validators, medallion architecture, pipeline config
7
+ middleware – Logging, metrics, tracing, request middleware
8
+ """
9
+
10
+ from importlib.metadata import PackageNotFoundError, version
11
+
12
+ try:
13
+ __version__ = version("dataenginex")
14
+ except PackageNotFoundError:
15
+ __version__ = "0.3.4"
16
+
@@ -0,0 +1,11 @@
1
+ """Reusable API components (auth, health, errors) for product services."""
2
+
3
+ from .errors import APIHTTPException, ServiceUnavailableError # noqa: F401
4
+ from .health import HealthChecker, HealthStatus # noqa: F401
5
+
6
+ __all__ = [
7
+ "HealthChecker",
8
+ "HealthStatus",
9
+ "APIHTTPException",
10
+ "ServiceUnavailableError",
11
+ ]
@@ -0,0 +1,173 @@
1
+ """
2
+ JWT authentication middleware for DEX API.
3
+
4
+ Provides ``AuthMiddleware``, a Starlette middleware that validates bearer tokens
5
+ signed with HMAC-SHA256 (HS256), plus ``create_token`` / ``decode_token`` helpers.
6
+
7
+ Configuration is via environment variables:
8
+ DEX_JWT_SECRET — HMAC shared secret (required when auth is enabled)
9
+ DEX_JWT_ALGORITHM — Not read by this module; only HS256 is implemented
10
+ DEX_AUTH_ENABLED — "true" to enforce auth (default "false")
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import hmac
16
+ import json
17
+ import os
18
+ import time
19
+ from base64 import urlsafe_b64decode, urlsafe_b64encode
20
+ from dataclasses import dataclass
21
+ from hashlib import sha256
22
+ from typing import Any
23
+
24
+ from fastapi import Request
25
+ from fastapi.responses import JSONResponse
26
+ from loguru import logger
27
+ from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
28
+ from starlette.responses import Response
29
+
30
+ # ---------------------------------------------------------------------------
31
+ # Token helpers (pure-Python HS256 — no external ``pyjwt`` needed)
32
+ # ---------------------------------------------------------------------------
33
+
34
+
35
def _b64url_decode(data: str) -> bytes:
    """Decode an unpadded URL-safe base64 string, as used for JWT segments.

    Parameters
    ----------
    data:
        Base64url text with its trailing ``=`` padding stripped.

    Raises
    ------
    ValueError
        If *data* is not valid base64 (``binascii.Error`` is a ``ValueError``
        subclass).
    """
    # ``-len % 4`` yields exactly the 0-3 pad characters that are missing.
    # The previous ``4 - len % 4`` appended a spurious "====" whenever the
    # input length was already a multiple of four.
    return urlsafe_b64decode(data + "=" * (-len(data) % 4))
38
+
39
+
40
def _b64url_encode(data: bytes) -> str:
    """Return *data* as URL-safe base64 text with trailing ``=`` padding removed."""
    encoded = urlsafe_b64encode(data).decode()
    return encoded.rstrip("=")
42
+
43
+
44
def create_token(payload: dict[str, Any], secret: str, ttl: int = 3600) -> str:
    """Build a signed HS256 JWT from *payload*.

    ``iat`` (now) and ``exp`` (now + *ttl*) are always written into the
    claims, replacing any values of the same name supplied by the caller.

    Parameters
    ----------
    payload:
        Claims dict (e.g. ``{"sub": "user123", "roles": ["admin"]}``).
    secret:
        HMAC shared secret.
    ttl:
        Time-to-live in seconds (default 1 hour).

    Returns
    -------
    str
        The compact ``header.payload.signature`` token.
    """
    issued_at = int(time.time())
    claims = {**payload, "iat": issued_at, "exp": issued_at + ttl}

    header_b64 = _b64url_encode(json.dumps({"alg": "HS256", "typ": "JWT"}).encode())
    # ``default=str`` stringifies claim values json cannot serialise natively.
    claims_b64 = _b64url_encode(json.dumps(claims, default=str).encode())
    signing_input = f"{header_b64}.{claims_b64}"

    mac = hmac.new(secret.encode(), signing_input.encode(), sha256)
    return f"{signing_input}.{_b64url_encode(mac.digest())}"
68
+
69
+
70
def decode_token(token: str, secret: str) -> dict[str, Any]:
    """Decode and verify a HS256 JWT token.

    The signature is always checked with HMAC-SHA256; the token header is
    not consulted, so an attacker-chosen ``alg`` value has no effect.

    Parameters
    ----------
    token:
        Compact ``header.payload.signature`` token.
    secret:
        HMAC shared secret the token must have been signed with.

    Returns
    -------
    dict[str, Any]
        The decoded claims.

    Raises
    ------
    ValueError
        If the token is malformed, the signature does not match, the payload
        is not a JSON object, the ``exp`` claim is not numeric, or the token
        has expired.
    """
    parts = token.split(".")
    if len(parts) != 3:
        raise ValueError("Malformed JWT")

    signing_input = f"{parts[0]}.{parts[1]}"
    expected_sig = hmac.new(secret.encode(), signing_input.encode(), sha256).digest()
    # Invalid base64 raises binascii.Error, which is already a ValueError.
    actual_sig = _b64url_decode(parts[2])

    # Constant-time comparison avoids leaking signature bytes via timing.
    if not hmac.compare_digest(expected_sig, actual_sig):
        raise ValueError("Invalid JWT signature")

    payload = json.loads(_b64url_decode(parts[1]))
    if not isinstance(payload, dict):
        # A correctly signed but non-object payload used to surface as
        # AttributeError, escaping callers that only catch ValueError.
        raise ValueError("JWT payload must be a JSON object")

    exp = payload.get("exp")
    if exp is not None:
        try:
            expires_at = int(exp)
        except (TypeError, ValueError, OverflowError) as exc:
            raise ValueError("Invalid exp claim") from exc
        if expires_at < int(time.time()):
            raise ValueError("Token expired")

    return payload
89
+
90
+
91
+ # ---------------------------------------------------------------------------
92
+ # Dataclass carrying the authenticated user info
93
+ # ---------------------------------------------------------------------------
94
+
95
+
96
@dataclass
class AuthUser:
    """Identity resolved from a verified JWT.

    ``AuthMiddleware`` stores an instance on ``request.state.auth_user``.
    """

    # Value of the ``sub`` claim ("anonymous" when the middleware finds none).
    sub: str
    # Value of the ``roles`` claim (empty list when the middleware finds none).
    roles: list[str]
    # The complete decoded claims dict.
    claims: dict[str, Any]
103
+
104
+
105
+ # ---------------------------------------------------------------------------
106
+ # FastAPI middleware
107
+ # ---------------------------------------------------------------------------
108
+
109
# Request paths served without authentication even when auth is enabled.
# They are matched exactly against ``request.url.path``.
_PUBLIC_PATHS: set[str] = {
    # root and probe/metrics endpoints
    "/", "/health", "/ready", "/startup", "/metrics",
    # API documentation endpoints
    "/docs", "/redoc", "/openapi.json", "/openapi.yaml",
}
121
+
122
+
123
class AuthMiddleware(BaseHTTPMiddleware):
    """Starlette middleware that enforces JWT auth when enabled.

    When ``DEX_AUTH_ENABLED`` is ``"true"`` (case-insensitive), every request
    to a non-public path must carry a valid ``Authorization: Bearer <token>``
    header. The decoded claims are stored on ``request.state.auth_user``.

    Both environment variables are read on every request, so configuration
    changes take effect without rebuilding the middleware stack.
    """

    @staticmethod
    def _unauthorized(message: str) -> JSONResponse:
        """Build a 401 response that advertises the ``Bearer`` scheme."""
        return JSONResponse(
            status_code=401,
            content={"error": "unauthorized", "message": message},
            # HTTP requires a challenge on 401 so clients know how to retry.
            headers={"WWW-Authenticate": "Bearer"},
        )

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Authenticate the request, or short-circuit with an error response.

        Returns the downstream response when auth is disabled, the path is
        public, or the token is valid; a 500 JSON response when auth is
        enabled without a secret; a 401 JSON response otherwise.
        """
        enabled = os.getenv("DEX_AUTH_ENABLED", "false").lower() == "true"
        if not enabled or request.url.path in _PUBLIC_PATHS:
            return await call_next(request)

        secret = os.getenv("DEX_JWT_SECRET", "")
        if not secret:
            logger.error("DEX_AUTH_ENABLED=true but DEX_JWT_SECRET is not set")
            return JSONResponse(
                status_code=500,
                content={"error": "auth_config_error", "message": "Auth secret not configured"},
            )

        # Authentication scheme names are case-insensitive, so "bearer" and
        # "BEARER" must be accepted alongside "Bearer".
        scheme, _, token = request.headers.get("Authorization", "").partition(" ")
        token = token.strip()
        if scheme.lower() != "bearer" or not token:
            return self._unauthorized("Missing bearer token")

        try:
            claims = decode_token(token, secret)
        except ValueError:
            logger.exception("JWT validation failed")
            return self._unauthorized("Invalid or expired authentication token")

        request.state.auth_user = AuthUser(
            sub=claims.get("sub", "anonymous"),
            roles=claims.get("roles", []),
            claims=claims,
        )
        return await call_next(request)
@@ -0,0 +1,70 @@
1
+ """Custom exception types for API error handling."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from fastapi import HTTPException, status
6
+
7
+ from dataenginex.core.schemas import ErrorDetail
8
+
9
+
10
class APIHTTPException(HTTPException):
    """Base HTTP exception carrying a machine-readable code and error details.

    Parameters
    ----------
    status_code:
        HTTP status code of the response.
    message:
        Human-readable message, passed to ``HTTPException`` as ``detail``.
    code:
        Short error identifier (default ``"api_error"``).
    details:
        Optional list of ``ErrorDetail`` entries.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        code: str = "api_error",
        details: list[ErrorDetail] | None = None,
    ) -> None:
        super().__init__(status_code=status_code, detail=message)
        self.code = code
        self.details = details
23
+
24
+
25
class BadRequestError(APIHTTPException):
    """HTTP 400 error (code ``"bad_request"``) for invalid or malformed requests."""

    def __init__(
        self,
        message: str = "Bad request",
        details: list[ErrorDetail] | None = None,
    ) -> None:
        super().__init__(status.HTTP_400_BAD_REQUEST, message, "bad_request", details)
39
+
40
+
41
class NotFoundError(APIHTTPException):
    """HTTP 404 error (code ``"not_found"``) for missing resources."""

    def __init__(
        self,
        message: str = "Resource not found",
        details: list[ErrorDetail] | None = None,
    ) -> None:
        super().__init__(status.HTTP_404_NOT_FOUND, message, "not_found", details)
55
+
56
+
57
class ServiceUnavailableError(APIHTTPException):
    """HTTP 503 error (code ``"service_unavailable"``) for an unavailable dependency."""

    def __init__(
        self,
        message: str = "Service unavailable",
        details: list[ErrorDetail] | None = None,
    ) -> None:
        super().__init__(
            status.HTTP_503_SERVICE_UNAVAILABLE,
            message,
            "service_unavailable",
            details,
        )
@@ -0,0 +1,133 @@
1
+ """Health check utilities for DEX."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import os
7
+ import time
8
+ from dataclasses import dataclass
9
+ from enum import StrEnum
10
+
11
+ import httpx
12
+
13
+
14
+ class HealthStatus(StrEnum):
15
+ """Supported health statuses."""
16
+
17
+ HEALTHY = "healthy"
18
+ DEGRADED = "degraded"
19
+ UNHEALTHY = "unhealthy"
20
+ SKIPPED = "skipped"
21
+
22
+
23
@dataclass(frozen=True)
class ComponentHealth:
    """Immutable result of probing a single dependency."""

    name: str
    status: HealthStatus
    message: str | None = None
    duration_ms: float | None = None

    def to_dict(self) -> dict[str, object | None]:
        """Return the result as a plain dict, with ``status`` as its string value."""
        return dict(
            name=self.name,
            status=self.status.value,
            message=self.message,
            duration_ms=self.duration_ms,
        )
37
+
38
+
39
class HealthChecker:
    """Runs health checks for DEX dependencies.

    Each dependency is configured through environment variables and is
    reported as ``SKIPPED`` when its variables are unset.

    Parameters
    ----------
    timeout_seconds:
        Per-probe timeout applied to TCP connects and to the HTTP request.
    """

    def __init__(self, timeout_seconds: float = 1.0) -> None:
        self.timeout_seconds = timeout_seconds

    async def check_all(self) -> list[ComponentHealth]:
        """Probe database, cache and external API; results keep that order."""
        # The probes are independent, so run them concurrently: the endpoint
        # then costs one timeout in the worst case instead of three.
        results = await asyncio.gather(
            self.check_database(),
            self.check_cache(),
            self.check_external_api(),
        )
        return list(results)

    def overall_status(self, components: list[ComponentHealth]) -> HealthStatus:
        """Collapse component results: UNHEALTHY beats DEGRADED beats HEALTHY.

        ``SKIPPED`` components do not affect the outcome.
        """
        if any(c.status == HealthStatus.UNHEALTHY for c in components):
            return HealthStatus.UNHEALTHY
        if any(c.status == HealthStatus.DEGRADED for c in components):
            return HealthStatus.DEGRADED
        return HealthStatus.HEALTHY

    async def check_database(self) -> ComponentHealth:
        """TCP-probe the database at ``DEX_DB_HOST`` / ``DEX_DB_PORT``."""
        return await self._check_tcp_dependency("database", "DEX_DB_HOST", "DEX_DB_PORT")

    async def check_cache(self) -> ComponentHealth:
        """TCP-probe the cache at ``DEX_CACHE_HOST`` / ``DEX_CACHE_PORT``."""
        return await self._check_tcp_dependency("cache", "DEX_CACHE_HOST", "DEX_CACHE_PORT")

    async def check_external_api(self) -> ComponentHealth:
        """GET ``DEX_EXTERNAL_API_URL``; any status below 500 counts as healthy."""
        url = os.getenv("DEX_EXTERNAL_API_URL")
        if not url:
            return ComponentHealth(
                name="external_api",
                status=HealthStatus.SKIPPED,
                message="external API not configured",
            )

        start = time.perf_counter()
        timeout = httpx.Timeout(self.timeout_seconds)
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.get(url)
            ok = response.status_code < 500
            message = f"status_code={response.status_code}"
        except httpx.HTTPError as exc:
            ok = False
            message = f"error={exc.__class__.__name__}"
        duration_ms = (time.perf_counter() - start) * 1000
        return ComponentHealth(
            name="external_api",
            status=HealthStatus.HEALTHY if ok else HealthStatus.UNHEALTHY,
            message=message,
            duration_ms=round(duration_ms, 2),
        )

    async def _check_tcp_dependency(
        self, name: str, host_var: str, port_var: str
    ) -> ComponentHealth:
        """Shared TCP reachability probe for a host/port pair read from the environment."""
        host = os.getenv(host_var)
        port = os.getenv(port_var)
        if not host or not port:
            return ComponentHealth(
                name=name,
                status=HealthStatus.SKIPPED,
                message=f"{name} not configured",
            )

        # A mistyped port must surface as an unhealthy component rather than
        # as an exception that takes down the whole health endpoint.
        try:
            port_number = int(port)
        except ValueError:
            port_number = -1
        if not 0 < port_number <= 65535:
            return ComponentHealth(
                name=name,
                status=HealthStatus.UNHEALTHY,
                message="error=InvalidPort",
            )

        start = time.perf_counter()
        ok, message = await self._tcp_check(host, port_number)
        duration_ms = (time.perf_counter() - start) * 1000
        return ComponentHealth(
            name=name,
            status=HealthStatus.HEALTHY if ok else HealthStatus.UNHEALTHY,
            message=message,
            duration_ms=round(duration_ms, 2),
        )

    async def _tcp_check(self, host: str, port: int) -> tuple[bool, str]:
        """Return ``(True, "reachable")`` if a TCP connection to host:port opens in time."""
        try:
            _reader, writer = await asyncio.wait_for(
                asyncio.open_connection(host, port), timeout=self.timeout_seconds
            )
        except (TimeoutError, OSError) as exc:
            return False, f"error={exc.__class__.__name__}"

        # Close the probe connection explicitly; previously the socket was
        # left open until garbage collection on every health check.
        writer.close()
        try:
            await asyncio.wait_for(writer.wait_closed(), timeout=self.timeout_seconds)
        except (TimeoutError, OSError):
            # The connect already succeeded; a failed close does not change that.
            pass
        return True, "reachable"
@@ -0,0 +1,94 @@
1
+ """
2
+ Cursor-based pagination utilities for the DEX API.
3
+
4
+ Usage::
5
+
6
+ from dataenginex.api.pagination import PaginatedResponse, paginate
7
+
8
+ @app.get("/api/v1/items")
9
+ def list_items(cursor: str | None = None, limit: int = 20):
10
+ all_items = get_all_items()
11
+ return paginate(all_items, cursor=cursor, limit=limit)
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import base64
17
+ import json
18
+ from typing import Any, TypeVar
19
+
20
+ from pydantic import BaseModel, Field
21
+
22
+ T = TypeVar("T")
23
+
24
+
25
class PaginationMeta(BaseModel):
    """Paging details that accompany a page of results."""

    total: int = Field(description="Total number of items")
    limit: int = Field(description="Page size")
    has_next: bool = Field(description="Whether more items exist")
    next_cursor: str | None = Field(default=None, description="Opaque cursor for next page")
    has_previous: bool = Field(False, description="Whether previous items exist")
33
+
34
+
35
class PaginatedResponse(BaseModel):
    """Response envelope: one page of items plus its pagination metadata."""

    # Items on the current page; defaults to an empty list.
    data: list[Any] = Field(default_factory=list)
    # Required paging details for the page in ``data``.
    pagination: PaginationMeta
40
+
41
+
42
def encode_cursor(offset: int) -> str:
    """Wrap *offset* as ``{"o": offset}`` JSON and return it base64url-encoded."""
    raw = json.dumps({"o": offset}).encode()
    return base64.urlsafe_b64encode(raw).decode()
45
+
46
+
47
def decode_cursor(cursor: str) -> int:
    """Decode a cursor back to a non-negative integer offset.

    Cursors are client-supplied, so every failure mode is mapped to ``0``
    (the first page) instead of raising.

    Parameters
    ----------
    cursor:
        Opaque cursor previously produced by ``encode_cursor``.

    Returns
    -------
    int
        The decoded offset, or ``0`` if the cursor is malformed or encodes a
        negative offset.
    """
    try:
        data = json.loads(base64.urlsafe_b64decode(cursor))
        offset = int(data.get("o", 0))
    except Exception:
        # Deliberately broad: untrusted input can fail in base64, UTF-8 or
        # JSON decoding, in attribute lookup, or in the int conversion.
        return 0
    # A forged negative offset would make callers slice from the end of the
    # list; treat it like any other invalid cursor.
    return max(offset, 0)
54
+
55
+
56
def paginate(
    items: list[Any],
    *,
    cursor: str | None = None,
    limit: int = 20,
    max_limit: int = 100,
) -> PaginatedResponse:
    """Slice *items* and return a ``PaginatedResponse``.

    Parameters
    ----------
    items:
        Full list of items to paginate.
    cursor:
        Opaque cursor from a previous response (or *None* for the first page).
        An undecodable cursor is treated as the first page.
    limit:
        Number of items per page; clamped to ``1..max_limit``.
    max_limit:
        Hard ceiling on *limit* to prevent abuse.
    """
    # Clamp to the ceiling first, then to 1, so the page size stays positive
    # even when ``max_limit`` is below 1. A zero-size page would hand out a
    # ``next_cursor`` that never advances.
    limit = max(1, min(limit, max_limit))
    # The cursor is client-controlled; a negative offset would slice from the
    # end of the list, so clamp it to the start.
    offset = max(0, decode_cursor(cursor)) if cursor else 0
    total = len(items)

    end = offset + limit
    has_next = end < total

    return PaginatedResponse(
        data=items[offset:end],
        pagination=PaginationMeta(
            total=total,
            limit=limit,
            has_next=has_next,
            next_cursor=encode_cursor(end) if has_next else None,
            has_previous=offset > 0,
        ),
    )
@@ -0,0 +1,122 @@
1
+ """
2
+ Rate-limiting middleware for DEX API.
3
+
4
+ Implements a token-bucket algorithm per client IP. Configuration:
5
+
6
+ DEX_RATE_LIMIT_ENABLED — "true" to enable (default "false")
7
+ DEX_RATE_LIMIT_RPM — Requests per minute per IP (default 60)
8
+ DEX_RATE_LIMIT_BURST — Maximum burst size (default 10)
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import os
14
+ import time
15
+ from dataclasses import dataclass
16
+ from typing import Any
17
+
18
+ from fastapi import Request
19
+ from fastapi.responses import JSONResponse
20
+ from loguru import logger
21
+ from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
22
+ from starlette.responses import Response
23
+
24
+
25
@dataclass
class _Bucket:
    """Token-bucket state for one client."""

    tokens: float
    last_refill: float
    capacity: float
    refill_rate: float  # tokens per second

    def consume(self) -> bool:
        """Refill for the time elapsed, then try to spend one token.

        Returns ``True`` if a token was available (and spent), else ``False``.
        """
        current = time.monotonic()
        refilled = self.tokens + (current - self.last_refill) * self.refill_rate
        self.last_refill = current
        # Never accumulate beyond the bucket's capacity.
        self.tokens = min(self.capacity, refilled)
        if self.tokens < 1.0:
            return False
        self.tokens -= 1.0
        return True
43
+
44
+
45
class RateLimiter:
    """In-memory token-bucket rate limiter.

    One bucket is kept per client id. Buckets that have been idle long enough
    to be completely refilled are dropped periodically from ``allow``; such a
    bucket behaves exactly like a freshly created one, so eviction bounds
    memory without changing any limiting decision.

    Parameters
    ----------
    requests_per_minute:
        Sustained request rate per client.
    burst:
        Maximum instantaneous burst size.
    """

    # Minimum number of seconds between two automatic eviction sweeps.
    _SWEEP_INTERVAL_SECONDS = 60.0

    def __init__(self, requests_per_minute: int = 60, burst: int = 10) -> None:
        self.rpm = requests_per_minute
        self.burst = burst
        self._buckets: dict[str, _Bucket] = {}
        self._last_sweep = time.monotonic()

    def allow(self, client_id: str) -> bool:
        """Return ``True`` if *client_id* may make a request right now."""
        self._sweep_if_due()
        bucket = self._buckets.get(client_id)
        if bucket is None:
            # New clients start with a full bucket.
            bucket = _Bucket(
                tokens=float(self.burst),
                last_refill=time.monotonic(),
                capacity=float(self.burst),
                refill_rate=self.rpm / 60.0,
            )
            self._buckets[client_id] = bucket
        return bucket.consume()

    def get_stats(self) -> dict[str, Any]:
        """Return the configured limits and the number of tracked clients."""
        return {
            "rpm": self.rpm,
            "burst": self.burst,
            "active_clients": len(self._buckets),
        }

    def cleanup(self, max_age_seconds: float = 300.0) -> int:
        """Evict buckets idle for more than *max_age_seconds*.

        Returns the number of buckets removed.
        """
        now = time.monotonic()
        stale = [k for k, b in self._buckets.items() if now - b.last_refill > max_age_seconds]
        for k in stale:
            del self._buckets[k]
        return len(stale)

    def _sweep_if_due(self) -> None:
        """Drop fully refilled buckets at most once per sweep interval."""
        now = time.monotonic()
        if now - self._last_sweep < self._SWEEP_INTERVAL_SECONDS:
            return
        self._last_sweep = now
        if self.rpm <= 0:
            # Buckets never refill, so none becomes equivalent to a fresh one.
            return
        # After burst / (rpm / 60) idle seconds a bucket is back at capacity.
        self.cleanup(max_age_seconds=self.burst * 60.0 / self.rpm)
87
+
88
+
89
# Request paths that bypass rate limiting (matched exactly against
# ``request.url.path``).
_EXEMPT_PATHS: set[str] = {
    "/health",
    "/ready",
    "/startup",
    "/metrics",
}
91
+
92
+
93
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Starlette middleware applying per-IP rate limiting.

    Configuration is read once, at construction time, from
    ``DEX_RATE_LIMIT_ENABLED``, ``DEX_RATE_LIMIT_RPM`` and
    ``DEX_RATE_LIMIT_BURST``.
    """

    def __init__(self, app: Any) -> None:  # noqa: ANN401
        super().__init__(app)
        rpm = int(os.getenv("DEX_RATE_LIMIT_RPM", "60"))
        burst = int(os.getenv("DEX_RATE_LIMIT_BURST", "10"))
        self._limiter = RateLimiter(requests_per_minute=rpm, burst=burst)
        self._enabled = os.getenv("DEX_RATE_LIMIT_ENABLED", "false").lower() == "true"

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Forward the request, or answer 429 when the client is over its limit."""
        if not self._enabled or request.url.path in _EXEMPT_PATHS:
            return await call_next(request)

        # Clients without a peer address share the single "unknown" bucket.
        client_ip = request.client.host if request.client else "unknown"
        if self._limiter.allow(client_ip):
            return await call_next(request)

        # loguru formats messages with ``str.format`` braces; the previous
        # "%s" placeholder was emitted literally and the IP was never logged.
        logger.warning("Rate limit exceeded for {}", client_ip)
        return JSONResponse(
            status_code=429,
            content={
                "error": "rate_limit_exceeded",
                "message": "Too many requests — please slow down",
            },
            headers={"Retry-After": "60"},
        )