svc-infra 0.1.600__py3-none-any.whl → 0.1.640__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of svc-infra might be problematic. Click here for more details.

Files changed (118)
  1. svc_infra/api/fastapi/admin/__init__.py +3 -0
  2. svc_infra/api/fastapi/admin/add.py +231 -0
  3. svc_infra/api/fastapi/billing/router.py +64 -0
  4. svc_infra/api/fastapi/billing/setup.py +19 -0
  5. svc_infra/api/fastapi/db/sql/add.py +32 -13
  6. svc_infra/api/fastapi/db/sql/crud_router.py +178 -16
  7. svc_infra/api/fastapi/db/sql/session.py +16 -0
  8. svc_infra/api/fastapi/dependencies/ratelimit.py +57 -7
  9. svc_infra/api/fastapi/docs/add.py +160 -0
  10. svc_infra/api/fastapi/docs/landing.py +1 -1
  11. svc_infra/api/fastapi/middleware/errors/handlers.py +45 -7
  12. svc_infra/api/fastapi/middleware/graceful_shutdown.py +87 -0
  13. svc_infra/api/fastapi/middleware/ratelimit.py +59 -1
  14. svc_infra/api/fastapi/middleware/ratelimit_store.py +12 -6
  15. svc_infra/api/fastapi/middleware/timeout.py +148 -0
  16. svc_infra/api/fastapi/openapi/mutators.py +114 -0
  17. svc_infra/api/fastapi/ops/add.py +73 -0
  18. svc_infra/api/fastapi/pagination.py +3 -1
  19. svc_infra/api/fastapi/routers/ping.py +1 -0
  20. svc_infra/api/fastapi/setup.py +11 -1
  21. svc_infra/api/fastapi/tenancy/add.py +19 -0
  22. svc_infra/api/fastapi/tenancy/context.py +112 -0
  23. svc_infra/app/README.md +5 -5
  24. svc_infra/billing/__init__.py +23 -0
  25. svc_infra/billing/async_service.py +147 -0
  26. svc_infra/billing/jobs.py +230 -0
  27. svc_infra/billing/models.py +131 -0
  28. svc_infra/billing/quotas.py +101 -0
  29. svc_infra/billing/schemas.py +33 -0
  30. svc_infra/billing/service.py +115 -0
  31. svc_infra/bundled_docs/README.md +5 -0
  32. svc_infra/bundled_docs/__init__.py +1 -0
  33. svc_infra/bundled_docs/getting-started.md +6 -0
  34. svc_infra/cache/__init__.py +4 -0
  35. svc_infra/cache/add.py +158 -0
  36. svc_infra/cache/backend.py +5 -2
  37. svc_infra/cache/decorators.py +19 -1
  38. svc_infra/cache/keys.py +24 -4
  39. svc_infra/cli/__init__.py +28 -8
  40. svc_infra/cli/cmds/__init__.py +8 -0
  41. svc_infra/cli/cmds/db/nosql/mongo/mongo_cmds.py +4 -3
  42. svc_infra/cli/cmds/db/nosql/mongo/mongo_scaffold_cmds.py +4 -4
  43. svc_infra/cli/cmds/db/sql/alembic_cmds.py +80 -11
  44. svc_infra/cli/cmds/db/sql/sql_export_cmds.py +80 -0
  45. svc_infra/cli/cmds/db/sql/sql_scaffold_cmds.py +3 -3
  46. svc_infra/cli/cmds/docs/docs_cmds.py +140 -0
  47. svc_infra/cli/cmds/dx/__init__.py +12 -0
  48. svc_infra/cli/cmds/dx/dx_cmds.py +99 -0
  49. svc_infra/cli/cmds/help.py +4 -0
  50. svc_infra/cli/cmds/obs/obs_cmds.py +4 -3
  51. svc_infra/cli/cmds/sdk/__init__.py +0 -0
  52. svc_infra/cli/cmds/sdk/sdk_cmds.py +102 -0
  53. svc_infra/data/add.py +61 -0
  54. svc_infra/data/backup.py +53 -0
  55. svc_infra/data/erasure.py +45 -0
  56. svc_infra/data/fixtures.py +40 -0
  57. svc_infra/data/retention.py +55 -0
  58. svc_infra/db/nosql/mongo/README.md +13 -13
  59. svc_infra/db/sql/repository.py +51 -11
  60. svc_infra/db/sql/resource.py +5 -0
  61. svc_infra/db/sql/templates/setup/env_async.py.tmpl +9 -1
  62. svc_infra/db/sql/templates/setup/env_sync.py.tmpl +9 -2
  63. svc_infra/db/sql/tenant.py +79 -0
  64. svc_infra/db/sql/utils.py +18 -4
  65. svc_infra/docs/acceptance-matrix.md +71 -0
  66. svc_infra/docs/acceptance.md +44 -0
  67. svc_infra/docs/admin.md +425 -0
  68. svc_infra/docs/adr/0002-background-jobs-and-scheduling.md +40 -0
  69. svc_infra/docs/adr/0003-webhooks-framework.md +24 -0
  70. svc_infra/docs/adr/0004-tenancy-model.md +42 -0
  71. svc_infra/docs/adr/0005-data-lifecycle.md +86 -0
  72. svc_infra/docs/adr/0006-ops-slos-and-metrics.md +47 -0
  73. svc_infra/docs/adr/0007-docs-and-sdks.md +83 -0
  74. svc_infra/docs/adr/0008-billing-primitives.md +143 -0
  75. svc_infra/docs/adr/0009-acceptance-harness.md +40 -0
  76. svc_infra/docs/adr/0010-timeouts-and-resource-limits.md +54 -0
  77. svc_infra/docs/adr/0011-admin-scope-and-impersonation.md +73 -0
  78. svc_infra/docs/api.md +59 -0
  79. svc_infra/docs/auth.md +11 -0
  80. svc_infra/docs/billing.md +190 -0
  81. svc_infra/docs/cache.md +76 -0
  82. svc_infra/docs/cli.md +74 -0
  83. svc_infra/docs/contributing.md +34 -0
  84. svc_infra/docs/data-lifecycle.md +52 -0
  85. svc_infra/docs/database.md +14 -0
  86. svc_infra/docs/docs-and-sdks.md +62 -0
  87. svc_infra/docs/environment.md +114 -0
  88. svc_infra/docs/getting-started.md +63 -0
  89. svc_infra/docs/idempotency.md +111 -0
  90. svc_infra/docs/jobs.md +67 -0
  91. svc_infra/docs/observability.md +16 -0
  92. svc_infra/docs/ops.md +37 -0
  93. svc_infra/docs/rate-limiting.md +125 -0
  94. svc_infra/docs/repo-review.md +48 -0
  95. svc_infra/docs/security.md +176 -0
  96. svc_infra/docs/tenancy.md +35 -0
  97. svc_infra/docs/timeouts-and-resource-limits.md +147 -0
  98. svc_infra/docs/webhooks.md +112 -0
  99. svc_infra/dx/add.py +63 -0
  100. svc_infra/dx/changelog.py +74 -0
  101. svc_infra/dx/checks.py +67 -0
  102. svc_infra/http/__init__.py +13 -0
  103. svc_infra/http/client.py +72 -0
  104. svc_infra/jobs/builtins/webhook_delivery.py +14 -2
  105. svc_infra/jobs/queue.py +9 -1
  106. svc_infra/jobs/runner.py +75 -0
  107. svc_infra/jobs/worker.py +17 -1
  108. svc_infra/mcp/svc_infra_mcp.py +85 -28
  109. svc_infra/obs/add.py +54 -7
  110. svc_infra/obs/grafana/dashboards/http-overview.json +45 -0
  111. svc_infra/security/headers.py +15 -2
  112. svc_infra/security/hibp.py +6 -2
  113. svc_infra/security/permissions.py +1 -0
  114. svc_infra/webhooks/service.py +10 -2
  115. {svc_infra-0.1.600.dist-info → svc_infra-0.1.640.dist-info}/METADATA +40 -14
  116. {svc_infra-0.1.600.dist-info → svc_infra-0.1.640.dist-info}/RECORD +118 -44
  117. {svc_infra-0.1.600.dist-info → svc_infra-0.1.640.dist-info}/WHEEL +0 -0
  118. {svc_infra-0.1.600.dist-info → svc_infra-0.1.640.dist-info}/entry_points.txt +0 -0
