pyspiral 0.1.0__cp310-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Sign up to get free protection for your applications and to get access to all the features.
- pyspiral-0.1.0.dist-info/METADATA +48 -0
- pyspiral-0.1.0.dist-info/RECORD +81 -0
- pyspiral-0.1.0.dist-info/WHEEL +4 -0
- pyspiral-0.1.0.dist-info/entry_points.txt +2 -0
- spiral/__init__.py +11 -0
- spiral/_lib.abi3.so +0 -0
- spiral/adbc.py +386 -0
- spiral/api/__init__.py +221 -0
- spiral/api/admin.py +29 -0
- spiral/api/filesystems.py +125 -0
- spiral/api/organizations.py +90 -0
- spiral/api/projects.py +160 -0
- spiral/api/tables.py +94 -0
- spiral/api/tokens.py +56 -0
- spiral/api/workloads.py +45 -0
- spiral/arrow.py +209 -0
- spiral/authn/__init__.py +0 -0
- spiral/authn/authn.py +89 -0
- spiral/authn/device.py +206 -0
- spiral/authn/github_.py +33 -0
- spiral/authn/modal_.py +18 -0
- spiral/catalog.py +78 -0
- spiral/cli/__init__.py +82 -0
- spiral/cli/__main__.py +4 -0
- spiral/cli/admin.py +21 -0
- spiral/cli/app.py +48 -0
- spiral/cli/console.py +95 -0
- spiral/cli/fs.py +47 -0
- spiral/cli/login.py +13 -0
- spiral/cli/org.py +90 -0
- spiral/cli/printer.py +45 -0
- spiral/cli/project.py +107 -0
- spiral/cli/state.py +3 -0
- spiral/cli/table.py +20 -0
- spiral/cli/token.py +27 -0
- spiral/cli/types.py +53 -0
- spiral/cli/workload.py +59 -0
- spiral/config.py +26 -0
- spiral/core/__init__.py +0 -0
- spiral/core/core/__init__.pyi +53 -0
- spiral/core/manifests/__init__.pyi +53 -0
- spiral/core/metastore/__init__.pyi +91 -0
- spiral/core/spec/__init__.pyi +257 -0
- spiral/dataset.py +239 -0
- spiral/debug.py +251 -0
- spiral/expressions/__init__.py +222 -0
- spiral/expressions/base.py +149 -0
- spiral/expressions/http.py +86 -0
- spiral/expressions/io.py +100 -0
- spiral/expressions/list_.py +68 -0
- spiral/expressions/refs.py +44 -0
- spiral/expressions/str_.py +39 -0
- spiral/expressions/struct.py +57 -0
- spiral/expressions/tiff.py +223 -0
- spiral/expressions/udf.py +46 -0
- spiral/grpc_.py +32 -0
- spiral/project.py +137 -0
- spiral/proto/_/__init__.py +0 -0
- spiral/proto/_/arrow/__init__.py +0 -0
- spiral/proto/_/arrow/flight/__init__.py +0 -0
- spiral/proto/_/arrow/flight/protocol/__init__.py +0 -0
- spiral/proto/_/arrow/flight/protocol/sql/__init__.py +1990 -0
- spiral/proto/_/scandal/__init__.py +223 -0
- spiral/proto/_/spfs/__init__.py +36 -0
- spiral/proto/_/spiral/__init__.py +0 -0
- spiral/proto/_/spiral/table/__init__.py +225 -0
- spiral/proto/_/spiraldb/__init__.py +0 -0
- spiral/proto/_/spiraldb/metastore/__init__.py +499 -0
- spiral/proto/__init__.py +0 -0
- spiral/proto/scandal/__init__.py +45 -0
- spiral/proto/spiral/__init__.py +0 -0
- spiral/proto/spiral/table/__init__.py +96 -0
- spiral/proto/substrait/__init__.py +3399 -0
- spiral/proto/substrait/extensions/__init__.py +115 -0
- spiral/proto/util.py +41 -0
- spiral/py.typed +0 -0
- spiral/scan_.py +168 -0
- spiral/settings.py +157 -0
- spiral/substrait_.py +275 -0
- spiral/table.py +157 -0
- spiral/types_.py +6 -0
spiral/api/tables.py
ADDED
@@ -0,0 +1,94 @@
|
|
1
|
+
from typing import Annotated
|
2
|
+
|
3
|
+
from pydantic import (
|
4
|
+
AfterValidator,
|
5
|
+
BaseModel,
|
6
|
+
StringConstraints,
|
7
|
+
)
|
8
|
+
|
9
|
+
from . import ArrowSchema, Paged, PagedRequest, PagedResponse, ProjectId, ServiceBase
|
10
|
+
|
11
|
+
|
12
|
+
def _validate_root_uri(uri: str) -> str:
|
13
|
+
if uri.endswith("/"):
|
14
|
+
raise ValueError("Root URI must not end with a slash.")
|
15
|
+
return uri
|
16
|
+
|
17
|
+
|
18
|
+
# Storage root for a table; `_validate_root_uri` guarantees no trailing slash.
RootUri = Annotated[str, AfterValidator(_validate_root_uri)]
# Identifiers: first char is a letter or underscore, followed by at least one
# letter, digit, underscore or hyphen (so minimum length is 2); max 128 chars.
DatasetName = Annotated[str, StringConstraints(max_length=128, pattern=r"^[a-zA-Z_][a-zA-Z0-9_-]+$")]
TableName = Annotated[str, StringConstraints(max_length=128, pattern=r"^[a-zA-Z_][a-zA-Z0-9_-]+$")]
|
21
|
+
|
22
|
+
|
23
|
+
class TableMetadata(BaseModel):
    """Storage-level metadata attached to a table."""

    key_schema: ArrowSchema  # Arrow schema of the table's key columns.
    root_uri: RootUri  # Storage root; validated to have no trailing slash.
    spfs_mount_id: str | None = None  # Optional SPFS mount id — semantics defined elsewhere.

    # TODO(marko): Randomize this on creation of metadata.
    # Column group salt is used to compute column group IDs.
    # It's used to ensure that column group IDs are unique
    # across different tables, even if paths are the same.
    # It's never modified.
    column_group_salt: int = 0
