aetherdialect 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aetherdialect-0.1.0.dist-info/METADATA +197 -0
- aetherdialect-0.1.0.dist-info/RECORD +34 -0
- aetherdialect-0.1.0.dist-info/WHEEL +5 -0
- aetherdialect-0.1.0.dist-info/licenses/LICENSE +7 -0
- aetherdialect-0.1.0.dist-info/top_level.txt +1 -0
- text2sql/__init__.py +7 -0
- text2sql/config.py +1063 -0
- text2sql/contracts_base.py +952 -0
- text2sql/contracts_core.py +1890 -0
- text2sql/core_utils.py +834 -0
- text2sql/dialect.py +1134 -0
- text2sql/expansion_ops.py +1218 -0
- text2sql/expansion_rules.py +496 -0
- text2sql/intent_expr.py +1759 -0
- text2sql/intent_process.py +2133 -0
- text2sql/intent_repair.py +1733 -0
- text2sql/intent_resolve.py +1292 -0
- text2sql/live_testing.py +1117 -0
- text2sql/main_execution.py +799 -0
- text2sql/pipeline.py +1662 -0
- text2sql/qsim_ops.py +1286 -0
- text2sql/qsim_sample.py +609 -0
- text2sql/qsim_struct.py +569 -0
- text2sql/schema.py +973 -0
- text2sql/schema_profiling.py +2075 -0
- text2sql/simulator.py +970 -0
- text2sql/sql_gen.py +1537 -0
- text2sql/templates.py +1037 -0
- text2sql/text2sql.py +726 -0
- text2sql/utils.py +973 -0
- text2sql/validation_agg.py +1033 -0
- text2sql/validation_execute.py +1092 -0
- text2sql/validation_schema.py +1847 -0
- text2sql/validation_semantic.py +2122 -0
text2sql/text2sql.py
ADDED
|
@@ -0,0 +1,726 @@
|
|
|
1
|
+
"""High-level API for text-to-SQL pipeline orchestration and mode management.
|
|
2
|
+
|
|
3
|
+
Provides the ``Text2SQL`` class as the primary entry point for all pipeline operations: engine configuration, API key management, schema loading, template persistence, and execution mode switching between interactive, simulator, and question-generation (qsim) modes.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from dotenv import load_dotenv
|
|
12
|
+
|
|
13
|
+
from .config import (
|
|
14
|
+
DatabricksRuntimeConfig,
|
|
15
|
+
EngineConfig,
|
|
16
|
+
PostgresRuntimeConfig,
|
|
17
|
+
QSimConfig,
|
|
18
|
+
)
|
|
19
|
+
from .contracts_base import (
|
|
20
|
+
QSimRange,
|
|
21
|
+
QSimSummary,
|
|
22
|
+
RejectedTemplateInfo,
|
|
23
|
+
SchemaGraph,
|
|
24
|
+
SimulatorSummary,
|
|
25
|
+
TemplateInfo,
|
|
26
|
+
)
|
|
27
|
+
from .dialect import get_dialect
|
|
28
|
+
from .main_execution import (
|
|
29
|
+
get_artifacts_dir,
|
|
30
|
+
get_qsim,
|
|
31
|
+
get_questions_only,
|
|
32
|
+
get_rejected_templates_list,
|
|
33
|
+
get_simulator_summary_from_dir,
|
|
34
|
+
get_templates_list,
|
|
35
|
+
interactive_run_once,
|
|
36
|
+
load_generated_questions,
|
|
37
|
+
qsim_run_once,
|
|
38
|
+
resolve_qsim_path,
|
|
39
|
+
simulator_run_once,
|
|
40
|
+
)
|
|
41
|
+
from .schema import load_or_create_schema_graph
|
|
42
|
+
from .templates import (
|
|
43
|
+
load_template_store,
|
|
44
|
+
store_to_rejected_templates,
|
|
45
|
+
store_to_templates,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _first_non_empty_env(names: list[str]) -> str | None:
    """Look up *names* in the process environment, in order.

    Args:
        names: Environment variable names to try, highest priority first.

    Returns:
        The value of the first variable that is set to a non-empty string,
        or ``None`` when every name is unset or empty.
    """
    # Unset variables yield None and empty ones yield "", both falsy, so a
    # single truthiness filter skips them alike.
    candidates = (os.environ.get(env_name) for env_name in names)
    return next((found for found in candidates if found), None)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _resolve_postgres_runtime_config(
    host: str | None,
    port: int | None,
    user: str | None,
    password: str | None,
    database: str | None,
    schema: str | None,
) -> tuple[str | None, int | None, str | None, str | None, str | None, str | None]:
    """Fill in missing PostgreSQL connection settings from the environment.

    A truthy explicit argument always wins. Otherwise the first non-empty
    environment variable from that setting's alias list is used.

    Args:
        host: Explicit host, or ``None``/empty to consult the environment.
        port: Explicit port; only ``None`` triggers the environment lookup.
        user: Explicit user name.
        password: Explicit password.
        database: Explicit database name.
        schema: Explicit schema name.

    Returns:
        ``(host, port, user, password, database, schema)`` after resolution.
        Any element may still be ``None`` when neither source supplied it.
    """

    def _or_env(explicit: str | None, env_names: list[str]) -> str | None:
        # The right-hand side is only evaluated when the argument is falsy.
        return explicit or _first_non_empty_env(env_names)

    effective_port = port
    if effective_port is None:
        raw_port = _first_non_empty_env(["POSTGRESQL_PORT", "PGPORT", "DB_PORT"])
        if raw_port is not None:
            try:
                effective_port = int(raw_port)
            except ValueError:
                # A non-numeric port in the environment is ignored.
                effective_port = None

    return (
        _or_env(host, ["POSTGRESQL_HOST", "PGHOST", "DB_HOST"]),
        effective_port,
        _or_env(user, ["POSTGRESQL_USER", "PGUSER", "DB_USER"]),
        _or_env(password, ["POSTGRESQL_PASSWORD", "PGPASSWORD", "DB_PASSWORD"]),
        _or_env(
            database,
            ["POSTGRESQL_DB", "POSTGRESQL_DATABASE", "PGDATABASE", "DB_NAME", "DB_DATABASE"],
        ),
        _or_env(schema, ["POSTGRESQL_SCHEMA", "DB_SCHEMA"]),
    )
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _resolve_databricks_runtime_config(
    catalog: str | None,
    schema: str | None,
    server_hostname: str | None,
    http_path: str | None,
    access_token: str | None,
) -> tuple[str | None, str | None, str | None, str | None, str | None]:
    """Fill in missing Databricks connection settings from the environment.

    A truthy explicit argument always wins. Otherwise the first non-empty
    environment variable from that setting's alias list is used.

    Args:
        catalog: Explicit catalog name.
        schema: Explicit schema name.
        server_hostname: Explicit server hostname.
        http_path: Explicit HTTP path.
        access_token: Explicit access token.

    Returns:
        ``(catalog, schema, server_hostname, http_path, access_token)`` after
        resolution. Any element may still be ``None``.
    """
    # One row per setting: (explicit argument, env aliases in priority order).
    lookup_table = (
        (catalog, ["DATABRICKS_CATALOG", "DBR_CATALOG"]),
        (schema, ["DATABRICKS_SCHEMA", "DBR_SCHEMA"]),
        (
            server_hostname,
            [
                "DATABRICKS_SERVER_HOSTNAME",
                "DATABRICKS_HOST",
                "DATABRICKS_SERVER",
                "DBR_SERVER_HOSTNAME",
                "DBR_HOST",
            ],
        ),
        (
            http_path,
            [
                "DATABRICKS_HTTP_PATH",
                "DATABRICKS_WAREHOUSE_HTTP_PATH",
                "DATABRICKS_SQL_HTTP_PATH",
                "DBR_HTTP_PATH",
            ],
        ),
        (
            access_token,
            [
                "DATABRICKS_ACCESS_TOKEN",
                "DATABRICKS_PAT",
                "ACCESS_TOKEN",
                "PERSONAL_ACCESS_TOKEN",
                "DBR_ACCESS_TOKEN",
                "DBR_PAT",
                "DATABRICKS_TOKEN",
            ],
        ),
    )
    final_catalog, final_schema, final_server, final_http_path, final_token = [
        explicit or _first_non_empty_env(env_names)
        for explicit, env_names in lookup_table
    ]
    return (final_catalog, final_schema, final_server, final_http_path, final_token)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class Text2SQL:
    """Primary entry point for the text-to-SQL pipeline.

    Manages engine configuration, API key resolution, schema loading,
    template persistence, and dispatches execution to the selected
    operational mode (interactive, simulator, or qsim).

    Note:
        Construction writes to the class-level ``EngineConfig``,
        ``QSimConfig`` and runtime config objects, so the most recently
        constructed instance determines that shared configuration.
    """

    _VALID_ENGINES = ("postgresql", "databricks")
    _VALID_MODES = ("interactive", "simulator", "qsim")

    # Value of ``QSimConfig.QUESTIONS_OUTPUT_PATH`` as it was before the first
    # instance rewrote it. Cached so later instances join their artifacts
    # directory with the pristine value instead of an already-joined path.
    _questions_output_default = None

    def __init__(
        self,
        engine: str = "postgresql",
        mode: str = "interactive",
        env_file: str | None = None,
        openai_api_key: str | None = None,
        host: str | None = None,
        port: int | None = None,
        user: str | None = None,
        password: str | None = None,
        database: str | None = None,
        schema: str | None = None,
        catalog: str | None = None,
        sql_file: str | None = None,
        server_hostname: str | None = None,
        http_path: str | None = None,
        access_token: str | None = None,
    ) -> None:
        """Initialise the deterministic Text-to-SQL system.

        Args:
            engine: ``"postgresql"`` or ``"databricks"``.
            mode: ``"interactive"``, ``"simulator"``, or ``"qsim"``.
            env_file: Path to ``.env`` containing ``OPENAI_API_KEY``.
            openai_api_key: Direct API key (overrides *env_file* and
                environment variable).
            host: PostgreSQL host.
            port: PostgreSQL port.
            user: PostgreSQL user.
            password: PostgreSQL password.
            database: PostgreSQL database name.
            schema: PostgreSQL or Databricks schema name.
            catalog: Databricks catalog name.
            sql_file: Optional ``CREATE TABLE`` file for offline schema loading.
            server_hostname: Databricks server hostname for
                ``databricks-sql-connector``.
            http_path: Databricks HTTP path for ``databricks-sql-connector``.
            access_token: Databricks personal access token for
                ``databricks-sql-connector``.

        Raises:
            ValueError: If *engine* or *mode* is invalid, if
                *openai_api_key* is blank, or if *env_file* yields no key.
            FileNotFoundError: If *env_file* is given but does not exist.
        """
        if engine not in self._VALID_ENGINES:
            raise ValueError(f"Invalid engine: {engine}. Must be 'postgresql' or 'databricks'.")
        self._require_valid_mode(mode)

        self._mode = mode
        self._api_key_set = False
        # Per-instance defaults for the ``debug`` / ``verbose`` arguments of
        # the run_* methods. They start from whatever the shared config
        # currently holds so that constructing an instance changes nothing.
        self._debug = getattr(EngineConfig, "DEBUG", False)
        self._verbose = getattr(EngineConfig, "VERBOSE", False)

        # Key precedence: explicit argument, then .env file, then environment.
        if openai_api_key:
            self.set_openai_api_key(openai_api_key)
        elif env_file:
            self.set_env(env_file)
        else:
            ambient_key = os.environ.get("OPENAI_API_KEY")
            if ambient_key:
                EngineConfig.API_TOKEN = ambient_key
                self._api_key_set = True

        EngineConfig.TYPE = engine

        # Start from the raw arguments; only the selected engine's settings
        # are completed from environment variables below.
        pg_host, pg_port, pg_user = host, port, user
        pg_password, pg_database, pg_schema = password, database, schema
        dbr_catalog, dbr_schema = catalog, schema
        dbr_server_hostname, dbr_http_path = server_hostname, http_path
        dbr_access_token = access_token

        if engine == "postgresql":
            (
                pg_host,
                pg_port,
                pg_user,
                pg_password,
                pg_database,
                pg_schema,
            ) = _resolve_postgres_runtime_config(
                host, port, user, password, database, schema
            )
        else:
            (
                dbr_catalog,
                dbr_schema,
                dbr_server_hostname,
                dbr_http_path,
                dbr_access_token,
            ) = _resolve_databricks_runtime_config(
                catalog, schema, server_hostname, http_path, access_token
            )

        # For Databricks, pg_host / pg_database still hold the raw arguments.
        self.artifacts_dir = get_artifacts_dir(
            engine,
            pg_host,
            pg_database,
            pg_schema if engine == "postgresql" else dbr_schema,
            dbr_catalog,
        )
        os.makedirs(self.artifacts_dir, exist_ok=True)

        EngineConfig.SCHEMA_JSON_PATH = os.path.join(self.artifacts_dir, "schema_graph.json")
        EngineConfig.TEMPLATE_JSON_PATH = os.path.join(self.artifacts_dir, "intent_templates.json")
        # Join against the original configured value, never against a path a
        # previous instance already rewrote (see _questions_output_default).
        if Text2SQL._questions_output_default is None:
            Text2SQL._questions_output_default = QSimConfig.QUESTIONS_OUTPUT_PATH
        QSimConfig.QUESTIONS_OUTPUT_PATH = os.path.join(
            self.artifacts_dir, Text2SQL._questions_output_default
        )
        QSimConfig.SKELETONS_JSON_PATH = os.path.join(self.artifacts_dir, "qsim_skeletons.json")

        if engine == "postgresql":
            EngineConfig.RUNTIME = PostgresRuntimeConfig
            self._apply_overrides(
                PostgresRuntimeConfig,
                (
                    ("HOST", pg_host),
                    ("PORT", pg_port),
                    ("USER", pg_user),
                    ("PASSWORD", pg_password),
                    ("DATABASE", pg_database),
                    ("SCHEMA", pg_schema),
                    ("SQL_FILE_PATH", sql_file),
                ),
            )
        else:
            EngineConfig.RUNTIME = DatabricksRuntimeConfig
            self._apply_overrides(
                DatabricksRuntimeConfig,
                (
                    ("CATALOG", dbr_catalog),
                    ("SCHEMA", dbr_schema),
                    ("SQL_FILE_PATH", sql_file),
                    ("SERVER_HOSTNAME", dbr_server_hostname),
                    ("HTTP_PATH", dbr_http_path),
                    ("ACCESS_TOKEN", dbr_access_token),
                ),
            )
            DatabricksRuntimeConfig.validate()

        self.dialect = get_dialect(EngineConfig.TYPE, EngineConfig.RUNTIME)

        if EngineConfig.TYPE == "postgresql":
            # Imported lazily so SQLAlchemy is only needed for PostgreSQL.
            from sqlalchemy import create_engine

            self.engine = create_engine(EngineConfig.RUNTIME.db_url(), future=True)
        elif EngineConfig.TYPE == "databricks":
            self.engine = None

        self._schema = load_or_create_schema_graph(self.engine)
        self._store = load_template_store(self._schema.schema_hash)
        self._templates = store_to_templates(self._store)
        self._rejected = store_to_rejected_templates(self._store)

        # Table names, column names as declared, and lower-cased column names.
        schema_terms = set(self._schema.tables.keys())
        for table_info in self._schema.tables.values():
            schema_terms.update(table_info.columns)
            schema_terms.update(column.lower() for column in table_info.columns)
        self._schema_terms = schema_terms
        self._schema_stats = self._schema.schema_stats or {}

    @classmethod
    def _require_valid_mode(cls, mode: str) -> None:
        """Raise ``ValueError`` unless *mode* is a supported mode name."""
        if mode not in cls._VALID_MODES:
            raise ValueError(f"Invalid mode: {mode}. Must be 'interactive', 'simulator', or 'qsim'.")

    @staticmethod
    def _apply_overrides(config_cls: Any, overrides: Any) -> None:
        """Set each ``(attribute, value)`` pair on *config_cls*, skipping ``None`` values."""
        for attribute, value in overrides:
            if value is not None:
                setattr(config_cls, attribute, value)

    @property
    def schema(self) -> SchemaGraph:
        """The loaded schema graph for the connected database."""
        return self._schema

    @property
    def store(self) -> dict[str, Any]:
        """The in-memory template store keyed by intent fingerprint."""
        return self._store

    @store.setter
    def store(self, value: dict[str, Any]) -> None:
        """Replace the template store."""
        self._store = value

    @property
    def templates(self) -> dict[str, Any]:
        """Dict of accepted templates derived from the template store."""
        return self._templates

    @templates.setter
    def templates(self, value: dict[str, Any]) -> None:
        """Replace the accepted templates dict."""
        self._templates = value

    @property
    def rejected(self) -> dict[str, Any]:
        """Dict of rejected templates derived from the template store."""
        return self._rejected

    @rejected.setter
    def rejected(self, value: dict[str, Any]) -> None:
        """Replace the rejected templates dict."""
        self._rejected = value

    @property
    def schema_terms(self) -> set[str]:
        """Set of all table and column name tokens from the loaded schema."""
        return self._schema_terms

    def _compute_num_intents_range(self) -> tuple[int, int]:
        """Compute an adaptive ``num_intents`` range based on schema complexity.

        The upper bound is ``min(200, table_count * 10)``. The lower bound is
        ``max(5, table_count)``, clamped to the upper bound so the range is
        never inverted (for example with more than 200 tables).

        Returns:
            A 2-tuple ``(min_intents, max_intents)`` with
            ``min_intents <= max_intents``.
        """
        table_count = self._schema_stats["table_count"]
        max_intents = min(200, table_count * 10)
        min_intents = min(max(5, table_count), max_intents)
        return (min_intents, max_intents)

    def _compute_num_questions_range(self) -> tuple[int, int]:
        """Compute an adaptive ``num_questions`` range based on schema variance potential.

        The upper bound is ``min(2000, total_filterable * 20)``. The lower
        bound is ``max(min_intents, 10)``, clamped to the upper bound so the
        range is never inverted.

        Returns:
            A 2-tuple ``(min_questions, max_questions)`` with
            ``min_questions <= max_questions``.
        """
        min_intents, _ = self._compute_num_intents_range()
        total_filterable = self._schema_stats["total_filterable"]
        max_questions = min(2000, total_filterable * 20)
        min_questions = min(max(min_intents, 10), max_questions)
        return (min_questions, max_questions)

    def get_qsim_range(self) -> QSimRange:
        """Return valid ``num_intents`` and ``num_questions`` ranges for this schema."""
        return QSimRange(
            num_intents_range=self._compute_num_intents_range(),
            num_questions_range=self._compute_num_questions_range(),
        )

    def set_mode(self, mode: str) -> None:
        """Switch the operational mode.

        Args:
            mode: ``"interactive"``, ``"simulator"``, or ``"qsim"``.

        Raises:
            ValueError: If *mode* is invalid.
        """
        self._require_valid_mode(mode)
        self._mode = mode

    def set_env(self, env_file: str) -> None:
        """Load ``OPENAI_API_KEY`` from a ``.env`` file.

        The key is read from the process environment after the file has been
        loaded.

        Args:
            env_file: Path to the ``.env`` file.

        Raises:
            FileNotFoundError: If *env_file* does not exist.
            ValueError: If no ``OPENAI_API_KEY`` is available after loading.
        """
        if not os.path.exists(env_file):
            raise FileNotFoundError(f"Environment file not found: {env_file}")
        # NOTE(review): python-dotenv does not override variables that are
        # already set unless ``override=True`` is passed, so a pre-existing
        # OPENAI_API_KEY presumably wins over the file - confirm intent.
        load_dotenv(env_file)
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY not found in environment file")
        EngineConfig.API_TOKEN = api_key
        self._api_key_set = True

    def set_openai_api_key(self, key: str) -> None:
        """Set the OpenAI API key directly.

        Args:
            key: Non-empty API key string; surrounding whitespace is stripped.

        Raises:
            ValueError: If *key* is empty or whitespace only.
        """
        if not key or not key.strip():
            raise ValueError("API key cannot be empty")
        EngineConfig.API_TOKEN = key.strip()
        self._api_key_set = True

    def _ensure_api_key(self) -> None:
        """Raise ``RuntimeError`` if no API key has been configured.

        Raises:
            RuntimeError: If this instance never set a key and
                ``EngineConfig.API_TOKEN`` is empty.
        """
        if not self._api_key_set and not EngineConfig.API_TOKEN:
            raise RuntimeError(
                "OpenAI API key not set. Use set_openai_api_key(), set_env(), or pass openai_api_key/env_file to __init__."
            )

    def _validate_num_intents(self, value: int) -> None:
        """Validate *value* is within the schema-adaptive ``num_intents`` range.

        Args:
            value: The requested number of intents.

        Raises:
            ValueError: If *value* is outside ``[min_intents, max_intents]``.
        """
        min_intents, max_intents = self._compute_num_intents_range()
        if not (min_intents <= value <= max_intents):
            raise ValueError(
                f"num_intents must be {min_intents}-{max_intents} for this schema ({self._schema_stats['table_count']} tables)"
            )

    def _validate_num_questions(self, value: int) -> None:
        """Validate *value* is within the schema-adaptive ``num_questions`` range.

        Args:
            value: The requested number of questions.

        Raises:
            ValueError: If *value* is outside ``[min_questions, max_questions]``.
        """
        min_questions, max_questions = self._compute_num_questions_range()
        if not (min_questions <= value <= max_questions):
            raise ValueError(
                f"num_questions must be {min_questions}-{max_questions} for this schema ({self._schema_stats['total_filterable']} total filterable columns)"
            )

    def _push_run_flags(self, debug: bool | None, verbose: bool | None) -> Any:
        """Apply per-run debug/verbose settings to ``EngineConfig``.

        ``None`` means "use this instance's default".

        Returns:
            The previous ``EngineConfig.VERBOSE`` value, for the caller to
            restore once the run finishes.
        """
        effective_debug = self._debug if debug is None else debug
        effective_verbose = self._verbose if verbose is None else verbose

        # NOTE(review): DEBUG is intentionally left as-is after the run,
        # matching the previous behaviour; only VERBOSE is restored.
        EngineConfig.DEBUG = effective_debug
        EngineConfig._propagate_debug()

        previous_verbose = EngineConfig.VERBOSE
        EngineConfig.VERBOSE = effective_verbose
        return previous_verbose

    def run_interactive(self, debug: bool | None = None, verbose: bool | None = None) -> None:
        """Run the interactive question-answering loop with template reuse.

        Args:
            debug: Override instance debug flag for this run.
            verbose: Override instance verbose flag for this run.

        Raises:
            RuntimeError: If mode is not ``"interactive"`` or no API key is set.
        """
        if self._mode != "interactive":
            raise RuntimeError(f"Current mode is '{self._mode}'. Call set_mode('interactive') first.")
        self._ensure_api_key()

        previous_verbose = self._push_run_flags(debug, verbose)
        try:
            interactive_run_once(
                self._schema,
                self._store,
                self._templates,
                self._rejected,
                self._schema_terms,
            )
        finally:
            EngineConfig.VERBOSE = previous_verbose

    def run_simulator(
        self,
        seed_filepath: str,
        interactive_gold: bool = True,
        debug: bool | None = None,
        verbose: bool | None = None,
    ) -> SimulatorSummary:
        """Generate synthetic templates from seed questions.

        Args:
            seed_filepath: Path to seed questions text file.
            interactive_gold: Interactively confirm each gold intent when ``True``.
            debug: Override instance debug flag for this run.
            verbose: Override instance verbose flag for this run.

        Returns:
            ``SimulatorSummary`` describing the results.

        Raises:
            RuntimeError: If mode is not ``"simulator"`` or no API key is set.
        """
        if self._mode != "simulator":
            raise RuntimeError(f"Current mode is '{self._mode}'. Call set_mode('simulator') first.")
        self._ensure_api_key()

        previous_verbose = self._push_run_flags(debug, verbose)
        try:
            return simulator_run_once(
                schema=self.schema,
                dialect=self.dialect,
                seed_filepath=seed_filepath,
                output_dir=self.artifacts_dir,
                store=self.store,
                templates=self.templates,
                interactive_gold=interactive_gold,
            )
        finally:
            EngineConfig.VERBOSE = previous_verbose

    def run_qsim(
        self,
        num_intents: int = 20,
        num_questions: int = 100,
        seed: int | None = None,
        debug: bool | None = None,
        verbose: bool | None = None,
    ) -> QSimSummary:
        """Generate synthetic NL questions from schema-derived intent skeletons.

        Args:
            num_intents: Number of distinct intent types to generate.
            num_questions: Total NL question variants to produce.
            seed: Random seed for reproducible generation.
            debug: Override instance debug flag for this run.
            verbose: Override instance verbose flag for this run.

        Returns:
            ``QSimSummary`` describing the generation results.

        Raises:
            RuntimeError: If mode is not ``"qsim"`` or no API key is set.
            ValueError: If *num_intents* or *num_questions* are out of range
                (see ``get_qsim_range``).
        """
        if self._mode != "qsim":
            raise RuntimeError(f"Current mode is '{self._mode}'. Call set_mode('qsim') first.")
        self._ensure_api_key()

        self._validate_num_intents(num_intents)
        self._validate_num_questions(num_questions)

        previous_verbose = self._push_run_flags(debug, verbose)
        try:
            return qsim_run_once(
                num_intents=num_intents,
                num_questions=num_questions,
                seed=seed,
                artifacts_dir=self.artifacts_dir,
                schema=self._schema,
            )
        finally:
            EngineConfig.VERBOSE = previous_verbose

    def get_qsim(self) -> list[QSimSummary]:
        """Return all QSim run summaries oldest-first."""
        return get_qsim(self.artifacts_dir)

    def get_questions_only(self, timestamp_or_result: str | QSimSummary, output_path: str | None = None) -> None:
        """Print and save NL questions from a QSim run.

        Args:
            timestamp_or_result: Timestamp string or ``QSimSummary`` identifying the run.
            output_path: Output file path; defaults to ``qsim_<timestamp>_questions.txt``.
        """
        if isinstance(timestamp_or_result, QSimSummary):
            timestamp = timestamp_or_result.timestamp
        else:
            timestamp = timestamp_or_result

        # The resolver receives the caller's original argument, not the
        # extracted timestamp.
        path = resolve_qsim_path(timestamp_or_result, self.artifacts_dir)
        results = load_generated_questions(path)

        if output_path is None:
            output_path = f"qsim_{timestamp}_questions.txt"

        # Inside a method this name resolves to the module-level function,
        # not to this method.
        get_questions_only(results, output_path)

    def get_stats(self) -> dict[str, Any]:
        """Return template store statistics (counts, schema hash, table count)."""
        return {
            "templates_count": len(self.templates),
            "rejected_count": len(self.rejected),
            "schema_hash": self.schema.schema_hash,
            "table_count": len(self.schema.tables),
        }

    def get_schema_created_at(self) -> str:
        """Return the ISO-format creation timestamp of the loaded schema."""
        return self.schema.created_at

    def get_templates(self) -> list[TemplateInfo]:
        """Return all accepted templates as ``TemplateInfo`` objects."""
        return get_templates_list(self.templates)

    def get_rejected_templates(self) -> list[RejectedTemplateInfo]:
        """Return all rejected templates as ``RejectedTemplateInfo`` objects."""
        return get_rejected_templates_list(self.rejected)

    def get_schema_tables(self) -> list[str]:
        """Return the list of table names in the loaded schema."""
        return list(self.schema.tables.keys())

    def get_simulator_summary(self, version: int) -> SimulatorSummary:
        """Return the ``SimulatorSummary`` for a specific run version.

        Args:
            version: Simulator run version number.
        """
        return get_simulator_summary_from_dir(self.artifacts_dir, version)
|