sqlframe 1.14.0__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46):
  1. sqlframe/_version.py +2 -2
  2. sqlframe/base/_typing.py +1 -0
  3. sqlframe/base/catalog.py +36 -13
  4. sqlframe/base/column.py +11 -9
  5. sqlframe/base/dataframe.py +72 -79
  6. sqlframe/base/decorators.py +0 -38
  7. sqlframe/base/function_alternatives.py +36 -25
  8. sqlframe/base/functions.py +88 -28
  9. sqlframe/base/mixins/catalog_mixins.py +156 -45
  10. sqlframe/base/mixins/dataframe_mixins.py +1 -1
  11. sqlframe/base/readerwriter.py +12 -14
  12. sqlframe/base/session.py +157 -84
  13. sqlframe/base/udf.py +36 -0
  14. sqlframe/base/util.py +71 -20
  15. sqlframe/bigquery/catalog.py +79 -28
  16. sqlframe/bigquery/functions.py +5 -8
  17. sqlframe/bigquery/session.py +4 -2
  18. sqlframe/bigquery/udf.py +11 -0
  19. sqlframe/duckdb/catalog.py +30 -13
  20. sqlframe/duckdb/dataframe.py +5 -0
  21. sqlframe/duckdb/functions.py +2 -0
  22. sqlframe/duckdb/readwriter.py +7 -6
  23. sqlframe/duckdb/session.py +8 -2
  24. sqlframe/duckdb/udf.py +19 -0
  25. sqlframe/postgres/catalog.py +30 -18
  26. sqlframe/postgres/functions.py +2 -0
  27. sqlframe/postgres/session.py +16 -5
  28. sqlframe/postgres/udf.py +11 -0
  29. sqlframe/redshift/catalog.py +28 -13
  30. sqlframe/redshift/session.py +4 -2
  31. sqlframe/redshift/udf.py +11 -0
  32. sqlframe/snowflake/catalog.py +64 -24
  33. sqlframe/snowflake/dataframe.py +9 -5
  34. sqlframe/snowflake/functions.py +1 -0
  35. sqlframe/snowflake/session.py +5 -2
  36. sqlframe/snowflake/udf.py +11 -0
  37. sqlframe/spark/catalog.py +180 -10
  38. sqlframe/spark/session.py +46 -14
  39. sqlframe/spark/udf.py +34 -0
  40. sqlframe/standalone/session.py +3 -0
  41. sqlframe/standalone/udf.py +11 -0
  42. {sqlframe-1.14.0.dist-info → sqlframe-2.1.0.dist-info}/METADATA +14 -10
  43. {sqlframe-1.14.0.dist-info → sqlframe-2.1.0.dist-info}/RECORD +46 -38
  44. {sqlframe-1.14.0.dist-info → sqlframe-2.1.0.dist-info}/LICENSE +0 -0
  45. {sqlframe-1.14.0.dist-info → sqlframe-2.1.0.dist-info}/WHEEL +0 -0
  46. {sqlframe-1.14.0.dist-info → sqlframe-2.1.0.dist-info}/top_level.txt +0 -0
sqlframe/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '1.14.0'
16
- __version_tuple__ = version_tuple = (1, 14, 0)
15
+ __version__ = version = '2.1.0'
16
+ __version_tuple__ = version_tuple = (2, 1, 0)
sqlframe/base/_typing.py CHANGED
@@ -24,6 +24,7 @@ OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert]
24
24
  StorageLevel = str
25
25
  PathOrPaths = t.Union[str, t.List[str]]
26
26
  OptionalPrimitiveType = t.Optional[PrimitiveType]
27
+ DataTypeOrString = t.Union[DataType, str]
27
28
 
28
29
 
29
30
  class UserDefinedFunctionLike(t.Protocol):
sqlframe/base/catalog.py CHANGED
@@ -3,12 +3,12 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import typing as t
6
+ from collections import defaultdict
6
7
 
7
8
  from sqlglot import MappingSchema, exp
