trellis-datamodel 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. trellis_datamodel/__init__.py +8 -0
  2. trellis_datamodel/adapters/__init__.py +41 -0
  3. trellis_datamodel/adapters/base.py +147 -0
  4. trellis_datamodel/adapters/dbt_core.py +975 -0
  5. trellis_datamodel/cli.py +292 -0
  6. trellis_datamodel/config.py +239 -0
  7. trellis_datamodel/models/__init__.py +13 -0
  8. trellis_datamodel/models/schemas.py +28 -0
  9. trellis_datamodel/routes/__init__.py +11 -0
  10. trellis_datamodel/routes/data_model.py +221 -0
  11. trellis_datamodel/routes/manifest.py +110 -0
  12. trellis_datamodel/routes/schema.py +183 -0
  13. trellis_datamodel/server.py +101 -0
  14. trellis_datamodel/static/_app/env.js +1 -0
  15. trellis_datamodel/static/_app/immutable/assets/0.ByDwyx3a.css +1 -0
  16. trellis_datamodel/static/_app/immutable/assets/2.DLAp_5AW.css +1 -0
  17. trellis_datamodel/static/_app/immutable/assets/trellis_squared.CTOnsdDx.svg +127 -0
  18. trellis_datamodel/static/_app/immutable/chunks/8ZaN1sxc.js +1 -0
  19. trellis_datamodel/static/_app/immutable/chunks/BfBfOTnK.js +1 -0
  20. trellis_datamodel/static/_app/immutable/chunks/C3yhlRfZ.js +2 -0
  21. trellis_datamodel/static/_app/immutable/chunks/CK3bXPEX.js +1 -0
  22. trellis_datamodel/static/_app/immutable/chunks/CXDUumOQ.js +1 -0
  23. trellis_datamodel/static/_app/immutable/chunks/DDNfEvut.js +1 -0
  24. trellis_datamodel/static/_app/immutable/chunks/DUdVct7e.js +1 -0
  25. trellis_datamodel/static/_app/immutable/chunks/QRltG_J6.js +2 -0
  26. trellis_datamodel/static/_app/immutable/chunks/zXDdy2c_.js +1 -0
  27. trellis_datamodel/static/_app/immutable/entry/app.abCkWeAJ.js +2 -0
  28. trellis_datamodel/static/_app/immutable/entry/start.B7CjH6Z7.js +1 -0
  29. trellis_datamodel/static/_app/immutable/nodes/0.bFI_DI3G.js +1 -0
  30. trellis_datamodel/static/_app/immutable/nodes/1.J_r941Qf.js +1 -0
  31. trellis_datamodel/static/_app/immutable/nodes/2.WqbMkq6o.js +27 -0
  32. trellis_datamodel/static/_app/version.json +1 -0
  33. trellis_datamodel/static/index.html +40 -0
  34. trellis_datamodel/static/robots.txt +3 -0
  35. trellis_datamodel/static/trellis_squared.svg +127 -0
  36. trellis_datamodel/tests/__init__.py +2 -0
  37. trellis_datamodel/tests/conftest.py +132 -0
  38. trellis_datamodel/tests/test_cli.py +526 -0
  39. trellis_datamodel/tests/test_data_model.py +151 -0
  40. trellis_datamodel/tests/test_dbt_schema.py +892 -0
  41. trellis_datamodel/tests/test_manifest.py +72 -0
  42. trellis_datamodel/tests/test_server_static.py +44 -0
  43. trellis_datamodel/tests/test_yaml_handler.py +228 -0
  44. trellis_datamodel/utils/__init__.py +2 -0
  45. trellis_datamodel/utils/yaml_handler.py +365 -0
  46. trellis_datamodel-0.3.3.dist-info/METADATA +333 -0
  47. trellis_datamodel-0.3.3.dist-info/RECORD +52 -0
  48. trellis_datamodel-0.3.3.dist-info/WHEEL +5 -0
  49. trellis_datamodel-0.3.3.dist-info/entry_points.txt +2 -0
  50. trellis_datamodel-0.3.3.dist-info/licenses/LICENSE +661 -0
  51. trellis_datamodel-0.3.3.dist-info/licenses/NOTICE +6 -0
  52. trellis_datamodel-0.3.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,975 @@
1
+ """
2
+ dbt-core adapter implementation.
3
+
4
+ Handles parsing dbt manifest.json/catalog.json and generating dbt schema YAML files.
5
+ """
6
+
7
+ import copy
8
+ import json
9
+ import os
10
+ import re
11
+ import yaml
12
+ from pathlib import Path
13
+ from typing import Any, Optional
14
+
15
+ from trellis_datamodel.utils.yaml_handler import YamlHandler
16
+ from .base import (
17
+ ColumnInfo,
18
+ ColumnSchema,
19
+ ModelInfo,
20
+ ModelSchema,
21
+ Relationship,
22
+ )
23
+
24
+
25
class DbtCoreAdapter:
    """Adapter for dbt-core transformation framework."""

    def __init__(
        self,
        manifest_path: str,
        catalog_path: str,
        project_path: str,
        data_model_path: str,
        model_paths: list[str],
    ):
        """
        Args:
            manifest_path: Path to the dbt manifest.json artifact.
            catalog_path: Path to the dbt catalog.json artifact (may not exist).
            project_path: Root directory of the dbt project.
            data_model_path: Path to the data model YAML (may be empty/missing).
            model_paths: Configured model subdirectories to scan (may be empty).
        """
        self.manifest_path = manifest_path
        self.catalog_path = catalog_path
        self.project_path = project_path
        self.data_model_path = data_model_path
        self.model_paths = model_paths
        # Shared helper used for all reads/writes of dbt schema YAML files.
        self.yaml_handler = YamlHandler()
