clickzetta-semantic-model-generator 1.0.2__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. {clickzetta_semantic_model_generator-1.0.2.dist-info → clickzetta_semantic_model_generator-1.0.3.dist-info}/METADATA +5 -5
  2. {clickzetta_semantic_model_generator-1.0.2.dist-info → clickzetta_semantic_model_generator-1.0.3.dist-info}/RECORD +21 -21
  3. semantic_model_generator/clickzetta_utils/clickzetta_connector.py +91 -33
  4. semantic_model_generator/clickzetta_utils/env_vars.py +7 -2
  5. semantic_model_generator/data_processing/cte_utils.py +1 -1
  6. semantic_model_generator/generate_model.py +588 -224
  7. semantic_model_generator/llm/dashscope_client.py +4 -2
  8. semantic_model_generator/llm/enrichment.py +144 -57
  9. semantic_model_generator/llm/progress_tracker.py +16 -15
  10. semantic_model_generator/relationships/discovery.py +1 -6
  11. semantic_model_generator/tests/clickzetta_connector_test.py +3 -7
  12. semantic_model_generator/tests/cte_utils_test.py +1 -1
  13. semantic_model_generator/tests/generate_model_classification_test.py +12 -2
  14. semantic_model_generator/tests/llm_enrichment_test.py +152 -46
  15. semantic_model_generator/tests/relationship_discovery_test.py +6 -3
  16. semantic_model_generator/tests/relationships_filters_test.py +166 -30
  17. semantic_model_generator/tests/utils_test.py +1 -1
  18. semantic_model_generator/validate/keywords.py +453 -53
  19. semantic_model_generator/validate/schema.py +4 -2
  20. {clickzetta_semantic_model_generator-1.0.2.dist-info → clickzetta_semantic_model_generator-1.0.3.dist-info}/LICENSE +0 -0
  21. {clickzetta_semantic_model_generator-1.0.2.dist-info → clickzetta_semantic_model_generator-1.0.3.dist-info}/WHEEL +0 -0
@@ -1,6 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- import json
4
3
  import os
5
4
  from dataclasses import dataclass
6
5
  from http import HTTPStatus
@@ -8,6 +7,7 @@ from typing import Any, Dict, List, Optional
8
7
  from urllib.parse import urlparse, urlunparse
9
8
 
10
9
  from loguru import logger
10
+
11
11
  try:
12
12
  import dashscope # type: ignore
13
13
  from dashscope import Generation # type: ignore
@@ -144,7 +144,9 @@ class DashscopeClient:
144
144
 
145
145
  output = getattr(response, "output", None)
146
146
  if not output or not hasattr(output, "choices"):
147
- raise DashscopeError(f"DashScope response missing output choices: {response}")
147
+ raise DashscopeError(
148
+ f"DashScope response missing output choices: {response}"
149
+ )
148
150
 
149
151
  choices = getattr(output, "choices")
150
152
  if not choices:
@@ -11,7 +11,7 @@ from semantic_model_generator.data_processing import data_types
11
11
  from semantic_model_generator.protos import semantic_model_pb2
12
12
 
13
13
  from .dashscope_client import DashscopeClient, DashscopeError
14
- from .progress_tracker import EnrichmentProgressTracker, EnrichmentStage, ProgressUpdate
14
+ from .progress_tracker import EnrichmentProgressTracker, EnrichmentStage
15
15
 
16
16
  if TYPE_CHECKING: # pragma: no cover
17
17
  from clickzetta.zettapark.session import Session
