strand-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
strand/__init__.py ADDED
@@ -0,0 +1,62 @@
1
+ """Strand Platform Python SDK.
2
+
3
+ Quickstart — one-call pipeline:
4
+
5
+ >>> from strand import Client
6
+ >>> client = Client() # reads STRAND_API_KEY
7
+ >>> result = client.predict(
8
+ ... "slide.svs",
9
+ ... markers=["CD3", "CD8"],
10
+ ... output_dir="./outputs/",
11
+ ... )
12
+ >>> print(f"used {result.credits_used} credits")
13
+
14
+ Lower-level primitives stay available for fine-grained control:
15
+
16
+ >>> upload = client.uploads.upload_file("slide.svs")
17
+ >>> job = client.predict.submit(upload.id, markers=["CD3", "CD8"])
18
+ >>> job.wait()
19
+ >>> adata = job.download_results()
20
+
21
+ See `https://app.strandai.com/docs/api` for the underlying REST API reference.
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ from ._client import Client
27
+ from ._errors import (
28
+ AuthError,
29
+ BadRequestError,
30
+ InsufficientCreditsError,
31
+ JobFailedError,
32
+ JobTimeoutError,
33
+ NotFoundError,
34
+ RateLimitError,
35
+ StrandError,
36
+ UploadError,
37
+ )
38
+ from ._jobs import Job, JobEvent
39
+ from ._models import Estimate, JobStatus, PredictResult, Upload
40
+ from ._results import JobResults
41
+
42
# Package version string.
# NOTE(review): `_http.USER_AGENT` embeds this same version ("0.1.0"); keep the two in sync.
__version__ = "0.1.0"

# Public names exported by `from strand import *`. Every entry is imported
# above from a private submodule; the list is kept in alphabetical order.
__all__ = [
    "AuthError",
    "BadRequestError",
    "Client",
    "Estimate",
    "InsufficientCreditsError",
    "Job",
    "JobEvent",
    "JobFailedError",
    "JobResults",
    "JobStatus",
    "JobTimeoutError",
    "NotFoundError",
    "PredictResult",
    "RateLimitError",
    "StrandError",
    "Upload",
    "UploadError",
]
strand/_client.py ADDED
@@ -0,0 +1,95 @@
1
+ """Top-level `Client` — entry point for the SDK."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ import httpx
8
+
9
+ from ._http import DEFAULT_TIMEOUT, HttpSession
10
+ from ._predict import Predict
11
+ from ._uploads import Uploads
12
+
13
+ if TYPE_CHECKING:
14
+ from ._jobs import Job
15
+
16
+
17
class _JobsNamespace:
    """Namespace exposed as `client.jobs`: look up existing jobs by id."""

    def __init__(self, client: Client) -> None:
        # The owning `Client`; every `Job` built here is bound to it.
        self._client = client

    def get(self, job_id: str) -> Job:
        """Build a `Job` handle for ``job_id`` and fetch its current status.

        The handle is refreshed once before it is returned, so its cached
        status is already populated. Any error raised by `Job.refresh()`
        (for example a typed HTTP error) propagates to the caller.
        """
        # Deferred import: at module level `Job` is only imported under TYPE_CHECKING.
        from ._jobs import Job

        handle = Job(id=job_id, reserved_credits=None, client=self._client)
        handle.refresh()
        return handle
