vega-framework 0.1.35__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vega/cli/commands/add.py +9 -10
- vega/cli/commands/generate.py +15 -15
- vega/cli/commands/init.py +9 -8
- vega/cli/commands/web.py +8 -7
- vega/cli/main.py +4 -4
- vega/cli/scaffolds/__init__.py +6 -2
- vega/cli/scaffolds/vega_web.py +109 -0
- vega/cli/templates/__init__.py +34 -8
- vega/cli/templates/components.py +29 -13
- vega/cli/templates/project/ARCHITECTURE.md.j2 +13 -13
- vega/cli/templates/project/README.md.j2 +5 -5
- vega/cli/templates/web/app.py.j2 +5 -5
- vega/cli/templates/web/health_route.py.j2 +2 -2
- vega/cli/templates/web/main.py.j2 +2 -3
- vega/cli/templates/web/middleware.py.j2 +3 -3
- vega/cli/templates/web/router.py.j2 +2 -2
- vega/cli/templates/web/routes_init.py.j2 +3 -3
- vega/cli/templates/web/routes_init_autodiscovery.py.j2 +2 -2
- vega/cli/templates/web/users_route.py.j2 +2 -2
- vega/discovery/routes.py +13 -13
- vega/web/__init__.py +100 -0
- vega/web/application.py +234 -0
- vega/web/builtin_middlewares.py +288 -0
- vega/web/exceptions.py +151 -0
- vega/web/middleware.py +185 -0
- vega/web/request.py +120 -0
- vega/web/response.py +220 -0
- vega/web/route_middleware.py +266 -0
- vega/web/router.py +350 -0
- vega/web/routing.py +347 -0
- {vega_framework-0.1.35.dist-info → vega_framework-0.2.1.dist-info}/METADATA +10 -9
- {vega_framework-0.1.35.dist-info → vega_framework-0.2.1.dist-info}/RECORD +35 -24
- {vega_framework-0.1.35.dist-info → vega_framework-0.2.1.dist-info}/WHEEL +0 -0
- {vega_framework-0.1.35.dist-info → vega_framework-0.2.1.dist-info}/entry_points.txt +0 -0
- {vega_framework-0.1.35.dist-info → vega_framework-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,288 @@
|
|
1
|
+
"""Built-in route middleware implementations for Vega Web Framework"""
|
2
|
+
|
3
|
+
import time
|
4
|
+
import logging
|
5
|
+
from typing import Optional
|
6
|
+
|
7
|
+
from .route_middleware import RouteMiddleware, MiddlewarePhase
|
8
|
+
from .request import Request
|
9
|
+
from .response import Response, JSONResponse
|
10
|
+
from .exceptions import HTTPException, status
|
11
|
+
|
12
|
+
|
13
|
+
class AuthMiddleware(RouteMiddleware):
    """
    Authentication middleware - validates Authorization header.

    Executes BEFORE the handler. Every rejection is a 401 JSON response that
    carries a ``WWW-Authenticate`` challenge naming the expected scheme, as
    HTTP requires for 401 responses.

    NOTE: the token check is a placeholder. Any well-formed
    ``<scheme> <token>`` header is accepted unless the token is the literal
    string ``"invalid"``, and ``request.state.user_id`` is always set to the
    fixed value ``"user_from_token"``. Override ``before`` with real token
    verification before relying on this for authentication.

    Args:
        header_name: Name of the header to check (default: "Authorization")
        scheme: Expected auth scheme, compared case-insensitively
            (default: "Bearer")

    Example:
        @router.get("/protected")
        @middleware(AuthMiddleware())
        async def protected_route():
            return {"data": "secret"}
    """

    def __init__(self, header_name: str = "Authorization", scheme: str = "Bearer"):
        super().__init__(phase=MiddlewarePhase.BEFORE)
        self.header_name = header_name
        self.scheme = scheme

    def _unauthorized(self, detail: str) -> Response:
        """Build a 401 JSON response that includes the WWW-Authenticate challenge."""
        return JSONResponse(
            content={"detail": detail},
            status_code=status.HTTP_401_UNAUTHORIZED,
            headers={"WWW-Authenticate": self.scheme},
        )

    async def before(self, request: Request) -> Optional[Response]:
        """
        Check for a valid authentication token.

        On success, stores ``user_id`` and ``token`` on ``request.state``.

        Returns:
            ``None`` to continue to the handler, or a 401 ``JSONResponse``
            that short-circuits the request.
        """
        # Looked up by lower-cased name; assumes request.headers matches
        # case-insensitively (as Starlette's Headers does) — TODO confirm.
        auth_header = request.headers.get(self.header_name.lower())
        if not auth_header:
            return self._unauthorized("Missing authentication credentials")

        # Expect exactly "<scheme> <token>".
        parts = auth_header.split()
        if len(parts) != 2 or parts[0].lower() != self.scheme.lower():
            return self._unauthorized(
                f"Invalid authentication scheme. Expected: {self.scheme}"
            )

        token = parts[1]

        # TODO: Validate token here (this is a simple example)
        if token == "invalid":
            return self._unauthorized("Invalid or expired token")

        # Store user info in request state for handler to use
        request.state.user_id = "user_from_token"
        request.state.token = token

        return None  # Continue to handler
|
69
|
+
|
70
|
+
|
71
|
+
class TimingMiddleware(RouteMiddleware):
    """
    Request timing middleware - measures execution time.

    Executes BOTH before and after the handler. The elapsed time is logged on
    the ``vega.web.timing`` logger. When the response exposes ``headers``, the
    time is also returned in an ``X-Process-Time`` header (seconds, three
    decimals).

    Durations are measured with ``time.perf_counter()``, a monotonic clock, so
    system clock adjustments cannot yield negative or skewed values.
    ``request.state.start_time`` is still set to the wall-clock ``time.time()``
    value for any code that reads it.

    Example:
        @router.get("/slow-operation")
        @middleware(TimingMiddleware())
        async def slow_operation():
            await asyncio.sleep(1)
            return {"status": "done"}
    """

    def __init__(self):
        super().__init__(phase=MiddlewarePhase.BOTH)
        self.logger = logging.getLogger("vega.web.timing")

    async def before(self, request: Request) -> Optional[Response]:
        """Record the wall-clock and monotonic start times on the request state."""
        request.state.start_time = time.time()
        # Monotonic counterpart actually used to compute the duration.
        request.state.vega_timing_start = time.perf_counter()
        return None

    async def after(self, request: Request, response: Response) -> Response:
        """Calculate and log execution time, then expose it as a response header."""
        start = getattr(request.state, "vega_timing_start", None)
        if start is None:
            # before() did not run for this request: nothing to measure.
            return response

        duration = time.perf_counter() - start
        self.logger.info(
            "%s %s completed in %.3fs", request.method, request.url.path, duration
        )

        # Add timing header to response
        if hasattr(response, "headers"):
            response.headers["X-Process-Time"] = f"{duration:.3f}"

        return response
|
107
|
+
|
108
|
+
|
109
|
+
class CacheControlMiddleware(RouteMiddleware):
    """
    Attach a ``Cache-Control`` header to a route's responses.

    Runs in the AFTER phase, once the handler has produced a response. The
    header value is ``"<public|private>, max-age=<max_age>"``. Responses
    without a ``headers`` attribute are passed through untouched.

    Args:
        max_age: Cache lifetime in seconds (default: 300)
        public: ``True`` for a shared/public cache directive, ``False`` for
            ``private`` (default: True)

    Example:
        @router.get("/static-data")
        @middleware(CacheControlMiddleware(max_age=3600))
        async def get_static_data():
            return {"data": "rarely changes"}
    """

    def __init__(self, max_age: int = 300, public: bool = True):
        super().__init__(phase=MiddlewarePhase.AFTER)
        self.max_age = max_age
        self.public = public

    async def after(self, request: Request, response: Response) -> Response:
        """Stamp the configured Cache-Control directive onto *response*."""
        if not hasattr(response, "headers"):
            return response

        if self.public:
            visibility = "public"
        else:
            visibility = "private"

        response.headers["Cache-Control"] = f"{visibility}, max-age={self.max_age}"
        return response
|
140
|
+
|
141
|
+
|
142
|
+
class CORSMiddleware(RouteMiddleware):
    """
    CORS middleware for specific routes.

    Executes AFTER the handler to add CORS headers.

    If ``"*"`` is configured or the request's origin is in ``allow_origins``,
    the request's ``Origin`` value is echoed back in
    ``Access-Control-Allow-Origin`` (``"*"`` when the request sent no
    ``Origin``). The allowed methods and headers are sent alongside it.

    Because these headers depend on the ``Origin`` request header,
    ``Vary: Origin`` is appended whenever the request carried one. This keeps
    shared caches from serving one origin's CORS headers to another.

    Args:
        allow_origins: Allowed origins; ``None`` means ``["*"]`` (any). An
            explicit empty list allows no origin.
        allow_methods: Allowed methods; ``None`` means
            ``["GET", "POST", "PUT", "DELETE", "OPTIONS"]``.
        allow_headers: Allowed request headers; ``None`` means ``["*"]``.

    Example:
        @router.get("/public-api/data")
        @middleware(CORSMiddleware(allow_origins=["*"]))
        async def public_data():
            return {"data": "public"}
    """

    def __init__(
        self,
        allow_origins: Optional[list] = None,
        allow_methods: Optional[list] = None,
        allow_headers: Optional[list] = None,
    ):
        super().__init__(phase=MiddlewarePhase.AFTER)
        # Only ``None`` selects a default. ``value or default`` would silently
        # turn an explicit empty allow-list into "allow everything".
        self.allow_origins = self._as_list(allow_origins, ["*"])
        self.allow_methods = self._as_list(
            allow_methods, ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
        )
        self.allow_headers = self._as_list(allow_headers, ["*"])

    @staticmethod
    def _as_list(value, default: list) -> list:
        """Return a private list copy of *value*, or of *default* when it is None."""
        if value is None:
            return list(default)
        if isinstance(value, str):
            # A bare string is one entry, not a sequence of characters.
            return [value]
        return list(value)

    @staticmethod
    def _append_vary_origin(response: Response) -> None:
        """Add ``Origin`` to the response's ``Vary`` header if it is not listed yet."""
        existing = response.headers.get("Vary")
        if not existing:
            response.headers["Vary"] = "Origin"
            return
        tokens = {token.strip().lower() for token in existing.split(",")}
        if "origin" not in tokens and "*" not in tokens:
            response.headers["Vary"] = f"{existing}, Origin"

    async def after(self, request: Request, response: Response) -> Response:
        """Add CORS headers to response"""
        if not hasattr(response, "headers"):
            return response

        request_origin = request.headers.get("origin")
        origin = "*" if request_origin is None else request_origin

        if "*" in self.allow_origins or origin in self.allow_origins:
            response.headers["Access-Control-Allow-Origin"] = origin
            response.headers["Access-Control-Allow-Methods"] = ", ".join(
                self.allow_methods
            )
            response.headers["Access-Control-Allow-Headers"] = ", ".join(
                self.allow_headers
            )

        if request_origin is not None:
            self._append_vary_origin(response)

        return response
|
186
|
+
|
187
|
+
|
188
|
+
class RateLimitMiddleware(RouteMiddleware):
    """
    Simple in-memory, per-client-IP, sliding-window rate limiting middleware.

    Executes BEFORE the handler. State lives on this middleware instance, in
    process memory.

    Timestamps come from ``time.monotonic()``, so wall-clock adjustments cannot
    shrink or stretch the window. Clients with no request left inside the
    window are evicted at most once per window, so ``self.requests`` does not
    grow without bound.

    Args:
        max_requests: Maximum requests allowed per client within the window
        window_seconds: Length of the sliding window in seconds

    Example:
        @router.post("/expensive-operation")
        @middleware(RateLimitMiddleware(max_requests=10, window_seconds=60))
        async def expensive_op():
            return {"status": "processing"}
    """

    def __init__(self, max_requests: int = 100, window_seconds: int = 60):
        super().__init__(phase=MiddlewarePhase.BEFORE)
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.requests = {}  # IP -> deque of monotonic timestamps, oldest first
        self._last_sweep = time.monotonic()

    def _evict_idle_clients(self, cutoff: float) -> None:
        """Forget clients whose newest recorded request is at or before *cutoff*."""
        idle = [
            ip
            for ip, stamps in self.requests.items()
            if not stamps or stamps[-1] <= cutoff
        ]
        for ip in idle:
            del self.requests[ip]

    async def before(self, request: Request) -> Optional[Response]:
        """
        Check rate limit.

        Returns:
            ``None`` when the client is under its limit (the request is then
            recorded), otherwise a 429 ``JSONResponse``.
        """
        import math
        from collections import deque

        client_ip = request.client.host if request.client else "unknown"
        now = time.monotonic()
        cutoff = now - self.window_seconds

        # Full sweep at most once per window: O(clients) amortised over requests.
        if now - self._last_sweep >= self.window_seconds:
            self._evict_idle_clients(cutoff)
            self._last_sweep = now

        stamps = self.requests.get(client_ip)
        if stamps is None:
            stamps = self.requests[client_ip] = deque()
        # Drop timestamps that fell out of the window (oldest sit on the left).
        while stamps and stamps[0] <= cutoff:
            stamps.popleft()

        if len(stamps) >= self.max_requests:
            # Seconds until the oldest in-window request expires (at least 1).
            if stamps:
                retry_after = max(1, math.ceil(stamps[0] + self.window_seconds - now))
            else:
                retry_after = self.window_seconds
            return JSONResponse(
                content={
                    "detail": f"Rate limit exceeded. Max {self.max_requests} requests per {self.window_seconds}s"
                },
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                headers={
                    "Retry-After": str(retry_after),
                    "X-RateLimit-Limit": str(self.max_requests),
                    "X-RateLimit-Remaining": "0",
                },
            )

        # Record request
        stamps.append(now)
        return None
|
246
|
+
|
247
|
+
|
248
|
+
class LoggingMiddleware(RouteMiddleware):
    """
    Log a line when a request enters a route and another when it leaves.

    Runs in BOTH phases.
    - Before the handler: ``→ METHOD PATH from CLIENT_HOST``, where the host
      is ``'unknown'`` when the request has no client.
    - After the handler: ``← METHOD PATH [STATUS]``.

    Args:
        logger_name: Name of the ``logging`` logger to write to
            (default: "vega.web.routes")

    Example:
        @router.post("/important-action")
        @middleware(LoggingMiddleware())
        async def important_action():
            return {"status": "done"}
    """

    def __init__(self, logger_name: str = "vega.web.routes"):
        super().__init__(phase=MiddlewarePhase.BOTH)
        self.logger = logging.getLogger(logger_name)

    async def before(self, request: Request) -> Optional[Response]:
        """Emit the incoming-request log line; never short-circuits."""
        client = request.client
        source = client.host if client else 'unknown'
        self.logger.info(f"→ {request.method} {request.url.path} from {source}")
        return None

    async def after(self, request: Request, response: Response) -> Response:
        """Emit the outgoing-response log line and hand the response back unchanged."""
        code = response.status_code
        self.logger.info(f"← {request.method} {request.url.path} [{code}]")
        return response
|
279
|
+
|
280
|
+
|
281
|
+
# Public API of this module: the built-in route-level middlewares, in
# declaration order.
__all__ = [
    "AuthMiddleware", "TimingMiddleware", "CacheControlMiddleware",
    "CORSMiddleware", "RateLimitMiddleware", "LoggingMiddleware",
]
|
vega/web/exceptions.py
ADDED
@@ -0,0 +1,151 @@
|
|
1
|
+
"""HTTP exceptions and status codes for Vega Web Framework"""
|
2
|
+
|
3
|
+
from typing import Any, Dict, Optional
|
4
|
+
|
5
|
+
|
6
|
+
class HTTPException(Exception):
    """
    HTTP exception that can be raised to return an HTTP error response.

    Compatible with FastAPI's HTTPException API for easy migration. As in
    Starlette/FastAPI, omitting ``detail`` (or passing ``None``) uses the
    standard reason phrase for ``status_code`` (e.g. 404 -> "Not Found"). For
    codes unknown to ``http.HTTPStatus``, ``detail`` stays ``None``.

    Args:
        status_code: HTTP status code
        detail: Error message or detail object
        headers: Optional HTTP headers to include in the response

    Example:
        raise HTTPException(status_code=404, detail="User not found")
        raise HTTPException(
            status_code=401,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"}
        )
    """

    def __init__(
        self,
        status_code: int,
        detail: Any = None,
        headers: Optional[Dict[str, str]] = None
    ) -> None:
        if detail is None:
            detail = self._default_detail(status_code)
        self.status_code = status_code
        self.detail = detail
        self.headers = headers
        super().__init__(detail)

    @staticmethod
    def _default_detail(status_code: int) -> Optional[str]:
        """Return the standard reason phrase for *status_code*, or None if unknown."""
        from http import HTTPStatus

        try:
            return HTTPStatus(status_code).phrase
        except (ValueError, TypeError):
            # Non-standard or malformed code: no canonical phrase to use.
            return None

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}"
            f"(status_code={self.status_code!r}, detail={self.detail!r})"
        )
