svc-infra 0.1.595__py3-none-any.whl → 0.1.706__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of svc-infra might be problematic; consult the registry's release notes for details.

Files changed (256)
  1. svc_infra/__init__.py +58 -2
  2. svc_infra/apf_payments/models.py +133 -42
  3. svc_infra/apf_payments/provider/aiydan.py +121 -47
  4. svc_infra/apf_payments/provider/base.py +30 -9
  5. svc_infra/apf_payments/provider/stripe.py +156 -62
  6. svc_infra/apf_payments/schemas.py +18 -9
  7. svc_infra/apf_payments/service.py +98 -41
  8. svc_infra/apf_payments/settings.py +5 -1
  9. svc_infra/api/__init__.py +61 -0
  10. svc_infra/api/fastapi/__init__.py +15 -0
  11. svc_infra/api/fastapi/admin/__init__.py +3 -0
  12. svc_infra/api/fastapi/admin/add.py +245 -0
  13. svc_infra/api/fastapi/apf_payments/router.py +128 -70
  14. svc_infra/api/fastapi/apf_payments/setup.py +13 -6
  15. svc_infra/api/fastapi/auth/__init__.py +65 -0
  16. svc_infra/api/fastapi/auth/_cookies.py +6 -2
  17. svc_infra/api/fastapi/auth/add.py +17 -14
  18. svc_infra/api/fastapi/auth/gaurd.py +45 -16
  19. svc_infra/api/fastapi/auth/mfa/models.py +3 -1
  20. svc_infra/api/fastapi/auth/mfa/pre_auth.py +10 -6
  21. svc_infra/api/fastapi/auth/mfa/router.py +15 -8
  22. svc_infra/api/fastapi/auth/mfa/security.py +1 -2
  23. svc_infra/api/fastapi/auth/mfa/utils.py +2 -1
  24. svc_infra/api/fastapi/auth/mfa/verify.py +9 -2
  25. svc_infra/api/fastapi/auth/policy.py +0 -1
  26. svc_infra/api/fastapi/auth/providers.py +3 -1
  27. svc_infra/api/fastapi/auth/routers/apikey_router.py +6 -6
  28. svc_infra/api/fastapi/auth/routers/oauth_router.py +146 -52
  29. svc_infra/api/fastapi/auth/routers/session_router.py +6 -2
  30. svc_infra/api/fastapi/auth/security.py +31 -10
  31. svc_infra/api/fastapi/auth/sender.py +8 -1
  32. svc_infra/api/fastapi/auth/state.py +3 -1
  33. svc_infra/api/fastapi/auth/ws_security.py +275 -0
  34. svc_infra/api/fastapi/billing/router.py +73 -0
  35. svc_infra/api/fastapi/billing/setup.py +19 -0
  36. svc_infra/api/fastapi/cache/add.py +9 -5
  37. svc_infra/api/fastapi/db/__init__.py +5 -1
  38. svc_infra/api/fastapi/db/http.py +3 -1
  39. svc_infra/api/fastapi/db/nosql/__init__.py +39 -1
  40. svc_infra/api/fastapi/db/nosql/mongo/add.py +47 -32
  41. svc_infra/api/fastapi/db/nosql/mongo/crud_router.py +30 -11
  42. svc_infra/api/fastapi/db/sql/__init__.py +5 -1
  43. svc_infra/api/fastapi/db/sql/add.py +71 -26
  44. svc_infra/api/fastapi/db/sql/crud_router.py +210 -22
  45. svc_infra/api/fastapi/db/sql/health.py +3 -1
  46. svc_infra/api/fastapi/db/sql/session.py +18 -0
  47. svc_infra/api/fastapi/db/sql/users.py +18 -6
  48. svc_infra/api/fastapi/dependencies/ratelimit.py +78 -14
  49. svc_infra/api/fastapi/docs/add.py +173 -0
  50. svc_infra/api/fastapi/docs/landing.py +4 -2
  51. svc_infra/api/fastapi/docs/scoped.py +62 -15
  52. svc_infra/api/fastapi/dual/__init__.py +12 -2
  53. svc_infra/api/fastapi/dual/dualize.py +1 -1
  54. svc_infra/api/fastapi/dual/protected.py +126 -4
  55. svc_infra/api/fastapi/dual/public.py +25 -0
  56. svc_infra/api/fastapi/dual/router.py +40 -13
  57. svc_infra/api/fastapi/dx.py +33 -2
  58. svc_infra/api/fastapi/ease.py +10 -2
  59. svc_infra/api/fastapi/http/concurrency.py +2 -1
  60. svc_infra/api/fastapi/http/conditional.py +3 -1
  61. svc_infra/api/fastapi/middleware/debug.py +4 -1
  62. svc_infra/api/fastapi/middleware/errors/catchall.py +6 -2
  63. svc_infra/api/fastapi/middleware/errors/exceptions.py +1 -1
  64. svc_infra/api/fastapi/middleware/errors/handlers.py +54 -8
  65. svc_infra/api/fastapi/middleware/graceful_shutdown.py +104 -0
  66. svc_infra/api/fastapi/middleware/idempotency.py +197 -70
  67. svc_infra/api/fastapi/middleware/idempotency_store.py +187 -0
  68. svc_infra/api/fastapi/middleware/optimistic_lock.py +42 -0
  69. svc_infra/api/fastapi/middleware/ratelimit.py +125 -28
  70. svc_infra/api/fastapi/middleware/ratelimit_store.py +43 -10
  71. svc_infra/api/fastapi/middleware/request_id.py +27 -11
  72. svc_infra/api/fastapi/middleware/request_size_limit.py +3 -3
  73. svc_infra/api/fastapi/middleware/timeout.py +177 -0
  74. svc_infra/api/fastapi/openapi/apply.py +5 -3
  75. svc_infra/api/fastapi/openapi/conventions.py +9 -2
  76. svc_infra/api/fastapi/openapi/mutators.py +165 -20
  77. svc_infra/api/fastapi/openapi/pipeline.py +1 -1
  78. svc_infra/api/fastapi/openapi/security.py +3 -1
  79. svc_infra/api/fastapi/ops/add.py +75 -0
  80. svc_infra/api/fastapi/pagination.py +47 -20
  81. svc_infra/api/fastapi/routers/__init__.py +43 -15
  82. svc_infra/api/fastapi/routers/ping.py +1 -0
  83. svc_infra/api/fastapi/setup.py +188 -57
  84. svc_infra/api/fastapi/tenancy/add.py +19 -0
  85. svc_infra/api/fastapi/tenancy/context.py +112 -0
  86. svc_infra/api/fastapi/versioned.py +101 -0
  87. svc_infra/app/README.md +5 -5
  88. svc_infra/app/__init__.py +3 -1
  89. svc_infra/app/env.py +69 -1
  90. svc_infra/app/logging/add.py +9 -2
  91. svc_infra/app/logging/formats.py +12 -5
  92. svc_infra/billing/__init__.py +23 -0
  93. svc_infra/billing/async_service.py +147 -0
  94. svc_infra/billing/jobs.py +241 -0
  95. svc_infra/billing/models.py +177 -0
  96. svc_infra/billing/quotas.py +103 -0
  97. svc_infra/billing/schemas.py +36 -0
  98. svc_infra/billing/service.py +123 -0
  99. svc_infra/bundled_docs/README.md +5 -0
  100. svc_infra/bundled_docs/__init__.py +1 -0
  101. svc_infra/bundled_docs/getting-started.md +6 -0
  102. svc_infra/cache/__init__.py +9 -0
  103. svc_infra/cache/add.py +170 -0
  104. svc_infra/cache/backend.py +7 -6
  105. svc_infra/cache/decorators.py +81 -15
  106. svc_infra/cache/demo.py +2 -2
  107. svc_infra/cache/keys.py +24 -4
  108. svc_infra/cache/recache.py +26 -14
  109. svc_infra/cache/resources.py +14 -5
  110. svc_infra/cache/tags.py +19 -44
  111. svc_infra/cache/utils.py +3 -1
  112. svc_infra/cli/__init__.py +52 -8
  113. svc_infra/cli/__main__.py +4 -0
  114. svc_infra/cli/cmds/__init__.py +39 -2
  115. svc_infra/cli/cmds/db/nosql/mongo/mongo_cmds.py +7 -4
  116. svc_infra/cli/cmds/db/nosql/mongo/mongo_scaffold_cmds.py +7 -5
  117. svc_infra/cli/cmds/db/ops_cmds.py +270 -0
  118. svc_infra/cli/cmds/db/sql/alembic_cmds.py +103 -18
  119. svc_infra/cli/cmds/db/sql/sql_export_cmds.py +88 -0
  120. svc_infra/cli/cmds/db/sql/sql_scaffold_cmds.py +3 -3
  121. svc_infra/cli/cmds/docs/docs_cmds.py +142 -0
  122. svc_infra/cli/cmds/dx/__init__.py +12 -0
  123. svc_infra/cli/cmds/dx/dx_cmds.py +116 -0
  124. svc_infra/cli/cmds/health/__init__.py +179 -0
  125. svc_infra/cli/cmds/health/health_cmds.py +8 -0
  126. svc_infra/cli/cmds/help.py +4 -0
  127. svc_infra/cli/cmds/jobs/__init__.py +1 -0
  128. svc_infra/cli/cmds/jobs/jobs_cmds.py +47 -0
  129. svc_infra/cli/cmds/obs/obs_cmds.py +36 -15
  130. svc_infra/cli/cmds/sdk/__init__.py +0 -0
  131. svc_infra/cli/cmds/sdk/sdk_cmds.py +112 -0
  132. svc_infra/cli/foundation/runner.py +6 -2
  133. svc_infra/data/add.py +61 -0
  134. svc_infra/data/backup.py +58 -0
  135. svc_infra/data/erasure.py +45 -0
  136. svc_infra/data/fixtures.py +42 -0
  137. svc_infra/data/retention.py +61 -0
  138. svc_infra/db/__init__.py +15 -0
  139. svc_infra/db/crud_schema.py +9 -9
  140. svc_infra/db/inbox.py +67 -0
  141. svc_infra/db/nosql/__init__.py +3 -0
  142. svc_infra/db/nosql/core.py +30 -9
  143. svc_infra/db/nosql/indexes.py +3 -1
  144. svc_infra/db/nosql/management.py +1 -1
  145. svc_infra/db/nosql/mongo/README.md +13 -13
  146. svc_infra/db/nosql/mongo/client.py +19 -2
  147. svc_infra/db/nosql/mongo/settings.py +6 -2
  148. svc_infra/db/nosql/repository.py +35 -15
  149. svc_infra/db/nosql/resource.py +20 -3
  150. svc_infra/db/nosql/scaffold.py +9 -3
  151. svc_infra/db/nosql/service.py +3 -1
  152. svc_infra/db/nosql/types.py +6 -2
  153. svc_infra/db/ops.py +384 -0
  154. svc_infra/db/outbox.py +108 -0
  155. svc_infra/db/sql/apikey.py +37 -9
  156. svc_infra/db/sql/authref.py +9 -3
  157. svc_infra/db/sql/constants.py +12 -8
  158. svc_infra/db/sql/core.py +2 -2
  159. svc_infra/db/sql/management.py +11 -8
  160. svc_infra/db/sql/repository.py +99 -26
  161. svc_infra/db/sql/resource.py +5 -0
  162. svc_infra/db/sql/scaffold.py +6 -2
  163. svc_infra/db/sql/service.py +15 -5
  164. svc_infra/db/sql/templates/models_schemas/auth/models.py.tmpl +7 -56
  165. svc_infra/db/sql/templates/setup/env_async.py.tmpl +34 -12
  166. svc_infra/db/sql/templates/setup/env_sync.py.tmpl +29 -7
  167. svc_infra/db/sql/tenant.py +88 -0
  168. svc_infra/db/sql/uniq_hooks.py +9 -3
  169. svc_infra/db/sql/utils.py +138 -51
  170. svc_infra/db/sql/versioning.py +14 -0
  171. svc_infra/deploy/__init__.py +538 -0
  172. svc_infra/documents/__init__.py +100 -0
  173. svc_infra/documents/add.py +264 -0
  174. svc_infra/documents/ease.py +233 -0
  175. svc_infra/documents/models.py +114 -0
  176. svc_infra/documents/storage.py +264 -0
  177. svc_infra/dx/add.py +65 -0
  178. svc_infra/dx/changelog.py +74 -0
  179. svc_infra/dx/checks.py +68 -0
  180. svc_infra/exceptions.py +141 -0
  181. svc_infra/health/__init__.py +864 -0
  182. svc_infra/http/__init__.py +13 -0
  183. svc_infra/http/client.py +105 -0
  184. svc_infra/jobs/builtins/outbox_processor.py +40 -0
  185. svc_infra/jobs/builtins/webhook_delivery.py +95 -0
  186. svc_infra/jobs/easy.py +33 -0
  187. svc_infra/jobs/loader.py +50 -0
  188. svc_infra/jobs/queue.py +116 -0
  189. svc_infra/jobs/redis_queue.py +256 -0
  190. svc_infra/jobs/runner.py +79 -0
  191. svc_infra/jobs/scheduler.py +53 -0
  192. svc_infra/jobs/worker.py +40 -0
  193. svc_infra/loaders/__init__.py +186 -0
  194. svc_infra/loaders/base.py +142 -0
  195. svc_infra/loaders/github.py +311 -0
  196. svc_infra/loaders/models.py +147 -0
  197. svc_infra/loaders/url.py +235 -0
  198. svc_infra/logging/__init__.py +374 -0
  199. svc_infra/mcp/svc_infra_mcp.py +91 -33
  200. svc_infra/obs/README.md +2 -0
  201. svc_infra/obs/add.py +65 -9
  202. svc_infra/obs/cloud_dash.py +2 -1
  203. svc_infra/obs/grafana/dashboards/http-overview.json +45 -0
  204. svc_infra/obs/metrics/__init__.py +3 -4
  205. svc_infra/obs/metrics/asgi.py +13 -7
  206. svc_infra/obs/metrics/http.py +9 -5
  207. svc_infra/obs/metrics/sqlalchemy.py +13 -9
  208. svc_infra/obs/metrics.py +6 -5
  209. svc_infra/obs/settings.py +6 -2
  210. svc_infra/security/add.py +217 -0
  211. svc_infra/security/audit.py +92 -10
  212. svc_infra/security/audit_service.py +4 -3
  213. svc_infra/security/headers.py +15 -2
  214. svc_infra/security/hibp.py +14 -4
  215. svc_infra/security/jwt_rotation.py +74 -22
  216. svc_infra/security/lockout.py +11 -5
  217. svc_infra/security/models.py +54 -12
  218. svc_infra/security/oauth_models.py +73 -0
  219. svc_infra/security/org_invites.py +5 -3
  220. svc_infra/security/passwords.py +3 -1
  221. svc_infra/security/permissions.py +25 -2
  222. svc_infra/security/session.py +1 -1
  223. svc_infra/security/signed_cookies.py +21 -1
  224. svc_infra/storage/__init__.py +93 -0
  225. svc_infra/storage/add.py +253 -0
  226. svc_infra/storage/backends/__init__.py +11 -0
  227. svc_infra/storage/backends/local.py +339 -0
  228. svc_infra/storage/backends/memory.py +216 -0
  229. svc_infra/storage/backends/s3.py +353 -0
  230. svc_infra/storage/base.py +239 -0
  231. svc_infra/storage/easy.py +185 -0
  232. svc_infra/storage/settings.py +195 -0
  233. svc_infra/testing/__init__.py +685 -0
  234. svc_infra/utils.py +7 -3
  235. svc_infra/webhooks/__init__.py +69 -0
  236. svc_infra/webhooks/add.py +339 -0
  237. svc_infra/webhooks/encryption.py +115 -0
  238. svc_infra/webhooks/fastapi.py +39 -0
  239. svc_infra/webhooks/router.py +55 -0
  240. svc_infra/webhooks/service.py +70 -0
  241. svc_infra/webhooks/signing.py +34 -0
  242. svc_infra/websocket/__init__.py +79 -0
  243. svc_infra/websocket/add.py +140 -0
  244. svc_infra/websocket/client.py +282 -0
  245. svc_infra/websocket/config.py +69 -0
  246. svc_infra/websocket/easy.py +76 -0
  247. svc_infra/websocket/exceptions.py +61 -0
  248. svc_infra/websocket/manager.py +344 -0
  249. svc_infra/websocket/models.py +49 -0
  250. svc_infra-0.1.706.dist-info/LICENSE +21 -0
  251. svc_infra-0.1.706.dist-info/METADATA +356 -0
  252. svc_infra-0.1.706.dist-info/RECORD +357 -0
  253. svc_infra-0.1.595.dist-info/METADATA +0 -80
  254. svc_infra-0.1.595.dist-info/RECORD +0 -253
  255. {svc_infra-0.1.595.dist-info → svc_infra-0.1.706.dist-info}/WHEEL +0 -0
  256. {svc_infra-0.1.595.dist-info → svc_infra-0.1.706.dist-info}/entry_points.txt +0 -0
