kubiya-control-plane-api 0.1.0__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kubiya-control-plane-api might be problematic. Click here for more details.
- control_plane_api/README.md +266 -0
- control_plane_api/__init__.py +0 -0
- control_plane_api/__version__.py +1 -0
- control_plane_api/alembic/README +1 -0
- control_plane_api/alembic/env.py +98 -0
- control_plane_api/alembic/script.py.mako +28 -0
- control_plane_api/alembic/versions/1382bec74309_initial_migration_with_all_models.py +251 -0
- control_plane_api/alembic/versions/1f54bc2a37e3_add_analytics_tables.py +162 -0
- control_plane_api/alembic/versions/2e4cb136dc10_rename_toolset_ids_to_skill_ids_in_teams.py +30 -0
- control_plane_api/alembic/versions/31cd69a644ce_add_skill_templates_table.py +28 -0
- control_plane_api/alembic/versions/89e127caa47d_add_jobs_and_job_executions_tables.py +161 -0
- control_plane_api/alembic/versions/add_llm_models_table.py +51 -0
- control_plane_api/alembic/versions/b0e10697f212_add_runtime_column_to_teams_simple.py +42 -0
- control_plane_api/alembic/versions/ce43b24b63bf_add_execution_trigger_source_and_fix_.py +155 -0
- control_plane_api/alembic/versions/d4eaf16e3f8d_rename_toolsets_to_skills.py +84 -0
- control_plane_api/alembic/versions/efa2dc427da1_rename_metadata_to_custom_metadata.py +32 -0
- control_plane_api/alembic/versions/f973b431d1ce_add_workflow_executor_to_skill_types.py +44 -0
- control_plane_api/alembic.ini +148 -0
- control_plane_api/api/index.py +12 -0
- control_plane_api/app/__init__.py +11 -0
- control_plane_api/app/activities/__init__.py +20 -0
- control_plane_api/app/activities/agent_activities.py +379 -0
- control_plane_api/app/activities/team_activities.py +410 -0
- control_plane_api/app/activities/temporal_cloud_activities.py +577 -0
- control_plane_api/app/config/__init__.py +35 -0
- control_plane_api/app/config/api_config.py +354 -0
- control_plane_api/app/config/model_pricing.py +318 -0
- control_plane_api/app/config.py +95 -0
- control_plane_api/app/database.py +135 -0
- control_plane_api/app/exceptions.py +408 -0
- control_plane_api/app/lib/__init__.py +11 -0
- control_plane_api/app/lib/job_executor.py +312 -0
- control_plane_api/app/lib/kubiya_client.py +235 -0
- control_plane_api/app/lib/litellm_pricing.py +166 -0
- control_plane_api/app/lib/planning_tools/__init__.py +22 -0
- control_plane_api/app/lib/planning_tools/agents.py +155 -0
- control_plane_api/app/lib/planning_tools/base.py +189 -0
- control_plane_api/app/lib/planning_tools/environments.py +214 -0
- control_plane_api/app/lib/planning_tools/resources.py +240 -0
- control_plane_api/app/lib/planning_tools/teams.py +198 -0
- control_plane_api/app/lib/policy_enforcer_client.py +939 -0
- control_plane_api/app/lib/redis_client.py +436 -0
- control_plane_api/app/lib/supabase.py +71 -0
- control_plane_api/app/lib/temporal_client.py +138 -0
- control_plane_api/app/lib/validation/__init__.py +20 -0
- control_plane_api/app/lib/validation/runtime_validation.py +287 -0
- control_plane_api/app/main.py +128 -0
- control_plane_api/app/middleware/__init__.py +8 -0
- control_plane_api/app/middleware/auth.py +513 -0
- control_plane_api/app/middleware/exception_handler.py +267 -0
- control_plane_api/app/middleware/rate_limiting.py +384 -0
- control_plane_api/app/middleware/request_id.py +202 -0
- control_plane_api/app/models/__init__.py +27 -0
- control_plane_api/app/models/agent.py +79 -0
- control_plane_api/app/models/analytics.py +206 -0
- control_plane_api/app/models/associations.py +81 -0
- control_plane_api/app/models/environment.py +63 -0
- control_plane_api/app/models/execution.py +93 -0
- control_plane_api/app/models/job.py +179 -0
- control_plane_api/app/models/llm_model.py +75 -0
- control_plane_api/app/models/presence.py +49 -0
- control_plane_api/app/models/project.py +47 -0
- control_plane_api/app/models/session.py +38 -0
- control_plane_api/app/models/team.py +66 -0
- control_plane_api/app/models/workflow.py +55 -0
- control_plane_api/app/policies/README.md +121 -0
- control_plane_api/app/policies/approved_users.rego +62 -0
- control_plane_api/app/policies/business_hours.rego +51 -0
- control_plane_api/app/policies/rate_limiting.rego +100 -0
- control_plane_api/app/policies/tool_restrictions.rego +86 -0
- control_plane_api/app/routers/__init__.py +4 -0
- control_plane_api/app/routers/agents.py +364 -0
- control_plane_api/app/routers/agents_v2.py +1260 -0
- control_plane_api/app/routers/analytics.py +1014 -0
- control_plane_api/app/routers/context_manager.py +562 -0
- control_plane_api/app/routers/environment_context.py +270 -0
- control_plane_api/app/routers/environments.py +715 -0
- control_plane_api/app/routers/execution_environment.py +517 -0
- control_plane_api/app/routers/executions.py +1911 -0
- control_plane_api/app/routers/health.py +92 -0
- control_plane_api/app/routers/health_v2.py +326 -0
- control_plane_api/app/routers/integrations.py +274 -0
- control_plane_api/app/routers/jobs.py +1344 -0
- control_plane_api/app/routers/models.py +82 -0
- control_plane_api/app/routers/models_v2.py +361 -0
- control_plane_api/app/routers/policies.py +639 -0
- control_plane_api/app/routers/presence.py +234 -0
- control_plane_api/app/routers/projects.py +902 -0
- control_plane_api/app/routers/runners.py +379 -0
- control_plane_api/app/routers/runtimes.py +172 -0
- control_plane_api/app/routers/secrets.py +155 -0
- control_plane_api/app/routers/skills.py +1001 -0
- control_plane_api/app/routers/skills_definitions.py +140 -0
- control_plane_api/app/routers/task_planning.py +1256 -0
- control_plane_api/app/routers/task_queues.py +654 -0
- control_plane_api/app/routers/team_context.py +270 -0
- control_plane_api/app/routers/teams.py +1400 -0
- control_plane_api/app/routers/worker_queues.py +1545 -0
- control_plane_api/app/routers/workers.py +935 -0
- control_plane_api/app/routers/workflows.py +204 -0
- control_plane_api/app/runtimes/__init__.py +6 -0
- control_plane_api/app/runtimes/validation.py +344 -0
- control_plane_api/app/schemas/job_schemas.py +295 -0
- control_plane_api/app/services/__init__.py +1 -0
- control_plane_api/app/services/agno_service.py +619 -0
- control_plane_api/app/services/litellm_service.py +190 -0
- control_plane_api/app/services/policy_service.py +525 -0
- control_plane_api/app/services/temporal_cloud_provisioning.py +150 -0
- control_plane_api/app/skills/__init__.py +44 -0
- control_plane_api/app/skills/base.py +229 -0
- control_plane_api/app/skills/business_intelligence.py +189 -0
- control_plane_api/app/skills/data_visualization.py +154 -0
- control_plane_api/app/skills/docker.py +104 -0
- control_plane_api/app/skills/file_generation.py +94 -0
- control_plane_api/app/skills/file_system.py +110 -0
- control_plane_api/app/skills/python.py +92 -0
- control_plane_api/app/skills/registry.py +65 -0
- control_plane_api/app/skills/shell.py +102 -0
- control_plane_api/app/skills/workflow_executor.py +469 -0
- control_plane_api/app/utils/workflow_executor.py +354 -0
- control_plane_api/app/workflows/__init__.py +11 -0
- control_plane_api/app/workflows/agent_execution.py +507 -0
- control_plane_api/app/workflows/agent_execution_with_skills.py +222 -0
- control_plane_api/app/workflows/namespace_provisioning.py +326 -0
- control_plane_api/app/workflows/team_execution.py +399 -0
- control_plane_api/scripts/seed_models.py +239 -0
- control_plane_api/worker/__init__.py +0 -0
- control_plane_api/worker/activities/__init__.py +0 -0
- control_plane_api/worker/activities/agent_activities.py +1241 -0
- control_plane_api/worker/activities/approval_activities.py +234 -0
- control_plane_api/worker/activities/runtime_activities.py +388 -0
- control_plane_api/worker/activities/skill_activities.py +267 -0
- control_plane_api/worker/activities/team_activities.py +1217 -0
- control_plane_api/worker/config/__init__.py +31 -0
- control_plane_api/worker/config/worker_config.py +275 -0
- control_plane_api/worker/control_plane_client.py +529 -0
- control_plane_api/worker/examples/analytics_integration_example.py +362 -0
- control_plane_api/worker/models/__init__.py +1 -0
- control_plane_api/worker/models/inputs.py +89 -0
- control_plane_api/worker/runtimes/__init__.py +31 -0
- control_plane_api/worker/runtimes/base.py +789 -0
- control_plane_api/worker/runtimes/claude_code_runtime.py +1443 -0
- control_plane_api/worker/runtimes/default_runtime.py +617 -0
- control_plane_api/worker/runtimes/factory.py +173 -0
- control_plane_api/worker/runtimes/validation.py +93 -0
- control_plane_api/worker/services/__init__.py +1 -0
- control_plane_api/worker/services/agent_executor.py +422 -0
- control_plane_api/worker/services/agent_executor_v2.py +383 -0
- control_plane_api/worker/services/analytics_collector.py +457 -0
- control_plane_api/worker/services/analytics_service.py +464 -0
- control_plane_api/worker/services/approval_tools.py +310 -0
- control_plane_api/worker/services/approval_tools_agno.py +207 -0
- control_plane_api/worker/services/cancellation_manager.py +177 -0
- control_plane_api/worker/services/data_visualization.py +827 -0
- control_plane_api/worker/services/jira_tools.py +257 -0
- control_plane_api/worker/services/runtime_analytics.py +328 -0
- control_plane_api/worker/services/session_service.py +194 -0
- control_plane_api/worker/services/skill_factory.py +175 -0
- control_plane_api/worker/services/team_executor.py +574 -0
- control_plane_api/worker/services/team_executor_v2.py +465 -0
- control_plane_api/worker/services/workflow_executor_tools.py +1418 -0
- control_plane_api/worker/tests/__init__.py +1 -0
- control_plane_api/worker/tests/e2e/__init__.py +0 -0
- control_plane_api/worker/tests/e2e/test_execution_flow.py +571 -0
- control_plane_api/worker/tests/integration/__init__.py +0 -0
- control_plane_api/worker/tests/integration/test_control_plane_integration.py +308 -0
- control_plane_api/worker/tests/unit/__init__.py +0 -0
- control_plane_api/worker/tests/unit/test_control_plane_client.py +401 -0
- control_plane_api/worker/utils/__init__.py +1 -0
- control_plane_api/worker/utils/chunk_batcher.py +305 -0
- control_plane_api/worker/utils/retry_utils.py +60 -0
- control_plane_api/worker/utils/streaming_utils.py +373 -0
- control_plane_api/worker/worker.py +753 -0
- control_plane_api/worker/workflows/__init__.py +0 -0
- control_plane_api/worker/workflows/agent_execution.py +589 -0
- control_plane_api/worker/workflows/team_execution.py +429 -0
- kubiya_control_plane_api-0.3.4.dist-info/METADATA +229 -0
- kubiya_control_plane_api-0.3.4.dist-info/RECORD +182 -0
- kubiya_control_plane_api-0.3.4.dist-info/entry_points.txt +2 -0
- kubiya_control_plane_api-0.3.4.dist-info/top_level.txt +1 -0
- kubiya_control_plane_api-0.1.0.dist-info/METADATA +0 -66
- kubiya_control_plane_api-0.1.0.dist-info/RECORD +0 -5
- kubiya_control_plane_api-0.1.0.dist-info/top_level.txt +0 -1
- {kubiya_control_plane_api-0.1.0.dist-info/licenses → control_plane_api}/LICENSE +0 -0
- {kubiya_control_plane_api-0.1.0.dist-info → kubiya_control_plane_api-0.3.4.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Global exception handler middleware for Control Plane API.
|
|
3
|
+
|
|
4
|
+
Catches all exceptions and returns standardized error responses.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from fastapi import Request, status
|
|
8
|
+
from fastapi.responses import JSONResponse
|
|
9
|
+
from fastapi.exceptions import RequestValidationError, HTTPException
|
|
10
|
+
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
11
|
+
from control_plane_api.app.exceptions import ControlPlaneException
|
|
12
|
+
import structlog
|
|
13
|
+
import traceback
|
|
14
|
+
import uuid
|
|
15
|
+
from datetime import datetime, timezone
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
# Module-level structlog logger used by every exception handler below.
logger = structlog.get_logger()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
async def control_plane_exception_handler(
    request: Request,
    exc: ControlPlaneException,
) -> JSONResponse:
    """
    Turn a ControlPlaneException into the standard JSON error envelope.

    The exception's own ``error_code``, ``message``, ``details`` and
    ``status_code`` drive the response. A fresh UUID is attached as the
    error id and is also written to the log entry so the two can be matched.

    Args:
        request: Incoming request that triggered the exception.
        exc: The raised ControlPlaneException.

    Returns:
        JSONResponse with ``exc.status_code`` and an ``{"error": {...}}`` body.
    """
    incident_id = str(uuid.uuid4())
    http_method = request.method

    # Record the failure together with the request it came from.
    logger.error(
        "control_plane_error",
        error_id=incident_id,
        error_code=exc.error_code,
        error_message=exc.message,
        error_details=exc.details,
        path=str(request.url),
        method=http_method,
    )

    error_body = {
        "id": incident_id,
        "code": exc.error_code,
        "message": exc.message,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }

    # Details are optional; leave the key out entirely when there are none.
    if exc.details:
        error_body["details"] = exc.details

    # Echo request context back only when the app runs in debug mode.
    if request.app.debug:
        query_string = request.url.query
        error_body["request"] = {
            "method": http_method,
            "path": str(request.url.path),
            "query": str(query_string) if query_string else None,
        }

    return JSONResponse(
        status_code=exc.status_code,
        content={"error": error_body},
    )
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
async def validation_exception_handler(
    request: Request,
    exc: RequestValidationError,
) -> JSONResponse:
    """
    Convert a request validation failure into the standard error envelope.

    Every entry reported by ``exc.errors()`` is flattened to a
    ``{"field", "message", "type"}`` dict, where ``field`` is the error
    location joined with dots.

    Args:
        request: Incoming request that failed validation.
        exc: The validation error raised for the request.

    Returns:
        JSONResponse with status 422 and the flattened errors under
        ``error.details.validation_errors``.
    """
    problem_id = str(uuid.uuid4())

    # Flatten each reported error into a compact, client-friendly dict.
    field_errors = [
        {
            "field": ".".join(map(str, item["loc"])),
            "message": item["msg"],
            "type": item["type"],
        }
        for item in exc.errors()
    ]

    logger.warning(
        "validation_error",
        error_id=problem_id,
        validation_errors=field_errors,
        path=str(request.url),
        method=request.method,
    )

    error_body = {
        "id": problem_id,
        "code": "VALIDATION_ERROR",
        "message": "Request validation failed",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "details": {"validation_errors": field_errors},
    }

    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content={"error": error_body},
    )
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
async def http_exception_handler(
    request: Request,
    exc: HTTPException,
) -> JSONResponse:
    """
    Render an HTTP exception in the standard error envelope.

    The status code is mapped to a symbolic error code ("HTTP_ERROR" when
    the code is not in the table). Any headers carried by the exception are
    sent on the response and also copied into the error body.

    Args:
        request: Incoming request that triggered the exception.
        exc: The HTTP exception to convert.

    Returns:
        JSONResponse carrying ``exc.status_code`` and the error envelope.
    """
    problem_id = str(uuid.uuid4())
    status_code = exc.status_code

    logger.warning(
        "http_exception",
        error_id=problem_id,
        status_code=status_code,
        detail=exc.detail,
        path=str(request.url),
        method=request.method,
    )

    # Symbolic names for the status codes this API distinguishes.
    codes_by_status = {
        400: "BAD_REQUEST",
        401: "AUTHENTICATION_ERROR",
        403: "AUTHORIZATION_ERROR",
        404: "RESOURCE_NOT_FOUND",
        405: "METHOD_NOT_ALLOWED",
        409: "RESOURCE_CONFLICT",
        422: "VALIDATION_ERROR",
        429: "RATE_LIMIT_EXCEEDED",
        500: "INTERNAL_ERROR",
        502: "BAD_GATEWAY",
        503: "SERVICE_UNAVAILABLE",
        504: "GATEWAY_TIMEOUT",
    }

    error_body = {
        "id": problem_id,
        "code": codes_by_status.get(status_code, "HTTP_ERROR"),
        # Fall back to a generic message when the exception has no detail.
        "message": exc.detail or f"HTTP {status_code} Error",
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }

    # Not every exception object is guaranteed to have a headers attribute.
    exc_headers = getattr(exc, "headers", None)
    if exc_headers:
        error_body["headers"] = dict(exc_headers)

    return JSONResponse(
        status_code=status_code,
        content={"error": error_body},
        headers=exc_headers,
    )
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
async def generic_exception_handler(
    request: Request,
    exc: Exception,
) -> JSONResponse:
    """
    Handle unexpected exceptions.

    This is the catch-all handler for any exception no other handler
    claimed. It always answers with HTTP 500. The exception text and
    traceback are included in the body only when ``request.app.debug`` is
    true; otherwise a generic message is returned.

    Args:
        request: Incoming request that triggered the exception.
        exc: The unhandled exception.

    Returns:
        JSONResponse with status 500 and the standard error envelope.
    """
    error_id = str(uuid.uuid4())
    exception_type = type(exc).__name__
    debug_enabled = request.app.debug

    # Format the traceback from the exception object itself rather than from
    # the interpreter's "currently handled exception" (traceback.format_exc()).
    # That keeps the output correct even if this handler is invoked outside an
    # active ``except`` block, and the text is computed once and reused.
    formatted_traceback = "".join(
        traceback.format_exception(type(exc), exc, exc.__traceback__)
    )

    # Log the full exception with traceback
    logger.error(
        "unhandled_exception",
        error_id=error_id,
        error_type=exception_type,
        error_message=str(exc),
        path=str(request.url),
        method=request.method,
        traceback=formatted_traceback,
        exc_info=exc,
    )

    # In production, don't expose internal error details
    if debug_enabled:
        message = f"{exception_type}: {str(exc)}"
        details = {
            "exception_type": exception_type,
            "traceback": formatted_traceback.split("\n"),
        }
    else:
        message = "An internal server error occurred"
        details = None

    response = {
        "error": {
            "id": error_id,
            "code": "INTERNAL_ERROR",
            "message": message,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
    }

    if details:
        response["error"]["details"] = details

    # Alert on unhandled exceptions (in production, this might trigger PagerDuty)
    if not debug_enabled:
        logger.critical(
            "unhandled_exception_alert",
            error_id=error_id,
            error_type=exception_type,
            path=str(request.url),
        )

    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content=response,
    )
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def setup_exception_handlers(app: Any) -> None:
    """
    Attach every exception handler defined in this module to ``app``.

    Call this in your main.py after creating the app.

    Example:
        from control_plane_api.app.middleware.exception_handler import setup_exception_handlers

        app = FastAPI()
        setup_exception_handlers(app)

    Args:
        app: Application object exposing ``add_exception_handler``.
    """
    # (exception type, handler) pairs, registered in this order.
    registrations = (
        # Our custom exceptions
        (ControlPlaneException, control_plane_exception_handler),
        # FastAPI/Pydantic validation errors
        (RequestValidationError, validation_exception_handler),
        # HTTP exceptions (both the FastAPI and the Starlette class)
        (HTTPException, http_exception_handler),
        (StarletteHTTPException, http_exception_handler),
        # Catch-all for unhandled exceptions
        (Exception, generic_exception_handler),
    )

    for exception_type, handler in registrations:
        app.add_exception_handler(exception_type, handler)

    logger.info("exception_handlers_registered")
