mcp-hangar 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (160) hide show
  1. mcp_hangar/__init__.py +139 -0
  2. mcp_hangar/application/__init__.py +1 -0
  3. mcp_hangar/application/commands/__init__.py +67 -0
  4. mcp_hangar/application/commands/auth_commands.py +118 -0
  5. mcp_hangar/application/commands/auth_handlers.py +296 -0
  6. mcp_hangar/application/commands/commands.py +59 -0
  7. mcp_hangar/application/commands/handlers.py +189 -0
  8. mcp_hangar/application/discovery/__init__.py +21 -0
  9. mcp_hangar/application/discovery/discovery_metrics.py +283 -0
  10. mcp_hangar/application/discovery/discovery_orchestrator.py +497 -0
  11. mcp_hangar/application/discovery/lifecycle_manager.py +315 -0
  12. mcp_hangar/application/discovery/security_validator.py +414 -0
  13. mcp_hangar/application/event_handlers/__init__.py +50 -0
  14. mcp_hangar/application/event_handlers/alert_handler.py +191 -0
  15. mcp_hangar/application/event_handlers/audit_handler.py +203 -0
  16. mcp_hangar/application/event_handlers/knowledge_base_handler.py +120 -0
  17. mcp_hangar/application/event_handlers/logging_handler.py +69 -0
  18. mcp_hangar/application/event_handlers/metrics_handler.py +152 -0
  19. mcp_hangar/application/event_handlers/persistent_audit_store.py +217 -0
  20. mcp_hangar/application/event_handlers/security_handler.py +604 -0
  21. mcp_hangar/application/mcp/tooling.py +158 -0
  22. mcp_hangar/application/ports/__init__.py +9 -0
  23. mcp_hangar/application/ports/observability.py +237 -0
  24. mcp_hangar/application/queries/__init__.py +52 -0
  25. mcp_hangar/application/queries/auth_handlers.py +237 -0
  26. mcp_hangar/application/queries/auth_queries.py +118 -0
  27. mcp_hangar/application/queries/handlers.py +227 -0
  28. mcp_hangar/application/read_models/__init__.py +11 -0
  29. mcp_hangar/application/read_models/provider_views.py +139 -0
  30. mcp_hangar/application/sagas/__init__.py +11 -0
  31. mcp_hangar/application/sagas/group_rebalance_saga.py +137 -0
  32. mcp_hangar/application/sagas/provider_failover_saga.py +266 -0
  33. mcp_hangar/application/sagas/provider_recovery_saga.py +172 -0
  34. mcp_hangar/application/services/__init__.py +9 -0
  35. mcp_hangar/application/services/provider_service.py +208 -0
  36. mcp_hangar/application/services/traced_provider_service.py +211 -0
  37. mcp_hangar/bootstrap/runtime.py +328 -0
  38. mcp_hangar/context.py +178 -0
  39. mcp_hangar/domain/__init__.py +117 -0
  40. mcp_hangar/domain/contracts/__init__.py +57 -0
  41. mcp_hangar/domain/contracts/authentication.py +225 -0
  42. mcp_hangar/domain/contracts/authorization.py +229 -0
  43. mcp_hangar/domain/contracts/event_store.py +178 -0
  44. mcp_hangar/domain/contracts/metrics_publisher.py +59 -0
  45. mcp_hangar/domain/contracts/persistence.py +383 -0
  46. mcp_hangar/domain/contracts/provider_runtime.py +146 -0
  47. mcp_hangar/domain/discovery/__init__.py +20 -0
  48. mcp_hangar/domain/discovery/conflict_resolver.py +267 -0
  49. mcp_hangar/domain/discovery/discovered_provider.py +185 -0
  50. mcp_hangar/domain/discovery/discovery_service.py +412 -0
  51. mcp_hangar/domain/discovery/discovery_source.py +192 -0
  52. mcp_hangar/domain/events.py +433 -0
  53. mcp_hangar/domain/exceptions.py +525 -0
  54. mcp_hangar/domain/model/__init__.py +70 -0
  55. mcp_hangar/domain/model/aggregate.py +58 -0
  56. mcp_hangar/domain/model/circuit_breaker.py +152 -0
  57. mcp_hangar/domain/model/event_sourced_api_key.py +413 -0
  58. mcp_hangar/domain/model/event_sourced_provider.py +423 -0
  59. mcp_hangar/domain/model/event_sourced_role_assignment.py +268 -0
  60. mcp_hangar/domain/model/health_tracker.py +183 -0
  61. mcp_hangar/domain/model/load_balancer.py +185 -0
  62. mcp_hangar/domain/model/provider.py +810 -0
  63. mcp_hangar/domain/model/provider_group.py +656 -0
  64. mcp_hangar/domain/model/tool_catalog.py +105 -0
  65. mcp_hangar/domain/policies/__init__.py +19 -0
  66. mcp_hangar/domain/policies/provider_health.py +187 -0
  67. mcp_hangar/domain/repository.py +249 -0
  68. mcp_hangar/domain/security/__init__.py +85 -0
  69. mcp_hangar/domain/security/input_validator.py +710 -0
  70. mcp_hangar/domain/security/rate_limiter.py +387 -0
  71. mcp_hangar/domain/security/roles.py +237 -0
  72. mcp_hangar/domain/security/sanitizer.py +387 -0
  73. mcp_hangar/domain/security/secrets.py +501 -0
  74. mcp_hangar/domain/services/__init__.py +20 -0
  75. mcp_hangar/domain/services/audit_service.py +376 -0
  76. mcp_hangar/domain/services/image_builder.py +328 -0
  77. mcp_hangar/domain/services/provider_launcher.py +1046 -0
  78. mcp_hangar/domain/value_objects.py +1138 -0
  79. mcp_hangar/errors.py +818 -0
  80. mcp_hangar/fastmcp_server.py +1105 -0
  81. mcp_hangar/gc.py +134 -0
  82. mcp_hangar/infrastructure/__init__.py +79 -0
  83. mcp_hangar/infrastructure/async_executor.py +133 -0
  84. mcp_hangar/infrastructure/auth/__init__.py +37 -0
  85. mcp_hangar/infrastructure/auth/api_key_authenticator.py +388 -0
  86. mcp_hangar/infrastructure/auth/event_sourced_store.py +567 -0
  87. mcp_hangar/infrastructure/auth/jwt_authenticator.py +360 -0
  88. mcp_hangar/infrastructure/auth/middleware.py +340 -0
  89. mcp_hangar/infrastructure/auth/opa_authorizer.py +243 -0
  90. mcp_hangar/infrastructure/auth/postgres_store.py +659 -0
  91. mcp_hangar/infrastructure/auth/projections.py +366 -0
  92. mcp_hangar/infrastructure/auth/rate_limiter.py +311 -0
  93. mcp_hangar/infrastructure/auth/rbac_authorizer.py +323 -0
  94. mcp_hangar/infrastructure/auth/sqlite_store.py +624 -0
  95. mcp_hangar/infrastructure/command_bus.py +112 -0
  96. mcp_hangar/infrastructure/discovery/__init__.py +110 -0
  97. mcp_hangar/infrastructure/discovery/docker_source.py +289 -0
  98. mcp_hangar/infrastructure/discovery/entrypoint_source.py +249 -0
  99. mcp_hangar/infrastructure/discovery/filesystem_source.py +383 -0
  100. mcp_hangar/infrastructure/discovery/kubernetes_source.py +247 -0
  101. mcp_hangar/infrastructure/event_bus.py +260 -0
  102. mcp_hangar/infrastructure/event_sourced_repository.py +443 -0
  103. mcp_hangar/infrastructure/event_store.py +396 -0
  104. mcp_hangar/infrastructure/knowledge_base/__init__.py +259 -0
  105. mcp_hangar/infrastructure/knowledge_base/contracts.py +202 -0
  106. mcp_hangar/infrastructure/knowledge_base/memory.py +177 -0
  107. mcp_hangar/infrastructure/knowledge_base/postgres.py +545 -0
  108. mcp_hangar/infrastructure/knowledge_base/sqlite.py +513 -0
  109. mcp_hangar/infrastructure/metrics_publisher.py +36 -0
  110. mcp_hangar/infrastructure/observability/__init__.py +10 -0
  111. mcp_hangar/infrastructure/observability/langfuse_adapter.py +534 -0
  112. mcp_hangar/infrastructure/persistence/__init__.py +33 -0
  113. mcp_hangar/infrastructure/persistence/audit_repository.py +371 -0
  114. mcp_hangar/infrastructure/persistence/config_repository.py +398 -0
  115. mcp_hangar/infrastructure/persistence/database.py +333 -0
  116. mcp_hangar/infrastructure/persistence/database_common.py +330 -0
  117. mcp_hangar/infrastructure/persistence/event_serializer.py +280 -0
  118. mcp_hangar/infrastructure/persistence/event_upcaster.py +166 -0
  119. mcp_hangar/infrastructure/persistence/in_memory_event_store.py +150 -0
  120. mcp_hangar/infrastructure/persistence/recovery_service.py +312 -0
  121. mcp_hangar/infrastructure/persistence/sqlite_event_store.py +386 -0
  122. mcp_hangar/infrastructure/persistence/unit_of_work.py +409 -0
  123. mcp_hangar/infrastructure/persistence/upcasters/README.md +13 -0
  124. mcp_hangar/infrastructure/persistence/upcasters/__init__.py +7 -0
  125. mcp_hangar/infrastructure/query_bus.py +153 -0
  126. mcp_hangar/infrastructure/saga_manager.py +401 -0
  127. mcp_hangar/logging_config.py +209 -0
  128. mcp_hangar/metrics.py +1007 -0
  129. mcp_hangar/models.py +31 -0
  130. mcp_hangar/observability/__init__.py +54 -0
  131. mcp_hangar/observability/health.py +487 -0
  132. mcp_hangar/observability/metrics.py +319 -0
  133. mcp_hangar/observability/tracing.py +433 -0
  134. mcp_hangar/progress.py +542 -0
  135. mcp_hangar/retry.py +613 -0
  136. mcp_hangar/server/__init__.py +120 -0
  137. mcp_hangar/server/__main__.py +6 -0
  138. mcp_hangar/server/auth_bootstrap.py +340 -0
  139. mcp_hangar/server/auth_cli.py +335 -0
  140. mcp_hangar/server/auth_config.py +305 -0
  141. mcp_hangar/server/bootstrap.py +735 -0
  142. mcp_hangar/server/cli.py +161 -0
  143. mcp_hangar/server/config.py +224 -0
  144. mcp_hangar/server/context.py +215 -0
  145. mcp_hangar/server/http_auth_middleware.py +165 -0
  146. mcp_hangar/server/lifecycle.py +467 -0
  147. mcp_hangar/server/state.py +117 -0
  148. mcp_hangar/server/tools/__init__.py +16 -0
  149. mcp_hangar/server/tools/discovery.py +186 -0
  150. mcp_hangar/server/tools/groups.py +75 -0
  151. mcp_hangar/server/tools/health.py +301 -0
  152. mcp_hangar/server/tools/provider.py +939 -0
  153. mcp_hangar/server/tools/registry.py +320 -0
  154. mcp_hangar/server/validation.py +113 -0
  155. mcp_hangar/stdio_client.py +229 -0
  156. mcp_hangar-0.2.0.dist-info/METADATA +347 -0
  157. mcp_hangar-0.2.0.dist-info/RECORD +160 -0
  158. mcp_hangar-0.2.0.dist-info/WHEEL +4 -0
  159. mcp_hangar-0.2.0.dist-info/entry_points.txt +2 -0
  160. mcp_hangar-0.2.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,366 @@
