strand-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- strand/__init__.py +62 -0
- strand/_client.py +95 -0
- strand/_errors.py +92 -0
- strand/_http.py +213 -0
- strand/_jobs.py +205 -0
- strand/_models.py +147 -0
- strand/_predict.py +144 -0
- strand/_results.py +348 -0
- strand/_uploads.py +156 -0
- strand/py.typed +0 -0
- strand_sdk-0.1.0.dist-info/METADATA +167 -0
- strand_sdk-0.1.0.dist-info/RECORD +14 -0
- strand_sdk-0.1.0.dist-info/WHEEL +4 -0
- strand_sdk-0.1.0.dist-info/licenses/LICENSE +17 -0
strand/__init__.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
"""Strand Platform Python SDK.
|
|
2
|
+
|
|
3
|
+
Quickstart — one-call pipeline:
|
|
4
|
+
|
|
5
|
+
>>> from strand import Client
|
|
6
|
+
>>> client = Client() # reads STRAND_API_KEY
|
|
7
|
+
>>> result = client.predict(
|
|
8
|
+
... "slide.svs",
|
|
9
|
+
... markers=["CD3", "CD8"],
|
|
10
|
+
... output_dir="./outputs/",
|
|
11
|
+
... )
|
|
12
|
+
>>> print(f"used {result.credits_used} credits")
|
|
13
|
+
|
|
14
|
+
Lower-level primitives stay available for fine-grained control:
|
|
15
|
+
|
|
16
|
+
>>> upload = client.uploads.upload_file("slide.svs")
|
|
17
|
+
>>> job = client.predict.submit(upload.id, markers=["CD3", "CD8"])
|
|
18
|
+
>>> job.wait()
|
|
19
|
+
>>> adata = job.download_results()
|
|
20
|
+
|
|
21
|
+
See `https://app.strandai.com/docs/api` for the underlying REST API reference.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
from ._client import Client
|
|
27
|
+
from ._errors import (
|
|
28
|
+
AuthError,
|
|
29
|
+
BadRequestError,
|
|
30
|
+
InsufficientCreditsError,
|
|
31
|
+
JobFailedError,
|
|
32
|
+
JobTimeoutError,
|
|
33
|
+
NotFoundError,
|
|
34
|
+
RateLimitError,
|
|
35
|
+
StrandError,
|
|
36
|
+
UploadError,
|
|
37
|
+
)
|
|
38
|
+
from ._jobs import Job, JobEvent
|
|
39
|
+
from ._models import Estimate, JobStatus, PredictResult, Upload
|
|
40
|
+
from ._results import JobResults
|
|
41
|
+
|
|
42
|
+
# Public names re-exported by `from strand import *`; kept in alphabetical order.
__all__ = [
    "AuthError", "BadRequestError", "Client", "Estimate",
    "InsufficientCreditsError", "Job", "JobEvent", "JobFailedError",
    "JobResults", "JobStatus", "JobTimeoutError", "NotFoundError",
    "PredictResult", "RateLimitError", "StrandError", "Upload", "UploadError",
]

# NOTE(review): `_http.USER_AGENT` embeds the same version string — bump both together.
__version__ = "0.1.0"
|
strand/_client.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""Top-level `Client` — entry point for the SDK."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
import httpx
|
|
8
|
+
|
|
9
|
+
from ._http import DEFAULT_TIMEOUT, HttpSession
|
|
10
|
+
from ._predict import Predict
|
|
11
|
+
from ._uploads import Uploads
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from ._jobs import Job
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class _JobsNamespace:
    """Accessor exposed as `client.jobs` for looking up existing jobs by id."""

    def __init__(self, client: Client) -> None:
        # Owning Client; passed through to every `Job` handle built here.
        self._client = client

    def get(self, job_id: str) -> Job:
        """Build a `Job` handle for `job_id` and eagerly fetch its status.

        The status is fetched immediately via `Job.refresh()`, so HTTP errors
        from that request (e.g. a 404 mapped to `NotFoundError`) surface here
        rather than on first status access.
        """
        # Deferred import — presumably to avoid an import cycle between modules.
        from ._jobs import Job

        handle = Job(id=job_id, reserved_credits=None, client=self._client)
        handle.refresh()
        return handle
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Client:
    """Entry point for the Strand Platform API.

    Args:
        api_key: API key (`sk-strand-...`). When omitted, `STRAND_API_KEY` is
            read from the environment; if neither is set, `AuthError` is raised.
        base_url: API origin without the `/api/v1` suffix. Falls back to the
            `STRAND_BASE_URL` env var, then `https://app.strandai.com`.
        timeout: Per-request timeout in seconds or an `httpx.Timeout`. `None`
            selects the SDK default (60 s) rather than disabling timeouts.
            Ignored when `http_client` is supplied.
        http_client: Pre-built `httpx.Client` for advanced use (e.g., custom
            transport, retries, ASGI mounting in tests). It is used as-is: the
            SDK does NOT add an `Authorization` header or a base URL to it, and
            `close()` leaves it open — wire auth headers and lifecycle yourself.

    Example — one-call pipeline:

        >>> client = Client(api_key="sk-strand-...")
        >>> result = client.predict(
        ...     "slide.svs",
        ...     markers=["CD3", "CD8"],
        ...     output_dir="./outputs/",
        ... )
        >>> print(f"used {result.credits_used} credits")

    Lower-level primitives (`client.predict` is also a namespace):

        >>> upload = client.uploads.upload_file("slide.svs")
        >>> estimate = client.predict.estimate(upload.id, markers=["CD3"])
        >>> job = client.predict.submit(upload.id, markers=["CD3"])
        >>> job.wait()
        >>> adata = job.download_results()
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        base_url: str | None = None,
        timeout: float | httpx.Timeout | None = DEFAULT_TIMEOUT,
        http_client: httpx.Client | None = None,
    ) -> None:
        session = HttpSession(
            api_key=api_key,
            base_url=base_url,
            timeout=timeout,
            client=http_client,
        )
        self._http = session
        # Resource namespaces all share the single HTTP session above.
        self.uploads = Uploads(session)
        self.predict = Predict(session, self)
        self.jobs = _JobsNamespace(self)

    @property
    def base_url(self) -> str:
        """Resolved API origin (trailing `/` stripped, no `/api/v1`)."""
        return self._http.base_url

    @property
    def api_root(self) -> str:
        """Versioned API root: `base_url` followed by `/api/v1`."""
        return self._http.api_root

    def close(self) -> None:
        """Release the HTTP session; a caller-supplied `http_client` stays open."""
        self._http.close()

    def __enter__(self) -> Client:
        """Return this client for use in a `with` block."""
        return self

    def __exit__(self, *_: object) -> None:
        """Close the client when leaving the `with` block; exceptions propagate."""
        self.close()
