google-adk-extras 0.2.6__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. google_adk_extras/__init__.py +3 -3
  2. google_adk_extras/adk_builder.py +15 -292
  3. google_adk_extras/artifacts/local_folder_artifact_service.py +0 -2
  4. google_adk_extras/artifacts/mongo_artifact_service.py +0 -1
  5. google_adk_extras/artifacts/s3_artifact_service.py +0 -1
  6. google_adk_extras/artifacts/sql_artifact_service.py +0 -1
  7. google_adk_extras/auth/__init__.py +10 -0
  8. google_adk_extras/auth/attach.py +227 -0
  9. google_adk_extras/auth/config.py +45 -0
  10. google_adk_extras/auth/jwt_utils.py +36 -0
  11. google_adk_extras/auth/sql_store.py +183 -0
  12. google_adk_extras/custom_agent_loader.py +1 -1
  13. google_adk_extras/enhanced_adk_web_server.py +0 -2
  14. google_adk_extras/enhanced_fastapi.py +6 -1
  15. google_adk_extras/memory/mongo_memory_service.py +0 -1
  16. google_adk_extras/memory/sql_memory_service.py +1 -1
  17. google_adk_extras/memory/yaml_file_memory_service.py +1 -3
  18. google_adk_extras/sessions/mongo_session_service.py +0 -1
  19. google_adk_extras/sessions/redis_session_service.py +1 -1
  20. google_adk_extras/sessions/yaml_file_session_service.py +0 -2
  21. google_adk_extras/streaming/streaming_controller.py +2 -2
  22. {google_adk_extras-0.2.6.dist-info → google_adk_extras-0.3.0.dist-info}/METADATA +84 -34
  23. google_adk_extras-0.3.0.dist-info/RECORD +37 -0
  24. google_adk_extras/credentials/__init__.py +0 -34
  25. google_adk_extras/credentials/github_oauth2_credential_service.py +0 -213
  26. google_adk_extras/credentials/google_oauth2_credential_service.py +0 -216
  27. google_adk_extras/credentials/http_basic_auth_credential_service.py +0 -388
  28. google_adk_extras/credentials/jwt_credential_service.py +0 -345
  29. google_adk_extras/credentials/microsoft_oauth2_credential_service.py +0 -250
  30. google_adk_extras/credentials/x_oauth2_credential_service.py +0 -240
  31. google_adk_extras-0.2.6.dist-info/RECORD +0 -39
  32. {google_adk_extras-0.2.6.dist-info → google_adk_extras-0.3.0.dist-info}/WHEEL +0 -0
  33. {google_adk_extras-0.2.6.dist-info → google_adk_extras-0.3.0.dist-info}/licenses/LICENSE +0 -0
  34. {google_adk_extras-0.2.6.dist-info → google_adk_extras-0.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,227 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional
4
+
5
+ from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
6
+ import base64
7
+ from fastapi.security import APIKeyHeader
8
+ from starlette.middleware.base import BaseHTTPMiddleware
9
+
10
+ from .config import AuthConfig, JwtIssuerConfig, JwtValidatorConfig
11
+ from .jwt_utils import decode_jwt, encode_jwt, now_ts
12
+ from .sql_store import AuthStore
13
+
14
+
15
def attach_auth(app: FastAPI, cfg: Optional[AuthConfig]) -> None:
    """Attach optional auth to the provided FastAPI app.

    - Adds middleware that enforces auth on sensitive routes.
    - Optionally registers token issuance endpoints if configured.
    - Registers API key management endpoints when a SQL store is configured.

    Args:
        app: Application that receives the middleware and routers.
        cfg: Auth configuration. Nothing is attached when it is ``None``,
            not enabled, or has ``allow_no_auth`` set.

    Raises:
        RuntimeError: If the issuer is enabled with HS256 but no secret.
    """
    # Function-scope imports keep this block self-contained.
    import hmac

    from fastapi.responses import JSONResponse

    if not cfg or not cfg.enabled or cfg.allow_no_auth:
        return

    validator = cfg.jwt_validator
    issuer_cfg = cfg.jwt_issuer
    api_keys = set(cfg.api_keys or [])
    basic_users = cfg.basic_users or {}
    auth_store: Optional[AuthStore] = None
    if issuer_cfg and issuer_cfg.database_url:
        auth_store = AuthStore(issuer_cfg.database_url)

    # Security helpers
    api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)

    async def _authenticate(request: Request) -> dict:
        """Return an identity dict for the request or raise a 401 HTTPException.

        Credentials are tried in this order: API key, HTTP Basic, Bearer JWT.
        """
        # API Key (query parameter first, then header)
        api_key = (
            request.query_params.get("api_key")
            or request.headers.get("x-api-key")
            or request.headers.get("X-API-Key")
        )
        if not api_key:
            api_key = await api_key_header(request)
        if api_key and api_key in api_keys:
            return {"method": "api_key", "sub": "api_key_client"}
        if api_key and auth_store and auth_store.verify_api_key(api_key):
            return {"method": "api_key", "sub": "api_key_client"}

        authz = request.headers.get("authorization") or request.headers.get("Authorization")
        scheme = authz.lower() if authz else ""

        # Basic
        if scheme.startswith("basic "):
            try:
                b64 = authz.split(" ", 1)[1]
                raw = base64.b64decode(b64).decode("utf-8")
                username, _, password = raw.partition(":")
            except ValueError:
                # Covers malformed base64 (binascii.Error) and invalid UTF-8
                # (UnicodeDecodeError); both are ValueError subclasses.
                username, password = "", ""
            # If SQL store present, try it first; else fall back to configured map
            if auth_store:
                uid = auth_store.authenticate_basic(username, password)
                if uid:
                    return {"method": "basic", "sub": uid, "username": username}
            stored = basic_users.get(username)
            # Constant-time comparison so response timing does not leak how
            # much of the configured password matched.
            if stored and hmac.compare_digest(stored.encode("utf-8"), password.encode("utf-8")):
                return {"method": "basic", "sub": username, "username": username}

        # Bearer JWT
        if scheme.startswith("bearer "):
            token = authz.split(" ", 1)[1]
            if validator and (validator.jwks_url or validator.hs256_secret):
                # Only the decode call is guarded, so the explicit 401 below is
                # not swallowed and re-wrapped by the generic handler.
                try:
                    claims = decode_jwt(
                        token,
                        issuer=validator.issuer,
                        audience=validator.audience,
                        jwks_url=validator.jwks_url,
                        hs256_secret=validator.hs256_secret,
                    )
                except Exception as e:
                    raise HTTPException(status_code=401, detail=f"Invalid token: {e}") from e
                subject = claims.get("sub")
                # Check before str(): str(None) is the truthy string "None",
                # which would let a token without a subject through.
                if subject is None or str(subject) == "":
                    raise HTTPException(status_code=401, detail="Invalid token: no subject")
                return {"method": "jwt", "sub": str(subject), "claims": claims}

        raise HTTPException(status_code=401, detail="Unauthorized")

    def _path_requires_auth(path: str, method: str) -> bool:
        """Return True when the given path/method combination must be authenticated."""
        method = method.upper()
        # Always protect core run endpoints
        if path in ("/run", "/run_sse") and method == "POST":
            return True
        # Sessions and artifacts under /apps
        if path.startswith("/apps/"):
            # Allow metrics to be toggled
            if path.endswith("/metrics-info") and method == "GET":
                return cfg.protect_metrics
            return True
        # Debug and builder are privileged
        if path.startswith(("/debug/", "/builder/")):
            return True
        # API key management endpoints
        if path.startswith("/auth/api-keys"):
            return True
        # Optionally protect list-apps
        if path == "/list-apps" and method == "GET":
            return cfg.protect_list_apps
        return False

    class _AuthMiddleware(BaseHTTPMiddleware):
        """Authenticate protected routes and enforce /users/{user_id}/ ownership."""

        async def dispatch(self, request: Request, call_next):
            path = request.url.path
            if not _path_requires_auth(path, request.method):
                return await call_next(request)
            # Authenticate; the HTTPException is converted to a JSON response
            # here rather than propagated out of the middleware.
            try:
                identity = await _authenticate(request)
            except HTTPException as e:
                return JSONResponse({"detail": e.detail}, status_code=e.status_code)
            request.state.identity = identity
            # Optional: Enforce user ownership when path has /users/{user_id}/
            parts = path.strip("/").split("/")
            if "users" in parts:
                idx = parts.index("users")
                # A trailing "users" segment carries no id to compare against.
                if idx + 1 < len(parts):
                    claimed = parts[idx + 1]
                    sub = str(identity.get("sub"))
                    # Allow api_key method to bypass ownership
                    if identity.get("method") != "api_key" and sub != claimed:
                        return JSONResponse({"detail": "Forbidden: user mismatch"}, status_code=403)
            return await call_next(request)

    app.add_middleware(_AuthMiddleware)

    # Token issuance endpoints (optional)
    if issuer_cfg and issuer_cfg.enabled:
        if issuer_cfg.algorithm == "HS256" and not issuer_cfg.hs256_secret:
            raise RuntimeError("HS256 issuer requires hs256_secret")
        router = APIRouter()

        def _issue_access_token(subject: str) -> str:
            """Build and sign an access token for ``subject`` from the issuer config."""
            now = now_ts()
            access = {
                "iss": issuer_cfg.issuer,
                "aud": issuer_cfg.audience,
                "sub": subject,
                "iat": now,
                "nbf": now,
                "exp": now + issuer_cfg.access_ttl_seconds,
            }
            # NOTE(review): non-HS256 algorithms are signed with an empty key,
            # which presumably fails at request time - confirm intended handling.
            key = issuer_cfg.hs256_secret if issuer_cfg.algorithm == "HS256" else ""
            return encode_jwt(access, algorithm=issuer_cfg.algorithm, key=key)

        @router.post("/auth/register")
        async def register(username: str, password: str):
            if not auth_store:
                raise HTTPException(status_code=400, detail="SQL store not configured")
            uid = auth_store.create_user(username, password)
            return {"user_id": uid}

        @router.post("/auth/token")
        async def token_grant(grant_type: str = "password", username: Optional[str] = None, password: Optional[str] = None,
                              user_id: Optional[str] = None, fingerprint: Optional[str] = None):
            sub: Optional[str] = None
            if grant_type == "password":
                if not auth_store or not username or password is None:
                    raise HTTPException(status_code=400, detail="invalid_request")
                uid = auth_store.authenticate_basic(username, password)
                if not uid:
                    raise HTTPException(status_code=401, detail="invalid_grant")
                sub = uid
            elif grant_type == "client_credentials":
                # For simplicity map to provided user_id
                if not user_id:
                    raise HTTPException(status_code=400, detail="invalid_request")
                sub = user_id
            else:
                raise HTTPException(status_code=400, detail="unsupported_grant_type")

            access_token = _issue_access_token(sub)

            # A refresh token is only issued when a SQL store can persist it.
            refresh_token = None
            if auth_store:
                refresh_token = auth_store.issue_refresh(sub, issuer_cfg.refresh_ttl_seconds, fingerprint=fingerprint)
            return {"access_token": access_token, "token_type": "bearer", "refresh_token": refresh_token}

        @router.post("/auth/refresh")
        async def refresh(user_id: str, refresh_token: str, fingerprint: Optional[str] = None):
            if not auth_store:
                raise HTTPException(status_code=400, detail="invalid_request")
            if not auth_store.verify_refresh(refresh_token, user_id, fingerprint=fingerprint):
                raise HTTPException(status_code=401, detail="invalid_grant")
            return {"access_token": _issue_access_token(user_id), "token_type": "bearer"}

        app.include_router(router)

    # API key management endpoints (require SQL store)
    if auth_store:
        api_router = APIRouter()

        @api_router.post("/auth/api-keys")
        async def create_api_key(user_id: Optional[str] = None, name: Optional[str] = None):
            key_id, key_plain = auth_store.create_api_key(user_id=user_id, name=name)
            return {"id": key_id, "api_key": key_plain}

        @api_router.get("/auth/api-keys")
        async def list_api_keys():
            return auth_store.list_api_keys()

        @api_router.delete("/auth/api-keys/{key_id}")
        async def delete_api_key(key_id: str):
            auth_store.revoke_api_key(key_id)
            return {"ok": True}

        app.include_router(api_router)
