omnidapter 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- omnidapter/__init__.py +118 -0
- omnidapter/auth/__init__.py +13 -0
- omnidapter/auth/kinds.py +5 -0
- omnidapter/auth/models.py +63 -0
- omnidapter/auth/oauth.py +215 -0
- omnidapter/auth/refresh.py +94 -0
- omnidapter/core/__init__.py +0 -0
- omnidapter/core/connection.py +106 -0
- omnidapter/core/errors.py +175 -0
- omnidapter/core/logging.py +41 -0
- omnidapter/core/metadata.py +69 -0
- omnidapter/core/omnidapter.py +160 -0
- omnidapter/core/registry.py +188 -0
- omnidapter/providers/__init__.py +0 -0
- omnidapter/providers/_base.py +128 -0
- omnidapter/providers/_oauth.py +236 -0
- omnidapter/providers/apple/__init__.py +0 -0
- omnidapter/providers/apple/calendar.py +47 -0
- omnidapter/providers/apple/metadata.py +52 -0
- omnidapter/providers/apple/provider.py +39 -0
- omnidapter/providers/caldav/__init__.py +0 -0
- omnidapter/providers/caldav/auth.py +15 -0
- omnidapter/providers/caldav/calendar.py +442 -0
- omnidapter/providers/caldav/mappers.py +234 -0
- omnidapter/providers/caldav/metadata.py +66 -0
- omnidapter/providers/caldav/provider.py +39 -0
- omnidapter/providers/caldav/server_hints.py +65 -0
- omnidapter/providers/google/__init__.py +0 -0
- omnidapter/providers/google/calendar.py +308 -0
- omnidapter/providers/google/mappers.py +361 -0
- omnidapter/providers/google/metadata.py +66 -0
- omnidapter/providers/google/oauth.py +27 -0
- omnidapter/providers/google/provider.py +47 -0
- omnidapter/providers/microsoft/__init__.py +0 -0
- omnidapter/providers/microsoft/calendar.py +298 -0
- omnidapter/providers/microsoft/mappers.py +365 -0
- omnidapter/providers/microsoft/metadata.py +67 -0
- omnidapter/providers/microsoft/oauth.py +20 -0
- omnidapter/providers/microsoft/provider.py +51 -0
- omnidapter/providers/zoho/__init__.py +0 -0
- omnidapter/providers/zoho/calendar.py +315 -0
- omnidapter/providers/zoho/mappers.py +165 -0
- omnidapter/providers/zoho/metadata.py +54 -0
- omnidapter/providers/zoho/oauth.py +21 -0
- omnidapter/providers/zoho/provider.py +47 -0
- omnidapter/services/__init__.py +0 -0
- omnidapter/services/calendar/__init__.py +0 -0
- omnidapter/services/calendar/capabilities.py +31 -0
- omnidapter/services/calendar/interface.py +121 -0
- omnidapter/services/calendar/models.py +179 -0
- omnidapter/services/calendar/requests.py +106 -0
- omnidapter/stores/__init__.py +0 -0
- omnidapter/stores/credentials.py +91 -0
- omnidapter/stores/memory.py +63 -0
- omnidapter/stores/oauth_state.py +40 -0
- omnidapter/testing/__init__.py +17 -0
- omnidapter/testing/contracts/__init__.py +0 -0
- omnidapter/testing/contracts/calendar.py +104 -0
- omnidapter/testing/fakes/__init__.py +0 -0
- omnidapter/testing/fakes/stores.py +21 -0
- omnidapter/transport/__init__.py +0 -0
- omnidapter/transport/client.py +365 -0
- omnidapter/transport/correlation.py +12 -0
- omnidapter/transport/hooks.py +69 -0
- omnidapter/transport/retry.py +53 -0
- omnidapter-0.3.2.dist-info/METADATA +9 -0
- omnidapter-0.3.2.dist-info/RECORD +68 -0
- omnidapter-0.3.2.dist-info/WHEEL +4 -0
omnidapter/__init__.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Omnidapter — provider-agnostic async integration library.
|
|
3
|
+
|
|
4
|
+
Quick start:
|
|
5
|
+
from omnidapter import Omnidapter
|
|
6
|
+
from omnidapter.transport.retry import RetryPolicy
|
|
7
|
+
|
|
8
|
+
omni = Omnidapter(
|
|
9
|
+
credential_store=my_store,
|
|
10
|
+
oauth_state_store=my_state_store,
|
|
11
|
+
)
|
|
12
|
+
conn = await omni.connection("conn_123")
|
|
13
|
+
calendar = conn.calendar()
|
|
14
|
+
|
|
15
|
+
async for event in calendar.list_events(calendar_id="primary"):
|
|
16
|
+
print(event.summary)
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from omnidapter.auth.models import (
|
|
20
|
+
ApiKeyCredentials,
|
|
21
|
+
BaseCredentials,
|
|
22
|
+
BasicCredentials,
|
|
23
|
+
OAuth2Credentials,
|
|
24
|
+
)
|
|
25
|
+
from omnidapter.core.errors import (
|
|
26
|
+
AuthError,
|
|
27
|
+
ConnectionNotFoundError,
|
|
28
|
+
InvalidCredentialFormatError,
|
|
29
|
+
OAuthStateError,
|
|
30
|
+
OmnidapterError,
|
|
31
|
+
ProviderAPIError,
|
|
32
|
+
ProviderNotConfiguredError,
|
|
33
|
+
RateLimitError,
|
|
34
|
+
ScopeInsufficientError,
|
|
35
|
+
TokenRefreshError,
|
|
36
|
+
TransportError,
|
|
37
|
+
UnsupportedCapabilityError,
|
|
38
|
+
)
|
|
39
|
+
from omnidapter.core.metadata import AuthKind, ProviderMetadata, ServiceKind
|
|
40
|
+
from omnidapter.core.omnidapter import Omnidapter
|
|
41
|
+
from omnidapter.core.registry import ProviderRegistry
|
|
42
|
+
from omnidapter.services.calendar.capabilities import CalendarCapability
|
|
43
|
+
from omnidapter.services.calendar.models import (
|
|
44
|
+
Attendee,
|
|
45
|
+
AvailabilityResponse,
|
|
46
|
+
Calendar,
|
|
47
|
+
CalendarEvent,
|
|
48
|
+
ConferenceData,
|
|
49
|
+
EventStatus,
|
|
50
|
+
Organizer,
|
|
51
|
+
Recurrence,
|
|
52
|
+
)
|
|
53
|
+
from omnidapter.services.calendar.requests import (
|
|
54
|
+
CreateCalendarRequest,
|
|
55
|
+
CreateEventRequest,
|
|
56
|
+
GetAvailabilityRequest,
|
|
57
|
+
UpdateCalendarRequest,
|
|
58
|
+
UpdateEventRequest,
|
|
59
|
+
)
|
|
60
|
+
from omnidapter.stores.credentials import CredentialStore, StoredCredential
|
|
61
|
+
from omnidapter.stores.memory import InMemoryCredentialStore, InMemoryOAuthStateStore
|
|
62
|
+
from omnidapter.stores.oauth_state import OAuthStateStore
|
|
63
|
+
from omnidapter.transport.retry import RetryPolicy
|
|
64
|
+
|
|
65
|
+
# Report the version recorded in the installed distribution's metadata so it
# cannot drift from the packaging version. Fall back to the release number
# when the package is imported from a tree without installed metadata.
import importlib.metadata as _metadata

