fastapi-radar 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastapi_radar/__init__.py +3 -2
- fastapi_radar/api.py +217 -34
- fastapi_radar/background.py +120 -0
- fastapi_radar/capture.py +38 -10
- fastapi_radar/dashboard/dist/assets/index-8Om0PGu6.js +326 -0
- fastapi_radar/dashboard/dist/assets/index-D51YrvFG.css +1 -0
- fastapi_radar/dashboard/dist/assets/index-p3czTzXB.js +361 -0
- fastapi_radar/dashboard/dist/index.html +1 -1
- fastapi_radar/dashboard/node_modules/flatted/python/flatted.py +149 -0
- fastapi_radar/middleware.py +26 -9
- fastapi_radar/models.py +41 -8
- fastapi_radar/radar.py +70 -25
- fastapi_radar/tracing.py +6 -6
- fastapi_radar/utils.py +24 -0
- {fastapi_radar-0.1.8.dist-info → fastapi_radar-0.3.1.dist-info}/METADATA +23 -2
- fastapi_radar-0.3.1.dist-info/RECORD +19 -0
- {fastapi_radar-0.1.8.dist-info → fastapi_radar-0.3.1.dist-info}/top_level.txt +0 -1
- fastapi_radar/dashboard/dist/assets/index-By5DXl8Z.js +0 -318
- fastapi_radar/dashboard/dist/assets/index-XlGcZj49.css +0 -1
- fastapi_radar-0.1.8.dist-info/RECORD +0 -18
- tests/__init__.py +0 -1
- tests/test_radar.py +0 -75
- {fastapi_radar-0.1.8.dist-info → fastapi_radar-0.3.1.dist-info}/WHEEL +0 -0
- {fastapi_radar-0.1.8.dist-info → fastapi_radar-0.3.1.dist-info}/licenses/LICENSE +0 -0
fastapi_radar/__init__.py
CHANGED
fastapi_radar/api.py
CHANGED
|
@@ -1,14 +1,23 @@
|
|
|
1
1
|
"""API endpoints for FastAPI Radar dashboard."""
|
|
2
2
|
|
|
3
|
-
from datetime import datetime, timedelta
|
|
3
|
+
from datetime import datetime, timedelta, timezone
|
|
4
4
|
from typing import Any, Dict, List, Optional, Union
|
|
5
|
+
import uuid
|
|
5
6
|
|
|
6
7
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
7
8
|
from pydantic import BaseModel
|
|
8
|
-
from sqlalchemy import desc
|
|
9
|
+
from sqlalchemy import case, desc, func
|
|
9
10
|
from sqlalchemy.orm import Session
|
|
10
|
-
|
|
11
|
-
|
|
11
|
+
import httpx
|
|
12
|
+
|
|
13
|
+
from .models import (
|
|
14
|
+
CapturedRequest,
|
|
15
|
+
CapturedQuery,
|
|
16
|
+
CapturedException,
|
|
17
|
+
Trace,
|
|
18
|
+
Span,
|
|
19
|
+
BackgroundTask,
|
|
20
|
+
)
|
|
12
21
|
from .tracing import TracingManager
|
|
13
22
|
|
|
14
23
|
|
|
@@ -92,6 +101,19 @@ class TraceSummary(BaseModel):
|
|
|
92
101
|
created_at: datetime
|
|
93
102
|
|
|
94
103
|
|
|
104
|
+
class BackgroundTaskSummary(BaseModel):
    """Serialized view of one tracked background-task execution.

    Built from ``BackgroundTask`` rows for the dashboard's background-task
    listing.
    """

    # --- identity -----------------------------------------------------
    id: int
    task_id: str
    request_id: Optional[str]  # id of the request that scheduled it, when linked
    name: str
    status: str
    # --- timing -------------------------------------------------------
    start_time: Optional[datetime]
    end_time: Optional[datetime]
    duration_ms: Optional[float]
    # --- outcome / bookkeeping ----------------------------------------
    error: Optional[str]
    created_at: datetime