@@ -0,0 +1,45 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import List, Optional
5
+
6
+
7
@dataclass
class JwtValidatorConfig:
    """Settings used to validate incoming bearer JWTs.

    Tokens may come from external issuers (e.g., Google/Auth0/Okta) or from
    our own issuer.
    """

    # JWKS endpoint used to look up the signing key
    jwks_url: Optional[str] = None
    # Expected issuer, passed to the JWT decoder
    issuer: Optional[str] = None
    # Expected audience, passed to the JWT decoder
    audience: Optional[str] = None
    # HS256 shared secret, for validating without JWKS (tests/dev)
    hs256_secret: Optional[str] = None
15
+
16
+
17
@dataclass
class JwtIssuerConfig:
    """Settings for our own token issuer, used when we issue tokens."""

    enabled: bool = False
    issuer: str = "https://example-issuer"
    audience: str = "adk-api"
    # HS256 or RS256/ES256 later
    algorithm: str = "HS256"
    hs256_secret: Optional[str] = None
    # Access token lifetime: one hour
    access_ttl_seconds: int = 3600
    # Refresh token lifetime: fourteen days
    refresh_ttl_seconds: int = 14 * 24 * 60 * 60
    # SQL store for users/refresh tokens, e.g. sqlite:///auth.db
    database_url: Optional[str] = None
