google-adk-extras 0.1.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_adk_extras/__init__.py +31 -1
- google_adk_extras/adk_builder.py +1030 -0
- google_adk_extras/artifacts/__init__.py +25 -12
- google_adk_extras/artifacts/base_custom_artifact_service.py +148 -11
- google_adk_extras/artifacts/local_folder_artifact_service.py +133 -13
- google_adk_extras/artifacts/s3_artifact_service.py +135 -19
- google_adk_extras/artifacts/sql_artifact_service.py +109 -10
- google_adk_extras/credentials/__init__.py +34 -0
- google_adk_extras/credentials/base_custom_credential_service.py +113 -0
- google_adk_extras/credentials/github_oauth2_credential_service.py +213 -0
- google_adk_extras/credentials/google_oauth2_credential_service.py +216 -0
- google_adk_extras/credentials/http_basic_auth_credential_service.py +388 -0
- google_adk_extras/credentials/jwt_credential_service.py +345 -0
- google_adk_extras/credentials/microsoft_oauth2_credential_service.py +250 -0
- google_adk_extras/credentials/x_oauth2_credential_service.py +240 -0
- google_adk_extras/custom_agent_loader.py +156 -0
- google_adk_extras/enhanced_adk_web_server.py +137 -0
- google_adk_extras/enhanced_fastapi.py +470 -0
- google_adk_extras/enhanced_runner.py +38 -0
- google_adk_extras/memory/__init__.py +30 -13
- google_adk_extras/memory/base_custom_memory_service.py +37 -5
- google_adk_extras/memory/sql_memory_service.py +105 -19
- google_adk_extras/memory/yaml_file_memory_service.py +115 -22
- google_adk_extras/sessions/__init__.py +29 -13
- google_adk_extras/sessions/base_custom_session_service.py +133 -11
- google_adk_extras/sessions/sql_session_service.py +127 -16
- google_adk_extras/sessions/yaml_file_session_service.py +122 -14
- google_adk_extras-0.2.3.dist-info/METADATA +302 -0
- google_adk_extras-0.2.3.dist-info/RECORD +37 -0
- google_adk_extras/py.typed +0 -0
- google_adk_extras-0.1.1.dist-info/METADATA +0 -175
- google_adk_extras-0.1.1.dist-info/RECORD +0 -25
- {google_adk_extras-0.1.1.dist-info → google_adk_extras-0.2.3.dist-info}/WHEEL +0 -0
- {google_adk_extras-0.1.1.dist-info → google_adk_extras-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {google_adk_extras-0.1.1.dist-info → google_adk_extras-0.2.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,345 @@
|
|
1
|
+
"""JWT credential service implementation."""
|
2
|
+
|
3
|
+
from typing import Optional, Dict, Any
|
4
|
+
import logging
|
5
|
+
import jwt
|
6
|
+
from datetime import datetime, timedelta, timezone
|
7
|
+
|
8
|
+
from google.adk.auth.credential_service.session_state_credential_service import SessionStateCredentialService
|
9
|
+
from google.adk.auth.credential_service.base_credential_service import CallbackContext
|
10
|
+
from google.adk.auth import AuthConfig, AuthCredential, AuthCredentialTypes
|
11
|
+
from google.adk.auth.auth_credential import HttpAuth, HttpCredentials
|
12
|
+
from fastapi.openapi.models import HTTPBearer
|
13
|
+
|
14
|
+
from .base_custom_credential_service import BaseCustomCredentialService
|
15
|
+
|
16
|
+
# Module-level logger used by JWTCredentialService below.
logger = logging.getLogger(__name__)
|
17
|
+
|
18
|
+
|
19
|
+
class JWTCredentialService(BaseCustomCredentialService):
    """JWT credential service for handling JSON Web Token authentication.

    This service generates and manages JWT tokens for API authentication.
    Stored tokens that are expired (or otherwise fail verification) are
    re-issued automatically by ``load_credential``.

    Args:
        secret: The key used to sign JWT tokens. For HMAC algorithms (HS*) this
            is the shared secret; for RSA/ECDSA algorithms (RS*/ES*) it is the
            private signing key.
        algorithm: The algorithm used for JWT signing. Default is 'HS256'.
        issuer: The issuer of the JWT token. Optional.
        audience: The intended audience of the JWT token. Optional.
        expiration_minutes: Token expiration time in minutes. Default is 60 minutes.
        custom_claims: Additional custom claims to include in the JWT payload.
        use_session_state: If True, stores credentials in session state. If False,
            uses in-memory storage. Default is True for persistence.
        verification_key: Key used to verify tokens. Defaults to ``secret``,
            which is right for HMAC algorithms; for RS*/ES* algorithms pass the
            public key matching ``secret``.

    Example:
        ```python
        credential_service = JWTCredentialService(
            secret="your-jwt-secret",
            algorithm="HS256",
            issuer="my-app",
            audience="api.example.com",
            expiration_minutes=120,
            custom_claims={"role": "admin", "permissions": ["read", "write"]}
        )
        await credential_service.initialize()

        # Use with Runner
        runner = Runner(
            agent=agent,
            session_service=session_service,
            credential_service=credential_service,
            app_name="my_app"
        )
        ```
    """

    SUPPORTED_ALGORITHMS = {
        'HS256', 'HS384', 'HS512',  # HMAC with SHA
        'RS256', 'RS384', 'RS512',  # RSA with SHA
        'ES256', 'ES384', 'ES512'   # ECDSA with SHA
    }

    def __init__(
        self,
        secret: str,
        algorithm: str = 'HS256',
        issuer: Optional[str] = None,
        audience: Optional[str] = None,
        expiration_minutes: int = 60,
        custom_claims: Optional[Dict[str, Any]] = None,
        use_session_state: bool = True,
        verification_key: Optional[str] = None,
    ):
        """Initialize the JWT credential service.

        Args:
            secret: JWT signing key (shared secret or private key).
            algorithm: JWT signing algorithm.
            issuer: JWT issuer.
            audience: JWT audience.
            expiration_minutes: Token expiration in minutes.
            custom_claims: Additional claims to include in JWT.
            use_session_state: Whether to use session state for credential storage.
            verification_key: Optional key for token verification; falls back
                to ``secret`` when not given.
        """
        super().__init__()
        self.secret = secret
        self.algorithm = algorithm
        self.issuer = issuer
        self.audience = audience
        self.expiration_minutes = expiration_minutes
        # Copy so later mutation of the caller's dict cannot leak into tokens.
        self.custom_claims = dict(custom_claims) if custom_claims else {}
        self.use_session_state = use_session_state
        self.verification_key = verification_key

        # Underlying credential service for storage
        if use_session_state:
            self._storage_service = SessionStateCredentialService()
        else:
            from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService
            self._storage_service = InMemoryCredentialService()

    async def _initialize_impl(self) -> None:
        """Initialize the JWT credential service.

        Validates the configuration parameters and performs a trial encode to
        confirm the signing key and algorithm work together.

        Raises:
            ValueError: If configuration is invalid.
        """
        if not self.secret:
            raise ValueError("JWT secret is required")

        if self.algorithm not in self.SUPPORTED_ALGORITHMS:
            raise ValueError(f"Unsupported JWT algorithm: {self.algorithm}. Supported: {self.SUPPORTED_ALGORITHMS}")

        if self.expiration_minutes <= 0:
            raise ValueError("JWT expiration_minutes must be positive")

        # Test JWT creation to validate secret and algorithm. Any failure here
        # (bad key format, missing crypto backend, ...) is a configuration error.
        try:
            jwt.encode({"test": "validation"}, self.secret, algorithm=self.algorithm)
        except Exception as e:
            raise ValueError(f"Invalid JWT configuration: {e}") from e

        logger.info("Initialized JWT credential service with algorithm %s", self.algorithm)

    def generate_jwt_token(self, user_id: str, additional_claims: Optional[Dict[str, Any]] = None) -> str:
        """Generate a JWT token for the specified user.

        The payload contains ``sub``, ``iat`` and ``exp``, plus ``iss``/``aud``
        when configured. ``custom_claims`` and then ``additional_claims`` are
        merged last, so they can override any of those keys.

        Args:
            user_id: The user ID to include in the JWT token.
            additional_claims: Additional claims to include in this specific token.

        Returns:
            str: The generated JWT token.

        Raises:
            RuntimeError: If the service is not initialized.
        """
        self._check_initialized()

        now = datetime.now(timezone.utc)
        exp = now + timedelta(minutes=self.expiration_minutes)

        payload = {
            "sub": user_id,  # Subject (user ID)
            "iat": now,      # Issued at
            "exp": exp,      # Expiration
        }

        # Add optional standard claims
        if self.issuer:
            payload["iss"] = self.issuer
        if self.audience:
            payload["aud"] = self.audience

        # Add custom claims
        payload.update(self.custom_claims)
        if additional_claims:
            payload.update(additional_claims)

        token = jwt.encode(payload, self.secret, algorithm=self.algorithm)
        logger.debug("Generated JWT token for user %s expiring at %s", user_id, exp)

        return token

    def verify_jwt_token(self, token: str) -> Dict[str, Any]:
        """Verify and decode a JWT token.

        Signature, expiry and issued-at are always checked; audience and
        issuer are checked against the configured values when set.

        Args:
            token: The JWT token to verify.

        Returns:
            Dict[str, Any]: The decoded token payload.

        Raises:
            jwt.InvalidTokenError: If the token is invalid or expired.
            RuntimeError: If the service is not initialized.
        """
        self._check_initialized()

        options = {
            "verify_signature": True,
            "verify_exp": True,
            "verify_iat": True,
        }

        # HMAC verifies with the signing secret; asymmetric algorithms need
        # the public key, supplied via verification_key.
        key = self.verification_key or self.secret

        return jwt.decode(
            token,
            key,
            algorithms=[self.algorithm],
            audience=self.audience or None,
            issuer=self.issuer or None,
            options=options,
        )

    def is_token_expired(self, token: str) -> bool:
        """Check if a JWT token is unusable without raising an exception.

        Args:
            token: The JWT token to check.

        Returns:
            bool: True if the token is expired or fails any other JWT
            validation (treated as "needs refresh"), False otherwise.
        """
        try:
            self.verify_jwt_token(token)
            return False
        except jwt.ExpiredSignatureError:
            return True
        except jwt.InvalidTokenError:
            # Other validation errors also count as "expired" for refresh purposes
            return True

    def create_auth_config(self, user_id: str, additional_claims: Optional[Dict[str, Any]] = None) -> AuthConfig:
        """Create an AuthConfig with a generated JWT token.

        Args:
            user_id: The user ID for the JWT token.
            additional_claims: Additional claims for this specific token.

        Returns:
            AuthConfig: Configured auth config with JWT Bearer token.
        """
        self._check_initialized()

        # Generate JWT token
        token = self.generate_jwt_token(user_id, additional_claims)

        # Create HTTP Bearer auth scheme
        auth_scheme = HTTPBearer()

        # Create HTTP Bearer credential
        auth_credential = AuthCredential(
            auth_type=AuthCredentialTypes.HTTP,
            http=HttpAuth(
                scheme="bearer",
                credentials=HttpCredentials(token=token)
            )
        )

        return AuthConfig(
            auth_scheme=auth_scheme,
            raw_auth_credential=auth_credential
        )

    async def load_credential(
        self,
        auth_config: AuthConfig,
        callback_context: CallbackContext,
    ) -> Optional[AuthCredential]:
        """Load JWT credential from storage and refresh if expired.

        A refreshed token is issued for the current invocation's user with the
        configured claims only. Per-token ``additional_claims`` of the old token
        are deliberately not copied, because that token may not pass verification.

        Args:
            auth_config: The auth config containing credential key information.
            callback_context: The current callback context.

        Returns:
            Optional[AuthCredential]: The stored credential (refreshed in place
            when its token was missing or invalid), or None if nothing usable
            is stored.
        """
        self._check_initialized()

        # Load existing credential
        credential = await self._storage_service.load_credential(auth_config, callback_context)

        if not credential or not credential.http or not credential.http.credentials:
            return None

        token = credential.http.credentials.token
        if token and not self.is_token_expired(token):
            return credential

        user_id = callback_context._invocation_context.user_id
        logger.info("JWT token expired for user %s, generating new token", user_id)

        # Update credential with new token
        credential.http.credentials.token = self.generate_jwt_token(user_id)

        # Save refreshed credential. Copy the incoming config instead of building a
        # new AuthConfig. The storage key stays the one we loaded from, whereas a
        # fresh AuthConfig may derive a different key from the new credential contents.
        updated_auth_config = auth_config.model_copy(
            update={
                "raw_auth_credential": credential,
                "exchanged_auth_credential": credential,
            }
        )
        await self._storage_service.save_credential(updated_auth_config, callback_context)

        return credential

    async def save_credential(
        self,
        auth_config: AuthConfig,
        callback_context: CallbackContext,
    ) -> None:
        """Save JWT credential to storage.

        Args:
            auth_config: The auth config containing the credential to save.
            callback_context: The current callback context.
        """
        self._check_initialized()
        await self._storage_service.save_credential(auth_config, callback_context)

        logger.info("Saved JWT credential for user %s", callback_context._invocation_context.user_id)

    def get_token_info(self, token: str) -> Dict[str, Any]:
        """Get information about a JWT token without full verification.

        Args:
            token: The JWT token to inspect.

        Returns:
            Dict[str, Any]: ``{"payload", "expired"}`` plus ``"expires_at"``
            (ISO-8601, UTC) when the token has an ``exp`` claim, or
            ``{"error", "expired": True}`` if the token cannot be decoded.
        """
        try:
            # Decode without verification to get token info
            payload = jwt.decode(token, options={"verify_signature": False})

            if self._initialized:
                # Full verification: invalid signatures also count as expired.
                expired = self.is_token_expired(token)
            elif "exp" in payload:
                # Simple expiration check without initialization
                expired = datetime.now(timezone.utc).timestamp() > payload["exp"]
            else:
                # No "exp" claim means the token never expires; this matches
                # the initialized path, where PyJWT only checks exp if present.
                expired = False

            info = {
                "payload": payload,
                "expired": expired,
            }

            if "exp" in payload:
                exp_datetime = datetime.fromtimestamp(payload["exp"], timezone.utc)
                info["expires_at"] = exp_datetime.isoformat()

            return info
        except Exception as e:
            return {"error": str(e), "expired": True}
|
@@ -0,0 +1,250 @@
|
|
1
|
+
"""Microsoft OAuth2 credential service implementation."""
|
2
|
+
|
3
|
+
from typing import Optional, List
|
4
|
+
import logging
|
5
|
+
|
6
|
+
from google.adk.auth.credential_service.session_state_credential_service import SessionStateCredentialService
|
7
|
+
from google.adk.auth.credential_service.base_credential_service import CallbackContext
|
8
|
+
from google.adk.auth import AuthConfig, AuthCredential, AuthCredentialTypes
|
9
|
+
from google.adk.auth.auth_credential import OAuth2Auth
|
10
|
+
from fastapi.openapi.models import OAuth2
|
11
|
+
from fastapi.openapi.models import OAuthFlowAuthorizationCode, OAuthFlows
|
12
|
+
|
13
|
+
from .base_custom_credential_service import BaseCustomCredentialService
|
14
|
+
|
15
|
+
# Module-level logger used by MicrosoftOAuth2CredentialService below.
logger = logging.getLogger(__name__)
|
16
|
+
|
17
|
+
|
18
|
+
class MicrosoftOAuth2CredentialService(BaseCustomCredentialService):
    """Microsoft OAuth2 credential service for handling Microsoft authentication flows.

    This service provides pre-configured OAuth2 flows for Microsoft Graph API including
    Outlook, Teams, OneDrive, and other Microsoft 365 services.

    Args:
        tenant_id: The Azure AD tenant ID. Use "common" for multi-tenant applications.
        client_id: The Microsoft OAuth2 client ID from Azure AD App Registration.
        client_secret: The Microsoft OAuth2 client secret from Azure AD App Registration.
        scopes: List of OAuth2 scopes to request. Defaults to
            ["User.Read", "Mail.Read"] when omitted or empty. Common scopes include:
            - "User.Read" - Read user profile
            - "Mail.Read" - Read user's mail
            - "Mail.ReadWrite" - Read and write user's mail
            - "Calendars.Read" - Read user's calendars
            - "Calendars.ReadWrite" - Read and write user's calendars
            - "Files.Read" - Read user's files
            - "Files.ReadWrite" - Read and write user's files
        use_session_state: If True, stores credentials in session state. If False,
            uses in-memory storage. Default is True for persistence.

    Example:
        ```python
        credential_service = MicrosoftOAuth2CredentialService(
            tenant_id="common",  # or specific tenant ID
            client_id="your-azure-client-id",
            client_secret="your-azure-client-secret",
            scopes=["User.Read", "Mail.Read", "Calendars.ReadWrite"]
        )
        await credential_service.initialize()

        # Use with Runner
        runner = Runner(
            agent=agent,
            session_service=session_service,
            credential_service=credential_service,
            app_name="my_app"
        )
        ```
    """

    # Resource prefix allowed in front of Microsoft Graph scope names,
    # e.g. "https://graph.microsoft.com/Mail.Read" == "Mail.Read".
    _GRAPH_RESOURCE_PREFIX = "https://graph.microsoft.com/"

    # Microsoft OAuth2 endpoints (v2.0)
    def _get_auth_url(self, tenant_id: str) -> str:
        """Get the Microsoft OAuth2 authorization URL for the tenant."""
        return f"https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/authorize"

    def _get_token_url(self, tenant_id: str) -> str:
        """Get the Microsoft OAuth2 token URL for the tenant."""
        return f"https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token"

    # Common Microsoft Graph API scopes
    COMMON_SCOPES = {
        # User and profile scopes
        "User.Read": "Read user profile",
        "User.ReadWrite": "Read and write user profile",
        "User.ReadBasic.All": "Read basic profiles of all users",
        "User.Read.All": "Read all users' full profiles",
        "User.ReadWrite.All": "Read and write all users' full profiles",

        # Mail scopes
        "Mail.Read": "Read user mail",
        "Mail.ReadWrite": "Read and write user mail",
        "Mail.Send": "Send mail as user",
        "Mail.Read.Shared": "Read user and shared mail",
        "Mail.ReadWrite.Shared": "Read and write user and shared mail",

        # Calendar scopes
        "Calendars.Read": "Read user calendars",
        "Calendars.ReadWrite": "Read and write user calendars",
        "Calendars.Read.Shared": "Read user and shared calendars",
        "Calendars.ReadWrite.Shared": "Read and write user and shared calendars",

        # Files and OneDrive scopes
        "Files.Read": "Read user files",
        "Files.ReadWrite": "Read and write user files",
        "Files.Read.All": "Read all files that user can access",
        "Files.ReadWrite.All": "Read and write all files that user can access",
        "Sites.Read.All": "Read items in all site collections",
        "Sites.ReadWrite.All": "Read and write items in all site collections",

        # Groups and Teams scopes
        "Group.Read.All": "Read all groups",
        "Group.ReadWrite.All": "Read and write all groups",
        "Team.ReadBasic.All": "Read names and descriptions of teams",
        "TeamSettings.Read.All": "Read all teams' settings",
        "TeamSettings.ReadWrite.All": "Read and write all teams' settings",

        # Directory scopes
        "Directory.Read.All": "Read directory data",
        "Directory.ReadWrite.All": "Read and write directory data",

        # Application scopes
        "Application.Read.All": "Read all applications",
        "Application.ReadWrite.All": "Read and write all applications",

        # OpenID Connect scopes
        "openid": "OpenID Connect sign-in",
        "email": "View user's email address",
        "profile": "View user's basic profile",
        "offline_access": "Maintain access to data you have given access to"
    }

    def __init__(
        self,
        tenant_id: str,
        client_id: str,
        client_secret: str,
        scopes: Optional[List[str]] = None,
        use_session_state: bool = True
    ):
        """Initialize the Microsoft OAuth2 credential service.

        Args:
            tenant_id: Azure AD tenant ID or "common" for multi-tenant.
            client_id: Microsoft OAuth2 client ID.
            client_secret: Microsoft OAuth2 client secret.
            scopes: List of OAuth2 scopes to request; None or empty falls back
                to ["User.Read", "Mail.Read"].
            use_session_state: Whether to use session state for credential storage.
        """
        super().__init__()
        self.tenant_id = tenant_id
        self.client_id = client_id
        self.client_secret = client_secret
        # Copy so later mutation of the caller's list does not change the
        # scopes this service requests.
        self.scopes = list(scopes) if scopes else ["User.Read", "Mail.Read"]
        self.use_session_state = use_session_state

        # Underlying credential service for storage
        if use_session_state:
            self._storage_service = SessionStateCredentialService()
        else:
            from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService
            self._storage_service = InMemoryCredentialService()

    async def _initialize_impl(self) -> None:
        """Initialize the Microsoft OAuth2 credential service.

        Validates the client credentials and warns about scopes not listed in
        COMMON_SCOPES (they are still requested).

        Raises:
            ValueError: If required parameters are missing.
        """
        if not self.tenant_id:
            raise ValueError("Microsoft OAuth2 tenant_id is required")
        if not self.client_id:
            raise ValueError("Microsoft OAuth2 client_id is required")
        if not self.client_secret:
            raise ValueError("Microsoft OAuth2 client_secret is required")
        # __init__ substitutes defaults for empty input, so this only fires if
        # self.scopes was emptied after construction.
        if not self.scopes:
            raise ValueError("At least one OAuth2 scope is required")

        # Validate scopes against known Microsoft scopes; compare URI-form
        # Graph scopes by their short name so they are not flagged.
        prefix = self._GRAPH_RESOURCE_PREFIX
        unknown_scopes = {
            scope
            for scope in self.scopes
            if (scope[len(prefix):] if scope.startswith(prefix) else scope) not in self.COMMON_SCOPES
        }
        if unknown_scopes:
            logger.warning("Unknown Microsoft OAuth2 scopes: %s", unknown_scopes)

        logger.info(
            "Initialized Microsoft OAuth2 credential service for tenant %s with scopes: %s",
            self.tenant_id,
            self.scopes,
        )

    def create_auth_config(self) -> AuthConfig:
        """Create an AuthConfig for Microsoft OAuth2 authentication.

        Builds an authorization-code flow against the tenant's v2.0 endpoints,
        describing each requested scope from COMMON_SCOPES (or a generic label).

        Returns:
            AuthConfig: Configured auth config for Microsoft OAuth2 flow.
        """
        self._check_initialized()

        # Create OAuth2 auth scheme
        auth_scheme = OAuth2(
            flows=OAuthFlows(
                authorizationCode=OAuthFlowAuthorizationCode(
                    authorizationUrl=self._get_auth_url(self.tenant_id),
                    tokenUrl=self._get_token_url(self.tenant_id),
                    scopes={
                        scope: self.COMMON_SCOPES.get(scope, f"Microsoft Graph scope: {scope}")
                        for scope in self.scopes
                    }
                )
            )
        )

        # Create OAuth2 credential
        auth_credential = AuthCredential(
            auth_type=AuthCredentialTypes.OAUTH2,
            oauth2=OAuth2Auth(
                client_id=self.client_id,
                client_secret=self.client_secret
            )
        )

        return AuthConfig(
            auth_scheme=auth_scheme,
            raw_auth_credential=auth_credential
        )

    async def load_credential(
        self,
        auth_config: AuthConfig,
        callback_context: CallbackContext,
    ) -> Optional[AuthCredential]:
        """Load Microsoft OAuth2 credential from storage.

        Args:
            auth_config: The auth config containing credential key information.
            callback_context: The current callback context.

        Returns:
            Optional[AuthCredential]: The stored credential or None if not found.
        """
        self._check_initialized()
        return await self._storage_service.load_credential(auth_config, callback_context)

    async def save_credential(
        self,
        auth_config: AuthConfig,
        callback_context: CallbackContext,
    ) -> None:
        """Save Microsoft OAuth2 credential to storage.

        Args:
            auth_config: The auth config containing the credential to save.
            callback_context: The current callback context.
        """
        self._check_initialized()
        await self._storage_service.save_credential(auth_config, callback_context)

        logger.info(
            "Saved Microsoft OAuth2 credential for user %s in tenant %s",
            callback_context._invocation_context.user_id,
            self.tenant_id,
        )

    def get_supported_scopes(self) -> dict:
        """Get dictionary of supported Microsoft OAuth2 scopes and their descriptions.

        Returns:
            dict: A copy of the scope-name to description mapping, safe to mutate.
        """
        return self.COMMON_SCOPES.copy()
|