|
|
115
|
+
|
|
116
|
+
|
|
95
117
|
class WaterfallSpan(BaseModel):
|
|
96
118
|
span_id: str
|
|
97
119
|
parent_span_id: Optional[str]
|
|
@@ -135,10 +157,16 @@ def create_api_router(get_session_context) -> APIRouter:
|
|
|
135
157
|
status_code: Optional[int] = None,
|
|
136
158
|
method: Optional[str] = None,
|
|
137
159
|
search: Optional[str] = None,
|
|
160
|
+
start_time: Optional[datetime] = None,
|
|
161
|
+
end_time: Optional[datetime] = None,
|
|
138
162
|
session: Session = Depends(get_db),
|
|
139
163
|
):
|
|
140
164
|
query = session.query(CapturedRequest)
|
|
141
165
|
|
|
166
|
+
if start_time:
|
|
167
|
+
query = query.filter(CapturedRequest.created_at >= start_time)
|
|
168
|
+
if end_time:
|
|
169
|
+
query = query.filter(CapturedRequest.created_at <= end_time)
|
|
142
170
|
if status_code:
|
|
143
171
|
if status_code in [200, 300, 400, 500]:
|
|
144
172
|
# Filter by status code range
|
|
@@ -228,6 +256,124 @@ def create_api_router(get_session_context) -> APIRouter:
|
|
|
228
256
|
],
|
|
229
257
|
)
|
|
230
258
|
|
|
259
|
+
@router.get("/requests/{request_id}/curl")
async def get_request_as_curl(request_id: str, session: Session = Depends(get_db)):
    """Render a captured request as a copy-pasteable cURL command.

    Every fragment taken from captured data (method, each header, body and
    URL) is escaped with ``shlex.quote``. Captured values that contain single
    quotes or shell metacharacters therefore stay a single argument and cannot
    break out of the command when it is pasted into a POSIX shell.

    Returns:
        ``{"curl": "<command>"}``.

    Raises:
        HTTPException: 404 if no request with ``request_id`` was captured.
    """
    import shlex  # local import: keeps the module-level import block unchanged

    request = (
        session.query(CapturedRequest)
        .filter(CapturedRequest.request_id == request_id)
        .first()
    )
    if not request:
        raise HTTPException(status_code=404, detail="Request not found")

    parts = ["curl", "-X", shlex.quote(str(request.method))]

    # Host and Content-Length are left for curl to compute from the URL/body.
    for key, value in (request.headers or {}).items():
        if key.lower() not in ("host", "content-length"):
            parts.extend(["-H", shlex.quote(f"{key}: {value}")])

    if request.body:
        parts.extend(["-d", shlex.quote(str(request.body))])

    # Prefer the full captured URL; fall back to the bare path.
    url = request.url if request.url else request.path
    parts.append(shlex.quote(str(url)))

    return {"curl": " ".join(parts)}
