pyspiral 0.1.0__cp310-abi3-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyspiral-0.1.0.dist-info/METADATA +48 -0
- pyspiral-0.1.0.dist-info/RECORD +81 -0
- pyspiral-0.1.0.dist-info/WHEEL +4 -0
- pyspiral-0.1.0.dist-info/entry_points.txt +2 -0
- spiral/__init__.py +11 -0
- spiral/_lib.abi3.so +0 -0
- spiral/adbc.py +386 -0
- spiral/api/__init__.py +221 -0
- spiral/api/admin.py +29 -0
- spiral/api/filesystems.py +125 -0
- spiral/api/organizations.py +90 -0
- spiral/api/projects.py +160 -0
- spiral/api/tables.py +94 -0
- spiral/api/tokens.py +56 -0
- spiral/api/workloads.py +45 -0
- spiral/arrow.py +209 -0
- spiral/authn/__init__.py +0 -0
- spiral/authn/authn.py +89 -0
- spiral/authn/device.py +206 -0
- spiral/authn/github_.py +33 -0
- spiral/authn/modal_.py +18 -0
- spiral/catalog.py +78 -0
- spiral/cli/__init__.py +82 -0
- spiral/cli/__main__.py +4 -0
- spiral/cli/admin.py +21 -0
- spiral/cli/app.py +48 -0
- spiral/cli/console.py +95 -0
- spiral/cli/fs.py +47 -0
- spiral/cli/login.py +13 -0
- spiral/cli/org.py +90 -0
- spiral/cli/printer.py +45 -0
- spiral/cli/project.py +107 -0
- spiral/cli/state.py +3 -0
- spiral/cli/table.py +20 -0
- spiral/cli/token.py +27 -0
- spiral/cli/types.py +53 -0
- spiral/cli/workload.py +59 -0
- spiral/config.py +26 -0
- spiral/core/__init__.py +0 -0
- spiral/core/core/__init__.pyi +53 -0
- spiral/core/manifests/__init__.pyi +53 -0
- spiral/core/metastore/__init__.pyi +91 -0
- spiral/core/spec/__init__.pyi +257 -0
- spiral/dataset.py +239 -0
- spiral/debug.py +251 -0
- spiral/expressions/__init__.py +222 -0
- spiral/expressions/base.py +149 -0
- spiral/expressions/http.py +86 -0
- spiral/expressions/io.py +100 -0
- spiral/expressions/list_.py +68 -0
- spiral/expressions/refs.py +44 -0
- spiral/expressions/str_.py +39 -0
- spiral/expressions/struct.py +57 -0
- spiral/expressions/tiff.py +223 -0
- spiral/expressions/udf.py +46 -0
- spiral/grpc_.py +32 -0
- spiral/project.py +137 -0
- spiral/proto/_/__init__.py +0 -0
- spiral/proto/_/arrow/__init__.py +0 -0
- spiral/proto/_/arrow/flight/__init__.py +0 -0
- spiral/proto/_/arrow/flight/protocol/__init__.py +0 -0
- spiral/proto/_/arrow/flight/protocol/sql/__init__.py +1990 -0
- spiral/proto/_/scandal/__init__.py +223 -0
- spiral/proto/_/spfs/__init__.py +36 -0
- spiral/proto/_/spiral/__init__.py +0 -0
- spiral/proto/_/spiral/table/__init__.py +225 -0
- spiral/proto/_/spiraldb/__init__.py +0 -0
- spiral/proto/_/spiraldb/metastore/__init__.py +499 -0
- spiral/proto/__init__.py +0 -0
- spiral/proto/scandal/__init__.py +45 -0
- spiral/proto/spiral/__init__.py +0 -0
- spiral/proto/spiral/table/__init__.py +96 -0
- spiral/proto/substrait/__init__.py +3399 -0
- spiral/proto/substrait/extensions/__init__.py +115 -0
- spiral/proto/util.py +41 -0
- spiral/py.typed +0 -0
- spiral/scan_.py +168 -0
- spiral/settings.py +157 -0
- spiral/substrait_.py +275 -0
- spiral/table.py +157 -0
- spiral/types_.py +6 -0
spiral/api/tables.py
ADDED
@@ -0,0 +1,94 @@
|
|
1
|
+
from typing import Annotated
|
2
|
+
|
3
|
+
from pydantic import (
|
4
|
+
AfterValidator,
|
5
|
+
BaseModel,
|
6
|
+
StringConstraints,
|
7
|
+
)
|
8
|
+
|
9
|
+
from . import ArrowSchema, Paged, PagedRequest, PagedResponse, ProjectId, ServiceBase
|
10
|
+
|
11
|
+
|
12
|
+
def _validate_root_uri(uri: str) -> str:
|
13
|
+
if uri.endswith("/"):
|
14
|
+
raise ValueError("Root URI must not end with a slash.")
|
15
|
+
return uri
|
16
|
+
|
17
|
+
|
18
|
+
# A table root URI; validated to never end with a trailing slash.
RootUri = Annotated[str, AfterValidator(_validate_root_uri)]
# Dataset/table names: start with a letter or underscore, then letters, digits,
# underscores or hyphens; at most 128 characters.
# NOTE(review): the pattern requires at least TWO characters (leading class plus
# `+` on the second class) — single-character names are rejected; confirm intended.
DatasetName = Annotated[str, StringConstraints(max_length=128, pattern=r"^[a-zA-Z_][a-zA-Z0-9_-]+$")]
TableName = Annotated[str, StringConstraints(max_length=128, pattern=r"^[a-zA-Z_][a-zA-Z0-9_-]+$")]
|
21
|
+
|
22
|
+
|
23
|
+
class TableMetadata(BaseModel):
    """Metadata attached to a table record."""

    # Arrow schema of the table's key columns.
    key_schema: ArrowSchema
    # Base URI for the table's data; validated to never end with a slash.
    root_uri: RootUri
    # NOTE(review): presumably the ID of an SPFS mount backing this table — confirm.
    spfs_mount_id: str | None = None

    # TODO(marko): Randomize this on creation of metadata.
    # Column group salt is used to compute column group IDs.
    # It's used to ensure that column group IDs are unique
    # across different tables, even if paths are the same.
    # It's never modified.
    column_group_salt: int = 0