|
|
@@ -0,0 +1,384 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Rate limiting middleware using token bucket algorithm.
|
|
3
|
+
|
|
4
|
+
Provides configurable rate limiting per client IP or user.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from fastapi import Request, Response, HTTPException, status
|
|
8
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
9
|
+
from starlette.types import ASGIApp
|
|
10
|
+
from typing import Dict, Optional, Tuple, Any
|
|
11
|
+
from datetime import datetime, timedelta
|
|
12
|
+
import time
|
|
13
|
+
import asyncio
|
|
14
|
+
import structlog
|
|
15
|
+
import hashlib
|
|
16
|
+
from control_plane_api.app.exceptions import RateLimitError
|
|
17
|
+
|
|
18
|
+
# Module-level structlog logger used by the rate limiting classes below.
logger = structlog.get_logger()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class TokenBucket:
|
|
22
|
+
"""
|
|
23
|
+
Token bucket implementation for rate limiting.
|
|
24
|
+
|
|
25
|
+
Each bucket starts with a capacity of tokens.
|
|
26
|
+
Tokens are consumed when requests are made.
|
|
27
|
+
Tokens are refilled at a constant rate.
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
def __init__(
|
|
31
|
+
self,
|
|
32
|
+
capacity: int,
|
|
33
|
+
refill_rate: float,
|
|
34
|
+
refill_period: float = 60.0,
|
|
35
|
+
):
|
|
36
|
+
"""
|
|
37
|
+
Initialize token bucket.
|
|
38
|
+
|
|
39
|
+
Args:
|
|
40
|
+
capacity: Maximum number of tokens in bucket
|
|
41
|
+
refill_rate: Number of tokens to add per period
|
|
42
|
+
refill_period: Period in seconds for refilling tokens
|
|
43
|
+
"""
|
|
44
|
+
self.capacity = capacity
|
|
45
|
+
self.refill_rate = refill_rate
|
|
46
|
+
self.refill_period = refill_period
|
|
47
|
+
self.tokens = capacity
|
|
48
|
+
self.last_refill = time.time()
|
|
49
|
+
self.lock = asyncio.Lock()
|
|
50
|
+
|
|
51
|
+
async def consume(self, tokens: int = 1) -> Tuple[bool, Dict[str, Any]]:
|
|
52
|
+
"""
|
|
53
|
+
Try to consume tokens from the bucket.
|
|
54
|
+
|
|
55
|
+
Args:
|
|
56
|
+
tokens: Number of tokens to consume
|
|
57
|
+
|
|
58
|
+
Returns:
|
|
59
|
+
Tuple of (success, info_dict)
|
|
60
|
+
"""
|
|
61
|
+
async with self.lock:
|
|
62
|
+
now = time.time()
|
|
63
|
+
|
|
64
|
+
# Refill tokens based on time elapsed
|
|
65
|
+
time_elapsed = now - self.last_refill
|
|
66
|
+
tokens_to_add = (time_elapsed / self.refill_period) * self.refill_rate
|
|
67
|
+
|
|
68
|
+
if tokens_to_add > 0:
|
|
69
|
+
self.tokens = min(self.capacity, self.tokens + tokens_to_add)
|
|
70
|
+
self.last_refill = now
|
|
71
|
+
|
|
72
|
+
# Check if we have enough tokens
|
|
73
|
+
if self.tokens >= tokens:
|
|
74
|
+
self.tokens -= tokens
|
|
75
|
+
success = True
|
|
76
|
+
retry_after = None
|
|
77
|
+
else:
|
|
78
|
+
success = False
|
|
79
|
+
# Calculate when enough tokens will be available
|
|
80
|
+
tokens_needed = tokens - self.tokens
|
|
81
|
+
time_to_wait = (tokens_needed / self.refill_rate) * self.refill_period
|
|
82
|
+
retry_after = int(time_to_wait) + 1
|
|
83
|
+
|
|
84
|
+
info = {
|
|
85
|
+
"limit": self.capacity,
|
|
86
|
+
"remaining": int(self.tokens),
|
|
87
|
+
"reset": int(now + self.refill_period),
|
|
88
|
+
"retry_after": retry_after,
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
return success, info
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class RateLimiter:
    """
    Keeps one TokenBucket per client identifier and checks requests against it.
    """

    def __init__(
        self,
        requests_per_minute: int = 60,
        burst_size: Optional[int] = None,
        cleanup_interval: int = 300,
    ):
        """
        Set up the limiter.

        Args:
            requests_per_minute: Refill rate of every bucket, in tokens per
                60 seconds
            burst_size: Bucket capacity; when omitted (or falsy) it becomes
                max(10, requests_per_minute // 4)
            cleanup_interval: Seconds between sweeps that drop idle buckets
        """
        self.requests_per_minute = requests_per_minute
        self.cleanup_interval = cleanup_interval
        self.burst_size = burst_size or max(10, requests_per_minute // 4)
        self.buckets: Dict[str, TokenBucket] = {}
        self.lock = asyncio.Lock()
        self.last_cleanup = time.time()

    async def check_rate_limit(self, identifier: str) -> Tuple[bool, Dict[str, Any]]:
        """
        Consume one token for ``identifier`` and report the outcome.

        Args:
            identifier: Client identifier (IP, user ID, etc.)

        Returns:
            Tuple of (allowed, headers_dict). The dict holds the
            X-RateLimit-* headers, plus Retry-After when the request was
            rejected and a wait time is known.
        """
        # Periodic sweep first, then fetch (or create) this client's bucket.
        await self._cleanup_buckets()
        client_bucket = await self._get_or_create_bucket(identifier)

        allowed, usage = await client_bucket.consume()

        headers = {
            f"X-RateLimit-{suffix}": str(usage[key])
            for suffix, key in (
                ("Limit", "limit"),
                ("Remaining", "remaining"),
                ("Reset", "reset"),
            )
        }

        wait_seconds = usage.get("retry_after")
        if wait_seconds and not allowed:
            headers["Retry-After"] = str(wait_seconds)

        return allowed, headers

    async def _get_or_create_bucket(self, identifier: str) -> TokenBucket:
        """Return the bucket for ``identifier``, creating it on first use."""
        async with self.lock:
            try:
                return self.buckets[identifier]
            except KeyError:
                fresh_bucket = TokenBucket(
                    capacity=self.burst_size,
                    refill_rate=self.requests_per_minute,
                    refill_period=60.0,
                )
                self.buckets[identifier] = fresh_bucket
                return fresh_bucket

    async def _cleanup_buckets(self):
        """Drop buckets idle for longer than the cleanup interval (bounds memory)."""
        current_time = time.time()
        if current_time - self.last_cleanup < self.cleanup_interval:
            # Swept recently enough; nothing to do yet.
            return

        async with self.lock:
            # A bucket is stale when its last refill predates the cutoff.
            stale_before = current_time - self.cleanup_interval
            stale_ids = [
                bucket_id
                for bucket_id, bucket in self.buckets.items()
                if bucket.last_refill < stale_before
            ]

            for bucket_id in stale_ids:
                del self.buckets[bucket_id]

            if stale_ids:
                logger.info(
                    "rate_limiter_cleanup",
                    removed_count=len(stale_ids),
                    remaining_count=len(self.buckets),
                )

            self.last_cleanup = current_time
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Rate limiting middleware for FastAPI.

    Limits requests per client based on IP address or authenticated user.
    Requests over the limit are answered directly with HTTP 429.
    """

    def __init__(
        self,
        app: ASGIApp,
        requests_per_minute: int = 60,
        burst_size: Optional[int] = None,
        exclude_paths: Optional[list] = None,
        identifier_callback: Optional[callable] = None,
    ):
        """
        Initialize rate limit middleware.

        Args:
            app: FastAPI application
            requests_per_minute: Default rate limit
            burst_size: Maximum burst size
            exclude_paths: Path prefixes to exclude from rate limiting
                (defaults to /health, /metrics, /docs and /openapi.json)
            identifier_callback: Async function taking the request and
                returning a custom client identifier (or a falsy value to
                fall back to the built-in identification)
        """
        super().__init__(app)
        self.rate_limiter = RateLimiter(requests_per_minute, burst_size)
        self.exclude_paths = exclude_paths or ["/health", "/metrics", "/docs", "/openapi.json"]
        self.identifier_callback = identifier_callback

    async def dispatch(self, request: Request, call_next):
        """
        Process request with rate limiting.

        Args:
            request: Incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response with X-RateLimit-* headers added, or a
            429 JSON response when the client is over its limit.
        """
        # Check if path is excluded
        if self._is_excluded(request.url.path):
            return await call_next(request)

        # Get client identifier
        identifier = await self._get_identifier(request)

        # Check rate limit
        allowed, headers = await self.rate_limiter.check_rate_limit(identifier)

        if not allowed:
            # Log rate limit exceeded
            logger.warning(
                "rate_limit_exceeded",
                identifier=self._hash_identifier(identifier),
                path=request.url.path,
                method=request.method,
            )

            # Build the 429 here instead of raising. User middleware runs
            # outside Starlette's ExceptionMiddleware, so handlers registered
            # with app.add_exception_handler() never see exceptions raised in
            # dispatch(); a raised RateLimitError would surface as a 500.
            return self._rate_limited_response(headers)

        # Process request
        response = await call_next(request)

        # Add rate limit headers to response
        for header, value in headers.items():
            response.headers[header] = value

        return response

    def _rate_limited_response(self, headers: Dict[str, str]) -> Response:
        """Build the 429 response (standard error envelope plus rate limit headers)."""
        # Local imports: these names are not part of the module's import block.
        import uuid
        from datetime import timezone

        from starlette.responses import JSONResponse

        retry_after = int(headers.get("Retry-After", 60))

        # Reuse RateLimitError for the error code/message/details so the body
        # matches what the exception would carry. The getattr fallbacks are
        # there because its definition lives in another module.
        exc = RateLimitError(
            limit=self.rate_limiter.requests_per_minute,
            window="minute",
            retry_after=retry_after,
        )

        error = {
            "id": str(uuid.uuid4()),
            "code": getattr(exc, "error_code", "RATE_LIMIT_EXCEEDED"),
            "message": getattr(exc, "message", None) or "Rate limit exceeded",
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        details = getattr(exc, "details", None)
        if details:
            error["details"] = details

        response_headers = dict(headers)
        response_headers.setdefault("Retry-After", str(retry_after))

        return JSONResponse(
            status_code=getattr(exc, "status_code", status.HTTP_429_TOO_MANY_REQUESTS),
            content={"error": error},
            headers=response_headers,
        )

    def _is_excluded(self, path: str) -> bool:
        """Check if path starts with any of the excluded prefixes."""
        return any(path.startswith(excluded) for excluded in self.exclude_paths)

    async def _get_identifier(self, request: Request) -> str:
        """
        Get client identifier for rate limiting.

        Priority:
        1. Custom identifier callback
        2. Authenticated user ID
        3. Client IP address (X-Real-IP, then X-Forwarded-For, then the
           socket peer address)
        """
        # Use custom identifier callback if provided
        if self.identifier_callback:
            identifier = await self.identifier_callback(request)
            if identifier:
                return f"custom:{identifier}"

        # Check for authenticated user
        if hasattr(request.state, "user") and request.state.user:
            user_id = getattr(request.state.user, "id", None)
            if user_id:
                return f"user:{user_id}"

        # Fall back to IP address
        client_host = request.client.host if request.client else "unknown"

        # Check for proxy headers
        # NOTE(review): both headers are client-controlled unless a trusted
        # proxy overwrites them; a client can rotate them to get a fresh
        # bucket per request. Confirm the deployment always sits behind
        # such a proxy.
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            # Take the first IP in the chain
            client_host = forwarded_for.split(",")[0].strip()

        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            client_host = real_ip

        return f"ip:{client_host}"

    def _hash_identifier(self, identifier: str) -> str:
        """Hash identifier for logging (privacy); returns 16 hex characters."""
        return hashlib.sha256(identifier.encode()).hexdigest()[:16]
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
# Per-endpoint rate limiting decorator
|
|
307
|
+
class EndpointRateLimiter:
    """
    Decorator for per-endpoint rate limiting.

    The decorated endpoint must accept the request as a parameter named
    ``request``; the wrapper needs it to identify the client.

    Usage:
        rate_limiter = EndpointRateLimiter()

        @app.get("/expensive-operation")
        @rate_limiter.limit(requests_per_minute=10)
        async def expensive_operation(request: Request):
            ...
    """

    def __init__(self):
        # One RateLimiter per decorated endpoint, keyed by "<module>.<name>".
        self.limiters: Dict[str, RateLimiter] = {}

    def limit(
        self,
        requests_per_minute: int,
        burst_size: Optional[int] = None,
        identifier_callback: Optional[callable] = None,
    ):
        """
        Create rate limit decorator for endpoint.

        Args:
            requests_per_minute: Rate limit for this endpoint
            burst_size: Burst size for this endpoint
            identifier_callback: Async function taking the request and
                returning a client identifier; a falsy result falls back to
                the default identification

        Returns:
            A decorator that wraps an async endpoint.

        Raises:
            RateLimitError: Raised by the wrapped endpoint when the client
                is over the limit.
        """
        # Local import: functools is not in the module's import block.
        import functools

        def decorator(func):
            # Create rate limiter for this endpoint
            endpoint_id = f"{func.__module__}.{func.__name__}"
            self.limiters[endpoint_id] = RateLimiter(
                requests_per_minute=requests_per_minute,
                burst_size=burst_size,
            )

            # functools.wraps keeps the endpoint's name, docstring and, via
            # __wrapped__, its signature, so a framework inspecting the
            # callable sees the real parameters and not (*args, **kwargs).
            @functools.wraps(func)
            async def wrapper(request: Request, *args, **kwargs):
                # Get client identifier
                identifier = None
                if identifier_callback:
                    identifier = await identifier_callback(request)
                if not identifier:
                    # Without this fallback, every client for which the
                    # callback returns nothing would share one bucket.
                    identifier = self._default_identifier(request)

                # Check rate limit
                limiter = self.limiters[endpoint_id]
                allowed, headers = await limiter.check_rate_limit(identifier)

                if not allowed:
                    raise RateLimitError(
                        limit=requests_per_minute,
                        window="minute",
                        retry_after=int(headers.get("Retry-After", 60)),
                    )

                # Add headers to response (if we have access to it); plain
                # return values such as dicts get no rate limit headers.
                response = await func(request, *args, **kwargs)
                if isinstance(response, Response):
                    for header, value in headers.items():
                        response.headers[header] = value

                return response

            return wrapper
        return decorator

    def _default_identifier(self, request: Request) -> str:
        """Default identifier extraction: authenticated user id, else client IP."""
        user = getattr(request.state, "user", None)
        # Guard the id lookup, as RateLimitMiddleware does, so a user object
        # without a usable id falls back to the IP.
        user_id = getattr(user, "id", None) if user else None
        if user_id:
            return f"user:{user_id}"

        client_host = request.client.host if request.client else "unknown"
        return f"ip:{client_host}"
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
# Global rate limiter instance for decorator usage
|
|
384
|
+
# Shared instance: decorate endpoints with ``endpoint_rate_limiter.limit(...)``.
endpoint_rate_limiter = EndpointRateLimiter()
|