kubiya-control-plane-api 0.1.0__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kubiya-control-plane-api might be problematic. Click here for more details.

Files changed (185) hide show
  1. control_plane_api/README.md +266 -0
  2. control_plane_api/__init__.py +0 -0
  3. control_plane_api/__version__.py +1 -0
  4. control_plane_api/alembic/README +1 -0
  5. control_plane_api/alembic/env.py +98 -0
  6. control_plane_api/alembic/script.py.mako +28 -0
  7. control_plane_api/alembic/versions/1382bec74309_initial_migration_with_all_models.py +251 -0
  8. control_plane_api/alembic/versions/1f54bc2a37e3_add_analytics_tables.py +162 -0
  9. control_plane_api/alembic/versions/2e4cb136dc10_rename_toolset_ids_to_skill_ids_in_teams.py +30 -0
  10. control_plane_api/alembic/versions/31cd69a644ce_add_skill_templates_table.py +28 -0
  11. control_plane_api/alembic/versions/89e127caa47d_add_jobs_and_job_executions_tables.py +161 -0
  12. control_plane_api/alembic/versions/add_llm_models_table.py +51 -0
  13. control_plane_api/alembic/versions/b0e10697f212_add_runtime_column_to_teams_simple.py +42 -0
  14. control_plane_api/alembic/versions/ce43b24b63bf_add_execution_trigger_source_and_fix_.py +155 -0
  15. control_plane_api/alembic/versions/d4eaf16e3f8d_rename_toolsets_to_skills.py +84 -0
  16. control_plane_api/alembic/versions/efa2dc427da1_rename_metadata_to_custom_metadata.py +32 -0
  17. control_plane_api/alembic/versions/f973b431d1ce_add_workflow_executor_to_skill_types.py +44 -0
  18. control_plane_api/alembic.ini +148 -0
  19. control_plane_api/api/index.py +12 -0
  20. control_plane_api/app/__init__.py +11 -0
  21. control_plane_api/app/activities/__init__.py +20 -0
  22. control_plane_api/app/activities/agent_activities.py +379 -0
  23. control_plane_api/app/activities/team_activities.py +410 -0
  24. control_plane_api/app/activities/temporal_cloud_activities.py +577 -0
  25. control_plane_api/app/config/__init__.py +35 -0
  26. control_plane_api/app/config/api_config.py +354 -0
  27. control_plane_api/app/config/model_pricing.py +318 -0
  28. control_plane_api/app/config.py +95 -0
  29. control_plane_api/app/database.py +135 -0
  30. control_plane_api/app/exceptions.py +408 -0
  31. control_plane_api/app/lib/__init__.py +11 -0
  32. control_plane_api/app/lib/job_executor.py +312 -0
  33. control_plane_api/app/lib/kubiya_client.py +235 -0
  34. control_plane_api/app/lib/litellm_pricing.py +166 -0
  35. control_plane_api/app/lib/planning_tools/__init__.py +22 -0
  36. control_plane_api/app/lib/planning_tools/agents.py +155 -0
  37. control_plane_api/app/lib/planning_tools/base.py +189 -0
  38. control_plane_api/app/lib/planning_tools/environments.py +214 -0
  39. control_plane_api/app/lib/planning_tools/resources.py +240 -0
  40. control_plane_api/app/lib/planning_tools/teams.py +198 -0
  41. control_plane_api/app/lib/policy_enforcer_client.py +939 -0
  42. control_plane_api/app/lib/redis_client.py +436 -0
  43. control_plane_api/app/lib/supabase.py +71 -0
  44. control_plane_api/app/lib/temporal_client.py +138 -0
  45. control_plane_api/app/lib/validation/__init__.py +20 -0
  46. control_plane_api/app/lib/validation/runtime_validation.py +287 -0
  47. control_plane_api/app/main.py +128 -0
  48. control_plane_api/app/middleware/__init__.py +8 -0
  49. control_plane_api/app/middleware/auth.py +513 -0
  50. control_plane_api/app/middleware/exception_handler.py +267 -0
  51. control_plane_api/app/middleware/rate_limiting.py +384 -0
  52. control_plane_api/app/middleware/request_id.py +202 -0
  53. control_plane_api/app/models/__init__.py +27 -0
  54. control_plane_api/app/models/agent.py +79 -0
  55. control_plane_api/app/models/analytics.py +206 -0
  56. control_plane_api/app/models/associations.py +81 -0
  57. control_plane_api/app/models/environment.py +63 -0
  58. control_plane_api/app/models/execution.py +93 -0
  59. control_plane_api/app/models/job.py +179 -0
  60. control_plane_api/app/models/llm_model.py +75 -0
  61. control_plane_api/app/models/presence.py +49 -0
  62. control_plane_api/app/models/project.py +47 -0
  63. control_plane_api/app/models/session.py +38 -0
  64. control_plane_api/app/models/team.py +66 -0
  65. control_plane_api/app/models/workflow.py +55 -0
  66. control_plane_api/app/policies/README.md +121 -0
  67. control_plane_api/app/policies/approved_users.rego +62 -0
  68. control_plane_api/app/policies/business_hours.rego +51 -0
  69. control_plane_api/app/policies/rate_limiting.rego +100 -0
  70. control_plane_api/app/policies/tool_restrictions.rego +86 -0
  71. control_plane_api/app/routers/__init__.py +4 -0
  72. control_plane_api/app/routers/agents.py +364 -0
  73. control_plane_api/app/routers/agents_v2.py +1260 -0
  74. control_plane_api/app/routers/analytics.py +1014 -0
  75. control_plane_api/app/routers/context_manager.py +562 -0
  76. control_plane_api/app/routers/environment_context.py +270 -0
  77. control_plane_api/app/routers/environments.py +715 -0
  78. control_plane_api/app/routers/execution_environment.py +517 -0
  79. control_plane_api/app/routers/executions.py +1911 -0
  80. control_plane_api/app/routers/health.py +92 -0
  81. control_plane_api/app/routers/health_v2.py +326 -0
  82. control_plane_api/app/routers/integrations.py +274 -0
  83. control_plane_api/app/routers/jobs.py +1344 -0
  84. control_plane_api/app/routers/models.py +82 -0
  85. control_plane_api/app/routers/models_v2.py +361 -0
  86. control_plane_api/app/routers/policies.py +639 -0
  87. control_plane_api/app/routers/presence.py +234 -0
  88. control_plane_api/app/routers/projects.py +902 -0
  89. control_plane_api/app/routers/runners.py +379 -0
  90. control_plane_api/app/routers/runtimes.py +172 -0
  91. control_plane_api/app/routers/secrets.py +155 -0
  92. control_plane_api/app/routers/skills.py +1001 -0
  93. control_plane_api/app/routers/skills_definitions.py +140 -0
  94. control_plane_api/app/routers/task_planning.py +1256 -0
  95. control_plane_api/app/routers/task_queues.py +654 -0
  96. control_plane_api/app/routers/team_context.py +270 -0
  97. control_plane_api/app/routers/teams.py +1400 -0
  98. control_plane_api/app/routers/worker_queues.py +1545 -0
  99. control_plane_api/app/routers/workers.py +935 -0
  100. control_plane_api/app/routers/workflows.py +204 -0
  101. control_plane_api/app/runtimes/__init__.py +6 -0
  102. control_plane_api/app/runtimes/validation.py +344 -0
  103. control_plane_api/app/schemas/job_schemas.py +295 -0
  104. control_plane_api/app/services/__init__.py +1 -0
  105. control_plane_api/app/services/agno_service.py +619 -0
  106. control_plane_api/app/services/litellm_service.py +190 -0
  107. control_plane_api/app/services/policy_service.py +525 -0
  108. control_plane_api/app/services/temporal_cloud_provisioning.py +150 -0
  109. control_plane_api/app/skills/__init__.py +44 -0
  110. control_plane_api/app/skills/base.py +229 -0
  111. control_plane_api/app/skills/business_intelligence.py +189 -0
  112. control_plane_api/app/skills/data_visualization.py +154 -0
  113. control_plane_api/app/skills/docker.py +104 -0
  114. control_plane_api/app/skills/file_generation.py +94 -0
  115. control_plane_api/app/skills/file_system.py +110 -0
  116. control_plane_api/app/skills/python.py +92 -0
  117. control_plane_api/app/skills/registry.py +65 -0
  118. control_plane_api/app/skills/shell.py +102 -0
  119. control_plane_api/app/skills/workflow_executor.py +469 -0
  120. control_plane_api/app/utils/workflow_executor.py +354 -0
  121. control_plane_api/app/workflows/__init__.py +11 -0
  122. control_plane_api/app/workflows/agent_execution.py +507 -0
  123. control_plane_api/app/workflows/agent_execution_with_skills.py +222 -0
  124. control_plane_api/app/workflows/namespace_provisioning.py +326 -0
  125. control_plane_api/app/workflows/team_execution.py +399 -0
  126. control_plane_api/scripts/seed_models.py +239 -0
  127. control_plane_api/worker/__init__.py +0 -0
  128. control_plane_api/worker/activities/__init__.py +0 -0
  129. control_plane_api/worker/activities/agent_activities.py +1241 -0
  130. control_plane_api/worker/activities/approval_activities.py +234 -0
  131. control_plane_api/worker/activities/runtime_activities.py +388 -0
  132. control_plane_api/worker/activities/skill_activities.py +267 -0
  133. control_plane_api/worker/activities/team_activities.py +1217 -0
  134. control_plane_api/worker/config/__init__.py +31 -0
  135. control_plane_api/worker/config/worker_config.py +275 -0
  136. control_plane_api/worker/control_plane_client.py +529 -0
  137. control_plane_api/worker/examples/analytics_integration_example.py +362 -0
  138. control_plane_api/worker/models/__init__.py +1 -0
  139. control_plane_api/worker/models/inputs.py +89 -0
  140. control_plane_api/worker/runtimes/__init__.py +31 -0
  141. control_plane_api/worker/runtimes/base.py +789 -0
  142. control_plane_api/worker/runtimes/claude_code_runtime.py +1443 -0
  143. control_plane_api/worker/runtimes/default_runtime.py +617 -0
  144. control_plane_api/worker/runtimes/factory.py +173 -0
  145. control_plane_api/worker/runtimes/validation.py +93 -0
  146. control_plane_api/worker/services/__init__.py +1 -0
  147. control_plane_api/worker/services/agent_executor.py +422 -0
  148. control_plane_api/worker/services/agent_executor_v2.py +383 -0
  149. control_plane_api/worker/services/analytics_collector.py +457 -0
  150. control_plane_api/worker/services/analytics_service.py +464 -0
  151. control_plane_api/worker/services/approval_tools.py +310 -0
  152. control_plane_api/worker/services/approval_tools_agno.py +207 -0
  153. control_plane_api/worker/services/cancellation_manager.py +177 -0
  154. control_plane_api/worker/services/data_visualization.py +827 -0
  155. control_plane_api/worker/services/jira_tools.py +257 -0
  156. control_plane_api/worker/services/runtime_analytics.py +328 -0
  157. control_plane_api/worker/services/session_service.py +194 -0
  158. control_plane_api/worker/services/skill_factory.py +175 -0
  159. control_plane_api/worker/services/team_executor.py +574 -0
  160. control_plane_api/worker/services/team_executor_v2.py +465 -0
  161. control_plane_api/worker/services/workflow_executor_tools.py +1418 -0
  162. control_plane_api/worker/tests/__init__.py +1 -0
  163. control_plane_api/worker/tests/e2e/__init__.py +0 -0
  164. control_plane_api/worker/tests/e2e/test_execution_flow.py +571 -0
  165. control_plane_api/worker/tests/integration/__init__.py +0 -0
  166. control_plane_api/worker/tests/integration/test_control_plane_integration.py +308 -0
  167. control_plane_api/worker/tests/unit/__init__.py +0 -0
  168. control_plane_api/worker/tests/unit/test_control_plane_client.py +401 -0
  169. control_plane_api/worker/utils/__init__.py +1 -0
  170. control_plane_api/worker/utils/chunk_batcher.py +305 -0
  171. control_plane_api/worker/utils/retry_utils.py +60 -0
  172. control_plane_api/worker/utils/streaming_utils.py +373 -0
  173. control_plane_api/worker/worker.py +753 -0
  174. control_plane_api/worker/workflows/__init__.py +0 -0
  175. control_plane_api/worker/workflows/agent_execution.py +589 -0
  176. control_plane_api/worker/workflows/team_execution.py +429 -0
  177. kubiya_control_plane_api-0.3.4.dist-info/METADATA +229 -0
  178. kubiya_control_plane_api-0.3.4.dist-info/RECORD +182 -0
  179. kubiya_control_plane_api-0.3.4.dist-info/entry_points.txt +2 -0
  180. kubiya_control_plane_api-0.3.4.dist-info/top_level.txt +1 -0
  181. kubiya_control_plane_api-0.1.0.dist-info/METADATA +0 -66
  182. kubiya_control_plane_api-0.1.0.dist-info/RECORD +0 -5
  183. kubiya_control_plane_api-0.1.0.dist-info/top_level.txt +0 -1
  184. {kubiya_control_plane_api-0.1.0.dist-info/licenses → control_plane_api}/LICENSE +0 -0
  185. {kubiya_control_plane_api-0.1.0.dist-info → kubiya_control_plane_api-0.3.4.dist-info}/WHEEL +0 -0
