pyspiral 0.6.8__cp312-abi3-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. pyspiral-0.6.8.dist-info/METADATA +51 -0
  2. pyspiral-0.6.8.dist-info/RECORD +102 -0
  3. pyspiral-0.6.8.dist-info/WHEEL +4 -0
  4. pyspiral-0.6.8.dist-info/entry_points.txt +2 -0
  5. spiral/__init__.py +35 -0
  6. spiral/_lib.abi3.so +0 -0
  7. spiral/adbc.py +411 -0
  8. spiral/api/__init__.py +78 -0
  9. spiral/api/admin.py +15 -0
  10. spiral/api/client.py +164 -0
  11. spiral/api/filesystems.py +134 -0
  12. spiral/api/key_space_indexes.py +23 -0
  13. spiral/api/organizations.py +77 -0
  14. spiral/api/projects.py +219 -0
  15. spiral/api/telemetry.py +19 -0
  16. spiral/api/text_indexes.py +56 -0
  17. spiral/api/types.py +22 -0
  18. spiral/api/workers.py +40 -0
  19. spiral/api/workloads.py +52 -0
  20. spiral/arrow_.py +216 -0
  21. spiral/cli/__init__.py +88 -0
  22. spiral/cli/__main__.py +4 -0
  23. spiral/cli/admin.py +14 -0
  24. spiral/cli/app.py +104 -0
  25. spiral/cli/console.py +95 -0
  26. spiral/cli/fs.py +76 -0
  27. spiral/cli/iceberg.py +97 -0
  28. spiral/cli/key_spaces.py +89 -0
  29. spiral/cli/login.py +24 -0
  30. spiral/cli/orgs.py +89 -0
  31. spiral/cli/printer.py +53 -0
  32. spiral/cli/projects.py +147 -0
  33. spiral/cli/state.py +5 -0
  34. spiral/cli/tables.py +174 -0
  35. spiral/cli/telemetry.py +17 -0
  36. spiral/cli/text.py +115 -0
  37. spiral/cli/types.py +50 -0
  38. spiral/cli/workloads.py +58 -0
  39. spiral/client.py +178 -0
  40. spiral/core/__init__.pyi +0 -0
  41. spiral/core/_tools/__init__.pyi +5 -0
  42. spiral/core/authn/__init__.pyi +27 -0
  43. spiral/core/client/__init__.pyi +237 -0
  44. spiral/core/table/__init__.pyi +101 -0
  45. spiral/core/table/manifests/__init__.pyi +35 -0
  46. spiral/core/table/metastore/__init__.pyi +58 -0
  47. spiral/core/table/spec/__init__.pyi +213 -0
  48. spiral/dataloader.py +285 -0
  49. spiral/dataset.py +255 -0
  50. spiral/datetime_.py +27 -0
  51. spiral/debug/__init__.py +0 -0
  52. spiral/debug/manifests.py +87 -0
  53. spiral/debug/metrics.py +56 -0
  54. spiral/debug/scan.py +266 -0
  55. spiral/expressions/__init__.py +276 -0
  56. spiral/expressions/base.py +157 -0
  57. spiral/expressions/http.py +86 -0
  58. spiral/expressions/io.py +100 -0
  59. spiral/expressions/list_.py +68 -0
  60. spiral/expressions/mp4.py +62 -0
  61. spiral/expressions/png.py +18 -0
  62. spiral/expressions/qoi.py +18 -0
  63. spiral/expressions/refs.py +58 -0
  64. spiral/expressions/str_.py +39 -0
  65. spiral/expressions/struct.py +59 -0
  66. spiral/expressions/text.py +62 -0
  67. spiral/expressions/tiff.py +223 -0
  68. spiral/expressions/udf.py +46 -0
  69. spiral/grpc_.py +32 -0
  70. spiral/iceberg.py +31 -0
  71. spiral/iterable_dataset.py +106 -0
  72. spiral/key_space_index.py +44 -0
  73. spiral/project.py +199 -0
  74. spiral/protogen/_/__init__.py +0 -0
  75. spiral/protogen/_/arrow/__init__.py +0 -0
  76. spiral/protogen/_/arrow/flight/__init__.py +0 -0
  77. spiral/protogen/_/arrow/flight/protocol/__init__.py +0 -0
  78. spiral/protogen/_/arrow/flight/protocol/sql/__init__.py +2548 -0
  79. spiral/protogen/_/google/__init__.py +0 -0
  80. spiral/protogen/_/google/protobuf/__init__.py +2310 -0
  81. spiral/protogen/_/message_pool.py +3 -0
  82. spiral/protogen/_/py.typed +0 -0
  83. spiral/protogen/_/scandal/__init__.py +190 -0
  84. spiral/protogen/_/spfs/__init__.py +72 -0
  85. spiral/protogen/_/spql/__init__.py +61 -0
  86. spiral/protogen/_/substrait/__init__.py +6196 -0
  87. spiral/protogen/_/substrait/extensions/__init__.py +169 -0
  88. spiral/protogen/__init__.py +0 -0
  89. spiral/protogen/util.py +41 -0
  90. spiral/py.typed +0 -0
  91. spiral/scan.py +285 -0
  92. spiral/server.py +17 -0
  93. spiral/settings.py +114 -0
  94. spiral/snapshot.py +56 -0
  95. spiral/streaming_/__init__.py +3 -0
  96. spiral/streaming_/reader.py +133 -0
  97. spiral/streaming_/stream.py +157 -0
  98. spiral/substrait_.py +274 -0
  99. spiral/table.py +293 -0
  100. spiral/text_index.py +17 -0
  101. spiral/transaction.py +58 -0
  102. spiral/types_.py +6 -0