|
strand/_errors.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
"""Typed exceptions for the Strand SDK.
|
|
2
|
+
|
|
3
|
+
All HTTP-level failures raised by the public surface inherit from `StrandError`.
|
|
4
|
+
Network-level failures (`httpx.HTTPError` and friends) pass through unchanged so
|
|
5
|
+
callers can apply their own retry logic — we only wrap responses that the
|
|
6
|
+
platform itself returned with a documented error shape.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class StrandError(Exception):
    """Base class for SDK errors raised against documented API responses.

    Attributes:
        message: Human-readable message (also the exception's ``str()``).
        status_code: HTTP status of the failing response, if any.
        error_code: Machine-readable error code, if any.
        body: Parsed error body; ``{}`` when none was supplied.
    """

    def __init__(
        self,
        message: str,
        *,
        status_code: int | None = None,
        error_code: str | None = None,
        body: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(message)
        self.message = message
        self.status_code = status_code
        self.error_code = error_code
        self.body = body or {}

    def __reduce__(self) -> tuple[Any, ...]:
        """Make this error (and every subclass) picklable / copyable.

        ``BaseException``'s default reduction re-invokes ``cls(*self.args)``,
        which raises ``TypeError`` for subclasses whose ``__init__`` requires
        keyword-only arguments (e.g. ``job_id``). Rebuild via ``cls.__new__``
        instead — which restores ``args`` without calling ``__init__`` — and
        then restore the instance ``__dict__`` (message, codes, body, extras).
        """
        import copyreg

        return (copyreg.__newobj__, (type(self), *self.args), self.__dict__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class AuthError(StrandError):
    """Authentication failure: HTTP 401 for a missing, invalid, or expired API key.

    Also raised client-side (``error_code="missing_api_key"``) when no key is
    configured at all.
    """
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class BadRequestError(StrandError):
    """HTTP 400: the server rejected the request body or its arguments."""
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class NotFoundError(StrandError):
    """HTTP 404: the referenced upload, job, or file is missing or not accessible."""
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class InsufficientCreditsError(StrandError):
    """HTTP 402: the org cannot reserve enough credits for this job.

    Attributes:
        required: Credits the server reported as needed, when it reported them.
        balance: Org credit balance, when the raiser supplied one; otherwise
            ``None``.
    """

    def __init__(
        self,
        message: str,
        *,
        required: int | None = None,
        balance: int | None = None,
        body: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(
            message,
            status_code=402,
            error_code="insufficient_credits",
            body=body,
        )
        self.required, self.balance = required, balance
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class RateLimitError(StrandError):
    """HTTP 429: the per-org concurrent job cap was exceeded.

    Attributes:
        retry_after: Suggested back-off in seconds, or ``None`` if unknown.
    """

    def __init__(
        self,
        message: str,
        *,
        retry_after: int | None = None,
        body: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(
            message,
            status_code=429,
            error_code="rate_limited",
            body=body,
        )
        self.retry_after = retry_after
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class JobFailedError(StrandError):
    """Raised by `Job.wait()` when the job's terminal status is ``"failed"``.

    Attributes:
        job_id: Id of the job that failed.
    """

    def __init__(self, message: str, *, job_id: str) -> None:
        super().__init__(message, error_code="job_failed")
        self.job_id = job_id
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class JobTimeoutError(StrandError):
    """Raised by `Job.wait(timeout=...)` if the deadline passes before a terminal status."""
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class UploadError(StrandError):
    """Raised when a resumable upload session aborts or gets an unexpected response."""
|
strand/_http.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
"""Thin HTTP helper built directly on `httpx.Client` (not the generated `AuthenticatedClient`).
|
|
2
|
+
|
|
3
|
+
Centralizes:
|
|
4
|
+
- Constructing an `httpx.Client` with our auth header, base URL, and timeouts.
|
|
5
|
+
- Mapping documented error response bodies onto our typed exceptions.
|
|
6
|
+
|
|
7
|
+
We deliberately do not lean on the generated `client.AuthenticatedClient.with_*`
|
|
8
|
+
helpers — `httpx.Client` is plenty, and we want a single source of truth for
|
|
9
|
+
error mapping.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import os
|
|
15
|
+
from collections.abc import Iterator
|
|
16
|
+
from typing import Any, cast
|
|
17
|
+
|
|
18
|
+
import httpx
|
|
19
|
+
|
|
20
|
+
from ._errors import (
|
|
21
|
+
AuthError,
|
|
22
|
+
BadRequestError,
|
|
23
|
+
InsufficientCreditsError,
|
|
24
|
+
NotFoundError,
|
|
25
|
+
RateLimitError,
|
|
26
|
+
StrandError,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
# Default API origin; `HttpSession` appends the `/api/v1` suffix itself.
DEFAULT_BASE_URL = "https://app.strandai.com"
# Default per-request timeout in seconds; also used when `timeout=None` is passed.
DEFAULT_TIMEOUT = 60.0
# NOTE(review): must match `strand.__version__`; bump both together.
_SDK_VERSION = "0.1.0"
USER_AGENT = f"strand-sdk-python/{_SDK_VERSION}"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class HttpSession:
    """Wraps `httpx.Client` with Strand auth + error mapping.

    Every response with status >= 400 is converted into a `StrandError`
    subclass by `_raise_for_error`. The session owns (and closes) the
    `httpx.Client` only when it built that client itself.
    """

    def __init__(
        self,
        *,
        api_key: str | None,
        base_url: str | None,
        timeout: float | httpx.Timeout | None,
        client: httpx.Client | None = None,
    ) -> None:
        """Resolve credentials and base URL, then build (or adopt) the client.

        Args:
            api_key: Explicit key; falls back to ``STRAND_API_KEY``.
            base_url: Explicit origin; falls back to ``STRAND_BASE_URL``, then
                ``DEFAULT_BASE_URL``. Trailing slashes are stripped.
            timeout: Timeout for the SDK-built client; ``None`` means
                ``DEFAULT_TIMEOUT``. Unused when ``client`` is supplied.
            client: Optional pre-built ``httpx.Client``, used as-is (its own
                headers and base URL apply; it is never closed by `close()`).

        Raises:
            AuthError: no API key could be resolved.
        """
        resolved_key = api_key or os.environ.get("STRAND_API_KEY")
        if not resolved_key:
            raise AuthError(
                "No API key provided. Pass api_key=... or set STRAND_API_KEY.",
                status_code=401,
                error_code="missing_api_key",
            )

        resolved_base = (
            base_url or os.environ.get("STRAND_BASE_URL") or DEFAULT_BASE_URL
        ).rstrip("/")

        self.api_key = resolved_key
        self.base_url = resolved_base
        self.api_root = f"{resolved_base}/api/v1"
        self._owned_client = client is None
        if client is not None:
            self._client = client
        else:
            self._client = httpx.Client(
                base_url=self.api_root,
                timeout=timeout if timeout is not None else DEFAULT_TIMEOUT,
                headers={
                    "Authorization": f"Bearer {resolved_key}",
                    "User-Agent": USER_AGENT,
                    "Accept": "application/json",
                },
                follow_redirects=False,
            )

    # ---------- lifecycle ----------

    def close(self) -> None:
        """Close the underlying client, but only if this session created it."""
        if self._owned_client:
            self._client.close()

    def __enter__(self) -> HttpSession:
        return self

    def __exit__(self, *_: object) -> None:
        self.close()

    @property
    def client(self) -> httpx.Client:
        """The underlying `httpx.Client` (SDK-built or caller-supplied)."""
        return self._client

    # ---------- request helpers ----------

    def request_json(
        self,
        method: str,
        path: str,
        *,
        json: Any = None,
        params: dict[str, Any] | None = None,
        expected: tuple[int, ...] = (200,),
    ) -> dict[str, Any]:
        """Send a request and return its decoded JSON body.

        An empty body (e.g. a 204 listed in ``expected``) yields ``{}``.

        Raises:
            StrandError: an error status (mapped to a typed subclass), or a
                non-error status that is not in ``expected``.
        """
        resp = self._client.request(method, path, json=json, params=params)
        self._raise_for_error(resp)
        if resp.status_code not in expected:
            raise StrandError(
                f"Unexpected status {resp.status_code} for {method} {path}",
                status_code=resp.status_code,
                body=_safe_body(resp),
            )
        if not resp.content:
            # Nothing to decode (e.g. 204 No Content); `.json()` would raise.
            return {}
        return cast(dict[str, Any], resp.json())

    def request_bytes(
        self,
        method: str,
        path: str,
        *,
        params: dict[str, Any] | None = None,
    ) -> bytes:
        """Send a request and return the raw response body after error mapping."""
        resp = self._client.request(method, path, params=params)
        self._raise_for_error(resp)
        return resp.content

    def stream_lines(
        self,
        method: str,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        timeout: float | httpx.Timeout | None = None,
    ) -> Iterator[bytes]:
        """Yield each line of a streaming response as UTF-8 bytes.

        Lines are yielded as produced by `httpx.Response.iter_lines()`, encoded
        to bytes. Error statuses are mapped before any line is yielded. With
        ``timeout=None`` the read timeout is disabled (connect/write/pool 10 s).
        """
        ctx_timeout = (
            timeout if timeout is not None else httpx.Timeout(connect=10.0, read=None, write=10.0, pool=10.0)
        )
        with self._client.stream(method, path, params=params, timeout=ctx_timeout) as resp:
            self._raise_for_error(resp)
            for line in resp.iter_lines():
                yield line.encode("utf-8") if isinstance(line, str) else line

    def stream_response(
        self,
        method: str,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        timeout: float | httpx.Timeout | None = None,
    ) -> httpx.Response:
        """Open a streaming response and return it; caller must call `.close()`.

        Used by the SSE wait loop with httpx-sse. Error statuses are mapped to
        typed `StrandError`s (and the response closed) instead of being handed
        to the caller. With ``timeout=None`` the read timeout is disabled.
        """
        ctx_timeout = (
            timeout if timeout is not None else httpx.Timeout(connect=10.0, read=None, write=10.0, pool=10.0)
        )
        req = self._client.build_request(method, path, params=params, timeout=ctx_timeout)
        resp = self._client.send(req, stream=True)
        if resp.status_code >= 400:
            try:
                self._raise_for_error(resp)  # always raises for >= 400
            finally:
                resp.close()
        return resp

    # ---------- error mapping ----------

    def _raise_for_error(self, resp: httpx.Response) -> None:
        """Raise the typed `StrandError` matching ``resp``'s status, if >= 400."""
        if resp.status_code < 400:
            return

        # Streaming responses must be read before `.json()` can decode them;
        # for already-loaded responses this just returns the cached content.
        resp.read()
        body = _safe_body(resp)
        message = _extract_message(body, resp)
        error_code = body.get("error") if isinstance(body, dict) else None

        if resp.status_code == 400:
            raise BadRequestError(message, status_code=400, error_code=error_code, body=body)
        if resp.status_code == 401:
            raise AuthError(message, status_code=401, error_code=error_code, body=body)
        if resp.status_code == 402:
            required = body.get("required") if isinstance(body, dict) else None
            raise InsufficientCreditsError(
                message,
                required=int(required) if isinstance(required, int) else None,
                body=body,
            )
        if resp.status_code == 404:
            raise NotFoundError(message, status_code=404, error_code=error_code, body=body)
        if resp.status_code == 429:
            retry_after_raw = resp.headers.get("Retry-After")
            try:
                retry_after = int(retry_after_raw) if retry_after_raw is not None else None
            except ValueError:
                # Non-integer forms (e.g. an HTTP-date) are not parsed.
                retry_after = None
            raise RateLimitError(message, retry_after=retry_after, body=body)
        # 409, other 4xx, and 5xx — surface as generic StrandError.
        raise StrandError(
            message,
            status_code=resp.status_code,
            error_code=error_code,
            body=body,
        )
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def _safe_body(resp: httpx.Response) -> dict[str, Any]:
|
|
198
|
+
try:
|
|
199
|
+
data = resp.json()
|
|
200
|
+
except ValueError:
|
|
201
|
+
return {}
|
|
202
|
+
return data if isinstance(data, dict) else {}
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def _extract_message(body: dict[str, Any], resp: httpx.Response) -> str:
|
|
206
|
+
if isinstance(body, dict):
|
|
207
|
+
msg = body.get("message")
|
|
208
|
+
if isinstance(msg, str) and msg:
|
|
209
|
+
return msg
|
|
210
|
+
err = body.get("error")
|
|
211
|
+
if isinstance(err, str) and err:
|
|
212
|
+
return err
|
|
213
|
+
return f"HTTP {resp.status_code}"
|
strand/_jobs.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
"""Job handle: status polling, SSE waits, results download."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import time
|
|
7
|
+
from collections.abc import Iterator
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import TYPE_CHECKING, Any
|
|
10
|
+
|
|
11
|
+
import httpx
|
|
12
|
+
from httpx_sse import EventSource, SSEError
|
|
13
|
+
|
|
14
|
+
from ._errors import JobFailedError, JobTimeoutError, StrandError
|
|
15
|
+
from ._models import JobStatus
|
|
16
|
+
from ._results import JobResults
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from ._client import Client
|
|
20
|
+
|
|
21
|
+
# Job statuses at which `Job.wait()` and `Job.stream_events()` stop.
TERMINAL_STATUSES = frozenset(("completed", "failed"))
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class JobEvent:
    """One job-status snapshot delivered over the SSE stream."""

    id: str | None
    status: str | None
    progress: float | None
    result_gcs_path: str | None
    raw: dict[str, Any]  # the decoded event payload, unmodified

    @classmethod
    def _from_payload(cls, payload: dict[str, Any]) -> JobEvent:
        """Build an event from a decoded SSE JSON payload (camelCase keys).

        ``progress`` is kept only when it is numeric (coerced to ``float``).
        """
        raw_progress = payload.get("progress")
        progress = float(raw_progress) if isinstance(raw_progress, (int, float)) else None
        return cls(
            id=payload.get("id"),
            status=payload.get("status"),
            progress=progress,
            result_gcs_path=payload.get("resultGcsPath"),
            raw=payload,
        )

    @property
    def is_terminal(self) -> bool:
        """Whether `status` is one of `TERMINAL_STATUSES`."""
        return self.status in TERMINAL_STATUSES
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class Job:
    """Handle for a submitted prediction job.

    Created by `Client.predict.submit(...)` and `Client.jobs.get(...)`. The
    handle caches the most recent `JobStatus` it fetched; it does not poll in
    the background.
    """

    def __init__(self, *, id: str, reserved_credits: int | None, client: Client) -> None:
        self.id = id
        self.reserved_credits = reserved_credits
        self._client = client
        self._http = client._http
        # Last snapshot returned by `refresh()`; None until the first fetch.
        self._cached_status: JobStatus | None = None

    # ---------- public surface ----------

    def __repr__(self) -> str:
        s = self._cached_status.status if self._cached_status else "unknown"
        return f"Job(id={self.id!r}, status={s!r})"

    def refresh(self) -> JobStatus:
        """Fetch the latest status snapshot (``GET /jobs/{id}``) and cache it."""
        raw = self._http.request_json("GET", f"/jobs/{self.id}")
        status = JobStatus._from_dict(raw)
        self._cached_status = status
        return status

    @property
    def status(self) -> JobStatus:
        """Most recently fetched status. Calls `refresh()` if none cached."""
        if self._cached_status is None:
            return self.refresh()
        return self._cached_status

    def stream_events(
        self,
        *,
        timeout: float | httpx.Timeout | None = None,
    ) -> Iterator[JobEvent]:
        """Yield `JobEvent`s as the server emits them on ``/jobs/{id}/stream``.

        The generator returns right after yielding the first terminal event.
        Only unnamed or ``message`` events whose data is a JSON object are
        yielded; other events and undecodable payloads are skipped. The
        platform emits `: keep-alive` heartbeats; httpx-sse filters those out.

        Args:
            timeout: Optional timeout for the streaming request. ``None`` keeps
                the HTTP session's streaming default (no read timeout).

        Raises:
            StrandError: the server answered with an error status (mapped to
                the matching typed subclass), pushed an ``error`` payload, or
                the stream was not valid SSE.
        """
        resp = self._http.stream_response("GET", f"/jobs/{self.id}/stream", timeout=timeout)
        try:
            if resp.status_code >= 400:
                # Map documented error bodies (401/404/...) onto typed errors
                # instead of surfacing them as an SSE parse failure. The
                # streamed body must be read before it can be decoded.
                resp.read()
                self._http._raise_for_error(resp)
            try:
                for sse in EventSource(resp).iter_sse():
                    if sse.event and sse.event != "message":
                        continue
                    data = sse.data
                    if not data:
                        continue
                    try:
                        payload = json.loads(data)
                    except json.JSONDecodeError:
                        continue
                    if not isinstance(payload, dict):
                        continue
                    if "error" in payload and "status" not in payload:
                        raise StrandError(
                            f"Server reported error in stream: {payload['error']}",
                            body=payload,
                        )
                    event = JobEvent._from_payload(payload)
                    yield event
                    if event.is_terminal:
                        return
            except SSEError as exc:
                # `wait()` treats this as a reason to fall back to polling.
                raise StrandError(f"Malformed SSE stream: {exc}") from exc
        finally:
            resp.close()

    def wait(
        self,
        *,
        timeout: float | None = None,
        poll_interval: float = 2.0,
        use_stream: bool = True,
    ) -> JobStatus:
        """Block until the job reaches a terminal status.

        Args:
            timeout: Max seconds to wait. `None` waits forever.
            poll_interval: Used by the polling fallback if `use_stream=False`
                or if the stream connection drops mid-job.
            use_stream: When `True` (default), prefer SSE; fall back to polling
                if the stream errors out.

        Returns:
            The terminal `JobStatus` (always confirmed with a final `refresh()`).

        Raises:
            JobFailedError: status terminates as `"failed"`.
            JobTimeoutError: `timeout` elapses before terminal status.
        """
        deadline = time.monotonic() + timeout if timeout is not None else None

        def _remaining() -> float | None:
            """Seconds left before the deadline, or None when waiting forever."""
            return None if deadline is None else deadline - time.monotonic()

        def _check_deadline() -> None:
            remaining = _remaining()
            if remaining is not None and remaining < 0:
                raise JobTimeoutError(
                    f"Job {self.id} did not reach terminal status within {timeout}s",
                )

        remaining = _remaining()
        if use_stream and (remaining is None or remaining > 0):
            # Bound each read by the remaining budget so a silent stream cannot
            # block past the deadline (the resulting httpx.ReadTimeout drops us
            # into polling). NOTE(review): httpx read timeouts apply per chunk,
            # so keep-alive heartbeats can still delay detection until the next
            # event or disconnect.
            stream_timeout = (
                None
                if remaining is None
                else httpx.Timeout(connect=10.0, read=remaining, write=10.0, pool=10.0)
            )
            events = self.stream_events(timeout=stream_timeout)
            try:
                for event in events:
                    if event.is_terminal:
                        break
                    _check_deadline()
            except JobTimeoutError:
                # Our own deadline — do not mistake it for a stream failure.
                raise
            except (httpx.HTTPError, StrandError):
                # Drop down to polling — possibly a transient disconnect.
                pass
            finally:
                # Release the HTTP stream deterministically on every exit path.
                events.close()

        # Polling loop; also confirms the final status after a terminal event.
        while True:
            status = self.refresh()
            if status.is_terminal:
                if status.status == "failed":
                    raise JobFailedError(
                        status.error_message or f"Job {self.id} failed",
                        job_id=self.id,
                    )
                return status
            # Checked after the refresh so a job that finished during the last
            # sleep is reported instead of timing out.
            _check_deadline()
            sleep_for = poll_interval
            remaining = _remaining()
            if remaining is not None:
                sleep_for = min(sleep_for, max(0.0, remaining))
            time.sleep(sleep_for)

    def results(self) -> JobResults:
        """Return a `JobResults` handle (lazy — does not fetch zarr bytes).

        Built from ``GET /jobs/{id}/results``.
        """
        raw = self._http.request_json("GET", f"/jobs/{self.id}/results")
        return JobResults(
            job_id=self.id,
            result_url=str(raw["resultUrl"]),
            result_base_path=str(raw.get("resultBasePath", "")),
            expires_at=str(raw["expiresAt"]),
            client=self._client,
        )

    def download_results(
        self,
        path: str | None = None,
    ) -> Any:
        """Download all result zarr files.

        - If `path` is `None` (default), parse the zarr in-memory and return an
          `AnnData` object (requires the `anndata` extra).
        - If `path` is given, write the zarr store to that directory and return
          the `Path`. No `anndata` dependency required.
        """
        results = self.results()
        if path is None:
            return results.to_anndata()
        return results.download_to(path)
|