kailash 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,137 @@
1
+ """
2
+ Authentication Models for Kailash Middleware
3
+
4
+ Provides data models for JWT authentication without any circular dependencies.
5
+ These models can be imported anywhere in the codebase safely.
6
+ """
7
+
8
+ from dataclasses import dataclass
9
+ from datetime import datetime
10
+ from typing import List, Optional
11
+
12
+
13
@dataclass
class JWTConfig:
    """Configuration for JWT authentication supporting both HS256 and RSA algorithms.

    Every field has a default, so ``JWTConfig()`` describes an HS256 setup with
    no secret key set. This class only stores values: it defines no
    ``__post_init__`` and performs no validation of its own.

    NOTE(review): ``algorithm`` and ``use_rsa`` are independent fields and can
    disagree (e.g. ``algorithm="HS256", use_rsa=True``). Nothing here reconciles
    them — confirm which one the consuming auth manager honours.
    """

    # Signing configuration
    algorithm: str = "HS256" # Default to HS256 for simplicity
    secret_key: Optional[str] = None # For HS256
    use_rsa: bool = False # Enable RSA mode
    private_key: Optional[str] = None # For RSA (PEM format)
    public_key: Optional[str] = None # For RSA (PEM format)

    # Token expiration (units are given by the field names)
    access_token_expire_minutes: int = 15
    refresh_token_expire_days: int = 7

    # Security settings
    # presumably the values used for the ``iss`` / ``aud`` claims — confirm with the issuer code
    issuer: str = "kailash-middleware"
    audience: str = "kailash-api"

    # Key management
    # NOTE(review): presumably generates keys when none are supplied; behaviour lives outside this class
    auto_generate_keys: bool = True
    key_rotation_days: int = 30 # Only applies to RSA mode

    # Token settings
    include_user_claims: bool = True
    include_permissions: bool = True
    # presumably caps the ``refresh_count`` carried by tokens — TODO confirm enforcement point
    max_refresh_count: int = 10

    # Security features (flags only; the features themselves are implemented elsewhere)
    enable_blacklist: bool = True
    enable_token_cleanup: bool = True
    cleanup_interval_minutes: int = 60
45
+
46
+
47
@dataclass
class TokenPayload:
    """JWT token payload structure.

    The first six fields mirror the registered JWT claims of RFC 7519
    (``sub``, ``iss``, ``aud``, ``exp``, ``iat``, ``jti``) and are required.
    The remaining fields are custom claims and token metadata with defaults.

    ``permissions`` and ``roles`` accept ``None`` (their default).
    ``__post_init__`` replaces ``None`` with a fresh empty list, so no two
    instances ever share the same list object.
    """

    # Standard claims
    sub: str  # Subject (user ID)
    iss: str  # Issuer
    aud: str  # Audience
    exp: int  # Expiration time (presumably seconds since epoch, per RFC 7519 NumericDate)
    iat: int  # Issued at (presumably seconds since epoch)
    jti: str  # JWT ID

    # Custom claims
    tenant_id: Optional[str] = None
    session_id: Optional[str] = None
    user_type: str = "user"
    permissions: Optional[List[str]] = None  # normalised to [] when omitted
    roles: Optional[List[str]] = None  # normalised to [] when omitted

    # Token metadata
    token_type: str = "access"  # access, refresh
    refresh_count: int = 0

    def __post_init__(self):
        # dataclasses forbid list defaults; a None default plus per-instance
        # initialisation gives every payload its own list.
        if self.permissions is None:
            self.permissions = []
        if self.roles is None:
            self.roles = []
75
+
76
+
77
@dataclass
class TokenPair:
    """Access and refresh token pair.

    Only ``access_token`` and ``refresh_token`` are required. The other fields
    have defaults and are not validated here.
    """

    access_token: str  # presumably the encoded access JWT
    refresh_token: str  # presumably the encoded refresh JWT
    token_type: str = "Bearer"  # authorization scheme name
    expires_in: int = 0  # NOTE(review): presumably seconds until access expiry (OAuth 2.0 convention) — confirm
    expires_at: Optional[datetime] = None  # absolute expiry; naive vs. aware depends on the producer — confirm
    scope: Optional[str] = None
