sqlframe-1.1.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. sqlframe/__init__.py +0 -0
  2. sqlframe/_version.py +16 -0
  3. sqlframe/base/__init__.py +0 -0
  4. sqlframe/base/_typing.py +39 -0
  5. sqlframe/base/catalog.py +1163 -0
  6. sqlframe/base/column.py +388 -0
  7. sqlframe/base/dataframe.py +1519 -0
  8. sqlframe/base/decorators.py +51 -0
  9. sqlframe/base/exceptions.py +14 -0
  10. sqlframe/base/function_alternatives.py +1055 -0
  11. sqlframe/base/functions.py +1678 -0
  12. sqlframe/base/group.py +102 -0
  13. sqlframe/base/mixins/__init__.py +0 -0
  14. sqlframe/base/mixins/catalog_mixins.py +419 -0
  15. sqlframe/base/mixins/readwriter_mixins.py +118 -0
  16. sqlframe/base/normalize.py +84 -0
  17. sqlframe/base/operations.py +87 -0
  18. sqlframe/base/readerwriter.py +679 -0
  19. sqlframe/base/session.py +585 -0
  20. sqlframe/base/transforms.py +13 -0
  21. sqlframe/base/types.py +418 -0
  22. sqlframe/base/util.py +242 -0
  23. sqlframe/base/window.py +139 -0
  24. sqlframe/bigquery/__init__.py +23 -0
  25. sqlframe/bigquery/catalog.py +255 -0
  26. sqlframe/bigquery/column.py +1 -0
  27. sqlframe/bigquery/dataframe.py +54 -0
  28. sqlframe/bigquery/functions.py +378 -0
  29. sqlframe/bigquery/group.py +14 -0
  30. sqlframe/bigquery/readwriter.py +29 -0
  31. sqlframe/bigquery/session.py +89 -0
  32. sqlframe/bigquery/types.py +1 -0
  33. sqlframe/bigquery/window.py +1 -0
  34. sqlframe/duckdb/__init__.py +20 -0
  35. sqlframe/duckdb/catalog.py +108 -0
  36. sqlframe/duckdb/column.py +1 -0
  37. sqlframe/duckdb/dataframe.py +55 -0
  38. sqlframe/duckdb/functions.py +47 -0
  39. sqlframe/duckdb/group.py +14 -0
  40. sqlframe/duckdb/readwriter.py +111 -0
  41. sqlframe/duckdb/session.py +65 -0
  42. sqlframe/duckdb/types.py +1 -0
  43. sqlframe/duckdb/window.py +1 -0
  44. sqlframe/postgres/__init__.py +23 -0
  45. sqlframe/postgres/catalog.py +106 -0
  46. sqlframe/postgres/column.py +1 -0
  47. sqlframe/postgres/dataframe.py +54 -0
  48. sqlframe/postgres/functions.py +61 -0
  49. sqlframe/postgres/group.py +14 -0
  50. sqlframe/postgres/readwriter.py +29 -0
  51. sqlframe/postgres/session.py +68 -0
  52. sqlframe/postgres/types.py +1 -0
  53. sqlframe/postgres/window.py +1 -0
  54. sqlframe/redshift/__init__.py +23 -0
  55. sqlframe/redshift/catalog.py +127 -0
  56. sqlframe/redshift/column.py +1 -0
  57. sqlframe/redshift/dataframe.py +54 -0
  58. sqlframe/redshift/functions.py +18 -0
  59. sqlframe/redshift/group.py +14 -0
  60. sqlframe/redshift/readwriter.py +29 -0
  61. sqlframe/redshift/session.py +53 -0
  62. sqlframe/redshift/types.py +1 -0
  63. sqlframe/redshift/window.py +1 -0
  64. sqlframe/snowflake/__init__.py +26 -0
  65. sqlframe/snowflake/catalog.py +134 -0
  66. sqlframe/snowflake/column.py +1 -0
  67. sqlframe/snowflake/dataframe.py +54 -0
  68. sqlframe/snowflake/functions.py +18 -0
  69. sqlframe/snowflake/group.py +14 -0
  70. sqlframe/snowflake/readwriter.py +29 -0
  71. sqlframe/snowflake/session.py +53 -0
  72. sqlframe/snowflake/types.py +1 -0
  73. sqlframe/snowflake/window.py +1 -0
  74. sqlframe/spark/__init__.py +23 -0
  75. sqlframe/spark/catalog.py +1028 -0
  76. sqlframe/spark/column.py +1 -0
  77. sqlframe/spark/dataframe.py +54 -0
  78. sqlframe/spark/functions.py +22 -0
  79. sqlframe/spark/group.py +14 -0
  80. sqlframe/spark/readwriter.py +29 -0
  81. sqlframe/spark/session.py +90 -0
  82. sqlframe/spark/types.py +1 -0
  83. sqlframe/spark/window.py +1 -0
  84. sqlframe/standalone/__init__.py +26 -0
  85. sqlframe/standalone/catalog.py +13 -0
  86. sqlframe/standalone/column.py +1 -0
  87. sqlframe/standalone/dataframe.py +36 -0
  88. sqlframe/standalone/functions.py +1 -0
  89. sqlframe/standalone/group.py +14 -0
  90. sqlframe/standalone/readwriter.py +19 -0
  91. sqlframe/standalone/session.py +40 -0
  92. sqlframe/standalone/types.py +1 -0
  93. sqlframe/standalone/window.py +1 -0
  94. sqlframe-1.1.3.dist-info/LICENSE +21 -0
  95. sqlframe-1.1.3.dist-info/METADATA +172 -0
  96. sqlframe-1.1.3.dist-info/RECORD +98 -0
  97. sqlframe-1.1.3.dist-info/WHEEL +5 -0
  98. sqlframe-1.1.3.dist-info/top_level.txt +1 -0