|
34
|
+
|
35
|
+
|
36
|
+
class Table(BaseModel):
    """A table registered under a project's dataset."""

    id: str  # Table identifier; used as the `/table/{id}` path parameter.
    project_id: ProjectId
    dataset: DatasetName
    table: TableName
    metadata: TableMetadata
|
42
|
+
|
43
|
+
|
44
|
+
class CreateTable:
    """Request/response pair for creating a table."""

    class Request(BaseModel):
        project_id: ProjectId
        dataset: DatasetName
        table: TableName
        key_schema: ArrowSchema
        root_uri: RootUri | None = None  # Optional; presumably chosen server-side when omitted — confirm.
        exist_ok: bool = False  # NOTE(review): likely "don't fail if the table exists" — verify server behavior.

    class Response(BaseModel):
        table: Table
|
55
|
+
|
56
|
+
|
57
|
+
class FindTable:
    """Look up a table by project, and optionally dataset/table name."""

    class Request(BaseModel):
        project_id: ProjectId
        # Fix: these were declared `DatasetName = None` / `TableName = None`,
        # i.e. a `None` default on a non-optional annotation. Widened to
        # `| None` so the declared type matches the default; callers that
        # omitted or passed None keep working.
        dataset: DatasetName | None = None
        table: TableName | None = None

    class Response(BaseModel):
        table: Table | None  # None when no matching table exists.
|
65
|
+
|
66
|
+
|
67
|
+
class GetTable:
    """Fetch a single table by its identifier."""

    class Request(BaseModel):
        id: str  # The table id, as found in `Table.id`.

    class Response(BaseModel):
        table: Table
|
73
|
+
|
74
|
+
|
75
|
+
class ListTables:
    """Page through a project's tables, optionally filtered to one dataset."""

    class Request(PagedRequest):
        project_id: ProjectId
        dataset: DatasetName | None = None  # When None, tables from all datasets are listed.

    class Response(PagedResponse[Table]): ...
|
81
|
+
|
82
|
+
|
83
|
+
class TableService(ServiceBase):
    """Client for the `/table` endpoints."""

    def create(self, req: CreateTable.Request) -> CreateTable.Response:
        """Create a table in a project's dataset."""
        return self.client.post("/table/create", req, CreateTable.Response)

    def find(self, req: FindTable.Request) -> FindTable.Response:
        """Find a table by project/dataset/table name; response table may be None."""
        return self.client.put("/table/find", req, FindTable.Response)

    def get(self, req: GetTable.Request) -> GetTable.Response:
        """Fetch a table by id.

        Fix: the request was not forwarded — `GetTable.Response` was passed in
        the request-body position (`put(path, GetTable.Response)`), unlike every
        other call in this service which uses `(path, request, response_type)`.
        """
        return self.client.put(f"/table/{req.id}", req, GetTable.Response)

    def list(self, req: ListTables.Request) -> Paged[Table]:
        """Iterate tables page by page."""
        return self.client.paged("/table/list", req, ListTables.Response)