1
+ """Auth Projections - Read models built from events.
2
+
3
+ Projections listen to domain events and build optimized read models
4
+ for queries. They are part of the CQRS read side.
5
+ """
6
+
7
+ from dataclasses import dataclass, replace
8
+ from datetime import datetime, timezone
9
+ import threading
10
+ from typing import Any
11
+
12
+ from ...domain.contracts.event_store import IEventStore
13
+ from ...domain.events import ApiKeyCreated, ApiKeyRevoked, DomainEvent, RoleAssigned, RoleRevoked
14
+ from ...logging_config import get_logger
15
+
16
+ logger = get_logger(__name__)
17
+
18
+
19
@dataclass
class ApiKeyReadModel:
    """Query-side snapshot of a single API key.

    Attributes:
        key_id: Identifier of the key.
        key_hash: Hash of the key material.
        principal_id: Principal the key belongs to.
        name: Human-readable key name.
        tenant_id: Owning tenant, or None.
        groups: Group names attached to the key.
        created_at: Creation time.
        created_by: Who created the key.
        expires_at: Expiry time, or None.
        revoked: True once the key has been revoked.
        revoked_at: Revocation time, or None while not revoked.
        revoked_by: Who revoked the key, or None while not revoked.
        revocation_reason: Reason given for revocation, or None.
    """

    # Identity
    key_id: str
    key_hash: str
    principal_id: str
    name: str
    tenant_id: str | None
    groups: list[str]

    # Lifecycle
    created_at: datetime
    created_by: str
    expires_at: datetime | None
    revoked: bool

    # Revocation details (defaults apply while the key is not revoked)
    revoked_at: datetime | None = None
    revoked_by: str | None = None
    revocation_reason: str | None = None
