testgenie-py 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- testgen/__init__.py +0 -0
- testgen/analyzer/__init__.py +0 -0
- testgen/analyzer/ast_analyzer.py +149 -0
- testgen/analyzer/contracts/__init__.py +0 -0
- testgen/analyzer/contracts/contract.py +13 -0
- testgen/analyzer/contracts/no_exception_contract.py +16 -0
- testgen/analyzer/contracts/nonnull_contract.py +15 -0
- testgen/analyzer/fuzz_analyzer.py +106 -0
- testgen/analyzer/random_feedback_analyzer.py +291 -0
- testgen/analyzer/reinforcement_analyzer.py +75 -0
- testgen/analyzer/test_case_analyzer.py +46 -0
- testgen/analyzer/test_case_analyzer_context.py +58 -0
- testgen/controller/__init__.py +0 -0
- testgen/controller/cli_controller.py +194 -0
- testgen/controller/docker_controller.py +169 -0
- testgen/docker/Dockerfile +22 -0
- testgen/docker/poetry.lock +361 -0
- testgen/docker/pyproject.toml +22 -0
- testgen/generator/__init__.py +0 -0
- testgen/generator/code_generator.py +66 -0
- testgen/generator/doctest_generator.py +208 -0
- testgen/generator/generator.py +55 -0
- testgen/generator/pytest_generator.py +77 -0
- testgen/generator/test_generator.py +26 -0
- testgen/generator/unit_test_generator.py +84 -0
- testgen/inspector/__init__.py +0 -0
- testgen/inspector/inspector.py +61 -0
- testgen/main.py +13 -0
- testgen/models/__init__.py +0 -0
- testgen/models/analysis_context.py +56 -0
- testgen/models/function_metadata.py +61 -0
- testgen/models/generator_context.py +63 -0
- testgen/models/test_case.py +8 -0
- testgen/presentation/__init__.py +0 -0
- testgen/presentation/cli_view.py +12 -0
- testgen/q_table/global_q_table.json +1 -0
- testgen/reinforcement/__init__.py +0 -0
- testgen/reinforcement/abstract_state.py +7 -0
- testgen/reinforcement/agent.py +153 -0
- testgen/reinforcement/environment.py +215 -0
- testgen/reinforcement/statement_coverage_state.py +33 -0
- testgen/service/__init__.py +0 -0
- testgen/service/analysis_service.py +260 -0
- testgen/service/cfg_service.py +55 -0
- testgen/service/generator_service.py +169 -0
- testgen/service/service.py +389 -0
- testgen/sqlite/__init__.py +0 -0
- testgen/sqlite/db.py +84 -0
- testgen/sqlite/db_service.py +219 -0
- testgen/tree/__init__.py +0 -0
- testgen/tree/node.py +7 -0
- testgen/tree/tree_utils.py +79 -0
- testgen/util/__init__.py +0 -0
- testgen/util/coverage_utils.py +168 -0
- testgen/util/coverage_visualizer.py +154 -0
- testgen/util/file_utils.py +110 -0
- testgen/util/randomizer.py +122 -0
- testgen/util/utils.py +143 -0
- testgen/util/z3_utils/__init__.py +0 -0
- testgen/util/z3_utils/ast_to_z3.py +99 -0
- testgen/util/z3_utils/branch_condition.py +72 -0
- testgen/util/z3_utils/constraint_extractor.py +36 -0
- testgen/util/z3_utils/variable_finder.py +10 -0
- testgen/util/z3_utils/z3_test_case.py +94 -0
- testgenie_py-0.1.0.dist-info/METADATA +24 -0
- testgenie_py-0.1.0.dist-info/RECORD +68 -0
- testgenie_py-0.1.0.dist-info/WHEEL +4 -0
- testgenie_py-0.1.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,389 @@
|
|
1
|
+
import inspect
|
2
|
+
import json
|
3
|
+
import os
|
4
|
+
import re
|
5
|
+
import sqlite3
|
6
|
+
import time
|
7
|
+
import subprocess
|
8
|
+
from types import ModuleType
|
9
|
+
from typing import List
|
10
|
+
import testgen.util.file_utils as file_utils
|
11
|
+
|
12
|
+
from testgen.models.test_case import TestCase
|
13
|
+
from testgen.service.analysis_service import AnalysisService
|
14
|
+
from testgen.service.generator_service import GeneratorService
|
15
|
+
from testgen.sqlite.db_service import DBService
|
16
|
+
from testgen.models.analysis_context import AnalysisContext
|
17
|
+
from testgen.util.coverage_visualizer import CoverageVisualizer
|
18
|
+
|
19
|
+
# Numeric identifiers for the supported test-generation strategies.
AST_STRAT, FUZZ_STRAT, RANDOM_STRAT, REINFORCE_STRAT = 1, 2, 3, 4