|
spiral/api/tokens.py
ADDED
@@ -0,0 +1,56 @@
|
|
1
|
+
from pydantic import BaseModel
|
2
|
+
|
3
|
+
from . import Paged, PagedRequest, PagedResponse, ServiceBase
|
4
|
+
|
5
|
+
|
6
|
+
class Token(BaseModel):
    """An API token record. The secret itself is not part of this model;
    it is only returned once, in `IssueToken.Response.token_secret`."""

    id: str
    project_id: str
    on_behalf_of: str  # Principal the token acts for.
|
10
|
+
|
11
|
+
|
12
|
+
class ExchangeToken:
    """Exchange a basic / identity token for a short-lived Spiral token.

    The request carries no fields; the caller's credentials travel with the
    HTTP request itself.
    """

    class Request(BaseModel): ...

    class Response(BaseModel):
        token: str  # The short-lived Spiral token.
|
17
|
+
|
18
|
+
|
19
|
+
class IssueToken:
    """Issue a new API token. The secret is only returned here, once."""

    class Request(BaseModel): ...

    class Response(BaseModel):
        token: Token
        token_secret: str  # Plaintext secret; not retrievable again later.
|
25
|
+
|
26
|
+
|
27
|
+
class RevokeToken:
    """Revoke an existing token by its id."""

    class Request(BaseModel):
        token_id: str

    class Response(BaseModel):
        token: Token  # The revoked token record.
|
33
|
+
|
34
|
+
|
35
|
+
class ListTokens:
    """Page through a project's tokens, optionally filtered by principal."""

    class Request(PagedRequest):
        project_id: str
        on_behalf_of: str | None = None  # When set, only tokens for this principal.

    class Response(PagedResponse[Token]): ...
|
41
|
+
|
42
|
+
|
43
|
+
class TokenService(ServiceBase):
    """Client for the `/token` endpoints."""

    def exchange(self) -> ExchangeToken.Response:
        """Exchange a basic / identity token to a short-lived Spiral token."""
        return self.client.post("/token/exchange", ExchangeToken.Request(), ExchangeToken.Response)

    def issue(self) -> IssueToken.Response:
        """Issue an API token on behalf of a principal."""
        return self.client.post("/token/issue", IssueToken.Request(), IssueToken.Response)

    def revoke(self, request: RevokeToken.Request) -> RevokeToken.Response:
        """Revoke an existing token by id."""
        return self.client.put("/token/revoke", request, RevokeToken.Response)

    def list(self, request: ListTokens.Request) -> Paged[Token]:
        """Iterate tokens page by page."""
        return self.client.paged("/token/list", request, ListTokens.Response)
|
spiral/api/workloads.py
ADDED
@@ -0,0 +1,45 @@
|
|
1
|
+
from pydantic import BaseModel, Field
|
2
|
+
|
3
|
+
from . import Paged, PagedRequest, PagedResponse, ProjectId, ServiceBase
|
4
|
+
|
5
|
+
|
6
|
+
class Workload(BaseModel):
    """A workload registered in a project."""

    id: str
    project_id: ProjectId
    name: str | None = None  # Optional human-readable name.
|
10
|
+
|
11
|
+
|
12
|
+
class CreateWorkload:
    """Create a new workload in a project."""

    class Request(BaseModel):
        # NOTE(review): declared `str` while `Workload.project_id` uses the
        # `ProjectId` alias — confirm whether this should match.
        project_id: str
        name: str | None = Field(default=None, description="Optional human-readable name for the workload")

    class Response(BaseModel):
        workload: Workload
|
19
|
+
|
20
|
+
|
21
|
+
class IssueToken:
    """Issue a token for a workload. The secret is only returned here, once."""

    class Request(BaseModel):
        workload_id: str

    class Response(BaseModel):
        token_id: str
        token_secret: str  # Plaintext secret; not retrievable again later.
|
28
|
+
|
29
|
+
|
30
|
+
class ListWorkloads:
    """Page through a project's workloads."""

    class Request(PagedRequest):
        project_id: str

    class Response(PagedResponse[Workload]): ...
