cradle-sdk 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cradle/sdk/client.py ADDED
@@ -0,0 +1,937 @@
1
+ import copy
2
+ import difflib
3
+ import time
4
+ from collections.abc import Iterator
5
+ from datetime import UTC, datetime
6
+ from importlib.metadata import version as pkg_version
7
+ from pathlib import Path
8
+ from typing import Any, BinaryIO, NoReturn, TypeVar
9
+ from urllib.parse import urljoin
10
+
11
+ import httpx
12
+ import pyarrow as pa
13
+ from pydantic import BaseModel, TypeAdapter
14
+ from typing_extensions import Generator, override
15
+
16
+ from cradle.sdk.auth.device import DeviceAuth
17
+ from cradle.sdk.exceptions import ClientError as ClientError # for re-export
18
+ from cradle.sdk.types.common import (
19
+ ContextProject,
20
+ ContextRound,
21
+ ContextWorkspace,
22
+ ListOptions,
23
+ )
24
+ from cradle.sdk.types.data import (
25
+ AddTableRequest,
26
+ ArtifactResponse,
27
+ BaseTableCreate,
28
+ BaseTableUpdate,
29
+ DataLoadCreate,
30
+ DataLoadResponse,
31
+ DataVersionResponse,
32
+ FileUploadResponse,
33
+ ListArtifactResponse,
34
+ ListDataLoadResponse,
35
+ ListDataVersionResponse,
36
+ ListTableResponse,
37
+ QueryDataRequest,
38
+ TableArchive,
39
+ TableRename,
40
+ TableResponse,
41
+ ViewTableCreate,
42
+ ViewTableUpdate,
43
+ )
44
+ from cradle.sdk.types.task import (
45
+ ListTaskResponse,
46
+ TaskCreate,
47
+ TaskResponse,
48
+ TaskState,
49
+ )
50
+ from cradle.sdk.types.workspace import (
51
+ ListProjectResponse,
52
+ ListRoundResponse,
53
+ ProjectCreate,
54
+ ProjectResponse,
55
+ RoundCreate,
56
+ RoundResponse,
57
+ WorkspaceResponse,
58
+ )
59
+
60
# Pydantic response model type produced by HttpClient.get / HttpClient.post.
T = TypeVar("T", bound=BaseModel)


