dbworkload 0.9.2.dev1__tar.gz → 0.10.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dbworkload-0.9.2.dev1 → dbworkload-0.10.0}/PKG-INFO +10 -2
- {dbworkload-0.9.2.dev1 → dbworkload-0.10.0}/dbworkload/cli/main.py +1 -1
- {dbworkload-0.9.2.dev1 → dbworkload-0.10.0}/dbworkload/cli/util.py +70 -10
- dbworkload-0.10.0/dbworkload/models/convert.py +482 -0
- dbworkload-0.10.0/dbworkload/models/prompts.py +418 -0
- {dbworkload-0.9.2.dev1 → dbworkload-0.10.0}/dbworkload/models/run.py +97 -200
- {dbworkload-0.9.2.dev1 → dbworkload-0.10.0}/dbworkload/utils/common.py +0 -9
- {dbworkload-0.9.2.dev1 → dbworkload-0.10.0}/pyproject.toml +14 -1
- {dbworkload-0.9.2.dev1 → dbworkload-0.10.0}/LICENSE +0 -0
- {dbworkload-0.9.2.dev1 → dbworkload-0.10.0}/README.md +0 -0
- {dbworkload-0.9.2.dev1 → dbworkload-0.10.0}/dbworkload/__init__.py +0 -0
- {dbworkload-0.9.2.dev1 → dbworkload-0.10.0}/dbworkload/cli/dep.py +0 -0
- {dbworkload-0.9.2.dev1 → dbworkload-0.10.0}/dbworkload/models/util.py +0 -0
- {dbworkload-0.9.2.dev1 → dbworkload-0.10.0}/dbworkload/templates/stub.j2 +0 -0
- {dbworkload-0.9.2.dev1 → dbworkload-0.10.0}/dbworkload/utils/simplefaker.py +0 -0
|
@@ -1,8 +1,9 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: dbworkload
|
|
3
|
-
Version: 0.9.2.dev1
|
|
3
|
+
Version: 0.10.0
|
|
4
4
|
Summary: Workload framework
|
|
5
5
|
License: GPLv3+
|
|
6
|
+
License-File: LICENSE
|
|
6
7
|
Author: Fabio Ghirardello
|
|
7
8
|
Requires-Python: >=3.11,<4.0
|
|
8
9
|
Classifier: License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)
|
|
@@ -12,8 +13,10 @@ Classifier: Programming Language :: Python :: 3
|
|
|
12
13
|
Classifier: Programming Language :: Python :: 3.11
|
|
13
14
|
Classifier: Programming Language :: Python :: 3.12
|
|
14
15
|
Classifier: Programming Language :: Python :: 3.13
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
15
17
|
Provides-Extra: all
|
|
16
18
|
Provides-Extra: cassandra
|
|
19
|
+
Provides-Extra: convert
|
|
17
20
|
Provides-Extra: mariadb
|
|
18
21
|
Provides-Extra: mongo
|
|
19
22
|
Provides-Extra: mysql
|
|
@@ -24,9 +27,14 @@ Provides-Extra: spanner
|
|
|
24
27
|
Requires-Dist: cassandra-driver ; extra == "all" or extra == "cassandra"
|
|
25
28
|
Requires-Dist: google-cloud-spanner ; extra == "all" or extra == "spanner"
|
|
26
29
|
Requires-Dist: jinja2
|
|
30
|
+
Requires-Dist: langchain-core (>=1.0.5,<2.0.0) ; extra == "convert"
|
|
31
|
+
Requires-Dist: langchain-ollama (>=1.0.0,<2.0.0) ; extra == "convert"
|
|
32
|
+
Requires-Dist: langchain-openai (>=1.0.2,<2.0.0) ; extra == "convert"
|
|
33
|
+
Requires-Dist: langgraph (>=1.0.3,<2.0.0) ; extra == "convert"
|
|
27
34
|
Requires-Dist: mariadb ; extra == "all" or extra == "mariadb"
|
|
28
35
|
Requires-Dist: mysql-connector-python ; extra == "all" or extra == "mysql"
|
|
29
36
|
Requires-Dist: numpy
|
|
37
|
+
Requires-Dist: openai (>=2.8.0,<3.0.0) ; extra == "convert"
|
|
30
38
|
Requires-Dist: oracledb ; extra == "all" or extra == "oracle"
|
|
31
39
|
Requires-Dist: pandas
|
|
32
40
|
Requires-Dist: plotext
|
|
@@ -1,16 +1,18 @@
|
|
|
1
1
|
#!/usr/bin/python
|
|
2
2
|
|
|
3
|
+
import sys
|
|
3
4
|
from enum import Enum
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
from typing import Optional
|
|
6
7
|
|
|
7
8
|
import typer
|
|
8
9
|
|
|
9
|
-
import dbworkload.models.run
|
|
10
10
|
import dbworkload.models.util
|
|
11
|
-
import dbworkload.utils.common
|
|
12
11
|
from dbworkload.cli.dep import EPILOG, Param
|
|
13
12
|
|
|
13
|
+
# import cloud_instance.cli.util
|
|
14
|
+
from ..models.convert import ConvertTool
|
|
15
|
+
|
|
14
16
|
|
|
15
17
|
class Compression(str, Enum):
|
|
16
18
|
bz2 = "bz2"
|
|
@@ -19,14 +21,14 @@ class Compression(str, Enum):
|
|
|
19
21
|
zip = "zip"
|
|
20
22
|
|
|
21
23
|
|
|
22
|
-
|
|
24
|
+
util_app = typer.Typer(
|
|
23
25
|
epilog=EPILOG,
|
|
24
26
|
no_args_is_help=True,
|
|
25
27
|
help="Various utils.",
|
|
26
28
|
)
|
|
27
29
|
|
|
28
30
|
|
|
29
|
-
@
|
|
31
|
+
@util_app.command(
|
|
30
32
|
"csv",
|
|
31
33
|
epilog=EPILOG,
|
|
32
34
|
no_args_is_help=True,
|
|
@@ -99,7 +101,7 @@ def util_csv(
|
|
|
99
101
|
)
|
|
100
102
|
|
|
101
103
|
|
|
102
|
-
@
|
|
104
|
+
@util_app.command(
|
|
103
105
|
"yaml",
|
|
104
106
|
epilog=EPILOG,
|
|
105
107
|
no_args_is_help=True,
|
|
@@ -135,7 +137,7 @@ def util_yaml(
|
|
|
135
137
|
dbworkload.models.util.util_yaml(input=input, output=output)
|
|
136
138
|
|
|
137
139
|
|
|
138
|
-
@
|
|
140
|
+
@util_app.command(
|
|
139
141
|
"merge_sort",
|
|
140
142
|
epilog=EPILOG,
|
|
141
143
|
no_args_is_help=True,
|
|
@@ -178,7 +180,7 @@ def util_sort_merge(
|
|
|
178
180
|
dbworkload.models.util.util_merge_sort(input, output, csv_max_rows, compress)
|
|
179
181
|
|
|
180
182
|
|
|
181
|
-
@
|
|
183
|
+
@util_app.command(
|
|
182
184
|
"plot",
|
|
183
185
|
epilog=EPILOG,
|
|
184
186
|
no_args_is_help=True,
|
|
@@ -201,7 +203,7 @@ def util_plot(
|
|
|
201
203
|
dbworkload.models.util.util_plot(input)
|
|
202
204
|
|
|
203
205
|
|
|
204
|
-
@
|
|
206
|
+
@util_app.command(
|
|
205
207
|
"html",
|
|
206
208
|
epilog=EPILOG,
|
|
207
209
|
no_args_is_help=True,
|
|
@@ -224,7 +226,7 @@ def util_html(
|
|
|
224
226
|
dbworkload.models.util.util_html(input)
|
|
225
227
|
|
|
226
228
|
|
|
227
|
-
@
|
|
229
|
+
@util_app.command(
|
|
228
230
|
"merge_csvs",
|
|
229
231
|
epilog=EPILOG,
|
|
230
232
|
no_args_is_help=True,
|
|
@@ -247,7 +249,7 @@ def util_merge_csvs(
|
|
|
247
249
|
dbworkload.models.util.util_merge_csvs(input_dir)
|
|
248
250
|
|
|
249
251
|
|
|
250
|
-
@
|
|
252
|
+
@util_app.command(
|
|
251
253
|
"gen_stub",
|
|
252
254
|
epilog=EPILOG,
|
|
253
255
|
no_args_is_help=True,
|
|
@@ -268,3 +270,61 @@ def util_gen_stub(
|
|
|
268
270
|
),
|
|
269
271
|
):
|
|
270
272
|
dbworkload.models.util.util_gen_stub(input_file)
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
@util_app.command(
    name="convert",
    help="Convert from PL to PL/pgSQL",
    no_args_is_help=True,
)
def cli_convert(
    base_dir: Optional[Path] = typer.Option(
        ".",
        "--dir",
        "-d",
        help="Directory path",
        exists=True,
        file_okay=False,
        dir_okay=True,
        writable=False,
        readable=True,
        resolve_path=True,
    ),
    uri: str = typer.Option(
        None,
        "--uri",
        help="The connection URI to the database.",
    ),
    root_file: Optional[str] = typer.Option(
        None,
        "--root-file",
        "-r",
        help="The root_file. Leave empty for processing all *.ddl files.",
    ),
    generator_llm: Optional[str] = typer.Option(
        "Ollama:llama3.2:3b",
        "--generator-llm",
        "-g",
        help="The generator provider:model_name",
    ),
    refiner_llm: Optional[str] = typer.Option(
        "OpenAI:gpt-5",
        "--refiner-llm",
        "-n",
        help="The refiner provider:model_name.",
    ),
):
    """
    CLI entry point for the `convert` command.

    Builds a ConvertTool from the command line options and runs it.
    Any exception raised while building or running the tool is printed
    to stderr and the command terminates with exit code 1.
    """
    try:
        ConvertTool(
            base_dir,
            uri,
            root_file,
            generator_llm,
            refiner_llm,
        ).run()
    except Exception as e:
        print(e, file=sys.stderr)
        # typer.Exit is an exception: it only terminates the command (and
        # sets the exit code) when it is raised, not merely instantiated.
        raise typer.Exit(1)
|
|
@@ -0,0 +1,482 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
from typing import Any, TypedDict
|
|
5
|
+
|
|
6
|
+
import openai
|
|
7
|
+
import psycopg
|
|
8
|
+
import sqlparse
|
|
9
|
+
import yaml
|
|
10
|
+
from langchain_core.callbacks import get_usage_metadata_callback
|
|
11
|
+
from langchain_core.output_parsers import StrOutputParser
|
|
12
|
+
from langchain_core.prompts import ChatPromptTemplate
|
|
13
|
+
from langchain_core.runnables import RunnablePassthrough
|
|
14
|
+
from langchain_ollama import ChatOllama
|
|
15
|
+
from langchain_openai import ChatOpenAI
|
|
16
|
+
from langgraph.graph import END, StateGraph
|
|
17
|
+
from psycopg.rows import dict_row
|
|
18
|
+
|
|
19
|
+
from .prompts import REFINER_PROMPT, SYSTEM_PROMPT
|
|
20
|
+
|
|
21
|
+
# setup global logger
|
|
22
|
+
logger = logging.getLogger("dbworkload")
|
|
23
|
+
|
|
24
|
+
openai.api_key = os.getenv("OPENAI_API_KEY") # or set it directly if needed
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def get_llm(provider: str, model: str) -> ChatOllama | ChatOpenAI | None:
    """
    Build the chat model client for the given provider.

    Args:
        provider: the provider name, matched case-insensitively.
            Supported values are "openai" and "ollama".
        model: the model name handed to the provider client.

    Returns:
        A ChatOpenAI or ChatOllama instance configured with temperature 0.1.

    Raises:
        ValueError: if the provider is not one of the supported values.
            Failing here gives a clear message instead of letting a None
            model be piped into a chain later on.
    """
    provider_name = provider.lower()

    if provider_name == "openai":
        return ChatOpenAI(
            model=model,
            temperature=0.1,
        )

    if provider_name == "ollama":
        return ChatOllama(
            model=model,
            temperature=0.1,
        )

    raise ValueError(
        f"Unsupported LLM provider '{provider}': expected 'OpenAI' or 'Ollama'"
    )
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ConversionState(TypedDict):
    """State dictionary handed from node to node while the graph runs."""

    # the source code to convert, supplied when the graph is invoked
    oracle_code: str
    # latest converted code, written by the generator and refiner nodes
    converted_code: str
    # error text set by the validator; empty when there is nothing to report
    validation_error: str
    # one entry appended by each node that runs
    history: list
    # attempt limit supplied when the graph is invoked
    max_attempts: int
    # number of validation passes done so far
    attempts: int
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ConvertTool:
|
|
54
|
+
def __init__(
    self,
    base_dir: str,
    uri: str,
    root: str,
    generator_llm: str,
    refiner_llm: str,
):
    """
    Store the settings and assemble the generator and refiner chains.

    Args:
        base_dir: directory holding the `in` and `out` sub-directories.
        uri: connection URI of the database used for validation.
        root: file name (without extension) to process, or None for all.
        generator_llm: "provider:model_name" of the generator model.
        refiner_llm: "provider:model_name" of the refiner model.
    """
    self.base_dir = base_dir
    self.uri = uri
    self.root = root

    # per-file working data, filled in while a file is being converted
    self.expected_output: dict = {}
    self.seed_statements: list[str] = []
    self.test_statements: list[str] = []

    # only the first colon separates provider from model name,
    # as model names may contain colons themselves
    generator_parts = generator_llm.split(":", maxsplit=1)
    refiner_parts = refiner_llm.split(":", maxsplit=1)
    self.generator_llm_provider, self.generator_llm_model = generator_parts
    self.refiner_llm_provider, self.refiner_llm_model = refiner_parts

    # prompt for the first draft: system instructions + the source code
    self.prompt = ChatPromptTemplate.from_messages(
        [("system", SYSTEM_PROMPT), ("user", "{oracle_code}")]
    )

    self.refiner_llm = get_llm(*refiner_parts)
    self.generator_llm = get_llm(*generator_parts)

    self.parser = StrOutputParser()

    # generator pipeline: raw input -> prompt -> model -> plain string
    generator_input = {"oracle_code": RunnablePassthrough()}
    self.conversion_chain = (
        generator_input | self.prompt | self.generator_llm | self.parser
    )

    # refiner pipeline: refiner prompt -> model -> plain string
    refiner_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", REFINER_PROMPT),
            (
                "user",
                "Please provide the fixed, corrected CockroachDB PL/pgSQL code.",
            ),
        ]
    )
    self.refiner_chain = refiner_prompt | self.refiner_llm | self.parser
|
|
112
|
+
|
|
113
|
+
def refine(self, state: ConversionState) -> dict:
    """
    Ask the refiner model to fix the code that failed validation.

    The refiner chain receives the current validation error and the
    current converted code. The returned state update carries the
    refined code, an emptied validation error and a new history entry.
    """
    logger.info(f"⚙️ Refiner Node (Attempt #{state['attempts'] + 1})")

    refiner_input = {
        "validation_error": state["validation_error"],
        "converted_code": state["converted_code"],
    }

    logger.info(f"📡 ➡️ Sending query to {self.refiner_llm_model}")
    with get_usage_metadata_callback() as ctx:
        refined_code = self.refiner_chain.invoke(refiner_input)

    logger.info(f"📡 ⬅️ Receiving from {self.refiner_llm_model}: {refined_code=}")
    logger.info(f"💰 {self.refiner_llm_model} cost={ctx.usage_metadata}")

    history_entry = {
        "attempt": state.get("attempts"),
        "status": "Refined",
        "refined_code": refined_code,
    }

    return {
        # the validator picks the refined code up on its next pass
        "converted_code": refined_code,
        # the previous error no longer applies to the refined code
        "validation_error": "",
        "history": state.get("history", []) + [history_entry],
    }
|
|
145
|
+
|
|
146
|
+
def generate_code(self, state: ConversionState) -> dict:
    """
    Invoke the conversion chain to generate the first draft of the
    CockroachDB PL/pgSQL code from the source code held in the state.

    Returns a state update with the generated code and a new history entry.
    """
    logger.info(f"⚙️ Generate Node (Attempt #{state['attempts'] + 1})")

    oracle_code = state["oracle_code"]

    logger.info(f"📡 ➡️ Sending query to {self.generator_llm_model}")

    with get_usage_metadata_callback() as ctx:
        converted_code_output = self.conversion_chain.invoke(oracle_code)

    logger.info(
        f"📡 ⬅️ Receiving from {self.generator_llm_model}: {converted_code_output=}"
    )

    # Log the whole usage mapping, as the refiner node does. Indexing it by
    # the configured model name raises KeyError whenever the provider
    # reports usage under a different name (or reports none at all), and a
    # log line must not abort the conversion.
    logger.info(f"💰 {self.generator_llm_model} cost={ctx.usage_metadata}")

    return {
        "converted_code": converted_code_output,
        "history": state.get("history", [])
        + [
            {
                "attempt": state.get("attempts"),
                "code": converted_code_output,
                "status": "Generated",
            }
        ],
    }
|
|
186
|
+
|
|
187
|
+
def validate_code(self, state: ConversionState) -> dict:
    """
    Deploy the converted code and run the test statements against it.

    Steps:
      1. run the seed statements, if any;
      2. execute the converted code; on failure the error is appended to
         the .out file and returned straight away as the validation error;
      3. execute every test statement and compare its result set with the
         expected output stored under the statement index. The expected
         value "ERR" means the statement is expected to raise;
      4. dump the collected result sets to the .json file.

    Returns a state update whose `validation_error` is empty when every
    check passed, or the description of ALL failed checks otherwise, so
    that an unexpected result triggers a refinement just like a SQL error
    does, and a later passing statement cannot hide an earlier failure.
    """
    logger.info(f"⚙️ Validator Node (Attempt #{state['attempts'] + 1})")

    converted_code = state["converted_code"]
    out_file = f"{self.base_dir}/out/{self.root}.out"

    def state_update(error_message: str) -> dict:
        # every validation pass counts as one attempt, successful or not
        return {
            "validation_error": error_message,
            "attempts": state.get("attempts", 0) + 1,
            "history": state.get("history", [])
            + [
                {
                    "attempt": state["attempts"],
                    "status": "Validated",
                    "error": error_message,
                }
            ],
        }

    if self.seed_statements:
        logger.info("🌱 Seeding CockroachDB prior to running tests")
        self.execute_sql_stmts(self.seed_statements)

    logger.info("Creating the CockroachDB SP in the test cluster")

    try:
        with psycopg.connect(self.uri, autocommit=True) as conn:
            with conn.cursor() as cur:
                cur.execute(converted_code)
                logger.info(f"🪳 🟢 {cur.statusmessage}")
    except Exception as e:
        error_message = str(e)
        logger.error(f"🪳 🔴 {error_message}")

        with open(out_file, "a") as f:
            f.write("Error creating Stored Procedure\n")
            f.write(error_message)
            f.write("\n\n")

        # the tests cannot run against code that failed to deploy
        return state_update(error_message)

    logger.info("Run the SQL Test statements against the CockroachDB cluster")

    actual: dict = {}
    failures: list[str] = []

    for idx, s in enumerate(self.test_statements):
        expected = self.expected_output.get(str(idx))
        failure = ""

        try:
            # row_factory yields dicts keyed by column names
            with psycopg.connect(self.uri, autocommit=True) as conn:
                with conn.cursor(row_factory=dict_row) as cur:
                    cur.execute(s)

                    # description is only set when the statement
                    # produced a result set (e.g., SELECT/SHOW)
                    if cur.description is not None:
                        actual[idx] = [
                            {k: self.to_jsonable(v) for k, v in row.items()}
                            for row in cur
                        ]
        except Exception as e:
            # an error is a success only when it was the expected outcome
            if expected != "ERR":
                failure = str(e)
        else:
            # statements without a result set have no entry in `actual`
            # and are compared as None
            got = actual.get(idx)
            if got != expected:
                failure = (
                    f"Unexpected output: expected {json.dumps(expected)}, "
                    f"got {json.dumps(got)}"
                )

        with open(out_file, "a") as f:
            if not failure:
                logger.info(f"{idx=} : 🟢 OK")
                f.write(f"{idx=} : 🟢 OK ")
                f.write("\n")
                continue

            logger.info(f"{idx=} : 🔴 FAIL")
            f.write(f"{idx=} : 🔴 FAIL")
            f.write("\n")
            f.write("<SQL>\n")
            f.write(s)
            f.write("\n</SQL>\n")
            f.write("<Error message>\n")
            f.write(failure)
            f.write("\n</Error message>\n")
            f.write("\n")

        failures.append(f"Test statement #{idx} failed: {failure}\nSQL: {s}")

    with open(f"{self.base_dir}/out/{self.root}.json", "w") as f:
        f.write(json.dumps(actual, indent=4))

    return state_update("\n\n".join(failures))
|
|
302
|
+
|
|
303
|
+
def langgraph_it(self, oracle_sp: str, max_attempts: int = 3):
    """
    Build and run the generate -> validate -> refine graph.

    The generator runs once, then the validator and the refiner alternate
    until the validator reports no error or the attempt limit is reached.

    Args:
        oracle_sp: the source code to convert.
        max_attempts: how many validation passes are allowed before the
            conversion is given up.

    Returns:
        The final graph state.
    """

    def should_continue(state: ConversionState):
        # decide where to go once the validator has run
        if not state["validation_error"]:
            logger.info(f"✅ Successful conversion for {self.root}")
            return "end"

        # honour the limit carried in the state instead of a literal
        if state["attempts"] >= state["max_attempts"]:
            logger.error(f"❌ Failed conversion for {self.root}: max attempt reached.")
            return "end"

        logger.warning("⚠️ Validation failed. Re-attempt conversion...")
        return "refiner"

    graph_builder = StateGraph(ConversionState)

    graph_builder.add_node("generator", self.generate_code)
    graph_builder.add_node("validator", self.validate_code)
    graph_builder.add_node("refiner", self.refine)

    # generated code is always validated
    graph_builder.add_edge("generator", "validator")

    # after validation either stop or hand the code to the refiner
    graph_builder.add_conditional_edges(
        "validator",
        should_continue,
        {"refiner": "refiner", "end": END},
    )

    # refined code goes back to the validator, closing the loop
    graph_builder.add_edge("refiner", "validator")
    graph_builder.set_entry_point("generator")

    app = graph_builder.compile()

    return app.invoke(
        {"oracle_code": oracle_sp, "max_attempts": max_attempts, "attempts": 0}
    )
|
|
353
|
+
|
|
354
|
+
def to_jsonable(self, obj: Any) -> Any:
    """Return obj unchanged when json can serialize it, else its str() form."""
    try:
        # serialization is only attempted as a probe, the output is discarded
        json.dumps(obj)
    except (TypeError, ValueError):
        return str(obj)
    else:
        return obj
|
|
362
|
+
|
|
363
|
+
def execute_sql_stmts(self, stmts: list[str]):
    """
    Run the statements in order over a single autocommit connection.

    The status message of each statement is logged. The first exception
    stops the run; it is logged as an error and not re-raised.
    """
    try:
        with psycopg.connect(self.uri, autocommit=True) as conn, conn.cursor() as cur:
            for stmt in stmts:
                cur.execute(stmt)
                logger.info(f"🪳 🟢 {cur.statusmessage}")
    except Exception as err:
        logger.error(f"🪳 🔴 {err}")
|
|
374
|
+
|
|
375
|
+
def run(self) -> list[dict]:
    """
    Convert the configured root, or every *.ddl file of the `in` directory
    when no root was configured.
    """
    if self.root is not None:
        targets = [self.root]
    else:
        # no root given: pick up every regular file with a .ddl extension
        targets = []
        for entry in os.listdir(os.path.join(self.base_dir, "in/")):
            if not os.path.isfile(os.path.join(self.base_dir, "in", entry)):
                continue
            if entry.lower().endswith(".ddl"):
                targets.append(os.path.splitext(entry)[0])

    for target in targets:
        self.convert(target)
|
|
391
|
+
|
|
392
|
+
def convert(self, root: str) -> None:
    """
    Convert a single source file and save the results.

    Reads `in/<root>.ddl` (code to convert), `in/<root>.json` (expected
    output per test statement) and `in/<root>.sql` (seed and test
    statements), runs the conversion graph and writes `out/<root>.out`,
    `out/<root>.ai.yaml` and `out/<root>.ddl`.

    The .sql file may hold up to three sections:
      - names of seed *.sql files, closed by a `betwixt_file_end` line;
      - seed statements, closed by a `betwixt_seed_end` line;
      - the test statements.

    The conversion is skipped, with a logged message, when an input file
    is missing or when the number of test statements does not match the
    number of expected outputs.
    """
    logger.info(f"🚀 Processing {root=}")

    self.root = root

    # start from a clean slate: when several roots are processed by the
    # same instance, statements of a previous root must not leak into
    # this one
    self.expected_output = {}
    self.seed_statements = []
    self.test_statements = []

    in_dir = os.path.join(self.base_dir, "in")
    out_file = f"{self.base_dir}/out/{root}.out"

    # create or override the out file; the out directory may not exist yet
    os.makedirs(os.path.join(self.base_dir, "out"), exist_ok=True)
    with open(out_file, "w"):
        pass

    if not os.path.exists(os.path.join(in_dir, f"{root}.json")):
        logger.error(f"💾 ❌ Couldn't find expected output file {root}.json")
        return

    if not os.path.exists(os.path.join(in_dir, f"{root}.sql")):
        logger.error(f"💾 ❌ Couldn't find test statements file {root}.sql")
        return

    if not os.path.exists(os.path.join(in_dir, f"{root}.ddl")):
        logger.error(f"💾 ❌ Couldn't find source code file {root}.ddl")
        return

    with open(os.path.join(in_dir, f"{root}.ddl"), "r") as f:
        oracle_sp = f.read()

    with open(os.path.join(in_dir, f"{root}.json"), "r") as f:
        self.expected_output = json.loads(f.read())

    with open(os.path.join(in_dir, f"{root}.sql"), "r") as f:
        txt = f.read()

    # section 1: seed files, if any
    if "betwixt_file_end" in txt:
        file_lines, _, txt = txt.partition("betwixt_file_end")

        # strip the names so that surrounding whitespace (e.g. a trailing
        # carriage return) does not end up in the file path
        seed_files = [
            line.strip()
            for line in file_lines.split("\n")
            if line.strip().endswith(".sql")
        ]

        for seed_file in seed_files:
            seed_path = os.path.join(in_dir, seed_file)
            if not os.path.isfile(seed_path):
                logger.warning(f"💾 ⚠️ File <{seed_file}> not found")
                return

            with open(seed_path, "r") as f:
                self.seed_statements += sqlparse.split(f.read())

    # section 2: individual seed statements, if any
    if "betwixt_seed_end" in txt:
        seed_sql, _, txt = txt.partition("betwixt_seed_end")
        self.seed_statements += sqlparse.split(seed_sql)

    # section 3: whatever is left are the test statements
    self.test_statements = sqlparse.split(txt)

    if len(self.expected_output) != len(self.test_statements):
        mismatch = (
            "Expected output and Test statement count should match. "
            f"Expected count={len(self.expected_output)}, "
            f"Statement count={len(self.test_statements)}"
        )
        logger.error(f"❌ {mismatch}")

        with open(out_file, "a") as f:
            # write() takes exactly one string
            f.write(mismatch)
            f.write("\n\n")

        return

    answer = self.langgraph_it(oracle_sp)

    logger.info(f"💾 Saved output to file out/{root}.out")

    with open(f"{self.base_dir}/out/{root}.ai.yaml", "w") as f:
        f.write(yaml.safe_dump(answer))
    logger.info(f"💾 Saved AI answer to file out/{root}.ai.yaml")

    with open(f"{self.base_dir}/out/{root}.ddl", "w") as f:
        f.write(answer["converted_code"])
    logger.info(f"💾 Saved converted code to file out/{root}.ddl")
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
# TODO llmlingua the prompt to save tokens
|
|
482
|
+
# TODO improve prompts syntax
|