29
+
30
+
31
@dataclass
class AuthConfig:
    """Top-level auth configuration.

    Scopes are advisory; we currently validate presence of a token and
    subject. Extend as needed.
    """

    # Global toggle
    enabled: bool = False

    # Modes
    # If True, bypass checks entirely
    allow_no_auth: bool = False
    # Accepted API keys
    api_keys: List[str] = field(default_factory=list)
    # username -> password (PBKDF2 hash or plaintext for tests)
    basic_users: dict[str, str] = field(default_factory=dict)
    jwt_validator: Optional[JwtValidatorConfig] = None
    jwt_issuer: Optional[JwtIssuerConfig] = None

    # Route policy toggles
    protect_list_apps: bool = True
    protect_metrics: bool = True
45
+
@@ -0,0 +1,36 @@
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import json
5
+ import time
6
+ from typing import Any, Dict, Optional
7
+
8
+ import jwt
9
+ from jwt import PyJWKClient
10
+
11
+
12
+ def _b64url(data: bytes) -> str:
13
+ return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
14
+
15
+
16
def encode_jwt(payload: Dict[str, Any], *, algorithm: str, key: str, headers: Optional[Dict[str, Any]] = None) -> str:
    """Sign *payload* with *key* using *algorithm* and return the encoded token.

    Optional *headers* are forwarded unchanged to ``jwt.encode``.
    """
    token = jwt.encode(payload, key, algorithm=algorithm, headers=headers)
    return token
