altimate-code 0.5.2 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101) hide show
  1. package/CHANGELOG.md +12 -0
  2. package/bin/altimate +6 -0
  3. package/bin/altimate-code +6 -0
  4. package/dbt-tools/bin/altimate-dbt +2 -0
  5. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/__init__.py +0 -0
  6. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/fetch_schema.py +35 -0
  7. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/utils.py +353 -0
  8. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/validate_sql.py +114 -0
  9. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__init__.py +178 -0
  10. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__main__.py +96 -0
  11. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/_typing.py +17 -0
  12. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/__init__.py +3 -0
  13. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/__init__.py +18 -0
  14. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/_typing.py +18 -0
  15. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/column.py +332 -0
  16. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/dataframe.py +866 -0
  17. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/functions.py +1267 -0
  18. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/group.py +59 -0
  19. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/normalize.py +78 -0
  20. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/operations.py +53 -0
  21. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/readwriter.py +108 -0
  22. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/session.py +190 -0
  23. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/transforms.py +9 -0
  24. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/types.py +212 -0
  25. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/util.py +32 -0
  26. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/window.py +134 -0
  27. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/__init__.py +118 -0
  28. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/athena.py +166 -0
  29. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/bigquery.py +1331 -0
  30. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/clickhouse.py +1393 -0
  31. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/databricks.py +131 -0
  32. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dialect.py +1915 -0
  33. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/doris.py +561 -0
  34. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/drill.py +157 -0
  35. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/druid.py +20 -0
  36. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/duckdb.py +1159 -0
  37. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dune.py +16 -0
  38. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/hive.py +787 -0
  39. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/materialize.py +94 -0
  40. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/mysql.py +1324 -0
  41. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/oracle.py +378 -0
  42. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/postgres.py +778 -0
  43. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/presto.py +788 -0
  44. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/prql.py +203 -0
  45. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/redshift.py +448 -0
  46. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/risingwave.py +78 -0
  47. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/snowflake.py +1464 -0
  48. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark.py +202 -0
  49. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark2.py +349 -0
  50. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/sqlite.py +320 -0
  51. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/starrocks.py +343 -0
  52. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tableau.py +61 -0
  53. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/teradata.py +356 -0
  54. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/trino.py +115 -0
  55. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tsql.py +1403 -0
  56. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/diff.py +456 -0
  57. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/errors.py +93 -0
  58. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/__init__.py +95 -0
  59. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/context.py +101 -0
  60. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/env.py +246 -0
  61. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/python.py +460 -0
  62. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/table.py +155 -0
  63. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/expressions.py +8870 -0
  64. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/generator.py +4993 -0
  65. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/helper.py +582 -0
  66. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/jsonpath.py +227 -0
  67. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/lineage.py +423 -0
  68. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/__init__.py +11 -0
  69. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/annotate_types.py +589 -0
  70. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/canonicalize.py +222 -0
  71. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_ctes.py +43 -0
  72. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_joins.py +181 -0
  73. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_subqueries.py +189 -0
  74. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/isolate_table_selects.py +50 -0
  75. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/merge_subqueries.py +415 -0
  76. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize.py +200 -0
  77. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize_identifiers.py +64 -0
  78. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimize_joins.py +91 -0
  79. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimizer.py +94 -0
  80. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_predicates.py +222 -0
  81. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_projections.py +172 -0
  82. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify.py +104 -0
  83. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_columns.py +1024 -0
  84. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_tables.py +155 -0
  85. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/scope.py +904 -0
  86. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/simplify.py +1587 -0
  87. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/unnest_subqueries.py +302 -0
  88. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/parser.py +8501 -0
  89. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/planner.py +463 -0
  90. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/schema.py +588 -0
  91. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/serde.py +68 -0
  92. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/time.py +687 -0
  93. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/tokens.py +1520 -0
  94. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/transforms.py +1020 -0
  95. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/trie.py +81 -0
  96. package/dbt-tools/dist/altimate_python_packages/dbt_core_integration.py +825 -0
  97. package/dbt-tools/dist/altimate_python_packages/dbt_utils.py +157 -0
  98. package/dbt-tools/dist/index.js +23859 -0
  99. package/package.json +13 -13
  100. package/postinstall.mjs +42 -0
  101. package/skills/altimate-setup/SKILL.md +31 -0