spiral/api/projects.py ADDED
@@ -0,0 +1,219 @@
1
+ from typing import Annotated, Literal
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+ from .client import Paged, PagedResponse, ServiceBase
6
+ from .types import OrgId, ProjectId, RoleId
7
+
8
+
9
class Project(BaseModel):
    """A Spiral project, owned by a single organization."""

    id: ProjectId
    org_id: OrgId
    name: str | None = None


class CreateProjectRequest(BaseModel):
    """Request body for creating a project."""

    # Optional prefix for the generated project id — presumably the server
    # derives the final id from it; TODO confirm against the API.
    id_prefix: str | None = None
    name: str | None = None


class CreateProjectResponse(BaseModel):
    """Response carrying the newly created project."""

    project: Project


class Grant(BaseModel):
    """A role granted to a principal on a project."""

    id: str
    project_id: ProjectId
    role_id: RoleId
    # Opaque principal identifier as returned by the server; the structured
    # form used when creating a grant is one of the *PrincipalConditions models.
    principal: str
    conditions: dict | None = None
30
+
31
+
32
class OrgRolePrincipalConditions(BaseModel):
    """Principal: every member of an organization holding a given org role."""

    type: Literal["org_role"] = "org_role"
    org_id: OrgId
    role: str


class OrgUserPrincipalConditions(BaseModel):
    """Principal: a single user within an organization."""

    type: Literal["org_user"] = "org_user"
    org_id: OrgId
    user_id: str


class WorkloadPrincipalConditions(BaseModel):
    """Principal: a registered workload (machine identity)."""

    type: Literal["workload"] = "workload"
    workload_id: str


class GitHubConditions(BaseModel):
    """Optional conditions for GitHub-based grants.

    Field names mirror GitHub Actions OIDC token claims — TODO confirm the
    server matches them against the token verbatim.
    """

    environment: str | None = None
    ref: str | None = None
    ref_type: str | None = None
    sha: str | None = None
    repository: str | None = None
    repository_owner: str | None = None
    repository_visibility: str | None = None
    repository_id: str | None = None
    repository_owner_id: str | None = None
    run_id: str | None = None
    run_number: str | None = None
    run_attempt: str | None = None
    runner_environment: str | None = None
    actor_id: str | None = None
    actor: str | None = None
    workflow: str | None = None
    head_ref: str | None = None
    base_ref: str | None = None
    job_workflow_ref: str | None = None
    event_name: str | None = None