18
+
19
+
20
# One PyJWKClient per JWKS URL, reused across calls so the client's own key
# cache is effective instead of being discarded after every token.
_JWK_CLIENTS: Dict[str, PyJWKClient] = {}


def _jwk_client_for(jwks_url: str) -> PyJWKClient:
    """Return the shared PyJWKClient for *jwks_url*, creating it on first use."""
    client = _JWK_CLIENTS.get(jwks_url)
    if client is None:
        # setdefault keeps a single winner if two callers race on first use.
        client = _JWK_CLIENTS.setdefault(jwks_url, PyJWKClient(jwks_url))
    return client


def decode_jwt(token: str, *, issuer: Optional[str] = None, audience: Optional[str] = None,
               jwks_url: Optional[str] = None, hs256_secret: Optional[str] = None) -> Dict[str, Any]:
    """Validate *token* and return its claims.

    Signature, ``exp`` and ``nbf`` are verified. When *jwks_url* is given the
    signing key is looked up there and only RS256/ES256 are accepted;
    otherwise *hs256_secret* is used with HS256 only. *jwks_url* takes
    precedence when both are supplied.

    Args:
        token: Encoded JWT.
        issuer: Expected issuer, forwarded to ``jwt.decode``.
        audience: Expected audience, forwarded to ``jwt.decode``.
        jwks_url: JWKS endpoint used to resolve the signing key.
        hs256_secret: Shared secret for HS256 validation.

    Returns:
        The decoded claims.

    Raises:
        ValueError: If neither *jwks_url* nor *hs256_secret* is provided.
    """
    options = {"verify_signature": True, "verify_exp": True, "verify_nbf": True}
    if jwks_url:
        signing_key = _jwk_client_for(jwks_url).get_signing_key_from_jwt(token).key
        return jwt.decode(token, signing_key, algorithms=["RS256", "ES256"], audience=audience, issuer=issuer, options=options)
    if hs256_secret:
        return jwt.decode(token, hs256_secret, algorithms=["HS256"], audience=audience, issuer=issuer, options=options)
    raise ValueError("No validation method configured (jwks_url or hs256_secret required)")