43
+ def _load_catalog(self) -> Optional[dict]:
44
+ """Load catalog.json if it exists."""
45
+ if not os.path.exists(self.catalog_path):
46
+ return None
47
+ try:
48
+ with open(self.catalog_path, "r") as f:
49
+ return json.load(f)
50
+ except Exception as exc:
51
+ print(f"Warning: failed to read catalog at {self.catalog_path}: {exc}")
52
+ return None
53
+
54
+ def _load_manifest(self) -> dict:
55
+ """Load manifest.json."""
56
+ with open(self.manifest_path, "r") as f:
57
+ return json.load(f)
58
+
59
+ def _load_data_model(self) -> dict:
60
+ """Load data model YAML if it exists."""
61
+ if not self.data_model_path or not os.path.exists(self.data_model_path):
62
+ return {}
63
+ try:
64
+ with open(self.data_model_path, "r") as f:
65
+ return yaml.safe_load(f) or {}
66
+ except Exception as e:
67
+ print(f"Warning: Could not load data model: {e}")
68
+ return {}
69
+
70
+ def get_model_dirs(self) -> list[str]:
71
+ """
72
+ Resolve configured models directories, normalizing common dbt prefixes.
73
+
74
+ Users may configure entries like "3_core", "models/3_entity", or absolute
75
+ paths. We normalize these to real directories so downstream scans work.
76
+ """
77
+
78
+ def _normalize(subdir: str) -> str:
79
+ # Absolute path - return as-is
80
+ if os.path.isabs(subdir):
81
+ return os.path.abspath(subdir).rstrip(os.sep)
82
+
83
+ # Remove leading "./"
84
+ while subdir.startswith("./"):
85
+ subdir = subdir[2:]
86
+
87
+ # Strip an optional leading "models/" so we don't double-prepend
88
+ prefix = f"models{os.sep}"
89
+ if subdir.startswith(prefix):
90
+ subdir = subdir[len(prefix) :]
91
+
92
+ return os.path.abspath(
93
+ os.path.join(self.project_path, "models", subdir)
94
+ ).rstrip(os.sep)
95
+
96
+ if self.model_paths:
97
+ # Remove duplicates while preserving order
98
+ seen = set()
99
+ normalized = []
100
+ for path in self.model_paths:
101
+ norm = _normalize(path)
102
+ if norm not in seen:
103
+ seen.add(norm)
104
+ normalized.append(norm)
105
+ return normalized
106
+
107
+ return [
108
+ os.path.abspath(os.path.join(self.project_path, "models")).rstrip(os.sep)
109
+ ]
110
+
111
+ def _entity_to_model_name(self, entity: dict[str, Any]) -> str:
112
+ """
113
+ Resolve the dbt model name for an entity.
114
+
115
+ Prefers the bound dbt_model (strips project prefix), otherwise falls back to
116
+ the entity ID so unbound entities still persist somewhere.
117
+ """
118
+ dbt_model = entity.get("dbt_model")
119
+ if dbt_model:
120
+ # dbt unique_id for versioned models looks like model.<project>.<name>.v2
121
+ parts = dbt_model.split(".")
122
+ if len(parts) >= 2 and re.match(r"v\d+$", parts[-1]):
123
+ # Use the model name part (the element before the vN suffix)
124
+ return parts[-2]
125
+ return parts[-1]
126
+ return entity.get("id") or ""
127
+
128
+ def _build_model_keys(self, base: str, version: Optional[str] = None) -> list[str]:
129
+ """
130
+ Generate a set of lookup keys for a model, including versioned variants.
131
+ """
132
+ keys = [base]
133
+ if version is not None:
134
+ # Version may come through as int from YAML; normalize to string
135
+ version_str = str(version)
136
+ # Support common ref patterns:
137
+ # - ref('model', v=2) -> base.v2
138
+ # - alias names -> base_v2
139
+ # - fully qualified -> base.v2 and base_v2
140
+ version_num = version_str.lstrip("v")
141
+ keys.extend(
142
+ [
143
+ f"{base}.v{version_num}",
144
+ f"{base}_v{version_num}",
145
+ f"{base}v{version_num}",
146
+ ]
147
+ )
148
+ # Deduplicate while preserving order
149
+ seen = set()
150
+ ordered_keys = []
151
+ for k in keys:
152
+ if k not in seen:
153
+ seen.add(k)
154
+ ordered_keys.append(k)
155
+ return ordered_keys
156
+
157
+ def _get_model_to_entity_map(self) -> dict[str, str]:
158
+ """Build mapping from model names (with version aliases) to entity IDs."""
159
+ model_to_entity: dict[str, str] = {}
160
+ data_model = self._load_data_model()
161
+ entities = data_model.get("entities", [])
162
+ for entity in entities:
163
+ entity_id = entity.get("id")
164
+ dbt_model = entity.get("dbt_model")
165
+ if dbt_model:
166
+ parts = dbt_model.split(".")
167
+ version_part = None
168
+ if len(parts) >= 2 and re.match(r"v\d+$", parts[-1]):
169
+ version_part = parts[-1].lstrip("v")
170
+ base_name = parts[-2]
171
+ else:
172
+ base_name = parts[-1]
173
+
174
+ # Map the raw unique_id as well as base/version variants
175
+ model_to_entity[dbt_model] = entity_id
176
+ for key in self._build_model_keys(base_name, version_part):
177
+ model_to_entity[key] = entity_id
178
+ # Map additional models to the same entity
179
+ additional_models = entity.get("additional_models", [])
180
+ for add_model in additional_models:
181
+ parts = add_model.split(".")
182
+ version_part = None
183
+ if len(parts) >= 2 and re.match(r"v\d+$", parts[-1]):
184
+ version_part = parts[-1].lstrip("v")
185
+ base_name = parts[-2]
186
+ else:
187
+ base_name = parts[-1]
188
+
189
+ model_to_entity[add_model] = entity_id
190
+ for key in self._build_model_keys(base_name, version_part):
191
+ model_to_entity[key] = entity_id
192
+ if entity_id:
193
+ model_to_entity[entity_id] = entity_id
194
+ return model_to_entity
195
+
196
    def _get_model_yml_path(
        self, model_name: str, target_version: Optional[int] = None
    ) -> Optional[str]:
        """Get the yml file path for a model from the manifest.

        Scans manifest nodes for models named *model_name*. A node whose
        "version" equals *target_version* wins outright; otherwise the first
        name match is used. Returns None when the manifest file or the model
        is missing.
        """
        if not os.path.exists(self.manifest_path):
            return None

        manifest = self._load_manifest()
        preferred_node: Optional[dict] = None
        fallback_node: Optional[dict] = None
        for key, node in manifest.get("nodes", {}).items():
            if node.get("resource_type") == "model" and node.get("name") == model_name:
                if target_version is not None and node.get("version") == target_version:
                    preferred_node = node
                    break
                # Remember the first name match in case no version matches.
                if fallback_node is None:
                    fallback_node = node

        node = preferred_node or fallback_node
        if not node:
            return None

        return self._derive_yml_path_from_node(node)