sqlframe/base/types.py ADDED
@@ -0,0 +1,418 @@
# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.

from __future__ import annotations

import typing as t
from decimal import Decimal

from sqlframe.base.exceptions import RowError


class DataType:
    def __repr__(self) -> str:
        return self.__class__.__name__ + "()"

    def __hash__(self) -> int:
        return hash(str(self))

    def __eq__(self, other: t.Any) -> bool:
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def __ne__(self, other: t.Any) -> bool:
        return not self.__eq__(other)

    def __str__(self) -> str:
        return self.typeName()

    @classmethod
    def typeName(cls) -> str:
        return cls.__name__[:-4].lower()

    def simpleString(self) -> str:
        return str(self)

    def jsonValue(self) -> t.Union[str, t.Dict[str, t.Any]]:
        return str(self)


class DataTypeWithLength(DataType):
    def __init__(self, length: int):
        self.length = length

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.length})"

    def __str__(self) -> str:
        return f"{self.typeName()}({self.length})"


class StringType(DataType):
    pass


class CharType(DataTypeWithLength):
    pass


class VarcharType(DataTypeWithLength):
    pass


class BinaryType(DataType):
    pass


class BooleanType(DataType):
    pass


class DateType(DataType):
    pass


class TimestampType(DataType):
    pass


class TimestampNTZType(DataType):
    @classmethod
    def typeName(cls) -> str:
        return "timestamp_ntz"


class DecimalType(DataType):
    def __init__(self, precision: int = 10, scale: int = 0):
        self.precision = precision
        self.scale = scale

    def simpleString(self) -> str:
        return f"decimal({self.precision}, {self.scale})"

    def jsonValue(self) -> str:
        return f"decimal({self.precision}, {self.scale})"

    def __repr__(self) -> str:
        return f"DecimalType({self.precision}, {self.scale})"


class DoubleType(DataType):
    pass


class FloatType(DataType):
    pass


class ByteType(DataType):
    def __str__(self) -> str:
        return "tinyint"


class IntegerType(DataType):
    def __str__(self) -> str:
        return "int"


class LongType(DataType):
    def __str__(self) -> str:
        return "bigint"


