kailash 0.6.0__py3-none-any.whl → 0.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kailash/__init__.py +1 -1
- kailash/access_control/__init__.py +1 -1
- kailash/core/actors/adaptive_pool_controller.py +630 -0
- kailash/core/actors/connection_actor.py +3 -3
- kailash/core/ml/__init__.py +1 -0
- kailash/core/ml/query_patterns.py +544 -0
- kailash/core/monitoring/__init__.py +19 -0
- kailash/core/monitoring/connection_metrics.py +488 -0
- kailash/core/optimization/__init__.py +1 -0
- kailash/core/resilience/__init__.py +17 -0
- kailash/core/resilience/circuit_breaker.py +382 -0
- kailash/gateway/api.py +7 -5
- kailash/gateway/enhanced_gateway.py +1 -1
- kailash/middleware/auth/access_control.py +11 -11
- kailash/middleware/communication/ai_chat.py +7 -7
- kailash/middleware/communication/api_gateway.py +5 -15
- kailash/middleware/gateway/checkpoint_manager.py +45 -8
- kailash/middleware/gateway/event_store.py +66 -26
- kailash/middleware/mcp/enhanced_server.py +2 -2
- kailash/nodes/admin/permission_check.py +110 -30
- kailash/nodes/admin/schema.sql +387 -0
- kailash/nodes/admin/tenant_isolation.py +249 -0
- kailash/nodes/admin/transaction_utils.py +244 -0
- kailash/nodes/admin/user_management.py +37 -9
- kailash/nodes/ai/ai_providers.py +55 -3
- kailash/nodes/ai/llm_agent.py +115 -13
- kailash/nodes/data/query_pipeline.py +641 -0
- kailash/nodes/data/query_router.py +895 -0
- kailash/nodes/data/sql.py +24 -0
- kailash/nodes/data/workflow_connection_pool.py +451 -23
- kailash/nodes/monitoring/__init__.py +3 -5
- kailash/nodes/monitoring/connection_dashboard.py +822 -0
- kailash/nodes/rag/__init__.py +1 -3
- kailash/resources/registry.py +6 -0
- kailash/runtime/async_local.py +7 -0
- kailash/utils/export.py +152 -0
- kailash/workflow/builder.py +42 -0
- kailash/workflow/graph.py +86 -17
- kailash/workflow/templates.py +4 -9
- {kailash-0.6.0.dist-info → kailash-0.6.2.dist-info}/METADATA +14 -1
- {kailash-0.6.0.dist-info → kailash-0.6.2.dist-info}/RECORD +45 -31
- {kailash-0.6.0.dist-info → kailash-0.6.2.dist-info}/WHEEL +0 -0
- {kailash-0.6.0.dist-info → kailash-0.6.2.dist-info}/entry_points.txt +0 -0
- {kailash-0.6.0.dist-info → kailash-0.6.2.dist-info}/licenses/LICENSE +0 -0
- {kailash-0.6.0.dist-info → kailash-0.6.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,382 @@
|
|
1
|
+
"""Circuit Breaker pattern implementation for connection management.
|
2
|
+
|
3
|
+
This module implements the Circuit Breaker pattern to prevent cascading failures
|
4
|
+
in connection pools and database operations. It provides automatic failure detection,
|
5
|
+
recovery testing, and graceful degradation.
|
6
|
+
|
7
|
+
The circuit breaker has three states:
|
8
|
+
- CLOSED: Normal operation, requests pass through
|
9
|
+
- OPEN: Failures detected, requests fail fast
|
10
|
+
- HALF_OPEN: Testing recovery, limited requests allowed
|
11
|
+
|
12
|
+
Example:
|
13
|
+
    >>> breaker = ConnectionCircuitBreaker(
    ...     CircuitBreakerConfig(
    ...         failure_threshold=5,
    ...         recovery_timeout=60,
    ...         half_open_requests=3,
    ...     )
    ... )
    >>>
    >>> # Wrap connection operations: pass the coroutine function and its arguments
    >>> result = await breaker.call(connection.execute, query)
|
22
|
+
"""
|
23
|
+
|
24
|
+
import asyncio
|
25
|
+
import logging
|
26
|
+
import time
|
27
|
+
from collections import deque
|
28
|
+
from dataclasses import dataclass, field
|
29
|
+
from datetime import datetime, timedelta
|
30
|
+
from enum import Enum
|
31
|
+
from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar
|
32
|
+
|
33
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
34
|
+
|
35
|
+
# Result type of the callable protected by ConnectionCircuitBreaker.call.
T = TypeVar("T")
|
36
|
+
|
37
|
+
|
38
|
+
class CircuitState(Enum):
    """The three operating modes of a circuit breaker."""

    # Requests flow through to the protected operation.
    CLOSED = "closed"
    # Requests are rejected immediately without calling the protected operation.
    OPEN = "open"
    # A limited number of trial requests probe whether the dependency recovered.
    HALF_OPEN = "half_open"
|
44
|
+
|
45
|
+
|
46
|
+
class CircuitBreakerError(Exception):
    """Signal that a call was refused by the circuit breaker.

    Raised instead of invoking the protected operation when the breaker is
    OPEN, or when it is HALF_OPEN and its trial-request budget is used up.
    """
|
50
|
+
|
51
|
+
|
52
|
+
@dataclass
class CircuitBreakerConfig:
    """Tunable thresholds controlling when a circuit breaker trips and recovers."""

    # Consecutive failures that trip the breaker to OPEN.
    failure_threshold: int = 5
    # Consecutive successes in HALF_OPEN needed to return to CLOSED.
    success_threshold: int = 3
    # Seconds the breaker stays OPEN before a HALF_OPEN trial is allowed.
    recovery_timeout: int = 60
    # Maximum trial requests admitted while HALF_OPEN.
    half_open_requests: int = 3
    # Failure fraction within the rolling window that trips the breaker.
    error_rate_threshold: float = 0.5
    # Number of most recent call outcomes kept in the rolling window.
    window_size: int = 100
    # Exception types that are re-raised but not counted as failures.
    excluded_exceptions: List[type] = field(default_factory=list)
|
63
|
+
|
64
|
+
|
65
|
+
@dataclass
class CircuitBreakerMetrics:
    """Running counters describing a circuit breaker's call history."""

    total_calls: int = 0
    successful_calls: int = 0
    failed_calls: int = 0
    rejected_calls: int = 0
    state_transitions: List[Dict[str, Any]] = field(default_factory=list)
    last_failure_time: Optional[float] = None
    consecutive_failures: int = 0
    consecutive_successes: int = 0

    def record_success(self):
        """Count a call that completed without error and end any failure streak."""
        self.consecutive_failures = 0
        self.consecutive_successes += 1
        self.successful_calls += 1
        self.total_calls += 1

    def record_failure(self):
        """Count a failed call, extend the failure streak and stamp its wall-clock time."""
        self.consecutive_successes = 0
        self.consecutive_failures += 1
        self.failed_calls += 1
        self.total_calls += 1
        self.last_failure_time = time.time()

    def record_rejection(self):
        """Count a call refused by the breaker; rejections are not part of total_calls."""
        self.rejected_calls += 1

    def get_error_rate(self) -> float:
        """Return cumulative failed/total ratio, or 0.0 before any call has run."""
        executed = self.total_calls
        return self.failed_calls / executed if executed else 0.0
|
102
|
+
|
103
|
+
|
104
|
+
class ConnectionCircuitBreaker(Generic[T]):
    """Circuit breaker for database connections and operations.

    Monitors failures and prevents cascading failures by failing fast once an
    error threshold is reached, then automatically tests recovery after a
    timeout period:

    * CLOSED -> OPEN when ``failure_threshold`` consecutive failures occur, or
      when at least half of the rolling window is populated and its failure
      fraction reaches ``error_rate_threshold``.
    * OPEN -> HALF_OPEN once ``recovery_timeout`` seconds have elapsed. This is
      checked lazily, when the next call arrives.
    * HALF_OPEN -> CLOSED after ``success_threshold`` consecutive successes;
      HALF_OPEN -> OPEN on any counted failure.

    Exceptions listed in ``config.excluded_exceptions`` are re-raised without
    being counted. Listeners registered with :meth:`add_listener` are awaited
    on each transition as ``listener(old_state, new_state, metrics)``; errors
    they raise are logged and suppressed.
    """

    def __init__(self, config: Optional[CircuitBreakerConfig] = None):
        """Initialize circuit breaker with configuration.

        Args:
            config: Thresholds to use. A default ``CircuitBreakerConfig`` is
                created when omitted.
        """
        self.config = config or CircuitBreakerConfig()
        self.state = CircuitState.CLOSED
        self.metrics = CircuitBreakerMetrics()
        # Guards state/metric mutations. asyncio.Lock is not reentrant, so the
        # helpers invoked while it is held (_check_state_transition,
        # _transition_to) must not acquire it again.
        self._lock = asyncio.Lock()
        # Trial calls admitted during the current HALF_OPEN period.
        self._half_open_requests = 0
        self._last_state_change = time.time()
        # Most recent outcomes: True = success, False = counted failure.
        self._rolling_window = deque(maxlen=self.config.window_size)
        self._listeners: List[Callable] = []

    async def call(self, func: Callable[..., T], *args, **kwargs) -> T:
        """Execute ``func`` with circuit breaker protection.

        Args:
            func: Async callable to protect. It is awaited with ``*args`` and
                ``**kwargs``.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            Whatever ``func`` returns.

        Raises:
            CircuitBreakerError: If the circuit is OPEN, or HALF_OPEN with its
                trial-request budget exhausted. ``func`` is not called.
            Exception: Any exception raised by ``func`` is re-raised unchanged.
        """
        async with self._lock:
            # Apply time/metric driven transitions before deciding admission
            await self._check_state_transition()

            if self.state == CircuitState.OPEN:
                self.metrics.record_rejection()
                raise CircuitBreakerError(
                    f"Circuit breaker is OPEN. "
                    f"Rejected after {self.metrics.consecutive_failures} failures. "
                    f"Will retry in {self._time_until_recovery():.1f}s"
                )

            is_probe = self.state == CircuitState.HALF_OPEN
            if is_probe:
                if self._half_open_requests >= self.config.half_open_requests:
                    self.metrics.record_rejection()
                    raise CircuitBreakerError(
                        "Circuit breaker is HALF_OPEN but request limit reached"
                    )
                self._half_open_requests += 1

        # The protected operation runs outside the lock so concurrent callers
        # are not serialized behind it.
        try:
            result = await func(*args, **kwargs)
        except Exception as e:
            if any(
                isinstance(e, exc_type) for exc_type in self.config.excluded_exceptions
            ):
                # Excluded errors count as neither success nor failure. Hand the
                # trial slot back; otherwise an excluded error during HALF_OPEN
                # could exhaust the budget and leave the breaker stuck there.
                if is_probe:
                    self._release_half_open_slot()
            else:
                await self._record_failure(e)
            raise
        except BaseException:
            # e.g. asyncio.CancelledError: no outcome is recorded, so release
            # the trial slot for the same reason as above.
            if is_probe:
                self._release_half_open_slot()
            raise

        await self._record_success()
        return result

    def _release_half_open_slot(self):
        """Give back one HALF_OPEN trial slot for a probe that produced no outcome.

        This method is synchronous and contains no ``await``, so it cannot
        interleave with other coroutines. That also makes it safe to call
        while a cancellation is propagating.
        """
        if self.state == CircuitState.HALF_OPEN and self._half_open_requests > 0:
            self._half_open_requests -= 1

    async def _check_state_transition(self):
        """Apply time- and metric-driven transitions before admitting a call.

        Called from :meth:`call` while ``self._lock`` is held.
        """
        if self.state == CircuitState.CLOSED:
            if self._should_open():
                await self._transition_to(CircuitState.OPEN)
        elif self.state == CircuitState.OPEN:
            time_since_open = time.time() - self._last_state_change
            if time_since_open >= self.config.recovery_timeout:
                await self._transition_to(CircuitState.HALF_OPEN)
        # HALF_OPEN transitions are driven by call outcomes in
        # _record_success / _record_failure.

    def _should_open(self) -> bool:
        """Return True when the failure streak or the windowed error rate is too high."""
        if self.metrics.consecutive_failures >= self.config.failure_threshold:
            return True

        window = self._rolling_window
        # Judge the error rate only once at least half the window is populated.
        # Never judge an empty window: with window_size=0 the size check alone
        # would pass and the division below would raise ZeroDivisionError.
        if window and len(window) >= self.config.window_size / 2:
            error_count = sum(1 for success in window if not success)
            if error_count / len(window) >= self.config.error_rate_threshold:
                return True

        return False

    async def _record_success(self):
        """Record a successful call; close the circuit once HALF_OPEN has enough successes."""
        async with self._lock:
            self.metrics.record_success()
            self._rolling_window.append(True)

            if (
                self.state == CircuitState.HALF_OPEN
                and self.metrics.consecutive_successes >= self.config.success_threshold
            ):
                await self._transition_to(CircuitState.CLOSED)

    async def _record_failure(self, error: Exception):
        """Record a counted failure and open the circuit if warranted."""
        async with self._lock:
            self.metrics.record_failure()
            self._rolling_window.append(False)

            if self.state == CircuitState.HALF_OPEN:
                # Single failure in half-open goes back to open
                await self._transition_to(CircuitState.OPEN)
            elif self.state == CircuitState.CLOSED:
                if self._should_open():
                    await self._transition_to(CircuitState.OPEN)

            logger.warning(
                f"Circuit breaker recorded failure: {type(error).__name__}: {error}"
            )

    async def _transition_to(self, new_state: CircuitState):
        """Switch to ``new_state``, record the transition and notify listeners.

        Callers hold ``self._lock``.
        """
        old_state = self.state
        self.state = new_state
        self._last_state_change = time.time()

        # Reset counters on state change
        if new_state == CircuitState.CLOSED:
            self.metrics.consecutive_failures = 0
            # Drop outcomes recorded before recovery. Otherwise a window still
            # dominated by pre-outage failures would make _should_open() trip
            # the freshly closed breaker again on the very next call.
            self._rolling_window.clear()
        elif new_state == CircuitState.OPEN:
            self.metrics.consecutive_successes = 0
        elif new_state == CircuitState.HALF_OPEN:
            # Each recovery attempt starts with a full trial budget.
            self._half_open_requests = 0

        self.metrics.state_transitions.append(
            {
                "from": old_state.value,
                "to": new_state.value,
                "timestamp": datetime.now().isoformat(),
                "reason": self._get_transition_reason(old_state, new_state),
            }
        )

        logger.info(
            f"Circuit breaker transitioned from {old_state.value} to {new_state.value}"
        )

        for listener in self._listeners:
            try:
                await listener(old_state, new_state, self.metrics)
            except Exception as e:
                logger.error(f"Error notifying circuit breaker listener: {e}")

    def _get_transition_reason(
        self, old_state: CircuitState, new_state: CircuitState
    ) -> str:
        """Get human-readable reason for state transition."""
        if old_state == CircuitState.CLOSED and new_state == CircuitState.OPEN:
            return f"Failure threshold reached ({self.metrics.consecutive_failures} failures)"
        elif old_state == CircuitState.OPEN and new_state == CircuitState.HALF_OPEN:
            return f"Recovery timeout elapsed ({self.config.recovery_timeout}s)"
        elif old_state == CircuitState.HALF_OPEN and new_state == CircuitState.CLOSED:
            return f"Success threshold reached ({self.metrics.consecutive_successes} successes)"
        elif old_state == CircuitState.HALF_OPEN and new_state == CircuitState.OPEN:
            return "Failure during recovery test"
        return "Unknown reason"

    def _time_until_recovery(self) -> float:
        """Return seconds until the OPEN circuit may try HALF_OPEN (0.0 otherwise)."""
        if self.state != CircuitState.OPEN:
            return 0.0
        elapsed = time.time() - self._last_state_change
        remaining = self.config.recovery_timeout - elapsed
        return max(0.0, remaining)

    async def force_open(self, reason: str = "Manual override"):
        """Manually open the circuit breaker (no-op if already OPEN)."""
        async with self._lock:
            if self.state != CircuitState.OPEN:
                logger.warning(f"Manually opening circuit breaker: {reason}")
                await self._transition_to(CircuitState.OPEN)

    async def force_close(self, reason: str = "Manual override"):
        """Manually close the circuit breaker (no-op if already CLOSED)."""
        async with self._lock:
            if self.state != CircuitState.CLOSED:
                logger.warning(f"Manually closing circuit breaker: {reason}")
                self.metrics.consecutive_failures = 0
                self.metrics.consecutive_successes = 0
                await self._transition_to(CircuitState.CLOSED)

    async def reset(self):
        """Reset circuit breaker to its initial state with fresh metrics.

        Listeners are not notified and no transition is recorded.
        """
        async with self._lock:
            self.state = CircuitState.CLOSED
            self.metrics = CircuitBreakerMetrics()
            self._rolling_window.clear()
            self._half_open_requests = 0
            self._last_state_change = time.time()
            logger.info("Circuit breaker reset to initial state")

    def add_listener(self, listener: Callable):
        """Add an async state-change listener ``listener(old, new, metrics)``."""
        self._listeners.append(listener)

    def remove_listener(self, listener: Callable):
        """Remove a state-change listener; unknown listeners are ignored."""
        if listener in self._listeners:
            self._listeners.remove(listener)

    def get_status(self) -> Dict[str, Any]:
        """Return a snapshot of state, metrics, key config and recent transitions."""
        return {
            "state": self.state.value,
            "metrics": {
                "total_calls": self.metrics.total_calls,
                "successful_calls": self.metrics.successful_calls,
                "failed_calls": self.metrics.failed_calls,
                "rejected_calls": self.metrics.rejected_calls,
                "error_rate": self.metrics.get_error_rate(),
                "consecutive_failures": self.metrics.consecutive_failures,
                "consecutive_successes": self.metrics.consecutive_successes,
            },
            "config": {
                "failure_threshold": self.config.failure_threshold,
                "success_threshold": self.config.success_threshold,
                "recovery_timeout": self.config.recovery_timeout,
                "error_rate_threshold": self.config.error_rate_threshold,
            },
            "time_until_recovery": (
                self._time_until_recovery() if self.state == CircuitState.OPEN else None
            ),
            "state_transitions": self.metrics.state_transitions[-5:],  # Last 5 transitions
        }
|
351
|
+
|
352
|
+
|
353
|
+
class CircuitBreakerManager:
    """Registry that hands out one named circuit breaker per resource."""

    def __init__(self):
        """Create an empty registry with a default ``CircuitBreakerConfig``."""
        self._breakers: Dict[str, ConnectionCircuitBreaker] = {}
        self._default_config = CircuitBreakerConfig()

    def get_or_create(
        self, name: str, config: Optional[CircuitBreakerConfig] = None
    ) -> ConnectionCircuitBreaker:
        """Return the breaker registered under ``name``, creating it on first use.

        ``config`` (or the manager's default config when omitted) is used only
        when the breaker is created. An already-registered breaker is
        returned unchanged.
        """
        if name in self._breakers:
            return self._breakers[name]
        created = ConnectionCircuitBreaker(config or self._default_config)
        self._breakers[name] = created
        return created

    def get_all_status(self) -> Dict[str, Dict[str, Any]]:
        """Map every registered breaker name to its ``get_status()`` snapshot."""
        snapshot: Dict[str, Dict[str, Any]] = {}
        for name, breaker in self._breakers.items():
            snapshot[name] = breaker.get_status()
        return snapshot

    async def reset_all(self):
        """Reset each registered breaker, one after another."""
        for breaker in self._breakers.values():
            await breaker.reset()

    def set_default_config(self, config: CircuitBreakerConfig):
        """Replace the config applied to breakers created from now on."""
        self._default_config = config
|
kailash/gateway/api.py
CHANGED
@@ -11,7 +11,7 @@ from typing import Any, Dict, List, Optional, Union
|
|
11
11
|
|
12
12
|
from fastapi import APIRouter, BackgroundTasks, Depends, FastAPI, HTTPException
|
13
13
|
from fastapi.responses import JSONResponse
|
14
|
-
from pydantic import BaseModel, Field
|
14
|
+
from pydantic import BaseModel, ConfigDict, Field
|
15
15
|
|
16
16
|
from ..resources.registry import ResourceRegistry
|
17
17
|
from .enhanced_gateway import (
|
@@ -37,14 +37,15 @@ class ResourceReferenceModel(BaseModel):
|
|
37
37
|
None, description="Reference to credentials secret"
|
38
38
|
)
|
39
39
|
|
40
|
-
|
41
|
-
|
40
|
+
model_config = ConfigDict(
|
41
|
+
json_schema_extra={
|
42
42
|
"example": {
|
43
43
|
"type": "database",
|
44
44
|
"config": {"host": "localhost", "port": 5432, "database": "myapp"},
|
45
45
|
"credentials_ref": "db_credentials",
|
46
46
|
}
|
47
47
|
}
|
48
|
+
)
|
48
49
|
|
49
50
|
|
50
51
|
class WorkflowRequestModel(BaseModel):
|
@@ -59,8 +60,8 @@ class WorkflowRequestModel(BaseModel):
|
|
59
60
|
None, description="Additional context variables"
|
60
61
|
)
|
61
62
|
|
62
|
-
|
63
|
-
|
63
|
+
model_config = ConfigDict(
|
64
|
+
json_schema_extra={
|
64
65
|
"example": {
|
65
66
|
"inputs": {"user_id": 123, "action": "process"},
|
66
67
|
"resources": {
|
@@ -74,6 +75,7 @@ class WorkflowRequestModel(BaseModel):
|
|
74
75
|
"context": {"environment": "production", "trace_id": "abc123"},
|
75
76
|
}
|
76
77
|
}
|
78
|
+
)
|
77
79
|
|
78
80
|
|
79
81
|
class WorkflowResponseModel(BaseModel):
|
@@ -40,7 +40,7 @@ class WorkflowRequest:
|
|
40
40
|
inputs: Dict[str, Any] = field(default_factory=dict)
|
41
41
|
resources: Dict[str, Union[str, ResourceReference]] = field(default_factory=dict)
|
42
42
|
context: Dict[str, Any] = field(default_factory=dict)
|
43
|
-
timestamp: datetime = field(default_factory=datetime.
|
43
|
+
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
44
44
|
|
45
45
|
def to_dict(self) -> Dict[str, Any]:
|
46
46
|
"""Convert to JSON-serializable dict."""
|
@@ -60,11 +60,11 @@ class MiddlewareAccessControlManager:
|
|
60
60
|
self.enable_audit = enable_audit
|
61
61
|
|
62
62
|
# Kailash nodes for operations
|
63
|
-
self.user_mgmt_node = UserManagementNode(
|
64
|
-
self.role_mgmt_node = RoleManagementNode(
|
65
|
-
self.permission_check_node = PermissionCheckNode(
|
66
|
-
self.audit_node = AuditLogNode(
|
67
|
-
self.security_event_node = SecurityEventNode(
|
63
|
+
self.user_mgmt_node = UserManagementNode()
|
64
|
+
self.role_mgmt_node = RoleManagementNode()
|
65
|
+
self.permission_check_node = PermissionCheckNode()
|
66
|
+
self.audit_node = AuditLogNode() if enable_audit else None
|
67
|
+
self.security_event_node = SecurityEventNode()
|
68
68
|
|
69
69
|
async def check_session_access(
|
70
70
|
self, user_context: UserContext, session_id: str, action: str = "access"
|
@@ -72,7 +72,7 @@ class MiddlewareAccessControlManager:
|
|
72
72
|
"""Check if user can access a specific session."""
|
73
73
|
|
74
74
|
# Use Kailash permission check node
|
75
|
-
result = self.permission_check_node.
|
75
|
+
result = self.permission_check_node.execute(
|
76
76
|
{
|
77
77
|
"user_context": user_context,
|
78
78
|
"resource_type": "session",
|
@@ -114,7 +114,7 @@ class MiddlewareAccessControlManager:
|
|
114
114
|
|
115
115
|
# Audit logging using Kailash audit node
|
116
116
|
if self.enable_audit and self.audit_node:
|
117
|
-
self.audit_node.
|
117
|
+
self.audit_node.execute(
|
118
118
|
{
|
119
119
|
"event_type": "workflow_access_check",
|
120
120
|
"user_id": user_context.user_id,
|
@@ -192,7 +192,7 @@ class MiddlewareAccessControlManager:
|
|
192
192
|
) -> Dict[str, Any]:
|
193
193
|
"""Assign role to user using Kailash role management node."""
|
194
194
|
|
195
|
-
result = self.role_mgmt_node.
|
195
|
+
result = self.role_mgmt_node.execute(
|
196
196
|
{
|
197
197
|
"action": "assign_role",
|
198
198
|
"user_id": user_id,
|
@@ -239,7 +239,7 @@ class MiddlewareAccessControlManager:
|
|
239
239
|
|
240
240
|
# Audit the rule creation
|
241
241
|
if self.enable_audit and self.audit_node:
|
242
|
-
self.audit_node.
|
242
|
+
self.audit_node.execute(
|
243
243
|
{
|
244
244
|
"event_type": "permission_rule_created",
|
245
245
|
"rule_data": rule_data,
|
@@ -366,7 +366,7 @@ class MiddlewareAuthenticationMiddleware:
|
|
366
366
|
try:
|
367
367
|
# This would typically validate JWT token
|
368
368
|
# For now, simulating with credential manager
|
369
|
-
cred_result = self.credential_manager.
|
369
|
+
cred_result = self.credential_manager.execute(
|
370
370
|
{"action": "validate_token", "token": token}
|
371
371
|
)
|
372
372
|
|
@@ -388,7 +388,7 @@ class MiddlewareAuthenticationMiddleware:
|
|
388
388
|
|
389
389
|
except Exception as e:
|
390
390
|
# Log security event using Kailash security event node
|
391
|
-
self.access_manager.security_event_node.
|
391
|
+
self.access_manager.security_event_node.execute(
|
392
392
|
{
|
393
393
|
"event_type": "authentication_failure",
|
394
394
|
"error": str(e),
|
@@ -371,7 +371,7 @@ EXPLANATION:
|
|
371
371
|
|
372
372
|
try:
|
373
373
|
result = await asyncio.to_thread(
|
374
|
-
self.llm_node.
|
374
|
+
self.llm_node.execute, messages=[{"role": "user", "content": prompt}]
|
375
375
|
)
|
376
376
|
|
377
377
|
# Extract content from response
|
@@ -847,10 +847,10 @@ What would you like to work on? Just describe what you want to accomplish and I'
|
|
847
847
|
"""Store chat message with embedding in vector database."""
|
848
848
|
try:
|
849
849
|
# Generate embedding
|
850
|
-
embedding_result =
|
850
|
+
embedding_result = self.embedding_node.execute(text=content)
|
851
851
|
|
852
852
|
# Store in database (simplified for now)
|
853
|
-
|
853
|
+
self.vector_db.execute(
|
854
854
|
{
|
855
855
|
"query": "INSERT INTO chat_messages (id, session_id, user_id, content, role, timestamp) VALUES (?, ?, ?, ?, ?, ?)",
|
856
856
|
"parameters": [
|
@@ -875,10 +875,10 @@ What would you like to work on? Just describe what you want to accomplish and I'
|
|
875
875
|
"""Find similar past conversations using vector search."""
|
876
876
|
try:
|
877
877
|
# Generate query embedding
|
878
|
-
query_embedding =
|
878
|
+
query_embedding = self.embedding_node.execute(text=query)
|
879
879
|
|
880
880
|
# Search for similar messages (simplified for now)
|
881
|
-
search_result =
|
881
|
+
search_result = self.vector_db.execute(
|
882
882
|
{
|
883
883
|
"query": "SELECT * FROM chat_messages WHERE role = 'user' ORDER BY timestamp DESC LIMIT ?",
|
884
884
|
"parameters": [limit * 2],
|
@@ -930,7 +930,7 @@ What would you like to work on? Just describe what you want to accomplish and I'
|
|
930
930
|
|
931
931
|
try:
|
932
932
|
# Generate query embedding
|
933
|
-
query_embedding =
|
933
|
+
query_embedding = self.embedding_node.execute(text=query)
|
934
934
|
|
935
935
|
# Prepare filters
|
936
936
|
filters = {}
|
@@ -948,7 +948,7 @@ What would you like to work on? Just describe what you want to accomplish and I'
|
|
948
948
|
query_parts.append("ORDER BY timestamp DESC LIMIT ?")
|
949
949
|
params.append(limit)
|
950
950
|
|
951
|
-
search_result =
|
951
|
+
search_result = self.vector_db.execute(
|
952
952
|
{"query": " ".join(query_parts), "parameters": params}
|
953
953
|
)
|
954
954
|
|
@@ -225,10 +225,8 @@ class APIGateway:
|
|
225
225
|
# Data transformer for request/response formatting
|
226
226
|
self.data_transformer = DataTransformer(
|
227
227
|
name="gateway_transformer",
|
228
|
-
|
229
|
-
|
230
|
-
{"type": "add_field", "field": "timestamp", "value": "now()"},
|
231
|
-
],
|
228
|
+
# Transformations will be provided at runtime
|
229
|
+
transformations=[],
|
232
230
|
)
|
233
231
|
|
234
232
|
# Credential manager for gateway security
|
@@ -362,17 +360,9 @@ class APIGateway:
|
|
362
360
|
"active": session.active,
|
363
361
|
}
|
364
362
|
|
365
|
-
transformed =
|
366
|
-
|
367
|
-
|
368
|
-
"transformations": [
|
369
|
-
{
|
370
|
-
"type": "add_field",
|
371
|
-
"field": "api_version",
|
372
|
-
"value": self.version,
|
373
|
-
}
|
374
|
-
],
|
375
|
-
}
|
363
|
+
transformed = self.data_transformer.execute(
|
364
|
+
data=response_data,
|
365
|
+
transformations=[f"{{**data, 'api_version': '{self.version}'}}"],
|
376
366
|
)
|
377
367
|
|
378
368
|
return SessionResponse(**transformed["result"])
|
@@ -175,8 +175,32 @@ class CheckpointManager:
|
|
175
175
|
compression_enabled: bool = True,
|
176
176
|
compression_threshold_bytes: int = 1024, # 1KB
|
177
177
|
retention_hours: int = 24,
|
178
|
+
# Backward compatibility parameter
|
179
|
+
storage: Optional[DiskStorage] = None,
|
178
180
|
):
|
179
|
-
"""Initialize checkpoint manager.
|
181
|
+
"""Initialize checkpoint manager.
|
182
|
+
|
183
|
+
Args:
|
184
|
+
memory_storage: Memory storage backend (optional)
|
185
|
+
disk_storage: Disk storage backend (optional)
|
186
|
+
cloud_storage: Cloud storage backend (optional)
|
187
|
+
compression_enabled: Enable compression for large checkpoints
|
188
|
+
compression_threshold_bytes: Minimum size for compression
|
189
|
+
retention_hours: Hours to retain checkpoints
|
190
|
+
storage: DEPRECATED - Use disk_storage instead
|
191
|
+
"""
|
192
|
+
# Handle backward compatibility
|
193
|
+
if storage is not None:
|
194
|
+
import warnings
|
195
|
+
|
196
|
+
warnings.warn(
|
197
|
+
"The 'storage' parameter is deprecated. Use 'disk_storage' instead.",
|
198
|
+
DeprecationWarning,
|
199
|
+
stacklevel=2,
|
200
|
+
)
|
201
|
+
if disk_storage is None:
|
202
|
+
disk_storage = storage
|
203
|
+
|
180
204
|
self.memory_storage = memory_storage or MemoryStorage()
|
181
205
|
self.disk_storage = disk_storage or DiskStorage()
|
182
206
|
self.cloud_storage = cloud_storage # Optional cloud backend
|
@@ -189,11 +213,23 @@ class CheckpointManager:
|
|
189
213
|
self.load_count = 0
|
190
214
|
self.compression_ratio_sum = 0.0
|
191
215
|
|
192
|
-
#
|
193
|
-
self._gc_task =
|
216
|
+
# Initialize garbage collection task (will be started when first used)
|
217
|
+
self._gc_task = None
|
218
|
+
self._gc_started = False
|
219
|
+
|
220
|
+
def _ensure_gc_started(self):
|
221
|
+
"""Ensure garbage collection task is started (lazy initialization)."""
|
222
|
+
if not self._gc_started:
|
223
|
+
try:
|
224
|
+
self._gc_task = asyncio.create_task(self._garbage_collection_loop())
|
225
|
+
self._gc_started = True
|
226
|
+
except RuntimeError:
|
227
|
+
# No event loop running, GC will be started later
|
228
|
+
pass
|
194
229
|
|
195
230
|
async def save_checkpoint(self, checkpoint: Checkpoint) -> None:
|
196
231
|
"""Save checkpoint to storage."""
|
232
|
+
self._ensure_gc_started()
|
197
233
|
start_time = time.time()
|
198
234
|
|
199
235
|
# Serialize checkpoint
|
@@ -391,8 +427,9 @@ class CheckpointManager:
|
|
391
427
|
|
392
428
|
async def close(self) -> None:
|
393
429
|
"""Close checkpoint manager and cleanup."""
|
394
|
-
self._gc_task
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
430
|
+
if self._gc_task is not None:
|
431
|
+
self._gc_task.cancel()
|
432
|
+
try:
|
433
|
+
await self._gc_task
|
434
|
+
except asyncio.CancelledError:
|
435
|
+
pass
|