clickzetta-semantic-model-generator 1.0.3__py3-none-any.whl → 1.0.5__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in that public registry.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: clickzetta-semantic-model-generator
-Version: 1.0.3
+Version: 1.0.5
 Summary: Curate a Semantic Model for ClickZetta Lakehouse
 License: Apache Software License; BSD License
 Author: qililiang
@@ -1,13 +1,13 @@
 semantic_model_generator/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-semantic_model_generator/clickzetta_utils/clickzetta_connector.py,sha256=rFBWdNQerLYinn6RoDV_J4k2G4LLofiFkLDa7j8hmng,32888
+semantic_model_generator/clickzetta_utils/clickzetta_connector.py,sha256=z8WYF2Ft2_u4JJsbaaN64IW-bIaiV9Bkv6e1pes3PdU,33777
 semantic_model_generator/clickzetta_utils/env_vars.py,sha256=8cbL6R75c1-aVQ2i1TDr9SiHCUjTrgvXbIRz4MbcmbE,7664
-semantic_model_generator/clickzetta_utils/utils.py,sha256=D0SX2faBjwvhFJLt1Yk4mlZmyHmQt7LN93Jrc5YIU-A,3800
+semantic_model_generator/clickzetta_utils/utils.py,sha256=UBfWy9qOTyut8tL02gOHHbh6Uz8RqRz5Mm2YdKWFN54,4950
 semantic_model_generator/data_processing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-semantic_model_generator/data_processing/cte_utils.py,sha256=jfTJIwc89-0nnelVw_5vpIVRout7V0YooUDfZzTzDr4,16086
+semantic_model_generator/data_processing/cte_utils.py,sha256=-kQw_PfPPe3mf7shQf1XV5rqfqYdB9WK4A-EwAcKc_o,16928
 semantic_model_generator/data_processing/cte_utils_test.py,sha256=l6QkyyH22FexLKjvvbS9Je3YtdTrJE3a-BiknCy1g9s,2822
 semantic_model_generator/data_processing/data_types.py,sha256=1HsSCkdCWvcXiwN3o1-HVQi_ZVIR0lYevXG9CE1TvRc,1172
 semantic_model_generator/data_processing/proto_utils.py,sha256=UwqCfQYilTx68KcA4IYZN7PeM4Pz_pK1h0FrVJomzV8,2938
-semantic_model_generator/generate_model.py,sha256=ogNvx1HNOnC5KIZlGDwcWL7PLMHRs8zcZZbwricffDo,121843
+semantic_model_generator/generate_model.py,sha256=vwISWJzYf4XS1TuLclpxKbberlsRKM99olrFlWaTCUw,125549
 semantic_model_generator/llm/__init__.py,sha256=rLQt2pzRmxtnBLKjxN_qZ2a_nvkFHtmguU5lyajCldw,1030
 semantic_model_generator/llm/dashscope_client.py,sha256=lHS36iqNZbFhwgidPpW1Bwwy4S2O7GeLyMSMdlSoBsY,6050
 semantic_model_generator/llm/enrichment.py,sha256=49e9Jg_jHfhUIEQ3JserEc5DV5sFWA12K76TY4UwnCg,41448
@@ -16,13 +16,13 @@ semantic_model_generator/output_models/.keep,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRk
 semantic_model_generator/protos/semantic_model.proto,sha256=WZiN4b8vR-ZX-Lj9Vsm6HjZNAyNvM1znIyut_YkPVSI,16473
 semantic_model_generator/protos/semantic_model_pb2.py,sha256=scbWkW-I-r3_hp_5SHoOWn02p52RJ9DJ0_-nRgr0LHc,25606
 semantic_model_generator/protos/semantic_model_pb2.pyi,sha256=iiBIZxtX9d6IuUO3aLcsJsHUeZqdi14vYNuUsSM8C0g,18267
-semantic_model_generator/relationships/__init__.py,sha256=HN6Opie25Oawt2fCDM_bZwRBVBEzqRsEXgDzYC7ytns,373
-semantic_model_generator/relationships/discovery.py,sha256=l_CixbfRvHBqxmLCmCq7bvQHRt3iUl0o5mui4R5LHXQ,5961
+semantic_model_generator/relationships/__init__.py,sha256=I9-_QJdp36nEllzKTGXi2aWbRjiXrrexQXUfB6mi3Ww,477
+semantic_model_generator/relationships/discovery.py,sha256=aw3LrthDZ6ng9P5eI3noxw-1E30csYqe2kyGn6CpLZA,13125
 semantic_model_generator/tests/clickzetta_connector_test.py,sha256=Fdx7jooNt1lslKB2Ub51wqOZ8OM0osgZiDDl3bV6riw,3086
-semantic_model_generator/tests/cte_utils_test.py,sha256=LdhWw_bHZDE1LyS2hBVy_VTNjLgodonesWaxw8jXpV4,17385
+semantic_model_generator/tests/cte_utils_test.py,sha256=_9GAJiOPGSagdWmQsoAEOOhEgsBY0LFlr_xtwrlgf4A,17561
 semantic_model_generator/tests/generate_model_classification_test.py,sha256=Amq29cmeKd0S7iVikJ60RFm9gpWaQv1TijXofp3J-lI,2275
 semantic_model_generator/tests/llm_enrichment_test.py,sha256=1avLrPWp7J7o_K3PKbI_PIvduM5Id21MmoL0JTeDTfs,15738
-semantic_model_generator/tests/relationship_discovery_test.py,sha256=SOuXCwbmSUgvZoOS2s5oGK1w0LW283M1hg--QlLaDVA,3490
+semantic_model_generator/tests/relationship_discovery_test.py,sha256=CBeQVfd9XT5haXpNs6tsccH79v8zDa6abnUYL8f2gSs,6829
 semantic_model_generator/tests/relationships_filters_test.py,sha256=bUm3r1UGaXca-hJOot7jMPz4It_TVsoddd-Xpk-76zM,10166
 semantic_model_generator/tests/samples/validate_yamls.py,sha256=262j-2i2oFZtTyK2susOrbxxE5eS-6IN-V0jFEOpt_w,156249
 semantic_model_generator/tests/utils_test.py,sha256=HWRXR45QYL1f6L8xsMppqLXzF9HAsrMwTMQIKpZrc_M,539
@@ -32,7 +32,7 @@ semantic_model_generator/validate/context_length.py,sha256=HL-GfaRXNcVji1-pAFGXG
 semantic_model_generator/validate/keywords.py,sha256=frZ5HjRXP69K6dYAU5_d86oSp40_3yoLUg1eQwU3oLM,7080
 semantic_model_generator/validate/schema.py,sha256=eL_wl5yscIeczwNBRUKhF_7QqWW2wSGimkgaOhMFsrA,5893
 semantic_model_generator/validate_model.py,sha256=Uq-V-GfPeF2Dy4l9uF5Guv104gDCDGh0Cxz1AJOu5dk,836
-clickzetta_semantic_model_generator-1.0.3.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-clickzetta_semantic_model_generator-1.0.3.dist-info/METADATA,sha256=A1kBc4PO_LEbIjWM-24jHnnV6NynmowuX5Jy91tlWBk,7816
-clickzetta_semantic_model_generator-1.0.3.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-clickzetta_semantic_model_generator-1.0.3.dist-info/RECORD,,
+clickzetta_semantic_model_generator-1.0.5.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+clickzetta_semantic_model_generator-1.0.5.dist-info/METADATA,sha256=rxOjgbcKvTYIapoteFS2Lz9E1388cFCCpZPa4VjjcrE,7816
+clickzetta_semantic_model_generator-1.0.5.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+clickzetta_semantic_model_generator-1.0.5.dist-info/RECORD,,
@@ -11,7 +11,12 @@ from clickzetta.zettapark.session import Session
 from loguru import logger
 
 from semantic_model_generator.clickzetta_utils import env_vars
-from semantic_model_generator.clickzetta_utils.utils import create_session
+from semantic_model_generator.clickzetta_utils.utils import (
+    create_session,
+    join_quoted_identifiers,
+    normalize_identifier,
+    quote_identifier,
+)
 from semantic_model_generator.data_processing.data_types import Column, Table
 
 ConnectionType = TypeVar("ConnectionType", bound=Session)
@@ -151,18 +156,8 @@ class ClickzettaConnectionProxy:
         self.session.close()
 
 
-def _quote_identifier(name: str) -> str:
-    return f'"{name}"'
-
-
 def _qualify_table(workspace: str, schema_name: str, table_name: str) -> str:
-    return ".".join(
-        [
-            _quote_identifier(workspace),
-            _quote_identifier(schema_name),
-            _quote_identifier(table_name),
-        ]
-    )
+    return join_quoted_identifiers(workspace, schema_name, table_name)
 
 
 def _value_is_true(value: Any) -> bool:
@@ -175,11 +170,9 @@ def _value_is_true(value: Any) -> bool:
 
 
 def _sanitize_identifier(value: Any, fallback: str = "") -> str:
-    if value is None or value == "":
+    normalized = normalize_identifier(value)
+    if not normalized:
         return fallback
-    normalized = str(value).strip()
-    if normalized.startswith('"') and normalized.endswith('"') and len(normalized) >= 2:
-        normalized = normalized[1:-1]
     return normalized
 
 
@@ -216,21 +209,19 @@ def _fetch_distinct_values(
     column_name: str,
     ndv: int,
 ) -> Optional[List[str]]:
-    workspace_part = (
-        _sanitize_identifier(workspace, workspace).upper() if workspace else ""
-    )
+    workspace_part = _sanitize_identifier(workspace, workspace) if workspace else ""
     schema_part = (
-        _sanitize_identifier(schema_name, schema_name).upper() if schema_name else ""
+        _sanitize_identifier(schema_name, schema_name) if schema_name else ""
     )
-    table_part = _sanitize_identifier(table_name, table_name).upper()
-    column_part = _sanitize_identifier(column_name, column_name).upper()
+    table_part = _sanitize_identifier(table_name, table_name)
+    column_part = _sanitize_identifier(column_name, column_name)
 
-    qualified_parts = [
-        part for part in (workspace_part, schema_part, table_part) if part
-    ]
-    qualified_table = ".".join(qualified_parts)
+    qualified_table = join_quoted_identifiers(
+        workspace_part, schema_part, table_part
+    )
+    column_expr = quote_identifier(column_part)
 
-    query = f"SELECT DISTINCT {column_part} FROM {qualified_table} LIMIT {ndv}"
+    query = f"SELECT DISTINCT {column_expr} FROM {qualified_table} LIMIT {ndv}"
     try:
         df = session.sql(query).to_pandas()
         if df.empty:
@@ -489,15 +480,30 @@ def _fetch_columns_via_show(
         return pd.DataFrame()
 
     rows: List[pd.DataFrame] = []
-    catalog = workspace.upper()
-    schema = table_schema.upper() if table_schema else ""
+    category = _catalog_category(session, workspace)
+    is_shared_catalog = category in {"SHARED", "EXTERNAL"}
+    catalog = workspace if is_shared_catalog else workspace.upper()
+    schema = (
+        table_schema or ""
+    )
+    if schema and not is_shared_catalog:
+        schema = schema.upper()
 
     for table_name in table_names:
         qualified_parts = [
-            part for part in (catalog, schema, table_name.upper()) if part
+            part
+            for part in (
+                catalog,
+                schema,
+                table_name.upper() if not is_shared_catalog else table_name,
+            )
+            if part
        ]
        qualified_table = ".".join(qualified_parts)
-        query = f"SHOW COLUMNS IN {qualified_table}"
+        if is_shared_catalog:
+            query = f"SHOW COLUMNS IN SHARE {qualified_table}"
+        else:
+            query = f"SHOW COLUMNS IN {qualified_table}"
        try:
            df = session.sql(query).to_pandas()
        except Exception as exc:
@@ -655,14 +661,25 @@ def fetch_tables_views_in_schema(
     parts = schema_name.split(".", maxsplit=1)
     workspace = parts[0]
     schema = parts[1] if len(parts) > 1 else ""
-    workspace_upper = workspace.upper()
-    schema_upper = schema.upper()
+    category = _catalog_category(session, workspace)
+    is_shared_catalog = category in {"SHARED", "EXTERNAL"}
+
+    workspace_token = workspace if is_shared_catalog else workspace.upper()
+    schema_token = schema if is_shared_catalog else schema.upper()
 
     try:
-        if workspace_upper and schema_upper:
-            df = session.sql(
-                f"SHOW TABLES IN {workspace_upper}.{schema_upper}"
-            ).to_pandas()
+        if workspace_token and schema_token:
+            if is_shared_catalog:
+                scope = ".".join(
+                    part for part in (workspace_token, schema_token) if part
+                )
+                df = session.sql(f"SHOW TABLES IN SHARE {scope}").to_pandas()
+            else:
+                scope = join_quoted_identifiers(
+                    workspace_token,
+                    schema_token,
+                )
+                df = session.sql(f"SHOW TABLES IN {scope}").to_pandas()
         else:
             df = session.sql("SHOW TABLES").to_pandas()
     except Exception as exc: # pragma: no cover
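Note (illustration, not part of the packaged diff): the branch above now emits two different query shapes. A rough sketch with made-up identifiers, assuming _catalog_category() reports "SHARED" for the second case:

    # Regular catalog: tokens are upper-cased and backtick-quoted via join_quoted_identifiers.
    query = "SHOW TABLES IN `ANALYTICS`.`SALES`"
    # Shared or external catalog: original casing is kept and the SHARE keyword is added.
    query = "SHOW TABLES IN SHARE analytics.sales"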
@@ -738,11 +755,15 @@ def fetch_stages_in_schema(connection: Any, schema_name: str) -> List[str]:
 
     queries: List[str] = []
     if schema:
-        queries.append(f"SHOW VOLUMES IN {workspace}.{schema}")
-        queries.append(f"SHOW STAGES IN SCHEMA {workspace}.{schema}")
+        scope = join_quoted_identifiers(workspace, schema)
+        if scope:
+            queries.append(f"SHOW VOLUMES IN {scope}")
+            queries.append(f"SHOW STAGES IN SCHEMA {scope}")
     else:
-        queries.append(f"SHOW VOLUMES IN {workspace}")
-        queries.append(f"SHOW STAGES IN DATABASE {workspace}")
+        workspace_identifier = quote_identifier(workspace)
+        if workspace_identifier:
+            queries.append(f"SHOW VOLUMES IN {workspace_identifier}")
+            queries.append(f"SHOW STAGES IN DATABASE {workspace_identifier}")
 
     stage_names: List[str] = ["volume:user://~/semantic_models/"]
     seen: set[str] = set(stage_names)
@@ -899,7 +920,7 @@ def create_table_in_schema(
     columns_schema: Dict[str, str],
 ) -> bool:
     fields = ", ".join(
-        f"{_quote_identifier(name)} {dtype}" for name, dtype in columns_schema.items()
+        f"{quote_identifier(name)} {dtype}" for name, dtype in columns_schema.items()
     )
     query = f"CREATE TABLE IF NOT EXISTS {table_fqn} ({fields})"
     try:
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from contextlib import contextmanager
-from typing import Dict, Iterable
+from typing import Any, Dict, Iterable
 
 from clickzetta.zettapark.session import Session
 
@@ -21,6 +21,47 @@ DEFAULT_HINTS: Dict[str, str] = {
 }
 
 
+def normalize_identifier(value: Any) -> str:
+    """
+    Strips outer quotes/backticks and surrounding whitespace from an identifier.
+    Returns an empty string when the identifier is missing.
+    """
+
+    if value is None:
+        return ""
+    text = str(value).strip()
+    if len(text) >= 2 and text[0] == text[-1] and text[0] in {'"', '`'}:
+        return text[1:-1]
+    return text
+
+
+def quote_identifier(value: Any) -> str:
+    """
+    Wraps an identifier in backticks, escaping embedded backticks as needed.
+    Returns an empty string if the identifier is missing.
+    """
+
+    normalized = normalize_identifier(value)
+    if not normalized:
+        return ""
+    escaped = normalized.replace("`", "``")
+    return f"`{escaped}`"
+
+
+def join_quoted_identifiers(*parts: Any) -> str:
+    """
+    Joins identifier parts with '.' and ensures each segment is backtick-quoted.
+    Empty segments are skipped.
+    """
+
+    quoted_parts = [
+        quote_identifier(part)
+        for part in parts
+        if normalize_identifier(part)
+    ]
+    return ".".join(part for part in quoted_parts if part)
+
+
 def create_fqn_table(fqn_str: str) -> FQNParts:
     """
     Splits a fully qualified table name into its ClickZetta components.
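For context (illustration only, not part of the packaged diff): the expected behavior of the three new helpers follows directly from the code above. A minimal sketch:

    from semantic_model_generator.clickzetta_utils.utils import (
        join_quoted_identifiers,
        normalize_identifier,
        quote_identifier,
    )

    # One layer of surrounding quotes/backticks plus whitespace is stripped.
    assert normalize_identifier(' "Orders" ') == "Orders"
    assert normalize_identifier(None) == ""
    # Identifiers are wrapped in backticks, with embedded backticks escaped.
    assert quote_identifier("my table") == "`my table`"
    assert quote_identifier("we`ird") == "`we``ird`"
    # Empty segments are skipped when joining.
    assert join_quoted_identifiers("demo", "", "orders") == "`demo`.`orders`"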
@@ -72,7 +113,8 @@ def _apply_session_context(session: Session, *, schema: str, vcluster: str) -> N
         ("schema", schema),
         ("vcluster", vcluster),
     ):
-        session.sql(f"USE {component.upper()} {value.upper()}")
+        identifier = quote_identifier(value)
+        session.sql(f"USE {component.upper()} {identifier}")
 
 
 def _iter_non_empty(*pairs: tuple[str, str]) -> Iterable[tuple[str, str]]:
@@ -11,12 +11,34 @@ from sqlglot import Dialect
 from semantic_model_generator.clickzetta_utils.clickzetta_connector import (
     OBJECT_DATATYPES,
 )
+from semantic_model_generator.clickzetta_utils.utils import (
+    join_quoted_identifiers,
+    normalize_identifier,
+)
 from semantic_model_generator.protos import semantic_model_pb2
 
 _SQLGLOT_CLICKZETTA_KEY = "".join(["snow", "flake"])
 ClickzettaDialect = Dialect.get_or_raise(_SQLGLOT_CLICKZETTA_KEY)
 
 _LOGICAL_TABLE_PREFIX = "__"
+_SQLGLOT_QUOTE_CHAR = '"'
+
+
+def _prepare_sql_for_parsing(sql: str) -> str:
+    """
+    Converts backtick-quoted identifiers to double quotes for SQLGlot parsing.
+    """
+
+    return sql.replace("`", _SQLGLOT_QUOTE_CHAR)
+
+
+def _render_clickzetta_sql(expression: sqlglot.Expression, *, pretty: bool = False) -> str:
+    """
+    Renders a SQLGlot expression using ClickZetta dialect and rewrites identifiers with backticks.
+    """
+
+    rendered = expression.sql(dialect=ClickzettaDialect, pretty=pretty)
+    return rendered.replace(_SQLGLOT_QUOTE_CHAR, "`")
 
 
 def is_logical_table(table_name: str) -> bool:
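For context (illustration only): the two helpers above let backtick-quoted ClickZetta SQL round-trip through the SQLGlot dialect, which expects double-quoted identifiers. A minimal sketch of that round trip using a made-up statement:

    sql_in = "SELECT `amount` FROM `db`.`sc`.`t1`"
    prepared = sql_in.replace("`", '"')    # what _prepare_sql_for_parsing does
    assert prepared == 'SELECT "amount" FROM "db"."sc"."t1"'
    rendered = prepared.replace('"', "`")  # the final step of _render_clickzetta_sql
    assert rendered == sql_in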
@@ -33,12 +55,12 @@ def logical_table_name(table: semantic_model_pb2.Table) -> str:
 
 def fully_qualified_table_name(table: semantic_model_pb2.FullyQualifiedTable) -> str:
     """Returns fully qualified table name such as my_db.my_schema.my_table"""
-    fqn = table.table
-    if len(table.schema) > 0:
-        fqn = f"{table.schema}.{fqn}"
-    if len(table.database) > 0:
-        fqn = f"{table.database}.{fqn}"
-    return fqn # type: ignore[no-any-return]
+    parts = [
+        normalize_identifier(component)
+        for component in (table.database, table.schema, table.table)
+        if component
+    ]
+    return join_quoted_identifiers(*parts) # type: ignore[no-any-return]
 
 
 def is_aggregation_expr(col: semantic_model_pb2.Column) -> bool:
@@ -156,8 +178,8 @@ def _generate_cte_for(
     cte = f"WITH {logical_table_name(table)} AS (\n"
     cte += "SELECT \n"
     cte += ",\n".join(expr_columns) + "\n"
-    cte += f"FROM {fully_qualified_table_name(table.base_table)}"
-    cte += ")"
+    cte += f"FROM {fully_qualified_table_name(table.base_table)}\n"
+    cte += ")\n"
     return cte
 
 
@@ -261,13 +283,15 @@ def _convert_to_clickzetta_sql(sql: str) -> str:
        str: The SQL statement in ClickZetta syntax.
     """
     try:
-        expression = sqlglot.parse_one(sql, dialect=ClickzettaDialect)
+        expression = sqlglot.parse_one(
+            _prepare_sql_for_parsing(sql), dialect=ClickzettaDialect
+        )
     except Exception as e:
         raise ValueError(
             f"Unable to parse sql statement.\n Provided sql: {sql}\n. Error: {e}"
         )
 
-    return expression.sql(dialect=ClickzettaDialect)
+    return _render_clickzetta_sql(expression)
 
 
 def generate_select(
@@ -332,12 +356,16 @@ def expand_all_logical_tables_as_ctes(
     for cte in ctes:
         new_withs.append(
             sqlglot.parse_one(
-                cte, read=ClickzettaDialect, into=sqlglot.expressions.With
+                _prepare_sql_for_parsing(cte),
+                read=ClickzettaDialect,
+                into=sqlglot.expressions.With,
             )
         )
 
     # Step 3: Prefix the CTEs to the original query.
-    ast = sqlglot.parse_one(sql_query, read=ClickzettaDialect)
+    ast = sqlglot.parse_one(
+        _prepare_sql_for_parsing(sql_query), read=ClickzettaDialect
+    )
     with_ = ast.args.get("with")
     # If the query doesn't have a WITH clause, then generate one.
     if with_ is None:
@@ -349,7 +377,9 @@
     else:
         new_ctes = [w.expressions[0] for w in new_withs]
         with_.set("expressions", new_ctes + with_.expressions)
-    return ast.sql(dialect=ClickzettaDialect, pretty=True) # type: ignore [no-any-return]
+    return _render_clickzetta_sql(
+        ast, pretty=True
+    ) # type: ignore [no-any-return]
 
 
 def context_to_column_format(
@@ -1,6 +1,7 @@
 import math
 import os
 import re
+import time
 from collections import defaultdict
 from datetime import datetime
 from typing import Any, Callable, Dict, List, Optional, Tuple
@@ -17,7 +18,12 @@ from semantic_model_generator.clickzetta_utils.clickzetta_connector import (
     get_table_representation,
     get_valid_schemas_tables_columns_df,
 )
-from semantic_model_generator.clickzetta_utils.utils import create_fqn_table
+from semantic_model_generator.clickzetta_utils.utils import (
+    create_fqn_table,
+    join_quoted_identifiers,
+    normalize_identifier,
+    quote_identifier,
+)
 from semantic_model_generator.data_processing import data_types, proto_utils
 from semantic_model_generator.llm import (
     DashscopeClient,
@@ -41,6 +47,14 @@ _AUTOGEN_COMMENT_TOKEN = (
 )
 _DEFAULT_N_SAMPLE_VALUES_PER_COL = 10
 _AUTOGEN_COMMENT_WARNING = f"# NOTE: This file was auto-generated by the semantic model generator. Please fill out placeholders marked with {_FILL_OUT_TOKEN} (or remove if not relevant) and verify autogenerated comments.\n"
+_GENERIC_IDENTIFIER_TOKENS = {
+    "ID",
+    "NAME",
+    "CODE",
+    "KEY",
+    "VALUE",
+    "NUMBER",
+}
 
 
 def _singularize(token: str) -> str:
@@ -90,6 +104,14 @@ def _identifier_tokens(
     return tokens
 
 
+def _is_generic_identifier(name: str) -> bool:
+    tokens = [token for token in _identifier_tokens(name) if token]
+    if not tokens:
+        return True
+    normalized_tokens = {token.upper() for token in tokens}
+    return normalized_tokens.issubset(_GENERIC_IDENTIFIER_TOKENS)
+
+
 def _sanitize_identifier_name(
     name: str, prefixes_to_drop: Optional[set[str]] = None
 ) -> str:
@@ -354,19 +376,17 @@ def _format_literal(value: str, base_type: str) -> str:
 
 def _format_sql_identifier(name: str) -> str:
     """
-    Formats an identifier for SQL (without quoting) by stripping quotes and uppercasing.
+    Formats an identifier for SQL by wrapping it in backticks.
     """
-    if not name:
-        return ""
-    return str(name).replace('"', "").replace("`", "").strip().upper()
+    return quote_identifier(name)
 
 
 def _qualified_table_name(fqn: data_types.FQNParts) -> str:
     """
-    Builds a fully qualified table name without quoting.
+    Builds a fully qualified, backtick-quoted table name.
     """
-    parts = [part for part in (fqn.database, fqn.schema_name, fqn.table) if part]
-    return ".".join(_format_sql_identifier(part) for part in parts if part)
+    parts = [normalize_identifier(part) for part in (fqn.database, fqn.schema_name, fqn.table)]
+    return join_quoted_identifiers(*(part for part in parts if part))
 
 
 def _levenshtein_distance(s1: str, s2: str) -> int:
@@ -977,6 +997,19 @@ def _calculate_relationship_confidence(
 
     confidence_score += name_confidence
 
+    generic_pair_count = sum(
+        1
+        for left_col, right_col in column_pairs
+        if _is_generic_identifier(left_col)
+        and _is_generic_identifier(right_col)
+    )
+    if generic_pair_count:
+        penalty = min(0.15 * generic_pair_count, 0.3)
+        confidence_score = max(confidence_score - penalty, 0.0)
+        reasoning_factors.append(
+            f"Generic identifier names detected on both sides (-{penalty:.2f} confidence)"
+        )
+
     # Check for foreign key naming patterns
     fk_pattern_confidence = 0.0
     for left_col, right_col in column_pairs:
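A quick worked example of the new penalty (numbers are hypothetical, not taken from the package):

    # Two column pairs that are generic on both sides, e.g. id -> id and value -> value.
    generic_pair_count = 2
    penalty = min(0.15 * generic_pair_count, 0.3)  # capped at 0.3
    confidence_score = max(0.55 - penalty, 0.0)    # 0.55 is an assumed pre-penalty score
    print(round(penalty, 2), round(confidence_score, 2))  # 0.3 0.25
    # 0.25 sits below the 0.5 default threshold used by the discovery wrappers
    # later in this diff, so such a candidate relationship would be dropped.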
@@ -2326,11 +2359,31 @@ def _infer_relationships(
     *,
     session: Optional[Session] = None,
     strict_join_inference: bool = False,
+    status: Optional[Dict[str, bool]] = None,
+    max_relationships: Optional[int] = None,
+    min_confidence: float = 0.2,
+    timeout_seconds: Optional[float] = None,
 ) -> List[semantic_model_pb2.Relationship]:
+    status_dict = status if status is not None else {}
+    if "limited_by_timeout" not in status_dict:
+        status_dict["limited_by_timeout"] = False
+    if "limited_by_max_relationships" not in status_dict:
+        status_dict["limited_by_max_relationships"] = False
+
     relationships: List[semantic_model_pb2.Relationship] = []
     if not raw_tables:
         return relationships
 
+    start_time = time.perf_counter()
+    min_confidence = max(0.0, min(min_confidence, 1.0))
+    limit_reached = False
+
+    def _timed_out() -> bool:
+        return (
+            timeout_seconds is not None
+            and (time.perf_counter() - start_time) >= timeout_seconds
+        )
+
     metadata = {}
     prefix_counter: Dict[str, int] = {}
     for _, raw_table in raw_tables:
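For orientation (sketch only; it mirrors the call that _discover_relationships makes later in this diff), this is how the new keyword arguments and the status dictionary are consumed; the budget values are placeholders:

    status: dict = {}
    relationships = _infer_relationships(
        raw_tables,                   # list of (FQNParts, Table) pairs
        session=None,
        strict_join_inference=False,
        status=status,                # populated with limited_by_* flags
        max_relationships=50,
        min_confidence=0.5,
        timeout_seconds=30.0,
    )
    if status.get("limited_by_timeout"):
        pass  # inference stopped early; the returned relationships may be partial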
@@ -2392,14 +2445,39 @@ def _infer_relationships(
     def _record_pair(
         left_table: str, right_table: str, left_col: str, right_col: str
     ) -> None:
+        nonlocal limit_reached
+        if limit_reached:
+            return
+        if _timed_out():
+            status_dict["limited_by_timeout"] = True
+            limit_reached = True
+            return
+
         key = (left_table, right_table)
         value = (left_col, right_col)
-        if value not in pairs.setdefault(key, []):
-            pairs[key].append(value)
+        bucket = pairs.setdefault(key, [])
+        if value not in bucket:
+            bucket.append(value)
+            if (
+                max_relationships is not None
+                and len(pairs) >= max_relationships
+            ):
+                status_dict["limited_by_max_relationships"] = True
+                limit_reached = True
 
     table_names = list(metadata.keys())
     for i in range(len(table_names)):
+        if limit_reached or status_dict["limited_by_timeout"]:
+            break
+        if _timed_out():
+            status_dict["limited_by_timeout"] = True
+            break
         for j in range(i + 1, len(table_names)):
+            if limit_reached or status_dict["limited_by_timeout"]:
+                break
+            if _timed_out():
+                status_dict["limited_by_timeout"] = True
+                break
             table_a_name = table_names[i]
             table_b_name = table_names[j]
             table_a = metadata[table_a_name]
@@ -2575,6 +2653,15 @@ def _infer_relationships(
 
     # Build relationships with inferred cardinality
     for (left_table, right_table), column_pairs in pairs.items():
+        if _timed_out():
+            status_dict["limited_by_timeout"] = True
+            break
+        if (
+            max_relationships is not None
+            and len(relationships) >= max_relationships
+        ):
+            status_dict["limited_by_max_relationships"] = True
+            break
         # Infer cardinality based on available metadata
         left_meta = metadata[left_table]
         right_meta = metadata[right_table]
@@ -2777,6 +2864,16 @@ def _infer_relationships(
         for factor in confidence_analysis["reasoning_factors"][:3]: # Top 3 factors
             logger.debug(f" + {factor}")
 
+        if confidence_analysis["confidence_score"] < min_confidence:
+            logger.debug(
+                "Dropping relationship {} -> {} due to low confidence {:.2f} (threshold {:.2f})",
+                left_table,
+                right_table,
+                confidence_analysis["confidence_score"],
+                min_confidence,
+            )
+            continue
+
         # Determine relationship type based on cardinality
         if left_card == "1" and right_card == "1":
             rel_type = semantic_model_pb2.RelationshipType.one_to_one
@@ -2804,16 +2901,27 @@ def _infer_relationships(
         relationships.append(relationship)
 
     # Phase 2: Detect many-to-many relationships through bridge table analysis
-    many_to_many_relationships = _detect_many_to_many_relationships(
-        raw_tables, metadata, relationships
-    )
-
-    if many_to_many_relationships:
-        relationships.extend(many_to_many_relationships)
-        logger.info(
-            f"Detected {len(many_to_many_relationships)} many-to-many relationships via bridge tables"
+    many_to_many_relationships: List[semantic_model_pb2.Relationship] = []
+    if not status_dict["limited_by_timeout"] and (
+        max_relationships is None or len(relationships) < max_relationships
+    ):
+        many_to_many_relationships = _detect_many_to_many_relationships(
+            raw_tables, metadata, relationships
         )
 
+    if many_to_many_relationships and max_relationships is not None:
+        remaining = max_relationships - len(relationships)
+        if remaining <= 0:
+            many_to_many_relationships = []
+        else:
+            many_to_many_relationships = many_to_many_relationships[:remaining]
+
+    if many_to_many_relationships:
+        relationships.extend(many_to_many_relationships)
+        logger.info(
+            f"Detected {len(many_to_many_relationships)} many-to-many relationships via bridge tables"
+        )
+
     logger.info(
         f"Inferred {len(relationships)} total relationships across {len(raw_tables)} tables"
     )
@@ -4,6 +4,7 @@ from .discovery import (
     RelationshipDiscoveryResult,
     RelationshipSummary,
     discover_relationships_from_schema,
+    discover_relationships_from_table_definitions,
     discover_relationships_from_tables,
 )
 
@@ -11,5 +12,6 @@ __all__ = [
     "RelationshipDiscoveryResult",
     "RelationshipSummary",
     "discover_relationships_from_schema",
+    "discover_relationships_from_table_definitions",
     "discover_relationships_from_tables",
 ]
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import time
 from dataclasses import dataclass
-from typing import Any, Iterable, List, Optional, Sequence, Tuple
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
 
 import pandas as pd
 from loguru import logger
@@ -13,7 +13,7 @@ from semantic_model_generator.clickzetta_utils.clickzetta_connector import (
     get_valid_schemas_tables_columns_df,
 )
 from semantic_model_generator.data_processing import data_types
-from semantic_model_generator.data_processing.data_types import FQNParts, Table
+from semantic_model_generator.data_processing.data_types import Column, FQNParts, Table
 from semantic_model_generator.generate_model import (
     _DEFAULT_N_SAMPLE_VALUES_PER_COL,
     _infer_relationships,
@@ -34,6 +34,10 @@ class RelationshipSummary:
     total_columns: int
     total_relationships_found: int
     processing_time_ms: int
+    limited_by_timeout: bool = False
+    limited_by_max_relationships: bool = False
+    limited_by_table_cap: bool = False
+    notes: Optional[str] = None
 
 
 @dataclass
@@ -97,20 +101,139 @@ def _build_tables_from_dataframe(
     return tables
 
 
+def _tables_payload_to_raw_tables(
+    tables: Sequence[Mapping[str, Any]],
+    *,
+    default_workspace: str = "OFFLINE",
+    default_schema: str = "PUBLIC",
+) -> List[Tuple[FQNParts, Table]]:
+    raw_tables: List[Tuple[FQNParts, Table]] = []
+    for table_index, table_entry in enumerate(tables):
+        if not isinstance(table_entry, Mapping):
+            raise TypeError("Each table definition must be a mapping of table metadata")
+
+        raw_table_identifier = str(
+            table_entry.get("table_name")
+            or table_entry.get("name")
+            or table_entry.get("table")
+            or ""
+        ).strip()
+        if not raw_table_identifier:
+            raise ValueError("Table definition missing 'table_name'")
+
+        identifier_workspace, identifier_schema, identifier_table = _split_table_identifier(
+            raw_table_identifier
+        )
+
+        workspace = str(
+            table_entry.get("workspace")
+            or table_entry.get("database")
+            or identifier_workspace
+            or default_workspace
+        ).strip() or default_workspace
+        schema = str(
+            table_entry.get("schema")
+            or table_entry.get("schema_name")
+            or identifier_schema
+            or default_schema
+        ).strip() or default_schema
+
+        table_name = identifier_table.strip()
+        if not table_name:
+            raise ValueError(f"Unable to parse table name from '{raw_table_identifier}'")
+
+        columns_payload = table_entry.get("columns")
+        if not isinstance(columns_payload, Sequence) or not columns_payload:
+            raise ValueError(
+                f"Table '{table_name}' must include a non-empty 'columns' list"
+            )
+
+        columns: List[Column] = []
+        for column_index, column_entry in enumerate(columns_payload):
+            if not isinstance(column_entry, Mapping):
+                raise TypeError(
+                    f"Column definition for table '{table_name}' must be a mapping"
+                )
+
+            column_name = str(
+                column_entry.get("name")
+                or column_entry.get("column_name")
+                or column_entry.get("field")
+                or ""
+            ).strip()
+            if not column_name:
+                raise ValueError(
+                    f"Column definition in table '{table_name}' missing 'name'"
+                )
+
+            column_type = str(
+                column_entry.get("type")
+                or column_entry.get("data_type")
+                or "STRING"
+            ).strip()
+
+            values = column_entry.get("sample_values") or column_entry.get("values")
+            if isinstance(values, Sequence) and not isinstance(values, (str, bytes)):
+                sample_values = [str(value) for value in values]
+            else:
+                sample_values = None
+
+            is_primary = bool(
+                column_entry.get("is_primary_key")
+                or column_entry.get("primary_key")
+                or column_entry.get("is_primary")
+            )
+
+            columns.append(
+                Column(
+                    id_=column_index,
+                    column_name=column_name,
+                    column_type=column_type,
+                    values=sample_values,
+                    comment=column_entry.get("comment"),
+                    is_primary_key=is_primary,
+                )
+            )
+
+        table_proto = Table(
+            id_=table_index,
+            name=table_name.upper(),
+            columns=columns,
+            comment=table_entry.get("comment"),
+        )
+        fqn = FQNParts(
+            database=workspace.upper(),
+            schema_name=schema.upper(),
+            table=table_name,
+        )
+        raw_tables.append((fqn, table_proto))
+
+    return raw_tables
+
+
 def _discover_relationships(
     raw_tables: List[Tuple[FQNParts, Table]],
     strict_join_inference: bool,
     session: Optional[Session],
-) -> List[semantic_model_pb2.Relationship]:
+    *,
+    max_relationships: Optional[int] = None,
+    min_confidence: float = 0.5,
+    timeout_seconds: Optional[float] = None,
+) -> Tuple[List[semantic_model_pb2.Relationship], Dict[str, bool]]:
     if not raw_tables:
-        return []
+        return [], {"limited_by_timeout": False, "limited_by_max_relationships": False}
 
+    status: Dict[str, bool] = {}
     relationships = _infer_relationships(
         raw_tables,
         session=session if strict_join_inference else None,
         strict_join_inference=strict_join_inference,
+        status=status,
+        max_relationships=max_relationships,
+        min_confidence=min_confidence,
+        timeout_seconds=timeout_seconds,
     )
-    return relationships
+    return relationships, status
 
 
 def discover_relationships_from_tables(
@@ -118,33 +241,86 @@ def discover_relationships_from_tables(
     *,
     strict_join_inference: bool = False,
     session: Optional[Session] = None,
+    max_relationships: Optional[int] = None,
+    min_confidence: float = 0.5,
+    timeout_seconds: Optional[float] = 30.0,
+    max_tables: Optional[int] = None,
 ) -> RelationshipDiscoveryResult:
     """
     Run relationship inference using pre-constructed table metadata.
     """
     start = time.perf_counter()
-    relationships = _discover_relationships(
-        list(tables),
+    raw_tables = list(tables)
+    limited_by_table_cap = False
+    notes: List[str] = []
+
+    if max_tables is not None and len(raw_tables) > max_tables:
+        limited_by_table_cap = True
+        notes.append(
+            f"Input contained {len(raw_tables)} tables; analysis limited to first {max_tables}."
+        )
+        raw_tables = raw_tables[:max_tables]
+
+    relationships, status = _discover_relationships(
+        raw_tables,
         strict_join_inference=strict_join_inference,
         session=session,
+        max_relationships=max_relationships,
+        min_confidence=min_confidence,
+        timeout_seconds=timeout_seconds,
     )
     end = time.perf_counter()
 
-    all_columns = sum(len(table.columns) for _, table in tables)
+    all_columns = sum(len(table.columns) for _, table in raw_tables)
     summary = RelationshipSummary(
-        total_tables=len(tables),
+        total_tables=len(raw_tables),
         total_columns=all_columns,
        total_relationships_found=len(relationships),
        processing_time_ms=int((end - start) * 1000),
+        limited_by_timeout=status.get("limited_by_timeout", False),
+        limited_by_max_relationships=status.get("limited_by_max_relationships", False),
+        limited_by_table_cap=limited_by_table_cap,
+        notes=" ".join(notes) if notes else None,
     )
 
     return RelationshipDiscoveryResult(
         relationships=relationships,
-        tables=[table for _, table in tables],
+        tables=[table for _, table in raw_tables],
         summary=summary,
     )
 
 
+def discover_relationships_from_table_definitions(
+    table_definitions: Sequence[Mapping[str, Any]],
+    *,
+    default_workspace: str = "OFFLINE",
+    default_schema: str = "PUBLIC",
+    strict_join_inference: bool = False,
+    session: Optional[Session] = None,
+    max_relationships: Optional[int] = None,
+    min_confidence: float = 0.5,
+    timeout_seconds: Optional[float] = 15.0,
+    max_tables: Optional[int] = None,
+) -> RelationshipDiscoveryResult:
+    """Run relationship inference using raw table metadata dictionaries."""
+
+    raw_tables = _tables_payload_to_raw_tables(
+        table_definitions,
+        default_workspace=default_workspace,
+        default_schema=default_schema,
+    )
+
+    return discover_relationships_from_tables(
+        raw_tables,
+        strict_join_inference=strict_join_inference,
+        session=session,
+        max_relationships=max_relationships,
+        min_confidence=min_confidence,
+        timeout_seconds=timeout_seconds,
+        max_tables=max_tables,
+    )
+
+
 def discover_relationships_from_schema(
     session: Session,
     workspace: str,
@@ -154,6 +330,10 @@ def discover_relationships_from_schema(
     sample_values_per_column: int = _DEFAULT_N_SAMPLE_VALUES_PER_COL,
     strict_join_inference: bool = False,
     max_workers: int = DEFAULT_MAX_WORKERS,
+    max_relationships: Optional[int] = None,
+    min_confidence: float = 0.5,
+    timeout_seconds: Optional[float] = 30.0,
+    max_tables: Optional[int] = 60,
 ) -> RelationshipDiscoveryResult:
     """
     Discover table relationships for all tables in a ClickZetta schema.
@@ -199,4 +379,24 @@
         raw_tables,
         strict_join_inference=strict_join_inference,
         session=session,
+        max_relationships=max_relationships,
+        min_confidence=min_confidence,
+        timeout_seconds=timeout_seconds,
+        max_tables=max_tables,
     )
+def _split_table_identifier(identifier: str) -> Tuple[Optional[str], Optional[str], str]:
+    """
+    Split a table identifier that may include workspace/schema prefixes.
+
+    Supported formats:
+    - workspace.schema.table
+    - schema.table
+    - table
+    """
+
+    parts = [part.strip() for part in identifier.split(".") if part.strip()]
+    if len(parts) == 3:
+        return parts[0], parts[1], parts[2]
+    if len(parts) == 2:
+        return None, parts[0], parts[1]
+    return None, None, parts[0]
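Spelled out for the three supported shapes (illustrative calls, assuming the helper above is in scope):

    assert _split_table_identifier("demo.sales.orders") == ("demo", "sales", "orders")
    assert _split_table_identifier("sales.orders") == (None, "sales", "orders")
    assert _split_table_identifier("orders") == (None, None, "orders")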
@@ -5,6 +5,7 @@ import sqlglot
 
 from semantic_model_generator.data_processing.cte_utils import (
     ClickzettaDialect,
+    _prepare_sql_for_parsing,
     _enrich_column_in_expr_with_aggregation,
     _get_col_expr,
     _validate_col,
@@ -304,7 +305,7 @@ class SemanticModelTest(TestCase):
         col_format_tbl = get_test_table_col_format()
         got = generate_select(col_format_tbl, 100)
         want = [
-            "WITH __t1 AS (SELECT d1_expr AS d1, d2_expr AS d2 FROM db.sc.t1) SELECT * FROM __t1 LIMIT 100"
+            "WITH __t1 AS (SELECT d1_expr AS d1, d2_expr AS d2 FROM `db`.`sc`.`t1`) SELECT * FROM __t1 LIMIT 100"
         ]
         assert got == want
 
@@ -312,8 +313,8 @@ class SemanticModelTest(TestCase):
         col_format_tbl = get_test_table_col_format_w_agg()
         got = generate_select(col_format_tbl, 100)
         want = [
-            "WITH __t1 AS (SELECT SUM(d2) AS d2_total FROM db.sc.t1) SELECT * FROM __t1 LIMIT 100",
-            "WITH __t1 AS (SELECT d1_expr AS d1, SUM(d3) OVER (PARTITION BY d1) AS d3 FROM db.sc.t1) SELECT * FROM __t1 LIMIT 100",
+            "WITH __t1 AS (SELECT SUM(d2) AS d2_total FROM `db`.`sc`.`t1`) SELECT * FROM __t1 LIMIT 100",
+            "WITH __t1 AS (SELECT d1_expr AS d1, SUM(d3) OVER (PARTITION BY d1) AS d3 FROM `db`.`sc`.`t1`) SELECT * FROM __t1 LIMIT 100",
         ]
         assert sorted(got) == sorted(want)
 
@@ -321,7 +322,7 @@ class SemanticModelTest(TestCase):
         col_format_tbl = get_test_table_col_format_w_agg_only()
         got = generate_select(col_format_tbl, 100)
         want = [
-            "WITH __t1 AS (SELECT SUM(d2) AS d2_total FROM db.sc.t1) SELECT * FROM __t1 LIMIT 100"
+            "WITH __t1 AS (SELECT SUM(d2) AS d2_total FROM `db`.`sc`.`t1`) SELECT * FROM __t1 LIMIT 100"
         ]
         assert sorted(got) == sorted(want)
 
@@ -437,21 +438,21 @@ class SemanticModelTest(TestCase):
         want = """WITH __t1 AS (SELECT
 d1_expr AS d1,
 d2_expr AS d2
-FROM db.sc.t1
+FROM `db`.`sc`.`t1`
 ), __t2 AS (
 SELECT
 td1_expr AS td1,
 m1_expr AS m1,
 m1_expr AS m2,
 m3_expr
-FROM db.sc.t2
+FROM `db`.`sc`.`t2`
 )
 SELECT
 *
 FROM __t2"""
-        assert sqlglot.parse_one(want, ClickzettaDialect) == sqlglot.parse_one(
-            got, ClickzettaDialect
-        )
+        assert sqlglot.parse_one(
+            _prepare_sql_for_parsing(want), ClickzettaDialect
+        ) == sqlglot.parse_one(_prepare_sql_for_parsing(got), ClickzettaDialect)
 
     def test_expand_all_logical_tables_as_ctes_with_column_renaming(self) -> None:
         ctx = semantic_model_pb2.SemanticModel(
@@ -465,12 +466,12 @@ FROM __t2"""
 clcks AS clicks,
 clcks,
 cst
-FROM db.sc.t1
+FROM `db`.`sc`.`t1`
 )
 SELECT
 *
 FROM __t1
 """
-        assert sqlglot.parse_one(want, ClickzettaDialect) == sqlglot.parse_one(
-            got, ClickzettaDialect
-        )
+        assert sqlglot.parse_one(
+            _prepare_sql_for_parsing(want), ClickzettaDialect
+        ) == sqlglot.parse_one(_prepare_sql_for_parsing(got), ClickzettaDialect)
@@ -6,6 +6,7 @@ import pandas as pd
 
 from semantic_model_generator.relationships.discovery import (
     discover_relationships_from_schema,
+    discover_relationships_from_table_definitions,
 )
 
 
@@ -112,3 +113,110 @@ def test_discover_relationships_from_schema_builds_relationships():
     right_tables = {rel.right_table for rel in result.relationships}
     assert "ORDERS" in left_tables
     assert "CUSTOMER" in right_tables
+
+
+def test_discover_relationships_from_table_definitions_allows_manual_metadata() -> None:
+    payload = [
+        {
+            "table_name": "orders",
+            "columns": [
+                {"name": "order_id", "type": "NUMBER", "is_primary_key": True},
+                {"name": "customer_id", "type": "NUMBER"},
+            ],
+        },
+        {
+            "table_name": "customers",
+            "columns": [
+                {"name": "customer_id", "type": "NUMBER", "is_primary_key": True},
+                {"name": "name", "type": "STRING"},
+            ],
+        },
+    ]
+
+    result = discover_relationships_from_table_definitions(
+        payload,
+        default_workspace="demo",
+        default_schema="sales",
+        max_relationships=5,
+        timeout_seconds=5.0,
+    )
+
+    assert result.summary.total_tables == 2
+    assert result.summary.total_relationships_found >= 1
+    assert not result.summary.limited_by_timeout
+    assert any(
+        rel.left_table == "ORDERS" and rel.right_table == "CUSTOMERS"
+        for rel in result.relationships
+    )
+
+
+def test_discover_relationships_from_table_definitions_filters_generic_ids() -> None:
+    payload = [
+        {
+            "table_name": "table_a",
+            "columns": [
+                {"name": "id", "type": "NUMBER", "is_primary_key": True},
+                {"name": "value", "type": "NUMBER"},
+            ],
+        },
+        {
+            "table_name": "table_b",
+            "columns": [
+                {"name": "id", "type": "NUMBER", "is_primary_key": True},
+                {"name": "value", "type": "NUMBER"},
+            ],
+        },
+    ]
+
+    result = discover_relationships_from_table_definitions(
+        payload,
+        min_confidence=0.6,
+        max_relationships=5,
+    )
+
+    assert result.summary.total_relationships_found == 0
+    assert not result.relationships
+
+
+def test_table_definitions_support_fully_qualified_names() -> None:
+    payload = [
+        {
+            "table_name": "demo.sales.orders",
+            "columns": [
+                {"name": "order_id", "type": "NUMBER", "is_primary_key": True},
+                {"name": "customer_id", "type": "NUMBER"},
+            ],
+        },
+        {
+            "table_name": "sales.customers",
+            "workspace": "demo",
+            "columns": [
+                {"name": "customer_id", "type": "NUMBER", "is_primary_key": True},
+                {"name": "name", "type": "STRING"},
+            ],
+        },
+        {
+            "table_name": "products",
+            "workspace": "demo",
+            "schema": "sales",
+            "columns": [
+                {"name": "product_id", "type": "NUMBER", "is_primary_key": True},
+                {"name": "name", "type": "STRING"},
+            ],
+        },
+    ]
+
+    result = discover_relationships_from_table_definitions(
+        payload,
+        default_workspace="fallback",
+        default_schema="fallback_schema",
+    )
+
+    table_names = {table.name for table in result.tables}
+    assert table_names == {"ORDERS", "CUSTOMERS", "PRODUCTS"}
+
+    # Ensure relationships include the orders -> customers edge despite mixed identifiers
+    assert any(
+        rel.left_table == "ORDERS" and rel.right_table == "CUSTOMERS"
+        for rel in result.relationships
+    )