sqlmesh 0.225.1.dev26__py3-none-any.whl → 0.228.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in the public registry.

Potentially problematic release.

This version of sqlmesh might be problematic.

Files changed (34)
  1. sqlmesh/_version.py +2 -2
  2. sqlmesh/core/config/connection.py +37 -1
  3. sqlmesh/core/context.py +60 -10
  4. sqlmesh/core/dialect.py +10 -2
  5. sqlmesh/core/engine_adapter/base.py +8 -1
  6. sqlmesh/core/engine_adapter/databricks.py +33 -16
  7. sqlmesh/core/engine_adapter/fabric.py +110 -2
  8. sqlmesh/core/engine_adapter/trino.py +44 -6
  9. sqlmesh/core/lineage.py +1 -0
  10. sqlmesh/core/linter/rules/builtin.py +15 -0
  11. sqlmesh/core/loader.py +17 -30
  12. sqlmesh/core/model/definition.py +9 -0
  13. sqlmesh/core/plan/definition.py +9 -7
  14. sqlmesh/core/renderer.py +7 -8
  15. sqlmesh/core/scheduler.py +45 -15
  16. sqlmesh/core/signal.py +35 -14
  17. sqlmesh/core/snapshot/definition.py +18 -12
  18. sqlmesh/core/snapshot/evaluator.py +24 -16
  19. sqlmesh/core/test/definition.py +5 -5
  20. sqlmesh/core/test/discovery.py +4 -0
  21. sqlmesh/dbt/common.py +4 -2
  22. sqlmesh/dbt/manifest.py +3 -1
  23. sqlmesh/integrations/github/cicd/command.py +11 -2
  24. sqlmesh/integrations/github/cicd/controller.py +6 -2
  25. sqlmesh/lsp/context.py +4 -2
  26. sqlmesh/magics.py +1 -1
  27. sqlmesh/utils/date.py +1 -1
  28. sqlmesh/utils/git.py +3 -1
  29. {sqlmesh-0.225.1.dev26.dist-info → sqlmesh-0.228.2.dist-info}/METADATA +3 -3
  30. {sqlmesh-0.225.1.dev26.dist-info → sqlmesh-0.228.2.dist-info}/RECORD +34 -34
  31. {sqlmesh-0.225.1.dev26.dist-info → sqlmesh-0.228.2.dist-info}/WHEEL +0 -0
  32. {sqlmesh-0.225.1.dev26.dist-info → sqlmesh-0.228.2.dist-info}/entry_points.txt +0 -0
  33. {sqlmesh-0.225.1.dev26.dist-info → sqlmesh-0.228.2.dist-info}/licenses/LICENSE +0 -0
  34. {sqlmesh-0.225.1.dev26.dist-info → sqlmesh-0.228.2.dist-info}/top_level.txt +0 -0
sqlmesh/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID

-__version__ = version = '0.225.1.dev26'
-__version_tuple__ = version_tuple = (0, 225, 1, 'dev26')
+__version__ = version = '0.228.2'
+__version_tuple__ = version_tuple = (0, 228, 2)

 __commit_id__ = commit_id = None
sqlmesh/core/config/connection.py CHANGED
@@ -17,6 +17,7 @@ from pydantic_core import from_json
 from packaging import version
 from sqlglot import exp
 from sqlglot.helper import subclasses
+from sqlglot.errors import ParseError

 from sqlmesh.core import engine_adapter
 from sqlmesh.core.config.base import BaseConfig
@@ -238,6 +239,7 @@ class DuckDBAttachOptions(BaseConfig):
     data_path: t.Optional[str] = None
     encrypted: bool = False
     data_inlining_row_limit: t.Optional[int] = None
+    metadata_schema: t.Optional[str] = None

     def to_sql(self, alias: str) -> str:
         options = []
@@ -259,6 +261,8 @@ class DuckDBAttachOptions(BaseConfig):
             options.append("ENCRYPTED")
         if self.data_inlining_row_limit is not None:
             options.append(f"DATA_INLINING_ROW_LIMIT {self.data_inlining_row_limit}")
+        if self.metadata_schema is not None:
+            options.append(f"METADATA_SCHEMA '{self.metadata_schema}'")

         options_sql = f" ({', '.join(options)})" if options else ""
         alias_sql = ""
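
A hedged sketch of how the new metadata_schema option might surface in the generated ATTACH statement; the attach type, path, and alias below are placeholders, not values taken from this release.

    # Illustrative only: exercising the new DuckDBAttachOptions.metadata_schema field.
    from sqlmesh.core.config.connection import DuckDBAttachOptions

    opts = DuckDBAttachOptions(
        type="ducklake",                      # placeholder attach type
        path="ducklake:metadata.ducklake",    # placeholder path
        metadata_schema="lake_meta",          # new in this release
    )
    # to_sql() should now include METADATA_SCHEMA 'lake_meta' among the ATTACH options.
    print(opts.to_sql(alias="lake"))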
@@ -1887,6 +1891,7 @@ class TrinoConnectionConfig(ConnectionConfig):

     # SQLMesh options
     schema_location_mapping: t.Optional[dict[re.Pattern, str]] = None
+    timestamp_mapping: t.Optional[dict[exp.DataType, exp.DataType]] = None
     concurrent_tasks: int = 4
     register_comments: bool = True
     pre_ping: t.Literal[False] = False
@@ -1911,6 +1916,34 @@ class TrinoConnectionConfig(ConnectionConfig):
         )
         return compiled

+    @field_validator("timestamp_mapping", mode="before")
+    @classmethod
+    def _validate_timestamp_mapping(
+        cls, value: t.Optional[dict[str, str]]
+    ) -> t.Optional[dict[exp.DataType, exp.DataType]]:
+        if value is None:
+            return value
+
+        result: dict[exp.DataType, exp.DataType] = {}
+        for source_type, target_type in value.items():
+            try:
+                source_datatype = exp.DataType.build(source_type)
+            except ParseError:
+                raise ConfigError(
+                    f"Invalid SQL type string in timestamp_mapping: "
+                    f"'{source_type}' is not a valid SQL data type."
+                )
+            try:
+                target_datatype = exp.DataType.build(target_type)
+            except ParseError:
+                raise ConfigError(
+                    f"Invalid SQL type string in timestamp_mapping: "
+                    f"'{target_type}' is not a valid SQL data type."
+                )
+            result[source_datatype] = target_datatype
+
+        return result
+
     @model_validator(mode="after")
     def _root_validator(self) -> Self:
         port = self.port
@@ -2013,7 +2046,10 @@ class TrinoConnectionConfig(ConnectionConfig):

     @property
     def _extra_engine_config(self) -> t.Dict[str, t.Any]:
-        return {"schema_location_mapping": self.schema_location_mapping}
+        return {
+            "schema_location_mapping": self.schema_location_mapping,
+            "timestamp_mapping": self.timestamp_mapping,
+        }


 class ClickhouseConnectionConfig(ConnectionConfig):
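
A hedged sketch of supplying the new timestamp_mapping option: keys and values are SQL type strings that the validator above turns into sqlglot DataType instances. The host, user, catalog, and the specific type pair are illustrative only.

    # Illustrative only: remap plain TIMESTAMP columns via the new option.
    from sqlmesh.core.config.connection import TrinoConnectionConfig

    connection = TrinoConnectionConfig(
        host="trino.example.com",   # placeholder
        user="sqlmesh",             # placeholder
        catalog="hive",             # placeholder
        timestamp_mapping={
            "timestamp": "timestamp(6) with time zone",  # parsed by exp.DataType.build
        },
    )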
sqlmesh/core/context.py CHANGED
@@ -115,6 +115,7 @@ from sqlmesh.core.test import (
     ModelTestMetadata,
     generate_test,
     run_tests,
+    filter_tests_by_patterns,
 )
 from sqlmesh.core.user import User
 from sqlmesh.utils import UniqueKeyDict, Verbosity
@@ -154,6 +155,8 @@ if t.TYPE_CHECKING:
     )
     from sqlmesh.core.snapshot import Node

+    from sqlmesh.core.snapshot.definition import Intervals
+
     ModelOrSnapshot = t.Union[str, Model, Snapshot]
     NodeOrSnapshot = t.Union[str, Model, StandaloneAudit, Snapshot]

@@ -276,6 +279,7 @@ class ExecutionContext(BaseContext):
         default_dialect: t.Optional[str] = None,
         default_catalog: t.Optional[str] = None,
         is_restatement: t.Optional[bool] = None,
+        parent_intervals: t.Optional[Intervals] = None,
         variables: t.Optional[t.Dict[str, t.Any]] = None,
         blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
     ):
@@ -287,6 +291,7 @@ class ExecutionContext(BaseContext):
         self._variables = variables or {}
         self._blueprint_variables = blueprint_variables or {}
         self._is_restatement = is_restatement
+        self._parent_intervals = parent_intervals

     @property
     def default_dialect(self) -> t.Optional[str]:
@@ -315,6 +320,10 @@ class ExecutionContext(BaseContext):
     def is_restatement(self) -> t.Optional[bool]:
         return self._is_restatement

+    @property
+    def parent_intervals(self) -> t.Optional[Intervals]:
+        return self._parent_intervals
+
     def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
         """Returns a variable value."""
         return self._variables.get(var_name.lower(), default)
@@ -390,6 +399,11 @@ class GenericContext(BaseContext, t.Generic[C]):
         self._standalone_audits: UniqueKeyDict[str, StandaloneAudit] = UniqueKeyDict(
             "standaloneaudits"
         )
+        self._model_test_metadata: t.List[ModelTestMetadata] = []
+        self._model_test_metadata_path_index: t.Dict[Path, t.List[ModelTestMetadata]] = {}
+        self._model_test_metadata_fully_qualified_name_index: t.Dict[str, ModelTestMetadata] = {}
+        self._models_with_tests: t.Set[str] = set()
+
         self._macros: UniqueKeyDict[str, ExecutableOrMacro] = UniqueKeyDict("macros")
         self._metrics: UniqueKeyDict[str, Metric] = UniqueKeyDict("metrics")
         self._jinja_macros = JinjaMacroRegistry()
@@ -628,6 +642,10 @@ class GenericContext(BaseContext, t.Generic[C]):
         self._excluded_requirements.clear()
         self._linters.clear()
         self._environment_statements = []
+        self._model_test_metadata.clear()
+        self._model_test_metadata_path_index.clear()
+        self._model_test_metadata_fully_qualified_name_index.clear()
+        self._models_with_tests.clear()

         for loader, project in zip(self._loaders, loaded_projects):
             self._jinja_macros = self._jinja_macros.merge(project.jinja_macros)
@@ -640,6 +658,16 @@ class GenericContext(BaseContext, t.Generic[C]):
             self._excluded_requirements.update(project.excluded_requirements)
             self._environment_statements.extend(project.environment_statements)

+            self._model_test_metadata.extend(project.model_test_metadata)
+            for metadata in project.model_test_metadata:
+                if metadata.path not in self._model_test_metadata_path_index:
+                    self._model_test_metadata_path_index[metadata.path] = []
+                self._model_test_metadata_path_index[metadata.path].append(metadata)
+                self._model_test_metadata_fully_qualified_name_index[
+                    metadata.fully_qualified_test_name
+                ] = metadata
+                self._models_with_tests.add(metadata.model_name)
+
             config = loader.config
             self._linters[config.project] = Linter.from_rules(
                 BUILTIN_RULES.union(project.user_rules), config.linter
@@ -1041,6 +1069,11 @@ class GenericContext(BaseContext, t.Generic[C]):
         """Returns all registered standalone audits in this context."""
         return MappingProxyType(self._standalone_audits)

+    @property
+    def models_with_tests(self) -> t.Set[str]:
+        """Returns all models with tests in this context."""
+        return self._models_with_tests
+
     @property
     def snapshots(self) -> t.Dict[str, Snapshot]:
         """Generates and returns snapshots based on models registered in this context.
@@ -2212,7 +2245,7 @@ class GenericContext(BaseContext, t.Generic[C]):

         pd.set_option("display.max_columns", None)

-        test_meta = self.load_model_tests(tests=tests, patterns=match_patterns)
+        test_meta = self.select_tests(tests=tests, patterns=match_patterns)

         result = run_tests(
             model_test_metadata=test_meta,
@@ -2271,6 +2304,7 @@ class GenericContext(BaseContext, t.Generic[C]):
             snapshot=snapshot,
             start=start,
             end=end,
+            execution_time=execution_time,
             snapshots=self.snapshots,
         ):
             audit_id = f"{audit_result.audit.name}"
@@ -3184,18 +3218,34 @@ class GenericContext(BaseContext, t.Generic[C]):

         return all_violations

-    def load_model_tests(
-        self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
+    def select_tests(
+        self,
+        tests: t.Optional[t.List[str]] = None,
+        patterns: t.Optional[t.List[str]] = None,
     ) -> t.List[ModelTestMetadata]:
-        # If a set of specific test path(s) are provided, we can use a single loader
-        # since it's not required to walk every tests/ folder in each repo
-        loaders = [self._loaders[0]] if tests else self._loaders
+        """Filter pre-loaded test metadata based on tests and patterns."""
+
+        test_meta = self._model_test_metadata
+
+        if tests:
+            filtered_tests = []
+            for test in tests:
+                if "::" in test:
+                    if test in self._model_test_metadata_fully_qualified_name_index:
+                        filtered_tests.append(
+                            self._model_test_metadata_fully_qualified_name_index[test]
+                        )
+                else:
+                    test_path = Path(test)
+                    if test_path in self._model_test_metadata_path_index:
+                        filtered_tests.extend(self._model_test_metadata_path_index[test_path])
+
+            test_meta = filtered_tests

-        model_tests = []
-        for loader in loaders:
-            model_tests.extend(loader.load_model_tests(tests=tests, patterns=patterns))
+        if patterns:
+            test_meta = filter_tests_by_patterns(test_meta, patterns)

-        return model_tests
+        return test_meta


 class Context(GenericContext[Config]):
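
A hedged sketch of how the reworked test selection might be invoked against the pre-loaded metadata; the project path, test file, and test names are made up.

    # Illustrative only: selecting tests by path, by fully qualified name
    # ("<path>::<test_name>"), and by pattern.
    from sqlmesh import Context

    context = Context(paths=".")

    # Everything defined in one YAML file.
    by_path = context.select_tests(tests=["tests/test_orders.yaml"])

    # A single test within that file.
    by_name = context.select_tests(tests=["tests/test_orders.yaml::test_full_refresh"])

    # Glob-style filtering, handled by filter_tests_by_patterns.
    by_pattern = context.select_tests(patterns=["*orders*"])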
sqlmesh/core/dialect.py CHANGED
@@ -803,8 +803,15 @@ def text_diff(
     return "\n".join(unified_diff(a_sql, b_sql))


+WS_OR_COMMENT = r"(?:\s|--[^\n]*\n|/\*.*?\*/)"
+HEADER = r"\b(?:model|audit)\b(?=\s*\()"
+KEY_BOUNDARY = r"(?:\(|,)"  # key is preceded by either '(' or ','
+DIALECT_VALUE = r"['\"]?(?P<dialect>[a-z][a-z0-9]*)['\"]?"
+VALUE_BOUNDARY = r"(?=,|\))"  # value is followed by comma or closing paren
+
 DIALECT_PATTERN = re.compile(
-    r"(model|audit).*?\(.*?dialect\s+'?([a-z]*)", re.IGNORECASE | re.DOTALL
+    rf"{HEADER}.*?{KEY_BOUNDARY}{WS_OR_COMMENT}*dialect{WS_OR_COMMENT}+{DIALECT_VALUE}{WS_OR_COMMENT}*{VALUE_BOUNDARY}",
+    re.IGNORECASE | re.DOTALL,
 )


@@ -895,7 +902,8 @@ def parse(
         A list of the parsed expressions: [Model, *Statements, Query, *Statements]
     """
     match = match_dialect and DIALECT_PATTERN.search(sql[:MAX_MODEL_DEFINITION_SIZE])
-    dialect = Dialect.get_or_raise(match.group(2) if match else default_dialect)
+    dialect_str = match.group("dialect") if match else None
+    dialect = Dialect.get_or_raise(dialect_str or default_dialect)

     tokens = dialect.tokenize(sql)
     chunks: t.List[t.Tuple[t.List[Token], ChunkType]] = [([], ChunkType.SQL)]
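
To make the regex change concrete, a hedged sketch of the kind of header the reworked pattern is meant to handle (quoted values, comments and whitespace around the key, capture via a named group); the MODEL text is illustrative.

    # Illustrative only: the reworked DIALECT_PATTERN tolerates comments before the
    # dialect key and exposes the value through the "dialect" named group.
    from sqlmesh.core.dialect import DIALECT_PATTERN

    sql = """
    MODEL (
      name demo.orders,  -- model metadata continues below
      dialect 'duckdb',
      kind FULL
    );
    SELECT 1 AS id
    """

    match = DIALECT_PATTERN.search(sql)
    print(match.group("dialect") if match else None)  # expected: duckdb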
sqlmesh/core/engine_adapter/base.py CHANGED
@@ -811,6 +811,7 @@ class EngineAdapter:
         column_descriptions: t.Optional[t.Dict[str, str]] = None,
         expressions: t.Optional[t.List[exp.PrimaryKey]] = None,
         is_view: bool = False,
+        materialized: bool = False,
     ) -> exp.Schema:
         """
         Build a schema expression for a table, columns, column comments, and additional schema properties.
@@ -823,6 +824,7 @@ class EngineAdapter:
                 target_columns_to_types=target_columns_to_types,
                 column_descriptions=column_descriptions,
                 is_view=is_view,
+                materialized=materialized,
             )
             + expressions,
         )
@@ -832,6 +834,7 @@ class EngineAdapter:
         target_columns_to_types: t.Dict[str, exp.DataType],
         column_descriptions: t.Optional[t.Dict[str, str]] = None,
         is_view: bool = False,
+        materialized: bool = False,
     ) -> t.List[exp.ColumnDef]:
         engine_supports_schema_comments = (
             self.COMMENT_CREATION_VIEW.supports_schema_def
@@ -1260,7 +1263,11 @@ class EngineAdapter:
         schema: t.Union[exp.Table, exp.Schema] = exp.to_table(view_name)
         if target_columns_to_types:
             schema = self._build_schema_exp(
-                exp.to_table(view_name), target_columns_to_types, column_descriptions, is_view=True
+                exp.to_table(view_name),
+                target_columns_to_types,
+                column_descriptions,
+                is_view=True,
+                materialized=materialized,
             )

         properties = create_kwargs.pop("properties", None)
sqlmesh/core/engine_adapter/databricks.py CHANGED
@@ -78,21 +78,21 @@ class DatabricksEngineAdapter(SparkEngineAdapter, GrantsFromInfoSchemaMixin):
     def _use_spark_session(self) -> bool:
         if self.can_access_spark_session(bool(self._extra_config.get("disable_spark_session"))):
             return True
-        return (
-            self.can_access_databricks_connect(
-                bool(self._extra_config.get("disable_databricks_connect"))
-            )
-            and (
-                {
-                    "databricks_connect_server_hostname",
-                    "databricks_connect_access_token",
-                }.issubset(self._extra_config)
-            )
-            and (
-                "databricks_connect_cluster_id" in self._extra_config
-                or "databricks_connect_use_serverless" in self._extra_config
-            )
-        )
+
+        if self.can_access_databricks_connect(
+            bool(self._extra_config.get("disable_databricks_connect"))
+        ):
+            if self._extra_config.get("databricks_connect_use_serverless"):
+                return True
+
+            if {
+                "databricks_connect_cluster_id",
+                "databricks_connect_server_hostname",
+                "databricks_connect_access_token",
+            }.issubset(self._extra_config):
+                return True
+
+        return False

     @property
     def is_spark_session_connection(self) -> bool:
@@ -108,7 +108,7 @@ class DatabricksEngineAdapter(SparkEngineAdapter, GrantsFromInfoSchemaMixin):

         connect_kwargs = dict(
             host=self._extra_config["databricks_connect_server_hostname"],
-            token=self._extra_config["databricks_connect_access_token"],
+            token=self._extra_config.get("databricks_connect_access_token"),
         )
         if "databricks_connect_use_serverless" in self._extra_config:
             connect_kwargs["serverless"] = True
@@ -394,3 +394,20 @@ class DatabricksEngineAdapter(SparkEngineAdapter, GrantsFromInfoSchemaMixin):
         expressions.append(clustered_by_exp)
         properties = exp.Properties(expressions=expressions)
         return properties
+
+    def _build_column_defs(
+        self,
+        target_columns_to_types: t.Dict[str, exp.DataType],
+        column_descriptions: t.Optional[t.Dict[str, str]] = None,
+        is_view: bool = False,
+        materialized: bool = False,
+    ) -> t.List[exp.ColumnDef]:
+        # Databricks requires column types to be specified when adding column comments
+        # in CREATE MATERIALIZED VIEW statements. Override is_view to False to force
+        # column types to be included when comments are present.
+        if is_view and materialized and column_descriptions:
+            is_view = False
+
+        return super()._build_column_defs(
+            target_columns_to_types, column_descriptions, is_view, materialized
+        )
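
For context, a hedged sketch of the connection settings implied by the reworked _use_spark_session check: with serverless enabled, a cluster id is no longer required before Databricks Connect is used. The hostname and token values are placeholders.

    # Illustrative only: extra connection settings as the new check reads them.
    extra_config = {
        "databricks_connect_server_hostname": "adb-1234567890123456.7.azuredatabricks.net",
        "databricks_connect_access_token": "<personal-access-token>",
        "databricks_connect_use_serverless": True,
        # "databricks_connect_cluster_id" is only needed on the non-serverless path.
    }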
sqlmesh/core/engine_adapter/fabric.py CHANGED
@@ -7,19 +7,20 @@ import time
 from functools import cached_property
 from sqlglot import exp
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_result
-from sqlmesh.core.engine_adapter.mixins import LogicalMergeMixin
 from sqlmesh.core.engine_adapter.mssql import MSSQLEngineAdapter
 from sqlmesh.core.engine_adapter.shared import (
     InsertOverwriteStrategy,
 )
 from sqlmesh.utils.errors import SQLMeshError
 from sqlmesh.utils.connection_pool import ConnectionPool
+from sqlmesh.core.schema_diff import TableAlterOperation
+from sqlmesh.utils import random_id


 logger = logging.getLogger(__name__)


-class FabricEngineAdapter(LogicalMergeMixin, MSSQLEngineAdapter):
+class FabricEngineAdapter(MSSQLEngineAdapter):
     """
     Adapter for Microsoft Fabric.
     """
@@ -154,6 +155,113 @@ class FabricEngineAdapter(LogicalMergeMixin, MSSQLEngineAdapter):
                 f"Unable to switch catalog to {catalog_name}, catalog ended up as {catalog_after_switch}"
             )

+    def alter_table(
+        self, alter_expressions: t.Union[t.List[exp.Alter], t.List[TableAlterOperation]]
+    ) -> None:
+        """
+        Applies alter expressions to a table. Fabric has limited support for ALTER TABLE,
+        so this method implements a workaround for column type changes.
+        This method is self-contained and sets its own catalog context.
+        """
+        if not alter_expressions:
+            return
+
+        # Get the target table from the first expression to determine the correct catalog.
+        first_op = alter_expressions[0]
+        expression = first_op.expression if isinstance(first_op, TableAlterOperation) else first_op
+        if not isinstance(expression, exp.Alter) or not expression.this.catalog:
+            # Fallback for unexpected scenarios
+            logger.warning(
+                "Could not determine catalog from alter expression, executing with current context."
+            )
+            super().alter_table(alter_expressions)
+            return
+
+        target_catalog = expression.this.catalog
+        self.set_current_catalog(target_catalog)
+
+        with self.transaction():
+            for op in alter_expressions:
+                expression = op.expression if isinstance(op, TableAlterOperation) else op
+
+                if not isinstance(expression, exp.Alter):
+                    self.execute(expression)
+                    continue
+
+                for action in expression.actions:
+                    table_name = expression.this
+
+                    table_name_without_catalog = table_name.copy()
+                    table_name_without_catalog.set("catalog", None)
+
+                    is_type_change = isinstance(action, exp.AlterColumn) and action.args.get(
+                        "dtype"
+                    )
+
+                    if is_type_change:
+                        column_to_alter = action.this
+                        new_type = action.args["dtype"]
+                        temp_column_name_str = f"{column_to_alter.name}__{random_id(short=True)}"
+                        temp_column_name = exp.to_identifier(temp_column_name_str)
+
+                        logger.info(
+                            "Applying workaround for column '%s' on table '%s' to change type to '%s'.",
+                            column_to_alter.sql(),
+                            table_name.sql(),
+                            new_type.sql(),
+                        )
+
+                        # Step 1: Add a temporary column.
+                        add_column_expr = exp.Alter(
+                            this=table_name_without_catalog.copy(),
+                            kind="TABLE",
+                            actions=[
+                                exp.ColumnDef(this=temp_column_name.copy(), kind=new_type.copy())
+                            ],
+                        )
+                        add_sql = self._to_sql(add_column_expr)
+                        self.execute(add_sql)
+
+                        # Step 2: Copy and cast data.
+                        update_sql = self._to_sql(
+                            exp.Update(
+                                this=table_name_without_catalog.copy(),
+                                expressions=[
+                                    exp.EQ(
+                                        this=temp_column_name.copy(),
+                                        expression=exp.Cast(
+                                            this=column_to_alter.copy(), to=new_type.copy()
+                                        ),
+                                    )
+                                ],
+                            )
+                        )
+                        self.execute(update_sql)
+
+                        # Step 3: Drop the original column.
+                        drop_sql = self._to_sql(
+                            exp.Alter(
+                                this=table_name_without_catalog.copy(),
+                                kind="TABLE",
+                                actions=[exp.Drop(this=column_to_alter.copy(), kind="COLUMN")],
+                            )
+                        )
+                        self.execute(drop_sql)
+
+                        # Step 4: Rename the temporary column.
+                        old_name_qualified = f"{table_name_without_catalog.sql(dialect=self.dialect)}.{temp_column_name.sql(dialect=self.dialect)}"
+                        new_name_unquoted = column_to_alter.sql(
+                            dialect=self.dialect, identify=False
+                        )
+                        rename_sql = f"EXEC sp_rename '{old_name_qualified}', '{new_name_unquoted}', 'COLUMN'"
+                        self.execute(rename_sql)
+                    else:
+                        # For other alterations, execute directly.
+                        direct_alter_expr = exp.Alter(
+                            this=table_name_without_catalog.copy(), kind="TABLE", actions=[action]
+                        )
+                        self.execute(direct_alter_expr)
+

 class FabricHttpClient:
     def __init__(self, tenant_id: str, workspace_id: str, client_id: str, client_secret: str):
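
To make the Fabric workaround concrete, a hedged sketch of the statement sequence the new alter_table emits for a single column type change; the table, column, target type, and temp-name suffix are made up.

    # Illustrative only: the four statements issued for changing orders.amount to DECIMAL(18, 2).
    # "__a1b2c3" stands in for the random_id() value generated at runtime.
    steps = [
        "ALTER TABLE dbo.orders ADD amount__a1b2c3 DECIMAL(18, 2)",               # 1. add temp column
        "UPDATE dbo.orders SET amount__a1b2c3 = CAST(amount AS DECIMAL(18, 2))",  # 2. copy + cast
        "ALTER TABLE dbo.orders DROP COLUMN amount",                              # 3. drop original
        "EXEC sp_rename 'dbo.orders.amount__a1b2c3', 'amount', 'COLUMN'",         # 4. rename temp column
    ]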
sqlmesh/core/engine_adapter/trino.py CHANGED
@@ -74,6 +74,32 @@ class TrinoEngineAdapter(
     def schema_location_mapping(self) -> t.Optional[t.Dict[re.Pattern, str]]:
         return self._extra_config.get("schema_location_mapping")

+    @property
+    def timestamp_mapping(self) -> t.Optional[t.Dict[exp.DataType, exp.DataType]]:
+        return self._extra_config.get("timestamp_mapping")
+
+    def _apply_timestamp_mapping(
+        self, columns_to_types: t.Dict[str, exp.DataType]
+    ) -> t.Tuple[t.Dict[str, exp.DataType], t.Set[str]]:
+        """Apply custom timestamp mapping to column types.
+
+        Returns:
+            A tuple of (mapped_columns_to_types, mapped_column_names) where mapped_column_names
+            contains the names of columns that were found in the mapping.
+        """
+        if not self.timestamp_mapping:
+            return columns_to_types, set()
+
+        result = {}
+        mapped_columns: t.Set[str] = set()
+        for column, column_type in columns_to_types.items():
+            if column_type in self.timestamp_mapping:
+                result[column] = self.timestamp_mapping[column_type]
+                mapped_columns.add(column)
+            else:
+                result[column] = column_type
+        return result, mapped_columns
+
     @property
     def catalog_support(self) -> CatalogSupport:
         return CatalogSupport.FULL_SUPPORT
@@ -117,7 +143,7 @@ class TrinoEngineAdapter(
         try:
             yield
         finally:
-            self.execute(f"RESET SESSION AUTHORIZATION")
+            self.execute("RESET SESSION AUTHORIZATION")

     def replace_query(
         self,
@@ -284,9 +310,13 @@ class TrinoEngineAdapter(
         column_descriptions: t.Optional[t.Dict[str, str]] = None,
         expressions: t.Optional[t.List[exp.PrimaryKey]] = None,
         is_view: bool = False,
+        materialized: bool = False,
     ) -> exp.Schema:
+        target_columns_to_types, mapped_columns = self._apply_timestamp_mapping(
+            target_columns_to_types
+        )
         if "delta_lake" in self.get_catalog_type_from_table(table):
-            target_columns_to_types = self._to_delta_ts(target_columns_to_types)
+            target_columns_to_types = self._to_delta_ts(target_columns_to_types, mapped_columns)

         return super()._build_schema_exp(
             table, target_columns_to_types, column_descriptions, expressions, is_view
@@ -312,10 +342,15 @@ class TrinoEngineAdapter(
         source_columns: t.Optional[t.List[str]] = None,
         **kwargs: t.Any,
     ) -> None:
+        mapped_columns: t.Set[str] = set()
+        if target_columns_to_types:
+            target_columns_to_types, mapped_columns = self._apply_timestamp_mapping(
+                target_columns_to_types
+            )
         if target_columns_to_types and "delta_lake" in self.get_catalog_type_from_table(
             target_table
         ):
-            target_columns_to_types = self._to_delta_ts(target_columns_to_types)
+            target_columns_to_types = self._to_delta_ts(target_columns_to_types, mapped_columns)

         return super()._scd_type_2(
             target_table,
@@ -345,18 +380,21 @@ class TrinoEngineAdapter(
     # - `timestamp(3) with time zone` for timezone-aware
    # https://trino.io/docs/current/connector/delta-lake.html#delta-lake-to-trino-type-mapping
     def _to_delta_ts(
-        self, columns_to_types: t.Dict[str, exp.DataType]
+        self,
+        columns_to_types: t.Dict[str, exp.DataType],
+        skip_columns: t.Optional[t.Set[str]] = None,
     ) -> t.Dict[str, exp.DataType]:
         ts6 = exp.DataType.build("timestamp(6)")
         ts3_tz = exp.DataType.build("timestamp(3) with time zone")
+        skip = skip_columns or set()

         delta_columns_to_types = {
-            k: ts6 if v.is_type(exp.DataType.Type.TIMESTAMP) else v
+            k: ts6 if k not in skip and v.is_type(exp.DataType.Type.TIMESTAMP) else v
             for k, v in columns_to_types.items()
         }

         delta_columns_to_types = {
-            k: ts3_tz if v.is_type(exp.DataType.Type.TIMESTAMPTZ) else v
+            k: ts3_tz if k not in skip and v.is_type(exp.DataType.Type.TIMESTAMPTZ) else v
             for k, v in delta_columns_to_types.items()
         }

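A hedged sketch of the interaction above: a column rewritten by the timestamp mapping is reported as mapped and then left alone by _to_delta_ts. Column names and types are illustrative.

    # Illustrative only: mirrors the lookup _apply_timestamp_mapping performs and the
    # skip set it hands to _to_delta_ts.
    from sqlglot import exp

    timestamp_mapping = {
        exp.DataType.build("timestamp"): exp.DataType.build("timestamp(6) with time zone"),
    }
    columns_to_types = {
        "created_at": exp.DataType.build("timestamp"),
        "id": exp.DataType.build("int"),
    }

    mapped = {c: timestamp_mapping.get(t, t) for c, t in columns_to_types.items()}
    skip = {c for c, t in columns_to_types.items() if t in timestamp_mapping}
    # skip == {"created_at"}, so the delta_lake TIMESTAMP -> timestamp(6) rewrite leaves it untouched.
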
sqlmesh/core/lineage.py CHANGED
@@ -66,6 +66,7 @@ def lineage(
         scope=scope,
         trim_selects=trim_selects,
         dialect=model.dialect,
+        copy=False,
     )

sqlmesh/core/linter/rules/builtin.py CHANGED
@@ -129,6 +129,21 @@ class NoMissingAudits(Rule):
         return self.violation()


+class NoMissingUnitTest(Rule):
+    """All models must have a unit test found in the tests/ directory yaml files"""
+
+    def check_model(self, model: Model) -> t.Optional[RuleViolation]:
+        # External models cannot have unit tests
+        if isinstance(model, ExternalModel):
+            return None
+
+        if model.name not in self.context.models_with_tests:
+            return self.violation(
+                violation_msg=f"Model {model.name} is missing unit test(s). Please add in the tests/ directory."
+            )
+        return None
+
+
 class NoMissingExternalModels(Rule):
     """All external models must be registered in the external_models.yaml file"""
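
For reference, a hedged sketch of opting into the new rule; it assumes the existing linter config shape (an enabled flag plus a list of lowercase rule names) and that the built-in rule is registered as the lowercase class name.

    # Illustrative only: enabling the new built-in rule via the project config.
    from sqlmesh.core.config import Config

    config = Config(
        model_defaults={"dialect": "duckdb"},           # placeholder dialect
        linter={
            "enabled": True,
            "rules": ["nomissingunittest"],             # assumed lowercase-class-name convention
        },
    )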