|
35
|
+
|
36
|
+
|
37
|
+
class WorkloadService(ServiceBase):
    """Client for the `/workload` endpoints."""

    def create(self, request: CreateWorkload.Request) -> CreateWorkload.Response:
        """Register a new workload in a project."""
        return self.client.post("/workload/create", request, CreateWorkload.Response)

    def issue_token(self, request: IssueToken.Request) -> IssueToken.Response:
        """Issue credentials for a workload; the secret is returned only once."""
        return self.client.post("/workload/issue-token", request, IssueToken.Response)

    def list(self, request: ListWorkloads.Request) -> Paged[Workload]:
        """Iterate workloads page by page."""
        return self.client.paged("/workload/list", request, ListWorkloads.Response)
|
spiral/arrow.py
ADDED
@@ -0,0 +1,209 @@
|
|
1
|
+
from collections import defaultdict
|
2
|
+
from collections.abc import Callable, Iterable
|
3
|
+
from functools import reduce
|
4
|
+
from typing import TypeVar
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
import pyarrow as pa
|
8
|
+
from pyarrow import compute as pc
|
9
|
+
|
10
|
+
T = TypeVar("T")
|
11
|
+
|
12
|
+
|
13
|
+
def arange(*args, **kwargs) -> pa.Array:
    """Build an int32 Arrow array via ``np.arange`` with the same arguments."""
    values = np.arange(*args, **kwargs)
    return pa.array(values, type=pa.int32())
|
15
|
+
|
16
|
+
|
17
|
+
def zip_tables(tables: Iterable[pa.Table]) -> pa.Table:
    """Concatenate tables column-wise into one table (duplicate names are kept)."""
    tables = list(tables)
    columns = [column for t in tables for column in t.columns]
    names = [name for t in tables for name in t.column_names]
    return pa.Table.from_arrays(columns, names=names)
|
24
|
+
|
25
|
+
|
26
|
+
def merge_arrays(*arrays: pa.StructArray) -> pa.StructArray:
    """Recursively merge arrays into nested struct arrays.

    Non-struct conflicts are resolved last-wins; mixing struct and non-struct
    inputs raises ValueError. Field order in the result follows first
    appearance across the inputs.
    """
    if len(arrays) == 1:
        return arrays[0]

    nstructs = sum(pa.types.is_struct(a.type) for a in arrays)
    if nstructs == 0:
        # Then we have conflicting arrays and we choose the last.
        return arrays[-1]

    if nstructs != len(arrays):
        raise ValueError("Cannot merge structs with non-structs.")

    # Group child arrays by field name, preserving first-seen field order
    # (defaultdict insertion order), then merge each group recursively.
    data = defaultdict(list)
    for array in arrays:
        if isinstance(array, pa.ChunkedArray):
            # StructArray field access below requires a contiguous array.
            array = array.combine_chunks()
        for field in array.type:
            data[field.name].append(array.field(field.name))

    return pa.StructArray.from_arrays([merge_arrays(*v) for v in data.values()], names=list(data.keys()))
|
47
|
+
|
48
|
+
|
49
|
+
def merge_scalars(*scalars: pa.StructScalar) -> pa.StructScalar:
    """Recursively merge scalars into nested struct scalars.

    Mirrors `merge_arrays`: non-struct conflicts resolve last-wins, and mixing
    struct and non-struct inputs raises ValueError.
    """
    if len(scalars) == 1:
        return scalars[0]

    nstructs = sum(pa.types.is_struct(a.type) for a in scalars)
    if nstructs == 0:
        # Then we have conflicting scalars and we choose the last.
        return scalars[-1]

    if nstructs != len(scalars):
        raise ValueError("Cannot merge scalars with non-scalars.")

    # Group field values by name (first-seen order) and merge recursively.
    data = defaultdict(list)
    for scalar in scalars:
        for field in scalar.type:
            data[field.name].append(scalar[field.name])

    return pa.scalar({k: merge_scalars(*v) for k, v in data.items()})
|
68
|
+
|
69
|
+
|
70
|
+
def null_table(schema: pa.Schema, length: int = 0) -> pa.Table:
    """Build a table of all-null columns with the given schema and row count."""
    # We add an extra nulls column to ensure the length is correctly applied,
    # then drop it; this keeps the row count even for an empty schema.
    return pa.table(
        [pa.nulls(length, type=field.type) for field in schema] + [pa.nulls(length)],
        schema=pa.schema(list(schema) + [pa.field("__", type=pa.null())]),
    ).drop(["__"])
