fastapi-radar 0.1.7__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fastapi-radar might be problematic; see the registry's advisory page for this release for more details.

fastapi_radar/__init__.py CHANGED
@@ -1,6 +1,7 @@
1
1
  """FastAPI Radar - Debugging dashboard for FastAPI applications."""
2
2
 
3
3
  from .radar import Radar
4
+ from .background import track_background_task
4
5
 
5
- __version__ = "0.1.7"
6
- __all__ = ["Radar"]
6
+ __version__ = "0.3.0"
7
+ __all__ = ["Radar", "track_background_task"]
fastapi_radar/api.py CHANGED
@@ -1,14 +1,23 @@
1
1
  """API endpoints for FastAPI Radar dashboard."""
2
2
 
3
- from datetime import datetime, timedelta
3
+ from datetime import datetime, timedelta, timezone
4
4
  from typing import Any, Dict, List, Optional, Union
5
+ import uuid
5
6
 
6
7
  from fastapi import APIRouter, Depends, HTTPException, Query
7
8
  from pydantic import BaseModel
8
- from sqlalchemy import desc
9
+ from sqlalchemy import case, desc, func
9
10
  from sqlalchemy.orm import Session
10
-
11
- from .models import CapturedRequest, CapturedQuery, CapturedException, Trace, Span
11
+ import httpx
12
+
13
+ from .models import (
14
+ CapturedRequest,
15
+ CapturedQuery,
16
+ CapturedException,
17
+ Trace,
18
+ Span,
19
+ BackgroundTask,
20
+ )
12
21
  from .tracing import TracingManager
13
22
 
14
23
 
@@ -92,6 +101,19 @@ class TraceSummary(BaseModel):
92
101
  created_at: datetime
93
102
 
94
103
 
104
class BackgroundTaskSummary(BaseModel):
    # One entry of the /background-tasks listing; the fields mirror the
    # BackgroundTask columns that get_background_tasks copies over.
    id: int
    task_id: str
    # Only set when the task was scheduled with a _radar_request_id.
    request_id: Union[str, None]
    name: str
    # "running", "completed" or "failed" as written by track_background_task.
    status: str
    start_time: Union[datetime, None]
    end_time: Union[datetime, None]
    # Rounded by the endpoint before it is returned.
    duration_ms: Union[float, None]
    error: Union[str, None]
    created_at: datetime
115
+
116
+
95
117
  class WaterfallSpan(BaseModel):
96
118
  span_id: str
97
119
  parent_span_id: Optional[str]
@@ -135,10 +157,16 @@ def create_api_router(get_session_context) -> APIRouter:
135
157
  status_code: Optional[int] = None,
136
158
  method: Optional[str] = None,
137
159
  search: Optional[str] = None,
160
+ start_time: Optional[datetime] = None,
161
+ end_time: Optional[datetime] = None,
138
162
  session: Session = Depends(get_db),
139
163
  ):
140
164
  query = session.query(CapturedRequest)
141
165
 
166
+ if start_time:
167
+ query = query.filter(CapturedRequest.created_at >= start_time)
168
+ if end_time:
169
+ query = query.filter(CapturedRequest.created_at <= end_time)
142
170
  if status_code:
143
171
  if status_code in [200, 300, 400, 500]:
144
172
  # Filter by status code range
@@ -228,6 +256,124 @@ def create_api_router(get_session_context) -> APIRouter:
228
256
  ],
229
257
  )
230
258
 
