pyspiral-0.1.0-cp310-abi3-macosx_11_0_arm64.whl

Files changed (81)
  1. pyspiral-0.1.0.dist-info/METADATA +48 -0
  2. pyspiral-0.1.0.dist-info/RECORD +81 -0
  3. pyspiral-0.1.0.dist-info/WHEEL +4 -0
  4. pyspiral-0.1.0.dist-info/entry_points.txt +2 -0
  5. spiral/__init__.py +11 -0
  6. spiral/_lib.abi3.so +0 -0
  7. spiral/adbc.py +386 -0
  8. spiral/api/__init__.py +221 -0
  9. spiral/api/admin.py +29 -0
  10. spiral/api/filesystems.py +125 -0
  11. spiral/api/organizations.py +90 -0
  12. spiral/api/projects.py +160 -0
  13. spiral/api/tables.py +94 -0
  14. spiral/api/tokens.py +56 -0
  15. spiral/api/workloads.py +45 -0
  16. spiral/arrow.py +209 -0
  17. spiral/authn/__init__.py +0 -0
  18. spiral/authn/authn.py +89 -0
  19. spiral/authn/device.py +206 -0
  20. spiral/authn/github_.py +33 -0
  21. spiral/authn/modal_.py +18 -0
  22. spiral/catalog.py +78 -0
  23. spiral/cli/__init__.py +82 -0
  24. spiral/cli/__main__.py +4 -0
  25. spiral/cli/admin.py +21 -0
  26. spiral/cli/app.py +48 -0
  27. spiral/cli/console.py +95 -0
  28. spiral/cli/fs.py +47 -0
  29. spiral/cli/login.py +13 -0
  30. spiral/cli/org.py +90 -0
  31. spiral/cli/printer.py +45 -0
  32. spiral/cli/project.py +107 -0
  33. spiral/cli/state.py +3 -0
  34. spiral/cli/table.py +20 -0
  35. spiral/cli/token.py +27 -0
  36. spiral/cli/types.py +53 -0
  37. spiral/cli/workload.py +59 -0
  38. spiral/config.py +26 -0
  39. spiral/core/__init__.py +0 -0
  40. spiral/core/core/__init__.pyi +53 -0
  41. spiral/core/manifests/__init__.pyi +53 -0
  42. spiral/core/metastore/__init__.pyi +91 -0
  43. spiral/core/spec/__init__.pyi +257 -0
  44. spiral/dataset.py +239 -0
  45. spiral/debug.py +251 -0
  46. spiral/expressions/__init__.py +222 -0
  47. spiral/expressions/base.py +149 -0
  48. spiral/expressions/http.py +86 -0
  49. spiral/expressions/io.py +100 -0
  50. spiral/expressions/list_.py +68 -0
  51. spiral/expressions/refs.py +44 -0
  52. spiral/expressions/str_.py +39 -0
  53. spiral/expressions/struct.py +57 -0
  54. spiral/expressions/tiff.py +223 -0
  55. spiral/expressions/udf.py +46 -0
  56. spiral/grpc_.py +32 -0
  57. spiral/project.py +137 -0
  58. spiral/proto/_/__init__.py +0 -0
  59. spiral/proto/_/arrow/__init__.py +0 -0
  60. spiral/proto/_/arrow/flight/__init__.py +0 -0
  61. spiral/proto/_/arrow/flight/protocol/__init__.py +0 -0
  62. spiral/proto/_/arrow/flight/protocol/sql/__init__.py +1990 -0
  63. spiral/proto/_/scandal/__init__.py +223 -0
  64. spiral/proto/_/spfs/__init__.py +36 -0
  65. spiral/proto/_/spiral/__init__.py +0 -0
  66. spiral/proto/_/spiral/table/__init__.py +225 -0
  67. spiral/proto/_/spiraldb/__init__.py +0 -0
  68. spiral/proto/_/spiraldb/metastore/__init__.py +499 -0
  69. spiral/proto/__init__.py +0 -0
  70. spiral/proto/scandal/__init__.py +45 -0
  71. spiral/proto/spiral/__init__.py +0 -0
  72. spiral/proto/spiral/table/__init__.py +96 -0
  73. spiral/proto/substrait/__init__.py +3399 -0
  74. spiral/proto/substrait/extensions/__init__.py +115 -0
  75. spiral/proto/util.py +41 -0
  76. spiral/py.typed +0 -0
  77. spiral/scan_.py +168 -0
  78. spiral/settings.py +157 -0
  79. spiral/substrait_.py +275 -0
  80. spiral/table.py +157 -0
  81. spiral/types_.py +6 -0
spiral/api/tables.py ADDED
@@ -0,0 +1,94 @@
+ from typing import Annotated
+
+ from pydantic import (
+     AfterValidator,
+     BaseModel,
+     StringConstraints,
+ )
+
+ from . import ArrowSchema, Paged, PagedRequest, PagedResponse, ProjectId, ServiceBase
+
+
+ def _validate_root_uri(uri: str) -> str:
+     if uri.endswith("/"):
+         raise ValueError("Root URI must not end with a slash.")
+     return uri
+
+
+ RootUri = Annotated[str, AfterValidator(_validate_root_uri)]
+ DatasetName = Annotated[str, StringConstraints(max_length=128, pattern=r"^[a-zA-Z_][a-zA-Z0-9_-]+$")]
+ TableName = Annotated[str, StringConstraints(max_length=128, pattern=r"^[a-zA-Z_][a-zA-Z0-9_-]+$")]
+
+
+ class TableMetadata(BaseModel):
+     key_schema: ArrowSchema
+     root_uri: RootUri
+     spfs_mount_id: str | None = None
+
+     # TODO(marko): Randomize this on creation of metadata.
+     # Column group salt is used to compute column group IDs.
+     # It's used to ensure that column group IDs are unique
+     # across different tables, even if paths are the same.
+     # It's never modified.
+     column_group_salt: int = 0
+
+
+ class Table(BaseModel):
+     id: str
+     project_id: ProjectId
+     dataset: DatasetName
+     table: TableName
+     metadata: TableMetadata
+
+
+ class CreateTable:
+     class Request(BaseModel):
+         project_id: ProjectId
+         dataset: DatasetName
+         table: TableName
+         key_schema: ArrowSchema
+         root_uri: RootUri | None = None
+         exist_ok: bool = False
+
+     class Response(BaseModel):
+         table: Table
+
+
+ class FindTable:
+     class Request(BaseModel):
+         project_id: ProjectId
+         dataset: DatasetName | None = None
+         table: TableName | None = None
+
+     class Response(BaseModel):
+         table: Table | None
+
+
+ class GetTable:
+     class Request(BaseModel):
+         id: str
+
+     class Response(BaseModel):
+         table: Table
+
+
+ class ListTables:
+     class Request(PagedRequest):
+         project_id: ProjectId
+         dataset: DatasetName | None = None
+
+     class Response(PagedResponse[Table]): ...
+
+
+ class TableService(ServiceBase):
+     def create(self, req: CreateTable.Request) -> CreateTable.Response:
+         return self.client.post("/table/create", req, CreateTable.Response)
+
+     def find(self, req: FindTable.Request) -> FindTable.Response:
+         return self.client.put("/table/find", req, FindTable.Response)
+
+     def get(self, req: GetTable.Request) -> GetTable.Response:
+         return self.client.put(f"/table/{req.id}", GetTable.Response)
+
+     def list(self, req: ListTables.Request) -> Paged[Table]:
+         return self.client.paged("/table/list", req, ListTables.Response)
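The RootUri validator and the DatasetName/TableName constraints above are ordinary Pydantic annotated types, so their behaviour can be checked on their own. A minimal sketch (the field values and the _Probe model are invented for illustration, not part of the package):

from pydantic import BaseModel, ValidationError

from spiral.api.tables import DatasetName, RootUri, TableName


class _Probe(BaseModel):
    dataset: DatasetName
    table: TableName
    root_uri: RootUri


# Valid: names match ^[a-zA-Z_][a-zA-Z0-9_-]+$ and the root URI has no trailing slash.
_Probe(dataset="events", table="clicks", root_uri="s3://bucket/prefix")

try:
    _Probe(dataset="events", table="clicks", root_uri="s3://bucket/prefix/")
except ValidationError as e:
    print(e)  # "Root URI must not end with a slash."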
spiral/api/tokens.py ADDED
@@ -0,0 +1,56 @@
+ from pydantic import BaseModel
+
+ from . import Paged, PagedRequest, PagedResponse, ServiceBase
+
+
+ class Token(BaseModel):
+     id: str
+     project_id: str
+     on_behalf_of: str
+
+
+ class ExchangeToken:
+     class Request(BaseModel): ...
+
+     class Response(BaseModel):
+         token: str
+
+
+ class IssueToken:
+     class Request(BaseModel): ...
+
+     class Response(BaseModel):
+         token: Token
+         token_secret: str
+
+
+ class RevokeToken:
+     class Request(BaseModel):
+         token_id: str
+
+     class Response(BaseModel):
+         token: Token
+
+
+ class ListTokens:
+     class Request(PagedRequest):
+         project_id: str
+         on_behalf_of: str | None = None
+
+     class Response(PagedResponse[Token]): ...
+
+
+ class TokenService(ServiceBase):
+     def exchange(self) -> ExchangeToken.Response:
+         """Exchange a basic / identity token for a short-lived Spiral token."""
+         return self.client.post("/token/exchange", ExchangeToken.Request(), ExchangeToken.Response)
+
+     def issue(self) -> IssueToken.Response:
+         """Issue an API token on behalf of a principal."""
+         return self.client.post("/token/issue", IssueToken.Request(), IssueToken.Response)
+
+     def revoke(self, request: RevokeToken.Request) -> RevokeToken.Response:
+         return self.client.put("/token/revoke", request, RevokeToken.Response)
+
+     def list(self, request: ListTokens.Request) -> Paged[Token]:
+         return self.client.paged("/token/list", request, ListTokens.Response)
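The request/response pairs above compose into simple flows. A sketch of rotating a project's API tokens, assuming token_service is an already-constructed TokenService and that the Paged result iterates over Token items (neither the client construction nor the Paged type is shown in this file):

from spiral.api.tokens import ListTokens, RevokeToken


def rotate_tokens(token_service, project_id: str) -> str:
    """Issue a fresh API token, revoke every other token in the project, and return the new secret."""
    issued = token_service.issue()  # IssueToken.Response: token + token_secret
    # Assumption: Paged[Token] is iterable over Token models.
    for token in token_service.list(ListTokens.Request(project_id=project_id)):
        if token.id != issued.token.id:
            token_service.revoke(RevokeToken.Request(token_id=token.id))
    return issued.token_secret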
spiral/api/workloads.py ADDED
@@ -0,0 +1,45 @@
+ from pydantic import BaseModel, Field
+
+ from . import Paged, PagedRequest, PagedResponse, ProjectId, ServiceBase
+
+
+ class Workload(BaseModel):
+     id: str
+     project_id: ProjectId
+     name: str | None = None
+
+
+ class CreateWorkload:
+     class Request(BaseModel):
+         project_id: str
+         name: str | None = Field(default=None, description="Optional human-readable name for the workload")
+
+     class Response(BaseModel):
+         workload: Workload
+
+
+ class IssueToken:
+     class Request(BaseModel):
+         workload_id: str
+
+     class Response(BaseModel):
+         token_id: str
+         token_secret: str
+
+
+ class ListWorkloads:
+     class Request(PagedRequest):
+         project_id: str
+
+     class Response(PagedResponse[Workload]): ...
+
+
+ class WorkloadService(ServiceBase):
+     def create(self, request: CreateWorkload.Request) -> CreateWorkload.Response:
+         return self.client.post("/workload/create", request, CreateWorkload.Response)
+
+     def issue_token(self, request: IssueToken.Request) -> IssueToken.Response:
+         return self.client.post("/workload/issue-token", request, IssueToken.Response)
+
+     def list(self, request: ListWorkloads.Request) -> Paged[Workload]:
+         return self.client.paged("/workload/list", request, ListWorkloads.Response)
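Workload credentials follow the same shape: create the workload, then issue a token for it. A sketch under the same assumption of a pre-built WorkloadService (the helper function is illustrative, not part of the package):

from spiral.api.workloads import CreateWorkload, IssueToken


def provision_workload(workload_service, project_id: str, name: str | None = None) -> tuple[str, str]:
    """Create a workload and return its (token_id, token_secret) credentials."""
    created = workload_service.create(CreateWorkload.Request(project_id=project_id, name=name))
    issued = workload_service.issue_token(IssueToken.Request(workload_id=created.workload.id))
    return issued.token_id, issued.token_secret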
spiral/arrow.py ADDED
@@ -0,0 +1,209 @@
+ from collections import defaultdict
+ from collections.abc import Callable, Iterable
+ from functools import reduce
+ from typing import TypeVar
+
+ import numpy as np
+ import pyarrow as pa
+ from pyarrow import compute as pc
+
+ T = TypeVar("T")
+
+
+ def arange(*args, **kwargs) -> pa.Array:
+     return pa.array(np.arange(*args, **kwargs), type=pa.int32())
+
+
+ def zip_tables(tables: Iterable[pa.Table]) -> pa.Table:
+     data = []
+     names = []
+     for table in tables:
+         data.extend(table.columns)
+         names.extend(table.column_names)
+     return pa.Table.from_arrays(data, names=names)
+
+
+ def merge_arrays(*arrays: pa.StructArray) -> pa.StructArray:
+     """Recursively merge arrays into nested struct arrays."""
+     if len(arrays) == 1:
+         return arrays[0]
+
+     nstructs = sum(pa.types.is_struct(a.type) for a in arrays)
+     if nstructs == 0:
+         # Then we have conflicting arrays and we choose the last.
+         return arrays[-1]
+
+     if nstructs != len(arrays):
+         raise ValueError("Cannot merge structs with non-structs.")
+
+     data = defaultdict(list)
+     for array in arrays:
+         if isinstance(array, pa.ChunkedArray):
+             array = array.combine_chunks()
+         for field in array.type:
+             data[field.name].append(array.field(field.name))
+
+     return pa.StructArray.from_arrays([merge_arrays(*v) for v in data.values()], names=list(data.keys()))
+
+
+ def merge_scalars(*scalars: pa.StructScalar) -> pa.StructScalar:
+     """Recursively merge scalars into nested struct scalars."""
+     if len(scalars) == 1:
+         return scalars[0]
+
+     nstructs = sum(pa.types.is_struct(a.type) for a in scalars)
+     if nstructs == 0:
+         # Then we have conflicting scalars and we choose the last.
+         return scalars[-1]
+
+     if nstructs != len(scalars):
+         raise ValueError("Cannot merge scalars with non-scalars.")
+
+     data = defaultdict(list)
+     for scalar in scalars:
+         for field in scalar.type:
+             data[field.name].append(scalar[field.name])
+
+     return pa.scalar({k: merge_scalars(*v) for k, v in data.items()})
+
+
+ def null_table(schema: pa.Schema, length: int = 0) -> pa.Table:
+     # We add an extra nulls column to ensure the length is correctly applied.
+     return pa.table(
+         [pa.nulls(length, type=field.type) for field in schema] + [pa.nulls(length)],
+         schema=pa.schema(list(schema) + [pa.field("__", type=pa.null())]),
+     ).drop(["__"])
+
+
+ def coalesce_all(table: pa.Table) -> pa.Table:
+     """Coalesce all columns that share the same name."""
+     columns: dict[str, list[pa.Array]] = defaultdict(list)
+     for i, col in enumerate(table.column_names):
+         columns[col].append(table[i])
+
+     data = []
+     names = []
+     for col, arrays in columns.items():
+         names.append(col)
+         if len(arrays) == 1:
+             data.append(arrays[0])
+         else:
+             data.append(pc.coalesce(*arrays))
+
+     return pa.Table.from_arrays(data, names=names)
+
+
+ def join(left: pa.Table, right: pa.Table, keys: list[str], join_type: str) -> pa.Table:
+     """Arrow's builtin join doesn't support struct columns. So we join ourselves and zip them in."""
+     # TODO(ngates): if join_type == inner, we may have better luck performing two index_in operations since this
+     # also preserves sort order.
+     lhs = left.select(keys).add_column(0, "__lhs", arange(len(left)))
+     rhs = right.select(keys).add_column(0, "__rhs", arange(len(right)))
+     joined = lhs.join(rhs, keys=keys, join_type=join_type).sort_by([(k, "ascending") for k in keys])
+     return zip_tables(
+         [joined.select(keys), left.take(joined["__lhs"]).drop(keys), right.take(joined["__rhs"]).drop(keys)]
+     )
+
+
+ def nest_structs(array: pa.StructArray | pa.StructScalar | dict) -> dict:
+     """Turn a struct-like value with dot-separated column names into a nested dictionary."""
+     data = {}
+
+     if isinstance(array, pa.StructArray | pa.StructScalar):
+         array = {f.name: field(array, f.name) for f in array.type}
+
+     for name in array.keys():
+         if "." not in name:
+             data[name] = array[name]
+             continue
+
+         parts = name.split(".")
+         child_data = data
+         for part in parts[:-1]:
+             if part not in child_data:
+                 child_data[part] = {}
+             child_data = child_data[part]
+         child_data[parts[-1]] = array[name]
+
+     return data
+
+
+ def flatten_struct_table(table: pa.Table, separator=".") -> pa.Table:
+     """Turn a nested struct table into a flat table with dot-separated names."""
+     data = []
+     names = []
+
+     def _unfold(array: pa.Array, prefix: str):
+         if pa.types.is_struct(array.type):
+             if isinstance(array, pa.ChunkedArray):
+                 array = array.combine_chunks()
+             for f in array.type:
+                 _unfold(field(array, f.name), f"{prefix}{separator}{f.name}")
+         else:
+             data.append(array)
+             names.append(prefix)
+
+     for col in table.column_names:
+         _unfold(table[col], col)
+
+     return pa.Table.from_arrays(data, names=names)
+
+
+ def dict_to_table(data) -> pa.Table:
+     return pa.Table.from_struct_array(dict_to_struct_array(data))
+
+
+ def dict_to_struct_array(data, propagate_nulls: bool = False) -> pa.StructArray:
+     """Convert a nested dictionary of arrays into a struct array of nested structs."""
+     if isinstance(data, pa.ChunkedArray):
+         return data.combine_chunks()
+     if isinstance(data, pa.Array):
+         return data
+     arrays = [dict_to_struct_array(value) for value in data.values()]
+     return pa.StructArray.from_arrays(
+         arrays,
+         names=list(data.keys()),
+         mask=reduce(pc.and_, [pc.is_null(array) for array in arrays]) if propagate_nulls else None,
+     )
+
+
+ def struct_array_to_dict(array: pa.StructArray, array_fn: Callable[[pa.Array], T] = lambda a: a) -> dict | T:
+     """Convert a struct array to a nested dictionary."""
+     if not pa.types.is_struct(array.type):
+         return array_fn(array)
+     if isinstance(array, pa.ChunkedArray):
+         array = array.combine_chunks()
+     return {field.name: struct_array_to_dict(array.field(i), array_fn=array_fn) for i, field in enumerate(array.type)}
+
+
+ def table_to_struct_array(table: pa.Table) -> pa.StructArray:
+     if not table.num_rows:
+         return pa.array([], type=pa.struct(table.schema))
+     array = table.to_struct_array()
+     if isinstance(array, pa.ChunkedArray):
+         array = array.combine_chunks()
+     return array
+
+
+ def table_from_struct_array(array: pa.StructArray | pa.ChunkedArray):
+     if len(array) == 0:
+         return null_table(pa.schema(array.type))
+     return pa.Table.from_struct_array(array)
+
+
+ def field(value: pa.StructArray | pa.StructScalar, name: str) -> pa.Array | pa.Scalar:
+     """Get a field from a struct-like value."""
+     if isinstance(value, pa.StructScalar):
+         return value[name]
+     return value.field(name)
+
+
+ def concat_tables(tables: list[pa.Table]) -> pa.Table:
+     """
+     Concatenate pyarrow.Table objects, filling "missing" data with appropriate null arrays
+     and casting arrays to the most common denominator type that fits all fields.
+     """
+     if len(tables) == 1:
+         return tables[0]
+     else:
+         return pa.concat_tables(tables, promote_options="permissive")
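Most of these helpers shuttle between nested struct columns and flat dot-separated column names. A small self-contained sketch of that round trip (the data is made up):

import pyarrow as pa

from spiral.arrow import dict_to_table, flatten_struct_table, nest_structs

# A nested dict of arrays becomes a table with a struct column...
nested = dict_to_table({"user": {"id": pa.array([1, 2]), "name": pa.array(["a", "b"])}})

# ...which flattens into dot-separated column names...
flat = flatten_struct_table(nested)
assert flat.column_names == ["user.id", "user.name"]

# ...and nest_structs rebuilds the nested shape from the flat mapping.
rebuilt = nest_structs({name: flat[name] for name in flat.column_names})
assert list(rebuilt["user"].keys()) == ["id", "name"]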
spiral/authn/__init__.py ADDED
File without changes
spiral/authn/authn.py ADDED
@@ -0,0 +1,89 @@
+ import base64
+ import logging
+ import os
+
+ from spiral.api import Authn, SpiralAPI
+
+ ENV_TOKEN_ID = "SPIRAL_TOKEN_ID"
+ ENV_TOKEN_SECRET = "SPIRAL_TOKEN_SECRET"
+
+ log = logging.getLogger(__name__)
+
+
+ class FallbackAuthn(Authn):
+     """Credential provider that tries multiple providers in order."""
+
+     def __init__(self, providers: list[Authn]):
+         self._providers = providers
+
+     def token(self) -> str | None:
+         for provider in self._providers:
+             token = provider.token()
+             if token is not None:
+                 return token
+         return None
+
+
+ class TokenAuthn(Authn):
+     """Credential provider that returns a fixed token."""
+
+     def __init__(self, token: str):
+         self._token = token
+
+     def token(self) -> str:
+         return self._token
+
+
+ class EnvironmentAuthn(Authn):
+     """Credential provider that returns a basic token from the environment.
+
+     NOTE: Returns a basic token, which must be exchanged.
+     """
+
+     def token(self) -> str | None:
+         if ENV_TOKEN_ID not in os.environ:
+             return None
+         if ENV_TOKEN_SECRET not in os.environ:
+             raise ValueError(f"{ENV_TOKEN_SECRET} is missing.")
+
+         token_id = os.environ[ENV_TOKEN_ID]
+         token_secret = os.environ[ENV_TOKEN_SECRET]
+         basic_token = base64.b64encode(f"{token_id}:{token_secret}".encode()).decode("utf-8")
+
+         return basic_token
+
+
+ class DeviceAuthProvider(Authn):
+     """Auth provider that uses the device flow to authenticate a Spiral user."""
+
+     def __init__(self, device_auth):
+         # NOTE(ngates): device_auth: spiral.auth.device_code.DeviceAuth
+         # We leave it untyped to satisfy our import linter.
+         self._device_auth = device_auth
+
+     def token(self) -> str | None:
+         # TODO(ngates): only run this if we're in a notebook, CLI, or otherwise on the user's machine.
+         return self._device_auth.authenticate().access_token
+
+
+ class TokenExchangeProvider(Authn):
+     """Auth provider that exchanges a basic token for a Spiral token."""
+
+     def __init__(self, authn: Authn, base_url: str):
+         self._authn = authn
+         self._token_service = SpiralAPI(authn, base_url).token
+
+         self._sp_token = None
+
+     def token(self) -> str | None:
+         if self._sp_token is not None:
+             return self._sp_token
+
+         # Don't try to exchange if no token is discovered.
+         if self._authn.token() is None:
+             return None
+
+         log.debug("Exchanging token")
+         self._sp_token = self._token_service.exchange().token
+
+         return self._sp_token
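These providers are meant to be chained: FallbackAuthn walks a list of sources in order, and TokenExchangeProvider upgrades a basic token into a short-lived Spiral token. A wiring sketch; the base URL and the SPIRAL_TOKEN override variable are illustrative placeholders, not part of this package:

import os

from spiral.authn.authn import EnvironmentAuthn, FallbackAuthn, TokenAuthn, TokenExchangeProvider

BASE_URL = "https://api.example.invalid"  # placeholder; the real API URL is configured elsewhere

providers = []
if os.environ.get("SPIRAL_TOKEN"):  # hypothetical explicit-token override, not defined by this package
    providers.append(TokenAuthn(os.environ["SPIRAL_TOKEN"]))
# SPIRAL_TOKEN_ID / SPIRAL_TOKEN_SECRET, if set, yield a basic token that is exchanged
# for a short-lived Spiral token on first use.
providers.append(TokenExchangeProvider(EnvironmentAuthn(), BASE_URL))

authn = FallbackAuthn(providers)
token = authn.token()  # None when no provider can supply credentials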
spiral/authn/device.py ADDED
@@ -0,0 +1,206 @@
+ import logging
+ import sys
+ import textwrap
+ import time
+ import webbrowser
+ from pathlib import Path
+
+ import httpx
+ import jwt
+ from pydantic import BaseModel
+
+ log = logging.getLogger(__name__)
+
+
+ class TokensModel(BaseModel):
+     access_token: str
+     refresh_token: str
+
+     @property
+     def organization_id(self) -> str | None:
+         return self.unverified_access_token().get("org_id")
+
+     def unverified_access_token(self):
+         return jwt.decode(self.access_token, options={"verify_signature": False})
+
+
+ class AuthModel(BaseModel):
+     tokens: TokensModel | None = None
+
+
+ class DeviceAuth:
+     def __init__(
+         self,
+         auth_file: Path,
+         domain: str,
+         client_id: str,
+         http: httpx.Client = None,
+     ):
+         self._auth_file = auth_file
+         self._domain = domain
+         self._client_id = client_id
+         self._http = http or httpx.Client()
+
+         if self._auth_file.exists():
+             with self._auth_file.open("r") as f:
+                 self._auth = AuthModel.model_validate_json(f.read())
+         else:
+             self._auth = AuthModel()
+
+         self._default_scope = ["email", "profile"]
+
+     def is_authenticated(self) -> bool:
+         """Check if the user is authenticated."""
+         tokens = self._auth.tokens
+         if tokens is None:
+             return False
+
+         # Give ourselves a 30-second buffer before the token expires.
+         return tokens.unverified_access_token()["exp"] - 30 > time.time()
+
+     def authenticate(self, force: bool = False, refresh: bool = False, organization_id: str = None) -> TokensModel:
+         """Blocking call to authenticate the user.
+
+         Triggers a device code flow and polls for the user to log in.
+         """
+         if force:
+             return self._device_code(organization_id)
+
+         if refresh:
+             if self._auth.tokens is None:
+                 raise ValueError("No tokens to refresh.")
+             tokens = self._refresh(self._auth.tokens, organization_id)
+             if not tokens:
+                 raise ValueError("Failed to refresh token.")
+             return tokens
+
+         # Check for mis-matched organization.
+         if organization_id is not None:
+             tokens = self._auth.tokens
+             if tokens is not None and tokens.unverified_access_token().get("org_id") != organization_id:
+                 tokens = self._refresh(self._auth.tokens, organization_id)
+                 if tokens is None:
+                     return self._device_code(organization_id)
+
+         if self.is_authenticated():
+             return self._auth.tokens
+
+         # Try to refresh.
+         tokens = self._auth.tokens
+         if tokens is not None:
+             tokens = self._refresh(tokens)
+             if tokens is not None:
+                 return tokens
+
+         # Otherwise, we kick off the device code flow.
+         return self._device_code(organization_id)
+
+     def logout(self):
+         self._remove_tokens()
+
+     def _device_code(self, organization_id: str | None):
+         scope = " ".join(self._default_scope)
+         res = self._http.post(
+             f"{self._domain}/auth/device/code",
+             data={
+                 "client_id": self._client_id,
+                 "scope": scope,
+                 "organization_id": organization_id,
+             },
+         )
+         res = res.raise_for_status().json()
+         device_code = res["device_code"]
+         user_code = res["user_code"]
+         expires_at = res["expires_in"] + time.time()
+         interval = res["interval"]
+         verification_uri_complete = res["verification_uri_complete"]
+
+         # We need to detect if the user is running in a terminal, in Jupyter, etc.
+         # For now, we'll try to open the browser.
+         sys.stderr.write(
+             textwrap.dedent(
+                 f"""
+                 Please login here: {verification_uri_complete}
+                 Your code is {user_code}.
+                 """
+             )
+         )
+
+         # Try to open the browser (this also works if the Jupyter notebook is running on the user's machine).
+         opened = webbrowser.open(verification_uri_complete)
+
+         # If we have a server-side Jupyter notebook, we can try to open with client-side JavaScript.
+         if not opened and _in_notebook():
+             from IPython.display import Javascript, display
+
+             display(Javascript(f'window.open("{verification_uri_complete}");'))
+
+         # In the meantime, we need to poll for the user to log in.
+         while True:
+             if time.time() > expires_at:
+                 raise TimeoutError("Login timed out.")
+             time.sleep(interval)
+             res = self._http.post(
+                 f"{self._domain}/auth/token",
+                 data={
+                     "client_id": self._client_id,
+                     "device_code": device_code,
+                     "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
+                 },
+             )
+             if not res.is_success:
+                 continue
+
+             tokens = TokensModel(
+                 access_token=res.json()["access_token"],
+                 refresh_token=res.json()["refresh_token"],
+             )
+             self._save_tokens(tokens)
+             return self._auth.tokens
+
+     def _refresh(self, tokens: TokensModel, organization_id: str = None) -> TokensModel | None:
+         """Attempt to use the refresh token."""
+         log.debug("Refreshing token %s", self._client_id)
+
+         res = self._http.post(
+             f"{self._domain}/auth/refresh",
+             data={
+                 "client_id": self._client_id,
+                 "grant_type": "refresh_token",
+                 "refresh_token": tokens.refresh_token,
+                 "organization_id": organization_id,
+             },
+         )
+         if not res.is_success:
+             print("Failed to refresh token", res.status_code, res.text)
+             return None
+
+         tokens = TokensModel(
+             access_token=res.json()["access_token"],
+             refresh_token=res.json()["refresh_token"],
+         )
+         self._save_tokens(tokens)
+         return tokens
+
+     def _save_tokens(self, tokens: TokensModel):
+         self._auth = self._auth.model_copy(update={"tokens": tokens})
+         self._auth_file.parent.mkdir(parents=True, exist_ok=True)
+         with self._auth_file.open("w") as f:
+             f.write(self._auth.model_dump_json(exclude_defaults=True))
+
+     def _remove_tokens(self):
+         self._auth_file.unlink(missing_ok=True)
+         self._auth = self._auth.model_copy(update={"tokens": None})
+
+
+ def _in_notebook():
+     try:
+         from IPython import get_ipython
+
+         if "IPKernelApp" not in get_ipython().config:  # pragma: no cover
+             return False
+     except ImportError:
+         return False
+     except AttributeError:
+         return False
+     return True
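authenticate() above is a blocking call: it returns cached tokens while they are still valid, tries the refresh token next, and otherwise runs the device-code flow (printing the verification URL and polling /auth/token). A usage sketch with placeholder domain, client id, and cache path; the real values ship with the package settings, which are not part of this excerpt:

from pathlib import Path

from spiral.authn.device import DeviceAuth

auth = DeviceAuth(
    auth_file=Path.home() / ".config" / "pyspiral" / "auth.json",  # placeholder cache location
    domain="https://auth.example.invalid",                         # placeholder auth domain
    client_id="example-client-id",                                 # placeholder client id
)

tokens = auth.authenticate()  # cached, refreshed, or freshly obtained via the device flow
print(tokens.organization_id)  # org id decoded (unverified) from the access token; may be None

auth.logout()  # removes the cached token file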