@@ -1,61 +1,158 @@
1
+ import json
1
2
  import time
2
3
 
3
- from starlette.middleware.base import BaseHTTPMiddleware
4
- from starlette.responses import JSONResponse
4
+ from fastapi import Request
5
+ from starlette.types import ASGIApp, Receive, Scope, Send
5
6
 
6
7
  from svc_infra.obs.metrics import emit_rate_limited
7
8
 
8
9
  from .ratelimit_store import InMemoryRateLimitStore, RateLimitStore
9
10
 
11
+ try:
12
+ # Optional import: tenancy may not be enabled in all apps
13
+ from svc_infra.api.fastapi.tenancy.context import (
14
+ resolve_tenant_id as _resolve_tenant_id,
15
+ )
16
+ except Exception: # pragma: no cover - fallback for minimal builds
17
+ _resolve_tenant_id = None # type: ignore[assignment]
18
+
19
+
20
+ class SimpleRateLimitMiddleware:
21
+ """
22
+ Pure ASGI rate limiting middleware.
23
+
24
+ Applies per-key rate limits with configurable windows. Use skip_paths for
25
+ endpoints that should bypass rate limiting (e.g., health checks, webhooks).
26
+ """
10
27
 
11
- class SimpleRateLimitMiddleware(BaseHTTPMiddleware):
12
28
  def __init__(
13
29
  self,
14
- app,
30
+ app: ASGIApp,
15
31
  limit: int = 120,
16
32
  window: int = 60,
17
33
  key_fn=None,
34
+ *,
35
+ # When provided, dynamically computes a limit for the current request (e.g. per-tenant quotas)
36
+ # Signature: (request: Request, tenant_id: Optional[str]) -> int | None
37
+ limit_resolver=None,
38
+ # If True, automatically scopes the bucket key by tenant id when available
39
+ scope_by_tenant: bool = False,
40
+ # When True, allows unresolved tenant IDs to fall back to an "X-Tenant-Id" header value.
41
+ # Disabled by default to avoid trusting arbitrary client-provided headers which could
42
+ # otherwise be used to evade per-tenant limits when authentication fails.
43
+ allow_untrusted_tenant_header: bool = False,
18
44
  store: RateLimitStore | None = None,
45
+ skip_paths: list[str] | None = None,
19
46
  ):
