ygg 0.1.31__py3-none-any.whl → 0.1.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ygg-0.1.31.dist-info → ygg-0.1.33.dist-info}/METADATA +1 -1
- ygg-0.1.33.dist-info/RECORD +60 -0
- yggdrasil/__init__.py +2 -0
- yggdrasil/databricks/__init__.py +2 -0
- yggdrasil/databricks/compute/__init__.py +2 -0
- yggdrasil/databricks/compute/cluster.py +244 -3
- yggdrasil/databricks/compute/execution_context.py +100 -11
- yggdrasil/databricks/compute/remote.py +24 -0
- yggdrasil/databricks/jobs/__init__.py +5 -0
- yggdrasil/databricks/jobs/config.py +29 -4
- yggdrasil/databricks/sql/__init__.py +2 -0
- yggdrasil/databricks/sql/engine.py +217 -36
- yggdrasil/databricks/sql/exceptions.py +1 -0
- yggdrasil/databricks/sql/statement_result.py +147 -0
- yggdrasil/databricks/sql/types.py +33 -1
- yggdrasil/databricks/workspaces/__init__.py +2 -1
- yggdrasil/databricks/workspaces/filesytem.py +183 -0
- yggdrasil/databricks/workspaces/io.py +387 -9
- yggdrasil/databricks/workspaces/path.py +297 -2
- yggdrasil/databricks/workspaces/path_kind.py +3 -0
- yggdrasil/databricks/workspaces/workspace.py +202 -5
- yggdrasil/dataclasses/__init__.py +2 -0
- yggdrasil/dataclasses/dataclass.py +42 -1
- yggdrasil/libs/__init__.py +2 -0
- yggdrasil/libs/databrickslib.py +9 -0
- yggdrasil/libs/extensions/__init__.py +2 -0
- yggdrasil/libs/extensions/polars_extensions.py +72 -0
- yggdrasil/libs/extensions/spark_extensions.py +116 -0
- yggdrasil/libs/pandaslib.py +7 -0
- yggdrasil/libs/polarslib.py +7 -0
- yggdrasil/libs/sparklib.py +41 -0
- yggdrasil/pyutils/__init__.py +4 -0
- yggdrasil/pyutils/callable_serde.py +106 -0
- yggdrasil/pyutils/exceptions.py +16 -0
- yggdrasil/pyutils/modules.py +44 -1
- yggdrasil/pyutils/parallel.py +29 -0
- yggdrasil/pyutils/python_env.py +301 -0
- yggdrasil/pyutils/retry.py +57 -0
- yggdrasil/requests/__init__.py +4 -0
- yggdrasil/requests/msal.py +124 -3
- yggdrasil/requests/session.py +18 -0
- yggdrasil/types/__init__.py +2 -0
- yggdrasil/types/cast/__init__.py +2 -1
- yggdrasil/types/cast/arrow_cast.py +123 -1
- yggdrasil/types/cast/cast_options.py +119 -1
- yggdrasil/types/cast/pandas_cast.py +29 -0
- yggdrasil/types/cast/polars_cast.py +47 -0
- yggdrasil/types/cast/polars_pandas_cast.py +29 -0
- yggdrasil/types/cast/registry.py +176 -0
- yggdrasil/types/cast/spark_cast.py +76 -0
- yggdrasil/types/cast/spark_pandas_cast.py +29 -0
- yggdrasil/types/cast/spark_polars_cast.py +28 -0
- yggdrasil/types/libs.py +2 -0
- yggdrasil/types/python_arrow.py +191 -0
- yggdrasil/types/python_defaults.py +73 -0
- yggdrasil/version.py +1 -0
- ygg-0.1.31.dist-info/RECORD +0 -59
- {ygg-0.1.31.dist-info → ygg-0.1.33.dist-info}/WHEEL +0 -0
- {ygg-0.1.31.dist-info → ygg-0.1.33.dist-info}/entry_points.txt +0 -0
- {ygg-0.1.31.dist-info → ygg-0.1.33.dist-info}/licenses/LICENSE +0 -0
- {ygg-0.1.31.dist-info → ygg-0.1.33.dist-info}/top_level.txt +0 -0
yggdrasil/requests/msal.py
CHANGED
@@ -1,3 +1,5 @@
+"""MSAL-backed authentication helpers for requests sessions."""
+
 # auth_session.py
 import os
 import time
@@ -27,6 +29,15 @@ __all__ = [
 
 @dataclass
 class MSALAuth:
+    """Configuration and token cache for MSAL client credential flows.
+
+    Args:
+        tenant_id: Azure tenant ID.
+        client_id: Azure application client ID.
+        client_secret: Azure application client secret.
+        authority: Optional authority URL override.
+        scopes: List of scopes to request.
+    """
     tenant_id: Optional[str] = None
     client_id: Optional[str] = None
     client_secret: Optional[str] = None
@@ -38,12 +49,34 @@ class MSALAuth:
     _access_token: Optional[str] = None
 
     def __setitem__(self, key, value):
+        """Set an attribute via mapping-style assignment.
+
+        Args:
+            key: Attribute name to set.
+            value: Value to assign.
+
+        Returns:
+            None.
+        """
         self.__setattr__(key, value)
 
     def __getitem__(self, item):
+        """Return attribute values via mapping-style access.
+
+        Args:
+            item: Attribute name to fetch.
+
+        Returns:
+            The attribute value.
+        """
         return getattr(self, item)
 
     def __post_init__(self):
+        """Populate defaults from environment variables and validate.
+
+        Returns:
+            None.
+        """
         self.tenant_id = self.tenant_id or os.environ.get("AZURE_TENANT_ID")
         self.client_id = self.client_id or os.environ.get("AZURE_CLIENT_ID")
         self.client_secret = self.client_secret or os.environ.get("AZURE_CLIENT_SECRET")
@@ -60,7 +93,11 @@ class MSALAuth:
         self._validate_config()
 
     def _validate_config(self):
-        """Validate that all required configuration is present.
+        """Validate that all required configuration is present.
+
+        Returns:
+            None.
+        """
         missing = []
 
         if not self.client_id:
@@ -81,6 +118,15 @@ class MSALAuth:
         env: Mapping = None,
         prefix: Optional[str] = None
     ) -> "MSALAuth":
+        """Return an MSALAuth built from environment variables if available.
+
+        Args:
+            env: Mapping to read variables from; defaults to os.environ.
+            prefix: Optional prefix for variable names.
+
+        Returns:
+            A configured MSALAuth instance or None.
+        """
         if not env:
             env = os.environ
             prefix = prefix or "AZURE_"
@@ -105,6 +151,14 @@ class MSALAuth:
         return None
 
     def export_to(self, to: dict = os.environ):
+        """Export the auth configuration to the provided mapping.
+
+        Args:
+            to: Mapping to populate with auth configuration values.
+
+        Returns:
+            None.
+        """
         for key, value in (
             ("AZURE_CLIENT_ID", self.client_id),
             ("AZURE_CLIENT_SECRET", self.client_secret),
@@ -116,6 +170,11 @@ class MSALAuth:
 
     @property
     def auth_app(self) -> ConfidentialClientApplication:
+        """Return or initialize the MSAL confidential client.
+
+        Returns:
+            MSAL confidential client instance.
+        """
         if not self._auth_app:
             self._auth_app = ConfidentialClientApplication(
                 client_id=self.client_id,
@@ -127,19 +186,42 @@ class MSALAuth:
 
     @property
     def expires_in(self) -> float:
+        """Return the number of seconds since the token expiry timestamp.
+
+        Returns:
+            Seconds elapsed since expiry (negative if not expired).
+        """
         return time.time() - self.expires_at
 
     @property
    def expires_at(self) -> float:
+        """Ensure the token is fresh and return the expiry timestamp.
+
+        Returns:
+            Token expiration time as a Unix timestamp.
+        """
         self.refresh()
 
         return self._expires_at
 
     @property
     def expired(self) -> bool:
+        """Return True when the token is missing or past its expiry time.
+
+        Returns:
+            True if expired or missing; False otherwise.
+        """
         return not self._expires_at or time.time() >= self._expires_at
 
     def refresh(self, force: bool | None = None):
+        """Acquire or refresh the token if needed.
+
+        Args:
+            force: Force refresh even if not expired.
+
+        Returns:
+            The updated MSALAuth instance.
+        """
         if self.expired or force:
             app = self.auth_app
             result = app.acquire_token_for_client(scopes=self.scopes)
@@ -157,16 +239,32 @@ class MSALAuth:
 
     @property
     def access_token(self) -> str:
-        """Return access token.
+        """Return access token.
+
+        Returns:
+            Access token string.
+        """
         self.refresh()
         return self._access_token
 
     @property
     def authorization(self) -> str:
-        """Return authorization token.
+        """Return authorization token.
+
+        Returns:
+            Authorization header value.
+        """
         return f"Bearer {self.access_token}"
 
     def requests_session(self, **kwargs):
+        """Build a requests session that injects the MSAL authorization header.
+
+        Args:
+            **kwargs: Passed through to MSALSession.
+
+        Returns:
+            Configured MSALSession.
+        """
         return MSALSession(
             msal_auth=self,
             **kwargs
@@ -174,6 +272,11 @@ class MSALAuth:
 
 
 class MSALSession(YGGSession):
+    """YGGSession subclass that injects MSAL authorization headers.
+
+    Args:
+        YGGSession: Base retry-capable session.
+    """
     msal_auth: MSALAuth | None = None
 
     def __init__(
@@ -182,11 +285,29 @@ class MSALSession(YGGSession):
         *args,
         **kwargs: dict
     ):
+        """Initialize the session with optional MSAL auth configuration.
+
+        Args:
+            msal_auth: MSALAuth configuration for token injection.
+            *args: Positional args for YGGSession.
+            **kwargs: Keyword args for YGGSession.
+
+        Returns:
+            None.
+        """
         super().__init__(*args, **kwargs)
         self.msal_auth = msal_auth
 
 
     def prepare_request(self, request):
+        """Prepare the request with an Authorization header when needed.
+
+        Args:
+            request: requests.PreparedRequest to mutate.
+
+        Returns:
+            Prepared request.
+        """
         # called before sending; ensure header exists
         if self.msal_auth is not None:
             request.headers["Authorization"] = request.headers.get("Authorization", self.msal_auth.authorization)
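For orientation, a minimal usage sketch of the API added above. It assumes MSALAuth and MSALSession remain importable from yggdrasil.requests.msal, that the scopes field accepts a list as the class docstring describes, and that credentials fall back to the AZURE_* environment variables read in __post_init__; the scope and URL below are illustrative placeholders.

# Hedged sketch; scope and endpoint are placeholders, not part of the package.
from yggdrasil.requests.msal import MSALAuth

# tenant_id / client_id / client_secret default to AZURE_TENANT_ID,
# AZURE_CLIENT_ID and AZURE_CLIENT_SECRET when not passed explicitly.
auth = MSALAuth(scopes=["https://management.azure.com/.default"])

# requests_session() returns an MSALSession; its prepare_request() injects
# an "Authorization: Bearer <token>" header when one is not already set.
session = auth.requests_session()
response = session.get("https://management.azure.com/subscriptions?api-version=2022-12-01")
print(response.status_code)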
yggdrasil/requests/session.py
CHANGED
@@ -1,3 +1,5 @@
+"""HTTP session helpers with retry-enabled defaults."""
+
 from typing import Optional, Dict
 
 from requests import Session
@@ -10,6 +12,11 @@ __all__ = [
 
 
 class YGGSession(Session):
+    """Requests session with preconfigured retry adapter support.
+
+    Args:
+        Session: Base requests session type.
+    """
     def __init__(
         self,
         num_retry: int = 4,
@@ -17,6 +24,17 @@ class YGGSession(Session):
         *args,
         **kwargs
     ):
+        """Initialize the session with retries and optional default headers.
+
+        Args:
+            num_retry: Number of retries for connection and read errors.
+            headers: Optional default headers to merge into the session.
+            *args: Additional positional arguments passed to Session.
+            **kwargs: Additional keyword arguments passed to Session.
+
+        Returns:
+            None.
+        """
         super(YGGSession, self).__init__()
 
         retry = Retry(
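A similarly small sketch of the retry-enabled session, assuming YGGSession is importable from yggdrasil.requests.session; the header value and URL are placeholders.

from yggdrasil.requests.session import YGGSession

# num_retry feeds the urllib3 Retry adapter; headers become session defaults.
session = YGGSession(num_retry=4, headers={"Accept": "application/json"})
response = session.get("https://example.com/api/health")
response.raise_for_status()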
yggdrasil/types/__init__.py
CHANGED
yggdrasil/types/cast/__init__.py
CHANGED
@@ -1,3 +1,5 @@
+"""Casting utilities and converters across Arrow and engine types."""
+
 from .registry import *
 from .arrow_cast import *
 from .polars_cast import *
@@ -6,4 +8,3 @@ from .spark_cast import *
 from .spark_polars_cast import *
 from .spark_pandas_cast import *
 from .polars_pandas_cast import *
-
yggdrasil/types/cast/arrow_cast.py
CHANGED
@@ -1,7 +1,10 @@
+"""Arrow casting helpers for arrays, tables, and schemas."""
+
 import dataclasses
 import enum
 import logging
 from dataclasses import is_dataclass
+from functools import partial
 from typing import Optional, Union, List, Tuple, Any
 
 import pyarrow as pa
@@ -452,6 +455,15 @@ def any_to_arrow_scalar(
     scalar: Any,
     options: Optional[CastOptions] = None,
 ) -> pa.Scalar:
+    """Convert a Python value to an Arrow scalar.
+
+    Args:
+        scalar: Input value.
+        options: Optional cast options.
+
+    Returns:
+        Arrow scalar.
+    """
     if isinstance(scalar, pa.Scalar):
         return cast_arrow_scalar(scalar, options)
 
@@ -492,6 +504,15 @@ def cast_arrow_scalar(
     scalar: pa.Scalar,
     options: Optional[CastOptions] = None,
 ) -> pa.Scalar:
+    """Cast an Arrow scalar to the target Arrow field.
+
+    Args:
+        scalar: Arrow scalar to cast.
+        options: Optional cast options.
+
+    Returns:
+        Casted Arrow scalar.
+    """
     options = CastOptions.check_arg(options)
     target_field = options.target_field
 
@@ -741,6 +762,28 @@ def cast_arrow_tabular(
     return data.__class__.from_arrays(all_arrays, schema=target_arrow_schema)
 
 
+@register_converter(pds.Dataset, pds.Dataset)
+def cast_arrow_dataset(data: pds.Dataset, options: Optional[CastOptions] = None) -> pds.Dataset:
+    """Cast a dataset to the target schema in options.
+
+    Args:
+        data: Arrow dataset to cast.
+        options: Optional cast options.
+
+    Returns:
+        Casted dataset.
+    """
+    if options is None:
+        return data
+
+    caster = partial(cast_arrow_tabular, options=options)
+
+    return pds.dataset(map(caster, data.to_batches(
+        batch_size=256 * 1024,
+        memory_pool=options.arrow_memory_pool
+    )))
+
+
 @register_converter(pa.RecordBatchReader, pa.RecordBatchReader)
 def cast_arrow_record_batch_reader(
     data: pa.RecordBatchReader,
@@ -757,6 +800,11 @@ def cast_arrow_record_batch_reader(
         return data
 
     def casted_batches():
+        """Yield casted batches from a RecordBatchReader.
+
+        Yields:
+            Casted RecordBatch instances.
+        """
         for batch in data:
             yield cast_arrow_tabular(batch, options)
 
@@ -770,6 +818,15 @@ def any_to_arrow_array(
     obj: Any,
     options: Optional[CastOptions] = None,
 ) -> pa.Array:
+    """Convert array-like input into an Arrow array.
+
+    Args:
+        obj: Array-like input.
+        options: Optional cast options.
+
+    Returns:
+        Arrow array.
+    """
     options = CastOptions.check_arg(options)
     arrow_array = None
 
@@ -846,6 +903,15 @@ def pylist_to_record_batch(
     data: list,
     options: Optional[CastOptions] = None,
 ) -> pa.RecordBatch:
+    """Convert a list of rows into a RecordBatch.
+
+    Args:
+        data: List of row objects.
+        options: Optional cast options.
+
+    Returns:
+        Arrow RecordBatch.
+    """
     options = CastOptions.check_arg(options)
 
     array: Union[pa.Array, pa.StructArray] = any_to_arrow_array(data, options)
@@ -1100,10 +1166,39 @@ def record_batch_reader_to_record_batch(
 def arrow_dataset_to_table(
     data: pds.Dataset,
     options: Optional[CastOptions] = None,
-) -> pa.
+) -> pa.Table:
+    """Convert a dataset to a Table and apply casting.
+
+    Args:
+        data: Arrow dataset.
+        options: Optional cast options.
+
+    Returns:
+        Arrow table.
+    """
     table = data.to_table()
     return cast_arrow_tabular(table, options)
 
+
+@register_converter(pa.Table, pds.Dataset)
+@register_converter(pa.RecordBatch, pds.Dataset)
+def arrow_tabular_to_dataset(
+    data: Union[pa.Table, pa.RecordBatch],
+    options: Optional[CastOptions] = None,
+) -> pa.Field:
+    """Convert Arrow tabular data to a dataset after casting.
+
+    Args:
+        data: Table or RecordBatch.
+        options: Optional cast options.
+
+    Returns:
+        Arrow dataset.
+    """
+    data = cast_arrow_tabular(data, options)
+    return pds.dataset([data])
+
+
 # ---------------------------------------------------------------------------
 # Field / Schema converters
 # ---------------------------------------------------------------------------
@@ -1154,6 +1249,15 @@ def arrow_schema_to_field(
     data: pa.Schema,
     options: Optional[CastOptions] = None,
 ) -> pa.Field:
+    """Wrap an Arrow schema as a struct field.
+
+    Args:
+        data: Arrow schema.
+        options: Optional cast options.
+
+    Returns:
+        Arrow field.
+    """
     dtype = pa.struct(list(data))
     md = dict(data.metadata or {})
     name = md.setdefault(b"name", b"root")
@@ -1166,6 +1270,15 @@ def arrow_field_to_schema(
     data: pa.Field,
     options: Optional[CastOptions] = None,
 ) -> pa.Schema:
+    """Return a schema view of an Arrow field.
+
+    Args:
+        data: Arrow field.
+        options: Optional cast options.
+
+    Returns:
+        Arrow schema.
+    """
     md = dict(data.metadata or {})
     md[b"name"] = data.name.encode()
 
@@ -1181,4 +1294,13 @@ def arrow_tabular_to_field(
     data: Union[pa.Table, pa.RecordBatch, pa.RecordBatchReader],
     options: Optional[CastOptions] = None,
 ) -> pa.Field:
+    """Return a field representing the schema of tabular data.
+
+    Args:
+        data: Arrow table/batch/reader.
+        options: Optional cast options.
+
+    Returns:
+        Arrow field.
+    """
     return arrow_schema_to_field(data.schema, options)
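The converters above are plain functions, so a short hedged sketch can illustrate the call pattern; it assumes pylist_to_record_batch, arrow_schema_to_field, and arrow_field_to_schema stay importable from yggdrasil.types.cast.arrow_cast, and the sample rows are placeholders.

import pyarrow as pa

from yggdrasil.types.cast.arrow_cast import (
    arrow_field_to_schema,
    arrow_schema_to_field,
    pylist_to_record_batch,
)

# Build a RecordBatch from plain Python rows; options are optional, so types are inferred.
rows = [
    {"id": 1, "name": "alpha"},
    {"id": 2, "name": "beta"},
]
batch = pylist_to_record_batch(rows)
print(batch.schema)

# Round-trip a schema through its struct-field representation.
field = arrow_schema_to_field(pa.schema([pa.field("id", pa.int64())]))
print(arrow_field_to_schema(field))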
yggdrasil/types/cast/cast_options.py
CHANGED
@@ -1,3 +1,5 @@
+"""Casting options for Arrow- and engine-aware conversions."""
+
 import dataclasses
 from typing import Optional, Union, List, Any
 
@@ -69,6 +71,22 @@ class CastOptions:
         target_field: pa.Field | pa.Schema | pa.DataType | None = None,
         **kwargs
     ):
+        """Build a CastOptions instance with optional source/target fields.
+
+        Args:
+            safe: Enable safe casting if True.
+            add_missing_columns: Add missing columns if True.
+            strict_match_names: Require exact field name matches if True.
+            allow_add_columns: Allow extra columns if True.
+            eager: Enable eager casting behavior if True.
+            datetime_patterns: Optional datetime parsing patterns.
+            source_field: Optional source Arrow field/schema/type.
+            target_field: Optional target Arrow field/schema/type.
+            **kwargs: Additional CastOptions fields.
+
+        Returns:
+            CastOptions instance.
+        """
         built = CastOptions(
             safe=safe,
             add_missing_columns=add_missing_columns,
@@ -169,6 +187,14 @@ class CastOptions:
         return result
 
     def check_source(self, obj: Any):
+        """Set the source field if not already configured.
+
+        Args:
+            obj: Source object to infer from.
+
+        Returns:
+            Self.
+        """
         if self.source_field is not None or obj is None:
             return self
 
@@ -177,6 +203,14 @@ class CastOptions:
         return self
 
     def need_arrow_type_cast(self, source_obj: Any):
+        """Return True when Arrow type casting is required.
+
+        Args:
+            source_obj: Source object to compare types against.
+
+        Returns:
+            True if Arrow type cast needed.
+        """
         if self.target_field is None:
             return False
 
@@ -185,6 +219,14 @@ class CastOptions:
         return self.source_field.type != self.target_field.type
 
     def need_polars_type_cast(self, source_obj: Any):
+        """Return True when Polars dtype casting is required.
+
+        Args:
+            source_obj: Source object to compare types against.
+
+        Returns:
+            True if Polars type cast needed.
+        """
         if self.target_polars_field is None:
             return False
 
@@ -193,6 +235,14 @@ class CastOptions:
         return self.source_polars_field.dtype != self.target_polars_field.dtype
 
     def need_spark_type_cast(self, source_obj: Any):
+        """Return True when Spark datatype casting is required.
+
+        Args:
+            source_obj: Source object to compare types against.
+
+        Returns:
+            True if Spark type cast needed.
+        """
         if self.target_spark_field is None:
             return False
 
@@ -201,6 +251,14 @@ class CastOptions:
         return self.source_spark_field.dataType != self.target_spark_field.dataType
 
     def need_nullability_check(self, source_obj: Any):
+        """Return True when nullability checks are required.
+
+        Args:
+            source_obj: Source object to compare nullability against.
+
+        Returns:
+            True if nullability check needed.
+        """
         if self.target_field is None:
             return False
 
@@ -213,6 +271,15 @@ class CastOptions:
         arrow_field: pa.Field,
         index: int
     ):
+        """Return a child Arrow field by index for nested types.
+
+        Args:
+            arrow_field: Parent Arrow field.
+            index: Child index.
+
+        Returns:
+            Child Arrow field.
+        """
         source_type: Union[
             pa.DataType, pa.ListType, pa.StructType, pa.MapType
         ] = arrow_field.type
@@ -235,6 +302,11 @@ class CastOptions:
 
     @property
     def source_field(self):
+        """Return the configured source Arrow field.
+
+        Returns:
+            Source Arrow field.
+        """
         return self.source_arrow_field
 
     @source_field.setter
@@ -248,10 +320,23 @@ class CastOptions:
         object.__setattr__(self, "source_arrow_field", value)
 
     def source_child_arrow_field(self, index: int):
+        """Return a child source Arrow field by index.
+
+        Args:
+            index: Child index.
+
+        Returns:
+            Child Arrow field.
+        """
         return self._child_arrow_field(self.source_arrow_field, index=index)
 
     @property
     def source_polars_field(self):
+        """Return or compute the cached Polars field for the source.
+
+        Returns:
+            Polars field or None.
+        """
         if self.source_arrow_field is not None and self._source_polars_field is None:
             from ...types.cast.polars_cast import arrow_field_to_polars_field
 
@@ -260,6 +345,11 @@ class CastOptions:
 
     @property
     def source_spark_field(self):
+        """Return or compute the cached Spark field for the source.
+
+        Returns:
+            Spark field or None.
+        """
         if self.source_arrow_field is not None and self._source_spark_field is None:
             from ...types.cast.spark_cast import arrow_field_to_spark_field
 
@@ -275,6 +365,11 @@ class CastOptions:
 
     @property
     def target_field_name(self):
+        """Return the effective target field name.
+
+        Returns:
+            Target field name or None.
+        """
         if self.target_field is None:
             if self.source_field is not None:
                 return self.source_field.name
@@ -295,10 +390,23 @@ class CastOptions:
         object.__setattr__(self, "target_arrow_field", value)
 
     def target_child_arrow_field(self, index: int):
+        """Return a child target Arrow field by index.
+
+        Args:
+            index: Child index.
+
+        Returns:
+            Child Arrow field.
+        """
         return self._child_arrow_field(self.target_arrow_field, index=index)
 
     @property
     def target_polars_field(self):
+        """Return or compute the cached Polars field for the target.
+
+        Returns:
+            Polars field or None.
+        """
         if self.target_arrow_field is not None and self._target_polars_field is None:
             from ...types.cast.polars_cast import arrow_field_to_polars_field
 
@@ -307,6 +415,11 @@ class CastOptions:
 
     @property
     def target_spark_field(self):
+        """Return or compute the cached Spark field for the target.
+
+        Returns:
+            Spark field or None.
+        """
         if self.target_arrow_field is not None and self._target_spark_field is None:
             from ...types.cast.spark_cast import arrow_field_to_spark_field
 
@@ -329,6 +442,11 @@ class CastOptions:
 
     @property
     def target_spark_schema(self) -> Optional["pyspark.sql.types.StructType"]:
+        """Return a Spark schema view of the target Arrow schema.
+
+        Returns:
+            Spark StructType schema or None.
+        """
         arrow_schema = self.target_arrow_schema
 
        if arrow_schema is not None:
@@ -338,4 +456,4 @@ class CastOptions:
         return arrow_schema
 
 
-DEFAULT_INSTANCE = CastOptions()
+DEFAULT_INSTANCE = CastOptions()
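To close the loop, a hedged sketch of driving a cast with CastOptions. It assumes the target_field setter mirrors the source_field setter shown above (accepting a pa.Schema, matching the pa.Field | pa.Schema | pa.DataType annotation in the builder) and that cast_arrow_tabular is importable from yggdrasil.types.cast.arrow_cast; the sample data is a placeholder.

import pyarrow as pa

from yggdrasil.types.cast.arrow_cast import cast_arrow_tabular
from yggdrasil.types.cast.cast_options import CastOptions

# A bare CastOptions() is valid (see DEFAULT_INSTANCE above); the target schema
# is attached afterwards via the assumed target_field setter.
options = CastOptions()
options.target_field = pa.schema([
    pa.field("id", pa.int64()),
    pa.field("name", pa.string()),
])

batch = pa.record_batch(
    [pa.array([1, 2], type=pa.int32()), pa.array(["a", "b"])],
    names=["id", "name"],
)

# cast_arrow_tabular rebuilds the batch against the target schema (id widens to int64).
print(cast_arrow_tabular(batch, options).schema)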