pyspiral 0.6.6__cp312-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyspiral-0.6.6.dist-info/METADATA +51 -0
- pyspiral-0.6.6.dist-info/RECORD +102 -0
- pyspiral-0.6.6.dist-info/WHEEL +4 -0
- pyspiral-0.6.6.dist-info/entry_points.txt +2 -0
- spiral/__init__.py +35 -0
- spiral/_lib.abi3.so +0 -0
- spiral/adbc.py +411 -0
- spiral/api/__init__.py +78 -0
- spiral/api/admin.py +15 -0
- spiral/api/client.py +164 -0
- spiral/api/filesystems.py +134 -0
- spiral/api/key_space_indexes.py +23 -0
- spiral/api/organizations.py +77 -0
- spiral/api/projects.py +219 -0
- spiral/api/telemetry.py +19 -0
- spiral/api/text_indexes.py +56 -0
- spiral/api/types.py +22 -0
- spiral/api/workers.py +40 -0
- spiral/api/workloads.py +52 -0
- spiral/arrow_.py +216 -0
- spiral/cli/__init__.py +88 -0
- spiral/cli/__main__.py +4 -0
- spiral/cli/admin.py +14 -0
- spiral/cli/app.py +104 -0
- spiral/cli/console.py +95 -0
- spiral/cli/fs.py +76 -0
- spiral/cli/iceberg.py +97 -0
- spiral/cli/key_spaces.py +89 -0
- spiral/cli/login.py +24 -0
- spiral/cli/orgs.py +89 -0
- spiral/cli/printer.py +53 -0
- spiral/cli/projects.py +147 -0
- spiral/cli/state.py +5 -0
- spiral/cli/tables.py +174 -0
- spiral/cli/telemetry.py +17 -0
- spiral/cli/text.py +115 -0
- spiral/cli/types.py +50 -0
- spiral/cli/workloads.py +58 -0
- spiral/client.py +178 -0
- spiral/core/__init__.pyi +0 -0
- spiral/core/_tools/__init__.pyi +5 -0
- spiral/core/authn/__init__.pyi +27 -0
- spiral/core/client/__init__.pyi +237 -0
- spiral/core/table/__init__.pyi +101 -0
- spiral/core/table/manifests/__init__.pyi +35 -0
- spiral/core/table/metastore/__init__.pyi +58 -0
- spiral/core/table/spec/__init__.pyi +213 -0
- spiral/dataloader.py +285 -0
- spiral/dataset.py +255 -0
- spiral/datetime_.py +27 -0
- spiral/debug/__init__.py +0 -0
- spiral/debug/manifests.py +87 -0
- spiral/debug/metrics.py +56 -0
- spiral/debug/scan.py +266 -0
- spiral/expressions/__init__.py +276 -0
- spiral/expressions/base.py +157 -0
- spiral/expressions/http.py +86 -0
- spiral/expressions/io.py +100 -0
- spiral/expressions/list_.py +68 -0
- spiral/expressions/mp4.py +62 -0
- spiral/expressions/png.py +18 -0
- spiral/expressions/qoi.py +18 -0
- spiral/expressions/refs.py +58 -0
- spiral/expressions/str_.py +39 -0
- spiral/expressions/struct.py +59 -0
- spiral/expressions/text.py +62 -0
- spiral/expressions/tiff.py +223 -0
- spiral/expressions/udf.py +46 -0
- spiral/grpc_.py +32 -0
- spiral/iceberg.py +31 -0
- spiral/iterable_dataset.py +106 -0
- spiral/key_space_index.py +44 -0
- spiral/project.py +199 -0
- spiral/protogen/_/__init__.py +0 -0
- spiral/protogen/_/arrow/__init__.py +0 -0
- spiral/protogen/_/arrow/flight/__init__.py +0 -0
- spiral/protogen/_/arrow/flight/protocol/__init__.py +0 -0
- spiral/protogen/_/arrow/flight/protocol/sql/__init__.py +2548 -0
- spiral/protogen/_/google/__init__.py +0 -0
- spiral/protogen/_/google/protobuf/__init__.py +2310 -0
- spiral/protogen/_/message_pool.py +3 -0
- spiral/protogen/_/py.typed +0 -0
- spiral/protogen/_/scandal/__init__.py +190 -0
- spiral/protogen/_/spfs/__init__.py +72 -0
- spiral/protogen/_/spql/__init__.py +61 -0
- spiral/protogen/_/substrait/__init__.py +6196 -0
- spiral/protogen/_/substrait/extensions/__init__.py +169 -0
- spiral/protogen/__init__.py +0 -0
- spiral/protogen/util.py +41 -0
- spiral/py.typed +0 -0
- spiral/scan.py +285 -0
- spiral/server.py +17 -0
- spiral/settings.py +114 -0
- spiral/snapshot.py +56 -0
- spiral/streaming_/__init__.py +3 -0
- spiral/streaming_/reader.py +133 -0
- spiral/streaming_/stream.py +157 -0
- spiral/substrait_.py +274 -0
- spiral/table.py +293 -0
- spiral/text_index.py +17 -0
- spiral/transaction.py +58 -0
- spiral/types_.py +6 -0
spiral/api/projects.py
ADDED
@@ -0,0 +1,219 @@
|
|
1
|
+
from typing import Annotated, Literal
|
2
|
+
|
3
|
+
from pydantic import BaseModel, Field
|
4
|
+
|
5
|
+
from .client import Paged, PagedResponse, ServiceBase
|
6
|
+
from .types import OrgId, ProjectId, RoleId
|
7
|
+
|
8
|
+
|
9
|
+
class Project(BaseModel):
    """A Spiral project, belonging to an organization."""

    id: ProjectId
    org_id: OrgId
    name: str | None = None