|
34
|
+
|
35
|
+
|
36
|
+
class Table(BaseModel):
    """A table record: identity, location within a project, and its metadata."""

    id: str
    project_id: ProjectId
    dataset: DatasetName
    table: TableName
    metadata: TableMetadata
|
42
|
+
|
43
|
+
|
44
|
+
class CreateTable:
    """Request/response pair for the table-creation endpoint."""

    class Request(BaseModel):
        project_id: ProjectId
        dataset: DatasetName
        table: TableName
        # Arrow schema of the key columns for the new table.
        key_schema: ArrowSchema
        # When None, the server chooses the root URI — TODO confirm server default.
        root_uri: RootUri | None = None
        # Presumably: succeed (returning the existing table) rather than fail when
        # the table already exists — confirm server semantics.
        exist_ok: bool = False

    class Response(BaseModel):
        table: Table
|
55
|
+
|
56
|
+
|
57
|
+
class FindTable:
    """Request/response pair for looking a table up by name.

    Fix: `dataset` and `table` previously had non-optional annotations
    (`DatasetName = None`) with a `None` default, which contradicts the declared
    constrained-string types; they are optional filters, matching how optional
    fields are declared elsewhere in this module (e.g. ListTables.Request).
    """

    class Request(BaseModel):
        project_id: ProjectId
        # Optional name filters; None means "unspecified".
        dataset: DatasetName | None = None
        table: TableName | None = None

    class Response(BaseModel):
        # None when no table matched the request.
        table: Table | None
|
65
|
+
|
66
|
+
|
67
|
+
class GetTable:
    """Request/response pair for fetching a table by its ID."""

    class Request(BaseModel):
        id: str

    class Response(BaseModel):
        table: Table
|
73
|
+
|
74
|
+
|
75
|
+
class ListTables:
    """Paged listing of tables in a project, optionally filtered by dataset."""

    class Request(PagedRequest):
        project_id: ProjectId
        # None lists tables across all datasets in the project.
        dataset: DatasetName | None = None

    class Response(PagedResponse[Table]): ...
|
81
|
+
|
82
|
+
|
83
|
+
class TableService(ServiceBase):
    """HTTP client for the table endpoints."""

    def create(self, req: CreateTable.Request) -> CreateTable.Response:
        """Create a table (or return the existing one when `exist_ok` is set)."""
        return self.client.post("/table/create", req, CreateTable.Response)

    def find(self, req: FindTable.Request) -> FindTable.Response:
        """Look a table up by project/dataset/table name."""
        return self.client.put("/table/find", req, FindTable.Response)

    def get(self, req: GetTable.Request) -> GetTable.Response:
        """Fetch a table by its ID.

        Fix: the request model was not being forwarded — every other endpoint in
        this module calls ``client.put(path, req, ResponseType)``, but this call
        passed only the path and response type.
        """
        return self.client.put(f"/table/{req.id}", req, GetTable.Response)

    def list(self, req: ListTables.Request) -> Paged[Table]:
        """Iterate tables in a project, optionally filtered by dataset."""
        return self.client.paged("/table/list", req, ListTables.Response)