class HttpClient:
    """Thin wrapper around ``httpx.Client`` that builds API URLs and decodes responses.

    Every request is sent to ``{base_url}/{prefix}{path}``. ``with_prefix`` derives a client
    for a nested API section that shares the same underlying ``httpx.Client``.
    """

    def __init__(
        self,
        prefix: str,
        base_url: str,
        auth: httpx.Auth | None = None,
        timeout: float = 60.0,
        user_agent: str | None = None,
        workspace: str | None = None,
        use_keyring: bool = True,
    ):
        """Create the underlying HTTP client.

        Args:
            prefix: Path prefix prepended to every request path (trailing "/" is stripped).
            base_url: Root URL of the API (trailing "/" is stripped).
            auth: Explicit httpx auth. When None, ``DeviceAuth.from_strategy`` is used instead.
            timeout: Request timeout in seconds.
            user_agent: User-Agent header; defaults to ``cradle-sdk-python/<installed version>``.
            workspace: Workspace name, forwarded to ``DeviceAuth`` and used for 403 diagnostics.
            use_keyring: Forwarded to ``DeviceAuth.from_strategy``.
        """
        base_url = base_url.rstrip("/")
        user_agent = user_agent or f"cradle-sdk-python/{pkg_version('cradle-sdk')}"
        headers = {
            "Accept": "application/json",
            "User-Agent": user_agent,
        }
        client = httpx.Client(auth=auth, headers=headers, timeout=timeout, follow_redirects=True)
        if auth is None:
            # DeviceAuth is given the client itself, so it can only be attached after construction.
            client.auth = DeviceAuth.from_strategy(
                client=client, workspace=workspace, base_url=base_url, use_keyring=use_keyring
            )

        self.http_client = client
        self.prefix = prefix.rstrip("/")
        self.base_url = base_url  # already stripped of trailing "/" above
        self.workspace = workspace

    def with_prefix(self, prefix: str) -> "HttpClient":
        """Return a shallow copy whose prefix is extended by ``prefix``.

        The copy shares ``http_client`` (and therefore auth and connection pool) with ``self``.
        """
        c = copy.copy(self)
        c.prefix = f"{self.prefix}/{prefix.strip('/')}".strip("/")
        return c

    def url(self, path: str) -> str:
        """Build the absolute URL for ``path`` (e.g. ``":get"`` or ``"/version:list"``) under this prefix."""
        return urljoin(f"{self.base_url}/", f"{self.prefix}{path}".lstrip("/"))

    def request(
        self,
        method: str,
        path: str,
        json: dict[str, Any] | BaseModel | None = None,
        params: dict[str, Any] | None = None,
        **kwargs,
    ) -> httpx.Response:
        """Send a request and return the raw response without checking its status.

        ``None`` values are dropped from ``params``, pydantic models passed as ``json`` are
        serialized by alias in JSON mode, and remaining keyword arguments are forwarded to
        ``httpx.Client.request``.
        """
        if params is not None:
            # HTTPX will send "...&foo=&...", which FastAPI would interpret as an empty string rather than None.
            kwargs["params"] = {k: v for k, v in params.items() if v is not None}
        if isinstance(json, BaseModel):
            json = json.model_dump(mode="json", by_alias=True)
        if json is not None:
            kwargs["json"] = json

        return self.http_client.request(method, self.url(path), **kwargs)

    def get(self, path: str, response_type: type[T] | Any | None, params: dict[str, Any] | None = None, **kwargs) -> T:
        """Send a GET request and decode the body as ``response_type`` (None discards the body)."""
        return self._handle_response(self.request("GET", path, params=params, **kwargs), response_type)

    def post(self, path: str, response_type: type[T] | Any | None, params: dict[str, Any] | None = None, **kwargs) -> T:
        """Send a POST request and decode the body as ``response_type`` (None discards the body)."""
        return self._handle_response(self.request("POST", path, params=params, **kwargs), response_type)

    def _handle_response(
        self, response: httpx.Response, response_type: type[T] | Iterator[str] | Iterator[bytes] | None
    ) -> Any:
        """Raise on an error status, otherwise decode the body according to ``response_type``.

        JSON bodies are validated with a pydantic ``TypeAdapter``. NDJSON and octet-stream bodies
        are returned as line/byte iterators when ``response_type`` is ``Iterator[str]`` /
        ``Iterator[bytes]`` respectively.

        Raises:
            ClientError: For non-success status codes (via ``handle_error_response``).
            ValueError: If the content type is missing or does not match ``response_type``.
        """
        if not response.is_success:
            self.handle_error_response(response)
        if response_type is None:
            return None

        ct = response.headers.get("content-type")
        if ct is None:
            raise ValueError("content type header missing")

        # Ignore parameters such as "; charset=utf-8".
        ct = ct.split(";")[0].strip()
        if ct == "application/json":
            return TypeAdapter(response_type).validate_python(response.json())
        if ct == "application/x-ndjson" and response_type == Iterator[str]:
            return response.iter_lines()
        if ct == "application/octet-stream" and response_type == Iterator[bytes]:
            return response.iter_bytes()
        raise ValueError(f"Unsupported content type {ct}")

    def handle_error_response(self, response: httpx.Response) -> NoReturn:
        """Raise a ``ClientError`` describing a failed response.

        On a 403 where the ``DeviceAuth`` credentials were authorized for a different workspace
        than this client's, ``suggest_logout`` is called first. The message is taken from the
        JSON body's ``detail`` (with ``error_id`` appended when present); bodies that are not a
        JSON object fall back to the raw response text.

        Raises:
            ClientError: Always.
        """
        auth = self.http_client.auth
        if (
            response.status_code == 403
            and self.workspace is not None
            and isinstance(auth, DeviceAuth)
            and auth.authorized_workspace_id != self.workspace
        ):
            auth.suggest_logout(self.workspace)

        error_msg = response.text
        errors = []
        try:
            body = response.json()
        except ValueError:
            # Not JSON (e.g. an HTML error page from a proxy); keep the raw text.
            body = None
        if isinstance(body, dict):
            error_msg = body.get("detail", response.text)
            if (error_id := body.get("error_id")) is not None:
                error_msg = f"{error_msg} (error ID: {error_id})"
            # The server may send an explicit null; callers expect a list.
            errors = body.get("errors") or []
        raise ClientError(response.status_code, error_msg, errors=errors)
163
+
164
+
165
class DisabledAuth(httpx.Auth):
    """httpx auth hook that passes every request through without adding credentials."""

    @override
    def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]:
        # Emit the request unchanged exactly once; the response is not inspected.
        yield request
172
+
173
+
174
# Default base URL of the hosted Cradle API.
API_URL = "https://api.cradle.bio"


