sqlmesh 0.225.0__py3-none-any.whl → 0.227.2.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (43)
  1. sqlmesh/__init__.py +10 -2
  2. sqlmesh/_version.py +2 -2
  3. sqlmesh/core/config/connection.py +10 -5
  4. sqlmesh/core/config/loader.py +1 -0
  5. sqlmesh/core/context.py +73 -8
  6. sqlmesh/core/engine_adapter/base.py +12 -7
  7. sqlmesh/core/engine_adapter/fabric.py +1 -2
  8. sqlmesh/core/engine_adapter/mssql.py +5 -2
  9. sqlmesh/core/engine_adapter/trino.py +1 -1
  10. sqlmesh/core/lineage.py +1 -0
  11. sqlmesh/core/linter/rules/builtin.py +15 -0
  12. sqlmesh/core/loader.py +4 -0
  13. sqlmesh/core/model/kind.py +2 -2
  14. sqlmesh/core/plan/definition.py +9 -7
  15. sqlmesh/core/renderer.py +7 -8
  16. sqlmesh/core/scheduler.py +45 -15
  17. sqlmesh/core/signal.py +35 -14
  18. sqlmesh/core/snapshot/definition.py +18 -12
  19. sqlmesh/core/snapshot/evaluator.py +31 -17
  20. sqlmesh/core/state_sync/db/snapshot.py +6 -1
  21. sqlmesh/core/table_diff.py +2 -2
  22. sqlmesh/core/test/definition.py +5 -3
  23. sqlmesh/core/test/discovery.py +4 -0
  24. sqlmesh/dbt/builtin.py +9 -11
  25. sqlmesh/dbt/column.py +17 -5
  26. sqlmesh/dbt/common.py +4 -2
  27. sqlmesh/dbt/context.py +2 -0
  28. sqlmesh/dbt/loader.py +15 -2
  29. sqlmesh/dbt/manifest.py +3 -1
  30. sqlmesh/dbt/model.py +13 -1
  31. sqlmesh/dbt/profile.py +3 -3
  32. sqlmesh/dbt/target.py +9 -4
  33. sqlmesh/utils/date.py +1 -1
  34. sqlmesh/utils/pydantic.py +6 -6
  35. sqlmesh/utils/windows.py +13 -3
  36. {sqlmesh-0.225.0.dist-info → sqlmesh-0.227.2.dev6.dist-info}/METADATA +2 -2
  37. {sqlmesh-0.225.0.dist-info → sqlmesh-0.227.2.dev6.dist-info}/RECORD +43 -43
  38. sqlmesh_dbt/cli.py +26 -1
  39. sqlmesh_dbt/operations.py +8 -2
  40. {sqlmesh-0.225.0.dist-info → sqlmesh-0.227.2.dev6.dist-info}/WHEEL +0 -0
  41. {sqlmesh-0.225.0.dist-info → sqlmesh-0.227.2.dev6.dist-info}/entry_points.txt +0 -0
  42. {sqlmesh-0.225.0.dist-info → sqlmesh-0.227.2.dev6.dist-info}/licenses/LICENSE +0 -0
  43. {sqlmesh-0.225.0.dist-info → sqlmesh-0.227.2.dev6.dist-info}/top_level.txt +0 -0
sqlmesh/core/scheduler.py CHANGED
@@ -352,7 +352,7 @@ class Scheduler:
  )
  for snapshot, intervals in merged_intervals.items()
  }
- snapshot_batches = {}
+ snapshot_batches: t.Dict[Snapshot, Intervals] = {}
  all_unready_intervals: t.Dict[str, set[Interval]] = {}
  for snapshot_id in dag:
  if snapshot_id not in snapshot_intervals:
@@ -364,6 +364,14 @@

  adapter = self.snapshot_evaluator.get_adapter(snapshot.model_gateway)

+ parent_intervals: Intervals = []
+ for parent_id in snapshot.parents:
+ parent_snapshot, _ = snapshot_intervals.get(parent_id, (None, []))
+ if not parent_snapshot or parent_snapshot.is_external:
+ continue
+
+ parent_intervals.extend(snapshot_batches[parent_snapshot])
+
  context = ExecutionContext(
  adapter,
  self.snapshots_by_name,
@@ -371,6 +379,7 @@
  default_dialect=adapter.dialect,
  default_catalog=self.default_catalog,
  is_restatement=is_restatement,
+ parent_intervals=parent_intervals,
  )

  intervals = self._check_ready_intervals(
@@ -538,6 +547,10 @@
  execution_time=execution_time,
  )
  else:
+ # If batch_index > 0, then the target table must exist since the first batch would have created it
+ target_table_exists = (
+ snapshot.snapshot_id not in snapshots_to_create or node.batch_index > 0
+ )
  audit_results = self.evaluate(
  snapshot=snapshot,
  environment_naming_info=environment_naming_info,
@@ -548,7 +561,7 @@
  batch_index=node.batch_index,
  allow_destructive_snapshots=allow_destructive_snapshots,
  allow_additive_snapshots=allow_additive_snapshots,
- target_table_exists=snapshot.snapshot_id not in snapshots_to_create,
+ target_table_exists=target_table_exists,
  selected_models=selected_models,
  )

@@ -646,6 +659,7 @@
  }
  snapshots_to_create = snapshots_to_create or set()
  original_snapshots_to_create = snapshots_to_create.copy()
+ upstream_dependencies_cache: t.Dict[SnapshotId, t.Set[SchedulingUnit]] = {}

  snapshot_dag = snapshot_dag or snapshots_to_dag(batches)
  dag = DAG[SchedulingUnit]()
@@ -657,12 +671,15 @@
  snapshot = self.snapshots_by_name[snapshot_id.name]
  intervals = intervals_per_snapshot.get(snapshot.name, [])

- upstream_dependencies: t.List[SchedulingUnit] = []
+ upstream_dependencies: t.Set[SchedulingUnit] = set()

  for p_sid in snapshot.parents:
- upstream_dependencies.extend(
+ upstream_dependencies.update(
  self._find_upstream_dependencies(
- p_sid, intervals_per_snapshot, original_snapshots_to_create
+ p_sid,
+ intervals_per_snapshot,
+ original_snapshots_to_create,
+ upstream_dependencies_cache,
  )
  )

@@ -713,29 +730,42 @@
  parent_sid: SnapshotId,
  intervals_per_snapshot: t.Dict[str, Intervals],
  snapshots_to_create: t.Set[SnapshotId],
- ) -> t.List[SchedulingUnit]:
+ cache: t.Dict[SnapshotId, t.Set[SchedulingUnit]],
+ ) -> t.Set[SchedulingUnit]:
  if parent_sid not in self.snapshots:
- return []
+ return set()
+ if parent_sid in cache:
+ return cache[parent_sid]

  p_intervals = intervals_per_snapshot.get(parent_sid.name, [])

+ parent_node: t.Optional[SchedulingUnit] = None
  if p_intervals:
  if len(p_intervals) > 1:
- return [DummyNode(snapshot_name=parent_sid.name)]
- interval = p_intervals[0]
- return [EvaluateNode(snapshot_name=parent_sid.name, interval=interval, batch_index=0)]
- if parent_sid in snapshots_to_create:
- return [CreateNode(snapshot_name=parent_sid.name)]
+ parent_node = DummyNode(snapshot_name=parent_sid.name)
+ else:
+ interval = p_intervals[0]
+ parent_node = EvaluateNode(
+ snapshot_name=parent_sid.name, interval=interval, batch_index=0
+ )
+ elif parent_sid in snapshots_to_create:
+ parent_node = CreateNode(snapshot_name=parent_sid.name)
+
+ if parent_node is not None:
+ cache[parent_sid] = {parent_node}
+ return {parent_node}
+
  # This snapshot has no intervals and doesn't need creation which means
  # that it can be a transitive dependency
- transitive_deps: t.List[SchedulingUnit] = []
+ transitive_deps: t.Set[SchedulingUnit] = set()
  parent_snapshot = self.snapshots[parent_sid]
  for grandparent_sid in parent_snapshot.parents:
- transitive_deps.extend(
+ transitive_deps.update(
  self._find_upstream_dependencies(
- grandparent_sid, intervals_per_snapshot, snapshots_to_create
+ grandparent_sid, intervals_per_snapshot, snapshots_to_create, cache
  )
  )
+ cache[parent_sid] = transitive_deps
  return transitive_deps

  def _run_or_audit(
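
Note: the `_find_upstream_dependencies` rewrite above switches from list accumulation to sets and memoizes each parent's result, so shared ancestors are resolved once instead of once per path. A minimal standalone sketch of that memoized traversal, using a plain dict-based graph rather than the sqlmesh snapshot types (all names here are hypothetical):

    import typing as t

    # Hypothetical dependency graph: node -> its parents (stand-in for snapshot.parents).
    PARENTS: t.Dict[str, t.List[str]] = {"c": ["b"], "b": ["a"], "a": []}

    def find_upstream(
        node: str,
        has_work: t.Set[str],
        cache: t.Dict[str, t.Set[str]],
    ) -> t.Set[str]:
        # Memoization: each node is resolved once, so widely shared ancestors
        # no longer blow up the traversal on large DAGs.
        if node in cache:
            return cache[node]
        if node in has_work:
            result = {node}
        else:
            # Transitive case: this node contributes nothing itself, so gather
            # whatever its own parents contribute.
            result = set()
            for parent in PARENTS.get(node, []):
                result |= find_upstream(parent, has_work, cache)
        cache[node] = result
        return result

    print(find_upstream("c", has_work={"a"}, cache={}))  # {'a'}
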
sqlmesh/core/signal.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations

  import typing as t
  from sqlmesh.utils import UniqueKeyDict, registry_decorator
+ from sqlmesh.utils.errors import MissingSourceError

  if t.TYPE_CHECKING:
  from sqlmesh.core.context import ExecutionContext
@@ -42,7 +43,16 @@ SignalRegistry = UniqueKeyDict[str, signal]


  @signal()
- def freshness(batch: DatetimeRanges, snapshot: Snapshot, context: ExecutionContext) -> bool:
+ def freshness(
+ batch: DatetimeRanges,
+ snapshot: Snapshot,
+ context: ExecutionContext,
+ ) -> bool:
+ """
+ Implements model freshness as a signal, i.e it considers this model to be fresh if:
+ - Any upstream SQLMesh model has available intervals to compute i.e is fresh
+ - Any upstream external model has been altered since the last time the model was evaluated
+ """
  adapter = context.engine_adapter
  if context.is_restatement or not adapter.SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS:
  return True
@@ -54,24 +64,35 @@ def freshness(batch: DatetimeRanges, snapshot: Snapshot, context: ExecutionConte
  if deployability_index.is_deployable(snapshot)
  else snapshot.dev_last_altered_ts
  )
+
  if not last_altered_ts:
  return True

  parent_snapshots = {context.snapshots[p.name] for p in snapshot.parents}
- if len(parent_snapshots) != len(snapshot.node.depends_on) or not all(
- p.is_external for p in parent_snapshots
- ):
- # The mismatch can happen if e.g an external model is not registered in the project
+
+ upstream_parent_snapshots = {p for p in parent_snapshots if not p.is_external}
+ external_parents = snapshot.node.depends_on - {p.name for p in upstream_parent_snapshots}
+
+ if context.parent_intervals:
+ # At least one upstream sqlmesh model has intervals to compute (i.e is fresh),
+ # so the current model is considered fresh too
  return True

- # Finding new data means that the upstream depedencies have been altered
- # since the last time the model was evaluated
- upstream_dep_has_new_data = any(
- upstream_last_altered_ts > last_altered_ts
- for upstream_last_altered_ts in adapter.get_table_last_modified_ts(
- [p.name for p in parent_snapshots]
+ if external_parents:
+ external_last_altered_timestamps = adapter.get_table_last_modified_ts(
+ list(external_parents)
+ )
+
+ if len(external_last_altered_timestamps) != len(external_parents):
+ raise MissingSourceError(
+ f"Expected {len(external_parents)} sources to be present, but got {len(external_last_altered_timestamps)}."
+ )
+
+ # Finding new data means that the upstream depedencies have been altered
+ # since the last time the model was evaluated
+ return any(
+ external_last_altered_ts > last_altered_ts
+ for external_last_altered_ts in external_last_altered_timestamps
  )
- )

- # Returning true is a no-op, returning False nullifies the batch so the model will not be evaluated.
- return upstream_dep_has_new_data
+ return False
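
Note: stripped of the sqlmesh types, the new freshness rule is: fresh if any upstream SQLMesh parent has intervals queued for this run, otherwise fresh only when some external parent was modified after the model's own last-altered timestamp. A small sketch of that decision with plain epoch timestamps and hypothetical names:

    import typing as t

    def is_fresh(
        parent_intervals: t.List[t.Tuple[int, int]],
        external_last_altered: t.List[int],
        model_last_altered: t.Optional[int],
    ) -> bool:
        # Never evaluated before: let it run.
        if model_last_altered is None:
            return True
        # An upstream SQLMesh parent will compute new intervals this run,
        # so this model will have fresh inputs.
        if parent_intervals:
            return True
        # Otherwise only newer external data makes the model fresh.
        return any(ts > model_last_altered for ts in external_last_altered)

    print(is_fresh([], [1700000000], 1690000000))  # True: external table changed since last run
    print(is_fresh([], [1680000000], 1690000000))  # False: nothing new upstream
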
sqlmesh/core/snapshot/definition.py CHANGED
@@ -2081,16 +2081,20 @@ def missing_intervals(
  continue
  snapshot_end_date = existing_interval_end

+ snapshot_start_date = max(
+ to_datetime(snapshot_start_date),
+ to_datetime(start_date(snapshot, snapshots, cache, relative_to=snapshot_end_date)),
+ )
+ if snapshot_start_date > to_datetime(snapshot_end_date):
+ continue
+
  missing_interval_end_date = snapshot_end_date
  node_end_date = snapshot.node.end
  if node_end_date and (to_datetime(node_end_date) < to_datetime(snapshot_end_date)):
  missing_interval_end_date = node_end_date

  intervals = snapshot.missing_intervals(
- max(
- to_datetime(snapshot_start_date),
- to_datetime(start_date(snapshot, snapshots, cache, relative_to=snapshot_end_date)),
- ),
+ snapshot_start_date,
  missing_interval_end_date,
  execution_time=execution_time,
  deployability_index=deployability_index,
@@ -2295,14 +2299,16 @@ def start_date(
  if not isinstance(snapshots, dict):
  snapshots = {snapshot.snapshot_id: snapshot for snapshot in snapshots}

- earliest = snapshot.node.cron_prev(snapshot.node.cron_floor(relative_to or now()))
-
- for parent in snapshot.parents:
- if parent in snapshots:
- earliest = min(
- earliest,
- start_date(snapshots[parent], snapshots, cache=cache, relative_to=relative_to),
- )
+ parent_starts = [
+ start_date(snapshots[parent], snapshots, cache=cache, relative_to=relative_to)
+ for parent in snapshot.parents
+ if parent in snapshots
+ ]
+ earliest = (
+ min(parent_starts)
+ if parent_starts
+ else snapshot.node.cron_prev(snapshot.node.cron_floor(relative_to or now()))
+ )

  cache[key] = earliest
  return earliest
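
Note the nuance in the `start_date` hunk: the node's own previous cron boundary is now only a fallback for nodes with no known parents, instead of always participating in the `min`. A toy illustration with integers standing in for datetimes (names are hypothetical):

    import typing as t

    def earliest_start(
        node: str,
        parents: t.Dict[str, t.List[str]],
        own_cron_start: t.Dict[str, int],
        cache: t.Dict[str, int],
    ) -> int:
        if node in cache:
            return cache[node]
        parent_starts = [
            earliest_start(p, parents, own_cron_start, cache) for p in parents.get(node, [])
        ]
        # Parents win when present; otherwise fall back to the node's own cron-derived start.
        earliest = min(parent_starts) if parent_starts else own_cron_start[node]
        cache[node] = earliest
        return earliest

    parents = {"child": ["parent"], "parent": []}
    print(earliest_start("child", parents, own_cron_start={"child": 10, "parent": 3}, cache={}))  # 3
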
sqlmesh/core/snapshot/evaluator.py CHANGED
@@ -1021,6 +1021,11 @@ class SnapshotEvaluator:
  ):
  import pandas as pd

+ try:
+ first_query_or_df = next(queries_or_dfs)
+ except StopIteration:
+ return
+
  query_or_df = reduce(
  lambda a, b: (
  pd.concat([a, b], ignore_index=True) # type: ignore
@@ -1028,6 +1033,7 @@
  else a.union_all(b) # type: ignore
  ), # type: ignore
  queries_or_dfs,
+ first_query_or_df,
  )
  apply(query_or_df, index=0)
  else:
@@ -1593,14 +1599,14 @@
  tables_by_gateway_and_schema: t.Dict[t.Union[str, None], t.Dict[exp.Table, set[str]]] = (
  defaultdict(lambda: defaultdict(set))
  )
- snapshots_by_table_name: t.Dict[str, Snapshot] = {}
+ snapshots_by_table_name: t.Dict[exp.Table, t.Dict[str, Snapshot]] = defaultdict(dict)
  for snapshot in target_snapshots:
  if not snapshot.is_model or snapshot.is_symbolic:
  continue
  table = table_name_callable(snapshot)
  table_schema = d.schema_(table.db, catalog=table.catalog)
  tables_by_gateway_and_schema[snapshot.model_gateway][table_schema].add(table.name)
- snapshots_by_table_name[table.name] = snapshot
+ snapshots_by_table_name[table_schema][table.name] = snapshot

  def _get_data_objects_in_schema(
  schema: exp.Table,
@@ -1613,23 +1619,25 @@
  )

  with self.concurrent_context():
- existing_objects: t.List[DataObject] = []
+ snapshot_id_to_obj: t.Dict[SnapshotId, DataObject] = {}
  # A schema can be shared across multiple engines, so we need to group tables by both gateway and schema
  for gateway, tables_by_schema in tables_by_gateway_and_schema.items():
- objs_for_gateway = [
- obj
- for objs in concurrent_apply_to_values(
- list(tables_by_schema),
- lambda s: _get_data_objects_in_schema(
- schema=s, object_names=tables_by_schema.get(s), gateway=gateway
- ),
- self.ddl_concurrent_tasks,
- )
- for obj in objs
- ]
- existing_objects.extend(objs_for_gateway)
+ schema_list = list(tables_by_schema.keys())
+ results = concurrent_apply_to_values(
+ schema_list,
+ lambda s: _get_data_objects_in_schema(
+ schema=s, object_names=tables_by_schema.get(s), gateway=gateway
+ ),
+ self.ddl_concurrent_tasks,
+ )
+
+ for schema, objs in zip(schema_list, results):
+ snapshots_by_name = snapshots_by_table_name.get(schema, {})
+ for obj in objs:
+ if obj.name in snapshots_by_name:
+ snapshot_id_to_obj[snapshots_by_name[obj.name].snapshot_id] = obj

- return {snapshots_by_table_name[obj.name].snapshot_id: obj for obj in existing_objects}
+ return snapshot_id_to_obj


  def _evaluation_strategy(snapshot: SnapshotInfoLike, adapter: EngineAdapter) -> EvaluationStrategy:
@@ -2185,7 +2193,13 @@ class MaterializableStrategy(PromotableStrategy, abc.ABC):
  if model.on_destructive_change.is_ignore or model.on_additive_change.is_ignore:
  # We need to identify the columns that are only in the source so we create an empty table with
  # the user query to determine that
- with self.adapter.temp_table(model.ctas_query(**render_kwargs)) as temp_table:
+ temp_table_name = exp.table_(
+ "diff",
+ db=model.physical_schema,
+ )
+ with self.adapter.temp_table(
+ model.ctas_query(**render_kwargs), name=temp_table_name
+ ) as temp_table:
  source_columns = list(self.adapter.columns(temp_table))
  else:
  source_columns = None
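
Note: the `next(queries_or_dfs)` / `StopIteration` guard in the first hunk matters because `functools.reduce` raises `TypeError` on an empty iterable when no initializer is given; peeling off the first item and passing it as the initial value covers both the empty and the single-element generator. A minimal sketch of the pattern with generic names:

    import typing as t
    from functools import reduce

    def combine_all(parts: t.Iterator[t.List[int]]) -> t.Optional[t.List[int]]:
        # An empty generator means there is nothing to combine; without this guard,
        # reduce() with no initializer would raise TypeError.
        try:
            first = next(parts)
        except StopIteration:
            return None
        # Fold the remaining items onto the first one.
        return reduce(lambda a, b: a + b, parts, first)

    print(combine_all(iter([[1], [2], [3]])))  # [1, 2, 3]
    print(combine_all(iter([])))               # None
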
sqlmesh/core/state_sync/db/snapshot.py CHANGED
@@ -185,7 +185,12 @@ class SnapshotState:
  promoted_snapshot_ids = {
  snapshot.snapshot_id
  for environment in environments
- for snapshot in environment.snapshots
+ for snapshot in (
+ environment.snapshots
+ if environment.finalized_ts is not None
+ # If the environment is not finalized, check both the current snapshots and the previous finalized snapshots
+ else [*environment.snapshots, *(environment.previous_finalized_snapshots or [])]
+ )
  }

  if promoted_snapshot_ids:
sqlmesh/core/table_diff.py CHANGED
@@ -367,8 +367,8 @@ class TableDiff:
  column_type = matched_columns[name]
  qualified_column = exp.column(name, table)

- if column_type.is_type(*exp.DataType.FLOAT_TYPES):
- return exp.func("ROUND", qualified_column, exp.Literal.number(self.decimals))
+ if column_type.is_type(*exp.DataType.REAL_TYPES):
+ return self.adapter._normalize_decimal_value(qualified_column, self.decimals)
  if column_type.is_type(*exp.DataType.NESTED_TYPES):
  return self.adapter._normalize_nested_value(qualified_column)

sqlmesh/core/test/definition.py CHANGED
@@ -454,6 +454,9 @@ class ModelTest(unittest.TestCase):
  query = outputs.get("query")
  partial = outputs.pop("partial", None)

+ if ctes is None and query is None:
+ _raise_error("Incomplete test, outputs must contain 'query' or 'ctes'", self.path)
+
  def _normalize_rows(
  values: t.List[Row] | t.Dict,
  name: str,
@@ -804,7 +807,7 @@ class PythonModelTest(ModelTest):
  actual_df.reset_index(drop=True, inplace=True)
  expected = self._create_df(values, columns=self.model.columns_to_types, partial=partial)

- self.assert_equal(expected, actual_df, sort=False, partial=partial)
+ self.assert_equal(expected, actual_df, sort=True, partial=partial)

  def _execute_model(self) -> pd.DataFrame:
  """Executes the python model and returns a DataFrame."""
@@ -922,8 +925,7 @@ def generate_test(
  cte_output = test._execute(cte_query)
  ctes[cte.alias] = (
  pandas_timestamp_to_pydatetime(
- cte_output.apply(lambda col: col.map(_normalize_df_value)),
- cte_query.named_selects,
+ df=cte_output.apply(lambda col: col.map(_normalize_df_value)),
  )
  .replace({np.nan: None})
  .to_dict(orient="records")
sqlmesh/core/test/discovery.py CHANGED
@@ -20,6 +20,10 @@ class ModelTestMetadata(PydanticModel):
  def fully_qualified_test_name(self) -> str:
  return f"{self.path}::{self.test_name}"

+ @property
+ def model_name(self) -> str:
+ return self.body.get("model", "")
+
  def __hash__(self) -> int:
  return self.fully_qualified_test_name.__hash__()

sqlmesh/dbt/builtin.py CHANGED
@@ -25,7 +25,7 @@ from sqlmesh.dbt.target import TARGET_TYPE_TO_CONFIG_CLASS
  from sqlmesh.dbt.util import DBT_VERSION
  from sqlmesh.utils import AttributeDict, debug_mode_enabled, yaml
  from sqlmesh.utils.date import now
- from sqlmesh.utils.errors import ConfigError, MacroEvalError
+ from sqlmesh.utils.errors import ConfigError
  from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroReference, MacroReturnVal

  logger = logging.getLogger(__name__)
@@ -381,18 +381,16 @@ def do_zip(*args: t.Any, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]
  return default


- def as_bool(value: str) -> bool:
- result = _try_literal_eval(value)
- if isinstance(result, bool):
- return result
- raise MacroEvalError(f"Failed to convert '{value}' into boolean.")
+ def as_bool(value: t.Any) -> t.Any:
+ # dbt's jinja TEXT_FILTERS just return the input value as is
+ # https://github.com/dbt-labs/dbt-common/blob/main/dbt_common/clients/jinja.py#L559
+ return value


  def as_number(value: str) -> t.Any:
- result = _try_literal_eval(value)
- if isinstance(value, (int, float)) and not isinstance(result, bool):
- return result
- raise MacroEvalError(f"Failed to convert '{value}' into number.")
+ # dbt's jinja TEXT_FILTERS just return the input value as is
+ # https://github.com/dbt-labs/dbt-common/blob/main/dbt_common/clients/jinja.py#L559
+ return value


  def _try_literal_eval(value: str) -> t.Any:
@@ -482,7 +480,7 @@ def create_builtin_globals(
  if variables is not None:
  builtin_globals["var"] = Var(variables)

- builtin_globals["config"] = Config(jinja_globals.pop("config", {}))
+ builtin_globals["config"] = Config(jinja_globals.pop("config", {"tags": []}))

  deployability_index = (
  jinja_globals.get("deployability_index") or DeployabilityIndex.all_deployable()
sqlmesh/dbt/column.py CHANGED
@@ -1,6 +1,7 @@
  from __future__ import annotations

  import typing as t
+ import logging

  from sqlglot import exp, parse_one
  from sqlglot.helper import ensure_list
@@ -9,6 +10,8 @@ from sqlmesh.dbt.common import GeneralConfig
  from sqlmesh.utils.conversions import ensure_bool
  from sqlmesh.utils.pydantic import field_validator

+ logger = logging.getLogger(__name__)
+

  def yaml_to_columns(
  yaml: t.Dict[str, ColumnConfig] | t.List[t.Dict[str, ColumnConfig]],
@@ -31,11 +34,20 @@ def column_types_to_sqlmesh(
  Returns:
  A dict of column name to exp.DataType
  """
- return {
- name: parse_one(column.data_type, into=exp.DataType, dialect=dialect or "")
- for name, column in columns.items()
- if column.enabled and column.data_type
- }
+ col_types_to_sqlmesh: t.Dict[str, exp.DataType] = {}
+ for name, column in columns.items():
+ if column.enabled and column.data_type:
+ column_def = parse_one(
+ f"{name} {column.data_type}", into=exp.ColumnDef, dialect=dialect or ""
+ )
+ if column_def.args.get("constraints"):
+ logger.warning(
+ f"Ignoring unsupported constraints for column '{name}' with definition '{column.data_type}'. Please refer to github.com/TobikoData/sqlmesh/issues/4717 for more information."
+ )
+ kind = column_def.kind
+ if kind:
+ col_types_to_sqlmesh[name] = kind
+ return col_types_to_sqlmesh


  def column_descriptions_to_sqlmesh(columns: t.Dict[str, ColumnConfig]) -> t.Dict[str, str]:
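
Note: the `column_types_to_sqlmesh` rewrite parses each declared type as a full column definition, so trailing constraints (e.g. `not null`) are detected and dropped instead of corrupting the type. A rough sketch of the same sqlglot calls, assuming a sqlglot version that supports `into=exp.ColumnDef` (as the one pinned here evidently does):

    from sqlglot import exp, parse_one

    # Hypothetical dbt-style column declarations: name -> data_type string.
    declared = {"id": "int not null", "name": "varchar(50)"}

    column_types = {}
    for name, data_type in declared.items():
        column_def = parse_one(f"{name} {data_type}", into=exp.ColumnDef)
        if column_def.args.get("constraints"):
            # Constraints are reported but not carried into the resulting type.
            print(f"ignoring constraints on {name!r}: {data_type}")
        if column_def.kind:
            column_types[name] = column_def.kind

    print({name: kind.sql() for name, kind in column_types.items()})  # {'id': 'INT', 'name': 'VARCHAR(50)'}
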
sqlmesh/dbt/common.py CHANGED
@@ -46,7 +46,9 @@ def load_yaml(source: str | Path) -> t.Dict:
  raise ConfigError(f"{source}: {ex}" if isinstance(source, Path) else f"{ex}")


- def parse_meta(v: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
+ def parse_meta(v: t.Optional[t.Dict[str, t.Any]]) -> t.Dict[str, t.Any]:
+ if v is None:
+ return {}
  for key, value in v.items():
  if isinstance(value, str):
  v[key] = try_str_to_bool(value)
@@ -115,7 +117,7 @@ class GeneralConfig(DbtConfig):

  @field_validator("meta", mode="before")
  @classmethod
- def _validate_meta(cls, v: t.Dict[str, t.Union[str, t.Any]]) -> t.Dict[str, t.Any]:
+ def _validate_meta(cls, v: t.Optional[t.Dict[str, t.Union[str, t.Any]]]) -> t.Dict[str, t.Any]:
  return parse_meta(v)

  _FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = {
sqlmesh/dbt/context.py CHANGED
@@ -37,6 +37,8 @@ class DbtContext:
  """Context for DBT environment"""

  project_root: Path = Path()
+ profiles_dir: t.Optional[Path] = None
+ """Optional override to specify the directory where profiles.yml is located, if not at the :project_root"""
  target_name: t.Optional[str] = None
  profile_name: t.Optional[str] = None
  project_schema: t.Optional[str] = None
sqlmesh/dbt/loader.py CHANGED
@@ -53,10 +53,18 @@ def sqlmesh_config(
  threads: t.Optional[int] = None,
  register_comments: t.Optional[bool] = None,
  infer_state_schema_name: bool = False,
+ profiles_dir: t.Optional[Path] = None,
  **kwargs: t.Any,
  ) -> Config:
  project_root = project_root or Path()
- context = DbtContext(project_root=project_root, profile_name=dbt_profile_name)
+ context = DbtContext(
+ project_root=project_root, profiles_dir=profiles_dir, profile_name=dbt_profile_name
+ )
+
+ # note: Profile.load() is called twice with different DbtContext's:
+ # - once here with the above DbtContext (to determine connnection / gateway config which has to be set up before everything else)
+ # - again on the SQLMesh side via GenericContext.load() -> DbtLoader._load_projects() -> Project.load() which constructs a fresh DbtContext and ignores the above one
+ # it's important to ensure that the DbtContext created within the DbtLoader uses the same project root / profiles dir that we use here
  profile = Profile.load(context, target_name=dbt_target_name)
  model_defaults = kwargs.pop("model_defaults", ModelDefaultsConfig())
  if model_defaults.dialect is None:
@@ -98,6 +106,7 @@ def sqlmesh_config(

  return Config(
  loader=loader,
+ loader_kwargs=dict(profiles_dir=profiles_dir),
  model_defaults=model_defaults,
  variables=variables or {},
  dbt=RootDbtConfig(infer_state_schema_name=infer_state_schema_name),
@@ -116,9 +125,12 @@ def sqlmesh_config(


  class DbtLoader(Loader):
- def __init__(self, context: GenericContext, path: Path) -> None:
+ def __init__(
+ self, context: GenericContext, path: Path, profiles_dir: t.Optional[Path] = None
+ ) -> None:
  self._projects: t.List[Project] = []
  self._macros_max_mtime: t.Optional[float] = None
+ self._profiles_dir = profiles_dir
  super().__init__(context, path)

  def load(self) -> LoadedProject:
@@ -225,6 +237,7 @@ class DbtLoader(Loader):
  project = Project.load(
  DbtContext(
  project_root=self.config_path,
+ profiles_dir=self._profiles_dir,
  target_name=target_name,
  sqlmesh_config=self.config,
  ),
sqlmesh/dbt/manifest.py CHANGED
@@ -11,7 +11,7 @@ from collections import defaultdict
  from functools import cached_property
  from pathlib import Path

- from dbt import constants as dbt_constants, flags
+ from dbt import flags

  from sqlmesh.dbt.util import DBT_VERSION
  from sqlmesh.utils.conversions import make_serializable
@@ -19,6 +19,8 @@ from sqlmesh.utils.conversions import make_serializable
  # Override the file name to prevent dbt commands from invalidating the cache.

  if DBT_VERSION >= (1, 6, 0):
+ from dbt import constants as dbt_constants
+
  dbt_constants.PARTIAL_PARSE_FILE_NAME = "sqlmesh_partial_parse.msgpack" # type: ignore
  else:
  from dbt.parser import manifest as dbt_manifest # type: ignore
sqlmesh/dbt/model.py CHANGED
@@ -567,6 +567,12 @@ class ModelConfig(BaseModelConfig):
  self.name,
  "views" if isinstance(kind, ViewKind) else "ephemeral models",
  )
+ elif context.target.dialect == "snowflake":
+ logger.warning(
+ "Ignoring partition_by config for model '%s' targeting %s. The partition_by config is not supported for Snowflake.",
+ self.name,
+ context.target.dialect,
+ )
  else:
  partitioned_by = []
  if isinstance(self.partition_by, list):
@@ -601,7 +607,13 @@
  clustered_by = []
  for c in self.cluster_by:
  try:
- clustered_by.append(d.parse_one(c, dialect=model_dialect))
+ cluster_expr = exp.maybe_parse(
+ c, into=exp.Cluster, prefix="CLUSTER BY", dialect=model_dialect
+ )
+ for expr in cluster_expr.expressions:
+ clustered_by.append(
+ expr.this if isinstance(expr, exp.Ordered) else expr
+ )
  except SqlglotError as e:
  raise ConfigError(
  f"Failed to parse model '{self.canonical_name(context)}' cluster_by field '{c}' in '{self.path}': {e}"
sqlmesh/dbt/profile.py CHANGED
@@ -60,7 +60,7 @@ class Profile:
  if not context.profile_name:
  raise ConfigError(f"{project_file.stem} must include project name.")

- profile_filepath = cls._find_profile(context.project_root)
+ profile_filepath = cls._find_profile(context.project_root, context.profiles_dir)
  if not profile_filepath:
  raise ConfigError(f"{cls.PROFILE_FILE} not found.")

@@ -68,8 +68,8 @@
  return Profile(profile_filepath, target_name, target)

  @classmethod
- def _find_profile(cls, project_root: Path) -> t.Optional[Path]:
- dir = os.environ.get("DBT_PROFILES_DIR", "")
+ def _find_profile(cls, project_root: Path, profiles_dir: t.Optional[Path]) -> t.Optional[Path]:
+ dir = os.environ.get("DBT_PROFILES_DIR", profiles_dir or "")
  path = Path(project_root, dir, cls.PROFILE_FILE)
  if path.exists():
  return path
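
Note: with the new parameter, profiles.yml resolution keeps its precedence: the DBT_PROFILES_DIR environment variable first, then the explicit `profiles_dir` override, then the project root itself. A small standalone sketch of that lookup (hypothetical helper, not the sqlmesh API):

    import os
    import typing as t
    from pathlib import Path

    def find_profiles_yml(project_root: Path, profiles_dir: t.Optional[Path] = None) -> t.Optional[Path]:
        # Environment variable wins; an absolute directory here makes project_root irrelevant.
        directory = os.environ.get("DBT_PROFILES_DIR", profiles_dir or "")
        candidate = Path(project_root, directory, "profiles.yml")
        return candidate if candidate.exists() else None

    # e.g. find_profiles_yml(Path("my_dbt_project"), profiles_dir=Path.home() / ".dbt")
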