|
76
|
+
|
77
|
+
|
78
|
+
def coalesce_all(table: pa.Table) -> pa.Table:
    """Coalesce all columns that share the same name.

    Duplicated names collapse to a single column whose values are the first
    non-null across the duplicates (in column order).
    """
    grouped: dict[str, list[pa.Array]] = defaultdict(list)
    for index, name in enumerate(table.column_names):
        grouped[name].append(table[index])

    names = list(grouped.keys())
    data = [
        arrays[0] if len(arrays) == 1 else pc.coalesce(*arrays)
        for arrays in grouped.values()
    ]
    return pa.Table.from_arrays(data, names=names)
|
94
|
+
|
95
|
+
|
96
|
+
def join(left: pa.Table, right: pa.Table, keys: list[str], join_type: str) -> pa.Table:
    """Arrow's builtin join doesn't support struct columns. So we join ourselves and zip them in.

    Strategy: join only the key columns plus synthetic row-index columns, then
    use those indices to `take` the full (struct-bearing) rows from each side.
    The result is sorted ascending by the join keys.
    """
    # TODO(ngates): if join_type == inner, we may have better luck performing two index_in operations since this
    #  also preserves sort order.
    lhs = left.select(keys).add_column(0, "__lhs", arange(len(left)))
    rhs = right.select(keys).add_column(0, "__rhs", arange(len(right)))
    joined = lhs.join(rhs, keys=keys, join_type=join_type).sort_by([(k, "ascending") for k in keys])
    # Reassemble: keys from the join result, payload columns gathered by index.
    return zip_tables(
        [joined.select(keys), left.take(joined["__lhs"]).drop(keys), right.take(joined["__rhs"]).drop(keys)]
    )
|
106
|
+
|
107
|
+
|
108
|
+
def nest_structs(array: pa.StructArray | pa.StructScalar | dict) -> dict:
    """Turn a struct-like value with dot-separated column names into a nested dictionary."""
    # Normalize Arrow struct values to a flat name -> value mapping first.
    if isinstance(array, pa.StructArray | pa.StructScalar):
        array = {f.name: field(array, f.name) for f in array.type}

    nested: dict = {}
    for name in array.keys():
        *parents, leaf = name.split(".")
        target = nested
        for parent in parents:
            target = target.setdefault(parent, {})
        target[leaf] = array[name]
    return nested
|
129
|
+
|
130
|
+
|
131
|
+
def flatten_struct_table(table: pa.Table, separator=".") -> pa.Table:
    """Turn a nested struct table into a flat table with dot-separated names."""
    data = []
    names = []

    def _unfold(array: pa.Array, prefix: str):
        # Depth-first descent: struct children recurse with an extended prefix;
        # non-struct leaves are emitted in visit order.
        if pa.types.is_struct(array.type):
            if isinstance(array, pa.ChunkedArray):
                # Field access below requires a contiguous array.
                array = array.combine_chunks()
            for f in array.type:
                _unfold(field(array, f.name), f"{prefix}{separator}{f.name}")
        else:
            data.append(array)
            names.append(prefix)

    for col in table.column_names:
        _unfold(table[col], col)

    return pa.Table.from_arrays(data, names=names)
|
150
|
+
|
151
|
+
|
152
|
+
def dict_to_table(data) -> pa.Table:
    """Convert a (possibly nested) dictionary of arrays into an Arrow table."""
    struct_array = dict_to_struct_array(data)
    return pa.Table.from_struct_array(struct_array)
|
154
|
+
|
155
|
+
|
156
|
+
def dict_to_struct_array(data, propagate_nulls: bool = False) -> pa.StructArray:
    """Convert a nested dictionary of arrays to a (possibly nested) struct array.

    Args:
        data: A pa.Array / pa.ChunkedArray, or a (possibly nested) dict of them.
        propagate_nulls: When True, a struct row is masked null when all of its
            child values are null.
    """
    if isinstance(data, pa.ChunkedArray):
        return data.combine_chunks()
    if isinstance(data, pa.Array):
        return data
    # Fix: forward `propagate_nulls` into the recursion. It was previously
    # dropped, so the flag only took effect at the top level and nested dicts
    # silently ignored it.
    arrays = [dict_to_struct_array(value, propagate_nulls) for value in data.values()]
    return pa.StructArray.from_arrays(
        arrays,
        names=list(data.keys()),
        mask=reduce(pc.and_, [pc.is_null(array) for array in arrays]) if propagate_nulls else None,
    )