class CreateProjectRequest(BaseModel):
    """Request body for creating a project."""

    # Optional prefix for the project id — presumably the server generates the
    # final id from this; confirm against the API implementation.
    id_prefix: str | None = None
    name: str | None = None


class CreateProjectResponse(BaseModel):
    """Response returned after creating a project."""

    project: Project


class Grant(BaseModel):
    """A role granted to a principal on a project."""

    id: str
    project_id: ProjectId
    role_id: RoleId
    principal: str
    # Free-form principal-matching conditions; shape depends on the principal type.
    conditions: dict | None = None
|
30
|
+
|
31
|
+
|
32
|
+
class OrgRolePrincipalConditions(BaseModel):
    """Principal conditions matching holders of a role within an organization."""

    type: Literal["org_role"] = "org_role"
    org_id: OrgId
    role: str


class OrgUserPrincipalConditions(BaseModel):
    """Principal conditions matching a single user within an organization."""

    type: Literal["org_user"] = "org_user"
    org_id: OrgId
    user_id: str


class WorkloadPrincipalConditions(BaseModel):
    """Principal conditions matching a workload identity."""

    type: Literal["workload"] = "workload"
    workload_id: str
|
47
|
+
|
48
|
+
|
49
|
+
class GitHubConditions(BaseModel):
    """Optional constraints on a GitHub workflow principal.

    Field names appear to mirror GitHub Actions OIDC token claims — confirm
    against the token issuer. Any field left as None is unconstrained.
    """

    environment: str | None = None
    ref: str | None = None
    ref_type: str | None = None
    sha: str | None = None
    repository: str | None = None
    repository_owner: str | None = None
    repository_visibility: str | None = None
    repository_id: str | None = None
    repository_owner_id: str | None = None
    run_id: str | None = None
    run_number: str | None = None
    run_attempt: str | None = None
    runner_environment: str | None = None
    actor_id: str | None = None
    actor: str | None = None
    workflow: str | None = None
    head_ref: str | None = None
    base_ref: str | None = None
    job_workflow_ref: str | None = None
    event_name: str | None = None


class GitHubPrincipalConditions(BaseModel):
    """Principal conditions matching workflows of a GitHub org/repo."""

    type: Literal["github"] = "github"
    org: str
    repo: str
    # Additional claim constraints; None means only org/repo are checked.
    conditions: GitHubConditions | None = None
|
77
|
+
|
78
|
+
|
79
|
+
class ModalConditions(BaseModel):
    """Optional constraints on a Modal principal; None fields are unconstrained."""

    app_id: str | None = None
    app_name: str | None = None
    function_id: str | None = None
    function_name: str | None = None
    container_id: str | None = None