class ShortType(DataType):
    def __str__(self) -> str:
        return "smallint"


class ArrayType(DataType):
    def __init__(self, elementType: DataType, containsNull: bool = True):
        self.elementType = elementType
        self.containsNull = containsNull

    def __repr__(self) -> str:
        return f"ArrayType({self.elementType}, {self.containsNull})"

    def simpleString(self) -> str:
        return f"array<{self.elementType.simpleString()}>"

    def jsonValue(self) -> t.Dict[str, t.Any]:
        return {
            "type": self.typeName(),
            "elementType": self.elementType.jsonValue(),
            "containsNull": self.containsNull,
        }


class MapType(DataType):
    def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True):
        self.keyType = keyType
        self.valueType = valueType
        self.valueContainsNull = valueContainsNull

    def __repr__(self) -> str:
        return f"MapType({self.keyType}, {self.valueType}, {str(self.valueContainsNull)})"

    def simpleString(self) -> str:
        return f"map<{self.keyType.simpleString()}, {self.valueType.simpleString()}>"

    def jsonValue(self) -> t.Dict[str, t.Any]:
        return {
            "type": self.typeName(),
            "keyType": self.keyType.jsonValue(),
            "valueType": self.valueType.jsonValue(),
            "valueContainsNull": self.valueContainsNull,
        }


class StructField(DataType):
    def __init__(
        self,
        name: str,
        dataType: DataType,
        nullable: bool = True,
        metadata: t.Optional[t.Dict[str, t.Any]] = None,
    ):
        self.name = name
        self.dataType = dataType
        self.nullable = nullable
        self.metadata = metadata or {}

    def __repr__(self) -> str:
        return f"StructField('{self.name}', {self.dataType}, {str(self.nullable)})"

    def simpleString(self) -> str:
        return f"{self.name}:{self.dataType.simpleString()}"

    def jsonValue(self) -> t.Dict[str, t.Any]:
        return {
            "name": self.name,
            "type": self.dataType.jsonValue(),
            "nullable": self.nullable,
            "metadata": self.metadata,
        }


class StructType(DataType):
    def __init__(self, fields: t.Optional[t.List[StructField]] = None):
        if not fields:
            self.fields = []
            self.names = []
        else:
            self.fields = fields
            self.names = [f.name for f in fields]

    def __iter__(self) -> t.Iterator[StructField]:
        return iter(self.fields)

    def __len__(self) -> int:
        return len(self.fields)

    def __repr__(self) -> str:
        return f"StructType({', '.join(str(field) for field in self)})"

    def simpleString(self) -> str:
        return f"struct<{', '.join(x.simpleString() for x in self)}>"

    def jsonValue(self) -> t.Dict[str, t.Any]:
        return {"type": self.typeName(), "fields": [x.jsonValue() for x in self]}

    def fieldNames(self) -> t.List[str]:
        return list(self.names)


def _create_row(
    fields: t.Union[Row, t.List[str]], values: t.Union[t.Tuple[t.Any, ...], t.List[t.Any]]
) -> Row:
    row = Row(*values)
    row.__fields__ = fields
    return row


