sqlframe 1.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. sqlframe/__init__.py +0 -0
  2. sqlframe/_version.py +16 -0
  3. sqlframe/base/__init__.py +0 -0
  4. sqlframe/base/_typing.py +39 -0
  5. sqlframe/base/catalog.py +1163 -0
  6. sqlframe/base/column.py +388 -0
  7. sqlframe/base/dataframe.py +1519 -0
  8. sqlframe/base/decorators.py +51 -0
  9. sqlframe/base/exceptions.py +14 -0
  10. sqlframe/base/function_alternatives.py +1055 -0
  11. sqlframe/base/functions.py +1678 -0
  12. sqlframe/base/group.py +102 -0
  13. sqlframe/base/mixins/__init__.py +0 -0
  14. sqlframe/base/mixins/catalog_mixins.py +419 -0
  15. sqlframe/base/mixins/readwriter_mixins.py +118 -0
  16. sqlframe/base/normalize.py +84 -0
  17. sqlframe/base/operations.py +87 -0
  18. sqlframe/base/readerwriter.py +679 -0
  19. sqlframe/base/session.py +585 -0
  20. sqlframe/base/transforms.py +13 -0
  21. sqlframe/base/types.py +418 -0
  22. sqlframe/base/util.py +242 -0
  23. sqlframe/base/window.py +139 -0
  24. sqlframe/bigquery/__init__.py +23 -0
  25. sqlframe/bigquery/catalog.py +255 -0
  26. sqlframe/bigquery/column.py +1 -0
  27. sqlframe/bigquery/dataframe.py +54 -0
  28. sqlframe/bigquery/functions.py +378 -0
  29. sqlframe/bigquery/group.py +14 -0
  30. sqlframe/bigquery/readwriter.py +29 -0
  31. sqlframe/bigquery/session.py +89 -0
  32. sqlframe/bigquery/types.py +1 -0
  33. sqlframe/bigquery/window.py +1 -0
  34. sqlframe/duckdb/__init__.py +20 -0
  35. sqlframe/duckdb/catalog.py +108 -0
  36. sqlframe/duckdb/column.py +1 -0
  37. sqlframe/duckdb/dataframe.py +55 -0
  38. sqlframe/duckdb/functions.py +47 -0
  39. sqlframe/duckdb/group.py +14 -0
  40. sqlframe/duckdb/readwriter.py +111 -0
  41. sqlframe/duckdb/session.py +65 -0
  42. sqlframe/duckdb/types.py +1 -0
  43. sqlframe/duckdb/window.py +1 -0
  44. sqlframe/postgres/__init__.py +23 -0
  45. sqlframe/postgres/catalog.py +106 -0
  46. sqlframe/postgres/column.py +1 -0
  47. sqlframe/postgres/dataframe.py +54 -0
  48. sqlframe/postgres/functions.py +61 -0
  49. sqlframe/postgres/group.py +14 -0
  50. sqlframe/postgres/readwriter.py +29 -0
  51. sqlframe/postgres/session.py +68 -0
  52. sqlframe/postgres/types.py +1 -0
  53. sqlframe/postgres/window.py +1 -0
  54. sqlframe/redshift/__init__.py +23 -0
  55. sqlframe/redshift/catalog.py +127 -0
  56. sqlframe/redshift/column.py +1 -0
  57. sqlframe/redshift/dataframe.py +54 -0
  58. sqlframe/redshift/functions.py +18 -0
  59. sqlframe/redshift/group.py +14 -0
  60. sqlframe/redshift/readwriter.py +29 -0
  61. sqlframe/redshift/session.py +53 -0
  62. sqlframe/redshift/types.py +1 -0
  63. sqlframe/redshift/window.py +1 -0
  64. sqlframe/snowflake/__init__.py +26 -0
  65. sqlframe/snowflake/catalog.py +134 -0
  66. sqlframe/snowflake/column.py +1 -0
  67. sqlframe/snowflake/dataframe.py +54 -0
  68. sqlframe/snowflake/functions.py +18 -0
  69. sqlframe/snowflake/group.py +14 -0
  70. sqlframe/snowflake/readwriter.py +29 -0
  71. sqlframe/snowflake/session.py +53 -0
  72. sqlframe/snowflake/types.py +1 -0
  73. sqlframe/snowflake/window.py +1 -0
  74. sqlframe/spark/__init__.py +23 -0
  75. sqlframe/spark/catalog.py +1028 -0
  76. sqlframe/spark/column.py +1 -0
  77. sqlframe/spark/dataframe.py +54 -0
  78. sqlframe/spark/functions.py +22 -0
  79. sqlframe/spark/group.py +14 -0
  80. sqlframe/spark/readwriter.py +29 -0
  81. sqlframe/spark/session.py +90 -0
  82. sqlframe/spark/types.py +1 -0
  83. sqlframe/spark/window.py +1 -0
  84. sqlframe/standalone/__init__.py +26 -0
  85. sqlframe/standalone/catalog.py +13 -0
  86. sqlframe/standalone/column.py +1 -0
  87. sqlframe/standalone/dataframe.py +36 -0
  88. sqlframe/standalone/functions.py +1 -0
  89. sqlframe/standalone/group.py +14 -0
  90. sqlframe/standalone/readwriter.py +19 -0
  91. sqlframe/standalone/session.py +40 -0
  92. sqlframe/standalone/types.py +1 -0
  93. sqlframe/standalone/window.py +1 -0
  94. sqlframe-1.1.3.dist-info/LICENSE +21 -0
  95. sqlframe-1.1.3.dist-info/METADATA +172 -0
  96. sqlframe-1.1.3.dist-info/RECORD +98 -0
  97. sqlframe-1.1.3.dist-info/WHEEL +5 -0
  98. sqlframe-1.1.3.dist-info/top_level.txt +1 -0
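The layout above repeats the same set of modules (session, dataframe, functions, readwriter, catalog, types, window) for each supported engine, with the shared implementation under sqlframe/base. A rough usage sketch of that structure (the class and method names below follow sqlframe's PySpark-compatible API but are assumptions for illustration, not taken from this diff):

    from sqlframe.standalone import StandaloneSession
    from sqlframe.base import functions as F

    session = StandaloneSession()  # each engine package exposes its own Session class
    df = session.createDataFrame([{"id": 1, "name": "a"}, {"id": 2, "name": "b"}])
    # DataFrame operations build a sqlglot expression tree; the SQL can be rendered
    # instead of (or before) being executed against an engine.
    print(df.where(F.col("id") > 1).select("name").sql())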
sqlframe/base/group.py ADDED
@@ -0,0 +1,102 @@
+# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
+
+from __future__ import annotations
+
+import typing as t
+
+from sqlframe.base.operations import Operation, group_operation, operation
+
+if t.TYPE_CHECKING:
+    from sqlframe.base.column import Column
+    from sqlframe.base.session import DF
+else:
+    DF = t.TypeVar("DF")
+
+
+# https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-groupby.html
+# https://stackoverflow.com/questions/37975227/what-is-the-difference-between-cube-rollup-and-groupby-operators
+class _BaseGroupedData(t.Generic[DF]):
+    def __init__(
+        self,
+        df: DF,
+        group_by_cols: t.Union[t.List[Column], t.List[t.List[Column]]],
+        last_op: Operation,
+    ):
+        self._df = df.copy()
+        self.session = df.session
+        self.last_op = last_op
+        self.group_by_cols = group_by_cols
+
+    def _get_function_applied_columns(
+        self, func_name: str, cols: t.Tuple[str, ...]
+    ) -> t.List[Column]:
+        from sqlframe.base import functions as F
+
+        func_name = func_name.lower()
+        return [
+            getattr(F, func_name)(name).alias(
+                self.session._sanitize_column_name(f"{func_name}({name})")
+            )
+            for name in cols
+        ]
+
+    @group_operation(Operation.SELECT)
+    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DF:
+        from sqlframe.base.column import Column
+
+        columns = (
+            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
+            if isinstance(exprs[0], dict)
+            else exprs
+        )
+        cols = self._df._ensure_and_normalize_cols(columns)
+
+        if not self.group_by_cols or not isinstance(self.group_by_cols[0], (list, tuple, set)):
+            expression = self._df.expression.group_by(
+                *[x.expression for x in self.group_by_cols]  # type: ignore
+            ).select(*[x.expression for x in self.group_by_cols + cols], append=False)  # type: ignore
+            group_by_cols = self.group_by_cols
+        else:
+            from sqlglot import exp
+
+            expression = self._df.expression
+            all_grouping_sets = []
+            group_by_cols = []
+            for grouping_set in self.group_by_cols:
+                all_grouping_sets.append(
+                    exp.Tuple(expressions=[x.expression for x in grouping_set])  # type: ignore
+                )
+                group_by_cols.extend(grouping_set)  # type: ignore
+            group_by_cols = list(dict.fromkeys(group_by_cols))
+            group_by = exp.Group(grouping_sets=all_grouping_sets)
+            expression.set("group", group_by)
+            for col in cols:
+                # Spark supports having an empty grouping_id which means all of the columns but other dialects
+                # like duckdb don't support this so we expand the grouping_id to include all of the columns
+                if col.column_expression.this == "GROUPING_ID":
+                    col.column_expression.set("expressions", [x.expression for x in group_by_cols])  # type: ignore
+            expression = expression.select(*[x.expression for x in group_by_cols + cols], append=False)  # type: ignore
+        return self._df.copy(expression=expression)
+
+    def count(self) -> DF:
+        from sqlframe.base import functions as F
+
+        return self.agg(F.count("*").alias("count"))
+
+    def mean(self, *cols: str) -> DF:
+        return self.avg(*cols)
+
+    def avg(self, *cols: str) -> DF:
+        return self.agg(*self._get_function_applied_columns("avg", cols))
+
+    def max(self, *cols: str) -> DF:
+        return self.agg(*self._get_function_applied_columns("max", cols))
+
+    def min(self, *cols: str) -> DF:
+        return self.agg(*self._get_function_applied_columns("min", cols))
+
+    def sum(self, *cols: str) -> DF:
+        return self.agg(*self._get_function_applied_columns("sum", cols))
+
+    def pivot(self, *cols: str) -> DF:
+        raise NotImplementedError("Pivot is not currently implemented")
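For context, a short sketch of how the grouped-data API defined above is reached (a minimal example; StandaloneSession and the column names are illustrative assumptions, and any sqlframe DataFrame would behave the same way):

    from sqlframe.standalone import StandaloneSession
    from sqlframe.base import functions as F

    df = StandaloneSession().createDataFrame(
        [{"store": "a", "price": 1.0}, {"store": "b", "price": 2.0}]
    )
    grouped = df.groupBy("store")              # a _BaseGroupedData subclass
    counts = grouped.count()                   # agg(F.count("*").alias("count"))
    totals = grouped.agg({"price": "sum"})     # dict form becomes Column("sum(price)")
    bounds = grouped.agg(F.min("price"), F.max("price"))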
sqlframe/base/mixins/catalog_mixins.py ADDED
@@ -0,0 +1,419 @@
+import fnmatch
+import typing as t
+
+from sqlglot import exp
+
+from sqlframe.base.catalog import (
+    DF,
+    SESSION,
+    CatalogMetadata,
+    Column,
+    Database,
+    Table,
+    _BaseCatalog,
+)
+from sqlframe.base.decorators import normalize
+from sqlframe.base.util import decoded_str, schema_, to_schema
+
+
+class _BaseInfoSchemaMixin(_BaseCatalog, t.Generic[SESSION, DF]):
+    QUALIFY_INFO_SCHEMA_WITH_DATABASE = False
+    UPPERCASE_INFO_SCHEMA = False
+
+    def _get_info_schema_table(
+        self,
+        table_name: str,
+        database: t.Optional[str] = None,
+        qualify_override: t.Optional[bool] = None,
+    ) -> exp.Table:
+        table = f"information_schema.{table_name}"
+        if self.UPPERCASE_INFO_SCHEMA:
+            table = table.upper()
+        qualify = (
+            qualify_override
+            if qualify_override is not None
+            else self.QUALIFY_INFO_SCHEMA_WITH_DATABASE
+        )
+        if qualify:
+            db = database or self.currentDatabase()
+            if not db:
+                raise ValueError("Table name must be qualified with a database.")
+            table = f"{db}.{table}"
+        return exp.to_table(table)
+
+
+class GetCurrentCatalogFromFunctionMixin(_BaseCatalog, t.Generic[SESSION, DF]):
+    CURRENT_CATALOG_EXPRESSION: exp.Expression = exp.func("current_catalog")
+
+    def currentCatalog(self) -> str:
+        """Returns the current default catalog in this session.
+
+        .. versionadded:: 3.4.0
+
+        Examples
+        --------
+        >>> spark.catalog.currentCatalog()
+        'spark_catalog'
+        """
+        return self.session._fetch_rows(
+            exp.select(self.CURRENT_CATALOG_EXPRESSION), quote_identifiers=False
+        )[0][0]
+
+
+class GetCurrentDatabaseFromFunctionMixin(_BaseCatalog, t.Generic[SESSION, DF]):
+    CURRENT_DATABASE_EXPRESSION: exp.Expression = exp.func("current_schema")
+
+    def currentDatabase(self) -> str:
+        """Returns the current default schema in this session.
+
+        .. versionadded:: 3.4.0
+
+        Examples
+        --------
+        >>> spark.catalog.currentDatabase()
+        'default'
+        """
+        return self.session._fetch_rows(exp.select(self.CURRENT_DATABASE_EXPRESSION))[0][0]
+
+
+class SetCurrentCatalogFromUseMixin(_BaseCatalog, t.Generic[SESSION, DF]):
+    def setCurrentCatalog(self, catalogName: str) -> None:
+        """Sets the current default catalog in this session.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        catalogName : str
+            name of the catalog to set
+
+        Examples
+        --------
+        >>> spark.catalog.setCurrentCatalog("spark_catalog")
+        """
+        self.session._execute(
+            exp.Use(this=exp.parse_identifier(catalogName, dialect=self.session.input_dialect))
+        )
+
+
+class ListDatabasesFromInfoSchemaMixin(_BaseInfoSchemaMixin, t.Generic[SESSION, DF]):
+    def listDatabases(self, pattern: t.Optional[str] = None) -> t.List[Database]:
+        """
+        Returns a list of databases available across all sessions.
+
+        .. versionadded:: 2.0.0
+
+        Parameters
+        ----------
+        pattern : str
+            The pattern that the database name needs to match.
+
+            .. versionchanged:: 3.5.0
+                Adds ``pattern`` argument.
+
+        Returns
+        -------
+        list
+            A list of :class:`Database`.
+
+        Examples
+        --------
+        >>> spark.catalog.listDatabases()
+        [Database(name='default', catalog='spark_catalog', description='default database', ...
+
+        >>> spark.catalog.listDatabases("def*")
+        [Database(name='default', catalog='spark_catalog', description='default database', ...
+
+        >>> spark.catalog.listDatabases("def2*")
+        []
+        """
+        table = self._get_info_schema_table("schemata", qualify_override=False)
+        results = self.session._fetch_rows(
+            exp.Select().select("schema_name", "catalog_name").from_(table)
+        )
+        databases = [
+            Database(name=x[0], catalog=x[1], description=None, locationUri="") for x in results
+        ]
+        if pattern:
+            databases = [db for db in databases if fnmatch.fnmatch(db.name, pattern)]
+        return databases
+
+
+class ListCatalogsFromInfoSchemaMixin(_BaseInfoSchemaMixin, t.Generic[SESSION, DF]):
+    def listCatalogs(self, pattern: t.Optional[str] = None) -> t.List[CatalogMetadata]:
+        """
+        Returns a list of catalogs available across all sessions.
+
+        .. versionadded:: 2.0.0
+
+        Parameters
+        ----------
+        pattern : str
+            The pattern that the catalog name needs to match.
+
+            .. versionchanged:: 3.5.0
+                Adds ``pattern`` argument.
+
+        Returns
+        -------
+        list
+            A list of :class:`CatalogMetadata`.
+
+        Examples
+        --------
+        >>> spark.catalog.listCatalogs()
+        [CatalogMetadata(name='spark_catalog', description=None)]
+
+        >>> spark.catalog.listCatalogs("spark*")
+        [CatalogMetadata(name='spark_catalog', description=None)]
+
+        >>> spark.catalog.listCatalogs("hive*")
+        []
+        """
+        table = self._get_info_schema_table("schemata")
+        results = self.session._fetch_rows(
+            exp.Select().select("catalog_name").from_(table).distinct()
+        )
+        catalogs = [CatalogMetadata(name=x[0], description=None) for x in results]
+        if pattern:
+            catalogs = [catalog for catalog in catalogs if fnmatch.fnmatch(catalog.name, pattern)]
+        return catalogs
+
+
+class SetCurrentDatabaseFromSearchPathMixin(_BaseCatalog, t.Generic[SESSION, DF]):
+    def setCurrentDatabase(self, dbName: str) -> None:
+        """
+        Sets the current default database in this session.
+
+        .. versionadded:: 2.0.0
+
+        Examples
+        --------
+        >>> spark.catalog.setCurrentDatabase("default")
+        """
+        self.session._execute(f'SET search_path TO "{dbName}"')
+
+
+class SetCurrentDatabaseFromUseMixin(_BaseCatalog, t.Generic[SESSION, DF]):
+    def setCurrentDatabase(self, dbName: str) -> None:
+        """
+        Sets the current default database in this session.
+
+        .. versionadded:: 2.0.0
+
+        Examples
+        --------
+        >>> spark.catalog.setCurrentDatabase("default")
+        """
+        schema = to_schema(dbName, dialect=self.session.input_dialect)
+        if not schema.catalog:
+            schema.set(
+                "catalog",
+                exp.parse_identifier(self.currentCatalog(), dialect=self.session.input_dialect),
+            )
+        self.session._execute(exp.Use(this=schema))
+
+
+class ListTablesFromInfoSchemaMixin(_BaseInfoSchemaMixin, t.Generic[SESSION, DF]):
+    @normalize(["dbName"])
+    def listTables(
+        self, dbName: t.Optional[str] = None, pattern: t.Optional[str] = None
+    ) -> t.List[Table]:
+        """Returns a list of tables/views in the specified database.
+
+        .. versionadded:: 2.0.0
+
+        Parameters
+        ----------
+        dbName : str
+            name of the database to list the tables.
+
+            .. versionchanged:: 3.4.0
+                Allow ``dbName`` to be qualified with catalog name.
+
+        pattern : str
+            The pattern that the table name needs to match.
+
+            .. versionchanged:: 3.5.0
+                Adds ``pattern`` argument.
+
+        Returns
+        -------
+        list
+            A list of :class:`Table`.
+
+        Notes
+        -----
+        If no database is specified, the current database and catalog
+        are used. This API includes all temporary views.
+
+        Examples
+        --------
+        >>> spark.range(1).createTempView("test_view")
+        >>> spark.catalog.listTables()
+        [Table(name='test_view', catalog=None, namespace=[], description=None, ...
+
+        >>> spark.catalog.listTables(pattern="test*")
+        [Table(name='test_view', catalog=None, namespace=[], description=None, ...
+
+        >>> spark.catalog.listTables(pattern="table*")
+        []
+
+        >>> _ = spark.catalog.dropTempView("test_view")
+        >>> spark.catalog.listTables()
+        []
+        """
+        if dbName is None and pattern is None:
+            schema = schema_(
+                db=exp.parse_identifier(self.currentDatabase(), dialect=self.session.input_dialect),
+                catalog=exp.parse_identifier(
+                    self.currentCatalog(), dialect=self.session.input_dialect
+                ),
+            )
+        elif dbName:
+            schema = to_schema(dbName, dialect=self.session.input_dialect)
+        else:
+            schema = None
+        table = self._get_info_schema_table("tables", database=schema.db if schema else None)
+        select = exp.select(
+            'table_name AS "table_name"',
+            'table_schema AS "table_schema"',
+            'table_catalog AS "table_catalog"',
+            'table_type AS "table_type"',
+        ).from_(table)
+        if schema and schema.db:
+            select = select.where(exp.column("table_schema").eq(schema.db))
+        if schema and schema.catalog:
+            select = select.where(exp.column("table_catalog").eq(schema.catalog))
+        results = self.session._fetch_rows(select)
+        tables = [
+            Table(
+                name=x["table_name"],
+                catalog=x["table_catalog"],
+                namespace=[x["table_schema"]],
+                description=None,
+                tableType="VIEW" if x["table_type"] == "VIEW" else "MANAGED",
+                isTemporary=False,
+            )
+            for x in results
+        ]
+        for table in self.session.temp_views.keys():
+            tables.append(
+                Table(
+                    name=table,  # type: ignore
+                    catalog=None,
+                    namespace=[],
+                    description=None,
+                    tableType="VIEW",
+                    isTemporary=True,
+                )
+            )
+        if pattern:
+            tables = [x for x in tables if fnmatch.fnmatch(x.name, pattern)]
+        return tables
+
+
+class ListColumnsFromInfoSchemaMixin(_BaseInfoSchemaMixin, t.Generic[SESSION, DF]):
+    @normalize(["tableName", "dbName"])
+    def listColumns(self, tableName: str, dbName: t.Optional[str] = None) -> t.List[Column]:
+        """Returns a list of columns for the given table/view in the specified database.
+
+        .. versionadded:: 2.0.0
+
+        Parameters
+        ----------
+        tableName : str
+            name of the table to list columns.
+
+            .. versionchanged:: 3.4.0
+                Allow ``tableName`` to be qualified with catalog name when ``dbName`` is None.
+
+        dbName : str, optional
+            name of the database to find the table to list columns.
+
+        Returns
+        -------
+        list
+            A list of :class:`Column`.
+
+        Notes
+        -----
+        The order of arguments here is different from that of its JVM counterpart
+        because Python does not support method overloading.
+
+        If no database is specified, the current database and catalog
+        are used. This API includes all temporary views.
+
+        Examples
+        --------
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tblA")
+        >>> _ = spark.sql("CREATE TABLE tblA (name STRING, age INT) USING parquet")
+        >>> spark.catalog.listColumns("tblA")
+        [Column(name='name', description=None, dataType='string', nullable=True, ...
+        >>> _ = spark.sql("DROP TABLE tblA")
+        """
+        if df := self.session.temp_views.get(tableName):
+            return [
+                Column(
+                    name=x,
+                    description=None,
+                    dataType="",
+                    nullable=True,
+                    isPartition=False,
+                    isBucket=False,
+                )
+                for x in df.columns
+            ]
+
+        table = exp.to_table(tableName, dialect=self.session.input_dialect)
+        schema = to_schema(dbName, dialect=self.session.input_dialect) if dbName else None
+        if not table.db:
+            if schema and schema.db:
+                table.set("db", schema.args["db"])
+            else:
+                table.set(
+                    "db",
+                    exp.parse_identifier(
+                        self.currentDatabase(), dialect=self.session.input_dialect
+                    ),
+                )
+        if not table.catalog:
+            if schema and schema.catalog:
+                table.set("catalog", schema.args["catalog"])
+            else:
+                table.set(
+                    "catalog",
+                    exp.parse_identifier(self.currentCatalog(), dialect=self.session.input_dialect),
+                )
+        # if self.QUALIFY_INFO_SCHEMA_WITH_DATABASE:
+        #     if not table.db:
+        #         raise ValueError("dbName must be specified when listing columns from INFORMATION_SCHEMA")
+        #     source_table = f"{table.db}.INFORMATION_SCHEMA.COLUMNS"
+        # else:
+        #     source_table = "INFORMATION_SCHEMA.COLUMNS"
+        source_table = self._get_info_schema_table("columns", database=table.db)
+        select = (
+            exp.select(
+                'column_name AS "column_name"',
+                'data_type AS "data_type"',
+                'is_nullable AS "is_nullable"',
+            )
+            .from_(source_table)
+            .where(exp.column("table_name").eq(table.name))
+        )
+        if table.db:
+            select = select.where(exp.column("table_schema").eq(table.db))
+        if table.catalog:
+            select = select.where(exp.column("table_catalog").eq(table.catalog))
+        results = self.session._fetch_rows(select)
+        return [
+            Column(
+                name=x["column_name"],
+                description=None,
+                dataType=x["data_type"],
+                nullable=x["is_nullable"] == "YES",
+                isPartition=False,
+                isBucket=False,
+            )
+            for x in results
+        ]
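To make the mixins above concrete, this is roughly the information_schema query that ListTablesFromInfoSchemaMixin.listTables assembles with sqlglot before calling _fetch_rows (the schema and catalog literals below are made-up placeholders):

    from sqlglot import exp

    table = exp.to_table("information_schema.tables")
    select = (
        exp.select(
            'table_name AS "table_name"',
            'table_schema AS "table_schema"',
            'table_catalog AS "table_catalog"',
            'table_type AS "table_type"',
        )
        .from_(table)
        .where(exp.column("table_schema").eq("main"))
        .where(exp.column("table_catalog").eq("memory"))
    )
    print(select.sql(dialect="duckdb"))
    # Roughly: SELECT table_name AS "table_name", ... FROM information_schema.tables
    # WHERE table_schema = 'main' AND table_catalog = 'memory'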
sqlframe/base/mixins/readwriter_mixins.py ADDED
@@ -0,0 +1,118 @@
+from __future__ import annotations
+
+import pathlib
+import typing as t
+
+import pandas as pd
+
+from sqlframe.base.exceptions import UnsupportedOperationError
+from sqlframe.base.readerwriter import (
+    DF,
+    SESSION,
+    _BaseDataFrameReader,
+    _BaseDataFrameWriter,
+    _infer_format,
+)
+from sqlframe.base.util import pandas_to_spark_schema
+
+if t.TYPE_CHECKING:
+    from sqlframe.base._typing import OptionalPrimitiveType, PathOrPaths
+    from sqlframe.base.types import StructType
+
+
+class PandasLoaderMixin(_BaseDataFrameReader, t.Generic[SESSION, DF]):
+    def load(
+        self,
+        path: t.Optional[PathOrPaths] = None,
+        format: t.Optional[str] = None,
+        schema: t.Optional[t.Union[StructType, str]] = None,
+        **options: OptionalPrimitiveType,
+    ) -> DF:
+        """Loads data from a data source and returns it as a :class:`DataFrame`.
+
+        .. versionadded:: 1.4.0
+
+        .. versionchanged:: 3.4.0
+            Supports Spark Connect.
+
+        Parameters
+        ----------
+        path : str or list, optional
+            Optional string or a list of string for file-system backed data sources.
+        format : str, optional
+            Optional string for format of the data source. Default to 'parquet'.
+        schema : :class:`pyspark.sql.types.StructType` or str, optional
+            Optional :class:`pyspark.sql.types.StructType` for the input schema
+            or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
+        **options : dict
+            all other string options
+
+        Examples
+        --------
+        Load a CSV file with format, schema and options specified.
+
+        >>> import tempfile
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Write a DataFrame into a CSV file with a header
+        ...     df = spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}])
+        ...     df.write.option("header", True).mode("overwrite").format("csv").save(d)
+        ...
+        ...     # Read the CSV file as a DataFrame with 'nullValue' option set to 'Hyukjin Kwon',
+        ...     # and 'header' option set to `True`.
+        ...     df = spark.read.load(
+        ...         d, schema=df.schema, format="csv", nullValue="Hyukjin Kwon", header=True)
+        ...     df.printSchema()
+        ...     df.show()
+        root
+         |-- age: long (nullable = true)
+         |-- name: string (nullable = true)
+        +---+----+
+        |age|name|
+        +---+----+
+        |100|NULL|
+        +---+----+
+        """
+        assert path is not None, "path is required"
+        assert isinstance(path, str), "path must be a string"
+        format = format or _infer_format(path)
+        kwargs = {k: v for k, v in options.items() if v is not None}
+        if format == "json":
+            df = pd.read_json(path, lines=True, **kwargs)  # type: ignore
+        elif format == "parquet":
+            df = pd.read_parquet(path, **kwargs)  # type: ignore
+        elif format == "csv":
+            df = pd.read_csv(path, **kwargs)  # type: ignore
+        else:
+            raise UnsupportedOperationError(f"Unsupported format: {format}")
+        schema = schema or pandas_to_spark_schema(df)
+        self.session._last_loaded_file = path
+        return self._session.createDataFrame(list(df.itertuples(index=False)), schema)
+
+
+class PandasWriterMixin(_BaseDataFrameWriter, t.Generic[SESSION, DF]):
+    def _write(self, path: str, mode: t.Optional[str], format: str, **options):  # type: ignore
+        mode, skip = self._validate_mode(path, mode)
+        if skip:
+            return
+        pandas_df = self._df.toPandas()
+        mode = self._mode_to_pandas_mode(mode)
+        kwargs = {k: v for k, v in options.items() if v is not None}
+        kwargs["index"] = False
+        if format == "csv":
+            kwargs["mode"] = mode
+            if mode == "a" and pathlib.Path(path).exists():
+                kwargs["header"] = False
+            pandas_df.to_csv(path, **kwargs)
+        elif format == "parquet":
+            if mode == "a":
+                raise NotImplementedError("Append mode is not supported for parquet.")
+            pandas_df.to_parquet(path, **kwargs)
+        elif format == "json":
+            # Pandas versions are inconsistent on how to handle True/False index so we just remove it
+            # since in all versions it will not result in an index column in the output.
+            del kwargs["index"]
+            kwargs["mode"] = mode
+            kwargs["orient"] = "records"
+            pandas_df.to_json(path, lines=True, **kwargs)
+        else:
+            raise NotImplementedError(f"Unsupported format: {format}")
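The loader above is a thin dispatch over pandas. A standalone sketch of that dispatch (the helper name and the extension-based stand-in for _infer_format are assumptions; the real mixin additionally converts the pandas result into a sqlframe DataFrame with an inferred schema):

    import typing as t

    import pandas as pd


    def load_with_pandas(path: str, format: t.Optional[str] = None, **options) -> pd.DataFrame:
        # Fall back to the file extension when no explicit format is given.
        format = (format or path.rsplit(".", 1)[-1]).lower()
        # Drop options that were left unset, mirroring the mixin's kwargs filtering.
        kwargs = {k: v for k, v in options.items() if v is not None}
        if format == "json":
            return pd.read_json(path, lines=True, **kwargs)  # newline-delimited JSON
        if format == "parquet":
            return pd.read_parquet(path, **kwargs)
        if format == "csv":
            return pd.read_csv(path, **kwargs)
        raise ValueError(f"Unsupported format: {format}")

    # e.g. load_with_pandas("events.csv", header=0) or load_with_pandas("events.jsonl", format="json")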