class Client:
    """Entry point of the Cradle SDK, scoped to a single workspace.

    A workspace is the home for all data, projects and tasks for an organization. By
    design no information can be shared across workspaces; they are the top-level
    concept in the Cradle API.

    The resource APIs are exposed through the ``project``, ``round``, ``data`` and ``task``
    properties, all of which send requests under ``/v2/workspace/{workspace}/``.

    Workspace creation and management is only available to Cradle administrators.
    """

    def __init__(
        self,
        workspace: str,
        auth: httpx.Auth | None = None,
        base_url: str = API_URL,
        timeout: float = 60,
        user_agent: str | None = None,
        use_keyring: bool = True,
    ):
        """Create a client bound to ``workspace``.

        Args:
            workspace: Name of the workspace all requests are scoped to.
            auth: Explicit httpx auth. When None, ``HttpClient`` configures ``DeviceAuth``.
            base_url: Root URL of the Cradle API.
            timeout: Request timeout in seconds (fractional values are accepted).
            user_agent: Optional custom User-Agent header.
            use_keyring: Forwarded to device authentication.
        """
        self._base_client = HttpClient(
            prefix="/",
            base_url=base_url,
            auth=auth,
            timeout=timeout,
            user_agent=user_agent,
            workspace=workspace,
            use_keyring=use_keyring,
        )
        self._client = self._base_client.with_prefix(f"/v2/workspace/{workspace}/")
        self._workspace = workspace

    def workspace(self) -> WorkspaceResponse:
        """Fetch information about this client's workspace."""
        return self._base_client.get("/v2/workspace:get", WorkspaceResponse, dict(name=self._workspace))

    @property
    def project(self) -> "ProjectClient":
        """Project endpoints of this workspace (a new client on every access)."""
        return ProjectClient(self._client)

    @property
    def round(self) -> "RoundClient":
        """Round endpoints of this workspace (a new client on every access)."""
        return RoundClient(self._client)

    @property
    def data(self) -> "DataClient":
        """Data (tables, artifacts, loads) endpoints of this workspace (a new client on every access)."""
        return DataClient(self._client)

    @property
    def task(self) -> "TaskClient":
        """Task endpoints of this workspace (a new client on every access)."""
        return TaskClient(self._client)
226
+
227
+
228
class ProjectClient:
    """Manage and interact with the projects of a workspace.

    Projects organize tasks in a workspace and represent the work involved in
    optimizing a specific protein or achieving a specific goal. Data is currently
    isolated to a specific project, but in the future we will support cross project
    data sharing within a workspace.
    """

    def __init__(self, client: HttpClient):
        self._client = client.with_prefix("/project")

    def get(self, project_id: int) -> ProjectResponse:
        """Fetch a single project by its `ID`."""
        return self._client.get(":get", ProjectResponse, {"id": project_id})

    def list(self) -> Generator[ProjectResponse, None, None]:
        """Iterate over every project in the workspace.

        Pages are requested one at a time as the generator is consumed.
        """
        options = ListOptions()
        while True:
            page = self._client.get(":list", ListProjectResponse, params=options.model_dump())
            yield from page.items
            if page.cursor is None:
                break
            options.cursor = page.cursor

    def get_by_name(self, name: str) -> ProjectResponse:
        """Return the first project called ``name``.

        Raises:
            ValueError: If no project with that name exists.
        """
        found = next((candidate for candidate in self.list() if candidate.name == name), None)
        if found is None:
            raise ValueError(f"Project with name '{name}' not found")
        return found

    def create(self, project: ProjectCreate) -> ProjectResponse:
        """Create a new project in the workspace."""
        return self._client.post(":create", ProjectResponse, json=project)

    def update(self, project_id: int, project: ProjectCreate) -> ProjectResponse:
        """Replace the information of project ``project_id``."""
        return self._client.post(":update", ProjectResponse, params={"id": project_id}, json=project)

    def archive(self, project_id: int) -> ProjectResponse:
        """Archive a project.

        Its data stays available, but new tasks, data loads, etc. can no longer be
        created in it.
        """
        return self._client.post(":archive", ProjectResponse, params={"id": project_id})

    def unarchive(self, project_id: int) -> ProjectResponse:
        """Restore a previously archived project.

        The API answers with HTTP 422 `Unprocessable Entity` if the project is not archived.
        """
        return self._client.post(":unarchive", ProjectResponse, params={"id": project_id})