259
@router.get("/requests/{request_id}/curl")
async def get_request_as_curl(request_id: str, session: Session = Depends(get_db)):
    """Render a captured request as a copy-pasteable cURL command.

    Every fragment taken from captured data (method, headers, body, URL) is
    passed through ``shlex.quote`` so quotes, spaces and shell metacharacters
    in the capture cannot break the command or inject extra shell syntax
    when it is pasted into a POSIX shell.

    Raises:
        HTTPException: 404 when no request with ``request_id`` was captured.
    """
    import shlex

    request = (
        session.query(CapturedRequest)
        .filter(CapturedRequest.request_id == request_id)
        .first()
    )
    if not request:
        raise HTTPException(status_code=404, detail="Request not found")

    parts = ["curl", "-X", shlex.quote(str(request.method))]

    # curl derives Host and Content-Length from the URL and body itself.
    skipped_headers = {"host", "content-length"}
    for key, value in (request.headers or {}).items():
        if key.lower() not in skipped_headers:
            parts.extend(["-H", shlex.quote(f"{key}: {value}")])

    if request.body:
        body = request.body
        if isinstance(body, bytes):
            body = body.decode("utf-8", errors="replace")
        # --data-raw, unlike -d, does not treat a leading "@" as a file to read.
        parts.extend(["--data-raw", shlex.quote(str(body))])

    # Prefer the full captured URL; fall back to the bare path.
    parts.append(shlex.quote(str(request.url or request.path or "")))
    return {"curl": " ".join(parts)}
288
+
289
@router.post("/requests/{request_id}/replay")
async def replay_request(
    request_id: str,
    body: Optional[Dict[str, Any]] = None,
    session: Session = Depends(get_db),
):
    """Replay a captured request with optional body override.

    ``body`` replaces the stored body; a dict is sent as JSON. The replay is
    stored as a new ``CapturedRequest`` (``client_ip="replay"``) and the
    response is returned together with the original status and duration.

    WARNING: this endpoint issues outbound HTTP requests to the captured URL.
    That URL originates from traffic the application received (presumably
    including its Host header -- verify in the capture middleware), so an
    unauthenticated replay endpoint is an SSRF vector. In production, add
    authentication and rate limiting, or allow-list target hosts.

    Raises:
        HTTPException: 404 if the request is unknown, 400 if it has no
            absolute http(s) URL to replay, 500 if the outbound request fails.
    """
    import json
    from urllib.parse import urlsplit

    request = (
        session.query(CapturedRequest)
        .filter(CapturedRequest.request_id == request_id)
        .first()
    )
    if not request:
        raise HTTPException(status_code=404, detail="Request not found")

    # Only absolute http(s) URLs can be replayed; a missing URL or bare path
    # would otherwise surface as an unhandled error or a misleading 500.
    if not request.url or urlsplit(request.url).scheme not in ("http", "https"):
        raise HTTPException(
            status_code=400, detail="Request has no replayable http(s) URL"
        )

    # Drop hop-by-hop and length headers (RFC 7230 section 6.1); httpx
    # recomputes Host/Content-Length. Compare case-insensitively because
    # header names are case-insensitive.
    excluded_headers = {
        "host",
        "content-length",
        "connection",
        "keep-alive",
        "transfer-encoding",
        "te",
        "trailer",
        "upgrade",
    }
    headers = {
        key: value
        for key, value in (request.headers or {}).items()
        if key.lower() not in excluded_headers
    }

    request_body = body if body is not None else request.body

    # Text form of the body persisted with the replayed capture.
    if isinstance(request_body, dict):
        stored_body = json.dumps(request_body)
    elif isinstance(request_body, bytes):
        stored_body = request_body.decode("utf-8", errors="replace")
    elif isinstance(request_body, str):
        stored_body = request_body
    else:
        stored_body = None

    try:
        async with httpx.AsyncClient(timeout=30.0, follow_redirects=False) as client:
            response = await client.request(
                method=request.method,
                url=request.url,
                headers=headers,
                content=(
                    request_body if isinstance(request_body, (str, bytes)) else None
                ),
                json=request_body if isinstance(request_body, dict) else None,
            )
    except httpx.RequestError as e:
        raise HTTPException(
            status_code=500, detail=f"Replay failed: {str(e)}"
        ) from e

    response_text = response.text
    elapsed_ms = response.elapsed.total_seconds() * 1000

    # Store the replayed request as a fresh capture.
    replayed_request = CapturedRequest(
        request_id=str(uuid.uuid4()),
        method=request.method,
        url=request.url,
        path=request.path,
        query_params=request.query_params,
        headers=dict(response.request.headers),
        body=stored_body,
        status_code=response.status_code,
        response_body=response_text[:10000] if response_text else None,
        response_headers=dict(response.headers),
        duration_ms=elapsed_ms,
        client_ip="replay",
    )
    session.add(replayed_request)
    session.commit()
    session.refresh(replayed_request)

    return {
        "status_code": response.status_code,
        "headers": dict(response.headers),
        "body": response_text,
        "elapsed_ms": elapsed_ms,
        "original_status": request.status_code,
        "original_duration_ms": request.duration_ms,
        "new_request_id": replayed_request.request_id,
    }