class GitHubPrincipalConditions(BaseModel):
    """Principal: a GitHub repository, optionally narrowed by claim conditions."""

    type: Literal["github"] = "github"
    org: str
    repo: str
    conditions: GitHubConditions | None = None


class ModalConditions(BaseModel):
    """Optional conditions narrowing a Modal principal to specific apps/functions/containers."""

    app_id: str | None = None
    app_name: str | None = None
    function_id: str | None = None
    function_name: str | None = None
    container_id: str | None = None


class ModalPrincipalConditions(BaseModel):
    """Principal: code running inside a Modal environment."""

    type: Literal["modal"] = "modal"

    # A Modal App is a group of functions and classes that are deployed together
    # (https://modal.com/docs/guide/apps). An app_name field was deliberately
    # left out for now: the assumption is that users want to authorize a whole
    # Modal environment to access Spiral (their data), not individual apps.
    workspace_id: str
    # Environments are sub-divisions of workspaces; the name is unique within
    # a workspace. See https://modal.com/docs/guide/environments
    environment_name: str

    conditions: ModalConditions | None = None


class GcpServiceAccountPrincipalConditions(BaseModel):
    """Principal: a GCP service account."""

    type: Literal["gcp"] = "gcp"
    service_account: str
    unique_id: str


class AwsAssumedRolePrincipalConditions(BaseModel):
    """Principal: an assumed AWS IAM role in a given account."""

    type: Literal["aws"] = "aws"
    account_id: str
    role_name: str
112
+
113
+
114
# Discriminated union of all principal kinds; pydantic dispatches on the
# literal "type" field of each member model.
PrincipalConditions = Annotated[
    OrgRolePrincipalConditions
    | OrgUserPrincipalConditions
    | WorkloadPrincipalConditions
    | GitHubPrincipalConditions
    | ModalPrincipalConditions
    | GcpServiceAccountPrincipalConditions
    | AwsAssumedRolePrincipalConditions,
    Field(discriminator="type"),
]
124
+
125
+
126
class GrantRoleRequest(BaseModel):
    """Request to grant a role to a principal on a project."""

    role_id: RoleId
    principal: PrincipalConditions


class GrantRoleResponse(BaseModel):
    """Response carrying the created grant."""

    grant: Grant


class TableResource(BaseModel):
    """A table belonging to a project, addressed by (dataset, table)."""

    id: str
    project_id: ProjectId
    dataset: str
    table: str


class TextIndexResource(BaseModel):
    """A text index belonging to a project."""

    id: str
    project_id: ProjectId
    name: str


class KeySpaceIndexResource(BaseModel):
    """A key space index belonging to a project."""

    id: str
    project_id: ProjectId
    name: str
