pyspiral-0.1.0-cp310-abi3-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. pyspiral-0.1.0.dist-info/METADATA +48 -0
  2. pyspiral-0.1.0.dist-info/RECORD +81 -0
  3. pyspiral-0.1.0.dist-info/WHEEL +4 -0
  4. pyspiral-0.1.0.dist-info/entry_points.txt +2 -0
  5. spiral/__init__.py +11 -0
  6. spiral/_lib.abi3.so +0 -0
  7. spiral/adbc.py +386 -0
  8. spiral/api/__init__.py +221 -0
  9. spiral/api/admin.py +29 -0
  10. spiral/api/filesystems.py +125 -0
  11. spiral/api/organizations.py +90 -0
  12. spiral/api/projects.py +160 -0
  13. spiral/api/tables.py +94 -0
  14. spiral/api/tokens.py +56 -0
  15. spiral/api/workloads.py +45 -0
  16. spiral/arrow.py +209 -0
  17. spiral/authn/__init__.py +0 -0
  18. spiral/authn/authn.py +89 -0
  19. spiral/authn/device.py +206 -0
  20. spiral/authn/github_.py +33 -0
  21. spiral/authn/modal_.py +18 -0
  22. spiral/catalog.py +78 -0
  23. spiral/cli/__init__.py +82 -0
  24. spiral/cli/__main__.py +4 -0
  25. spiral/cli/admin.py +21 -0
  26. spiral/cli/app.py +48 -0
  27. spiral/cli/console.py +95 -0
  28. spiral/cli/fs.py +47 -0
  29. spiral/cli/login.py +13 -0
  30. spiral/cli/org.py +90 -0
  31. spiral/cli/printer.py +45 -0
  32. spiral/cli/project.py +107 -0
  33. spiral/cli/state.py +3 -0
  34. spiral/cli/table.py +20 -0
  35. spiral/cli/token.py +27 -0
  36. spiral/cli/types.py +53 -0
  37. spiral/cli/workload.py +59 -0
  38. spiral/config.py +26 -0
  39. spiral/core/__init__.py +0 -0
  40. spiral/core/core/__init__.pyi +53 -0
  41. spiral/core/manifests/__init__.pyi +53 -0
  42. spiral/core/metastore/__init__.pyi +91 -0
  43. spiral/core/spec/__init__.pyi +257 -0
  44. spiral/dataset.py +239 -0
  45. spiral/debug.py +251 -0
  46. spiral/expressions/__init__.py +222 -0
  47. spiral/expressions/base.py +149 -0
  48. spiral/expressions/http.py +86 -0
  49. spiral/expressions/io.py +100 -0
  50. spiral/expressions/list_.py +68 -0
  51. spiral/expressions/refs.py +44 -0
  52. spiral/expressions/str_.py +39 -0
  53. spiral/expressions/struct.py +57 -0
  54. spiral/expressions/tiff.py +223 -0
  55. spiral/expressions/udf.py +46 -0
  56. spiral/grpc_.py +32 -0
  57. spiral/project.py +137 -0
  58. spiral/proto/_/__init__.py +0 -0
  59. spiral/proto/_/arrow/__init__.py +0 -0
  60. spiral/proto/_/arrow/flight/__init__.py +0 -0
  61. spiral/proto/_/arrow/flight/protocol/__init__.py +0 -0
  62. spiral/proto/_/arrow/flight/protocol/sql/__init__.py +1990 -0
  63. spiral/proto/_/scandal/__init__.py +223 -0
  64. spiral/proto/_/spfs/__init__.py +36 -0
  65. spiral/proto/_/spiral/__init__.py +0 -0
  66. spiral/proto/_/spiral/table/__init__.py +225 -0
  67. spiral/proto/_/spiraldb/__init__.py +0 -0
  68. spiral/proto/_/spiraldb/metastore/__init__.py +499 -0
  69. spiral/proto/__init__.py +0 -0
  70. spiral/proto/scandal/__init__.py +45 -0
  71. spiral/proto/spiral/__init__.py +0 -0
  72. spiral/proto/spiral/table/__init__.py +96 -0
  73. spiral/proto/substrait/__init__.py +3399 -0
  74. spiral/proto/substrait/extensions/__init__.py +115 -0
  75. spiral/proto/util.py +41 -0
  76. spiral/py.typed +0 -0
  77. spiral/scan_.py +168 -0
  78. spiral/settings.py +157 -0
  79. spiral/substrait_.py +275 -0
  80. spiral/table.py +157 -0
  81. spiral/types_.py +6 -0
spiral/api/tables.py ADDED
@@ -0,0 +1,94 @@
+ from typing import Annotated
+
+ from pydantic import (
+     AfterValidator,
+     BaseModel,
+     StringConstraints,
+ )
+
+ from . import ArrowSchema, Paged, PagedRequest, PagedResponse, ProjectId, ServiceBase
+
+
+ def _validate_root_uri(uri: str) -> str:
+     if uri.endswith("/"):
+         raise ValueError("Root URI must not end with a slash.")
+     return uri
+
+
+ RootUri = Annotated[str, AfterValidator(_validate_root_uri)]
+ DatasetName = Annotated[str, StringConstraints(max_length=128, pattern=r"^[a-zA-Z_][a-zA-Z0-9_-]+$")]
+ TableName = Annotated[str, StringConstraints(max_length=128, pattern=r"^[a-zA-Z_][a-zA-Z0-9_-]+$")]
+
+
+ class TableMetadata(BaseModel):
+     key_schema: ArrowSchema
+     root_uri: RootUri
+     spfs_mount_id: str | None = None
+
+     # TODO(marko): Randomize this on creation of metadata.
+     # Column group salt is used to compute column group IDs.
+     # It's used to ensure that column group IDs are unique
+     # across different tables, even if paths are the same.
+     # It's never modified.
+     column_group_salt: int = 0
+
+
+ class Table(BaseModel):
+     id: str
+     project_id: ProjectId
+     dataset: DatasetName
+     table: TableName
+     metadata: TableMetadata
+
+
+ class CreateTable:
+     class Request(BaseModel):
+         project_id: ProjectId
+         dataset: DatasetName
+         table: TableName
+         key_schema: ArrowSchema
+         root_uri: RootUri | None = None
+         exist_ok: bool = False
+
+     class Response(BaseModel):
+         table: Table
+
+
+ class FindTable:
+     class Request(BaseModel):
+         project_id: ProjectId
+         dataset: DatasetName = None
+         table: TableName = None
+
+     class Response(BaseModel):
+         table: Table | None
+
+
+ class GetTable:
+     class Request(BaseModel):
+         id: str
+
+     class Response(BaseModel):
+         table: Table
+
+
+ class ListTables:
+     class Request(PagedRequest):
+         project_id: ProjectId
+         dataset: DatasetName | None = None
+
+     class Response(PagedResponse[Table]): ...
+
+
+ class TableService(ServiceBase):
+     def create(self, req: CreateTable.Request) -> CreateTable.Response:
+         return self.client.post("/table/create", req, CreateTable.Response)
+
+     def find(self, req: FindTable.Request) -> FindTable.Response:
+         return self.client.put("/table/find", req, FindTable.Response)
+
+     def get(self, req: GetTable.Request) -> GetTable.Response:
+         return self.client.put(f"/table/{req.id}", GetTable.Response)
+
+     def list(self, req: ListTables.Request) -> Paged[Table]:
+         return self.client.paged("/table/list", req, ListTables.Response)
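
The request/response types above are plain Pydantic v2 models, so the DatasetName, TableName and RootUri constraints are enforced the moment a request object is built, before any HTTP call is made. A minimal standalone sketch of that validation behavior (the aliases are redefined locally here; the ServiceBase/client wiring is not shown):

    from typing import Annotated

    from pydantic import AfterValidator, BaseModel, StringConstraints, ValidationError

    def _validate_root_uri(uri: str) -> str:
        if uri.endswith("/"):
            raise ValueError("Root URI must not end with a slash.")
        return uri

    RootUri = Annotated[str, AfterValidator(_validate_root_uri)]
    DatasetName = Annotated[str, StringConstraints(max_length=128, pattern=r"^[a-zA-Z_][a-zA-Z0-9_-]+$")]

    class Example(BaseModel):
        dataset: DatasetName
        root_uri: RootUri

    Example(dataset="my_dataset", root_uri="s3://bucket/prefix")  # passes both checks

    try:
        Example(dataset="1-bad-name", root_uri="s3://bucket/prefix/")  # both fields rejected
    except ValidationError as exc:
        print(exc)
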
spiral/api/tokens.py ADDED
@@ -0,0 +1,56 @@
+ from pydantic import BaseModel
+
+ from . import Paged, PagedRequest, PagedResponse, ServiceBase
+
+
+ class Token(BaseModel):
+     id: str
+     project_id: str
+     on_behalf_of: str
+
+
+ class ExchangeToken:
+     class Request(BaseModel): ...
+
+     class Response(BaseModel):
+         token: str
+
+
+ class IssueToken:
+     class Request(BaseModel): ...
+
+     class Response(BaseModel):
+         token: Token
+         token_secret: str
+
+
+ class RevokeToken:
+     class Request(BaseModel):
+         token_id: str
+
+     class Response(BaseModel):
+         token: Token
+
+
+ class ListTokens:
+     class Request(PagedRequest):
+         project_id: str
+         on_behalf_of: str | None = None
+
+     class Response(PagedResponse[Token]): ...
+
+
+ class TokenService(ServiceBase):
+     def exchange(self) -> ExchangeToken.Response:
+         """Exchange a basic / identity token to a short-lived Spiral token."""
+         return self.client.post("/token/exchange", ExchangeToken.Request(), ExchangeToken.Response)
+
+     def issue(self) -> IssueToken.Response:
+         """Issue an API token on behalf of a principal."""
+         return self.client.post("/token/issue", IssueToken.Request(), IssueToken.Response)
+
+     def revoke(self, request: RevokeToken.Request) -> RevokeToken.Response:
+         return self.client.put("/token/revoke", request, RevokeToken.Response)
+
+     def list(self, request: ListTokens.Request) -> Paged[Token]:
+         return self.client.paged("/token/list", request, ListTokens.Response)
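
For orientation, a hypothetical sketch of driving these endpoints through the service object. Here api stands for an already-configured spiral.api.SpiralAPI (its construction is shown later in spiral/authn/authn.py, which uses SpiralAPI(authn, base_url).token), and the identifiers are placeholders:

    from spiral.api.tokens import ListTokens, RevokeToken

    # `api` is assumed to be a configured spiral.api.SpiralAPI instance.
    issued = api.token.issue()  # IssueToken.Response: the Token plus its one-time secret
    print(issued.token.id, issued.token_secret)

    # List tokens in the same project (a Paged[Token]), then revoke the one just created.
    page = api.token.list(ListTokens.Request(project_id=issued.token.project_id))
    api.token.revoke(RevokeToken.Request(token_id=issued.token.id))
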
spiral/api/workloads.py ADDED
@@ -0,0 +1,45 @@
+ from pydantic import BaseModel, Field
+
+ from . import Paged, PagedRequest, PagedResponse, ProjectId, ServiceBase
+
+
+ class Workload(BaseModel):
+     id: str
+     project_id: ProjectId
+     name: str | None = None
+
+
+ class CreateWorkload:
+     class Request(BaseModel):
+         project_id: str
+         name: str | None = Field(default=None, description="Optional human-readable name for the workload")
+
+     class Response(BaseModel):
+         workload: Workload
+
+
+ class IssueToken:
+     class Request(BaseModel):
+         workload_id: str
+
+     class Response(BaseModel):
+         token_id: str
+         token_secret: str
+
+
+ class ListWorkloads:
+     class Request(PagedRequest):
+         project_id: str
+
+     class Response(PagedResponse[Workload]): ...
+
+
+ class WorkloadService(ServiceBase):
+     def create(self, request: CreateWorkload.Request) -> CreateWorkload.Response:
+         return self.client.post("/workload/create", request, CreateWorkload.Response)
+
+     def issue_token(self, request: IssueToken.Request) -> IssueToken.Response:
+         return self.client.post("/workload/issue-token", request, IssueToken.Response)
+
+     def list(self, request: ListWorkloads.Request) -> Paged[Workload]:
+         return self.client.paged("/workload/list", request, ListWorkloads.Response)
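
As with tokens, a workload credential is issued as a (token_id, token_secret) pair. A hedged sketch of the intended call sequence; the api.workload attribute is an assumption mirroring api.token above, and the project id and name are placeholders:

    from spiral.api.workloads import CreateWorkload, IssueToken

    # `api.workload` is assumed to expose WorkloadService, mirroring `api.token` above.
    created = api.workload.create(
        CreateWorkload.Request(project_id="proj_123", name="nightly-compaction")  # placeholder values
    )
    issued = api.workload.issue_token(IssueToken.Request(workload_id=created.workload.id))
    # issued.token_id / issued.token_secret can then be exported as
    # SPIRAL_TOKEN_ID / SPIRAL_TOKEN_SECRET for the EnvironmentAuthn provider below.
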
spiral/arrow.py ADDED
@@ -0,0 +1,209 @@
+ from collections import defaultdict
+ from collections.abc import Callable, Iterable
+ from functools import reduce
+ from typing import TypeVar
+
+ import numpy as np
+ import pyarrow as pa
+ from pyarrow import compute as pc
+
+ T = TypeVar("T")
+
+
+ def arange(*args, **kwargs) -> pa.Array:
+     return pa.array(np.arange(*args, **kwargs), type=pa.int32())
+
+
+ def zip_tables(tables: Iterable[pa.Table]) -> pa.Table:
+     data = []
+     names = []
+     for table in tables:
+         data.extend(table.columns)
+         names.extend(table.column_names)
+     return pa.Table.from_arrays(data, names=names)
+
+
+ def merge_arrays(*arrays: pa.StructArray) -> pa.StructArray:
+     """Recursively merge arrays into nested struct arrays."""
+     if len(arrays) == 1:
+         return arrays[0]
+
+     nstructs = sum(pa.types.is_struct(a.type) for a in arrays)
+     if nstructs == 0:
+         # Then we have conflicting arrays and we choose the last.
+         return arrays[-1]
+
+     if nstructs != len(arrays):
+         raise ValueError("Cannot merge structs with non-structs.")
+
+     data = defaultdict(list)
+     for array in arrays:
+         if isinstance(array, pa.ChunkedArray):
+             array = array.combine_chunks()
+         for field in array.type:
+             data[field.name].append(array.field(field.name))
+
+     return pa.StructArray.from_arrays([merge_arrays(*v) for v in data.values()], names=list(data.keys()))
+
+
+ def merge_scalars(*scalars: pa.StructScalar) -> pa.StructScalar:
+     """Recursively merge scalars into nested struct scalars."""
+     if len(scalars) == 1:
+         return scalars[0]
+
+     nstructs = sum(pa.types.is_struct(a.type) for a in scalars)
+     if nstructs == 0:
+         # Then we have conflicting scalars and we choose the last.
+         return scalars[-1]
+
+     if nstructs != len(scalars):
+         raise ValueError("Cannot merge scalars with non-scalars.")
+
+     data = defaultdict(list)
+     for scalar in scalars:
+         for field in scalar.type:
+             data[field.name].append(scalar[field.name])
+
+     return pa.scalar({k: merge_scalars(*v) for k, v in data.items()})
+
+
+ def null_table(schema: pa.Schema, length: int = 0) -> pa.Table:
+     # We add an extra nulls column to ensure the length is correctly applied.
+     return pa.table(
+         [pa.nulls(length, type=field.type) for field in schema] + [pa.nulls(length)],
+         schema=pa.schema(list(schema) + [pa.field("__", type=pa.null())]),
+     ).drop(["__"])
+
+
+ def coalesce_all(table: pa.Table) -> pa.Table:
+     """Coalesce all columns that share the same name."""
+     columns: dict[str, list[pa.Array]] = defaultdict(list)
+     for i, col in enumerate(table.column_names):
+         columns[col].append(table[i])
+
+     data = []
+     names = []
+     for col, arrays in columns.items():
+         names.append(col)
+         if len(arrays) == 1:
+             data.append(arrays[0])
+         else:
+             data.append(pc.coalesce(*arrays))
+
+     return pa.Table.from_arrays(data, names=names)
+
+
+ def join(left: pa.Table, right: pa.Table, keys: list[str], join_type: str) -> pa.Table:
+     """Arrow's builtin join doesn't support struct columns. So we join ourselves and zip them in."""
+     # TODO(ngates): if join_type == inner, we may have better luck performing two index_in operations since this
+     #  also preserves sort order.
+     lhs = left.select(keys).add_column(0, "__lhs", arange(len(left)))
+     rhs = right.select(keys).add_column(0, "__rhs", arange(len(right)))
+     joined = lhs.join(rhs, keys=keys, join_type=join_type).sort_by([(k, "ascending") for k in keys])
+     return zip_tables(
+         [joined.select(keys), left.take(joined["__lhs"]).drop(keys), right.take(joined["__rhs"]).drop(keys)]
+     )
+
+
+ def nest_structs(array: pa.StructArray | pa.StructScalar | dict) -> dict:
+     """Turn a struct-like value with dot-separated column names into a nested dictionary."""
+     data = {}
+
+     if isinstance(array, pa.StructArray | pa.StructScalar):
+         array = {f.name: field(array, f.name) for f in array.type}
+
+     for name in array.keys():
+         if "." not in name:
+             data[name] = array[name]
+             continue
+
+         parts = name.split(".")
+         child_data = data
+         for part in parts[:-1]:
+             if part not in child_data:
+                 child_data[part] = {}
+             child_data = child_data[part]
+         child_data[parts[-1]] = array[name]
+
+     return data
+
+
+ def flatten_struct_table(table: pa.Table, separator=".") -> pa.Table:
+     """Turn a nested struct table into a flat table with dot-separated names."""
+     data = []
+     names = []
+
+     def _unfold(array: pa.Array, prefix: str):
+         if pa.types.is_struct(array.type):
+             if isinstance(array, pa.ChunkedArray):
+                 array = array.combine_chunks()
+             for f in array.type:
+                 _unfold(field(array, f.name), f"{prefix}{separator}{f.name}")
+         else:
+             data.append(array)
+             names.append(prefix)
+
+     for col in table.column_names:
+         _unfold(table[col], col)
+
+     return pa.Table.from_arrays(data, names=names)
+
+
+ def dict_to_table(data) -> pa.Table:
+     return pa.Table.from_struct_array(dict_to_struct_array(data))
+
+
+ def dict_to_struct_array(data, propagate_nulls: bool = False) -> pa.StructArray:
+     """Convert a nested dictionary of arrays to a table with nested structs."""
+     if isinstance(data, pa.ChunkedArray):
+         return data.combine_chunks()
+     if isinstance(data, pa.Array):
+         return data
+     arrays = [dict_to_struct_array(value) for value in data.values()]
+     return pa.StructArray.from_arrays(
+         arrays,
+         names=list(data.keys()),
+         mask=reduce(pc.and_, [pc.is_null(array) for array in arrays]) if propagate_nulls else None,
+     )
+
+
+ def struct_array_to_dict(array: pa.StructArray, array_fn: Callable[[pa.Array], T] = lambda a: a) -> dict | T:
+     """Convert a struct array to a nested dictionary."""
+     if not pa.types.is_struct(array.type):
+         return array_fn(array)
+     if isinstance(array, pa.ChunkedArray):
+         array = array.combine_chunks()
+     return {field.name: struct_array_to_dict(array.field(i), array_fn=array_fn) for i, field in enumerate(array.type)}
+
+
+ def table_to_struct_array(table: pa.Table) -> pa.StructArray:
+     if not table.num_rows:
+         return pa.array([], type=pa.struct(table.schema))
+     array = table.to_struct_array()
+     if isinstance(array, pa.ChunkedArray):
+         array = array.combine_chunks()
+     return array
+
+
+ def table_from_struct_array(array: pa.StructArray | pa.ChunkedArray):
+     if len(array) == 0:
+         return null_table(pa.schema(array.type))
+     return pa.Table.from_struct_array(array)
+
+
+ def field(value: pa.StructArray | pa.StructScalar, name: str) -> pa.Array | pa.Scalar:
+     """Get a field from a struct-like value."""
+     if isinstance(value, pa.StructScalar):
+         return value[name]
+     return value.field(name)
+
+
+ def concat_tables(tables: list[pa.Table]) -> pa.Table:
+     """
+     Concatenate pyarrow.Table objects, filling "missing" data with appropriate null arrays
+     and casting arrays to the most common denominator type that fits all fields.
+     """
+     if len(tables) == 1:
+         return tables[0]
+     else:
+         return pa.concat_tables(tables, promote_options="permissive")
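
These helpers mostly exist to round-trip between flat tables with dot-separated column names and tables with nested struct columns, which Arrow's built-in operations (notably join, per the docstring above) do not handle directly. A small runnable illustration, assuming the wheel is installed so the module imports as spiral.arrow:

    import pyarrow as pa

    from spiral import arrow as sa

    flat = pa.table({"a.x": [1, 2], "a.y": ["p", "q"], "b": [True, False]})

    # Dot-separated names -> nested dict of arrays -> one struct column per top-level key.
    nested = sa.dict_to_table(sa.nest_structs(sa.table_to_struct_array(flat)))
    print(nested.schema)  # a: struct<x: int64, y: string>, b: bool

    # And back again: nested structs flattened into dot-separated column names.
    roundtrip = sa.flatten_struct_table(nested)
    assert roundtrip.column_names == ["a.x", "a.y", "b"]
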
spiral/authn/__init__.py ADDED
File without changes
spiral/authn/authn.py ADDED
@@ -0,0 +1,89 @@
+ import base64
+ import logging
+ import os
+
+ from spiral.api import Authn, SpiralAPI
+
+ ENV_TOKEN_ID = "SPIRAL_TOKEN_ID"
+ ENV_TOKEN_SECRET = "SPIRAL_TOKEN_SECRET"
+
+ log = logging.getLogger(__name__)
+
+
+ class FallbackAuthn(Authn):
+     """Credential provider that tries multiple providers in order."""
+
+     def __init__(self, providers: list[Authn]):
+         self._providers = providers
+
+     def token(self) -> str | None:
+         for provider in self._providers:
+             token = provider.token()
+             if token is not None:
+                 return token
+         return None
+
+
+ class TokenAuthn(Authn):
+     """Credential provider that returns a fixed token."""
+
+     def __init__(self, token: str):
+         self._token = token
+
+     def token(self) -> str:
+         return self._token
+
+
+ class EnvironmentAuthn(Authn):
+     """Credential provider that returns a basic token from the environment.
+
+     NOTE: Returns basic token. Must be exchanged.
+     """
+
+     def token(self) -> str | None:
+         if ENV_TOKEN_ID not in os.environ:
+             return None
+         if ENV_TOKEN_SECRET not in os.environ:
+             raise ValueError(f"{ENV_TOKEN_SECRET} is missing.")
+
+         token_id = os.environ[ENV_TOKEN_ID]
+         token_secret = os.environ[ENV_TOKEN_SECRET]
+         basic_token = base64.b64encode(f"{token_id}:{token_secret}".encode()).decode("utf-8")
+
+         return basic_token
+
+
+ class DeviceAuthProvider(Authn):
+     """Auth provider that uses the device flow to authenticate a Spiral user."""
+
+     def __init__(self, device_auth):
+         # NOTE(ngates): device_auth: spiral.auth.device_code.DeviceAuth
+         #  We don't type it to satisfy our import linter
+         self._device_auth = device_auth
+
+     def token(self) -> str | None:
+         # TODO(ngates): only run this if we're in a notebook, CLI, or otherwise on the user's machine.
+         return self._device_auth.authenticate().access_token
+
+
+ class TokenExchangeProvider(Authn):
+     """Auth provider that exchanges a basic token for a Spiral token."""
+
+     def __init__(self, authn: Authn, base_url: str):
+         self._authn = authn
+         self._token_service = SpiralAPI(authn, base_url).token
+
+         self._sp_token = None
+
+     def token(self) -> str | None:
+         if self._sp_token is not None:
+             return self._sp_token
+
+         # Don't try to exchange if token is not discovered.
+         if self._authn.token() is None:
+             return None
+
+         log.debug("Exchanging token")
+         self._sp_token = self._token_service.exchange().token
+
+         return self._sp_token
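
The providers compose: FallbackAuthn walks a list of providers until one yields a token, and TokenExchangeProvider wraps another provider to swap its basic token for a short-lived Spiral token via the /token/exchange endpoint shown earlier. A small sketch that runs without network access (with SPIRAL_TOKEN_ID unset, EnvironmentAuthn yields None and the fixed token wins):

    from spiral.authn.authn import EnvironmentAuthn, FallbackAuthn, TokenAuthn

    authn = FallbackAuthn([EnvironmentAuthn(), TokenAuthn("explicit-token")])
    assert authn.token() == "explicit-token"

    # In real use the first entry would typically be a TokenExchangeProvider wrapping
    # EnvironmentAuthn and pointed at the Spiral API base URL, so that basic credentials
    # from the environment are exchanged for a short-lived token on first use.
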
spiral/authn/device.py ADDED
@@ -0,0 +1,206 @@
+ import logging
+ import sys
+ import textwrap
+ import time
+ import webbrowser
+ from pathlib import Path
+
+ import httpx
+ import jwt
+ from pydantic import BaseModel
+
+ log = logging.getLogger(__name__)
+
+
+ class TokensModel(BaseModel):
+     access_token: str
+     refresh_token: str
+
+     @property
+     def organization_id(self) -> str | None:
+         return self.unverified_access_token().get("org_id")
+
+     def unverified_access_token(self):
+         return jwt.decode(self.access_token, options={"verify_signature": False})
+
+
+ class AuthModel(BaseModel):
+     tokens: TokensModel | None = None
+
+
+ class DeviceAuth:
+     def __init__(
+         self,
+         auth_file: Path,
+         domain: str,
+         client_id: str,
+         http: httpx.Client = None,
+     ):
+         self._auth_file = auth_file
+         self._domain = domain
+         self._client_id = client_id
+         self._http = http or httpx.Client()
+
+         if self._auth_file.exists():
+             with self._auth_file.open("r") as f:
+                 self._auth = AuthModel.model_validate_json(f.read())
+         else:
+             self._auth = AuthModel()
+
+         self._default_scope = ["email", "profile"]
+
+     def is_authenticated(self) -> bool:
+         """Check if the user is authenticated."""
+         tokens = self._auth.tokens
+         if tokens is None:
+             return False
+
+         # Give ourselves a 30-second buffer before the token expires.
+         return tokens.unverified_access_token()["exp"] - 30 > time.time()
+
+     def authenticate(self, force: bool = False, refresh: bool = False, organization_id: str = None) -> TokensModel:
+         """Blocking call to authenticate the user.
+
+         Triggers a device code flow and polls for the user to login.
+         """
+         if force:
+             return self._device_code(organization_id)
+
+         if refresh:
+             if self._auth.tokens is None:
+                 raise ValueError("No tokens to refresh.")
+             tokens = self._refresh(self._auth.tokens, organization_id)
+             if not tokens:
+                 raise ValueError("Failed to refresh token.")
+             return tokens
+
+         # Check for mis-matched organization.
+         if organization_id is not None:
+             tokens = self._auth.tokens
+             if tokens is not None and tokens.unverified_access_token().get("org_id") != organization_id:
+                 tokens = self._refresh(self._auth.tokens, organization_id)
+                 if tokens is None:
+                     return self._device_code(organization_id)
+
+         if self.is_authenticated():
+             return self._auth.tokens
+
+         # Try to refresh.
+         tokens = self._auth.tokens
+         if tokens is not None:
+             tokens = self._refresh(tokens)
+             if tokens is not None:
+                 return tokens
+
+         # Otherwise, we kick off the device code flow.
+         return self._device_code(organization_id)
+
+     def logout(self):
+         self._remove_tokens()
+
+     def _device_code(self, organization_id: str | None):
+         scope = " ".join(self._default_scope)
+         res = self._http.post(
+             f"{self._domain}/auth/device/code",
+             data={
+                 "client_id": self._client_id,
+                 "scope": scope,
+                 "organization_id": organization_id,
+             },
+         )
+         res = res.raise_for_status().json()
+         device_code = res["device_code"]
+         user_code = res["user_code"]
+         expires_at = res["expires_in"] + time.time()
+         interval = res["interval"]
+         verification_uri_complete = res["verification_uri_complete"]
+
+         # We need to detect if the user is running in a terminal, in Jupyter, etc.
+         # For now, we'll try to open the browser.
+         sys.stderr.write(
+             textwrap.dedent(
+                 f"""
+                 Please login here: {verification_uri_complete}
+                 Your code is {user_code}.
+                 """
+             )
+         )
+
+         # Try to open the browser (this also works if the Jupyter notebook is running on the user's machine).
+         opened = webbrowser.open(verification_uri_complete)
+
+         # If we have a server-side Jupyter notebook, we can try to open with client-side JavaScript.
+         if not opened and _in_notebook():
+             from IPython.display import Javascript, display
+
+             display(Javascript(f'window.open("{verification_uri_complete}");'))
+
+         # In the meantime, we need to poll for the user to login.
+         while True:
+             if time.time() > expires_at:
+                 raise TimeoutError("Login timed out.")
+             time.sleep(interval)
+             res = self._http.post(
+                 f"{self._domain}/auth/token",
+                 data={
+                     "client_id": self._client_id,
+                     "device_code": device_code,
+                     "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
+                 },
+             )
+             if not res.is_success:
+                 continue
+
+             tokens = TokensModel(
+                 access_token=res.json()["access_token"],
+                 refresh_token=res.json()["refresh_token"],
+             )
+             self._save_tokens(tokens)
+             return self._auth.tokens
+
+     def _refresh(self, tokens: TokensModel, organization_id: str = None) -> TokensModel | None:
+         """Attempt to use the refresh token."""
+         log.debug("Refreshing token %s", self._client_id)
+
+         res = self._http.post(
+             f"{self._domain}/auth/refresh",
+             data={
+                 "client_id": self._client_id,
+                 "grant_type": "refresh_token",
+                 "refresh_token": tokens.refresh_token,
+                 "organization_id": organization_id,
+             },
+         )
+         if not res.is_success:
+             print("Failed to refresh token", res.status_code, res.text)
+             return None
+
+         tokens = TokensModel(
+             access_token=res.json()["access_token"],
+             refresh_token=res.json()["refresh_token"],
+         )
+         self._save_tokens(tokens)
+         return tokens
+
+     def _save_tokens(self, tokens: TokensModel):
+         self._auth = self._auth.model_copy(update={"tokens": tokens})
+         self._auth_file.parent.mkdir(parents=True, exist_ok=True)
+         with self._auth_file.open("w") as f:
+             f.write(self._auth.model_dump_json(exclude_defaults=True))
+
+     def _remove_tokens(self):
+         self._auth_file.unlink(missing_ok=True)
+         self._auth = self._auth.model_copy(update={"tokens": None})
+
+
+ def _in_notebook():
+     try:
+         from IPython import get_ipython
+
+         if "IPKernelApp" not in get_ipython().config:  # pragma: no cover
+             return False
+     except ImportError:
+         return False
+     except AttributeError:
+         return False
+     return True
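
DeviceAuth caches tokens in a JSON file and only falls back to the interactive device-code flow when it cannot refresh. A hedged usage sketch; the cache path, domain, and client id below are placeholders, and authenticate() blocks while it polls the token endpoint:

    from pathlib import Path

    from spiral.authn.device import DeviceAuth

    auth = DeviceAuth(
        auth_file=Path.home() / ".spiral" / "auth.json",  # hypothetical cache location
        domain="https://auth.example.invalid",            # placeholder issuer domain
        client_id="client_abc",                           # placeholder OAuth client id
    )

    if not auth.is_authenticated():
        tokens = auth.authenticate()  # opens the verification URL and polls until login completes
        print(tokens.organization_id)
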