|
39
|
+
|
40
|
+
|
41
|
+
class ValidationError(HTTPException):
    """
    Signal that the incoming request failed validation.

    Always carries status code 422 (Unprocessable Entity).

    Args:
        detail: Message or structured detail (default: "Validation Error")
        headers: Optional extra response headers
    """

    def __init__(self, detail: Any = "Validation Error", headers: Optional[Dict[str, str]] = None):
        super().__init__(422, detail, headers)
|
46
|
+
|
47
|
+
|
48
|
+
class NotFoundError(HTTPException):
    """
    Signal that the requested resource does not exist.

    Always carries status code 404 (Not Found).

    Args:
        detail: Message or structured detail (default: "Not Found")
        headers: Optional extra response headers
    """

    def __init__(self, detail: Any = "Not Found", headers: Optional[Dict[str, str]] = None):
        code = 404
        super().__init__(code, detail, headers)
|
53
|
+
|
54
|
+
|
55
|
+
class UnauthorizedError(HTTPException):
    """
    Signal that the request lacks valid authentication.

    Always carries status code 401 (Unauthorized). Pass a
    ``WWW-Authenticate`` entry in ``headers`` to include the challenge.

    Args:
        detail: Message or structured detail (default: "Unauthorized")
        headers: Optional extra response headers
    """

    def __init__(self, detail: Any = "Unauthorized", headers: Optional[Dict[str, str]] = None):
        super().__init__(401, detail, headers)