class ModalPrincipalConditions(BaseModel):
    """Principal conditions matching workloads running in a Modal environment."""

    type: Literal["modal"] = "modal"

    # A Modal App is a group of functions and classes that are deployed together.
    # See https://modal.com/docs/guide/apps. Nick and Marko discussed having an app_name
    # here as well and decided to leave it out for now with the assumption that people
    # will want to authorize the whole Modal environment to access Spiral (their data).
    workspace_id: str
    # Environments are sub-divisions of workspaces. Name is unique within a workspace.
    # See https://modal.com/docs/guide/environments
    environment_name: str

    # Additional narrowing conditions; None means the whole environment matches.
    conditions: ModalConditions | None = None


class GcpServiceAccountPrincipalConditions(BaseModel):
    """Principal conditions matching a GCP service account."""

    type: Literal["gcp"] = "gcp"
    service_account: str
    unique_id: str


class AwsAssumedRolePrincipalConditions(BaseModel):
    """Principal conditions matching an assumed AWS IAM role."""

    type: Literal["aws"] = "aws"
    account_id: str
    role_name: str
|
112
|
+
|
113
|
+
|
114
|
+
# Discriminated union over every principal-conditions variant; pydantic selects
# the concrete model via the "type" literal field.
PrincipalConditions = Annotated[
    OrgRolePrincipalConditions
    | OrgUserPrincipalConditions
    | WorkloadPrincipalConditions
    | GitHubPrincipalConditions
    | ModalPrincipalConditions
    | GcpServiceAccountPrincipalConditions
    | AwsAssumedRolePrincipalConditions,
    Field(discriminator="type"),
]


class GrantRoleRequest(BaseModel):
    """Request body for granting a role to a principal."""

    role_id: RoleId
    principal: PrincipalConditions


class GrantRoleResponse(BaseModel):
    """Response returned after granting a role."""

    grant: Grant
|
133
|
+
|
134
|
+
|
135
|
+
class TableResource(BaseModel):
    """A table within a project, addressed by dataset and table name."""

    id: str
    project_id: ProjectId
    dataset: str
    table: str


class TextIndexResource(BaseModel):
    """A named text index within a project."""

    id: str
    project_id: ProjectId
    name: str


class KeySpaceIndexResource(BaseModel):
    """A named key space index within a project."""

    id: str
    project_id: ProjectId
    name: str
|
152
|
+
|
153
|
+
|
154
|
+
class ProjectService(ServiceBase):
    """Service for project operations."""

    def create(self, request: CreateProjectRequest) -> CreateProjectResponse:
        """Create a new project."""
        return self.client.post("/v1/projects", request, CreateProjectResponse)

    def list(self) -> Paged[Project]:
        """List projects."""
        return self.client.paged("/v1/projects", PagedResponse[Project])

    def list_tables(
        self, project_id: ProjectId, dataset: str | None = None, table: str | None = None
    ) -> Paged[TableResource]:
        """List tables in a project, optionally filtered by dataset and/or table name."""
        # Only forward filters the caller actually supplied; falsy values
        # (None or empty string) are dropped.
        params = {key: value for key, value in (("dataset", dataset), ("table", table)) if value}
        return self.client.paged(f"/v1/projects/{project_id}/tables", PagedResponse[TableResource], params=params)

    def list_text_indexes(self, project_id: ProjectId, name: str | None = None) -> Paged[TextIndexResource]:
        """List text indexes in a project, optionally filtered by name."""
        params = {"name": name} if name else {}
        return self.client.paged(
            f"/v1/projects/{project_id}/text-indexes", PagedResponse[TextIndexResource], params=params
        )

    def list_key_space_indexes(self, project_id: ProjectId, name: str | None = None) -> Paged[KeySpaceIndexResource]:
        """List key space indexes in a project, optionally filtered by name."""
        params = {"name": name} if name else {}
        return self.client.paged(
            f"/v1/projects/{project_id}/key-space-indexes", PagedResponse[KeySpaceIndexResource], params=params
        )

    def get(self, project_id: ProjectId) -> Project:
        """Get a project."""
        return self.client.get(f"/v1/projects/{project_id}", Project)

    def grant_role(self, project_id: ProjectId, request: GrantRoleRequest) -> GrantRoleResponse:
        """Grant a role to a principal."""
        return self.client.post(f"/v1/projects/{project_id}/grants", request, GrantRoleResponse)

    def list_grants(
        self,
        project_id: ProjectId,
        principal: str | None = None,
    ) -> Paged[Grant]:
        """List active project grants, optionally filtered by principal."""
        params = {"principal": principal} if principal else {}
        return self.client.paged(f"/v1/projects/{project_id}/grants", PagedResponse[Grant], params=params)

    def get_grant(self, grant_id: str) -> Grant:
        """Get a grant."""
        return self.client.get(f"/v1/grants/{grant_id}", Grant)

    def revoke_grant(self, grant_id: str):
        """Revoke a grant."""
        return self.client.delete(f"/v1/grants/{grant_id}", type[None])
