pytrilogy 0.0.3.93__py3-none-any.whl → 0.0.3.95__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pytrilogy might be problematic. Click here for more details.

Files changed (39) hide show
  1. {pytrilogy-0.0.3.93.dist-info → pytrilogy-0.0.3.95.dist-info}/METADATA +170 -145
  2. {pytrilogy-0.0.3.93.dist-info → pytrilogy-0.0.3.95.dist-info}/RECORD +38 -34
  3. trilogy/__init__.py +1 -1
  4. trilogy/authoring/__init__.py +4 -0
  5. trilogy/core/enums.py +13 -0
  6. trilogy/core/env_processor.py +21 -10
  7. trilogy/core/environment_helpers.py +111 -0
  8. trilogy/core/exceptions.py +21 -1
  9. trilogy/core/functions.py +6 -1
  10. trilogy/core/graph_models.py +60 -67
  11. trilogy/core/internal.py +18 -0
  12. trilogy/core/models/author.py +16 -25
  13. trilogy/core/models/build.py +5 -4
  14. trilogy/core/models/core.py +3 -0
  15. trilogy/core/models/environment.py +28 -0
  16. trilogy/core/models/execute.py +7 -0
  17. trilogy/core/processing/node_generators/node_merge_node.py +30 -28
  18. trilogy/core/processing/node_generators/select_helpers/datasource_injection.py +25 -11
  19. trilogy/core/processing/node_generators/select_merge_node.py +68 -82
  20. trilogy/core/query_processor.py +2 -1
  21. trilogy/core/statements/author.py +18 -3
  22. trilogy/core/statements/common.py +0 -10
  23. trilogy/core/statements/execute.py +71 -16
  24. trilogy/core/validation/__init__.py +0 -0
  25. trilogy/core/validation/common.py +109 -0
  26. trilogy/core/validation/concept.py +122 -0
  27. trilogy/core/validation/datasource.py +192 -0
  28. trilogy/core/validation/environment.py +71 -0
  29. trilogy/dialect/base.py +40 -21
  30. trilogy/dialect/sql_server.py +3 -1
  31. trilogy/engine.py +25 -7
  32. trilogy/executor.py +145 -83
  33. trilogy/parsing/parse_engine.py +35 -4
  34. trilogy/parsing/trilogy.lark +11 -5
  35. trilogy/core/processing/node_generators/select_merge_node_v2.py +0 -792
  36. {pytrilogy-0.0.3.93.dist-info → pytrilogy-0.0.3.95.dist-info}/WHEEL +0 -0
  37. {pytrilogy-0.0.3.93.dist-info → pytrilogy-0.0.3.95.dist-info}/entry_points.txt +0 -0
  38. {pytrilogy-0.0.3.93.dist-info → pytrilogy-0.0.3.95.dist-info}/licenses/LICENSE.md +0 -0
  39. {pytrilogy-0.0.3.93.dist-info → pytrilogy-0.0.3.95.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,109 @@
1
+ from dataclasses import dataclass
2
+ from enum import Enum
3
+
4
+ from trilogy import Environment
5
+ from trilogy.authoring import ConceptRef
6
+ from trilogy.core.exceptions import ModelValidationError
7
+ from trilogy.core.models.build import (
8
+ BuildComparison,
9
+ BuildConcept,
10
+ BuildConditional,
11
+ BuildDatasource,
12
+ )
13
+ from trilogy.core.models.environment import EnvironmentConceptDict
14
+ from trilogy.core.models.execute import (
15
+ CTE,
16
+ QueryDatasource,
17
+ )
18
+ from trilogy.core.statements.execute import ProcessedQuery
19
+
20
+
21
class ExpectationType(Enum):
    """Category of expectation that a ValidationTest checks."""

    # Pass/fail style expectation (e.g. "valid_sql", "datatype_match").
    LOGICAL = "logical"
    # Expectation stated as a count (e.g. "0" rows).
    ROWCOUNT = "rowcount"
    # NOTE(review): presumably an expectation over a list of data types;
    # no usage is visible alongside this definition - confirm.
    DATA_TYPE_LIST = "data_type_list"
25
+
26
+
27
@dataclass
class ValidationTest:
    """Record of a single validation check and, if it ran, its outcome."""

    # Category of expectation being checked.
    check_type: ExpectationType
    # SQL text backing the check, when there is one.
    query: str | None = None
    # Short description of the expected outcome (e.g. "0", "valid_sql").
    expected: str | None = None
    # Error describing the failure; None when nothing failed or nothing ran.
    result: ModelValidationError | None = None
    # False when the check was only generated and not executed.
    ran: bool = True
34
+
35
+
36
class ValidationType(Enum):
    """Kinds of model object that validation can target."""

    DATASOURCES = "datasources"
    CONCEPTS = "concepts"
39
+
40
+
41
def easy_query(
    concepts: list[BuildConcept],
    datasource: BuildDatasource,
    env: Environment,
    condition: BuildConditional | BuildComparison | None = None,
    limit: int = 100,
):
    """
    Build a basic query that reads from one specific datasource.

    Two CTEs are assembled: a base CTE that sources ``concepts`` from
    ``datasource`` (with ``group_to_grain`` enabled at the datasource grain),
    and a filter CTE on top of it that carries ``condition`` and ``limit``.

    Args:
        concepts: Concepts the query should output.
        datasource: The single datasource the query reads from.
        env: Not referenced in the body; kept as part of the signature.
        condition: Optional condition attached to the filter CTE.
        limit: Row limit attached to the filter CTE.

    Returns:
        A ProcessedQuery holding both CTEs, with the base CTE as ``base``.
    """
    bound_addresses = {c.address for c in datasource.concepts}
    # include all base datasource concepts for convenience
    candidate_concepts = datasource.concepts + concepts

    # Concepts the datasource binds resolve to it; anything else (e.g. a
    # derived concept passed in ``concepts``) has no direct source.
    qds_source_map = {}
    cte_source_map = {}
    for concept in candidate_concepts:
        is_bound = concept.address in bound_addresses
        qds_source_map[concept.address] = {datasource} if is_bound else set()
        cte_source_map[concept.address] = (
            [datasource.safe_identifier] if is_bound else []
        )

    root_qds = QueryDatasource(
        input_concepts=candidate_concepts,
        output_concepts=concepts,
        datasources=[datasource],
        joins=[],
        source_map=qds_source_map,
        grain=datasource.grain,
    )
    cte = CTE(
        name=f"datasource_{datasource.name}_base",
        source=root_qds,
        output_columns=concepts,
        source_map=cte_source_map,
        grain=datasource.grain,
        group_to_grain=True,
        base_alias_override=datasource.safe_identifier,
    )

    filter_source = QueryDatasource(
        datasources=[root_qds],
        input_concepts=cte.output_columns,
        output_concepts=cte.output_columns,
        joins=[],
        source_map={concept.address: {root_qds} for concept in concepts},
        grain=cte.grain,
    )
    filter_cte = CTE(
        name=f"datasource_{datasource.name}_filter",
        source=filter_source,
        parent_ctes=[cte],
        output_columns=cte.output_columns,
        source_map={
            concept.address: [cte.identifier] for concept in cte.output_columns
        },
        grain=cte.grain,
        condition=condition,
        limit=limit,
    )

    output_refs = [ConceptRef(address=concept.address) for concept in concepts]
    return ProcessedQuery(
        output_columns=output_refs,
        ctes=[cte, filter_cte],
        base=cte,
        local_concepts=EnvironmentConceptDict(),
    )
@@ -0,0 +1,122 @@
1
+ from trilogy import Executor
2
+ from trilogy.core.enums import Derivation, Purpose
3
+ from trilogy.core.exceptions import (
4
+ ConceptModelValidationError,
5
+ DatasourceModelValidationError,
6
+ )
7
+ from trilogy.core.models.build import (
8
+ BuildConcept,
9
+ )
10
+ from trilogy.core.models.build_environment import BuildEnvironment
11
+ from trilogy.core.validation.common import ExpectationType, ValidationTest, easy_query
12
+
13
+
14
def validate_property_concept(
    concept: BuildConcept, generate_only: bool = False
) -> list[ValidationTest]:
    """
    Validate a property concept.

    No property-specific checks are implemented; both arguments are ignored
    and an empty list is always returned.
    """
    no_tests: list[ValidationTest] = []
    return no_tests
18
+
19
+
20
def validate_key_concept(
    concept: BuildConcept,
    build_env: BuildEnvironment,
    exec: Executor,
    generate_only: bool = False,
) -> list[ValidationTest]:
    """
    Check that a key concept is fully present in every datasource that binds it.

    For each datasource exposing ``concept``, the
    ``grain_check_<concept.safe_address>`` concept from ``build_env`` is
    queried against that datasource. The largest value seen across
    datasources is the reference: a datasource whose column assignment is
    marked complete but returns a smaller value is reported as missing values.

    Args:
        concept: The key concept to validate.
        build_env: Build environment supplying datasources and concepts.
        exec: Executor used to generate and run SQL.
        generate_only: When True, no SQL is executed. One un-run
            ValidationTest carrying the generated SQL is returned per
            datasource whose assignment is complete.

    Returns:
        One ValidationTest per relevant datasource. ``result`` holds a
        DatasourceModelValidationError on failure and None otherwise.
    """
    results: list[ValidationTest] = []
    counts_by_datasource: dict[str, int | None] = {}
    # (datasource, column assignment) pairs that were actually queried
    checked = []
    for datasource in build_env.datasources.values():
        if concept.address not in [c.address for c in datasource.concepts]:
            continue
        assignment = [
            x for x in datasource.columns if x.concept.address == concept.address
        ][0]
        type_query = easy_query(
            concepts=[
                build_env.concepts[f"grain_check_{concept.safe_address}"],
            ],
            datasource=datasource,
            env=exec.environment,
            limit=1,
        )
        type_sql = exec.generate_sql(type_query)[-1]

        if generate_only:
            # Generate-only mode must not touch the database: record the SQL
            # for complete assignments and move on without executing it.
            if assignment.is_complete:
                results.append(
                    ValidationTest(
                        query=type_sql,
                        check_type=ExpectationType.ROWCOUNT,
                        expected=f"equal_max_{concept.safe_address}",
                        result=None,
                        ran=False,
                    )
                )
            continue

        rows = exec.execute_raw_sql(type_sql).fetchall()
        counts_by_datasource[datasource.name] = rows[0][0] if rows else None
        checked.append((datasource, assignment))

    if generate_only:
        return results

    max_seen = max(
        (v for v in counts_by_datasource.values() if v is not None), default=0
    )
    for datasource, assignment in checked:
        count = counts_by_datasource[datasource.name]
        err = None
        # Partial assignments are allowed to hold fewer values than the max.
        if (count or 0) < max_seen and assignment.is_complete:
            err = DatasourceModelValidationError(
                f"Key concept {concept.address} is missing values in datasource {datasource.name} (max cardinality in data {max_seen}, datasource has {count} values) but is not marked as partial."
            )
        results.append(
            ValidationTest(
                query=None,
                check_type=ExpectationType.ROWCOUNT,
                expected=str(max_seen),
                result=err,
                ran=True,
            )
        )

    return results
81
+
82
+
83
def validate_datasources(
    concept: BuildConcept, build_env: BuildEnvironment
) -> list[ValidationTest]:
    """
    Report a root concept that no datasource binds.

    Returns an empty list when the concept has lineage, is exposed by at
    least one datasource, is not a ROOT derivation, or has a name or
    namespace starting with ``__``. Otherwise returns a single failed
    LOGICAL ValidationTest.
    """
    if concept.lineage:
        return []

    is_bound = any(
        concept.address in [c.address for c in datasource.concepts]
        for datasource in build_env.datasources.values()
    )
    if is_bound:
        return []

    if concept.derivation != Derivation.ROOT:
        return []

    # Names/namespaces with a double-underscore prefix are exempt.
    is_dunder = concept.name.startswith("__") or (
        concept.namespace and concept.namespace.startswith("__")
    )
    if is_dunder:
        return []

    unbound_error = ConceptModelValidationError(
        f"Concept {concept.address} is a root concept but has no datasources bound"
    )
    return [
        ValidationTest(
            query=None,
            check_type=ExpectationType.LOGICAL,
            expected=None,
            result=unbound_error,
            ran=True,
        )
    ]
108
+
109
+
110
def validate_concept(
    concept: BuildConcept,
    build_env: BuildEnvironment,
    exec: Executor,
    generate_only: bool = False,
) -> list[ValidationTest]:
    """
    Run every applicable validation for a single concept.

    The datasource-binding check always runs; property and key concepts
    additionally get their purpose-specific checks.
    """
    collected: list[ValidationTest] = list(validate_datasources(concept, build_env))
    purpose = concept.purpose
    if purpose == Purpose.PROPERTY:
        collected.extend(validate_property_concept(concept, generate_only))
    elif purpose == Purpose.KEY:
        collected.extend(
            validate_key_concept(concept, build_env, exec, generate_only)
        )
    return collected
@@ -0,0 +1,192 @@
1
+ from datetime import date, datetime
2
+ from decimal import Decimal
3
+ from typing import Any
4
+
5
+ from trilogy import Executor
6
+ from trilogy.authoring import (
7
+ ArrayType,
8
+ DataType,
9
+ MapType,
10
+ NumericType,
11
+ StructType,
12
+ TraitDataType,
13
+ )
14
+ from trilogy.core.enums import ComparisonOperator
15
+ from trilogy.core.exceptions import DatasourceModelValidationError
16
+ from trilogy.core.models.build import (
17
+ BuildComparison,
18
+ BuildDatasource,
19
+ )
20
+ from trilogy.core.models.build_environment import BuildEnvironment
21
+ from trilogy.core.validation.common import ExpectationType, ValidationTest, easy_query
22
+ from trilogy.utility import unique
23
+
24
+
25
def type_check(
    input: Any,
    expected_type: (
        DataType | ArrayType | StructType | MapType | NumericType | TraitDataType
    ),
    nullable: bool = True,
) -> bool:
    """
    Return whether a Python value is compatible with an expected data type.

    Args:
        input: The value to check.
        expected_type: The declared type. TraitDataType wrappers are
            unwrapped to their underlying ``data_type`` first.
        nullable: Whether ``None`` is an acceptable value.

    Returns:
        True if ``input`` matches ``expected_type``; False otherwise,
        including for any type this function does not recognise.
    """
    if input is None:
        # None matches no concrete type, so it is valid only when nullable.
        return nullable

    target_type = expected_type
    while isinstance(target_type, TraitDataType):
        target_type = target_type.data_type

    # bool is a subclass of int, so a plain isinstance(..., int) would accept
    # True/False as numbers; exclude it explicitly from the numeric checks.
    is_bool = isinstance(input, bool)

    if target_type == DataType.STRING:
        return isinstance(input, str)
    if target_type == DataType.INTEGER:
        return isinstance(input, int) and not is_bool
    if target_type == DataType.FLOAT or isinstance(target_type, NumericType):
        return isinstance(input, (float, int, Decimal)) and not is_bool
    if target_type == DataType.BOOL:
        return is_bool
    if target_type == DataType.DATE:
        # datetime subclasses date, so datetime values are accepted here too.
        return isinstance(input, date)
    if target_type == DataType.DATETIME:
        return isinstance(input, datetime)
    if target_type == DataType.ARRAY or isinstance(target_type, ArrayType):
        return isinstance(input, list)
    if target_type == DataType.MAP or isinstance(target_type, MapType):
        return isinstance(input, dict)
    if target_type == DataType.STRUCT or isinstance(target_type, StructType):
        return isinstance(input, dict)
    return False
60
+
61
+
62
def validate_datasource(
    datasource: BuildDatasource,
    build_env: BuildEnvironment,
    exec: Executor,
    generate_only: bool = False,
) -> list[ValidationTest]:
    """
    Validate one datasource's column types and its declared grain.

    Two checks are performed:

    1. Datatype check: up to 100 rows are read and every column value is
       compared against its concept's datatype with ``type_check``. Only the
       first failure per column is recorded.
    2. Grain check: the grain components are queried together with the
       ``grain_check`` concept from ``build_env``, filtered to
       ``grain_check > 1``. Any returned row (at most 10 are fetched) is
       reported as a grain violation.

    Args:
        datasource: The datasource to validate.
        build_env: Build environment supplying the concepts.
        exec: Executor used to generate and run SQL.
        generate_only: When True, no SQL is executed; one un-run
            ValidationTest carrying the generated SQL is returned for each
            of the two checks.

    Returns:
        The ValidationTests produced. If the datatype query itself fails to
        execute, a single failed test is returned and the grain check is
        skipped.
    """
    results: list[ValidationTest] = []
    # we might have merged concepts, where both will map out to the same address
    unique_outputs = unique(
        [build_env.concepts[col.concept.address] for col in datasource.columns],
        "address",
    )
    type_query = easy_query(
        concepts=unique_outputs,
        datasource=datasource,
        env=exec.environment,
        limit=100,
    )
    type_sql = exec.generate_sql(type_query)[-1]
    rows = []
    if generate_only:
        # Record the datatype query but keep going, so the grain query below
        # is also emitted instead of being skipped by an early return.
        results.append(
            ValidationTest(
                query=type_sql,
                check_type=ExpectationType.LOGICAL,
                expected="datatype_match",
                result=None,
                ran=False,
            )
        )
    else:
        try:
            rows = exec.execute_raw_sql(type_sql).fetchall()
        except Exception as e:
            # Any execution failure is surfaced as a validation failure
            # rather than propagated.
            results.append(
                ValidationTest(
                    query=type_sql,
                    check_type=ExpectationType.LOGICAL,
                    expected="valid_sql",
                    result=DatasourceModelValidationError(
                        f"Datasource {datasource.name} failed validation. Error executing type query {type_sql}: {e}"
                    ),
                    ran=True,
                )
            )
            return results

    failures: list[
        tuple[
            str,
            Any,
            DataType | ArrayType | StructType | MapType | NumericType | TraitDataType,
            bool,
        ]
    ] = []
    cols_with_error = set()
    # ``rows`` is empty in generate-only mode, so this loop is a no-op there.
    for row in rows:
        for col in datasource.columns:
            actual_address = build_env.concepts[col.concept.address].safe_address
            if actual_address in cols_with_error:
                # one reported failure per column is enough
                continue
            rval = row[actual_address]
            if type_check(rval, col.concept.datatype, col.is_nullable):
                continue
            failures.append(
                (
                    col.concept.address,
                    rval,
                    col.concept.datatype,
                    col.is_nullable,
                )
            )
            cols_with_error.add(actual_address)

    def format_failure(failure):
        return f"Concept {failure[0]} value '{failure[1]}' does not conform to expected type {str(failure[2])} (nullable={failure[3]})"

    if failures:
        results.append(
            ValidationTest(
                query=None,
                check_type=ExpectationType.LOGICAL,
                expected="datatype_match",
                ran=True,
                result=DatasourceModelValidationError(
                    f"Datasource {datasource.name} failed validation. Found rows that do not conform to types: {[format_failure(failure) for failure in failures]}",
                ),
            )
        )

    query = easy_query(
        concepts=[build_env.concepts[name] for name in datasource.grain.components]
        + [build_env.concepts["grain_check"]],
        datasource=datasource,
        env=exec.environment,
        condition=BuildComparison(
            left=build_env.concepts["grain_check"],
            right=1,
            operator=ComparisonOperator.GT,
        ),
    )
    sql = exec.generate_sql(query)[-1]
    if generate_only:
        results.append(
            ValidationTest(
                query=sql,
                check_type=ExpectationType.ROWCOUNT,
                expected="0",
                result=None,
                ran=False,
            )
        )
        return results

    rows = exec.execute_raw_sql(sql).fetchmany(10)
    if rows:
        results.append(
            ValidationTest(
                query=sql,
                check_type=ExpectationType.ROWCOUNT,
                expected="0",
                result=DatasourceModelValidationError(
                    f"Datasource {datasource.name} failed validation. Found rows that do not conform to grain: {rows}"
                ),
                ran=True,
            )
        )

    return results
@@ -0,0 +1,71 @@
1
+ from trilogy import Environment, Executor
2
+ from trilogy.authoring import DataType, Function
3
+ from trilogy.core.enums import FunctionType, Purpose, ValidationScope
4
+ from trilogy.core.exceptions import (
5
+ ModelValidationError,
6
+ )
7
+ from trilogy.core.validation.common import ValidationTest
8
+ from trilogy.core.validation.concept import validate_concept
9
+ from trilogy.core.validation.datasource import validate_datasource
10
+ from trilogy.parsing.common import function_to_concept
11
+
12
+
13
def validate_environment(
    env: Environment,
    exec: Executor,
    scope: ValidationScope = ValidationScope.ALL,
    targets: list[str] | None = None,
    generate_only: bool = False,
) -> list[ValidationTest]:
    """
    Validate the datasources and/or concepts of an environment.

    Works on a duplicate of ``env``, to which helper metric concepts are
    added: ``grain_check`` (SUM of the literal 1) and one
    ``grain_check_<safe_address>`` (COUNT_DISTINCT) per existing concept.

    Args:
        env: Environment to validate; the duplicate is modified, not ``env``.
        exec: Executor used by the individual validators.
        scope: Restricts validation to datasources, concepts, or both.
        targets: Optional datasource names / concept addresses to limit
            validation to; falsy means everything in scope.
        generate_only: Passed through to the validators; also suppresses
            the final raise.

    Returns:
        All ValidationTests produced.

    Raises:
        ModelValidationError: When any test carries an error and
            ``generate_only`` is False; the individual errors are attached
            as ``children``.
    """
    # avoid mutating the environment for validation
    env = env.duplicate()
    env.add_concept(
        function_to_concept(
            parent=Function(
                operator=FunctionType.SUM,
                arguments=[1],
                output_datatype=DataType.INTEGER,
                output_purpose=Purpose.METRIC,
            ),
            name="grain_check",
            environment=env,
        )
    )

    # Build all per-concept helpers first and register them afterwards, so
    # env.concepts is not added to while it is being iterated.
    per_concept_checks = [
        function_to_concept(
            parent=Function(
                operator=FunctionType.COUNT_DISTINCT,
                arguments=[existing.reference],
                output_datatype=DataType.INTEGER,
                output_purpose=Purpose.METRIC,
            ),
            name=f"grain_check_{existing.safe_address}",
            environment=env,
        )
        for existing in env.concepts.values()
    ]
    for helper in per_concept_checks:
        env.add_concept(helper)

    build_env = env.materialize_for_select()
    results: list[ValidationTest] = []

    if scope in (ValidationScope.ALL, ValidationScope.DATASOURCES):
        for datasource in build_env.datasources.values():
            if targets and datasource.name not in targets:
                continue
            results.extend(
                validate_datasource(datasource, build_env, exec, generate_only)
            )

    if scope in (ValidationScope.ALL, ValidationScope.CONCEPTS):
        for bconcept in build_env.concepts.values():
            if targets and bconcept.address not in targets:
                continue
            results.extend(validate_concept(bconcept, build_env, exec, generate_only))

    # raise a nicely formatted union of all exceptions
    exceptions: list[ModelValidationError] = [e.result for e in results if e.result]
    if exceptions and not generate_only:
        messages = "\n".join(str(e) for e in exceptions)
        raise ModelValidationError(
            f"Environment validation failed with the following errors:\n{messages}",
            children=exceptions,
        )
    return results
trilogy/dialect/base.py CHANGED
@@ -72,14 +72,16 @@ from trilogy.core.statements.author import (
72
72
  RowsetDerivationStatement,
73
73
  SelectStatement,
74
74
  ShowStatement,
75
+ ValidateStatement,
75
76
  )
76
77
  from trilogy.core.statements.execute import (
77
- ProcessedCopyStatement,
78
+ PROCESSED_STATEMENT_TYPES,
78
79
  ProcessedQuery,
79
80
  ProcessedQueryPersist,
80
81
  ProcessedRawSQLStatement,
81
82
  ProcessedShowStatement,
82
83
  ProcessedStaticValueOutput,
84
+ ProcessedValidateStatement,
83
85
  )
84
86
  from trilogy.core.utility import safe_quote
85
87
  from trilogy.dialect.common import render_join, render_unnest
@@ -1025,21 +1027,11 @@ class BaseDialect:
1025
1027
  | RawSQLStatement
1026
1028
  | MergeStatementV2
1027
1029
  | CopyStatement
1030
+ | ValidateStatement
1028
1031
  ],