36
+
37
+
38
@dataclass
class RoleAssignmentReadModel:
    """Query-side snapshot of one role granted to a principal.

    Attributes:
        principal_id: Principal holding the role.
        role_name: Name of the assigned role.
        scope: Scope the assignment applies to.
        assigned_at: When the role was assigned.
        assigned_by: Who assigned the role.
    """

    principal_id: str
    role_name: str
    scope: str
    assigned_at: datetime
    assigned_by: str
47
+
48
+
49
class AuthProjection:
    """Projection that builds auth read models from events.

    Processes events from the event store and maintains in-memory
    read models optimized for queries.

    Features:
    - API keys indexed by key_id and by principal_id
    - Role assignments indexed by principal_id
    - Index mutations and queries are guarded by a re-entrant lock
    - Supports catchup from an event store
    - Applying the same create/assign event twice is a no-op

    Note:
        A key_hash index (``_keys_by_hash``) is allocated but never
        populated here, because ``ApiKeyCreated`` does not carry the hash.
    """

    def __init__(self, event_store: IEventStore | None = None):
        """Initialize the projection.

        Args:
            event_store: Optional event store for catchup.
        """
        self._event_store = event_store
        self._lock = threading.RLock()

        # API Key indexes
        self._keys_by_hash: dict[str, ApiKeyReadModel] = {}  # reserved; not populated
        self._keys_by_id: dict[str, ApiKeyReadModel] = {}
        self._keys_by_principal: dict[str, list[str]] = {}  # principal -> key_ids

        # Role assignment indexes
        self._roles_by_principal: dict[str, list[RoleAssignmentReadModel]] = {}

        # Track position for incremental updates
        self._last_position = 0

    def catchup(self) -> int:
        """Catch up with events from event store.

        Reads events starting at the last known position, applies each one
        and records its position.

        Returns:
            Number of events processed (0 when no event store is configured).
        """
        if not self._event_store:
            return 0

        count = 0
        for position, _stream_id, event in self._event_store.read_all(from_position=self._last_position):
            self.apply(event)
            self._last_position = position
            count += 1

        logger.info(
            "auth_projection_catchup_complete",
            events_processed=count,
            last_position=self._last_position,
        )

        return count

    def apply(self, event: DomainEvent) -> None:
        """Apply a domain event to update read models.

        Events of any type other than the four auth events are ignored.

        Args:
            event: Event to apply.
        """
        if isinstance(event, ApiKeyCreated):
            self._apply_key_created(event)
        elif isinstance(event, ApiKeyRevoked):
            self._apply_key_revoked(event)
        elif isinstance(event, RoleAssigned):
            self._apply_role_assigned(event)
        elif isinstance(event, RoleRevoked):
            self._apply_role_revoked(event)

    def _apply_key_created(self, event: ApiKeyCreated) -> None:
        """Apply ApiKeyCreated event.

        Idempotent: if the key is already known the event is ignored, so a
        replayed event can neither duplicate the key in the principal index
        nor reset an already revoked key back to active.
        """
        with self._lock:
            if event.key_id in self._keys_by_id:
                return

            expires_at = (
                datetime.fromtimestamp(event.expires_at, tz=timezone.utc) if event.expires_at else None
            )

            # The event carries no key_hash, tenant or groups, so those
            # fields are filled with empty placeholders.
            model = ApiKeyReadModel(
                key_id=event.key_id,
                key_hash="",  # Not in event
                principal_id=event.principal_id,
                name=event.key_name,
                tenant_id=None,  # Not in event
                groups=[],  # Not in event
                created_at=datetime.fromtimestamp(event.occurred_at, tz=timezone.utc),
                created_by=event.created_by,
                expires_at=expires_at,
                revoked=False,
            )

            self._keys_by_id[event.key_id] = model
            self._keys_by_principal.setdefault(event.principal_id, []).append(event.key_id)

    def _apply_key_revoked(self, event: ApiKeyRevoked) -> None:
        """Apply ApiKeyRevoked event.

        Unknown key ids are ignored.
        """
        with self._lock:
            model = self._keys_by_id.get(event.key_id)
            if model is None:
                return

            # Swap in a modified copy rather than mutating the stored model,
            # so objects already handed out to callers do not change.
            self._keys_by_id[event.key_id] = replace(
                model,
                revoked=True,
                revoked_at=datetime.fromtimestamp(event.occurred_at, tz=timezone.utc),
                revoked_by=event.revoked_by,
                revocation_reason=event.reason,
            )

    def _apply_role_assigned(self, event: RoleAssigned) -> None:
        """Apply RoleAssigned event.

        Idempotent: an assignment with the same role name and scope is
        recorded at most once per principal.
        """
        with self._lock:
            model = RoleAssignmentReadModel(
                principal_id=event.principal_id,
                role_name=event.role_name,
                scope=event.scope,
                assigned_at=datetime.fromtimestamp(event.occurred_at, tz=timezone.utc),
                assigned_by=event.assigned_by,
            )

            assignments = self._roles_by_principal.setdefault(event.principal_id, [])
            already_assigned = any(
                r.role_name == event.role_name and r.scope == event.scope for r in assignments
            )
            if not already_assigned:
                assignments.append(model)

    def _apply_role_revoked(self, event: RoleRevoked) -> None:
        """Apply RoleRevoked event.

        Removes every assignment matching the event's role name and scope.
        """
        with self._lock:
            if event.principal_id in self._roles_by_principal:
                self._roles_by_principal[event.principal_id] = [
                    r
                    for r in self._roles_by_principal[event.principal_id]
                    if not (r.role_name == event.role_name and r.scope == event.scope)
                ]

    # =========================================================================
    # Queries
    # =========================================================================

    def get_key_by_id(self, key_id: str) -> ApiKeyReadModel | None:
        """Get API key by ID, or None if unknown."""
        with self._lock:
            return self._keys_by_id.get(key_id)

    def get_keys_for_principal(self, principal_id: str) -> list[ApiKeyReadModel]:
        """Get all API keys (active and revoked) for a principal."""
        with self._lock:
            key_ids = self._keys_by_principal.get(principal_id, [])
            return [self._keys_by_id[kid] for kid in key_ids if kid in self._keys_by_id]

    def get_active_key_count(self, principal_id: str) -> int:
        """Get count of active (non-revoked) keys for a principal."""
        keys = self.get_keys_for_principal(principal_id)
        return sum(1 for k in keys if not k.revoked)

    def get_roles_for_principal(self, principal_id: str) -> list[RoleAssignmentReadModel]:
        """Get a copy of the list of role assignments for a principal."""
        with self._lock:
            return list(self._roles_by_principal.get(principal_id, []))

    def has_role(self, principal_id: str, role_name: str, scope: str = "*") -> bool:
        """Check if principal has a specific role.

        Args:
            principal_id: Principal to check.
            role_name: Role name to look for.
            scope: Requested scope. "*" matches an assignment in any scope.

        Returns:
            True if an assignment with that role name exists whose scope
            equals the requested scope or is "global", or if scope is "*".
        """
        roles = self.get_roles_for_principal(principal_id)
        for role in roles:
            if role.role_name == role_name:
                if scope == "*" or role.scope == scope or role.scope == "global":
                    return True
        return False

    # =========================================================================
    # Statistics
    # =========================================================================

    def get_stats(self) -> dict[str, Any]:
        """Get projection statistics.

        Returns:
            Dict with key counts (total/active/revoked), the number of
            principals with keys and with roles, the total number of role
            assignments and the last applied event position.
        """
        with self._lock:
            total_keys = len(self._keys_by_id)
            revoked_keys = sum(1 for k in self._keys_by_id.values() if k.revoked)

            total_assignments = sum(len(roles) for roles in self._roles_by_principal.values())

            return {
                "total_api_keys": total_keys,
                "active_api_keys": total_keys - revoked_keys,
                "revoked_api_keys": revoked_keys,
                "total_principals_with_keys": len(self._keys_by_principal),
                "total_role_assignments": total_assignments,
                "total_principals_with_roles": len(self._roles_by_principal),
                "last_event_position": self._last_position,
            }