|
spiral/api/telemetry.py
ADDED
@@ -0,0 +1,19 @@
|
|
1
|
+
from pydantic import BaseModel
|
2
|
+
|
3
|
+
from .client import ServiceBase
|
4
|
+
|
5
|
+
|
6
|
+
class IssueExportTokenRequest(BaseModel):
    """Empty request body for issuing a telemetry export token."""

    pass


class IssueExportTokenResponse(BaseModel):
    """Response carrying the issued telemetry export token."""

    token: str
|
12
|
+
|
13
|
+
|
14
|
+
class TelemetryService(ServiceBase):
    """Service for telemetry operations."""

    def issue_export_token(self) -> IssueExportTokenResponse:
        """Issue telemetry export token."""
        # The endpoint takes an empty request body.
        request = IssueExportTokenRequest()
        return self.client.put("/v1/telemetry/token", request, IssueExportTokenResponse)
|
@@ -0,0 +1,56 @@
|
|
1
|
+
from pydantic import BaseModel
|
2
|
+
|
3
|
+
from .client import Paged, PagedResponse, ServiceBase
|
4
|
+
from .types import IndexId, ProjectId, WorkerId
|
5
|
+
from .workers import CPU, GcpRegion, Memory, ResourceClass
|
6
|
+
|
7
|
+
|
8
|
+
class TextSearchWorker(BaseModel):
    """A search worker attached to a text index."""

    worker_id: WorkerId
    project_id: ProjectId
    index_id: IndexId
    # Worker URL; None presumably means no URL assigned yet — confirm with server.
    url: str | None


class CreateWorkerRequest(BaseModel):
    """Resource requirements for a new search worker."""

    cpu: CPU
    memory: Memory
    region: GcpRegion


class CreateWorkerResponse(BaseModel):
    """Response carrying the id of the created worker."""

    worker_id: WorkerId


class SyncIndexRequest(BaseModel):
    """Request to sync a text index."""

    resources: ResourceClass


class SyncIndexResponse(BaseModel):
    """Response carrying the id of the worker performing the sync."""

    worker_id: WorkerId
|
33
|
+
|
34
|
+
|
35
|
+
class TextIndexesService(ServiceBase):
    """Service for text index operations."""

    def create_worker(self, index_id: IndexId, request: CreateWorkerRequest) -> CreateWorkerResponse:
        """Create a new search worker."""
        path = f"/v1/text-indexes/{index_id}/workers"
        return self.client.post(path, request, CreateWorkerResponse)

    def list_workers(self, index_id: IndexId) -> Paged[WorkerId]:
        """List text index workers for the given index."""
        path = f"/v1/text-indexes/{index_id}/workers"
        return self.client.paged(path, PagedResponse[WorkerId])

    def get_worker(self, worker_id: WorkerId) -> TextSearchWorker:
        """Get a text index worker."""
        path = f"/v1/text-index-workers/{worker_id}"
        return self.client.get(path, TextSearchWorker)

    def shutdown_worker(self, worker_id: WorkerId) -> None:
        """Shutdown a text index worker."""
        path = f"/v1/text-index-workers/{worker_id}"
        return self.client.delete(path, type[None])

    def sync_index(self, index_id: IndexId, request: SyncIndexRequest) -> SyncIndexResponse:
        """Start a job to sync an index."""
        path = f"/v1/text-indexes/{index_id}/sync"
        return self.client.post(path, request, SyncIndexResponse)
