sqlmesh 0.225.0__py3-none-any.whl → 0.227.2.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (43)
  1. sqlmesh/__init__.py +10 -2
  2. sqlmesh/_version.py +2 -2
  3. sqlmesh/core/config/connection.py +10 -5
  4. sqlmesh/core/config/loader.py +1 -0
  5. sqlmesh/core/context.py +73 -8
  6. sqlmesh/core/engine_adapter/base.py +12 -7
  7. sqlmesh/core/engine_adapter/fabric.py +1 -2
  8. sqlmesh/core/engine_adapter/mssql.py +5 -2
  9. sqlmesh/core/engine_adapter/trino.py +1 -1
  10. sqlmesh/core/lineage.py +1 -0
  11. sqlmesh/core/linter/rules/builtin.py +15 -0
  12. sqlmesh/core/loader.py +4 -0
  13. sqlmesh/core/model/kind.py +2 -2
  14. sqlmesh/core/plan/definition.py +9 -7
  15. sqlmesh/core/renderer.py +7 -8
  16. sqlmesh/core/scheduler.py +45 -15
  17. sqlmesh/core/signal.py +35 -14
  18. sqlmesh/core/snapshot/definition.py +18 -12
  19. sqlmesh/core/snapshot/evaluator.py +31 -17
  20. sqlmesh/core/state_sync/db/snapshot.py +6 -1
  21. sqlmesh/core/table_diff.py +2 -2
  22. sqlmesh/core/test/definition.py +5 -3
  23. sqlmesh/core/test/discovery.py +4 -0
  24. sqlmesh/dbt/builtin.py +9 -11
  25. sqlmesh/dbt/column.py +17 -5
  26. sqlmesh/dbt/common.py +4 -2
  27. sqlmesh/dbt/context.py +2 -0
  28. sqlmesh/dbt/loader.py +15 -2
  29. sqlmesh/dbt/manifest.py +3 -1
  30. sqlmesh/dbt/model.py +13 -1
  31. sqlmesh/dbt/profile.py +3 -3
  32. sqlmesh/dbt/target.py +9 -4
  33. sqlmesh/utils/date.py +1 -1
  34. sqlmesh/utils/pydantic.py +6 -6
  35. sqlmesh/utils/windows.py +13 -3
  36. {sqlmesh-0.225.0.dist-info → sqlmesh-0.227.2.dev6.dist-info}/METADATA +2 -2
  37. {sqlmesh-0.225.0.dist-info → sqlmesh-0.227.2.dev6.dist-info}/RECORD +43 -43
  38. sqlmesh_dbt/cli.py +26 -1
  39. sqlmesh_dbt/operations.py +8 -2
  40. {sqlmesh-0.225.0.dist-info → sqlmesh-0.227.2.dev6.dist-info}/WHEEL +0 -0
  41. {sqlmesh-0.225.0.dist-info → sqlmesh-0.227.2.dev6.dist-info}/entry_points.txt +0 -0
  42. {sqlmesh-0.225.0.dist-info → sqlmesh-0.227.2.dev6.dist-info}/licenses/LICENSE +0 -0
  43. {sqlmesh-0.225.0.dist-info → sqlmesh-0.227.2.dev6.dist-info}/top_level.txt +0 -0
sqlmesh/__init__.py CHANGED
@@ -188,6 +188,7 @@ def configure_logging(
     write_to_file: bool = True,
     log_file_dir: t.Optional[t.Union[str, Path]] = None,
     ignore_warnings: bool = False,
+    log_level: t.Optional[t.Union[str, int]] = None,
 ) -> None:
     # Remove noisy grpc logs that are not useful for users
     os.environ["GRPC_VERBOSITY"] = os.environ.get("GRPC_VERBOSITY", "NONE")
@@ -195,8 +196,15 @@ def configure_logging(
     logger = logging.getLogger()
     debug = force_debug or debug_mode_enabled()

-    # base logger needs to be the lowest level that we plan to log
-    level = logging.DEBUG if debug else logging.INFO
+    if log_level is not None:
+        if isinstance(log_level, str):
+            level = logging._nameToLevel.get(log_level.upper()) or logging.INFO
+        else:
+            level = log_level
+    else:
+        # base logger needs to be the lowest level that we plan to log
+        level = logging.DEBUG if debug else logging.INFO
+
     logger.setLevel(level)

     if debug:
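
Note: per the hunk above, an explicit log_level overrides the debug-derived default, and string names are resolved case-insensitively with a fallback to INFO. A minimal usage sketch (only write_to_file and log_level come from the signature shown; the scenario is illustrative):

    from sqlmesh import configure_logging

    # Cap logger noise at WARNING; "warning" is upper-cased and looked up in
    # logging._nameToLevel, and unknown names fall back to logging.INFO.
    configure_logging(write_to_file=False, log_level="warning")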
sqlmesh/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID

-__version__ = version = '0.225.0'
-__version_tuple__ = version_tuple = (0, 225, 0)
+__version__ = version = '0.227.2.dev6'
+__version_tuple__ = version_tuple = (0, 227, 2, 'dev6')

 __commit_id__ = commit_id = None
sqlmesh/core/config/connection.py CHANGED
@@ -58,6 +58,7 @@ FORBIDDEN_STATE_SYNC_ENGINES = {
     "clickhouse",
 }
 MOTHERDUCK_TOKEN_REGEX = re.compile(r"(\?|\&)(motherduck_token=)(\S*)")
+PASSWORD_REGEX = re.compile(r"(password=)(\S+)")


 def _get_engine_import_validator(
@@ -479,13 +480,13 @@ class BaseDuckDBConnectionConfig(ConnectionConfig):
         adapter = BaseDuckDBConnectionConfig._data_file_to_adapter.get(key)
         if adapter is not None:
             logger.info(
-                f"Using existing DuckDB adapter due to overlapping data file: {self._mask_motherduck_token(key)}"
+                f"Using existing DuckDB adapter due to overlapping data file: {self._mask_sensitive_data(key)}"
             )
             return adapter

         if data_files:
             masked_files = {
-                self._mask_motherduck_token(file if isinstance(file, str) else file.path)
+                self._mask_sensitive_data(file if isinstance(file, str) else file.path)
                 for file in data_files
             }
             logger.info(f"Creating new DuckDB adapter for data files: {masked_files}")
@@ -507,10 +508,14 @@ class BaseDuckDBConnectionConfig(ConnectionConfig):
             return list(self.catalogs)[0]
         return None

-    def _mask_motherduck_token(self, string: str) -> str:
-        return MOTHERDUCK_TOKEN_REGEX.sub(
-            lambda m: f"{m.group(1)}{m.group(2)}{'*' * len(m.group(3))}", string
+    def _mask_sensitive_data(self, string: str) -> str:
+        # Mask MotherDuck tokens with fixed number of asterisks
+        result = MOTHERDUCK_TOKEN_REGEX.sub(
+            lambda m: f"{m.group(1)}{m.group(2)}{'*' * 8 if m.group(3) else ''}", string
         )
+        # Mask PostgreSQL/MySQL passwords with fixed number of asterisks
+        result = PASSWORD_REGEX.sub(lambda m: f"{m.group(1)}{'*' * 8}", result)
+        return result


 class MotherDuckConnectionConfig(BaseDuckDBConnectionConfig):
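
Note: the rewritten helper masks with a fixed eight asterisks, so the mask no longer reveals the secret's length, and it now also covers password= pairs. A standalone sketch of the same logic with the regexes copied from the hunk (the free function is illustrative; in the package this is a method):

    import re

    MOTHERDUCK_TOKEN_REGEX = re.compile(r"(\?|\&)(motherduck_token=)(\S*)")
    PASSWORD_REGEX = re.compile(r"(password=)(\S+)")

    def mask_sensitive_data(string: str) -> str:
        # MotherDuck tokens: keep the key, replace the value with 8 asterisks
        result = MOTHERDUCK_TOKEN_REGEX.sub(
            lambda m: f"{m.group(1)}{m.group(2)}{'*' * 8 if m.group(3) else ''}", string
        )
        # password=... pairs get the same constant-width mask
        return PASSWORD_REGEX.sub(lambda m: f"{m.group(1)}{'*' * 8}", result)

    print(mask_sensitive_data("md:db?motherduck_token=abc123"))
    # md:db?motherduck_token=********
    print(mask_sensitive_data("host=h user=u password=hunter2"))
    # host=h user=u password=********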
sqlmesh/core/config/loader.py CHANGED
@@ -178,6 +178,7 @@ def load_config_from_paths(

     dbt_python_config = sqlmesh_config(
         project_root=dbt_project_file.parent,
+        profiles_dir=kwargs.pop("profiles_dir", None),
         dbt_profile_name=kwargs.pop("profile", None),
         dbt_target_name=kwargs.pop("target", None),
         variables=variables,
sqlmesh/core/context.py CHANGED
@@ -115,6 +115,7 @@ from sqlmesh.core.test import (
     ModelTestMetadata,
     generate_test,
     run_tests,
+    filter_tests_by_patterns,
 )
 from sqlmesh.core.user import User
 from sqlmesh.utils import UniqueKeyDict, Verbosity
@@ -139,20 +140,23 @@ from sqlmesh.utils.errors import (
 )
 from sqlmesh.utils.config import print_config
 from sqlmesh.utils.jinja import JinjaMacroRegistry
+from sqlmesh.utils.windows import IS_WINDOWS, fix_windows_path

 if t.TYPE_CHECKING:
     import pandas as pd
     from typing_extensions import Literal

     from sqlmesh.core.engine_adapter._typing import (
-        BigframeSession,
         DF,
+        BigframeSession,
         PySparkDataFrame,
         PySparkSession,
         SnowparkSession,
     )
     from sqlmesh.core.snapshot import Node

+    from sqlmesh.core.snapshot.definition import Intervals
+
 ModelOrSnapshot = t.Union[str, Model, Snapshot]
 NodeOrSnapshot = t.Union[str, Model, StandaloneAudit, Snapshot]

@@ -275,6 +279,7 @@ class ExecutionContext(BaseContext):
         default_dialect: t.Optional[str] = None,
         default_catalog: t.Optional[str] = None,
         is_restatement: t.Optional[bool] = None,
+        parent_intervals: t.Optional[Intervals] = None,
         variables: t.Optional[t.Dict[str, t.Any]] = None,
         blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
     ):
@@ -286,6 +291,7 @@ class ExecutionContext(BaseContext):
         self._variables = variables or {}
         self._blueprint_variables = blueprint_variables or {}
         self._is_restatement = is_restatement
+        self._parent_intervals = parent_intervals

     @property
     def default_dialect(self) -> t.Optional[str]:
@@ -314,6 +320,10 @@ class ExecutionContext(BaseContext):
     def is_restatement(self) -> t.Optional[bool]:
         return self._is_restatement

+    @property
+    def parent_intervals(self) -> t.Optional[Intervals]:
+        return self._parent_intervals
+
     def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
         """Returns a variable value."""
         return self._variables.get(var_name.lower(), default)
@@ -389,6 +399,10 @@ class GenericContext(BaseContext, t.Generic[C]):
         self._standalone_audits: UniqueKeyDict[str, StandaloneAudit] = UniqueKeyDict(
             "standaloneaudits"
         )
+        self._model_test_metadata: t.List[ModelTestMetadata] = []
+        self._model_test_metadata_path_index: t.Dict[Path, t.List[ModelTestMetadata]] = {}
+        self._model_test_metadata_fully_qualified_name_index: t.Dict[str, ModelTestMetadata] = {}
+        self._models_with_tests: t.Set[str] = set()
         self._macros: UniqueKeyDict[str, ExecutableOrMacro] = UniqueKeyDict("macros")
         self._metrics: UniqueKeyDict[str, Metric] = UniqueKeyDict("metrics")
         self._jinja_macros = JinjaMacroRegistry()
@@ -627,6 +641,10 @@ class GenericContext(BaseContext, t.Generic[C]):
         self._excluded_requirements.clear()
         self._linters.clear()
         self._environment_statements = []
+        self._model_test_metadata.clear()
+        self._model_test_metadata_path_index.clear()
+        self._model_test_metadata_fully_qualified_name_index.clear()
+        self._models_with_tests.clear()

         for loader, project in zip(self._loaders, loaded_projects):
             self._jinja_macros = self._jinja_macros.merge(project.jinja_macros)
@@ -638,6 +656,15 @@ class GenericContext(BaseContext, t.Generic[C]):
             self._requirements.update(project.requirements)
             self._excluded_requirements.update(project.excluded_requirements)
             self._environment_statements.extend(project.environment_statements)
+            self._model_test_metadata.extend(project.model_test_metadata)
+            for metadata in project.model_test_metadata:
+                if metadata.path not in self._model_test_metadata_path_index:
+                    self._model_test_metadata_path_index[metadata.path] = []
+                self._model_test_metadata_path_index[metadata.path].append(metadata)
+                self._model_test_metadata_fully_qualified_name_index[
+                    metadata.fully_qualified_test_name
+                ] = metadata
+                self._models_with_tests.add(metadata.model_name)

             config = loader.config
             self._linters[config.project] = Linter.from_rules(
@@ -1040,6 +1067,11 @@ class GenericContext(BaseContext, t.Generic[C]):
         """Returns all registered standalone audits in this context."""
         return MappingProxyType(self._standalone_audits)

+    @property
+    def models_with_tests(self) -> t.Set[str]:
+        """Returns all models with tests in this context."""
+        return self._models_with_tests
+
     @property
     def snapshots(self) -> t.Dict[str, Snapshot]:
         """Generates and returns snapshots based on models registered in this context.
@@ -2211,7 +2243,9 @@ class GenericContext(BaseContext, t.Generic[C]):

         pd.set_option("display.max_columns", None)

-        test_meta = self.load_model_tests(tests=tests, patterns=match_patterns)
+        test_meta = self._select_tests(
+            test_meta=self._model_test_metadata, tests=tests, patterns=match_patterns
+        )

         result = run_tests(
             model_test_metadata=test_meta,
@@ -2270,6 +2304,7 @@ class GenericContext(BaseContext, t.Generic[C]):
             snapshot=snapshot,
             start=start,
             end=end,
+            execution_time=execution_time,
             snapshots=self.snapshots,
         ):
             audit_id = f"{audit_result.audit.name}"
@@ -2590,12 +2625,15 @@ class GenericContext(BaseContext, t.Generic[C]):
         )

     def clear_caches(self) -> None:
-        for path in self.configs:
-            cache_path = path / c.CACHE
-            if cache_path.exists():
-                rmtree(cache_path)
-        if self.cache_dir.exists():
-            rmtree(self.cache_dir)
+        paths_to_remove = [path / c.CACHE for path in self.configs]
+        paths_to_remove.append(self.cache_dir)
+
+        if IS_WINDOWS:
+            paths_to_remove = [fix_windows_path(path) for path in paths_to_remove]
+
+        for path in paths_to_remove:
+            if path.exists():
+                rmtree(path)

         if isinstance(self._state_sync, CachingStateSync):
             self._state_sync.clear_cache()
@@ -2769,6 +2807,33 @@ class GenericContext(BaseContext, t.Generic[C]):
             raise SQLMeshError(f"Gateway '{gateway}' not found in the available engine adapters.")
         return self.engine_adapter

+    def _select_tests(
+        self,
+        test_meta: t.List[ModelTestMetadata],
+        tests: t.Optional[t.List[str]] = None,
+        patterns: t.Optional[t.List[str]] = None,
+    ) -> t.List[ModelTestMetadata]:
+        """Filter pre-loaded test metadata based on tests and patterns."""
+
+        if tests:
+            filtered_tests = []
+            for test in tests:
+                if "::" in test:
+                    if test in self._model_test_metadata_fully_qualified_name_index:
+                        filtered_tests.append(
+                            self._model_test_metadata_fully_qualified_name_index[test]
+                        )
+                else:
+                    test_path = Path(test)
+                    if test_path in self._model_test_metadata_path_index:
+                        filtered_tests.extend(self._model_test_metadata_path_index[test_path])
+            test_meta = filtered_tests
+
+        if patterns:
+            test_meta = filter_tests_by_patterns(test_meta, patterns)
+
+        return test_meta
+
     def _snapshots(
         self, models_override: t.Optional[UniqueKeyDict[str, Model]] = None
     ) -> t.Dict[str, Snapshot]:
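
Note: _select_tests resolves two selector shapes against the indexes built at load time: a plain path selects every test in that file, while a "<path>::<test name>" selector hits the fully-qualified-name index; match_patterns is then applied via filter_tests_by_patterns. A hedged sketch of how these selectors would reach this code through the test entrypoint whose body is shown above (method and paths are illustrative):

    # all tests defined in one yaml file
    context.test(tests=["tests/test_orders.yaml"])

    # a single test, fully qualified as <path>::<test name>
    context.test(tests=["tests/test_orders.yaml::test_daily_totals"])

    # pattern filtering over the pre-loaded test metadata
    context.test(match_patterns=["*daily*"])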
sqlmesh/core/engine_adapter/base.py CHANGED
@@ -18,7 +18,7 @@ from functools import cached_property, partial

 from sqlglot import Dialect, exp
 from sqlglot.errors import ErrorLevel
-from sqlglot.helper import ensure_list
+from sqlglot.helper import ensure_list, seq_get
 from sqlglot.optimizer.qualify_columns import quote_identifiers

 from sqlmesh.core.dialect import (
@@ -551,11 +551,13 @@ class EngineAdapter:
                 target_table,
                 source_queries,
                 target_columns_to_types,
+                **kwargs,
             )
         return self._insert_overwrite_by_condition(
             target_table,
             source_queries,
             target_columns_to_types,
+            **kwargs,
         )

     def create_index(
@@ -1614,7 +1616,7 @@ class EngineAdapter:
         **kwargs: t.Any,
     ) -> None:
         return self._insert_overwrite_by_condition(
-            table_name, source_queries, target_columns_to_types, where
+            table_name, source_queries, target_columns_to_types, where, **kwargs
         )

     def _values_to_sql(
@@ -1772,7 +1774,7 @@ class EngineAdapter:
         valid_from_col: exp.Column,
         valid_to_col: exp.Column,
         execution_time: t.Union[TimeLike, exp.Column],
-        check_columns: t.Union[exp.Star, t.Sequence[exp.Column]],
+        check_columns: t.Union[exp.Star, t.Sequence[exp.Expression]],
         invalidate_hard_deletes: bool = True,
         execution_time_as_valid_from: bool = False,
         target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -1810,7 +1812,7 @@ class EngineAdapter:
         execution_time: t.Union[TimeLike, exp.Column],
         invalidate_hard_deletes: bool = True,
         updated_at_col: t.Optional[exp.Column] = None,
-        check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Column]]] = None,
+        check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expression]]] = None,
         updated_at_as_valid_from: bool = False,
         execution_time_as_valid_from: bool = False,
         target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -1885,8 +1887,10 @@ class EngineAdapter:
         # they are equal or not, the extra check is not a problem and we gain simplified logic here.
         # If we want to change this, then we just need to check the expressions in unique_key and pull out the
         # column names and then remove them from the unmanaged_columns
-        if check_columns and check_columns == exp.Star():
-            check_columns = [exp.column(col) for col in unmanaged_columns_to_types]
+        if check_columns:
+            # Handle both Star directly and [Star()] (which can happen during serialization/deserialization)
+            if isinstance(seq_get(ensure_list(check_columns), 0), exp.Star):
+                check_columns = [exp.column(col) for col in unmanaged_columns_to_types]
         execution_ts = (
             exp.cast(execution_time, time_data_type, dialect=self.dialect)
             if isinstance(execution_time, exp.Column)
@@ -1923,7 +1927,8 @@ class EngineAdapter:
                 col_qualified.set("table", exp.to_identifier("joined"))

                 t_col = col_qualified.copy()
-                t_col.this.set("this", f"t_{col.name}")
+                for column in t_col.find_all(exp.Column):
+                    column.this.set("this", f"t_{column.name}")

                 row_check_conditions.extend(
                     [
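
Note: the SCD Type 2 code previously expanded check columns only when given a bare exp.Star; the new guard also accepts a star wrapped in a list, which the inline comment attributes to serialization round-trips. A small standalone sqlglot sketch of the normalization (column names are placeholders):

    from sqlglot import exp
    from sqlglot.helper import ensure_list, seq_get

    unmanaged_columns = ["id", "name"]  # stand-in for unmanaged_columns_to_types

    for check_columns in (exp.Star(), [exp.Star()]):
        # ensure_list wraps the bare Star; seq_get peeks at the first element,
        # so both spellings expand to the same explicit column list.
        if isinstance(seq_get(ensure_list(check_columns), 0), exp.Star):
            check_columns = [exp.column(col) for col in unmanaged_columns]
        print(check_columns)  # the same two Column expressions in both iterations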
sqlmesh/core/engine_adapter/fabric.py CHANGED
@@ -7,7 +7,6 @@ import time
 from functools import cached_property
 from sqlglot import exp
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_result
-from sqlmesh.core.engine_adapter.mixins import LogicalMergeMixin
 from sqlmesh.core.engine_adapter.mssql import MSSQLEngineAdapter
 from sqlmesh.core.engine_adapter.shared import (
     InsertOverwriteStrategy,
@@ -19,7 +18,7 @@ from sqlmesh.utils.connection_pool import ConnectionPool
 logger = logging.getLogger(__name__)


-class FabricEngineAdapter(LogicalMergeMixin, MSSQLEngineAdapter):
+class FabricEngineAdapter(MSSQLEngineAdapter):
     """
     Adapter for Microsoft Fabric.
     """
sqlmesh/core/engine_adapter/mssql.py CHANGED
@@ -56,6 +56,7 @@ class MSSQLEngineAdapter(
     COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED
     COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED
     SUPPORTS_REPLACE_TABLE = False
+    MAX_IDENTIFIER_LENGTH = 128
     SUPPORTS_QUERY_EXECUTION_TRACKING = True
     SCHEMA_DIFFER_KWARGS = {
         "parameterized_type_defaults": {
@@ -422,7 +423,9 @@ class MSSQLEngineAdapter(
         insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None,
         **kwargs: t.Any,
     ) -> None:
-        if not where or where == exp.true():
+        # note that this is passed as table_properties here rather than physical_properties
+        use_merge_strategy = kwargs.get("table_properties", {}).get("mssql_merge_exists")
+        if (not where or where == exp.true()) and not use_merge_strategy:
             # this is a full table replacement, call the base strategy to do DELETE+INSERT
             # which will result in TRUNCATE+INSERT due to how we have overridden self.delete_from()
             return EngineAdapter._insert_overwrite_by_condition(
@@ -435,7 +438,7 @@ class MSSQLEngineAdapter(
             **kwargs,
         )

-        # For actual conditional overwrites, use MERGE from InsertOverwriteWithMergeMixin
+        # For conditional overwrites or when mssql_merge_exists is set use MERGE
         return super()._insert_overwrite_by_condition(
             table_name=table_name,
             source_queries=source_queries,
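
Note: per the hunk's comment, the opt-in flag reaches the adapter in kwargs as table_properties even though models declare it under physical_properties. A condensed sketch of the new routing decision (simplified; only the flag name comes from the hunk):

    from sqlglot import exp

    def insert_overwrite_strategy(where, table_properties):
        use_merge = table_properties.get("mssql_merge_exists")
        if (not where or where == exp.true()) and not use_merge:
            # full-table replacement: base DELETE+INSERT path, which the
            # delete_from() override turns into TRUNCATE+INSERT
            return "delete+insert"
        # conditional overwrite, or the model opted in via mssql_merge_exists
        return "merge"

    print(insert_overwrite_strategy(None, {}))                            # delete+insert
    print(insert_overwrite_strategy(None, {"mssql_merge_exists": True}))  # merge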
sqlmesh/core/engine_adapter/trino.py CHANGED
@@ -302,7 +302,7 @@ class TrinoEngineAdapter(
         execution_time: t.Union[TimeLike, exp.Column],
         invalidate_hard_deletes: bool = True,
         updated_at_col: t.Optional[exp.Column] = None,
-        check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Column]]] = None,
+        check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expression]]] = None,
         updated_at_as_valid_from: bool = False,
         execution_time_as_valid_from: bool = False,
         target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
sqlmesh/core/lineage.py CHANGED
@@ -66,6 +66,7 @@ def lineage(
         scope=scope,
         trim_selects=trim_selects,
         dialect=model.dialect,
+        copy=False,
     )

sqlmesh/core/linter/rules/builtin.py CHANGED
@@ -129,6 +129,21 @@ class NoMissingAudits(Rule):
         return self.violation()


+class NoMissingUnitTest(Rule):
+    """All models must have a unit test found in the test/ directory yaml files"""
+
+    def check_model(self, model: Model) -> t.Optional[RuleViolation]:
+        # External models cannot have unit tests
+        if isinstance(model, ExternalModel):
+            return None
+
+        if model.name not in self.context.models_with_tests:
+            return self.violation(
+                violation_msg=f"Model {model.name} is missing unit test(s). Please add in the tests/ directory."
+            )
+        return None
+
+
 class NoMissingExternalModels(Rule):
     """All external models must be registered in the external_models.yaml file"""
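
Note: the new rule only consults Context.models_with_tests, which the loader now populates from the project's test yaml files, and external models are exempt. A toy illustration of the gating (model names are placeholders):

    # populated during load from the tests/ directory yaml files
    models_with_tests = {"db.orders"}

    def violates(model_name: str, is_external: bool) -> bool:
        # mirrors NoMissingUnitTest.check_model: external models are skipped
        return not is_external and model_name not in models_with_tests

    print(violates("db.orders", False))     # False: covered by a unit test
    print(violates("db.customers", False))  # True: would be flagged
    print(violates("raw.events", True))     # False: external models are exempt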
 
sqlmesh/core/loader.py CHANGED
@@ -64,6 +64,7 @@ class LoadedProject:
     excluded_requirements: t.Set[str]
     environment_statements: t.List[EnvironmentStatements]
     user_rules: RuleSet
+    model_test_metadata: t.List[ModelTestMetadata]


 class CacheBase(abc.ABC):
@@ -243,6 +244,8 @@ class Loader(abc.ABC):

         user_rules = self._load_linting_rules()

+        model_test_metadata = self.load_model_tests()
+
         project = LoadedProject(
             macros=macros,
             jinja_macros=jinja_macros,
@@ -254,6 +257,7 @@ class Loader(abc.ABC):
             excluded_requirements=excluded_requirements,
             environment_statements=environment_statements,
             user_rules=user_rules,
+            model_test_metadata=model_test_metadata,
         )
         return project

sqlmesh/core/model/kind.py CHANGED
@@ -23,7 +23,7 @@ from sqlmesh.utils.pydantic import (
     PydanticModel,
     SQLGlotBool,
     SQLGlotColumn,
-    SQLGlotListOfColumnsOrStar,
+    SQLGlotListOfFieldsOrStar,
     SQLGlotListOfFields,
     SQLGlotPositiveInt,
     SQLGlotString,
@@ -852,7 +852,7 @@ class SCDType2ByTimeKind(_SCDType2Kind):

 class SCDType2ByColumnKind(_SCDType2Kind):
     name: t.Literal[ModelKindName.SCD_TYPE_2_BY_COLUMN] = ModelKindName.SCD_TYPE_2_BY_COLUMN
-    columns: SQLGlotListOfColumnsOrStar
+    columns: SQLGlotListOfFieldsOrStar
     execution_time_as_valid_from: SQLGlotBool = False
     updated_at_name: t.Optional[SQLGlotColumn] = None
sqlmesh/core/plan/definition.py CHANGED
@@ -63,7 +63,7 @@ class Plan(PydanticModel, frozen=True):
     restatements: t.Dict[SnapshotId, Interval]
     """
     All models being restated, which are typically the explicitly selected ones + their downstream dependencies.
-
+
     Note that dev previews are also considered restatements, so :selected_models_to_restate can be empty
     while :restatements is still populated with dev previews
     """
@@ -213,8 +213,8 @@ class Plan(PydanticModel, frozen=True):

         snapshots_by_name = self.context_diff.snapshots_by_name
         snapshots = [s.table_info for s in self.snapshots.values()]
-        promoted_snapshot_ids = None
-        if self.is_dev and not self.include_unmodified:
+        promotable_snapshot_ids = None
+        if self.is_dev:
             if self.selected_models_to_backfill is not None:
                 # Only promote models that have been explicitly selected for backfill.
                 promotable_snapshot_ids = {
@@ -225,12 +225,14 @@ class Plan(PydanticModel, frozen=True):
                     if m in snapshots_by_name
                 ],
             }
-            else:
+            elif not self.include_unmodified:
                 promotable_snapshot_ids = self.context_diff.promotable_snapshot_ids.copy()

-        promoted_snapshot_ids = [
-            s.snapshot_id for s in snapshots if s.snapshot_id in promotable_snapshot_ids
-        ]
+        promoted_snapshot_ids = (
+            [s.snapshot_id for s in snapshots if s.snapshot_id in promotable_snapshot_ids]
+            if promotable_snapshot_ids is not None
+            else None
+        )

         previous_finalized_snapshots = (
             self.context_diff.environment_snapshots
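
Note: the rewrite changes two behaviors: explicit backfill selection now constrains promotion for every dev plan (previously that branch was skipped whenever include_unmodified was set), and promoted_snapshot_ids stays None when nothing restricts promotion. A condensed sketch of the decision (names mirror the hunk; simplified):

    def promotable_ids(is_dev, selected_to_backfill, include_unmodified, diff_promotable, selected_ids):
        if not is_dev:
            return None  # no restriction computed outside dev plans
        if selected_to_backfill is not None:
            return selected_ids  # only explicitly selected models get promoted
        if not include_unmodified:
            return diff_promotable  # default dev behavior: modified models only
        return None  # dev plan including unmodified models: no restriction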
sqlmesh/core/renderer.py CHANGED
@@ -196,7 +196,14 @@ class BaseExpressionRenderer:
             **kwargs,
         }

+        if this_model:
+            render_kwargs["this_model"] = this_model
+
+        macro_evaluator.locals.update(render_kwargs)
+
         variables = kwargs.pop("variables", {})
+        if variables:
+            macro_evaluator.locals.setdefault(c.SQLMESH_VARS, {}).update(variables)

         expressions = [self._expression]
         if isinstance(self._expression, d.Jinja):
@@ -268,14 +275,6 @@ class BaseExpressionRenderer:
                     f"Could not parse the rendered jinja at '{self._path}'.\n{ex}"
                 ) from ex

-        if this_model:
-            render_kwargs["this_model"] = this_model
-
-        macro_evaluator.locals.update(render_kwargs)
-
-        if variables:
-            macro_evaluator.locals.setdefault(c.SQLMESH_VARS, {}).update(variables)
-
         for definition in self._macro_definitions:
             try:
                 macro_evaluator.evaluate(definition)