try:
    __version__ = _metadata.version("omnidapter")
except _metadata.PackageNotFoundError:
    __version__ = "0.3.2"
|
|
66
|
+
|
|
67
|
+
# Public API of the package: ``from omnidapter import *`` binds exactly these
# names. Every name listed here is imported at the top of this module; keep
# the two in sync when adding or removing exports.
__all__ = [
    "Omnidapter",
    # Errors
    "OmnidapterError",
    "AuthError",
    "OAuthStateError",
    "ProviderNotConfiguredError",
    "TokenRefreshError",
    "UnsupportedCapabilityError",
    "ConnectionNotFoundError",
    "InvalidCredentialFormatError",
    "ScopeInsufficientError",
    "TransportError",
    "ProviderAPIError",
    "RateLimitError",
    # Auth
    "BaseCredentials",
    "OAuth2Credentials",
    "ApiKeyCredentials",
    "BasicCredentials",
    # Stores
    "StoredCredential",
    "CredentialStore",
    "OAuthStateStore",
    "InMemoryCredentialStore",
    "InMemoryOAuthStateStore",
    # Transport
    "RetryPolicy",
    # Registry
    "ProviderRegistry",
    # Metadata
    "AuthKind",
    "ServiceKind",
    "ProviderMetadata",
    # Calendar
    "CalendarCapability",
    "CalendarEvent",
    "EventStatus",
    "Calendar",
    "AvailabilityResponse",
    "Attendee",
    "Organizer",
    "Recurrence",
    "ConferenceData",
    "CreateCalendarRequest",
    "CreateEventRequest",
    "UpdateCalendarRequest",
    "UpdateEventRequest",
    "GetAvailabilityRequest",
    # Version
    "__version__",
]
|
omnidapter/auth/models.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Typed credential payload models.
|
|
3
|
+
|
|
4
|
+
These are the auth-kind-specific payloads stored inside a ``StoredCredential``.
|
|
5
|
+
Every concrete credential type must inherit from :class:`BaseCredentials` so that
|
|
6
|
+
calling code can use ``isinstance(creds, BaseCredentials)`` to distinguish a
|
|
7
|
+
credential envelope's payload from arbitrary dicts.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from datetime import datetime, timedelta, timezone
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
from pydantic import BaseModel, Field
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BaseCredentials(BaseModel):
    """Base class for all credential payload types.

    ``OAuth2Credentials``, ``ApiKeyCredentials``, and ``BasicCredentials`` all
    inherit from this class. It is the common base for ``isinstance`` checks
    and an extension point for shared behaviour in future auth kinds.

    Consuming apps that implement custom credential types for non-standard
    auth schemes should also inherit from ``BaseCredentials`` so that
    ``StoredCredential.credentials`` remains uniformly typed.

    The class itself declares no fields; it is a plain pydantic ``BaseModel``
    subclass used as a marker type.
    """
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class OAuth2Credentials(BaseCredentials):
    """OAuth2 token payload.

    ``expires_at`` is interpreted as UTC. A naive datetime (for example one
    parsed from an ISO-8601 string that carries no offset) is treated as UTC
    when checking expiry, instead of being compared with an aware "now",
    which would raise ``TypeError``.
    """

    access_token: str
    refresh_token: str | None = None
    token_type: str = "Bearer"
    expires_at: datetime | None = None  # UTC
    id_token: str | None = None
    # NOTE(review): presumably the unparsed token-endpoint response, kept for
    # provider-specific fields — confirm against the provider implementations.
    raw: dict[str, Any] = Field(default_factory=dict)

    def is_expired(self, buffer_seconds: float = 60.0) -> bool:
        """Return True if the access token is expired (or within *buffer_seconds* of expiry).

        A token with no known ``expires_at`` is never reported as expired.
        """
        if self.expires_at is None:
            return False
        expires_at = self.expires_at
        if expires_at.tzinfo is None or expires_at.utcoffset() is None:
            # The field is documented as UTC; make it aware so the comparison
            # below does not fail with a naive/aware TypeError.
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        return datetime.now(tz=timezone.utc) >= expires_at - timedelta(seconds=buffer_seconds)

    def is_refreshable(self) -> bool:
        """Return True if a non-empty refresh_token is available.

        An empty string cannot be exchanged for a new access token, so it is
        treated the same as a missing refresh token.
        """
        return bool(self.refresh_token)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class ApiKeyCredentials(BaseCredentials):
    """API key auth payload.

    Holds the secret key and the name of the header associated with it
    (``X-API-Key`` unless overridden).
    """

    api_key: str
    # NOTE(review): presumably the HTTP request header the key is sent in; the
    # code that applies it is not in this module — confirm in the transport.
    header_name: str = "X-API-Key"
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class BasicCredentials(BaseCredentials):
    """HTTP Basic auth payload.

    Both fields are required plain strings; this model applies no hashing or
    encoding to the password.
    """

    username: str
    password: str