|
spiral/api/types.py
ADDED
@@ -0,0 +1,22 @@
|
|
1
|
+
from typing import Annotated
|
2
|
+
|
3
|
+
from pydantic import AfterValidator, StringConstraints
|
4
|
+
|
5
|
+
|
6
|
+
def _validate_root_uri(uri: str) -> str:
|
7
|
+
if uri.endswith("/"):
|
8
|
+
raise ValueError("Root URI must not end with a slash.")
|
9
|
+
return uri
|
10
|
+
|
11
|
+
|
12
|
+
# Plain string aliases for the various identifier kinds used across the API.
UserId = str
OrgId = str
ProjectId = str
RoleId = str
IndexId = str
WorkerId = str

# Root URI for a filesystem location; trailing slashes are rejected by _validate_root_uri.
RootUri = Annotated[str, AfterValidator(_validate_root_uri)]
# Identifier names: leading letter or underscore, then letters/digits/underscore/hyphen,
# at most 128 characters.
# FIX: DatasetName previously used `+` in its pattern (requiring >= 2 characters) while
# TableName and IndexName used `*` (>= 1). Aligned to `*` so single-character dataset
# names validate consistently with the other name types (backward-compatible: strictly
# more permissive).
DatasetName = Annotated[str, StringConstraints(max_length=128, pattern=r"^[a-zA-Z_][a-zA-Z0-9_-]*$")]
TableName = Annotated[str, StringConstraints(max_length=128, pattern=r"^[a-zA-Z_][a-zA-Z0-9_-]*$")]
IndexName = Annotated[str, StringConstraints(max_length=128, pattern=r"^[a-zA-Z_][a-zA-Z0-9_-]*$")]
|
spiral/api/workers.py
ADDED
@@ -0,0 +1,40 @@
|
|
1
|
+
from enum import Enum, IntEnum
|
2
|
+
|
3
|
+
|
4
|
+
class CPU(IntEnum):
    """Worker CPU count."""

    ONE = 1
    TWO = 2
    FOUR = 4
    EIGHT = 8

    def __str__(self) -> str:
        # Render the numeric value, not the enum member name.
        return f"{self.value}"


class Memory(str, Enum):
    """Worker memory size, expressed as a quantity string (e.g. "1Gi")."""

    MB_512 = "512Mi"
    GB_1 = "1Gi"
    GB_2 = "2Gi"
    GB_4 = "4Gi"
    GB_8 = "8Gi"

    def __str__(self) -> str:
        # Render the quantity string value, not the enum member name.
        return f"{self.value}"


class GcpRegion(str, Enum):
    """GCP region a worker runs in."""

    US_EAST4 = "us-east4"
    EUROPE_WEST4 = "europe-west4"

    def __str__(self) -> str:
        # Render the region string value, not the enum member name.
        return f"{self.value}"


class ResourceClass(str, Enum):
    """Resource class for text index sync."""

    SMALL = "small"
    LARGE = "large"

    def __str__(self) -> str:
        # Render the class name value, not the enum member name.
        return f"{self.value}"
|
spiral/api/workloads.py
ADDED
@@ -0,0 +1,52 @@
|
|
1
|
+
from pydantic import BaseModel
|
2
|
+
|
3
|
+
from .client import Paged, PagedResponse, ServiceBase
|
4
|
+
from .types import ProjectId
|
5
|
+
|
6
|
+
|
7
|
+
class Workload(BaseModel):
    """A workload identity within a project."""

    id: str
    project_id: ProjectId
    name: str | None = None


class CreateWorkloadRequest(BaseModel):
    """Request body for creating a workload."""

    name: str | None = None


class CreateWorkloadResponse(BaseModel):
    """Response returned after creating a workload."""

    workload: Workload


class IssueWorkloadCredentialsResponse(BaseModel):
    """Response carrying newly issued workload client credentials."""

    client_id: str
    client_secret: str
    # Presumably the client id of credentials revoked as part of this issuance,
    # if any — confirm against the server implementation.
    revoked_client_id: str | None = None