|
|
288
|
+
|
|
289
|
+
@router.post("/requests/{request_id}/replay")
async def replay_request(
    request_id: str,
    body: Optional[Dict[str, Any]] = None,
    session: Session = Depends(get_db),
):
    """Replay a captured request with optional body override.

    The captured method, URL and headers are re-sent with ``httpx``. When
    ``body`` is supplied it replaces the captured body and is sent as JSON.
    The replay is persisted as a new ``CapturedRequest`` row (``client_ip`` is
    ``"replay"``). Its stored body is the text actually sent.

    WARNING: This endpoint issues outbound HTTP requests to whatever URL was
    captured (SSRF risk). Use with caution in production; consider adding
    authentication, rate limiting and an allow-list of target hosts.

    Returns:
        The replay's status, headers, body and timing, together with the
        original request's status/duration and the id of the new row.

    Raises:
        HTTPException: 404 if ``request_id`` is unknown; 400 if the captured
            request has no URL to replay; 500 if the outbound request fails.
    """
    request = (
        session.query(CapturedRequest)
        .filter(CapturedRequest.request_id == request_id)
        .first()
    )
    if not request:
        raise HTTPException(status_code=404, detail="Request not found")
    if not request.url:
        # Without a URL the httpx call would fail with an error that the
        # httpx.RequestError handler below does not catch.
        raise HTTPException(
            status_code=400, detail="Captured request has no URL to replay"
        )

    # Connection-level / framing headers from the original exchange must not be
    # replayed verbatim; httpx recomputes them. Compare case-insensitively,
    # since captured header names may not be lower-case.
    skipped = {"host", "content-length", "connection", "keep-alive", "transfer-encoding"}
    if body is not None:
        # An override is always sent as JSON, so drop the captured Content-Type
        # and let httpx set a matching one.
        skipped.add("content-type")
    headers = {
        key: value
        for key, value in (request.headers or {}).items()
        if key.lower() not in skipped
    }

    request_body = body if body is not None else request.body

    try:
        async with httpx.AsyncClient(timeout=30.0, follow_redirects=False) as client:
            response = await client.request(
                method=request.method,
                url=request.url,
                headers=headers,
                content=(
                    request_body if isinstance(request_body, (str, bytes)) else None
                ),
                json=request_body if isinstance(request_body, dict) else None,
            )
    except httpx.RequestError as e:
        raise HTTPException(status_code=500, detail=f"Replay failed: {str(e)}") from e

    # Persist what was actually sent: a captured text body as-is, or the JSON
    # payload httpx serialised for a dict override (previously dropped).
    if isinstance(request_body, str):
        stored_body = request_body
    elif isinstance(request_body, dict):
        stored_body = response.request.content.decode("utf-8", errors="replace")
    else:
        stored_body = None

    elapsed_ms = response.elapsed.total_seconds() * 1000
    response_text = response.text

    replayed_request = CapturedRequest(
        request_id=str(uuid.uuid4()),
        method=request.method,
        url=request.url,
        path=request.path,
        query_params=request.query_params,
        headers=dict(response.request.headers),
        body=stored_body,
        status_code=response.status_code,
        response_body=response_text[:10000] if response_text else None,
        response_headers=dict(response.headers),
        duration_ms=elapsed_ms,
        client_ip="replay",
    )
    session.add(replayed_request)
    session.commit()
    session.refresh(replayed_request)

    return {
        "status_code": response.status_code,
        "headers": dict(response.headers),
        "body": response_text,
        "elapsed_ms": elapsed_ms,
        "original_status": request.status_code,
        "original_duration_ms": request.duration_ms,
        "new_request_id": replayed_request.request_id,
    }