1029
1032
  hooks: Optional[List[BaseHook]] = None,
1030
- ) -> List[
1031
- ProcessedQuery
1032
- | ProcessedQueryPersist
1033
- | ProcessedShowStatement
1034
- | ProcessedRawSQLStatement
1035
- ]:
1036
- output: List[
1037
- ProcessedQuery
1038
- | ProcessedQueryPersist
1039
- | ProcessedShowStatement
1040
- | ProcessedRawSQLStatement
1041
- | ProcessedCopyStatement
1042
- ] = []
1033
+ ) -> List[PROCESSED_STATEMENT_TYPES]:
1034
+ output: List[PROCESSED_STATEMENT_TYPES] = []
1043
1035
  for statement in statements:
1044
1036
  if isinstance(statement, PersistStatement):
1045
1037
  if hooks:
@@ -1089,10 +1081,39 @@ class BaseDialect:
1089
1081
  output.append(
1090
1082
  self.create_show_output(environment, statement.content)
1091
1083
  )
1084
+ elif isinstance(statement.content, ValidateStatement):
1085
+ output.append(
1086
+ ProcessedShowStatement(
1087
+ output_columns=[
1088
+ environment.concepts[
1089
+ DEFAULT_CONCEPTS["label"].address
1090
+ ].reference,
1091
+ environment.concepts[
1092
+ DEFAULT_CONCEPTS["query_text"].address
1093
+ ].reference,
1094
+ environment.concepts[
1095
+ DEFAULT_CONCEPTS["expected"].address
1096
+ ].reference,
1097
+ ],
1098
+ output_values=[
1099
+ ProcessedValidateStatement(
1100
+ scope=statement.content.scope,
1101
+ targets=statement.content.targets,
1102
+ )
1103
+ ],
1104
+ )
1105
+ )
1092
1106
  else:
1093
1107
  raise NotImplementedError(type(statement.content))
1094
1108
  elif isinstance(statement, RawSQLStatement):
1095
1109
  output.append(ProcessedRawSQLStatement(text=statement.text))
1110
+ elif isinstance(statement, ValidateStatement):
1111
+ output.append(
1112
+ ProcessedValidateStatement(
1113
+ scope=statement.scope,
1114
+ targets=statement.targets,
1115
+ )
1116
+ )
1096
1117
  elif isinstance(
1097
1118
  statement,
1098
1119
  (
@@ -1111,18 +1132,16 @@ class BaseDialect:
1111
1132
 
1112
1133
  def compile_statement(
1113
1134
  self,
1114
- query: (
1115
- ProcessedQuery
1116
- | ProcessedQueryPersist
1117
- | ProcessedShowStatement
1118
- | ProcessedRawSQLStatement
1119
- ),
1135
+ query: PROCESSED_STATEMENT_TYPES,
1120
1136
  ) -> str:
1121
1137
  if isinstance(query, ProcessedShowStatement):
1122
1138
  return ";\n".join([str(x) for x in query.output_values])
1123
1139
  elif isinstance(query, ProcessedRawSQLStatement):
1124
1140
  return query.text
1125
1141
 
1142
+ elif isinstance(query, ProcessedValidateStatement):
1143
+ return "select 1;"
1144
+
1126
1145
  recursive = any(isinstance(x, RecursiveCTE) for x in query.ctes)
1127
1146
 
1128
1147
  compiled_ctes = self.generate_ctes(query)
@@ -1139,7 +1158,7 @@ class BaseDialect:
1139
1158
  if CONFIG.strict_mode and INVALID_REFERENCE_STRING(1) in final:
1140
1159
  raise ValueError(
1141
1160
  f"Invalid reference string found in query: {final}, this should never"
1142
- " occur. Please create a GitHub issue to report this."
1161
+ " occur. Please create an issue to report this."
1143
1162
  )
1144
1163
  logger.info(f"{LOGGER_PREFIX} Compiled query: {final}")
1145
1164
  return final
@@ -8,6 +8,7 @@ from trilogy.core.statements.execute import (
8
8
  ProcessedQueryPersist,
9
9
  ProcessedRawSQLStatement,
10
10
  ProcessedShowStatement,
11
+ ProcessedValidateStatement,
11
12
  )
12
13
  from trilogy.dialect.base import BaseDialect
13
14
  from trilogy.utility import string_to_hash
@@ -90,10 +91,11 @@ class SqlServerDialect(BaseDialect):
90
91
  | ProcessedQueryPersist
91
92
  | ProcessedShowStatement
92
93
  | ProcessedRawSQLStatement
94
+ | ProcessedValidateStatement
93
95
  ),
94
96
  ) -> str:
95
97
  base = super().compile_statement(query)
96
- if isinstance(base, (ProcessedQuery, ProcessedQueryPersist)):
98
+ if isinstance(query, (ProcessedQuery, ProcessedQueryPersist)):
97
99
  for cte in query.ctes:
98
100
  if len(cte.name) > MAX_IDENTIFIER_LENGTH:
99
101
  new_name = f"rhash_{string_to_hash(cte.name)}"
trilogy/engine.py CHANGED
@@ -1,21 +1,27 @@
1
- from typing import Any, Protocol
1
+ from typing import Any, Generator, List, Optional, Protocol
2
2
 
3
3
  from sqlalchemy.engine import Connection, CursorResult, Engine
4
4
 
5
5
  from trilogy.core.models.environment import Environment
6
6
 
7
7
 
8
- class EngineResult(Protocol):
9
- pass
8
+ class ResultProtocol(Protocol):
10
9
 
11
- def fetchall(self) -> list[tuple]:
12
- pass
10
+ def fetchall(self) -> List[Any]: ...
11
+
12
+ def keys(self) -> List[str]: ...
13
+
14
+ def fetchone(self) -> Optional[Any]: ...
15
+
16
+ def fetchmany(self, size: int) -> List[Any]: ...
17
+
18
+ def __iter__(self) -> Generator[Any, None, None]: ...
13
19
 
14
20
 
15
21
  class EngineConnection(Protocol):
16
22
  pass
17
23
 
18
- def execute(self, statement: str, parameters: Any | None = None) -> EngineResult:
24
+ def execute(self, statement: str, parameters: Any | None = None) -> ResultProtocol:
19
25
  pass
20
26
 
21
27
  def commit(self):
@@ -39,13 +45,25 @@ class ExecutionEngine(Protocol):
39
45
 
40
46
 
41
47
  ### Begin default SQLAlchemy implementation
42
- class SqlAlchemyResult(EngineResult):
48
+ class SqlAlchemyResult:
43
49
  def __init__(self, result: CursorResult):
44
50
  self.result = result
45
51
 
46
52
  def fetchall(self):
47
53
  return self.result.fetchall()
48
54
 
55
+ def keys(self):
56
+ return self.result.keys()
57
+
58
+ def fetchone(self):
59
+ return self.result.fetchone()
60
+
61
+ def fetchmany(self, size: int):
62
+ return self.result.fetchmany(size)
63
+
64
+ def __iter__(self):
65
+ return iter(self.result)
66
+
49
67
 
50
68
  class SqlAlchemyConnection(EngineConnection):
51
69
  def __init__(self, connection: Connection):