376
+
231
377
  @router.get("/queries", response_model=List[QueryDetail])
232
378
  async def get_queries(
233
379
  limit: int = Query(100, ge=1, le=1000),
@@ -302,45 +448,43 @@ def create_api_router(get_session_context) -> APIRouter:
302
448
  slow_threshold: int = Query(100),
303
449
  session: Session = Depends(get_db),
304
450
  ):
305
- since = datetime.utcnow() - timedelta(hours=hours)
451
+ since = datetime.now(timezone.utc) - timedelta(hours=hours)
306
452
 
307
453
  requests = (
308
- session.query(CapturedRequest)
454
+ session.query(
455
+ func.count().label("total_requests"),
456
+ func.avg(CapturedRequest.duration_ms).label("avg_response_time"),
457
+ )
309
458
  .filter(CapturedRequest.created_at >= since)
310
- .all()
459
+ .one()
311
460
  )
312
461
 
313
462
  queries = (
314
- session.query(CapturedQuery).filter(CapturedQuery.created_at >= since).all()
463
+ session.query(
464
+ func.count().label("total_queries"),
465
+ func.avg(CapturedQuery.duration_ms).label("avg_query_time"),
466
+ func.sum(
467
+ case((CapturedQuery.duration_ms >= slow_threshold, 1), else_=0)
468
+ ).label("slow_queries"),
469
+ )
470
+ .filter(CapturedQuery.created_at >= since)
471
+ .one()
315
472
  )
316
473
 
317
474
  exceptions = (
318
- session.query(CapturedException)
475
+ session.query(func.count().label("total_exceptions"))
319
476
  .filter(CapturedException.created_at >= since)
320
- .all()
477
+ .one()
321
478
  )
322
479
 
323
- total_requests = len(requests)
324
- avg_response_time = None
325
- if requests:
326
- valid_times = [r.duration_ms for r in requests if r.duration_ms is not None]
327
- if valid_times:
328
- avg_response_time = sum(valid_times) / len(valid_times)
329
-
330
- total_queries = len(queries)
331
- avg_query_time = None
332
- slow_queries = 0
333
- if queries:
334
- valid_times = [q.duration_ms for q in queries if q.duration_ms is not None]
335
- if valid_times:
336
- avg_query_time = sum(valid_times) / len(valid_times)
337
- slow_queries = len(
338
- [
339
- q
340
- for q in queries
341
- if q.duration_ms and q.duration_ms >= slow_threshold
342
- ]
343
- )
480
+ total_requests = requests.total_requests
481
+ avg_response_time = requests.avg_response_time
482
+
483
+ total_queries = queries.total_queries
484
+ avg_query_time = queries.avg_query_time
485
+ slow_queries = queries.slow_queries or 0
486
+
487
+ total_exceptions = exceptions.total_exceptions
344
488
 
345
489
  requests_per_minute = total_requests / (hours * 60)
346
490
 
@@ -349,7 +493,7 @@ def create_api_router(get_session_context) -> APIRouter:
349
493
  avg_response_time=round_float(avg_response_time),
350
494
  total_queries=total_queries,
351
495
  avg_query_time=round_float(avg_query_time),
352
- total_exceptions=len(exceptions),
496
+ total_exceptions=total_exceptions,
353
497
  slow_queries=slow_queries,
354
498
  requests_per_minute=round_float(requests_per_minute),
355
499
  )
@@ -359,7 +503,7 @@ def create_api_router(get_session_context) -> APIRouter:
359
503
  older_than_hours: Optional[int] = None, session: Session = Depends(get_db)