87
+
88
+
89
@dataclass
class RefreshTokenData:
    """Metadata for tracking refresh tokens.

    Only ``jti`` and ``user_id`` are required. When ``created_at`` is omitted,
    ``__post_init__`` fills it with the current UTC time as a *naive* datetime
    (no ``tzinfo``). That is the same value the previous ``datetime.utcnow()``
    call produced, so comparisons against other naive timestamps keep working.
    """

    jti: str  # Token ID
    user_id: str
    tenant_id: Optional[str] = None
    session_id: Optional[str] = None
    created_at: Optional[datetime] = None  # filled in by __post_init__ when omitted
    last_used: Optional[datetime] = None
    refresh_count: int = 0
    ip_address: Optional[str] = None
    user_agent: Optional[str] = None

    def __post_init__(self):
        if self.created_at is None:
            # datetime.utcnow() is deprecated since Python 3.12. Derive the same
            # naive-UTC value from an aware "now" instead.
            # NOTE(review): the token-expiry helpers in this package return aware
            # UTC datetimes; consider storing aware values here once every
            # consumer of created_at has been checked.
            from datetime import timezone

            self.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
106
+
107
+
108
@dataclass
class UserClaims:
    """User claims for JWT tokens.

    Only ``user_id`` is required. ``roles``, ``permissions`` and ``metadata``
    accept ``None`` (their default). ``__post_init__`` replaces ``None`` with a
    fresh empty list or dict, so instances never share mutable state.
    """

    user_id: str
    tenant_id: Optional[str] = None
    email: Optional[str] = None
    username: Optional[str] = None
    roles: Optional[List[str]] = None  # normalised to [] when omitted
    permissions: Optional[List[str]] = None  # normalised to [] when omitted
    metadata: Optional[dict] = None  # normalised to {} when omitted

    def __post_init__(self):
        # dataclasses forbid list/dict defaults; a None default plus
        # per-instance initialisation avoids a shared mutable default.
        if self.roles is None:
            self.roles = []
        if self.permissions is None:
            self.permissions = []
        if self.metadata is None:
            self.metadata = {}
127
+
128
+
129
@dataclass
class AuthenticationResult:
    """Result of authentication attempt.

    ``success`` is the only required field; every other field defaults to
    ``None``. NOTE(review): presumably ``token_pair`` / ``user_claims`` are set
    on success and ``error`` / ``error_code`` on failure, but nothing here
    enforces that split — confirm with the code that builds these results.
    """

    success: bool  # whether the authentication attempt succeeded
    token_pair: Optional[TokenPair] = None  # issued tokens, if any
    user_claims: Optional[UserClaims] = None  # claims for the authenticated user, if any
    error: Optional[str] = None  # failure message, if any
    error_code: Optional[str] = None  # failure code, if any (format not defined here)
@@ -0,0 +1,257 @@
1
+ """
2
+ Authentication Utilities for Kailash Middleware
3
+
4
+ Provides helper functions for authentication without circular dependencies.
5
+ """
6
+
7
+ import base64
8
+ import secrets
9
+ import string
10
+ from datetime import datetime, timedelta, timezone
11
+ from typing import Dict, Optional, Tuple
12
+
13
+
14
def generate_secret_key(length: int = 32) -> str:
    """
    Create a random secret suitable for HMAC (HS256) signing.

    Args:
        length: Number of random bytes to draw (default: 32). The returned
            text is longer than this, because it is base64url-encoded.

    Returns:
        URL-safe base64 text without padding, drawn from ``secrets``.
    """
    key = secrets.token_urlsafe(length)
    return key
25
+
26
+
27
def generate_key_pair() -> Tuple[str, str]:
    """
    Generate RSA key pair for RS256.

    Creates a 2048-bit RSA key with public exponent 65537 and serializes both
    halves to PEM text.

    Returns:
        Tuple of (private_key_pem, public_key_pem). The private key is an
        UNENCRYPTED PKCS#8 PEM, so callers must protect it. The public key is
        a SubjectPublicKeyInfo PEM.

    Raises:
        ImportError: If the optional ``cryptography`` package is not installed.
            The original import error is chained as ``__cause__``.
    """
    # Guard only the imports: an ImportError can then only mean the optional
    # dependency is missing, and the original error is kept as the cause.
    try:
        from cryptography.hazmat.backends import default_backend
        from cryptography.hazmat.primitives import serialization
        from cryptography.hazmat.primitives.asymmetric import rsa
    except ImportError as exc:
        raise ImportError(
            "RSA key generation requires 'cryptography' package. "
            "Install with: pip install cryptography"
        ) from exc

    private_key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )

    private_pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode("utf-8")

    public_pem = (
        private_key.public_key()
        .public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )
        .decode("utf-8")
    )

    return private_pem, public_pem