@@ -65,24 +65,27 @@ def enrich_semantic_model(
65
65
 
66
66
  # Initialize progress tracking
67
67
  total_tables = len(model.tables)
68
- total_model_steps = 3 # model description, model metrics, verified queries
69
68
 
70
69
  if progress_tracker:
71
70
  progress_tracker.update_progress(
72
71
  EnrichmentStage.TABLE_ENRICHMENT,
73
72
  0,
74
73
  total_tables,
75
- message="Starting table enrichment"
74
+ message="Starting table enrichment",
76
75
  )
77
76
 
78
- raw_lookup: Dict[str, data_types.Table] = {tbl.name.upper(): tbl for _, tbl in raw_tables}
77
+ raw_lookup: Dict[str, data_types.Table] = {
78
+ tbl.name.upper(): tbl for _, tbl in raw_tables
79
+ }
79
80
  metric_notes: List[str] = []
80
81
 
81
82
  # Process each table with progress tracking
82
83
  for table_index, table in enumerate(model.tables):
83
84
  raw_table = raw_lookup.get(table.name.upper())
84
85
  if not raw_table:
85
- logger.debug("No raw metadata for table {}; skipping enrichment.", table.name)
86
+ logger.debug(
87
+ "No raw metadata for table {}; skipping enrichment.", table.name
88
+ )
86
89
  continue
87
90
 
88
91
  # Update progress for current table
@@ -92,11 +95,13 @@ def enrich_semantic_model(
92
95
  table_index + 1,
93
96
  total_tables,
94
97
  table_name=table.name,
95
- message=f"Enriching table {table.name}"
98
+ message=f"Enriching table {table.name}",
96
99
  )
97
100
 
98
101
  try:
99
- payload = _serialize_table_prompt(table, raw_table, model.description, placeholder, custom_prompt)
102
+ payload = _serialize_table_prompt(
103
+ table, raw_table, model.description, placeholder, custom_prompt
104
+ )
100
105
  response = client.chat_completion(payload["messages"])
101
106
  enrichment = _parse_llm_response(response.content)
102
107
  if enrichment:
@@ -108,10 +113,15 @@ def enrich_semantic_model(
108
113
  if (
109
114
  model_description
110
115
  and isinstance(model_description, str)
111
- and (model.description == placeholder or not model.description.strip())
116
+ and (
117
+ model.description == placeholder
118
+ or not model.description.strip()
119
+ )
112
120
  ):
113
121
  model.description = model_description.strip()
114
- except DashscopeError as exc: # pragma: no cover - network failures or remote errors
122
+ except (
123
+ DashscopeError
124
+ ) as exc: # pragma: no cover - network failures or remote errors
115
125
  logger.warning("DashScope enrichment failed for {}: {}", table.name, exc)
116
126
  except Exception as exc: # pragma: no cover - defensive guard
117
127
  logger.exception("Unexpected error enriching table {}: {}", table.name, exc)
@@ -121,7 +131,7 @@ def enrich_semantic_model(
121
131
  EnrichmentStage.MODEL_DESCRIPTION,
122
132
  0,
123
133
  1,
124
- message="Generating model description"
134
+ message="Generating model description",
125
135
  )
126
136
 
127
137
  if model.description == placeholder or not model.description.strip():
@@ -132,7 +142,7 @@ def enrich_semantic_model(
132
142
  EnrichmentStage.MODEL_DESCRIPTION,
133
143
  1,
134
144
  1,
135
- message="Model description generated"
145
+ message="Model description generated",
136
146
  )
137
147
 
138
148
  # Model metrics generation
@@ -141,7 +151,7 @@ def enrich_semantic_model(
141
151
  EnrichmentStage.MODEL_METRICS,
142
152
  0,
143
153
  1,
144
- message="Generating model-level metrics"
154
+ message="Generating model-level metrics",
145
155
  )
146
156
 
147
157
  overview = _build_model_overview(model, raw_lookup, raw_tables)
@@ -154,10 +164,7 @@ def enrich_semantic_model(
154
164
 
155
165
  if progress_tracker:
156
166
  progress_tracker.update_progress(
157
- EnrichmentStage.MODEL_METRICS,
158
- 1,
159
- 1,
160
- message="Model metrics generated"
167
+ EnrichmentStage.MODEL_METRICS, 1, 1, message="Model metrics generated"
161
168
  )
162
169
 
163
170
  # Verified queries generation
@@ -166,7 +173,7 @@ def enrich_semantic_model(
166
173
  EnrichmentStage.VERIFIED_QUERIES,
167
174
  0,
168
175
  1,
169
- message="Generating verified queries"
176
+ message="Generating verified queries",
170
177
  )
171
178
 
172
179
  try:
@@ -185,10 +192,7 @@ def enrich_semantic_model(
185
192
 
186
193
  if progress_tracker:
187
194
  progress_tracker.update_progress(
188
- EnrichmentStage.VERIFIED_QUERIES,
189
- 1,
190
- 1,
191
- message="Verified queries generated"
195
+ EnrichmentStage.VERIFIED_QUERIES, 1, 1, message="Verified queries generated"
192
196
  )
193
197
 
194
198
  if metric_notes:
@@ -248,7 +252,9 @@ def _serialize_table_prompt(
248
252
  "name": nf.name,
249
253
  "expr": nf.expr,
250
254
  "has_description": bool(nf.description.strip()),
251
- "has_synonyms": any(s.strip() and s != placeholder for s in nf.synonyms),
255
+ "has_synonyms": any(
256
+ s.strip() and s != placeholder for s in nf.synonyms
257
+ ),
252
258
  }
253
259
  for nf in table.filters
254
260
  ],
@@ -270,14 +276,14 @@ def _serialize_table_prompt(
270
276
  "{\n"
271
277
  ' "table_description": "Orders fact table that captures the status and finances of each order",\n'
272
278
  ' "columns": [\n'
273
- ' {\n'
279
+ " {\n"
274
280
  ' "name": "O_TOTALPRICE",\n'
275
281
  ' "description": "Total order value including tax",\n'
276
282
  ' "synonyms": ["Order amount", "Order total"]\n'
277
283
  " }\n"
278
284
  " ],\n"
279
285
  ' "business_metrics": [\n'
280
- ' {\n'
286
+ " {\n"
281
287
  ' "name": "Gross merchandise value",\n'
282
288
  ' "source_columns": ["O_TOTALPRICE"],\n'
283
289
  ' "description": "Used to measure GMV derived from the total order price."\n'
@@ -306,7 +312,9 @@ def _parse_llm_response(content: str) -> Optional[Dict[str, object]]:
306
312
  try:
307
313
  data = json.loads(json_text)
308
314
  except json.JSONDecodeError as exc:
309
- logger.warning("Unable to parse DashScope response as JSON: {} | raw={}", exc, content)
315
+ logger.warning(
316
+ "Unable to parse DashScope response as JSON: {} | raw={}", exc, content
317
+ )
310
318
  return None
311
319
  if not isinstance(data, dict):
312
320
  return None
@@ -368,7 +376,10 @@ def _apply_column_enrichment(
368
376
  continue
369
377
 
370
378
  description = entry.get("description")
371
- if isinstance(description, str) and getattr(target, "description", "") == placeholder:
379
+ if (
380
+ isinstance(description, str)
381
+ and getattr(target, "description", "") == placeholder
382
+ ):
372
383
  target.description = description.strip()
373
384
 
374
385
  synonyms = entry.get("synonyms")
@@ -376,7 +387,9 @@ def _apply_column_enrichment(
376
387
  _apply_synonyms(target, synonyms, placeholder)
377
388
 
378
389
 
379
- def _apply_synonyms(target: object, synonyms: Sequence[object], placeholder: str) -> None:
390
+ def _apply_synonyms(
391
+ target: object, synonyms: Sequence[object], placeholder: str
392
+ ) -> None:
380
393
  clean_synonyms: List[str] = []
381
394
  for item in synonyms:
382
395
  if isinstance(item, str):
@@ -386,7 +399,11 @@ def _apply_synonyms(target: object, synonyms: Sequence[object], placeholder: str
386
399
  if not clean_synonyms:
387
400
  return
388
401
 
389
- existing = [syn for syn in getattr(target, "synonyms", []) if syn.strip() and syn != placeholder]
402
+ existing = [
403
+ syn
404
+ for syn in getattr(target, "synonyms", [])
405
+ if syn.strip() and syn != placeholder
406
+ ]
390
407
  merged = _deduplicate(existing + clean_synonyms)
391
408
 
392
409
  if hasattr(target, "synonyms"):
@@ -452,7 +469,11 @@ def _apply_filter_enrichment(
452
469
  target.description = description.strip()
453
470
  synonyms = entry.get("synonyms")
454
471
  if isinstance(synonyms, list):
455
- clean_synonyms = [str(item).strip() for item in synonyms if isinstance(item, (str, int, float))]
472
+ clean_synonyms = [
473
+ str(item).strip()
474
+ for item in synonyms
475
+ if isinstance(item, (str, int, float))
476
+ ]
456
477
  if clean_synonyms:
457
478
  del target.synonyms[:]
458
479
  target.synonyms.extend(clean_synonyms)
@@ -483,7 +504,15 @@ _COUNT_KEYWORDS = (
483
504
  "headcount",
484
505
  )
485
506
  _DISTINCT_KEYWORDS = ("distinct", "unique", "deduplicated")
486
- _AVERAGE_KEYWORDS = ("average", "avg", "mean", "typical", "expected", "per order", "per customer")
507
+ _AVERAGE_KEYWORDS = (
508
+ "average",
509
+ "avg",
510
+ "mean",
511
+ "typical",
512
+ "expected",
513
+ "per order",
514
+ "per customer",
515
+ )
487
516
  _SUM_KEYWORDS = (
488
517
  "total",
489
518
  "sum",
@@ -610,7 +639,9 @@ def _apply_metric_enrichment(
610
639
  business_metrics: Sequence[object],
611
640
  placeholder: str,
612
641
  ) -> tuple[Optional[str], bool]:
613
- column_type_map = {col.column_name.upper(): col.column_type for col in raw_table.columns}
642
+ column_type_map = {
643
+ col.column_name.upper(): col.column_type for col in raw_table.columns
644
+ }
614
645
  existing_names: set[str] = {metric.name for metric in table.metrics}
615
646
  notes: List[Dict[str, object]] = []
616
647
  metrics_added = False
@@ -634,8 +665,12 @@ def _apply_metric_enrichment(
634
665
  continue
635
666
 
636
667
  metric_name = _sanitize_metric_name(name, existing_names)
637
- aggregation, use_product = _derive_metric_intent(entry, resolved_sources, column_type_map)
638
- expression = _build_metric_expression(resolved_sources, column_type_map, aggregation, use_product)
668
+ aggregation, use_product = _derive_metric_intent(
669
+ entry, resolved_sources, column_type_map
670
+ )
671
+ expression = _build_metric_expression(
672
+ resolved_sources, column_type_map, aggregation, use_product
673
+ )
639
674
 
640
675
  metric = table.metrics.add()
641
676
  metric.name = metric_name
@@ -643,7 +678,9 @@ def _apply_metric_enrichment(
643
678
 
644
679
  description = entry.get("description")
645
680
  metric.description = (
646
- description.strip() if isinstance(description, str) and description.strip() else placeholder
681
+ description.strip()
682
+ if isinstance(description, str) and description.strip()
683
+ else placeholder
647
684
  )
648
685
 
649
686
  synonyms = entry.get("synonyms")
@@ -661,8 +698,16 @@ def _apply_metric_enrichment(
661
698
  notes.append(
662
699
  {
663
700
  "name": name.strip(),
664
- "source_columns": raw_sources if isinstance(raw_sources, list) and raw_sources else resolved_sources,
665
- "description": description.strip() if isinstance(description, str) and description.strip() else "",
701
+ "source_columns": (
702
+ raw_sources
703
+ if isinstance(raw_sources, list) and raw_sources
704
+ else resolved_sources
705
+ ),
706
+ "description": (
707
+ description.strip()
708
+ if isinstance(description, str) and description.strip()
709
+ else ""
710
+ ),
666
711
  }
667
712
  )
668
713
  metrics_added = True
@@ -683,7 +728,9 @@ def _summarize_model_description(
683
728
  table_lines = []
684
729
  for table in model.tables:
685
730
  role = "fact" if table.facts or table.metrics else "dimension"
686
- desc = table.description.strip() if table.description.strip() else "No description"
731
+ desc = (
732
+ table.description.strip() if table.description.strip() else "No description"
733
+ )
687
734
  metrics = ", ".join(metric.name for metric in table.metrics) or "None"
688
735
  table_lines.append(f"- {table.name} ({role}): {desc}. Metrics: {metrics}")
689
736
 
@@ -692,7 +739,8 @@ def _summarize_model_description(
692
739
  parts = [f"{rel.left_table} -> {rel.right_table}"]
693
740
  if rel.relationship_columns:
694
741
  columns = ", ".join(
695
- f"{col.left_column}={col.right_column}" for col in rel.relationship_columns
742
+ f"{col.left_column}={col.right_column}"
743
+ for col in rel.relationship_columns
696
744
  )
697
745
  parts.append(f"on {columns}")
698
746
  relationship_lines.append(" ".join(parts))
@@ -754,8 +802,12 @@ def _build_model_overview(
754
802
  "name": table.name,
755
803
  "description": (table.description or "").strip(),
756
804
  "base_table": {
757
- "database": table.base_table.database if table.HasField("base_table") else "",
758
- "schema": table.base_table.schema if table.HasField("base_table") else "",
805
+ "database": (
806
+ table.base_table.database if table.HasField("base_table") else ""
807
+ ),
808
+ "schema": (
809
+ table.base_table.schema if table.HasField("base_table") else ""
810
+ ),
759
811
  "table": table.base_table.table if table.HasField("base_table") else "",
760
812
  },
761
813
  "dimensions": [
@@ -833,7 +885,9 @@ def _build_model_overview(
833
885
  "left_table": rel.left_table,
834
886
  "right_table": rel.right_table,
835
887
  "join_type": semantic_model_pb2.JoinType.Name(rel.join_type),
836
- "relationship_type": semantic_model_pb2.RelationshipType.Name(rel.relationship_type),
888
+ "relationship_type": semantic_model_pb2.RelationshipType.Name(
889
+ rel.relationship_type
890
+ ),
837
891
  "columns": [
838
892
  {"left_column": col.left_column, "right_column": col.right_column}
839
893
  for col in rel.relationship_columns
@@ -859,8 +913,10 @@ def _generate_model_metrics(
859
913
  metrics_accessible = False
860
914
 
861
915
  # Step 1: Check if metrics attribute exists
862
- if not hasattr(model, 'metrics'):
863
- logger.warning("Model object missing 'metrics' attribute, skipping model-level metrics generation")
916
+ if not hasattr(model, "metrics"):
917
+ logger.warning(
918
+ "Model object missing 'metrics' attribute, skipping model-level metrics generation"
919
+ )
864
920
  return
865
921
 
866
922
  # Step 2: Test basic read access
@@ -888,14 +944,22 @@ def _generate_model_metrics(
888
944
  logger.debug("Metrics field write access verified and cleaned up")
889
945
  else:
890
946
  # Metric add appeared to succeed but count didn't change - something is wrong
891
- logger.warning("Metrics field write access inconsistent (count: {} -> {}), skipping model-level metrics", current_count, new_count)
947
+ logger.warning(
948
+ "Metrics field write access inconsistent (count: {} -> {}), skipping model-level metrics",
949
+ current_count,
950
+ new_count,
951
+ )
892
952
  return
893
953
 
894
954
  except Exception as exc:
895
955
  logger.warning("Cannot write to model.metrics field: {}", str(exc))
896
956
  # Try to provide diagnostic information without causing more errors
897
957
  try:
898
- logger.debug("Model type: {}, metrics type: {}", type(model).__name__, type(getattr(model, 'metrics', None)))
958
+ logger.debug(
959
+ "Model type: {}, metrics type: {}",
960
+ type(model).__name__,
961
+ type(getattr(model, "metrics", None)),
962
+ )
899
963
  except Exception:
900
964
  pass
901
965
  return
@@ -920,12 +984,12 @@ def _generate_model_metrics(
920
984
  "Design up to three model-level business metrics (KPIs) using the semantic model summary below.\n"
921
985
  "Return JSON with the structure:\n"
922
986
  "{\n"
923
- " \"model_metrics\": [\n"
987
+ ' "model_metrics": [\n'
924
988
  " {\n"
925
- " \"name\": \"...\",\n"
926
- " \"expr\": \"SUM(FACT_SALES.total_amount)\",\n"
927
- " \"description\": \"...\",\n"
928
- " \"synonyms\": [\"...\"]\n"
989
+ ' "name": "...",\n'
990
+ ' "expr": "SUM(FACT_SALES.total_amount)",\n'
991
+ ' "description": "...",\n'
992
+ ' "synonyms": ["..."]\n'
929
993
  " }\n"
930
994
  " ]\n"
931
995
  "}\n"
@@ -953,7 +1017,10 @@ def _generate_model_metrics(
953
1017
 
954
1018
  entries = payload.get("model_metrics")
955
1019
  if not isinstance(entries, list):
956
- logger.debug("No model_metrics list found in LLM response: {}", payload.keys() if payload else "None")
1020
+ logger.debug(
1021
+ "No model_metrics list found in LLM response: {}",
1022
+ payload.keys() if payload else "None",
1023
+ )
957
1024
  return
958
1025
 
959
1026
  logger.debug("Found {} model metrics entries to process", len(entries))
@@ -981,8 +1048,14 @@ def _generate_model_metrics(
981
1048
  try:
982
1049
  metric = model.metrics.add()
983
1050
  except Exception as exc:
984
- logger.warning("Failed to add model-level metric '{}' despite pre-check: {}", name, str(exc))
985
- logger.info("Aborting model-level metrics generation due to unexpected field access failure")
1051
+ logger.warning(
1052
+ "Failed to add model-level metric '{}' despite pre-check: {}",
1053
+ name,
1054
+ str(exc),
1055
+ )
1056
+ logger.info(
1057
+ "Aborting model-level metrics generation due to unexpected field access failure"
1058
+ )
986
1059
  return
987
1060
 
988
1061
  metric.name = _sanitize_metric_name(name, existing_names)
@@ -996,7 +1069,11 @@ def _generate_model_metrics(
996
1069
 
997
1070
  synonyms = entry.get("synonyms")
998
1071
  if isinstance(synonyms, list):
999
- clean_synonyms = [str(item).strip() for item in synonyms if isinstance(item, (str, int, float)) and str(item).strip()]
1072
+ clean_synonyms = [
1073
+ str(item).strip()
1074
+ for item in synonyms
1075
+ if isinstance(item, (str, int, float)) and str(item).strip()
1076
+ ]
1000
1077
  if clean_synonyms:
1001
1078
  metric.synonyms.extend(clean_synonyms)
1002
1079
 
@@ -1033,7 +1110,9 @@ def _generate_verified_queries(
1033
1110
  max_items: int = 3,
1034
1111
  ) -> None:
1035
1112
  if session is None:
1036
- logger.debug("Skipping verified query generation because no ClickZetta session was provided.")
1113
+ logger.debug(
1114
+ "Skipping verified query generation because no ClickZetta session was provided."
1115
+ )
1037
1116
  return
1038
1117
 
1039
1118
  prompt_json = json.dumps(overview, ensure_ascii=False, indent=2)
@@ -1079,7 +1158,11 @@ def _generate_verified_queries(
1079
1158
  continue
1080
1159
  query_name = entry.get("name")
1081
1160
  if not isinstance(query_name, str) or not query_name.strip():
1082
- query_name = question if isinstance(question, str) and question.strip() else "Verified query"
1161
+ query_name = (
1162
+ question
1163
+ if isinstance(question, str) and question.strip()
1164
+ else "Verified query"
1165
+ )
1083
1166
 
1084
1167
  normalized_sql = _ensure_limit_clause(sql)
1085
1168
  if normalized_sql.strip().lower() in existing_sql:
@@ -1088,7 +1171,11 @@ def _generate_verified_queries(
1088
1171
  try:
1089
1172
  session.sql(normalized_sql).to_pandas()
1090
1173
  except Exception as exc: # pragma: no cover - ClickZetta query failed
1091
- logger.warning("Skipping verified query '{}' due to validation failure: {}", query_name, exc)
1174
+ logger.warning(
1175
+ "Skipping verified query '{}' due to validation failure: {}",
1176
+ query_name,
1177
+ exc,
1178
+ )
1092
1179
  continue
1093
1180
 
1094
1181
  verified_query = model.verified_queries.add()
@@ -1,6 +1,7 @@
1
1
  """
2
2
  Progress tracking system for semantic model enrichment process.
3
3
  """
4
+
4
5
  from __future__ import annotations
5
6
 
6
7
  from dataclasses import dataclass
@@ -41,7 +42,9 @@ class EnrichmentProgressTracker:
41
42
  to UI components via callback functions.
42
43
  """
43
44
 
44
- def __init__(self, progress_callback: Optional[Callable[[ProgressUpdate], None]] = None):
45
+ def __init__(
46
+ self, progress_callback: Optional[Callable[[ProgressUpdate], None]] = None
47
+ ):
45
48
  """
46
49
  Initialize the progress tracker.
47
50
 
@@ -53,11 +56,11 @@ class EnrichmentProgressTracker:
53
56
 
54
57
  # Weight distribution across stages (should sum to 1.0)
55
58
  self.stage_weights = {
56
- EnrichmentStage.METADATA_FETCH: 0.05, # 5% - Quick metadata collection
57
- EnrichmentStage.TABLE_ENRICHMENT: 0.70, # 70% - Most time-consuming (multiple LLM calls)
58
- EnrichmentStage.MODEL_DESCRIPTION: 0.05, # 5% - Single LLM call
59
- EnrichmentStage.MODEL_METRICS: 0.10, # 10% - Single LLM call
60
- EnrichmentStage.VERIFIED_QUERIES: 0.10, # 10% - Single LLM call + validation
59
+ EnrichmentStage.METADATA_FETCH: 0.05, # 5% - Quick metadata collection
60
+ EnrichmentStage.TABLE_ENRICHMENT: 0.70, # 70% - Most time-consuming (multiple LLM calls)
61
+ EnrichmentStage.MODEL_DESCRIPTION: 0.05, # 5% - Single LLM call
62
+ EnrichmentStage.MODEL_METRICS: 0.10, # 10% - Single LLM call
63
+ EnrichmentStage.VERIFIED_QUERIES: 0.10, # 10% - Single LLM call + validation
61
64
  }
62
65
 
63
66
  # Track accumulated progress from completed stages
@@ -70,7 +73,7 @@ class EnrichmentProgressTracker:
70
73
  total: int,
71
74
  table_name: Optional[str] = None,
72
75
  message: str = "",
73
- details: Optional[Dict[str, Any]] = None
76
+ details: Optional[Dict[str, Any]] = None,
74
77
  ) -> None:
75
78
  """
76
79
  Update progress for the current enrichment stage.
@@ -99,7 +102,7 @@ class EnrichmentProgressTracker:
99
102
  table_name=table_name,
100
103
  message=message,
101
104
  percentage=percentage,
102
- details=details or {}
105
+ details=details or {},
103
106
  )
104
107
 
105
108
  # Send update via callback
@@ -116,10 +119,7 @@ class EnrichmentProgressTracker:
116
119
  self.completed_stage_progress += self.stage_weights[stage]
117
120
 
118
121
  def _calculate_overall_percentage(
119
- self,
120
- stage: EnrichmentStage,
121
- current: int,
122
- total: int
122
+ self, stage: EnrichmentStage, current: int, total: int
123
123
  ) -> float:
124
124
  """
125
125
  Calculate overall progress percentage across all stages.
@@ -153,7 +153,7 @@ class EnrichmentProgressTracker:
153
153
  stage=EnrichmentStage.COMPLETE,
154
154
  current=1,
155
155
  total=1,
156
- message="Enrichment complete"
156
+ message="Enrichment complete",
157
157
  )
158
158
 
159
159
 
@@ -164,6 +164,7 @@ def create_ui_progress_callback() -> Callable[[ProgressUpdate], None]:
164
164
  Returns:
165
165
  Callback function that formats progress updates for UI display.
166
166
  """
167
+
167
168
  def callback(update: ProgressUpdate) -> None:
168
169
  """Format and display progress update in UI."""
169
170
  # Build progress message
@@ -173,7 +174,7 @@ def create_ui_progress_callback() -> Callable[[ProgressUpdate], None]:
173
174
  EnrichmentStage.MODEL_DESCRIPTION: "Generating model description",
174
175
  EnrichmentStage.MODEL_METRICS: "Generating model metrics",
175
176
  EnrichmentStage.VERIFIED_QUERIES: "Generating verified queries",
176
- EnrichmentStage.COMPLETE: "Complete"
177
+ EnrichmentStage.COMPLETE: "Complete",
177
178
  }
178
179
 
179
180
  stage_label = stage_labels.get(update.stage, update.stage.value)
@@ -195,4 +196,4 @@ def create_ui_progress_callback() -> Callable[[ProgressUpdate], None]:
195
196
  # For now, we'll use the existing progress callback mechanism
196
197
  print(f"[{update.percentage:.1f}%] {full_message}")
197
198
 
198
- return callback
199
+ return callback
@@ -9,7 +9,6 @@ from loguru import logger
9
9
 
10
10
  from semantic_model_generator.clickzetta_utils.clickzetta_connector import (
11
11
  _TABLE_NAME_COL,
12
- _TABLE_SCHEMA_COL,
13
12
  get_table_representation,
14
13
  get_valid_schemas_tables_columns_df,
15
14
  )
@@ -68,11 +67,7 @@ def _build_tables_from_dataframe(
68
67
  )
69
68
 
70
69
  table_order = (
71
- columns_df[_TABLE_NAME_COL]
72
- .astype(str)
73
- .str.upper()
74
- .drop_duplicates()
75
- .tolist()
70
+ columns_df[_TABLE_NAME_COL].astype(str).str.upper().drop_duplicates().tolist()
76
71
  )
77
72
 
78
73
  tables: List[Tuple[FQNParts, Table]] = []
@@ -3,15 +3,13 @@ from unittest import mock
3
3
 
4
4
  import pandas as pd
5
5
 
6
- from semantic_model_generator.clickzetta_utils import env_vars
7
6
  from semantic_model_generator.clickzetta_utils import clickzetta_connector as connector
7
+ from semantic_model_generator.clickzetta_utils import env_vars
8
8
 
9
9
 
10
10
  def test_fetch_stages_includes_user_volume(monkeypatch):
11
11
  data = pd.DataFrame({"name": ["shared_stage"]})
12
- with mock.patch.object(
13
- connector, "_execute_query_to_pandas", return_value=data
14
- ):
12
+ with mock.patch.object(connector, "_execute_query_to_pandas", return_value=data):
15
13
  stages = connector.fetch_stages_in_schema(
16
14
  connection=mock.MagicMock(), schema_name="WORKSPACE.SCHEMA"
17
15
  )
@@ -29,9 +27,7 @@ def test_fetch_yaml_names_in_user_volume(monkeypatch):
29
27
  ]
30
28
  }
31
29
  )
32
- with mock.patch.object(
33
- connector, "_execute_query_to_pandas", return_value=data
34
- ):
30
+ with mock.patch.object(connector, "_execute_query_to_pandas", return_value=data):
35
31
  files = connector.fetch_yaml_names_in_stage(
36
32
  connection=mock.MagicMock(),
37
33
  stage="volume:user://~/semantic_models/",
@@ -4,10 +4,10 @@ import pytest
4
4
  import sqlglot
5
5
 
6
6
  from semantic_model_generator.data_processing.cte_utils import (
7
+ ClickzettaDialect,
7
8
  _enrich_column_in_expr_with_aggregation,
8
9
  _get_col_expr,
9
10
  _validate_col,
10
- ClickzettaDialect,
11
11
  context_to_column_format,
12
12
  expand_all_logical_tables_as_ctes,
13
13
  generate_select,
@@ -31,8 +31,18 @@ def test_string_date_promoted_to_time_dimension() -> None:
31
31
  id_=0,
32
32
  name="ORDERS",
33
33
  columns=[
34
- Column(id_=0, column_name="order_date", column_type="STRING", values=["2024-01-01", "2024-02-01"]),
35
- Column(id_=1, column_name="order_status", column_type="STRING", values=["OPEN", "CLOSED"]),
34
+ Column(
35
+ id_=0,
36
+ column_name="order_date",
37
+ column_type="STRING",
38
+ values=["2024-01-01", "2024-02-01"],
39
+ ),
40
+ Column(
41
+ id_=1,
42
+ column_name="order_status",
43
+ column_type="STRING",
44
+ values=["OPEN", "CLOSED"],
45
+ ),
36
46
  ],
37
47
  )
38
48