284
+
285
+
286
class RoundClient:
    """Manage and interact with rounds, which organize the experiments within a project.

    Commonly, a round represents all in-vitro experiments related to a specific sub-goal or
    timeframe of a project. While rounds are typically ordered sequentially, multiple rounds
    can exist in parallel (for example, when testing multiple design methods in parallel).

    Tasks, data loads, and reports can be (but don't have to be!) assigned to a round.
    """

    def __init__(self, client: HttpClient):
        self._client = client.with_prefix("/round")

    def get(self, round_id: int) -> RoundResponse:
        """Fetch a single round by its ID."""
        return self._client.get(":get", RoundResponse, {"id": round_id})

    def list(self, project_id: int | None = None) -> Generator[RoundResponse, None, None]:
        """Iterate over rounds, optionally restricted to the project ``project_id``.

        Pages are requested one at a time as the generator is consumed.
        """
        filters = {} if project_id is None else {"project_id": project_id}

        options = ListOptions()
        while True:
            page = self._client.get(
                ":list",
                ListRoundResponse,
                params={**options.model_dump(), **filters},
            )
            yield from page.items
            if page.cursor is None:
                break
            options.cursor = page.cursor

    def get_by_name(self, project_id: int, name: str) -> RoundResponse:
        """Return the first round called ``name`` in project ``project_id``.

        Raises:
            ValueError: If the project has no round with that name.
        """
        matches = (candidate for candidate in self.list(project_id=project_id) if candidate.name == name)
        found = next(matches, None)
        if found is None:
            raise ValueError(f"Round with name '{name}' not found in project {project_id}")
        return found

    def create(self, round_: RoundCreate) -> RoundResponse:
        """Create a new round in a project."""
        return self._client.post(":create", RoundResponse, json=round_)

    def update(self, round_id: int, round_: RoundCreate) -> RoundResponse:
        """Change the name and description of round ``round_id``."""
        return self._client.post(":update", RoundResponse, params={"id": round_id}, json=round_)

    def archive(self, round_id: int) -> RoundResponse:
        """Archive a round.

        Its data stays available, but new tasks, data loads, etc. can no longer be
        created in it.
        """
        return self._client.post(":archive", RoundResponse, params={"id": round_id})

    def unarchive(self, round_id: int) -> RoundResponse:
        """Restore a previously archived round.

        The API answers with HTTP 422 `Unprocessable Entity` if the round is not archived.
        """
        return self._client.post(":unarchive", RoundResponse, params={"id": round_id})
351
+
352
+
353
class DataClient:
    """Entry point for the data APIs of a workspace.

    Data comes in the form of *tables* and *artifacts*.
    *Tables* are used to store row-based data, such as measurements, sequences, and other structured data.
    *Artifacts* are used to store the results of tasks, such as machine learning models, predictions, and other outputs.

    While tables can be created and uploaded by users, artifacts are only produced by tasks and cannot be uploaded directly.

    Artifacts and table rows (via their data load source) have a context specified. This context refers to the
    origin of the data and is used in other parts of the API to determine data visibility, e.g. to
    isolate data between projects.

    See the [Artifacts](#tag/Artifacts) API endpoints for more details on artifact data.

    See the [Tables](#tag/Tables) API endpoints for more details on tabular data.
    """

    def __init__(self, client: HttpClient):
        self._client = client.with_prefix("/data")

    @property
    def artifact(self) -> "ArtifactClient":
        """Artifact endpoints (a new client on every access)."""
        return ArtifactClient(self._client)

    @property
    def table(self) -> "TableClient":
        """Table endpoints (a new client on every access)."""
        return TableClient(self._client)

    @property
    def load(self) -> "DataLoadClient":
        """Data load endpoints (a new client on every access)."""
        return DataLoadClient(self._client)

    def list_versions(
        self,
        project_id: int | None = None,
        table_id: int | None = None,
    ) -> Generator[DataVersionResponse, None, None]:
        """Iterate over data versions, optionally filtered by project and/or table.

        Pages are requested one at a time as the generator is consumed.
        """
        filters = {
            key: value
            for key, value in (("project_id", project_id), ("table_id", table_id))
            if value is not None
        }

        options = ListOptions()
        while True:
            page = self._client.get(
                "/version:list", ListDataVersionResponse, params={**filters, **options.model_dump()}
            )
            yield from page.items
            if page.cursor is None:
                break
            options.cursor = page.cursor
404
+
405
+
406
class ArtifactClient:
    """Read access to artifacts, the typed outputs produced by tasks.

    Artifacts can be referenced as inputs to other tasks. They can generally not be
    uploaded and downloaded directly. Their well-defined type determines for which
    purpose they can be used.
    """

    def __init__(self, client: HttpClient):
        self._client = client.with_prefix("/artifact")

    def get(self, artifact_id: int) -> ArtifactResponse:
        """Fetch the metadata of a single artifact."""
        return self._client.get(":get", ArtifactResponse, {"id": artifact_id})

    def list(
        self,
        project_id: int | None = None,
        round_id: int | None = None,
    ) -> Generator[ArtifactResponse, None, None]:
        """Iterate over artifacts in the workspace, optionally filtered by project and/or round.

        Pages are requested one at a time as the generator is consumed.
        """
        filters = {
            key: value
            for key, value in (("project_id", project_id), ("round_id", round_id))
            if value is not None
        }

        options = ListOptions()
        while True:
            page = self._client.get(
                ":list",
                ListArtifactResponse,
                params={**options.model_dump(), **filters},
            )
            yield from page.items
            if page.cursor is None:
                break
            options.cursor = page.cursor