20
- super().__init__(app)
47
+ self.app = app
21
48
  self.limit, self.window = limit, window
22
- self.key_fn = key_fn or (lambda r: r.headers.get("X-API-Key") or r.client.host)
49
+ self.key_fn = key_fn
50
+ self._limit_resolver = limit_resolver
51
+ self.scope_by_tenant = scope_by_tenant
52
+ self._allow_untrusted_tenant_header = allow_untrusted_tenant_header
23
53
  self.store = store or InMemoryRateLimitStore(limit=limit)
54
+ self.skip_paths = skip_paths or []
55
+
56
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
57
+ if scope.get("type") != "http":
58
+ await self.app(scope, receive, send)
59
+ return
60
+
61
+ path = scope.get("path", "")
62
+
63
+ # Skip specified paths
64
+ if any(skip in path for skip in self.skip_paths):
65
+ await self.app(scope, receive, send)
66
+ return
67
+
68
+ # Create a Request object for key extraction and tenant resolution
69
+ request = Request(scope, receive)
70
+
71
+ # Default key function
72
+ key_fn = self.key_fn or (
73
+ lambda r: r.headers.get("X-API-Key")
74
+ or (r.client.host if r.client else "unknown")
75
+ )
76
+
77
+ # Resolve tenant when possible
78
+ tenant_id = None
79
+ if self.scope_by_tenant or self._limit_resolver:
80
+ try:
81
+ if _resolve_tenant_id is not None:
82
+ tenant_id = await _resolve_tenant_id(request)
83
+ except Exception:
84
+ tenant_id = None
85
+ # Fallback header behavior - ONLY if explicitly allowed
86
+ # Never trust untrusted headers by default to prevent rate limit evasion
87
+ if not tenant_id and self._allow_untrusted_tenant_header:
88
+ tenant_id = request.headers.get("X-Tenant-Id") or request.headers.get(
89
+ "X-Tenant-ID"
90
+ )
91
+
92
+ key = key_fn(request)
93
+ if self.scope_by_tenant and tenant_id:
94
+ key = f"{key}:tenant:{tenant_id}"
95
+
96
+ # Allow dynamic limit overrides
97
+ eff_limit = self.limit
98
+ if self._limit_resolver:
99
+ try:
100
+ v = self._limit_resolver(request, tenant_id)
101
+ eff_limit = int(v) if v is not None else self.limit
102
+ except Exception:
103
+ eff_limit = self.limit
24
104
 