@@ -0,0 +1,59 @@
1
+ from __future__ import annotations
2
+
3
+ import typing as t
4
+
5
+ from sqlglot.dataframe.sql import functions as F
6
+ from sqlglot.dataframe.sql.column import Column
7
+ from sqlglot.dataframe.sql.operations import Operation, operation
8
+
9
+ if t.TYPE_CHECKING:
10
+ from sqlglot.dataframe.sql.dataframe import DataFrame
11
+
12
+
13
+ class GroupedData:
14
+ def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation):
15
+ self._df = df.copy()
16
+ self.spark = df.spark
17
+ self.last_op = last_op
18
+ self.group_by_cols = group_by_cols
19
+
20
+ def _get_function_applied_columns(
21
+ self, func_name: str, cols: t.Tuple[str, ...]
22
+ ) -> t.List[Column]:
23
+ func_name = func_name.lower()
24
+ return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]
25
+
26
+ @operation(Operation.SELECT)
27
+ def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame:
28
+ columns = (
29
+ [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
30
+ if isinstance(exprs[0], dict)
31
+ else exprs
32
+ )
33
+ cols = self._df._ensure_and_normalize_cols(columns)
34
+
35
+ expression = self._df.expression.group_by(
36
+ *[x.expression for x in self.group_by_cols]
37
+ ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
38
+ return self._df.copy(expression=expression)
39
+
40
+ def count(self) -> DataFrame:
41
+ return self.agg(F.count("*").alias("count"))
42
+
43
+ def mean(self, *cols: str) -> DataFrame:
44
+ return self.avg(*cols)
45
+
46
+ def avg(self, *cols: str) -> DataFrame:
47
+ return self.agg(*self._get_function_applied_columns("avg", cols))
48
+
49
+ def max(self, *cols: str) -> DataFrame:
50
+ return self.agg(*self._get_function_applied_columns("max", cols))
51
+
52
+ def min(self, *cols: str) -> DataFrame:
53
+ return self.agg(*self._get_function_applied_columns("min", cols))
54
+
55
+ def sum(self, *cols: str) -> DataFrame:
56
+ return self.agg(*self._get_function_applied_columns("sum", cols))
57
+
58
+ def pivot(self, *cols: str) -> DataFrame:
59
+ raise NotImplementedError("Sum distinct is not currently implemented")
@@ -0,0 +1,78 @@
1
+ from __future__ import annotations
2
+
3
+ import typing as t
4
+
5
+ from sqlglot import expressions as exp
6
+ from sqlglot.dataframe.sql.column import Column
7
+ from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
8
+ from sqlglot.helper import ensure_list
9
+
10
+ NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
11
+
12
+ if t.TYPE_CHECKING:
13
+ from sqlglot.dataframe.sql.session import SparkSession
14
+
15
+
16
def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[NORMALIZE_INPUT]):
    """Normalize every identifier inside *expr* for the session's dialect, then
    rewrite identifiers that refer to aliases, branch ids, or sequence ids so
    they point at the corresponding CTE names in *expression_context*.
    """
    for expression in _ensure_expressions(ensure_list(expr)):
        for identifier in expression.find_all(exp.Identifier):
            identifier.transform(spark.dialect.normalize_identifier)
            replace_alias_name_with_cte_name(spark, expression_context, identifier)
            replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
25
+
26
+
27
def replace_alias_name_with_cte_name(
    spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
):
    """If *id* names a user alias, repoint it at the most recent CTE whose
    sequence id was registered under that alias."""
    name = id.alias_or_name
    if name not in spark.name_to_sequence_id_mapping:
        return
    candidate_sequence_ids = spark.name_to_sequence_id_mapping[name]
    # Walk the CTEs newest-first and take the first match.
    for cte in reversed(expression_context.ctes):
        if cte.args["sequence_id"] in candidate_sequence_ids:
            _set_alias_name(id, cte.alias_or_name)
            break
35
+
36
+
37
def replace_branch_and_sequence_ids_with_cte_name(
    spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
):
    """If *id* is a known branch/sequence id, repoint it at the matching CTE name.

    Bug fix: the original indexed ``ctes_in_join[1]`` *before* its
    ``assert len(ctes_in_join) == 2`` guard, so fewer than two matching CTEs
    crashed with IndexError (and the assert vanishes under ``python -O``).
    The length is now checked as part of the condition so non-2 cases simply
    fall through to the generic CTE lookup below.
    """
    if id.alias_or_name in spark.known_ids:
        # Check if we have a join and if both the tables in that join share a common branch id
        # If so we need to have this reference the left table by default unless the id is a sequence
        # id then it keeps that reference. This handles the weird edge case in spark that shouldn't
        # be common in practice
        if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids:
            join_table_aliases = [
                x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)
            ]
            ctes_in_join = [
                cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases
            ]
            if (
                len(ctes_in_join) == 2
                and ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]
            ):
                _set_alias_name(id, ctes_in_join[0].alias_or_name)
                return

        for cte in reversed(expression_context.ctes):
            if id.alias_or_name in (cte.args["branch_id"], cte.args["sequence_id"]):
                _set_alias_name(id, cte.alias_or_name)
                return