|
25
|
+
|
26
|
+
|
27
|
+
class WorkloadService(ServiceBase):
    """Service for workload operations."""

    def create(self, project_id: ProjectId, request: CreateWorkloadRequest) -> CreateWorkloadResponse:
        """Create a new workload."""
        return self.client.post(f"/v1/projects/{project_id}/workloads", request, CreateWorkloadResponse)

    def list(self, project_id: ProjectId) -> Paged[Workload]:
        """List active project workloads."""
        # FIX: path previously lacked the "/v1" prefix used by every other
        # endpoint in this package (including create() just above).
        return self.client.paged(f"/v1/projects/{project_id}/workloads", PagedResponse[Workload])

    def get(self, workload_id: str) -> Workload:
        """Get a workload."""
        return self.client.get(f"/v1/workloads/{workload_id}", Workload)

    def deactivate(self, workload_id: str) -> None:
        """De-activate a workload."""
        return self.client.delete(f"/v1/workloads/{workload_id}", None)

    def issue_credentials(self, workload_id: str) -> IssueWorkloadCredentialsResponse:
        """Issue workload credentials."""
        return self.client.post(f"/v1/workloads/{workload_id}/credentials", None, IssueWorkloadCredentialsResponse)

    def revoke_credentials(self, client_id: str) -> None:
        """Revoke workload credentials."""
        return self.client.delete(f"/v1/credentials/{client_id}", None)
|
spiral/arrow_.py
ADDED
@@ -0,0 +1,216 @@
|
|
1
|
+
from collections import defaultdict
|
2
|
+
from collections.abc import Callable, Iterable
|
3
|
+
from functools import reduce
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
import pyarrow as pa
|
7
|
+
from pyarrow import compute as pc
|
8
|
+
|
9
|
+
|
10
|
+
def arange(*args, **kwargs) -> pa.Array:
    """Like numpy.arange, but returned as a pyarrow int32 array."""
    values = np.arange(*args, **kwargs)
    return pa.array(values, type=pa.int32())
|
12
|
+
|
13
|
+
|
14
|
+
def zip_tables(tables: Iterable[pa.Table]) -> pa.Table:
    """Combine tables side by side into one table, preserving column order."""
    columns = []
    column_names = []
    for tbl in tables:
        columns += tbl.columns
        column_names += tbl.column_names
    return pa.Table.from_arrays(columns, names=column_names)
|
21
|
+
|
22
|
+
|
23
|
+
def merge_arrays(*arrays: pa.StructArray) -> pa.StructArray:
    """Recursively merge arrays into nested struct arrays.

    Non-struct inputs are treated as conflicting and the last one wins.
    Mixing struct and non-struct inputs raises ValueError.
    """
    if len(arrays) == 1:
        return arrays[0]

    nstructs = sum(pa.types.is_struct(a.type) for a in arrays)
    if nstructs == 0:
        # Then we have conflicting arrays and we choose the last.
        return arrays[-1]

    if nstructs != len(arrays):
        raise ValueError("Cannot merge structs with non-structs.")

    # Group child arrays by field name across all inputs, then recursively
    # merge each group into a single child.
    data = defaultdict(list)
    for array in arrays:
        if isinstance(array, pa.ChunkedArray):
            # .field() below needs a contiguous StructArray.
            array = array.combine_chunks()
        for field in array.type:
            data[field.name].append(array.field(field.name))

    return pa.StructArray.from_arrays([merge_arrays(*v) for v in data.values()], names=list(data.keys()))
|
44
|
+
|
45
|
+
|
46
|
+
def merge_scalars(*scalars: pa.StructScalar) -> pa.StructScalar:
    """Recursively merge scalars into nested struct scalars.

    Mirrors merge_arrays: non-struct inputs conflict and the last wins;
    mixing struct and non-struct inputs raises ValueError.
    """
    if len(scalars) == 1:
        return scalars[0]

    nstructs = sum(pa.types.is_struct(a.type) for a in scalars)
    if nstructs == 0:
        # Then we have conflicting scalars and we choose the last.
        return scalars[-1]

    if nstructs != len(scalars):
        raise ValueError("Cannot merge scalars with non-scalars.")

    # Group child scalars by field name across all inputs, then recursively
    # merge each group.
    data = defaultdict(list)
    for scalar in scalars:
        for field in scalar.type:
            data[field.name].append(scalar[field.name])

    return pa.scalar({k: merge_scalars(*v) for k, v in data.items()})
