sqlmesh 0.225.1.dev26__py3-none-any.whl → 0.228.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of sqlmesh might be problematic.

Files changed (34)
  1. sqlmesh/_version.py +2 -2
  2. sqlmesh/core/config/connection.py +37 -1
  3. sqlmesh/core/context.py +60 -10
  4. sqlmesh/core/dialect.py +10 -2
  5. sqlmesh/core/engine_adapter/base.py +8 -1
  6. sqlmesh/core/engine_adapter/databricks.py +33 -16
  7. sqlmesh/core/engine_adapter/fabric.py +110 -2
  8. sqlmesh/core/engine_adapter/trino.py +44 -6
  9. sqlmesh/core/lineage.py +1 -0
  10. sqlmesh/core/linter/rules/builtin.py +15 -0
  11. sqlmesh/core/loader.py +17 -30
  12. sqlmesh/core/model/definition.py +9 -0
  13. sqlmesh/core/plan/definition.py +9 -7
  14. sqlmesh/core/renderer.py +7 -8
  15. sqlmesh/core/scheduler.py +45 -15
  16. sqlmesh/core/signal.py +35 -14
  17. sqlmesh/core/snapshot/definition.py +18 -12
  18. sqlmesh/core/snapshot/evaluator.py +24 -16
  19. sqlmesh/core/test/definition.py +5 -5
  20. sqlmesh/core/test/discovery.py +4 -0
  21. sqlmesh/dbt/common.py +4 -2
  22. sqlmesh/dbt/manifest.py +3 -1
  23. sqlmesh/integrations/github/cicd/command.py +11 -2
  24. sqlmesh/integrations/github/cicd/controller.py +6 -2
  25. sqlmesh/lsp/context.py +4 -2
  26. sqlmesh/magics.py +1 -1
  27. sqlmesh/utils/date.py +1 -1
  28. sqlmesh/utils/git.py +3 -1
  29. {sqlmesh-0.225.1.dev26.dist-info → sqlmesh-0.228.2.dist-info}/METADATA +3 -3
  30. {sqlmesh-0.225.1.dev26.dist-info → sqlmesh-0.228.2.dist-info}/RECORD +34 -34
  31. {sqlmesh-0.225.1.dev26.dist-info → sqlmesh-0.228.2.dist-info}/WHEEL +0 -0
  32. {sqlmesh-0.225.1.dev26.dist-info → sqlmesh-0.228.2.dist-info}/entry_points.txt +0 -0
  33. {sqlmesh-0.225.1.dev26.dist-info → sqlmesh-0.228.2.dist-info}/licenses/LICENSE +0 -0
  34. {sqlmesh-0.225.1.dev26.dist-info → sqlmesh-0.228.2.dist-info}/top_level.txt +0 -0
sqlmesh/core/loader.py CHANGED
@@ -35,7 +35,7 @@ from sqlmesh.core.model import (
 from sqlmesh.core.model import model as model_registry
 from sqlmesh.core.model.common import make_python_env
 from sqlmesh.core.signal import signal
-from sqlmesh.core.test import ModelTestMetadata, filter_tests_by_patterns
+from sqlmesh.core.test import ModelTestMetadata
 from sqlmesh.utils import UniqueKeyDict, sys_path
 from sqlmesh.utils.errors import ConfigError
 from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor
@@ -64,6 +64,7 @@ class LoadedProject:
     excluded_requirements: t.Set[str]
     environment_statements: t.List[EnvironmentStatements]
     user_rules: RuleSet
+    model_test_metadata: t.List[ModelTestMetadata]


 class CacheBase(abc.ABC):
@@ -243,6 +244,8 @@ class Loader(abc.ABC):

         user_rules = self._load_linting_rules()

+        model_test_metadata = self.load_model_tests()
+
         project = LoadedProject(
             macros=macros,
             jinja_macros=jinja_macros,
@@ -254,6 +257,7 @@ class Loader(abc.ABC):
             excluded_requirements=excluded_requirements,
             environment_statements=environment_statements,
             user_rules=user_rules,
+            model_test_metadata=model_test_metadata,
         )
         return project

@@ -423,9 +427,7 @@ class Loader(abc.ABC):
         """Loads user linting rules"""
         return RuleSet()

-    def load_model_tests(
-        self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
-    ) -> t.List[ModelTestMetadata]:
+    def load_model_tests(self) -> t.List[ModelTestMetadata]:
         """Loads YAML-based model tests"""
         return []

@@ -864,38 +866,23 @@ class SqlMeshLoader(Loader):

         return model_test_metadata

-    def load_model_tests(
-        self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
-    ) -> t.List[ModelTestMetadata]:
+    def load_model_tests(self) -> t.List[ModelTestMetadata]:
         """Loads YAML-based model tests"""
         test_meta_list: t.List[ModelTestMetadata] = []

-        if tests:
-            for test in tests:
-                filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "")
-
-                test_meta = self._load_model_test_file(Path(filename))
-                if test_name:
-                    test_meta_list.append(test_meta[test_name])
-                else:
-                    test_meta_list.extend(test_meta.values())
-        else:
-            search_path = Path(self.config_path) / c.TESTS
+        search_path = Path(self.config_path) / c.TESTS

-            for yaml_file in itertools.chain(
-                search_path.glob("**/test*.yaml"),
-                search_path.glob("**/test*.yml"),
+        for yaml_file in itertools.chain(
+            search_path.glob("**/test*.yaml"),
+            search_path.glob("**/test*.yml"),
+        ):
+            if any(
+                yaml_file.match(ignore_pattern)
+                for ignore_pattern in self.config.ignore_patterns or []
             ):
-                if any(
-                    yaml_file.match(ignore_pattern)
-                    for ignore_pattern in self.config.ignore_patterns or []
-                ):
-                    continue
-
-                test_meta_list.extend(self._load_model_test_file(yaml_file).values())
+                continue

-        if patterns:
-            test_meta_list = filter_tests_by_patterns(test_meta_list, patterns)
+            test_meta_list.extend(self._load_model_test_file(yaml_file).values())

         return test_meta_list

sqlmesh/core/model/definition.py CHANGED
@@ -34,6 +34,7 @@ from sqlmesh.core.model.common import (
 )
 from sqlmesh.core.model.meta import ModelMeta
 from sqlmesh.core.model.kind import (
+    ExternalKind,
     ModelKindName,
     SeedKind,
     ModelKind,
@@ -1035,6 +1036,13 @@ class _Model(ModelMeta, frozen=True):
             # Will raise if the custom materialization points to an invalid class
             get_custom_materialization_type_or_raise(self.kind.materialization)

+        # Embedded model kind shouldn't have audits
+        if self.kind.name == ModelKindName.EMBEDDED and self.audits:
+            raise_config_error(
+                "Audits are not supported for embedded models",
+                self._path,
+            )
+
     def is_breaking_change(self, previous: Model) -> t.Optional[bool]:
         """Determines whether this model is a breaking change in relation to the `previous` model.

@@ -1962,6 +1970,7 @@ class PythonModel(_Model):
 class ExternalModel(_Model):
     """The model definition which represents an external source/table."""

+    kind: ModelKind = ExternalKind()
     source_type: t.Literal["external"] = "external"

     def is_breaking_change(self, previous: Model) -> t.Optional[bool]:
sqlmesh/core/plan/definition.py CHANGED
@@ -63,7 +63,7 @@ class Plan(PydanticModel, frozen=True):
     restatements: t.Dict[SnapshotId, Interval]
     """
     All models being restated, which are typically the explicitly selected ones + their downstream dependencies.
-
+
     Note that dev previews are also considered restatements, so :selected_models_to_restate can be empty
     while :restatements is still populated with dev previews
     """
@@ -213,8 +213,8 @@

         snapshots_by_name = self.context_diff.snapshots_by_name
         snapshots = [s.table_info for s in self.snapshots.values()]
-        promoted_snapshot_ids = None
-        if self.is_dev and not self.include_unmodified:
+        promotable_snapshot_ids = None
+        if self.is_dev:
             if self.selected_models_to_backfill is not None:
                 # Only promote models that have been explicitly selected for backfill.
                 promotable_snapshot_ids = {
@@ -225,12 +225,14 @@
                     if m in snapshots_by_name
                 ],
             }
-        else:
+        elif not self.include_unmodified:
             promotable_snapshot_ids = self.context_diff.promotable_snapshot_ids.copy()

-        promoted_snapshot_ids = [
-            s.snapshot_id for s in snapshots if s.snapshot_id in promotable_snapshot_ids
-        ]
+        promoted_snapshot_ids = (
+            [s.snapshot_id for s in snapshots if s.snapshot_id in promotable_snapshot_ids]
+            if promotable_snapshot_ids is not None
+            else None
+        )

         previous_finalized_snapshots = (
             self.context_diff.environment_snapshots
sqlmesh/core/renderer.py CHANGED
@@ -196,7 +196,14 @@ class BaseExpressionRenderer:
             **kwargs,
         }

+        if this_model:
+            render_kwargs["this_model"] = this_model
+
+        macro_evaluator.locals.update(render_kwargs)
+
         variables = kwargs.pop("variables", {})
+        if variables:
+            macro_evaluator.locals.setdefault(c.SQLMESH_VARS, {}).update(variables)

         expressions = [self._expression]
         if isinstance(self._expression, d.Jinja):
@@ -268,14 +275,6 @@ class BaseExpressionRenderer:
                     f"Could not parse the rendered jinja at '{self._path}'.\n{ex}"
                 ) from ex

-        if this_model:
-            render_kwargs["this_model"] = this_model
-
-        macro_evaluator.locals.update(render_kwargs)
-
-        if variables:
-            macro_evaluator.locals.setdefault(c.SQLMESH_VARS, {}).update(variables)
-
         for definition in self._macro_definitions:
             try:
                 macro_evaluator.evaluate(definition)
sqlmesh/core/scheduler.py CHANGED
@@ -352,7 +352,7 @@ class Scheduler:
             )
             for snapshot, intervals in merged_intervals.items()
         }
-        snapshot_batches = {}
+        snapshot_batches: t.Dict[Snapshot, Intervals] = {}
        all_unready_intervals: t.Dict[str, set[Interval]] = {}
        for snapshot_id in dag:
            if snapshot_id not in snapshot_intervals:
@@ -364,6 +364,14 @@

            adapter = self.snapshot_evaluator.get_adapter(snapshot.model_gateway)

+            parent_intervals: Intervals = []
+            for parent_id in snapshot.parents:
+                parent_snapshot, _ = snapshot_intervals.get(parent_id, (None, []))
+                if not parent_snapshot or parent_snapshot.is_external:
+                    continue
+
+                parent_intervals.extend(snapshot_batches[parent_snapshot])
+
            context = ExecutionContext(
                adapter,
                self.snapshots_by_name,
@@ -371,6 +379,7 @@
                default_dialect=adapter.dialect,
                default_catalog=self.default_catalog,
                is_restatement=is_restatement,
+                parent_intervals=parent_intervals,
            )

            intervals = self._check_ready_intervals(
@@ -538,6 +547,10 @@
                execution_time=execution_time,
            )
        else:
+            # If batch_index > 0, then the target table must exist since the first batch would have created it
+            target_table_exists = (
+                snapshot.snapshot_id not in snapshots_to_create or node.batch_index > 0
+            )
            audit_results = self.evaluate(
                snapshot=snapshot,
                environment_naming_info=environment_naming_info,
@@ -548,7 +561,7 @@
                batch_index=node.batch_index,
                allow_destructive_snapshots=allow_destructive_snapshots,
                allow_additive_snapshots=allow_additive_snapshots,
-                target_table_exists=snapshot.snapshot_id not in snapshots_to_create,
+                target_table_exists=target_table_exists,
                selected_models=selected_models,
            )

@@ -646,6 +659,7 @@
        }
        snapshots_to_create = snapshots_to_create or set()
        original_snapshots_to_create = snapshots_to_create.copy()
+        upstream_dependencies_cache: t.Dict[SnapshotId, t.Set[SchedulingUnit]] = {}

        snapshot_dag = snapshot_dag or snapshots_to_dag(batches)
        dag = DAG[SchedulingUnit]()
@@ -657,12 +671,15 @@
            snapshot = self.snapshots_by_name[snapshot_id.name]
            intervals = intervals_per_snapshot.get(snapshot.name, [])

-            upstream_dependencies: t.List[SchedulingUnit] = []
+            upstream_dependencies: t.Set[SchedulingUnit] = set()

            for p_sid in snapshot.parents:
-                upstream_dependencies.extend(
+                upstream_dependencies.update(
                    self._find_upstream_dependencies(
-                        p_sid, intervals_per_snapshot, original_snapshots_to_create
+                        p_sid,
+                        intervals_per_snapshot,
+                        original_snapshots_to_create,
+                        upstream_dependencies_cache,
                    )
                )

@@ -713,29 +730,42 @@
        parent_sid: SnapshotId,
        intervals_per_snapshot: t.Dict[str, Intervals],
        snapshots_to_create: t.Set[SnapshotId],
-    ) -> t.List[SchedulingUnit]:
+        cache: t.Dict[SnapshotId, t.Set[SchedulingUnit]],
+    ) -> t.Set[SchedulingUnit]:
        if parent_sid not in self.snapshots:
-            return []
+            return set()
+        if parent_sid in cache:
+            return cache[parent_sid]

        p_intervals = intervals_per_snapshot.get(parent_sid.name, [])

+        parent_node: t.Optional[SchedulingUnit] = None
        if p_intervals:
            if len(p_intervals) > 1:
-                return [DummyNode(snapshot_name=parent_sid.name)]
-            interval = p_intervals[0]
-            return [EvaluateNode(snapshot_name=parent_sid.name, interval=interval, batch_index=0)]
-        if parent_sid in snapshots_to_create:
-            return [CreateNode(snapshot_name=parent_sid.name)]
+                parent_node = DummyNode(snapshot_name=parent_sid.name)
+            else:
+                interval = p_intervals[0]
+                parent_node = EvaluateNode(
+                    snapshot_name=parent_sid.name, interval=interval, batch_index=0
+                )
+        elif parent_sid in snapshots_to_create:
+            parent_node = CreateNode(snapshot_name=parent_sid.name)
+
+        if parent_node is not None:
+            cache[parent_sid] = {parent_node}
+            return {parent_node}
+
        # This snapshot has no intervals and doesn't need creation which means
        # that it can be a transitive dependency
-        transitive_deps: t.List[SchedulingUnit] = []
+        transitive_deps: t.Set[SchedulingUnit] = set()
        parent_snapshot = self.snapshots[parent_sid]
        for grandparent_sid in parent_snapshot.parents:
-            transitive_deps.extend(
+            transitive_deps.update(
                self._find_upstream_dependencies(
-                    grandparent_sid, intervals_per_snapshot, snapshots_to_create
+                    grandparent_sid, intervals_per_snapshot, snapshots_to_create, cache
                )
            )
+        cache[parent_sid] = transitive_deps
        return transitive_deps

    def _run_or_audit(
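The `_find_upstream_dependencies` changes above thread a `cache` dictionary through the recursion so that shared ancestors in the snapshot DAG are traversed only once. A stripped-down sketch of that memoization pattern, using generic names rather than the scheduler's actual `SchedulingUnit` types:

    import typing as t


    def find_deps(
        node: str,
        parents: t.Dict[str, t.List[str]],
        cache: t.Dict[str, t.Set[str]],
    ) -> t.Set[str]:
        # Reuse the cached result if this node was already visited.
        if node in cache:
            return cache[node]

        deps: t.Set[str] = set()
        for parent in parents.get(node, []):
            deps.add(parent)
            deps.update(find_deps(parent, parents, cache))

        cache[node] = deps
        return deps


    parents = {"c": ["a", "b"], "b": ["a"], "a": []}
    print(find_deps("c", parents, {}))  # {'a', 'b'}
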
sqlmesh/core/signal.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations

 import typing as t
 from sqlmesh.utils import UniqueKeyDict, registry_decorator
+from sqlmesh.utils.errors import MissingSourceError

 if t.TYPE_CHECKING:
     from sqlmesh.core.context import ExecutionContext
@@ -42,7 +43,16 @@ SignalRegistry = UniqueKeyDict[str, signal]


 @signal()
-def freshness(batch: DatetimeRanges, snapshot: Snapshot, context: ExecutionContext) -> bool:
+def freshness(
+    batch: DatetimeRanges,
+    snapshot: Snapshot,
+    context: ExecutionContext,
+) -> bool:
+    """
+    Implements model freshness as a signal, i.e it considers this model to be fresh if:
+    - Any upstream SQLMesh model has available intervals to compute i.e is fresh
+    - Any upstream external model has been altered since the last time the model was evaluated
+    """
     adapter = context.engine_adapter
     if context.is_restatement or not adapter.SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS:
         return True
@@ -54,24 +64,35 @@ def freshness(batch: DatetimeRanges, snapshot: Snapshot, context: ExecutionContext) -> bool:
         if deployability_index.is_deployable(snapshot)
         else snapshot.dev_last_altered_ts
     )
+
     if not last_altered_ts:
         return True

     parent_snapshots = {context.snapshots[p.name] for p in snapshot.parents}
-    if len(parent_snapshots) != len(snapshot.node.depends_on) or not all(
-        p.is_external for p in parent_snapshots
-    ):
-        # The mismatch can happen if e.g an external model is not registered in the project
+
+    upstream_parent_snapshots = {p for p in parent_snapshots if not p.is_external}
+    external_parents = snapshot.node.depends_on - {p.name for p in upstream_parent_snapshots}
+
+    if context.parent_intervals:
+        # At least one upstream sqlmesh model has intervals to compute (i.e is fresh),
+        # so the current model is considered fresh too
         return True

-    # Finding new data means that the upstream depedencies have been altered
-    # since the last time the model was evaluated
-    upstream_dep_has_new_data = any(
-        upstream_last_altered_ts > last_altered_ts
-        for upstream_last_altered_ts in adapter.get_table_last_modified_ts(
-            [p.name for p in parent_snapshots]
+    if external_parents:
+        external_last_altered_timestamps = adapter.get_table_last_modified_ts(
+            list(external_parents)
+        )
+
+        if len(external_last_altered_timestamps) != len(external_parents):
+            raise MissingSourceError(
+                f"Expected {len(external_parents)} sources to be present, but got {len(external_last_altered_timestamps)}."
+            )
+
+        # Finding new data means that the upstream depedencies have been altered
+        # since the last time the model was evaluated
+        return any(
+            external_last_altered_ts > last_altered_ts
+            for external_last_altered_ts in external_last_altered_timestamps
        )
-    )

-    # Returning true is a no-op, returning False nullifies the batch so the model will not be evaluated.
-    return upstream_dep_has_new_data
+    return False
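Beyond the built-in `freshness` signal shown above, the same `@signal()` decorator registers user-defined signals (the loader diff earlier imports it via `from sqlmesh.core.signal import signal`). A minimal sketch with an illustrative name and a trivially permissive body:

    from sqlmesh.core.signal import signal


    @signal()
    def always_ready(batch):
        # batch holds the candidate datetime ranges for the evaluation.
        # Returning True keeps the batch; returning False nullifies it so the
        # model is not evaluated (the behavior described by the comment removed above).
        return True
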
sqlmesh/core/snapshot/definition.py CHANGED
@@ -2081,16 +2081,20 @@ def missing_intervals(
                 continue
             snapshot_end_date = existing_interval_end

+        snapshot_start_date = max(
+            to_datetime(snapshot_start_date),
+            to_datetime(start_date(snapshot, snapshots, cache, relative_to=snapshot_end_date)),
+        )
+        if snapshot_start_date > to_datetime(snapshot_end_date):
+            continue
+
         missing_interval_end_date = snapshot_end_date
         node_end_date = snapshot.node.end
         if node_end_date and (to_datetime(node_end_date) < to_datetime(snapshot_end_date)):
             missing_interval_end_date = node_end_date

         intervals = snapshot.missing_intervals(
-            max(
-                to_datetime(snapshot_start_date),
-                to_datetime(start_date(snapshot, snapshots, cache, relative_to=snapshot_end_date)),
-            ),
+            snapshot_start_date,
             missing_interval_end_date,
             execution_time=execution_time,
             deployability_index=deployability_index,
@@ -2295,14 +2299,16 @@ def start_date(
     if not isinstance(snapshots, dict):
         snapshots = {snapshot.snapshot_id: snapshot for snapshot in snapshots}

-    earliest = snapshot.node.cron_prev(snapshot.node.cron_floor(relative_to or now()))
-
-    for parent in snapshot.parents:
-        if parent in snapshots:
-            earliest = min(
-                earliest,
-                start_date(snapshots[parent], snapshots, cache=cache, relative_to=relative_to),
-            )
+    parent_starts = [
+        start_date(snapshots[parent], snapshots, cache=cache, relative_to=relative_to)
+        for parent in snapshot.parents
+        if parent in snapshots
+    ]
+    earliest = (
+        min(parent_starts)
+        if parent_starts
+        else snapshot.node.cron_prev(snapshot.node.cron_floor(relative_to or now()))
+    )

     cache[key] = earliest
     return earliest
sqlmesh/core/snapshot/evaluator.py CHANGED
@@ -1021,6 +1021,11 @@ class SnapshotEvaluator:
         ):
             import pandas as pd

+            try:
+                first_query_or_df = next(queries_or_dfs)
+            except StopIteration:
+                return
+
             query_or_df = reduce(
                 lambda a, b: (
                     pd.concat([a, b], ignore_index=True)  # type: ignore
@@ -1028,6 +1033,7 @@
                     else a.union_all(b)  # type: ignore
                 ),  # type: ignore
                 queries_or_dfs,
+                first_query_or_df,
             )
             apply(query_or_df, index=0)
         else:
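The `next()`/`StopIteration` guard added in this hunk exists because `functools.reduce` raises a `TypeError` when it is handed an empty iterator and no initial value; pulling the first item out and passing it as the seed both avoids the error and allows an early return. A standalone illustration of the underlying behavior (the names are illustrative):

    from functools import reduce

    empty_gen = iter([])  # stands in for a generator that yields no queries or DataFrames

    try:
        reduce(lambda a, b: a + b, empty_gen)
    except TypeError as ex:
        print(ex)  # reduce() of empty iterable with no initial value

    # With an explicit seed (the role played by first_query_or_df above), reduce is safe:
    print(reduce(lambda a, b: a + b, iter([]), 0))  # 0
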
@@ -1593,14 +1599,14 @@
         tables_by_gateway_and_schema: t.Dict[t.Union[str, None], t.Dict[exp.Table, set[str]]] = (
             defaultdict(lambda: defaultdict(set))
         )
-        snapshots_by_table_name: t.Dict[str, Snapshot] = {}
+        snapshots_by_table_name: t.Dict[exp.Table, t.Dict[str, Snapshot]] = defaultdict(dict)
         for snapshot in target_snapshots:
             if not snapshot.is_model or snapshot.is_symbolic:
                 continue
             table = table_name_callable(snapshot)
             table_schema = d.schema_(table.db, catalog=table.catalog)
             tables_by_gateway_and_schema[snapshot.model_gateway][table_schema].add(table.name)
-            snapshots_by_table_name[table.name] = snapshot
+            snapshots_by_table_name[table_schema][table.name] = snapshot

         def _get_data_objects_in_schema(
             schema: exp.Table,
@@ -1613,23 +1619,25 @@
             )

         with self.concurrent_context():
-            existing_objects: t.List[DataObject] = []
+            snapshot_id_to_obj: t.Dict[SnapshotId, DataObject] = {}
             # A schema can be shared across multiple engines, so we need to group tables by both gateway and schema
             for gateway, tables_by_schema in tables_by_gateway_and_schema.items():
-                objs_for_gateway = [
-                    obj
-                    for objs in concurrent_apply_to_values(
-                        list(tables_by_schema),
-                        lambda s: _get_data_objects_in_schema(
-                            schema=s, object_names=tables_by_schema.get(s), gateway=gateway
-                        ),
-                        self.ddl_concurrent_tasks,
-                    )
-                    for obj in objs
-                ]
-                existing_objects.extend(objs_for_gateway)
+                schema_list = list(tables_by_schema.keys())
+                results = concurrent_apply_to_values(
+                    schema_list,
+                    lambda s: _get_data_objects_in_schema(
+                        schema=s, object_names=tables_by_schema.get(s), gateway=gateway
+                    ),
+                    self.ddl_concurrent_tasks,
+                )
+
+                for schema, objs in zip(schema_list, results):
+                    snapshots_by_name = snapshots_by_table_name.get(schema, {})
+                    for obj in objs:
+                        if obj.name in snapshots_by_name:
+                            snapshot_id_to_obj[snapshots_by_name[obj.name].snapshot_id] = obj

-        return {snapshots_by_table_name[obj.name].snapshot_id: obj for obj in existing_objects}
+        return snapshot_id_to_obj


 def _evaluation_strategy(snapshot: SnapshotInfoLike, adapter: EngineAdapter) -> EvaluationStrategy:
sqlmesh/core/test/definition.py CHANGED
@@ -355,11 +355,12 @@ class ModelTest(unittest.TestCase):
                 for df in _split_df_by_column_pairs(diff)
             )
         else:
-            from pandas import MultiIndex
+            from pandas import DataFrame, MultiIndex

             levels = t.cast(MultiIndex, diff.columns).levels[0]
             for col in levels:
-                col_diff = diff[col]
+                # diff[col] returns a DataFrame when columns is a MultiIndex
+                col_diff = t.cast(DataFrame, diff[col])
                 if not col_diff.empty:
                     table = df_to_table(
                         f"[bold red]Column '{col}' mismatch{failed_subtest}[/bold red]",
@@ -807,7 +808,7 @@ class PythonModelTest(ModelTest):
         actual_df.reset_index(drop=True, inplace=True)
         expected = self._create_df(values, columns=self.model.columns_to_types, partial=partial)

-        self.assert_equal(expected, actual_df, sort=False, partial=partial)
+        self.assert_equal(expected, actual_df, sort=True, partial=partial)

     def _execute_model(self) -> pd.DataFrame:
         """Executes the python model and returns a DataFrame."""
@@ -925,8 +926,7 @@ def generate_test(
             cte_output = test._execute(cte_query)
             ctes[cte.alias] = (
                 pandas_timestamp_to_pydatetime(
-                    cte_output.apply(lambda col: col.map(_normalize_df_value)),
-                    cte_query.named_selects,
+                    df=cte_output.apply(lambda col: col.map(_normalize_df_value)),
                 )
                 .replace({np.nan: None})
                 .to_dict(orient="records")
sqlmesh/core/test/discovery.py CHANGED
@@ -20,6 +20,10 @@ class ModelTestMetadata(PydanticModel):
     def fully_qualified_test_name(self) -> str:
         return f"{self.path}::{self.test_name}"

+    @property
+    def model_name(self) -> str:
+        return self.body.get("model", "")
+
     def __hash__(self) -> int:
         return self.fully_qualified_test_name.__hash__()

sqlmesh/dbt/common.py CHANGED
@@ -46,7 +46,9 @@ def load_yaml(source: str | Path) -> t.Dict:
         raise ConfigError(f"{source}: {ex}" if isinstance(source, Path) else f"{ex}")


-def parse_meta(v: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
+def parse_meta(v: t.Optional[t.Dict[str, t.Any]]) -> t.Dict[str, t.Any]:
+    if v is None:
+        return {}
     for key, value in v.items():
         if isinstance(value, str):
             v[key] = try_str_to_bool(value)
@@ -115,7 +117,7 @@ class GeneralConfig(DbtConfig):

     @field_validator("meta", mode="before")
     @classmethod
-    def _validate_meta(cls, v: t.Dict[str, t.Union[str, t.Any]]) -> t.Dict[str, t.Any]:
+    def _validate_meta(cls, v: t.Optional[t.Dict[str, t.Union[str, t.Any]]]) -> t.Dict[str, t.Any]:
         return parse_meta(v)

     _FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = {
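The `None` guard added to `parse_meta` matters because a dbt properties file can declare a `meta:` key with no value, which YAML loads as `None` rather than an empty mapping. A standalone illustration, assuming PyYAML as the loader (the sample document is illustrative):

    import yaml

    doc = yaml.safe_load("meta:\n")  # a properties entry whose meta key is left empty
    print(doc)  # {'meta': None}

    # The previous signature assumed a dict, so calling .items() on the None value
    # would raise AttributeError; the updated parse_meta now returns {} instead.
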
sqlmesh/dbt/manifest.py CHANGED
@@ -11,7 +11,7 @@ from collections import defaultdict
 from functools import cached_property
 from pathlib import Path

-from dbt import constants as dbt_constants, flags
+from dbt import flags

 from sqlmesh.dbt.util import DBT_VERSION
 from sqlmesh.utils.conversions import make_serializable
@@ -19,6 +19,8 @@ from sqlmesh.utils.conversions import make_serializable
 # Override the file name to prevent dbt commands from invalidating the cache.

 if DBT_VERSION >= (1, 6, 0):
+    from dbt import constants as dbt_constants
+
     dbt_constants.PARTIAL_PARSE_FILE_NAME = "sqlmesh_partial_parse.msgpack"  # type: ignore
 else:
     from dbt.parser import manifest as dbt_manifest  # type: ignore
sqlmesh/integrations/github/cicd/command.py CHANGED
@@ -25,12 +25,21 @@ logger = logging.getLogger(__name__)
     envvar="GITHUB_TOKEN",
     help="The Github Token to be used. Pass in `${{ secrets.GITHUB_TOKEN }}` if you want to use the one created by Github actions",
 )
+@click.option(
+    "--full-logs",
+    is_flag=True,
+    help="Whether to print all logs in the Github Actions output or only in their relevant GA check",
+)
 @click.pass_context
-def github(ctx: click.Context, token: str) -> None:
+def github(ctx: click.Context, token: str, full_logs: bool = False) -> None:
     """Github Action CI/CD Bot. See https://sqlmesh.readthedocs.io/en/stable/integrations/github/ for details"""
     # set a larger width because if none is specified, it auto-detects 80 characters when running in GitHub Actions
     # which can result in surprise newlines when outputting dates to backfill
-    set_console(MarkdownConsole(width=1000, warning_capture_only=True, error_capture_only=True))
+    set_console(
+        MarkdownConsole(
+            width=1000, warning_capture_only=not full_logs, error_capture_only=not full_logs
+        )
+    )
     ctx.obj["github"] = GithubController(
         paths=ctx.obj["paths"],
         token=token,
sqlmesh/integrations/github/cicd/controller.py CHANGED
@@ -448,10 +448,9 @@ class GithubController:
             c.PROD,
             # this is required to highlight any data gaps between this PR environment and prod (since PR environments may only contain a subset of data)
             no_gaps=False,
-            # this works because the snapshots were already categorized when applying self.pr_plan so there are no uncategorized local snapshots to trigger a plan error
-            no_auto_categorization=True,
             skip_tests=True,
             skip_linter=True,
+            categorizer_config=self.bot_config.auto_categorize_changes,
             run=self.bot_config.run_on_deploy_to_prod,
             forward_only=self.forward_only_plan,
         )
@@ -773,6 +772,11 @@
                 "PR is already merged and this event was triggered prior to the merge."
             )
         merge_status = self._get_merge_state_status()
+        if merge_status.is_blocked:
+            raise CICDBotError(
+                "Branch protection or ruleset requirement is likely not satisfied, e.g. missing CODEOWNERS approval. "
+                "Please check PR and resolve any issues."
+            )
         if merge_status.is_dirty:
             raise CICDBotError(
                 "Merge commit cannot be cleanly created. Likely from a merge conflict. "