31
+
32
+
33
def now_ts() -> int:
    """Return the current Unix time in whole seconds."""
    seconds = time.time()
    return int(seconds)
35
+
36
+
@@ -0,0 +1,183 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import os
5
+ import secrets
6
+ from datetime import datetime, timedelta, timezone
7
+ from typing import Optional
8
+
9
+ try:
10
+ from sqlalchemy import Column, String, DateTime, create_engine, Text, Boolean
11
+ from sqlalchemy.orm import declarative_base, sessionmaker
12
+ except ImportError as e:
13
+ raise ImportError(
14
+ "SQLAlchemy is required for the auth SQL store. Install with: pip install sqlalchemy"
15
+ ) from e
16
+
17
+
18
+ Base = declarative_base()
19
+
20
+
21
+ def _pbkdf2(password: str, salt: str) -> str:
22
+ dk = hashlib.pbkdf2_hmac("sha256", password.encode(), salt.encode(), 200_000)
23
+ return dk.hex()
24
+
25
+
26
def hash_password(password: str) -> str:
    """Hash *password* with a fresh random salt.

    The result has the form ``pbkdf2_sha256$<salt>$<digest>``.
    """
    salt = secrets.token_hex(16)
    digest = _pbkdf2(password, salt)
    return "$".join(("pbkdf2_sha256", salt, digest))
29
+
30
+
31
def verify_password(password: str, stored: str) -> bool:
    """Check *password* against a stored credential.

    *stored* is normally ``pbkdf2_sha256$<salt>$<digest>``. Any other value is
    treated as a plaintext password (fallback used in tests) and compared
    directly.

    Args:
        password: Candidate password supplied by the caller.
        stored: Stored hash string, or plaintext.

    Returns:
        True when the password matches, False otherwise.
    """

    def _same(left: str, right: str) -> bool:
        # compare_digest rejects non-ASCII str with TypeError, so compare the
        # UTF-8 bytes; this keeps the comparison constant-time for any text.
        return secrets.compare_digest(left.encode("utf-8"), right.encode("utf-8"))

    algo, first_sep, remainder = stored.partition("$")
    salt, second_sep, digest = remainder.partition("$")
    if first_sep and second_sep and algo == "pbkdf2_sha256":
        return _same(_pbkdf2(password, salt), digest)
    # fallback for plaintext in tests (also covers malformed stored values)
    return _same(password, stored)
