google-adk-extras 0.2.6__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_adk_extras/__init__.py +3 -3
- google_adk_extras/adk_builder.py +15 -292
- google_adk_extras/artifacts/local_folder_artifact_service.py +0 -2
- google_adk_extras/artifacts/mongo_artifact_service.py +0 -1
- google_adk_extras/artifacts/s3_artifact_service.py +0 -1
- google_adk_extras/artifacts/sql_artifact_service.py +0 -1
- google_adk_extras/auth/__init__.py +10 -0
- google_adk_extras/auth/attach.py +227 -0
- google_adk_extras/auth/config.py +45 -0
- google_adk_extras/auth/jwt_utils.py +36 -0
- google_adk_extras/auth/sql_store.py +183 -0
- google_adk_extras/custom_agent_loader.py +1 -1
- google_adk_extras/enhanced_adk_web_server.py +0 -2
- google_adk_extras/enhanced_fastapi.py +6 -1
- google_adk_extras/memory/mongo_memory_service.py +0 -1
- google_adk_extras/memory/sql_memory_service.py +1 -1
- google_adk_extras/memory/yaml_file_memory_service.py +1 -3
- google_adk_extras/sessions/mongo_session_service.py +0 -1
- google_adk_extras/sessions/redis_session_service.py +1 -1
- google_adk_extras/sessions/yaml_file_session_service.py +0 -2
- google_adk_extras/streaming/streaming_controller.py +2 -2
- {google_adk_extras-0.2.6.dist-info → google_adk_extras-0.3.0.dist-info}/METADATA +84 -34
- google_adk_extras-0.3.0.dist-info/RECORD +37 -0
- google_adk_extras/credentials/__init__.py +0 -34
- google_adk_extras/credentials/github_oauth2_credential_service.py +0 -213
- google_adk_extras/credentials/google_oauth2_credential_service.py +0 -216
- google_adk_extras/credentials/http_basic_auth_credential_service.py +0 -388
- google_adk_extras/credentials/jwt_credential_service.py +0 -345
- google_adk_extras/credentials/microsoft_oauth2_credential_service.py +0 -250
- google_adk_extras/credentials/x_oauth2_credential_service.py +0 -240
- google_adk_extras-0.2.6.dist-info/RECORD +0 -39
- {google_adk_extras-0.2.6.dist-info → google_adk_extras-0.3.0.dist-info}/WHEEL +0 -0
- {google_adk_extras-0.2.6.dist-info → google_adk_extras-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {google_adk_extras-0.2.6.dist-info → google_adk_extras-0.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,227 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import Optional
|
4
|
+
|
5
|
+
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
|
6
|
+
import base64
|
7
|
+
from fastapi.security import APIKeyHeader
|
8
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
9
|
+
|
10
|
+
from .config import AuthConfig, JwtIssuerConfig, JwtValidatorConfig
|
11
|
+
from .jwt_utils import decode_jwt, encode_jwt, now_ts
|
12
|
+
from .sql_store import AuthStore
|
13
|
+
|
14
|
+
|
15
|
+
def attach_auth(app: FastAPI, cfg: Optional[AuthConfig]) -> None:
    """Attach optional authentication to the provided FastAPI app.

    - Adds middleware that enforces auth on sensitive routes.
    - Optionally registers token issuance endpoints if configured.
    - Registers API-key management endpoints when a SQL store is configured
      (``cfg.jwt_issuer.database_url``).

    Credentials are tried in this order: API key (``api_key`` query parameter
    or ``X-API-Key`` header), HTTP Basic, then Bearer JWT.

    Args:
        app: FastAPI application, modified in place.
        cfg: Auth configuration. When ``None``, not ``enabled``, or
            ``allow_no_auth`` is set, the app is left untouched.

    Raises:
        RuntimeError: If the issuer is enabled with ``HS256`` but has no
            ``hs256_secret``.
    """
    if not cfg or not cfg.enabled or cfg.allow_no_auth:
        return

    # Imported here so they are only needed once auth is actually enabled.
    import secrets

    from fastapi.responses import JSONResponse

    from .sql_store import verify_password

    validator = cfg.jwt_validator
    issuer_cfg = cfg.jwt_issuer
    api_keys = set(cfg.api_keys or [])
    basic_users = cfg.basic_users or {}
    auth_store: Optional[AuthStore] = None
    if issuer_cfg and issuer_cfg.database_url:
        auth_store = AuthStore(issuer_cfg.database_url)

    # Without this fallback, tokens minted by our own HS256 issuer would be
    # rejected by the middleware whenever no explicit validator is configured.
    if (
        validator is None
        and issuer_cfg is not None
        and issuer_cfg.enabled
        and issuer_cfg.algorithm == "HS256"
        and issuer_cfg.hs256_secret
    ):
        validator = JwtValidatorConfig(
            issuer=issuer_cfg.issuer,
            audience=issuer_cfg.audience,
            hs256_secret=issuer_cfg.hs256_secret,
        )

    def _request_api_key(request: Request) -> Optional[str]:
        """Return the API key presented via query parameter or header, if any."""
        # Starlette headers are case-insensitive, so one lookup covers X-API-Key.
        return request.query_params.get("api_key") or request.headers.get("x-api-key")

    def _api_key_is_valid(api_key: Optional[str]) -> bool:
        """Check *api_key* against the static config list, then the SQL store."""
        if not api_key:
            return False
        if api_key in api_keys:
            return True
        return bool(auth_store and auth_store.verify_api_key(api_key))

    def _check_static_password(password: str, stored: str) -> bool:
        """Compare *password* with a ``basic_users`` entry in constant time.

        Entries in ``pbkdf2_sha256$<salt>$<digest>`` form are verified as PBKDF2
        hashes; anything else is treated as plaintext.
        """
        if stored.startswith("pbkdf2_sha256$"):
            try:
                return verify_password(password, stored)
            except TypeError:
                # secrets.compare_digest rejects non-ASCII str arguments.
                return False
        return secrets.compare_digest(password.encode("utf-8"), stored.encode("utf-8"))

    async def _authenticate(request: Request) -> dict:
        """Return the caller's identity dict or raise ``HTTPException(401)``."""
        # API Key
        if _api_key_is_valid(_request_api_key(request)):
            return {"method": "api_key", "sub": "api_key_client"}

        authz = request.headers.get("authorization")

        # Basic
        if authz and authz.lower().startswith("basic "):
            try:
                raw = base64.b64decode(authz.split(" ", 1)[1]).decode("utf-8")
                username, _, password = raw.partition(":")
            except ValueError:
                # binascii.Error and UnicodeDecodeError both subclass ValueError.
                username, password = "", ""
            # If SQL store present, try it first; else fall back to configured map
            if auth_store:
                uid = auth_store.authenticate_basic(username, password)
                if uid:
                    return {"method": "basic", "sub": uid, "username": username}
            stored = basic_users.get(username)
            if stored and _check_static_password(password, stored):
                return {"method": "basic", "sub": username, "username": username}

        # Bearer JWT
        if authz and authz.lower().startswith("bearer "):
            token = authz.split(" ", 1)[1]
            if validator and (validator.jwks_url or validator.hs256_secret):
                try:
                    claims = decode_jwt(
                        token,
                        issuer=validator.issuer,
                        audience=validator.audience,
                        jwks_url=validator.jwks_url,
                        hs256_secret=validator.hs256_secret,
                    )
                except Exception as e:
                    raise HTTPException(status_code=401, detail=f"Invalid token: {e}") from e
                # Check the raw claim: str(None) would be the truthy string "None".
                sub = claims.get("sub")
                if not sub:
                    raise HTTPException(status_code=401, detail="Invalid token: no subject")
                return {"method": "jwt", "sub": str(sub), "claims": claims}

        raise HTTPException(status_code=401, detail="Unauthorized")

    def _path_requires_auth(path: str, method: str) -> bool:
        """Decide whether a request to *path* with *method* must be authenticated."""
        method = method.upper()
        # Always protect core run endpoints
        if path in ("/run", "/run_sse") and method == "POST":
            return True
        # Sessions and artifacts under /apps
        if path.startswith("/apps/"):
            # Allow metrics to be toggled
            if path.endswith("/metrics-info") and method == "GET":
                return cfg.protect_metrics
            return True
        # Debug and builder are privileged
        if path.startswith("/debug/") or path.startswith("/builder/"):
            return True
        # API key management endpoints
        if path.startswith("/auth/api-keys"):
            return True
        # Optionally protect list-apps
        if path == "/list-apps" and method == "GET":
            return cfg.protect_list_apps
        return False

    def _claimed_user_id(path: str) -> Optional[str]:
        """Return the ``{user_id}`` path segment a request targets, if any."""
        parts = path.strip("/").split("/")
        # ADK layout /apps/{app_name}/users/{user_id}/...; checked first so an
        # app literally named "users" is not mistaken for the users segment.
        if len(parts) >= 4 and parts[0] == "apps" and parts[2] == "users":
            return parts[3]
        if "users" in parts:
            idx = parts.index("users")
            if idx + 1 < len(parts):
                return parts[idx + 1]
        return None

    class _AuthMiddleware(BaseHTTPMiddleware):
        """Authenticate protected routes and enforce per-user path ownership."""

        async def dispatch(self, request: Request, call_next):
            path = request.url.path
            if not _path_requires_auth(path, request.method):
                return await call_next(request)
            # Authenticate
            try:
                identity = await _authenticate(request)
            except HTTPException as e:
                return JSONResponse({"detail": e.detail}, status_code=e.status_code)
            request.state.identity = identity
            # Enforce user ownership when path has /users/{user_id}/;
            # the api_key method bypasses ownership.
            claimed = _claimed_user_id(path)
            if (
                claimed is not None
                and identity.get("method") != "api_key"
                and str(identity.get("sub")) != claimed
            ):
                return JSONResponse({"detail": "Forbidden: user mismatch"}, status_code=403)
            return await call_next(request)

    app.add_middleware(_AuthMiddleware)

    # Token issuance endpoints (optional)
    if issuer_cfg and issuer_cfg.enabled:
        if issuer_cfg.algorithm == "HS256" and not issuer_cfg.hs256_secret:
            raise RuntimeError("HS256 issuer requires hs256_secret")
        router = APIRouter()

        def _mint_access_token(sub: str) -> str:
            """Sign an access token for *sub* using the issuer settings."""
            now = now_ts()
            access = {
                "iss": issuer_cfg.issuer,
                "aud": issuer_cfg.audience,
                "sub": sub,
                "iat": now,
                "nbf": now,
                "exp": now + issuer_cfg.access_ttl_seconds,
            }
            # NOTE(review): only HS256 gets a real key; other algorithms are
            # passed an empty key.
            key = issuer_cfg.hs256_secret if issuer_cfg.algorithm == "HS256" else ""
            return encode_jwt(access, algorithm=issuer_cfg.algorithm, key=key)

        @router.post("/auth/register")
        async def register(username: str, password: str):
            if not auth_store:
                raise HTTPException(status_code=400, detail="SQL store not configured")
            uid = auth_store.create_user(username, password)
            return {"user_id": uid}

        @router.post("/auth/token")
        async def token_grant(request: Request, grant_type: str = "password", username: Optional[str] = None,
                              password: Optional[str] = None, user_id: Optional[str] = None,
                              fingerprint: Optional[str] = None):
            sub: Optional[str] = None
            if grant_type == "password":
                if not auth_store or not username or password is None:
                    raise HTTPException(status_code=400, detail="invalid_request")
                uid = auth_store.authenticate_basic(username, password)
                if not uid:
                    raise HTTPException(status_code=401, detail="invalid_grant")
                sub = uid
            elif grant_type == "client_credentials":
                if not user_id:
                    raise HTTPException(status_code=400, detail="invalid_request")
                # Minting a token for an arbitrary user_id must itself be
                # authenticated, otherwise any caller could impersonate any user.
                if not _api_key_is_valid(_request_api_key(request)):
                    raise HTTPException(status_code=401, detail="invalid_client")
                sub = user_id
            else:
                raise HTTPException(status_code=400, detail="unsupported_grant_type")

            access_token = _mint_access_token(sub)

            refresh_token = None
            if auth_store:
                refresh_token = auth_store.issue_refresh(sub, issuer_cfg.refresh_ttl_seconds, fingerprint=fingerprint)
            return {"access_token": access_token, "token_type": "bearer", "refresh_token": refresh_token}

        @router.post("/auth/refresh")
        async def refresh(user_id: str, refresh_token: str, fingerprint: Optional[str] = None):
            if not auth_store:
                raise HTTPException(status_code=400, detail="invalid_request")
            if not auth_store.verify_refresh(refresh_token, user_id, fingerprint=fingerprint):
                raise HTTPException(status_code=401, detail="invalid_grant")
            return {"access_token": _mint_access_token(user_id), "token_type": "bearer"}

        app.include_router(router)

    # API key management endpoints (require SQL store)
    if auth_store:
        api_router = APIRouter()

        @api_router.post("/auth/api-keys")
        async def create_api_key(user_id: Optional[str] = None, name: Optional[str] = None):
            key_id, key_plain = auth_store.create_api_key(user_id=user_id, name=name)
            return {"id": key_id, "api_key": key_plain}

        @api_router.get("/auth/api-keys")
        async def list_api_keys():
            return auth_store.list_api_keys()

        @api_router.delete("/auth/api-keys/{key_id}")
        async def delete_api_key(key_id: str):
            auth_store.revoke_api_key(key_id)
            return {"ok": True}

        app.include_router(api_router)
|
@@ -0,0 +1,45 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass, field
|
4
|
+
from typing import List, Optional
|
5
|
+
|
6
|
+
|
7
|
+
@dataclass
class JwtValidatorConfig:
    """Settings for verifying incoming ``Authorization: Bearer`` JWTs.

    Accepts tokens from external issuers (e.g. Google/Auth0/Okta) or from our
    own issuer. A verification key source is needed: ``jwks_url`` for
    RS256/ES256 keys, or ``hs256_secret`` for a shared secret (tests/dev).
    """

    # Endpoint serving the JSON Web Key Set used to verify signatures.
    jwks_url: Optional[str] = field(default=None)
    # Expected "iss" claim, forwarded to the JWT decoder.
    issuer: Optional[str] = field(default=None)
    # Expected "aud" claim, forwarded to the JWT decoder.
    audience: Optional[str] = field(default=None)
    # Shared HS256 secret (tests/dev).
    hs256_secret: Optional[str] = field(default=None)
|
15
|
+
|
16
|
+
|
17
|
+
@dataclass
class JwtIssuerConfig:
    """Settings for issuing our own access and refresh tokens."""

    # Master switch for the built-in issuer endpoints.
    enabled: bool = field(default=False)
    # Value placed in the "iss" claim of issued tokens.
    issuer: str = field(default="https://example-issuer")
    # Value placed in the "aud" claim of issued tokens.
    audience: str = field(default="adk-api")
    # Signing algorithm name: HS256 today, RS256/ES256 later.
    algorithm: str = field(default="HS256")
    # Shared signing secret; needed when algorithm is HS256.
    hs256_secret: Optional[str] = field(default=None)
    # Access-token lifetime in seconds (one hour).
    access_ttl_seconds: int = field(default=3600)
    # Refresh-token lifetime in seconds (fourteen days).
    refresh_ttl_seconds: int = field(default=14 * 24 * 60 * 60)
    # SQL store URL (e.g. sqlite:///auth.db) for users, refresh tokens and API keys.
    database_url: Optional[str] = field(default=None)
|
29
|
+
|
30
|
+
|
31
|
+
@dataclass
class AuthConfig:
    """Top-level auth settings consumed by ``attach_auth``.

    Fields are declared in constructor order; keep it stable because callers
    may build this dataclass positionally.
    """

    # Global toggle: auth is attached only when this is True.
    enabled: bool = field(default=False)
    # When True, checks are bypassed entirely.
    allow_no_auth: bool = field(default=False)
    # Static API keys accepted on protected routes.
    api_keys: List[str] = field(default_factory=list)
    # HTTP Basic credentials: username -> stored password value
    # (see attach_auth for how values are compared).
    basic_users: dict[str, str] = field(default_factory=dict)
    # Settings for validating incoming bearer JWTs.
    jwt_validator: Optional[JwtValidatorConfig] = field(default=None)
    # Settings for the built-in token issuer and its SQL store.
    jwt_issuer: Optional[JwtIssuerConfig] = field(default=None)
    # Route policy: require auth for GET /list-apps.
    protect_list_apps: bool = field(default=True)
    # Route policy: require auth for GET .../metrics-info under /apps/.
    protect_metrics: bool = field(default=True)
    # Scopes are advisory; we currently validate presence of a token and subject. Extend as needed.
|
45
|
+
|
@@ -0,0 +1,36 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import base64
|
4
|
+
import json
|
5
|
+
import time
|
6
|
+
from typing import Any, Dict, Optional
|
7
|
+
|
8
|
+
import jwt
|
9
|
+
from jwt import PyJWKClient
|
10
|
+
|
11
|
+
|
12
|
+
def _b64url(data: bytes) -> str:
|
13
|
+
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
|
14
|
+
|
15
|
+
|
16
|
+
def encode_jwt(payload: Dict[str, Any], *, algorithm: str, key: str, headers: Optional[Dict[str, Any]] = None) -> str:
    """Sign *payload* with *key* using *algorithm* and return the compact JWT.

    Thin wrapper over PyJWT's ``jwt.encode``; *headers* are extra JOSE header
    fields merged into the token header.
    """
    signed = jwt.encode(payload, key, algorithm=algorithm, headers=headers)
    return signed
|
18
|
+
|
19
|
+
|
20
|
+
# PyJWKClient instances keyed by JWKS URL. Each client caches the key set it
# fetched, so reusing one avoids re-downloading the JWKS on every request.
_JWK_CLIENTS: Dict[str, Any] = {}


def _get_jwk_client(jwks_url: str) -> "PyJWKClient":
    """Return the shared ``PyJWKClient`` for *jwks_url*, creating it on first use."""
    client = _JWK_CLIENTS.get(jwks_url)
    if client is None:
        client = PyJWKClient(jwks_url)
        _JWK_CLIENTS[jwks_url] = client
    return client


def decode_jwt(token: str, *, issuer: Optional[str] = None, audience: Optional[str] = None,
               jwks_url: Optional[str] = None, hs256_secret: Optional[str] = None) -> Dict[str, Any]:
    """Verify *token* and return its claims.

    Signature, ``exp`` and ``nbf`` are always verified; *issuer* and *audience*
    are forwarded to ``jwt.decode``. When *jwks_url* is given it takes
    precedence and RS256/ES256 keys are fetched from it. Otherwise
    *hs256_secret* is used with HS256.

    Raises:
        ValueError: If neither *jwks_url* nor *hs256_secret* is provided.
        jwt.PyJWTError: On an invalid token or a failed key lookup.
    """
    options = {"verify_signature": True, "verify_exp": True, "verify_nbf": True}
    if jwks_url:
        signing_key = _get_jwk_client(jwks_url).get_signing_key_from_jwt(token).key
        return jwt.decode(token, signing_key, algorithms=["RS256", "ES256"], audience=audience, issuer=issuer,
                          options=options)
    if hs256_secret:
        return jwt.decode(token, hs256_secret, algorithms=["HS256"], audience=audience, issuer=issuer,
                          options=options)
    raise ValueError("No validation method configured (jwks_url or hs256_secret required)")
|
31
|
+
|
32
|
+
|
33
|
+
def now_ts() -> int:
    """Return the current Unix timestamp, truncated to whole seconds."""
    seconds = time.time()
    return int(seconds)
|
35
|
+
|
36
|
+
|
@@ -0,0 +1,183 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import hashlib
|
4
|
+
import os
|
5
|
+
import secrets
|
6
|
+
from datetime import datetime, timedelta, timezone
|
7
|
+
from typing import Optional
|
8
|
+
|
9
|
+
try:
|
10
|
+
from sqlalchemy import Column, String, DateTime, create_engine, Text, Boolean
|
11
|
+
from sqlalchemy.orm import declarative_base, sessionmaker
|
12
|
+
except ImportError as e:
|
13
|
+
raise ImportError(
|
14
|
+
"SQLAlchemy is required for the auth SQL store. Install with: pip install sqlalchemy"
|
15
|
+
) from e
|
16
|
+
|
17
|
+
|
18
|
+
# Shared declarative base: every auth model below registers its table on it,
# which is what lets AuthStore create them all via Base.metadata.create_all().
Base = declarative_base()
|
19
|
+
|
20
|
+
|
21
|
+
def _pbkdf2(password: str, salt: str) -> str:
|
22
|
+
dk = hashlib.pbkdf2_hmac("sha256", password.encode(), salt.encode(), 200_000)
|
23
|
+
return dk.hex()
|
24
|
+
|
25
|
+
|
26
|
+
def hash_password(password: str) -> str:
    """Hash *password* with a fresh random 16-byte salt.

    Returns ``"pbkdf2_sha256$<salt hex>$<digest hex>"``, the format that
    ``verify_password`` understands.
    """
    salt = secrets.token_hex(16)
    digest = _pbkdf2(password, salt)
    return "$".join(("pbkdf2_sha256", salt, digest))
|
29
|
+
|
30
|
+
|
31
|
+
def verify_password(password: str, stored: str) -> bool:
    """Check *password* against a value produced by ``hash_password``.

    Values that are not in ``pbkdf2_sha256$<salt>$<digest>`` form are treated
    as plaintext (a fallback intended for tests) and compared directly. All
    comparisons are constant-time and operate on UTF-8 bytes, so non-ASCII
    passwords are supported instead of raising ``TypeError``.

    Returns:
        True when the password matches, False otherwise.
    """

    def _same(a: str, b: str) -> bool:
        # secrets.compare_digest only accepts ASCII-only str, so compare bytes.
        return secrets.compare_digest(a.encode("utf-8"), b.encode("utf-8"))

    parts = stored.split("$", 2)
    if len(parts) != 3 or parts[0] != "pbkdf2_sha256":
        # fallback for plaintext in tests
        return _same(password, stored)
    _algo, salt, digest = parts
    return _same(_pbkdf2(password, salt), digest)
|
40
|
+
|
41
|
+
|
42
|
+
class User(Base):
    """Row in ``auth_users``: a locally registered account."""

    __tablename__ = "auth_users"
    # Primary key; AuthStore.create_user fills it with a uuid4 string unless one is given.
    id = Column(String, primary_key=True)
    # Login name; unique, and indexed for the lookup in authenticate_basic.
    username = Column(String, unique=True, index=True, nullable=False)
    # Output of hash_password(): "pbkdf2_sha256$<salt>$<hex digest>".
    password_hash = Column(String, nullable=False)
    # comma-separated role names; NOTE(review): not read by AuthStore in this module.
    roles = Column(String, default="")  # comma-separated
    # Set to the current UTC time when the row is inserted.
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    # Disabled users are rejected by AuthStore.authenticate_basic.
    disabled = Column(Boolean, default=False)
|
50
|
+
|
51
|
+
|
52
|
+
class RefreshToken(Base):
    """Row in ``auth_refresh_tokens``: an opaque refresh token issued to a subject."""

    __tablename__ = "auth_refresh_tokens"
    # The token value itself (a uuid4 string) returned to the client as refresh_token.
    jti = Column(String, primary_key=True)
    # Subject the token was issued to; must match on refresh.
    user_id = Column(String, index=True, nullable=False)
    # Absolute expiry, compared against the current UTC time when refreshing.
    expires_at = Column(DateTime(timezone=True), nullable=False)
    # Set on revocation; revoked tokens are rejected.
    revoked_at = Column(DateTime(timezone=True), nullable=True)
    # Optional client fingerprint recorded at issuance and compared on refresh.
    fingerprint = Column(String, nullable=True)
|
59
|
+
|
60
|
+
|
61
|
+
class ApiKey(Base):
    """Row in ``auth_api_keys``: a hashed API key; the plaintext key is never stored."""

    __tablename__ = "auth_api_keys"
    # Random uuid4 identifier returned to the creator and used for revocation.
    id = Column(String, primary_key=True)
    # Optional owner; stored, but not consulted by AuthStore.verify_api_key.
    user_id = Column(String, index=True, nullable=True)
    # "api_pbkdf2_sha256$<salt>$<hex digest>" produced by AuthStore._hash_api_key.
    key_hash = Column(String, nullable=False)
    # Optional human-readable label.
    name = Column(String, nullable=True)
    # Set to the current UTC time when the row is inserted.
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    # Set on revocation; revoked keys are excluded from verification.
    revoked_at = Column(DateTime(timezone=True), nullable=True)
|
69
|
+
|
70
|
+
|
71
|
+
class AuthStore:
    """SQLAlchemy-backed storage for users, refresh tokens and API keys.

    The constructor creates any missing auth tables. Every public method opens
    its own short-lived session.
    """

    def __init__(self, database_url: str):
        """Create the engine for *database_url* and ensure the auth tables exist."""
        self.engine = create_engine(database_url)
        Base.metadata.create_all(self.engine)
        self.Session = sessionmaker(bind=self.engine, autoflush=False, autocommit=False)

    def create_user(self, username: str, password: str, user_id: Optional[str] = None) -> str:
        """Insert a user with a PBKDF2-hashed password and return its id.

        A uuid4 string is used when *user_id* is not given. A duplicate
        *username* violates the unique constraint, and the database error
        propagates from ``commit()``.
        """
        import uuid

        uid = user_id or str(uuid.uuid4())
        with self.Session() as s:
            u = User(id=uid, username=username, password_hash=hash_password(password))
            s.add(u)
            s.commit()
        return uid

    def authenticate_basic(self, username: str, password: str) -> Optional[str]:
        """Return the user id when *username*/*password* match an enabled user, else None."""
        with self.Session() as s:
            u: Optional[User] = s.query(User).filter_by(username=username).first()
            if not u or u.disabled:
                return None
            if verify_password(password, u.password_hash):
                return u.id
            return None

    def issue_refresh(self, user_id: str, ttl_seconds: int, fingerprint: Optional[str] = None) -> str:
        """Persist a new refresh token for *user_id* valid for *ttl_seconds* and return it."""
        import uuid

        jti = str(uuid.uuid4())
        with self.Session() as s:
            rt = RefreshToken(
                jti=jti,
                user_id=user_id,
                expires_at=datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds),
                fingerprint=fingerprint,
            )
            s.add(rt)
            s.commit()
        return jti

    def verify_refresh(self, jti: str, user_id: str, fingerprint: Optional[str] = None) -> bool:
        """Return True if *jti* is a live, unrevoked refresh token issued to *user_id*.

        A token recorded with a fingerprint is only accepted when the same
        *fingerprint* is presented; omitting it does not bypass the binding.
        """
        with self.Session() as s:
            rt: Optional[RefreshToken] = s.query(RefreshToken).filter_by(jti=jti, user_id=user_id).first()
            if not rt or rt.revoked_at is not None:
                return False
            expires_at = rt.expires_at
            # Some backends (e.g. SQLite) hand back naive datetimes even for
            # DateTime(timezone=True); issue_refresh always writes UTC, so
            # attach UTC before comparing with an aware "now".
            if expires_at.tzinfo is None:
                expires_at = expires_at.replace(tzinfo=timezone.utc)
            if expires_at <= datetime.now(timezone.utc):
                return False
            if rt.fingerprint and rt.fingerprint != fingerprint:
                return False
            return True

    def revoke_refresh(self, jti: str) -> None:
        """Mark refresh token *jti* as revoked; unknown ids are ignored."""
        with self.Session() as s:
            rt: Optional[RefreshToken] = s.query(RefreshToken).filter_by(jti=jti).first()
            if not rt:
                return
            rt.revoked_at = datetime.now(timezone.utc)
            s.add(rt)
            s.commit()

    # API Keys
    def _hash_api_key(self, key: str) -> str:
        """Return ``api_pbkdf2_sha256$<salt>$<digest>`` for *key* with a fresh salt."""
        # Reuse PBKDF2; different prefix
        salt = secrets.token_hex(16)
        return f"api_pbkdf2_sha256${salt}${_pbkdf2(key, salt)}"

    def _verify_api_key(self, key: str, stored: str) -> bool:
        """Check *key* against a stored hash; any parsing/comparison error means no match."""
        try:
            # Same "<algo>$<salt>$<digest>" layout as verify_password.
            algo, salt, digest = stored.split("$", 2)
            if algo != "api_pbkdf2_sha256":
                return secrets.compare_digest(key, stored)
            return secrets.compare_digest(_pbkdf2(key, salt), digest)
        except Exception:
            return False

    def create_api_key(self, user_id: Optional[str] = None, name: Optional[str] = None) -> tuple[str, str]:
        """Create a random API key and return ``(key_id, plaintext_key)``.

        Only the hash is stored; the plaintext is returned exactly once.
        """
        import uuid

        key_plain = secrets.token_urlsafe(32)
        key_id = str(uuid.uuid4())
        with self.Session() as s:
            rec = ApiKey(id=key_id, user_id=user_id, key_hash=self._hash_api_key(key_plain), name=name)
            s.add(rec)
            s.commit()
        return key_id, key_plain

    def list_api_keys(self):
        """Return metadata (never the key material) for every stored API key."""
        with self.Session() as s:
            rows = s.query(ApiKey).all()
            return [
                {
                    "id": r.id,
                    "user_id": r.user_id,
                    "name": r.name,
                    "created_at": r.created_at.isoformat() if r.created_at else None,
                    "revoked": r.revoked_at is not None,
                }
                for r in rows
            ]

    def revoke_api_key(self, key_id: str) -> None:
        """Mark API key *key_id* as revoked; unknown ids are ignored."""
        with self.Session() as s:
            rec = s.query(ApiKey).filter_by(id=key_id).first()
            if not rec:
                return
            rec.revoked_at = datetime.now(timezone.utc)
            s.add(rec)
            s.commit()

    def verify_api_key(self, key: str) -> bool:
        """Return True if *key* matches any unrevoked stored API key."""
        # NOTE(review): this runs PBKDF2 (200,000 rounds) once per active key on
        # every call, so cost grows linearly with the number of keys.
        with self.Session() as s:
            rows = s.query(ApiKey).filter(ApiKey.revoked_at.is_(None)).all()
            return any(self._verify_api_key(key, r.key_hash) for r in rows)
|
@@ -7,7 +7,7 @@ thread-safe registry management.
|
|
7
7
|
|
8
8
|
import logging
|
9
9
|
import threading
|
10
|
-
from typing import Dict, List
|
10
|
+
from typing import Dict, List
|
11
11
|
|
12
12
|
from google.adk.agents.base_agent import BaseAgent
|
13
13
|
from google.adk.cli.utils.base_agent_loader import BaseAgentLoader
|
@@ -5,13 +5,11 @@ AdkWebServer to use our EnhancedRunner with advanced features.
|
|
5
5
|
"""
|
6
6
|
|
7
7
|
import os
|
8
|
-
from typing import Optional
|
9
8
|
|
10
9
|
from google.adk.cli.adk_web_server import AdkWebServer
|
11
10
|
from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService
|
12
11
|
from google.adk.cli.utils import cleanup
|
13
12
|
from google.adk.cli.utils import envs
|
14
|
-
from google.adk.runners import Runner
|
15
13
|
|
16
14
|
from .enhanced_runner import EnhancedRunner
|
17
15
|
|
@@ -22,7 +22,6 @@ from watchdog.observers import Observer
|
|
22
22
|
|
23
23
|
from google.adk.artifacts.gcs_artifact_service import GcsArtifactService
|
24
24
|
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
|
25
|
-
from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService
|
26
25
|
from google.adk.auth.credential_service.base_credential_service import BaseCredentialService
|
27
26
|
from google.adk.evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
|
28
27
|
from google.adk.evaluation.local_eval_sets_manager import LocalEvalSetsManager
|
@@ -35,6 +34,7 @@ from google.adk.sessions.database_session_service import DatabaseSessionService
|
|
35
34
|
from google.adk.utils.feature_decorator import working_in_progress
|
36
35
|
from google.adk.cli.adk_web_server import AdkWebServer
|
37
36
|
from .enhanced_adk_web_server import EnhancedAdkWebServer
|
37
|
+
from .auth import attach_auth, AuthConfig, JwtIssuerConfig, JwtValidatorConfig
|
38
38
|
from .streaming import StreamingController, StreamingConfig
|
39
39
|
from google.adk.cli.utils import envs
|
40
40
|
from google.adk.cli.utils import evals
|
@@ -69,6 +69,8 @@ def get_enhanced_fast_api_app(
|
|
69
69
|
# Streaming layer (optional)
|
70
70
|
enable_streaming: bool = False,
|
71
71
|
streaming_config: Optional[StreamingConfig] = None,
|
72
|
+
# Auth layer (optional)
|
73
|
+
auth_config: Optional[AuthConfig] = None,
|
72
74
|
) -> FastAPI:
|
73
75
|
"""Enhanced version of Google ADK's get_fast_api_app with EnhancedRunner integration.
|
74
76
|
|
@@ -601,4 +603,7 @@ def get_enhanced_fast_api_app(
|
|
601
603
|
|
602
604
|
app.include_router(router)
|
603
605
|
|
606
|
+
# Attach optional auth layer last so all routes are covered
|
607
|
+
attach_auth(app, auth_config)
|
608
|
+
|
604
609
|
return app
|
@@ -1,9 +1,9 @@
|
|
1
1
|
import asyncio
|
2
2
|
import time
|
3
3
|
from dataclasses import dataclass, field
|
4
|
-
from typing import Any, Dict, Optional,
|
4
|
+
from typing import Any, Dict, Optional, Callable, Awaitable
|
5
5
|
|
6
|
-
from fastapi import
|
6
|
+
from fastapi import HTTPException
|
7
7
|
from pydantic import BaseModel
|
8
8
|
|
9
9
|
from google.adk.events.event import Event
|