@@ -0,0 +1,267 @@
1
+ """
2
+ Global exception handler middleware for Control Plane API.
3
+
4
+ Catches all exceptions and returns standardized error responses.
5
+ """
6
+
7
+ from fastapi import Request, status
8
+ from fastapi.responses import JSONResponse
9
+ from fastapi.exceptions import RequestValidationError, HTTPException
10
+ from starlette.exceptions import HTTPException as StarletteHTTPException
11
+ from control_plane_api.app.exceptions import ControlPlaneException
12
+ import structlog
13
+ import traceback
14
+ import uuid
15
+ from datetime import datetime, timezone
16
+ from typing import Any
17
+
18
+ logger = structlog.get_logger()
19
+
20
+
21
async def control_plane_exception_handler(
    request: Request,
    exc: ControlPlaneException,
) -> JSONResponse:
    """
    Turn a ControlPlaneException into the standard JSON error envelope.

    A fresh UUID is generated per failure and used both in the log record
    and in the response body so the two can be correlated.

    Args:
        request: The request being processed when the exception was raised.
        exc: The raised exception; its ``error_code``, ``message``,
            ``details`` and ``status_code`` attributes are read.

    Returns:
        A JSONResponse carrying ``exc.status_code`` and an ``{"error": ...}``
        body.
    """
    incident_id = str(uuid.uuid4())

    # Record the failure together with where it happened.
    logger.error(
        "control_plane_error",
        error_id=incident_id,
        error_code=exc.error_code,
        error_message=exc.message,
        error_details=exc.details,
        path=str(request.url),
        method=request.method,
    )

    error_body = {
        "id": incident_id,
        "code": exc.error_code,
        "message": exc.message,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }

    # Details are optional; omit the key entirely when there are none.
    if exc.details:
        error_body["details"] = exc.details

    # Request context is only echoed back when the app runs in debug mode.
    if request.app.debug:
        url = request.url
        error_body["request"] = {
            "method": request.method,
            "path": str(url.path),
            "query": str(url.query) if url.query else None,
        }

    return JSONResponse(
        status_code=exc.status_code,
        content={"error": error_body},
    )