360
504
  ):
361
505
  if older_than_hours:
362
- cutoff = datetime.utcnow() - timedelta(hours=older_than_hours)
506
+ cutoff = datetime.now(timezone.utc) - timedelta(hours=older_than_hours)
363
507
  session.query(CapturedRequest).filter(
364
508
  CapturedRequest.created_at < cutoff
365
509
  ).delete()
@@ -382,7 +526,7 @@ def create_api_router(get_session_context) -> APIRouter:
382
526
  session: Session = Depends(get_db),
383
527
  ):
384
528
  """List traces."""
385
- since = datetime.utcnow() - timedelta(hours=hours)
529
+ since = datetime.now(timezone.utc) - timedelta(hours=hours)
386
530
  query = session.query(Trace).filter(Trace.created_at >= since)
387
531
 
388
532
  if status:
@@ -491,4 +635,43 @@ def create_api_router(get_session_context) -> APIRouter:
491
635
  "created_at": span.created_at.isoformat(),
492
636
  }
493
637
 
638
@router.get("/background-tasks", response_model=List[BackgroundTaskSummary])
async def get_background_tasks(
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    status: Optional[str] = None,
    request_id: Optional[str] = None,
    session: Session = Depends(get_db),
):
    """List tracked background tasks, newest first, optionally filtered by status or request."""
    task_query = session.query(BackgroundTask)

    # Each (column, requested value) pair narrows the query only when a value was given.
    column_filters = (
        (BackgroundTask.status, status),
        (BackgroundTask.request_id, request_id),
    )
    for column, wanted in column_filters:
        if wanted:
            task_query = task_query.filter(column == wanted)

    newest_first = task_query.order_by(desc(BackgroundTask.created_at))
    page = newest_first.offset(offset).limit(limit).all()

    summaries = []
    for row in page:
        summaries.append(
            BackgroundTaskSummary(
                id=row.id,
                task_id=row.task_id,
                request_id=row.request_id,
                name=row.name,
                status=row.status,
                start_time=row.start_time,
                end_time=row.end_time,
                duration_ms=round_float(row.duration_ms),
                error=row.error,
                created_at=row.created_at,
            )
        )
    return summaries
676
+
494
677
  return router
