sqlframe 1.9.0__py3-none-any.whl → 1.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,16 +7,17 @@ import typing as t
7
7
 
8
8
  from sqlglot import exp, parse_one
9
9
 
10
- from sqlframe.base.catalog import Function, _BaseCatalog
10
+ from sqlframe.base.catalog import Column, Function, _BaseCatalog
11
+ from sqlframe.base.decorators import normalize
11
12
  from sqlframe.base.mixins.catalog_mixins import (
12
13
  GetCurrentCatalogFromFunctionMixin,
13
14
  GetCurrentDatabaseFromFunctionMixin,
14
15
  ListCatalogsFromInfoSchemaMixin,
15
- ListColumnsFromInfoSchemaMixin,
16
16
  ListDatabasesFromInfoSchemaMixin,
17
17
  ListTablesFromInfoSchemaMixin,
18
18
  SetCurrentDatabaseFromSearchPathMixin,
19
19
  )
20
+ from sqlframe.base.util import to_schema
20
21
 
21
22
  if t.TYPE_CHECKING:
22
23
  from sqlframe.postgres.session import PostgresSession # noqa
@@ -30,12 +31,131 @@ class PostgresCatalog(
30
31
  ListCatalogsFromInfoSchemaMixin["PostgresSession", "PostgresDataFrame"],
31
32
  SetCurrentDatabaseFromSearchPathMixin["PostgresSession", "PostgresDataFrame"],
32
33
  ListTablesFromInfoSchemaMixin["PostgresSession", "PostgresDataFrame"],
33
- ListColumnsFromInfoSchemaMixin["PostgresSession", "PostgresDataFrame"],
34
34
  _BaseCatalog["PostgresSession", "PostgresDataFrame"],
35
35
  ):
36
36
  CURRENT_CATALOG_EXPRESSION: exp.Expression = exp.column("current_catalog")
37
37
  TEMP_SCHEMA_FILTER = exp.column("table_schema").like("pg_temp_%")
38
38
 
39
@normalize(["tableName", "dbName"])
def listColumns(
    self, tableName: str, dbName: t.Optional[str] = None, include_temp: bool = False
) -> t.List[Column]:
    """Return the columns of the given table/view in the specified database.

    Parameters
    ----------
    tableName : str
        Name of the table to list columns for. It may be qualified with a
        database and catalog; any part that is missing is taken from
        ``dbName`` and, failing that, from the current database/catalog.
    dbName : str, optional
        Name of the database in which to look for the table.
    include_temp : bool, optional
        When True, the schema and catalog filters are widened with
        ``TEMP_SCHEMA_FILTER`` / ``TEMP_CATALOG_FILTER`` (when those are set).

    Returns
    -------
    list
        A list of :class:`Column`.

    Notes
    -----
    A name found in ``session.temp_views`` is answered from that DataFrame's
    ``columns`` without querying the database; such columns are reported with
    an empty ``dataType`` and ``nullable=True``.

    Examples
    --------
    >>> _ = spark.sql("DROP TABLE IF EXISTS tblA")
    >>> _ = spark.sql("CREATE TABLE tblA (name STRING, age INT) USING parquet")
    >>> spark.catalog.listColumns("tblA")
    [Column(name='name', description=None, dataType='string', nullable=True, ...
    >>> _ = spark.sql("DROP TABLE tblA")
    """
    # Temp views are resolved locally; only their column names are available.
    if df := self.session.temp_views.get(tableName):
        return [
            Column(
                name=x,
                description=None,
                dataType="",
                nullable=True,
                isPartition=False,
                isBucket=False,
            )
            for x in df.columns
        ]

    dialect = self.session.input_dialect
    table = exp.to_table(tableName, dialect=dialect)
    schema = to_schema(dbName, dialect=dialect) if dbName else None

    # Fill in whichever qualifiers the table name did not carry itself.
    if not table.db:
        if schema and schema.db:
            table.set("db", schema.args["db"])
        else:
            table.set("db", exp.parse_identifier(self.currentDatabase(), dialect=dialect))
    if not table.catalog:
        if schema and schema.catalog:
            table.set("catalog", schema.args["catalog"])
        else:
            table.set("catalog", exp.parse_identifier(self.currentCatalog(), dialect=dialect))

    # The static part of the query contains no user-supplied text.
    select = parse_one(
        """
        SELECT
            att.attname AS column_name,
            pg_catalog.format_type(att.atttypid, NULL) AS data_type,
            col.is_nullable
        FROM
            pg_catalog.pg_attribute att
        JOIN
            pg_catalog.pg_class cls ON cls.oid = att.attrelid
        JOIN
            pg_catalog.pg_namespace nsp ON nsp.oid = cls.relnamespace
        JOIN
            information_schema.columns col ON col.table_schema = nsp.nspname AND col.table_name = cls.relname AND col.column_name = att.attname
        WHERE
            att.attnum > 0 AND
            NOT att.attisdropped
        ORDER BY
            att.attnum;
        """,
        dialect="postgres",
    )
    # Bind the table name as a string literal built by sqlglot rather than
    # splicing it into the SQL text, so embedded quotes are escaped on output.
    select = select.where(exp.column("relname", table="cls").eq(table.name))  # type: ignore

    if table.db:
        schema_filter: exp.Expression = exp.column("table_schema").eq(table.db)
        if include_temp and self.TEMP_SCHEMA_FILTER:
            schema_filter = exp.Or(this=schema_filter, expression=self.TEMP_SCHEMA_FILTER)
        select = select.where(schema_filter)  # type: ignore
    if table.catalog:
        catalog_filter: exp.Expression = exp.column("table_catalog").eq(table.catalog)
        if include_temp and self.TEMP_CATALOG_FILTER:
            catalog_filter = exp.Or(this=catalog_filter, expression=self.TEMP_CATALOG_FILTER)
        select = select.where(catalog_filter)  # type: ignore

    results = self.session._fetch_rows(select)
    return [
        Column(
            name=x["column_name"],
            description=None,
            dataType=x["data_type"],
            # information_schema reports nullability as the strings YES / NO.
            nullable=x["is_nullable"] == "YES",
            isPartition=False,
            isBucket=False,
        )
        for x in results
    ]