220
+ def _normalize_patch_path(self, patch_path: str) -> str:
221
+ """
222
+ Convert a manifest patch_path (which may include a scheme) into an
223
+ absolute filesystem path rooted at the dbt project.
224
+ """
225
+ if "://" in patch_path:
226
+ patch_path = patch_path.split("://", 1)[1]
227
+
228
+ patch_path = patch_path.lstrip("/")
229
+
230
+ if os.path.isabs(patch_path):
231
+ return patch_path
232
+
233
+ # Default: resolve relative to the dbt project directory
234
+ base = self.project_path or "."
235
+ return os.path.abspath(os.path.join(base, patch_path))
236
+
237
+ def _derive_yml_path_from_node(self, node: dict) -> Optional[str]:
238
+ """
239
+ Determine the YAML file path for a manifest node, preferring patch_path.
240
+ """
241
+ patch_path = node.get("patch_path")
242
+ if patch_path:
243
+ return self._normalize_patch_path(patch_path)
244
+
245
+ original_file_path = node.get("original_file_path", "")
246
+ if not original_file_path:
247
+ return None
248
+
249
+ sql_path = os.path.join(self.project_path, original_file_path)
250
+ base_path = os.path.splitext(sql_path)[0]
251
+ yml_path = f"{base_path}.yml"
252
+
253
+ # Common versioned layout: player_v2.sql + player.yml
254
+ if not os.path.exists(yml_path):
255
+ base_path = re.sub(r"_v\d+$", "", base_path)
256
+ yml_path = f"{base_path}.yml"
257
+
258
+ return yml_path
259
+
260
+ def _extract_version_from_string(self, value: str) -> Optional[int]:
261
+ """Extract integer version from strings like 'model.proj.player.v2'."""
262
+ if not value:
263
+ return None
264
+
265
+ match = re.search(r"\.v(\d+)$", value)
266
+ if match:
267
+ return int(match.group(1))
268
+
269
+ if value.startswith("v") and value[1:].isdigit():
270
+ return int(value[1:])
271
+
272
+ return None
273
+
274
+ def _resolve_model_version(
275
+ self, model_name: str, entity_id: str, data_model: dict[str, Any]
276
+ ) -> Optional[int]:
277
+ """
278
+ Resolve target version from data model (dbt_model) or manifest node.
279
+
280
+ Prefers explicit dbt_model binding with .vN suffix, falls back to
281
+ manifest node version.
282
+ """
283
+ for entity in data_model.get("entities", []):
284
+ if entity.get("id") != entity_id:
285
+ continue
286
+ dbt_model = entity.get("dbt_model") or ""
287
+ # Check last token after split to support fully-qualified model IDs
288
+ last_token = dbt_model.split(".")[-1] if dbt_model else ""
289
+ version = self._extract_version_from_string(last_token)
290
+ if version is not None:
291
+ return version
292
+
293
+ if os.path.exists(self.manifest_path):
294
+ manifest = self._load_manifest()
295
+ for node in manifest.get("nodes", {}).values():
296
+ if node.get("resource_type") != "model":
297
+ continue
298
+ if node.get("name") != model_name:
299
+ continue
300
+ node_version = node.get("version")
301
+ if node_version:
302
+ return int(node_version)
303
+
304
+ return None
305
+
306
+ def _find_manifest_model_nodes(
307
+ self, manifest: dict[str, Any], model_name: str
308
+ ) -> list[dict[str, Any]]:
309
+ """Collect manifest nodes matching a given model name."""
310
+ return [
311
+ node
312
+ for node in manifest.get("nodes", {}).values()
313
+ if node.get("resource_type") == "model" and node.get("name") == model_name
314
+ ]
315
+
316
+ def _select_model_node(
317
+ self, candidates: list[dict[str, Any]], target_version: Optional[int]
318
+ ) -> Optional[dict[str, Any]]:
319
+ """
320
+ Choose the best matching manifest node.
321
+
322
+ Prefers an exact version match when requested; otherwise picks the
323
+ highest numbered version when available, or the first candidate.
324
+ """
325
+ if target_version is not None:
326
+ for node in candidates:
327
+ node_version = node.get("version")
328
+ if node_version is not None and int(node_version) == target_version:
329
+ return node
330
+
331
+ if not candidates:
332
+ return None
333
+
334
+ versioned = [n for n in candidates if n.get("version") is not None]
335
+ if versioned:
336
+ return sorted(versioned, key=lambda n: n.get("version") or 0)[-1]
337
+
338
+ return candidates[0]
339
+
340
    def get_models(self) -> "list[ModelInfo]":
        """Parse dbt manifest and catalog to return available models.

        Returns one entry per manifest model node (optionally filtered by the
        configured model paths), sorted by model name. Columns come from
        catalog.json when available, otherwise from the manifest's declared
        columns.

        Raises:
            FileNotFoundError: if manifest.json does not exist.
        """
        if not os.path.exists(self.manifest_path):
            raise FileNotFoundError(f"Manifest not found at {self.manifest_path}")

        manifest = self._load_manifest()
        catalog = self._load_catalog()
        # Catalog is optional; fall back to an empty node map when absent.
        catalog_nodes = (catalog or {}).get("nodes", {})

        models: list[ModelInfo] = []
        for key, node in manifest.get("nodes", {}).items():
            if node.get("resource_type") != "model":
                continue

            # Filter by path. NOTE(review): this is a substring match against the
            # raw configured entries; get_model_dirs() normalization is NOT
            # applied here — confirm that is intended.
            original_path = node.get("original_file_path", "")
            if self.model_paths:
                match = any(pattern in original_path for pattern in self.model_paths)
                if not match:
                    continue

            # Get columns from catalog or manifest (catalog preferred: it
            # reflects the actual database types).
            columns: list[ColumnInfo] = []
            unique_id = node.get("unique_id")
            catalog_node = catalog_nodes.get(unique_id)

            if catalog_node:
                for col in catalog_node.get("columns", {}).values():
                    columns.append(
                        {
                            "name": col.get("name"),
                            "type": col.get("type") or col.get("data_type"),
                        }
                    )
            else:
                for col_name, col_data in node.get("columns", {}).items():
                    columns.append({"name": col_name, "type": col_data.get("type")})

            # Extract materialization ("view" when the config omits it).
            config = node.get("config", {})
            materialized = config.get("materialized", "view")

            models.append(
                {
                    "unique_id": unique_id,
                    "name": node.get("name"),
                    "version": node.get("version"),
                    "schema": node.get("schema"),
                    # "alias" overrides the physical table name when set.
                    "table": node.get("alias", node.get("name")),
                    "columns": columns,
                    "description": node.get("description"),
                    "materialization": materialized,
                    "file_path": original_path,
                    "tags": node.get("tags", []),
                }
            )

        models.sort(key=lambda x: x["name"])
        return models
