clickzetta-semantic-model-generator 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
Files changed (24)
  1. {clickzetta_semantic_model_generator-1.0.2.dist-info → clickzetta_semantic_model_generator-1.0.4.dist-info}/METADATA +5 -5
  2. clickzetta_semantic_model_generator-1.0.4.dist-info/RECORD +38 -0
  3. semantic_model_generator/clickzetta_utils/clickzetta_connector.py +100 -48
  4. semantic_model_generator/clickzetta_utils/env_vars.py +7 -2
  5. semantic_model_generator/clickzetta_utils/utils.py +44 -2
  6. semantic_model_generator/data_processing/cte_utils.py +44 -14
  7. semantic_model_generator/generate_model.py +711 -239
  8. semantic_model_generator/llm/dashscope_client.py +4 -2
  9. semantic_model_generator/llm/enrichment.py +144 -57
  10. semantic_model_generator/llm/progress_tracker.py +16 -15
  11. semantic_model_generator/relationships/__init__.py +2 -0
  12. semantic_model_generator/relationships/discovery.py +181 -16
  13. semantic_model_generator/tests/clickzetta_connector_test.py +3 -7
  14. semantic_model_generator/tests/cte_utils_test.py +15 -14
  15. semantic_model_generator/tests/generate_model_classification_test.py +12 -2
  16. semantic_model_generator/tests/llm_enrichment_test.py +152 -46
  17. semantic_model_generator/tests/relationship_discovery_test.py +70 -3
  18. semantic_model_generator/tests/relationships_filters_test.py +166 -30
  19. semantic_model_generator/tests/utils_test.py +1 -1
  20. semantic_model_generator/validate/keywords.py +453 -53
  21. semantic_model_generator/validate/schema.py +4 -2
  22. clickzetta_semantic_model_generator-1.0.2.dist-info/RECORD +0 -38
  23. {clickzetta_semantic_model_generator-1.0.2.dist-info → clickzetta_semantic_model_generator-1.0.4.dist-info}/LICENSE +0 -0
  24. {clickzetta_semantic_model_generator-1.0.2.dist-info → clickzetta_semantic_model_generator-1.0.4.dist-info}/WHEEL +0 -0
semantic_model_generator/relationships/discovery.py
@@ -2,19 +2,18 @@ from __future__ import annotations
 
 import time
 from dataclasses import dataclass
-from typing import Any, Iterable, List, Optional, Sequence, Tuple
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
 
 import pandas as pd
 from loguru import logger
 
 from semantic_model_generator.clickzetta_utils.clickzetta_connector import (
     _TABLE_NAME_COL,
-    _TABLE_SCHEMA_COL,
     get_table_representation,
     get_valid_schemas_tables_columns_df,
 )
 from semantic_model_generator.data_processing import data_types
-from semantic_model_generator.data_processing.data_types import FQNParts, Table
+from semantic_model_generator.data_processing.data_types import Column, FQNParts, Table
 from semantic_model_generator.generate_model import (
     _DEFAULT_N_SAMPLE_VALUES_PER_COL,
     _infer_relationships,
@@ -35,6 +34,10 @@ class RelationshipSummary:
     total_columns: int
     total_relationships_found: int
     processing_time_ms: int
+    limited_by_timeout: bool = False
+    limited_by_max_relationships: bool = False
+    limited_by_table_cap: bool = False
+    notes: Optional[str] = None
 
 
 @dataclass
@@ -68,11 +71,7 @@ def _build_tables_from_dataframe(
     )
 
     table_order = (
-        columns_df[_TABLE_NAME_COL]
-        .astype(str)
-        .str.upper()
-        .drop_duplicates()
-        .tolist()
+        columns_df[_TABLE_NAME_COL].astype(str).str.upper().drop_duplicates().tolist()
     )
 
     tables: List[Tuple[FQNParts, Table]] = []
@@ -102,20 +101,125 @@ def _build_tables_from_dataframe(
     return tables
 
 
+def _tables_payload_to_raw_tables(
+    tables: Sequence[Mapping[str, Any]],
+    *,
+    default_workspace: str = "OFFLINE",
+    default_schema: str = "PUBLIC",
+) -> List[Tuple[FQNParts, Table]]:
+    raw_tables: List[Tuple[FQNParts, Table]] = []
+    for table_index, table_entry in enumerate(tables):
+        if not isinstance(table_entry, Mapping):
+            raise TypeError("Each table definition must be a mapping of table metadata")
+
+        table_name = str(
+            table_entry.get("table_name")
+            or table_entry.get("name")
+            or table_entry.get("table")
+            or ""
+        ).strip()
+        if not table_name:
+            raise ValueError("Table definition missing 'table_name'")
+
+        workspace = str(table_entry.get("workspace") or default_workspace).strip() or default_workspace
+        schema = str(
+            table_entry.get("schema")
+            or table_entry.get("schema_name")
+            or default_schema
+        ).strip() or default_schema
+
+        columns_payload = table_entry.get("columns")
+        if not isinstance(columns_payload, Sequence) or not columns_payload:
+            raise ValueError(
+                f"Table '{table_name}' must include a non-empty 'columns' list"
+            )
+
+        columns: List[Column] = []
+        for column_index, column_entry in enumerate(columns_payload):
+            if not isinstance(column_entry, Mapping):
+                raise TypeError(
+                    f"Column definition for table '{table_name}' must be a mapping"
+                )
+
+            column_name = str(
+                column_entry.get("name")
+                or column_entry.get("column_name")
+                or column_entry.get("field")
+                or ""
+            ).strip()
+            if not column_name:
+                raise ValueError(
+                    f"Column definition in table '{table_name}' missing 'name'"
+                )
+
+            column_type = str(
+                column_entry.get("type")
+                or column_entry.get("data_type")
+                or "STRING"
+            ).strip()
+
+            values = column_entry.get("sample_values") or column_entry.get("values")
+            if isinstance(values, Sequence) and not isinstance(values, (str, bytes)):
+                sample_values = [str(value) for value in values]
+            else:
+                sample_values = None
+
+            is_primary = bool(
+                column_entry.get("is_primary_key")
+                or column_entry.get("primary_key")
+                or column_entry.get("is_primary")
+            )
+
+            columns.append(
+                Column(
+                    id_=column_index,
+                    column_name=column_name,
+                    column_type=column_type,
+                    values=sample_values,
+                    comment=column_entry.get("comment"),
+                    is_primary_key=is_primary,
+                )
+            )
+
+        table_proto = Table(
+            id_=table_index,
+            name=table_name.upper(),
+            columns=columns,
+            comment=table_entry.get("comment"),
+        )
+        fqn = FQNParts(
+            database=workspace.upper(),
+            schema_name=schema.upper(),
+            table=table_name,
+        )
+        raw_tables.append((fqn, table_proto))
+
+    return raw_tables
+
+
 def _discover_relationships(
     raw_tables: List[Tuple[FQNParts, Table]],
     strict_join_inference: bool,
     session: Optional[Session],
-) -> List[semantic_model_pb2.Relationship]:
+    *,
+    max_relationships: Optional[int] = None,
+    min_confidence: float = 0.5,
+    timeout_seconds: Optional[float] = None,
+) -> Tuple[List[semantic_model_pb2.Relationship], Dict[str, bool]]:
     if not raw_tables:
-        return []
+        return [], {"limited_by_timeout": False, "limited_by_max_relationships": False}
 
+    status: Dict[str, bool] = {}
     relationships = _infer_relationships(
         raw_tables,
         session=session if strict_join_inference else None,
         strict_join_inference=strict_join_inference,
+        status=status,
+        max_relationships=max_relationships,
+        min_confidence=min_confidence,
+        timeout_seconds=timeout_seconds,
    )
-    return relationships
+    return relationships, status
 
 
 def discover_relationships_from_tables(
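
Note on the new payload parser: the dictionary shape accepted by _tables_payload_to_raw_tables follows directly from the key fallbacks above. A sketch of a valid payload (the table, schema, and column names here are illustrative, not part of the package):

    tables_payload = [
        {
            "table_name": "orders",      # "name" or "table" also accepted
            "workspace": "ANALYTICS",    # omitted -> default_workspace ("OFFLINE")
            "schema": "sales",           # "schema_name" also accepted; default "PUBLIC"
            "comment": "Order fact table",
            "columns": [
                {"name": "order_id", "type": "BIGINT", "is_primary_key": True},
                {
                    "column_name": "customer_id",  # alias for "name"
                    "data_type": "BIGINT",         # alias for "type"; default is "STRING"
                    "sample_values": [101, 102],   # coerced to strings by the parser
                },
            ],
        },
    ]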
@@ -123,33 +227,86 @@ def discover_relationships_from_tables(
     *,
     strict_join_inference: bool = False,
     session: Optional[Session] = None,
+    max_relationships: Optional[int] = None,
+    min_confidence: float = 0.5,
+    timeout_seconds: Optional[float] = 30.0,
+    max_tables: Optional[int] = None,
 ) -> RelationshipDiscoveryResult:
     """
     Run relationship inference using pre-constructed table metadata.
     """
     start = time.perf_counter()
-    relationships = _discover_relationships(
-        list(tables),
+    raw_tables = list(tables)
+    limited_by_table_cap = False
+    notes: List[str] = []
+
+    if max_tables is not None and len(raw_tables) > max_tables:
+        limited_by_table_cap = True
+        notes.append(
+            f"Input contained {len(raw_tables)} tables; analysis limited to first {max_tables}."
+        )
+        raw_tables = raw_tables[:max_tables]
+
+    relationships, status = _discover_relationships(
+        raw_tables,
         strict_join_inference=strict_join_inference,
         session=session,
+        max_relationships=max_relationships,
+        min_confidence=min_confidence,
+        timeout_seconds=timeout_seconds,
     )
     end = time.perf_counter()
 
-    all_columns = sum(len(table.columns) for _, table in tables)
+    all_columns = sum(len(table.columns) for _, table in raw_tables)
     summary = RelationshipSummary(
-        total_tables=len(tables),
+        total_tables=len(raw_tables),
         total_columns=all_columns,
         total_relationships_found=len(relationships),
         processing_time_ms=int((end - start) * 1000),
+        limited_by_timeout=status.get("limited_by_timeout", False),
+        limited_by_max_relationships=status.get("limited_by_max_relationships", False),
+        limited_by_table_cap=limited_by_table_cap,
+        notes=" ".join(notes) if notes else None,
     )
 
     return RelationshipDiscoveryResult(
         relationships=relationships,
-        tables=[table for _, table in tables],
+        tables=[table for _, table in raw_tables],
         summary=summary,
     )
 
 
+def discover_relationships_from_table_definitions(
+    table_definitions: Sequence[Mapping[str, Any]],
+    *,
+    default_workspace: str = "OFFLINE",
+    default_schema: str = "PUBLIC",
+    strict_join_inference: bool = False,
+    session: Optional[Session] = None,
+    max_relationships: Optional[int] = None,
+    min_confidence: float = 0.5,
+    timeout_seconds: Optional[float] = 15.0,
+    max_tables: Optional[int] = None,
+) -> RelationshipDiscoveryResult:
+    """Run relationship inference using raw table metadata dictionaries."""
+
+    raw_tables = _tables_payload_to_raw_tables(
+        table_definitions,
+        default_workspace=default_workspace,
+        default_schema=default_schema,
+    )
+
+    return discover_relationships_from_tables(
+        raw_tables,
+        strict_join_inference=strict_join_inference,
+        session=session,
+        max_relationships=max_relationships,
+        min_confidence=min_confidence,
+        timeout_seconds=timeout_seconds,
+        max_tables=max_tables,
+    )
+
+
 def discover_relationships_from_schema(
     session: Session,
     workspace: str,
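
With that payload shape, the new entry point can run entirely offline. A minimal usage sketch based only on the signature added above (a live session is needed only when strict_join_inference=True):

    result = discover_relationships_from_table_definitions(
        tables_payload,            # list of dicts as sketched after the parser above
        max_relationships=50,      # stop after 50 candidate joins
        min_confidence=0.7,        # drop low-confidence candidates
        timeout_seconds=15.0,      # matches this entry point's default
    )
    print(result.summary.total_relationships_found)
    for relationship in result.relationships:
        print(relationship)        # semantic_model_pb2.Relationship messages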
@@ -159,6 +316,10 @@ def discover_relationships_from_schema(
     sample_values_per_column: int = _DEFAULT_N_SAMPLE_VALUES_PER_COL,
     strict_join_inference: bool = False,
     max_workers: int = DEFAULT_MAX_WORKERS,
+    max_relationships: Optional[int] = None,
+    min_confidence: float = 0.5,
+    timeout_seconds: Optional[float] = 30.0,
+    max_tables: Optional[int] = 60,
 ) -> RelationshipDiscoveryResult:
     """
     Discover table relationships for all tables in a ClickZetta schema.
@@ -204,4 +365,8 @@
         raw_tables,
         strict_join_inference=strict_join_inference,
         session=session,
+        max_relationships=max_relationships,
+        min_confidence=min_confidence,
+        timeout_seconds=timeout_seconds,
+        max_tables=max_tables,
     )
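
Because the summary now records why a run may be incomplete, callers can surface that directly. A hypothetical helper (explain_limits is not part of the package; it relies only on the RelationshipSummary fields added in this release):

    def explain_limits(summary: RelationshipSummary) -> str:
        # Translate the new limit flags into a human-readable status line.
        reasons = []
        if summary.limited_by_timeout:
            reasons.append("stopped early when timeout_seconds elapsed")
        if summary.limited_by_max_relationships:
            reasons.append("capped at max_relationships")
        if summary.limited_by_table_cap:
            reasons.append("input truncated to max_tables")
        if summary.notes:
            reasons.append(summary.notes)
        return "; ".join(reasons) if reasons else "discovery ran to completion"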
semantic_model_generator/tests/clickzetta_connector_test.py
@@ -3,15 +3,13 @@ from unittest import mock
 
 import pandas as pd
 
-from semantic_model_generator.clickzetta_utils import env_vars
 from semantic_model_generator.clickzetta_utils import clickzetta_connector as connector
+from semantic_model_generator.clickzetta_utils import env_vars
 
 
 def test_fetch_stages_includes_user_volume(monkeypatch):
     data = pd.DataFrame({"name": ["shared_stage"]})
-    with mock.patch.object(
-        connector, "_execute_query_to_pandas", return_value=data
-    ):
+    with mock.patch.object(connector, "_execute_query_to_pandas", return_value=data):
         stages = connector.fetch_stages_in_schema(
             connection=mock.MagicMock(), schema_name="WORKSPACE.SCHEMA"
         )
@@ -29,9 +27,7 @@ def test_fetch_yaml_names_in_user_volume(monkeypatch):
             ]
         }
     )
-    with mock.patch.object(
-        connector, "_execute_query_to_pandas", return_value=data
-    ):
+    with mock.patch.object(connector, "_execute_query_to_pandas", return_value=data):
         files = connector.fetch_yaml_names_in_stage(
             connection=mock.MagicMock(),
             stage="volume:user://~/semantic_models/",
semantic_model_generator/tests/cte_utils_test.py
@@ -4,10 +4,11 @@ import pytest
 import sqlglot
 
 from semantic_model_generator.data_processing.cte_utils import (
+    ClickzettaDialect,
+    _prepare_sql_for_parsing,
     _enrich_column_in_expr_with_aggregation,
     _get_col_expr,
     _validate_col,
-    ClickzettaDialect,
     context_to_column_format,
     expand_all_logical_tables_as_ctes,
     generate_select,
@@ -304,7 +305,7 @@ class SemanticModelTest(TestCase):
         col_format_tbl = get_test_table_col_format()
         got = generate_select(col_format_tbl, 100)
         want = [
-            "WITH __t1 AS (SELECT d1_expr AS d1, d2_expr AS d2 FROM db.sc.t1) SELECT * FROM __t1 LIMIT 100"
+            "WITH __t1 AS (SELECT d1_expr AS d1, d2_expr AS d2 FROM `db`.`sc`.`t1`) SELECT * FROM __t1 LIMIT 100"
         ]
         assert got == want
 
@@ -312,8 +313,8 @@ class SemanticModelTest(TestCase):
         col_format_tbl = get_test_table_col_format_w_agg()
         got = generate_select(col_format_tbl, 100)
         want = [
-            "WITH __t1 AS (SELECT SUM(d2) AS d2_total FROM db.sc.t1) SELECT * FROM __t1 LIMIT 100",
-            "WITH __t1 AS (SELECT d1_expr AS d1, SUM(d3) OVER (PARTITION BY d1) AS d3 FROM db.sc.t1) SELECT * FROM __t1 LIMIT 100",
+            "WITH __t1 AS (SELECT SUM(d2) AS d2_total FROM `db`.`sc`.`t1`) SELECT * FROM __t1 LIMIT 100",
+            "WITH __t1 AS (SELECT d1_expr AS d1, SUM(d3) OVER (PARTITION BY d1) AS d3 FROM `db`.`sc`.`t1`) SELECT * FROM __t1 LIMIT 100",
         ]
         assert sorted(got) == sorted(want)
 
@@ -321,7 +322,7 @@ class SemanticModelTest(TestCase):
         col_format_tbl = get_test_table_col_format_w_agg_only()
         got = generate_select(col_format_tbl, 100)
         want = [
-            "WITH __t1 AS (SELECT SUM(d2) AS d2_total FROM db.sc.t1) SELECT * FROM __t1 LIMIT 100"
+            "WITH __t1 AS (SELECT SUM(d2) AS d2_total FROM `db`.`sc`.`t1`) SELECT * FROM __t1 LIMIT 100"
         ]
         assert sorted(got) == sorted(want)
 
@@ -437,21 +438,21 @@ class SemanticModelTest(TestCase):
         want = """WITH __t1 AS (SELECT
   d1_expr AS d1,
   d2_expr AS d2
-FROM db.sc.t1
+FROM `db`.`sc`.`t1`
 ), __t2 AS (
 SELECT
   td1_expr AS td1,
   m1_expr AS m1,
   m1_expr AS m2,
   m3_expr
-FROM db.sc.t2
+FROM `db`.`sc`.`t2`
 )
 SELECT
   *
 FROM __t2"""
-        assert sqlglot.parse_one(want, ClickzettaDialect) == sqlglot.parse_one(
-            got, ClickzettaDialect
-        )
+        assert sqlglot.parse_one(
+            _prepare_sql_for_parsing(want), ClickzettaDialect
+        ) == sqlglot.parse_one(_prepare_sql_for_parsing(got), ClickzettaDialect)
 
     def test_expand_all_logical_tables_as_ctes_with_column_renaming(self) -> None:
         ctx = semantic_model_pb2.SemanticModel(
@@ -465,12 +466,12 @@ FROM __t2"""
   clcks AS clicks,
   clcks,
   cst
-FROM db.sc.t1
+FROM `db`.`sc`.`t1`
 )
 SELECT
   *
 FROM __t1
 """
-        assert sqlglot.parse_one(want, ClickzettaDialect) == sqlglot.parse_one(
-            got, ClickzettaDialect
-        )
+        assert sqlglot.parse_one(
+            _prepare_sql_for_parsing(want), ClickzettaDialect
+        ) == sqlglot.parse_one(_prepare_sql_for_parsing(got), ClickzettaDialect)
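
These tests now route both SQL strings through _prepare_sql_for_parsing before handing them to sqlglot. The helper's implementation is not shown in this diff; given that generated SQL now backtick-quotes qualified names (`db`.`sc`.`t1`), one plausible job is normalizing those identifiers into a form the parse dialect accepts. A standalone sketch under that assumption (prepare_sql is hypothetical, not the package's helper):

    import sqlglot

    def prepare_sql(sql: str) -> str:
        # Assumed behavior only: swap backtick identifier quotes for ANSI double quotes.
        return sql.replace("`", '"')

    want = "SELECT * FROM `db`.`sc`.`t1`"
    got = 'SELECT * FROM "db"."sc"."t1"'

    # Comparing parsed ASTs, as the tests do, ignores quoting and whitespace differences.
    assert sqlglot.parse_one(prepare_sql(want)) == sqlglot.parse_one(prepare_sql(got))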
semantic_model_generator/tests/generate_model_classification_test.py
@@ -31,8 +31,18 @@ def test_string_date_promoted_to_time_dimension() -> None:
         id_=0,
         name="ORDERS",
         columns=[
-            Column(id_=0, column_name="order_date", column_type="STRING", values=["2024-01-01", "2024-02-01"]),
-            Column(id_=1, column_name="order_status", column_type="STRING", values=["OPEN", "CLOSED"]),
+            Column(
+                id_=0,
+                column_name="order_date",
+                column_type="STRING",
+                values=["2024-01-01", "2024-02-01"],
+            ),
+            Column(
+                id_=1,
+                column_name="order_status",
+                column_type="STRING",
+                values=["OPEN", "CLOSED"],
+            ),
         ],
     )