61
+
62
+
63
def _set_alias_name(id: exp.Identifier, name: str):
    # Rewrite the identifier's text in place so it now refers to *name*
    # (the CTE alias chosen by the callers above).
    id.set("this", name)
65
+
66
+
67
def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]:
    """Coerce each value (column name string, Column wrapper, or raw sqlglot
    expression) into a bare ``exp.Expression``; reject anything else."""

    def _to_expression(value: NORMALIZE_INPUT) -> exp.Expression:
        # Same check order as before: str, then Column, then Expression.
        if isinstance(value, str):
            return Column.ensure_col(value).expression
        if isinstance(value, Column):
            return value.expression
        if isinstance(value, exp.Expression):
            return value
        raise ValueError(f"Got an invalid type to normalize: {type(value)}")

    return [_to_expression(value) for value in values]
@@ -0,0 +1,53 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import typing as t
5
+ from enum import IntEnum
6
+
7
+ if t.TYPE_CHECKING:
8
+ from sqlglot.dataframe.sql.dataframe import DataFrame
9
+ from sqlglot.dataframe.sql.group import GroupedData
10
+
11
+
12
+ class Operation(IntEnum):
13
+ INIT = -1
14
+ NO_OP = 0
15
+ FROM = 1
16
+ WHERE = 2
17
+ GROUP_BY = 3
18
+ HAVING = 4
19
+ SELECT = 5
20
+ ORDER_BY = 6
21
+ LIMIT = 7
22
+
23
+
24
+ def operation(op: Operation):
25
+ """
26
+ Decorator used around DataFrame methods to indicate what type of operation is being performed from the
27
+ ordered Operation enums. This is used to determine which operations should be performed on a CTE vs.
28
+ included with the previous operation.
29
+
30
+ Ex: After a user does a join we want to allow them to select which columns for the different
31
+ tables that they want to carry through to the following operation. If we put that join in
32
+ a CTE preemptively then the user would not have a chance to select which column they want
33
+ in cases where there is overlap in names.
34
+ """
35
+
36
+ def decorator(func: t.Callable):
37
+ @functools.wraps(func)
38
+ def wrapper(self: DataFrame, *args, **kwargs):
39
+ if self.last_op == Operation.INIT:
40
+ self = self._convert_leaf_to_cte()
41
+ self.last_op = Operation.NO_OP
42
+ last_op = self.last_op
43
+ new_op = op if op != Operation.NO_OP else last_op
44
+ if new_op < last_op or (last_op == new_op == Operation.SELECT):
45
+ self = self._convert_leaf_to_cte()
46
+ df: t.Union[DataFrame, GroupedData] = func(self, *args, **kwargs)
47
+ df.last_op = new_op # type: ignore
48
+ return df
49
+
50
+ wrapper.__wrapped__ = func # type: ignore
51
+ return wrapper
52
+
53
+ return decorator
@@ -0,0 +1,108 @@
1
+ from __future__ import annotations
2
+
3
+ import typing as t
4
+
5
+ import sqlglot as sqlglot
6
+ from sqlglot import expressions as exp
7
+ from sqlglot.helper import object_to_dict
8
+
9
+ if t.TYPE_CHECKING:
10
+ from sqlglot.dataframe.sql.dataframe import DataFrame
11
+ from sqlglot.dataframe.sql.session import SparkSession
12
+
13
+
14
class DataFrameReader:
    """Reads tables registered in the global sqlglot schema into DataFrames."""

    def __init__(self, spark: SparkSession):
        self.spark = spark

    def table(self, tableName: str) -> DataFrame:
        """Return a DataFrame selecting every known column of *tableName*."""
        from sqlglot.dataframe.sql.dataframe import DataFrame
        from sqlglot.dataframe.sql.session import SparkSession

        # SparkSession is a singleton, so one construction gives the same
        # instance (and dialect) the original code fetched repeatedly.
        session = SparkSession()
        sqlglot.schema.add_table(tableName, dialect=session.dialect)

        table_expression = exp.to_table(tableName, dialect=session.dialect).transform(
            session.dialect.normalize_identifier
        )
        column_names = sqlglot.schema.column_names(tableName, dialect=session.dialect)
        select_expression = exp.Select().from_(table_expression).select(*column_names)
        return DataFrame(self.spark, select_expression)