@@ -1,12 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import time
4
- from typing import Callable
4
+ from typing import Callable, Optional
5
5
 
6
6
  from fastapi import HTTPException
7
7
  from starlette.requests import Request
8
8
 
9
9
  from svc_infra.api.fastapi.middleware.ratelimit_store import InMemoryRateLimitStore, RateLimitStore
10
+
11
+ try:
12
+ from svc_infra.api.fastapi.tenancy.context import resolve_tenant_id as _resolve_tenant_id
13
+ except Exception: # pragma: no cover - minimal builds
14
+ _resolve_tenant_id = None # type: ignore
10
15
  from svc_infra.obs.metrics import emit_rate_limited
11
16
 
12
17
 
@@ -17,20 +22,44 @@ class RateLimiter:
17
22
  limit: int,
18
23
  window: int = 60,
19
24
  key_fn: Callable = lambda r: "global",
25
+ limit_resolver: Optional[Callable[[Request, Optional[str]], Optional[int]]] = None,
26
+ scope_by_tenant: bool = False,
20
27
  store: RateLimitStore | None = None,
21
28
  ):
22
29
  self.limit = limit
23
30
  self.window = window
24
31
  self.key_fn = key_fn
32
+ self._limit_resolver = limit_resolver
33
+ self.scope_by_tenant = scope_by_tenant
25
34
  self.store = store or InMemoryRateLimitStore(limit=limit)