|
65
|
+
|
66
|
+
|
67
|
+
def null_table(schema: pa.Schema, length: int = 0) -> pa.Table:
    """Build an all-null table with the given schema and row count.

    A throwaway null "__" column is appended and then dropped so that the
    requested length is applied even for schemas with no columns.
    """
    # We add an extra nulls column to ensure the length is correctly applied.
    return pa.table(
        [pa.nulls(length, type=field.type) for field in schema] + [pa.nulls(length)],
        schema=pa.schema(list(schema) + [pa.field("__", type=pa.null())]),
    ).drop(["__"])
|
73
|
+
|
74
|
+
|
75
|
+
def coalesce_all(table: pa.Table) -> pa.Table:
    """Coalesce all columns that share the same name."""
    # Group column chunks by name, preserving first-seen order.
    grouped: dict[str, list] = defaultdict(list)
    for idx, name in enumerate(table.column_names):
        grouped[name].append(table[idx])

    # A uniquely-named column passes through untouched; duplicates are
    # coalesced left-to-right.
    data = [cols[0] if len(cols) == 1 else pc.coalesce(*cols) for cols in grouped.values()]
    return pa.Table.from_arrays(data, names=list(grouped.keys()))
|
91
|
+
|
92
|
+
|
93
|
+
def join(left: pa.Table, right: pa.Table, keys: list[str], join_type: str) -> pa.Table:
    """Arrow's builtin join doesn't support struct columns. So we join ourselves and zip them in.

    Only the key columns plus positional index columns ("__lhs"/"__rhs") are
    joined; the indices are then used to take full payload rows from each side.
    The result is sorted ascending by the key columns.
    """
    # TODO(ngates): if join_type == inner, we may have better luck performing two index_in operations since this
    # also preserves sort order.
    lhs = left.select(keys).add_column(0, "__lhs", arange(len(left)))
    rhs = right.select(keys).add_column(0, "__rhs", arange(len(right)))
    joined = lhs.join(rhs, keys=keys, join_type=join_type).sort_by([(k, "ascending") for k in keys])
    return zip_tables(
        [joined.select(keys), left.take(joined["__lhs"]).drop(keys), right.take(joined["__rhs"]).drop(keys)]
    )
|
103
|
+
|
104
|
+
|
105
|
+
def nest_structs(array: pa.StructArray | pa.StructScalar | dict) -> dict:
    """Turn a struct-like value with dot-separated column names into a nested dictionary."""
    data = {}

    if isinstance(array, pa.StructArray | pa.StructScalar):
        # Normalize struct inputs to a flat {name: child} dict first.
        array = {f.name: field(array, f.name) for f in array.type}

    for name in array.keys():
        if "." not in name:
            data[name] = array[name]
            continue

        # Walk (creating as needed) one nested dict level per dotted component,
        # placing the value under the final component.
        parts = name.split(".")
        child_data = data
        for part in parts[:-1]:
            if part not in child_data:
                child_data[part] = {}
            child_data = child_data[part]
        child_data[parts[-1]] = array[name]

    return data
|
126
|
+
|
127
|
+
|
128
|
+
def flatten_struct_table(table: pa.Table, separator=".") -> pa.Table:
    """Turn a nested struct table into a flat table with dot-separated names."""
    data = []
    names = []

    def _unfold(array: pa.Array, prefix: str):
        # Depth-first walk; each leaf (non-struct) array is emitted with its
        # accumulated separator-joined name.
        if pa.types.is_struct(array.type):
            if isinstance(array, pa.ChunkedArray):
                # field() below needs a contiguous StructArray.
                array = array.combine_chunks()
            for f in array.type:
                _unfold(field(array, f.name), f"{prefix}{separator}{f.name}")
        else:
            data.append(array)
            names.append(prefix)

    for col in table.column_names:
        _unfold(table[col], col)

    return pa.Table.from_arrays(data, names=names)