158
+
39
159
  def listFunctions(
40
160
  self, dbName: t.Optional[str] = None, pattern: t.Optional[str] = None
41
161
  ) -> t.List[Function]:
@@ -9,7 +9,10 @@ from sqlframe.base.dataframe import (
9
9
  _BaseDataFrameNaFunctions,
10
10
  _BaseDataFrameStatFunctions,
11
11
  )
12
- from sqlframe.base.mixins.dataframe_mixins import PrintSchemaFromTempObjectsMixin
12
+ from sqlframe.base.mixins.dataframe_mixins import (
13
+ NoCachePersistSupportMixin,
14
+ TypedColumnsFromTempViewMixin,
15
+ )
13
16
  from sqlframe.postgres.group import PostgresGroupedData
14
17
 
15
18
  if sys.version_info >= (3, 11):
@@ -34,7 +37,8 @@ class PostgresDataFrameStatFunctions(_BaseDataFrameStatFunctions["PostgresDataFr
34
37
 
35
38
 
36
39
  class PostgresDataFrame(
37
- PrintSchemaFromTempObjectsMixin,
40
+ NoCachePersistSupportMixin,
41
+ TypedColumnsFromTempViewMixin,
38
42
  _BaseDataFrame[
39
43
  "PostgresSession",
40
44
  "PostgresDataFrameWriter",
@@ -46,11 +50,3 @@ class PostgresDataFrame(
46
50
  _na = PostgresDataFrameNaFunctions
47
51
  _stat = PostgresDataFrameStatFunctions
48
52
  _group_data = PostgresGroupedData
49
-
50
- def cache(self) -> Self:
51
- logger.warning("Postgres does not support caching. Ignoring cache() call.")
52
- return self
53
-
54
- def persist(self) -> Self:
55
- logger.warning("Postgres does not support persist. Ignoring persist() call.")
56
- return self
@@ -16,6 +16,7 @@ globals().update(
16
16
 
17
17
 
18
18
  from sqlframe.base.function_alternatives import ( # noqa
19
+ any_value_ignore_nulls_not_supported as any_value,
19
20
  e_literal as e,
20
21
  expm1_from_exp as expm1,
21
22
  log1p_from_log as log1p,
@@ -40,6 +41,7 @@ from sqlframe.base.function_alternatives import ( # noqa
40
41
  date_add_by_multiplication as date_add,
41
42
  date_sub_by_multiplication as date_sub,
42
43
  date_diff_with_subtraction as date_diff,
44
+ date_diff_with_subtraction as datediff,
43
45
  add_months_by_multiplication as add_months,
44
46
  months_between_from_age_and_extract as months_between,
45
47
  from_unixtime_from_timestamp as from_unixtime,
@@ -58,4 +60,8 @@ from sqlframe.base.function_alternatives import ( # noqa
58
60
  get_json_object_using_arrow_op as get_json_object,
59
61
  array_min_from_subquery as array_min,
60
62
  array_max_from_subquery as array_max,
63
+ left_cast_len as left,
64
+ right_cast_len as right,
65
+ position_cast_start as position,
66
+ try_element_at_zero_based as try_element_at,
61
67
  )
@@ -1,4 +1,5 @@
1
1
  from sqlframe.base.function_alternatives import ( # noqa
2
+ any_value_ignore_nulls_not_supported as any_value,
2
3
  e_literal as e,
3
4
  expm1_from_exp as expm1,
4
5
  log1p_from_log as log1p,
@@ -23,6 +24,7 @@ from sqlframe.base.function_alternatives import ( # noqa
23
24
  date_add_by_multiplication as date_add,
24
25
  date_sub_by_multiplication as date_sub,
25
26
  date_diff_with_subtraction as date_diff,
27
+ date_diff_with_subtraction as datediff,
26
28
  add_months_by_multiplication as add_months,
27
29
  months_between_from_age_and_extract as months_between,
28
30
  from_unixtime_from_timestamp as from_unixtime,
@@ -41,6 +43,10 @@ from sqlframe.base.function_alternatives import ( # noqa
41
43
  get_json_object_using_arrow_op as get_json_object,
42
44
  array_min_from_subquery as array_min,
43
45
  array_max_from_subquery as array_max,
46
+ left_cast_len as left,
47
+ right_cast_len as right,
48
+ position_cast_start as position,
49
+ try_element_at_zero_based as try_element_at,
44
50
  )
45
51
  from sqlframe.base.functions import (
46
52
  abs as abs,
@@ -64,9 +70,13 @@ from sqlframe.base.functions import (
64
70
  bit_length as bit_length,
65
71
  bitwiseNOT as bitwiseNOT,
66
72
  bitwise_not as bitwise_not,
73
+ bool_and as bool_and,
74
+ bool_or as bool_or,
75
+ call_function as call_function,
67
76
  cbrt as cbrt,
68
77
  ceil as ceil,
69
78
  ceiling as ceiling,
79
+ char as char,
70
80
  coalesce as coalesce,
71
81
  col as col,
72
82
  collect_list as collect_list,
@@ -84,8 +94,10 @@ from sqlframe.base.functions import (
84
94
  cume_dist as cume_dist,
85
95
  current_date as current_date,
86
96
  current_timestamp as current_timestamp,
97
+ current_user as current_user,
87
98
  date_format as date_format,
88
99
  date_trunc as date_trunc,
100
+ dateadd as dateadd,
89
101
  degrees as degrees,
90
102
  dense_rank as dense_rank,
91
103
  desc as desc,
@@ -94,18 +106,22 @@ from sqlframe.base.functions import (
94
106
  exp as exp,
95
107
  explode as explode,
96
108
  expr as expr,
109
+ extract as extract,
97
110
  factorial as factorial,
98
111
  floor as floor,
99
112
  greatest as greatest,
113
+ ifnull as ifnull,
100
114
  initcap as initcap,
101
115
  input_file_name as input_file_name,
102
116
  instr as instr,
103
117
  lag as lag,
118
+ lcase as lcase,
104
119
  lead as lead,
105
120
  least as least,
106
121
  length as length,
107
122
  levenshtein as levenshtein,
108
123
  lit as lit,
124
+ ln as ln,
109
125
  locate as locate,
110
126
  log as log,
111
127
  log10 as log10,
@@ -117,19 +133,25 @@ from sqlframe.base.functions import (
117
133
  md5 as md5,
118
134
  mean as mean,
119
135
  min as min,
136
+ now as now,
120
137
  nth_value as nth_value,
121
138
  ntile as ntile,
122
139
  nullif as nullif,
140
+ nvl as nvl,
141
+ nvl2 as nvl2,
123
142
  octet_length as octet_length,
124
143
  overlay as overlay,
125
144
  percent_rank as percent_rank,
126
145
  percentile as percentile,
127
146
  pow as pow,
147
+ power as power,
128
148
  radians as radians,
129
149
  rank as rank,
150
+ regexp_like as regexp_like,
130
151
  regexp_replace as regexp_replace,
131
152
  repeat as repeat,
132
153
  reverse as reverse,
154
+ rlike as rlike,
133
155
  row_number as row_number,
134
156
  rpad as rpad,
135
157
  rtrim as rtrim,
@@ -137,12 +159,14 @@ from sqlframe.base.functions import (
137
159
  shiftRight as shiftRight,
138
160
  shiftleft as shiftleft,
139
161
  shiftright as shiftright,
162
+ sign as sign,
140
163
  signum as signum,
141
164
  sin as sin,
142
165
  sinh as sinh,
143
166
  size as size,
144
167
  soundex as soundex,
145
168
  sqrt as sqrt,
169
+ startswith as startswith,
146
170
  stddev as stddev,
147
171
  stddev_pop as stddev_pop,
148
172
  stddev_samp as stddev_samp,
@@ -156,11 +180,15 @@ from sqlframe.base.functions import (
156
180
  toDegrees as toDegrees,
157
181
  toRadians as toRadians,
158
182
  to_date as to_date,
183
+ to_number as to_number,
159
184
  to_timestamp as to_timestamp,
160
185
  translate as translate,
161
186
  trim as trim,
162
187
  trunc as trunc,
188
+ ucase as ucase,
189
+ unix_date as unix_date,
163
190
  upper as upper,
191
+ user as user,
164
192
  var_pop as var_pop,
165
193
  var_samp as var_samp,
166
194
  variance as variance,
@@ -9,13 +9,9 @@ from sqlframe.base.dataframe import (
9
9
  _BaseDataFrameNaFunctions,
10
10
  _BaseDataFrameStatFunctions,
11
11
  )
12
+ from sqlframe.base.mixins.dataframe_mixins import NoCachePersistSupportMixin
12
13
  from sqlframe.redshift.group import RedshiftGroupedData
13
14
 
14
- if sys.version_info >= (3, 11):
15
- from typing import Self
16
- else:
17
- from typing_extensions import Self
18
-
19
15
  if t.TYPE_CHECKING:
20
16
  from sqlframe.redshift.readwriter import RedshiftDataFrameWriter
21
17
  from sqlframe.redshift.session import RedshiftSession
@@ -33,22 +29,15 @@ class RedshiftDataFrameStatFunctions(_BaseDataFrameStatFunctions["RedshiftDataFr
33
29
 
34
30
 
35
31
  class RedshiftDataFrame(
32
+ NoCachePersistSupportMixin,
36
33
  _BaseDataFrame[
37
34
  "RedshiftSession",
38
35
  "RedshiftDataFrameWriter",
39
36
  "RedshiftDataFrameNaFunctions",
40
37
  "RedshiftDataFrameStatFunctions",
41
38
  "RedshiftGroupedData",
42
- ]
39
+ ],
43
40
  ):
44
41
  _na = RedshiftDataFrameNaFunctions
45
42
  _stat = RedshiftDataFrameStatFunctions
46
43
  _group_data = RedshiftGroupedData
47
-
48
- def cache(self) -> Self:
49
- logger.warning("Redshift does not support caching. Ignoring cache() call.")
50
- return self
51
-
52
- def persist(self) -> Self:
53
- logger.warning("Redshift does not support persist. Ignoring persist() call.")
54
- return self
@@ -4,18 +4,15 @@ import logging
4
4
  import sys
5
5
  import typing as t
6
6
 
7
+ from sqlframe.base.catalog import Column as CatalogColumn
7
8
  from sqlframe.base.dataframe import (
8
9
  _BaseDataFrame,
9
10
  _BaseDataFrameNaFunctions,
10
11
  _BaseDataFrameStatFunctions,
11
12
  )
13
+ from sqlframe.base.mixins.dataframe_mixins import NoCachePersistSupportMixin
12
14
  from sqlframe.snowflake.group import SnowflakeGroupedData
13
15
 
14
- if sys.version_info >= (3, 11):
15
- from typing import Self
16
- else:
17
- from typing_extensions import Self
18
-
19
16
  if t.TYPE_CHECKING:
20
17
  from sqlframe.snowflake.readwriter import SnowflakeDataFrameWriter
21
18
  from sqlframe.snowflake.session import SnowflakeSession
@@ -33,22 +30,35 @@ class SnowflakeDataFrameStatFunctions(_BaseDataFrameStatFunctions["SnowflakeData
33
30
 
34
31
 
35
32
  class SnowflakeDataFrame(
33
+ NoCachePersistSupportMixin,
36
34
  _BaseDataFrame[
37
35
  "SnowflakeSession",
38
36
  "SnowflakeDataFrameWriter",
39
37
  "SnowflakeDataFrameNaFunctions",
40
38
  "SnowflakeDataFrameStatFunctions",
41
39
  "SnowflakeGroupedData",
42
- ]
40
+ ],
43
41
  ):
44
42
  _na = SnowflakeDataFrameNaFunctions
45
43
  _stat = SnowflakeDataFrameStatFunctions
46
44
  _group_data = SnowflakeGroupedData
47
45
 
48
- def cache(self) -> Self:
49
- logger.warning("Snowflake does not support caching. Ignoring cache() call.")
50
- return self
51
-
52
- def persist(self) -> Self:
53
- logger.warning("Snowflake does not support persist. Ignoring persist() call.")
54
- return self
46
@property
def _typed_columns(self) -> t.List[CatalogColumn]:
    """Describe this DataFrame's result columns as catalog ``Column`` entries.

    The DataFrame's query is executed with ``LIMIT 0`` and the resulting
    query id is passed to ``DESCRIBE RESULT``; each described row becomes one
    :class:`CatalogColumn`.
    """
    # Run the query with LIMIT 0 purely to obtain a query id to describe.
    probe = self._convert_leaf_to_cte().limit(0)
    self.session._execute(probe.expression)
    last_query_id = self.session._cur.sfqid
    described_rows = self.session._fetch_rows(f"DESCRIBE RESULT '{last_query_id}'")
    return [
        CatalogColumn(
            name=row.name,
            dataType=row.type,
            # The "null?" field is compared against "Y" to derive nullability.
            nullable=row["null?"] == "Y",
            description=row.comment,
            isPartition=False,
            isBucket=False,
        )
        for row in described_rows
    ]
@@ -16,6 +16,7 @@ globals().update(
16
16
 
17
17
 
18
18
  from sqlframe.base.function_alternatives import ( # noqa
19
+ any_value_ignore_nulls_not_supported as any_value,
19
20
  e_literal as e,
20
21
  expm1_from_exp as expm1,
21
22
  log1p_from_log as log1p,
@@ -32,6 +33,7 @@ from sqlframe.base.function_alternatives import ( # noqa
32
33
  struct_with_eq as struct,
33
34
  make_date_date_from_parts as make_date,
34
35
  date_add_no_date_sub as date_add,
36
+ date_add_no_date_sub as dateadd,
35
37
  date_sub_by_date_add as date_sub,
36
38
  add_months_using_func as add_months,
37
39
  months_between_cast_as_date_cast_roundoff as months_between,
@@ -60,4 +62,5 @@ from sqlframe.base.function_alternatives import ( # noqa
60
62
  flatten_using_array_flatten as flatten,
61
63
  map_concat_using_map_cat as map_concat,
62
64
  sequence_from_array_generate_range as sequence,
65
+ to_number_using_to_double as to_number,
63
66
  )
@@ -1,4 +1,5 @@
1
1
  from sqlframe.base.function_alternatives import ( # noqa
2
+ any_value_ignore_nulls_not_supported as any_value,
2
3
  e_literal as e,
3
4
  expm1_from_exp as expm1,
4
5
  log1p_from_log as log1p,
@@ -15,6 +16,7 @@ from sqlframe.base.function_alternatives import ( # noqa
15
16
  struct_with_eq as struct,
16
17
  make_date_date_from_parts as make_date,
17
18
  date_add_no_date_sub as date_add,
19
+ date_add_no_date_sub as dateadd,
18
20
  date_sub_by_date_add as date_sub,
19
21
  add_months_using_func as add_months,
20
22
  months_between_cast_as_date_cast_roundoff as months_between,
@@ -43,6 +45,7 @@ from sqlframe.base.function_alternatives import ( # noqa
43
45
  flatten_using_array_flatten as flatten,
44
46
  map_concat_using_map_cat as map_concat,
45
47
  sequence_from_array_generate_range as sequence,
48
+ to_number_using_to_double as to_number,
46
49
  )
47
50
  from sqlframe.base.functions import (
48
51
  abs as abs,
@@ -69,9 +72,13 @@ from sqlframe.base.functions import (
69
72
  avg as avg,
70
73
  bit_length as bit_length,
71
74
  bitwiseNOT as bitwiseNOT,
75
+ bool_and as bool_and,
76
+ bool_or as bool_or,
77
+ call_function as call_function,
72
78
  cbrt as cbrt,
73
79
  ceil as ceil,
74
80
  ceiling as ceiling,
81
+ char as char,
75
82
  coalesce as coalesce,
76
83
  col as col,
77
84
  collect_list as collect_list,
@@ -85,14 +92,17 @@ from sqlframe.base.functions import (
85
92
  count as count,
86
93
  countDistinct as countDistinct,
87
94
  count_distinct as count_distinct,
95
+ count_if as count_if,
88
96
  covar_pop as covar_pop,
89
97
  covar_samp as covar_samp,
90
98
  cume_dist as cume_dist,
91
99
  current_date as current_date,
92
100
  current_timestamp as current_timestamp,
101
+ current_user as current_user,
93
102
  date_diff as date_diff,
94
103
  date_format as date_format,
95
104
  date_trunc as date_trunc,
105
+ datediff as datediff,
96
106
  dayofmonth as dayofmonth,
97
107
  dayofweek as dayofweek,
98
108
  dayofyear as dayofyear,
@@ -104,21 +114,26 @@ from sqlframe.base.functions import (
104
114
  exp as exp,
105
115
  explode as explode,
106
116
  expr as expr,
117
+ extract as extract,
107
118
  factorial as factorial,
108
119
  floor as floor,
109
120
  greatest as greatest,
110
121
  grouping_id as grouping_id,
111
122
  hash as hash,
112
123
  hour as hour,
124
+ ifnull as ifnull,
113
125
  initcap as initcap,
114
126
  input_file_name as input_file_name,
115
127
  instr as instr,
116
128
  kurtosis as kurtosis,
117
129
  lag as lag,
130
+ lcase as lcase,
118
131
  lead as lead,
119
132
  least as least,
133
+ left as left,
120
134
  length as length,
121
135
  lit as lit,
136
+ ln as ln,
122
137
  locate as locate,
123
138
  log as log,
124
139
  log10 as log10,
@@ -136,35 +151,44 @@ from sqlframe.base.functions import (
136
151
  minute as minute,
137
152
  month as month,
138
153
  next_day as next_day,
154
+ now as now,
139
155
  nth_value as nth_value,
140
156
  ntile as ntile,
141
157
  nullif as nullif,
158
+ nvl as nvl,
159
+ nvl2 as nvl2,
142
160
  octet_length as octet_length,
143
161
  percent_rank as percent_rank,
144
162
  percentile as percentile,
145
163
  posexplode as posexplode,
164
+ position as position,
146
165
  pow as pow,
166
+ power as power,
147
167
  quarter as quarter,
148
168
  radians as radians,
149
169
  rand as rand,
150
170
  rank as rank,
151
171
  regexp_replace as regexp_replace,
152
172
  repeat as repeat,
173
+ right as right,
153
174
  round as round,
154
175
  row_number as row_number,
155
176
  rpad as rpad,
156
177
  rtrim as rtrim,
157
178
  second as second,
179
+ sha as sha,
158
180
  sha1 as sha1,
159
181
  sha2 as sha2,
160
182
  shiftLeft as shiftLeft,
161
183
  shiftRight as shiftRight,
184
+ sign as sign,
162
185
  signum as signum,
163
186
  sin as sin,
164
187
  sinh as sinh,
165
188
  size as size,
166
189
  soundex as soundex,
167
190
  sqrt as sqrt,
191
+ startswith as startswith,
168
192
  stddev as stddev,
169
193
  stddev_pop as stddev_pop,
170
194
  stddev_samp as stddev_samp,
@@ -182,7 +206,10 @@ from sqlframe.base.functions import (
182
206
  translate as translate,
183
207
  trim as trim,
184
208
  trunc as trunc,
209
+ ucase as ucase,
210
+ unix_date as unix_date,
185
211
  upper as upper,
212
+ user as user,
186
213
  var_pop as var_pop,
187
214
  var_samp as var_samp,
188
215
  variance as variance,
@@ -1,26 +1,23 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
- import sys
5
4
  import typing as t
6
5
 
6
+ from sqlglot import exp
7
+
8
+ from sqlframe.base.catalog import Column
7
9
  from sqlframe.base.dataframe import (
8
10
  _BaseDataFrame,
9
11
  _BaseDataFrameNaFunctions,
10
12
  _BaseDataFrameStatFunctions,
11
13
  )
14
+ from sqlframe.base.mixins.dataframe_mixins import NoCachePersistSupportMixin
12
15
  from sqlframe.spark.group import SparkGroupedData
13
16
 
14
- if sys.version_info >= (3, 11):
15
- from typing import Self
16
- else:
17
- from typing_extensions import Self
18
-
19
17
  if t.TYPE_CHECKING:
20
18
  from sqlframe.spark.readwriter import SparkDataFrameWriter
21
19
  from sqlframe.spark.session import SparkSession
22
20
 
23
-
24
21
  logger = logging.getLogger(__name__)
25
22
 
26
23
 
@@ -33,22 +30,35 @@ class SparkDataFrameStatFunctions(_BaseDataFrameStatFunctions["SparkDataFrame"])
33
30
 
34
31
 
35
32
  class SparkDataFrame(
33
+ NoCachePersistSupportMixin,
36
34
  _BaseDataFrame[
37
35
  "SparkSession",
38
36
  "SparkDataFrameWriter",
39
37
  "SparkDataFrameNaFunctions",
40
38
  "SparkDataFrameStatFunctions",
41
39
  "SparkGroupedData",
42
- ]
40
+ ],
43
41
  ):
44
42
  _na = SparkDataFrameNaFunctions
45
43
  _stat = SparkDataFrameStatFunctions
46
44
  _group_data = SparkGroupedData
47
45
 
48
- def cache(self) -> Self:
49
- logger.warning("Spark does not support caching. Ignoring cache() call.")
50
- return self
51
-
52
- def persist(self) -> Self:
53
- logger.warning("Spark does not support persist. Ignoring persist() call.")
54
- return self
46
@property
def _typed_columns(self) -> t.List[Column]:
    """Describe this DataFrame's result columns as catalog ``Column`` entries.

    The DataFrame's expression is rendered to SQL, handed to the underlying
    Spark session, and the fields of the resulting schema are converted one
    by one into :class:`Column` objects.
    """
    spark = self.session.spark_session
    result_schema = spark.sql(self.session._to_sql(self.expression)).schema
    return [
        Column(
            name=field.name,
            # Round-trip the type's simple string through sqlglot's Spark dialect.
            dataType=exp.DataType.build(
                field.dataType.simpleString(), dialect="spark"
            ).sql(dialect="spark"),
            nullable=field.nullable,
            description=None,
            isPartition=False,
            isBucket=False,
        )
        for field in result_schema.fields
    ]