25
- async def dispatch(self, request, call_next):
26
- key = self.key_fn(request)
27
105
  now = int(time.time())
28
- # Increment counter in store
29
- count, limit, reset = self.store.incr(str(key), self.window)
106
+ count, store_limit, reset = self.store.incr(str(key), self.window)
107
+ limit = eff_limit
30
108
  remaining = max(0, limit - count)
31
109
 
32
- if remaining < 0: # defensive clamp
33
- remaining = 0
34
-
35
110
  if count > limit:
111
+ # Rate limited - return 429
36
112
  retry = max(0, reset - now)
37
113
  try:
38
114
  emit_rate_limited(str(key), limit, retry)
39
115
  except Exception:
40
116
  pass
41
- return JSONResponse(
42
- status_code=429,
43
- content={
117
+
118
+ body = json.dumps(
119
+ {
44
120
  "title": "Too Many Requests",
45
121
  "status": 429,
46
122
  "detail": "Rate limit exceeded.",
47
123
  "code": "RATE_LIMITED",
48
- },
49
- headers={
50
- "X-RateLimit-Limit": str(limit),
51
- "X-RateLimit-Remaining": "0",
52
- "X-RateLimit-Reset": str(reset),
53
- "Retry-After": str(retry),
54
- },
124
+ }
125
+ ).encode("utf-8")
126
+
127
+ await send(
128
+ {
129
+ "type": "http.response.start",
130
+ "status": 429,
131
+ "headers": [
132
+ (b"content-type", b"application/json"),
133
+ (b"x-ratelimit-limit", str(limit).encode()),
134
+ (b"x-ratelimit-remaining", b"0"),
135
+ (b"x-ratelimit-reset", str(reset).encode()),
136
+ (b"retry-after", str(retry).encode()),
137
+ ],
138
+ }
55
139
  )
140
+ await send({"type": "http.response.body", "body": body, "more_body": False})
141
+ return
142
+
143
+ # Not rate limited - add headers to response
144
+ async def send_with_headers(message):
145
+ if message["type"] == "http.response.start":
146
+ headers = list(message.get("headers", []))
147
+ # Add rate limit headers if not already present
148
+ header_names = {h[0].lower() for h in headers}
149
+ if b"x-ratelimit-limit" not in header_names:
150
+ headers.append((b"x-ratelimit-limit", str(limit).encode()))
151
+ if b"x-ratelimit-remaining" not in header_names:
152
+ headers.append((b"x-ratelimit-remaining", str(remaining).encode()))
153
+ if b"x-ratelimit-reset" not in header_names:
154
+ headers.append((b"x-ratelimit-reset", str(reset).encode()))
155
+ message = {**message, "headers": headers}
156
+ await send(message)
56
157
 
57
- resp = await call_next(request)
58
- resp.headers.setdefault("X-RateLimit-Limit", str(limit))
59
- resp.headers.setdefault("X-RateLimit-Remaining", str(remaining))
60
- resp.headers.setdefault("X-RateLimit-Reset", str(reset))
61
- return resp
158
+ await self.app(scope, receive, send_with_headers)
@@ -1,7 +1,31 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import logging
4
+ import os
3
5
  import time
4
- from typing import Optional, Protocol, Tuple
6
+ import warnings
7
+ from typing import Callable, Optional, Protocol, Tuple
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ _INMEMORY_WARNED = False
12
+
13
+
14
+ def _check_inmemory_production_warning(class_name: str) -> None:
15
+ """Warn if in-memory store is used in production."""
16
+ global _INMEMORY_WARNED
17
+ if _INMEMORY_WARNED:
18
+ return
19
+ env = os.getenv("ENV", "development").lower()
20
+ if env in ("production", "staging", "prod"):
21
+ _INMEMORY_WARNED = True
22
+ msg = (
23
+ f"{class_name} is being used in {env} environment. "
24
+ "This is NOT suitable for production - data will be lost on restart. "
25
+ "Use RedisRateLimitStore instead."
26
+ )
27
+ warnings.warn(msg, RuntimeWarning, stacklevel=3)
28
+ logger.critical(msg)
5
29
 
6
30
 
7
31
  class RateLimitStore(Protocol):
@@ -15,15 +39,22 @@ class RateLimitStore(Protocol):
15
39
 
16
40
  class InMemoryRateLimitStore:
17
41
  def __init__(self, limit: int = 120):
42
+ _check_inmemory_production_warning("InMemoryRateLimitStore")
18
43
  self.limit = limit
19
- self._buckets: dict[tuple[str, int], int] = {}
44
+ # Track per-key rolling windows: key -> (count, window_start_epoch)
45
+ self._state: dict[str, tuple[int, float]] = {}
20
46
 
21
47
  def incr(self, key: str, window: int) -> Tuple[int, int, int]:
22
- now = int(time.time())
23
- win = now - (now % window)
24
- count = self._buckets.get((key, win), 0) + 1
25
- self._buckets[(key, win)] = count
26
- reset = win + window
48
+ now = time.time()
49
+ count, window_start = self._state.get(key, (0, now))
50
+ # If outside the rolling window, reset
51
+ if now >= window_start + window:
52
+ count = 1
53
+ window_start = now
54
+ else:
55
+ count += 1
56
+ self._state[key] = (count, window_start)
57
+ reset = int(window_start + window)
27
58
  return count, self.limit, reset
28
59
 
29
60
 
@@ -43,14 +74,14 @@ class RedisRateLimitStore:
43
74
  *,
44
75
  limit: int = 120,
45
76
  prefix: str = "ratelimit",
46
- clock: Optional[callable] = None,
77
+ clock: Optional[Callable[[], float]] = None,
47
78
  ):