class Row(tuple):
    """
    A row in :class:`DataFrame`.
    The fields in it can be accessed:

    * like attributes (``row.key``)
    * like dictionary values (``row[key]``)

    ``key in row`` will search through row keys.

    Row can be used to create a row object by using named arguments.
    It is not allowed to omit a named argument to represent that the value is
    None or missing. This should be explicitly set to None in this case.

    .. versionchanged:: 3.0.0
        Rows created from named arguments no longer have
        field names sorted alphabetically and will be ordered in the position as
        entered.

    Examples
    --------
    >>> from pyspark.sql import Row
    >>> row = Row(name="Alice", age=11)
    >>> row
    Row(name='Alice', age=11)
    >>> row['name'], row['age']
    ('Alice', 11)
    >>> row.name, row.age
    ('Alice', 11)
    >>> 'name' in row
    True
    >>> 'wrong_key' in row
    False

    Row also can be used to create another Row like class, then it
    could be used to create Row objects, such as

    >>> Person = Row("name", "age")
    >>> Person
    <Row('name', 'age')>
    >>> 'name' in Person
    True
    >>> 'wrong_key' in Person
    False
    >>> Person("Alice", 11)
    Row(name='Alice', age=11)

    This form can also be used to create rows as tuple values, i.e. with unnamed
    fields.

    >>> row1 = Row("Alice", 11)
    >>> row2 = Row(name="Alice", age=11)
    >>> row1 == row2
    True
    """

    @t.overload
    def __new__(cls, *args: str) -> Row: ...

    @t.overload
    def __new__(cls, **kwargs: t.Any) -> Row: ...

    def __new__(cls, *args: t.Optional[str], **kwargs: t.Optional[t.Any]) -> Row:
        if args and kwargs:
            raise RowError("Cannot use both args and kwargs to create Row")
        if kwargs:
            # create row objects
            # psycopg2 returns Decimal type for numeric while PySpark returns float so we convert to float
            row = tuple.__new__(
                cls, [float(x) if isinstance(x, Decimal) else x for x in kwargs.values()]
            )
            row.__fields__ = list(kwargs.keys())
            return row
        else:
            # create row class or objects
            return tuple.__new__(cls, args)

    def asDict(self, recursive: bool = False) -> t.Dict[str, t.Any]:
        """
        Return as a dict

        Parameters
        ----------
        recursive : bool, optional
            turns the nested Rows to dict (default: False).

        Notes
        -----
        If a row contains duplicate field names, e.g., the rows of a join
        between two :class:`DataFrame` that both have the fields of same names,
        one of the duplicate fields will be selected by ``asDict``. ``__getitem__``
        will also return one of the duplicate fields, however returned value might
        be different to ``asDict``.

        Examples
        --------
        >>> from pyspark.sql import Row
        >>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11}
        True
        >>> row = Row(key=1, value=Row(name='a', age=2))
        >>> row.asDict() == {'key': 1, 'value': Row(name='a', age=2)}
        True
        >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}}
        True
        """
        if not hasattr(self, "__fields__"):
            raise RowError("Cannot convert a Row class into dict")

        if recursive:

            def conv(obj: t.Any) -> t.Any:
                if isinstance(obj, Row):
                    return obj.asDict(True)
                elif isinstance(obj, list):
                    return [conv(o) for o in obj]
                elif isinstance(obj, dict):
                    return dict((k, conv(v)) for k, v in obj.items())
                else:
                    return obj

            return dict(zip(self.__fields__, (conv(o) for o in self)))
        else:
            return dict(zip(self.__fields__, self))

    def __contains__(self, item: t.Any) -> bool:
        if hasattr(self, "__fields__"):
            return item in self.__fields__
        else:
            return super(Row, self).__contains__(item)

    # let the object act like a class
    def __call__(self, *args: t.Any) -> Row:
        """create new Row object"""
        if len(args) > len(self):
            raise RowError(
                "Cannot create Row with fields %s. Expected %d values but got %s instead"
                % (self, len(self), args)
            )

        return _create_row(self, args)

    def __getitem__(self, item: t.Any) -> t.Any:
        if isinstance(item, (int, slice)):
            return super(Row, self).__getitem__(item)
        try:
            # it will be slow when it has many fields,
            # but this will not be used in normal cases
            idx = self.__fields__.index(item)
            return super(Row, self).__getitem__(idx)
        except IndexError:
            raise KeyError(item)
        except ValueError:
            raise RowError(item)

    def __getattr__(self, item: str) -> t.Any:
        if item.startswith("__"):
            raise AttributeError(item)
        try:
            # it will be slow when it has many fields,
            # but this will not be used in normal cases
            idx = self.__fields__.index(item)
            return self[idx]
        except IndexError:
            raise AttributeError(item)
        except ValueError:
            raise AttributeError(item)

    def __setattr__(self, key: t.Any, value: t.Any) -> None:
        if key != "__fields__":
            raise RuntimeError("Row is read-only")
        self.__dict__[key] = value

    def __reduce__(
        self,
    ) -> t.Union[str, t.Tuple[t.Any, ...]]:
        """Returns a tuple so Python knows how to pickle Row."""
        if hasattr(self, "__fields__"):
            return (_create_row, (self.__fields__, tuple(self)))
        else:
            return tuple.__reduce__(self)

    def __repr__(self) -> str:
        """Printable representation of Row used in Python REPL."""
        if hasattr(self, "__fields__"):
            return "Row(%s)" % ", ".join(
                "%s=%r" % (k, v) for k, v in zip(self.__fields__, tuple(self))
            )
        else:
            return "<Row(%s)>" % ", ".join(repr(field) for field in self)