40
+
41
+
42
class User(Base):
    """Row of the ``auth_users`` table: one login account."""

    __tablename__ = "auth_users"

    id = Column(String, primary_key=True)
    username = Column(String, nullable=False, unique=True, index=True)
    password_hash = Column(String, nullable=False)
    # comma-separated
    roles = Column(String, default="")
    created_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
    )
    disabled = Column(Boolean, default=False)
50
+
51
+
52
class RefreshToken(Base):
    """Row of the ``auth_refresh_tokens`` table, keyed by the token id (``jti``)."""

    __tablename__ = "auth_refresh_tokens"

    jti = Column(String, primary_key=True)
    user_id = Column(String, nullable=False, index=True)
    expires_at = Column(DateTime(timezone=True), nullable=False)
    # Stays NULL until the token is revoked
    revoked_at = Column(DateTime(timezone=True), nullable=True)
    fingerprint = Column(String, nullable=True)
59
+
60
+
61
class ApiKey(Base):
    """Row of the ``auth_api_keys`` table; holds a hash of the key, not the key itself."""

    __tablename__ = "auth_api_keys"

    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=True, index=True)
    key_hash = Column(String, nullable=False)
    name = Column(String, nullable=True)
    created_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
    )
    # Stays NULL until the key is revoked
    revoked_at = Column(DateTime(timezone=True), nullable=True)