152
+
153
+
154
class ProjectService(ServiceBase):
    """Service for project operations (CRUD, resource listing, and grants)."""

    def create(self, request: CreateProjectRequest) -> CreateProjectResponse:
        """Create a new project."""
        return self.client.post("/v1/projects", request, CreateProjectResponse)

    def list(self) -> Paged[Project]:
        """List projects visible to the caller."""
        return self.client.paged("/v1/projects", PagedResponse[Project])

    def list_tables(
        self, project_id: ProjectId, dataset: str | None = None, table: str | None = None
    ) -> Paged[TableResource]:
        """List tables in a project, optionally filtered by dataset and/or table name."""
        # Only truthy filters are forwarded as query parameters; an empty
        # string is treated the same as None.
        params = {}
        if dataset:
            params["dataset"] = dataset
        if table:
            params["table"] = table
        return self.client.paged(f"/v1/projects/{project_id}/tables", PagedResponse[TableResource], params=params)

    def list_text_indexes(self, project_id: ProjectId, name: str | None = None) -> Paged[TextIndexResource]:
        """List text indexes in a project, optionally filtered by name."""
        params = {}
        if name:
            params["name"] = name
        return self.client.paged(
            f"/v1/projects/{project_id}/text-indexes", PagedResponse[TextIndexResource], params=params
        )

    def list_key_space_indexes(self, project_id: ProjectId, name: str | None = None) -> Paged[KeySpaceIndexResource]:
        """List key space indexes in a project, optionally filtered by name."""
        params = {}
        if name:
            params["name"] = name
        return self.client.paged(
            f"/v1/projects/{project_id}/key-space-indexes", PagedResponse[KeySpaceIndexResource], params=params
        )

    def get(self, project_id: ProjectId) -> Project:
        """Get a project by id."""
        return self.client.get(f"/v1/projects/{project_id}", Project)

    def grant_role(self, project_id: ProjectId, request: GrantRoleRequest) -> GrantRoleResponse:
        """Grant a role to a principal."""
        return self.client.post(f"/v1/projects/{project_id}/grants", request, GrantRoleResponse)

    def list_grants(
        self,
        project_id: ProjectId,
        principal: str | None = None,
    ) -> Paged[Grant]:
        """List active project grants, optionally filtered by principal."""
        params = {}
        if principal:
            params["principal"] = principal
        return self.client.paged(f"/v1/projects/{project_id}/grants", PagedResponse[Grant], params=params)

    def get_grant(self, grant_id: str) -> Grant:
        """Get a grant by id."""
        return self.client.get(f"/v1/grants/{grant_id}", Grant)

    def revoke_grant(self, grant_id: str):
        """Revoke a grant."""
        # NOTE(review): WorkloadService passes plain None as the DELETE
        # response type while this passes `type[None]` — confirm which form
        # the client actually expects; they look inconsistent.
        return self.client.delete(f"/v1/grants/{grant_id}", type[None])
@@ -0,0 +1,19 @@
1
+ from pydantic import BaseModel
2
+
3
+ from .client import ServiceBase
4
+
5
+
6
class IssueExportTokenRequest(BaseModel):
    """Empty request body for issuing a telemetry export token."""

    pass


class IssueExportTokenResponse(BaseModel):
    """Response carrying the issued telemetry export token."""

    token: str
12
+
13
+
14
class TelemetryService(ServiceBase):
    """Service for telemetry operations."""

    def issue_export_token(self) -> IssueExportTokenResponse:
        """Issue a telemetry export token (PUT with an empty request body)."""
        return self.client.put("/v1/telemetry/token", IssueExportTokenRequest(), IssueExportTokenResponse)
@@ -0,0 +1,56 @@
1
+ from pydantic import BaseModel
2
+
3
+ from .client import Paged, PagedResponse, ServiceBase
4
+ from .types import IndexId, ProjectId, WorkerId
5
+ from .workers import CPU, GcpRegion, Memory, ResourceClass
6
+
7
+
8
class TextSearchWorker(BaseModel):
    """A search worker serving a text index."""

    worker_id: WorkerId
    project_id: ProjectId
    index_id: IndexId
    # URL of the running worker; None presumably means not yet reachable —
    # TODO confirm lifecycle semantics.
    url: str | None


class CreateWorkerRequest(BaseModel):
    """Resource shape and placement for a new search worker."""

    cpu: CPU
    memory: Memory
    region: GcpRegion


class CreateWorkerResponse(BaseModel):
    """Response carrying the id of the created worker."""

    worker_id: WorkerId


class SyncIndexRequest(BaseModel):
    """Request to sync a text index."""

    resources: ResourceClass


class SyncIndexResponse(BaseModel):
    """Response carrying the id of the worker running the sync job."""

    worker_id: WorkerId
33
+
34
+
35
class TextIndexesService(ServiceBase):
    """Service for text index operations."""

    def create_worker(self, index_id: IndexId, request: CreateWorkerRequest) -> CreateWorkerResponse:
        """Create a new search worker for the given index."""
        return self.client.post(f"/v1/text-indexes/{index_id}/workers", request, CreateWorkerResponse)

    def list_workers(self, index_id: IndexId) -> Paged[WorkerId]:
        """List text index worker ids for the given index."""
        return self.client.paged(f"/v1/text-indexes/{index_id}/workers", PagedResponse[WorkerId])

    def get_worker(self, worker_id: WorkerId) -> TextSearchWorker:
        """Get a text index worker."""
        return self.client.get(f"/v1/text-index-workers/{worker_id}", TextSearchWorker)

    def shutdown_worker(self, worker_id: WorkerId) -> None:
        """Shutdown a text index worker."""
        # NOTE(review): `type[None]` vs the plain None used by WorkloadService
        # for DELETE responses — confirm which the client expects.
        return self.client.delete(f"/v1/text-index-workers/{worker_id}", type[None])

    def sync_index(self, index_id: IndexId, request: SyncIndexRequest) -> SyncIndexResponse:
        """Start a job to sync an index."""
        return self.client.post(f"/v1/text-indexes/{index_id}/sync", request, SyncIndexResponse)