|
spiral/api/tokens.py
ADDED
@@ -0,0 +1,56 @@
|
|
1
|
+
from pydantic import BaseModel
|
2
|
+
|
3
|
+
from . import Paged, PagedRequest, PagedResponse, ServiceBase
|
4
|
+
|
5
|
+
|
6
|
+
class Token(BaseModel):
    """An issued API token (the secret itself only appears in IssueToken.Response)."""

    id: str
    project_id: str
    # Principal the token was issued on behalf of.
    on_behalf_of: str
|
10
|
+
|
11
|
+
|
12
|
+
class ExchangeToken:
    """Request/response pair for exchanging a basic/identity token for a Spiral token."""

    class Request(BaseModel): ...

    class Response(BaseModel):
        # The short-lived Spiral token (see TokenService.exchange).
        token: str
|
17
|
+
|
18
|
+
|
19
|
+
class IssueToken:
    """Request/response pair for issuing an API token on behalf of a principal."""

    class Request(BaseModel): ...

    class Response(BaseModel):
        token: Token
        # The token's secret; only returned here, at issue time.
        token_secret: str
|
25
|
+
|
26
|
+
|
27
|
+
class RevokeToken:
    """Request/response pair for revoking a token by ID."""

    class Request(BaseModel):
        token_id: str

    class Response(BaseModel):
        # The token that was revoked.
        token: Token
|
33
|
+
|
34
|
+
|
35
|
+
class ListTokens:
    """Paged listing of tokens within a project."""

    class Request(PagedRequest):
        project_id: str
        # Optionally restrict to tokens issued on behalf of this principal.
        on_behalf_of: str | None = None

    class Response(PagedResponse[Token]): ...
|
41
|
+
|
42
|
+
|
43
|
+
class TokenService(ServiceBase):
    """HTTP client for the token endpoints."""

    def exchange(self) -> ExchangeToken.Response:
        """Exchange a basic / identity token to a short-lived Spiral token."""
        request = ExchangeToken.Request()
        return self.client.post("/token/exchange", request, ExchangeToken.Response)

    def issue(self) -> IssueToken.Response:
        """Issue an API token on behalf of a principal."""
        request = IssueToken.Request()
        return self.client.post("/token/issue", request, IssueToken.Response)

    def revoke(self, request: RevokeToken.Request) -> RevokeToken.Response:
        """Revoke a previously issued token."""
        return self.client.put("/token/revoke", request, RevokeToken.Response)

    def list(self, request: ListTokens.Request) -> Paged[Token]:
        """Page through the tokens of a project."""
        return self.client.paged("/token/list", request, ListTokens.Response)
|
spiral/api/workloads.py
ADDED
@@ -0,0 +1,45 @@
|
|
1
|
+
from pydantic import BaseModel, Field
|
2
|
+
|
3
|
+
from . import Paged, PagedRequest, PagedResponse, ProjectId, ServiceBase
|
4
|
+
|
5
|
+
|
6
|
+
class Workload(BaseModel):
    """A workload registered within a project."""

    id: str
    project_id: ProjectId
    # Optional human-readable name.
    name: str | None = None
|
10
|
+
|
11
|
+
|
12
|
+
class CreateWorkload:
    """Request/response pair for registering a new workload."""

    class Request(BaseModel):
        # NOTE(review): annotated `str` while Workload.project_id uses ProjectId —
        # confirm whether this should also be ProjectId.
        project_id: str
        name: str | None = Field(default=None, description="Optional human-readable name for the workload")

    class Response(BaseModel):
        workload: Workload
|
19
|
+
|
20
|
+
|
21
|
+
class IssueToken:
    """Request/response pair for issuing a token bound to a workload."""

    class Request(BaseModel):
        workload_id: str

    class Response(BaseModel):
        token_id: str
        # The token's secret; only returned here, at issue time.
        token_secret: str
|
28
|
+
|
29
|
+
|
30
|
+
class ListWorkloads:
    """Paged listing of workloads within a project."""

    class Request(PagedRequest):
        project_id: str

    class Response(PagedResponse[Workload]): ...