443
+
444
+
445
+ class _IteratorReader:
446
+ def __init__(self, iterator: Iterator[bytes]):
447
+ self.iterator = iterator
448
+ self.buffer = b""
449
+ self._closed = False
450
+
451
+ @property
452
+ def closed(self) -> bool:
453
+ return self._closed
454
+
455
+ def close(self) -> None:
456
+ self._closed = True
457
+ self._buffer = b""
458
+
459
+ def read(self, size=-1):
460
+ try:
461
+ while size < 0 or len(self.buffer) < size:
462
+ self.buffer += next(self.iterator)
463
+ except StopIteration:
464
+ pass
465
+ if size < 0:
466
+ out, self.buffer = self.buffer, b""
467
+ else:
468
+ out, self.buffer = self.buffer[:size], self.buffer[size:]
469
+ return out
470
+
471
+ def __enter__(self):
472
+ return self
473
+
474
+ def __exit__(self, exc_type, exc_val, exc_tb):
475
+ self.close()
476
+
477
+
478
class TableClient:
    """Tables store row-based data with a well-defined column schema. Such data may come from spreadsheets
    or other tabular data stores.

    Each table is created with a well-defined column schema, which can include primitive columns such
    as numbers and strings, as well as nested struct columns and arrays.
    A table's schema can be extended but existing columns cannot generally be modified or removed.

    Table data as well as its column schema are version controlled. Each change to the contained
    rows is tracked as a data version.
    Tables can be queried at any historic version unless some destructive change has been made, such as
    hard deletion of historic data for compliance reasons.

    Because tables are versioned, each row is immutable. This means that it is not possible to change individual
    values in a row. Instead, the data load that added the row can be undone and the rows can be loaded
    again with the changed data.

    See the [Data Loads](#tag/Data-Loads) API endpoints on how to add data to and remove data from tables.
    """

    def __init__(self, client: HttpClient):
        self._client = client.with_prefix("/table")

    def query(
        self,
        query: str,
        *,
        project_id: int | None,
        version_id: int | None = None,
    ) -> Generator[pa.RecordBatch, None, None]:
        """Run a SQL query and stream the result as Arrow record batches.

        Args:
            query: The SQL query to execute.
            project_id: Project ID to filter the source data by. Pass ``None`` to query data
                across all projects the authenticated user has access to.
            version_id: Optional data version ID at which to execute the query.

        Raises:
            ClientError: If the API returns an error response.
            ValueError: If the response content type is missing or is not an Arrow stream.

        Example:
            ```
            batches = client.data.table.query("SELECT * FROM source_table", project_id=None)
            table = pyarrow.Table.from_batches(batches)

            # Write as CSV. Only works with primitive columns.
            pyarrow.csv.write_csv(table, "result.csv")

            # Write as Parquet. Also works with struct and array columns.
            pyarrow.parquet.write_table(table, "result.parquet")
            ```
        """
        request = QueryDataRequest(
            query=query,
            project_id=project_id,
            version_id=version_id,
        )
        # Serialize the same way HttpClient.request does for pydantic bodies.
        body = request.model_dump(mode="json", by_alias=True)
        with self._client.http_client.stream("POST", self._client.url(":query"), json=body) as response:
            if not response.is_success:
                # A streamed body must be read explicitly before the error handler can access it.
                response.read()
                self._client.handle_error_response(response)

            ct = response.headers.get("content-type")
            if ct is None:
                raise ValueError("content type header missing")
            ct = ct.split(";")[0].strip()

            if ct != "application/vnd.apache.arrow.stream":
                raise ValueError(f"unexpected content type {ct}")
            with _IteratorReader(response.iter_bytes()) as source:
                yield from pa.ipc.open_stream(source)

    def get_by_id(self, table_id: int) -> TableResponse:
        """Retrieve the table by its ID.

        The ID is an opaque identifier for the table returned at table creation time and is distinct
        from the table reference.
        """
        return self._client.get(":getById", TableResponse, dict(id=table_id))

    def get(self, reference: str) -> TableResponse:
        """Get a table by its reference."""
        return self._client.get(":get", TableResponse, dict(reference=reference))

    def list(self) -> Generator[TableResponse, None, None]:
        """List all current base tables and views in the workspace, fetching pages lazily."""
        opts = ListOptions()
        while True:
            resp = self._client.get(":list", ListTableResponse, params=opts.model_dump())
            yield from resp.items
            if resp.cursor is None:
                return
            opts.cursor = resp.cursor

    def list_versions(self, table_id: int) -> Generator[TableResponse, None, None]:
        """List all versions of the table ``table_id``, fetching pages lazily."""
        opts = ListOptions()
        while True:
            resp = self._client.get(
                ":listVersions",
                ListTableResponse,
                params=dict(**opts.model_dump(), id=table_id),
            )
            yield from resp.items
            if resp.cursor is None:
                return
            opts.cursor = resp.cursor

    def list_archived(self, reference: str) -> Generator[TableResponse, None, None]:
        """List all archived tables that match ``reference``, fetching pages lazily."""
        opts = ListOptions()
        while True:
            resp = self._client.get(
                ":listArchived",
                ListTableResponse,
                params=dict(**opts.model_dump(), reference=reference),
            )
            yield from resp.items
            if resp.cursor is None:
                return
            opts.cursor = resp.cursor

    def create(self, table: ViewTableCreate | BaseTableCreate) -> TableResponse:
        """Create a new table or view."""
        return self._client.post(":create", TableResponse, json=table)

    def update(self, reference: str, table: ViewTableUpdate | BaseTableUpdate) -> TableResponse:
        """Update a table or view.

        Tables can only be updated in forward-compatible ways.
        New columns can be added but not removed.
        Existing columns can be turned from non-nullable to nullable but not the other way around.
        Existing columns cannot be renamed or changed in type.

        Views can be updated arbitrarily.
        """
        return self._client.post(":update", TableResponse, params=dict(reference=reference), json=table)

    def archive(self, reference: str, archive: TableArchive | None = None) -> TableResponse:
        """Archive a table or view.

        Archiving makes the table or view inaccessible for regular operations
        but preserves it for historical purposes. When ``archive`` is omitted, a default
        ``TableArchive()`` request body is sent.
        """
        if archive is None:
            archive = TableArchive()
        return self._client.post(":archive", TableResponse, params=dict(reference=reference), json=archive)

    def unarchive(self, table_id: int) -> TableResponse:
        """Unarchive a previously archived table or view.

        This makes the table or view available for regular operations again.
        If an active table with the same reference already exists, the unarchive operation
        fails with an HTTP `409 Conflict` error.
        """
        return self._client.post(":unarchive", TableResponse, params=dict(id=table_id))

    def rename(self, reference: str, rename: TableRename) -> TableResponse:
        """Rename the table identified by ``reference``."""
        return self._client.post(":rename", TableResponse, params=dict(reference=reference), json=rename)