@@ -0,0 +1,120 @@
1
+ """Background task monitoring for FastAPI Radar."""
2
+
3
+ import inspect
4
+ import time
5
+ import uuid
6
+ from datetime import datetime, timezone
7
+ from functools import wraps
8
+ from typing import Callable, Any
9
+
10
+ from .models import BackgroundTask
11
+
12
+
13
def track_background_task(get_session: Callable):
    """Build a decorator that records each call of a function as a ``BackgroundTask`` row.

    ``get_session`` is called with no arguments and used as a context manager
    yielding a session that supports ``add``, ``query`` and ``commit``.

    The decorated function (sync or ``async``) accepts one extra, optional
    keyword argument, ``_radar_request_id``. It is removed before the real
    call and stored as the row's ``request_id``::

        tracked = track_background_task(get_session)(send_email)
        background_tasks.add_task(tracked, "user@example.com", _radar_request_id="abc123")

    For every call, a row with status ``"running"`` is inserted before the
    function starts. Afterwards it is updated with ``"completed"`` or
    ``"failed"``, the end time, the duration in milliseconds and, on
    failure, the error text.

    Tracking is best-effort. If writing either row fails, the error is
    logged, and the wrapped function still runs and returns its value or
    raises its own exception unchanged.

    Args:
        get_session: Zero-argument factory returning a session context manager.

    Returns:
        A decorator. It returns an ``async`` wrapper for coroutine functions
        and a plain wrapper otherwise.
    """
    import logging

    logger = logging.getLogger(__name__)

    def _record_start(task_id: str, request_id: Any, name: str) -> bool:
        """Insert the "running" row; return whether tracking is active for this call."""
        try:
            with get_session() as session:
                session.add(
                    BackgroundTask(
                        task_id=task_id,
                        request_id=request_id,
                        name=name,
                        status="running",
                        start_time=datetime.now(timezone.utc),
                    )
                )
                session.commit()
        except Exception:
            logger.exception(
                "FastAPI Radar: could not record start of background task %r", name
            )
            return False
        return True

    def _record_end(
        task_id: str, name: str, status: str, duration_ms: float, error: Any
    ) -> None:
        """Write the final outcome onto the task row; failures are logged, never raised."""
        try:
            with get_session() as session:
                task = (
                    session.query(BackgroundTask)
                    .filter(BackgroundTask.task_id == task_id)
                    .first()
                )
                if task:
                    task.status = status
                    task.end_time = datetime.now(timezone.utc)
                    task.duration_ms = duration_ms
                    task.error = error
                    session.commit()
        except Exception:
            logger.exception(
                "FastAPI Radar: could not record outcome of background task %r", name
            )

    def decorator(func: Callable) -> Callable:
        # Plain function name rather than the module path. Fall back to repr()
        # for callables without __name__, such as functools.partial objects.
        task_name = getattr(func, "__name__", repr(func))

        def _begin(kwargs: dict) -> tuple:
            """Pop the radar kwarg and insert the "running" row; return (task_id, tracked)."""
            task_id = str(uuid.uuid4())
            req_id = kwargs.pop("_radar_request_id", None)
            return task_id, _record_start(task_id, req_id, task_name)

        def _end(task_id: str, tracked: bool, started: float, status: str, error: Any) -> None:
            """Record the outcome, but only if the start row was written."""
            if tracked:
                duration_ms = (time.perf_counter() - started) * 1000
                _record_end(task_id, task_name, status, duration_ms, error)

        @wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            # NOTE: get_session() is a synchronous context manager, so the
            # tracking writes run on the event loop thread.
            task_id, tracked = _begin(kwargs)
            started = time.perf_counter()  # monotonic clock for durations
            status, error = "failed", None
            try:
                result = await func(*args, **kwargs)
                status = "completed"
                return result
            except BaseException as e:
                # Catch BaseException (and re-raise) so that cancellation is
                # recorded as "failed" too; asyncio.CancelledError is not an
                # Exception subclass.
                error = str(e) or type(e).__name__
                raise
            finally:
                _end(task_id, tracked, started, status, error)

        @wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            task_id, tracked = _begin(kwargs)
            started = time.perf_counter()  # monotonic clock for durations
            status, error = "failed", None
            try:
                result = func(*args, **kwargs)
                status = "completed"
                return result
            except BaseException as e:
                # See async_wrapper: record interrupts as "failed", then re-raise.
                error = str(e) or type(e).__name__
                raise
            finally:
                _end(task_id, tracked, started, status, error)

        # Return the wrapper matching the function type.
        if inspect.iscoroutinefunction(func):
            return async_wrapper
        return sync_wrapper

    return decorator
fastapi_radar/capture.py CHANGED
@@ -1,13 +1,18 @@
1
1
  """SQLAlchemy query capture for FastAPI Radar."""
2
2
 
3
3
  import time
4
- from typing import Any, Callable, Dict, List, Union
5
-
4
+ from typing import Any, Callable, Dict, List, Optional, Union
6
5
  from sqlalchemy import event
7
6
  from sqlalchemy.engine import Engine
8
7
 
8
+ try: # SQLAlchemy async support is optional
9
+ from sqlalchemy.ext.asyncio import AsyncEngine
10
+ except Exception: # pragma: no cover - module might not exist in older SQLAlchemy
11
+ AsyncEngine = None # type: ignore[assignment]
9
12
  from .middleware import request_context
10
13
  from .models import CapturedQuery
14
+
15
+
11
16
  from .utils import format_sql
12
17
  from .tracing import get_current_trace_context
13
18
 
@@ -23,14 +28,20 @@ class QueryCapture:
23
28
  self.capture_bindings = capture_bindings
24
29
  self.slow_query_threshold = slow_query_threshold
25
30
  self._query_start_times = {}
31
+ self._registered_engines: Dict[int, Engine] = {}
26
32
 