|
35
|
+
|
36
|
+
|
37
|
+
class WorkloadService(ServiceBase):
    """HTTP client for the workload endpoints."""

    def create(self, request: CreateWorkload.Request) -> CreateWorkload.Response:
        """Register a new workload in a project."""
        return self.client.post(
            "/workload/create",
            request,
            CreateWorkload.Response,
        )

    def issue_token(self, request: IssueToken.Request) -> IssueToken.Response:
        """Issue a token bound to the given workload."""
        return self.client.post(
            "/workload/issue-token",
            request,
            IssueToken.Response,
        )

    def list(self, request: ListWorkloads.Request) -> Paged[Workload]:
        """Page through the workloads of a project."""
        return self.client.paged(
            "/workload/list",
            request,
            ListWorkloads.Response,
        )
|
spiral/arrow.py
ADDED
@@ -0,0 +1,209 @@
|
|
1
|
+
from collections import defaultdict
|
2
|
+
from collections.abc import Callable, Iterable
|
3
|
+
from functools import reduce
|
4
|
+
from typing import TypeVar
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
import pyarrow as pa
|
8
|
+
from pyarrow import compute as pc
|
9
|
+
|
10
|
+
T = TypeVar("T")
|
11
|
+
|
12
|
+
|
13
|
+
def arange(*args, **kwargs) -> pa.Array:
    """Build an int32 pyarrow array from ``np.arange(*args, **kwargs)``."""
    values = np.arange(*args, **kwargs)
    return pa.array(values, type=pa.int32())
|
15
|
+
|
16
|
+
|
17
|
+
def zip_tables(tables: Iterable[pa.Table]) -> pa.Table:
    """Combine tables column-wise into one table (duplicate names are preserved)."""
    columns = []
    labels = []
    for t in tables:
        columns += t.columns
        labels += t.column_names
    return pa.Table.from_arrays(columns, names=labels)
|
24
|
+
|
25
|
+
|
26
|
+
def merge_arrays(*arrays: pa.StructArray) -> pa.StructArray:
    """Recursively merge arrays into nested struct arrays.

    Non-struct conflicts resolve to the last array given; mixing struct and
    non-struct arrays raises ValueError.
    """
    if len(arrays) == 1:
        return arrays[0]

    nstructs = sum(pa.types.is_struct(a.type) for a in arrays)
    if nstructs == 0:
        # Then we have conflicting arrays and we choose the last.
        return arrays[-1]

    if nstructs != len(arrays):
        raise ValueError("Cannot merge structs with non-structs.")

    # Group child arrays by field name (first-seen field order), then recurse.
    data = defaultdict(list)
    for array in arrays:
        if isinstance(array, pa.ChunkedArray):
            array = array.combine_chunks()
        for field in array.type:
            data[field.name].append(array.field(field.name))

    return pa.StructArray.from_arrays([merge_arrays(*v) for v in data.values()], names=list(data.keys()))
|
47
|
+
|
48
|
+
|
49
|
+
def merge_scalars(*scalars: pa.StructScalar) -> pa.StructScalar:
    """Recursively merge scalars into nested struct scalars.

    Mirrors merge_arrays: non-struct conflicts resolve to the last scalar;
    mixing struct and non-struct scalars raises ValueError.
    """
    if len(scalars) == 1:
        return scalars[0]

    nstructs = sum(pa.types.is_struct(a.type) for a in scalars)
    if nstructs == 0:
        # Then we have conflicting scalars and we choose the last.
        return scalars[-1]

    if nstructs != len(scalars):
        raise ValueError("Cannot merge scalars with non-scalars.")

    # Group child scalars by field name (first-seen field order), then recurse.
    data = defaultdict(list)
    for scalar in scalars:
        for field in scalar.type:
            data[field.name].append(scalar[field.name])

    return pa.scalar({k: merge_scalars(*v) for k, v in data.items()})
|
68
|
+
|
69
|
+
|
70
|
+
def null_table(schema: pa.Schema, length: int = 0) -> pa.Table:
    """Build an all-null table with the given schema and row count."""
    # A sentinel null column pins the table length even for an empty schema;
    # it is dropped before returning.
    arrays = [pa.nulls(length, type=f.type) for f in schema]
    arrays.append(pa.nulls(length))
    padded_schema = pa.schema(list(schema) + [pa.field("__", type=pa.null())])
    return pa.table(arrays, schema=padded_schema).drop(["__"])