67
+
68
+
69
def calculate_token_expiry(
    token_type: str = "access", access_minutes: int = 15, refresh_days: int = 7
) -> datetime:
    """
    Compute when a newly issued token should expire.

    Args:
        token_type: ``"access"`` selects the access-token lifetime. Any other
            value is treated as a refresh token.
        access_minutes: Access-token lifetime in minutes.
        refresh_days: Refresh-token lifetime in days.

    Returns:
        Timezone-aware UTC datetime of the expiry moment.
    """
    # NOTE(review): any value other than "access" (including typos such as
    # "Access") silently receives the longer refresh lifetime.
    if token_type == "access":
        lifetime = timedelta(minutes=access_minutes)
    else:
        lifetime = timedelta(days=refresh_days)
    return datetime.now(timezone.utc) + lifetime
89
+
90
+
91
def is_token_expired(exp_timestamp: int) -> bool:
    """
    Report whether a token's ``exp`` claim lies in the past.

    Args:
        exp_timestamp: Value of the token's ``exp`` claim, interpreted as
            seconds since the Unix epoch in UTC.

    Returns:
        True only when the current UTC time is strictly later than the
        expiry. An expiry equal to "now" still counts as valid.
    """
    current = datetime.now(timezone.utc)
    expires_at = datetime.fromtimestamp(exp_timestamp, tz=timezone.utc)
    return expires_at < current
104
+
105
+
106
def generate_jti() -> str:
    """
    Create a random, URL-safe identifier for a token's ``jti`` claim.

    Returns:
        A 22-character base64url string derived from 16 bytes of
        cryptographically secure randomness.
    """
    entropy_bytes = 16
    identifier = secrets.token_urlsafe(entropy_bytes)
    return identifier