27
33
  def register(self, engine: Engine) -> None:
28
- event.listen(engine, "before_cursor_execute", self._before_cursor_execute)
29
- event.listen(engine, "after_cursor_execute", self._after_cursor_execute)
34
+ sync_engine = self._resolve_engine(engine)
35
+ event.listen(sync_engine, "before_cursor_execute", self._before_cursor_execute)
36
+ event.listen(sync_engine, "after_cursor_execute", self._after_cursor_execute)
37
+ self._registered_engines[id(engine)] = sync_engine
30
38
 
31
39
  def unregister(self, engine: Engine) -> None:
32
- event.remove(engine, "before_cursor_execute", self._before_cursor_execute)
33
- event.remove(engine, "after_cursor_execute", self._after_cursor_execute)
40
+ sync_engine = self._registered_engines.pop(id(engine), None)
41
+ if not sync_engine:
42
+ sync_engine = self._resolve_engine(engine)
43
+ event.remove(sync_engine, "before_cursor_execute", self._before_cursor_execute)
44
+ event.remove(sync_engine, "after_cursor_execute", self._after_cursor_execute)
34
45
 
35
46
  def _before_cursor_execute(
36
47
  self,
@@ -44,14 +55,14 @@ class QueryCapture:
44
55
  request_id = request_context.get()
45
56
  if not request_id:
46
57
  return
47
-
48
58
  context_id = id(context)
49
59
  self._query_start_times[context_id] = time.time()
50
-
60
+ setattr(context, "_radar_request_id", request_id)
51
61
  trace_ctx = get_current_trace_context()
52
62
  if trace_ctx:
53
63
  formatted_sql = format_sql(statement)
54
64
  operation_type = self._get_operation_type(statement)
65
+ db_tags = self._get_db_tags(conn)
55
66
  span_id = trace_ctx.create_span(
56
67
  operation_name=f"DB {operation_type}",
57
68
  span_kind="client",
@@ -59,6 +70,7 @@ class QueryCapture:
59
70
  "db.statement": formatted_sql[:500], # limit SQL length
60
71
  "db.operation_type": operation_type,
61
72
  "component": "database",
73
+ **db_tags,
62
74
  },
63
75
  )
64
76
  setattr(context, "_radar_span_id", span_id)
@@ -74,8 +86,9 @@ class QueryCapture:
74
86
  ) -> None:
75
87
  request_id = request_context.get()
76
88
  if not request_id:
77
- return
78
-
89
+ request_id = getattr(context, "_radar_request_id", None)
90
+ if not request_id:
91
+ return
79
92
  start_time = self._query_start_times.pop(id(context), None)
80
93
  if start_time is None:
81
94
  return
@@ -160,3 +173,18 @@ class QueryCapture:
160
173
  return {k: str(v) for k, v in list(parameters.items())[:100]}
161
174
 
162
175
  return [str(parameters)]
176
+
177
def _resolve_engine(self, engine: Engine) -> Engine:
    """Return the synchronous Engine that cursor events must be attached to.

    An ``AsyncEngine`` (when the asyncio extension could be imported) is
    unwrapped to its ``sync_engine``. Any other engine is returned unchanged.
    """
    is_async_engine = AsyncEngine is not None and isinstance(engine, AsyncEngine)
    return engine.sync_engine if is_async_engine else engine
181
+
182
def _get_db_tags(self, conn: Any) -> Dict[str, Optional[str]]:
    """Collect ``db.system`` / ``db.name`` span tags from ``conn``'s engine.

    Returns an empty dict when the connection exposes no engine or the engine
    has no dialect. ``db.name`` is only added when the engine has a URL; its
    value may be ``None`` when that URL carries no database name.
    """
    tags: Dict[str, Optional[str]] = {}
    engine = getattr(conn, "engine", None)
    if not engine or not getattr(engine, "dialect", None):
        return tags

    tags["db.system"] = engine.dialect.name
    url = getattr(engine, "url", None)
    if url is not None:
        tags["db.name"] = getattr(url, "database", None)
    return tags