|
76
|
+
|
77
|
+
|
78
|
+
def coalesce_all(table: pa.Table) -> pa.Table:
    """Coalesce all columns that share the same name."""
    # Bucket columns by name, keeping first-seen name order.
    grouped: dict[str, list] = defaultdict(list)
    for idx, name in enumerate(table.column_names):
        grouped[name].append(table[idx])

    merged = [
        cols[0] if len(cols) == 1 else pc.coalesce(*cols)
        for cols in grouped.values()
    ]
    return pa.Table.from_arrays(merged, names=list(grouped.keys()))
|
94
|
+
|
95
|
+
|
96
|
+
def join(left: pa.Table, right: pa.Table, keys: list[str], join_type: str) -> pa.Table:
    """Arrow's builtin join doesn't support struct columns. So we join ourselves and zip them in."""
    # TODO(ngates): if join_type == inner, we may have better luck performing two index_in operations since this
    # also preserves sort order.
    # Join only the key columns plus per-side row indices, then use those indices
    # to gather the full (struct-bearing) payload columns from each side.
    lhs = left.select(keys).add_column(0, "__lhs", arange(len(left)))
    rhs = right.select(keys).add_column(0, "__rhs", arange(len(right)))
    # NOTE(review): result rows come back sorted by key, not in input order;
    # for outer joins, unmatched rows carry null "__lhs"/"__rhs" indices — confirm
    # take() on nulls yields the intended null rows.
    joined = lhs.join(rhs, keys=keys, join_type=join_type).sort_by([(k, "ascending") for k in keys])
    return zip_tables(
        [joined.select(keys), left.take(joined["__lhs"]).drop(keys), right.take(joined["__rhs"]).drop(keys)]
    )
|
106
|
+
|
107
|
+
|
108
|
+
def nest_structs(array: pa.StructArray | pa.StructScalar | dict) -> dict:
    """Turn a struct-like value with dot-separated column names into a nested dictionary."""
    data = {}

    if isinstance(array, pa.StructArray | pa.StructScalar):
        # Normalize struct input into a flat {name: child} mapping first.
        array = {f.name: field(array, f.name) for f in array.type}

    for name in array.keys():
        if "." not in name:
            data[name] = array[name]
            continue

        # Walk/create intermediate dicts for "a.b.c"-style names, then place the
        # value at the leaf.
        parts = name.split(".")
        child_data = data
        for part in parts[:-1]:
            if part not in child_data:
                child_data[part] = {}
            child_data = child_data[part]
        child_data[parts[-1]] = array[name]

    return data
|
129
|
+
|
130
|
+
|
131
|
+
def flatten_struct_table(table: pa.Table, separator=".") -> pa.Table:
    """Turn a nested struct table into a flat table with dot-separated names.

    Inverse of nest_structs for tabular data: each leaf (non-struct) array
    becomes a top-level column named by its path.
    """
    data = []
    names = []

    def _unfold(array: pa.Array, prefix: str):
        # Depth-first descent: structs recurse into children; leaves are emitted.
        if pa.types.is_struct(array.type):
            if isinstance(array, pa.ChunkedArray):
                array = array.combine_chunks()
            for f in array.type:
                _unfold(field(array, f.name), f"{prefix}{separator}{f.name}")
        else:
            data.append(array)
            names.append(prefix)

    for col in table.column_names:
        _unfold(table[col], col)

    return pa.Table.from_arrays(data, names=names)
|
150
|
+
|
151
|
+
|
152
|
+
def dict_to_table(data) -> pa.Table:
    """Convert a nested dict of arrays into a table of (possibly nested) structs."""
    struct_array = dict_to_struct_array(data)
    return pa.Table.from_struct_array(struct_array)
|
154
|
+
|
155
|
+
|
156
|
+
def dict_to_struct_array(data, propagate_nulls: bool = False) -> pa.StructArray:
    """Convert a nested dictionary of arrays to a table with nested structs.

    Arrays/chunked arrays pass through unchanged (chunked arrays are combined).
    """
    if isinstance(data, pa.ChunkedArray):
        return data.combine_chunks()
    if isinstance(data, pa.Array):
        return data
    # NOTE(review): propagate_nulls is not forwarded to the recursive call, so the
    # null mask is only applied at this level — confirm whether that is intended.
    arrays = [dict_to_struct_array(value) for value in data.values()]
    return pa.StructArray.from_arrays(
        arrays,
        names=list(data.keys()),
        # Mark a parent row null only when every child is null at that row.
        mask=reduce(pc.and_, [pc.is_null(array) for array in arrays]) if propagate_nulls else None,
    )