30
+
31
+
32
class Client:
    """Strand Platform API client.

    Args:
        api_key: API key (`sk-strand-...`). Falls back to the `STRAND_API_KEY`
            environment variable; `AuthError` is raised when neither is set.
        base_url: API base URL, without the `/api/v1` suffix. Falls back to the
            `STRAND_BASE_URL` environment variable, else `https://app.strandai.com`.
        timeout: Per-request timeout in seconds (or an `httpx.Timeout`) for the
            SDK-built HTTP client. `None` selects the SDK default rather than
            disabling timeouts.
        http_client: Pre-built `httpx.Client` for advanced use (e.g., custom
            transport, retries, ASGI mounting in tests). It is used as-is: the
            SDK does NOT set its `Authorization` header, `base_url`, or timeout,
            and `close()` leaves it open. Wire auth headers yourself and point
            its `base_url` at the `/api/v1` root, since requests are issued with
            relative paths such as `/jobs/{id}`.

    Example — one-call pipeline:

        >>> client = Client(api_key="sk-strand-...")
        >>> result = client.predict(
        ...     "slide.svs",
        ...     markers=["CD3", "CD8"],
        ...     output_dir="./outputs/",
        ... )
        >>> print(f"used {result.credits_used} credits")

    Lower-level primitives (`client.predict` is also a namespace):

        >>> upload = client.uploads.upload_file("slide.svs")
        >>> estimate = client.predict.estimate(upload.id, markers=["CD3"])
        >>> job = client.predict.submit(upload.id, markers=["CD3"])
        >>> job.wait()
        >>> adata = job.download_results()

    The client is a context manager; leaving the `with` block calls `close()`.
    """

    def __init__(
        self,
        *,
        api_key: str | None = None,
        base_url: str | None = None,
        timeout: float | httpx.Timeout | None = DEFAULT_TIMEOUT,
        http_client: httpx.Client | None = None,
    ) -> None:
        session = HttpSession(
            api_key=api_key,
            base_url=base_url,
            timeout=timeout,
            client=http_client,
        )
        self._http = session
        # Every resource namespace shares this single HTTP session.
        self.uploads = Uploads(session)
        self.predict = Predict(session, self)
        self.jobs = _JobsNamespace(self)

    @property
    def base_url(self) -> str:
        """Resolved base URL (trailing slash stripped, no `/api/v1` suffix)."""
        return self._http.base_url

    @property
    def api_root(self) -> str:
        """`base_url` with the `/api/v1` suffix appended."""
        return self._http.api_root

    def close(self) -> None:
        """Close the underlying httpx client, if the SDK created it."""
        self._http.close()

    def __enter__(self) -> Client:
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()
strand/_errors.py ADDED
@@ -0,0 +1,92 @@
1
+ """Typed exceptions for the Strand SDK.
2
+
3
+ All HTTP-level failures raised by the public surface inherit from `StrandError`.
4
+ Network-level failures (`httpx.HTTPError` and friends) pass through unchanged so
5
+ callers can apply their own retry logic — we only wrap responses that the
6
+ platform itself returned with a documented error shape.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import Any
12
+
13
+
14
class StrandError(Exception):
    """Root of the SDK's exception hierarchy.

    Attributes:
        message: Human-readable description; also what ``str(exc)`` returns.
        status_code: HTTP status of the failing response, when there is one.
        error_code: Machine-readable error code, when known.
        body: Parsed error payload; ``{}`` when none (or an empty one) was given.
    """

    def __init__(
        self,
        message: str,
        *,
        status_code: int | None = None,
        error_code: str | None = None,
        body: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(message)
        self.message = message
        self.status_code, self.error_code = status_code, error_code
        # Normalize a missing or empty payload to a fresh dict so callers can
        # always use `.get()` on it.
        self.body = body if body else {}
30
+
31
+
32
class AuthError(StrandError):
    """HTTP 401: the API key is missing, invalid, or expired.

    Also raised client-side when no API key can be resolved at all.
    """


class BadRequestError(StrandError):
    """HTTP 400: the server rejected the request body or its arguments."""


class NotFoundError(StrandError):
    """HTTP 404: the referenced upload, job, or file does not exist or is not accessible."""
42
+
43
+
44
class InsufficientCreditsError(StrandError):
    """HTTP 402: the org cannot reserve enough credits for this job.

    Always carries ``status_code=402`` and ``error_code="insufficient_credits"``.

    Attributes:
        required: Credits the server says the job needs, if it reported them.
        balance: Best-effort cached org balance from the most recent estimate,
            if the raiser supplied one.
    """

    def __init__(
        self,
        message: str,
        *,
        required: int | None = None,
        balance: int | None = None,
        body: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(
            message,
            body=body,
            error_code="insufficient_credits",
            status_code=402,
        )
        self.required, self.balance = required, balance
63
+
64
+
65
class RateLimitError(StrandError):
    """HTTP 429: the org's concurrent-job cap is exceeded.

    Always carries ``status_code=429`` and ``error_code="rate_limited"``.

    Attributes:
        retry_after: Suggested wait in seconds, or `None` when not provided.
    """

    def __init__(
        self,
        message: str,
        *,
        retry_after: int | None = None,
        body: dict[str, Any] | None = None,
    ) -> None:
        self.retry_after = retry_after
        super().__init__(
            message,
            body=body,
            error_code="rate_limited",
            status_code=429,
        )
77
+
78
+
79
class JobFailedError(StrandError):
    """A waited-on job ended with ``status == "failed"`` (raised by `Job.wait()`).

    Carries ``error_code="job_failed"``; no HTTP status is attached.

    Attributes:
        job_id: Id of the job that failed.
    """

    def __init__(self, message: str, *, job_id: str) -> None:
        self.job_id = job_id
        super().__init__(message, error_code="job_failed")
85
+
86
+
87
class JobTimeoutError(StrandError):
    """`Job.wait(timeout=...)` ran out of time before the job reached a terminal status."""


class UploadError(StrandError):
    """The resumable upload session aborted or answered with an unexpected response."""
strand/_http.py ADDED
@@ -0,0 +1,213 @@
1
+ """Thin HTTP helper around the generated `AuthenticatedClient`.
2
+
3
+ Centralizes:
4
+ - Constructing an `httpx.Client` with our auth header, base URL, and timeouts.
5
+ - Mapping documented error response bodies onto our typed exceptions.
6
+
7
+ We deliberately do not lean on the generated `client.AuthenticatedClient.with_*`
8
+ helpers — `httpx.Client` is plenty, and we want a single source of truth for
9
+ error mapping.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import os
15
+ from collections.abc import Iterator
16
+ from typing import Any, cast
17
+
18
+ import httpx
19
+
20
+ from ._errors import (
21
+ AuthError,
22
+ BadRequestError,
23
+ InsufficientCreditsError,
24
+ NotFoundError,
25
+ RateLimitError,
26
+ StrandError,
27
+ )
28
+
29
# Base URL used when neither `base_url` nor the `STRAND_BASE_URL` env var is set.
DEFAULT_BASE_URL = "https://app.strandai.com"
# Default request timeout in seconds for the SDK-built client; also substituted
# when a caller passes `timeout=None`.
DEFAULT_TIMEOUT = 60.0
# User-Agent header sent by the SDK-built client.
# NOTE(review): hardcodes the package version; keep in sync with `strand.__version__`.
USER_AGENT = "strand-sdk-python/0.1.0"
32
+
33
+
34
class HttpSession:
    """Wraps an `httpx.Client` with Strand auth headers and typed error mapping.

    Args:
        api_key: API key. Falls back to the `STRAND_API_KEY` environment variable.
        base_url: Base URL without the `/api/v1` suffix. Falls back to the
            `STRAND_BASE_URL` environment variable, then `DEFAULT_BASE_URL`.
            Trailing slashes are stripped.
        timeout: Timeout for the SDK-built client; `None` selects
            `DEFAULT_TIMEOUT`. Unused when `client` is supplied.
        client: Optional pre-built `httpx.Client`. It is used as-is (no auth
            header, base URL, or timeout is applied to it) and `close()` leaves
            it open. Requests use paths relative to the API root (e.g.
            `/jobs/{id}`), so its `base_url` should point at `<base>/api/v1`.

    Raises:
        AuthError: no API key was passed and `STRAND_API_KEY` is unset or empty.
    """

    def __init__(
        self,
        *,
        api_key: str | None,
        base_url: str | None,
        timeout: float | httpx.Timeout | None,
        client: httpx.Client | None = None,
    ) -> None:
        resolved_key = api_key or os.environ.get("STRAND_API_KEY")
        if not resolved_key:
            raise AuthError(
                "No API key provided. Pass api_key=... or set STRAND_API_KEY.",
                status_code=401,
                error_code="missing_api_key",
            )

        resolved_base = (
            base_url or os.environ.get("STRAND_BASE_URL") or DEFAULT_BASE_URL
        ).rstrip("/")

        self.api_key = resolved_key
        self.base_url = resolved_base
        self.api_root = f"{resolved_base}/api/v1"
        # Only a client we built ourselves is ours to close.
        self._owned_client = client is None
        if client is not None:
            self._client = client
        else:
            self._client = httpx.Client(
                base_url=self.api_root,
                timeout=timeout if timeout is not None else DEFAULT_TIMEOUT,
                headers={
                    "Authorization": f"Bearer {resolved_key}",
                    "User-Agent": USER_AGENT,
                    "Accept": "application/json",
                },
                # Redirects are not followed: a 3xx reaches the status checks
                # in the request helpers as-is.
                follow_redirects=False,
            )

    # ---------- lifecycle ----------

    def close(self) -> None:
        """Close the httpx client if this session created it."""
        if self._owned_client:
            self._client.close()

    def __enter__(self) -> HttpSession:
        return self

    def __exit__(self, *_: object) -> None:
        self.close()

    @property
    def client(self) -> httpx.Client:
        return self._client

    # ---------- request helpers ----------

    def request_json(
        self,
        method: str,
        path: str,
        *,
        json: Any = None,
        params: dict[str, Any] | None = None,
        expected: tuple[int, ...] = (200,),
    ) -> dict[str, Any]:
        """Send a request and return its JSON-object body.

        An empty body on an expected status (e.g. 204) yields ``{}``.

        Raises:
            StrandError (or a typed subclass): the response has an error status,
                a status outside ``expected``, invalid JSON, or a JSON value
                that is not an object.
        """
        resp = self._client.request(method, path, json=json, params=params)
        self._raise_for_error(resp)
        if resp.status_code not in expected:
            raise StrandError(
                f"Unexpected status {resp.status_code} for {method} {path}",
                status_code=resp.status_code,
                body=_safe_body(resp),
            )
        if not resp.content:
            return {}
        try:
            data = resp.json()
        except ValueError as exc:
            raise StrandError(
                f"Invalid JSON in response to {method} {path}",
                status_code=resp.status_code,
            ) from exc
        if not isinstance(data, dict):
            raise StrandError(
                f"Expected a JSON object in response to {method} {path}",
                status_code=resp.status_code,
            )
        return cast(dict[str, Any], data)

    def request_bytes(
        self,
        method: str,
        path: str,
        *,
        params: dict[str, Any] | None = None,
    ) -> bytes:
        """Send a request and return the raw response body; error statuses raise."""
        resp = self._client.request(method, path, params=params)
        self._raise_for_error(resp)
        return resp.content

    @staticmethod
    def _stream_timeout(
        timeout: float | httpx.Timeout | None,
    ) -> float | httpx.Timeout:
        """Resolve the timeout for streaming calls.

        By default there is no read timeout, because a long-lived stream (SSE)
        may sit idle between events; connect/write/pool stay bounded.
        """
        if timeout is not None:
            return timeout
        return httpx.Timeout(connect=10.0, read=None, write=10.0, pool=10.0)

    def stream_lines(
        self,
        method: str,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        timeout: float | httpx.Timeout | None = None,
    ) -> Iterator[bytes]:
        """Yield a streaming response body line by line, as UTF-8 bytes.

        Lines come from httpx's `iter_lines()` (which does not include line
        terminators in recent httpx releases). An error status raises the
        mapped typed exception before any line is yielded. The SSE wait loop
        uses `stream_response` instead.
        """
        with self._client.stream(
            method, path, params=params, timeout=self._stream_timeout(timeout)
        ) as resp:
            self._raise_for_error(resp)
            for line in resp.iter_lines():
                yield line.encode("utf-8") if isinstance(line, str) else line

    def stream_response(
        self,
        method: str,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        timeout: float | httpx.Timeout | None = None,
    ) -> httpx.Response:
        """Open a streaming response and return it; the caller must `.close()` it.

        Used by the SSE wait loop with httpx-sse. An error status is mapped to
        its typed exception (after the response is closed), so callers only
        ever receive a successful stream.
        """
        req = self._client.build_request(
            method, path, params=params, timeout=self._stream_timeout(timeout)
        )
        resp = self._client.send(req, stream=True)
        if resp.status_code >= 400:
            try:
                self._raise_for_error(resp)
            finally:
                resp.close()
        return resp

    # ---------- error mapping ----------

    def _raise_for_error(self, resp: httpx.Response) -> None:
        """Raise the typed exception for an error (>= 400) response; no-op otherwise."""
        if resp.status_code < 400:
            return

        # A streamed response must be read before `.json()` can decode it (httpx
        # raises `ResponseNotRead` otherwise). For an already-loaded response,
        # `read()` just returns the cached content.
        resp.read()
        body = _safe_body(resp)
        message = _extract_message(body, resp)
        raw_code = body.get("error")
        error_code = raw_code if isinstance(raw_code, str) else None

        if resp.status_code == 400:
            raise BadRequestError(message, status_code=400, error_code=error_code, body=body)
        if resp.status_code == 401:
            raise AuthError(message, status_code=401, error_code=error_code, body=body)
        if resp.status_code == 402:
            required = body.get("required")
            # bool is an int subclass, but never a meaningful credit count.
            valid_required = isinstance(required, int) and not isinstance(required, bool)
            raise InsufficientCreditsError(
                message,
                required=int(required) if valid_required else None,
                body=body,
            )
        if resp.status_code == 404:
            raise NotFoundError(message, status_code=404, error_code=error_code, body=body)
        if resp.status_code == 429:
            retry_after_raw = resp.headers.get("Retry-After")
            try:
                retry_after = int(retry_after_raw) if retry_after_raw is not None else None
            except ValueError:
                # e.g. the HTTP-date form of Retry-After; not interpreted here.
                retry_after = None
            raise RateLimitError(message, retry_after=retry_after, body=body)
        # Any other error status (409, other 4xx, and 5xx) surfaces as a
        # generic StrandError.
        raise StrandError(
            message,
            status_code=resp.status_code,
            error_code=error_code,
            body=body,
        )