635
+
636
+
637
class DataLoadClient:
    """Manage data loads, the mechanism for adding rows to tables.

    A data load is created empty. Files are then uploaded to it with ``upload_file``, tables can
    be registered in its configuration with ``add_table``, and ``finalize`` starts ingesting rows
    from the uploaded files in the background. A load that has not been completed yet can be
    deleted. Loads can be undone and redone, which appends entries to the data changelog so that
    their rows disappear from (or reappear in) subsequent data versions.
    """

    def __init__(self, client: HttpClient):
        self._client = client.with_prefix("/load")

    def list(
        self, project_id: int | None = None, round_id: int | None = None
    ) -> Generator[DataLoadResponse, None, None]:
        """List data loads in the workspace, optionally filtered by project and/or round.

        Pages are fetched lazily as the generator is consumed.
        """
        params = {}
        if project_id is not None:
            params["project_id"] = project_id
        if round_id is not None:
            params["round_id"] = round_id

        opts = ListOptions()
        while True:
            resp = self._client.get(":list", ListDataLoadResponse, params={**params, **opts.model_dump()})
            yield from resp.items
            if resp.cursor is None:
                return
            opts.cursor = resp.cursor

    def get(self, load_id: int) -> DataLoadResponse:
        """Get a data load by ID."""
        return self._client.get(":get", DataLoadResponse, dict(id=load_id))

    def create(self, load: DataLoadCreate) -> DataLoadResponse:
        """Create a new data load.

        After creation, the data load will be empty. Files can subsequently be added using
        ``upload_file``.
        """
        return self._client.post(":create", DataLoadResponse, json=load)

    def upload_file(
        self,
        load_id: int,
        file: str | Path,
        *,
        description: str | None = None,
        filepath: str | Path | None = None,
        table_reference: str | None = None,
        source_file_id: int | None = None,
    ) -> FileUploadResponse:
        """Upload a file to the cloud storage bucket and add upload information to the load.

        Args:
            load_id: ID of the data load the file is added to.
            file: Local file to upload; a plain string path is accepted as well as a ``Path``.
            description: Sent as the ``description`` form field when given.
            filepath: Sent (as a string) as the ``filepath`` form field when given.
            table_reference: Sent as the ``table_reference`` form field when given.
            source_file_id: Sent as the ``source_file_id`` form field when given.
        """
        path = Path(file)

        # Only include form fields that were explicitly provided.
        data = {}
        if description is not None:
            data["description"] = description
        if filepath is not None:
            data["filepath"] = str(filepath)
        if table_reference is not None:
            data["table_reference"] = table_reference
        if source_file_id is not None:
            data["source_file_id"] = source_file_id

        with path.open("rb") as f:
            return self._client.post(
                ":uploadFile",
                FileUploadResponse,
                params=dict(id=load_id),
                files=dict(file=(path.name, f, "application/octet-stream")),
                data=data,
            )

    def download_file(self, file_id: str, buffer: BinaryIO) -> None:
        """Download a previously uploaded file from a load and write its bytes to ``buffer``.

        NOTE(review): the request is not streamed, so the whole file is received before writing.
        """
        bytes_stream: Iterator[bytes] = self._client.get(":downloadFile", Iterator[bytes], dict(file_id=file_id))
        buffer.writelines(bytes_stream)

    def finalize(self, load_id: int) -> DataLoadResponse:
        """Finalize the data load and start ingesting rows from the uploaded files.

        The ingestion process will run in the background. The progress can be tracked by polling the data load
        and waiting for the status to switch from `LOADING` to `COMPLETED` or `FAILED`.
        """
        return self._client.post(":finalize", DataLoadResponse, params=dict(id=load_id))

    def delete(self, load_id: int) -> DataLoadResponse:
        """Delete the data load if it has not been completed yet.

        The deletion will happen asynchronously afterwards and is irreversible.
        """
        return self._client.post(":delete", DataLoadResponse, params=dict(id=load_id))

    def undo(self, load_id: int) -> DataLoadResponse:
        """Append an entry to the changelog undoing the load.

        The rows from this data load will no longer show at subsequent data versions.
        Previously undone loads can be redone to restore their data.
        """
        return self._client.post(":undo", DataLoadResponse, params=dict(id=load_id))

    def redo(self, load_id: int) -> DataLoadResponse:
        """Append an entry to the changelog redoing the load.

        The rows from this data load will be visible again at subsequent data versions.
        """
        return self._client.post(":redo", DataLoadResponse, params=dict(id=load_id))

    def add_table(self, load_id: int, request: AddTableRequest) -> DataLoadResponse:
        """Register a new table in the data load configuration."""
        return self._client.post(":addTable", DataLoadResponse, params=dict(id=load_id), json=request)