|
168
|
+
|
169
|
+
|
170
|
+
def struct_array_to_dict(array: pa.StructArray, array_fn: Callable[[pa.Array], T] = lambda a: a) -> dict | T:
    """Convert a struct array to a nested dictionary.

    ``array_fn`` is applied to every leaf (non-struct) array, e.g. to convert
    leaves to Python lists or NumPy arrays.
    """
    if not pa.types.is_struct(array.type):
        return array_fn(array)
    if isinstance(array, pa.ChunkedArray):
        array = array.combine_chunks()
    return {field.name: struct_array_to_dict(array.field(i), array_fn=array_fn) for i, field in enumerate(array.type)}
|
177
|
+
|
178
|
+
|
179
|
+
def table_to_struct_array(table: pa.Table) -> pa.StructArray:
    """Convert a table to a single (non-chunked) struct array."""
    if table.num_rows == 0:
        # Zero-row tables are constructed directly to keep the struct type.
        return pa.array([], type=pa.struct(table.schema))
    result = table.to_struct_array()
    return result.combine_chunks() if isinstance(result, pa.ChunkedArray) else result
|
186
|
+
|
187
|
+
|
188
|
+
def table_from_struct_array(array: pa.StructArray | pa.ChunkedArray):
    """Build a table from a struct array; empty input yields an empty table with the struct's fields."""
    if not len(array):
        # Zero-length input is routed through null_table to keep the schema.
        return null_table(pa.schema(array.type))
    return pa.Table.from_struct_array(array)
|
192
|
+
|
193
|
+
|
194
|
+
def field(value: pa.StructArray | pa.StructScalar, name: str) -> pa.Array | pa.Scalar:
    """Get a field from a struct-like value (scalars subscript; arrays use .field())."""
    return value[name] if isinstance(value, pa.StructScalar) else value.field(name)
|
199
|
+
|
200
|
+
|
201
|
+
def concat_tables(tables: list[pa.Table]) -> pa.Table:
    """
    Concatenate pyarrow.Table objects, filling "missing" data with appropriate null arrays
    and casting arrays to the most common denominator type that fits all fields.
    """
    if len(tables) != 1:
        return pa.concat_tables(tables, promote_options="permissive")
    # Single table: nothing to concatenate.
    return tables[0]
|
spiral/authn/__init__.py
ADDED
File without changes
|
spiral/authn/authn.py
ADDED
@@ -0,0 +1,89 @@
|
|
1
|
+
import base64
|
2
|
+
import logging
|
3
|
+
import os
|
4
|
+
|
5
|
+
from spiral.api import Authn, SpiralAPI
|
6
|
+
|
7
|
+
ENV_TOKEN_ID = "SPIRAL_TOKEN_ID"
|
8
|
+
ENV_TOKEN_SECRET = "SPIRAL_TOKEN_SECRET"
|
9
|
+
|
10
|
+
log = logging.getLogger(__name__)
|
11
|
+
|
12
|
+
|
13
|
+
class FallbackAuthn(Authn):
    """Credential provider that tries multiple providers in order."""

    def __init__(self, providers: list[Authn]):
        self._providers = providers

    def token(self) -> str | None:
        """Return the first non-None token produced by the configured providers, else None."""
        return next(
            (tok for p in self._providers if (tok := p.token()) is not None),
            None,
        )
|
25
|
+
|
26
|
+
|
27
|
+
class TokenAuthn(Authn):
    """Credential provider that returns a fixed token."""

    def __init__(self, token: str):
        self._fixed_token = token

    def token(self) -> str:
        """Return the token supplied at construction time."""
        return self._fixed_token
|
35
|
+
|
36
|
+
|
37
|
+
class EnvironmentAuthn(Authn):
    """Credential provider that reads a basic token from the process environment.

    NOTE: Returns basic token. Must be exchanged.
    """

    def token(self) -> str | None:
        # No token ID configured at all -> defer to the next provider.
        token_id = os.environ.get(ENV_TOKEN_ID)
        if token_id is None:
            return None
        # An ID without a secret is a misconfiguration, not a missing credential.
        token_secret = os.environ.get(ENV_TOKEN_SECRET)
        if token_secret is None:
            raise ValueError(f"{ENV_TOKEN_SECRET} is missing.")

        # Encode "id:secret" as HTTP basic-auth material.
        credentials = f"{token_id}:{token_secret}".encode()
        return base64.b64encode(credentials).decode("utf-8")