69
+
70
+
71
class AuthStore:
    """SQL-backed store for users, refresh tokens and API keys.

    Creating an instance creates the engine for ``database_url`` and makes
    sure the auth tables exist. Each method opens its own short-lived session.
    """

    def __init__(self, database_url: str):
        self.engine = create_engine(database_url)
        Base.metadata.create_all(self.engine)
        self.Session = sessionmaker(bind=self.engine, autoflush=False, autocommit=False)

    @staticmethod
    def _as_utc(value: datetime) -> datetime:
        """Return *value* as an aware datetime, treating naive values as UTC.

        Some backends (e.g. SQLite) hand back naive datetimes even for
        ``DateTime(timezone=True)`` columns; comparing those with an aware
        "now" raises TypeError. All timestamps here are written in UTC.
        """
        if value.tzinfo is None:
            return value.replace(tzinfo=timezone.utc)
        return value

    def create_user(self, username: str, password: str, user_id: Optional[str] = None) -> str:
        """Insert a user with a hashed password and return its id.

        A random UUID is used when *user_id* is not supplied.
        """
        import uuid
        uid = user_id or str(uuid.uuid4())
        with self.Session() as s:
            s.add(User(id=uid, username=username, password_hash=hash_password(password)))
            s.commit()
        return uid

    def authenticate_basic(self, username: str, password: str) -> Optional[str]:
        """Return the user id for valid credentials, else None.

        Unknown and disabled users are rejected.
        """
        with self.Session() as s:
            u: Optional[User] = s.query(User).filter_by(username=username).first()
            if not u or u.disabled:
                return None
            if verify_password(password, u.password_hash):
                return u.id
            return None

    def issue_refresh(self, user_id: str, ttl_seconds: int, fingerprint: Optional[str] = None) -> str:
        """Persist a new refresh token for *user_id* and return its id (jti)."""
        import uuid
        jti = str(uuid.uuid4())
        with self.Session() as s:
            rt = RefreshToken(
                jti=jti,
                user_id=user_id,
                expires_at=datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds),
                fingerprint=fingerprint,
            )
            s.add(rt)
            s.commit()
        return jti

    def verify_refresh(self, jti: str, user_id: str, fingerprint: Optional[str] = None) -> bool:
        """Return True when the refresh token exists, is unrevoked and unexpired.

        The fingerprint is only compared when both the stored token and the
        caller provide one.
        """
        with self.Session() as s:
            rt: Optional[RefreshToken] = s.query(RefreshToken).filter_by(jti=jti, user_id=user_id).first()
            if not rt or rt.revoked_at is not None:
                return False
            # Normalise first: a naive value from the backend cannot be
            # compared with the aware "now" below.
            if self._as_utc(rt.expires_at) <= datetime.now(timezone.utc):
                return False
            # NOTE(review): a token issued with a fingerprint is still accepted
            # when the caller omits it - confirm whether binding should be strict.
            if fingerprint and rt.fingerprint and rt.fingerprint != fingerprint:
                return False
            return True

    def revoke_refresh(self, jti: str) -> None:
        """Mark the refresh token as revoked; unknown ids are ignored."""
        with self.Session() as s:
            rt: Optional[RefreshToken] = s.query(RefreshToken).filter_by(jti=jti).first()
            if not rt:
                return
            rt.revoked_at = datetime.now(timezone.utc)
            s.add(rt)
            s.commit()

    # API Keys
    def _hash_api_key(self, key: str) -> str:
        """Return ``api_pbkdf2_sha256$<salt>$<digest>`` for *key* with a fresh salt."""
        # Reuse PBKDF2; different prefix
        salt = secrets.token_hex(16)
        return f"api_pbkdf2_sha256${salt}${_pbkdf2(key, salt)}"

    def _verify_api_key(self, key: str, stored: str) -> bool:
        """Return True when *key* matches the stored hash string.

        Any malformed stored value or comparison error counts as a mismatch.
        """
        try:
            algo, salt, digest = stored.split("$", 3)
            if algo != "api_pbkdf2_sha256":
                return secrets.compare_digest(key, stored)
            return secrets.compare_digest(_pbkdf2(key, salt), digest)
        except Exception:
            # Deliberately fail closed: a bad record must never authenticate.
            return False

    def create_api_key(self, user_id: Optional[str] = None, name: Optional[str] = None) -> tuple[str, str]:
        """Create an API key and return ``(key_id, plaintext_key)``.

        Only the hash is stored, so the plaintext is available only from this
        return value.
        """
        import uuid
        key_plain = secrets.token_urlsafe(32)
        key_id = str(uuid.uuid4())
        with self.Session() as s:
            rec = ApiKey(id=key_id, user_id=user_id, key_hash=self._hash_api_key(key_plain), name=name)
            s.add(rec)
            s.commit()
        return key_id, key_plain

    def list_api_keys(self):
        """Return metadata for every API key (revoked ones included), without hashes."""
        with self.Session() as s:
            rows = s.query(ApiKey).all()
            return [
                {
                    "id": r.id,
                    "user_id": r.user_id,
                    "name": r.name,
                    "created_at": r.created_at.isoformat() if r.created_at else None,
                    "revoked": r.revoked_at is not None,
                }
                for r in rows
            ]

    def revoke_api_key(self, key_id: str) -> None:
        """Mark the API key as revoked; unknown ids are ignored."""
        with self.Session() as s:
            rec = s.query(ApiKey).filter_by(id=key_id).first()
            if not rec:
                return
            rec.revoked_at = datetime.now(timezone.utc)
            s.add(rec)
            s.commit()

    def verify_api_key(self, key: str) -> bool:
        """Return True when *key* matches any non-revoked stored API key."""
        with self.Session() as s:
            rows = s.query(ApiKey).filter(ApiKey.revoked_at.is_(None)).all()
            # NOTE(review): every row has its own salt, so this runs one PBKDF2
            # per active key on each call; cost grows with the number of keys.
            return any(self._verify_api_key(key, r.key_hash) for r in rows)
@@ -7,7 +7,7 @@ thread-safe registry management.
7
7
 
8
8
  import logging
9
9
  import threading
10
- from typing import Dict, List, Optional
10
+ from typing import Dict, List
11
11
 
12
12
  from google.adk.agents.base_agent import BaseAgent
13
13
  from google.adk.cli.utils.base_agent_loader import BaseAgentLoader
