sqlframe 1.9.0__py3-none-any.whl → 1.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlframe/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '1.9.0'
- __version_tuple__ = version_tuple = (1, 9, 0)
+ __version__ = version = '1.11.0'
+ __version_tuple__ = version_tuple = (1, 11, 0)
@@ -22,6 +22,7 @@ from sqlglot.optimizer.pushdown_projections import pushdown_projections
  from sqlglot.optimizer.qualify import qualify
  from sqlglot.optimizer.qualify_columns import quote_identifiers

+ from sqlframe.base.catalog import Column as CatalogColumn
  from sqlframe.base.decorators import normalize
  from sqlframe.base.operations import Operation, operation
  from sqlframe.base.transforms import replace_id_value
@@ -29,6 +30,7 @@ from sqlframe.base.util import (
      get_func_from_session,
      get_tables_from_expression_with_join,
      quote_preserving_alias_or_name,
+     sqlglot_to_spark,
      verify_openai_installed,
  )

@@ -231,6 +233,10 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
      def __copy__(self):
          return self.copy()

+     @property
+     def _typed_columns(self) -> t.List[CatalogColumn]:
+         raise NotImplementedError
+
      @property
      def write(self) -> WRITER:
          return self.session._writer(self)
@@ -293,7 +299,24 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
          StructType([StructField('age', LongType(), True),
                      StructField('name', StringType(), True)])
          """
-         raise NotImplementedError
+         from sqlframe.base import types
+
+         try:
+             return types.StructType(
+                 [
+                     types.StructField(
+                         c.name,
+                         sqlglot_to_spark(
+                             exp.DataType.build(c.dataType, dialect=self.session.output_dialect)
+                         ),
+                     )
+                     for c in self._typed_columns
+                 ]
+             )
+         except NotImplementedError as e:
+             raise NotImplementedError(
+                 "This engine does not support schema inference likely since it does not have an active connection."
+             ) from e

      def _replace_cte_names_with_hashes(self, expression: exp.Select):
          replacement_mapping = {}
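
The new schema body above maps each catalog column's data type string to a Spark-style field: the string is parsed into a sqlglot DataType for the session's output dialect and then converted with sqlglot_to_spark. A minimal standalone sketch of that mapping, outside a DataFrame; the "BIGINT" literal and "duckdb" dialect are illustrative inputs, not taken from this diff:

    from sqlglot import exp
    from sqlframe.base import types
    from sqlframe.base.util import sqlglot_to_spark

    # Parse an engine type string into a sqlglot DataType, then convert it to a
    # Spark-style type and wrap it in a StructField, as the schema property does.
    data_type = exp.DataType.build("BIGINT", dialect="duckdb")
    field = types.StructField("age", sqlglot_to_spark(data_type))
    print(types.StructType([field]))
    # Expected, following the docstring shown in the hunk above:
    # StructType([StructField('age', LongType(), True)])
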
@@ -1537,6 +1560,36 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
              table.add_row(list(row))
          print(table)

+     def printSchema(self, level: t.Optional[int] = None) -> None:
+         def print_schema(
+             column_name: str, column_type: exp.DataType, nullable: bool, current_level: int
+         ):
+             if level and current_level >= level:
+                 return
+             if current_level > 0:
+                 print(" | " * current_level, end="")
+             print(
+                 f" |-- {column_name}: {column_type.sql(self.session.output_dialect).lower()} (nullable = {str(nullable).lower()})"
+             )
+             if column_type.this in (exp.DataType.Type.STRUCT, exp.DataType.Type.OBJECT):
+                 for column_def in column_type.expressions:
+                     print_schema(column_def.name, column_def.args["kind"], True, current_level + 1)
+             if column_type.this == exp.DataType.Type.ARRAY:
+                 for data_type in column_type.expressions:
+                     print_schema("element", data_type, True, current_level + 1)
+             if column_type.this == exp.DataType.Type.MAP:
+                 print_schema("key", column_type.expressions[0], True, current_level + 1)
+                 print_schema("value", column_type.expressions[1], True, current_level + 1)
+
+         print("root")
+         for column in self._typed_columns:
+             print_schema(
+                 column.name,
+                 exp.DataType.build(column.dataType, dialect=self.session.output_dialect),
+                 column.nullable,
+                 0,
+             )
+
      def toPandas(self) -> pd.DataFrame:
          sql_kwargs = dict(
              pretty=False, optimize=False, dialect=self.session.output_dialect, as_list=True
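
The added printSchema prints one " |-- name: type (nullable = ...)" line per typed column and recurses into struct, array, and map types, adding one " | " segment per nesting level. A usage sketch under stated assumptions: it presumes the DuckDB engine implements the new _typed_columns hook, and the exact lowercased type names in the output are illustrative rather than copied from this diff:

    from sqlframe.duckdb import DuckDBSession

    session = DuckDBSession()  # in-memory DuckDB connection
    df = session.createDataFrame([(1, "Alice")], schema=["id", "name"])
    df.printSchema()
    # Illustrative output, following the format string in the diff:
    # root
    #  |-- id: bigint (nullable = true)
    #  |-- name: varchar (nullable = true)
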
@@ -12,3 +12,15 @@ class RowError(SQLFrameException):

  class TableSchemaError(SQLFrameException):
      pass
+
+
+ class PandasDiffError(SQLFrameException):
+     pass
+
+
+ class DataFrameDiffError(SQLFrameException):
+     pass
+
+
+ class SchemaDiffError(SQLFrameException):
+     pass
@@ -1424,3 +1424,99 @@ def bit_length_from_length(col: ColumnOrName) -> Column:
      col_func = get_func_from_session("col")

      return Column(expression.Length(this=col_func(col).expression)) * lit(8)
+
+
+ def any_value_always_ignore_nulls(
+     col: ColumnOrName, ignoreNulls: t.Optional[t.Union[bool, Column]] = None
+ ) -> Column:
+     from sqlframe.base.functions import any_value
+
+     if not ignoreNulls:
+         logger.warning("Nulls are always ignored when using `ANY_VALUE` on this engine")
+     return any_value(col)
+
+
+ def any_value_ignore_nulls_not_supported(
+     col: ColumnOrName, ignoreNulls: t.Optional[t.Union[bool, Column]] = None
+ ) -> Column:
+     from sqlframe.base.functions import any_value
+
+     if ignoreNulls:
+         logger.warning("Ignoring nulls is not supported in this dialect")
+     return any_value(col)
+
+
+ def current_user_from_session_user() -> Column:
+     return Column(expression.Anonymous(this="SESSION_USER"))
+
+
+ def extract_convert_to_var(field: ColumnOrName, source: ColumnOrName) -> Column:
+     from sqlframe.base.functions import extract
+
+     field = expression.Var(this=Column.ensure_col(field).alias_or_name)  # type: ignore
+     return extract(field, source)  # type: ignore
+
+
+ def left_cast_len(str: ColumnOrName, len: ColumnOrName) -> Column:
+     from sqlframe.base.functions import left
+
+     len = Column.ensure_col(len).cast("integer")
+     return left(str, len)
+
+
+ def right_cast_len(str: ColumnOrName, len: ColumnOrName) -> Column:
+     from sqlframe.base.functions import right
+
+     len = Column.ensure_col(len).cast("integer")
+     return right(str, len)
+
+
+ def position_cast_start(
+     substr: ColumnOrName, str: ColumnOrName, start: t.Optional[ColumnOrName] = None
+ ) -> Column:
+     from sqlframe.base.functions import position
+
+     start = Column.ensure_col(start).cast("integer") if start else None
+     return position(substr, str, start)
+
+
+ def position_as_strpos(
+     substr: ColumnOrName, str: ColumnOrName, start: t.Optional[ColumnOrName] = None
+ ) -> Column:
+     substr_func = get_func_from_session("substr")
+     lit = get_func_from_session("lit")
+
+     if start:
+         str = substr_func(str, start)
+     column = Column.invoke_anonymous_function(str, "STRPOS", substr)
+     if start:
+         return column + start - lit(1)
+     return column
+
+
+ def to_number_using_to_double(col: ColumnOrName, format: ColumnOrName) -> Column:
+     return Column.invoke_anonymous_function(col, "TO_DOUBLE", format)
+
+
+ def try_element_at_zero_based(col: ColumnOrName, extraction: ColumnOrName) -> Column:
+     from sqlframe.base.functions import try_element_at
+
+     lit = get_func_from_session("lit")
+     index = Column.ensure_col(extraction)
+     if isinstance(index.expression, expression.Literal) and index.expression.is_number:
+         index = index - lit(1)
+     return try_element_at(col, index)
+
+
+ def to_unix_timestamp_include_default_format(
+     timestamp: ColumnOrName,
+     format: t.Optional[ColumnOrName] = None,
+ ) -> Column:
+     from sqlframe.base.functions import to_unix_timestamp
+
+     lit = get_func_from_session("lit")
+
+     if not format:
+         format = lit("%Y-%m-%d %H:%M:%S")
+
+     return to_unix_timestamp(timestamp, format)
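
Most of these helpers adapt a PySpark-style call to an engine quirk by rewriting arguments before delegating to the shared implementation (casting lengths to integer, shifting a literal index for zero-based arrays, filling in a default timestamp format). The offset arithmetic in position_as_strpos is the least obvious piece; below is a plain-Python worked example of the same arithmetic, with no sqlframe imports and values made up for illustration:

    # position('lo', 'hello world', 4) searches from 1-based position 4.
    # The helper slices the string at `start`, runs STRPOS on the remainder,
    # then adds start - 1 to turn the relative hit back into an absolute position.
    s, sub, start = "hello world", "lo", 4
    relative = s[start - 1:].find(sub) + 1   # STRPOS-style, 1-based
    absolute = relative + start - 1
    assert absolute == s.find(sub) + 1 == 4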