400
    def get_model_schema(
        self, model_name: str, version: Optional[int] = None
    ) -> "ModelSchema":
        """Get the current schema definition for a specific model from its YAML file.

        Args:
            model_name: dbt model name to look up.
            version: Optional explicit model version to read.

        Returns:
            A dict with model_name, description, columns, tags and file_path.
            Returns an empty-schema dict (but with file_path set) when the YAML
            file or model entry does not exist yet.

        Raises:
            FileNotFoundError: if manifest.json does not exist.
            ValueError: if the model is absent from the manifest or has no
                derivable YAML path.
        """
        if not os.path.exists(self.manifest_path):
            raise FileNotFoundError(f"Manifest not found at {self.manifest_path}")

        manifest = self._load_manifest()

        candidate_nodes = self._find_manifest_model_nodes(manifest, model_name)
        model_node = self._select_model_node(candidate_nodes, target_version=version)

        if not model_node:
            raise ValueError(f"Model '{model_name}' not found in manifest")

        yml_path = self._derive_yml_path_from_node(model_node)
        if not yml_path:
            raise ValueError(
                f"No patch_path or original_file_path found for model '{model_name}'"
            )

        data = self.yaml_handler.load_file(yml_path)
        if not data:
            # YAML file absent or empty: report an empty schema at that path.
            return {
                "model_name": model_name,
                "description": "",
                "columns": [],
                "tags": [],
                "file_path": yml_path,
            }

        model_entry = self.yaml_handler.find_model(data, model_name)
        if not model_entry:
            # File exists but has no entry for this model yet.
            return {
                "model_name": model_name,
                "description": "",
                "columns": [],
                "tags": [],
                "file_path": yml_path,
            }

        # Requested version wins; otherwise use the manifest node's version.
        target_version = version if version is not None else model_node.get("version")
        version_entry = None
        versions = model_entry.get("versions") or []
        if versions:
            if target_version is not None:
                for ver in versions:
                    if ver.get("v") == target_version:
                        version_entry = ver
                        break
            # No (matching) target version: fall back to the highest "v" entry.
            if version_entry is None:
                version_entry = sorted(versions, key=lambda ver: ver.get("v") or 0)[-1]

        # Prefer versioned block if present
        node_for_schema = version_entry or model_entry

        columns = self.yaml_handler.get_columns(node_for_schema)
        tags = self.yaml_handler.get_model_tags(node_for_schema)
        return {
            "model_name": model_name,
            "description": node_for_schema.get("description", ""),
            "columns": columns,
            "tags": tags,
            "file_path": yml_path,
        }