|
54
|
+
|
55
|
+
|
56
|
+
class DeviceAuthProvider(Authn):
    """Auth provider that uses the device flow to authenticate a Spiral user."""

    def __init__(self, device_auth):
        # NOTE(ngates): device_auth: spiral.auth.device_code.DeviceAuth
        # We don't type it to satisfy our import linter
        self._device_auth = device_auth

    def token(self) -> str | None:
        # TODO(ngates): only run this if we're in a notebook, CLI, or otherwise on the user's machine.
        tokens = self._device_auth.authenticate()
        return tokens.access_token
|
67
|
+
|
68
|
+
|
69
|
+
class TokenExchangeProvider(Authn):
    """Auth provider that exchanges a basic token for a Spiral token."""

    def __init__(self, authn: Authn, base_url: str):
        # Upstream provider supplying the basic token to be exchanged.
        self._authn = authn
        self._token_service = SpiralAPI(authn, base_url).token

        # Cached Spiral token: the exchange happens at most once per instance.
        # NOTE(review): the cache is never invalidated on expiry — confirm the
        # token's lifetime matches this object's lifetime.
        self._sp_token = None

    def token(self) -> str | None:
        if self._sp_token is not None:
            return self._sp_token

        # Don't try to exchange if token is not discovered.
        if self._authn.token() is None:
            return None

        log.debug("Exchanging token")
        self._sp_token = self._token_service.exchange().token

        return self._sp_token
|
spiral/authn/device.py
ADDED
@@ -0,0 +1,206 @@
|
|
1
|
+
import logging
|
2
|
+
import sys
|
3
|
+
import textwrap
|
4
|
+
import time
|
5
|
+
import webbrowser
|
6
|
+
from pathlib import Path
|
7
|
+
|
8
|
+
import httpx
|
9
|
+
import jwt
|
10
|
+
from pydantic import BaseModel
|
11
|
+
|
12
|
+
log = logging.getLogger(__name__)
|
13
|
+
|
14
|
+
|
15
|
+
class TokensModel(BaseModel):
    """OAuth access/refresh token pair as persisted in the local auth file."""

    # JWT access token.
    access_token: str
    refresh_token: str

    @property
    def organization_id(self) -> str | None:
        """Organization ID ("org_id" claim) from the access token, if present."""
        return self.unverified_access_token().get("org_id")

    def unverified_access_token(self) -> dict:
        """Decode the access token's JWT claims WITHOUT signature verification.

        Only suitable for local inspection (expiry, org), never for trust decisions.
        """
        return jwt.decode(self.access_token, options={"verify_signature": False})
|
25
|
+
|
26
|
+
|
27
|
+
class AuthModel(BaseModel):
    """Root of the persisted auth file; tokens is None when logged out."""

    tokens: TokensModel | None = None