751
+
752
+
753
+ class TaskClient:
754
+ """Create endpoints for all supported task types.
755
+
756
+ When creating a task, a context must be specified. The context determines data visibility: data can either be visible
757
+ in the entire project, or only in a specific round within a project.
758
+ Artifact and table data also has a context and a task can only access data with a compatible context. A data context
759
+ is compatible when it a workspace context or, if it is a project or round context, belongs to the same project.
760
+
761
+ For example, suppose project `A` has rounds `abc` and `def` and project `B` has round `xyz`.
762
+ A task with context `project=A` or `round=abc` can only see data with context `project=A` or `round=abc` or `round=def`.
763
+ A task with context `project=B` can only see data with context `project=B` or `round=xyz`.
764
+
765
+ Any data a task creates will be saved with the same context as the task.
766
+
767
+ Tasks always use the latest version of the data in the workspace, unless a specific version is provided.
768
+ """
769
+
770
def __init__(self, client: HttpClient):
    """Bind this client to the ``/task`` section below the given workspace client."""
    task_scoped = client.with_prefix("/task")
    self._client = task_scoped
772
+
773
def list(self, project_id: int | None = None, round_id: int | None = None) -> Generator[TaskResponse, None, None]:
    """Iterate over the tasks of a workspace, project or round.

    - Workspace: leave both `project_id` and `round_id` unset.
    - Project: set `project_id` only.
    - Round: set `round_id` only.

    Set at most one of `project_id` and `round_id`. Pages are fetched lazily as the
    generator is consumed.
    """
    filters = {
        key: value
        for key, value in (("project_id", project_id), ("round_id", round_id))
        if value is not None
    }

    options = ListOptions()
    while True:
        page = self._client.get(":list", ListTaskResponse, params={**filters, **options.model_dump()})
        yield from page.items
        if page.cursor is None:
            break
        options.cursor = page.cursor
795
+
796
def get(self, task_id: int) -> TaskResponse:
    """Fetch the current state and, once available, the result of a task.

    Poll this periodically to monitor a task and detect when it has finished executing.

    - `state == COMPLETED`: the `result` field holds the task result (often containing IDs of
      artifacts and tables created by the task).
    - `state == FAILED`: the `error` field explains why the task failed.

    `COMPLETED`, `FAILED` and `CANCELLED` are terminal states; a task in any other state will
    eventually transition into one of these three.
    """
    query = {"id": task_id}
    return self._client.get(":get", TaskResponse, query)
812
+
813
def get_by_name(
    self,
    name: str,
    project_id: int | None = None,
    round_id: int | None = None,
) -> TaskResponse:
    """Get a task by its unique name within its workspace, project, or round context.

    Only tasks whose context is exactly the requested one match. If a task named "TaskA"
    exists in the project with ID 123, `project_id` must be specified; leaving it unset is
    not sufficient, because tasks with the same name may exist at the workspace or round
    context or in a different project.

    Raises:
        ValueError: If no task with that name exists in the requested context.
    """
    # Listing a workspace or project also returns tasks from nested contexts, so compare
    # each task against the exact context that was asked for.
    if round_id is not None:
        scope_desc = f"round {round_id}"
        wanted_context = ContextRound(round_id=round_id)
    elif project_id is not None:
        scope_desc = f"project {project_id}"
        wanted_context = ContextProject(project_id=project_id)
    else:
        scope_desc = "workspace"
        wanted_context = ContextWorkspace()

    # Consume the paginated listing lazily: no further pages are requested once a match is found.
    for task in self.list(project_id=project_id, round_id=round_id):
        if task.context == wanted_context and task.name == name:
            return task
    raise ValueError(f"Task with name '{name}' not found in {scope_desc}")