466
    def save_model_schema(
        self,
        model_name: str,
        columns: "list[ColumnSchema]",
        description: Optional[str] = None,
        tags: Optional[list[str]] = None,
        version: Optional[int] = None,
    ) -> Path:
        """Save/update the schema definition for a model.

        Args:
            model_name: dbt model whose YAML entry should be written.
            columns: Column definitions to write (batch update).
            description: When given, replaces the model description.
            tags: When given, replaces the model/version tags.
            version: Explicit version to write; defaults to the manifest
                node's version.

        Returns:
            Path of the YAML file that was written.

        Raises:
            FileNotFoundError: if manifest.json does not exist.
            ValueError: if the model is absent from the manifest or has no
                derivable YAML path.
        """
        if not os.path.exists(self.manifest_path):
            raise FileNotFoundError(f"Manifest not found at {self.manifest_path}")

        manifest = self._load_manifest()

        candidate_nodes = self._find_manifest_model_nodes(manifest, model_name)
        model_node = self._select_model_node(candidate_nodes, target_version=version)

        if not model_node:
            raise ValueError(f"Model '{model_name}' not found in manifest")

        yml_path = self._derive_yml_path_from_node(model_node)
        if not yml_path:
            raise ValueError(
                f"No patch_path or original_file_path found for model '{model_name}'"
            )

        data = self.yaml_handler.load_file(yml_path)
        if not data:
            # Start a fresh dbt schema file skeleton.
            data = {"version": 2, "models": []}

        model_entry = self.yaml_handler.ensure_model(data, model_name)

        target_version = version if version is not None else model_node.get("version")
        if target_version is not None:
            # Versioned model: update version entry and keep latest_version in sync (non-decreasing)
            self.yaml_handler.set_latest_version(model_entry, target_version)
            version_entry = self.yaml_handler.ensure_model_version(
                model_entry, target_version
            )

            if description is not None:
                self.yaml_handler.update_model_description(version_entry, description)
                # Keep top-level description aligned when present
                self.yaml_handler.update_model_description(model_entry, description)

            self.yaml_handler.update_columns_batch(version_entry, columns)

            if tags is not None:
                self.yaml_handler.update_version_tags(version_entry, tags)
        else:
            # Non-versioned model
            if description is not None:
                self.yaml_handler.update_model_description(model_entry, description)

            self.yaml_handler.update_columns_batch(model_entry, columns)

            if tags is not None:
                self.yaml_handler.update_model_tags(model_entry, tags)

        self.yaml_handler.save_file(yml_path, data)
        return Path(yml_path)