sqlframe/base/util.py ADDED
@@ -0,0 +1,242 @@
from __future__ import annotations

import importlib
import typing as t
import unicodedata

from sqlglot import expressions as exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.schema import ensure_column_mapping as sqlglot_ensure_column_mapping

if t.TYPE_CHECKING:
    from pandas.core.frame import DataFrame as PandasDataFrame
    from pyspark.sql.dataframe import SparkSession as PySparkSession

    from sqlframe.base import types
    from sqlframe.base._typing import OptionalPrimitiveType, SchemaInput
    from sqlframe.base.session import _BaseSession
    from sqlframe.base.types import StructType


def decoded_str(value: t.Union[str, bytes]) -> str:
    if isinstance(value, bytes):
        return value.decode("utf-8")
    return value


def schema_(
    db: exp.Identifier | str,
    catalog: t.Optional[exp.Identifier | str] = None,
    quoted: t.Optional[bool] = None,
) -> exp.Table:
    """Build a Schema.

    Args:
        db: Database name.
        catalog: Catalog name.
        quoted: Whether to force quotes on the schema's identifiers.

    Returns:
        The new Schema instance.
    """
    return exp.Table(
        this=None,
        db=exp.to_identifier(db, quoted=quoted) if db else None,
        catalog=exp.to_identifier(catalog, quoted=quoted) if catalog else None,
    )


def to_schema(
    sql_path: t.Union[str, exp.Table], dialect: t.Optional[DialectType] = None
) -> exp.Table:
    if isinstance(sql_path, exp.Table) and sql_path.this is None:
        return sql_path
    table = exp.to_table(
        sql_path.copy() if isinstance(sql_path, exp.Table) else sql_path, dialect=dialect
    )
    table.set("catalog", table.args.get("db"))
    table.set("db", table.args.get("this"))
    table.set("this", None)
    return table


def get_column_mapping_from_schema_input(
    schema: SchemaInput, dialect: DialectType = None
) -> t.Dict[str, t.Optional[exp.DataType]]:
    from sqlframe.base import types

    if isinstance(schema, dict):
        value = schema
    elif isinstance(schema, str):
        col_name_type_strs = [x.strip() for x in schema.split(",")]
        if len(col_name_type_strs) == 1 and len(col_name_type_strs[0].split(" ")) == 1:
            value = {"value": col_name_type_strs[0].strip()}
        else:
            value = {
                name_type_str.split(" ")[0].strip(): name_type_str.split(" ")[1].strip()
                for name_type_str in col_name_type_strs
            }
    elif isinstance(schema, types.StructType):
        value = {struct_field.name: struct_field.dataType.simpleString() for struct_field in schema}
    else:
        value = {x.strip(): None for x in schema}
    return {
        exp.to_column(k).sql(dialect=dialect): exp.DataType.build(v, dialect=dialect)
        if v is not None
        else v
        for k, v in value.items()
    }
    # return {x.strip(): None for x in schema} # type: ignore


def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.Table]:
    if not expression.args.get("joins"):
        return []

    left_table = expression.args["from"].this
    other_tables = [join.this for join in expression.args["joins"]]
    return [left_table] + other_tables