# Numeric identifiers for the supported output test formats.
UNITTEST_FORMAT, PYTEST_FORMAT, DOCTEST_FORMAT = 1, 2, 3
|
29
|
+
|
30
|
+
class Service:
|
31
|
+
def __init__(self):
|
32
|
+
self.test_strategy: int = 0
|
33
|
+
self.test_format: int = 0
|
34
|
+
self.file_path = None
|
35
|
+
self.generated_file_path = None
|
36
|
+
self.class_name = None
|
37
|
+
self.test_cases = []
|
38
|
+
self.reinforcement_mode = "train"
|
39
|
+
|
40
|
+
# Initialize specialized services
|
41
|
+
self.analysis_service = AnalysisService()
|
42
|
+
self.generator_service = GeneratorService(None, None, None)
|
43
|
+
self.db_service = DBService()
|
44
|
+
|
45
|
+
def select_all_from_db(self) -> None:
|
46
|
+
rows = self.db_service.get_test_suites()
|
47
|
+
for row in rows:
|
48
|
+
print(repr(dict(row)))
|
49
|
+
|
50
|
+
def generate_tests(self, output_path=None):
|
51
|
+
"""Generate tests for a class or module."""
|
52
|
+
module = file_utils.load_module(self.file_path)
|
53
|
+
class_name = self.analysis_service.get_class_name(module)
|
54
|
+
|
55
|
+
if self.test_strategy == AST_STRAT:
|
56
|
+
self.generated_file_path = self.generate_function_code()
|
57
|
+
Service.wait_for_file(self.generated_file_path)
|
58
|
+
self.analysis_service.set_file_path(self.generated_file_path)
|
59
|
+
module = file_utils.load_module(self.generated_file_path)
|
60
|
+
class_name = self.analysis_service.get_class_name(module)
|
61
|
+
self.analysis_service.set_test_strategy(self.test_strategy, module.__name__, self.class_name)
|
62
|
+
|
63
|
+
else:
|
64
|
+
self.analysis_service.set_test_strategy(self.test_strategy, module.__name__, self.class_name)
|
65
|
+
|
66
|
+
test_cases: List[TestCase] = []
|
67
|
+
if self.test_strategy == REINFORCE_STRAT:
|
68
|
+
test_cases = self.analysis_service.do_reinforcement_learning(self.file_path, self.reinforcement_mode)
|
69
|
+
else:
|
70
|
+
test_cases = self.analysis_service.generate_test_cases()
|
71
|
+
|
72
|
+
self.test_cases = test_cases
|
73
|
+
|
74
|
+
file_path_to_use = self.generated_file_path if self.test_strategy == AST_STRAT else self.file_path
|
75
|
+
self.db_service.save_test_generation_data(
|
76
|
+
file_path_to_use,
|
77
|
+
test_cases,
|
78
|
+
self.test_strategy,
|
79
|
+
class_name
|
80
|
+
)
|
81
|
+
|
82
|
+
if os.environ.get("RUNNING_IN_DOCKER") is not None:
|
83
|
+
print(f"Serializing test cases {test_cases}")
|
84
|
+
self.serialize_test_cases(test_cases)
|
85
|
+
return None # Exit early in analysis-only mode
|
86
|
+
|
87
|
+
test_file = self.generate_test_file(test_cases, output_path, module, class_name)
|
88
|
+
|
89
|
+
# Ensure the test file is ready
|
90
|
+
Service.wait_for_file(test_file)
|
91
|
+
return test_file
|
92
|
+
|
93
|
+
def generate_test_file(self, test_cases: List[TestCase], output_path: str | None = None,
|
94
|
+
module: ModuleType | None = None, class_name: str | None = None) -> str:
|
95
|
+
if module is None:
|
96
|
+
module = file_utils.load_module(self.file_path)
|
97
|
+
class_name = self.analysis_service.get_class_name(module)
|
98
|
+
|
99
|
+
if self.test_strategy == AST_STRAT:
|
100
|
+
self.generator_service = GeneratorService(self.generated_file_path, output_path, self.test_format)
|
101
|
+
self.generator_service.set_test_format(self.test_format)
|
102
|
+
module = file_utils.load_module(self.generated_file_path)
|
103
|
+
else:
|
104
|
+
# Create the correct instance of the generator service
|
105
|
+
self.generator_service = GeneratorService(self.file_path, output_path, self.test_format)
|
106
|
+
self.generator_service.set_test_format(self.test_format)
|
107
|
+
|
108
|
+
test_file = self.generator_service.generate_test_file(
|
109
|
+
module,
|
110
|
+
class_name,
|
111
|
+
test_cases,
|
112
|
+
output_path
|
113
|
+
)
|
114
|
+
|
115
|
+
# Ensure the test file is ready
|
116
|
+
Service.wait_for_file(test_file)
|
117
|
+
return test_file
|
118
|
+
|
119
|
+
def generate_function_code(self):
|
120
|
+
"""Generate function code for a given class or module."""
|
121
|
+
module = file_utils.load_module(self.file_path)
|
122
|
+
class_name = self.analysis_service.get_class_name(module)
|
123
|
+
functions = self.inspect_class(class_name)
|
124
|
+
return self.generator_service.generate_function_code(self.file_path, class_name, functions)
|
125
|
+
|
126
|
+
def run_coverage(self, test_file):
|
127
|
+
"""Run coverage analysis on the generated tests."""
|
128
|
+
Service.wait_for_file(test_file)
|
129
|
+
file_path_to_use = self.generated_file_path if self.test_strategy == AST_STRAT else self.file_path
|
130
|
+
coverage_output = ""
|
131
|
+
|
132
|
+
try:
|
133
|
+
if self.test_format == UNITTEST_FORMAT:
|
134
|
+
subprocess.run(["python", "-m", "coverage", "run", "--source=.", "-m", "unittest", test_file], check=True)
|
135
|
+
result = subprocess.run(
|
136
|
+
["python", "-m", "coverage", "report", file_path_to_use],
|
137
|
+
check=True,
|
138
|
+
capture_output=True,
|
139
|
+
text=True
|
140
|
+
)
|
141
|
+
coverage_output = result.stdout
|
142
|
+
elif self.test_format == PYTEST_FORMAT:
|
143
|
+
subprocess.run(["python", "-m", "coverage", "run", "--source=.", "-m", "pytest", test_file], check=True)
|
144
|
+
result = subprocess.run(
|
145
|
+
["python", "-m", "coverage", "report", file_path_to_use],
|
146
|
+
check=True,
|
147
|
+
capture_output=True,
|
148
|
+
text=True
|
149
|
+
)
|
150
|
+
coverage_output = result.stdout
|
151
|
+
elif self.test_format == DOCTEST_FORMAT:
|
152
|
+
result = subprocess.run(
|
153
|
+
["python", "-m", "coverage", "run", "--source=.", "-m", "doctest", "-v", test_file],
|
154
|
+
check=True,
|
155
|
+
capture_output=True,
|
156
|
+
text=True
|
157
|
+
)
|
158
|
+
coverage_output = result.stdout
|
159
|
+
else:
|
160
|
+
raise ValueError("Unsupported test format for coverage analysis.")
|
161
|
+
|
162
|
+
self._save_coverage_data(coverage_output, file_path_to_use)
|
163
|
+
|
164
|
+
except subprocess.CalledProcessError as e:
|
165
|
+
raise RuntimeError(f"Error running coverage subprocess: {e}")
|
166
|
+
|
167
|
+
def _save_coverage_data(self, coverage_output, file_path):
|
168
|
+
"""Parse coverage output and save to database."""
|
169
|
+
try:
|
170
|
+
lines = coverage_output.strip().split('\n')
|
171
|
+
for line in lines:
|
172
|
+
if file_path in line:
|
173
|
+
parts = line.split()
|
174
|
+
if len(parts) >= 4:
|
175
|
+
file_name = os.path.basename(file_path)
|
176
|
+
try:
|
177
|
+
total_lines = int(parts[-3])
|
178
|
+
missed_lines = int(parts[-2])
|
179
|
+
executed_lines = total_lines - missed_lines
|
180
|
+
coverage_str = parts[-1].strip('%')
|
181
|
+
branch_coverage = float(coverage_str) / 100
|
182
|
+
|
183
|
+
source_file_id = self._get_source_file_id(file_path)
|
184
|
+
|
185
|
+
self.db_service.insert_coverage_data(
|
186
|
+
file_name,
|
187
|
+
executed_lines,
|
188
|
+
missed_lines,
|
189
|
+
branch_coverage,
|
190
|
+
source_file_id
|
191
|
+
)
|
192
|
+
break
|
193
|
+
except (ValueError, IndexError) as e:
|
194
|
+
print(f"Error parsing coverage data: {e}")
|
195
|
+
except Exception as e:
|
196
|
+
print(f"Error saving coverage data: {e}")
|
197
|
+
|
198
|
+
def _get_source_file_id(self, file_path):
|
199
|
+
"""Helper to get source file ID from DB."""
|
200
|
+
conn = sqlite3.connect(self.db_service.db_name)
|
201
|
+
conn.row_factory = sqlite3.Row
|
202
|
+
cursor = conn.cursor()
|
203
|
+
|
204
|
+
cursor.execute("SELECT id FROM SourceFile WHERE path = ?", (file_path,))
|
205
|
+
row = cursor.fetchone()
|
206
|
+
|
207
|
+
conn.close()
|
208
|
+
|
209
|
+
if row:
|
210
|
+
return row[0]
|
211
|
+
else:
|
212
|
+
with open(file_path, 'r') as f:
|
213
|
+
lines_of_code = len(f.readlines())
|
214
|
+
return self.db_service.insert_source_file(file_path, lines_of_code)
|
215
|
+
|
216
|
+
def serialize_test_cases(self, test_cases):
|
217
|
+
"""Serialize input arguments ot JSON-compatible format"""
|
218
|
+
print("##TEST_CASES_BEGIN##")
|
219
|
+
serialized = []
|
220
|
+
for tc in test_cases:
|
221
|
+
case_data = {
|
222
|
+
"func_name": tc.func_name,
|
223
|
+
"inputs": self.serialize_value(tc.inputs),
|
224
|
+
"expected": self.serialize_value(tc.expected)
|
225
|
+
}
|
226
|
+
serialized.append(case_data)
|
227
|
+
print(json.dumps(serialized))
|
228
|
+
print("##TEST_CASES_END##")
|
229
|
+
|
230
|
+
def serialize_value(self, value):
|
231
|
+
"""Serialize a single value to a JSON-compatible format"""
|
232
|
+
if value is None or isinstance(value, (bool, int, float, str)):
|
233
|
+
return value
|
234
|
+
elif isinstance(value, (list, tuple)):
|
235
|
+
return [self.serialize_value(v) for v in value]
|
236
|
+
else:
|
237
|
+
return str(value)
|
238
|
+
|
239
|
+
@staticmethod
|
240
|
+
def parse_test_cases_from_logs(logs_output):
|
241
|
+
"""Extract and parse test cases from container logs"""
|
242
|
+
pattern = r"##TEST_CASES_BEGIN##\n(.*?)\n##TEST_CASES_END##"
|
243
|
+
match = re.search(pattern, logs_output, re.DOTALL)
|
244
|
+
|
245
|
+
if not match:
|
246
|
+
raise ValueError("Could not find test cases in the container logs")
|
247
|
+
|
248
|
+
test_cases_json = match.group(1).strip()
|
249
|
+
test_cases_data = json.loads(test_cases_json)
|
250
|
+
|
251
|
+
test_cases: List[TestCase] = []
|
252
|
+
for tc_data in test_cases_data:
|
253
|
+
test_case = TestCase(
|
254
|
+
func_name=tc_data["func_name"],
|
255
|
+
inputs=tc_data["inputs"],
|
256
|
+
expected=tc_data["expected"]
|
257
|
+
)
|
258
|
+
test_cases.append(test_case)
|
259
|
+
|
260
|
+
return test_cases
|
261
|
+
|
262
|
+
def set_file_path(self, path: str):
|
263
|
+
"""Set the file path for analysis and validate it."""
|
264
|
+
if os.path.isfile(path) and path.endswith(".py"):
|
265
|
+
self.file_path = path
|
266
|
+
self.analysis_service.set_file_path(path)
|
267
|
+
else:
|
268
|
+
raise ValueError("Invalid file path! Please provide a valid Python file path.")
|
269
|
+
|
270
|
+
def set_class_name(self, class_name: str):
|
271
|
+
"""Set the class name to analyze."""
|
272
|
+
self.class_name = class_name
|
273
|
+
|
274
|
+
def set_test_generator_format(self, test_format: int):
|
275
|
+
"""Set the test generator format."""
|
276
|
+
self.test_format = test_format
|
277
|
+
self.generator_service.set_test_format(test_format)
|
278
|
+
|
279
|
+
def set_test_analysis_strategy(self, strategy: int):
|
280
|
+
"""Set the test analysis strategy."""
|
281
|
+
self.test_strategy = strategy
|
282
|
+
module = file_utils.load_module(self.file_path)
|
283
|
+
self.analysis_service.set_test_strategy(strategy, module.__name__, self.class_name)
|
284
|
+
|
285
|
+
def get_analysis_context(self, filepath: str) -> AnalysisContext:
|
286
|
+
"""Create an analysis context for the given file."""
|
287
|
+
return self.analysis_service.create_analysis_context(filepath)
|
288
|
+
|
289
|
+
@staticmethod
|
290
|
+
def wait_for_file(file_path, retries=5, delay=1):
|
291
|
+
"""Wait for the generated file to appear."""
|
292
|
+
while retries > 0 and not os.path.exists(file_path):
|
293
|
+
time.sleep(delay)
|
294
|
+
retries -= 1
|
295
|
+
if not os.path.exists(file_path):
|
296
|
+
raise FileNotFoundError(f"File '{file_path}' not found after waiting.")
|
297
|
+
|
298
|
+
def get_full_import_path(self) -> str:
|
299
|
+
"""Get the full import path for the current file."""
|
300
|
+
package_root = self.find_package_root()
|
301
|
+
if not package_root:
|
302
|
+
raise ImportError(f"Could not determine the package root for {self.file_path}.")
|
303
|
+
|
304
|
+
module_path = os.path.abspath(self.file_path)
|
305
|
+
rel_path = os.path.relpath(module_path, package_root)
|
306
|
+
package_path = rel_path.replace(os.sep, ".")
|
307
|
+
|
308
|
+
if package_path.endswith(".py"):
|
309
|
+
package_path = package_path[:-3]
|
310
|
+
|
311
|
+
package_name = os.path.basename(package_root)
|
312
|
+
if not package_path.startswith(package_name + "."):
|
313
|
+
package_path = package_name + "." + package_path
|
314
|
+
|
315
|
+
return package_path
|
316
|
+
|
317
|
+
def find_package_root(self):
|
318
|
+
"""Find the package root directory."""
|
319
|
+
current_dir = os.path.abspath(os.path.dirname(self.file_path))
|
320
|
+
last_valid = None
|
321
|
+
|
322
|
+
while current_dir:
|
323
|
+
if "__init__.py" in os.listdir(current_dir):
|
324
|
+
last_valid = current_dir
|
325
|
+
else:
|
326
|
+
break
|
327
|
+
|
328
|
+
parent_dir = os.path.dirname(current_dir)
|
329
|
+
if parent_dir == current_dir:
|
330
|
+
break
|
331
|
+
current_dir = parent_dir
|
332
|
+
|
333
|
+
return last_valid
|
334
|
+
|
335
|
+
def inspect_class(self, class_name):
|
336
|
+
"""Inspect a class or module and return its functions."""
|
337
|
+
module = file_utils.load_module(self.file_path)
|
338
|
+
|
339
|
+
# Handle module-level functions when class_name is None
|
340
|
+
if not class_name:
|
341
|
+
# Get module-level functions
|
342
|
+
functions = inspect.getmembers(module, inspect.isfunction)
|
343
|
+
return functions
|
344
|
+
|
345
|
+
# Handle class functions
|
346
|
+
cls = getattr(module, class_name, None)
|
347
|
+
if cls is None:
|
348
|
+
raise ValueError(f"Class '{class_name}' not found in module '{module.__name__}'.")
|
349
|
+
|
350
|
+
functions = inspect.getmembers(cls, inspect.isfunction)
|
351
|
+
return functions
|
352
|
+
|
353
|
+
@staticmethod
|
354
|
+
def resolve_module_path(module_name):
|
355
|
+
"""Resolve a module name to its file path by checking multiple locations."""
|
356
|
+
direct_path = f"/controller/{module_name}.py"
|
357
|
+
if os.path.exists(direct_path):
|
358
|
+
print(f"Found module at {direct_path}")
|
359
|
+
return direct_path
|
360
|
+
|
361
|
+
testgen_path = f"/controller/testgen/{module_name}.py"
|
362
|
+
if os.path.exists(testgen_path):
|
363
|
+
print(f"Found module at {testgen_path}")
|
364
|
+
return testgen_path
|
365
|
+
|
366
|
+
if '.' in module_name:
|
367
|
+
parts = module_name.split('.')
|
368
|
+
potential_path = os.path.join('/controller', *parts) + '.py'
|
369
|
+
if os.path.exists(potential_path):
|
370
|
+
print(f"Found module at {potential_path}")
|
371
|
+
return potential_path
|
372
|
+
|
373
|
+
print(f"Could not find module: {module_name}")
|
374
|
+
return None
|
375
|
+
|
376
|
+
def visualize_test_coverage(self):
|
377
|
+
from testgen.service.cfg_service import CFGService
|
378
|
+
cfg_service = CFGService()
|
379
|
+
cfg_service.initialize_visualizer(self)
|
380
|
+
|
381
|
+
return cfg_service.visualize_test_coverage(
|
382
|
+
file_path=self.file_path,
|
383
|
+
test_cases=self.test_cases,
|
384
|
+
)
|
385
|
+
|
386
|
+
def set_reinforcement_mode(self, mode: str):
|
387
|
+
self.reinforcement_mode = mode
|
388
|
+
if hasattr(self ,'analysis_service'):
|
389
|
+
self.analysis_service.set_reinforcement_mode(mode)
|
File without changes
|
testgen/sqlite/db.py
ADDED
@@ -0,0 +1,84 @@
|
|
1
|
+
import sqlite3
|
2
|
+
|
3
|
+
def create_database(db_name="testgen.db"):
    """Create the TestGen SQLite schema in ``db_name``.

    Every table is created with ``CREATE TABLE IF NOT EXISTS``, so running this
    against an existing database leaves existing tables and rows untouched.

    Args:
        db_name: Path of the SQLite database file; sqlite3 creates it if missing.
    """
    conn = sqlite3.connect(db_name)
    try:
        cursor = conn.cursor()

        # Enable foreign key constraints (a per-connection setting in SQLite)
        cursor.execute("PRAGMA foreign_keys = ON;")

        # Create tables
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS TestSuite (
                id INTEGER PRIMARY KEY,
                name TEXT,
                creation_date TIMESTAMP
            );
        """)

        cursor.execute("""
            CREATE TABLE IF NOT EXISTS SourceFile (
                id INTEGER PRIMARY KEY,
                path TEXT,
                lines_of_code INTEGER,
                last_modified TIMESTAMP
            );
        """)

        cursor.execute("""
            CREATE TABLE IF NOT EXISTS Function (
                id INTEGER PRIMARY KEY,
                name TEXT,
                start_line INTEGER,
                end_line INTEGER,
                num_lines INTEGER,
                source_file_id INTEGER,
                FOREIGN KEY (source_file_id) REFERENCES SourceFile(id)
            );
        """)

        cursor.execute("""
            CREATE TABLE IF NOT EXISTS TestCase (
                id INTEGER PRIMARY KEY,
                name TEXT,
                expected_output TEXT, -- storing JSON as TEXT
                input TEXT, -- storing JSON as TEXT
                test_function TEXT,
                last_run_time TIMESTAMP,
                test_method_type INTEGER,
                test_suite_id INTEGER,
                function_id INTEGER,
                FOREIGN KEY (test_suite_id) REFERENCES TestSuite(id),
                FOREIGN KEY (function_id) REFERENCES Function(id)
            );
        """)

        cursor.execute("""
            CREATE TABLE IF NOT EXISTS TestResult (
                id INTEGER PRIMARY KEY,
                test_case_id INTEGER,
                status BOOLEAN,
                error TEXT,
                execution_time TIMESTAMP,
                FOREIGN KEY (test_case_id) REFERENCES TestCase(id)
            );
        """)

        cursor.execute("""
            CREATE TABLE IF NOT EXISTS CoverageData (
                id INTEGER PRIMARY KEY,
                file_name TEXT,
                executed_lines INTEGER,
                missed_lines INTEGER,
                branch_coverage REAL,
                source_file_id INTEGER,
                FOREIGN KEY (source_file_id) REFERENCES SourceFile(id)
            );
        """)

        conn.commit()
    finally:
        # Close the handle even if a DDL statement fails; previously it leaked.
        conn.close()

    print(f"Database '{db_name}' created successfully with all tables.")


if __name__ == "__main__":
    create_database()
|
@@ -0,0 +1,219 @@
|
|
1
|
+
import os
|
2
|
+
import sqlite3
|
3
|
+
import json
|
4
|
+
import time
|
5
|
+
import ast
|
6
|
+
from datetime import datetime
|
7
|
+
from typing import List, Tuple
|
8
|
+
|
9
|
+
from testgen.models.test_case import TestCase
|
10
|
+
from testgen.sqlite.db import create_database
|
11
|
+
|
12
|
+
class DBService:
|
13
|
+
def __init__(self, db_name="testgen.db"):
|
14
|
+
"""Initialize database service with connection to specified database."""
|
15
|
+
self.db_name = db_name
|
16
|
+
self.conn = None
|
17
|
+
self.cursor = None
|
18
|
+
self._connect()
|
19
|
+
|
20
|
+
def _connect(self):
|
21
|
+
"""Establish connection to the database."""
|
22
|
+
if not os.path.exists(self.db_name):
|
23
|
+
create_database(self.db_name)
|
24
|
+
|
25
|
+
self.conn = sqlite3.connect(self.db_name)
|
26
|
+
self.conn.row_factory = sqlite3.Row
|
27
|
+
self.cursor = self.conn.cursor()
|
28
|
+
# Enable foreign keys
|
29
|
+
self.cursor.execute("PRAGMA foreign_keys = ON;")
|
30
|
+
|
31
|
+
def close(self):
|
32
|
+
"""Close the database connection."""
|
33
|
+
if self.conn:
|
34
|
+
self.conn.close()
|
35
|
+
self.conn = None
|
36
|
+
self.cursor = None
|
37
|
+
|
38
|
+
def insert_test_suite(self, name: str) -> int:
|
39
|
+
"""Insert a test suite and return its ID."""
|
40
|
+
self.cursor.execute(
|
41
|
+
"INSERT INTO TestSuite (name, creation_date) VALUES (?, ?)",
|
42
|
+
(name, datetime.now())
|
43
|
+
)
|
44
|
+
self.conn.commit()
|
45
|
+
return self.cursor.lastrowid
|
46
|
+
|
47
|
+
def insert_source_file(self, path: str, lines_of_code: int) -> int:
|
48
|
+
"""Insert a source file and return its ID."""
|
49
|
+
# Check if file already exists
|
50
|
+
self.cursor.execute("SELECT id FROM SourceFile WHERE path = ?", (path,))
|
51
|
+
existing = self.cursor.fetchone()
|
52
|
+
if existing:
|
53
|
+
return existing[0]
|
54
|
+
|
55
|
+
self.cursor.execute(
|
56
|
+
"INSERT INTO SourceFile (path, lines_of_code, last_modified) VALUES (?, ?, ?)",
|
57
|
+
(path, lines_of_code, datetime.now())
|
58
|
+
)
|
59
|
+
self.conn.commit()
|
60
|
+
return self.cursor.lastrowid
|
61
|
+
|
62
|
+
def insert_function(self, name: str, start_line: int, end_line: int, source_file_id: int) -> int:
|
63
|
+
"""Insert a function and return its ID."""
|
64
|
+
num_lines = end_line - start_line + 1
|
65
|
+
|
66
|
+
# Check if function already exists for this source file
|
67
|
+
self.cursor.execute(
|
68
|
+
"SELECT id FROM Function WHERE name = ? AND source_file_id = ?",
|
69
|
+
(name, source_file_id)
|
70
|
+
)
|
71
|
+
existing = self.cursor.fetchone()
|
72
|
+
if existing:
|
73
|
+
return existing[0]
|
74
|
+
|
75
|
+
self.cursor.execute(
|
76
|
+
"INSERT INTO Function (name, start_line, end_line, num_lines, source_file_id) VALUES (?, ?, ?, ?, ?)",
|
77
|
+
(name, start_line, end_line, num_lines, source_file_id)
|
78
|
+
)
|
79
|
+
self.conn.commit()
|
80
|
+
return self.cursor.lastrowid
|
81
|
+
|
82
|
+
def insert_test_case(self, test_case: TestCase, test_suite_id: int, function_id: int, test_method_type: int) -> int:
|
83
|
+
"""Insert a test case and return its ID."""
|
84
|
+
# Convert inputs and expected output to JSON strings
|
85
|
+
inputs_json = json.dumps(test_case.inputs)
|
86
|
+
expected_json = json.dumps(test_case.expected)
|
87
|
+
|
88
|
+
self.cursor.execute(
|
89
|
+
"INSERT INTO TestCase (name, expected_output, input, test_function, last_run_time, test_method_type, test_suite_id, function_id) "
|
90
|
+
"VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
91
|
+
(
|
92
|
+
f"test_{test_case.func_name}",
|
93
|
+
expected_json,
|
94
|
+
inputs_json,
|
95
|
+
test_case.func_name,
|
96
|
+
datetime.now(),
|
97
|
+
test_method_type,
|
98
|
+
test_suite_id,
|
99
|
+
function_id
|
100
|
+
)
|
101
|
+
)
|
102
|
+
self.conn.commit()
|
103
|
+
return self.cursor.lastrowid
|
104
|
+
|
105
|
+
def insert_test_result(self, test_case_id: int, status: bool, error: str = None) -> int:
|
106
|
+
"""Insert a test result and return its ID."""
|
107
|
+
self.cursor.execute(
|
108
|
+
"INSERT INTO TestResult (test_case_id, status, error, execution_time) VALUES (?, ?, ?, ?)",
|
109
|
+
(test_case_id, status, error, datetime.now())
|
110
|
+
)
|
111
|
+
self.conn.commit()
|
112
|
+
return self.cursor.lastrowid
|
113
|
+
|
114
|
+
def insert_coverage_data(self, file_name: str, executed_lines: int, missed_lines: int,
|
115
|
+
branch_coverage: float, source_file_id: int) -> int:
|
116
|
+
"""Insert coverage data and return its ID."""
|
117
|
+
self.cursor.execute(
|
118
|
+
"INSERT INTO CoverageData (file_name, executed_lines, missed_lines, branch_coverage, source_file_id) "
|
119
|
+
"VALUES (?, ?, ?, ?, ?)",
|
120
|
+
(file_name, executed_lines, missed_lines, branch_coverage, source_file_id)
|
121
|
+
)
|
122
|
+
self.conn.commit()
|
123
|
+
return self.cursor.lastrowid
|
124
|
+
|
125
|
+
def save_test_generation_data(self, file_path: str, test_cases: List[TestCase],
|
126
|
+
test_method_type: int, class_name: str = None) -> Tuple[int, List[int]]:
|
127
|
+
"""
|
128
|
+
Save all data related to a test generation run.
|
129
|
+
Returns the test suite ID and a list of test case IDs.
|
130
|
+
"""
|
131
|
+
# Count lines in the source file
|
132
|
+
with open(file_path, 'r') as f:
|
133
|
+
lines_of_code = len(f.readlines())
|
134
|
+
|
135
|
+
# Create test suite
|
136
|
+
strategy_names = {1: "AST", 2: "Fuzz", 3: "Random", 4: "Reinforcement"}
|
137
|
+
suite_name = f"{strategy_names.get(test_method_type, 'Unknown')}_Suite_{int(time.time())}"
|
138
|
+
test_suite_id = self.insert_test_suite(suite_name)
|
139
|
+
|
140
|
+
# Insert source file
|
141
|
+
source_file_id = self.insert_source_file(file_path, lines_of_code)
|
142
|
+
|
143
|
+
# Process functions and test cases
|
144
|
+
test_case_ids = []
|
145
|
+
function_ids = {} # Cache function IDs to avoid redundant queries
|
146
|
+
|
147
|
+
for test_case in test_cases:
|
148
|
+
# Extract function name from test case
|
149
|
+
func_name = test_case.func_name
|
150
|
+
|
151
|
+
# If function not already processed
|
152
|
+
if func_name not in function_ids:
|
153
|
+
# Get function line numbers
|
154
|
+
start_line, end_line = self._get_function_line_numbers(file_path, func_name)
|
155
|
+
function_id = self.insert_function(func_name, start_line, end_line, source_file_id)
|
156
|
+
function_ids[func_name] = function_id
|
157
|
+
|
158
|
+
# Insert test case
|
159
|
+
test_case_id = self.insert_test_case(
|
160
|
+
test_case,
|
161
|
+
test_suite_id,
|
162
|
+
function_ids[func_name],
|
163
|
+
test_method_type
|
164
|
+
)
|
165
|
+
test_case_ids.append(test_case_id)
|
166
|
+
|
167
|
+
return test_suite_id, test_case_ids
|
168
|
+
|
169
|
+
def _get_function_line_numbers(self, file_path: str, function_name: str) -> Tuple[int, int]:
|
170
|
+
"""
|
171
|
+
Extract the start and end line numbers for a function in a file.
|
172
|
+
Returns a tuple of (start_line, end_line).
|
173
|
+
"""
|
174
|
+
try:
|
175
|
+
# Load the file and parse it
|
176
|
+
with open(file_path, 'r') as f:
|
177
|
+
file_content = f.read()
|
178
|
+
|
179
|
+
tree = ast.parse(file_content)
|
180
|
+
|
181
|
+
# Find the function definition
|
182
|
+
for node in ast.walk(tree):
|
183
|
+
if isinstance(node, ast.FunctionDef) and node.name == function_name:
|
184
|
+
end_line = node.end_lineno if hasattr(node, 'end_lineno') else node.lineno + 5 # Estimate if end_lineno not available
|
185
|
+
return node.lineno, end_line
|
186
|
+
|
187
|
+
# Also look for class methods
|
188
|
+
for node in ast.walk(tree):
|
189
|
+
if isinstance(node, ast.ClassDef):
|
190
|
+
for class_node in node.body:
|
191
|
+
if isinstance(class_node, ast.FunctionDef) and class_node.name == function_name:
|
192
|
+
end_line = class_node.end_lineno if hasattr(class_node, 'end_lineno') else class_node.lineno + 5
|
193
|
+
return class_node.lineno, end_line
|
194
|
+
except Exception as e:
|
195
|
+
print(f"Error getting function line numbers: {e}")
|
196
|
+
|
197
|
+
# If we reach here, the function wasn't found or there was an error
|
198
|
+
return 0, 0
|
199
|
+
|
200
|
+
def get_test_suites(self):
|
201
|
+
"""Get all test suites from the database."""
|
202
|
+
self.cursor.execute("SELECT * FROM TestSuite ORDER BY creation_date DESC")
|
203
|
+
return self.cursor.fetchall()
|
204
|
+
|
205
|
+
def get_test_cases_by_function(self, function_name):
|
206
|
+
"""Get all test cases for a specific function."""
|
207
|
+
self.cursor.execute(
|
208
|
+
"SELECT tc.* FROM TestCase tc JOIN Function f ON tc.function_id = f.id WHERE f.name = ?",
|
209
|
+
(function_name,)
|
210
|
+
)
|
211
|
+
return self.cursor.fetchall()
|
212
|
+
|
213
|
+
def get_coverage_by_file(self, file_path):
|
214
|
+
"""Get coverage data for a specific file."""
|
215
|
+
self.cursor.execute(
|
216
|
+
"SELECT cd.* FROM CoverageData cd JOIN SourceFile sf ON cd.source_file_id = sf.id WHERE sf.path = ?",
|
217
|
+
(file_path,)
|
218
|
+
)
|
219
|
+
return self.cursor.fetchall()
|
testgen/tree/__init__.py
ADDED
File without changes
|