251
+
252
+
253
class AuthAuditLog:
    """Audit log projection for auth events.

    Keeps an in-memory list of audit entries in the order events were
    applied, capped at ``max_entries`` (oldest entries are dropped first).
    """

    def __init__(self, max_entries: int = 10000):
        """Initialize audit log.

        Args:
            max_entries: Maximum entries to keep in memory.
        """
        self._max_entries = max_entries
        self._entries: list[dict[str, Any]] = []
        self._lock = threading.RLock()

    def apply(self, event: DomainEvent) -> None:
        """Apply event to audit log.

        Events that do not map to an audit entry are ignored.

        Args:
            event: Event to record.
        """
        entry = self._event_to_entry(event)
        if entry:
            with self._lock:
                self._entries.append(entry)

                # Drop the oldest entries once over the cap. Deleting a
                # prefix (rather than slicing with a negative index) also
                # works when the cap is 0, where entries[-0:] would keep
                # everything.
                overflow = len(self._entries) - max(self._max_entries, 0)
                if overflow > 0:
                    del self._entries[:overflow]

    def _event_to_entry(self, event: DomainEvent) -> dict[str, Any] | None:
        """Convert event to audit entry.

        Returns:
            A dict with "timestamp", "event_type", "principal_id" and
            "details" keys, or None for event types that are not audited.
        """
        if isinstance(event, ApiKeyCreated):
            return {
                "timestamp": event.occurred_at,
                "event_type": "api_key_created",
                "principal_id": event.principal_id,
                "details": {
                    "key_id": event.key_id,
                    "key_name": event.key_name,
                    "created_by": event.created_by,
                    "expires_at": event.expires_at,
                },
            }

        if isinstance(event, ApiKeyRevoked):
            return {
                "timestamp": event.occurred_at,
                "event_type": "api_key_revoked",
                "principal_id": event.principal_id,
                "details": {
                    "key_id": event.key_id,
                    "revoked_by": event.revoked_by,
                    "reason": event.reason,
                },
            }

        if isinstance(event, RoleAssigned):
            return {
                "timestamp": event.occurred_at,
                "event_type": "role_assigned",
                "principal_id": event.principal_id,
                "details": {
                    "role_name": event.role_name,
                    "scope": event.scope,
                    "assigned_by": event.assigned_by,
                },
            }

        if isinstance(event, RoleRevoked):
            return {
                "timestamp": event.occurred_at,
                "event_type": "role_revoked",
                "principal_id": event.principal_id,
                "details": {
                    "role_name": event.role_name,
                    "scope": event.scope,
                    "revoked_by": event.revoked_by,
                },
            }

        return None

    def query(
        self,
        principal_id: str | None = None,
        event_type: str | None = None,
        since: float | None = None,
        limit: int = 100,
    ) -> list[dict[str, Any]]:
        """Query audit log entries, newest first.

        Args:
            principal_id: Filter by principal (None or "" means no filter).
            event_type: Filter by event type (None or "" means no filter).
            since: Only return entries with a timestamp strictly greater
                than this value.
            limit: Maximum entries to return. Zero or negative returns
                an empty list.

        Returns:
            List of matching audit entries.
        """
        # Check the cap up front; checking only after an append would let
        # limit=0 return one entry.
        if limit <= 0:
            return []

        with self._lock:
            result = []
            for entry in reversed(self._entries):
                # Apply filters
                if principal_id and entry.get("principal_id") != principal_id:
                    continue
                if event_type and entry.get("event_type") != event_type:
                    continue
                if since is not None and entry.get("timestamp", 0) <= since:
                    continue

                result.append(entry)
                if len(result) >= limit:
                    break

            return result