69
+
70
+
71
async def validation_exception_handler(
    request: Request,
    exc: RequestValidationError,
) -> JSONResponse:
    """
    Render request validation failures in the standard error envelope.

    Each entry from ``exc.errors()`` is flattened to a ``field`` (the
    location parts joined with dots), a ``message`` and a ``type``.

    Args:
        request: The request that failed validation.
        exc: The validation error raised for the request.

    Returns:
        A 422 JSONResponse whose ``details.validation_errors`` lists the
        flattened errors.
    """
    incident_id = str(uuid.uuid4())

    # Flatten every reported problem into a small, serializable dict.
    problems = [
        {
            "field": ".".join(str(part) for part in item["loc"]),
            "message": item["msg"],
            "type": item["type"],
        }
        for item in exc.errors()
    ]

    logger.warning(
        "validation_error",
        error_id=incident_id,
        validation_errors=problems,
        path=str(request.url),
        method=request.method,
    )

    payload = {
        "error": {
            "id": incident_id,
            "code": "VALIDATION_ERROR",
            "message": "Request validation failed",
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "details": {
                "validation_errors": problems,
            },
        }
    }

    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content=payload,
    )
118
+
119
+
120
async def http_exception_handler(
    request: Request,
    exc: HTTPException,
) -> JSONResponse:
    """
    Render an HTTP exception in the standard error envelope.

    The numeric status code is translated to a symbolic error code; codes
    without a mapping fall back to ``"HTTP_ERROR"``.

    Args:
        request: The request being processed when the exception was raised.
        exc: The HTTP exception; ``status_code``, ``detail`` and (when
            present) ``headers`` are read.

    Returns:
        A JSONResponse with the exception's status code. Any headers on the
        exception are forwarded on the response and also copied into the
        body under ``error.headers``.
    """
    incident_id = str(uuid.uuid4())
    status_code = exc.status_code

    logger.warning(
        "http_exception",
        error_id=incident_id,
        status_code=status_code,
        detail=exc.detail,
        path=str(request.url),
        method=request.method,
    )

    # Symbolic names for the status codes we expect to see.
    symbolic_codes = {
        400: "BAD_REQUEST",
        401: "AUTHENTICATION_ERROR",
        403: "AUTHORIZATION_ERROR",
        404: "RESOURCE_NOT_FOUND",
        405: "METHOD_NOT_ALLOWED",
        409: "RESOURCE_CONFLICT",
        422: "VALIDATION_ERROR",
        429: "RATE_LIMIT_EXCEEDED",
        500: "INTERNAL_ERROR",
        502: "BAD_GATEWAY",
        503: "SERVICE_UNAVAILABLE",
        504: "GATEWAY_TIMEOUT",
    }

    error_body = {
        "id": incident_id,
        "code": symbolic_codes.get(status_code, "HTTP_ERROR"),
        # Fall back to a generic message when the exception has no detail.
        "message": exc.detail or f"HTTP {status_code} Error",
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }

    # Not every exception object carries headers, hence the getattr.
    extra_headers = getattr(exc, "headers", None)
    if extra_headers:
        error_body["headers"] = dict(extra_headers)

    return JSONResponse(
        status_code=status_code,
        content={"error": error_body},
        headers=extra_headers,
    )