|
60
|
+
|
61
|
+
|
62
|
+
class ForbiddenError(HTTPException):
    """
    Signal that the caller is not allowed to access the resource.

    Always carries status code 403 (Forbidden).

    Args:
        detail: Message or structured detail (default: "Forbidden")
        headers: Optional extra response headers
    """

    def __init__(self, detail: Any = "Forbidden", headers: Optional[Dict[str, str]] = None):
        forbidden = 403
        super().__init__(forbidden, detail, headers)
|
67
|
+
|
68
|
+
|
69
|
+
class BadRequestError(HTTPException):
    """
    Signal a malformed or otherwise invalid request.

    Always carries status code 400 (Bad Request).

    Args:
        detail: Message or structured detail (default: "Bad Request")
        headers: Optional extra response headers
    """

    def __init__(self, detail: Any = "Bad Request", headers: Optional[Dict[str, str]] = None):
        super().__init__(400, detail, headers)
|
74
|
+
|
75
|
+
|
76
|
+
# HTTP Status codes - compatible with FastAPI's status module
class status:
    """
    HTTP status codes (compatible with fastapi.status).

    Each ``HTTP_<code>_<NAME>`` attribute equals ``<code>``. Both the legacy
    names and the RFC 9110 names used by current Starlette/FastAPI are
    provided: ``HTTP_413_CONTENT_TOO_LARGE``, ``HTTP_414_URI_TOO_LONG``,
    ``HTTP_416_RANGE_NOT_SATISFIABLE`` and ``HTTP_422_UNPROCESSABLE_CONTENT``.
    Code written against either naming works unchanged.
    """

    # 1xx Informational
    HTTP_100_CONTINUE = 100
    HTTP_101_SWITCHING_PROTOCOLS = 101
    HTTP_102_PROCESSING = 102
    HTTP_103_EARLY_HINTS = 103

    # 2xx Success
    HTTP_200_OK = 200
    HTTP_201_CREATED = 201
    HTTP_202_ACCEPTED = 202
    HTTP_203_NON_AUTHORITATIVE_INFORMATION = 203
    HTTP_204_NO_CONTENT = 204
    HTTP_205_RESET_CONTENT = 205
    HTTP_206_PARTIAL_CONTENT = 206
    HTTP_207_MULTI_STATUS = 207
    HTTP_208_ALREADY_REPORTED = 208
    HTTP_226_IM_USED = 226

    # 3xx Redirection
    HTTP_300_MULTIPLE_CHOICES = 300
    HTTP_301_MOVED_PERMANENTLY = 301
    HTTP_302_FOUND = 302
    HTTP_303_SEE_OTHER = 303
    HTTP_304_NOT_MODIFIED = 304
    HTTP_305_USE_PROXY = 305
    HTTP_306_RESERVED = 306
    HTTP_307_TEMPORARY_REDIRECT = 307
    HTTP_308_PERMANENT_REDIRECT = 308

    # 4xx Client Error
    HTTP_400_BAD_REQUEST = 400
    HTTP_401_UNAUTHORIZED = 401
    HTTP_402_PAYMENT_REQUIRED = 402
    HTTP_403_FORBIDDEN = 403
    HTTP_404_NOT_FOUND = 404
    HTTP_405_METHOD_NOT_ALLOWED = 405
    HTTP_406_NOT_ACCEPTABLE = 406
    HTTP_407_PROXY_AUTHENTICATION_REQUIRED = 407
    HTTP_408_REQUEST_TIMEOUT = 408
    HTTP_409_CONFLICT = 409
    HTTP_410_GONE = 410
    HTTP_411_LENGTH_REQUIRED = 411
    HTTP_412_PRECONDITION_FAILED = 412
    HTTP_413_REQUEST_ENTITY_TOO_LARGE = 413
    HTTP_413_CONTENT_TOO_LARGE = 413  # RFC 9110 name
    HTTP_414_REQUEST_URI_TOO_LONG = 414
    HTTP_414_URI_TOO_LONG = 414  # RFC 9110 name
    HTTP_415_UNSUPPORTED_MEDIA_TYPE = 415
    HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE = 416
    HTTP_416_RANGE_NOT_SATISFIABLE = 416  # RFC 9110 name
    HTTP_417_EXPECTATION_FAILED = 417
    HTTP_418_IM_A_TEAPOT = 418
    HTTP_421_MISDIRECTED_REQUEST = 421
    HTTP_422_UNPROCESSABLE_ENTITY = 422
    HTTP_422_UNPROCESSABLE_CONTENT = 422  # RFC 9110 name
    HTTP_423_LOCKED = 423
    HTTP_424_FAILED_DEPENDENCY = 424
    HTTP_425_TOO_EARLY = 425
    HTTP_426_UPGRADE_REQUIRED = 426
    HTTP_428_PRECONDITION_REQUIRED = 428
    HTTP_429_TOO_MANY_REQUESTS = 429
    HTTP_431_REQUEST_HEADER_FIELDS_TOO_LARGE = 431
    HTTP_451_UNAVAILABLE_FOR_LEGAL_REASONS = 451

    # 5xx Server Error
    HTTP_500_INTERNAL_SERVER_ERROR = 500
    HTTP_501_NOT_IMPLEMENTED = 501
    HTTP_502_BAD_GATEWAY = 502
    HTTP_503_SERVICE_UNAVAILABLE = 503
    HTTP_504_GATEWAY_TIMEOUT = 504
    HTTP_505_HTTP_VERSION_NOT_SUPPORTED = 505
    HTTP_506_VARIANT_ALSO_NEGOTIATES = 506
    HTTP_507_INSUFFICIENT_STORAGE = 507
    HTTP_508_LOOP_DETECTED = 508
    HTTP_510_NOT_EXTENDED = 510
    HTTP_511_NETWORK_AUTHENTICATION_REQUIRED = 511