spiral/api/types.py ADDED
@@ -0,0 +1,22 @@
1
+ from typing import Annotated
2
+
3
+ from pydantic import AfterValidator, StringConstraints
4
+
5
+
6
+ def _validate_root_uri(uri: str) -> str:
7
+ if uri.endswith("/"):
8
+ raise ValueError("Root URI must not end with a slash.")
9
+ return uri
10
+
11
+
12
# Opaque string identifiers used throughout the API surface.
UserId = str
OrgId = str
ProjectId = str
RoleId = str
IndexId = str
WorkerId = str

# A root URI is stored without a trailing slash (enforced by the validator).
RootUri = Annotated[str, AfterValidator(_validate_root_uri)]
# Name rules: start with a letter or underscore, then letters/digits/_/-.
# Fixed: DatasetName previously used "+" (requiring at least two characters)
# while TableName/IndexName used "*" (allowing single-character names).
# Unified on "*" so the three name kinds validate consistently; this only
# loosens DatasetName, so existing valid values remain valid.
DatasetName = Annotated[str, StringConstraints(max_length=128, pattern=r"^[a-zA-Z_][a-zA-Z0-9_-]*$")]
TableName = Annotated[str, StringConstraints(max_length=128, pattern=r"^[a-zA-Z_][a-zA-Z0-9_-]*$")]
IndexName = Annotated[str, StringConstraints(max_length=128, pattern=r"^[a-zA-Z_][a-zA-Z0-9_-]*$")]
spiral/api/workers.py ADDED
@@ -0,0 +1,40 @@
1
+ from enum import Enum, IntEnum
2
+
3
+
4
class CPU(IntEnum):
    """Allowed worker CPU core counts."""

    ONE = 1
    TWO = 2
    FOUR = 4
    EIGHT = 8

    def __str__(self) -> str:
        # Render as the bare number, e.g. "4", not "CPU.FOUR".
        return f"{self.value}"


class Memory(str, Enum):
    """Allowed worker memory sizes (Kubernetes-style quantity strings)."""

    MB_512 = "512Mi"
    GB_1 = "1Gi"
    GB_2 = "2Gi"
    GB_4 = "4Gi"
    GB_8 = "8Gi"

    def __str__(self) -> str:
        # Render as the raw quantity string, e.g. "1Gi".
        return str(self.value)


class GcpRegion(str, Enum):
    """GCP regions where workers can be placed."""

    US_EAST4 = "us-east4"
    EUROPE_WEST4 = "europe-west4"

    def __str__(self) -> str:
        return str(self.value)


class ResourceClass(str, Enum):
    """Resource class for text index sync."""

    SMALL = "small"
    LARGE = "large"

    def __str__(self) -> str:
        return str(self.value)
@@ -0,0 +1,52 @@
1
+ from pydantic import BaseModel
2
+
3
+ from .client import Paged, PagedResponse, ServiceBase
4
+ from .types import ProjectId
5
+
6
+
7
class Workload(BaseModel):
    """A machine identity scoped to a project."""

    id: str
    project_id: ProjectId
    name: str | None = None


class CreateWorkloadRequest(BaseModel):
    """Request body for creating a workload; name is optional."""

    name: str | None = None


class CreateWorkloadResponse(BaseModel):
    """Response carrying the newly created workload."""

    workload: Workload


class IssueWorkloadCredentialsResponse(BaseModel):
    """Freshly issued client credentials for a workload."""

    client_id: str
    client_secret: str
    # Id of the credentials that were revoked as part of this rotation,
    # if any — presumably set when re-issuing; TODO confirm.
    revoked_client_id: str | None = None