|
omnidapter/auth/oauth.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
"""
|
|
2
|
+
OAuth 2.0 flow helpers — begin and complete flows.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import hashlib
|
|
8
|
+
import secrets
|
|
9
|
+
import urllib.parse
|
|
10
|
+
from datetime import datetime, timedelta, timezone
|
|
11
|
+
from typing import TYPE_CHECKING, Any
|
|
12
|
+
|
|
13
|
+
from pydantic import BaseModel
|
|
14
|
+
|
|
15
|
+
from omnidapter.core.logging import auth_logger
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
import httpx
|
|
19
|
+
|
|
20
|
+
from omnidapter.core.registry import ProviderRegistry
|
|
21
|
+
from omnidapter.stores.credentials import CredentialStore
|
|
22
|
+
from omnidapter.stores.oauth_state import OAuthStateStore
|
|
23
|
+
from omnidapter.transport.hooks import TransportHooks
|
|
24
|
+
from omnidapter.transport.retry import RetryPolicy
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class OAuthBeginResult(BaseModel):
    """Result of beginning an OAuth flow.

    Returned by ``OAuthHelper.begin``: redirect the user to
    ``authorization_url``, then pass ``state`` (echoed back by the provider),
    ``connection_id`` and ``provider`` to ``OAuthHelper.complete``.
    """

    # Provider authorization endpoint with the URL-encoded query parameters.
    authorization_url: str
    # Random key under which the pending flow was saved in the OAuthStateStore.
    state: str
    connection_id: str
    provider: str
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class OAuthPendingState(BaseModel):
    """Payload stored in the OAuthStateStore during a pending OAuth flow.

    ``OAuthHelper.begin`` saves it (serialised with ``model_dump(mode="json")``)
    and ``OAuthHelper.complete`` loads it to validate the callback.
    """

    connection_id: str
    provider: str
    # Must equal the redirect_uri passed to ``complete``.
    redirect_uri: str
    # PKCE verifier; None when the provider's OAuth config does not support PKCE.
    code_verifier: str | None = None
    expires_at: datetime  # UTC
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _generate_pkce_pair() -> tuple[str, str]:
|
|
47
|
+
"""Generate a PKCE code_verifier and code_challenge pair (S256 method)."""
|
|
48
|
+
import base64
|
|
49
|
+
|
|
50
|
+
verifier = secrets.token_urlsafe(64)
|
|
51
|
+
challenge = (
|
|
52
|
+
base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest()).rstrip(b"=").decode()
|
|
53
|
+
)
|
|
54
|
+
return verifier, challenge
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class OAuthHelper:
    """Manages OAuth begin/complete flows with automatic credential persistence.

    :meth:`begin` builds the provider's authorization URL and saves a
    short-lived :class:`OAuthPendingState` under a random ``state`` value.
    :meth:`complete` validates and consumes that state, exchanges the
    authorization code through the provider implementation, and saves the
    resulting credential to the credential store.
    """

    # Query parameters that bind the authorization redirect to the stored
    # pending state. If ``extra_params`` overrode any of them, the callback
    # could not complete: the state lookup, the redirect_uri check, or the
    # provider's PKCE verification would fail.
    _FLOW_BINDING_PARAMS = frozenset(
        {"state", "redirect_uri", "code_challenge", "code_challenge_method"}
    )

    # How long a pending OAuth state stays valid after ``begin``.
    _STATE_TTL = timedelta(minutes=15)

    def __init__(
        self,
        registry: ProviderRegistry,
        credential_store: CredentialStore,
        oauth_state_store: OAuthStateStore,
        retry_policy: RetryPolicy | None = None,
        hooks: TransportHooks | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        self._registry = registry
        self._credential_store = credential_store
        self._oauth_state_store = oauth_state_store
        self._retry_policy = retry_policy
        self._hooks = hooks
        self._http_client = http_client

    def _configure_provider_transport(self, provider_impl: Any) -> None:
        """Forward this helper's transport settings to *provider_impl* if it accepts them."""
        configure = getattr(provider_impl, "configure_oauth_transport", None)
        if callable(configure):
            configure(
                retry_policy=self._retry_policy,
                hooks=self._hooks,
                http_client=self._http_client,
            )

    async def begin(
        self,
        provider: str,
        connection_id: str,
        redirect_uri: str,
        scopes: list[str] | None = None,
        extra_params: dict[str, str] | None = None,
    ) -> OAuthBeginResult:
        """Begin an OAuth flow.

        Generates the authorization URL, persists temporary state, and returns
        the URL the user should be redirected to.

        Args:
            provider: Registry key of the provider.
            connection_id: Connection the resulting credential will be saved under.
            redirect_uri: Callback URI; must be passed unchanged to :meth:`complete`.
            scopes: Scopes to request. The provider's default scopes are used
                when this is ``None`` or empty.
            extra_params: Additional authorization-URL query parameters. They
                are applied last and may override provider defaults, but not
                the flow-binding parameters ``state``, ``redirect_uri``,
                ``code_challenge`` or ``code_challenge_method``.

        Raises:
            ValueError: If the provider has no OAuth2 configuration, or if
                ``extra_params`` tries to override a flow-binding parameter.
        """
        provider_impl = self._registry.get(provider)
        oauth_config = provider_impl.get_oauth_config()
        if oauth_config is None:
            raise ValueError(f"Provider {provider!r} does not support OAuth2")

        if extra_params:
            clashing = sorted(self._FLOW_BINDING_PARAMS.intersection(extra_params))
            if clashing:
                raise ValueError(
                    f"extra_params may not override flow-binding parameters: {clashing}"
                )

        state_id = secrets.token_urlsafe(32)
        code_verifier: str | None = None

        params: dict[str, str] = {
            "response_type": "code",
            "client_id": oauth_config.client_id,
            "redirect_uri": redirect_uri,
            "state": state_id,
        }

        # Scopes
        effective_scopes = scopes or oauth_config.default_scopes
        if effective_scopes:
            params["scope"] = oauth_config.scope_separator.join(effective_scopes)

        # Provider-defined extra params (e.g. access_type=offline for Google)
        if oauth_config.extra_auth_params:
            params.update(oauth_config.extra_auth_params)

        # PKCE
        if oauth_config.supports_pkce:
            code_verifier, code_challenge = _generate_pkce_pair()
            params["code_challenge"] = code_challenge
            params["code_challenge_method"] = "S256"

        # Caller overrides are applied last so they win over provider defaults
        # (flow-binding keys were rejected above).
        if extra_params:
            params.update(extra_params)

        authorization_url = (
            oauth_config.authorization_endpoint + "?" + urllib.parse.urlencode(params)
        )

        pending = OAuthPendingState(
            connection_id=connection_id,
            provider=provider,
            redirect_uri=redirect_uri,
            code_verifier=code_verifier,
            expires_at=datetime.now(tz=timezone.utc) + self._STATE_TTL,
        )

        await self._oauth_state_store.save_state(
            state_id=state_id,
            payload=pending.model_dump(mode="json"),
            expires_at=pending.expires_at,
        )

        auth_logger.info(
            "OAuth begin: provider=%r connection_id=%r",
            provider,
            connection_id,
        )

        return OAuthBeginResult(
            authorization_url=authorization_url,
            state=state_id,
            connection_id=connection_id,
            provider=provider,
        )

    async def complete(
        self,
        provider: str,
        connection_id: str,
        code: str,
        state: str,
        redirect_uri: str,
    ) -> Any:
        """Complete an OAuth flow.

        Validates the pending state, consumes it, exchanges the code for
        tokens, and persists the credentials. The state is deleted before
        the code exchange, so it can complete at most one flow. If the
        exchange fails, the flow must be restarted with :meth:`begin`.

        Returns:
            The StoredCredential (for inspection only — already persisted).

        Raises:
            OAuthStateError: If the state is unknown, expired, or was issued
                for a different connection_id, provider, or redirect_uri.
        """
        from omnidapter.core.errors import OAuthStateError

        # Load and validate state
        state_payload = await self._oauth_state_store.load_state(state)
        if state_payload is None:
            raise OAuthStateError("OAuth state not found or expired")

        pending = OAuthPendingState.model_validate(state_payload)

        # Enforce expiry here as well. The store's handling of expires_at is
        # not visible from this module, so do not rely on it alone.
        expires_at = pending.expires_at
        if expires_at.tzinfo is None or expires_at.utcoffset() is None:
            # OAuthPendingState.expires_at is documented as UTC.
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        if datetime.now(tz=timezone.utc) >= expires_at:
            await self._oauth_state_store.delete_state(state)
            raise OAuthStateError("OAuth state expired")

        if pending.connection_id != connection_id:
            raise OAuthStateError("OAuth state connection_id mismatch")
        if pending.provider != provider:
            raise OAuthStateError("OAuth state provider mismatch")
        if pending.redirect_uri != redirect_uri:
            raise OAuthStateError("OAuth state redirect_uri mismatch")

        # Consume the state before any network I/O. This stops a replayed or
        # concurrently delivered callback from passing validation again.
        await self._oauth_state_store.delete_state(state)

        # Exchange code for tokens
        provider_impl = self._registry.get(provider)
        self._configure_provider_transport(provider_impl)
        stored_credential = await provider_impl.exchange_code_for_tokens(
            connection_id=connection_id,
            code=code,
            redirect_uri=redirect_uri,
            code_verifier=pending.code_verifier,
        )

        # Persist credentials
        await self._credential_store.save_credentials(connection_id, stored_credential)

        auth_logger.info(
            "OAuth complete: provider=%r connection_id=%r",
            provider,
            connection_id,
        )

        return stored_credential
