apache-airflow-providers-openlineage 1.3.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of apache-airflow-providers-openlineage might be problematic.

@@ -0,0 +1,60 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ from __future__ import annotations
+
+ import os
+ import typing
+
+ from airflow.configuration import conf
+ from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter
+
+ if typing.TYPE_CHECKING:
+     from airflow.models import TaskInstance
+
+ _JOB_NAMESPACE = conf.get("openlineage", "namespace", fallback=os.getenv("OPENLINEAGE_NAMESPACE", "default"))
+
+
+ def lineage_run_id(task_instance: TaskInstance):
+     """
+     Macro function which returns the generated run id for a given task.
+
+     This can be used to forward the run id from a task to a child run so the job hierarchy is preserved.
+
+     .. seealso::
+         For more information on how to use this macro, take a look at the guide:
+         :ref:`howto/macros:openlineage`
+     """
+     return OpenLineageAdapter.build_task_instance_run_id(
+         task_instance.task.task_id, task_instance.execution_date, task_instance.try_number
+     )
+
+
+ def lineage_parent_id(run_id: str, task_instance: TaskInstance):
+     """
+     Macro function which returns the generated job and run id for a given task.
+
+     This can be used to forward the ids from a task to a child run so the job
+     hierarchy is preserved. A child run can create a ParentRunFacet from these ids.
+
+     .. seealso::
+         For more information on how to use this macro, take a look at the guide:
+         :ref:`howto/macros:openlineage`
+     """
+     job_name = OpenLineageAdapter.build_task_instance_run_id(
+         task_instance.task.task_id, task_instance.execution_date, task_instance.try_number
+     )
+     return f"{_JOB_NAMESPACE}/{job_name}/{run_id}"
@@ -0,0 +1,51 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ from __future__ import annotations
+
+ import os
+
+ from airflow.configuration import conf
+ from airflow.plugins_manager import AirflowPlugin
+ from airflow.providers.openlineage.plugins.listener import get_openlineage_listener
+ from airflow.providers.openlineage.plugins.macros import lineage_parent_id, lineage_run_id
+
+
+ def _is_disabled() -> bool:
+     return (
+         conf.getboolean("openlineage", "disabled", fallback=False)
+         or os.getenv("OPENLINEAGE_DISABLED", "false").lower() == "true"
+         or (
+             conf.get("openlineage", "transport", fallback="") == ""
+             and conf.get("openlineage", "config_path", fallback="") == ""
+             and os.getenv("OPENLINEAGE_URL", "") == ""
+             and os.getenv("OPENLINEAGE_CONFIG", "") == ""
+         )
+     )
+
+
+ class OpenLineageProviderPlugin(AirflowPlugin):
+     """
+     OpenLineage plugin that provides a listener and template macros.
+
+     The listener emits OpenLineage (OL) events on DAG start, completion, and
+     failure, and on TaskInstance start, completion, and failure.
+     """
+
+     name = "OpenLineageProviderPlugin"
+     if not _is_disabled():
+         macros = [lineage_run_id, lineage_parent_id]
+         listeners = [get_openlineage_listener()]
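Per `_is_disabled()`, the plugin registers its macros and listener only when it is not explicitly disabled and at least one transport-related setting is present. A hedged sketch of the switches it inspects; the endpoint URL and transport JSON shape are illustrative:

import os

# Any one of these enables emission; if all four are absent, the plugin stays inert.
os.environ["OPENLINEAGE_URL"] = "http://localhost:5000"  # illustrative OpenLineage endpoint
# Equivalently, [openlineage] transport or config_path in airflow.cfg, e.g. via Airflow's env mapping:
# os.environ["AIRFLOW__OPENLINEAGE__TRANSPORT"] = '{"type": "http", "url": "http://localhost:5000"}'

# Explicit kill switch, checked before any transport settings:
os.environ["OPENLINEAGE_DISABLED"] = "false"  # "true" disables regardless of transport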
@@ -0,0 +1,347 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Callable
+
+ import sqlparse
+ from attrs import define
+ from openlineage.client.facet import (
+     BaseFacet,
+     ColumnLineageDatasetFacet,
+     ColumnLineageDatasetFacetFieldsAdditional,
+     ColumnLineageDatasetFacetFieldsAdditionalInputFields,
+     ExtractionError,
+     ExtractionErrorRunFacet,
+     SqlJobFacet,
+ )
+ from openlineage.common.sql import DbTableMeta, SqlMeta, parse
+
+ from airflow.providers.openlineage.extractors.base import OperatorLineage
+ from airflow.providers.openlineage.utils.sql import (
+     TablesHierarchy,
+     create_information_schema_query,
+     get_table_schemas,
+ )
+ from airflow.typing_compat import TypedDict
+
+ if TYPE_CHECKING:
+     from openlineage.client.run import Dataset
+     from sqlalchemy.engine import Engine
+
+     from airflow.hooks.base import BaseHook
+
+ DEFAULT_NAMESPACE = "default"
+ DEFAULT_INFORMATION_SCHEMA_COLUMNS = [
+     "table_schema",
+     "table_name",
+     "column_name",
+     "ordinal_position",
+     "udt_name",
+ ]
+ DEFAULT_INFORMATION_SCHEMA_TABLE_NAME = "information_schema.columns"
+
+
+ def default_normalize_name_method(name: str) -> str:
+     return name.lower()
+
+
+ class GetTableSchemasParams(TypedDict):
+     """get_table_schemas params."""
+
+     normalize_name: Callable[[str], str]
+     is_cross_db: bool
+     information_schema_columns: list[str]
+     information_schema_table: str
+     is_uppercase_names: bool
+     database: str | None
+
+
+ @define
+ class DatabaseInfo:
+     """
+     Contains database specific information needed to process a SQL statement parse result.
+
+     :param scheme: Scheme part of URI in OpenLineage namespace.
+     :param authority: Authority part of URI in OpenLineage namespace.
+         In most cases it should be the `{host}:{port}` part of the Airflow connection.
+         See: https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md
+     :param database: Takes precedence over the parsed database name.
+     :param information_schema_columns: List of column names from the information schema table.
+     :param information_schema_table_name: Information schema table name.
+     :param is_information_schema_cross_db: Specifies if the information schema contains
+         cross-database data.
+     :param is_uppercase_names: Specifies if the database accepts only uppercase names (e.g. Snowflake).
+     :param normalize_name_method: Method to normalize database, schema and table names.
+         Defaults to `name.lower()`.
+     """
+
+     scheme: str
+     authority: str | None = None
+     database: str | None = None
+     information_schema_columns: list[str] = DEFAULT_INFORMATION_SCHEMA_COLUMNS
+     information_schema_table_name: str = DEFAULT_INFORMATION_SCHEMA_TABLE_NAME
+     is_information_schema_cross_db: bool = False
+     is_uppercase_names: bool = False
+     normalize_name_method: Callable[[str], str] = default_normalize_name_method
+
+
+ class SQLParser:
+     """Interface for openlineage-sql.
+
+     :param dialect: dialect specific to the database
+     :param default_schema: schema applied to each table with no schema parsed
+     """
+
+     def __init__(self, dialect: str | None = None, default_schema: str | None = None) -> None:
+         self.dialect = dialect
+         self.default_schema = default_schema
+
+     def parse(self, sql: list[str] | str) -> SqlMeta | None:
+         """Parse a single SQL statement or a list of SQL statements."""
+         return parse(sql=sql, dialect=self.dialect)
+
+     def parse_table_schemas(
+         self,
+         hook: BaseHook,
+         inputs: list[DbTableMeta],
+         outputs: list[DbTableMeta],
+         database_info: DatabaseInfo,
+         namespace: str = DEFAULT_NAMESPACE,
+         database: str | None = None,
+         sqlalchemy_engine: Engine | None = None,
+     ) -> tuple[list[Dataset], ...]:
+         """Parse schemas for input and output tables."""
+         database_kwargs: GetTableSchemasParams = {
+             "normalize_name": database_info.normalize_name_method,
+             "is_cross_db": database_info.is_information_schema_cross_db,
+             "information_schema_columns": database_info.information_schema_columns,
+             "information_schema_table": database_info.information_schema_table_name,
+             "is_uppercase_names": database_info.is_uppercase_names,
+             "database": database or database_info.database,
+         }
+         return get_table_schemas(
+             hook,
+             namespace,
+             self.default_schema,
+             database or database_info.database,
+             self.create_information_schema_query(
+                 tables=inputs, sqlalchemy_engine=sqlalchemy_engine, **database_kwargs
+             )
+             if inputs
+             else None,
+             self.create_information_schema_query(
+                 tables=outputs, sqlalchemy_engine=sqlalchemy_engine, **database_kwargs
+             )
+             if outputs
+             else None,
+         )
+
+     def attach_column_lineage(
+         self, datasets: list[Dataset], database: str | None, parse_result: SqlMeta
+     ) -> None:
+         """
+         Attach a column lineage facet to each dataset in the list.
+
+         Note that currently each dataset gets the same column lineage information.
+         This is expected to change after OpenLineage SQL parser improvements.
+         """
+         if not parse_result.column_lineage:
+             return
+         for dataset in datasets:
+             dataset.facets["columnLineage"] = ColumnLineageDatasetFacet(
+                 fields={
+                     column_lineage.descendant.name: ColumnLineageDatasetFacetFieldsAdditional(
+                         inputFields=[
+                             ColumnLineageDatasetFacetFieldsAdditionalInputFields(
+                                 namespace=dataset.namespace,
+                                 name=".".join(
+                                     filter(
+                                         None,
+                                         (
+                                             column_meta.origin.database or database,
+                                             column_meta.origin.schema or self.default_schema,
+                                             column_meta.origin.name,
+                                         ),
+                                     )
+                                 )
+                                 if column_meta.origin
+                                 else "",
+                                 field=column_meta.name,
+                             )
+                             for column_meta in column_lineage.lineage
+                         ],
+                         transformationType="",
+                         transformationDescription="",
+                     )
+                     for column_lineage in parse_result.column_lineage
+                 }
+             )
+
+     def generate_openlineage_metadata_from_sql(
+         self,
+         sql: list[str] | str,
+         hook: BaseHook,
+         database_info: DatabaseInfo,
+         database: str | None = None,
+         sqlalchemy_engine: Engine | None = None,
+     ) -> OperatorLineage:
+         """Parse SQL statement(s) and generate OpenLineage metadata.
+
+         Generated OpenLineage metadata contains:
+
+         * input tables with schemas parsed
+         * output tables with schemas parsed
+         * run facets
+         * job facets.
+
+         :param sql: a SQL statement or list of SQL statements to be parsed
+         :param hook: Airflow Hook used to connect to the database
+         :param database_info: database specific information
+         :param database: when passed it takes precedence over the parsed database name
+         :param sqlalchemy_engine: when passed, the engine's dialect is used to compile SQL queries
+         """
+         job_facets: dict[str, BaseFacet] = {"sql": SqlJobFacet(query=self.normalize_sql(sql))}
+         parse_result = self.parse(self.split_sql_string(sql))
+         if not parse_result:
+             return OperatorLineage(job_facets=job_facets)
+
+         run_facets: dict[str, BaseFacet] = {}
+         if parse_result.errors:
+             run_facets["extractionError"] = ExtractionErrorRunFacet(
+                 totalTasks=len(sql) if isinstance(sql, list) else 1,
+                 failedTasks=len(parse_result.errors),
+                 errors=[
+                     ExtractionError(
+                         errorMessage=error.message,
+                         stackTrace=None,
+                         task=error.origin_statement,
+                         taskNumber=error.index,
+                     )
+                     for error in parse_result.errors
+                 ],
+             )
+
+         namespace = self.create_namespace(database_info=database_info)
+         inputs, outputs = self.parse_table_schemas(
+             hook=hook,
+             inputs=parse_result.in_tables,
+             outputs=parse_result.out_tables,
+             namespace=namespace,
+             database=database,
+             database_info=database_info,
+             sqlalchemy_engine=sqlalchemy_engine,
+         )
+
+         self.attach_column_lineage(outputs, database or database_info.database, parse_result)
+
+         return OperatorLineage(
+             inputs=inputs,
+             outputs=outputs,
+             run_facets=run_facets,
+             job_facets=job_facets,
+         )
+
+     @staticmethod
+     def create_namespace(database_info: DatabaseInfo) -> str:
+         return (
+             f"{database_info.scheme}://{database_info.authority}"
+             if database_info.authority
+             else database_info.scheme
+         )
+
+     @classmethod
+     def normalize_sql(cls, sql: list[str] | str) -> str:
+         """Return the SQL statements as a single semicolon-separated string."""
+         return ";\n".join(stmt.rstrip(" ;\r\n") for stmt in cls.split_sql_string(sql))
+
+     @classmethod
+     def split_sql_string(cls, sql: list[str] | str) -> list[str]:
+         """
+         Split a SQL string into a list of statements.
+
+         Tries to use `DbApiHook.split_sql_string` if available.
+         Otherwise, uses the same logic.
+         """
+         try:
+             from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+             split_statement = DbApiHook.split_sql_string
+         except (ImportError, AttributeError):
+             # No common.sql Airflow provider available or version is too old.
+             def split_statement(sql: str) -> list[str]:
+                 splits = sqlparse.split(sqlparse.format(sql, strip_comments=True))
+                 return [s for s in splits if s]
+
+         if isinstance(sql, str):
+             return split_statement(sql)
+         return [obj for stmt in sql for obj in cls.split_sql_string(stmt) if obj != ""]
+
+     @classmethod
+     def create_information_schema_query(
+         cls,
+         tables: list[DbTableMeta],
+         normalize_name: Callable[[str], str],
+         is_cross_db: bool,
+         information_schema_columns: list[str],
+         information_schema_table: str,
+         is_uppercase_names: bool,
+         database: str | None = None,
+         sqlalchemy_engine: Engine | None = None,
+     ) -> str:
+         """Create a SELECT statement to query the information schema table."""
+         tables_hierarchy = cls._get_tables_hierarchy(
+             tables,
+             normalize_name=normalize_name,
+             database=database,
+             is_cross_db=is_cross_db,
+         )
+         return create_information_schema_query(
+             columns=information_schema_columns,
+             information_schema_table_name=information_schema_table,
+             tables_hierarchy=tables_hierarchy,
+             uppercase_names=is_uppercase_names,
+             sqlalchemy_engine=sqlalchemy_engine,
+         )
+
+     @staticmethod
+     def _get_tables_hierarchy(
+         tables: list[DbTableMeta],
+         normalize_name: Callable[[str], str],
+         database: str | None = None,
+         is_cross_db: bool = False,
+     ) -> TablesHierarchy:
+         """
+         Create a hierarchy of database -> schema -> table name.
+
+         This helps to create a simpler information schema query grouped by database and schema.
+
+         :param tables: List of tables.
+         :param normalize_name: A method to normalize all names.
+         :param is_cross_db: If False, set the top (database) level to None
+             when creating the hierarchy.
+         """
+         hierarchy: TablesHierarchy = {}
+         for table in tables:
+             if is_cross_db:
+                 db = table.database or database
+             else:
+                 db = None
+             schemas = hierarchy.setdefault(normalize_name(db) if db else db, {})
+             table_names = schemas.setdefault(normalize_name(table.schema) if table.schema else None, [])
+             table_names.append(table.name)
+         return hierarchy
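Taken together, an operator-level extractor can drive this class end to end. A minimal sketch, assuming a Postgres connection id of `my_postgres`; the connection id, SQL, and authority are illustrative, not part of this package:

from airflow.providers.postgres.hooks.postgres import PostgresHook

parser = SQLParser(dialect="postgres", default_schema="public")
database_info = DatabaseInfo(scheme="postgres", authority="localhost:5432", database="mydb")

lineage = parser.generate_openlineage_metadata_from_sql(
    sql="INSERT INTO orders_summary SELECT order_id, total FROM orders;",
    hook=PostgresHook(postgres_conn_id="my_postgres"),  # assumed connection id
    database_info=database_info,
)
# lineage.inputs / lineage.outputs hold OpenLineage Datasets with schema facets
# (and, for outputs, column lineage facets); lineage.job_facets["sql"] carries
# the normalized, semicolon-joined query produced by normalize_sql().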
@@ -0,0 +1,16 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
@@ -0,0 +1,199 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ from __future__ import annotations
+
+ from collections import defaultdict
+ from contextlib import closing
+ from enum import IntEnum
+ from typing import TYPE_CHECKING, Dict, List, Optional
+
+ from attrs import define
+ from openlineage.client.facet import SchemaDatasetFacet, SchemaField
+ from openlineage.client.run import Dataset
+ from sqlalchemy import Column, MetaData, Table, and_, union_all
+
+ if TYPE_CHECKING:
+     from sqlalchemy.engine import Engine
+     from sqlalchemy.sql import ClauseElement
+
+     from airflow.hooks.base import BaseHook
+
+
+ class ColumnIndex(IntEnum):
+     """Enumerates the indices of columns in the information schema view."""
+
+     SCHEMA = 0
+     TABLE_NAME = 1
+     COLUMN_NAME = 2
+     ORDINAL_POSITION = 3
+     # Use 'udt_name', which is the underlying type of the column
+     UDT_NAME = 4
+     # Database is optional, as the 5th column
+     DATABASE = 5
+
+
+ TablesHierarchy = Dict[Optional[str], Dict[Optional[str], List[str]]]
+
+
+ @define
+ class TableSchema:
+     """Temporary object used to construct an OpenLineage Dataset."""
+
+     table: str
+     schema: str | None
+     database: str | None
+     fields: list[SchemaField]
+
+     def to_dataset(self, namespace: str, database: str | None = None, schema: str | None = None) -> Dataset:
+         # Prefix the table name with database and schema name using
+         # the format: {database_name}.{table_schema}.{table_name}.
+         name = ".".join(
+             part
+             for part in [self.database or database, self.schema or schema, self.table]
+             if part is not None
+         )
+         return Dataset(
+             namespace=namespace,
+             name=name,
+             facets={"schema": SchemaDatasetFacet(fields=self.fields)} if self.fields else {},
+         )
+
+
+ def get_table_schemas(
+     hook: BaseHook,
+     namespace: str,
+     schema: str | None,
+     database: str | None,
+     in_query: str | None,
+     out_query: str | None,
+ ) -> tuple[list[Dataset], list[Dataset]]:
+     """Query the database for table schemas.
+
+     Uses the provided hook. Responsibility to provide queries for this function lies with particular extractors.
+     If a query for an input or output table is not provided, that query is skipped.
+     """
+     # Do not query if we got neither query.
+     if not in_query and not out_query:
+         return [], []
+
+     with closing(hook.get_conn()) as conn, closing(conn.cursor()) as cursor:
+         if in_query:
+             cursor.execute(in_query)
+             in_datasets = [x.to_dataset(namespace, database, schema) for x in parse_query_result(cursor)]
+         else:
+             in_datasets = []
+         if out_query:
+             cursor.execute(out_query)
+             out_datasets = [x.to_dataset(namespace, database, schema) for x in parse_query_result(cursor)]
+         else:
+             out_datasets = []
+         return in_datasets, out_datasets
+
+
+ def parse_query_result(cursor) -> list[TableSchema]:
+     """Fetch results from a DB-API 2.0 cursor and create a list of table schemas.
+
+     For each row it creates a :class:`TableSchema`.
+     """
+     schemas: dict = {}
+     columns: dict = defaultdict(list)
+     for row in cursor.fetchall():
+         table_schema_name: str = row[ColumnIndex.SCHEMA]
+         table_name: str = row[ColumnIndex.TABLE_NAME]
+         table_column: SchemaField = SchemaField(
+             name=row[ColumnIndex.COLUMN_NAME],
+             type=row[ColumnIndex.UDT_NAME],
+             description=None,
+         )
+         ordinal_position = row[ColumnIndex.ORDINAL_POSITION]
+         try:
+             table_database = row[ColumnIndex.DATABASE]
+         except IndexError:
+             table_database = None
+
+         # Build the fully qualified table key.
+         table_key = ".".join(filter(None, [table_database, table_schema_name, table_name]))
+
+         schemas[table_key] = TableSchema(
+             table=table_name, schema=table_schema_name, database=table_database, fields=[]
+         )
+         columns[table_key].append((ordinal_position, table_column))
+
+     for schema in schemas.values():
+         table_key = ".".join(filter(None, [schema.database, schema.schema, schema.table]))
+         schema.fields = [x for _, x in sorted(columns[table_key])]
+
+     return list(schemas.values())
+
+
+ def create_information_schema_query(
+     columns: list[str],
+     information_schema_table_name: str,
+     tables_hierarchy: TablesHierarchy,
+     uppercase_names: bool = False,
+     sqlalchemy_engine: Engine | None = None,
+ ) -> str:
+     """Create a query for getting table schemas from the information schema."""
+     metadata = MetaData(sqlalchemy_engine)
+     select_statements = []
+     for db, schema_mapping in tables_hierarchy.items():
+         # The information schema table name is expected to be "<information_schema schema>.<view/table name>",
+         # usually "information_schema.columns". To build the correct table identifier for various tables,
+         # we need to pass the first part of the dot-separated identifier as the `schema` argument to `sqlalchemy.Table`.
+         if db:
+             # Use the database as the first part of the table identifier.
+             schema = db
+             table_name = information_schema_table_name
+         else:
+             # When no database is passed, use the schema as the first part of the table identifier.
+             schema, table_name = information_schema_table_name.split(".")
+         information_schema_table = Table(
+             table_name,
+             metadata,
+             *[Column(column) for column in columns],
+             schema=schema,
+             quote=False,
+         )
+         filter_clauses = create_filter_clauses(schema_mapping, information_schema_table, uppercase_names)
+         select_statements.append(information_schema_table.select().filter(*filter_clauses))
+     return str(
+         union_all(*select_statements).compile(sqlalchemy_engine, compile_kwargs={"literal_binds": True})
+     )
+
+
+ def create_filter_clauses(
+     schema_mapping: dict, information_schema_table: Table, uppercase_names: bool = False
+ ) -> list[ClauseElement]:
+     """
+     Create filter clauses for all tables in one database.
+
+     :param schema_mapping: a dictionary mapping schema names to the list of tables in each schema
+     :param information_schema_table: `sqlalchemy.Table` instance used to construct the clauses.
+         For most SQL databases it contains `table_name` and `table_schema` columns,
+         so the table is expected to have them defined.
+     :param uppercase_names: if True, use uppercase schema and table names
+     """
+     filter_clauses = []
+     for schema, tables in schema_mapping.items():
+         filter_clause = information_schema_table.c.table_name.in_(
+             name.upper() if uppercase_names else name for name in tables
+         )
+         if schema:
+             schema = schema.upper() if uppercase_names else schema
+             filter_clause = and_(information_schema_table.c.table_schema == schema, filter_clause)
+         filter_clauses.append(filter_clause)
+     return filter_clauses
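To make the query construction concrete, a small sketch of feeding a `TablesHierarchy` through `create_information_schema_query`; the schema and table names are illustrative:

# database -> schema -> tables; a None database key means the tables
# are not qualified at that level.
hierarchy: TablesHierarchy = {None: {"public": ["orders", "customers"]}}

query = create_information_schema_query(
    columns=["table_schema", "table_name", "column_name", "ordinal_position", "udt_name"],
    information_schema_table_name="information_schema.columns",
    tables_hierarchy=hierarchy,
)
# Compiled with SQLAlchemy's default dialect (no engine passed), this yields roughly:
#   SELECT table_schema, table_name, column_name, ordinal_position, udt_name
#   FROM information_schema.columns
#   WHERE table_schema = 'public' AND table_name IN ('orders', 'customers')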