svc-infra 0.1.562__py3-none-any.whl → 0.1.654__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175) hide show
  1. svc_infra/apf_payments/README.md +732 -0
  2. svc_infra/apf_payments/models.py +142 -4
  3. svc_infra/apf_payments/provider/__init__.py +4 -0
  4. svc_infra/apf_payments/provider/aiydan.py +797 -0
  5. svc_infra/apf_payments/provider/base.py +178 -12
  6. svc_infra/apf_payments/provider/stripe.py +757 -48
  7. svc_infra/apf_payments/schemas.py +163 -1
  8. svc_infra/apf_payments/service.py +582 -42
  9. svc_infra/apf_payments/settings.py +22 -2
  10. svc_infra/api/fastapi/admin/__init__.py +3 -0
  11. svc_infra/api/fastapi/admin/add.py +231 -0
  12. svc_infra/api/fastapi/apf_payments/router.py +792 -73
  13. svc_infra/api/fastapi/apf_payments/setup.py +13 -4
  14. svc_infra/api/fastapi/auth/add.py +10 -4
  15. svc_infra/api/fastapi/auth/gaurd.py +67 -5
  16. svc_infra/api/fastapi/auth/routers/oauth_router.py +74 -34
  17. svc_infra/api/fastapi/auth/routers/session_router.py +63 -0
  18. svc_infra/api/fastapi/auth/settings.py +2 -0
  19. svc_infra/api/fastapi/billing/router.py +64 -0
  20. svc_infra/api/fastapi/billing/setup.py +19 -0
  21. svc_infra/api/fastapi/cache/add.py +9 -5
  22. svc_infra/api/fastapi/db/nosql/mongo/add.py +33 -27
  23. svc_infra/api/fastapi/db/sql/add.py +40 -18
  24. svc_infra/api/fastapi/db/sql/crud_router.py +176 -14
  25. svc_infra/api/fastapi/db/sql/session.py +16 -0
  26. svc_infra/api/fastapi/db/sql/users.py +13 -1
  27. svc_infra/api/fastapi/dependencies/ratelimit.py +116 -0
  28. svc_infra/api/fastapi/docs/add.py +160 -0
  29. svc_infra/api/fastapi/docs/landing.py +1 -1
  30. svc_infra/api/fastapi/docs/scoped.py +41 -6
  31. svc_infra/api/fastapi/middleware/errors/handlers.py +45 -7
  32. svc_infra/api/fastapi/middleware/graceful_shutdown.py +87 -0
  33. svc_infra/api/fastapi/middleware/idempotency.py +82 -42
  34. svc_infra/api/fastapi/middleware/idempotency_store.py +187 -0
  35. svc_infra/api/fastapi/middleware/optimistic_lock.py +37 -0
  36. svc_infra/api/fastapi/middleware/ratelimit.py +84 -11
  37. svc_infra/api/fastapi/middleware/ratelimit_store.py +84 -0
  38. svc_infra/api/fastapi/middleware/request_size_limit.py +36 -0
  39. svc_infra/api/fastapi/middleware/timeout.py +148 -0
  40. svc_infra/api/fastapi/openapi/mutators.py +244 -38
  41. svc_infra/api/fastapi/ops/add.py +73 -0
  42. svc_infra/api/fastapi/pagination.py +133 -32
  43. svc_infra/api/fastapi/routers/ping.py +1 -0
  44. svc_infra/api/fastapi/setup.py +23 -14
  45. svc_infra/api/fastapi/tenancy/add.py +19 -0
  46. svc_infra/api/fastapi/tenancy/context.py +112 -0
  47. svc_infra/api/fastapi/versioned.py +101 -0
  48. svc_infra/app/README.md +5 -5
  49. svc_infra/billing/__init__.py +23 -0
  50. svc_infra/billing/async_service.py +147 -0
  51. svc_infra/billing/jobs.py +230 -0
  52. svc_infra/billing/models.py +131 -0
  53. svc_infra/billing/quotas.py +101 -0
  54. svc_infra/billing/schemas.py +33 -0
  55. svc_infra/billing/service.py +115 -0
  56. svc_infra/bundled_docs/README.md +5 -0
  57. svc_infra/bundled_docs/__init__.py +1 -0
  58. svc_infra/bundled_docs/getting-started.md +6 -0
  59. svc_infra/cache/__init__.py +4 -0
  60. svc_infra/cache/add.py +158 -0
  61. svc_infra/cache/backend.py +5 -2
  62. svc_infra/cache/decorators.py +19 -1
  63. svc_infra/cache/keys.py +24 -4
  64. svc_infra/cli/__init__.py +32 -8
  65. svc_infra/cli/__main__.py +4 -0
  66. svc_infra/cli/cmds/__init__.py +10 -0
  67. svc_infra/cli/cmds/db/nosql/mongo/mongo_cmds.py +4 -3
  68. svc_infra/cli/cmds/db/nosql/mongo/mongo_scaffold_cmds.py +4 -4
  69. svc_infra/cli/cmds/db/sql/alembic_cmds.py +80 -11
  70. svc_infra/cli/cmds/db/sql/sql_export_cmds.py +80 -0
  71. svc_infra/cli/cmds/db/sql/sql_scaffold_cmds.py +3 -3
  72. svc_infra/cli/cmds/docs/docs_cmds.py +140 -0
  73. svc_infra/cli/cmds/dx/__init__.py +12 -0
  74. svc_infra/cli/cmds/dx/dx_cmds.py +99 -0
  75. svc_infra/cli/cmds/help.py +4 -0
  76. svc_infra/cli/cmds/jobs/__init__.py +1 -0
  77. svc_infra/cli/cmds/jobs/jobs_cmds.py +43 -0
  78. svc_infra/cli/cmds/obs/obs_cmds.py +4 -3
  79. svc_infra/cli/cmds/sdk/__init__.py +0 -0
  80. svc_infra/cli/cmds/sdk/sdk_cmds.py +102 -0
  81. svc_infra/data/add.py +61 -0
  82. svc_infra/data/backup.py +53 -0
  83. svc_infra/data/erasure.py +45 -0
  84. svc_infra/data/fixtures.py +40 -0
  85. svc_infra/data/retention.py +55 -0
  86. svc_infra/db/inbox.py +67 -0
  87. svc_infra/db/nosql/mongo/README.md +13 -13
  88. svc_infra/db/outbox.py +104 -0
  89. svc_infra/db/sql/repository.py +52 -12
  90. svc_infra/db/sql/resource.py +5 -0
  91. svc_infra/db/sql/templates/models_schemas/auth/schemas.py.tmpl +1 -1
  92. svc_infra/db/sql/templates/setup/env_async.py.tmpl +13 -8
  93. svc_infra/db/sql/templates/setup/env_sync.py.tmpl +9 -5
  94. svc_infra/db/sql/tenant.py +79 -0
  95. svc_infra/db/sql/utils.py +18 -4
  96. svc_infra/db/sql/versioning.py +14 -0
  97. svc_infra/docs/acceptance-matrix.md +71 -0
  98. svc_infra/docs/acceptance.md +44 -0
  99. svc_infra/docs/admin.md +425 -0
  100. svc_infra/docs/adr/0002-background-jobs-and-scheduling.md +40 -0
  101. svc_infra/docs/adr/0003-webhooks-framework.md +24 -0
  102. svc_infra/docs/adr/0004-tenancy-model.md +42 -0
  103. svc_infra/docs/adr/0005-data-lifecycle.md +86 -0
  104. svc_infra/docs/adr/0006-ops-slos-and-metrics.md +47 -0
  105. svc_infra/docs/adr/0007-docs-and-sdks.md +83 -0
  106. svc_infra/docs/adr/0008-billing-primitives.md +143 -0
  107. svc_infra/docs/adr/0009-acceptance-harness.md +40 -0
  108. svc_infra/docs/adr/0010-timeouts-and-resource-limits.md +54 -0
  109. svc_infra/docs/adr/0011-admin-scope-and-impersonation.md +73 -0
  110. svc_infra/docs/api.md +59 -0
  111. svc_infra/docs/auth.md +11 -0
  112. svc_infra/docs/billing.md +190 -0
  113. svc_infra/docs/cache.md +76 -0
  114. svc_infra/docs/cli.md +74 -0
  115. svc_infra/docs/contributing.md +34 -0
  116. svc_infra/docs/data-lifecycle.md +52 -0
  117. svc_infra/docs/database.md +14 -0
  118. svc_infra/docs/docs-and-sdks.md +62 -0
  119. svc_infra/docs/environment.md +114 -0
  120. svc_infra/docs/getting-started.md +63 -0
  121. svc_infra/docs/idempotency.md +111 -0
  122. svc_infra/docs/jobs.md +67 -0
  123. svc_infra/docs/observability.md +16 -0
  124. svc_infra/docs/ops.md +37 -0
  125. svc_infra/docs/rate-limiting.md +125 -0
  126. svc_infra/docs/repo-review.md +48 -0
  127. svc_infra/docs/security.md +176 -0
  128. svc_infra/docs/tenancy.md +35 -0
  129. svc_infra/docs/timeouts-and-resource-limits.md +147 -0
  130. svc_infra/docs/versioned-integrations.md +146 -0
  131. svc_infra/docs/webhooks.md +112 -0
  132. svc_infra/dx/add.py +63 -0
  133. svc_infra/dx/changelog.py +74 -0
  134. svc_infra/dx/checks.py +67 -0
  135. svc_infra/http/__init__.py +13 -0
  136. svc_infra/http/client.py +72 -0
  137. svc_infra/jobs/builtins/outbox_processor.py +38 -0
  138. svc_infra/jobs/builtins/webhook_delivery.py +90 -0
  139. svc_infra/jobs/easy.py +32 -0
  140. svc_infra/jobs/loader.py +45 -0
  141. svc_infra/jobs/queue.py +81 -0
  142. svc_infra/jobs/redis_queue.py +191 -0
  143. svc_infra/jobs/runner.py +75 -0
  144. svc_infra/jobs/scheduler.py +41 -0
  145. svc_infra/jobs/worker.py +40 -0
  146. svc_infra/mcp/svc_infra_mcp.py +85 -28
  147. svc_infra/obs/README.md +2 -0
  148. svc_infra/obs/add.py +54 -7
  149. svc_infra/obs/grafana/dashboards/http-overview.json +45 -0
  150. svc_infra/obs/metrics/__init__.py +53 -0
  151. svc_infra/obs/metrics.py +52 -0
  152. svc_infra/security/add.py +201 -0
  153. svc_infra/security/audit.py +130 -0
  154. svc_infra/security/audit_service.py +73 -0
  155. svc_infra/security/headers.py +52 -0
  156. svc_infra/security/hibp.py +95 -0
  157. svc_infra/security/jwt_rotation.py +53 -0
  158. svc_infra/security/lockout.py +96 -0
  159. svc_infra/security/models.py +255 -0
  160. svc_infra/security/org_invites.py +128 -0
  161. svc_infra/security/passwords.py +77 -0
  162. svc_infra/security/permissions.py +149 -0
  163. svc_infra/security/session.py +98 -0
  164. svc_infra/security/signed_cookies.py +80 -0
  165. svc_infra/webhooks/__init__.py +16 -0
  166. svc_infra/webhooks/add.py +322 -0
  167. svc_infra/webhooks/fastapi.py +37 -0
  168. svc_infra/webhooks/router.py +55 -0
  169. svc_infra/webhooks/service.py +67 -0
  170. svc_infra/webhooks/signing.py +30 -0
  171. svc_infra-0.1.654.dist-info/METADATA +154 -0
  172. {svc_infra-0.1.562.dist-info → svc_infra-0.1.654.dist-info}/RECORD +174 -56
  173. svc_infra-0.1.562.dist-info/METADATA +0 -79
  174. {svc_infra-0.1.562.dist-info → svc_infra-0.1.654.dist-info}/WHEEL +0 -0
  175. {svc_infra-0.1.562.dist-info → svc_infra-0.1.654.dist-info}/entry_points.txt +0 -0