|
|
376
|
+
|
|
231
377
|
@router.get("/queries", response_model=List[QueryDetail])
|
|
232
378
|
async def get_queries(
|
|
233
379
|
limit: int = Query(100, ge=1, le=1000),
|
|
@@ -302,45 +448,43 @@ def create_api_router(get_session_context) -> APIRouter:
|
|
|
302
448
|
slow_threshold: int = Query(100),
|
|
303
449
|
session: Session = Depends(get_db),
|
|
304
450
|
):
|
|
305
|
-
since = datetime.
|
|
451
|
+
since = datetime.now(timezone.utc) - timedelta(hours=hours)
|
|
306
452
|
|
|
307
453
|
requests = (
|
|
308
|
-
session.query(
|
|
454
|
+
session.query(
|
|
455
|
+
func.count().label("total_requests"),
|
|
456
|
+
func.avg(CapturedRequest.duration_ms).label("avg_response_time"),
|
|
457
|
+
)
|
|
309
458
|
.filter(CapturedRequest.created_at >= since)
|
|
310
|
-
.
|
|
459
|
+
.one()
|
|
311
460
|
)
|
|
312
461
|
|
|
313
462
|
queries = (
|
|
314
|
-
session.query(
|
|
463
|
+
session.query(
|
|
464
|
+
func.count().label("total_queries"),
|
|
465
|
+
func.avg(CapturedQuery.duration_ms).label("avg_query_time"),
|
|
466
|
+
func.sum(
|
|
467
|
+
case((CapturedQuery.duration_ms >= slow_threshold, 1), else_=0)
|
|
468
|
+
).label("slow_queries"),
|
|
469
|
+
)
|
|
470
|
+
.filter(CapturedQuery.created_at >= since)
|
|
471
|
+
.one()
|
|
315
472
|
)
|
|
316
473
|
|
|
317
474
|
exceptions = (
|
|
318
|
-
session.query(
|
|
475
|
+
session.query(func.count().label("total_exceptions"))
|
|
319
476
|
.filter(CapturedException.created_at >= since)
|
|
320
|
-
.
|
|
477
|
+
.one()
|
|
321
478
|
)
|
|
322
479
|
|
|
323
|
-
total_requests =
|
|
324
|
-
avg_response_time =
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
avg_query_time = None
|
|
332
|
-
slow_queries = 0
|
|
333
|
-
if queries:
|
|
334
|
-
valid_times = [q.duration_ms for q in queries if q.duration_ms is not None]
|
|
335
|
-
if valid_times:
|
|
336
|
-
avg_query_time = sum(valid_times) / len(valid_times)
|
|
337
|
-
slow_queries = len(
|
|
338
|
-
[
|
|
339
|
-
q
|
|
340
|
-
for q in queries
|
|
341
|
-
if q.duration_ms and q.duration_ms >= slow_threshold
|
|
342
|
-
]
|
|
343
|
-
)
|
|
480
|
+
total_requests = requests.total_requests
|
|
481
|
+
avg_response_time = requests.avg_response_time
|
|
482
|
+
|
|
483
|
+
total_queries = queries.total_queries
|
|
484
|
+
avg_query_time = queries.avg_query_time
|
|
485
|
+
slow_queries = queries.slow_queries or 0
|
|
486
|
+
|
|
487
|
+
total_exceptions = exceptions.total_exceptions
|
|
344
488
|
|
|
345
489
|
requests_per_minute = total_requests / (hours * 60)
|
|
346
490
|
|
|
@@ -349,7 +493,7 @@ def create_api_router(get_session_context) -> APIRouter:
|
|
|
349
493
|
avg_response_time=round_float(avg_response_time),
|
|
350
494
|
total_queries=total_queries,
|
|
351
495
|
avg_query_time=round_float(avg_query_time),
|
|
352
|
-
total_exceptions=
|
|
496
|
+
total_exceptions=total_exceptions,
|
|
353
497
|
slow_queries=slow_queries,
|
|
354
498
|
requests_per_minute=round_float(requests_per_minute),
|
|
355
499
|
)
|
|
@@ -359,7 +503,7 @@ def create_api_router(get_session_context) -> APIRouter:
|
|
|
359
503
|
older_than_hours: Optional[int] = None, session: Session = Depends(get_db)
|
|
360
504
|
):
|
|
361
505
|
if older_than_hours:
|
|
362
|
-
cutoff = datetime.
|
|
506
|
+
cutoff = datetime.now(timezone.utc) - timedelta(hours=older_than_hours)
|
|
363
507
|
session.query(CapturedRequest).filter(
|
|
364
508
|
CapturedRequest.created_at < cutoff
|
|
365
509
|
).delete()
|
|
@@ -382,7 +526,7 @@ def create_api_router(get_session_context) -> APIRouter:
|
|
|
382
526
|
session: Session = Depends(get_db),
|
|
383
527
|
):
|
|
384
528
|
"""List traces."""
|
|
385
|
-
since = datetime.
|
|
529
|
+
since = datetime.now(timezone.utc) - timedelta(hours=hours)
|
|
386
530
|
query = session.query(Trace).filter(Trace.created_at >= since)
|
|
387
531
|
|
|
388
532
|
if status:
|
|
@@ -491,4 +635,43 @@ def create_api_router(get_session_context) -> APIRouter:
|
|
|
491
635
|
"created_at": span.created_at.isoformat(),
|
|
492
636
|
}
|
|
493
637
|
|
|
638
|
+
@router.get("/background-tasks", response_model=List[BackgroundTaskSummary])
async def get_background_tasks(
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    status: Optional[str] = None,
    request_id: Optional[str] = None,
    session: Session = Depends(get_db),
):
    """Get background tasks with optional filters.

    Results are ordered newest ``created_at`` first and paginated with
    ``limit``/``offset``. ``status`` and ``request_id`` are applied only when
    truthy, so an empty string means "no filter".
    """
    criteria = []
    if status:
        criteria.append(BackgroundTask.status == status)
    if request_id:
        criteria.append(BackgroundTask.request_id == request_id)

    rows = (
        session.query(BackgroundTask)
        .filter(*criteria)  # multiple criteria are AND-ed; none is a no-op
        .order_by(desc(BackgroundTask.created_at))
        .offset(offset)
        .limit(limit)
        .all()
    )

    def to_summary(row):
        # Copy the row's columns into the response model; duration is rounded.
        return BackgroundTaskSummary(
            id=row.id,
            task_id=row.task_id,
            request_id=row.request_id,
            name=row.name,
            status=row.status,
            start_time=row.start_time,
            end_time=row.end_time,
            duration_ms=round_float(row.duration_ms),
            error=row.error,
            created_at=row.created_at,
        )

    return [to_summary(row) for row in rows]