def to_csv(options: t.Dict[str, OptionalPrimitiveType], equality_char: str = "=") -> str:
    return ", ".join(
        [f"{k}{equality_char}{v}" for k, v in (options or {}).items() if v is not None]
    )


def ensure_column_mapping(schema: t.Union[str, StructType]) -> t.Dict:
    if isinstance(schema, str):
        col_name_type_strs = [x.strip() for x in schema.split(",")]
        schema = {  # type: ignore
            name_type_str.split(" ")[0].strip(): name_type_str.split(" ")[1].strip()
            for name_type_str in col_name_type_strs
        }
    # TODO: Make a protocol with a `simpleString` attribute as what it looks for instead of the actual
    # `StructType` object.
    elif hasattr(schema, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in schema}
    return sqlglot_ensure_column_mapping(schema)  # type: ignore


# SO: https://stackoverflow.com/questions/37513355/converting-pandas-dataframe-into-spark-dataframe-error
def get_equivalent_spark_type(pandas_type) -> types.DataType:
    """
    This method will retrieve the corresponding spark type given a pandas
    type.

    Args:
        pandas_type (str): pandas data type

    Returns:
        spark data type
    """
    from sqlframe.base import types

    type_map = {
        "datetime64[ns]": types.TimestampType(),
        "int64": types.LongType(),
        "int32": types.IntegerType(),
        "float64": types.DoubleType(),
        "float32": types.FloatType(),
    }
    return type_map.get(str(pandas_type).lower(), types.StringType())


def pandas_to_spark_schema(pandas_df: PandasDataFrame) -> types.StructType:
    """
    This method will return a spark dataframe schema given a pandas dataframe.

    Args:
        pandas_df (pandas.core.frame.DataFrame): pandas DataFrame

    Returns:
        equivalent spark DataFrame schema
    """
    from sqlframe.base import types

    columns = list([x.replace("?column?", "unknown_column") for x in pandas_df.columns])
    d_types = list(pandas_df.dtypes)
    p_schema = types.StructType(
        [
            types.StructField(column, get_equivalent_spark_type(pandas_type))
            for column, pandas_type in zip(columns, d_types)
        ]
    )
    return p_schema


def dialect_to_string(dialect: Dialect) -> str:
    mapping = {v: k for k, v in Dialect.classes.items()}
    return mapping[type(dialect)]


def get_func_from_session(
    name: str,
    session: t.Optional[t.Union[_BaseSession, PySparkSession]] = None,
    fallback: bool = True,
) -> t.Callable:
    from sqlframe.base.session import _BaseSession

    session = session if session else _BaseSession()

    if isinstance(session, _BaseSession):
        dialect_str = dialect_to_string(session.input_dialect)
        import_path = f"sqlframe.{dialect_str}.functions"
    else:
        import_path = "pyspark.sql.functions"
    try:
        func = getattr(importlib.import_module(import_path), name)
    except AttributeError as e:
        if not fallback:
            raise e
        func = getattr(importlib.import_module("sqlframe.base.functions"), name)
        if session.output_dialect in func.unsupported_engines:  # type: ignore
            raise NotImplementedError(
                f"{name} is not supported by the engine: {session.output_dialect}"  # type: ignore
            )
    return func


def soundex(s):
    if not s:
        return ""

    s = unicodedata.normalize("NFKD", s)
    s = s.upper()

    replacements = (
        ("BFPV", "1"),
        ("CGJKQSXZ", "2"),
        ("DT", "3"),
        ("L", "4"),
        ("MN", "5"),
        ("R", "6"),
    )
    result = [s[0]]
    count = 1

    # find would-be replacement for first character
    for lset, sub in replacements:
        if s[0] in lset:
            last = sub
            break
    else:
        last = None

    for letter in s[1:]:
        for lset, sub in replacements:
            if letter in lset:
                if sub != last:
                    result.append(sub)
                    count += 1
                last = sub
                break
        else:
            if letter != "H" and letter != "W":
                # leave last alone if middle letter is H or W
                last = None
        if count == 4:
            break

    result += "0" * (4 - count)
    return "".join(result)