842
+
843
+ def create(self, task: TaskCreate) -> TaskResponse:
844
+ return self._client.post(f"/{task.parameters.task_type}:create", TaskResponse, json=task)
845
+
846
+ def create_or_get(self, task: TaskCreate, skip_parameter_check: bool = False) -> TaskResponse:
847
+ """Create a task or get existing one with matching name and parameters.
848
+
849
+ The task must have a name set. If the parameter check is not skipped the existing task must have
850
+ the same parameters as `task`. If an explicit `data_version_id` is specified, it must also match.
851
+
852
+ Raises ValueError if parameters don't match or if name is None.
853
+ """
854
+ if task.name is None:
855
+ raise ValueError("Task must have a name")
856
+ try:
857
+ if isinstance(task.context, ContextRound):
858
+ existing_task = self.get_by_name(task.name, round_id=task.context.round_id)
859
+ elif isinstance(task.context, ContextProject):
860
+ existing_task = self.get_by_name(task.name, project_id=task.context.project_id)
861
+ else:
862
+ existing_task = self.get_by_name(task.name)
863
+ except ValueError:
864
+ existing_task = None
865
+
866
+ if existing_task is None:
867
+ return self.create(task)
868
+
869
+ if skip_parameter_check:
870
+ return existing_task
871
+
872
+ if existing_task.parameters != task.parameters:
873
+ existing_params = existing_task.parameters.model_dump_json(indent=2, by_alias=True)
874
+ new_params = task.parameters.model_dump_json(indent=2, by_alias=True)
875
+ diff = "\n".join(
876
+ difflib.unified_diff(
877
+ existing_params.splitlines(),
878
+ new_params.splitlines(),
879
+ fromfile="existing_parameters",
880
+ tofile="new_parameters",
881
+ lineterm="",
882
+ )
883
+ )
884
+ raise ValueError(f"Task with name '{task.name}' exists but has different parameters:\n{diff}")
885
+
886
+ if task.data_version_id is not None and existing_task.data_version_id != task.data_version_id:
887
+ raise ValueError(
888
+ f"Task with name '{task.name}' exists but has different data_version_id (existing: {existing_task.data_version_id}, new: {task.data_version_id})"
889
+ )
890
+
891
+ return existing_task
892
+
893
+ def cancel(self, task_id: int) -> TaskResponse:
894
+ """Cancel the task if possible.
895
+
896
+ Only tasks that have not yet terminated may be canceled. There is no guarantee that cancellation succeeds
897
+ before the task successfully completes.
898
+ """
899
+ return self._client.post(":cancel", TaskResponse, dict(id=task_id))
900
+
901
+ def recover(self, task_id: int) -> TaskResponse:
902
+ """Recover a task that has failed or been cancelled by putting it into state EXECUTING.
903
+
904
+ Note that this is best-effort and is not guaranteed to fix any issues in the underlying
905
+ backend execution, and may also cause unexpected side-effects such as duplicate data loads.
906
+ Main use-case is tasks that failed due to intermittent issues in the backend which have been resolved.
907
+ """
908
+ return self._client.post(":recover", TaskResponse, dict(id=task_id))
909
+
910
+ def archive(self, task_id: int) -> TaskResponse:
911
+ """Archive the task.
912
+
913
+ This will make the task disappear from regular task listings. It will not remove data that has been
914
+ produced by the task. If data removal is required, use the `data_load_id` for the task's results
915
+ and call the `/data/load:undo` endpoint separately.
916
+ """
917
+ return self._client.post(":archive", TaskResponse, dict(id=task_id))
918
+
919
+ def unarchive(self, task_id: int) -> TaskResponse:
920
+ """Unarchive the task."""
921
+ return self._client.post(":unarchive", TaskResponse, dict(id=task_id))
922
+
923
+ def wait(self, task_id: int, timeout: float = 60.0) -> TaskResponse:
924
+ start = time.monotonic()
925
+
926
+ while time.monotonic() - start < timeout:
927
+ task = self.get(task_id)
928
+ print(f"\r{task.state} @ {datetime.now(tz=UTC).strftime('%Y-%m-%d %H:%M:%S %Z')}", end="", flush=True)
929
+ if task.state == TaskState.FAILED:
930
+ raise ValueError(f"Task {task_id} failed: {task.errors}")
931
+ if task.state == TaskState.CANCELLED:
932
+ raise ValueError(f"Task {task_id} was cancelled")
933
+ if task.state == TaskState.COMPLETED:
934
+ break
935
+ time.sleep(3)
936
+ print()
937
+ return self.get(task_id)