@@ -5,13 +5,11 @@ AdkWebServer to use our EnhancedRunner with advanced features.
5
5
  """
6
6
 
7
7
  import os
8
- from typing import Optional
9
8
 
10
9
  from google.adk.cli.adk_web_server import AdkWebServer
11
10
  from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService
12
11
  from google.adk.cli.utils import cleanup
13
12
  from google.adk.cli.utils import envs
14
- from google.adk.runners import Runner
15
13
 
16
14
  from .enhanced_runner import EnhancedRunner
17
15
 
@@ -22,7 +22,6 @@ from watchdog.observers import Observer
22
22
 
23
23
  from google.adk.artifacts.gcs_artifact_service import GcsArtifactService
24
24
  from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
25
- from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService
26
25
  from google.adk.auth.credential_service.base_credential_service import BaseCredentialService
27
26
  from google.adk.evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
28
27
  from google.adk.evaluation.local_eval_sets_manager import LocalEvalSetsManager
@@ -35,6 +34,7 @@ from google.adk.sessions.database_session_service import DatabaseSessionService
35
34
  from google.adk.utils.feature_decorator import working_in_progress
36
35
  from google.adk.cli.adk_web_server import AdkWebServer
37
36
  from .enhanced_adk_web_server import EnhancedAdkWebServer
37
+ from .auth import attach_auth, AuthConfig, JwtIssuerConfig, JwtValidatorConfig
38
38
  from .streaming import StreamingController, StreamingConfig
39
39
  from google.adk.cli.utils import envs
40
40
  from google.adk.cli.utils import evals
@@ -69,6 +69,8 @@ def get_enhanced_fast_api_app(
69
69
  # Streaming layer (optional)
70
70
  enable_streaming: bool = False,
71
71
  streaming_config: Optional[StreamingConfig] = None,
72
+ # Auth layer (optional)
73
+ auth_config: Optional[AuthConfig] = None,
72
74
  ) -> FastAPI:
73
75
  """Enhanced version of Google ADK's get_fast_api_app with EnhancedRunner integration.
74
76
 
@@ -601,4 +603,7 @@ def get_enhanced_fast_api_app(
601
603
 
602
604
  app.include_router(router)
603
605
 
606
+ # Attach optional auth layer last so all routes are covered
607
+ attach_auth(app, auth_config)
608
+
604
609
  return app
@@ -1,6 +1,5 @@
1
1
  """MongoDB-based memory service implementation using PyMongo."""
2
2
 
3
- import json
4
3
  import logging
5
4
  from typing import Optional, List
6
5
  import re
@@ -2,7 +2,7 @@
2
2
 
3
3
  import json
4
4
  import logging
5
- from typing import Optional, List
5
+ from typing import Optional
6
6
  import re
7
7
  from datetime import datetime, timezone
8
8
 
@@ -1,9 +1,7 @@
1
1
  """YAML file-based memory service implementation."""
2
2
 
3
- import os
4
- import json
5
3
  import logging
6
- from typing import Optional, List
4
+ from typing import List
7
5
  from pathlib import Path
8
6
  import re
9
7
  from datetime import datetime
@@ -1,6 +1,5 @@
1
1
  """MongoDB-based session service implementation."""
2
2
 
3
- import json
4
3
  import time
5
4
  import uuid
6
5
  from typing import Any, Optional, Dict
@@ -3,7 +3,7 @@
3
3
  import json
4
4
  import time
5
5
  import uuid
6
- from typing import Any, Optional, Dict
6
+ from typing import Any, Optional
7
7
 
8
8
  try:
9
9
  import redis
@@ -1,9 +1,7 @@
1
1
  """YAML file-based session service implementation."""
2
2
 
3
- import json
4
3
  import time
5
4
  import uuid
6
- import os
7
5
  from typing import Any, Optional
8
6
  from pathlib import Path
9
7
 
@@ -1,9 +1,9 @@
1
1
  import asyncio
2
2
  import time
3
3
  from dataclasses import dataclass, field
4
- from typing import Any, Dict, Optional, Set, Callable, Awaitable
4
+ from typing import Any, Dict, Optional, Callable, Awaitable
5
5
 
6
- from fastapi import WebSocket, HTTPException
6
+ from fastapi import HTTPException
7
7
  from pydantic import BaseModel
8
8
 
9
9
  from google.adk.events.event import Event