41
+
42
+
43
class DataFrameWriter:
    """Builds INSERT/CREATE output expressions for a DataFrame (PySpark-like API)."""

    def __init__(
        self,
        df: DataFrame,
        spark: t.Optional[SparkSession] = None,
        mode: t.Optional[str] = None,
        by_name: bool = False,
    ):
        self._df = df
        self._spark = spark or df.spark
        self._mode = mode
        self._by_name = by_name

    def copy(self, **kwargs) -> DataFrameWriter:
        """Return a new writer with *kwargs* overriding the current state."""
        state = object_to_dict(self, **kwargs)
        # Strip the leading underscore so private attrs line up with the
        # constructor's parameter names.
        return DataFrameWriter(
            **{(key[1:] if key.startswith("_") else key): value for key, value in state.items()}
        )

    def sql(self, **kwargs) -> t.List[str]:
        return self._df.sql(**kwargs)

    def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter:
        return self.copy(_mode=saveMode)

    @property
    def byName(self):
        """Writer variant that matches columns by name on insertInto."""
        return self.copy(by_name=True)

    def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
        from sqlglot.dataframe.sql.session import SparkSession

        output_expression_container = exp.Insert(
            this=exp.to_table(tableName),
            overwrite=overwrite,
        )
        df = self._df.copy(output_expression_container=output_expression_container)
        if self._by_name:
            # Match the target table's visible columns by name rather than position.
            columns = sqlglot.schema.column_names(
                tableName, only_visible=True, dialect=SparkSession().dialect
            )
            df = df._convert_leaf_to_cte().select(*columns)

        return self.copy(_df=df)

    def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None):
        """Build a CREATE TABLE (or INSERT for mode="append") expression."""
        if format is not None:
            raise NotImplementedError("Providing Format in the save as table is not supported")
        # NOTE: str(None) yields "None", which matches no mode below — this
        # preserves the original's behavior for an unset mode.
        mode = mode or str(self._mode)
        if mode == "append":
            return self.insertInto(name)
        exists = True if mode == "ignore" else None
        replace = True if mode == "overwrite" else None
        output_expression_container = exp.Create(
            this=exp.to_table(name),
            kind="TABLE",
            exists=exists,
            replace=replace,
        )
        return self.copy(_df=self._df.copy(output_expression_container=output_expression_container))