178
+
179
+
180
async def generic_exception_handler(
    request: Request,
    exc: Exception,
) -> JSONResponse:
    """
    Handle unexpected exceptions.

    This is the catch-all handler for any exception no other handler
    claimed. The full traceback is always logged; it is only included in
    the response body when the application runs in debug mode.

    Args:
        request: The request being processed when the exception was raised.
        exc: The unhandled exception.

    Returns:
        A 500 JSONResponse in the standard ``{"error": ...}`` envelope.
    """
    error_id = str(uuid.uuid4())
    error_type = type(exc).__name__

    # Format the traceback from the exception object itself rather than via
    # traceback.format_exc(): the latter reports whatever exception the
    # interpreter is *currently handling*, which is "NoneType: None" if this
    # coroutine is ever invoked outside an active ``except`` block.
    traceback_text = "".join(
        traceback.format_exception(type(exc), exc, exc.__traceback__)
    )

    # Log the full exception with traceback
    logger.error(
        "unhandled_exception",
        error_id=error_id,
        error_type=error_type,
        error_message=str(exc),
        path=str(request.url),
        method=request.method,
        traceback=traceback_text,
        exc_info=exc,
    )

    # Build response
    # Outside debug mode, don't expose internal error details
    if request.app.debug:
        message = f"{error_type}: {str(exc)}"
        details = {
            "exception_type": error_type,
            "traceback": traceback_text.split("\n"),
        }
    else:
        message = "An internal server error occurred"
        details = None

    response = {
        "error": {
            "id": error_id,
            "code": "INTERNAL_ERROR",
            "message": message,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
    }

    if details:
        response["error"]["details"] = details

    # Emit a separate critical-level record outside debug mode so alerting
    # can key off it (in production, this might trigger PagerDuty).
    if not request.app.debug:
        logger.critical(
            "unhandled_exception_alert",
            error_id=error_id,
            error_type=error_type,
            path=str(request.url),
        )

    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content=response,
    )