195
+
196
+
197
def _safe_body(resp: httpx.Response) -> dict[str, Any]:
    """Best-effort decode of ``resp``'s JSON body.

    Returns the decoded value when it is a JSON object (dict). Returns ``{}``
    when the body is not valid JSON or decodes to any other JSON type.
    """
    try:
        payload = resp.json()
    except ValueError:
        # json.JSONDecodeError is a ValueError subclass.
        payload = None
    return payload if isinstance(payload, dict) else {}
203
+
204
+
205
def _extract_message(body: dict[str, Any], resp: httpx.Response) -> str:
    """Pick a human-readable message for an error response.

    Preference order: a non-empty string ``body["message"]``, then a non-empty
    string ``body["error"]``, else the fallback ``"HTTP <status>"``.
    """
    if isinstance(body, dict):
        for key in ("message", "error"):
            candidate = body.get(key)
            if isinstance(candidate, str) and candidate:
                return candidate
    return f"HTTP {resp.status_code}"
strand/_jobs.py ADDED
@@ -0,0 +1,205 @@
1
+ """Job handle: status polling, SSE waits, results download."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import time
7
+ from collections.abc import Iterator
8
+ from dataclasses import dataclass
9
+ from typing import TYPE_CHECKING, Any
10
+
11
+ import httpx
12
+ from httpx_sse import EventSource, SSEError
13
+
14
+ from ._errors import JobFailedError, JobTimeoutError, StrandError
15
+ from ._models import JobStatus
16
+ from ._results import JobResults
17
+
18
+ if TYPE_CHECKING:
19
+ from ._client import Client
20
+
21
# Job statuses treated as final; `JobEvent.is_terminal` tests membership here.
TERMINAL_STATUSES = frozenset({"completed", "failed"})


@dataclass
class JobEvent:
    """One job status snapshot received over the SSE stream.

    Fields are read from the event's decoded JSON payload; `raw` keeps all of it.
    """

    id: str | None  # payload "id"
    status: str | None  # payload "status"
    progress: float | None  # payload "progress" as float; None if absent or non-numeric
    result_gcs_path: str | None  # payload "resultGcsPath"
    raw: dict[str, Any]  # the complete decoded payload

    @classmethod
    def _from_payload(cls, payload: dict[str, Any]) -> JobEvent:
        """Build an event from a decoded SSE ``data`` payload."""
        progress_value = payload.get("progress")
        is_numeric = isinstance(progress_value, (int, float))
        return cls(
            id=payload.get("id"),
            status=payload.get("status"),
            progress=float(progress_value) if is_numeric else None,
            result_gcs_path=payload.get("resultGcsPath"),
            raw=payload,
        )

    @property
    def is_terminal(self) -> bool:
        """Whether `status` is one of `TERMINAL_STATUSES`."""
        return self.status in TERMINAL_STATUSES
51
+
52
+
53
class Job:
    """Handle for a submitted prediction job.

    Created by `Client.predict.submit(...)` and `Client.jobs.get(...)`.

    Attributes:
        id: Job id used in every `/jobs/{id}` request.
        reserved_credits: Credits reserved at submission; `None` when the
            handle was built from a bare id.
    """

    def __init__(self, *, id: str, reserved_credits: int | None, client: Client) -> None:
        self.id = id
        self.reserved_credits = reserved_credits
        self._client = client
        self._http = client._http
        # Last snapshot fetched by `refresh()`; None until the first fetch.
        self._cached_status: JobStatus | None = None

    # ---------- public surface ----------

    def __repr__(self) -> str:
        s = self._cached_status.status if self._cached_status is not None else "unknown"
        return f"Job(id={self.id!r}, status={s!r})"

    def refresh(self) -> JobStatus:
        """Fetch the latest status snapshot and cache it on the job."""
        raw = self._http.request_json("GET", f"/jobs/{self.id}")
        status = JobStatus._from_dict(raw)
        self._cached_status = status
        return status

    @property
    def status(self) -> JobStatus:
        """Most recently fetched status. Calls `refresh()` if none cached."""
        if self._cached_status is None:
            return self.refresh()
        return self._cached_status

    def stream_events(
        self,
        *,
        timeout: float | httpx.Timeout | None = None,
    ) -> Iterator[JobEvent]:
        """Yield `JobEvent`s as the server emits them.

        The generator stops after the first terminal event. The platform emits
        `: keep-alive` heartbeats; httpx-sse filters those out, so they are
        never yielded.

        Args:
            timeout: Optional timeout for the stream request, forwarded to the
                HTTP session. `None` keeps the session's streaming default.

        Raises:
            StrandError (or a typed subclass): the server answered with an
                error status, reported an error event, or sent a malformed stream.
        """
        resp = self._http.stream_response("GET", f"/jobs/{self.id}/stream", timeout=timeout)
        try:
            if resp.status_code >= 400:
                # Surface the typed HTTP error instead of letting httpx-sse reject
                # the non-SSE error body as a "malformed stream". The streamed
                # body must be read before it can be parsed.
                resp.read()
                self._http._raise_for_error(resp)
            try:
                event_source = EventSource(resp)
                for sse in event_source.iter_sse():
                    if sse.event and sse.event != "message":
                        continue
                    data = sse.data
                    if not data:
                        continue
                    try:
                        payload = json.loads(data)
                    except json.JSONDecodeError:
                        continue
                    if not isinstance(payload, dict):
                        continue
                    if "error" in payload and "status" not in payload:
                        raise StrandError(
                            f"Server reported error in stream: {payload['error']}",
                            body=payload,
                        )
                    event = JobEvent._from_payload(payload)
                    yield event
                    if event.is_terminal:
                        return
            except SSEError as exc:
                # Re-raised as StrandError; `wait()` treats it as a cue to poll.
                raise StrandError(f"Malformed SSE stream: {exc}") from exc
        finally:
            resp.close()

    def wait(
        self,
        *,
        timeout: float | None = None,
        poll_interval: float = 2.0,
        use_stream: bool = True,
    ) -> JobStatus:
        """Block until the job reaches a terminal status.

        Args:
            timeout: Max seconds to wait. `None` waits forever.
            poll_interval: Seconds between status polls, used if
                `use_stream=False` or once the stream ends or errors out.
            use_stream: When `True` (default), wait on SSE first; fall back to
                polling if the stream errors out.

        Returns:
            The terminal `JobStatus`, always re-fetched via `refresh()`.

        Raises:
            JobFailedError: status terminates as `"failed"`.
            JobTimeoutError: `timeout` elapses before terminal status.

        While streaming, the deadline is checked whenever an event arrives, and
        each blocking read is capped at the time remaining, so a silent
        connection cannot outlive the deadline.
        NOTE(review): heartbeat comments are not surfaced by httpx-sse, so a
        server sending only keep-alives can still hold the stream open past the
        deadline.
        """
        deadline = time.monotonic() + timeout if timeout is not None else None

        if use_stream:
            self._wait_on_stream(deadline, timeout)

        # Polling also re-confirms a terminal status seen on the stream and
        # turns a failed job into `JobFailedError`.
        while True:
            self._check_deadline(deadline, timeout)
            status = self.refresh()
            if status.is_terminal:
                if status.status == "failed":
                    raise JobFailedError(
                        status.error_message or f"Job {self.id} failed",
                        job_id=self.id,
                    )
                return status
            sleep_for = poll_interval
            if deadline is not None:
                # Never sleep past the deadline; the next loop check fires then.
                sleep_for = min(sleep_for, max(0.0, deadline - time.monotonic()))
            time.sleep(sleep_for)

    def results(self) -> JobResults:
        """Return a `JobResults` handle (lazy — does not fetch zarr bytes).

        Raises:
            KeyError: the results response lacks `resultUrl` or `expiresAt`.
        """
        raw = self._http.request_json("GET", f"/jobs/{self.id}/results")
        return JobResults(
            job_id=self.id,
            result_url=str(raw["resultUrl"]),
            result_base_path=str(raw.get("resultBasePath", "")),
            expires_at=str(raw["expiresAt"]),
            client=self._client,
        )

    def download_results(
        self,
        path: str | None = None,
    ) -> Any:
        """Download all result zarr files.

        - If `path` is `None` (default), return `JobResults.to_anndata()`: the
          zarr parsed in-memory as an `AnnData` object (requires the `anndata`
          extra).
        - If `path` is given, return `JobResults.download_to(path)`, which
          writes the zarr store to that directory (presumably returning its
          `Path`); no `anndata` dependency required.
        """
        results = self.results()
        if path is None:
            return results.to_anndata()
        return results.download_to(path)

    # ---------- internals ----------

    def _check_deadline(self, deadline: float | None, timeout: float | None) -> None:
        """Raise `JobTimeoutError` once `deadline` (a `time.monotonic()` value) has passed."""
        if deadline is not None and time.monotonic() > deadline:
            raise JobTimeoutError(
                f"Job {self.id} did not reach terminal status within {timeout}s",
            )

    def _wait_on_stream(self, deadline: float | None, timeout: float | None) -> None:
        """Consume SSE events until a terminal one arrives or the stream fails.

        HTTP and stream errors are swallowed so `wait()` can fall back to
        polling; `JobTimeoutError` propagates.
        """
        stream_timeout: httpx.Timeout | None = None
        if deadline is not None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                # No time left to stream; the polling loop's deadline check takes over.
                return
            # Cap each blocking read at the time left, so a connection that
            # goes silent cannot hold `wait()` past its deadline.
            stream_timeout = httpx.Timeout(connect=10.0, read=remaining, write=10.0, pool=10.0)

        events = self.stream_events(timeout=stream_timeout)
        try:
            for event in events:
                self._check_deadline(deadline, timeout)
                if event.is_terminal:
                    return
        except JobTimeoutError:
            # A StrandError subclass — must not be mistaken for a stream failure.
            raise
        except (httpx.HTTPError, StrandError):
            # Possibly a transient disconnect; the caller falls back to polling.
            pass
        finally:
            # Close the generator now so its HTTP response is released promptly.
            events.close()