528
+ def _parse_ref(self, ref_value: str) -> tuple[str, Optional[str]]:
529
+ """
530
+ Parse ref() targets, supporting optional version arguments.
531
+
532
+ Examples:
533
+ ref('player') -> ("player", None)
534
+ ref('player', v=1) -> ("player", "1")
535
+ ref("player", version=2) -> ("player", "2")
536
+ """
537
+ ref_pattern = (
538
+ r"ref\(\s*['\"]([^,'\"]+)['\"](?:\s*,\s*(?:v|version)\s*=\s*([0-9]+))?\s*\)"
539
+ )
540
+ match = re.fullmatch(ref_pattern, ref_value.strip())
541
+ if match:
542
+ return match.group(1), match.group(2)
543
+ return ref_value, None
544
+
545
+ def _resolve_entity_id(
546
+ self, model_to_entity: dict[str, str], base_name: str, version: Optional[str]
547
+ ) -> str:
548
+ """
549
+ Resolve an entity id using base model name plus optional version.
550
+ Falls back to the base name if no mapping is found.
551
+ """
552
+ for key in self._build_model_keys(base_name, version):
553
+ if key in model_to_entity:
554
+ return model_to_entity[key]
555
+ # Fallback: try the raw base name
556
+ return model_to_entity.get(base_name, base_name)
557
+
558
+ def infer_relationships(self, include_unbound: bool = False) -> list[Relationship]:
559
+ """Scan dbt yml files and infer entity relationships from relationship tests.
560
+
561
+ When include_unbound=True, returns ALL relationships found in dbt yml files
562
+ using raw model names. The frontend is responsible for mapping model names
563
+ to entity IDs based on current canvas state (which may not be saved yet).
564
+ """
565
+ model_dirs = self.get_model_dirs()
566
+ model_to_entity = self._get_model_to_entity_map()
567
+
568
+ # Only keep relationships where both ends map to entities that are bound to
569
+ # at least one dbt model (including additional_models). This prevents writing
570
+ # relationships for unbound entities in large projects.
571
+ # When include_unbound is True, skip filtering entirely - return all relationships
572
+ # using raw model names so the frontend can map them to current canvas state.
573
+ bound_entities: set[str] = set()
574
+ if not include_unbound:
575
+ data_model = self._load_data_model()
576
+ bound_entities = {
577
+ e.get("id")
578
+ for e in data_model.get("entities", [])
579
+ if e.get("id") and (e.get("dbt_model") or e.get("additional_models"))
580
+ }
581
+
582
+ relationships: list[Relationship] = []
583
+ yml_found = False
584
+
585
+ for models_dir in model_dirs:
586
+ if not os.path.exists(models_dir):
587
+ continue
588
+
589
+ for root, _, files in os.walk(models_dir):
590
+ for filename in files:
591
+ if not filename.endswith((".yml", ".yaml")):
592
+ continue
593
+ yml_found = True
594
+
595
+ filepath = os.path.join(root, filename)
596
+ try:
597
+ with open(filepath, "r") as f:
598
+ schema_data = yaml.safe_load(f) or {}
599
+
600
+ models_list = schema_data.get("models", [])
601
+ for model in models_list:
602
+ base_model_name = model.get("name")
603
+ if not base_model_name:
604
+ continue
605
+
606
+ # Versioned models may declare columns inside versions list
607
+ version_entries = model.get("versions", [])
608
+ versioned_columns = []
609
+ base_columns = model.get("columns", []) or []
610
+
611
+ if isinstance(version_entries, list) and version_entries:
612
+ # Carry forward columns when a version uses "include: all"
613
+ previous_columns = copy.deepcopy(base_columns)
614
+ for ver in version_entries:
615
+ raw_columns = ver.get("columns", []) or []
616
+ expanded_columns: list[dict] = []
617
+ for col in raw_columns:
618
+ if isinstance(col, dict) and col.get("include") == "all":
619
+ expanded_columns.extend(copy.deepcopy(previous_columns))
620
+ else:
621
+ expanded_columns.append(col)
622
+
623
+ # If no columns are explicitly provided, fall back to previous set
624
+ if not expanded_columns and previous_columns:
625
+ expanded_columns = copy.deepcopy(previous_columns)
626
+
627
+ versioned_columns.append(
628
+ (ver.get("v") or ver.get("version"), expanded_columns)
629
+ )
630
+ previous_columns = copy.deepcopy(expanded_columns)
631
+ else:
632
+ versioned_columns.append((None, base_columns))
633
+
634
+ for model_version, columns in versioned_columns:
635
+ # When include_unbound, use raw model name so frontend can remap
636
+ # Otherwise resolve to entity ID from saved data model
637
+ if include_unbound:
638
+ entity_id = base_model_name
639
+ else:
640
+ entity_id = self._resolve_entity_id(
641
+ model_to_entity, base_model_name, model_version
642
+ )
643
+
644
+ for column in columns or []:
645
+ test_blocks = []
646
+ for key in ("tests", "data_tests"):
647
+ value = column.get(key, [])
648
+ if isinstance(value, list):
649
+ test_blocks.extend(value)
650
+
651
+ for test in test_blocks:
652
+ if (
653
+ not isinstance(test, dict)
654
+ or "relationships" not in test
655
+ ):
656
+ continue
657
+
658
+ rel_test = test["relationships"]
659
+ args = rel_test.get("arguments", {}) or {}
660
+
661
+ # Support both the recommended arguments block and legacy top-level keys
662
+ to_ref = rel_test.get("to", "") or args.get(
663
+ "to", ""
664
+ )
665
+ target_field = rel_test.get(
666
+ "field", ""
667
+ ) or args.get("field", "")
668
+
669
+ # If either ref target or field is missing, skip and log for debugging
670
+ if not to_ref or not target_field:
671
+ continue
672
+
673
+ target_base, target_version = self._parse_ref(
674
+ to_ref
675
+ )
676
+
677
+ # When include_unbound, use raw model name
678
+ if include_unbound:
679
+ target_entity_id = target_base
680
+ else:
681
+ target_entity_id = self._resolve_entity_id(
682
+ model_to_entity, target_base, target_version
683
+ )
684
+
685
+ # Skip relationships where either side is not bound
686
+ if (
687
+ entity_id not in bound_entities
688
+ or target_entity_id not in bound_entities
689
+ ):
690
+ continue
691
+
692
+ relationships.append(
693
+ {
694
+ "source": target_entity_id,
695
+ "target": entity_id,
696
+ "label": "",
697
+ "type": "one_to_many",
698
+ "source_field": target_field,
699
+ "target_field": column.get("name"),
700
+ }
701
+ )
702
+ except Exception as e:
703
+ print(f"Warning: Could not parse {filepath}: {e}")
704
+ continue
705
+
706
+ if not yml_found:
707
+ raise FileNotFoundError(
708
+ f"No schema yml files found under configured dbt model paths: {model_dirs}"
709
+ )
710
+
711
+ # Remove duplicates
712
+ seen: set[tuple] = set()
713
+ unique_relationships: list[Relationship] = []
714
+ for rel in relationships:
715
+ key = (
716
+ rel["source"],
717
+ rel["target"],
718
+ rel.get("source_field"),
719
+ rel.get("target_field"),
720
+ )
721
+ if key not in seen:
722
+ seen.add(key)
723
+ unique_relationships.append(rel)
724
+
725
+ return unique_relationships
726
+
727
    def sync_relationships(
        self,
        entities: list[dict[str, Any]],
        relationships: list[dict[str, Any]],
    ) -> list[Path]:
        """Sync relationship definitions from data model to dbt yml files.

        Writes descriptions, tags, drafted fields and relationship tests into
        one YAML file per entity (the manifest-derived file for bound
        entities, models_dir/{entity_id}.yml otherwise).

        Args:
            entities: Data-model entity dicts (id, dbt_model, description, ...).
            relationships: Data-model relationship dicts (source, target,
                type, source_field, target_field).

        Returns:
            Paths of every YAML file written.
        """
        # Build entity lookup
        entity_map = {e["id"]: e for e in entities if e.get("id")}
        # Map entity_id -> dbt model name (best-effort)
        entity_model_name: dict[str, str] = {
            eid: self._entity_to_model_name(ent) for eid, ent in entity_map.items()
        }

        # Group relationships by target entity (the one with the FK)
        fk_by_entity: dict[str, list[dict]] = {}

        for rel in relationships:
            source_id = rel.get("source")
            target_id = rel.get("target")
            rel_type = rel.get("type", "one_to_many")
            source_field = rel.get("source_field")
            target_field = rel.get("target_field")

            # A relationship test needs concrete fields on both ends.
            if not source_field or not target_field:
                continue

            # For one_to_many the FK lives on the target; otherwise on the source.
            fk_on_target = rel_type == "one_to_many"
            fk_entity = target_id if fk_on_target else source_id
            fk_field = target_field if fk_on_target else source_field
            ref_entity = source_id if fk_on_target else target_id
            ref_field = source_field if fk_on_target else target_field

            fk_by_entity.setdefault(fk_entity, []).append(
                {
                    "fk_field": fk_field,
                    "ref_entity": ref_entity,
                    "ref_field": ref_field,
                }
            )

        # Fallback output directory for entities without a manifest-derived path.
        models_dir = self.get_model_dirs()[0]
        os.makedirs(models_dir, exist_ok=True)

        updated_files: list[Path] = []

        for entity in entities:
            entity_id = entity.get("id")
            if not entity_id:
                continue

            model_name = entity_model_name.get(entity_id, entity_id)

            # For bound entities, use the correct path from manifest
            # For unbound entities, fall back to models_dir/{entity_id}.yml
            yml_path = None
            if entity.get("dbt_model"):
                yml_path = self._get_model_yml_path(model_name)
            if not yml_path:
                yml_path = os.path.join(models_dir, f"{entity_id}.yml")

            data = self.yaml_handler.load_file(yml_path)
            if not data:
                # Start a fresh dbt schema file skeleton.
                data = {"version": 2, "models": []}

            model_entry = self.yaml_handler.ensure_model(data, model_name)

            if entity.get("description"):
                self.yaml_handler.update_model_description(
                    model_entry, entity.get("description")
                )

            # Sync Tags
            entity_tags = entity.get("tags")
            if entity_tags is not None:
                self.yaml_handler.update_model_tags(model_entry, entity_tags)

            # Sync Drafted Fields
            drafted_fields = entity.get("drafted_fields", [])
            for field in drafted_fields:
                f_name = field.get("name")
                f_type = field.get("datatype")
                f_desc = field.get("description")

                if not f_name:
                    continue

                col = self.yaml_handler.ensure_column(model_entry, f_name)
                self.yaml_handler.update_column(
                    col, data_type=f_type, description=f_desc
                )

            # Sync Relationships (FKs)
            fk_list = fk_by_entity.get(entity_id, [])
            for fk_info in fk_list:
                fk_field = fk_info["fk_field"]
                ref_entity = fk_info["ref_entity"]
                ref_field = fk_info["ref_field"]

                # Resolve reference model name (dbt model) for relationship test
                ref_model_name = entity_model_name.get(ref_entity, ref_entity)

                col = self.yaml_handler.ensure_column(model_entry, fk_field)

                # Ensure the FK column has some data type ("text" as a default).
                if "data_type" not in col:
                    col["data_type"] = "text"

                self.yaml_handler.add_relationship_test(col, ref_model_name, ref_field)

            self.yaml_handler.save_file(yml_path, data)
            updated_files.append(Path(yml_path))

        return updated_files