|
|
omnidapter/auth/refresh.py
@@ -0,0 +1,94 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Automatic token refresh logic.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING

from omnidapter.core.logging import auth_logger

if TYPE_CHECKING:
    import httpx

    from omnidapter.core.registry import ProviderRegistry
    from omnidapter.stores.credentials import CredentialStore, StoredCredential
    from omnidapter.transport.hooks import TransportHooks
    from omnidapter.transport.retry import RetryPolicy
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TokenRefreshManager:
    """Manages automatic token refresh for OAuth2 credentials.

    Refreshes for one ``connection_id`` are serialised with an in-process
    :class:`asyncio.Lock`, and the credential is re-read from the store once
    the lock is held. Without this, several tasks that see the same expired
    token at once would each call the provider's refresh endpoint with the
    same refresh token. Providers that rotate refresh tokens may reject every
    use after the first. The locks belong to this manager instance and do not
    coordinate across processes.
    """

    def __init__(
        self,
        registry: ProviderRegistry,
        credential_store: CredentialStore,
        retry_policy: RetryPolicy | None = None,
        hooks: TransportHooks | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        self._registry = registry
        self._credential_store = credential_store
        self._retry_policy = retry_policy
        self._hooks = hooks
        self._http_client = http_client
        # Per-connection refresh locks, created on first use. Entries are not
        # evicted; the cost is one small object per connection id refreshed.
        self._refresh_locks: dict[str, asyncio.Lock] = {}

    def _configure_provider_transport(self, provider_impl: object) -> None:
        """Forward this manager's transport settings to *provider_impl* if it accepts them."""
        configure = getattr(provider_impl, "configure_oauth_transport", None)
        if callable(configure):
            configure(
                retry_policy=self._retry_policy,
                hooks=self._hooks,
                http_client=self._http_client,
            )

    def _refresh_lock(self, connection_id: str) -> asyncio.Lock:
        """Return the lock guarding refreshes for *connection_id*, creating it on first use."""
        lock = self._refresh_locks.get(connection_id)
        if lock is None:
            lock = asyncio.Lock()
            self._refresh_locks[connection_id] = lock
        return lock

    async def _load(self, connection_id: str) -> StoredCredential:
        """Load the stored credential for *connection_id*.

        Raises:
            ConnectionNotFoundError: If the store has no credential for it.
        """
        stored = await self._credential_store.get_credentials(connection_id)
        if stored is None:
            from omnidapter.core.errors import ConnectionNotFoundError

            raise ConnectionNotFoundError(connection_id)
        return stored

    @staticmethod
    def _needs_refresh(connection_id: str, stored: StoredCredential) -> bool:
        """Return True if *stored* is an expired, refreshable OAuth2 credential.

        Non-OAuth2 credentials and unexpired tokens never need a refresh. An
        expired token without a refresh token is logged as a warning and
        reported as not refreshable.
        """
        from omnidapter.auth.models import OAuth2Credentials
        from omnidapter.core.metadata import AuthKind

        # Only refresh OAuth2 credentials
        if stored.auth_kind != AuthKind.OAUTH2:
            return False

        creds = stored.credentials
        if not isinstance(creds, OAuth2Credentials):
            return False

        if not creds.is_expired():
            return False

        if not creds.is_refreshable():
            auth_logger.warning(
                "Token expired but no refresh_token available: connection_id=%r",
                connection_id,
            )
            return False

        return True

    async def _refresh(self, connection_id: str, stored: StoredCredential) -> StoredCredential:
        """Refresh *stored* through its provider and persist the result."""
        auth_logger.info("Refreshing token: connection_id=%r", connection_id)

        provider_impl = self._registry.get(stored.provider_key)
        self._configure_provider_transport(provider_impl)
        updated = await provider_impl.refresh_token(stored)

        # Persist updated credentials
        await self._credential_store.save_credentials(connection_id, updated)

        auth_logger.info(
            "Token refreshed successfully: connection_id=%r provider=%r",
            connection_id,
            stored.provider_key,
        )

        return updated

    async def ensure_fresh(self, connection_id: str) -> StoredCredential:
        """Ensure credentials are fresh, refreshing if necessary.

        An expired OAuth2 token with no refresh token is returned unchanged,
        after a warning is logged.

        Returns:
            Fresh StoredCredential.

        Raises:
            ConnectionNotFoundError: If no credential is stored for *connection_id*.
        """
        stored = await self._load(connection_id)
        if not self._needs_refresh(connection_id, stored):
            return stored

        async with self._refresh_lock(connection_id):
            # Another task may have refreshed and persisted while we waited;
            # re-read so the same refresh token is not spent twice.
            stored = await self._load(connection_id)
            if not self._needs_refresh(connection_id, stored):
                return stored
            return await self._refresh(connection_id, stored)
|
|
omnidapter/core/__init__.py
File without changes
|
|
omnidapter/core/connection.py
@@ -0,0 +1,106 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Connection represents authorization to a provider account.
|
|
3
|
+
|
|
4
|
+
Services are accessed from a connection.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from collections.abc import Awaitable, Callable
|
|
10
|
+
from typing import TYPE_CHECKING, Any
|
|
11
|
+
|
|
12
|
+
from omnidapter.core.metadata import ServiceKind
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
import httpx
|
|
16
|
+
|
|
17
|
+
from omnidapter.core.registry import ProviderRegistry
|
|
18
|
+
from omnidapter.services.calendar.interface import CalendarService
|
|
19
|
+
from omnidapter.stores.credentials import StoredCredential
|
|
20
|
+
from omnidapter.transport.hooks import TransportHooks
|
|
21
|
+
from omnidapter.transport.retry import RetryPolicy
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class Connection:
    """An authorized link to one provider account, from which services are obtained.

    Example::

        conn = await omni.connection("conn_123")
        calendar = conn.calendar()
        await calendar.list_calendars()
    """

    def __init__(
        self,
        connection_id: str,
        stored_credential: StoredCredential,
        registry: ProviderRegistry,
        retry_policy: RetryPolicy | None = None,
        hooks: TransportHooks | None = None,
        credential_resolver: Callable[[str], Awaitable[StoredCredential]] | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        # Identity and the credential envelope this connection wraps.
        self._connection_id = connection_id
        self._stored = stored_credential
        # Provider lookup, plus transport settings handed to the services it builds.
        self._registry = registry
        self._retry_policy = retry_policy
        self._hooks = hooks
        self._http_client = http_client
        # Optional async callback injected into each service that is built.
        self._credential_resolver = credential_resolver

    @property
    def connection_id(self) -> str:
        """The identifier this connection was created with."""
        return self._connection_id

    @property
    def provider_key(self) -> str:
        """Registry key of the provider recorded on the stored credential."""
        return self._stored.provider_key

    @property
    def stored_credential(self) -> StoredCredential:
        """The credential envelope this connection was created with."""
        return self._stored

    def _provider(self) -> Any:
        """Look up this connection's provider implementation in the registry."""
        return self._registry.get(self._stored.provider_key)

    def supports(self, service: ServiceKind) -> bool:
        """Report whether this connection's provider lists *service* in its metadata."""
        offered = self._provider().metadata.services
        return service in offered

    def _configure_service_runtime(self, service: CalendarService) -> CalendarService:
        """Inject connection-level runtime settings into a freshly built service.

        Sets the service's ``_credential_resolver`` when a resolver was given.
        When a shared ``httpx.AsyncClient`` was given, passes it to the
        service's ``_http`` object, provided that object has a callable
        ``set_shared_client``. Returns the same service object.
        """
        runtime: Any = service

        resolver = self._credential_resolver
        if resolver is not None:
            runtime._credential_resolver = resolver

        shared_client = self._http_client
        if shared_client is None:
            return service

        http_transport = getattr(runtime, "_http", None)
        attach_client = getattr(http_transport, "set_shared_client", None)
        if callable(attach_client):
            attach_client(shared_client)
        return service

    def calendar(self) -> CalendarService:
        """Build the calendar service for this connection.

        Raises:
            UnsupportedCapabilityError: If the provider does not offer calendars.
                Call ``conn.supports(ServiceKind.CALENDAR)`` first to avoid this.
        """
        if self.supports(ServiceKind.CALENDAR):
            built = self._provider().get_calendar_service(
                connection_id=self._connection_id,
                stored_credential=self._stored,
                retry_policy=self._retry_policy,
                hooks=self._hooks,
            )
            return self._configure_service_runtime(built)

        from omnidapter.core.errors import UnsupportedCapabilityError

        raise UnsupportedCapabilityError(
            f"Provider {self._stored.provider_key!r} does not support calendars. "
            "Check conn.supports(ServiceKind.CALENDAR) before calling conn.calendar().",
            provider_key=self._stored.provider_key,
            capability=ServiceKind.CALENDAR,
        )
|