|
|
676
|
+
|
|
494
677
|
return router
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""Background task monitoring for FastAPI Radar."""
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
import time
|
|
5
|
+
import uuid
|
|
6
|
+
from datetime import datetime, timezone
|
|
7
|
+
from functools import wraps
|
|
8
|
+
from typing import Callable, Any
|
|
9
|
+
|
|
10
|
+
from .models import BackgroundTask
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def track_background_task(get_session: Callable):
    """Build a decorator that records every run of a background task.

    For each call of the decorated function a ``BackgroundTask`` row is first
    inserted with status ``"running"``. After the function returns or raises,
    the row is updated with:

    * the final status (``"completed"`` or ``"failed"``);
    * the end time;
    * the duration in milliseconds;
    * the error message.

    Both regular and ``async def`` functions are supported; the wrapper type
    matches the wrapped function.

    To link a task to a captured request, pass the request id through the
    reserved ``_radar_request_id`` keyword. It is removed from the keyword
    arguments before the wrapped function is called::

        background_tasks.add_task(my_task, arg1, _radar_request_id="abc123")

    Tracking is best-effort. Errors while writing the tracking rows are logged
    and never stop the task from running, nor replace the task's own
    exception.

    Args:
        get_session: Zero-argument callable returning a context manager that
            yields a SQLAlchemy session.
    """
    import logging  # local import: keeps the module-level import block unchanged

    logger = logging.getLogger(__name__)

    def record_start(task_id: str, request_id: Any, task_name: str) -> None:
        """Insert the initial ``running`` row; log and continue on failure."""
        try:
            with get_session() as session:
                session.add(
                    BackgroundTask(
                        task_id=task_id,
                        request_id=request_id,
                        name=task_name,
                        status="running",
                        start_time=datetime.now(timezone.utc),
                    )
                )
                session.commit()
        except Exception:
            logger.exception(
                "FastAPI Radar: could not record start of background task %r",
                task_name,
            )

    def record_finish(task_id: str, status: str, duration_ms: float, error: Any) -> None:
        """Write final status/timing onto the task's row; log and continue on failure."""
        try:
            with get_session() as session:
                task = (
                    session.query(BackgroundTask)
                    .filter(BackgroundTask.task_id == task_id)
                    .first()
                )
                if task:
                    task.status = status
                    task.end_time = datetime.now(timezone.utc)
                    task.duration_ms = duration_ms
                    task.error = error
                    session.commit()
        except Exception:
            logger.exception(
                "FastAPI Radar: could not record completion of background task %s",
                task_id,
            )

    def describe(exc: BaseException) -> str:
        # str() of e.g. asyncio.CancelledError() is empty; fall back to the type name.
        return str(exc) or type(exc).__name__

    def decorator(func: Callable) -> Callable:
        # Plain function name (not the module path). getattr keeps callables
        # without __name__ (e.g. functools.partial) usable.
        task_name = getattr(func, "__name__", type(func).__name__)

        @wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            task_id = str(uuid.uuid4())
            record_start(task_id, kwargs.pop("_radar_request_id", None), task_name)
            started = time.perf_counter()
            # Pre-set so the finally block never sees an unbound name, even
            # when a BaseException (e.g. task cancellation) propagates.
            status, error = "failed", None
            try:
                result = await func(*args, **kwargs)
                status = "completed"
                return result
            except BaseException as exc:
                error = describe(exc)
                raise
            finally:
                record_finish(
                    task_id, status, (time.perf_counter() - started) * 1000, error
                )

        @wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            task_id = str(uuid.uuid4())
            record_start(task_id, kwargs.pop("_radar_request_id", None), task_name)
            started = time.perf_counter()
            status, error = "failed", None
            try:
                result = func(*args, **kwargs)
                status = "completed"
                return result
            except BaseException as exc:
                error = describe(exc)
                raise
            finally:
                record_finish(
                    task_id, status, (time.perf_counter() - started) * 1000, error
                )

        # Return the wrapper whose calling convention matches the function.
        return async_wrapper if inspect.iscoroutinefunction(func) else sync_wrapper

    return decorator