26
35
 
27
36
  async def __call__(self, request: Request):
37
+ # Try resolving tenant when asked
38
+ tenant_id = None
39
+ if self.scope_by_tenant or self._limit_resolver:
40
+ try:
41
+ if _resolve_tenant_id is not None:
42
+ tenant_id = await _resolve_tenant_id(request)
43
+ except Exception:
44
+ tenant_id = None
45
+
28
46
  key = self.key_fn(request)
29
- count, limit, reset = self.store.incr(str(key), self.window)
30
- if count > limit:
47
+ if self.scope_by_tenant and tenant_id:
48
+ key = f"{key}:tenant:{tenant_id}"
49
+
50
+ eff_limit = self.limit
51
+ if self._limit_resolver:
52
+ try:
53
+ v = self._limit_resolver(request, tenant_id)
54
+ eff_limit = int(v) if v is not None else self.limit
55
+ except Exception:
56
+ eff_limit = self.limit
57
+
58
+ count, store_limit, reset = self.store.incr(str(key), self.window)
59
+ if count > eff_limit:
31
60
  retry = max(0, reset - int(time.time()))
32
61
  try:
33
- emit_rate_limited(str(key), limit, retry)
62
+ emit_rate_limited(str(key), eff_limit, retry)
34
63
  except Exception:
35
64
  pass