@@ -0,0 +1,311 @@
1
+ """Rate limiting for authentication attempts.
2
+
3
+ Provides protection against brute-force attacks by limiting
4
+ the number of failed authentication attempts per IP address.
5
+
6
+ Uses a sliding-window attempt counter with per-IP tracking and a lockout period.
7
+ """
8
+
9
+ from dataclasses import dataclass, field
10
+ import threading
11
+ import time
12
+ from typing import NamedTuple
13
+
14
+ import structlog
15
+
16
+ logger = structlog.get_logger(__name__)
17
+
18
+
19
class RateLimitResult(NamedTuple):
    """Outcome of a single rate limit check.

    Fields, in positional order:
        allowed: Whether the attempt may proceed.
        remaining: Attempts left in the current window.
        retry_after: Seconds until the next attempt is allowed, or None.
        reason: Short machine-readable explanation of the outcome.
    """

    allowed: bool
    remaining: int
    retry_after: float | None
    reason: str
26
+
27
+
28
@dataclass
class AuthRateLimitConfig:
    """Settings that control authentication rate limiting.

    Attributes:
        enabled: Turns rate limiting on or off.
        max_attempts: Failed attempts allowed within one window.
        window_seconds: Length of the counting window, in seconds.
        lockout_seconds: Lockout duration once the limit is exceeded, in seconds.
        cleanup_interval: Minimum seconds between cleanups of stale entries.
    """

    enabled: bool = True
    max_attempts: int = 10
    window_seconds: int = 60  # 1 minute
    lockout_seconds: int = 300  # 5 minutes
    cleanup_interval: int = 300  # 5 minutes