|
168
|
+
|
169
|
+
|
170
|
+
def struct_array_to_dict(array: pa.StructArray, array_fn: Callable[[pa.Array], T] = lambda a: a) -> dict | T:
    """Convert a struct array to a nested dictionary.

    Non-struct leaves are passed through ``array_fn`` (identity by default).
    """
    if not pa.types.is_struct(array.type):
        return array_fn(array)
    if isinstance(array, pa.ChunkedArray):
        # Field access below requires a contiguous array.
        array = array.combine_chunks()
    return {field.name: struct_array_to_dict(array.field(i), array_fn=array_fn) for i, field in enumerate(array.type)}
|
177
|
+
|
178
|
+
|
179
|
+
def table_to_struct_array(table: pa.Table) -> pa.StructArray:
    """Convert a table to a struct array; an empty table yields an empty struct array."""
    if table.num_rows == 0:
        return pa.array([], type=pa.struct(table.schema))
    result = table.to_struct_array()
    if isinstance(result, pa.ChunkedArray):
        result = result.combine_chunks()
    return result
|
186
|
+
|
187
|
+
|
188
|
+
def table_from_struct_array(array: pa.StructArray | pa.ChunkedArray):
    """Convert a struct array to a table; an empty array yields an all-null-column table."""
    if len(array):
        return pa.Table.from_struct_array(array)
    return null_table(pa.schema(array.type))
|
192
|
+
|
193
|
+
|
194
|
+
def field(value: pa.StructArray | pa.StructScalar, name: str) -> pa.Array | pa.Scalar:
    """Get a field from a struct-like value (scalar indexing vs array accessor)."""
    return value[name] if isinstance(value, pa.StructScalar) else value.field(name)
|
199
|
+
|
200
|
+
|
201
|
+
def concat_tables(tables: list[pa.Table]) -> pa.Table:
    """Concatenate pyarrow.Table objects.

    Missing fields are filled with nulls and columns are cast to the widest
    compatible type (pyarrow "permissive" promotion). A single-table input is
    returned as-is.
    """
    if len(tables) == 1:
        return tables[0]
    return pa.concat_tables(tables, promote_options="permissive")
|
spiral/authn/__init__.py
ADDED
File without changes
|
spiral/authn/authn.py
ADDED
@@ -0,0 +1,89 @@
|
|
1
|
+
import base64
|
2
|
+
import logging
|
3
|
+
import os
|
4
|
+
|
5
|
+
from spiral.api import Authn, SpiralAPI
|
6
|
+
|
7
|
+
ENV_TOKEN_ID = "SPIRAL_TOKEN_ID"
|
8
|
+
ENV_TOKEN_SECRET = "SPIRAL_TOKEN_SECRET"
|
9
|
+
|
10
|
+
log = logging.getLogger(__name__)
|
11
|
+
|
12
|
+
|
13
|
+
class FallbackAuthn(Authn):
    """Credential provider that tries multiple providers in order."""

    def __init__(self, providers: list[Authn]):
        self._providers = providers

    def token(self) -> str | None:
        # Lazily ask each provider in turn; first non-None token wins,
        # None when every provider declines.
        return next(
            (token for token in (p.token() for p in self._providers) if token is not None),
            None,
        )
|
25
|
+
|
26
|
+
|
27
|
+
class TokenAuthn(Authn):
    """Credential provider that returns a fixed token."""

    def __init__(self, token: str):
        self._token = token

    def token(self) -> str:
        # Returned as-is: no refresh, expiry, or validation.
        return self._token
|
35
|
+
|
36
|
+
|
37
|
+
class EnvironmentAuthn(Authn):
    """Credential provider that builds a basic token from environment variables.

    NOTE: Returns basic token. Must be exchanged.
    """

    def token(self) -> str | None:
        # Absent token id means this provider simply declines; a token id
        # without its secret is a configuration error.
        token_id = os.environ.get(ENV_TOKEN_ID)
        if token_id is None:
            return None
        if ENV_TOKEN_SECRET not in os.environ:
            raise ValueError(f"{ENV_TOKEN_SECRET} is missing.")
        token_secret = os.environ[ENV_TOKEN_SECRET]
        credentials = f"{token_id}:{token_secret}".encode()
        return base64.b64encode(credentials).decode("utf-8")
