sqlframe 3.31.4__py3-none-any.whl → 3.32.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlframe/_version.py +2 -2
- sqlframe/base/catalog.py +12 -1
- sqlframe/base/dataframe.py +3 -0
- sqlframe/base/functions.py +7 -7
- sqlframe/base/util.py +87 -0
- {sqlframe-3.31.4.dist-info → sqlframe-3.32.1.dist-info}/METADATA +1 -1
- {sqlframe-3.31.4.dist-info → sqlframe-3.32.1.dist-info}/RECORD +10 -10
- {sqlframe-3.31.4.dist-info → sqlframe-3.32.1.dist-info}/LICENSE +0 -0
- {sqlframe-3.31.4.dist-info → sqlframe-3.32.1.dist-info}/WHEEL +0 -0
- {sqlframe-3.31.4.dist-info → sqlframe-3.32.1.dist-info}/top_level.txt +0 -0
sqlframe/_version.py
CHANGED
sqlframe/base/catalog.py
CHANGED
@@ -6,9 +6,16 @@ import typing as t
 from collections import defaultdict
 
 from sqlglot import MappingSchema, exp
+from sqlglot.helper import seq_get
 
+from sqlframe.base import types
 from sqlframe.base.exceptions import TableSchemaError
-from sqlframe.base.util import
+from sqlframe.base.util import (
+    ensure_column_mapping,
+    normalize_string,
+    spark_to_sqlglot,
+    to_schema,
+)
 
 if t.TYPE_CHECKING:
     from sqlglot.schema import ColumnMapping
@@ -99,6 +106,10 @@ class _BaseCatalog(t.Generic[SESSION, DF, TABLE]):
                 "This session does not have access to a catalog that can lookup column information. See docs for explicitly defining columns or using a session that can automatically determine this."
             )
         column_mapping = ensure_column_mapping(column_mapping)  # type: ignore
+        if isinstance(column_mapping, dict) and isinstance(
+            seq_get(list(column_mapping.values()), 0), types.DataType
+        ):
+            column_mapping = {k: spark_to_sqlglot(v) for k, v in column_mapping.items()}
         for column_name in column_mapping:
            column = exp.to_column(column_name, dialect=self.session.input_dialect)
            if column.this.quoted:
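For context, the added branch lets a column mapping carry Spark DataType values and converts them to SQLGlot types before the existing loop runs. Below is a minimal standalone sketch of that normalization; the helper name _normalize_mapping is hypothetical, while the imports are the ones the diff introduces.

# Illustrative sketch only (not part of the diff).
from sqlglot.helper import seq_get

from sqlframe.base import types
from sqlframe.base.util import spark_to_sqlglot


def _normalize_mapping(column_mapping: dict) -> dict:
    # If the first value is a Spark DataType, convert every value to a
    # SQLGlot exp.DataType, mirroring the branch added to _BaseCatalog.
    if isinstance(seq_get(list(column_mapping.values()), 0), types.DataType):
        return {k: spark_to_sqlglot(v) for k, v in column_mapping.items()}
    return column_mapping


# e.g. {"id": types.IntegerType(), "name": types.StringType()} becomes a dict of
# SQLGlot DataType expressions (INT, TEXT).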
sqlframe/base/dataframe.py
CHANGED
@@ -260,6 +260,9 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
     def __copy__(self):
         return self.copy()
 
+    def _display_(self) -> str:
+        return self.__repr__()
+
     @property
     def _typed_columns(self) -> t.List[CatalogColumn]:
         raise NotImplementedError
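The new _display_ method simply delegates to __repr__; it reads as a rich-display hook for front ends that probe for a _display_ method, though the diff does not name a target environment. A hypothetical call, assuming an existing sqlframe session:

# Hypothetical usage sketch: a front end that looks for a _display_ hook now
# gets the same text that repr() produces for a DataFrame.
df = session.createDataFrame([(1, "a")], ["id", "value"])  # assumes `session` exists
assert df._display_() == repr(df)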
sqlframe/base/functions.py
CHANGED
@@ -494,21 +494,21 @@ def skewness(col: ColumnOrName) -> Column:
     func_name = "SKEW"
 
     if session._is_duckdb or session._is_snowflake:
+        col = Column.ensure_col(col)
         when_func = get_func_from_session("when")
         count_func = get_func_from_session("count")
-
+        count_col = count_func(col)
         lit_func = get_func_from_session("lit")
         sqrt_func = get_func_from_session("sqrt")
-        col = Column.ensure_col(col)
         full_calc = (
             Column.invoke_anonymous_function(col, func_name)
-            * (
-            / (sqrt_func(
+            * (count_col - lit_func(2))
+            / (sqrt_func(count_col * (count_col - lit_func(1))))
         )
         return (
-            when_func(
-            .when(
-            .when(
+            when_func(count_col == lit_func(0), lit_func(None))
+            .when(count_col == lit_func(1), lit_func(None))
+            .when(count_col == lit_func(2), lit_func(0.0))
             .otherwise(full_calc)
         )
 
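The rewritten DuckDB/Snowflake branch builds the count expression once (count_col) and rescales the engine's SKEW result, returning NULL for counts of 0 or 1 and 0.0 for a count of 2. The arithmetic matches the usual conversion from a bias-corrected sample skewness G1 to Spark's population skewness, though the diff itself does not state that intent. A plain-Python sketch of the same arithmetic (the function name and the G1 interpretation are assumptions):

import math
from typing import Optional

# Sketch of the adjustment encoded in full_calc, assuming the engine's SKEW
# returns the sample (bias-corrected) skewness G1:
#   spark_skew = G1 * (n - 2) / sqrt(n * (n - 1))
def spark_style_skewness(engine_skew: float, n: int) -> Optional[float]:
    if n in (0, 1):
        return None   # mirrors when(count == 0 or 1, lit(None))
    if n == 2:
        return 0.0    # mirrors when(count == 2, lit(0.0))
    return engine_skew * (n - 2) / math.sqrt(n * (n - 1))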
sqlframe/base/util.py
CHANGED
@@ -347,6 +347,93 @@ def sqlglot_to_spark(sqlglot_dtype: exp.DataType) -> types.DataType:
     raise NotImplementedError(f"Unsupported data type: {sqlglot_dtype}")
 
 
+def spark_to_sqlglot(spark_dtype: types.DataType) -> exp.DataType:
+    """
+    Convert a Spark data type to a SQLGlot data type.
+
+    This function is the opposite of sqlglot_to_spark.
+
+    Args:
+        spark_dtype: A Spark data type
+
+    Returns:
+        The equivalent SQLGlot data type
+    """
+    from sqlframe.base import types
+
+    # Handle primitive types
+    if isinstance(spark_dtype, types.StringType):
+        return exp.DataType(this=exp.DataType.Type.TEXT)
+    elif isinstance(spark_dtype, types.VarcharType):
+        return exp.DataType(
+            this=exp.DataType.Type.VARCHAR,
+            expressions=[exp.DataTypeParam(this=exp.Literal.number(spark_dtype.length))],
+        )
+    elif isinstance(spark_dtype, types.CharType):
+        return exp.DataType(
+            this=exp.DataType.Type.CHAR,
+            expressions=[exp.DataTypeParam(this=exp.Literal.number(spark_dtype.length))],
+        )
+    elif isinstance(spark_dtype, types.BinaryType):
+        return exp.DataType(this=exp.DataType.Type.BINARY)
+    elif isinstance(spark_dtype, types.BooleanType):
+        return exp.DataType(this=exp.DataType.Type.BOOLEAN)
+    elif isinstance(spark_dtype, types.IntegerType):
+        return exp.DataType(this=exp.DataType.Type.INT)
+    elif isinstance(spark_dtype, types.LongType):
+        return exp.DataType(this=exp.DataType.Type.BIGINT)
+    elif isinstance(spark_dtype, types.ShortType):
+        return exp.DataType(this=exp.DataType.Type.SMALLINT)
+    elif isinstance(spark_dtype, types.ByteType):
+        return exp.DataType(this=exp.DataType.Type.TINYINT)
+    elif isinstance(spark_dtype, types.FloatType):
+        return exp.DataType(this=exp.DataType.Type.FLOAT)
+    elif isinstance(spark_dtype, types.DoubleType):
+        return exp.DataType(this=exp.DataType.Type.DOUBLE)
+    elif isinstance(spark_dtype, types.DecimalType):
+        if spark_dtype.precision is not None and spark_dtype.scale is not None:
+            return exp.DataType(
+                this=exp.DataType.Type.DECIMAL,
+                expressions=[
+                    exp.DataTypeParam(this=exp.Literal.number(spark_dtype.precision)),
+                    exp.DataTypeParam(this=exp.Literal.number(spark_dtype.scale)),
+                ],
+            )
+        return exp.DataType(this=exp.DataType.Type.DECIMAL)
+    elif isinstance(spark_dtype, types.TimestampType):
+        return exp.DataType(this=exp.DataType.Type.TIMESTAMP)
+    elif isinstance(spark_dtype, types.TimestampNTZType):
+        return exp.DataType(this=exp.DataType.Type.TIMESTAMPNTZ)
+    elif isinstance(spark_dtype, types.DateType):
+        return exp.DataType(this=exp.DataType.Type.DATE)
+
+    # Handle complex types
+    elif isinstance(spark_dtype, types.ArrayType):
+        return exp.DataType(
+            this=exp.DataType.Type.ARRAY, expressions=[spark_to_sqlglot(spark_dtype.elementType)]
+        )
+    elif isinstance(spark_dtype, types.MapType):
+        return exp.DataType(
+            this=exp.DataType.Type.MAP,
+            expressions=[
+                spark_to_sqlglot(spark_dtype.keyType),
+                spark_to_sqlglot(spark_dtype.valueType),
+            ],
+        )
+    elif isinstance(spark_dtype, types.StructType):
+        return exp.DataType(
+            this=exp.DataType.Type.STRUCT,
+            expressions=[
+                exp.ColumnDef(
+                    this=exp.to_identifier(field.name), kind=spark_to_sqlglot(field.dataType)
+                )
+                for field in spark_dtype
+            ],
+        )
+
+    raise NotImplementedError(f"Unsupported data type: {spark_dtype}")
+
+
 def normalize_string(
     value: t.Union[str, exp.Expression],
     from_dialect: DialectType = None,
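A quick, hypothetical exercise of the new helper, using the same types module the function itself imports; the printed SQL is illustrative, not taken from the diff:

# Hypothetical usage sketch for the new spark_to_sqlglot helper.
from sqlframe.base import types
from sqlframe.base.util import spark_to_sqlglot

schema = types.StructType(
    [
        types.StructField("id", types.LongType()),
        types.StructField("tags", types.ArrayType(types.StringType())),
    ]
)
dtype = spark_to_sqlglot(schema)
print(dtype.sql())  # e.g. STRUCT<id BIGINT, tags ARRAY<TEXT>>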
{sqlframe-3.31.4.dist-info → sqlframe-3.32.1.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 sqlframe/__init__.py,sha256=SB80yLTITBXHI2GCDS6n6bN5ObHqgPjfpRPAUwxaots,3403
-sqlframe/_version.py,sha256=
+sqlframe/_version.py,sha256=Pcai4MtywlDCWdUgy7BkNFFkadRh0SVOsjFMYSHbJLU,513
 sqlframe/py.typed,sha256=Nqnn8clbgv-5l0PgxcTOldg8mkMKrFn4TvPL-rYUUGg,1
 sqlframe/base/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sqlframe/base/_typing.py,sha256=b2clI5HI1zEZKB_3Msx3FeAJQyft44ubUifJwQRVXyQ,1298
-sqlframe/base/catalog.py,sha256
+sqlframe/base/catalog.py,sha256=-YulM2BMK8MoWbXi05AsJIPxd4AuiZDBCZuk4HoeMlE,38900
 sqlframe/base/column.py,sha256=sp3fJstA49FslE2CcgvVFHyi7Jxsxk8qHTd-Z0cAEWc,19932
-sqlframe/base/dataframe.py,sha256=
+sqlframe/base/dataframe.py,sha256=6L8xTdwwQCkUzpJ6K3QlCcz5zqk2QQmGzteI-1EJ23A,84374
 sqlframe/base/decorators.py,sha256=IhE5xNQDkwJHacCvulq5WpUKyKmXm7dL2A3o5WuKGP4,2131
 sqlframe/base/exceptions.py,sha256=9Uwvqn2eAkDpqm4BrRgbL61qM-GMCbJEMAW8otxO46s,370
 sqlframe/base/function_alternatives.py,sha256=Bs1bwl25fN3Yy9rb4GnUWBGunQ1C_yelkb2yV9DSZIY,53918
-sqlframe/base/functions.py,sha256=
+sqlframe/base/functions.py,sha256=i93fc9t7HooXMo8p35VLHd3FeYazVZztVIWqGBmsMYA,227188
 sqlframe/base/group.py,sha256=OY4w1WRsCqLgW-Pi7DjF63zbbxSLISCF3qjAbzI2CQ4,4283
 sqlframe/base/normalize.py,sha256=nXAJ5CwxVf4DV0GsH-q1w0p8gmjSMlv96k_ez1eVul8,3880
 sqlframe/base/operations.py,sha256=g-YNcbvNKTOBbYm23GKfB3fmydlR7ZZDAuZUtXIHtzw,4438
@@ -19,7 +19,7 @@ sqlframe/base/table.py,sha256=rCeh1W5SWbtEVfkLAUiexzrZwNgmZeptLEmLcM1ABkE,6961
 sqlframe/base/transforms.py,sha256=y0j3SGDz3XCmNGrvassk1S-owllUWfkHyMgZlY6SFO4,467
 sqlframe/base/types.py,sha256=iBNk9bpFtb2NBIogYS8i7OlQZMRvpR6XxqzBebsjQDU,12280
 sqlframe/base/udf.py,sha256=O6hMhBUy9NVv-mhJRtfFhXTIa_-Z8Y_FkmmuOHu0l90,1117
-sqlframe/base/util.py,sha256=
+sqlframe/base/util.py,sha256=gv_kRc3LxCuQy3t4dHFldV7elB8RU5PMqIN5-xSkWSo,19107
 sqlframe/base/window.py,sha256=7NaKDTlhun-95LEghukBCjFBwq0RHrPaajWQNCsLxok,4818
 sqlframe/base/mixins/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sqlframe/base/mixins/catalog_mixins.py,sha256=9fZGWToz9xMJSzUl1vsVtj6TH3TysP3fBCKJLnGUQzE,23353
@@ -130,8 +130,8 @@ sqlframe/standalone/udf.py,sha256=azmgtUjHNIPs0WMVNId05SHwiYn41MKVBhKXsQJ5dmY,27
 sqlframe/standalone/window.py,sha256=6GKPzuxeSapJakBaKBeT9VpED1ACdjggDv9JRILDyV0,35
 sqlframe/testing/__init__.py,sha256=VVCosQhitU74A3NnE52O4mNtGZONapuEXcc20QmSlnQ,132
 sqlframe/testing/utils.py,sha256=PFsGZpwNUE_4-g_f43_vstTqsK0AQ2lBneb5Eb6NkFo,13008
-sqlframe-3.
-sqlframe-3.
-sqlframe-3.
-sqlframe-3.
-sqlframe-3.
+sqlframe-3.32.1.dist-info/LICENSE,sha256=VZu79YgW780qxaFJMr0t5ZgbOYEh04xWoxaWOaqIGWk,1068
+sqlframe-3.32.1.dist-info/METADATA,sha256=fRuHZ0SdYY3p5ns78Az8S5qfa6irmVZe2S14ixRYaBE,8987
+sqlframe-3.32.1.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+sqlframe-3.32.1.dist-info/top_level.txt,sha256=T0_RpoygaZSF6heeWwIDQgaP0varUdSK1pzjeJZRjM8,9
+sqlframe-3.32.1.dist-info/RECORD,,
{sqlframe-3.31.4.dist-info → sqlframe-3.32.1.dist-info}/LICENSE
File without changes
{sqlframe-3.31.4.dist-info → sqlframe-3.32.1.dist-info}/WHEEL
File without changes
{sqlframe-3.31.4.dist-info → sqlframe-3.32.1.dist-info}/top_level.txt
File without changes