visionai-sdk-python 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- visionai_sdk_python/__init__.py +30 -0
- visionai_sdk_python/_base.py +203 -0
- visionai_sdk_python/_jwt_verifier.py +207 -0
- visionai_sdk_python/async_client.py +301 -0
- visionai_sdk_python/client.py +300 -0
- visionai_sdk_python/constants.py +49 -0
- visionai_sdk_python/endpoints.py +10 -0
- visionai_sdk_python/exceptions.py +40 -0
- visionai_sdk_python/models.py +39 -0
- visionai_sdk_python/py.typed +0 -0
- visionai_sdk_python-0.1.0.dist-info/METADATA +14 -0
- visionai_sdk_python-0.1.0.dist-info/RECORD +13 -0
- visionai_sdk_python-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""VisionAI SDK for Python.
|
|
2
|
+
|
|
3
|
+
A client library for interacting with VisionAI authentication and VLM services.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from .async_client import AsyncClient
|
|
7
|
+
from .client import Client
|
|
8
|
+
from .exceptions import (
|
|
9
|
+
AuthenticationError,
|
|
10
|
+
ClientError,
|
|
11
|
+
JwksDiscoveryError,
|
|
12
|
+
NetworkError,
|
|
13
|
+
PermissionDeniedError,
|
|
14
|
+
ServerError,
|
|
15
|
+
VisionaiSDKError,
|
|
16
|
+
)
|
|
17
|
+
from .models import TokenResponse
|
|
18
|
+
|
|
19
|
+
__all__ = [
|
|
20
|
+
"AsyncClient",
|
|
21
|
+
"Client",
|
|
22
|
+
"TokenResponse",
|
|
23
|
+
"VisionaiSDKError",
|
|
24
|
+
"AuthenticationError",
|
|
25
|
+
"PermissionDeniedError",
|
|
26
|
+
"ClientError",
|
|
27
|
+
"ServerError",
|
|
28
|
+
"NetworkError",
|
|
29
|
+
"JwksDiscoveryError",
|
|
30
|
+
]
|
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
import time
|
|
2
|
+
import httpx
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from .exceptions import (
|
|
6
|
+
APIError,
|
|
7
|
+
AuthenticationError,
|
|
8
|
+
ClientError,
|
|
9
|
+
PermissionDeniedError,
|
|
10
|
+
ServerError,
|
|
11
|
+
)
|
|
12
|
+
from ._jwt_verifier import JwtVerifier
|
|
13
|
+
from .constants import resolve_allowed_issuers
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class _BaseClient:
|
|
17
|
+
"""Base class for VisionAI SDK clients.
|
|
18
|
+
|
|
19
|
+
Holds shared connection configuration and provides
|
|
20
|
+
common URL/header builder utilities.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
def __init__(
|
|
24
|
+
self,
|
|
25
|
+
auth_url: str,
|
|
26
|
+
vlm_url: str,
|
|
27
|
+
allowed_issuers: list[str] | None = None,
|
|
28
|
+
verify_ssl: bool = True,
|
|
29
|
+
timeout: float = 10.0,
|
|
30
|
+
max_connections: int = 100,
|
|
31
|
+
max_keepalive_connections: int = 20,
|
|
32
|
+
) -> None:
|
|
33
|
+
"""Initialize the client with connection settings.
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
auth_url: Base URL for the authentication service.
|
|
37
|
+
vlm_url: Base URL for the VLM inference service.
|
|
38
|
+
allowed_issuers: Optional list of allowed JWT issuers. If provided, tokens
|
|
39
|
+
whose ``iss`` claim is not in this list will be rejected. If omitted,
|
|
40
|
+
issuer validation is skipped.
|
|
41
|
+
verify_ssl: Whether to verify TLS certificates.
|
|
42
|
+
timeout: Default request timeout in seconds.
|
|
43
|
+
max_connections: Maximum number of concurrent connections in the pool.
|
|
44
|
+
max_keepalive_connections: Maximum number of idle keep-alive connections.
|
|
45
|
+
"""
|
|
46
|
+
if not auth_url.strip():
|
|
47
|
+
raise ValueError("auth_url must not be empty")
|
|
48
|
+
if not vlm_url.strip():
|
|
49
|
+
raise ValueError("vlm_url must not be empty")
|
|
50
|
+
self.auth_url = auth_url.strip()
|
|
51
|
+
self.vlm_url = vlm_url.strip()
|
|
52
|
+
self.verify_ssl = verify_ssl
|
|
53
|
+
self.timeout = timeout
|
|
54
|
+
self.max_connections = max_connections
|
|
55
|
+
self.max_keepalive_connections = max_keepalive_connections
|
|
56
|
+
resolved_issuers: list[str] = allowed_issuers if allowed_issuers is not None else resolve_allowed_issuers(self.auth_url)
|
|
57
|
+
self._jwt_verifier = JwtVerifier(
|
|
58
|
+
auth_url=self.auth_url,
|
|
59
|
+
allowed_issuers=resolved_issuers,
|
|
60
|
+
verify_ssl=verify_ssl,
|
|
61
|
+
timeout=timeout,
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
# Token management state
|
|
65
|
+
self._access_token: str | None = None
|
|
66
|
+
self._token_expires_at: float | None = None # monotonic time
|
|
67
|
+
self._credentials: dict | None = None
|
|
68
|
+
self._credentials_type: Literal["login", "client"] | None = None
|
|
69
|
+
|
|
70
|
+
@staticmethod
|
|
71
|
+
def _build_url(base_url: str, path: str) -> str:
|
|
72
|
+
"""Join a base URL and a path, normalizing slashes."""
|
|
73
|
+
return f"{base_url.rstrip('/')}/{path.lstrip('/')}"
|
|
74
|
+
|
|
75
|
+
@staticmethod
|
|
76
|
+
def _build_auth_header(access_token: str) -> dict[str, str]:
|
|
77
|
+
"""Build authorization header."""
|
|
78
|
+
if not access_token:
|
|
79
|
+
raise ValueError("access_token must not be empty")
|
|
80
|
+
return {"Authorization": f"Bearer {access_token}"}
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _store_token(
|
|
84
|
+
self,
|
|
85
|
+
access_token: str,
|
|
86
|
+
expires_in: int,
|
|
87
|
+
credentials: dict | None = None,
|
|
88
|
+
credentials_type: Literal["login", "client"] | None = None,
|
|
89
|
+
) -> None:
|
|
90
|
+
"""Store token and credentials for auto-refresh.
|
|
91
|
+
|
|
92
|
+
Args:
|
|
93
|
+
access_token: JWT access token.
|
|
94
|
+
expires_in: Token expiration time in seconds.
|
|
95
|
+
credentials: Optional credentials for auto-refresh.
|
|
96
|
+
credentials_type: Type of credentials ("login" or "client").
|
|
97
|
+
"""
|
|
98
|
+
self._access_token = access_token
|
|
99
|
+
self._token_expires_at = time.monotonic() + expires_in
|
|
100
|
+
self._credentials = credentials
|
|
101
|
+
self._credentials_type = credentials_type
|
|
102
|
+
|
|
103
|
+
def _is_token_expiring_soon(self, buffer_seconds: int = 30) -> bool:
    """Report whether the stored token is within ``buffer_seconds`` of its deadline.

    Args:
        buffer_seconds: Safety margin, in seconds, subtracted from the deadline.

    Returns:
        True when no deadline is recorded or the margin has been reached,
        False otherwise.
    """
    deadline = self._token_expires_at
    return deadline is None or time.monotonic() >= (deadline - buffer_seconds)
|
|
115
|
+
|
|
116
|
+
def set_token(self, access_token: str, expires_in: int | None = None) -> None:
|
|
117
|
+
"""Set externally obtained token.
|
|
118
|
+
|
|
119
|
+
Use this when you have an access token obtained through other means
|
|
120
|
+
(e.g., frontend authorization code flow). Note that tokens set this way
|
|
121
|
+
cannot be auto-refreshed since no credentials are stored.
|
|
122
|
+
|
|
123
|
+
The token will be validated locally to ensure it has a valid signature
|
|
124
|
+
and has not expired. The actual expiration time from the token's ``exp``
|
|
125
|
+
claim will be used. If ``expires_in`` is provided, the minimum of the two
|
|
126
|
+
will be used to prevent extending the token's lifetime beyond its true expiration.
|
|
127
|
+
|
|
128
|
+
Args:
|
|
129
|
+
access_token: JWT access token.
|
|
130
|
+
expires_in: Optional token expiration time in seconds. If provided,
|
|
131
|
+
the effective expiration will be min(expires_in, token's actual remaining time).
|
|
132
|
+
If None, the token's actual ``exp`` claim will be used.
|
|
133
|
+
|
|
134
|
+
Raises:
|
|
135
|
+
ValueError: If access_token is empty.
|
|
136
|
+
jwt.ExpiredSignatureError: If token has already expired.
|
|
137
|
+
jwt.InvalidSignatureError: If token signature is invalid.
|
|
138
|
+
jwt.DecodeError: If token is malformed.
|
|
139
|
+
jwt.MissingRequiredClaimError: If token is missing required claims.
|
|
140
|
+
jwt.InvalidIssuerError: If token issuer is not allowed.
|
|
141
|
+
|
|
142
|
+
Example:
|
|
143
|
+
>>> client = Client(auth_url="...", vlm_url="...")
|
|
144
|
+
>>> client.set_token("eyJhbG...")
|
|
145
|
+
>>> result = client.chat(payload) # Token will be used automatically
|
|
146
|
+
"""
|
|
147
|
+
if not access_token.strip():
|
|
148
|
+
raise ValueError("access_token must not be empty")
|
|
149
|
+
|
|
150
|
+
# Validate token and extract exp claim
|
|
151
|
+
claims = self._jwt_verifier.verify_sync(access_token)
|
|
152
|
+
jwt_exp = claims.get("exp")
|
|
153
|
+
if jwt_exp is None:
|
|
154
|
+
raise ValueError("Token missing 'exp' claim")
|
|
155
|
+
|
|
156
|
+
# Calculate remaining time from JWT exp
|
|
157
|
+
jwt_expires_in = int(jwt_exp - time.time())
|
|
158
|
+
if jwt_expires_in <= 0:
|
|
159
|
+
raise ValueError("Token has already expired")
|
|
160
|
+
|
|
161
|
+
# Use the minimum of provided expires_in and actual JWT expiration
|
|
162
|
+
if expires_in is not None:
|
|
163
|
+
effective_expires_in = min(expires_in, jwt_expires_in)
|
|
164
|
+
else:
|
|
165
|
+
effective_expires_in = jwt_expires_in
|
|
166
|
+
|
|
167
|
+
self._store_token(access_token, effective_expires_in, credentials=None, credentials_type=None)
|
|
168
|
+
|
|
169
|
+
@staticmethod
def _handle_response(response: httpx.Response) -> httpx.Response:
    """Raise SDK-specific exceptions for non-2xx responses.

    The error message is taken from the ``detail`` or ``message`` field of a
    JSON-object error body when present; otherwise the text of the underlying
    ``httpx.HTTPStatusError`` is used.

    Args:
        response: Response to check.

    Returns:
        The same ``response`` object when ``raise_for_status()`` does not raise.

    Raises:
        AuthenticationError: If the server returns 401
        PermissionDeniedError: If the server returns 403
        ClientError: If the server returns any other 4xx
        ServerError: If the server returns 5xx
        APIError: If the server returns any other non-2xx status
    """
    try:
        response.raise_for_status()
    except httpx.HTTPStatusError as e:
        body = None
        if e.response.content:
            try:
                body = e.response.json()
            except ValueError:
                # Error body is not valid JSON (or not decodable); fall back to str(e).
                body = None

        raw_detail = None
        if isinstance(body, dict):
            raw_detail = body.get("detail") or body.get("message")
        if not raw_detail:
            detail: str = str(e)
        elif isinstance(raw_detail, str):
            detail = raw_detail
        else:
            # Structured details (lists/dicts) are stringified so the
            # exception message is always a string.
            detail = str(raw_detail)

        status = e.response.status_code
        if status == 401:
            raise AuthenticationError(detail) from e
        if status == 403:
            raise PermissionDeniedError(detail) from e
        if 400 <= status < 500:
            raise ClientError(status, detail) from e
        if 500 <= status < 600:
            raise ServerError(status, detail) from e
        raise APIError(status, detail) from e
    return response
|
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import ssl
|
|
3
|
+
import time
|
|
4
|
+
|
|
5
|
+
import httpx
|
|
6
|
+
import jwt
|
|
7
|
+
from jwt import PyJWKClient
|
|
8
|
+
|
|
9
|
+
from .exceptions import JwksDiscoveryError
|
|
10
|
+
|
|
11
|
+
_OIDC_DISCOVERY_PATH = "/.well-known/openid-configuration"
|
|
12
|
+
_JWKS_URI_TTL: float = 3600.0 # seconds; refresh OIDC discovery cache after 1 hour
|
|
13
|
+
# Supported asymmetric signing algorithms:
|
|
14
|
+
# - RS256/384/512: RSA signature with SHA-256/384/512
|
|
15
|
+
# - ES256/384/512: ECDSA signature with SHA-256/384/512
|
|
16
|
+
_ALLOWED_ALGORITHMS: list[str] = ["RS256", "RS384", "RS512", "ES256", "ES384", "ES512"]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _get_insecure_context() -> ssl.SSLContext:
    """Build an SSL context with hostname checking and certificate verification turned off."""
    insecure = ssl.create_default_context()
    # Hostname checking is disabled first, then certificate verification.
    insecure.check_hostname = False
    insecure.verify_mode = ssl.CERT_NONE
    return insecure
|
|
24
|
+
|
|
25
|
+
class JwtVerifier:
|
|
26
|
+
"""Stateful JWT verifier that handles OIDC discovery and JWKS key fetching.
|
|
27
|
+
|
|
28
|
+
Maintains two internal caches:
|
|
29
|
+
- ``_jwks_uri_cache``: issuer → jwks_uri with TTL-based expiry.
|
|
30
|
+
- ``_jwks_clients``: jwks_uri → PyJWKClient (key-level caching delegated to PyJWKClient).
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
def __init__(
|
|
34
|
+
self,
|
|
35
|
+
auth_url: str,
|
|
36
|
+
allowed_issuers: list[str] | None = None,
|
|
37
|
+
verify_ssl: bool = True,
|
|
38
|
+
timeout: float = 10.0,
|
|
39
|
+
) -> None:
|
|
40
|
+
self._auth_url = auth_url
|
|
41
|
+
self._allowed_issuers: frozenset[str] = frozenset(allowed_issuers) if allowed_issuers else frozenset()
|
|
42
|
+
self._verify_ssl = verify_ssl
|
|
43
|
+
self._timeout = timeout
|
|
44
|
+
# issuer -> (jwks_uri, expire_at)
|
|
45
|
+
self._jwks_uri_cache: dict[str, tuple[str, float]] = {}
|
|
46
|
+
# jwks_uri -> PyJWKClient
|
|
47
|
+
self._jwks_clients: dict[str, PyJWKClient] = {}
|
|
48
|
+
|
|
49
|
+
# ------------------------------------------------------------------
|
|
50
|
+
# Cache helpers
|
|
51
|
+
# ------------------------------------------------------------------
|
|
52
|
+
|
|
53
|
+
def _get_cached_jwks_uri(self, issuer: str) -> str | None:
|
|
54
|
+
entry = self._jwks_uri_cache.get(issuer)
|
|
55
|
+
if entry is None:
|
|
56
|
+
return None
|
|
57
|
+
jwks_uri, expire_at = entry
|
|
58
|
+
if time.monotonic() >= expire_at:
|
|
59
|
+
del self._jwks_uri_cache[issuer]
|
|
60
|
+
return None
|
|
61
|
+
return jwks_uri
|
|
62
|
+
|
|
63
|
+
def _cache_jwks_uri(self, issuer: str, jwks_uri: str) -> None:
    """Remember ``jwks_uri`` for ``issuer`` until ``_JWKS_URI_TTL`` seconds from now."""
    expires_at = time.monotonic() + _JWKS_URI_TTL
    self._jwks_uri_cache[issuer] = (jwks_uri, expires_at)
|
|
65
|
+
|
|
66
|
+
def _validate_issuer(self, issuer: str) -> None:
    """Check ``issuer`` (ignoring trailing slashes) against the allow-list.

    An empty allow-list accepts every issuer.

    Raises:
        jwt.InvalidIssuerError: If an allow-list is configured and the issuer is not in it.
    """
    normalized = issuer.rstrip("/")
    if not self._allowed_issuers or normalized in self._allowed_issuers:
        return
    raise jwt.InvalidIssuerError(
        f"Token issuer '{normalized}' is not in the allowed issuers list"
    )
|
|
77
|
+
|
|
78
|
+
def _get_jwks_client(self, jwks_uri: str) -> PyJWKClient:
    """Return the ``PyJWKClient`` for ``jwks_uri``, creating and caching it on first use."""
    if jwks_uri in self._jwks_clients:
        return self._jwks_clients[jwks_uri]

    # With verification disabled, hand PyJWKClient an unverified SSL context;
    # otherwise pass None for ssl_context.
    if self._verify_ssl:
        ssl_context = None
    else:
        ssl_context = _get_insecure_context()
    client = PyJWKClient(
        jwks_uri,
        cache_keys=True,
        timeout=self._timeout,
        ssl_context=ssl_context,
    )
    self._jwks_clients[jwks_uri] = client
    return self._jwks_clients[jwks_uri]
|
|
88
|
+
|
|
89
|
+
# ------------------------------------------------------------------
|
|
90
|
+
# Pure JWT helpers (no I/O)
|
|
91
|
+
# ------------------------------------------------------------------
|
|
92
|
+
|
|
93
|
+
@staticmethod
def _get_issuer(access_token: str) -> str:
    """Return the ``iss`` claim of the token without verifying its signature.

    The value comes from an unverified token and must be treated as untrusted.

    Raises:
        jwt.MissingRequiredClaimError: If the ``iss`` claim is absent or empty.
        jwt.InvalidIssuerError: If the ``iss`` claim is present but not a string.
    """
    claims = jwt.decode(
        access_token,
        options={"verify_signature": False, "verify_exp": False},
        algorithms=_ALLOWED_ALGORITHMS,
    )
    issuer = claims.get("iss", "")
    if not issuer:
        raise jwt.MissingRequiredClaimError("iss")
    if not isinstance(issuer, str):
        # Callers apply str methods to the issuer; reject other JSON types here
        # with a JWT error rather than letting an AttributeError escape later.
        raise jwt.InvalidIssuerError("Token 'iss' claim must be a string")
    return issuer
|
|
105
|
+
|
|
106
|
+
@staticmethod
def _decode_verified(access_token: str, signing_key: object) -> dict:
    """Decode ``access_token`` with ``signing_key`` and return its claims.

    Expiration checking is enabled and audience checking is disabled; only the
    algorithms in ``_ALLOWED_ALGORITHMS`` are accepted.
    """
    decode_options = {"verify_exp": True, "verify_aud": False}
    verified_claims = jwt.decode(
        access_token,
        signing_key,  # type: ignore[arg-type]
        algorithms=_ALLOWED_ALGORITHMS,
        options=decode_options,
    )
    return verified_claims
|
|
115
|
+
|
|
116
|
+
# ------------------------------------------------------------------
|
|
117
|
+
# OIDC discovery: sync / async
|
|
118
|
+
# ------------------------------------------------------------------
|
|
119
|
+
|
|
120
|
+
def _fetch_jwks_uri_sync(self, issuer: str) -> str:
    """Resolve jwks_uri from the OIDC discovery document (blocking).

    A cached value is returned when available; otherwise the discovery document
    at ``<issuer>/.well-known/openid-configuration`` is fetched and the result
    is cached.

    Raises:
        JwksDiscoveryError: If the discovery endpoint is unreachable, returns a
            non-2xx status, returns a body that is not a JSON object, or the
            document has no usable string ``jwks_uri``.
    """
    cached = self._get_cached_jwks_uri(issuer)
    if cached is not None:
        return cached

    discovery_url = f"{issuer.rstrip('/')}{_OIDC_DISCOVERY_PATH}"
    try:
        with httpx.Client(verify=self._verify_ssl, timeout=self._timeout) as http:
            resp = http.get(discovery_url)
            resp.raise_for_status()
            jwks_uri = resp.json()["jwks_uri"]
    except (
        httpx.RequestError,
        httpx.HTTPStatusError,
        KeyError,
        TypeError,  # body is JSON but not an object (e.g. a list)
        ValueError,  # body is not valid JSON
    ) as e:
        raise JwksDiscoveryError(
            f"Failed to fetch OIDC discovery document from '{discovery_url}': {e}"
        ) from e

    if not isinstance(jwks_uri, str) or not jwks_uri:
        # Do not cache or hand a null/non-string value to PyJWKClient.
        raise JwksDiscoveryError(
            f"Failed to fetch OIDC discovery document from '{discovery_url}': invalid 'jwks_uri'"
        )

    self._cache_jwks_uri(issuer, jwks_uri)
    return jwks_uri
|
|
144
|
+
|
|
145
|
+
async def _fetch_jwks_uri_async(self, issuer: str) -> str:
    """Resolve jwks_uri from the OIDC discovery document (non-blocking).

    A cached value is returned when available; otherwise the discovery document
    at ``<issuer>/.well-known/openid-configuration`` is fetched and the result
    is cached.

    Raises:
        JwksDiscoveryError: If the discovery endpoint is unreachable, returns a
            non-2xx status, returns a body that is not a JSON object, or the
            document has no usable string ``jwks_uri``.
    """
    cached = self._get_cached_jwks_uri(issuer)
    if cached is not None:
        return cached

    discovery_url = f"{issuer.rstrip('/')}{_OIDC_DISCOVERY_PATH}"
    try:
        async with httpx.AsyncClient(verify=self._verify_ssl, timeout=self._timeout) as http:
            resp = await http.get(discovery_url)
            resp.raise_for_status()
            jwks_uri = resp.json()["jwks_uri"]
    except (
        httpx.RequestError,
        httpx.HTTPStatusError,
        KeyError,
        TypeError,  # body is JSON but not an object (e.g. a list)
        ValueError,  # body is not valid JSON
    ) as e:
        raise JwksDiscoveryError(
            f"Failed to fetch OIDC discovery document from '{discovery_url}': {e}"
        ) from e

    if not isinstance(jwks_uri, str) or not jwks_uri:
        # Do not cache or hand a null/non-string value to PyJWKClient.
        raise JwksDiscoveryError(
            f"Failed to fetch OIDC discovery document from '{discovery_url}': invalid 'jwks_uri'"
        )

    self._cache_jwks_uri(issuer, jwks_uri)
    return jwks_uri
|
|
169
|
+
|
|
170
|
+
# ------------------------------------------------------------------
|
|
171
|
+
# Public verify API
|
|
172
|
+
# ------------------------------------------------------------------
|
|
173
|
+
|
|
174
|
+
def verify_sync(self, access_token: str) -> dict:
    """Verify a JWT's issuer, signature and expiration (blocking) and return its claims.

    Steps: read the unverified issuer, check it against the allow-list, resolve
    the issuer's jwks_uri, look up the signing key, then decode with verification.

    Raises:
        jwt.ExpiredSignatureError: Token has expired.
        jwt.InvalidSignatureError: Signature does not match.
        jwt.DecodeError: Token is malformed.
        jwt.MissingRequiredClaimError: Token is missing the ``iss`` claim.
        jwt.InvalidIssuerError: Token issuer is not in the allowed issuers list.
    """
    token_issuer = self._get_issuer(access_token)
    self._validate_issuer(token_issuer)
    keys_url = self._fetch_jwks_uri_sync(token_issuer)
    jwks_client = self._get_jwks_client(keys_url)
    jwk = jwks_client.get_signing_key_from_jwt(access_token)
    return self._decode_verified(access_token, jwk.key)
|
|
189
|
+
|
|
190
|
+
async def verify_async(self, access_token: str) -> dict:
    """Verify a JWT's issuer, signature and expiration (non-blocking) and return its claims.

    Same steps as the blocking variant; discovery is awaited and the JWKS key
    lookup runs in a worker thread.

    Raises:
        jwt.ExpiredSignatureError: Token has expired.
        jwt.InvalidSignatureError: Signature does not match.
        jwt.DecodeError: Token is malformed.
        jwt.MissingRequiredClaimError: Token is missing the ``iss`` claim.
        jwt.InvalidIssuerError: Token issuer is not in the allowed issuers list.
    """
    token_issuer = self._get_issuer(access_token)
    self._validate_issuer(token_issuer)
    keys_url = await self._fetch_jwks_uri_async(token_issuer)
    # PyJWKClient.get_signing_key_from_jwt does blocking HTTP; offload to thread pool
    lookup_key = self._get_jwks_client(keys_url).get_signing_key_from_jwt
    jwk = await asyncio.to_thread(lookup_key, access_token)
    return self._decode_verified(access_token, jwk.key)
|
|
@@ -0,0 +1,301 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import httpx
|
|
3
|
+
import jwt
|
|
4
|
+
|
|
5
|
+
from ._base import _BaseClient
|
|
6
|
+
from .endpoints import AuthEndpoint, VLMEndpoint
|
|
7
|
+
from .models import TokenResponse
|
|
8
|
+
from .exceptions import AuthenticationError, JwksDiscoveryError, NetworkError, VisionaiSDKError
|
|
9
|
+
from .models import TokenResponse, NIMRequestModel, ResponseNormalModel, ResponseErrorModel
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class AsyncClient(_BaseClient):
|
|
16
|
+
def __init__(
    self,
    auth_url: str,
    vlm_url: str,
    allowed_issuers: list[str] | None = None,
    verify_ssl: bool = True,
    timeout: float = 10.0,
    max_connections: int = 100,
    max_keepalive_connections: int = 20,
) -> None:
    """Set up shared configuration and open the pooled ``httpx.AsyncClient``.

    All arguments are forwarded unchanged to ``_BaseClient.__init__``; the
    HTTP client is then built from the values the base class stored.
    """
    super().__init__(
        auth_url=auth_url,
        vlm_url=vlm_url,
        allowed_issuers=allowed_issuers,
        verify_ssl=verify_ssl,
        timeout=timeout,
        max_connections=max_connections,
        max_keepalive_connections=max_keepalive_connections,
    )
    pool_limits = httpx.Limits(
        max_connections=self.max_connections,
        max_keepalive_connections=self.max_keepalive_connections,
    )
    self._client = httpx.AsyncClient(
        verify=self.verify_ssl,
        timeout=self.timeout,
        limits=pool_limits,
    )
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
async def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
    """Send one async HTTP request and translate httpx failures into SDK exceptions.

    Timeouts and network-level errors become ``NetworkError``; any other
    ``httpx.RequestError`` becomes ``VisionaiSDKError``. A response that was
    received is passed through ``_handle_response`` for status-code mapping.
    """
    try:
        raw_response = await self._client.request(method, url, **kwargs)
    except httpx.TimeoutException as exc:
        raise NetworkError("Request timed out") from exc
    except httpx.NetworkError as exc:
        raise NetworkError(f"Network error: {exc}") from exc
    except httpx.RequestError as exc:
        raise VisionaiSDKError(f"Request failed: {exc}") from exc
    else:
        return self._handle_response(raw_response)
|
|
56
|
+
|
|
57
|
+
async def close(self) -> None:
    """Shut down the underlying async HTTP client, releasing its connections."""
    http_client = self._client
    await http_client.aclose()
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
async def __aenter__(self) -> "AsyncClient":
    """Enter ``async with``: the client itself is the managed object."""
    return self
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
    """Leave ``async with``: close the client; exceptions are not suppressed."""
    await self.close()
    return None
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
async def get_access_token(self, client_id: str, client_secret: str) -> TokenResponse:
    """Obtain an access token via the client-credentials flow.

    The token and the credentials are stored on the client so later API calls
    can use the token and re-request it when it is about to expire.

    Args:
        client_id: OAuth client ID
        client_secret: OAuth client secret

    Returns:
        TokenResponse with access_token, expires_in, and token_type

    Raises:
        ValueError: If client_id or client_secret is empty
        NetworkError: If the request times out or a network error occurs
        VisionaiSDKError: If the request fails for any other reason
    """
    if not client_id.strip():
        raise ValueError("client_id must not be empty")
    if not client_secret.strip():
        raise ValueError("client_secret must not be empty")

    token_url = self._build_url(self.auth_url, AuthEndpoint.CLIENT_TOKEN)
    request_body = {"client_id": client_id, "client_secret": client_secret}
    response = await self._request("POST", token_url, json=request_body)
    token_response = TokenResponse(**response.json())

    # Keep a separate copy of the credentials for auto-refresh
    self._store_token(
        access_token=token_response.access_token,
        expires_in=token_response.expires_in,
        credentials=dict(request_body),
        credentials_type="client",
    )
    return token_response
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
async def login(self, email: str, password: str) -> TokenResponse:
    """Log in with email and password and obtain a JWT.

    The token and the credentials are stored on the client so later API calls
    can use the token and re-request it when it is about to expire.

    Args:
        email: User email address
        password: User password

    Returns:
        TokenResponse with access_token, expires_in, and token_type

    Raises:
        ValueError: If email or password is empty
        NetworkError: If the request times out or a network error occurs
        VisionaiSDKError: If the request fails for any other reason
    """
    if not email.strip():
        raise ValueError("email must not be empty")
    if not password.strip():
        raise ValueError("password must not be empty")

    login_url = self._build_url(self.auth_url, AuthEndpoint.LOGIN)
    request_body = {"email": email, "password": password}
    response = await self._request("POST", login_url, json=request_body)
    token_response = TokenResponse(**response.json())

    # Keep a separate copy of the credentials for auto-refresh
    self._store_token(
        access_token=token_response.access_token,
        expires_in=token_response.expires_in,
        credentials=dict(request_body),
        credentials_type="login",
    )
    return token_response
|
|
152
|
+
|
|
153
|
+
async def _refresh_token(self) -> None:
    """Obtain a fresh token by replaying the stored credentials.

    Dispatches to ``login`` or ``get_access_token`` according to the stored
    credentials type.

    Raises:
        AuthenticationError: If no credentials are stored or refresh fails.
    """
    creds, kind = self._credentials, self._credentials_type
    if creds is None or kind is None:
        raise AuthenticationError("No credentials available for token refresh")

    if kind == "client":
        await self.get_access_token(
            client_id=creds["client_id"],
            client_secret=creds["client_secret"],
        )
    elif kind == "login":
        await self.login(
            email=creds["email"],
            password=creds["password"],
        )
|
|
172
|
+
|
|
173
|
+
async def _ensure_token(self) -> None:
    """Make sure a usable token is stored, refreshing it when it is close to expiry.

    Raises:
        AuthenticationError: If no token is available or token expired without credentials.
    """
    if self._access_token is None:
        raise AuthenticationError("Not authenticated. Call login() or get_access_token() first.")

    if not self._is_token_expiring_soon():
        return

    # Token is (nearly) expired: a refresh is only possible with stored credentials.
    if self._credentials is None:
        raise AuthenticationError("Token expired and no credentials available for refresh")
    await self._refresh_token()
|
|
186
|
+
|
|
187
|
+
async def is_token_valid(self, access_token: str) -> bool:
    """Report whether a JWT access token currently passes verification.

    Runs the async verifier and converts the expected failure modes into a
    ``False`` result instead of raising.

    Args:
        access_token: JWT access token to validate.

    Returns:
        ``True`` if verification succeeds; ``False`` if the token is invalid,
        the JWKS client reports an error, or OIDC discovery fails.

    Note:
        Each handled failure is logged. Any other exception propagates so
        programming errors fail fast.
    """
    try:
        await self._jwt_verifier.verify_async(access_token)
    except jwt.InvalidTokenError as exc:
        # Expected: expired, malformed, invalid signature, missing claims
        error_name = type(exc).__name__
        logger.warning(
            "%s: Token validation failed",
            error_name,
            extra={"jwt_error_type": error_name, "jwt_error_message": str(exc)},
        )
        return False
    except jwt.PyJWKClientError as exc:
        # Expected: JWKS endpoint unavailable, network issues
        logger.error(
            "%s: JWKS client error during token validation",
            type(exc).__name__,
            extra={"jwt_error_type": "PyJWKClientError", "jwt_error_message": str(exc)},
        )
        return False
    except JwksDiscoveryError as exc:
        # Expected: OIDC discovery endpoint unreachable or returned unexpected response
        error_name = type(exc).__name__
        logger.error(
            "%s: OIDC discovery failed during token validation",
            error_name,
            extra={"jwt_error_type": error_name, "jwt_error_message": str(exc)},
        )
        return False
    else:
        return True
|
|
230
|
+
|
|
231
|
+
async def chat(
    self,
    payload: NIMRequestModel | dict,
) -> ResponseNormalModel | ResponseErrorModel:
    """Submit an inference request to the VLM service.

    The stored access token is used; it is refreshed first when it is close
    to expiry and refresh credentials are available.

    Args:
        payload: Inference parameters as a NIMRequestModel instance or a dict
            whose keys match NIMRequestModel fields (validated via model_validate).

    Returns:
        ResponseErrorModel when the response's ``status`` is ``"failed"`` or
        ``"timeout"``; ResponseNormalModel for any other response.

    Raises:
        ValidationError: If payload is a dict that fails NIMRequestModel validation.
        AuthenticationError: If not authenticated or token expired without refresh credentials.
        NetworkError: If the request times out or a network error occurs.
        VisionaiSDKError: If the request fails for any other reason.
    """
    await self._ensure_token()

    if isinstance(payload, dict):
        nim_request = NIMRequestModel.model_validate(payload)
    else:
        nim_request = payload

    chat_url = self._build_url(self.vlm_url, VLMEndpoint.CHAT)
    auth_header = self._build_auth_header(self._access_token)
    request_body = nim_request.model_dump(mode="json")
    response = await self._request("POST", chat_url, headers=auth_header, json=request_body)

    data = response.json()
    is_error = data.get("status") in ("failed", "timeout")
    result_model = ResponseErrorModel if is_error else ResponseNormalModel
    return result_model(**data)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
async def get_chat(self, result_id: str) -> ResponseNormalModel | ResponseErrorModel:
    """Poll the result of a previously submitted inference request.

    The stored access token is used; it is refreshed first when it is close
    to expiry and refresh credentials are available.

    Args:
        result_id: Chat result ID returned from a prior chat() call.

    Returns:
        ResponseErrorModel when the response's ``status`` is ``"failed"`` or
        ``"timeout"``; ResponseNormalModel for any other response.

    Raises:
        ValueError: If result_id is empty or whitespace-only.
        AuthenticationError: If not authenticated or token expired without refresh credentials.
        NetworkError: If the request times out or a network error occurs.
        VisionaiSDKError: If the request fails for any other reason.
    """
    # Reject blank IDs up front (same convention as login/get_access_token):
    # an empty ID would otherwise address the collection URL instead of a result.
    if not result_id.strip():
        raise ValueError("result_id must not be empty")

    await self._ensure_token()

    response = await self._request(
        "GET",
        self._build_url(self.vlm_url, f"{VLMEndpoint.CHAT}/{result_id}"),
        headers=self._build_auth_header(self._access_token),
    )
    data = response.json()
    if data.get("status") in ("failed", "timeout"):
        return ResponseErrorModel(**data)
    return ResponseNormalModel(**data)
|
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import httpx
|
|
3
|
+
import jwt
|
|
4
|
+
|
|
5
|
+
from ._base import _BaseClient
|
|
6
|
+
from .endpoints import AuthEndpoint, VLMEndpoint
|
|
7
|
+
from .exceptions import AuthenticationError, JwksDiscoveryError, NetworkError, VisionaiSDKError
|
|
8
|
+
from .models import TokenResponse, NIMRequestModel, ResponseNormalModel, ResponseErrorModel
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Client(_BaseClient):
    """Synchronous VisionAI client backed by a pooled ``httpx.Client``.

    Authenticate with :meth:`login` or :meth:`get_access_token`; the token and
    the credentials that produced it are stored so that :meth:`chat` and
    :meth:`get_chat` can re-authenticate when the token is about to expire.
    Use the client as a context manager (or call :meth:`close`) to release
    the connection pool.
    """

    def __init__(
        self,
        auth_url: str,
        vlm_url: str,
        allowed_issuers: list[str] | None = None,
        verify_ssl: bool = True,
        timeout: float = 10.0,
        max_connections: int = 100,
        max_keepalive_connections: int = 20,
    ) -> None:
        """Configure the client and create the underlying ``httpx.Client``.

        Args:
            auth_url: Base URL of the authentication service.
            vlm_url: Base URL of the VLM service.
            allowed_issuers: Trusted JWT issuers; forwarded to the base class.
            verify_ssl: Passed to httpx as ``verify``.
            timeout: Passed to httpx as the request timeout.
            max_connections: Connection-pool limit passed to ``httpx.Limits``.
            max_keepalive_connections: Keep-alive limit passed to ``httpx.Limits``.
        """
        super().__init__(
            auth_url=auth_url,
            vlm_url=vlm_url,
            allowed_issuers=allowed_issuers,
            verify_ssl=verify_ssl,
            timeout=timeout,
            max_connections=max_connections,
            max_keepalive_connections=max_keepalive_connections,
        )
        # The settings are read back from the attributes the base class
        # initialiser is expected to have set.
        self._client = httpx.Client(
            verify=self.verify_ssl,
            timeout=self.timeout,
            limits=httpx.Limits(
                max_connections=self.max_connections,
                max_keepalive_connections=self.max_keepalive_connections,
            ),
        )

    def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
        """Execute an HTTP request, mapping httpx exceptions to SDK exceptions.

        Args:
            method: HTTP method name.
            url: Absolute request URL.
            **kwargs: Forwarded unchanged to ``httpx.Client.request``.

        Returns:
            Whatever ``_handle_response`` returns for the received response.

        Raises:
            NetworkError: On timeouts and network-level failures.
            VisionaiSDKError: On any other ``httpx.RequestError``.
        """
        try:
            response = self._client.request(method, url, **kwargs)
        # Most specific handlers first: RequestError is the common base class
        # and would otherwise shadow the two NetworkError mappings.
        except httpx.TimeoutException as e:
            raise NetworkError("Request timed out") from e
        except httpx.NetworkError as e:
            raise NetworkError(f"Network error: {e}") from e
        except httpx.RequestError as e:
            raise VisionaiSDKError(f"Request failed: {e}") from e
        return self._handle_response(response)

    @staticmethod
    def _decode_json_object(response: httpx.Response) -> dict:
        """Decode a response body that must be a JSON object.

        Raises:
            VisionaiSDKError: If the body is not valid JSON or is not an object.
        """
        try:
            payload = response.json()
        except ValueError as e:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            raise VisionaiSDKError("Server returned a response that is not valid JSON") from e
        if not isinstance(payload, dict):
            raise VisionaiSDKError("Server returned JSON that is not an object")
        return payload

    def _parse_chat_payload(
        self, response: httpx.Response
    ) -> ResponseNormalModel | ResponseErrorModel:
        """Turn a VLM chat response into the model matching its ``status``."""
        data = self._decode_json_object(response)
        if data.get("status") in ("failed", "timeout"):
            return ResponseErrorModel(**data)
        return ResponseNormalModel(**data)

    def close(self) -> None:
        """Close the HTTP client and release connections."""
        self._client.close()

    def __enter__(self) -> "Client":
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Context manager exit - close client."""
        self.close()

    def get_access_token(self, client_id: str, client_secret: str) -> TokenResponse:
        """Get access token using client credentials flow.

        The token is stored internally and will be used automatically for subsequent
        API calls. The credentials are stored too, so the token can be
        re-requested when it is close to expiring.

        Args:
            client_id: OAuth client ID
            client_secret: OAuth client secret

        Returns:
            TokenResponse with access_token, expires_in, and token_type

        Raises:
            ValueError: If client_id or client_secret is empty
            NetworkError: If the request times out or a network error occurs
            VisionaiSDKError: If the request fails for any other reason,
                including a response body that is not a JSON object
        """
        if not client_id.strip():
            raise ValueError("client_id must not be empty")
        if not client_secret.strip():
            raise ValueError("client_secret must not be empty")

        response = self._request(
            "POST",
            self._build_url(self.auth_url, AuthEndpoint.CLIENT_TOKEN),
            json={"client_id": client_id, "client_secret": client_secret},
        )
        token_response = TokenResponse(**self._decode_json_object(response))

        # Store token and credentials for auto-refresh
        self._store_token(
            access_token=token_response.access_token,
            expires_in=token_response.expires_in,
            credentials={"client_id": client_id, "client_secret": client_secret},
            credentials_type="client",
        )

        return token_response

    def login(self, email: str, password: str) -> TokenResponse:
        """Login with email and password to get JWT token.

        The token is stored internally and will be used automatically for subsequent
        API calls. The credentials are stored too, so the token can be
        re-requested when it is close to expiring.

        Args:
            email: User email address
            password: User password

        Returns:
            TokenResponse with access_token, expires_in, and token_type

        Raises:
            ValueError: If email or password is empty
            NetworkError: If the request times out or a network error occurs
            VisionaiSDKError: If the request fails for any other reason,
                including a response body that is not a JSON object
        """
        if not email.strip():
            raise ValueError("email must not be empty")
        if not password.strip():
            raise ValueError("password must not be empty")

        response = self._request(
            "POST",
            self._build_url(self.auth_url, AuthEndpoint.LOGIN),
            json={"email": email, "password": password},
        )
        token_response = TokenResponse(**self._decode_json_object(response))

        # Store token and credentials for auto-refresh
        self._store_token(
            access_token=token_response.access_token,
            expires_in=token_response.expires_in,
            credentials={"email": email, "password": password},
            credentials_type="login",
        )

        return token_response

    def _refresh_token(self) -> None:
        """Refresh token using stored credentials.

        Re-runs whichever authentication call originally produced the token.

        Raises:
            AuthenticationError: If no credentials are stored, or the stored
                credentials type is not one this client knows how to replay.
        """
        if self._credentials is None or self._credentials_type is None:
            raise AuthenticationError("No credentials available for token refresh")

        if self._credentials_type == "login":
            self.login(
                email=self._credentials["email"],
                password=self._credentials["password"],
            )
        elif self._credentials_type == "client":
            self.get_access_token(
                client_id=self._credentials["client_id"],
                client_secret=self._credentials["client_secret"],
            )
        else:
            # Falling through silently would let callers continue with the
            # expired token, so surface the inconsistency instead.
            raise AuthenticationError(
                f"Unsupported credentials type for token refresh: {self._credentials_type!r}"
            )

    def _ensure_token(self) -> None:
        """Ensure a valid token is available, refreshing if necessary.

        Raises:
            AuthenticationError: If no token is available or token expired without credentials.
        """
        if self._access_token is None:
            raise AuthenticationError("Not authenticated. Call login() or get_access_token() first.")

        if self._is_token_expiring_soon():
            if self._credentials is None:
                raise AuthenticationError("Token expired and no credentials available for refresh")
            self._refresh_token()

    def is_token_valid(self, access_token: str) -> bool:
        """Check whether a JWT access token is currently valid.

        Delegates to the JWT verifier and converts the expected failure modes
        into ``False`` instead of raising.

        Args:
            access_token: JWT access token to validate.

        Returns:
            ``True`` if the verifier accepts the token,
            ``False`` otherwise (invalid token or JWKS service unavailable).

        Note:
            Logs token validation failures. Unexpected errors will propagate
            to allow fail-fast behavior for programming errors.
        """
        try:
            self._jwt_verifier.verify_sync(access_token)
            return True
        except jwt.InvalidTokenError as e:
            # Expected: expired, malformed, invalid signature, missing claims
            logger.warning(
                "%s: Token validation failed",
                type(e).__name__,
                extra={"jwt_error_type": type(e).__name__, "jwt_error_message": str(e)},
            )
            return False
        except jwt.PyJWKClientError as e:
            # Expected: JWKS endpoint unavailable, network issues
            logger.error(
                "%s: JWKS client error during token validation",
                type(e).__name__,
                # Report the concrete class, as the other handlers do, so
                # subclasses stay distinguishable in structured logs.
                extra={"jwt_error_type": type(e).__name__, "jwt_error_message": str(e)},
            )
            return False
        except JwksDiscoveryError as e:
            # Expected: OIDC discovery endpoint unreachable or returned unexpected response
            logger.error(
                "%s: OIDC discovery failed during token validation",
                type(e).__name__,
                extra={"jwt_error_type": type(e).__name__, "jwt_error_message": str(e)},
            )
            return False

    def chat(
        self,
        payload: NIMRequestModel | dict,
    ) -> ResponseNormalModel | ResponseErrorModel:
        """Submit an inference request to the VLM service.

        Uses the internally stored access token obtained from login() or get_access_token().
        If the token is expiring soon, it will be automatically refreshed.

        Args:
            payload: Inference parameters as a NIMRequestModel instance or a dict
                whose keys match NIMRequestModel fields (validated via model_validate).

        Returns:
            ResponseErrorModel when the payload's ``status`` is ``"failed"`` or
            ``"timeout"``; otherwise ResponseNormalModel.

        Raises:
            ValidationError: If payload is a dict that fails NIMRequestModel validation.
            AuthenticationError: If not authenticated or token expired without refresh credentials.
            NetworkError: If the request times out or a network error occurs.
            VisionaiSDKError: If the request fails for any other reason,
                including a response body that is not a JSON object.
        """
        self._ensure_token()

        # Dicts are validated into the request model; model instances are
        # sent as given.
        nim_request = (
            NIMRequestModel.model_validate(payload)
            if isinstance(payload, dict)
            else payload
        )
        response = self._request(
            "POST",
            self._build_url(self.vlm_url, VLMEndpoint.CHAT),
            headers=self._build_auth_header(self._access_token),
            json=nim_request.model_dump(mode="json"),
        )
        return self._parse_chat_payload(response)

    def get_chat(self, result_id: str) -> ResponseNormalModel | ResponseErrorModel:
        """Poll the result of a previously submitted inference request.

        Uses the internally stored access token obtained from login() or get_access_token().
        If the token is expiring soon, it will be automatically refreshed.

        Args:
            result_id: Chat result ID returned from a prior chat() call.

        Returns:
            ResponseErrorModel when the payload's ``status`` is ``"failed"`` or
            ``"timeout"``; otherwise ResponseNormalModel.

        Raises:
            ValueError: If result_id is empty or only whitespace.
            AuthenticationError: If not authenticated or token expired without refresh credentials.
            NetworkError: If the request times out or a network error occurs.
            VisionaiSDKError: If the request fails for any other reason,
                including a response body that is not a JSON object.
        """
        # A blank id would address ".../chat/" itself rather than a result;
        # str() keeps non-str ids (e.g. ints) working as before.
        if not str(result_id).strip():
            raise ValueError("result_id must not be empty")

        self._ensure_token()

        response = self._request(
            "GET",
            self._build_url(self.vlm_url, f"{VLMEndpoint.CHAT}/{result_id}"),
            headers=self._build_auth_header(self._access_token),
        )
        return self._parse_chat_payload(response)
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
# Maps known VisionAI server base URLs to their trusted JWT issuers.
# Keys are normalized (no trailing slash).
# Values are exact issuer strings as they appear in the JWT `iss` claim.
#
# Long-term: replace this static table with dynamic discovery from the server's
# trusted-issuers endpoint (e.g. GET {auth_url}/api/v1/auth/trusted-issuers).

_KEYCLOAK_REALM_PATH = "/keycloak/realms/linker-platform"

# On-premise deployments backed by Keycloak (realm: linker-platform).
# Issuer = base_url + _KEYCLOAK_REALM_PATH for all entries.
_KEYCLOAK_BASE_URLS: list[str] = [
    "https://offline.visionai.linkervision.com",
    "https://lighthouse.visionai.linkervision.ai",
    "https://lighthouse-production.visionai.linkervision.ai",
]

# Cloud deployments backed by Auth0 (one tenant per environment).
_AUTH0_URL_TO_ISSUERS: dict[str, list[str]] = {
    "https://visionai.linkervision.com": [
        "https://data-engine-prod.us.auth0.com",
    ],
    "https://staging.visionai.linkervision.com": [
        "https://data-engine-staging.jp.auth0.com",
    ],
    "https://dev2.visionai.linkervision.com": [
        "https://data-engine-dev2.jp.auth0.com",
    ],
    "https://dev.visionai.linkervision.com": [
        "https://dev-045acunea5v1mm3l.us.auth0.com",
    ],
}

# Combined lookup table: derived Keycloak issuers plus the explicit Auth0 ones.
_AUTH_URL_TO_ISSUERS: dict[str, list[str]] = {
    **{url: [f"{url}{_KEYCLOAK_REALM_PATH}"] for url in _KEYCLOAK_BASE_URLS},
    **_AUTH0_URL_TO_ISSUERS,
}


def resolve_allowed_issuers(auth_url: str) -> list[str]:
    """Return the trusted issuers for the given auth_url.

    Trailing slashes on ``auth_url`` are ignored. If the URL is in the known
    table, its configured issuers are returned. Otherwise a Keycloak
    deployment is assumed and the issuer is derived from ``auth_url``.

    Args:
        auth_url: Base URL of the authentication server.

    Returns:
        A new list of issuer strings; callers may mutate it freely.
    """
    normalized = auth_url.rstrip("/")
    if known := _AUTH_URL_TO_ISSUERS.get(normalized):
        # Hand out a copy: returning the table's own list would let a caller
        # that appends to the result silently widen the module-level trust table.
        return list(known)
    return [f"{normalized}{_KEYCLOAK_REALM_PATH}"]
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
class VisionaiSDKError(Exception):
    """Base exception for all SDK errors."""


class APIError(VisionaiSDKError):
    """Base exception for all HTTP API errors (has status_code)."""

    def __init__(self, status_code: int, message: str) -> None:
        """Store the HTTP status and build the ``"[status] message"`` text.

        Args:
            status_code: HTTP status code associated with the error.
            message: Human-readable error description.
        """
        self.status_code = status_code
        super().__init__(f"[{status_code}] {message}")

    def __reduce__(self):
        """Support pickle/copy without re-running ``__init__``.

        The default exception reduction re-calls the class with ``self.args``,
        which holds only the formatted message. That does not match the
        ``(status_code, message)`` signature here, nor the ``(message)``
        signature of the fixed-status subclasses, where it would prefix the
        status a second time. Rebuilding via ``__new__`` and restoring the
        instance dict keeps both the message and ``status_code`` intact.
        """
        import copyreg

        return (copyreg.__newobj__, (type(self), *self.args), self.__dict__.copy())


class ClientError(APIError):
    """4xx - Client-side error."""


class AuthenticationError(ClientError):
    """401 - Invalid credentials/password."""

    def __init__(self, message: str) -> None:
        """Create the error with the status code fixed to 401."""
        super().__init__(401, message)


class PermissionDeniedError(ClientError):
    """403 - Insufficient permissions."""

    def __init__(self, message: str) -> None:
        """Create the error with the status code fixed to 403."""
        super().__init__(403, message)


class ServerError(APIError):
    """5xx - Server-side failure, consider retry."""


class NetworkError(VisionaiSDKError):
    """Connection or timeout failure."""


class JwksDiscoveryError(VisionaiSDKError):
    """Failed to fetch or parse the OIDC discovery document or JWKS endpoint."""
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from pydantic import AnyHttpUrl, BaseModel, Field
|
|
2
|
+
from typing import Literal
|
|
3
|
+
|
|
4
|
+
class TokenResponse(BaseModel):
    """Token payload returned by the authentication endpoints.

    Shared by the /api/users/jwt and /api/users/client-token endpoints.
    All three fields are required.
    """

    # The JWT itself.
    access_token: str = Field(..., description="JWT access token")
    # Token lifetime as reported by the server.
    expires_in: int = Field(..., description="Token expiration time in seconds")
    # Scheme label reported alongside the token.
    token_type: str = Field(..., description="Token type (e.g., 'Bearer')")
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class NIMRequestModel(BaseModel):
    """Request body for a VLM inference call.

    Only ``img`` and ``prompt`` are required; every other field carries a
    default.
    """

    # One image reference or several; the accepted string format is defined
    # by the server. NOTE(review): presumably URLs or encoded images - confirm.
    img: str | list[str]
    prompt: str

    # Tuning parameters with their defaults; None is accepted for each.
    temperature: float | None = 0.2
    max_tokens: int | None = 500
    top_p: float | None = 0.7
    stream: bool = False
    use_cache: bool = True
    num_beams: int | None = 1

    # Optional URLs; both must be valid HTTP(S) URLs when given.
    api_endpoint: AnyHttpUrl | None = None
    hook: AnyHttpUrl | None = None
    # Only has an effect together with `hook`; that dependency is not
    # enforced by validation.
    use_response_postprocess: bool = False
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ResponseNormalModel(BaseModel):
    """Chat response whose status is pending, running or completed."""

    # Identifier of the chat request this response belongs to.
    chat_id: str
    # Restricted to the non-error states; other values fail validation.
    status: Literal["pending", "running", "completed"]
    # Optional message text; absent or null becomes None.
    message: str | None = None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ResponseErrorModel(BaseModel):
    """Chat response whose status is failed or timeout."""

    # Identifier of the chat request this response belongs to.
    chat_id: str
    # Restricted to the error states; other values fail validation.
    status: Literal["failed", "timeout"]
    # Required error description supplied with the failure.
    error: str
    # Optional message text; absent or null becomes None.
    message: str | None = None
|
|
File without changes
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: visionai-sdk-python
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: VisionAI SDK for Python
|
|
5
|
+
Author: Tony Yang
|
|
6
|
+
Author-email: Tony Yang <tonyyang@linkervision.com>
|
|
7
|
+
Requires-Dist: httpx>=0.28.1
|
|
8
|
+
Requires-Dist: pydantic>=2.12.5
|
|
9
|
+
Requires-Dist: cryptography>=46.0.5
|
|
10
|
+
Requires-Dist: pyjwt[cryptography]>=2.8.0
|
|
11
|
+
Requires-Python: >=3.11
|
|
12
|
+
Description-Content-Type: text/markdown
|
|
13
|
+
|
|
14
|
+
# README
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
visionai_sdk_python/__init__.py,sha256=FjhwCuJI-nI2sv3xxBp1mAXmtAefelZZj7fhFIQy9X0,624
|
|
2
|
+
visionai_sdk_python/_base.py,sha256=gr-4Sb9i7kI1UMZueQXu57_7PvlBbEJh8PlA0EoljNo,8040
|
|
3
|
+
visionai_sdk_python/_jwt_verifier.py,sha256=gFnm5mMoEAGdamIHYZ8wwPXP2itnfCm5Z4_8_SPzlwM,8365
|
|
4
|
+
visionai_sdk_python/async_client.py,sha256=CS_GRf6cxIG8vZTOuOPp5Hf86UdJMCXsHHmTrcYB8eA,11601
|
|
5
|
+
visionai_sdk_python/client.py,sha256=l8VnNXMINton247U7MsFYg-itRN1OJjXAhDsUD9c3jE,11414
|
|
6
|
+
visionai_sdk_python/constants.py,sha256=hDLdISOkrrEifLumVy0gPLO2NmUDxdfe2UIueizpbh8,1868
|
|
7
|
+
visionai_sdk_python/endpoints.py,sha256=CC-AoxXzg39sMAhqx1CTS3gRWKx_H7VOqBANRLVotoU,183
|
|
8
|
+
visionai_sdk_python/exceptions.py,sha256=r4bxvsQGNTVa7-ISX1JAC7VRJ-6qs3n30l3j3qxai-Y,1047
|
|
9
|
+
visionai_sdk_python/models.py,sha256=RJcN3dS-5tFR8_NhyWEc8ANMBmPDNEI7ZbPf5ZxcfcU,1198
|
|
10
|
+
visionai_sdk_python/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
|
+
visionai_sdk_python-0.1.0.dist-info/WHEEL,sha256=s_zqWxHFEH8b58BCtf46hFCqPaISurdB9R1XJ8za6XI,80
|
|
12
|
+
visionai_sdk_python-0.1.0.dist-info/METADATA,sha256=Hm46dTeS6xx5Z-ydjF24hegGFAqY4orip_K0awOpWwM,379
|
|
13
|
+
visionai_sdk_python-0.1.0.dist-info/RECORD,,
|