25
+
26
+
27
class WorkloadService(ServiceBase):
    """Service for workload operations (creation, credentials, lifecycle)."""

    def create(self, project_id: ProjectId, request: CreateWorkloadRequest) -> CreateWorkloadResponse:
        """Create a new workload."""
        return self.client.post(f"/v1/projects/{project_id}/workloads", request, CreateWorkloadResponse)

    def list(self, project_id: ProjectId) -> Paged[Workload]:
        """List active project workloads."""
        # Fixed: this route previously lacked the "/v1" prefix used by every
        # other endpoint in this service and its siblings.
        return self.client.paged(f"/v1/projects/{project_id}/workloads", PagedResponse[Workload])

    def get(self, workload_id: str) -> Workload:
        """Get a workload by id."""
        return self.client.get(f"/v1/workloads/{workload_id}", Workload)

    def deactivate(self, workload_id: str) -> None:
        """De-activate a workload."""
        return self.client.delete(f"/v1/workloads/{workload_id}", None)

    def issue_credentials(self, workload_id: str) -> IssueWorkloadCredentialsResponse:
        """Issue workload credentials (POST with no request body)."""
        return self.client.post(f"/v1/workloads/{workload_id}/credentials", None, IssueWorkloadCredentialsResponse)

    def revoke_credentials(self, client_id: str) -> None:
        """Revoke workload credentials."""
        return self.client.delete(f"/v1/credentials/{client_id}", None)
spiral/arrow_.py ADDED
@@ -0,0 +1,216 @@
1
+ from collections import defaultdict
2
+ from collections.abc import Callable, Iterable
3
+ from functools import reduce
4
+
5
+ import numpy as np
6
+ import pyarrow as pa
7
+ from pyarrow import compute as pc
8
+
9
+
10
def arange(*args, **kwargs) -> pa.Array:
    """Build an int32 pyarrow array from ``numpy.arange`` arguments."""
    values = np.arange(*args, **kwargs)
    return pa.array(values, type=pa.int32())
12
+
13
+
14
def zip_tables(tables: Iterable[pa.Table]) -> pa.Table:
    """Place the columns of several tables side by side in a single table.

    Duplicate column names are preserved as-is (see ``coalesce_all``).
    """
    columns = []
    column_names = []
    for tbl in tables:
        columns += list(tbl.columns)
        column_names += list(tbl.column_names)
    return pa.Table.from_arrays(columns, names=column_names)
21
+
22
+
23
def merge_arrays(*arrays: pa.StructArray) -> pa.StructArray:
    """Recursively merge arrays into nested struct arrays.

    Non-struct conflicts resolve last-wins; struct fields with the same name
    are merged recursively. Mixing structs with non-structs raises ValueError.
    """
    if len(arrays) == 1:
        return arrays[0]

    nstructs = sum(pa.types.is_struct(a.type) for a in arrays)
    if nstructs == 0:
        # Then we have conflicting arrays and we choose the last.
        return arrays[-1]

    if nstructs != len(arrays):
        raise ValueError("Cannot merge structs with non-structs.")

    # Group child arrays by field name, preserving first-seen field order.
    data = defaultdict(list)
    for array in arrays:
        if isinstance(array, pa.ChunkedArray):
            # .field() below requires a contiguous StructArray.
            array = array.combine_chunks()
        for field in array.type:
            data[field.name].append(array.field(field.name))

    return pa.StructArray.from_arrays([merge_arrays(*v) for v in data.values()], names=list(data.keys()))