|
fastapi_radar/capture.py
CHANGED
|
@@ -1,13 +1,18 @@
|
|
|
1
1
|
"""SQLAlchemy query capture for FastAPI Radar."""
|
|
2
2
|
|
|
3
3
|
import time
|
|
4
|
-
from typing import Any, Callable, Dict, List, Union
|
|
5
|
-
|
|
4
|
+
from typing import Any, Callable, Dict, List, Optional, Union
|
|
6
5
|
from sqlalchemy import event
|
|
7
6
|
from sqlalchemy.engine import Engine
|
|
8
7
|
|
|
8
|
+
try: # SQLAlchemy async support is optional
|
|
9
|
+
from sqlalchemy.ext.asyncio import AsyncEngine
|
|
10
|
+
except Exception: # pragma: no cover - module might not exist in older SQLAlchemy
|
|
11
|
+
AsyncEngine = None # type: ignore[assignment]
|
|
9
12
|
from .middleware import request_context
|
|
10
13
|
from .models import CapturedQuery
|
|
14
|
+
|
|
15
|
+
|
|
11
16
|
from .utils import format_sql
|
|
12
17
|
from .tracing import get_current_trace_context
|
|
13
18
|
|
|
@@ -23,14 +28,20 @@ class QueryCapture:
|
|
|
23
28
|
self.capture_bindings = capture_bindings
|
|
24
29
|
self.slow_query_threshold = slow_query_threshold
|
|
25
30
|
self._query_start_times = {}
|
|
31
|
+
self._registered_engines: Dict[int, Engine] = {}
|
|
26
32
|
|
|
27
33
|
def register(self, engine: Engine) -> None:
    """Attach the query-capture cursor hooks to ``engine``.

    An ``AsyncEngine`` is first resolved to its underlying sync engine, since
    the cursor events are attached there. The resolved target is remembered,
    keyed by ``id(engine)``, so ``unregister`` can detach from the same object.
    """
    target = self._resolve_engine(engine)
    hooks = (
        ("before_cursor_execute", self._before_cursor_execute),
        ("after_cursor_execute", self._after_cursor_execute),
    )
    for event_name, handler in hooks:
        event.listen(target, event_name, handler)
    self._registered_engines[id(engine)] = target