|
54
|
+
|
55
|
+
|
56
|
+
class DeviceAuthProvider(Authn):
    """Auth provider that uses the device flow to authenticate a Spiral user."""

    def __init__(self, device_auth):
        # NOTE(ngates): device_auth: spiral.auth.device_code.DeviceAuth
        #  We don't type it to satisfy our import linter
        self._device_auth = device_auth

    def token(self) -> str | None:
        # TODO(ngates): only run this if we're in a notebook, CLI, or otherwise on the user's machine.
        # Blocking: may open a browser and poll until the user completes login.
        return self._device_auth.authenticate().access_token
|
67
|
+
|
68
|
+
|
69
|
+
class TokenExchangeProvider(Authn):
    """Auth provider that exchanges a basic token for a Spiral token."""

    def __init__(self, authn: Authn, base_url: str):
        self._authn = authn
        self._token_service = SpiralAPI(authn, base_url).token

        # Cached exchanged token; the exchange runs at most once per instance.
        # NOTE(review): no expiry handling here — a stale token is served until
        # the provider is recreated; confirm callers expect that.
        self._sp_token = None

    def token(self) -> str | None:
        if self._sp_token is not None:
            return self._sp_token

        # Don't try to exchange if token is not discovered.
        if self._authn.token() is None:
            return None

        log.debug("Exchanging token")
        self._sp_token = self._token_service.exchange().token

        return self._sp_token
|
spiral/authn/device.py
ADDED
@@ -0,0 +1,206 @@
|
|
1
|
+
import logging
|
2
|
+
import sys
|
3
|
+
import textwrap
|
4
|
+
import time
|
5
|
+
import webbrowser
|
6
|
+
from pathlib import Path
|
7
|
+
|
8
|
+
import httpx
|
9
|
+
import jwt
|
10
|
+
from pydantic import BaseModel
|
11
|
+
|
12
|
+
log = logging.getLogger(__name__)
|
13
|
+
|
14
|
+
|
15
|
+
class TokensModel(BaseModel):
    """Access/refresh token pair persisted in the local auth file."""

    access_token: str
    refresh_token: str

    @property
    def organization_id(self) -> str | None:
        # The "org_id" claim of the (unverified) access token, if present.
        return self.unverified_access_token().get("org_id")

    def unverified_access_token(self):
        # Decode JWT claims WITHOUT signature verification — for local
        # inspection (expiry, org) only; never for authorization decisions.
        return jwt.decode(self.access_token, options={"verify_signature": False})
|
25
|
+
|
26
|
+
|
27
|
+
class AuthModel(BaseModel):
    """On-disk auth state; `tokens` is None when the user is logged out."""

    tokens: TokensModel | None = None