44
+
45
+
46
def merge_scalars(*scalars: pa.StructScalar) -> pa.StructScalar:
    """Recursively merge scalars into nested struct scalars.

    Mirrors ``merge_arrays``: last-wins for non-struct conflicts, recursive
    merge of same-named struct fields, ValueError when structs are mixed
    with non-structs.
    """
    if len(scalars) == 1:
        return scalars[0]

    nstructs = sum(pa.types.is_struct(a.type) for a in scalars)
    if nstructs == 0:
        # Then we have conflicting scalars and we choose the last.
        return scalars[-1]

    if nstructs != len(scalars):
        raise ValueError("Cannot merge scalars with non-scalars.")

    # Group child scalars by field name, preserving first-seen field order.
    data = defaultdict(list)
    for scalar in scalars:
        for field in scalar.type:
            data[field.name].append(scalar[field.name])

    return pa.scalar({k: merge_scalars(*v) for k, v in data.items()})
65
+
66
+
67
def null_table(schema: pa.Schema, length: int = 0) -> pa.Table:
    """Build a table of all-null columns with the given schema and row count."""
    # We add an extra nulls column ("__", dropped immediately) to ensure the
    # requested length is applied even when the schema has zero fields.
    return pa.table(
        [pa.nulls(length, type=field.type) for field in schema] + [pa.nulls(length)],
        schema=pa.schema(list(schema) + [pa.field("__", type=pa.null())]),
    ).drop(["__"])
73
+
74
+
75
def coalesce_all(table: pa.Table) -> pa.Table:
    """Collapse duplicate column names by coalescing their values left-to-right."""
    # Group columns by name, keeping first-seen order of names.
    grouped: dict[str, list[pa.Array]] = defaultdict(list)
    for idx, name in enumerate(table.column_names):
        grouped[name].append(table[idx])

    names = list(grouped.keys())
    data = [
        arrays[0] if len(arrays) == 1 else pc.coalesce(*arrays)
        for arrays in grouped.values()
    ]
    return pa.Table.from_arrays(data, names=names)
91
+
92
+
93
def join(left: pa.Table, right: pa.Table, keys: list[str], join_type: str) -> pa.Table:
    """Arrow's builtin join doesn't support struct columns. So we join ourselves and zip them in.

    Joins only the key columns plus row-index markers, then uses the resulting
    indices to take the full (possibly struct-bearing) rows from each side.
    Output is sorted ascending by the join keys.
    """
    # TODO(ngates): if join_type == inner, we may have better luck performing two index_in operations since this
    #  also preserves sort order.
    lhs = left.select(keys).add_column(0, "__lhs", arange(len(left)))
    rhs = right.select(keys).add_column(0, "__rhs", arange(len(right)))
    joined = lhs.join(rhs, keys=keys, join_type=join_type).sort_by([(k, "ascending") for k in keys])
    # take() with the surviving row indices re-attaches the non-key columns.
    return zip_tables(
        [joined.select(keys), left.take(joined["__lhs"]).drop(keys), right.take(joined["__rhs"]).drop(keys)]
    )
103
+
104
+
105
def nest_structs(array: pa.StructArray | pa.StructScalar | dict) -> dict:
    """Turn a struct-like value with dot-separated column names into a nested dictionary."""
    data = {}

    # Normalize struct arrays/scalars to a flat {name: value} mapping first.
    if isinstance(array, pa.StructArray | pa.StructScalar):
        array = {f.name: field(array, f.name) for f in array.type}

    for name in array.keys():
        if "." not in name:
            data[name] = array[name]
            continue

        # Walk/create intermediate dicts for each dotted path component.
        parts = name.split(".")
        child_data = data
        for part in parts[:-1]:
            if part not in child_data:
                child_data[part] = {}
            child_data = child_data[part]
        child_data[parts[-1]] = array[name]

    return data
126
+
127
+
128
def flatten_struct_table(table: pa.Table, separator=".") -> pa.Table:
    """Turn a nested struct table into a flat table with dot-separated names.

    Inverse of ``nest_structs`` for tables: each leaf array becomes a column
    named by its path, joined with `separator`.
    """
    data = []
    names = []

    def _unfold(array: pa.Array, prefix: str):
        # Depth-first: recurse into struct children, emit leaves.
        if pa.types.is_struct(array.type):
            if isinstance(array, pa.ChunkedArray):
                # field() below requires a contiguous StructArray.
                array = array.combine_chunks()
            for f in array.type:
                _unfold(field(array, f.name), f"{prefix}{separator}{f.name}")
        else:
            data.append(array)
            names.append(prefix)

    for col in table.column_names:
        _unfold(table[col], col)

    return pa.Table.from_arrays(data, names=names)
