cortex-suite-sdk 1.0.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cortex_suite_sdk-1.0.2/.gitignore +4 -0
- cortex_suite_sdk-1.0.2/PKG-INFO +77 -0
- cortex_suite_sdk-1.0.2/README.md +60 -0
- cortex_suite_sdk-1.0.2/cortex_sdk/__init__.py +4 -0
- cortex_suite_sdk-1.0.2/cortex_sdk/_generated_constants.py +26 -0
- cortex_suite_sdk-1.0.2/cortex_sdk/_generated_errors.py +27 -0
- cortex_suite_sdk-1.0.2/cortex_sdk/auth.py +117 -0
- cortex_suite_sdk-1.0.2/cortex_sdk/client.py +331 -0
- cortex_suite_sdk-1.0.2/cortex_sdk/constants.py +50 -0
- cortex_suite_sdk-1.0.2/cortex_sdk/errors.py +49 -0
- cortex_suite_sdk-1.0.2/cortex_sdk/liveness.py +111 -0
- cortex_suite_sdk-1.0.2/cortex_sdk/py.typed +1 -0
- cortex_suite_sdk-1.0.2/cortex_sdk/session.py +184 -0
- cortex_suite_sdk-1.0.2/cortex_sdk/transport.py +129 -0
- cortex_suite_sdk-1.0.2/cortex_sdk/types.py +55 -0
- cortex_suite_sdk-1.0.2/cortex_sdk/upload.py +50 -0
- cortex_suite_sdk-1.0.2/pyproject.toml +35 -0
- cortex_suite_sdk-1.0.2/tests/__init__.py +0 -0
- cortex_suite_sdk-1.0.2/tests/helpers.py +46 -0
- cortex_suite_sdk-1.0.2/tests/mock_server.py +617 -0
- cortex_suite_sdk-1.0.2/tests/scenarios.py +455 -0
- cortex_suite_sdk-1.0.2/tests/schema_validation.py +75 -0
- cortex_suite_sdk-1.0.2/tests/test_auth_config.py +215 -0
- cortex_suite_sdk-1.0.2/tests/test_auth_refresh.py +70 -0
- cortex_suite_sdk-1.0.2/tests/test_error_coverage.py +381 -0
- cortex_suite_sdk-1.0.2/tests/test_file_upload.py +80 -0
- cortex_suite_sdk-1.0.2/tests/test_liveness.py +97 -0
- cortex_suite_sdk-1.0.2/tests/test_normal_chat.py +53 -0
- cortex_suite_sdk-1.0.2/tests/test_parity_adapter.py +24 -0
- cortex_suite_sdk-1.0.2/tests/test_reconnect.py +84 -0
- cortex_suite_sdk-1.0.2/tests/test_schema_conformance.py +12 -0
- cortex_suite_sdk-1.0.2/tests/test_session_state.py +93 -0
- cortex_suite_sdk-1.0.2/tests/test_shared_artifacts.py +56 -0
- cortex_suite_sdk-1.0.2/tests/test_streaming_chat.py +97 -0
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: cortex-suite-sdk
|
|
3
|
+
Version: 1.0.2
|
|
4
|
+
Summary: Cortex SDK — transport client for Python
|
|
5
|
+
License: MIT
|
|
6
|
+
Requires-Python: >=3.10
|
|
7
|
+
Requires-Dist: httpx>=0.27
|
|
8
|
+
Requires-Dist: websockets>=12.0
|
|
9
|
+
Provides-Extra: dev
|
|
10
|
+
Requires-Dist: anyio[trio]; extra == 'dev'
|
|
11
|
+
Requires-Dist: build>=1.2; extra == 'dev'
|
|
12
|
+
Requires-Dist: jsonschema>=4.23; extra == 'dev'
|
|
13
|
+
Requires-Dist: mypy>=1.9; extra == 'dev'
|
|
14
|
+
Requires-Dist: pytest-asyncio>=0.23; extra == 'dev'
|
|
15
|
+
Requires-Dist: pytest>=8.0; extra == 'dev'
|
|
16
|
+
Description-Content-Type: text/markdown
|
|
17
|
+
|
|
18
|
+
# Cortex SDK for Python
|
|
19
|
+
|
|
20
|
+
## Install
|
|
21
|
+
|
|
22
|
+
`pip install cortex-suite-sdk`
|
|
23
|
+
|
|
24
|
+
## Quick start
|
|
25
|
+
|
|
26
|
+
```python
|
|
27
|
+
import asyncio
|
|
28
|
+
|
|
29
|
+
from cortex_sdk import CortexClient
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def on_message(msg):
|
|
33
|
+
print(msg["type"], msg["payload"])
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
async def main():
|
|
37
|
+
client = CortexClient(
|
|
38
|
+
api_key="your-api-key",
|
|
39
|
+
# auth_url="https://auth.cortexsuite.app", # optional override
|
|
40
|
+
on_message=on_message,
|
|
41
|
+
)
|
|
42
|
+
await client.connect()
|
|
43
|
+
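    await client.send_message(content="Hello")
    await asyncio.sleep(10)  # give the runtime time to stream replies to on_message
    await client.disconnect()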
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
asyncio.run(main())
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
## API
|
|
50
|
+
|
|
51
|
+
See the [full API reference](../docs/04_api_reference.md).
|
|
52
|
+
`auth_url` is optional; if omitted, the SDK uses its default auth base URL.
|
|
53
|
+
|
|
54
|
+
## Error handling
|
|
55
|
+
|
|
56
|
+
```python
|
|
57
|
+
import asyncio

from cortex_sdk import CortexClient, CortexError


def on_message(msg):
    if msg["type"] == "system::error":
        print("Runtime error:", msg["payload"]["code"])
        if msg["payload"].get("fatal"):
            pass  # handle unrecoverable session


async def main():
    client = CortexClient(
        api_key="your-api-key",
        on_message=on_message,
    )

    try:
        await client.connect()
    except CortexError as err:
        if err.code in ("auth_invalid", "auth_refresh_failed"):
            pass  # prompt re-authentication


asyncio.run(main())
|
|
77
|
+
```
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# Cortex SDK for Python
|
|
2
|
+
|
|
3
|
+
## Install
|
|
4
|
+
|
|
5
|
+
`pip install cortex-suite-sdk`
|
|
6
|
+
|
|
7
|
+
## Quick start
|
|
8
|
+
|
|
9
|
+
```python
|
|
10
|
+
import asyncio
|
|
11
|
+
|
|
12
|
+
from cortex_sdk import CortexClient
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def on_message(msg):
|
|
16
|
+
print(msg["type"], msg["payload"])
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
async def main():
|
|
20
|
+
client = CortexClient(
|
|
21
|
+
api_key="your-api-key",
|
|
22
|
+
# auth_url="https://auth.cortexsuite.app", # optional override
|
|
23
|
+
on_message=on_message,
|
|
24
|
+
)
|
|
25
|
+
await client.connect()
|
|
26
|
+
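    await client.send_message(content="Hello")
    await asyncio.sleep(10)  # give the runtime time to stream replies to on_message
    await client.disconnect()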
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
asyncio.run(main())
|
|
30
|
+
```
|
|
31
|
+
|
|
32
|
+
## API
|
|
33
|
+
|
|
34
|
+
See the [full API reference](../docs/04_api_reference.md).
|
|
35
|
+
`auth_url` is optional; if omitted, the SDK uses its default auth base URL.
|
|
36
|
+
|
|
37
|
+
## Error handling
|
|
38
|
+
|
|
39
|
+
```python
|
|
40
|
+
import asyncio

from cortex_sdk import CortexClient, CortexError


def on_message(msg):
    if msg["type"] == "system::error":
        print("Runtime error:", msg["payload"]["code"])
        if msg["payload"].get("fatal"):
            pass  # handle unrecoverable session


async def main():
    client = CortexClient(
        api_key="your-api-key",
        on_message=on_message,
    )

    try:
        await client.connect()
    except CortexError as err:
        if err.code in ("auth_invalid", "auth_refresh_failed"):
            pass  # prompt re-authentication


asyncio.run(main())
|
|
60
|
+
```
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
# Generated from shared/constants.json. Do not edit manually.
|
|
4
|
+
|
|
5
|
+
# NOTE(review): this module is generated from shared/constants.json; these
# comments should live in the generator template or they will be lost on the
# next regeneration.

# Auth service base URL used when no auth_url override is supplied.
DEFAULT_AUTH_URL: str = 'https://auth.cortexsuite.app'
# Endpoint path, appended to the auth base URL, for the API key -> token exchange.
AUTH_TOKEN_PATH: str = '/auth/token'
# Endpoint path, appended to the auth base URL, for access-token refresh.
AUTH_REFRESH_PATH: str = '/auth/refresh'
# WebSocket subprotocol identifier (the "v1" suffix presumably tracks the
# wire-protocol version).
WS_SUBPROTOCOL: str = 'cortex-sdk.v1'
# Prefix of a subprotocol entry that carries a JWT -- presumably the access
# token is appended to it during the WS handshake; confirm in transport.py.
WS_SUBPROTOCOL_JWT_PREFIX: str = 'cortex-sdk.jwt.'
# Schema version string of the message format.
SCHEMA_VERSION: str = '1.0'

# Timeouts and intervals, in milliseconds (per the _MS suffix). The public
# client API takes seconds; presumably constants.py converts these -- verify.
DEFAULT_CONNECT_TIMEOUT_MS: int = 10000
DEFAULT_SEND_TIMEOUT_MS: int = 10000
DEFAULT_RESYNC_TIMEOUT_MS: int = 15000
DEFAULT_PING_INTERVAL_MS: int = 15000
DEFAULT_PONG_TIMEOUT_MS: int = 5000
DEFAULT_STALE_THRESHOLD_MS: int = 45000
# Window before token expiry in which the access token is refreshed (ms).
TOKEN_REFRESH_BUFFER_MS: int = 60000

# Reconnect delays in milliseconds, one per attempt; callers are expected to
# clamp to the last entry once attempts exceed the tuple length.
RECONNECT_BACKOFF_MS: tuple[int, ...] = (1000, 2000, 5000, 10000, 20000, 30000,)

# Deprecated compatibility aliases. Prefer composing DEFAULT_AUTH_URL +
# AUTH_TOKEN_PATH (instead of CORTEX_AUTH_URL), DEFAULT_AUTH_URL +
# AUTH_REFRESH_PATH (instead of CORTEX_REFRESH_URL), and WS_SUBPROTOCOL
# (instead of WS_SUBPROTOCOL_BASE).
CORTEX_AUTH_URL: str = DEFAULT_AUTH_URL + AUTH_TOKEN_PATH
CORTEX_REFRESH_URL: str = DEFAULT_AUTH_URL + AUTH_REFRESH_PATH
WS_SUBPROTOCOL_BASE: str = WS_SUBPROTOCOL
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
# Generated from sdk/shared/errors.json. Do not edit manually.
|
|
6
|
+
|
|
7
|
+
@dataclass(frozen=True)
class GeneratedErrorEntry:
    """One immutable entry of the SDK error catalog.

    Entries are produced by the code generator from ``sdk/shared/errors.json``.
    """

    code: str  # machine-readable error code, e.g. 'auth_invalid'
    retryable: bool  # whether the catalog classifies the condition as retryable
    fatal: bool  # whether the catalog classifies the condition as fatal
|
|
12
|
+
|
|
13
|
+
# Catalog of SDK error codes with their retryable/fatal flags, generated from
# sdk/shared/errors.json. Grouped below by code prefix for readability only.
GENERATED_ERROR_CATALOG: tuple[GeneratedErrorEntry, ...] = (
    # Authentication
    GeneratedErrorEntry('auth_invalid', retryable=False, fatal=True),
    GeneratedErrorEntry('auth_expired', retryable=True, fatal=False),
    GeneratedErrorEntry('auth_refresh_failed', retryable=False, fatal=True),
    # Transport
    GeneratedErrorEntry('transport_connect_timeout', retryable=True, fatal=False),
    GeneratedErrorEntry('transport_send_timeout', retryable=True, fatal=False),
    GeneratedErrorEntry('transport_protocol_violation', retryable=False, fatal=True),
    # Session
    GeneratedErrorEntry('session_not_found', retryable=False, fatal=True),
    GeneratedErrorEntry('session_terminal', retryable=False, fatal=True),
    # Resync / replay
    GeneratedErrorEntry('resync_timeout', retryable=True, fatal=False),
    GeneratedErrorEntry('replay_unavailable', retryable=True, fatal=False),
    # Uploads
    GeneratedErrorEntry('upload_failed', retryable=True, fatal=False),
    GeneratedErrorEntry('upload_too_large', retryable=False, fatal=False),
    GeneratedErrorEntry('upload_type_rejected', retryable=False, fatal=False),
)
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import json
|
|
5
|
+
import time
|
|
6
|
+
|
|
7
|
+
import httpx
|
|
8
|
+
|
|
9
|
+
from .constants import (
|
|
10
|
+
DEFAULT_AUTH_URL,
|
|
11
|
+
AUTH_TOKEN_PATH,
|
|
12
|
+
AUTH_REFRESH_PATH,
|
|
13
|
+
TOKEN_REFRESH_BUFFER,
|
|
14
|
+
)
|
|
15
|
+
from .errors import make_error
|
|
16
|
+
from .types import AuthTokenResponse
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _parse_jwt_exp(token: str) -> float | None:
|
|
20
|
+
"""Extract exp (as Unix timestamp in seconds) from a JWT without verifying signature."""
|
|
21
|
+
try:
|
|
22
|
+
parts = token.split(".")
|
|
23
|
+
if len(parts) != 3:
|
|
24
|
+
return None
|
|
25
|
+
# base64url → base64 standard (add padding)
|
|
26
|
+
payload_b64 = parts[1]
|
|
27
|
+
# Add padding
|
|
28
|
+
padding = 4 - len(payload_b64) % 4
|
|
29
|
+
if padding != 4:
|
|
30
|
+
payload_b64 += "=" * padding
|
|
31
|
+
payload_json = base64.urlsafe_b64decode(payload_b64)
|
|
32
|
+
payload = json.loads(payload_json)
|
|
33
|
+
exp = payload.get("exp")
|
|
34
|
+
return float(exp) if isinstance(exp, (int, float)) else None
|
|
35
|
+
except Exception:
|
|
36
|
+
return None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def is_token_expiring_soon(access_token: str) -> bool:
    """Report whether *access_token* is within TOKEN_REFRESH_BUFFER seconds of expiry.

    Tokens whose ``exp`` claim cannot be read are treated as not expiring, so
    the caller never refreshes on the basis of an unparseable token.
    """
    expires_at = _parse_jwt_exp(access_token)
    if expires_at is None:
        return False
    refresh_deadline = expires_at - TOKEN_REFRESH_BUFFER
    return time.time() > refresh_deadline
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
async def exchange_api_key(
    api_key: str,
    *,
    auth_base_url: str = DEFAULT_AUTH_URL,
) -> AuthTokenResponse:
    """Exchange an API key for tokens and WS URL.

    POSTs to ``{auth_base_url}{AUTH_TOKEN_PATH}`` with an ``ApiKey`` credential
    and returns the decoded JSON body.

    Raises:
        CortexError: when the server rejects the request. The error code and
            message come from the response's ``error`` / ``message`` fields,
            falling back to ``auth_invalid`` / ``API key rejected`` when the
            body is not a JSON object or those fields are missing, null or empty.
    """
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            _build_auth_endpoint(auth_base_url, AUTH_TOKEN_PATH),
            headers={
                "Content-Type": "application/json",
                "Authorization": f"ApiKey {api_key}",
            },
        )

    if not resp.is_success:
        # Best-effort extraction of the server-provided error details.
        try:
            body = resp.json()
        except ValueError:  # JSONDecodeError / UnicodeDecodeError
            body = None
        fields = body if isinstance(body, dict) else {}
        # `or` (not a .get default) so that null/empty values also fall back,
        # instead of producing a CortexError whose code is "None".
        code = fields.get("error") or "auth_invalid"
        message = fields.get("message") or "API key rejected"
        raise make_error(str(code), str(message))

    return resp.json()  # type: ignore[no-any-return]
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
async def refresh_access_token(
    refresh_token: str,
    *,
    auth_base_url: str = DEFAULT_AUTH_URL,
) -> str:
    """Refresh the access token using the refresh token.

    POSTs to ``{auth_base_url}{AUTH_REFRESH_PATH}`` with the refresh token as a
    Bearer credential and returns the new access token.

    Raises:
        CortexError: ``auth_refresh_failed`` when the server rejects the
            request, or when a successful response carries no usable
            ``access_token`` (non-JSON body, not an object, field missing or
            empty). Network errors from httpx propagate unchanged.
    """
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            _build_auth_endpoint(auth_base_url, AUTH_REFRESH_PATH),
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {refresh_token}",
            },
        )

    if not resp.is_success:
        raise make_error("auth_refresh_failed", "Refresh token expired or invalid")

    # Normalise malformed success bodies into the SDK error type: callers
    # (e.g. the client's reconnect loop) only handle CortexError from here,
    # so a bare KeyError/JSONDecodeError would escape and kill them.
    try:
        body = resp.json()
    except ValueError:  # JSONDecodeError / UnicodeDecodeError
        body = None
    access_token = body.get("access_token") if isinstance(body, dict) else None
    if not access_token:
        raise make_error(
            "auth_refresh_failed",
            "Refresh response did not contain an access token",
        )
    return str(access_token)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def normalize_auth_base_url(auth_url: str) -> str:
    """Return *auth_url* as an auth-service base URL without a trailing slash.

    If the value ends with a known endpoint path (``AUTH_TOKEN_PATH`` or
    ``AUTH_REFRESH_PATH``), that path is removed and a ``UserWarning`` is
    emitted with ``stacklevel=3`` (attributed to the caller's caller). An empty
    result falls back to ``DEFAULT_AUTH_URL``.
    """
    import warnings

    trimmed = auth_url.rstrip("/")
    # Guard: consumer may have passed a full endpoint instead of the base URL.
    offending = next(
        (
            suffix
            for suffix in (AUTH_TOKEN_PATH, AUTH_REFRESH_PATH)
            if trimmed.endswith(suffix)
        ),
        None,
    )
    if offending is None:
        return trimmed or DEFAULT_AUTH_URL

    base = trimmed[: -len(offending)]
    warnings.warn(
        f"[CortexSDK] auth_url must be a base URL (origin only), not a full endpoint. "
        f'Received "{auth_url}" — "{offending}" has been stripped automatically. '
        f'Pass the base URL only, e.g., "{base}".',
        UserWarning,
        stacklevel=3,
    )
    return base or DEFAULT_AUTH_URL
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _build_auth_endpoint(auth_base_url: str, path: str) -> str:
    """Join a normalised auth base URL with an endpoint path such as AUTH_TOKEN_PATH."""
    base = normalize_auth_base_url(auth_base_url)
    return base + path
|
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import secrets
|
|
5
|
+
from typing import BinaryIO
|
|
6
|
+
from urllib.parse import urlsplit, urlunsplit
|
|
7
|
+
|
|
8
|
+
from .auth import (
|
|
9
|
+
exchange_api_key,
|
|
10
|
+
refresh_access_token,
|
|
11
|
+
is_token_expiring_soon,
|
|
12
|
+
normalize_auth_base_url,
|
|
13
|
+
)
|
|
14
|
+
from .constants import (
|
|
15
|
+
DEFAULT_AUTH_URL,
|
|
16
|
+
DEFAULT_CONNECT_TIMEOUT,
|
|
17
|
+
DEFAULT_SEND_TIMEOUT,
|
|
18
|
+
DEFAULT_RESYNC_TIMEOUT,
|
|
19
|
+
DEFAULT_PING_INTERVAL,
|
|
20
|
+
DEFAULT_PONG_TIMEOUT,
|
|
21
|
+
DEFAULT_STALE_THRESHOLD,
|
|
22
|
+
TOKEN_REFRESH_BUFFER,
|
|
23
|
+
RECONNECT_BACKOFF,
|
|
24
|
+
)
|
|
25
|
+
from .errors import CortexError, make_error
|
|
26
|
+
from .liveness import LivenessMonitor
|
|
27
|
+
from .session import SessionManager
|
|
28
|
+
from .transport import Transport
|
|
29
|
+
from .types import ChannelState, CortexMessage, MessageCallback, SessionState
|
|
30
|
+
from .upload import upload_file
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class CortexClient:
    """Cortex SDK client — one instance per session.

    Public API:
        connect() / disconnect()
        send_message(content, attachments)
        upload_attachment(file)
        stop()
        session_state / channel_state / session_id (read-only properties)

    Time options are in **seconds** (not milliseconds).

    Unrecoverable errors detected by the SDK (fatal session errors, and a
    rejected token refresh while reconnecting) are delivered to ``on_message``
    as ``system::error`` messages whose payload contains ``"fatal": True``.
    """

    def __init__(
        self,
        api_key: str,
        on_message: MessageCallback,
        auth_url: str | None = None,
        connect_timeout: float = DEFAULT_CONNECT_TIMEOUT,
        send_timeout: float = DEFAULT_SEND_TIMEOUT,
        resync_timeout: float = DEFAULT_RESYNC_TIMEOUT,
        ping_interval: float = DEFAULT_PING_INTERVAL,
        pong_timeout: float = DEFAULT_PONG_TIMEOUT,
        stale_threshold: float = DEFAULT_STALE_THRESHOLD,
        # Private test override — not part of the public API
        _upload_url: str | None = None,
    ) -> None:
        """Configure the client; no network I/O happens until :meth:`connect`.

        Args:
            api_key: API key exchanged for access/refresh tokens on connect.
            on_message: Callback for every incoming message except internal
                ``system::pong`` frames.
            auth_url: Auth service base URL (defaults to ``DEFAULT_AUTH_URL``).
                A full endpoint URL is accepted but stripped with a warning.
            connect_timeout, send_timeout: Passed to the transport; the send
                timeout is also handed to the session manager.
            resync_timeout: Max seconds to wait for a resync after reconnecting.
            ping_interval, pong_timeout, stale_threshold: Liveness settings.
        """
        self._api_key = api_key
        self._on_message_cb = on_message
        self._auth_url = normalize_auth_base_url(auth_url or DEFAULT_AUTH_URL)
        self._connect_timeout = connect_timeout
        self._send_timeout = send_timeout
        self._resync_timeout = resync_timeout
        self._ping_interval = ping_interval
        self._pong_timeout = pong_timeout
        self._stale_threshold = stale_threshold

        # None → derived from the WS URL returned by the auth exchange in connect()
        self._upload_url = _upload_url

        # Internal state
        self._channel_state: ChannelState = "CLOSED"
        self._access_token: str | None = None
        self._refresh_token: str | None = None
        self._ws_url: str | None = None
        self._channel_id: str = f"ch_{secrets.token_hex(4)}"
        self._reconnect_attempt: int = 0
        self._disconnect_requested: bool = False

        # Components
        self._transport = Transport(connect_timeout, send_timeout)
        self._session = SessionManager(
            on_message=self._dispatch_message,
            on_fatal_error=self._handle_fatal_error,
        )
        self._liveness: LivenessMonitor | None = None
        self._reconnect_task: asyncio.Task[None] | None = None
        self._token_refresh_task: asyncio.Task[None] | None = None

        # Wire transport callbacks
        self._transport.on_message = self._session.handle_incoming
        self._transport.on_close = self._handle_close
        self._transport.on_error = None  # handled via on_close

    # ── public properties ─────────────────────────────────────────────────────

    @property
    def session_state(self) -> SessionState:
        """Current session state, as tracked by the session manager."""
        return self._session.session_state

    @property
    def channel_state(self) -> ChannelState:
        """Current WebSocket channel state (e.g. CLOSED, OPEN, RECONNECTING)."""
        return self._channel_state

    @property
    def session_id(self) -> str | None:
        """Session id reported by the session manager, or None if not yet known."""
        return self._session.session_id

    # ── public methods ────────────────────────────────────────────────────────

    async def connect(self) -> None:
        """Full bootstrap: auth exchange → WS open → system::init → liveness start.

        Raises:
            CortexError: if the API-key exchange is rejected (e.g. ``auth_invalid``).
            Network and transport errors propagate unchanged.
        """
        self._disconnect_requested = False
        self._reconnect_attempt = 0

        auth = await exchange_api_key(self._api_key, auth_base_url=self._auth_url)
        self._access_token = auth["access_token"]
        self._refresh_token = auth["refresh_token"]
        self._ws_url = auth["ws_url"]
        if self._upload_url is None:
            self._upload_url = _derive_upload_url_from_ws_url(self._ws_url)

        await self._open_channel()

        self._session.set_transport(self._transport, self._send_timeout)
        await self._session.send_init(auth["runtime_bootstrap"])

        self._start_liveness()
        self._schedule_token_refresh()

    async def disconnect(self) -> None:
        """Close the session and WebSocket connection cleanly.

        Stops background tasks (liveness, token refresh, any in-flight
        reconnect) before closing the transport, so the resulting close is not
        treated as a connection loss.
        """
        self._disconnect_requested = True
        self._stop_liveness()
        self._stop_token_refresh()

        if self._reconnect_task and not self._reconnect_task.done():
            self._reconnect_task.cancel()
            try:
                await self._reconnect_task
            except (asyncio.CancelledError, Exception):
                pass
        self._reconnect_task = None

        self._channel_state = "CLOSED"
        await self._transport.aclose()

    async def send_message(
        self, content: str, attachments: list[str] | None = None
    ) -> None:
        """Send a chat message, optionally referencing uploaded attachments."""
        await self._session.send_chat_message(content, attachments)

    async def upload_attachment(self, file: str | bytes | BinaryIO) -> str:
        """Upload *file* and return the string produced by the upload helper.

        The returned value is presumably the attachment id to pass to
        ``send_message(attachments=...)`` — confirm against upload.py.

        Raises:
            CortexError: ``auth_invalid`` if called before a successful connect().
        """
        if not self._access_token:
            raise make_error("auth_invalid", "Not connected")
        url = self._upload_url or "/upload"
        return await upload_file(file, self._access_token, upload_url=url)

    async def stop(self) -> None:
        """Ask the session to stop the current response."""
        await self._session.send_stop()

    # ── internal helpers ──────────────────────────────────────────────────────

    async def _open_channel(self) -> None:
        """Open the WebSocket with the current URL and access token."""
        if not self._ws_url or not self._access_token:
            raise RuntimeError("Auth not completed before _open_channel")
        self._channel_state = "CONNECTING"
        await self._transport.open(self._ws_url, self._access_token)
        self._channel_state = "OPEN"
        self._reconnect_attempt = 0

    def _start_liveness(self) -> None:
        """(Re)start the liveness monitor on the current transport."""
        self._stop_liveness()
        self._liveness = LivenessMonitor(
            transport=self._transport,
            ping_interval=self._ping_interval,
            pong_timeout=self._pong_timeout,
            stale_threshold=self._stale_threshold,
            on_stale=self._handle_stale,
            get_session_id=lambda: self._session.session_id,
            get_channel_id=lambda: self._channel_id,
        )
        self._liveness.start()

    def _stop_liveness(self) -> None:
        """Stop and drop the liveness monitor, if any (idempotent)."""
        if self._liveness:
            self._liveness.stop()
            self._liveness = None

    def _schedule_token_refresh(self) -> None:
        """Start a background task that refreshes the access token before expiry."""
        self._stop_token_refresh()

        async def _refresh_loop() -> None:
            try:
                while not self._disconnect_requested:
                    await asyncio.sleep(TOKEN_REFRESH_BUFFER / 2)
                    if not self._access_token or not self._refresh_token:
                        continue
                    if is_token_expiring_soon(self._access_token):
                        try:
                            self._access_token = await refresh_access_token(
                                self._refresh_token,
                                auth_base_url=self._auth_url,
                            )
                        except Exception:
                            # Best-effort: the next reconnect retries the refresh
                            # and reports a rejected refresh token as fatal.
                            pass
            except asyncio.CancelledError:
                pass

        self._token_refresh_task = asyncio.create_task(_refresh_loop())

    def _stop_token_refresh(self) -> None:
        """Cancel the background token-refresh task, if any (idempotent)."""
        if self._token_refresh_task and not self._token_refresh_task.done():
            self._token_refresh_task.cancel()
        self._token_refresh_task = None

    # ── callbacks from transport / session ────────────────────────────────────

    def _dispatch_message(self, msg: CortexMessage) -> None:
        """Route an incoming message to the liveness monitor or the user callback."""
        # system::pong is internal — route to liveness, not to user callback
        if msg.get("type") == "system::pong":
            hb_id = msg["payload"].get("heartbeat_id")
            if isinstance(hb_id, str) and self._liveness:
                self._liveness.record_pong(hb_id)
            return
        self._on_message_cb(msg)

    def _handle_fatal_error(self, err: CortexError) -> None:
        """Tear down after an unrecoverable error and report it via ``on_message``.

        Marks the channel AUTH_FAILED (which both _handle_close and
        _reconnect_loop treat as "do not reconnect"), stops background tasks,
        closes the transport and emits a synthesized ``system::error`` whose
        payload carries ``fatal: True``.
        """
        self._channel_state = "AUTH_FAILED"
        self._stop_liveness()
        self._stop_token_refresh()
        self._transport.close()
        # Surface error as a system::error message. "fatal" lets handlers that
        # check payload.get("fatal") recognise SDK-detected fatal errors too.
        error_msg: CortexMessage = {
            "type": "system::error",
            "schema": "1.0",
            "session_id": self._session.session_id or "",
            "payload": {"code": err.code, "message": str(err), "fatal": True},
            "ts": _now_iso(),
        }
        self._on_message_cb(error_msg)

    def _handle_stale(self) -> None:
        """Liveness callback: close the stale socket so on_close drives a reconnect."""
        # Nothing to do if a reconnect is already under way, auth has failed
        # (no reconnect allowed), or the client was closed.
        if self._channel_state in ("STALE", "RECONNECTING", "AUTH_FAILED", "CLOSED"):
            return
        self._channel_state = "STALE"
        self._stop_liveness()
        self._transport.close(1001, "stale")
        # on_close will trigger reconnect

    def _handle_close(self, code: int, reason: str) -> None:
        """Transport on_close callback: ignore, mark auth failure, or reconnect."""
        if self._disconnect_requested:
            return
        if self._channel_state == "AUTH_FAILED":
            return
        # The socket is gone: stop pinging it so the liveness monitor cannot
        # fire _handle_stale (and close the *next* socket) mid-reconnect or
        # revive a connection after an auth failure.
        self._stop_liveness()
        if code == 4001:
            # Close code 4001 is treated as an authentication failure: no retry.
            self._channel_state = "AUTH_FAILED"
            self._stop_token_refresh()
            return
        self._channel_state = "RECONNECTING"
        if self._reconnect_task is not None and not self._reconnect_task.done():
            # A reconnect loop is already running (e.g. it closed the socket
            # itself after a failed resync); it retries on its own, so do not
            # start a second, concurrent loop.
            return
        self._reconnect_task = asyncio.ensure_future(self._reconnect_loop())

    async def _reconnect_loop(self) -> None:
        """Retry until reconnected, disconnected, or authentication has failed.

        Each attempt: backoff sleep → token refresh if needed → open channel →
        resync (bounded by ``resync_timeout``) → restart liveness and refresh.
        """
        try:
            while (
                not self._disconnect_requested
                and self._channel_state != "AUTH_FAILED"
            ):
                idx = min(self._reconnect_attempt, len(RECONNECT_BACKOFF) - 1)
                backoff = RECONNECT_BACKOFF[idx]
                self._reconnect_attempt += 1

                await asyncio.sleep(backoff)
                if self._disconnect_requested:
                    break

                # Proactively refresh token if needed
                try:
                    await self._maybe_refresh_token()
                except CortexError as err:
                    # The refresh token was rejected; retrying cannot succeed.
                    # Report it as fatal (this also sets AUTH_FAILED).
                    self._handle_fatal_error(err)
                    return
                except Exception:
                    # Transient failure (e.g. auth service unreachable): keep the
                    # loop alive and retry after the next backoff.
                    continue

                try:
                    await self._open_channel()
                except Exception:
                    self._channel_state = "RECONNECTING"
                    continue  # retry after next backoff

                # Re-attach session to the new transport
                self._session.set_transport(self._transport, self._send_timeout)

                # Resync with timeout
                try:
                    await asyncio.wait_for(
                        self._session.send_resync(),
                        timeout=self._resync_timeout,
                    )
                except Exception:
                    # Drop the half-initialised socket and try again; the close
                    # callback sees this loop is still running and will not
                    # start another one.
                    self._transport.close()
                    continue

                # Reconnect successful — restart liveness + token refresh
                self._start_liveness()
                self._schedule_token_refresh()
                return
        except asyncio.CancelledError:
            pass

    async def _maybe_refresh_token(self) -> None:
        """Refresh the access token if it is missing or close to expiry.

        Raises:
            CortexError: ``auth_refresh_failed`` if no refresh token is held or
                the refresh is rejected. Other errors from the refresh request
                propagate unchanged.
        """
        if not self._refresh_token:
            raise make_error("auth_refresh_failed", "No refresh token")
        if not self._access_token or is_token_expiring_soon(self._access_token):
            self._access_token = await refresh_access_token(
                self._refresh_token,
                auth_base_url=self._auth_url,
            )
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def _now_iso() -> str:
    """Return the current time as a timezone-aware ISO-8601 string in UTC."""
    import datetime as _dt

    current = _dt.datetime.now(tz=_dt.timezone.utc)
    return current.isoformat()
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def _derive_upload_url_from_ws_url(ws_url: str | None) -> str:
|
|
325
|
+
if not ws_url:
|
|
326
|
+
return "/upload"
|
|
327
|
+
|
|
328
|
+
parsed = urlsplit(ws_url)
|
|
329
|
+
scheme = parsed.scheme.lower()
|
|
330
|
+
http_scheme = "https" if scheme == "wss" else "http" if scheme == "ws" else scheme
|
|
331
|
+
return urlunsplit((http_scheme, parsed.netloc, "/upload", "", ""))
|