48
79
  self.redis = redis_client
49
80
  self.limit = limit
50
81
  self.prefix = prefix
51
82
  self._clock = clock or time.time
52
83
 
53
- def _window_key(self, key: str, window: int) -> tuple[str, int, str]:
84
+ def _window_key(self, key: str, window: int) -> tuple[str, int, int]:
54
85
  now = int(self._clock())
55
86
  win = now - (now % window)
56
87
  redis_key = f"{self.prefix}:{key}:{win}"
@@ -63,7 +94,9 @@ class RedisRateLimitStore:
63
94
  pipe.incr(rkey)
64
95
  pipe.ttl(rkey)
65
96
  count, ttl = pipe.execute()
66
- if ttl == -1: # key exists without expire or just created; set expire to end of window
97
+ if (
98
+ ttl == -1
99
+ ): # key exists without expire or just created; set expire to end of window
67
100
  expire_sec = (win + window) - now
68
101
  if expire_sec <= 0:
69
102
  expire_sec = window
@@ -1,23 +1,39 @@
1
1
  import contextvars
2
2
  from uuid import uuid4
3
3
 
4
- from starlette.middleware.base import BaseHTTPMiddleware
5
- from starlette.types import ASGIApp
4
+ from starlette.datastructures import Headers, MutableHeaders
5
+ from starlette.types import ASGIApp, Message, Receive, Scope, Send
6
6
 