240
+
241
+
242
def setup_exception_handlers(app: Any) -> None:
    """
    Attach every exception handler defined in this module to ``app``.

    Intended to be called once from main.py, right after the app is created.

    Example:
        from control_plane_api.app.middleware.exception_handler import setup_exception_handlers

        app = FastAPI()
        setup_exception_handlers(app)

    Args:
        app: The application object; only ``add_exception_handler`` is used.
    """
    # (exception class, handler) pairs, registered in this order.
    registrations = (
        # Our custom exceptions
        (ControlPlaneException, control_plane_exception_handler),
        # FastAPI/Pydantic validation errors
        (RequestValidationError, validation_exception_handler),
        # HTTP exceptions, both the FastAPI and the Starlette class
        (HTTPException, http_exception_handler),
        (StarletteHTTPException, http_exception_handler),
        # Catch-all for anything else
        (Exception, generic_exception_handler),
    )

    for exception_class, handler in registrations:
        app.add_exception_handler(exception_class, handler)

    logger.info("exception_handlers_registered")
@@ -0,0 +1,384 @@
1
+ """
2
+ Rate limiting middleware using token bucket algorithm.
3
+
4
+ Provides configurable rate limiting per client IP or user.
5
+ """
6
+
7
+ from fastapi import Request, Response, HTTPException, status
8
+ from starlette.middleware.base import BaseHTTPMiddleware
9
+ from starlette.types import ASGIApp
10
+ from typing import Dict, Optional, Tuple, Any
11
+ from datetime import datetime, timedelta
12
+ import time
13
+ import asyncio
14
+ import structlog
15
+ import hashlib
16
+ from control_plane_api.app.exceptions import RateLimitError
17
+
18
+ logger = structlog.get_logger()
19
+
20
+
21
class TokenBucket:
    """
    Token bucket implementation for rate limiting.

    Each bucket starts full, with ``capacity`` tokens.
    Tokens are consumed when requests are made.
    Tokens are refilled continuously, at ``refill_rate`` tokens per
    ``refill_period`` seconds, and never exceed ``capacity``.
    """

    def __init__(
        self,
        capacity: int,
        refill_rate: float,
        refill_period: float = 60.0,
    ):
        """
        Initialize token bucket.

        Args:
            capacity: Maximum number of tokens in bucket
            refill_rate: Number of tokens to add per period; 0 means the
                bucket never refills
            refill_period: Period in seconds for refilling tokens
        """
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.refill_period = refill_period
        self.tokens = capacity
        # Wall-clock timestamp (time.time()) of the last refill.
        self.last_refill = time.time()
        # Serializes the refill-then-consume sequence across coroutines.
        self.lock = asyncio.Lock()

    async def consume(self, tokens: int = 1) -> Tuple[bool, Dict[str, Any]]:
        """
        Try to consume tokens from the bucket.

        Args:
            tokens: Number of tokens to consume

        Returns:
            Tuple of (success, info_dict). ``info_dict`` has the keys
            ``limit`` (the capacity), ``remaining`` (whole tokens left),
            ``reset`` (now plus one refill period, as an integer timestamp)
            and ``retry_after`` (whole seconds until enough tokens are
            available; ``None`` on success or when the bucket never refills).
        """
        async with self.lock:
            now = time.time()

            # Refill tokens based on time elapsed
            time_elapsed = now - self.last_refill
            tokens_to_add = (time_elapsed / self.refill_period) * self.refill_rate

            if tokens_to_add > 0:
                self.tokens = min(self.capacity, self.tokens + tokens_to_add)
                self.last_refill = now

            retry_after = None
            success = self.tokens >= tokens

            if success:
                self.tokens -= tokens
            elif self.refill_rate > 0:
                # Calculate when enough tokens will be available. Skipped
                # for a non-positive refill rate: the wait would be
                # unbounded and the division below would raise
                # ZeroDivisionError.
                tokens_needed = tokens - self.tokens
                time_to_wait = (tokens_needed / self.refill_rate) * self.refill_period
                # Round up to the next whole second.
                retry_after = int(time_to_wait) + 1

            info = {
                "limit": self.capacity,
                "remaining": int(self.tokens),
                "reset": int(now + self.refill_period),
                "retry_after": retry_after,
            }

            return success, info
