sqlframe-1.1.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlframe/__init__.py +0 -0
- sqlframe/_version.py +16 -0
- sqlframe/base/__init__.py +0 -0
- sqlframe/base/_typing.py +39 -0
- sqlframe/base/catalog.py +1163 -0
- sqlframe/base/column.py +388 -0
- sqlframe/base/dataframe.py +1519 -0
- sqlframe/base/decorators.py +51 -0
- sqlframe/base/exceptions.py +14 -0
- sqlframe/base/function_alternatives.py +1055 -0
- sqlframe/base/functions.py +1678 -0
- sqlframe/base/group.py +102 -0
- sqlframe/base/mixins/__init__.py +0 -0
- sqlframe/base/mixins/catalog_mixins.py +419 -0
- sqlframe/base/mixins/readwriter_mixins.py +118 -0
- sqlframe/base/normalize.py +84 -0
- sqlframe/base/operations.py +87 -0
- sqlframe/base/readerwriter.py +679 -0
- sqlframe/base/session.py +585 -0
- sqlframe/base/transforms.py +13 -0
- sqlframe/base/types.py +418 -0
- sqlframe/base/util.py +242 -0
- sqlframe/base/window.py +139 -0
- sqlframe/bigquery/__init__.py +23 -0
- sqlframe/bigquery/catalog.py +255 -0
- sqlframe/bigquery/column.py +1 -0
- sqlframe/bigquery/dataframe.py +54 -0
- sqlframe/bigquery/functions.py +378 -0
- sqlframe/bigquery/group.py +14 -0
- sqlframe/bigquery/readwriter.py +29 -0
- sqlframe/bigquery/session.py +89 -0
- sqlframe/bigquery/types.py +1 -0
- sqlframe/bigquery/window.py +1 -0
- sqlframe/duckdb/__init__.py +20 -0
- sqlframe/duckdb/catalog.py +108 -0
- sqlframe/duckdb/column.py +1 -0
- sqlframe/duckdb/dataframe.py +55 -0
- sqlframe/duckdb/functions.py +47 -0
- sqlframe/duckdb/group.py +14 -0
- sqlframe/duckdb/readwriter.py +111 -0
- sqlframe/duckdb/session.py +65 -0
- sqlframe/duckdb/types.py +1 -0
- sqlframe/duckdb/window.py +1 -0
- sqlframe/postgres/__init__.py +23 -0
- sqlframe/postgres/catalog.py +106 -0
- sqlframe/postgres/column.py +1 -0
- sqlframe/postgres/dataframe.py +54 -0
- sqlframe/postgres/functions.py +61 -0
- sqlframe/postgres/group.py +14 -0
- sqlframe/postgres/readwriter.py +29 -0
- sqlframe/postgres/session.py +68 -0
- sqlframe/postgres/types.py +1 -0
- sqlframe/postgres/window.py +1 -0
- sqlframe/redshift/__init__.py +23 -0
- sqlframe/redshift/catalog.py +127 -0
- sqlframe/redshift/column.py +1 -0
- sqlframe/redshift/dataframe.py +54 -0
- sqlframe/redshift/functions.py +18 -0
- sqlframe/redshift/group.py +14 -0
- sqlframe/redshift/readwriter.py +29 -0
- sqlframe/redshift/session.py +53 -0
- sqlframe/redshift/types.py +1 -0
- sqlframe/redshift/window.py +1 -0
- sqlframe/snowflake/__init__.py +26 -0
- sqlframe/snowflake/catalog.py +134 -0
- sqlframe/snowflake/column.py +1 -0
- sqlframe/snowflake/dataframe.py +54 -0
- sqlframe/snowflake/functions.py +18 -0
- sqlframe/snowflake/group.py +14 -0
- sqlframe/snowflake/readwriter.py +29 -0
- sqlframe/snowflake/session.py +53 -0
- sqlframe/snowflake/types.py +1 -0
- sqlframe/snowflake/window.py +1 -0
- sqlframe/spark/__init__.py +23 -0
- sqlframe/spark/catalog.py +1028 -0
- sqlframe/spark/column.py +1 -0
- sqlframe/spark/dataframe.py +54 -0
- sqlframe/spark/functions.py +22 -0
- sqlframe/spark/group.py +14 -0
- sqlframe/spark/readwriter.py +29 -0
- sqlframe/spark/session.py +90 -0
- sqlframe/spark/types.py +1 -0
- sqlframe/spark/window.py +1 -0
- sqlframe/standalone/__init__.py +26 -0
- sqlframe/standalone/catalog.py +13 -0
- sqlframe/standalone/column.py +1 -0
- sqlframe/standalone/dataframe.py +36 -0
- sqlframe/standalone/functions.py +1 -0
- sqlframe/standalone/group.py +14 -0
- sqlframe/standalone/readwriter.py +19 -0
- sqlframe/standalone/session.py +40 -0
- sqlframe/standalone/types.py +1 -0
- sqlframe/standalone/window.py +1 -0
- sqlframe-1.1.3.dist-info/LICENSE +21 -0
- sqlframe-1.1.3.dist-info/METADATA +172 -0
- sqlframe-1.1.3.dist-info/RECORD +98 -0
- sqlframe-1.1.3.dist-info/WHEEL +5 -0
- sqlframe-1.1.3.dist-info/top_level.txt +1 -0
sqlframe/base/types.py
ADDED
@@ -0,0 +1,418 @@
# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.

from __future__ import annotations

import typing as t
from decimal import Decimal

from sqlframe.base.exceptions import RowError


class DataType:
    def __repr__(self) -> str:
        return self.__class__.__name__ + "()"

    def __hash__(self) -> int:
        return hash(str(self))

    def __eq__(self, other: t.Any) -> bool:
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def __ne__(self, other: t.Any) -> bool:
        return not self.__eq__(other)

    def __str__(self) -> str:
        return self.typeName()

    @classmethod
    def typeName(cls) -> str:
        return cls.__name__[:-4].lower()

    def simpleString(self) -> str:
        return str(self)

    def jsonValue(self) -> t.Union[str, t.Dict[str, t.Any]]:
        return str(self)


class DataTypeWithLength(DataType):
    def __init__(self, length: int):
        self.length = length

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.length})"

    def __str__(self) -> str:
        return f"{self.typeName()}({self.length})"


class StringType(DataType):
    pass


class CharType(DataTypeWithLength):
    pass


class VarcharType(DataTypeWithLength):
    pass


class BinaryType(DataType):
    pass


class BooleanType(DataType):
    pass


class DateType(DataType):
    pass


class TimestampType(DataType):
    pass


class TimestampNTZType(DataType):
    @classmethod
    def typeName(cls) -> str:
        return "timestamp_ntz"


class DecimalType(DataType):
    def __init__(self, precision: int = 10, scale: int = 0):
        self.precision = precision
        self.scale = scale

    def simpleString(self) -> str:
        return f"decimal({self.precision}, {self.scale})"

    def jsonValue(self) -> str:
        return f"decimal({self.precision}, {self.scale})"

    def __repr__(self) -> str:
        return f"DecimalType({self.precision}, {self.scale})"


class DoubleType(DataType):
    pass


class FloatType(DataType):
    pass


class ByteType(DataType):
    def __str__(self) -> str:
        return "tinyint"


class IntegerType(DataType):
    def __str__(self) -> str:
        return "int"


class LongType(DataType):
    def __str__(self) -> str:
        return "bigint"


class ShortType(DataType):
    def __str__(self) -> str:
        return "smallint"


class ArrayType(DataType):
    def __init__(self, elementType: DataType, containsNull: bool = True):
        self.elementType = elementType
        self.containsNull = containsNull

    def __repr__(self) -> str:
        return f"ArrayType({self.elementType, str(self.containsNull)}"

    def simpleString(self) -> str:
        return f"array<{self.elementType.simpleString()}>"

    def jsonValue(self) -> t.Dict[str, t.Any]:
        return {
            "type": self.typeName(),
            "elementType": self.elementType.jsonValue(),
            "containsNull": self.containsNull,
        }


class MapType(DataType):
    def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True):
        self.keyType = keyType
        self.valueType = valueType
        self.valueContainsNull = valueContainsNull

    def __repr__(self) -> str:
        return f"MapType({self.keyType}, {self.valueType}, {str(self.valueContainsNull)})"

    def simpleString(self) -> str:
        return f"map<{self.keyType.simpleString()}, {self.valueType.simpleString()}>"

    def jsonValue(self) -> t.Dict[str, t.Any]:
        return {
            "type": self.typeName(),
            "keyType": self.keyType.jsonValue(),
            "valueType": self.valueType.jsonValue(),
            "valueContainsNull": self.valueContainsNull,
        }


class StructField(DataType):
    def __init__(
        self,
        name: str,
        dataType: DataType,
        nullable: bool = True,
        metadata: t.Optional[t.Dict[str, t.Any]] = None,
    ):
        self.name = name
        self.dataType = dataType
        self.nullable = nullable
        self.metadata = metadata or {}

    def __repr__(self) -> str:
        return f"StructField('{self.name}', {self.dataType}, {str(self.nullable)})"

    def simpleString(self) -> str:
        return f"{self.name}:{self.dataType.simpleString()}"

    def jsonValue(self) -> t.Dict[str, t.Any]:
        return {
            "name": self.name,
            "type": self.dataType.jsonValue(),
            "nullable": self.nullable,
            "metadata": self.metadata,
        }


class StructType(DataType):
    def __init__(self, fields: t.Optional[t.List[StructField]] = None):
        if not fields:
            self.fields = []
            self.names = []
        else:
            self.fields = fields
            self.names = [f.name for f in fields]

    def __iter__(self) -> t.Iterator[StructField]:
        return iter(self.fields)

    def __len__(self) -> int:
        return len(self.fields)

    def __repr__(self) -> str:
        return f"StructType({', '.join(str(field) for field in self)})"

    def simpleString(self) -> str:
        return f"struct<{', '.join(x.simpleString() for x in self)}>"

    def jsonValue(self) -> t.Dict[str, t.Any]:
        return {"type": self.typeName(), "fields": [x.jsonValue() for x in self]}

    def fieldNames(self) -> t.List[str]:
        return list(self.names)


def _create_row(
    fields: t.Union[Row, t.List[str]], values: t.Union[t.Tuple[t.Any, ...], t.List[t.Any]]
) -> Row:
    row = Row(*values)
    row.__fields__ = fields
    return row


class Row(tuple):
    """
    A row in :class:`DataFrame`.
    The fields in it can be accessed:

    * like attributes (``row.key``)
    * like dictionary values (``row[key]``)

    ``key in row`` will search through row keys.

    Row can be used to create a row object by using named arguments.
    It is not allowed to omit a named argument to represent that the value is
    None or missing. This should be explicitly set to None in this case.

    .. versionchanged:: 3.0.0
        Rows created from named arguments no longer have
        field names sorted alphabetically and will be ordered in the position as
        entered.

    Examples
    --------
    >>> from pyspark.sql import Row
    >>> row = Row(name="Alice", age=11)
    >>> row
    Row(name='Alice', age=11)
    >>> row['name'], row['age']
    ('Alice', 11)
    >>> row.name, row.age
    ('Alice', 11)
    >>> 'name' in row
    True
    >>> 'wrong_key' in row
    False

    Row also can be used to create another Row like class, then it
    could be used to create Row objects, such as

    >>> Person = Row("name", "age")
    >>> Person
    <Row('name', 'age')>
    >>> 'name' in Person
    True
    >>> 'wrong_key' in Person
    False
    >>> Person("Alice", 11)
    Row(name='Alice', age=11)

    This form can also be used to create rows as tuple values, i.e. with unnamed
    fields.

    >>> row1 = Row("Alice", 11)
    >>> row2 = Row(name="Alice", age=11)
    >>> row1 == row2
    True
    """

    @t.overload
    def __new__(cls, *args: str) -> Row: ...

    @t.overload
    def __new__(cls, **kwargs: t.Any) -> Row: ...

    def __new__(cls, *args: t.Optional[str], **kwargs: t.Optional[t.Any]) -> Row:
        if args and kwargs:
            raise RowError("Cannot use both args and kwargs to create Row")
        if kwargs:
            # create row objects
            # psycopg2 returns Decimal type for numeric while PySpark returns float so we convert to float
            row = tuple.__new__(
                cls, [float(x) if isinstance(x, Decimal) else x for x in kwargs.values()]
            )
            row.__fields__ = list(kwargs.keys())
            return row
        else:
            # create row class or objects
            return tuple.__new__(cls, args)

    def asDict(self, recursive: bool = False) -> t.Dict[str, t.Any]:
        """
        Return as a dict

        Parameters
        ----------
        recursive : bool, optional
            turns the nested Rows to dict (default: False).

        Notes
        -----
        If a row contains duplicate field names, e.g., the rows of a join
        between two :class:`DataFrame` that both have the fields of same names,
        one of the duplicate fields will be selected by ``asDict``. ``__getitem__``
        will also return one of the duplicate fields, however returned value might
        be different to ``asDict``.

        Examples
        --------
        >>> from pyspark.sql import Row
        >>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11}
        True
        >>> row = Row(key=1, value=Row(name='a', age=2))
        >>> row.asDict() == {'key': 1, 'value': Row(name='a', age=2)}
        True
        >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}}
        True
        """
        if not hasattr(self, "__fields__"):
            raise RowError("Cannot convert a Row class into dict")

        if recursive:

            def conv(obj: t.Any) -> t.Any:
                if isinstance(obj, Row):
                    return obj.asDict(True)
                elif isinstance(obj, list):
                    return [conv(o) for o in obj]
                elif isinstance(obj, dict):
                    return dict((k, conv(v)) for k, v in obj.items())
                else:
                    return obj

            return dict(zip(self.__fields__, (conv(o) for o in self)))
        else:
            return dict(zip(self.__fields__, self))

    def __contains__(self, item: t.Any) -> bool:
        if hasattr(self, "__fields__"):
            return item in self.__fields__
        else:
            return super(Row, self).__contains__(item)

    # let object acts like class
    def __call__(self, *args: t.Any) -> Row:
        """create new Row object"""
        if len(args) > len(self):
            raise RowError(
                "Cannot create Row with fields %s. Expected %d values but got %s instead"
                % (self, len(self), args)
            )

        return _create_row(self, args)

    def __getitem__(self, item: t.Any) -> t.Any:
        if isinstance(item, (int, slice)):
            return super(Row, self).__getitem__(item)
        try:
            # it will be slow when it has many fields,
            # but this will not be used in normal cases
            idx = self.__fields__.index(item)
            return super(Row, self).__getitem__(idx)
        except IndexError:
            raise KeyError(item)
        except ValueError:
            raise RowError(item)

    def __getattr__(self, item: str) -> t.Any:
        if item.startswith("__"):
            raise AttributeError(item)
        try:
            # it will be slow when it has many fields,
            # but this will not be used in normal cases
            idx = self.__fields__.index(item)
            return self[idx]
        except IndexError:
            raise AttributeError(item)
        except ValueError:
            raise AttributeError(item)

    def __setattr__(self, key: t.Any, value: t.Any) -> None:
        if key != "__fields__":
            raise RuntimeError("Row is read-only")
        self.__dict__[key] = value

    def __reduce__(
        self,
    ) -> t.Union[str, t.Tuple[t.Any, ...]]:
        """Returns a tuple so Python knows how to pickle Row."""
        if hasattr(self, "__fields__"):
            return (_create_row, (self.__fields__, tuple(self)))
        else:
            return tuple.__reduce__(self)

    def __repr__(self) -> str:
        """Printable representation of Row used in Python REPL."""
        if hasattr(self, "__fields__"):
            return "Row(%s)" % ", ".join(
                "%s=%r" % (k, v) for k, v in zip(self.__fields__, tuple(self))
            )
        else:
            return "<Row(%s)>" % ", ".join(repr(field) for field in self)
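The file above mirrors PySpark's type and Row interfaces. For orientation only (this sketch is not part of the wheel and only uses names defined in the listing above), here is a minimal example of how the classes compose:

    # Illustrative sketch, not part of the package: exercising sqlframe.base.types.
    from sqlframe.base.types import IntegerType, Row, StringType, StructField, StructType

    # StructType/StructField render Spark-style type strings and JSON metadata.
    schema = StructType(
        [
            StructField("name", StringType()),
            StructField("age", IntegerType()),
        ]
    )
    print(schema.simpleString())            # struct<name:string, age:int>
    print(schema.jsonValue()["fields"][0])  # {'name': 'name', 'type': 'string', 'nullable': True, 'metadata': {}}

    # Keyword arguments build a concrete row; fields are accessible by attribute or key.
    row = Row(name="Alice", age=11)
    print(row.name, row["age"])             # Alice 11
    print(row.asDict())                     # {'name': 'Alice', 'age': 11}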
sqlframe/base/util.py
ADDED
@@ -0,0 +1,242 @@
from __future__ import annotations

import importlib
import typing as t
import unicodedata

from sqlglot import expressions as exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.schema import ensure_column_mapping as sqlglot_ensure_column_mapping

if t.TYPE_CHECKING:
    from pandas.core.frame import DataFrame as PandasDataFrame
    from pyspark.sql.dataframe import SparkSession as PySparkSession

    from sqlframe.base import types
    from sqlframe.base._typing import OptionalPrimitiveType, SchemaInput
    from sqlframe.base.session import _BaseSession
    from sqlframe.base.types import StructType


def decoded_str(value: t.Union[str, bytes]) -> str:
    if isinstance(value, bytes):
        return value.decode("utf-8")
    return value


def schema_(
    db: exp.Identifier | str,
    catalog: t.Optional[exp.Identifier | str] = None,
    quoted: t.Optional[bool] = None,
) -> exp.Table:
    """Build a Schema.

    Args:
        db: Database name.
        catalog: Catalog name.
        quoted: Whether to force quotes on the schema's identifiers.

    Returns:
        The new Schema instance.
    """
    return exp.Table(
        this=None,
        db=exp.to_identifier(db, quoted=quoted) if db else None,
        catalog=exp.to_identifier(catalog, quoted=quoted) if catalog else None,
    )


def to_schema(
    sql_path: t.Union[str, exp.Table], dialect: t.Optional[DialectType] = None
) -> exp.Table:
    if isinstance(sql_path, exp.Table) and sql_path.this is None:
        return sql_path
    table = exp.to_table(
        sql_path.copy() if isinstance(sql_path, exp.Table) else sql_path, dialect=dialect
    )
    table.set("catalog", table.args.get("db"))
    table.set("db", table.args.get("this"))
    table.set("this", None)
    return table


def get_column_mapping_from_schema_input(
    schema: SchemaInput, dialect: DialectType = None
) -> t.Dict[str, t.Optional[exp.DataType]]:
    from sqlframe.base import types

    if isinstance(schema, dict):
        value = schema
    elif isinstance(schema, str):
        col_name_type_strs = [x.strip() for x in schema.split(",")]
        if len(col_name_type_strs) == 1 and len(col_name_type_strs[0].split(" ")) == 1:
            value = {"value": col_name_type_strs[0].strip()}
        else:
            value = {
                name_type_str.split(" ")[0].strip(): name_type_str.split(" ")[1].strip()
                for name_type_str in col_name_type_strs
            }
    elif isinstance(schema, types.StructType):
        value = {struct_field.name: struct_field.dataType.simpleString() for struct_field in schema}
    else:
        value = {x.strip(): None for x in schema}
    return {
        exp.to_column(k).sql(dialect=dialect): exp.DataType.build(v, dialect=dialect)
        if v is not None
        else v
        for k, v in value.items()
    }
    # return {x.strip(): None for x in schema}  # type: ignore


def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.Table]:
    if not expression.args.get("joins"):
        return []

    left_table = expression.args["from"].this
    other_tables = [join.this for join in expression.args["joins"]]
    return [left_table] + other_tables


def to_csv(options: t.Dict[str, OptionalPrimitiveType], equality_char: str = "=") -> str:
    return ", ".join(
        [f"{k}{equality_char}{v}" for k, v in (options or {}).items() if v is not None]
    )


def ensure_column_mapping(schema: t.Union[str, StructType]) -> t.Dict:
    if isinstance(schema, str):
        col_name_type_strs = [x.strip() for x in schema.split(",")]
        schema = {  # type: ignore
            name_type_str.split(" ")[0].strip(): name_type_str.split(" ")[1].strip()
            for name_type_str in col_name_type_strs
        }
    # TODO: Make a protocol with a `simpleString` attribute as what it looks for instead of the actual
    # `StructType` object.
    elif hasattr(schema, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in schema}
    return sqlglot_ensure_column_mapping(schema)  # type: ignore


# SO: https://stackoverflow.com/questions/37513355/converting-pandas-dataframe-into-spark-dataframe-error
def get_equivalent_spark_type(pandas_type) -> types.DataType:
    """
    This method will retrieve the corresponding spark type given a pandas
    type.

    Args:
        pandas_type (str): pandas data type

    Returns:
        spark data type
    """
    from sqlframe.base import types

    type_map = {
        "datetime64[ns]": types.TimestampType(),
        "int64": types.LongType(),
        "int32": types.IntegerType(),
        "float64": types.DoubleType(),
        "float32": types.FloatType(),
    }
    return type_map.get(str(pandas_type).lower(), types.StringType())


def pandas_to_spark_schema(pandas_df: PandasDataFrame) -> types.StructType:
    """
    This method will return a spark dataframe schema given a pandas dataframe.

    Args:
        pandas_df (pandas.core.frame.DataFrame): pandas DataFrame

    Returns:
        equivalent spark DataFrame schema
    """
    from sqlframe.base import types

    columns = list([x.replace("?column?", "unknown_column") for x in pandas_df.columns])
    d_types = list(pandas_df.dtypes)
    p_schema = types.StructType(
        [
            types.StructField(column, get_equivalent_spark_type(pandas_type))
            for column, pandas_type in zip(columns, d_types)
        ]
    )
    return p_schema


def dialect_to_string(dialect: Dialect) -> str:
    mapping = {v: k for k, v in Dialect.classes.items()}
    return mapping[type(dialect)]


def get_func_from_session(
    name: str,
    session: t.Optional[t.Union[_BaseSession, PySparkSession]] = None,
    fallback: bool = True,
) -> t.Callable:
    from sqlframe.base.session import _BaseSession

    session = session if session else _BaseSession()

    if isinstance(session, _BaseSession):
        dialect_str = dialect_to_string(session.input_dialect)
        import_path = f"sqlframe.{dialect_str}.functions"
    else:
        import_path = "pyspark.sql.functions"
    try:
        func = getattr(importlib.import_module(import_path), name)
    except AttributeError as e:
        if not fallback:
            raise e
        func = getattr(importlib.import_module("sqlframe.base.functions"), name)
        if session.output_dialect in func.unsupported_engines:  # type: ignore
            raise NotImplementedError(
                f"{name} is not supported by the engine: {session.output_dialect}"  # type: ignore
            )
    return func


def soundex(s):
    if not s:
        return ""

    s = unicodedata.normalize("NFKD", s)
    s = s.upper()

    replacements = (
        ("BFPV", "1"),
        ("CGJKQSXZ", "2"),
        ("DT", "3"),
        ("L", "4"),
        ("MN", "5"),
        ("R", "6"),
    )
    result = [s[0]]
    count = 1

    # find would-be replacment for first character
    for lset, sub in replacements:
        if s[0] in lset:
            last = sub
            break
    else:
        last = None

    for letter in s[1:]:
        for lset, sub in replacements:
            if letter in lset:
                if sub != last:
                    result.append(sub)
                    count += 1
                last = sub
                break
        else:
            if letter != "H" and letter != "W":
                # leave last alone if middle letter is H or W
                last = None
        if count == 4:
            break

    result += "0" * (4 - count)
    return "".join(result)
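For orientation only (again, this sketch is not part of the wheel and uses only helpers defined in the listing above), a short tour of the utilities:

    # Illustrative sketch, not part of the package: exercising sqlframe.base.util.
    from sqlframe.base.util import get_equivalent_spark_type, soundex, to_csv, to_schema

    # to_schema shifts a dotted path one level up: "catalog.db" rather than "db.table".
    schema = to_schema("my_catalog.my_db")
    print(schema.catalog, schema.db)             # my_catalog my_db

    print(get_equivalent_spark_type("int64"))    # bigint
    print(to_csv({"header": True, "sep": ","}))  # header=True, sep=,
    print(soundex("Robert"))                     # R163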