@@ -0,0 +1,190 @@
1
+ from __future__ import annotations
2
+
3
+ import typing as t
4
+ import uuid
5
+ from collections import defaultdict
6
+
7
+ import sqlglot as sqlglot
8
+ from sqlglot import Dialect, expressions as exp
9
+ from sqlglot.dataframe.sql import functions as F
10
+ from sqlglot.dataframe.sql.dataframe import DataFrame
11
+ from sqlglot.dataframe.sql.readwriter import DataFrameReader
12
+ from sqlglot.dataframe.sql.types import StructType
13
+ from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
14
+ from sqlglot.helper import classproperty
15
+
16
+ if t.TYPE_CHECKING:
17
+ from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput
18
+
19
+
20
class SparkSession:
    """Singleton stand-in for PySpark's SparkSession that builds sqlglot
    expressions instead of executing anything against a cluster."""

    # Dialect name used when the Builder does not override it.
    DEFAULT_DIALECT = "spark"
    # Shared singleton instance returned by every SparkSession() call.
    _instance = None

    def __init__(self):
        # __init__ runs on *every* SparkSession() call even though __new__
        # returns the singleton, so guard against wiping existing state.
        if not hasattr(self, "known_ids"):
            self.known_ids = set()
            self.known_branch_ids = set()
            self.known_sequence_ids = set()
            self.name_to_sequence_id_mapping = defaultdict(list)
            self.incrementing_id = 1
            # NOTE: get_or_raise returns the dialect class here; the trailing
            # () instantiates it.
            self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)()

    def __new__(cls, *args, **kwargs) -> SparkSession:
        # Classic singleton: create once, then always hand back the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @property
    def read(self) -> DataFrameReader:
        """Reader bound to this session (mirrors spark.read)."""
        return DataFrameReader(self)

    def table(self, tableName: str) -> DataFrame:
        """Shortcut for ``spark.read.table(tableName)``."""
        return self.read.table(tableName)

    def createDataFrame(
        self,
        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DataFrame:
        """Build a DataFrame from literal rows as a SELECT ... FROM VALUES.

        Raises NotImplementedError for unsupported PySpark options and
        ValueError for empty *data*.
        """
        from sqlglot.dataframe.sql.dataframe import DataFrame

        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
        if schema is not None and (
            not isinstance(schema, (StructType, str, list))
            or (isinstance(schema, list) and not isinstance(schema[0], str))
        ):
            raise NotImplementedError("Only schema of either list or string of list supported")
        if not data:
            raise ValueError("Must provide data to create into a DataFrame")

        # Maps column name -> optional cast type (None means "no cast").
        column_mapping: t.Dict[str, t.Optional[str]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(schema)
        elif isinstance(data[0], dict):
            # Dict rows: column names come from the first row's keys.
            column_mapping = {col_name.strip(): None for col_name in data[0]}
        else:
            # Positional rows: synthesize _1, _2, ... like PySpark does.
            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

        # One VALUES tuple per input row, each cell wrapped as a literal.
        data_expressions = [
            exp.Tuple(
                expressions=list(
                    map(
                        lambda x: F.lit(x).expression,
                        row if not isinstance(row, dict) else row.values(),
                    )
                )
            )
            for row in data
        ]

        # Select list: cast-and-alias when a type was supplied, bare column otherwise.
        sel_columns = [
            F.col(name).cast(data_type).alias(name).expression
            if data_type is not None
            else F.col(name).expression
            for name, data_type in column_mapping.items()
        ]

        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(
                this=exp.Values(
                    expressions=data_expressions,
                    alias=exp.TableAlias(
                        this=exp.to_identifier(self._auto_incrementing_name),
                        columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                    ),
                ),
            ),
        }

        sel_expression = exp.Select(**select_kwargs)
        return DataFrame(self, sel_expression)

    def sql(self, sqlQuery: str) -> DataFrame:
        """Parse a SQL string into a DataFrame.

        Supports SELECT directly, and CREATE/INSERT by splitting off the inner
        SELECT as the DataFrame body with the outer statement kept as the
        output container. Anything else raises ValueError.
        """
        expression = sqlglot.parse_one(sqlQuery, read=self.dialect)
        if isinstance(expression, exp.Select):
            df = DataFrame(self, expression)
            df = df._convert_leaf_to_cte()
        elif isinstance(expression, (exp.Create, exp.Insert)):
            select_expression = expression.expression.copy()
            if isinstance(expression, exp.Insert):
                # Move the WITH clause from the INSERT onto the inner SELECT.
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            # Detach the inner SELECT; the container keeps only the DDL/DML shell.
            del expression.args["expression"]
            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
            df = df._convert_leaf_to_cte()
        else:
            raise ValueError(
                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
            )
        return df

    @property
    def _auto_incrementing_name(self) -> str:
        # Generates a1, a2, ... for VALUES table aliases; mutates counter state.
        name = f"a{self.incrementing_id}"
        self.incrementing_id += 1
        return name

    @property
    def _random_branch_id(self) -> str:
        # Fresh random id, also recorded as a branch id.
        id = self._random_id
        self.known_branch_ids.add(id)
        return id

    @property
    def _random_sequence_id(self):
        # Fresh random id, also recorded as a sequence id.
        id = self._random_id
        self.known_sequence_ids.add(id)
        return id

    @property
    def _random_id(self) -> str:
        # "r" prefix keeps the hex id a valid identifier.
        id = "r" + uuid.uuid4().hex
        self.known_ids.add(id)
        return id

    @property
    def _join_hint_names(self) -> t.Set[str]:
        # Spark join-hint names recognized elsewhere in the dataframe layer.
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _add_alias_to_mapping(self, name: str, sequence_id: str):
        # Track every sequence id registered under a user-visible alias.
        self.name_to_sequence_id_mapping[name].append(sequence_id)

    class Builder:
        """Fluent builder that only honors the sqlframe dialect config key;
        every other attribute/call chains back to the builder itself."""

        SQLFRAME_DIALECT_KEY = "sqlframe.dialect"

        def __init__(self):
            self.dialect = "spark"

        def __getattr__(self, item) -> SparkSession.Builder:
            # Swallow arbitrary builder methods (appName, master, ...) so
            # PySpark-style chains keep working.
            return self

        def __call__(self, *args, **kwargs):
            return self

        def config(
            self,
            key: t.Optional[str] = None,
            value: t.Optional[t.Any] = None,
            *,
            map: t.Optional[t.Dict[str, t.Any]] = None,
            **kwargs: t.Any,
        ) -> SparkSession.Builder:
            # Only the sqlframe dialect key has any effect; other configs are ignored.
            if key == self.SQLFRAME_DIALECT_KEY:
                self.dialect = value
            elif map and self.SQLFRAME_DIALECT_KEY in map:
                self.dialect = map[self.SQLFRAME_DIALECT_KEY]
            return self

        def getOrCreate(self) -> SparkSession:
            # Returns the singleton session, re-pointing its dialect at the
            # builder's configured one.
            spark = SparkSession()
            spark.dialect = Dialect.get_or_raise(self.dialect)()
            return spark

    @classproperty
    def builder(cls) -> Builder:
        """A *new* Builder per access (mirrors pyspark's SparkSession.builder)."""
        return cls.Builder()
@@ -0,0 +1,9 @@
1
+ import typing as t
2
+
3
+ from sqlglot import expressions as exp
4
+
5
+
6
def replace_id_value(node, replacement_mapping: t.Dict[exp.Identifier, exp.Identifier]):
    """Transform callback: swap an identifier for a copy of its mapped
    replacement; any node without a mapping entry passes through unchanged."""
    if isinstance(node, exp.Identifier) and node in replacement_mapping:
        return node.replace(replacement_mapping[node].copy())
    return node