8
9
 
9
- from sqlframe.base.decorators import normalize
10
10
  from sqlframe.base.exceptions import TableSchemaError
11
- from sqlframe.base.util import ensure_column_mapping, to_schema
11
+ from sqlframe.base.util import ensure_column_mapping, normalize_string, to_schema
12
12
 
13
13
  if t.TYPE_CHECKING:
14
14
  from sqlglot.schema import ColumnMapping
@@ -33,6 +33,7 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
33
33
  """Create a new Catalog that wraps the underlying JVM object."""
34
34
  self.session = sparkSession
35
35
  self._schema = schema or MappingSchema()
36
+ self._quoted_columns: t.Dict[exp.Table, t.List[str]] = defaultdict(list)
36
37
 
37
38
  @property
38
39
  def spark(self) -> SESSION:
@@ -52,7 +53,7 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
52
53
  def get_columns_from_schema(self, table: exp.Table | str) -> t.Dict[str, exp.DataType]:
53
54
  table = self.ensure_table(table)
54
55
  return {
55
- exp.column(name, quoted=True).sql(
56
+ exp.column(name, quoted=name in self._quoted_columns[table]).sql(
56
57
  dialect=self.session.input_dialect
57
58
  ): exp.DataType.build(dtype, dialect=self.session.input_dialect)
58
59
  for name, dtype in self._schema.find(table, raise_on_missing=True).items() # type: ignore
@@ -64,9 +65,7 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
64
65
  if not columns:
65
66
  return {}
66
67
  return {
67
- exp.column(c.name, quoted=True).sql(
68
- dialect=self.session.input_dialect
69
- ): exp.DataType.build(c.dataType, dialect=self.session.input_dialect)
68
+ c.name: exp.DataType.build(c.dataType, dialect=self.session.output_dialect)
70
69
  for c in columns
71
70
  }
72
71
 
@@ -79,16 +78,30 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
79
78
  return
80
79
  if not column_mapping:
81
80
  try:
82
- column_mapping = self.get_columns(table)
81
+ column_mapping = {
82
+ normalize_string(
83
+ k, from_dialect="output", to_dialect="input", is_column=True
84
+ ): normalize_string(
85
+ v.sql(dialect=self.session.output_dialect),
86
+ from_dialect="output",
87
+ to_dialect="input",
88
+ is_datatype=True,
89
+ )
90
+ for k, v in self.get_columns(table).items()
91
+ }
83
92
  except NotImplementedError:
84
93
  # TODO: Add doc link
85
94
  raise TableSchemaError(
86
95
  "This session does not have access to a catalog that can lookup column information. See docs for explicitly defining columns or using a session that can automatically determine this."
87
96
  )
88
97
  column_mapping = ensure_column_mapping(column_mapping) # type: ignore
98
+ for column_name in column_mapping:
99
+ column = exp.to_column(column_name, dialect=self.session.input_dialect)
100
+ if column.this.quoted:
101
+ self._quoted_columns[table].append(column.this.name)
102
+
89
103
  self._schema.add_table(table, column_mapping, dialect=self.session.input_dialect)
90
104
 
91
- @normalize(["dbName"])
92
105
  def getDatabase(self, dbName: str) -> Database:
93
106
  """Get the database with the specified name.
94
107
  This throws an :class:`AnalysisException` when the database cannot be found.
@@ -115,6 +128,7 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
115
128
  >>> spark.catalog.getDatabase("spark_catalog.default")
116
129
  Database(name='default', catalog='spark_catalog', description='default database', ...
117
130
  """
131
+ dbName = normalize_string(dbName, from_dialect="input", is_schema=True)
118
132
  schema = to_schema(dbName, dialect=self.session.input_dialect)
119
133
  database_name = schema.db
120
134
  databases = self.listDatabases(pattern=database_name)
@@ -122,12 +136,16 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
122
136
  raise ValueError(f"Database '{dbName}' not found")
123
137
  if len(databases) > 1:
124
138
  if schema.catalog is not None:
125
- filtered_databases = [db for db in databases if db.catalog == schema.catalog]
139
+ filtered_databases = [
140
+ db
141
+ for db in databases
142
+ if normalize_string(db.catalog, from_dialect="output", to_dialect="input") # type: ignore
143
+ == schema.catalog
144
+ ]
126
145
  if filtered_databases:
127
146
  return filtered_databases[0]
128
147
  return databases[0]
129
148
 
130
- @normalize(["dbName"])
131
149
  def databaseExists(self, dbName: str) -> bool:
132
150
  """Check if the database with the specified name exists.
133
151
 
@@ -168,7 +186,6 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
168
186
  except ValueError:
169
187
  return False
170
188
 
171
- @normalize(["tableName"])
172
189
  def getTable(self, tableName: str) -> Table:
173
190
  """Get the table or view with the specified name. This table can be a temporary view or a
174
191
  table/view. This throws an :class:`AnalysisException` when no Table can be found.
@@ -210,13 +227,18 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
210
227
  ...
211
228
  AnalysisException: ...
212
229
  """
230
+ tableName = normalize_string(tableName, from_dialect="input", is_table=True)
213
231
  table = exp.to_table(tableName, dialect=self.session.input_dialect)
214
232
  schema = table.copy()
215
233
  schema.set("this", None)
216
234
  tables = self.listTables(
217
235
  schema.sql(dialect=self.session.input_dialect) if schema.db else None
218
236
  )
219
- matching_tables = [t for t in tables if t.name == table.name]
237
+ matching_tables = [
238
+ t
239
+ for t in tables
240
+ if normalize_string(t.name, from_dialect="output", to_dialect="input") == table.name
241
+ ]
220
242
  if not matching_tables:
221
243
  raise ValueError(f"Table '{tableName}' not found")
222
244
  return matching_tables[0]
@@ -315,7 +337,6 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
315
337
  raise ValueError(f"Function '{functionName}' not found")
316
338
  return matching_functions[0]
317
339
 
318
- @normalize(["tableName", "dbName"])
319
340
  def tableExists(self, tableName: str, dbName: t.Optional[str] = None) -> bool:
320
341
  """Check if the table or view with the specified name exists.
321
342
  This can either be a temporary view or a table/view.
@@ -389,6 +410,8 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
389
410
  >>> spark.catalog.tableExists("view1")
390
411
  False
391
412
  """
413
+ tableName = normalize_string(tableName, from_dialect="input", is_table=True)
414
+ dbName = normalize_string(dbName, from_dialect="input", is_schema=True) if dbName else None
392
415
  table = exp.to_table(tableName, dialect=self.session.input_dialect)
393
416
  schema_arg = to_schema(dbName, dialect=self.session.input_dialect) if dbName else None
394
417
  if not table.db:
sqlframe/base/column.py CHANGED
@@ -7,11 +7,11 @@ import math
7
7
  import typing as t
8
8
 
9
9
  import sqlglot
10
+ from sqlglot import Dialect
10
11
  from sqlglot import expressions as exp
11
12
  from sqlglot.helper import flatten, is_iterable
12
13
  from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
13
14
 
14
- from sqlframe.base.decorators import normalize
15
15
  from sqlframe.base.exceptions import UnsupportedOperationError
16
16
  from sqlframe.base.types import DataType
17
17
  from sqlframe.base.util import get_func_from_session, quote_preserving_alias_or_name
@@ -211,9 +211,8 @@ class Column:
211
211
  def binary_op(
212
212
  self, klass: t.Callable, other: ColumnOrLiteral, paren: bool = False, **kwargs
213
213
  ) -> Column:
214
- op = klass(
215
- this=self.column_expression, expression=Column(other).column_expression, **kwargs
216
- )
214
+ other = self._lit(other) if isinstance(other, str) else Column(other)
215
+ op = klass(this=self.column_expression, expression=other.column_expression, **kwargs)
217
216
  if paren:
218
217
  return Column(exp.Paren(this=op))
219
218
  return Column(op)
@@ -221,9 +220,8 @@ class Column:
221
220
  def inverse_binary_op(
222
221
  self, klass: t.Callable, other: ColumnOrLiteral, paren: bool = False, **kwargs
223
222
  ) -> Column:
224
- op = klass(
225
- this=Column(other).column_expression, expression=self.column_expression, **kwargs
226
- )
223
+ other = self._lit(other) if isinstance(other, str) else Column(other)
224
+ op = klass(this=other.column_expression, expression=self.column_expression, **kwargs)
227
225
  if paren:
228
226
  return Column(exp.Paren(this=op))
229
227
  return Column(op)
@@ -340,13 +338,17 @@ class Column:
340
338
  new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
341
339
  return Column(new_expression)
342
340
 
343
- def cast(self, dataType: t.Union[str, DataType]) -> Column:
341
+ def cast(
342
+ self, dataType: t.Union[str, DataType], dialect: t.Optional[t.Union[str, Dialect]] = None
343
+ ) -> Column:
344
344
  from sqlframe.base.session import _BaseSession
345
345
 
346
346
  if isinstance(dataType, DataType):
347
347
  dataType = dataType.simpleString()
348
348
  return Column(
349
- exp.cast(self.column_expression, dataType, dialect=_BaseSession().input_dialect)
349
+ exp.cast(
350
+ self.column_expression, dataType, dialect=dialect or _BaseSession().input_dialect
351
+ )
350
352
  )
351
353
 
352
354
  def startswith(self, value: t.Union[str, Column]) -> Column:
@@ -23,12 +23,12 @@ from sqlglot.optimizer.qualify import qualify
23
23
  from sqlglot.optimizer.qualify_columns import quote_identifiers
24
24
 
25
25
  from sqlframe.base.catalog import Column as CatalogColumn
26
- from sqlframe.base.decorators import normalize
27
26
  from sqlframe.base.operations import Operation, operation
28
27
  from sqlframe.base.transforms import replace_id_value
29
28
  from sqlframe.base.util import (
30
29
  get_func_from_session,
31
30
  get_tables_from_expression_with_join,
31
+ normalize_string,
32
32
  quote_preserving_alias_or_name,
33
33
  sqlglot_to_spark,
34
34
  verify_openai_installed,
@@ -41,6 +41,7 @@ else:
41
41
 
42
42
  if t.TYPE_CHECKING:
43
43
  import pandas as pd
44
+ from pyarrow import Table as ArrowTable
44
45
  from sqlglot.dialects.dialect import DialectType
45
46
 
46
47
  from sqlframe.base._typing import (
@@ -535,39 +536,11 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
535
536
  )
536
537
  return [col]
537
538
 
538
- @t.overload
539
- def sql(
540
- self,
541
- dialect: DialectType = ...,
542
- optimize: bool = ...,
543
- pretty: bool = ...,
544
- *,
545
- as_list: t.Literal[False],
546
- **kwargs: t.Any,
547
- ) -> str: ...
548
-
549
- @t.overload
550
- def sql(
551
- self,
552
- dialect: DialectType = ...,
553
- optimize: bool = ...,
554
- pretty: bool = ...,
555
- *,
556
- as_list: t.Literal[True],
557
- **kwargs: t.Any,
558
- ) -> t.List[str]: ...
559
-
560
- def sql(
539
+ def _get_expressions(
561
540
  self,
562
- dialect: DialectType = None,
563
541
  optimize: bool = True,
564
- pretty: bool = True,
565
542
  openai_config: t.Optional[t.Union[t.Dict[str, t.Any], OpenAIConfig]] = None,
566
- as_list: bool = False,
567
- **kwargs,
568
- ) -> t.Union[str, t.List[str]]:
569
- dialect = Dialect.get_or_raise(dialect or self.session.output_dialect)
570
-
543
+ ) -> t.List[exp.Expression]:
571
544
  df = self._resolve_pending_hints()
572
545
  select_expressions = df._get_select_expressions()
573
546
  output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
@@ -583,12 +556,16 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
583
556
  replace_id_value, replacement_mapping
584
557
  ).assert_is(exp.Select)
585
558
  if optimize:
586
- quote_identifiers(select_expression, dialect=dialect)
587
559
  select_expression = t.cast(
588
- exp.Select, self.session._optimize(select_expression, dialect=dialect)
560
+ exp.Select,
561
+ self.session._optimize(select_expression),
589
562
  )
590
563
  elif openai_config:
591
- qualify(select_expression, dialect=dialect, schema=self.session.catalog._schema)
564
+ qualify(
565
+ select_expression,
566
+ dialect=self.session.input_dialect,
567
+ schema=self.session.catalog._schema,
568
+ )
592
569
  pushdown_projections(select_expression, schema=self.session.catalog._schema)
593
570
 
594
571
  select_expression = df._replace_cte_names_with_hashes(select_expression)
@@ -606,7 +583,7 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
606
583
  cache_table_name,
607
584
  {
608
585
  quote_preserving_alias_or_name(expression): expression.type.sql(
609
- dialect=dialect
586
+ dialect=self.session.input_dialect
610
587
  )
611
588
  if expression.type
612
589
  else "UNKNOWN"
@@ -642,10 +619,43 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
642
619
  raise ValueError(f"Invalid expression type: {expression_type}")
643
620
 
644
621
  output_expressions.append(expression)
622
+ return output_expressions # type: ignore
623
+
624
+ @t.overload
625
+ def sql(
626
+ self,
627
+ dialect: DialectType = ...,
628
+ optimize: bool = ...,
629
+ pretty: bool = ...,
630
+ *,
631
+ as_list: t.Literal[False],
632
+ **kwargs: t.Any,
633
+ ) -> str: ...
634
+
635
+ @t.overload
636
+ def sql(
637
+ self,
638
+ dialect: DialectType = ...,
639
+ optimize: bool = ...,
640
+ pretty: bool = ...,
641
+ *,
642
+ as_list: t.Literal[True],
643
+ **kwargs: t.Any,
644
+ ) -> t.List[str]: ...
645
645
 
646
+ def sql(
647
+ self,
648
+ dialect: DialectType = None,
649
+ optimize: bool = True,
650
+ pretty: bool = True,
651
+ openai_config: t.Optional[t.Union[t.Dict[str, t.Any], OpenAIConfig]] = None,
652
+ as_list: bool = False,
653
+ **kwargs,
654
+ ) -> t.Union[str, t.List[str]]:
655
+ dialect = Dialect.get_or_raise(dialect) if dialect else self.session.output_dialect
646
656
  results = []
647
- for expression in output_expressions:
648
- sql = expression.sql(dialect=dialect, pretty=pretty, **kwargs)
657
+ for expression in self._get_expressions(optimize=optimize, openai_config=openai_config):
658
+ sql = self.session._to_sql(expression, dialect=dialect, pretty=pretty, **kwargs)
649
659
  if openai_config:
650
660
  assert isinstance(openai_config, OpenAIConfig)
651
661
  verify_openai_installed()
@@ -828,7 +838,7 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
828
838
  from sqlframe.base.functions import coalesce
829
839
 
830
840
  if on is None:
831
- logger.warning("Got no value for on. This appears change the join to a cross join.")
841
+ logger.warning("Got no value for on. This appears to change the join to a cross join.")
832
842
  how = "cross"
833
843
  other_df = other_df._convert_leaf_to_cte()
834
844
  # We will determine actual "join on" expression later so we don't provide it at first
@@ -871,13 +881,12 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
871
881
  )
872
882
  join_column_pairs.append((left_column, right_column))
873
883
  num_matching_ctes += 1
874
- if num_matching_ctes > 1:
875
- raise ValueError(
876
- f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
877
- )
878
- elif num_matching_ctes == 0:
884
+ # We only want to match one table to the column and that should be matched left -> right
885
+ # so we break after the first match
886
+ break
887
+ if num_matching_ctes == 0:
879
888
  raise ValueError(
880
- f"Column {join_column.alias_or_name} does not exist in any of the tables."
889
+ f"Column `{join_column.alias_or_name}` does not exist in any of the tables."
881
890
  )
882
891
  join_clause = functools.reduce(
883
892
  lambda x, y: x & y,
@@ -1154,7 +1163,7 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
1154
1163
  if len(sql_queries) > 1:
1155
1164
  raise ValueError("Cannot explain a DataFrame with multiple queries")
1156
1165
  sql_query = "EXPLAIN " + sql_queries[0]
1157
- self.session._execute(sql_query, quote_identifiers=False)
1166
+ self.session._execute(sql_query)
1158
1167
 
1159
1168
  @operation(Operation.FROM)
1160
1169
  def fillna(
@@ -1373,23 +1382,20 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
1373
1382
  self._ensure_and_normalize_col(k).alias_or_name: self._ensure_and_normalize_col(v)
1374
1383
  for k, v in colsMap[0].items()
1375
1384
  }
1376
- existing_col_names = [
1377
- x.alias_or_name for x in self._get_outer_select_columns(self.expression)
1378
- ]
1379
- updated_expression = self.expression.copy()
1380
- select_columns = []
1385
+ existing_cols = self._get_outer_select_columns(self.expression)
1386
+ existing_col_names = [x.alias_or_name for x in existing_cols]
1387
+ select_columns = existing_cols
1381
1388
  for column_name, col_value in col_map.items():
1382
1389
  existing_col_index = (
1383
1390
  existing_col_names.index(column_name) if column_name in existing_col_names else None
1384
1391
  )
1385
- if existing_col_index:
1386
- updated_expression.expressions[existing_col_index] = col_value.alias(
1392
+ if existing_col_index is not None:
1393
+ select_columns[existing_col_index] = col_value.alias( # type: ignore
1387
1394
  column_name
1388
1395
  ).expression
1389
1396
  else:
1390
1397
  select_columns.append(col_value.alias(column_name))
1391
- df = self.copy(expression=updated_expression)
1392
- return df.select.__wrapped__(df, *select_columns, append=True) # type: ignore
1398
+ return self.select.__wrapped__(self, *select_columns) # type: ignore
1393
1399
 
1394
1400
  @operation(Operation.SELECT)
1395
1401
  def drop(self, *cols: t.Union[str, Column]) -> Self:
@@ -1539,10 +1545,10 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
1539
1545
  return self._group_data(self, grouping_columns, self.last_op)
1540
1546
 
1541
1547
  def collect(self) -> t.List[Row]:
1542
- result = []
1543
- for sql in self.sql(pretty=False, optimize=False, as_list=True):
1544
- result = self.session._fetch_rows(sql)
1545
- return result
1548
+ return self._collect()
1549
+
1550
+ def _collect(self, **kwargs) -> t.List[Row]:
1551
+ return self.session._collect(self._get_expressions(optimize=False), **kwargs)
1546
1552
 
1547
1553
  @t.overload
1548
1554
  def head(self) -> t.Optional[Row]: ...
@@ -1569,11 +1575,7 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
1569
1575
  logger.warning("Truncate is ignored so full results will be displayed")
1570
1576
  # Make sure that the limit we add doesn't affect the results
1571
1577
  df = self._convert_leaf_to_cte()
1572
- sql = df.limit(n).sql(
1573
- pretty=False, optimize=False, dialect=self.session.output_dialect, as_list=True
1574
- )
1575
- for sql in ensure_list(sql):
1576
- result = self.session._fetch_rows(sql)
1578
+ result = df.limit(n).collect()
1577
1579
  table = PrettyTable()
1578
1580
  if row := seq_get(result, 0):
1579
1581
  table.field_names = row._unique_field_names
@@ -1612,18 +1614,10 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
1612
1614
  )
1613
1615
 
1614
1616
  def toPandas(self) -> pd.DataFrame:
1615
- sql_kwargs = dict(
1616
- pretty=False, optimize=False, dialect=self.session.output_dialect, as_list=True
1617
- )
1618
- sqls = [None] + self.sql(**sql_kwargs) # type: ignore
1619
- for sql in self.sql(**sql_kwargs)[:-1]: # type: ignore
1620
- if sql:
1621
- self.session._execute(sql)
1622
- assert sqls[-1] is not None
1623
- return self.session._fetchdf(sqls[-1])
1624
-
1625
- @normalize("name")
1617
+ return self.session._fetchdf(self._get_expressions(optimize=False))
1618
+
1626
1619
  def createOrReplaceTempView(self, name: str) -> None:
1620
+ name = normalize_string(name, from_dialect="input")
1627
1621
  self.session.temp_views[name] = self.copy()._convert_leaf_to_cte()
1628
1622
 
1629
1623
  def count(self) -> int:
@@ -1632,11 +1626,7 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
1632
1626
 
1633
1627
  df = self._convert_leaf_to_cte()
1634
1628
  df = self.copy(expression=df.expression.select("count(*)", append=False))
1635
- for sql in df.sql(
1636
- dialect=self.session.output_dialect, pretty=False, optimize=False, as_list=True
1637
- ):
1638
- result = self.session._fetch_rows(sql)
1639
- return result[0][0]
1629
+ return df.collect()[0][0]
1640
1630
 
1641
1631
  def createGlobalTempView(self, name: str) -> None:
1642
1632
  raise NotImplementedError("Global temp views are not yet supported")
@@ -1816,3 +1806,6 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
1816
1806
  col_func = get_func_from_session("col")
1817
1807
 
1818
1808
  return self.select(covar_samp(col_func(col1), col_func(col2))).collect()[0][0]
1809
+
1810
+ def toArrow(self) -> ArrowTable:
1811
+ raise NotImplementedError("Arrow conversion is not supported by this engine")
@@ -1,50 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
- import functools
4
3
  import typing as t
5
4
 
6
- from sqlglot import parse_one
7
5
  from sqlglot.helper import ensure_list
8
- from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
9
-
10
- if t.TYPE_CHECKING:
11
- from sqlframe.base.catalog import _BaseCatalog
12
6
 
13
7
  CALLING_CLASS = t.TypeVar("CALLING_CLASS")
14
8
 
15
9
 
16
- def normalize(normalize_kwargs: t.Union[str, t.List[str]]) -> t.Callable[[t.Callable], t.Callable]:
17
- """
18
- Decorator used to normalize identifiers in the kwargs of a method.
19
- """
20
-
21
- def decorator(func: t.Callable) -> t.Callable:
22
- @functools.wraps(func)
23
- def wrapper(self: CALLING_CLASS, *args, **kwargs) -> CALLING_CLASS:
24
- from sqlframe.base.session import _BaseSession
25
-
26
- input_dialect = _BaseSession().input_dialect
27
- kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))
28
- for kwarg in ensure_list(normalize_kwargs):
29
- if kwarg in kwargs:
30
- value = kwargs.get(kwarg)
31
- if value:
32
- expression = (
33
- parse_one(value, dialect=input_dialect)
34
- if isinstance(value, str)
35
- else value
36
- )
37
- kwargs[kwarg] = normalize_identifiers(expression, input_dialect).sql(
38
- dialect=input_dialect
39
- )
40
- return func(self, **kwargs)
41
-
42
- wrapper.__wrapped__ = func # type: ignore
43
- return wrapper
44
-
45
- return decorator
46
-
47
-
48
10
  def func_metadata(unsupported_engines: t.Optional[t.Union[str, t.List[str]]] = None) -> t.Callable:
49
11
  def _metadata(func: t.Callable) -> t.Callable:
50
12
  func.unsupported_engines = ensure_list(unsupported_engines) if unsupported_engines else [] # type: ignore