840
+ def save_dbt_schema(
841
+ self,
842
+ entity_id: str,
843
+ model_name: str,
844
+ fields: list[dict[str, str]],
845
+ description: Optional[str] = None,
846
+ tags: Optional[list[str]] = None,
847
+ ) -> Path:
848
+ """
849
+ Generate and save a dbt schema YAML file for drafted fields.
850
+
851
+ This is used for creating new schema files from the data model editor.
852
+ """
853
+ data_model = self._load_data_model()
854
+
855
+ # Use description from request if available, otherwise fallback to data model
856
+ entity_description = description
857
+ if not entity_description:
858
+ entities = data_model.get("entities", [])
859
+ for entity in entities:
860
+ if entity.get("id") == entity_id:
861
+ entity_description = entity.get("description")
862
+ break
863
+
864
+ # Build a map of field names to relationships for this entity
865
+ relationships = data_model.get("relationships", [])
866
+ field_to_relationship: dict[str, dict] = {}
867
+ # Map entity -> dbt model name for refs
868
+ entity_model_name = {
869
+ e.get("id"): self._entity_to_model_name(e)
870
+ for e in data_model.get("entities", [])
871
+ if e.get("id")
872
+ }
873
+
874
+ for rel in relationships:
875
+ source_id = rel.get("source")
876
+ target_id = rel.get("target")
877
+ rel_type = rel.get("type", "one_to_many")
878
+ source_field = rel.get("source_field")
879
+ target_field = rel.get("target_field")
880
+
881
+ if not source_field or not target_field:
882
+ continue
883
+
884
+ fk_on_target = rel_type == "one_to_many"
885
+ fk_entity = target_id if fk_on_target else source_id
886
+ fk_field = target_field if fk_on_target else source_field
887
+ ref_entity = source_id if fk_on_target else target_id
888
+ ref_field = source_field if fk_on_target else target_field
889
+
890
+ if fk_entity == entity_id:
891
+ field_to_relationship[fk_field] = {
892
+ "target_entity": ref_entity,
893
+ "target_field": ref_field,
894
+ }
895
+
896
+ # Generate YAML content with relationship tests
897
+ columns = []
898
+ for field in fields:
899
+ column_dict: dict[str, Any] = {
900
+ "name": field["name"],
901
+ "data_type": field["datatype"],
902
+ }
903
+
904
+ if field.get("description"):
905
+ column_dict["description"] = field["description"]
906
+
907
+ field_name = field["name"]
908
+ if field_name in field_to_relationship:
909
+ rel_info = field_to_relationship[field_name]
910
+ ref_model = entity_model_name.get(
911
+ rel_info["target_entity"], rel_info["target_entity"]
912
+ )
913
+ column_dict["data_tests"] = [
914
+ {
915
+ "relationships": {
916
+ "arguments": {
917
+ "to": f"ref('{ref_model}')",
918
+ "field": rel_info["target_field"],
919
+ }
920
+ }
921
+ }
922
+ ]
923
+
924
+ columns.append(column_dict)
925
+
926
+ target_version = self._resolve_model_version(
927
+ model_name=model_name, entity_id=entity_id, data_model=data_model
928
+ )
929
+
930
+ # Prefer manifest-derived path; fall back to default models directory
931
+ yml_path = self._get_model_yml_path(model_name, target_version=target_version)
932
+ if not yml_path:
933
+ models_dir = self.get_model_dirs()[0]
934
+ os.makedirs(models_dir, exist_ok=True)
935
+ yml_path = os.path.join(models_dir, f"{entity_id}.yml")
936
+
937
+ data = self.yaml_handler.load_file(yml_path)
938
+ if not data:
939
+ data = {"version": 2, "models": []}
940
+
941
+ model_entry = self.yaml_handler.ensure_model(data, model_name)
942
+
943
+ if target_version is not None:
944
+ # Versioned model: update version block and latest_version
945
+ self.yaml_handler.set_latest_version(model_entry, target_version)
946
+ if entity_description:
947
+ self.yaml_handler.update_model_description(
948
+ model_entry, entity_description
949
+ )
950
+
951
+ version_entry = self.yaml_handler.ensure_model_version(
952
+ model_entry, target_version
953
+ )
954
+ self.yaml_handler.update_columns_batch(version_entry, columns)
955
+
956
+ if entity_description:
957
+ self.yaml_handler.update_model_description(
958
+ version_entry, entity_description
959
+ )
960
+ if tags is not None:
961
+ self.yaml_handler.update_version_tags(version_entry, tags)
962
+ else:
963
+ # Non-versioned model: update columns and metadata directly
964
+ if entity_description:
965
+ self.yaml_handler.update_model_description(
966
+ model_entry, entity_description
967
+ )
968
+
969
+ self.yaml_handler.update_columns_batch(model_entry, columns)
970
+
971
+ if tags is not None:
972
+ self.yaml_handler.update_model_tags(model_entry, tags)
973
+
974
+ self.yaml_handler.save_file(yml_path, data)
975
+ return Path(yml_path)