|
147
|
+
|
148
|
+
|
149
|
+
def struct_array(fields: list[tuple[str, bool, pa.Array]], /, mask: list[bool] | None = None) -> pa.StructArray:
    """Build a StructArray from (name, nullable, array) triples, with an optional validity mask."""
    child_arrays = []
    child_fields = []
    for name, nullable, arr in fields:
        child_arrays.append(arr)
        child_fields.append(pa.field(name, type=arr.type, nullable=nullable))
    return pa.StructArray.from_arrays(
        arrays=child_arrays,
        fields=child_fields,
        mask=pa.array(mask) if mask else mask,
    )
|
155
|
+
|
156
|
+
|
157
|
+
def table(fields: list[tuple[str, bool, pa.Array]], /) -> pa.Table:
    """Build a table from (name, nullable, array) triples."""
    sa = struct_array(fields)
    return pa.Table.from_struct_array(sa)


def dict_to_table(data) -> pa.Table:
    """Build a table from a (possibly nested) dictionary of arrays."""
    sa = dict_to_struct_array(data)
    return pa.Table.from_struct_array(sa)
|
163
|
+
|
164
|
+
|
165
|
+
def dict_to_struct_array(data: dict | pa.StructArray, propagate_nulls: bool = False) -> pa.StructArray:
    """Convert a nested dictionary of arrays to a table with nested structs."""
    if isinstance(data, pa.Array):
        # Leaf: already an arrow array; return as-is.
        return data
    # NOTE(review): propagate_nulls is not forwarded to the recursive call, so it
    # only takes effect at the top level of the dictionary — confirm whether
    # nested levels should propagate too.
    arrays = [dict_to_struct_array(value) for value in data.values()]
    return pa.StructArray.from_arrays(
        arrays,
        names=list(data.keys()),
        # A struct row becomes null only when every child is null at that row.
        mask=reduce(pc.and_, [pc.is_null(array) for array in arrays]) if propagate_nulls else None,
    )
|
175
|
+
|
176
|
+
|
177
|
+
def struct_array_to_dict[T](array: pa.StructArray, array_fn: Callable[[pa.Array], T] = lambda a: a) -> dict | T:
    """Convert a struct array to a nested dictionary.

    Leaf (non-struct) arrays are passed through array_fn (identity by default).
    """
    if not pa.types.is_struct(array.type):
        return array_fn(array)
    if isinstance(array, pa.ChunkedArray):
        # .field() below needs a contiguous StructArray.
        array = array.combine_chunks()
    return {field.name: struct_array_to_dict(array.field(i), array_fn=array_fn) for i, field in enumerate(array.type)}
|
184
|
+
|
185
|
+
|
186
|
+
def table_to_struct_array(table: pa.Table) -> pa.StructArray:
    """Convert a table to a single contiguous struct array."""
    if not table.num_rows:
        # Empty tables are built directly from the schema.
        return pa.array([], type=pa.struct(table.schema))
    result = table.to_struct_array()
    # Flatten multi-chunk results into one contiguous array.
    return result.combine_chunks() if isinstance(result, pa.ChunkedArray) else result
|
193
|
+
|
194
|
+
|
195
|
+
def table_from_struct_array(array: pa.StructArray | pa.ChunkedArray):
    """Convert a struct array to a table, handling the zero-length case."""
    if not len(array):
        # Zero-length input goes through null_table so the schema is preserved.
        return null_table(pa.schema(array.type))
    return pa.Table.from_struct_array(array)
|
199
|
+
|
200
|
+
|
201
|
+
def field(value: pa.StructArray | pa.StructScalar, name: str) -> pa.Array | pa.Scalar:
    """Get a field from a struct-like value."""
    # Scalars index by name; arrays expose .field().
    return value[name] if isinstance(value, pa.StructScalar) else value.field(name)
|
206
|
+
|
207
|
+
|
208
|
+
def concat_tables(tables: list[pa.Table]) -> pa.Table:
    """
    Concatenate pyarrow.Table objects, filling "missing" data with appropriate null arrays
    and casting arrays to the most common denominator type that fits all fields.
    """
    if len(tables) != 1:
        return pa.concat_tables(tables, promote_options="permissive")
    # Single input: nothing to concatenate or promote.
    return tables[0]
|