36
65
  raise HTTPException(
@@ -46,17 +75,38 @@ def rate_limiter(
46
75
  limit: int,
47
76
  window: int = 60,
48
77
  key_fn: Callable = lambda r: "global",
78
+ limit_resolver: Optional[Callable[[Request, Optional[str]], Optional[int]]] = None,
79
+ scope_by_tenant: bool = False,
49
80
  store: RateLimitStore | None = None,
50
81
  ):
51
82
  store_ = store or InMemoryRateLimitStore(limit=limit)
52
83
 
53
84
  async def dep(request: Request):
85
+ tenant_id = None
86
+ if scope_by_tenant or limit_resolver:
87
+ try:
88
+ if _resolve_tenant_id is not None:
89
+ tenant_id = await _resolve_tenant_id(request)
90
+ except Exception:
91
+ tenant_id = None
92
+
54
93
  key = key_fn(request)
55
- count, lim, reset = store_.incr(str(key), window)
56
- if count > lim:
94
+ if scope_by_tenant and tenant_id:
95
+ key = f"{key}:tenant:{tenant_id}"
96
+
97
+ eff_limit = limit
98
+ if limit_resolver:
99
+ try:
100
+ v = limit_resolver(request, tenant_id)
101
+ eff_limit = int(v) if v is not None else limit
102
+ except Exception:
103
+ eff_limit = limit
104
+
105
+ count, _store_limit, reset = store_.incr(str(key), window)
106
+ if count > eff_limit:
57
107
  retry = max(0, reset - int(time.time()))
58
108
  try:
59
- emit_rate_limited(str(key), lim, retry)
109
+ emit_rate_limited(str(key), eff_limit, retry)
60
110
  except Exception:
61
111
  pass
62
112
  raise HTTPException(
@@ -0,0 +1,160 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ from fastapi import FastAPI, Request
8
+ from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
9
+ from fastapi.responses import HTMLResponse, JSONResponse
10
+
11
+ from .landing import CardSpec, DocTargets, render_index_html
12
+ from .scoped import DOC_SCOPES
13
+
14
+
15
+ def add_docs(
16
+ app: FastAPI,
17
+ *,
18
+ redoc_url: str = "/redoc",
19
+ swagger_url: str = "/docs",
20
+ openapi_url: str = "/openapi.json",
21
+ export_openapi_to: Optional[str] = None,
22
+ # Landing page options
23
+ landing_url: str = "/",
24
+ include_landing: bool = True,
25
+ ) -> None:
26
+ """Enable docs endpoints and optionally export OpenAPI schema to disk on startup.
27
+
28
+ We mount docs and OpenAPI routes explicitly so this works even when configured post-init.
29
+ """
30
+
31
+ # OpenAPI JSON route
32
+ async def openapi_handler() -> JSONResponse: # noqa: ANN201
33
+ return JSONResponse(app.openapi())
34
+
35
+ app.add_api_route(openapi_url, openapi_handler, methods=["GET"], include_in_schema=False)
36
+
37
+ # Swagger UI route
38
+ async def swagger_ui(request: Request) -> HTMLResponse: # noqa: ANN201
39
+ resp = get_swagger_ui_html(openapi_url=openapi_url, title="API Docs")
40
+ theme = request.query_params.get("theme")
41
+ if theme == "dark":
42
+ return _with_dark_mode(resp)
43
+ return resp
44
+
45
+ app.add_api_route(swagger_url, swagger_ui, methods=["GET"], include_in_schema=False)
46
+
47
+ # Redoc route
48
+ async def redoc_ui(request: Request) -> HTMLResponse: # noqa: ANN201
49
+ resp = get_redoc_html(openapi_url=openapi_url, title="API ReDoc")
50
+ theme = request.query_params.get("theme")
51
+ if theme == "dark":
52
+ return _with_dark_mode(resp)
53
+ return resp
54
+
55
+ app.add_api_route(redoc_url, redoc_ui, methods=["GET"], include_in_schema=False)
56
+
57
+ # Optional export to disk on startup
58
+ if export_openapi_to:
59
+ export_path = Path(export_openapi_to)
60
+
61
+ async def _export_docs() -> None:
62
+ # Startup export
63
+ spec = app.openapi()
64
+ export_path.parent.mkdir(parents=True, exist_ok=True)
65
+ export_path.write_text(json.dumps(spec, indent=2))
66
+
67
+ app.add_event_handler("startup", _export_docs)
68
+
69
+ # Optional landing page with the same look/feel as setup_service_api
70
+ if include_landing:
71
+ # Avoid path collision; if landing_url is already taken for GET, fallback to "/_docs"
72
+ existing_paths = {
73
+ (getattr(r, "path", None) or getattr(r, "path_format", None))
74
+ for r in getattr(app, "routes", [])
75
+ if getattr(r, "methods", None) and "GET" in r.methods
76
+ }
77
+ landing_path = landing_url or "/"
78
+ if landing_path in existing_paths:
79
+ landing_path = "/_docs"
80
+
81
+ async def _landing() -> HTMLResponse: # noqa: ANN201
82
+ cards: list[CardSpec] = []
83
+ # Root docs card using the provided paths
84
+ cards.append(
85
+ CardSpec(
86
+ tag="",
87
+ docs=DocTargets(swagger=swagger_url, redoc=redoc_url, openapi_json=openapi_url),
88
+ )
89
+ )
90
+ # Scoped docs (if any were registered via add_prefixed_docs)
91
+ for scope, swagger, redoc, openapi_json, _title in DOC_SCOPES:
92
+ cards.append(
93
+ CardSpec(
94
+ tag=scope.strip("/"),
95
+ docs=DocTargets(swagger=swagger, redoc=redoc, openapi_json=openapi_json),
96
+ )
97
+ )
98
+ html = render_index_html(
99
+ service_name=app.title or "API", release=app.version or "", cards=cards
100
+ )
101
+ return HTMLResponse(html)
102
+
103
+ app.add_api_route(landing_path, _landing, methods=["GET"], include_in_schema=False)
104
+
105
+
106
+ def _with_dark_mode(resp: HTMLResponse) -> HTMLResponse:
107
+ """Return a copy of the HTMLResponse with a minimal dark-theme CSS injected.
108
+
109
+ We avoid depending on custom Swagger/ReDoc builds; this works by inlining a small CSS
110
+ block and toggling a `.dark` class on the body element.
111
+ """
112
+ try:
113
+ body = resp.body.decode("utf-8", errors="ignore")
114
+ except Exception: # pragma: no cover - very unlikely
115
+ return resp
116
+
117
+ css = _DARK_CSS
118
+ if "</head>" in body:
119
+ body = body.replace("</head>", f"<style>\n{css}\n</style></head>", 1)
120
+ # add class to body to allow stronger selectors
121
+ body = body.replace("<body>", '<body class="dark">', 1)
122
+ return HTMLResponse(content=body, status_code=resp.status_code, headers=dict(resp.headers))
123
+
124
+
125
+ _DARK_CSS = """
126
+ /* Minimal dark mode override for Swagger/ReDoc */
127
+ @media (prefers-color-scheme: dark) { :root { color-scheme: dark; } }
128
+ html.dark, body.dark { background: #0b0e14; color: #e0e6f1; }
129
+ #swagger, .redoc-wrap { background: transparent; }
130
+ a { color: #62aef7; }
131
+ """
132
+
133
+
134
+ def add_sdk_generation_stub(
135
+ app: FastAPI,
136
+ *,
137
+ on_generate: Optional[callable] = None,
138
+ openapi_path: str = "/openapi.json",
139
+ ) -> None:
140
+ """Hook to add an SDK generation stub.
141
+
142
+ Provide `on_generate()` to run generation (e.g., openapi-generator). This is a stub only; we
143
+ don't ship a hard dependency. If `on_generate` is provided, we expose `/_docs/generate-sdk`.
144
+ """
145
+ from svc_infra.api.fastapi.dual.public import public_router
146
+
147
+ if not on_generate:
148
+ return
149
+
150
+ router = public_router(prefix="/_docs", include_in_schema=False)
151
+
152
+ @router.post("/generate-sdk")
153
+ async def _generate() -> dict: # noqa: ANN201
154
+ on_generate()
155
+ return {"status": "ok"}
156
+
157
+ app.include_router(router)
158
+
159
+
160
+ __all__ = ["add_docs", "add_sdk_generation_stub"]
@@ -115,7 +115,7 @@ def render_index_html(*, service_name: str, release: str, cards: Iterable[CardSp
115
115
  <section class="grid">
116
116
  {grid}
117
117
  </section>
118
- <footer>Tip: each card exposes Swagger, ReDoc, and a pretty JSON view.</footer>
118
+ <footer>Tip: each card exposes Swagger, ReDoc, and a JSON view.</footer>
119
119
  </div>
120
120
  </body>
121
121
  </html>
@@ -2,6 +2,7 @@ import logging
2
2
  import traceback
3
3
  from typing import Any, Dict, Optional
4
4
 
5
+ import httpx
5
6
  from fastapi import Request
6
7
  from fastapi.exceptions import HTTPException, RequestValidationError
7
8
  from fastapi.responses import JSONResponse, Response
@@ -46,6 +47,7 @@ def problem_response(
46
47
  code: str | None = None,
47
48
  errors: list[dict] | None = None,
48
49
  trace_id: str | None = None,
50
+ headers: dict[str, str] | None = None,
49
51
  ) -> Response:
50
52
  body: Dict[str, Any] = {
51
53
  "type": type_uri,
@@ -62,10 +64,24 @@ def problem_response(
62
64
  body["errors"] = errors
63
65
  if trace_id:
64
66
  body["trace_id"] = trace_id
65
- return JSONResponse(status_code=status, content=body, media_type=PROBLEM_MT)
67
+ return JSONResponse(status_code=status, content=body, media_type=PROBLEM_MT, headers=headers)
66
68
 
67
69
 
68
70
  def register_error_handlers(app):
71
+ @app.exception_handler(httpx.TimeoutException)
72
+ async def handle_httpx_timeout(request: Request, exc: httpx.TimeoutException):
73
+ trace_id = _trace_id_from_request(request)
74
+ # Map outbound HTTP client timeouts to 504 Gateway Timeout
75
+ # Keep details generic in prod
76
+ return problem_response(
77
+ status=504,
78
+ title="Gateway Timeout",
79
+ detail=("Upstream request timed out." if IS_PROD else (str(exc) or "httpx timeout")),
80
+ code="GATEWAY_TIMEOUT",
81
+ instance=str(request.url),
82
+ trace_id=trace_id,
83
+ )
84
+
69
85
  @app.exception_handler(FastApiException)
70
86
  async def handle_app_exception(request: Request, exc: FastApiException):
71
87
  trace_id = _trace_id_from_request(request)
@@ -104,14 +120,25 @@ def register_error_handlers(app):
104
120
  @app.exception_handler(HTTPException)
105
121
  async def handle_http_exception(request: Request, exc: HTTPException):
106
122
  trace_id = _trace_id_from_request(request)
107
- title = {401: "Unauthorized", 403: "Forbidden", 404: "Not Found"}.get(
108
- exc.status_code, "Error"
109
- )
123
+ title = {
124
+ 401: "Unauthorized",
125
+ 403: "Forbidden",
126
+ 404: "Not Found",
127
+ 429: "Too Many Requests",
128
+ }.get(exc.status_code, "Error")
110
129
  detail = (
111
130
  exc.detail
112
131
  if not IS_PROD or exc.status_code < 500
113
132
  else "Something went wrong. Please contact support."
114
133
  )
134
+ # Preserve headers set on the exception (e.g., Retry-After for rate limits)
135
+ hdrs: dict[str, str] | None = None
136
+ try:
137
+ if getattr(exc, "headers", None):
138
+ # FastAPI/Starlette exceptions store headers as a dict[str, str]
139
+ hdrs = dict(getattr(exc, "headers")) # type: ignore[arg-type]
140
+ except Exception:
141
+ hdrs = None
115
142
  return problem_response(
116
143
  status=exc.status_code,
117
144
  title=title,
@@ -119,19 +146,29 @@ def register_error_handlers(app):
119
146
  code=title.replace(" ", "_").upper(),
120
147
  instance=str(request.url),
121
148
  trace_id=trace_id,
149
+ headers=hdrs,
122
150
  )
123
151
 
124
152
  @app.exception_handler(StarletteHTTPException)
125
153
  async def handle_starlette_http_exception(request: Request, exc: StarletteHTTPException):
126
154
  trace_id = _trace_id_from_request(request)
127
- title = {401: "Unauthorized", 403: "Forbidden", 404: "Not Found"}.get(
128
- exc.status_code, "Error"
129
- )
155
+ title = {
156
+ 401: "Unauthorized",
157
+ 403: "Forbidden",
158
+ 404: "Not Found",
159
+ 429: "Too Many Requests",
160
+ }.get(exc.status_code, "Error")
130
161
  detail = (
131
162
  exc.detail
132
163
  if not IS_PROD or exc.status_code < 500
133
164
  else "Something went wrong. Please contact support."
134
165
  )
166
+ hdrs: dict[str, str] | None = None
167
+ try:
168
+ if getattr(exc, "headers", None):
169
+ hdrs = dict(getattr(exc, "headers")) # type: ignore[arg-type]
170
+ except Exception:
171
+ hdrs = None
135
172
  return problem_response(
136
173
  status=exc.status_code,
137
174
  title=title,
@@ -139,6 +176,7 @@ def register_error_handlers(app):
139
176
  code=title.replace(" ", "_").upper(),
140
177
  instance=str(request.url),
141
178
  trace_id=trace_id,
179
+ headers=hdrs,
142
180
  )
143
181
 
144
182
  @app.exception_handler(IntegrityError)
@@ -0,0 +1,87 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ import os
6
+ from contextlib import asynccontextmanager
7
+ from typing import Optional
8
+
9
+ from fastapi import FastAPI
10
+ from starlette.types import ASGIApp, Receive, Scope, Send
11
+
12
+ from svc_infra.app.env import pick
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ def _get_grace_period_seconds() -> float:
18
+ default = pick(prod=20.0, nonprod=5.0)
19
+ raw = os.getenv("SHUTDOWN_GRACE_PERIOD_SECONDS")
20
+ if raw is None or raw == "":
21
+ return float(default)
22
+ try:
23
+ return float(raw)
24
+ except ValueError:
25
+ return float(default)
26
+
27
+
28
+ class InflightTrackerMiddleware:
29
+ """Tracks number of in-flight requests to support graceful shutdown drains."""
30
+
31
+ def __init__(self, app: ASGIApp):
32
+ self.app = app
33
+
34
+ async def __call__(self, scope: Scope, receive: Receive, send: Send):
35
+ if scope.get("type") != "http":
36
+ await self.app(scope, receive, send)
37
+ return
38
+ state = scope.get("app").state # type: ignore[attr-defined]
39
+ state._inflight_requests = getattr(state, "_inflight_requests", 0) + 1
40
+ try:
41
+ await self.app(scope, receive, send)
42
+ finally:
43
+ state._inflight_requests = max(0, getattr(state, "_inflight_requests", 1) - 1)
44
+
45
+
46
+ async def _wait_for_drain(app: FastAPI, grace: float) -> None:
47
+ interval = 0.1
48
+ waited = 0.0
49
+ while waited < grace:
50
+ inflight = int(getattr(app.state, "_inflight_requests", 0))
51
+ if inflight <= 0:
52
+ return
53
+ await asyncio.sleep(interval)
54
+ waited += interval
55
+ inflight = int(getattr(app.state, "_inflight_requests", 0))
56
+ if inflight > 0:
57
+ logger.warning(
58
+ "Graceful shutdown timeout: %s in-flight request(s) after %.2fs", inflight, waited
59
+ )
60
+
61
+
62
+ def install_graceful_shutdown(app: FastAPI, *, grace_seconds: Optional[float] = None) -> None:
63
+ """Install inflight tracking and lifespan hooks to wait for requests to drain.
64
+
65
+ - Adds InflightTrackerMiddleware
66
+ - Registers a lifespan handler that initializes state and waits up to grace_seconds on shutdown
67
+ """
68
+ app.add_middleware(InflightTrackerMiddleware)
69
+
70
+ g = float(grace_seconds) if grace_seconds is not None else _get_grace_period_seconds()
71
+
72
+ # Preserve any existing lifespan and wrap it so our drain runs on shutdown.
73
+ previous_lifespan = getattr(app.router, "lifespan_context", None)
74
+
75
+ @asynccontextmanager
76
+ async def _lifespan(a: FastAPI): # noqa: ANN202
77
+ # Startup: initialize inflight counter
78
+ a.state._inflight_requests = 0
79
+ if previous_lifespan is not None:
80
+ async with previous_lifespan(a):
81
+ yield
82
+ else:
83
+ yield
84
+ # Shutdown: wait for in-flight requests to drain (up to grace period)
85
+ await _wait_for_drain(a, g)
86
+
87
+ app.router.lifespan_context = _lifespan
@@ -7,6 +7,12 @@ from svc_infra.obs.metrics import emit_rate_limited
7
7
 
8
8
  from .ratelimit_store import InMemoryRateLimitStore, RateLimitStore
9
9
 
10
+ try:
11
+ # Optional import: tenancy may not be enabled in all apps
12
+ from svc_infra.api.fastapi.tenancy.context import resolve_tenant_id as _resolve_tenant_id
13
+ except Exception: # pragma: no cover - fallback for minimal builds
14
+ _resolve_tenant_id = None # type: ignore
15
+
10
16
 
11
17
  class SimpleRateLimitMiddleware(BaseHTTPMiddleware):
12
18
  def __init__(
@@ -15,18 +21,70 @@ class SimpleRateLimitMiddleware(BaseHTTPMiddleware):
15
21
  limit: int = 120,
16
22
  window: int = 60,
17
23
  key_fn=None,
24
+ *,
25
+ # When provided, dynamically computes a limit for the current request (e.g. per-tenant quotas)
26
+ # Signature: (request: Request, tenant_id: Optional[str]) -> int | None
27
+ limit_resolver=None,
28
+ # If True, automatically scopes the bucket key by tenant id when available
29
+ scope_by_tenant: bool = False,
30
+ # When True, allows unresolved tenant IDs to fall back to an "X-Tenant-Id" header value.
31
+ # Disabled by default to avoid trusting arbitrary client-provided headers which could
32
+ # otherwise be used to evade per-tenant limits when authentication fails.
33
+ allow_untrusted_tenant_header: bool = False,
18
34
  store: RateLimitStore | None = None,
19
35
  ):
20
36
  super().__init__(app)
21
37
  self.limit, self.window = limit, window
22
38
  self.key_fn = key_fn or (lambda r: r.headers.get("X-API-Key") or r.client.host)
39
+ self._limit_resolver = limit_resolver
40
+ self.scope_by_tenant = scope_by_tenant
41
+ self._allow_untrusted_tenant_header = allow_untrusted_tenant_header
23
42
  self.store = store or InMemoryRateLimitStore(limit=limit)
24
43
 
25
44
  async def dispatch(self, request, call_next):
45
+ # Resolve tenant when possible
46
+ tenant_id = None
47
+ if self.scope_by_tenant or self._limit_resolver:
48
+ try:
49
+ if _resolve_tenant_id is not None:
50
+ tenant_id = await _resolve_tenant_id(request)
51
+ except Exception:
52
+ tenant_id = None
53
+ # Fallback header behavior:
54
+ # - If tenancy context is unavailable (minimal builds), accept header by default so
55
+ # unit/integration tests can exercise per-tenant scoping without full auth state.
56
+ # - If tenancy is available, only trust the header when explicitly allowed.
57
+ if not tenant_id:
58
+ if _resolve_tenant_id is None:
59
+ tenant_id = request.headers.get("X-Tenant-Id") or request.headers.get(
60
+ "X-Tenant-ID"
61
+ )
62
+ elif self._allow_untrusted_tenant_header:
63
+ tenant_id = request.headers.get("X-Tenant-Id") or request.headers.get(
64
+ "X-Tenant-ID"
65
+ )
66
+
26
67
  key = self.key_fn(request)
68
+ if self.scope_by_tenant and tenant_id:
69
+ key = f"{key}:tenant:{tenant_id}"
70
+
71
+ # Allow dynamic limit overrides
72
+ eff_limit = self.limit
73
+ if self._limit_resolver:
74
+ try:
75
+ v = self._limit_resolver(request, tenant_id)
76
+ eff_limit = int(v) if v is not None else self.limit
77
+ except Exception:
78
+ eff_limit = self.limit
79
+
27
80
  now = int(time.time())
28
81
  # Increment counter in store
29
- count, limit, reset = self.store.incr(str(key), self.window)
82
+ # Update store limit if it differs; stores capture configured limit internally
83
+ # For in-memory store, we can temporarily adjust per-request by swapping a new store instance
84
+ # but to keep API simple, we reuse store and clamp by eff_limit below.
85
+ count, store_limit, reset = self.store.incr(str(key), self.window)
86
+ # Enforce the effective limit selected for this request
87
+ limit = eff_limit
30
88
  remaining = max(0, limit - count)
31
89
 
32
90
  if remaining < 0: # defensive clamp
@@ -16,14 +16,20 @@ class RateLimitStore(Protocol):
16
16
  class InMemoryRateLimitStore:
17
17
  def __init__(self, limit: int = 120):
18
18
  self.limit = limit
19
- self._buckets: dict[tuple[str, int], int] = {}
19
+ # Track per-key rolling windows: key -> (count, window_start_epoch)
20
+ self._state: dict[str, tuple[int, float]] = {}
20
21
 
21
22
  def incr(self, key: str, window: int) -> Tuple[int, int, int]:
22
- now = int(time.time())
23
- win = now - (now % window)
24
- count = self._buckets.get((key, win), 0) + 1
25
- self._buckets[(key, win)] = count
26
- reset = win + window
23
+ now = time.time()
24
+ count, window_start = self._state.get(key, (0, now))
25
+ # If outside the rolling window, reset
26
+ if now >= window_start + window:
27
+ count = 1
28
+ window_start = now
29
+ else:
30
+ count += 1
31
+ self._state[key] = (count, window_start)
32
+ reset = int(window_start + window)
27
33
  return count, self.limit, reset
28
34
 
29
35