|
|
30
38
|
|
|
31
39
|
def unregister(self, engine: Engine) -> None:
    """Detach the query-capture cursor hooks from ``engine``.

    The sync engine remembered by ``register`` is used when available;
    otherwise ``engine`` is resolved again. ``event.remove`` is called
    unconditionally; there is no guard here for engines that were never
    registered.
    """
    target = self._registered_engines.pop(id(engine), None) or self._resolve_engine(
        engine
    )
    hooks = (
        ("before_cursor_execute", self._before_cursor_execute),
        ("after_cursor_execute", self._after_cursor_execute),
    )
    for event_name, handler in hooks:
        event.remove(target, event_name, handler)
|
|
34
45
|
|
|
35
46
|
def _before_cursor_execute(
|
|
36
47
|
self,
|
|
@@ -44,14 +55,14 @@ class QueryCapture:
|
|
|
44
55
|
request_id = request_context.get()
|
|
45
56
|
if not request_id:
|
|
46
57
|
return
|
|
47
|
-
|
|
48
58
|
context_id = id(context)
|
|
49
59
|
self._query_start_times[context_id] = time.time()
|
|
50
|
-
|
|
60
|
+
setattr(context, "_radar_request_id", request_id)
|
|
51
61
|
trace_ctx = get_current_trace_context()
|
|
52
62
|
if trace_ctx:
|
|
53
63
|
formatted_sql = format_sql(statement)
|
|
54
64
|
operation_type = self._get_operation_type(statement)
|
|
65
|
+
db_tags = self._get_db_tags(conn)
|
|
55
66
|
span_id = trace_ctx.create_span(
|
|
56
67
|
operation_name=f"DB {operation_type}",
|
|
57
68
|
span_kind="client",
|
|
@@ -59,6 +70,7 @@ class QueryCapture:
|
|
|
59
70
|
"db.statement": formatted_sql[:500], # limit SQL length
|
|
60
71
|
"db.operation_type": operation_type,
|
|
61
72
|
"component": "database",
|
|
73
|
+
**db_tags,
|
|
62
74
|
},
|
|
63
75
|
)
|
|
64
76
|
setattr(context, "_radar_span_id", span_id)
|
|
@@ -74,8 +86,9 @@ class QueryCapture:
|
|
|
74
86
|
) -> None:
|
|
75
87
|
request_id = request_context.get()
|
|
76
88
|
if not request_id:
|
|
77
|
-
|
|
78
|
-
|
|
89
|
+
request_id = getattr(context, "_radar_request_id", None)
|
|
90
|
+
if not request_id:
|
|
91
|
+
return
|
|
79
92
|
start_time = self._query_start_times.pop(id(context), None)
|
|
80
93
|
if start_time is None:
|
|
81
94
|
return
|
|
@@ -160,3 +173,18 @@ class QueryCapture:
|
|
|
160
173
|
return {k: str(v) for k, v in list(parameters.items())[:100]}
|
|
161
174
|
|
|
162
175
|
return [str(parameters)]
|
|
176
|
+
|
|
177
|
+
def _resolve_engine(self, engine: Engine) -> Engine:
    """Return the synchronous engine that cursor event listeners attach to.

    When SQLAlchemy's asyncio extension is importable and ``engine`` is an
    ``AsyncEngine``, its ``sync_engine`` is returned. Any other object is
    returned unchanged.
    """
    if AsyncEngine is None or not isinstance(engine, AsyncEngine):
        return engine
    return engine.sync_engine
|
|
181
|
+
|
|
182
|
+
def _get_db_tags(self, conn: Any) -> Dict[str, Optional[str]]:
    """Collect database tags for a query span from ``conn``'s engine.

    Returns an empty dict unless ``conn.engine`` exists and has a truthy
    ``dialect``. Otherwise the result holds ``db.system`` (the dialect name).
    When the engine has a ``url``, it also holds ``db.name`` (the URL's
    ``database`` attribute, which may be ``None``).
    """
    engine = getattr(conn, "engine", None)
    if not engine or not getattr(engine, "dialect", None):
        return {}

    tags: Dict[str, Optional[str]] = {"db.system": engine.dialect.name}
    url = getattr(engine, "url", None)
    if url is not None:
        tags["db.name"] = getattr(url, "database", None)
    return tags
|