|
vega/web/middleware.py
ADDED
@@ -0,0 +1,185 @@
|
|
1
|
+
"""Middleware utilities for Vega Web Framework"""
|
2
|
+
|
3
|
+
from typing import Callable
|
4
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
5
|
+
from starlette.requests import Request
|
6
|
+
from starlette.responses import Response
|
7
|
+
|
8
|
+
|
9
|
+
class VegaMiddleware(BaseHTTPMiddleware):
    """
    Base class for application-wide Vega HTTP middlewares.

    Subclass it and override ``dispatch`` to run code around every request.
    The default ``dispatch`` just forwards the request and returns the
    downstream response unchanged.

    Example:
        class LoggingMiddleware(VegaMiddleware):
            async def dispatch(self, request: Request, call_next: Callable) -> Response:
                print(f"Request: {request.method} {request.url.path}")
                response = await call_next(request)
                print(f"Response: {response.status_code}")
                return response

        app.add_middleware(LoggingMiddleware)
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Forward *request* down the stack and return whatever comes back.

        Args:
            request: Incoming request
            call_next: Awaitable that invokes the next middleware or endpoint

        Returns:
            The downstream response, untouched
        """
        downstream_response = await call_next(request)
        return downstream_response
|
38
|
+
|
39
|
+
|
40
|
+
from starlette.middleware.cors import CORSMiddleware as _StarletteCORSMiddleware