7
- request_id_ctx: contextvars.ContextVar[str] = contextvars.ContextVar("request_id", default="")
7
+ request_id_ctx: contextvars.ContextVar[str] = contextvars.ContextVar(
8
+ "request_id", default=""
9
+ )
8
10
 
9
11
 
10
- class RequestIdMiddleware(BaseHTTPMiddleware):
12
+ class RequestIdMiddleware:
13
+ """Pure ASGI middleware that adds request IDs. Compatible with streaming responses."""
14
+
11
15
  def __init__(self, app: ASGIApp, header_name: str = "X-Request-Id"):
12
- super().__init__(app)
13
- self.header_name = header_name
16
+ self.app = app
17
+ self.header_name = header_name.lower()
18
+
19
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
20
+ if scope["type"] != "http":
21
+ await self.app(scope, receive, send)
22
+ return
14
23
 
15
- async def dispatch(self, request, call_next):
16
- rid = request.headers.get(self.header_name) or uuid4().hex
24
+ # Extract or generate request ID
25
+ headers = Headers(scope=scope)
26
+ rid = headers.get(self.header_name) or uuid4().hex
17
27
  token = request_id_ctx.set(rid)
28
+
29
+ async def send_with_request_id(message: Message) -> None:
30
+ if message["type"] == "http.response.start":
31
+ # Add request ID to response headers
32
+ response_headers = MutableHeaders(scope=message)
33
+ response_headers.append(self.header_name, rid)
34
+ await send(message)
35
+
18
36
  try:
19
- resp = await call_next(request)
20
- resp.headers[self.header_name] = rid
21
- return resp
37
+ await self.app(scope, receive, send_with_request_id)
22
38
  finally:
23
39
  request_id_ctx.reset(token)
@@ -19,9 +19,9 @@ class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
19
19
  size = None
20
20
  if size is not None and size > self.max_bytes:
21
21
  try:
22
- emit_suspect_payload(
23
- getattr(request, "url", None).path if hasattr(request, "url") else None, size
24
- )
22
+ url = getattr(request, "url", None)
23
+ path = url.path if url is not None else None
24
+ emit_suspect_payload(path, size)
25
25
  except Exception:
26
26
  pass
27
27
  return JSONResponse(
@@ -0,0 +1,177 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import os
5
+ from typing import Any
6
+
7
+ from fastapi import Request
8
+ from starlette.types import ASGIApp, Receive, Scope, Send
9
+
10
+ from svc_infra.api.fastapi.middleware.errors.handlers import problem_response
11
+ from svc_infra.app.env import pick
12
+
13
+
14
+ def _env_int(name: str, default: int) -> int:
15
+ v = os.getenv(name)
16
+ if v is None:
17
+ return default
18
+ try:
19
+ return int(v)
20
+ except Exception:
21
+ return default
22
+
23
+
24
+ REQUEST_BODY_TIMEOUT_SECONDS: int = pick(
25
+ prod=_env_int("REQUEST_BODY_TIMEOUT_SECONDS", 15),
26
+ nonprod=_env_int("REQUEST_BODY_TIMEOUT_SECONDS", 30),
27
+ )
28
+ REQUEST_TIMEOUT_SECONDS: int = pick(
29
+ prod=_env_int("REQUEST_TIMEOUT_SECONDS", 30),
30
+ nonprod=_env_int("REQUEST_TIMEOUT_SECONDS", 15),
31
+ )
32
+
33
+
34
+ class HandlerTimeoutMiddleware:
35
+ """
36
+ Caps total handler execution time. If exceeded, returns 504 Problem+JSON.
37
+
38
+ Use skip_paths for endpoints that may run longer than the timeout
39
+ (e.g., streaming responses, long-polling, file uploads).
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ app: ASGIApp,
45
+ timeout_seconds: int | None = None,
46
+ skip_paths: list[str] | None = None,
47
+ ) -> None:
48
+ self.app = app
49
+ self.timeout_seconds = (
50
+ timeout_seconds if timeout_seconds is not None else REQUEST_TIMEOUT_SECONDS
51
+ )
52
+ self.skip_paths = skip_paths or []
53
+
54
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
55
+ if scope.get("type") != "http":
56
+ await self.app(scope, receive, send)
57
+ return
58
+
59
+ path = scope.get("path", "")
60
+
61
+ # Skip specified paths (e.g., long-running endpoints)
62
+ if any(skip in path for skip in self.skip_paths):
63
+ await self.app(scope, receive, send)
64
+ return
65
+
66
+ # Track if response has started (headers sent)
67
+ response_started = False
68
+
69
+ async def send_wrapper(message: dict) -> None:
70
+ nonlocal response_started
71
+ if message.get("type") == "http.response.start":
72
+ response_started = True
73
+ await send(message)
74
+
75
+ try:
76
+ await asyncio.wait_for(
77
+ self.app(scope, receive, send_wrapper), # type: ignore[arg-type] # ASGI send signature
78
+ timeout=self.timeout_seconds,
79
+ )
80
+ except asyncio.TimeoutError:
81
+ # Only send 504 if response hasn't started yet
82
+ if not response_started:
83
+ response = problem_response(
84
+ status=504,
85
+ title="Gateway Timeout",
86
+ detail=f"Handler did not complete within {self.timeout_seconds}s",
87
+ )
88
+ await response(scope, receive, send)
89
+ # If response already started, we can't change it - just let it fail
90
+
91
+
92
+ class BodyReadTimeoutMiddleware:
93
+ """
94
+ Enforces a timeout while reading the request body to mitigate slowloris.
95
+ If body read does not make progress within the timeout, returns 408 Problem+JSON.
96
+ """
97
+
98
+ def __init__(self, app: ASGIApp, timeout_seconds: int | None = None) -> None:
99
+ self.app = app
100
+ self.timeout_seconds = (
101
+ timeout_seconds
102
+ if timeout_seconds is not None
103
+ else REQUEST_BODY_TIMEOUT_SECONDS
104
+ )
105
+
106
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
107
+ if scope.get("type") != "http":
108
+ await self.app(scope, receive, send)
109
+ return
110
+
111
+ # Strategy: greedily drain the incoming request body here while enforcing
112
+ # per-receive timeout, then replay it to the downstream app from a buffer.
113
+ # This ensures we can detect slowloris-style uploads even if the app only
114
+ # reads the body later (after the server has finished buffering).
115
+ buffered = bytearray()
116
+
117
+ try:
118
+ while True:
119
+ message = await asyncio.wait_for(
120
+ receive(), timeout=self.timeout_seconds
121
+ )
122
+
123
+ mtype = message.get("type")
124
+ if mtype == "http.request":
125
+ chunk = message.get("body", b"") or b""
126
+ if chunk:
127
+ buffered.extend(chunk)
128
+ # Stop when server indicates no more body
129
+ if not message.get("more_body", False):
130
+ break
131
+ # else: continue reading remaining chunks with timeout
132
+ continue
133
+
134
+ if mtype == "http.disconnect": # client disconnected mid-upload
135
+ # Treat as end of body for the purposes of replay; downstream
136
+ # will see an empty body. No timeout response needed here.
137
+ break
138
+ # Ignore other message types and continue
139
+ except asyncio.TimeoutError:
140
+ # Timed out while waiting for the next body chunk → return 408
141
+ request = Request(scope, receive=receive)
142
+ trace_id = None
143
+ for h in ("x-request-id", "x-correlation-id", "x-trace-id"):
144
+ v = request.headers.get(h)
145
+ if v:
146
+ trace_id = v
147
+ break
148
+ resp = problem_response(
149
+ status=408,
150
+ title="Request Timeout",
151
+ detail="Timed out while reading request body.",
152
+ code="REQUEST_TIMEOUT",
153
+ instance=str(request.url),
154
+ trace_id=trace_id,
155
+ )
156
+ await resp(scope, receive, send)
157
+ return
158
+
159
+ # Replay the drained body to the app as a single http.request message.
160
+ # IMPORTANT: After replaying the body, we must forward the original receive()
161
+ # so that Starlette's listen_for_disconnect can properly detect client disconnects.
162
+ # This is required for streaming responses on ASGI spec < 2.4.
163
+ body_sent = False
164
+
165
+ async def _replay_receive() -> dict[str, Any]:
166
+ nonlocal body_sent
167
+ if not body_sent:
168
+ body_sent = True
169
+ return {
170
+ "type": "http.request",
171
+ "body": bytes(buffered),
172
+ "more_body": False,
173
+ }
174
+ # After body is sent, forward to original receive for disconnect detection
175
+ return dict(await receive())
176
+
177
+ await self.app(scope, _replay_receive, send)
@@ -5,7 +5,9 @@ from typing import Any, Callable
5
5
  from fastapi import APIRouter
6
6
 
7
7
 
8
- def apply_default_security(router: APIRouter, *, default_security: list[dict] | None) -> None:
8
+ def apply_default_security(
9
+ router: APIRouter, *, default_security: list[dict] | None
10
+ ) -> None:
9
11
  if default_security is None:
10
12
  return
11
13
  original_add = router.add_api_route
@@ -17,7 +19,7 @@ def apply_default_security(router: APIRouter, *, default_security: list[dict] |
17
19
  kwargs["openapi_extra"] = ox
18
20
  return original_add(path, endpoint, **kwargs)
19
21
 
20
- router.add_api_route = _wrapped_add_api_route # type: ignore[attr-defined]
22
+ setattr(router, "add_api_route", _wrapped_add_api_route)
21
23
 
22
24
 
23
25
  def apply_default_responses(router: APIRouter, defaults: dict[int, dict]) -> None:
@@ -38,4 +40,4 @@ def apply_default_responses(router: APIRouter, defaults: dict[int, dict]) -> Non
38
40
  kwargs["responses"] = responses
39
41
  return original_add(path, endpoint, **kwargs)
40
42
 
41
- router.add_api_route = _wrapped_add_api_route # type: ignore[attr-defined]
43
+ setattr(router, "add_api_route", _wrapped_add_api_route)
@@ -16,7 +16,11 @@ PROBLEM_SCHEMA: Dict[str, Any] = {
16
16
  "description": "URI identifying the error type",
17
17
  },
18
18
  "title": {"type": "string", "description": "Short, human-readable summary"},
19
- "status": {"type": "integer", "format": "int32", "description": "HTTP status code"},
19
+ "status": {
20
+ "type": "integer",
21
+ "format": "int32",
22
+ "description": "HTTP status code",
23
+ },
20
24
  "detail": {"type": "string", "description": "Human-readable explanation"},
21
25
  "instance": {
22
26
  "type": "string",
@@ -36,7 +40,10 @@ PROBLEM_SCHEMA: Dict[str, Any] = {
36
40
  },
37
41
  },
38
42
  },
39
- "trace_id": {"type": "string", "description": "Correlation/trace id (if available)"},
43
+ "trace_id": {
44
+ "type": "string",
45
+ "description": "Correlation/trace id (if available)",
46
+ },
40
47
  },
41
48
  "required": ["title", "status"],
42
49
  }