|
29
|
+
|
30
|
+
|
31
|
+
class DeviceAuth:
    """OAuth device-code login flow with tokens cached in a local auth file.

    State is held in an AuthModel loaded from (and saved back to) ``auth_file``.
    """

    def __init__(
        self,
        auth_file: Path,
        domain: str,
        client_id: str,
        http: httpx.Client | None = None,
    ):
        """
        Args:
            auth_file: Path of the JSON file where tokens are persisted.
            domain: Base URL of the auth server (no trailing path).
            client_id: OAuth client ID used in all requests.
            http: Optional pre-configured httpx client; a default one is created otherwise.
        """
        self._auth_file = auth_file
        self._domain = domain
        self._client_id = client_id
        self._http = http or httpx.Client()

        # Load previously persisted tokens, if any.
        if self._auth_file.exists():
            with self._auth_file.open("r") as f:
                self._auth = AuthModel.model_validate_json(f.read())
        else:
            self._auth = AuthModel()

        self._default_scope = ["email", "profile"]

    def is_authenticated(self) -> bool:
        """Check if the user is authenticated."""
        tokens = self._auth.tokens
        if tokens is None:
            return False

        # Give ourselves a 30-second buffer before the token expires.
        return tokens.unverified_access_token()["exp"] - 30 > time.time()

    def authenticate(self, force: bool = False, refresh: bool = False, organization_id: str | None = None) -> TokensModel:
        """Blocking call to authenticate the user.

        Triggers a device code flow and polls for the user to login.

        Args:
            force: Always run the device-code flow, ignoring cached tokens.
            refresh: Only refresh existing tokens; raises ValueError if absent or refresh fails.
            organization_id: Target organization; a cached token for a different
                org is refreshed (or re-authenticated) to match.
        """
        if force:
            return self._device_code(organization_id)

        if refresh:
            if self._auth.tokens is None:
                raise ValueError("No tokens to refresh.")
            tokens = self._refresh(self._auth.tokens, organization_id)
            if not tokens:
                raise ValueError("Failed to refresh token.")
            return tokens

        # Check for mis-matched organization.
        if organization_id is not None:
            tokens = self._auth.tokens
            if tokens is not None and tokens.unverified_access_token().get("org_id") != organization_id:
                tokens = self._refresh(self._auth.tokens, organization_id)
                if tokens is None:
                    return self._device_code(organization_id)

        if self.is_authenticated():
            return self._auth.tokens

        # Try to refresh.
        tokens = self._auth.tokens
        if tokens is not None:
            tokens = self._refresh(tokens)
            if tokens is not None:
                return tokens

        # Otherwise, we kick off the device code flow.
        return self._device_code(organization_id)

    def logout(self):
        """Forget persisted tokens (file and in-memory state)."""
        self._remove_tokens()

    def _device_code(self, organization_id: str | None):
        """Run the full device-code flow: request a code, prompt, poll until login."""
        scope = " ".join(self._default_scope)
        res = self._http.post(
            f"{self._domain}/auth/device/code",
            data={
                "client_id": self._client_id,
                "scope": scope,
                "organization_id": organization_id,
            },
        )
        res = res.raise_for_status().json()
        device_code = res["device_code"]
        user_code = res["user_code"]
        expires_at = res["expires_in"] + time.time()
        interval = res["interval"]
        verification_uri_complete = res["verification_uri_complete"]

        # We need to detect if the user is running in a terminal, in Jupyter, etc.
        # For now, we'll try to open the browser.
        sys.stderr.write(
            textwrap.dedent(
                f"""
                Please login here: {verification_uri_complete}
                Your code is {user_code}.
                """
            )
        )

        # Try to open the browser (this also works if the Jupiter notebook is running on the user's machine).
        opened = webbrowser.open(verification_uri_complete)

        # If we have a server-side Jupyter notebook, we can try to open with client-side JavaScript.
        if not opened and _in_notebook():
            from IPython.display import Javascript, display

            display(Javascript(f'window.open("{verification_uri_complete}");'))

        # In the meantime, we need to poll for the user to login.
        while True:
            if time.time() > expires_at:
                raise TimeoutError("Login timed out.")
            time.sleep(interval)
            res = self._http.post(
                f"{self._domain}/auth/token",
                data={
                    "client_id": self._client_id,
                    "device_code": device_code,
                    "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
                },
            )
            # Any non-success response (e.g. authorization still pending) -> keep polling.
            if not res.is_success:
                continue

            tokens = TokensModel(
                access_token=res.json()["access_token"],
                refresh_token=res.json()["refresh_token"],
            )
            self._save_tokens(tokens)
            return self._auth.tokens

    def _refresh(self, tokens: TokensModel, organization_id: str | None = None) -> TokensModel | None:
        """Attempt to use the refresh token. Returns None on failure."""
        log.debug("Refreshing token %s", self._client_id)

        res = self._http.post(
            f"{self._domain}/auth/refresh",
            data={
                "client_id": self._client_id,
                "grant_type": "refresh_token",
                "refresh_token": tokens.refresh_token,
                "organization_id": organization_id,
            },
        )
        if not res.is_success:
            # NOTE(review): writes to stdout instead of using `log` — consider log.warning.
            print("Failed to refresh token", res.status_code, res.text)
            return None

        tokens = TokensModel(
            access_token=res.json()["access_token"],
            refresh_token=res.json()["refresh_token"],
        )
        self._save_tokens(tokens)
        return tokens

    def _save_tokens(self, tokens: TokensModel):
        """Update in-memory state and persist it to the auth file."""
        self._auth = self._auth.model_copy(update={"tokens": tokens})
        self._auth_file.parent.mkdir(parents=True, exist_ok=True)
        with self._auth_file.open("w") as f:
            f.write(self._auth.model_dump_json(exclude_defaults=True))

    def _remove_tokens(self):
        """Delete the auth file and clear in-memory tokens."""
        self._auth_file.unlink(missing_ok=True)
        self._auth = self._auth.model_copy(update={"tokens": None})
|
194
|
+
|
195
|
+
|
196
|
+
def _in_notebook():
|
197
|
+
try:
|
198
|
+
from IPython import get_ipython
|
199
|
+
|
200
|
+
if "IPKernelApp" not in get_ipython().config: # pragma: no cover
|
201
|
+
return False
|
202
|
+
except ImportError:
|
203
|
+
return False
|
204
|
+
except AttributeError:
|
205
|
+
return False
|
206
|
+
return True
|