|
29
|
+
|
30
|
+
|
31
|
+
class DeviceAuth:
    """OAuth device-code flow client with on-disk token caching.

    Tokens are persisted to `auth_file` as an `AuthModel`; `authenticate()`
    reuses, refreshes, or re-acquires them as needed.
    """

    def __init__(
        self,
        auth_file: Path,
        domain: str,
        client_id: str,
        # Fix: annotation was `httpx.Client = None`; widened to match the default.
        http: httpx.Client | None = None,
    ):
        self._auth_file = auth_file
        self._domain = domain
        self._client_id = client_id
        self._http = http or httpx.Client()

        # Load previously persisted tokens, if any.
        if self._auth_file.exists():
            with self._auth_file.open("r") as f:
                self._auth = AuthModel.model_validate_json(f.read())
        else:
            self._auth = AuthModel()

        self._default_scope = ["email", "profile"]

    def is_authenticated(self) -> bool:
        """Check if the user is authenticated."""
        tokens = self._auth.tokens
        if tokens is None:
            return False

        # Give ourselves a 30-second buffer before the token expires.
        return tokens.unverified_access_token()["exp"] - 30 > time.time()

    def authenticate(self, force: bool = False, refresh: bool = False, organization_id: str | None = None) -> TokensModel:
        """Blocking call to authenticate the user.

        Triggers a device code flow and polls for the user to login.

        Args:
            force: Skip all caches and run the device-code flow.
            refresh: Require a refresh; raises ValueError if there are no
                tokens or the refresh fails.
            organization_id: Desired organization; a cached token for a
                different org is refreshed or re-acquired.
        """
        if force:
            return self._device_code(organization_id)

        if refresh:
            if self._auth.tokens is None:
                raise ValueError("No tokens to refresh.")
            tokens = self._refresh(self._auth.tokens, organization_id)
            if not tokens:
                raise ValueError("Failed to refresh token.")
            return tokens

        # Check for mis-matched organization.
        if organization_id is not None:
            tokens = self._auth.tokens
            if tokens is not None and tokens.unverified_access_token().get("org_id") != organization_id:
                tokens = self._refresh(self._auth.tokens, organization_id)
                if tokens is None:
                    return self._device_code(organization_id)
                # On success the refreshed tokens were saved; fall through to
                # the is_authenticated() check below.

        if self.is_authenticated():
            return self._auth.tokens

        # Try to refresh.
        tokens = self._auth.tokens
        if tokens is not None:
            tokens = self._refresh(tokens)
            if tokens is not None:
                return tokens

        # Otherwise, we kick off the device code flow.
        return self._device_code(organization_id)

    def logout(self):
        """Forget persisted tokens (local only; nothing is revoked server-side)."""
        self._remove_tokens()

    def _device_code(self, organization_id: str | None):
        """Run the full device-code flow: request a code, prompt, poll, persist."""
        scope = " ".join(self._default_scope)
        res = self._http.post(
            f"{self._domain}/auth/device/code",
            data={
                "client_id": self._client_id,
                "scope": scope,
                "organization_id": organization_id,
            },
        )
        res = res.raise_for_status().json()
        device_code = res["device_code"]
        user_code = res["user_code"]
        expires_at = res["expires_in"] + time.time()
        interval = res["interval"]
        verification_uri_complete = res["verification_uri_complete"]

        # We need to detect if the user is running in a terminal, in Jupyter, etc.
        # For now, we'll try to open the browser.
        sys.stderr.write(
            textwrap.dedent(
                f"""
                Please login here: {verification_uri_complete}
                Your code is {user_code}.
                """
            )
        )

        # Try to open the browser (this also works if the Jupiter notebook is running on the user's machine).
        opened = webbrowser.open(verification_uri_complete)

        # If we have a server-side Jupyter notebook, we can try to open with client-side JavaScript.
        if not opened and _in_notebook():
            from IPython.display import Javascript, display

            display(Javascript(f'window.open("{verification_uri_complete}");'))

        # In the meantime, we need to poll for the user to login.
        while True:
            if time.time() > expires_at:
                raise TimeoutError("Login timed out.")
            time.sleep(interval)
            res = self._http.post(
                f"{self._domain}/auth/token",
                data={
                    "client_id": self._client_id,
                    "device_code": device_code,
                    "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
                },
            )
            if not res.is_success:
                # Includes the "authorization_pending" case; keep polling.
                continue

            tokens = TokensModel(
                access_token=res.json()["access_token"],
                refresh_token=res.json()["refresh_token"],
            )
            self._save_tokens(tokens)
            return self._auth.tokens

    def _refresh(self, tokens: TokensModel, organization_id: str | None = None) -> TokensModel | None:
        """Attempt to use the refresh token; returns None on failure."""
        log.debug("Refreshing token %s", self._client_id)

        res = self._http.post(
            f"{self._domain}/auth/refresh",
            data={
                "client_id": self._client_id,
                "grant_type": "refresh_token",
                "refresh_token": tokens.refresh_token,
                "organization_id": organization_id,
            },
        )
        if not res.is_success:
            # Fix: was a bare print() left in library code; report through the
            # module logger instead.
            log.warning("Failed to refresh token: %s %s", res.status_code, res.text)
            return None

        tokens = TokensModel(
            access_token=res.json()["access_token"],
            refresh_token=res.json()["refresh_token"],
        )
        self._save_tokens(tokens)
        return tokens

    def _save_tokens(self, tokens: TokensModel):
        """Persist tokens to the auth file, creating parent directories."""
        self._auth = self._auth.model_copy(update={"tokens": tokens})
        self._auth_file.parent.mkdir(parents=True, exist_ok=True)
        with self._auth_file.open("w") as f:
            f.write(self._auth.model_dump_json(exclude_defaults=True))

    def _remove_tokens(self):
        """Delete the auth file and clear the in-memory tokens."""
        self._auth_file.unlink(missing_ok=True)
        self._auth = self._auth.model_copy(update={"tokens": None})
|
194
|
+
|
195
|
+
|
196
|
+
def _in_notebook():
|
197
|
+
try:
|
198
|
+
from IPython import get_ipython
|
199
|
+
|
200
|
+
if "IPKernelApp" not in get_ipython().config: # pragma: no cover
|
201
|
+
return False
|
202
|
+
except ImportError:
|
203
|
+
return False
|
204
|
+
except AttributeError:
|
205
|
+
return False
|
206
|
+
return True
|