114
+
115
+
116
def encode_for_jwks(number: int) -> str:
    """
    Encode a non-negative integer as a JWK ``Base64urlUInt`` value.

    The integer is written big-endian in the minimum number of octets and then
    base64url-encoded without padding (RFC 7518, section 2). As the RFC
    requires, zero is encoded as a single zero octet (``"AA"``), not as an
    empty string.

    Args:
        number: Non-negative integer to encode (e.g., RSA modulus or exponent)

    Returns:
        Base64url encoded string without padding

    Raises:
        OverflowError: If ``number`` is negative (raised by ``int.to_bytes``).
    """
    # bit_length() of 0 is 0, which would produce zero octets; max(1, ...)
    # keeps zero as one octet.
    byte_length = max(1, (number.bit_length() + 7) // 8)
    number_bytes = number.to_bytes(byte_length, "big")
    return base64.urlsafe_b64encode(number_bytes).decode("ascii").rstrip("=")
129
+
130
+
131
def validate_algorithm(algorithm: str) -> bool:
    """
    Check whether a JWT signing algorithm name is supported.

    Only the HMAC-SHA2 (``HS*``) and RSA (``RS*``) families with 256-, 384- or
    512-bit digests are accepted. The comparison is case-sensitive, so
    ``"hs256"`` and ``"none"`` are rejected.

    Args:
        algorithm: Algorithm name, e.g. ``"HS256"``.

    Returns:
        True if the algorithm is in the supported set.
    """
    return algorithm in (
        "HS256",
        "HS384",
        "HS512",
        "RS256",
        "RS384",
        "RS512",
    )
143
+
144
+
145
def parse_bearer_token(authorization_header: str) -> Optional[str]:
    """
    Pull the token out of an ``Authorization: Bearer <token>`` header value.

    The header must split on whitespace into exactly two parts. The scheme
    name is matched case-insensitively.

    Args:
        authorization_header: Raw header value (may be empty or None).

    Returns:
        The token text, or None when the header is missing or malformed.
    """
    if not authorization_header:
        return None

    pieces = authorization_header.split()
    if len(pieces) == 2 and pieces[0].lower() == "bearer":
        return pieces[1]
    return None
164
+
165
+
166
def generate_random_password(
    length: int = 16,
    include_uppercase: bool = True,
    include_lowercase: bool = True,
    include_digits: bool = True,
    include_symbols: bool = True,
) -> str:
    """
    Build a random password from the enabled character classes.

    Characters are drawn independently with ``secrets.choice``. If every class
    is disabled, ASCII letters plus digits are used instead.

    NOTE(review): the result is not guaranteed to contain a character from
    each enabled class.

    Args:
        length: Number of characters to produce (non-positive gives "").
        include_uppercase: Allow ``A``-``Z``.
        include_lowercase: Allow ``a``-``z``.
        include_digits: Allow ``0``-``9``.
        include_symbols: Allow ASCII punctuation.

    Returns:
        Random password string
    """
    pools = (
        (include_uppercase, string.ascii_uppercase),
        (include_lowercase, string.ascii_lowercase),
        (include_digits, string.digits),
        (include_symbols, string.punctuation),
    )
    alphabet = "".join(chars for enabled, chars in pools if enabled)
    # Fall back to alphanumerics when every class was switched off.
    alphabet = alphabet or (string.ascii_letters + string.digits)
    return "".join(secrets.choice(alphabet) for _ in range(length))
201
+
202
+
203
def hash_token_for_storage(token: str) -> str:
    """
    Derive a storage-safe fingerprint of a token (e.g., for a blacklist).

    Args:
        token: Raw JWT string.

    Returns:
        64-character lowercase hexadecimal SHA-256 digest of the UTF-8
        encoded token.
    """
    import hashlib

    hasher = hashlib.sha256()
    hasher.update(token.encode("utf-8"))
    return hasher.hexdigest()
216
+
217
+
218
def create_jwks_response(
    public_key_pem: str, key_id: str, algorithm: str = "RS256"
) -> Dict:
    """
    Create JWKS response for public key endpoint.

    This is best-effort: on any failure the function returns an empty key set
    (``{"keys": []}``) instead of raising, and logs a warning explaining why.
    Failures include a missing ``cryptography`` package, a PEM that cannot be
    loaded, and a key that is not RSA.

    Args:
        public_key_pem: Public key in PEM format
        key_id: Key identifier (emitted as ``kid``)
        algorithm: Algorithm used (emitted as ``alg``)

    Returns:
        JWKS formatted response with a single RSA signing key, or an empty
        key set on failure.
    """
    import logging

    logger = logging.getLogger(__name__)

    try:
        from cryptography.hazmat.backends import default_backend
        from cryptography.hazmat.primitives import serialization
        from cryptography.hazmat.primitives.asymmetric import rsa
    except ImportError:
        logger.warning(
            "Cannot build JWKS response: the 'cryptography' package is not installed"
        )
        return {"keys": []}

    try:
        public_key = serialization.load_pem_public_key(
            public_key_pem.encode(), backend=default_backend()
        )
    except Exception as exc:  # malformed PEM, non-string input, unsupported format
        logger.warning("Cannot load public key %r for JWKS response: %s", key_id, exc)
        return {"keys": []}

    # The JWK below is hard-wired to kty=RSA, so reject other key types
    # explicitly instead of failing on missing RSA attributes.
    if not isinstance(public_key, rsa.RSAPublicKey):
        logger.warning(
            "Cannot build JWKS response for key %r: expected an RSA public key, got %s",
            key_id,
            type(public_key).__name__,
        )
        return {"keys": []}

    try:
        public_numbers = public_key.public_numbers()
        return {
            "keys": [
                {
                    "kty": "RSA",
                    "kid": key_id,
                    "use": "sig",
                    "alg": algorithm,
                    "n": encode_for_jwks(public_numbers.n),
                    "e": encode_for_jwks(public_numbers.e),
                }
            ]
        }
    except Exception as exc:  # keep the historical best-effort contract
        logger.warning("Cannot encode public key %r for JWKS response: %s", key_id, exc)
        return {"keys": []}
@@ -33,7 +33,6 @@ from ...nodes.security import CredentialManagerNode
33
33
  from ...nodes.transform import DataTransformer
34
34
  from ...workflow import Workflow
35
35
  from ...workflow.builder import WorkflowBuilder
36
- from ..auth import KailashJWTAuthManager
37
36
  from ..core.agent_ui import AgentUIMiddleware
38
37
  from ..core.schema import DynamicSchemaRegistry
39
38
  from .events import EventFilter, EventType
@@ -41,8 +40,8 @@ from .realtime import RealtimeMiddleware
41
40
 
42
41
  logger = logging.getLogger(__name__)
43
42
 
44
- # Use SDK Auth Manager instead of manual security
45
- auth_manager = KailashJWTAuthManager()
43
+ # Auth manager will be injected via dependency injection
44
+ # This avoids circular imports and allows for flexible auth implementations
46
45
 
47
46
 
48
47
  # Pydantic Models
@@ -135,8 +134,23 @@ class APIGateway:
135
134
  enable_docs: bool = True,
136
135
  max_sessions: int = 1000,
137
136
  enable_auth: bool = True,
137
+ auth_manager=None, # Dependency injection for auth
138
138
  database_url: str = None,
139
139
  ):