147
+
148
+
149
def struct_array(fields: list[tuple[str, bool, pa.Array]], /, mask: list[bool] | None = None) -> pa.StructArray:
    """Build a StructArray from (name, nullable, array) triples with an optional validity mask."""
    return pa.StructArray.from_arrays(
        arrays=[x[2] for x in fields],
        fields=[pa.field(x[0], type=x[2].type, nullable=x[1]) for x in fields],
        # NOTE(review): an empty mask list is falsy and is passed through
        # unconverted rather than as pa.array([]) — confirm this is intended.
        mask=pa.array(mask) if mask else mask,
    )
155
+
156
+
157
def table(fields: list[tuple[str, bool, pa.Array]], /) -> pa.Table:
    """Build a table from (name, nullable, array) triples."""
    combined = struct_array(fields)
    return pa.Table.from_struct_array(combined)
159
+
160
+
161
def dict_to_table(data) -> pa.Table:
    """Convert a (possibly nested) dict of arrays into a table."""
    as_struct = dict_to_struct_array(data)
    return pa.Table.from_struct_array(as_struct)
163
+
164
+
165
def dict_to_struct_array(data: dict | pa.StructArray, propagate_nulls: bool = False) -> pa.StructArray:
    """Convert a nested dictionary of arrays to a table with nested structs.

    If `propagate_nulls` is set, a struct row is null when *all* of its
    children are null.
    """
    if isinstance(data, pa.Array):
        # Already an array (recursion base case) — pass through unchanged.
        return data
    arrays = [dict_to_struct_array(value) for value in data.values()]
    return pa.StructArray.from_arrays(
        arrays,
        names=list(data.keys()),
        mask=reduce(pc.and_, [pc.is_null(array) for array in arrays]) if propagate_nulls else None,
    )
175
+
176
+
177
def struct_array_to_dict[T](array: pa.StructArray, array_fn: Callable[[pa.Array], T] = lambda a: a) -> dict | T:
    """Convert a struct array to a nested dictionary.

    Leaf (non-struct) arrays are passed through `array_fn`, which defaults
    to the identity.
    """
    if not pa.types.is_struct(array.type):
        return array_fn(array)
    if isinstance(array, pa.ChunkedArray):
        # .field(i) below requires a contiguous StructArray.
        array = array.combine_chunks()
    return {field.name: struct_array_to_dict(array.field(i), array_fn=array_fn) for i, field in enumerate(array.type)}
184
+
185
+
186
def table_to_struct_array(table: pa.Table) -> pa.StructArray:
    """Convert a table to a single contiguous struct array."""
    # An empty table short-circuits to an empty array of the struct type.
    if table.num_rows == 0:
        return pa.array([], type=pa.struct(table.schema))
    result = table.to_struct_array()
    if isinstance(result, pa.ChunkedArray):
        result = result.combine_chunks()
    return result
193
+
194
+
195
def table_from_struct_array(array: pa.StructArray | pa.ChunkedArray):
    """Convert a struct array to a table; empty input yields an empty table of the same schema."""
    if not len(array):
        return null_table(pa.schema(array.type))
    return pa.Table.from_struct_array(array)
199
+
200
+
201
def field(value: pa.StructArray | pa.StructScalar, name: str) -> pa.Array | pa.Scalar:
    """Get a field from a struct-like value.

    Scalars are indexed with [], arrays via .field().
    """
    if not isinstance(value, pa.StructScalar):
        return value.field(name)
    return value[name]
206
+
207
+
208
def concat_tables(tables: list[pa.Table]) -> pa.Table:
    """
    Concatenate pyarrow.Table objects, filling "missing" data with appropriate null arrays
    and casting arrays to the most common denominator type that fits all fields.
    """
    # Single-table input is returned as-is; otherwise delegate to Arrow's
    # permissive promotion.
    if len(tables) != 1:
        return pa.concat_tables(tables, promote_options="permissive")
    return tables[0]