acryl-datahub 1.0.0.1rc2__py3-none-any.whl → 1.0.0.1rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of acryl-datahub has been flagged as potentially problematic.
- {acryl_datahub-1.0.0.1rc2.dist-info → acryl_datahub-1.0.0.1rc4.dist-info}/METADATA +2569 -2569
- {acryl_datahub-1.0.0.1rc2.dist-info → acryl_datahub-1.0.0.1rc4.dist-info}/RECORD +37 -35
- datahub/_version.py +1 -1
- datahub/emitter/rest_emitter.py +2 -2
- datahub/ingestion/graph/client.py +6 -11
- datahub/ingestion/graph/filters.py +22 -2
- datahub/ingestion/source/common/subtypes.py +1 -1
- datahub/ingestion/source/gc/soft_deleted_entity_cleanup.py +101 -104
- datahub/ingestion/source/ge_data_profiler.py +11 -1
- datahub/ingestion/source/mlflow.py +19 -1
- datahub/ingestion/source/redshift/lineage_v2.py +7 -0
- datahub/ingestion/source/redshift/query.py +1 -1
- datahub/ingestion/source/snowflake/constants.py +1 -0
- datahub/ingestion/source/snowflake/snowflake_config.py +14 -1
- datahub/ingestion/source/snowflake/snowflake_query.py +17 -0
- datahub/ingestion/source/snowflake/snowflake_report.py +3 -0
- datahub/ingestion/source/snowflake/snowflake_schema.py +29 -0
- datahub/ingestion/source/snowflake/snowflake_schema_gen.py +112 -42
- datahub/ingestion/source/snowflake/snowflake_utils.py +25 -1
- datahub/ingestion/source/sql/mssql/job_models.py +15 -1
- datahub/ingestion/source/sql/mssql/source.py +8 -4
- datahub/ingestion/source/sql/stored_procedures/__init__.py +0 -0
- datahub/ingestion/source/sql/stored_procedures/base.py +242 -0
- datahub/ingestion/source/sql/{mssql/stored_procedure_lineage.py → stored_procedures/lineage.py} +1 -29
- datahub/ingestion/source/superset.py +153 -13
- datahub/ingestion/source/vertexai/vertexai.py +1 -1
- datahub/metadata/schema.avsc +2 -0
- datahub/metadata/schemas/Deprecation.avsc +2 -0
- datahub/metadata/schemas/MetadataChangeEvent.avsc +2 -0
- datahub/sdk/__init__.py +1 -0
- datahub/sdk/main_client.py +2 -1
- datahub/sdk/search_filters.py +18 -23
- datahub/sql_parsing/split_statements.py +17 -3
- {acryl_datahub-1.0.0.1rc2.dist-info → acryl_datahub-1.0.0.1rc4.dist-info}/WHEEL +0 -0
- {acryl_datahub-1.0.0.1rc2.dist-info → acryl_datahub-1.0.0.1rc4.dist-info}/entry_points.txt +0 -0
- {acryl_datahub-1.0.0.1rc2.dist-info → acryl_datahub-1.0.0.1rc4.dist-info}/licenses/LICENSE +0 -0
- {acryl_datahub-1.0.0.1rc2.dist-info → acryl_datahub-1.0.0.1rc4.dist-info}/top_level.txt +0 -0
datahub/ingestion/source/superset.py
CHANGED

@@ -3,7 +3,7 @@ import logging
 from dataclasses import dataclass, field
 from datetime import datetime
 from functools import lru_cache
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

 import dateutil.parser as dp
 import requests
@@ -11,6 +11,7 @@ from pydantic import BaseModel
 from pydantic.class_validators import root_validator, validator
 from pydantic.fields import Field

+import datahub.emitter.mce_builder as builder
 from datahub.configuration.common import AllowDenyPattern
 from datahub.configuration.source_common import (
     EnvConfigMixin,
@@ -26,6 +27,7 @@ from datahub.emitter.mce_builder import (
     make_schema_field_urn,
     make_user_urn,
 )
+from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.emitter.mcp_builder import add_domain_to_entity_wu
 from datahub.ingestion.api.common import PipelineContext
 from datahub.ingestion.api.decorators import (
@@ -50,6 +52,8 @@ from datahub.ingestion.source.state.stateful_ingestion_base import (
 )
 from datahub.metadata.com.linkedin.pegasus2avro.common import (
     ChangeAuditStamps,
+    InputField,
+    InputFields,
     Status,
     TimeStamp,
 )
@@ -60,11 +64,16 @@ from datahub.metadata.com.linkedin.pegasus2avro.metadata.snapshot import (
 )
 from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent
 from datahub.metadata.com.linkedin.pegasus2avro.schema import (
+    BooleanTypeClass,
+    DateTypeClass,
     MySqlDDL,
     NullType,
+    NullTypeClass,
+    NumberTypeClass,
     SchemaField,
     SchemaFieldDataType,
     SchemaMetadata,
+    StringTypeClass,
 )
 from datahub.metadata.schema_classes import (
     AuditStampClass,
@@ -113,9 +122,17 @@ chart_type_from_viz_type = {
     "box_plot": ChartTypeClass.BAR,
 }

-
 platform_without_databases = ["druid"]

+FIELD_TYPE_MAPPING = {
+    "INT": NumberTypeClass,
+    "STRING": StringTypeClass,
+    "FLOAT": NumberTypeClass,
+    "DATETIME": DateTypeClass,
+    "BOOLEAN": BooleanTypeClass,
+    "SQL": StringTypeClass,
+}
+

 @dataclass
 class SupersetSourceReport(StaleEntityRemovalSourceReport):
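The new FIELD_TYPE_MAPPING drives the column-level lineage code added further down: native Superset column types are upper-cased and looked up, and anything unmapped falls back to NullTypeClass. A minimal sketch of that lookup pattern (the resolve_field_type helper is illustrative, not part of the package):

# Illustrative sketch of the lookup used by the new build_input_fields helper.
from datahub.metadata.com.linkedin.pegasus2avro.schema import (
    BooleanTypeClass,
    DateTypeClass,
    NullTypeClass,
    NumberTypeClass,
    SchemaFieldDataType,
    StringTypeClass,
)

FIELD_TYPE_MAPPING = {
    "INT": NumberTypeClass,
    "STRING": StringTypeClass,
    "FLOAT": NumberTypeClass,
    "DATETIME": DateTypeClass,
    "BOOLEAN": BooleanTypeClass,
    "SQL": StringTypeClass,
}

def resolve_field_type(native_type: str) -> SchemaFieldDataType:
    # "FLOAT" -> NumberTypeClass; "VARCHAR(255)" has no entry -> NullTypeClass
    type_class = FIELD_TYPE_MAPPING.get(native_type.upper(), NullTypeClass)
    return SchemaFieldDataType(type=type_class())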
@@ -512,7 +529,119 @@ class SupersetSource(StatefulIngestionSourceBase):
             entity_urn=dashboard_snapshot.urn,
         )

-    def
+    def build_input_fields(
+        self,
+        chart_columns: List[Tuple[str, str, str]],
+        datasource_urn: Union[str, None],
+    ) -> List[InputField]:
+        input_fields: List[InputField] = []
+
+        for column in chart_columns:
+            col_name, col_type, description = column
+            if not col_type or not datasource_urn:
+                continue
+
+            type_class = FIELD_TYPE_MAPPING.get(
+                col_type.upper(), NullTypeClass
+            )  # gets the type mapping
+
+            input_fields.append(
+                InputField(
+                    schemaFieldUrn=builder.make_schema_field_urn(
+                        parent_urn=str(datasource_urn),
+                        field_path=col_name,
+                    ),
+                    schemaField=SchemaField(
+                        fieldPath=col_name,
+                        type=SchemaFieldDataType(type=type_class()),  # type: ignore
+                        description=(description if description != "null" else ""),
+                        nativeDataType=col_type,
+                        globalTags=None,
+                        nullable=True,
+                    ),
+                )
+            )
+
+        return input_fields
+
+    def construct_chart_cll(
+        self,
+        chart_data: dict,
+        datasource_urn: Union[str, None],
+        datasource_id: Union[Any, int],
+    ) -> List[InputField]:
+        column_data: List[Union[str, dict]] = chart_data.get("form_data", {}).get(
+            "all_columns", []
+        )
+
+        # the second field represents whether its a SQL expression,
+        # false being just regular column and true being SQL col
+        chart_column_data: List[Tuple[str, bool]] = [
+            (column, False)
+            if isinstance(column, str)
+            else (column.get("label", ""), True)
+            for column in column_data
+        ]
+
+        dataset_columns: List[Tuple[str, str, str]] = []
+
+        # parses the superset dataset's column info, to build type and description info
+        if datasource_id:
+            dataset_info = self.get_dataset_info(datasource_id).get("result", {})
+            dataset_column_info = dataset_info.get("columns", [])
+
+            for column in dataset_column_info:
+                col_name = column.get("column_name", "")
+                col_type = column.get("type", "")
+                col_description = column.get("description", "")
+
+                # if missing column name or column type, cannot construct the column,
+                # so we skip this column, missing description is fine
+                if col_name == "" or col_type == "":
+                    logger.info(f"could not construct column lineage for {column}")
+                    continue
+
+                dataset_columns.append((col_name, col_type, col_description))
+        else:
+            # if no datasource id, cannot build cll, just return
+            logger.warning(
+                "no datasource id was found, cannot build column level lineage"
+            )
+            return []
+
+        chart_columns: List[Tuple[str, str, str]] = []
+        for chart_col in chart_column_data:
+            chart_col_name, is_sql = chart_col
+            if is_sql:
+                chart_columns.append(
+                    (
+                        chart_col_name,
+                        "SQL",
+                        "",
+                    )
+                )
+                continue
+
+            # find matching upstream column
+            for dataset_col in dataset_columns:
+                dataset_col_name, dataset_col_type, dataset_col_description = (
+                    dataset_col
+                )
+                if dataset_col_name == chart_col_name:
+                    chart_columns.append(
+                        (chart_col_name, dataset_col_type, dataset_col_description)
+                    )  # column name, column type, description
+                    break
+
+            # if no matching upstream column was found
+            if len(chart_columns) == 0 or chart_columns[-1][0] != chart_col_name:
+                chart_columns.append((chart_col_name, "", ""))
+
+        return self.build_input_fields(chart_columns, datasource_urn)
+
+    def construct_chart_from_chart_data(
+        self, chart_data: dict
+    ) -> Iterable[MetadataWorkUnit]:
         chart_urn = make_chart_urn(
             platform=self.platform,
             name=str(chart_data["id"]),
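For context, construct_chart_cll reads the chart's form_data.all_columns and the owning dataset's column metadata, then pairs them into (name, type, description) tuples before handing them to build_input_fields. A hypothetical illustration of the payload shapes involved (all values are made up):

# Illustrative only: shapes that construct_chart_cll consumes.
chart_data = {
    "form_data": {
        # plain column names, or dicts (ad-hoc SQL expressions) carrying a "label"
        "all_columns": ["order_id", {"label": "total_with_tax"}],
    }
}
dataset_response = {
    "result": {
        "columns": [
            {"column_name": "order_id", "type": "INT", "description": "Order key"},
            {"column_name": "amount", "type": "FLOAT", "description": ""},
        ]
    }
}
# After matching, the intermediate chart_columns list would look like:
# [("order_id", "INT", "Order key"), ("total_with_tax", "SQL", "")]
# and build_input_fields turns each tuple into an InputField on the chart.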
@@ -600,6 +729,18 @@ class SupersetSource(StatefulIngestionSourceBase):
         )
         chart_snapshot.aspects.append(chart_info)

+        input_fields = self.construct_chart_cll(
+            chart_data, datasource_urn, datasource_id
+        )
+
+        if input_fields:
+            yield MetadataChangeProposalWrapper(
+                entityUrn=chart_urn,
+                aspect=InputFields(
+                    fields=sorted(input_fields, key=lambda x: x.schemaFieldUrn)
+                ),
+            ).as_workunit()
+
         chart_owners_list = self.build_owner_urn(chart_data)
         owners_info = OwnershipClass(
             owners=[
@@ -612,7 +753,14 @@ class SupersetSource(StatefulIngestionSourceBase):
             lastModified=last_modified,
         )
         chart_snapshot.aspects.append(owners_info)
-
+        yield MetadataWorkUnit(
+            id=chart_urn, mce=MetadataChangeEvent(proposedSnapshot=chart_snapshot)
+        )
+
+        yield from self._get_domain_wu(
+            title=chart_data.get("slice_name", ""),
+            entity_urn=chart_urn,
+        )

     def emit_chart_mces(self) -> Iterable[MetadataWorkUnit]:
         for chart_data in self.paginate_entity_api_results("chart/", PAGE_SIZE):
@@ -642,20 +790,12 @@ class SupersetSource(StatefulIngestionSourceBase):
                     f"Chart '{chart_name}' (id: {chart_id}) uses dataset '{dataset_name}' which is filtered by dataset_pattern"
                 )

-
-
-                mce = MetadataChangeEvent(proposedSnapshot=chart_snapshot)
+                yield from self.construct_chart_from_chart_data(chart_data)
             except Exception as e:
                 self.report.warning(
                     f"Failed to construct chart snapshot. Chart name: {chart_name}. Error: \n{e}"
                 )
                 continue
-            # Emit the chart
-            yield MetadataWorkUnit(id=chart_snapshot.urn, mce=mce)
-            yield from self._get_domain_wu(
-                title=chart_data.get("slice_name", ""),
-                entity_urn=chart_snapshot.urn,
-            )

     def gen_schema_fields(self, column_data: List[Dict[str, str]]) -> List[SchemaField]:
         schema_fields: List[SchemaField] = []
datahub/ingestion/source/vertexai/vertexai.py
CHANGED

@@ -107,7 +107,6 @@ class ContainerKeyWithId(ContainerKey):
     SourceCapability.DESCRIPTIONS,
     "Extract descriptions for Vertex AI Registered Models and Model Versions",
 )
-@capability(SourceCapability.TAGS, "Extract tags for Vertex AI Registered Model Stages")
 class VertexAISource(Source):
     platform: str = "vertexai"

@@ -602,6 +601,7 @@ class VertexAISource(Source):
                 else None
             ),
             customProperties=None,
+            externalUrl=self._make_model_external_url(model),
         ),
         SubTypesClass(typeNames=[MLAssetSubTypes.VERTEX_MODEL_GROUP]),
         ContainerClass(container=self._get_project_container().as_urn()),
datahub/metadata/schema.avsc
CHANGED
datahub/sdk/__init__.py
CHANGED
@@ -3,6 +3,7 @@ import types
 import datahub.metadata.schema_classes as models
 from datahub.errors import SdkUsageError
 from datahub.ingestion.graph.config import DatahubClientConfig
+from datahub.ingestion.graph.filters import FilterOperator
 from datahub.metadata.urns import (
     ChartUrn,
     ContainerUrn,
datahub/sdk/main_client.py
CHANGED
datahub/sdk/search_filters.py
CHANGED
@@ -14,7 +14,7 @@ import pydantic
 from datahub.configuration.common import ConfigModel
 from datahub.configuration.pydantic_migration_helpers import PYDANTIC_VERSION_2
 from datahub.ingestion.graph.client import entity_type_to_graphql
-from datahub.ingestion.graph.filters import SearchFilterRule
+from datahub.ingestion.graph.filters import FilterOperator, SearchFilterRule
 from datahub.metadata.schema_classes import EntityTypeName
 from datahub.metadata.urns import DataPlatformUrn, DomainUrn

@@ -63,25 +63,19 @@ class _EntityTypeFilter(_BaseFilter):


 class _EntitySubtypeFilter(_BaseFilter):
-    entity_type: str
     entity_subtype: str = pydantic.Field(
         description="The entity subtype to filter on. Can be 'Table', 'View', 'Source', etc. depending on the native platform's concepts.",
     )

+    def _build_rule(self) -> SearchFilterRule:
+        return SearchFilterRule(
+            field="typeNames",
+            condition="EQUAL",
+            values=[self.entity_subtype],
+        )
+
     def compile(self) -> _OrFilters:
-
-            SearchFilterRule(
-                field="_entityType",
-                condition="EQUAL",
-                values=[_flexible_entity_type_to_graphql(self.entity_type)],
-            ),
-            SearchFilterRule(
-                field="typeNames",
-                condition="EQUAL",
-                values=[self.entity_subtype],
-            ),
-        ]
-        return [{"and": rules}]
+        return [{"and": [self._build_rule()]}]


 class _PlatformFilter(_BaseFilter):
@@ -160,7 +154,7 @@ class _CustomCondition(_BaseFilter):
     """Represents a single field condition."""

     field: str
-    condition:
+    condition: FilterOperator
     values: List[str]

     def compile(self) -> _OrFilters:
@@ -329,14 +323,15 @@ class FilterDsl:
         )

     @staticmethod
-    def entity_subtype(
+    def entity_subtype(
+        entity_subtype: Union[str, Sequence[str]],
+    ) -> _EntitySubtypeFilter:
         return _EntitySubtypeFilter(
-
-            entity_subtype=subtype,
+            entity_subtype=entity_subtype,
         )

     @staticmethod
-    def platform(platform: Union[str,
+    def platform(platform: Union[str, Sequence[str]], /) -> _PlatformFilter:
         return _PlatformFilter(
             platform=[platform] if isinstance(platform, str) else platform
         )
@@ -344,11 +339,11 @@ class FilterDsl:
     # TODO: Add a platform_instance filter

     @staticmethod
-    def domain(domain: Union[str,
+    def domain(domain: Union[str, Sequence[str]], /) -> _DomainFilter:
         return _DomainFilter(domain=[domain] if isinstance(domain, str) else domain)

     @staticmethod
-    def env(env: Union[str,
+    def env(env: Union[str, Sequence[str]], /) -> _EnvFilter:
         return _EnvFilter(env=[env] if isinstance(env, str) else env)

     @staticmethod
@@ -365,7 +360,7 @@ class FilterDsl:

     @staticmethod
     def custom_filter(
-        field: str, condition:
+        field: str, condition: FilterOperator, values: Sequence[str]
     ) -> _CustomCondition:
         return _CustomCondition(
             field=field,
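Taken together, the search-filter changes tighten the FilterDsl signatures: entity_subtype now takes only the subtype, platform/domain/env accept a string or a sequence, and custom_filter's condition is typed as FilterOperator. A small usage sketch under those assumptions (it passes the string "EQUAL" as the condition; if FilterOperator is an enum rather than a string alias, use the matching member instead):

# Illustrative usage of the updated FilterDsl signatures shown above.
from datahub.sdk.search_filters import FilterDsl

subtype_filter = FilterDsl.entity_subtype("Table")               # subtype only
platform_filter = FilterDsl.platform(["snowflake", "redshift"])  # str or sequence
custom = FilterDsl.custom_filter(
    field="typeNames",
    condition="EQUAL",  # typed as FilterOperator after this change
    values=["Table"],
)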
datahub/sql_parsing/split_statements.py
CHANGED

@@ -1,7 +1,9 @@
+import logging
 import re
 from enum import Enum
 from typing import Iterator, List, Tuple

+logger = logging.getLogger(__name__)
 SELECT_KEYWORD = "SELECT"
 CASE_KEYWORD = "CASE"
 END_KEYWORD = "END"
@@ -120,7 +122,9 @@ class _StatementSplitter:
         # Reset current_statement-specific state.
         self.does_select_mean_new_statement = False
         if self.current_case_statements != 0:
-
+            logger.warning(
+                f"Unexpected END keyword. Current case statements: {self.current_case_statements}"
+            )
             self.current_case_statements = 0

     def process(self) -> Iterator[str]:
@@ -233,8 +237,10 @@
             ),
         )
         if (
-            is_force_new_statement_keyword
-
+            is_force_new_statement_keyword
+            and not self._has_preceding_cte(most_recent_real_char)
+            and not self._is_part_of_merge_query()
+        ):
             # Force termination of current statement
             yield from self._yield_if_complete()

@@ -247,6 +253,14 @@
         else:
             self.current_statement.append(c)

+    def _has_preceding_cte(self, most_recent_real_char: str) -> bool:
+        # usually we'd have a close paren that closes a CTE
+        return most_recent_real_char == ")"
+
+    def _is_part_of_merge_query(self) -> bool:
+        # In merge statement we'd have `when matched then` or `when not matched then"
+        return "".join(self.current_statement).strip().lower().endswith("then")
+

 def split_statements(sql: str) -> Iterator[str]:
     """
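The two new helpers keep the splitter from breaking a statement at a keyword that follows a CTE's closing parenthesis or a MERGE branch ending in THEN. A quick, illustrative check (the SQL is made up, and the exact keyword set that forces a new statement is not shown in this diff):

# Illustrative check of the new MERGE handling in split_statements.
from datahub.sql_parsing.split_statements import split_statements

sql = """
MERGE INTO target t USING source s ON t.id = s.id
WHEN MATCHED THEN UPDATE SET t.value = s.value
WHEN NOT MATCHED THEN INSERT (id, value) VALUES (s.id, s.value);
SELECT 1;
"""

# Previously a statement-starting keyword right after "... THEN" could force a
# premature break; _is_part_of_merge_query() now suppresses that, so we expect
# the MERGE to stay intact, followed by the trailing SELECT.
for statement in split_statements(sql):
    print(statement.strip())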
Files without changes (4): {acryl_datahub-1.0.0.1rc2.dist-info → acryl_datahub-1.0.0.1rc4.dist-info}/WHEEL, entry_points.txt, licenses/LICENSE, top_level.txt