92
+
93
+
94
class RateLimiter:
    """
    Keeps one TokenBucket per client identifier and checks requests
    against it.
    """

    def __init__(
        self,
        requests_per_minute: int = 60,
        burst_size: Optional[int] = None,
        cleanup_interval: int = 300,
    ):
        """
        Set up an empty limiter.

        Args:
            requests_per_minute: Refill rate handed to every bucket
            burst_size: Bucket capacity; when falsy it defaults to
                ``max(10, requests_per_minute // 4)``
            cleanup_interval: Minimum number of seconds between sweeps of
                idle buckets, and also the idle age at which a bucket is
                dropped
        """
        self.requests_per_minute = requests_per_minute
        self.burst_size = burst_size or max(10, requests_per_minute // 4)
        self.buckets: Dict[str, TokenBucket] = {}
        self.last_cleanup = time.time()
        self.cleanup_interval = cleanup_interval
        # Guards every mutation of ``self.buckets``.
        self.lock = asyncio.Lock()

    async def check_rate_limit(self, identifier: str) -> Tuple[bool, Dict[str, Any]]:
        """
        Charge one token to ``identifier`` and report the outcome.

        Args:
            identifier: Client identifier (IP, user ID, etc.)

        Returns:
            Tuple of (allowed, headers_dict). The headers always contain the
            three ``X-RateLimit-*`` entries; ``Retry-After`` is added only
            for a rejected request whose bucket reported a wait time.
        """
        # Opportunistic sweep of idle buckets; a no-op most of the time.
        await self._cleanup_buckets()

        client_bucket = await self._get_or_create_bucket(identifier)
        allowed, usage = await client_bucket.consume()

        headers = {
            "X-RateLimit-Limit": str(usage["limit"]),
            "X-RateLimit-Remaining": str(usage["remaining"]),
            "X-RateLimit-Reset": str(usage["reset"]),
        }

        wait_seconds = usage.get("retry_after")
        if wait_seconds and not allowed:
            headers["Retry-After"] = str(wait_seconds)

        return allowed, headers

    async def _get_or_create_bucket(self, identifier: str) -> TokenBucket:
        """Return the bucket for ``identifier``, creating it on first use."""
        async with self.lock:
            existing = self.buckets.get(identifier)
            if existing is not None:
                return existing

            created = TokenBucket(
                capacity=self.burst_size,
                refill_rate=self.requests_per_minute,
                refill_period=60.0,
            )
            self.buckets[identifier] = created
            return created

    async def _cleanup_buckets(self):
        """Drop buckets that have been idle for a full cleanup interval."""
        now = time.time()
        sweep_due = now - self.last_cleanup >= self.cleanup_interval
        if not sweep_due:
            return

        async with self.lock:
            # A bucket counts as idle when its last refill predates the cutoff.
            cutoff_time = now - self.cleanup_interval
            stale_ids = [
                client_id
                for client_id, client_bucket in self.buckets.items()
                if client_bucket.last_refill < cutoff_time
            ]

            for client_id in stale_ids:
                del self.buckets[client_id]

            if stale_ids:
                logger.info(
                    "rate_limiter_cleanup",
                    removed_count=len(stale_ids),
                    remaining_count=len(self.buckets),
                )

            self.last_cleanup = now
188
+
189
+
190
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Rate limiting middleware for FastAPI.

    Limits requests per client based on a custom identifier, the
    authenticated user, or the client IP address (in that priority order).
    """

    def __init__(
        self,
        app: ASGIApp,
        requests_per_minute: int = 60,
        burst_size: Optional[int] = None,
        exclude_paths: Optional[list] = None,
        identifier_callback: Optional[callable] = None,
    ):
        """
        Initialize rate limit middleware.

        Args:
            app: FastAPI application
            requests_per_minute: Default rate limit
            burst_size: Maximum burst size
            exclude_paths: Path prefixes to exclude from rate limiting;
                when falsy, a built-in list of health/metrics/docs paths
                is used
            identifier_callback: Custom function to get client identifier.
                It receives the request and may be either a plain function
                or a coroutine function; a falsy result falls through to
                the default identification.
        """
        super().__init__(app)
        self.rate_limiter = RateLimiter(requests_per_minute, burst_size)
        self.exclude_paths = exclude_paths or ["/health", "/metrics", "/docs", "/openapi.json"]
        self.identifier_callback = identifier_callback

    async def dispatch(self, request: Request, call_next):
        """
        Process request with rate limiting.

        Excluded paths are passed straight through. Otherwise one token is
        charged to the client; allowed requests get the rate limit headers
        added to their response.

        Raises:
            RateLimitError: When the client has exhausted its tokens.
        """
        # Check if path is excluded
        if self._is_excluded(request.url.path):
            return await call_next(request)

        # Get client identifier
        identifier = await self._get_identifier(request)

        # Check rate limit
        allowed, headers = await self.rate_limiter.check_rate_limit(identifier)

        if not allowed:
            # Log rate limit exceeded (identifier is hashed, not logged raw)
            logger.warning(
                "rate_limit_exceeded",
                identifier=self._hash_identifier(identifier),
                path=request.url.path,
                method=request.method,
            )

            # NOTE(review): this is raised from inside the middleware rather
            # than from a route; confirm the app's registered exception
            # handlers actually receive it and answer with a 429.
            raise RateLimitError(
                limit=self.rate_limiter.requests_per_minute,
                window="minute",
                # 60 is the fallback when the limiter gave no Retry-After.
                retry_after=int(headers.get("Retry-After", 60)),
            )

        # Process request
        response = await call_next(request)

        # Add rate limit headers to response
        for header, value in headers.items():
            response.headers[header] = value

        return response

    def _is_excluded(self, path: str) -> bool:
        """Check if path starts with any of the excluded prefixes."""
        return any(path.startswith(excluded) for excluded in self.exclude_paths)

    async def _get_identifier(self, request: Request) -> str:
        """
        Get client identifier for rate limiting.

        Priority:
        1. Custom identifier callback
        2. Authenticated user ID
        3. Client IP address (X-Real-IP, then the first X-Forwarded-For
           entry, then the socket peer)

        Returns:
            The identifier, prefixed with ``custom:``, ``user:`` or ``ip:``.
        """
        # Use custom identifier callback if provided
        if self.identifier_callback:
            import inspect

            identifier = self.identifier_callback(request)
            # The callback may be sync or async; only await real awaitables
            # so a plain function no longer fails with TypeError.
            if inspect.isawaitable(identifier):
                identifier = await identifier
            if identifier:
                return f"custom:{identifier}"

        # Check for authenticated user
        if hasattr(request.state, "user") and request.state.user:
            user_id = getattr(request.state.user, "id", None)
            if user_id:
                return f"user:{user_id}"

        # Fall back to IP address
        client_host = request.client.host if request.client else "unknown"

        # NOTE(review): both proxy headers below are taken from the request
        # as-is. Unless a trusted proxy overwrites them, a client can set
        # them freely and so pick its own bucket; confirm the deployment.
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            # Take the first IP in the chain
            client_host = forwarded_for.split(",")[0].strip()

        # X-Real-IP, when present, wins over X-Forwarded-For.
        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            client_host = real_ip

        return f"ip:{client_host}"

    def _hash_identifier(self, identifier: str) -> str:
        """Hash identifier for logging (privacy): first 16 hex chars of SHA-256."""
        return hashlib.sha256(identifier.encode()).hexdigest()[:16]
304
+
305
+
306
+ # Per-endpoint rate limiting decorator
307
class EndpointRateLimiter:
    """
    Decorator for per-endpoint rate limiting.

    The decorated endpoint must accept the request as its first parameter,
    because the wrapper needs it to identify the client.

    Usage:
        rate_limiter = EndpointRateLimiter()

        @app.get("/expensive-operation")
        @rate_limiter.limit(requests_per_minute=10)
        async def expensive_operation(request: Request):
            ...
    """

    def __init__(self):
        # One independent RateLimiter per decorated endpoint, keyed by
        # "<module>.<function name>".
        self.limiters: Dict[str, RateLimiter] = {}

    def limit(
        self,
        requests_per_minute: int,
        burst_size: Optional[int] = None,
        identifier_callback: Optional[callable] = None,
    ):
        """
        Create rate limit decorator for endpoint.

        Args:
            requests_per_minute: Rate limit for this endpoint
            burst_size: Burst size for this endpoint
            identifier_callback: Custom identifier function. It receives
                the request and may be either a plain function or a
                coroutine function.

        Returns:
            A decorator for an async endpoint. The wrapped endpoint raises
            RateLimitError when the client has exhausted its tokens.
        """
        import functools
        import inspect

        def decorator(func):
            # Create rate limiter for this endpoint
            endpoint_id = f"{func.__module__}.{func.__name__}"
            self.limiters[endpoint_id] = RateLimiter(
                requests_per_minute=requests_per_minute,
                burst_size=burst_size,
            )

            # functools.wraps keeps the endpoint's name, docstring and
            # (through __wrapped__) its signature visible to anything that
            # introspects the wrapper instead of showing (*args, **kwargs).
            @functools.wraps(func)
            async def wrapper(request: Request, *args, **kwargs):
                # Get client identifier
                if identifier_callback:
                    identifier = identifier_callback(request)
                    # Accept sync and async callbacks alike.
                    if inspect.isawaitable(identifier):
                        identifier = await identifier
                else:
                    identifier = self._default_identifier(request)

                # Check rate limit
                limiter = self.limiters[endpoint_id]
                allowed, headers = await limiter.check_rate_limit(identifier)

                if not allowed:
                    raise RateLimitError(
                        limit=requests_per_minute,
                        window="minute",
                        # 60 is the fallback when no Retry-After was given.
                        retry_after=int(headers.get("Retry-After", 60)),
                    )

                # Headers can only be attached when the endpoint returned an
                # actual Response object; other return values pass through.
                response = await func(request, *args, **kwargs)
                if isinstance(response, Response):
                    for header, value in headers.items():
                        response.headers[header] = value

                return response

            return wrapper
        return decorator

    def _default_identifier(self, request: Request) -> str:
        """
        Default identifier extraction.

        Uses the authenticated user's id when one is available and
        otherwise the client's socket address.
        """
        if hasattr(request.state, "user") and request.state.user:
            # getattr keeps a user object without an ``id`` from raising;
            # such requests fall back to the IP, as in the middleware.
            user_id = getattr(request.state.user, "id", None)
            if user_id:
                return f"user:{user_id}"

        client_host = request.client.host if request.client else "unknown"
        return f"ip:{client_host}"
381
+
382
+
383
+ # Global rate limiter instance for decorator usage
384
+ endpoint_rate_limiter = EndpointRateLimiter()