45
+
46
+
47
+ @dataclass
48
+ class _AttemptTracker:
49
+ """Tracks authentication attempts for a single IP."""
50
+
51
+ attempts: list[float] = field(default_factory=list)
52
+ locked_until: float | None = None
53
+
54
+
55
class AuthRateLimiter:
    """Rate limiter for authentication attempts.

    Tracks failed authentication attempts per IP address in a sliding
    time window and locks out IPs that reach the configured threshold.

    All access to the per-IP trackers is guarded by an RLock.

    Usage:
        limiter = AuthRateLimiter(config)

        # Before authentication
        result = limiter.check_rate_limit(client_ip)
        if not result.allowed:
            raise RateLimitExceeded(result.retry_after)

        # After failed authentication
        limiter.record_failure(client_ip)

        # After successful authentication
        limiter.record_success(client_ip)
    """

    def __init__(self, config: AuthRateLimitConfig | None = None):
        """Initialize rate limiter.

        Args:
            config: Rate limit configuration. Uses defaults if None.
        """
        self._config = config or AuthRateLimitConfig()
        self._trackers: dict[str, _AttemptTracker] = {}
        self._lock = threading.RLock()
        self._last_cleanup = time.time()

    @property
    def enabled(self) -> bool:
        """Check if rate limiting is enabled."""
        return self._config.enabled

    def check_rate_limit(self, ip: str) -> RateLimitResult:
        """Check if an IP is allowed to attempt authentication.

        The lockout starts on the first check made after the number of
        recorded failures in the window has reached ``max_attempts``.

        Args:
            ip: Client IP address.

        Returns:
            RateLimitResult indicating if attempt is allowed.
        """
        if not self._config.enabled:
            return RateLimitResult(
                allowed=True,
                remaining=self._config.max_attempts,
                retry_after=None,
                reason="rate_limiting_disabled",
            )

        now = time.time()

        with self._lock:
            self._maybe_cleanup(now)

            tracker = self._trackers.get(ip)
            if tracker is None:
                return RateLimitResult(
                    allowed=True,
                    remaining=self._config.max_attempts,
                    retry_after=None,
                    reason="no_previous_attempts",
                )

            # Check if locked out
            if tracker.locked_until is not None:
                if now < tracker.locked_until:
                    retry_after = tracker.locked_until - now
                    logger.warning(
                        "auth_rate_limit_locked",
                        ip=ip,
                        retry_after=retry_after,
                    )
                    return RateLimitResult(
                        allowed=False,
                        remaining=0,
                        retry_after=retry_after,
                        reason="locked_out",
                    )

                # Lockout expired: start over with a clean slate
                tracker.locked_until = None
                tracker.attempts.clear()

            # Count attempts in current window
            window_start = now - self._config.window_seconds
            recent_attempts = [t for t in tracker.attempts if t > window_start]
            tracker.attempts = recent_attempts  # Prune old attempts

            remaining = self._config.max_attempts - len(recent_attempts)

            if remaining <= 0:
                # Lock out the IP
                tracker.locked_until = now + self._config.lockout_seconds
                logger.warning(
                    "auth_rate_limit_exceeded",
                    ip=ip,
                    attempts=len(recent_attempts),
                    lockout_seconds=self._config.lockout_seconds,
                )
                return RateLimitResult(
                    allowed=False,
                    remaining=0,
                    retry_after=self._config.lockout_seconds,
                    reason="rate_limit_exceeded",
                )

            return RateLimitResult(
                allowed=True,
                remaining=remaining,
                retry_after=None,
                reason="within_limit",
            )

    def record_failure(self, ip: str) -> None:
        """Record a failed authentication attempt.

        Timestamps older than the counting window are dropped first, so the
        per-IP list stays bounded even if check_rate_limit is never called
        for that IP, and the logged count covers the current window only.

        Args:
            ip: Client IP address.
        """
        if not self._config.enabled:
            return

        now = time.time()

        with self._lock:
            tracker = self._trackers.get(ip)
            if tracker is None:
                tracker = self._trackers[ip] = _AttemptTracker()

            window_start = now - self._config.window_seconds
            tracker.attempts = [t for t in tracker.attempts if t > window_start]
            tracker.attempts.append(now)

            logger.debug(
                "auth_failure_recorded",
                ip=ip,
                total_attempts=len(tracker.attempts),
            )

    def record_success(self, ip: str) -> None:
        """Record a successful authentication (clears failure count).

        Args:
            ip: Client IP address.
        """
        if not self._config.enabled:
            return

        with self._lock:
            if ip in self._trackers:
                del self._trackers[ip]
                logger.debug("auth_success_cleared_tracker", ip=ip)

    def get_status(self, ip: str) -> dict:
        """Get rate limit status for an IP.

        Read-only: does not prune attempts or change lockout state.

        Args:
            ip: Client IP address.

        Returns:
            Dict with keys "ip", "attempts" (failures in the current
            window), "remaining", "locked" and "locked_until". While the
            IP is locked out, "remaining" is 0, matching check_rate_limit.
        """
        with self._lock:
            tracker = self._trackers.get(ip)
            if tracker is None:
                return {
                    "ip": ip,
                    "attempts": 0,
                    "remaining": self._config.max_attempts,
                    "locked": False,
                    "locked_until": None,
                }

            now = time.time()
            window_start = now - self._config.window_seconds
            recent = sum(1 for t in tracker.attempts if t > window_start)
            locked = tracker.locked_until is not None and now < tracker.locked_until

            # The lockout can outlast the counting window; a locked IP has
            # no attempts left regardless of how many failures are recent.
            remaining = 0 if locked else max(0, self._config.max_attempts - recent)

            return {
                "ip": ip,
                "attempts": recent,
                "remaining": remaining,
                "locked": locked,
                "locked_until": tracker.locked_until,
            }

    def clear(self, ip: str | None = None) -> None:
        """Clear rate limit data.

        Args:
            ip: Specific IP to clear, or None to clear all.
        """
        with self._lock:
            if ip is None:
                self._trackers.clear()
                logger.info("auth_rate_limit_cleared_all")
            elif ip in self._trackers:
                del self._trackers[ip]
                logger.info("auth_rate_limit_cleared", ip=ip)

    def _maybe_cleanup(self, now: float) -> None:
        """Clean up old entries periodically.

        Does nothing until ``cleanup_interval`` seconds have passed since
        the previous cleanup. Must be called with the lock held.

        Args:
            now: Current timestamp.
        """
        if now - self._last_cleanup < self._config.cleanup_interval:
            return

        self._last_cleanup = now
        window_start = now - self._config.window_seconds

        # Remove trackers that are neither locked nor recently active.
        # Collect first: the dict must not change size while iterating.
        to_remove = []
        for ip, tracker in self._trackers.items():
            # Keep if locked
            if tracker.locked_until is not None and now < tracker.locked_until:
                continue
            # Keep if has recent attempts
            if any(t > window_start for t in tracker.attempts):
                continue
            to_remove.append(ip)

        for ip in to_remove:
            del self._trackers[ip]

        if to_remove:
            logger.debug("auth_rate_limit_cleanup", removed_count=len(to_remove))
286
+
287
+
288
# Process-wide limiter shared across the application; created on first use
_default_limiter: AuthRateLimiter | None = None


def get_auth_rate_limiter() -> AuthRateLimiter:
    """Return the shared auth rate limiter, creating it on first call.

    Returns:
        The global AuthRateLimiter instance.
    """
    global _default_limiter
    if _default_limiter is not None:
        return _default_limiter

    _default_limiter = AuthRateLimiter()
    return _default_limiter
302
+
303
+
304
def set_auth_rate_limiter(limiter: AuthRateLimiter) -> None:
    """Replace the shared auth rate limiter.

    Args:
        limiter: Instance that get_auth_rate_limiter() should return
            from now on.
    """
    global _default_limiter
    _default_limiter = limiter