140
+ """
141
+ Initialize API Gateway with dependency injection support.
142
+
143
+ Args:
144
+ title: API title
145
+ description: API description
146
+ version: API version
147
+ cors_origins: Allowed CORS origins
148
+ enable_docs: Enable OpenAPI documentation
149
+ max_sessions: Maximum concurrent sessions
150
+ enable_auth: Enable authentication
151
+ auth_manager: Optional auth manager instance (creates default if None and auth enabled)
152
+ database_url: Optional database URL for persistence
153
+ """
140
154
  self.title = title
141
155
  self.version = version
142
156
  self.enable_docs = enable_docs
@@ -153,7 +167,21 @@ class APIGateway:
153
167
 
154
168
  # Initialize auth manager if enabled
155
169
  if enable_auth:
156
- self.auth_manager = KailashJWTAuthManager(secret_key="api-gateway-secret")
170
+ if auth_manager is None:
171
+ # Create default auth manager if none provided
172
+ # Import here to avoid circular dependency
173
+ from ..auth import JWTAuthManager
174
+
175
+ self.auth_manager = JWTAuthManager(
176
+ secret_key="api-gateway-secret",
177
+ algorithm="HS256",
178
+ issuer="kailash-gateway",
179
+ audience="kailash-api",
180
+ )
181
+ else:
182
+ self.auth_manager = auth_manager
183
+ else:
184
+ self.auth_manager = None
157
185
 
158
186
  # Create FastAPI app with lifespan management
159
187
  @asynccontextmanager
@@ -774,25 +802,39 @@ class APIGateway:
774
802
 
775
803
  # Convenience function for quick setup
776
804
  def create_gateway(
777
- agent_ui_middleware: AgentUIMiddleware = None, **kwargs
805
+ agent_ui_middleware: AgentUIMiddleware = None, auth_manager=None, **kwargs
778
806
  ) -> APIGateway:
779
807
  """
780
- Create a configured API gateway instance.
808
+ Create a configured API gateway instance with dependency injection.
781
809
 
782
810
  Args:
783
811
  agent_ui_middleware: Optional existing AgentUIMiddleware instance
812
+ auth_manager: Optional auth manager instance (e.g., JWTAuthManager)
784
813
  **kwargs: Additional arguments for APIGateway initialization
785
814
 
786
815
  Returns:
787
816
  Configured APIGateway instance
788
817
 
789
818
  Example:
819
+ >>> from kailash.middleware.auth import JWTAuthManager
820
+ >>>
821
+ >>> # Create with custom auth
822
+ >>> auth = JWTAuthManager(use_rsa=True)
790
823
  >>> gateway = create_gateway(
791
824
  ... title="My App Gateway",
792
- ... cors_origins=["http://localhost:3000"]
825
+ ... cors_origins=["http://localhost:3000"],
826
+ ... auth_manager=auth
793
827
  ... )
828
+ >>>
829
+ >>> # Or use default auth
830
+ >>> gateway = create_gateway(title="My App")
831
+ >>>
794
832
  >>> gateway.run(port=8000)
795
833
  """
834
+ # Pass auth_manager to APIGateway
835
+ if auth_manager is not None:
836
+ kwargs["auth_manager"] = auth_manager
837
+
796
838
  gateway = APIGateway(**kwargs)
797
839
 
798
840
  if agent_ui_middleware:
@@ -540,7 +540,114 @@ class AgentUIMiddleware:
540
540
  inputs: Dict[str, Any] = None,
541
541
  config_overrides: Dict[str, Any] = None,
542
542
  ) -> str:
543
- """Execute a workflow asynchronously."""
543
+ """Execute a workflow asynchronously.
544
+
545
+ .. deprecated:: 0.5.0
546
+ Use :meth:`execute` instead. This method will be removed in version 1.0.0.
547
+ """
548
+ import warnings
549
+
550
+ warnings.warn(
551
+ "execute_workflow() is deprecated and will be removed in version 1.0.0. "
552
+ "Use execute() instead for consistency with runtime API.",
553
+ DeprecationWarning,
554
+ stacklevel=2,
555
+ )
556
+ session = await self.get_session(session_id)
557
+ if not session:
558
+ raise ValueError(f"Session {session_id} not found")
559
+
560
+ # Get workflow
561
+ workflow = None
562
+ if workflow_id in session.workflows:
563
+ workflow = session.workflows[workflow_id]
564
+ elif workflow_id in self.shared_workflows:
565
+ workflow = self.shared_workflows[workflow_id]
566
+ else:
567
+ raise ValueError(f"Workflow {workflow_id} not found")
568
+
569
+ # Start execution
570
+ execution_id = session.start_execution(workflow_id, inputs)
571
+
572
+ # Track execution
573
+ self.active_executions[execution_id] = {
574
+ "session_id": session_id,
575
+ "workflow_id": workflow_id,
576
+ "workflow": workflow,
577
+ "inputs": inputs or {},
578
+ "config_overrides": config_overrides or {},
579
+ "start_time": time.time(),
580
+ }
581
+
582
+ # Persist execution if enabled
583
+ if self.enable_persistence:
584
+ try:
585
+ await self.execution_repo.create(
586
+ {
587
+ "id": execution_id,
588
+ "workflow_id": workflow_id,
589
+ "session_id": session_id,
590
+ "user_id": session.user_id,
591
+ "inputs": inputs,
592
+ "metadata": config_overrides,
593
+ }
594
+ )
595
+ except Exception as e:
596
+ logger.error(f"Failed to persist execution: {e}")
597
+
598
+ # Log execution start
599
+ logger.info(
600
+ f"Workflow execution started: {execution_id} for workflow {workflow_id}"
601
+ )
602
+
603
+ # Emit started event
604
+ await self.event_stream.emit_workflow_started(
605
+ workflow_id=workflow_id,
606
+ workflow_name=workflow.name,
607
+ execution_id=execution_id,
608
+ user_id=session.user_id,
609
+ session_id=session_id,
610
+ )
611
+
612
+ # Execute in background
613
+ asyncio.create_task(self._execute_workflow_async(execution_id))
614
+
615
+ self.workflows_executed += 1
616
+ return execution_id
617
+
618
+ async def execute(
619
+ self,
620
+ session_id: str,
621
+ workflow_id: str,
622
+ inputs: Dict[str, Any] = None,
623
+ config_overrides: Dict[str, Any] = None,
624
+ ) -> str:
625
+ """
626
+ Execute a workflow asynchronously.
627
+
628
+ This is the preferred method for workflow execution, providing consistency
629
+ with the runtime API.
630
+
631
+ Args:
632
+ session_id: Session identifier
633
+ workflow_id: Workflow identifier
634
+ inputs: Optional input parameters for the workflow
635
+ config_overrides: Optional configuration overrides
636
+
637
+ Returns:
638
+ str: Execution ID for tracking
639
+
640
+ Raises:
641
+ ValueError: If session or workflow not found
642
+ RuntimeError: If execution fails
643
+
644
+ Example:
645
+ >>> execution_id = await middleware.execute(
646
+ ... session_id="sess_123",
647
+ ... workflow_id="data_pipeline",
648
+ ... inputs={"data": "input.csv"}
649
+ ... )
650
+ """
544
651
  session = await self.get_session(session_id)
545
652
  if not session:
546
653
  raise ValueError(f"Session {session_id} not found")
@@ -56,7 +56,9 @@ class BaseRepository:
56
56
  """Execute database query using SDK node."""
57
57
  try:
58
58
  if self.use_async:
59
- result = await self.db_node.execute(query=query, params=params or {})
59
+ result = await self.db_node.execute_async(
60
+ query=query, params=params or {}
61
+ )
60
62
  else:
61
63
  result = self.db_node.execute(query=query, params=params or {})
62
64
 
@@ -248,8 +248,8 @@ else:
248
248
 
249
249
  self.tool_register_workflow.add_node(validator)
250
250
  self.tool_register_workflow.add_node(register_handler)
251
- self.tool_register_workflow.connect(
252
- validator, register_handler, mapping={"result": "validation_result"}
251
+ self.tool_register_workflow.add_connection(
252
+ "validate_tool", "result", "register_tool", "validation_result"
253
253
  )
254
254
 
255
255
  # Tool Execution Workflow