google-adk-extras 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_adk_extras/__init__.py +3 -3
- google_adk_extras/adk_builder.py +15 -292
- google_adk_extras/artifacts/local_folder_artifact_service.py +0 -2
- google_adk_extras/artifacts/mongo_artifact_service.py +0 -1
- google_adk_extras/artifacts/s3_artifact_service.py +0 -1
- google_adk_extras/artifacts/sql_artifact_service.py +0 -1
- google_adk_extras/custom_agent_loader.py +1 -1
- google_adk_extras/enhanced_adk_web_server.py +0 -2
- google_adk_extras/enhanced_fastapi.py +97 -1
- google_adk_extras/memory/mongo_memory_service.py +0 -1
- google_adk_extras/memory/sql_memory_service.py +1 -1
- google_adk_extras/memory/yaml_file_memory_service.py +1 -3
- google_adk_extras/sessions/mongo_session_service.py +0 -1
- google_adk_extras/sessions/redis_session_service.py +1 -1
- google_adk_extras/sessions/yaml_file_session_service.py +0 -2
- google_adk_extras/streaming/__init__.py +12 -0
- google_adk_extras/streaming/streaming_controller.py +262 -0
- {google_adk_extras-0.2.5.dist-info → google_adk_extras-0.2.7.dist-info}/METADATA +12 -34
- google_adk_extras-0.2.7.dist-info/RECORD +32 -0
- google_adk_extras/credentials/__init__.py +0 -34
- google_adk_extras/credentials/github_oauth2_credential_service.py +0 -213
- google_adk_extras/credentials/google_oauth2_credential_service.py +0 -216
- google_adk_extras/credentials/http_basic_auth_credential_service.py +0 -388
- google_adk_extras/credentials/jwt_credential_service.py +0 -345
- google_adk_extras/credentials/microsoft_oauth2_credential_service.py +0 -250
- google_adk_extras/credentials/x_oauth2_credential_service.py +0 -240
- google_adk_extras-0.2.5.dist-info/RECORD +0 -37
- {google_adk_extras-0.2.5.dist-info → google_adk_extras-0.2.7.dist-info}/WHEEL +0 -0
- {google_adk_extras-0.2.5.dist-info → google_adk_extras-0.2.7.dist-info}/licenses/LICENSE +0 -0
- {google_adk_extras-0.2.5.dist-info → google_adk_extras-0.2.7.dist-info}/top_level.txt +0 -0
google_adk_extras/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
"""Production-ready services
|
1
|
+
"""Production-ready services and FastAPI wiring for Google ADK.
|
2
2
|
|
3
3
|
Public API surface:
|
4
4
|
- AdkBuilder
|
@@ -11,7 +11,7 @@ Service groups are exposed via subpackages:
|
|
11
11
|
- google_adk_extras.sessions
|
12
12
|
- google_adk_extras.artifacts
|
13
13
|
- google_adk_extras.memory
|
14
|
-
|
14
|
+
(credential services are provided by ADK; no custom extras here)
|
15
15
|
"""
|
16
16
|
|
17
17
|
from .adk_builder import AdkBuilder
|
@@ -28,4 +28,4 @@ __all__ = [
|
|
28
28
|
"CustomAgentLoader",
|
29
29
|
]
|
30
30
|
|
31
|
-
__version__ = "0.2.5"
|
31
|
+
__version__ = "0.2.7"
|
google_adk_extras/adk_builder.py
CHANGED
@@ -4,9 +4,7 @@ This module provides the AdkBuilder class that extends Google ADK's FastAPI inte
|
|
4
4
|
with support for custom credential services and enhanced configuration options.
|
5
5
|
"""
|
6
6
|
|
7
|
-
import os
|
8
7
|
import logging
|
9
|
-
from pathlib import Path
|
10
8
|
from typing import Any, Dict, List, Mapping, Optional, Union, Callable
|
11
9
|
from starlette.types import Lifespan
|
12
10
|
|
@@ -15,20 +13,16 @@ from google.adk.runners import Runner
|
|
15
13
|
from google.adk.agents.base_agent import BaseAgent
|
16
14
|
from google.adk.sessions.base_session_service import BaseSessionService
|
17
15
|
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
18
|
-
from google.adk.sessions.database_session_service import DatabaseSessionService
|
19
16
|
from google.adk.artifacts.base_artifact_service import BaseArtifactService
|
20
17
|
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
|
21
18
|
# GCS removed - vendor specific
|
22
19
|
from google.adk.memory.base_memory_service import BaseMemoryService
|
23
20
|
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
|
24
21
|
from google.adk.auth.credential_service.base_credential_service import BaseCredentialService
|
25
|
-
from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService
|
26
22
|
from google.adk.cli.utils.agent_loader import AgentLoader
|
27
23
|
from google.adk.cli.utils.base_agent_loader import BaseAgentLoader
|
28
|
-
from google.adk.cli.adk_web_server import AdkWebServer
|
29
24
|
|
30
25
|
from .custom_agent_loader import CustomAgentLoader
|
31
|
-
from .credentials.base_custom_credential_service import BaseCustomCredentialService
|
32
26
|
|
33
27
|
logger = logging.getLogger(__name__)
|
34
28
|
|
@@ -43,24 +37,19 @@ class AdkBuilder:
|
|
43
37
|
Example:
|
44
38
|
```python
|
45
39
|
from google_adk_extras import AdkBuilder
|
46
|
-
|
40
|
+
# (removed) custom credential service example
|
47
41
|
|
48
42
|
# Build FastAPI app with Google OAuth2 credentials
|
49
43
|
app = (AdkBuilder()
|
50
44
|
.with_agents_dir("./agents")
|
51
45
|
.with_session_service("sqlite:///sessions.db")
|
52
|
-
|
53
|
-
client_id="your-client-id",
|
54
|
-
client_secret="your-secret",
|
55
|
-
scopes=["calendar", "gmail.readonly"]
|
56
|
-
))
|
46
|
+
# credentials: rely on ADK defaults or pass an ADK BaseCredentialService explicitly
|
57
47
|
.with_web_ui()
|
58
48
|
.build_fastapi_app())
|
59
49
|
|
60
50
|
# Or build a Runner directly
|
61
51
|
runner = (AdkBuilder()
|
62
52
|
.with_agents_dir("./agents")
|
63
|
-
.with_credential_service_uri("oauth2-google://client-id:secret@scopes=calendar,gmail.readonly")
|
64
53
|
.build_runner("my_agent"))
|
65
54
|
```
|
66
55
|
"""
|
@@ -75,7 +64,7 @@ class AdkBuilder:
|
|
75
64
|
self._session_service_uri: Optional[str] = None
|
76
65
|
self._artifact_service_uri: Optional[str] = None
|
77
66
|
self._memory_service_uri: Optional[str] = None
|
78
|
-
|
67
|
+
# Note: custom credential-service URI parsing has been removed.
|
79
68
|
self._eval_storage_uri: Optional[str] = None
|
80
69
|
|
81
70
|
# Service instances (alternative to URIs)
|
@@ -190,25 +179,7 @@ class AdkBuilder:
|
|
190
179
|
self._memory_service_uri = uri
|
191
180
|
return self
|
192
181
|
|
193
|
-
|
194
|
-
"""Configure credential service using URI.
|
195
|
-
|
196
|
-
Supported URIs:
|
197
|
-
- "oauth2-google://client-id:secret@scopes=scope1,scope2"
|
198
|
-
- "oauth2-github://client-id:secret@scopes=user,repo"
|
199
|
-
- "oauth2-microsoft://tenant-id/client-id:secret@scopes=User.Read"
|
200
|
-
- "oauth2-x://client-id:secret@scopes=tweet.read,users.read"
|
201
|
-
- "jwt://secret@algorithm=HS256&issuer=my-app&audience=api.example.com&expiration_minutes=60"
|
202
|
-
- "basic-auth://username:password@realm=My API"
|
203
|
-
|
204
|
-
Args:
|
205
|
-
uri: Credential service URI.
|
206
|
-
|
207
|
-
Returns:
|
208
|
-
AdkBuilder: Self for method chaining.
|
209
|
-
"""
|
210
|
-
self._credential_service_uri = uri
|
211
|
-
return self
|
182
|
+
# Removed: URI-based credential service configuration. Use with_credential_service(instance) if needed.
|
212
183
|
|
213
184
|
def with_eval_storage(self, uri: str) -> "AdkBuilder":
|
214
185
|
"""Configure evaluation storage using URI.
|
@@ -606,15 +577,8 @@ class AdkBuilder:
|
|
606
577
|
return InMemoryMemoryService()
|
607
578
|
|
608
579
|
def _create_credential_service(self) -> Optional[BaseCredentialService]:
|
609
|
-
"""
|
610
|
-
|
611
|
-
return self._credential_service
|
612
|
-
|
613
|
-
if self._credential_service_uri:
|
614
|
-
return self._parse_credential_service_uri(self._credential_service_uri)
|
615
|
-
|
616
|
-
# No credential service configured; allow server to default
|
617
|
-
return None
|
580
|
+
"""Return explicitly provided ADK credential service instance (optional)."""
|
581
|
+
return self._credential_service
|
618
582
|
|
619
583
|
def _create_agent_loader(self) -> BaseAgentLoader:
|
620
584
|
"""Create agent loader from configuration.
|
@@ -669,237 +633,19 @@ class AdkBuilder:
|
|
669
633
|
"or with_agent_loader() to configure agents."
|
670
634
|
)
|
671
635
|
|
672
|
-
|
673
|
-
"""Parse credential service URI and create appropriate service.
|
674
|
-
|
675
|
-
Args:
|
676
|
-
uri: Credential service URI.
|
677
|
-
|
678
|
-
Returns:
|
679
|
-
BaseCredentialService: Configured credential service.
|
680
|
-
|
681
|
-
Raises:
|
682
|
-
ValueError: If URI format is invalid or unsupported.
|
683
|
-
"""
|
684
|
-
try:
|
685
|
-
if uri.startswith("oauth2-google://"):
|
686
|
-
return self._parse_google_oauth2_uri(uri)
|
687
|
-
elif uri.startswith("oauth2-github://"):
|
688
|
-
return self._parse_github_oauth2_uri(uri)
|
689
|
-
elif uri.startswith("oauth2-microsoft://"):
|
690
|
-
return self._parse_microsoft_oauth2_uri(uri)
|
691
|
-
elif uri.startswith("oauth2-x://"):
|
692
|
-
return self._parse_x_oauth2_uri(uri)
|
693
|
-
elif uri.startswith("jwt://"):
|
694
|
-
return self._parse_jwt_uri(uri)
|
695
|
-
elif uri.startswith("basic-auth://"):
|
696
|
-
return self._parse_basic_auth_uri(uri)
|
697
|
-
else:
|
698
|
-
raise ValueError(f"Unsupported credential service URI scheme: {uri}")
|
699
|
-
except Exception as e:
|
700
|
-
raise ValueError(f"Failed to parse credential service URI '{uri}': {e}")
|
636
|
+
# Removed: all URI-based credential parsing helpers for credentials.
|
701
637
|
|
702
|
-
|
703
|
-
"""Parse Google OAuth2 URI: oauth2-google://client-id:secret@scopes=scope1,scope2"""
|
704
|
-
from .credentials.google_oauth2_credential_service import (
|
705
|
-
GoogleOAuth2CredentialService,
|
706
|
-
)
|
707
|
-
# Remove scheme
|
708
|
-
uri_part = uri[len("oauth2-google://"):]
|
709
|
-
|
710
|
-
# Split at @
|
711
|
-
if "@" in uri_part:
|
712
|
-
credentials_part, params_part = uri_part.split("@", 1)
|
713
|
-
else:
|
714
|
-
credentials_part = uri_part
|
715
|
-
params_part = ""
|
716
|
-
|
717
|
-
# Parse credentials
|
718
|
-
if ":" in credentials_part:
|
719
|
-
client_id, client_secret = credentials_part.split(":", 1)
|
720
|
-
else:
|
721
|
-
raise ValueError("Google OAuth2 URI must include client_id:client_secret")
|
722
|
-
|
723
|
-
# Parse parameters
|
724
|
-
scopes = []
|
725
|
-
if params_part:
|
726
|
-
for param in params_part.split("&"):
|
727
|
-
if param.startswith("scopes="):
|
728
|
-
scopes = param[7:].split(",")
|
729
|
-
|
730
|
-
return GoogleOAuth2CredentialService(
|
731
|
-
client_id=client_id,
|
732
|
-
client_secret=client_secret,
|
733
|
-
scopes=scopes or ["openid", "email", "profile"]
|
734
|
-
)
|
638
|
+
# (removed) _parse_google_oauth2_uri
|
735
639
|
|
736
|
-
|
737
|
-
"""Parse GitHub OAuth2 URI: oauth2-github://client-id:secret@scopes=user,repo"""
|
738
|
-
from .credentials.github_oauth2_credential_service import (
|
739
|
-
GitHubOAuth2CredentialService,
|
740
|
-
)
|
741
|
-
uri_part = uri[len("oauth2-github://"):]
|
742
|
-
|
743
|
-
if "@" in uri_part:
|
744
|
-
credentials_part, params_part = uri_part.split("@", 1)
|
745
|
-
else:
|
746
|
-
credentials_part = uri_part
|
747
|
-
params_part = ""
|
748
|
-
|
749
|
-
if ":" in credentials_part:
|
750
|
-
client_id, client_secret = credentials_part.split(":", 1)
|
751
|
-
else:
|
752
|
-
raise ValueError("GitHub OAuth2 URI must include client_id:client_secret")
|
753
|
-
|
754
|
-
scopes = []
|
755
|
-
if params_part:
|
756
|
-
for param in params_part.split("&"):
|
757
|
-
if param.startswith("scopes="):
|
758
|
-
scopes = param[7:].split(",")
|
759
|
-
|
760
|
-
return GitHubOAuth2CredentialService(
|
761
|
-
client_id=client_id,
|
762
|
-
client_secret=client_secret,
|
763
|
-
scopes=scopes or ["user"]
|
764
|
-
)
|
640
|
+
# (removed) _parse_github_oauth2_uri
|
765
641
|
|
766
|
-
|
767
|
-
"""Parse Microsoft OAuth2 URI: oauth2-microsoft://tenant-id/client-id:secret@scopes=User.Read"""
|
768
|
-
from .credentials.microsoft_oauth2_credential_service import (
|
769
|
-
MicrosoftOAuth2CredentialService,
|
770
|
-
)
|
771
|
-
uri_part = uri[len("oauth2-microsoft://"):]
|
772
|
-
|
773
|
-
if "@" in uri_part:
|
774
|
-
credentials_part, params_part = uri_part.split("@", 1)
|
775
|
-
else:
|
776
|
-
credentials_part = uri_part
|
777
|
-
params_part = ""
|
778
|
-
|
779
|
-
# Parse tenant_id/client_id:secret
|
780
|
-
if "/" in credentials_part:
|
781
|
-
tenant_part, client_part = credentials_part.split("/", 1)
|
782
|
-
else:
|
783
|
-
raise ValueError("Microsoft OAuth2 URI must include tenant-id/client-id:secret")
|
784
|
-
|
785
|
-
if ":" in client_part:
|
786
|
-
client_id, client_secret = client_part.split(":", 1)
|
787
|
-
else:
|
788
|
-
raise ValueError("Microsoft OAuth2 URI must include client_id:client_secret")
|
789
|
-
|
790
|
-
scopes = []
|
791
|
-
if params_part:
|
792
|
-
for param in params_part.split("&"):
|
793
|
-
if param.startswith("scopes="):
|
794
|
-
scopes = param[7:].split(",")
|
795
|
-
|
796
|
-
return MicrosoftOAuth2CredentialService(
|
797
|
-
tenant_id=tenant_part,
|
798
|
-
client_id=client_id,
|
799
|
-
client_secret=client_secret,
|
800
|
-
scopes=scopes or ["User.Read"]
|
801
|
-
)
|
642
|
+
# (removed) _parse_microsoft_oauth2_uri
|
802
643
|
|
803
|
-
|
804
|
-
"""Parse X OAuth2 URI: oauth2-x://client-id:secret@scopes=tweet.read,users.read"""
|
805
|
-
from .credentials.x_oauth2_credential_service import XOAuth2CredentialService
|
806
|
-
uri_part = uri[len("oauth2-x://"):]
|
807
|
-
|
808
|
-
if "@" in uri_part:
|
809
|
-
credentials_part, params_part = uri_part.split("@", 1)
|
810
|
-
else:
|
811
|
-
credentials_part = uri_part
|
812
|
-
params_part = ""
|
813
|
-
|
814
|
-
if ":" in credentials_part:
|
815
|
-
client_id, client_secret = credentials_part.split(":", 1)
|
816
|
-
else:
|
817
|
-
raise ValueError("X OAuth2 URI must include client_id:client_secret")
|
818
|
-
|
819
|
-
scopes = []
|
820
|
-
if params_part:
|
821
|
-
for param in params_part.split("&"):
|
822
|
-
if param.startswith("scopes="):
|
823
|
-
scopes = param[7:].split(",")
|
824
|
-
|
825
|
-
return XOAuth2CredentialService(
|
826
|
-
client_id=client_id,
|
827
|
-
client_secret=client_secret,
|
828
|
-
scopes=scopes or ["tweet.read", "users.read", "offline.access"]
|
829
|
-
)
|
644
|
+
# (removed) _parse_x_oauth2_uri
|
830
645
|
|
831
|
-
|
832
|
-
"""Parse JWT URI: jwt://secret@algorithm=HS256&issuer=my-app&audience=api.example.com&expiration_minutes=60"""
|
833
|
-
from .credentials.jwt_credential_service import JWTCredentialService
|
834
|
-
uri_part = uri[len("jwt://"):]
|
835
|
-
|
836
|
-
if "@" in uri_part:
|
837
|
-
secret, params_part = uri_part.split("@", 1)
|
838
|
-
else:
|
839
|
-
secret = uri_part
|
840
|
-
params_part = ""
|
841
|
-
|
842
|
-
# Parse parameters
|
843
|
-
algorithm = "HS256"
|
844
|
-
issuer = None
|
845
|
-
audience = None
|
846
|
-
expiration_minutes = 60
|
847
|
-
custom_claims = {}
|
848
|
-
|
849
|
-
if params_part:
|
850
|
-
for param in params_part.split("&"):
|
851
|
-
if "=" in param:
|
852
|
-
key, value = param.split("=", 1)
|
853
|
-
if key == "algorithm":
|
854
|
-
algorithm = value
|
855
|
-
elif key == "issuer":
|
856
|
-
issuer = value
|
857
|
-
elif key == "audience":
|
858
|
-
audience = value
|
859
|
-
elif key == "expiration_minutes":
|
860
|
-
expiration_minutes = int(value)
|
861
|
-
else:
|
862
|
-
# Custom claim
|
863
|
-
custom_claims[key] = value
|
864
|
-
|
865
|
-
return JWTCredentialService(
|
866
|
-
secret=secret,
|
867
|
-
algorithm=algorithm,
|
868
|
-
issuer=issuer,
|
869
|
-
audience=audience,
|
870
|
-
expiration_minutes=expiration_minutes,
|
871
|
-
custom_claims=custom_claims
|
872
|
-
)
|
646
|
+
# (removed) _parse_jwt_uri
|
873
647
|
|
874
|
-
|
875
|
-
"""Parse Basic Auth URI: basic-auth://username:password@realm=My API"""
|
876
|
-
from .credentials.http_basic_auth_credential_service import (
|
877
|
-
HTTPBasicAuthCredentialService,
|
878
|
-
)
|
879
|
-
uri_part = uri[len("basic-auth://"):]
|
880
|
-
|
881
|
-
if "@" in uri_part:
|
882
|
-
credentials_part, params_part = uri_part.split("@", 1)
|
883
|
-
else:
|
884
|
-
credentials_part = uri_part
|
885
|
-
params_part = ""
|
886
|
-
|
887
|
-
if ":" in credentials_part:
|
888
|
-
username, password = credentials_part.split(":", 1)
|
889
|
-
else:
|
890
|
-
raise ValueError("Basic Auth URI must include username:password")
|
891
|
-
|
892
|
-
realm = None
|
893
|
-
if params_part:
|
894
|
-
for param in params_part.split("&"):
|
895
|
-
if param.startswith("realm="):
|
896
|
-
realm = param[6:]
|
897
|
-
|
898
|
-
return HTTPBasicAuthCredentialService(
|
899
|
-
username=username,
|
900
|
-
password=password,
|
901
|
-
realm=realm
|
902
|
-
)
|
648
|
+
# (removed) _parse_basic_auth_uri
|
903
649
|
|
904
650
|
# Build methods
|
905
651
|
def build_fastapi_app(self) -> FastAPI:
|
@@ -918,20 +664,7 @@ class AdkBuilder:
|
|
918
664
|
memory_service = self._create_memory_service()
|
919
665
|
credential_service = self._create_credential_service()
|
920
666
|
|
921
|
-
#
|
922
|
-
if isinstance(credential_service, BaseCustomCredentialService):
|
923
|
-
import asyncio
|
924
|
-
try:
|
925
|
-
# Try to initialize in current event loop
|
926
|
-
loop = asyncio.get_event_loop()
|
927
|
-
if loop.is_running():
|
928
|
-
# Create a task for initialization
|
929
|
-
asyncio.create_task(credential_service.initialize())
|
930
|
-
else:
|
931
|
-
loop.run_until_complete(credential_service.initialize())
|
932
|
-
except RuntimeError:
|
933
|
-
# No event loop, create one
|
934
|
-
asyncio.run(credential_service.initialize())
|
667
|
+
# No custom credential initialization; ADK services are passed through
|
935
668
|
|
936
669
|
# Use our enhanced FastAPI function that properly supports credential services
|
937
670
|
logger.info("Building FastAPI app with enhanced credential service support")
|
@@ -1005,17 +738,7 @@ class AdkBuilder:
|
|
1005
738
|
memory_service = self._create_memory_service()
|
1006
739
|
credential_service = self._create_credential_service()
|
1007
740
|
|
1008
|
-
#
|
1009
|
-
if isinstance(credential_service, BaseCustomCredentialService):
|
1010
|
-
import asyncio
|
1011
|
-
try:
|
1012
|
-
loop = asyncio.get_event_loop()
|
1013
|
-
if loop.is_running():
|
1014
|
-
asyncio.create_task(credential_service.initialize())
|
1015
|
-
else:
|
1016
|
-
loop.run_until_complete(credential_service.initialize())
|
1017
|
-
except RuntimeError:
|
1018
|
-
asyncio.run(credential_service.initialize())
|
741
|
+
# No custom credential initialization; ADK services are passed through
|
1019
742
|
|
1020
743
|
# Create Runner with all services
|
1021
744
|
app_name = self._app_name or (agent_or_agent_name if isinstance(agent_or_agent_name, str) else "default_app")
|
@@ -7,7 +7,7 @@ thread-safe registry management.
|
|
7
7
|
|
8
8
|
import logging
|
9
9
|
import threading
|
10
|
-
from typing import Dict, List
|
10
|
+
from typing import Dict, List
|
11
11
|
|
12
12
|
from google.adk.agents.base_agent import BaseAgent
|
13
13
|
from google.adk.cli.utils.base_agent_loader import BaseAgentLoader
|
@@ -5,13 +5,11 @@ AdkWebServer to use our EnhancedRunner with advanced features.
|
|
5
5
|
"""
|
6
6
|
|
7
7
|
import os
|
8
|
-
from typing import Optional
|
9
8
|
|
10
9
|
from google.adk.cli.adk_web_server import AdkWebServer
|
11
10
|
from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService
|
12
11
|
from google.adk.cli.utils import cleanup
|
13
12
|
from google.adk.cli.utils import envs
|
14
|
-
from google.adk.runners import Runner
|
15
13
|
|
16
14
|
from .enhanced_runner import EnhancedRunner
|
17
15
|
|
@@ -5,6 +5,7 @@ that properly supports custom credential services.
|
|
5
5
|
"""
|
6
6
|
|
7
7
|
import json
|
8
|
+
import asyncio
|
8
9
|
import logging
|
9
10
|
import os
|
10
11
|
from pathlib import Path
|
@@ -21,7 +22,6 @@ from watchdog.observers import Observer
|
|
21
22
|
|
22
23
|
from google.adk.artifacts.gcs_artifact_service import GcsArtifactService
|
23
24
|
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
|
24
|
-
from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService
|
25
25
|
from google.adk.auth.credential_service.base_credential_service import BaseCredentialService
|
26
26
|
from google.adk.evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
|
27
27
|
from google.adk.evaluation.local_eval_sets_manager import LocalEvalSetsManager
|
@@ -34,6 +34,7 @@ from google.adk.sessions.database_session_service import DatabaseSessionService
|
|
34
34
|
from google.adk.utils.feature_decorator import working_in_progress
|
35
35
|
from google.adk.cli.adk_web_server import AdkWebServer
|
36
36
|
from .enhanced_adk_web_server import EnhancedAdkWebServer
|
37
|
+
from .streaming import StreamingController, StreamingConfig
|
37
38
|
from google.adk.cli.utils import envs
|
38
39
|
from google.adk.cli.utils import evals
|
39
40
|
from google.adk.cli.utils.agent_change_handler import AgentChangeEventHandler
|
@@ -64,6 +65,9 @@ def get_enhanced_fast_api_app(
|
|
64
65
|
trace_to_cloud: bool = False,
|
65
66
|
reload_agents: bool = False,
|
66
67
|
lifespan: Optional[Lifespan[FastAPI]] = None,
|
68
|
+
# Streaming layer (optional)
|
69
|
+
enable_streaming: bool = False,
|
70
|
+
streaming_config: Optional[StreamingConfig] = None,
|
67
71
|
) -> FastAPI:
|
68
72
|
"""Enhanced version of Google ADK's get_fast_api_app with EnhancedRunner integration.
|
69
73
|
|
@@ -504,4 +508,96 @@ def get_enhanced_fast_api_app(
|
|
504
508
|
logger.error("Failed to setup programmatic A2A agent %s: %s", app_name, e)
|
505
509
|
|
506
510
|
logger.info("Enhanced FastAPI app created with credential service support")
|
511
|
+
|
512
|
+
# Optional streaming mounts (SSE + WebSocket)
|
513
|
+
if enable_streaming:
|
514
|
+
cfg = streaming_config or StreamingConfig(enable_streaming=True)
|
515
|
+
controller = StreamingController(
|
516
|
+
config=cfg,
|
517
|
+
get_runner_async=adk_web_server.get_runner_async,
|
518
|
+
session_service=session_service,
|
519
|
+
)
|
520
|
+
app.state.streaming_controller = controller
|
521
|
+
@app.on_event("startup")
|
522
|
+
async def _start_streaming(): # pragma: no cover - lifecycle glue
|
523
|
+
controller.start()
|
524
|
+
@app.on_event("shutdown")
|
525
|
+
async def _stop_streaming(): # pragma: no cover - lifecycle glue
|
526
|
+
await controller.stop()
|
527
|
+
|
528
|
+
from fastapi import APIRouter, WebSocket, Query
|
529
|
+
from fastapi.responses import StreamingResponse
|
530
|
+
from google.adk.cli.adk_web_server import RunAgentRequest
|
531
|
+
|
532
|
+
router = APIRouter()
|
533
|
+
base = cfg.streaming_path_base.rstrip("/")
|
534
|
+
|
535
|
+
@router.get(f"{base}/events/{{channel_id}}")
|
536
|
+
async def stream_events(channel_id: str, appName: str = Query(...), userId: str = Query(...), sessionId: Optional[str] = Query(None)):
|
537
|
+
ch = await app.state.streaming_controller.open_or_bind_channel(
|
538
|
+
channel_id=channel_id, app_name=appName, user_id=userId, session_id=sessionId
|
539
|
+
)
|
540
|
+
q = app.state.streaming_controller.subscribe(channel_id, kind="sse")
|
541
|
+
|
542
|
+
async def gen():
|
543
|
+
try:
|
544
|
+
# Announce channel binding with session id
|
545
|
+
yield "event: channel-bound\n"
|
546
|
+
yield f"data: {{\"appName\":\"{appName}\",\"userId\":\"{userId}\",\"sessionId\":\"{ch.session_id}\"}}\n\n"
|
547
|
+
while True:
|
548
|
+
payload = await q.get()
|
549
|
+
yield f"data: {payload}\n\n"
|
550
|
+
except asyncio.CancelledError:
|
551
|
+
pass
|
552
|
+
finally:
|
553
|
+
app.state.streaming_controller.unsubscribe(channel_id, q)
|
554
|
+
|
555
|
+
return StreamingResponse(gen(), media_type="text/event-stream")
|
556
|
+
|
557
|
+
@router.post(f"{base}/send/{{channel_id}}")
|
558
|
+
async def send_message(channel_id: str, req: RunAgentRequest):
|
559
|
+
# Validation: channel binding must match
|
560
|
+
await app.state.streaming_controller.enqueue(channel_id, req)
|
561
|
+
return PlainTextResponse("", status_code=204)
|
562
|
+
|
563
|
+
@router.websocket(f"{base}/ws/{{channel_id}}")
|
564
|
+
async def ws_endpoint(websocket: WebSocket, channel_id: str, appName: str, userId: str, sessionId: Optional[str] = None):
|
565
|
+
await websocket.accept()
|
566
|
+
try:
|
567
|
+
await app.state.streaming_controller.open_or_bind_channel(
|
568
|
+
channel_id=channel_id, app_name=appName, user_id=userId, session_id=sessionId
|
569
|
+
)
|
570
|
+
q = app.state.streaming_controller.subscribe(channel_id, kind="ws")
|
571
|
+
# Send channel binding info including session id
|
572
|
+
await websocket.send_text(json.dumps({"event": "channel-bound", "appName": appName, "userId": userId, "sessionId": app.state.streaming_controller._channels[channel_id].session_id}))
|
573
|
+
|
574
|
+
async def downlink():
|
575
|
+
try:
|
576
|
+
while True:
|
577
|
+
payload = await q.get()
|
578
|
+
await websocket.send_text(payload)
|
579
|
+
except asyncio.CancelledError:
|
580
|
+
pass
|
581
|
+
|
582
|
+
async def uplink():
|
583
|
+
try:
|
584
|
+
while True:
|
585
|
+
text = await websocket.receive_text()
|
586
|
+
# Strict type parity by default
|
587
|
+
req = RunAgentRequest.model_validate_json(text)
|
588
|
+
await app.state.streaming_controller.enqueue(channel_id, req)
|
589
|
+
except Exception:
|
590
|
+
return
|
591
|
+
|
592
|
+
down = asyncio.create_task(downlink())
|
593
|
+
up = asyncio.create_task(uplink())
|
594
|
+
await asyncio.wait({down, up}, return_when=asyncio.FIRST_COMPLETED)
|
595
|
+
finally:
|
596
|
+
try:
|
597
|
+
app.state.streaming_controller.unsubscribe(channel_id, q)
|
598
|
+
except Exception:
|
599
|
+
pass
|
600
|
+
|
601
|
+
app.include_router(router)
|
602
|
+
|
507
603
|
return app
|