@@ -115,7 +115,7 @@ def render_index_html(*, service_name: str, release: str, cards: Iterable[CardSp
115
115
  <section class="grid">
116
116
  {grid}
117
117
  </section>
118
- <footer>Tip: each card exposes Swagger, ReDoc, and a pretty JSON view.</footer>
118
+ <footer>Tip: each card exposes Swagger, ReDoc, and a JSON view.</footer>
119
119
  </div>
120
120
  </body>
121
121
  </html>
@@ -65,11 +65,18 @@ def _close_over_component_refs(
65
65
 
66
66
 
67
67
  def _prune_to_paths(
68
- full_schema: Dict, keep_paths: Dict[str, dict], title_suffix: Optional[str]
68
+ full_schema: Dict,
69
+ keep_paths: Dict[str, dict],
70
+ title_suffix: Optional[str],
71
+ server_prefix: Optional[str] = None,
69
72
  ) -> Dict:
70
73
  schema = copy.deepcopy(full_schema)
71
74
  schema["paths"] = keep_paths
72
75
 
76
+ # Set server URL for scoped docs
77
+ if server_prefix is not None:
78
+ schema["servers"] = [{"url": server_prefix}]
79
+
73
80
  used_tags: Set[str] = set()
74
81
  direct_refs: Set[Tuple[str, str]] = set()
75
82
  used_security_schemes: Set[str] = set()
@@ -124,7 +131,26 @@ def _build_filtered_schema(
124
131
  keep_paths = {
125
132
  p: v for p, v in paths.items() if _path_included(p, include_prefixes, exclude_prefixes)
126
133
  }
127
- return _prune_to_paths(full_schema, keep_paths, title_suffix)
134
+
135
+ # Determine the server prefix for scoped docs
136
+ server_prefix = None
137
+ if include_prefixes and len(include_prefixes) == 1:
138
+ # Single include prefix = scoped docs
139
+ server_prefix = include_prefixes[0].rstrip("/") or "/"
140
+
141
+ # Strip prefix from paths to make them relative to the server
142
+ stripped_paths = {}
143
+ for path, spec in keep_paths.items():
144
+ if path.startswith(server_prefix) and path != server_prefix:
145
+ # Remove prefix, keeping the leading slash
146
+ relative_path = path[len(server_prefix) :]
147
+ stripped_paths[relative_path] = spec
148
+ else:
149
+ # Path equals prefix or doesn't start with it
150
+ stripped_paths[path] = spec
151
+ keep_paths = stripped_paths
152
+
153
+ return _prune_to_paths(full_schema, keep_paths, title_suffix, server_prefix=server_prefix)
128
154
 
129
155
 
130
156
  def _ensure_original_openapi_saved(app: FastAPI) -> None:
@@ -175,11 +201,23 @@ def add_prefixed_docs(
175
201
  auto_exclude_from_root: bool = True,
176
202
  visible_envs: Optional[Iterable[Environment | str]] = (LOCAL_ENV, DEV_ENV),
177
203
  ) -> None:
204
+ scope = prefix.rstrip("/") or "/"
205
+
206
+ # Always exclude from root if requested, regardless of environment
207
+ if auto_exclude_from_root:
208
+ _ensure_original_openapi_saved(app)
209
+ # Add to exclusion list for root docs
210
+ if not hasattr(app.state, "_scoped_root_exclusions"):
211
+ app.state._scoped_root_exclusions = []
212
+ if scope not in app.state._scoped_root_exclusions:
213
+ app.state._scoped_root_exclusions.append(scope)
214
+ _install_root_filter(app, app.state._scoped_root_exclusions)
215
+
216
+ # Only create scoped docs in allowed environments
178
217
  allow = _normalize_envs(visible_envs)
179
218
  if allow is not None and CURRENT_ENVIRONMENT not in allow:
180
219
  return
181
220
 
182
- scope = prefix.rstrip("/") or "/"
183
221
  openapi_path = f"{scope}/openapi.json"
184
222
  swagger_path = f"{scope}/docs"
185
223
  redoc_path = f"{scope}/redoc"
@@ -211,9 +249,6 @@ def add_prefixed_docs(
211
249
 
212
250
  DOC_SCOPES.append((scope, swagger_path, redoc_path, openapi_path, title))
213
251
 
214
- if auto_exclude_from_root:
215
- _ensure_root_excludes_registered_scopes(app)
216
-
217
252
 
218
253
  def replace_root_openapi_with_exclusions(app: FastAPI, *, exclude_prefixes: List[str]) -> None:
219
254
  _install_root_filter(app, exclude_prefixes)
@@ -2,6 +2,7 @@ import logging
2
2
  import traceback
3
3
  from typing import Any, Dict, Optional
4
4
 
5
+ import httpx
5
6
  from fastapi import Request
6
7
  from fastapi.exceptions import HTTPException, RequestValidationError
7
8
  from fastapi.responses import JSONResponse, Response
@@ -46,6 +47,7 @@ def problem_response(
46
47
  code: str | None = None,
47
48
  errors: list[dict] | None = None,
48
49
  trace_id: str | None = None,
50
+ headers: dict[str, str] | None = None,
49
51
  ) -> Response:
50
52
  body: Dict[str, Any] = {
51
53
  "type": type_uri,
@@ -62,10 +64,24 @@ def problem_response(
62
64
  body["errors"] = errors
63
65
  if trace_id:
64
66
  body["trace_id"] = trace_id
65
- return JSONResponse(status_code=status, content=body, media_type=PROBLEM_MT)
67
+ return JSONResponse(status_code=status, content=body, media_type=PROBLEM_MT, headers=headers)
66
68
 
67
69
 
68
70
  def register_error_handlers(app):
71
+ @app.exception_handler(httpx.TimeoutException)
72
+ async def handle_httpx_timeout(request: Request, exc: httpx.TimeoutException):
73
+ trace_id = _trace_id_from_request(request)
74
+ # Map outbound HTTP client timeouts to 504 Gateway Timeout
75
+ # Keep details generic in prod
76
+ return problem_response(
77
+ status=504,
78
+ title="Gateway Timeout",
79
+ detail=("Upstream request timed out." if IS_PROD else (str(exc) or "httpx timeout")),
80
+ code="GATEWAY_TIMEOUT",
81
+ instance=str(request.url),
82
+ trace_id=trace_id,
83
+ )
84
+
69
85
  @app.exception_handler(FastApiException)
70
86
  async def handle_app_exception(request: Request, exc: FastApiException):
71
87
  trace_id = _trace_id_from_request(request)
@@ -104,14 +120,25 @@ def register_error_handlers(app):
104
120
  @app.exception_handler(HTTPException)
105
121
  async def handle_http_exception(request: Request, exc: HTTPException):
106
122
  trace_id = _trace_id_from_request(request)
107
- title = {401: "Unauthorized", 403: "Forbidden", 404: "Not Found"}.get(
108
- exc.status_code, "Error"
109
- )
123
+ title = {
124
+ 401: "Unauthorized",
125
+ 403: "Forbidden",
126
+ 404: "Not Found",
127
+ 429: "Too Many Requests",
128
+ }.get(exc.status_code, "Error")
110
129
  detail = (
111
130
  exc.detail
112
131
  if not IS_PROD or exc.status_code < 500
113
132
  else "Something went wrong. Please contact support."
114
133
  )
134
+ # Preserve headers set on the exception (e.g., Retry-After for rate limits)
135
+ hdrs: dict[str, str] | None = None
136
+ try:
137
+ if getattr(exc, "headers", None):
138
+ # FastAPI/Starlette exceptions store headers as a dict[str, str]
139
+ hdrs = dict(getattr(exc, "headers")) # type: ignore[arg-type]
140
+ except Exception:
141
+ hdrs = None
115
142
  return problem_response(
116
143
  status=exc.status_code,
117
144
  title=title,
@@ -119,19 +146,29 @@ def register_error_handlers(app):
119
146
  code=title.replace(" ", "_").upper(),
120
147
  instance=str(request.url),
121
148
  trace_id=trace_id,
149
+ headers=hdrs,
122
150
  )
123
151
 
124
152
  @app.exception_handler(StarletteHTTPException)
125
153
  async def handle_starlette_http_exception(request: Request, exc: StarletteHTTPException):
126
154
  trace_id = _trace_id_from_request(request)
127
- title = {401: "Unauthorized", 403: "Forbidden", 404: "Not Found"}.get(
128
- exc.status_code, "Error"
129
- )
155
+ title = {
156
+ 401: "Unauthorized",
157
+ 403: "Forbidden",
158
+ 404: "Not Found",
159
+ 429: "Too Many Requests",
160
+ }.get(exc.status_code, "Error")
130
161
  detail = (
131
162
  exc.detail
132
163
  if not IS_PROD or exc.status_code < 500
133
164
  else "Something went wrong. Please contact support."
134
165
  )
166
+ hdrs: dict[str, str] | None = None
167
+ try:
168
+ if getattr(exc, "headers", None):
169
+ hdrs = dict(getattr(exc, "headers")) # type: ignore[arg-type]
170
+ except Exception:
171
+ hdrs = None
135
172
  return problem_response(
136
173
  status=exc.status_code,
137
174
  title=title,
@@ -139,6 +176,7 @@ def register_error_handlers(app):
139
176
  code=title.replace(" ", "_").upper(),
140
177
  instance=str(request.url),
141
178
  trace_id=trace_id,
179
+ headers=hdrs,
142
180
  )
143
181
 
144
182
  @app.exception_handler(IntegrityError)
@@ -0,0 +1,87 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ import os
6
+ from contextlib import asynccontextmanager
7
+ from typing import Optional
8
+
9
+ from fastapi import FastAPI
10
+ from starlette.types import ASGIApp, Receive, Scope, Send
11
+
12
+ from svc_infra.app.env import pick
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
def _get_grace_period_seconds() -> float:
    """Resolve the shutdown drain grace period, in seconds.

    Reads ``SHUTDOWN_GRACE_PERIOD_SECONDS`` from the environment.

    Returns:
        The parsed value when it is a finite number (negative values are
        clamped to ``0.0``, i.e. no waiting). Otherwise -- unset, empty,
        unparseable, ``nan`` or ``inf`` -- the value selected by
        ``pick(prod=20.0, nonprod=5.0)``.
    """
    import math  # function-scope import: `math` is not a module-level import here

    raw = os.getenv("SHUTDOWN_GRACE_PERIOD_SECONDS")
    if raw:
        try:
            value = float(raw)
        except ValueError:
            value = math.nan
        # float() accepts "nan" and "inf". An infinite grace period would make
        # the drain wait forever on a stuck request, and NaN compares false
        # against everything, so both are treated as invalid configuration.
        if math.isfinite(value):
            return max(0.0, value)
    return float(pick(prod=20.0, nonprod=5.0))
26
+
27
+
28
class InflightTrackerMiddleware:
    """Tracks number of in-flight requests to support graceful shutdown drains.

    The counter lives on ``scope["app"].state._inflight_requests``. It is
    incremented before the wrapped ASGI app is called for an ``http`` scope
    and decremented (never below zero) when that call finishes or raises.
    Non-HTTP scopes, and scopes that carry no application object, are passed
    through untouched.
    """

    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        owner = scope.get("app") if scope.get("type") == "http" else None
        state = getattr(owner, "state", None)
        if state is None:
            # Nothing to count against (non-HTTP scope, or no app/state in the
            # scope): forward the call instead of failing on a missing attribute.
            await self.app(scope, receive, send)
            return

        state._inflight_requests = getattr(state, "_inflight_requests", 0) + 1
        try:
            await self.app(scope, receive, send)
        finally:
            # Clamp at zero so a counter reset during the request cannot go negative.
            state._inflight_requests = max(0, getattr(state, "_inflight_requests", 1) - 1)
44
+
45
+
46
async def _wait_for_drain(app: FastAPI, grace: float) -> None:
    """Poll ``app.state._inflight_requests`` until it reaches zero or ``grace`` elapses.

    Args:
        app: Application whose ``state._inflight_requests`` counter is polled
            (a missing counter is read as 0).
        grace: Maximum number of seconds to wait.

    Returns immediately once the counter is zero or below. If requests are
    still in flight when the grace period is over, a warning is logged and the
    function returns normally.
    """
    poll_interval = 0.1
    loop = asyncio.get_running_loop()
    started = loop.time()
    deadline = started + grace

    while True:
        inflight = int(getattr(app.state, "_inflight_requests", 0))
        if inflight <= 0:
            return
        remaining = deadline - loop.time()
        # Written as `not (> 0)` so a NaN grace period ends the wait instead
        # of looping forever.
        if not (remaining > 0):
            break
        # Measure against the clock rather than summing nominal intervals, and
        # never sleep past the deadline.
        await asyncio.sleep(min(poll_interval, remaining))

    logger.warning(
        "Graceful shutdown timeout: %s in-flight request(s) after %.2fs",
        inflight,
        loop.time() - started,
    )
60
+
61
+
62
def install_graceful_shutdown(app: FastAPI, *, grace_seconds: Optional[float] = None) -> None:
    """Install inflight tracking and lifespan hooks to wait for requests to drain.

    - Adds InflightTrackerMiddleware
    - Replaces ``app.router.lifespan_context`` with a wrapper that resets the
      in-flight counter on startup and waits up to the grace period on shutdown

    Args:
        app: Application to instrument.
        grace_seconds: Maximum drain time in seconds. When ``None`` the value
            comes from ``_get_grace_period_seconds()``.
    """
    app.add_middleware(InflightTrackerMiddleware)

    grace = float(grace_seconds) if grace_seconds is not None else _get_grace_period_seconds()

    # Preserve any existing lifespan and wrap it so our drain runs on shutdown.
    previous_lifespan = getattr(app.router, "lifespan_context", None)

    @asynccontextmanager
    async def _lifespan(a: FastAPI):  # noqa: ANN202
        # Startup: initialize inflight counter
        a.state._inflight_requests = 0
        if previous_lifespan is None:
            yield
            await _wait_for_drain(a, grace)
            return

        async with previous_lifespan(a) as lifespan_state:
            # Forward whatever the wrapped lifespan yields so lifespan state
            # provided by the application is not lost by the wrapping.
            yield lifespan_state
            # Shutdown: drain while still inside the wrapped lifespan, so
            # requests that are finishing can use resources it tears down on exit.
            await _wait_for_drain(a, grace)

    app.router.lifespan_context = _lifespan
@@ -1,36 +1,32 @@
1
+ import base64
1
2
  import hashlib
2
3
  import time
4
+ from typing import Annotated, Dict, Optional
3
5
 
6
+ from fastapi import Header, HTTPException, Request
4
7
  from starlette.middleware.base import BaseHTTPMiddleware
5
- from starlette.responses import Response
8
+ from starlette.responses import JSONResponse, Response
9
+
10
+ from .idempotency_store import IdempotencyStore, InMemoryIdempotencyStore
6
11
 
7
12
 
8
13
  class IdempotencyMiddleware(BaseHTTPMiddleware):
9
- def __init__(self, app, ttl_seconds: int = 24 * 3600, store=None):
14
+ def __init__(
15
+ self,
16
+ app,
17
+ ttl_seconds: int = 24 * 3600,
18
+ store: Optional[IdempotencyStore] = None,
19
+ header_name: str = "Idempotency-Key",
20
+ ):
10
21
  super().__init__(app)
11
22
  self.ttl = ttl_seconds
12
- self.store = store or {} # replace with Redis
23
+ self.store: IdempotencyStore = store or InMemoryIdempotencyStore()
24
+ self.header_name = header_name
13
25
 
14
26
  def _cache_key(self, request, idkey: str):
15
- body = getattr(request, "_body", None)
16
- if body is None:
17
- body = b""
18
-
19
- async def _read():
20
- data = await request.body()
21
- request._body = data # stash for downstream
22
- return data
23
-
24
- # read once
25
- # note: starlette Request is awaitable; we read in dispatch below
26
-
27
+ # The cache key must NOT include the body to allow conflict detection for mismatched payloads.
27
28
  sig = hashlib.sha256(
28
- (
29
- request.method + "|" + request.url.path + "|" + idkey + "|" + (request._body or b"")
30
- ).encode()
31
- if isinstance(request._body, str)
32
- else (request.method + "|" + request.url.path + "|" + idkey).encode()
33
- + (request._body or b"")
29
+ (request.method + "|" + request.url.path + "|" + idkey).encode()
34
30
  ).hexdigest()
35
31
  return f"idmp:{sig}"
36
32
 
@@ -39,33 +35,69 @@ class IdempotencyMiddleware(BaseHTTPMiddleware):
39
35
  # read & buffer body once
40
36
  body = await request.body()
41
37
  request._body = body
42
- idkey = request.headers.get("Idempotency-Key")
38
+ idkey = request.headers.get(self.header_name)
43
39
  if idkey:
44
40
  k = self._cache_key(request, idkey)
45
- entry = self.store.get(k)
46
41
  now = time.time()
47
- if entry and entry["exp"] > now:
48
- cached = entry["resp"]
49
- return Response(
50
- content=cached["body"],
51
- status_code=cached["status"],
52
- headers=cached["headers"],
53
- media_type=cached.get("media_type"),
54
- )
42
+ # build request hash to detect mismatched replays
43
+ req_hash = hashlib.sha256(body or b"").hexdigest()
44
+
45
+ existing = self.store.get(k)
46
+ if existing and existing.exp > now:
47
+ # If payload mismatches any existing claim, return conflict
48
+ if existing.req_hash and existing.req_hash != req_hash:
49
+ return JSONResponse(
50
+ status_code=409,
51
+ content={
52
+ "type": "about:blank",
53
+ "title": "Conflict",
54
+ "detail": "Idempotency-Key re-used with different request payload.",
55
+ },
56
+ )
57
+ # If response cached and payload matches, replay it
58
+ if existing.status is not None and existing.body_b64 is not None:
59
+ return Response(
60
+ content=base64.b64decode(existing.body_b64),
61
+ status_code=existing.status,
62
+ headers=existing.headers or {},
63
+ media_type=existing.media_type,
64
+ )
65
+
66
+ # Claim the key if not present
67
+ exp = now + self.ttl
68
+ created = self.store.set_initial(k, req_hash, exp)
69
+ if not created:
70
+ # Someone else claimed; re-check for conflict or replay
71
+ existing = self.store.get(k)
72
+ if existing and existing.req_hash and existing.req_hash != req_hash:
73
+ return JSONResponse(
74
+ status_code=409,
75
+ content={
76
+ "type": "about:blank",
77
+ "title": "Conflict",
78
+ "detail": "Idempotency-Key re-used with different request payload.",
79
+ },
80
+ )
81
+ if existing and existing.status is not None and existing.body_b64 is not None:
82
+ return Response(
83
+ content=base64.b64decode(existing.body_b64),
84
+ status_code=existing.status,
85
+ headers=existing.headers or {},
86
+ media_type=existing.media_type,
87
+ )
88
+
89
+ # Proceed to handler
55
90
  resp = await call_next(request)
56
- # cache only 2xx/201 responses
57
91
  if 200 <= resp.status_code < 300:
58
92
  body_bytes = b"".join([section async for section in resp.body_iterator])
59
- headers = dict(resp.headers)
60
- self.store[k] = {
61
- "resp": {
62
- "status": resp.status_code,
63
- "body": body_bytes,
64
- "headers": headers,
65
- "media_type": resp.media_type,
66
- },
67
- "exp": now + self.ttl,
68
- }
93
+ headers: Dict[str, str] = dict(resp.headers)
94
+ self.store.set_response(
95
+ k,
96
+ status=resp.status_code,
97
+ body=body_bytes,
98
+ headers=headers,
99
+ media_type=resp.media_type,
100
+ )
69
101
  return Response(
70
102
  content=body_bytes,
71
103
  status_code=resp.status_code,
@@ -74,3 +106,11 @@ class IdempotencyMiddleware(BaseHTTPMiddleware):
74
106
  )
75
107
  return resp
76
108
  return await call_next(request)
109
+
110
+
111
async def require_idempotency_key(
    idempotency_key: Annotated[str, Header(alias="Idempotency-Key")],
    request: Request,
) -> None:
    """Dependency that rejects requests whose ``Idempotency-Key`` header is blank.

    Raises:
        HTTPException: 400 when the header value is empty or whitespace only.
    """
    has_value = bool(idempotency_key.strip())
    if has_value:
        return
    raise HTTPException(status_code=400, detail="Idempotency-Key must not be empty.")
@@ -0,0 +1,187 @@
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import json
5
+ import time
6
+ from dataclasses import dataclass
7
+ from typing import Dict, Optional, Protocol
8
+
9
+
10
@dataclass
class IdempotencyEntry:
    """Record kept for one idempotency key: the claim and, once set, the response."""

    # Hash identifying the request that claimed the key.
    req_hash: str
    # Expiry as a timestamp in seconds; the stores compare it with time.time().
    exp: float
    # Optional response fields when available
    status: Optional[int] = None
    # Response body as base64 text.
    body_b64: Optional[str] = None
    headers: Optional[Dict[str, str]] = None
    media_type: Optional[str] = None
19
+
20
+
21
class IdempotencyStore(Protocol):
    """Structural interface for idempotency-key storage backends."""

    def get(self, key: str) -> Optional[IdempotencyEntry]:
        """Look up the entry for ``key``; ``None`` when there is none."""
        ...

    def set_initial(self, key: str, req_hash: str, exp: float) -> bool:
        """Atomically create an entry if absent. Returns True if created, False if already exists."""
        ...

    def set_response(
        self,
        key: str,
        *,
        status: int,
        body: bytes,
        headers: Dict[str, str],
        media_type: Optional[str],
    ) -> None:
        """Attach response data to the entry stored under ``key``."""
        ...

    def delete(self, key: str) -> None:
        """Remove the entry stored under ``key``."""
        ...
42
+
43
+
44
class InMemoryIdempotencyStore:
    """Idempotency store backed by a dict held on the instance.

    Expired entries are dropped lazily when their key is read, and in bulk by
    a sweep that ``set_initial`` runs at most once per ``sweep_interval``
    seconds, so keys that are never looked up again do not accumulate forever.
    """

    def __init__(self, *, sweep_interval: float = 60.0):
        """Create an empty store.

        Args:
            sweep_interval: Minimum number of seconds between bulk sweeps of
                expired entries.
        """
        self._store: dict[str, IdempotencyEntry] = {}
        self._sweep_interval = sweep_interval
        self._next_sweep = time.time() + sweep_interval

    def purge_expired(self, now: Optional[float] = None) -> int:
        """Delete every entry whose ``exp`` is at or before ``now``.

        Args:
            now: Reference timestamp; defaults to ``time.time()``.

        Returns:
            The number of entries removed.
        """
        cutoff = time.time() if now is None else now
        stale = [key for key, entry in self._store.items() if entry.exp <= cutoff]
        for key in stale:
            del self._store[key]
        return len(stale)

    def get(self, key: str) -> Optional[IdempotencyEntry]:
        """Return the live entry for ``key``, or ``None`` if absent or expired."""
        entry = self._store.get(key)
        if not entry:
            return None
        # expire lazily
        if entry.exp <= time.time():
            self._store.pop(key, None)
            return None
        return entry

    def set_initial(self, key: str, req_hash: str, exp: float) -> bool:
        """Claim ``key`` unless a live entry already holds it.

        Returns:
            True when a new entry was stored, False when an unexpired entry exists.
        """
        now = time.time()
        if now >= self._next_sweep:
            # Periodic cleanup keeps the dict bounded by the number of live keys.
            self.purge_expired(now)
            self._next_sweep = now + self._sweep_interval
        existing = self._store.get(key)
        if existing and existing.exp > now:
            return False
        self._store[key] = IdempotencyEntry(req_hash=req_hash, exp=exp)
        return True

    def set_response(
        self,
        key: str,
        *,
        status: int,
        body: bytes,
        headers: Dict[str, str],
        media_type: Optional[str],
    ) -> None:
        """Record the response for ``key``; the body is stored base64-encoded."""
        entry = self._store.get(key)
        if not entry:
            # Create if missing to ensure replay works until exp
            entry = IdempotencyEntry(req_hash="", exp=time.time() + 60)
            self._store[key] = entry
        entry.status = status
        entry.body_b64 = base64.b64encode(body).decode()
        entry.headers = dict(headers)
        entry.media_type = media_type

    def delete(self, key: str) -> None:
        """Remove ``key`` if present."""
        self._store.pop(key, None)
87
+
88
+
89
class RedisIdempotencyStore:
    """A simple Redis-backed store.

    Notes:
    - Entries are JSON payloads read and written with GET/SET.
    - The initial claim is one ``SET ... NX EX`` call, so the key and its TTL
      are written atomically.
    - Not fully atomic for response update; sufficient for basic dedupe.
    - For strict guarantees, replace with a Lua script (future improvement).
    """

    def __init__(self, redis_client, *, prefix: str = "idmp"):
        self.r = redis_client
        self.prefix = prefix

    def _k(self, key: str) -> str:
        """Return the namespaced Redis key for ``key``."""
        return f"{self.prefix}:{key}"

    @staticmethod
    def _ttl(exp: float) -> int:
        """Whole seconds from now until ``exp``, never less than 1."""
        return max(1, int(exp - time.time()))

    def get(self, key: str) -> Optional[IdempotencyEntry]:
        """Return the live entry for ``key``.

        Returns ``None`` when the key is absent, its payload cannot be decoded
        into an entry, or the entry has expired (the expired key is then
        deleted on a best-effort basis).
        """
        raw = self.r.get(self._k(key))
        if not raw:
            return None
        try:
            data = json.loads(raw)
        except (TypeError, ValueError):
            return None
        if not isinstance(data, dict):
            # Valid JSON that is not an object cannot describe an entry.
            return None
        try:
            exp = float(data.get("exp", 0))
        except (TypeError, ValueError):
            return None
        entry = IdempotencyEntry(
            req_hash=data.get("req_hash", ""),
            exp=exp,
            status=data.get("status"),
            body_b64=data.get("body_b64"),
            headers=data.get("headers"),
            media_type=data.get("media_type"),
        )
        if entry.exp <= time.time():
            self.delete(key)
            return None
        return entry

    def set_initial(self, key: str, req_hash: str, exp: float) -> bool:
        """Claim ``key`` unless a live entry already holds it.

        Returns:
            True when this call stored the claim, False when a live entry exists.
        """
        name = self._k(key)
        payload = json.dumps({"req_hash": req_hash, "exp": exp})
        ttl = self._ttl(exp)

        # Value and TTL are written in a single command; a separate EXPIRE
        # could fail and leave a claim that never expires in Redis.
        if self.r.set(name, payload, nx=True, ex=ttl):
            return True

        # The key exists. get() reports None for an expired entry (which it
        # deletes) and for an undecodable payload (which it leaves in place).
        if self.get(key) is not None:
            return False
        # Retry with NX so that, of several callers racing to replace an
        # expired entry, only one gets the claim.
        if self.r.set(name, payload, nx=True, ex=ttl):
            return True
        if self.get(key) is not None:
            return False
        # Still present and still undecodable: replace the leftover payload.
        self.r.set(name, payload, ex=ttl)
        return True

    def set_response(
        self,
        key: str,
        *,
        status: int,
        body: bytes,
        headers: Dict[str, str],
        media_type: Optional[str],
    ) -> None:
        """Store the response for ``key``, keeping the entry's existing expiry."""
        entry = self.get(key)
        if not entry:
            # default short ttl if missing; caller should have set initial
            entry = IdempotencyEntry(req_hash="", exp=time.time() + 60)
        entry.status = status
        entry.body_b64 = base64.b64encode(body).decode()
        entry.headers = dict(headers)
        entry.media_type = media_type
        payload = json.dumps(
            {
                "req_hash": entry.req_hash,
                "exp": entry.exp,
                "status": entry.status,
                "body_b64": entry.body_b64,
                "headers": entry.headers,
                "media_type": entry.media_type,
            }
        )
        self.r.set(self._k(key), payload, ex=self._ttl(entry.exp))

    def delete(self, key: str) -> None:
        """Delete ``key``; errors from the client are deliberately ignored."""
        try:
            self.r.delete(self._k(key))
        except Exception:
            # Best-effort cleanup: the entry's own ``exp`` still guards reads.
            pass
@@ -0,0 +1,37 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Annotated, Any, Callable, Optional
4
+
5
+ from fastapi import Header, HTTPException
6
+
7
+
8
async def require_if_match(
    version: Annotated[Optional[str], Header(alias="If-Match")] = None
) -> str:
    """Dependency enforcing that mutating requests carry an ``If-Match`` header.

    Returns:
        The header value, unchanged.

    Raises:
        HTTPException: 428 when the header is absent or empty.
    """
    if version:
        return version
    raise HTTPException(status_code=428, detail="Missing If-Match header for optimistic locking.")
20
+
21
+
22
def check_version_or_409(get_current_version: Callable[[], Any], provided: str) -> None:
    """Compare provided version with current version; raise 409 on mismatch.

    Args:
        get_current_version: callable returning the resource's current version (int/str)
        provided: ``If-Match`` header value. Surrounding double quotes (the
            entity-tag form, e.g. ``"3"``) are accepted as well as the bare
            value. When the current version is an int, the value is coerced
            to int.

    Raises:
        HTTPException: 400 when an integer version is expected but the value
            is not an integer; 409 when the versions differ.
    """
    current = get_current_version()

    # HTTP entity-tags are double-quoted; also accept the unquoted inner value.
    unquoted = provided
    if isinstance(provided, str):
        stripped = provided.strip()
        if len(stripped) >= 2 and stripped[0] == '"' and stripped[-1] == '"':
            unquoted = stripped[1:-1]

    if isinstance(current, int):
        try:
            candidates = (int(unquoted),)
        except (TypeError, ValueError) as exc:
            raise HTTPException(
                status_code=400, detail="Invalid If-Match value; expected integer."
            ) from exc
    else:
        # Try the raw header first so a stored version that itself contains
        # the quotes keeps matching exactly as before.
        candidates = (provided, unquoted)

    if all(candidate != current for candidate in candidates):
        raise HTTPException(status_code=409, detail="Version mismatch (optimistic locking).")