class CORSMiddleware(_StarletteCORSMiddleware):
    """
    CORS (Cross-Origin Resource Sharing) middleware.

    A thin subclass of Starlette's ``CORSMiddleware``: it takes exactly the
    same arguments and behaves identically. Because it is a real subclass,
    ``isinstance``/``issubclass`` checks against either class work, and the
    constructor signature stays visible to IDEs and type checkers.

    Example:
        from vega.web import VegaApp
        from vega.web.middleware import CORSMiddleware

        app = VegaApp()
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
    """

    # Alias so code that referenced ``CORSMiddleware._CORSMiddleware`` keeps working.
    _CORSMiddleware = _StarletteCORSMiddleware
|
65
|
+
|
66
|
+
|
67
|
+
from starlette.middleware.trustedhost import TrustedHostMiddleware as _StarletteTrustedHostMiddleware


class TrustedHostMiddleware(_StarletteTrustedHostMiddleware):
    """
    Middleware to validate the Host header.

    A thin subclass of Starlette's ``TrustedHostMiddleware``: it takes exactly
    the same arguments and behaves identically. Because it is a real subclass,
    ``isinstance``/``issubclass`` checks work and the constructor signature
    stays visible to IDEs and type checkers.

    Example:
        app.add_middleware(
            TrustedHostMiddleware,
            allowed_hosts=["example.com", "*.example.com"]
        )
    """

    # Alias so code that referenced ``TrustedHostMiddleware._TrustedHostMiddleware`` keeps working.
    _TrustedHostMiddleware = _StarletteTrustedHostMiddleware
|
84
|
+
|
85
|
+
|
86
|
+
from starlette.middleware.gzip import GZipMiddleware as _StarletteGZipMiddleware


class GZipMiddleware(_StarletteGZipMiddleware):
    """
    Middleware to compress responses using GZip.

    A thin subclass of Starlette's ``GZipMiddleware``: it takes exactly the
    same arguments and behaves identically. Because it is a real subclass,
    ``isinstance``/``issubclass`` checks work and the constructor signature
    stays visible to IDEs and type checkers.

    Example:
        app.add_middleware(GZipMiddleware, minimum_size=1000)
    """

    # Alias so code that referenced ``GZipMiddleware._GZipMiddleware`` keeps working.
    _GZipMiddleware = _StarletteGZipMiddleware
|
100
|
+
|
101
|
+
|
102
|
+
class RateLimitMiddleware(VegaMiddleware):
    """
    Simple in-memory, per-client-IP rate limiting middleware (example implementation).

    Allows at most ``requests_per_minute`` requests per client IP within any
    sliding 60-second window. Further requests get a 429 JSON response with a
    ``Retry-After`` header giving the seconds until a slot frees up.

    Timestamps come from ``time.monotonic()``, so wall-clock changes cannot
    distort the window. Idle clients are evicted at most once per window, so
    ``self.requests`` does not grow without bound. State is per process.

    Args:
        app: The ASGI application to wrap (supplied by ``add_middleware``)
        requests_per_minute: Maximum requests per minute per IP

    Example:
        app.add_middleware(RateLimitMiddleware, requests_per_minute=60)
    """

    # Sliding-window length in seconds (the "per minute" in the limit).
    _WINDOW_SECONDS = 60

    def __init__(self, app, requests_per_minute: int = 60):
        import time

        super().__init__(app)
        self.requests_per_minute = requests_per_minute
        self.requests: dict = {}  # IP -> deque of monotonic timestamps, oldest first
        self._last_sweep = time.monotonic()

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Reject with 429 if the client is over its limit, otherwise forward the request."""
        import math
        import time
        from collections import deque

        # Starlette's JSONResponse: this module already depends on Starlette, and
        # it is the response type BaseHTTPMiddleware.dispatch must return.
        from starlette.responses import JSONResponse

        client_ip = request.client.host if request.client else "unknown"
        now = time.monotonic()
        cutoff = now - self._WINDOW_SECONDS

        # Evict idle clients at most once per window (O(clients), amortised).
        if now - self._last_sweep >= self._WINDOW_SECONDS:
            idle = [
                ip
                for ip, stamps in self.requests.items()
                if not stamps or stamps[-1] <= cutoff
            ]
            for ip in idle:
                del self.requests[ip]
            self._last_sweep = now

        stamps = self.requests.get(client_ip)
        if stamps is None:
            stamps = self.requests[client_ip] = deque()
        # Drop timestamps that fell out of the window (oldest sit on the left).
        while stamps and stamps[0] <= cutoff:
            stamps.popleft()

        # Check rate limit
        if len(stamps) >= self.requests_per_minute:
            if stamps:
                retry_after = max(1, math.ceil(stamps[0] + self._WINDOW_SECONDS - now))
            else:
                retry_after = self._WINDOW_SECONDS
            return JSONResponse(
                content={"detail": "Rate limit exceeded"},
                status_code=429,
                headers={"Retry-After": str(retry_after)},
            )

        # Record request
        stamps.append(now)

        return await call_next(request)
|
147
|
+
|
148
|
+
|
149
|
+
class RequestLoggingMiddleware(VegaMiddleware):
    """
    Middleware to log all requests.

    Writes to the ``vega.web`` logger:
    - an incoming line: ``→ METHOD PATH``;
    - a completion line: ``← METHOD PATH [STATUS] SECONDSs``;
    - if the downstream app raises, an error line with the traceback. The
      exception is then re-raised unchanged.

    Elapsed time is measured with ``time.perf_counter()`` (monotonic).

    Example:
        app.add_middleware(RequestLoggingMiddleware)
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Log request and response"""
        import logging
        import time

        logger = logging.getLogger("vega.web")

        # perf_counter is monotonic: immune to wall-clock jumps (NTP, manual changes).
        start_time = time.perf_counter()
        logger.info("→ %s %s", request.method, request.url.path)

        try:
            response = await call_next(request)
        except Exception:
            logger.exception(
                "✗ %s %s failed after %.3fs",
                request.method,
                request.url.path,
                time.perf_counter() - start_time,
            )
            raise

        process_time = time.perf_counter() - start_time
        logger.info(
            "← %s %s [%s] %.3fs",
            request.method,
            request.url.path,
            response.status_code,
            process_time,
        )

        return response
|
176
|
+
|
177
|
+
|
178
|
+
# Public API of this module: the base middleware class, the Starlette-backed
# middlewares, and the example middlewares, in declaration order.
__all__ = [
    "VegaMiddleware", "CORSMiddleware", "TrustedHostMiddleware",
    "GZipMiddleware", "RateLimitMiddleware", "RequestLoggingMiddleware",
]
|