testgenie-py 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. testgen/__init__.py +0 -0
  2. testgen/analyzer/__init__.py +0 -0
  3. testgen/analyzer/ast_analyzer.py +149 -0
  4. testgen/analyzer/contracts/__init__.py +0 -0
  5. testgen/analyzer/contracts/contract.py +13 -0
  6. testgen/analyzer/contracts/no_exception_contract.py +16 -0
  7. testgen/analyzer/contracts/nonnull_contract.py +15 -0
  8. testgen/analyzer/fuzz_analyzer.py +106 -0
  9. testgen/analyzer/random_feedback_analyzer.py +291 -0
  10. testgen/analyzer/reinforcement_analyzer.py +75 -0
  11. testgen/analyzer/test_case_analyzer.py +46 -0
  12. testgen/analyzer/test_case_analyzer_context.py +58 -0
  13. testgen/controller/__init__.py +0 -0
  14. testgen/controller/cli_controller.py +194 -0
  15. testgen/controller/docker_controller.py +169 -0
  16. testgen/docker/Dockerfile +22 -0
  17. testgen/docker/poetry.lock +361 -0
  18. testgen/docker/pyproject.toml +22 -0
  19. testgen/generator/__init__.py +0 -0
  20. testgen/generator/code_generator.py +66 -0
  21. testgen/generator/doctest_generator.py +208 -0
  22. testgen/generator/generator.py +55 -0
  23. testgen/generator/pytest_generator.py +77 -0
  24. testgen/generator/test_generator.py +26 -0
  25. testgen/generator/unit_test_generator.py +84 -0
  26. testgen/inspector/__init__.py +0 -0
  27. testgen/inspector/inspector.py +61 -0
  28. testgen/main.py +13 -0
  29. testgen/models/__init__.py +0 -0
  30. testgen/models/analysis_context.py +56 -0
  31. testgen/models/function_metadata.py +61 -0
  32. testgen/models/generator_context.py +63 -0
  33. testgen/models/test_case.py +8 -0
  34. testgen/presentation/__init__.py +0 -0
  35. testgen/presentation/cli_view.py +12 -0
  36. testgen/q_table/global_q_table.json +1 -0
  37. testgen/reinforcement/__init__.py +0 -0
  38. testgen/reinforcement/abstract_state.py +7 -0
  39. testgen/reinforcement/agent.py +153 -0
  40. testgen/reinforcement/environment.py +215 -0
  41. testgen/reinforcement/statement_coverage_state.py +33 -0
  42. testgen/service/__init__.py +0 -0
  43. testgen/service/analysis_service.py +260 -0
  44. testgen/service/cfg_service.py +55 -0
  45. testgen/service/generator_service.py +169 -0
  46. testgen/service/service.py +389 -0
  47. testgen/sqlite/__init__.py +0 -0
  48. testgen/sqlite/db.py +84 -0
  49. testgen/sqlite/db_service.py +219 -0
  50. testgen/tree/__init__.py +0 -0
  51. testgen/tree/node.py +7 -0
  52. testgen/tree/tree_utils.py +79 -0
  53. testgen/util/__init__.py +0 -0
  54. testgen/util/coverage_utils.py +168 -0
  55. testgen/util/coverage_visualizer.py +154 -0
  56. testgen/util/file_utils.py +110 -0
  57. testgen/util/randomizer.py +122 -0
  58. testgen/util/utils.py +143 -0
  59. testgen/util/z3_utils/__init__.py +0 -0
  60. testgen/util/z3_utils/ast_to_z3.py +99 -0
  61. testgen/util/z3_utils/branch_condition.py +72 -0
  62. testgen/util/z3_utils/constraint_extractor.py +36 -0
  63. testgen/util/z3_utils/variable_finder.py +10 -0
  64. testgen/util/z3_utils/z3_test_case.py +94 -0
  65. testgenie_py-0.1.0.dist-info/METADATA +24 -0
  66. testgenie_py-0.1.0.dist-info/RECORD +68 -0
  67. testgenie_py-0.1.0.dist-info/WHEEL +4 -0
  68. testgenie_py-0.1.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,389 @@
1
+ import inspect
2
+ import json
3
+ import os
4
+ import re
5
+ import sqlite3
6
+ import time
7
+ import subprocess
8
+ from types import ModuleType
9
+ from typing import List
10
+ import testgen.util.file_utils as file_utils
11
+
12
+ from testgen.models.test_case import TestCase
13
+ from testgen.service.analysis_service import AnalysisService
14
+ from testgen.service.generator_service import GeneratorService
15
+ from testgen.sqlite.db_service import DBService
16
+ from testgen.models.analysis_context import AnalysisContext
17
+ from testgen.util.coverage_visualizer import CoverageVisualizer
18
+
19
+ # Test-analysis strategy identifiers; Service.generate_tests branches on AST_STRAT and REINFORCE_STRAT.
20
+ AST_STRAT = 1
21
+ FUZZ_STRAT = 2
22
+ RANDOM_STRAT = 3
23
+ REINFORCE_STRAT = 4
24
+
25
+ # Test output-format identifiers; Service.run_coverage selects the test runner from these.
26
+ UNITTEST_FORMAT = 1
27
+ PYTEST_FORMAT = 2
28
+ DOCTEST_FORMAT = 3
29
+
30
+ class Service:
31
+ def __init__(self):
32
+ self.test_strategy: int = 0
33
+ self.test_format: int = 0
34
+ self.file_path = None
35
+ self.generated_file_path = None
36
+ self.class_name = None
37
+ self.test_cases = []
38
+ self.reinforcement_mode = "train"
39
+
40
+ # Initialize specialized services
41
+ self.analysis_service = AnalysisService()
42
+ self.generator_service = GeneratorService(None, None, None)
43
+ self.db_service = DBService()
44
+
45
+ def select_all_from_db(self) -> None:
46
+ rows = self.db_service.get_test_suites()
47
+ for row in rows:
48
+ print(repr(dict(row)))
49
+
50
+ def generate_tests(self, output_path=None):
51
+ """Generate tests for a class or module."""
52
+ module = file_utils.load_module(self.file_path)
53
+ class_name = self.analysis_service.get_class_name(module)
54
+
55
+ if self.test_strategy == AST_STRAT:
56
+ self.generated_file_path = self.generate_function_code()
57
+ Service.wait_for_file(self.generated_file_path)
58
+ self.analysis_service.set_file_path(self.generated_file_path)
59
+ module = file_utils.load_module(self.generated_file_path)
60
+ class_name = self.analysis_service.get_class_name(module)
61
+ self.analysis_service.set_test_strategy(self.test_strategy, module.__name__, self.class_name)
62
+
63
+ else:
64
+ self.analysis_service.set_test_strategy(self.test_strategy, module.__name__, self.class_name)
65
+
66
+ test_cases: List[TestCase] = []
67
+ if self.test_strategy == REINFORCE_STRAT:
68
+ test_cases = self.analysis_service.do_reinforcement_learning(self.file_path, self.reinforcement_mode)
69
+ else:
70
+ test_cases = self.analysis_service.generate_test_cases()
71
+
72
+ self.test_cases = test_cases
73
+
74
+ file_path_to_use = self.generated_file_path if self.test_strategy == AST_STRAT else self.file_path
75
+ self.db_service.save_test_generation_data(
76
+ file_path_to_use,
77
+ test_cases,
78
+ self.test_strategy,
79
+ class_name
80
+ )
81
+
82
+ if os.environ.get("RUNNING_IN_DOCKER") is not None:
83
+ print(f"Serializing test cases {test_cases}")
84
+ self.serialize_test_cases(test_cases)
85
+ return None # Exit early in analysis-only mode
86
+
87
+ test_file = self.generate_test_file(test_cases, output_path, module, class_name)
88
+
89
+ # Ensure the test file is ready
90
+ Service.wait_for_file(test_file)
91
+ return test_file
92
+
93
+ def generate_test_file(self, test_cases: List[TestCase], output_path: str | None = None,
94
+ module: ModuleType | None = None, class_name: str | None = None) -> str:
95
+ if module is None:
96
+ module = file_utils.load_module(self.file_path)
97
+ class_name = self.analysis_service.get_class_name(module)
98
+
99
+ if self.test_strategy == AST_STRAT:
100
+ self.generator_service = GeneratorService(self.generated_file_path, output_path, self.test_format)
101
+ self.generator_service.set_test_format(self.test_format)
102
+ module = file_utils.load_module(self.generated_file_path)
103
+ else:
104
+ # Create the correct instance of the generator service
105
+ self.generator_service = GeneratorService(self.file_path, output_path, self.test_format)
106
+ self.generator_service.set_test_format(self.test_format)
107
+
108
+ test_file = self.generator_service.generate_test_file(
109
+ module,
110
+ class_name,
111
+ test_cases,
112
+ output_path
113
+ )
114
+
115
+ # Ensure the test file is ready
116
+ Service.wait_for_file(test_file)
117
+ return test_file
118
+
119
+ def generate_function_code(self):
120
+ """Generate function code for a given class or module."""
121
+ module = file_utils.load_module(self.file_path)
122
+ class_name = self.analysis_service.get_class_name(module)
123
+ functions = self.inspect_class(class_name)
124
+ return self.generator_service.generate_function_code(self.file_path, class_name, functions)
125
+
126
+ def run_coverage(self, test_file):
127
+ """Run coverage analysis on the generated tests."""
128
+ Service.wait_for_file(test_file)
129
+ file_path_to_use = self.generated_file_path if self.test_strategy == AST_STRAT else self.file_path
130
+ coverage_output = ""
131
+
132
+ try:
133
+ if self.test_format == UNITTEST_FORMAT:
134
+ subprocess.run(["python", "-m", "coverage", "run", "--source=.", "-m", "unittest", test_file], check=True)
135
+ result = subprocess.run(
136
+ ["python", "-m", "coverage", "report", file_path_to_use],
137
+ check=True,
138
+ capture_output=True,
139
+ text=True
140
+ )
141
+ coverage_output = result.stdout
142
+ elif self.test_format == PYTEST_FORMAT:
143
+ subprocess.run(["python", "-m", "coverage", "run", "--source=.", "-m", "pytest", test_file], check=True)
144
+ result = subprocess.run(
145
+ ["python", "-m", "coverage", "report", file_path_to_use],
146
+ check=True,
147
+ capture_output=True,
148
+ text=True
149
+ )
150
+ coverage_output = result.stdout
151
+ elif self.test_format == DOCTEST_FORMAT:
152
+ result = subprocess.run(
153
+ ["python", "-m", "coverage", "run", "--source=.", "-m", "doctest", "-v", test_file],
154
+ check=True,
155
+ capture_output=True,
156
+ text=True
157
+ )
158
+ coverage_output = result.stdout
159
+ else:
160
+ raise ValueError("Unsupported test format for coverage analysis.")
161
+
162
+ self._save_coverage_data(coverage_output, file_path_to_use)
163
+
164
+ except subprocess.CalledProcessError as e:
165
+ raise RuntimeError(f"Error running coverage subprocess: {e}")
166
+
167
+ def _save_coverage_data(self, coverage_output, file_path):
168
+ """Parse coverage output and save to database."""
169
+ try:
170
+ lines = coverage_output.strip().split('\n')
171
+ for line in lines:
172
+ if file_path in line:
173
+ parts = line.split()
174
+ if len(parts) >= 4:
175
+ file_name = os.path.basename(file_path)
176
+ try:
177
+ total_lines = int(parts[-3])
178
+ missed_lines = int(parts[-2])
179
+ executed_lines = total_lines - missed_lines
180
+ coverage_str = parts[-1].strip('%')
181
+ branch_coverage = float(coverage_str) / 100
182
+
183
+ source_file_id = self._get_source_file_id(file_path)
184
+
185
+ self.db_service.insert_coverage_data(
186
+ file_name,
187
+ executed_lines,
188
+ missed_lines,
189
+ branch_coverage,
190
+ source_file_id
191
+ )
192
+ break
193
+ except (ValueError, IndexError) as e:
194
+ print(f"Error parsing coverage data: {e}")
195
+ except Exception as e:
196
+ print(f"Error saving coverage data: {e}")
197
+
198
+ def _get_source_file_id(self, file_path):
199
+ """Helper to get source file ID from DB."""
200
+ conn = sqlite3.connect(self.db_service.db_name)
201
+ conn.row_factory = sqlite3.Row
202
+ cursor = conn.cursor()
203
+
204
+ cursor.execute("SELECT id FROM SourceFile WHERE path = ?", (file_path,))
205
+ row = cursor.fetchone()
206
+
207
+ conn.close()
208
+
209
+ if row:
210
+ return row[0]
211
+ else:
212
+ with open(file_path, 'r') as f:
213
+ lines_of_code = len(f.readlines())
214
+ return self.db_service.insert_source_file(file_path, lines_of_code)
215
+
216
+ def serialize_test_cases(self, test_cases):
217
+ """Serialize input arguments ot JSON-compatible format"""
218
+ print("##TEST_CASES_BEGIN##")
219
+ serialized = []
220
+ for tc in test_cases:
221
+ case_data = {
222
+ "func_name": tc.func_name,
223
+ "inputs": self.serialize_value(tc.inputs),
224
+ "expected": self.serialize_value(tc.expected)
225
+ }
226
+ serialized.append(case_data)
227
+ print(json.dumps(serialized))
228
+ print("##TEST_CASES_END##")
229
+
230
+ def serialize_value(self, value):
231
+ """Serialize a single value to a JSON-compatible format"""
232
+ if value is None or isinstance(value, (bool, int, float, str)):
233
+ return value
234
+ elif isinstance(value, (list, tuple)):
235
+ return [self.serialize_value(v) for v in value]
236
+ else:
237
+ return str(value)
238
+
239
+ @staticmethod
240
+ def parse_test_cases_from_logs(logs_output):
241
+ """Extract and parse test cases from container logs"""
242
+ pattern = r"##TEST_CASES_BEGIN##\n(.*?)\n##TEST_CASES_END##"
243
+ match = re.search(pattern, logs_output, re.DOTALL)
244
+
245
+ if not match:
246
+ raise ValueError("Could not find test cases in the container logs")
247
+
248
+ test_cases_json = match.group(1).strip()
249
+ test_cases_data = json.loads(test_cases_json)
250
+
251
+ test_cases: List[TestCase] = []
252
+ for tc_data in test_cases_data:
253
+ test_case = TestCase(
254
+ func_name=tc_data["func_name"],
255
+ inputs=tc_data["inputs"],
256
+ expected=tc_data["expected"]
257
+ )
258
+ test_cases.append(test_case)
259
+
260
+ return test_cases
261
+
262
+ def set_file_path(self, path: str):
263
+ """Set the file path for analysis and validate it."""
264
+ if os.path.isfile(path) and path.endswith(".py"):
265
+ self.file_path = path
266
+ self.analysis_service.set_file_path(path)
267
+ else:
268
+ raise ValueError("Invalid file path! Please provide a valid Python file path.")
269
+
270
+ def set_class_name(self, class_name: str):
271
+ """Set the class name to analyze."""
272
+ self.class_name = class_name
273
+
274
+ def set_test_generator_format(self, test_format: int):
275
+ """Set the test generator format."""
276
+ self.test_format = test_format
277
+ self.generator_service.set_test_format(test_format)
278
+
279
+ def set_test_analysis_strategy(self, strategy: int):
280
+ """Set the test analysis strategy."""
281
+ self.test_strategy = strategy
282
+ module = file_utils.load_module(self.file_path)
283
+ self.analysis_service.set_test_strategy(strategy, module.__name__, self.class_name)
284
+
285
+ def get_analysis_context(self, filepath: str) -> AnalysisContext:
286
+ """Create an analysis context for the given file."""
287
+ return self.analysis_service.create_analysis_context(filepath)
288
+
289
+ @staticmethod
290
+ def wait_for_file(file_path, retries=5, delay=1):
291
+ """Wait for the generated file to appear."""
292
+ while retries > 0 and not os.path.exists(file_path):
293
+ time.sleep(delay)
294
+ retries -= 1
295
+ if not os.path.exists(file_path):
296
+ raise FileNotFoundError(f"File '{file_path}' not found after waiting.")
297
+
298
+ def get_full_import_path(self) -> str:
299
+ """Get the full import path for the current file."""
300
+ package_root = self.find_package_root()
301
+ if not package_root:
302
+ raise ImportError(f"Could not determine the package root for {self.file_path}.")
303
+
304
+ module_path = os.path.abspath(self.file_path)
305
+ rel_path = os.path.relpath(module_path, package_root)
306
+ package_path = rel_path.replace(os.sep, ".")
307
+
308
+ if package_path.endswith(".py"):
309
+ package_path = package_path[:-3]
310
+
311
+ package_name = os.path.basename(package_root)
312
+ if not package_path.startswith(package_name + "."):
313
+ package_path = package_name + "." + package_path
314
+
315
+ return package_path
316
+
317
+ def find_package_root(self):
318
+ """Find the package root directory."""
319
+ current_dir = os.path.abspath(os.path.dirname(self.file_path))
320
+ last_valid = None
321
+
322
+ while current_dir:
323
+ if "__init__.py" in os.listdir(current_dir):
324
+ last_valid = current_dir
325
+ else:
326
+ break
327
+
328
+ parent_dir = os.path.dirname(current_dir)
329
+ if parent_dir == current_dir:
330
+ break
331
+ current_dir = parent_dir
332
+
333
+ return last_valid
334
+
335
+ def inspect_class(self, class_name):
336
+ """Inspect a class or module and return its functions."""
337
+ module = file_utils.load_module(self.file_path)
338
+
339
+ # Handle module-level functions when class_name is None
340
+ if not class_name:
341
+ # Get module-level functions
342
+ functions = inspect.getmembers(module, inspect.isfunction)
343
+ return functions
344
+
345
+ # Handle class functions
346
+ cls = getattr(module, class_name, None)
347
+ if cls is None:
348
+ raise ValueError(f"Class '{class_name}' not found in module '{module.__name__}'.")
349
+
350
+ functions = inspect.getmembers(cls, inspect.isfunction)
351
+ return functions
352
+
353
+ @staticmethod
354
+ def resolve_module_path(module_name):
355
+ """Resolve a module name to its file path by checking multiple locations."""
356
+ direct_path = f"/controller/{module_name}.py"
357
+ if os.path.exists(direct_path):
358
+ print(f"Found module at {direct_path}")
359
+ return direct_path
360
+
361
+ testgen_path = f"/controller/testgen/{module_name}.py"
362
+ if os.path.exists(testgen_path):
363
+ print(f"Found module at {testgen_path}")
364
+ return testgen_path
365
+
366
+ if '.' in module_name:
367
+ parts = module_name.split('.')
368
+ potential_path = os.path.join('/controller', *parts) + '.py'
369
+ if os.path.exists(potential_path):
370
+ print(f"Found module at {potential_path}")
371
+ return potential_path
372
+
373
+ print(f"Could not find module: {module_name}")
374
+ return None
375
+
376
+ def visualize_test_coverage(self):
377
+ from testgen.service.cfg_service import CFGService
378
+ cfg_service = CFGService()
379
+ cfg_service.initialize_visualizer(self)
380
+
381
+ return cfg_service.visualize_test_coverage(
382
+ file_path=self.file_path,
383
+ test_cases=self.test_cases,
384
+ )
385
+
386
+ def set_reinforcement_mode(self, mode: str):
387
+ self.reinforcement_mode = mode
388
+ if hasattr(self ,'analysis_service'):
389
+ self.analysis_service.set_reinforcement_mode(mode)
File without changes
testgen/sqlite/db.py ADDED
@@ -0,0 +1,84 @@
1
+ import sqlite3
2
+
3
def create_database(db_name="testgen.db"):
    """Create the testgen SQLite schema in ``db_name``.

    Every table is created with ``IF NOT EXISTS``, so calling this on an
    existing database leaves existing tables untouched. Prints a confirmation
    message on success.

    Raises:
        sqlite3.Error: if the database cannot be opened or a statement fails.
    """
    schema_statements = (
        """
        CREATE TABLE IF NOT EXISTS TestSuite (
            id INTEGER PRIMARY KEY,
            name TEXT,
            creation_date TIMESTAMP
        );
        """,
        """
        CREATE TABLE IF NOT EXISTS SourceFile (
            id INTEGER PRIMARY KEY,
            path TEXT,
            lines_of_code INTEGER,
            last_modified TIMESTAMP
        );
        """,
        """
        CREATE TABLE IF NOT EXISTS Function (
            id INTEGER PRIMARY KEY,
            name TEXT,
            start_line INTEGER,
            end_line INTEGER,
            num_lines INTEGER,
            source_file_id INTEGER,
            FOREIGN KEY (source_file_id) REFERENCES SourceFile(id)
        );
        """,
        """
        CREATE TABLE IF NOT EXISTS TestCase (
            id INTEGER PRIMARY KEY,
            name TEXT,
            expected_output TEXT, -- storing JSON as TEXT
            input TEXT, -- storing JSON as TEXT
            test_function TEXT,
            last_run_time TIMESTAMP,
            test_method_type INTEGER,
            test_suite_id INTEGER,
            function_id INTEGER,
            FOREIGN KEY (test_suite_id) REFERENCES TestSuite(id),
            FOREIGN KEY (function_id) REFERENCES Function(id)
        );
        """,
        """
        CREATE TABLE IF NOT EXISTS TestResult (
            id INTEGER PRIMARY KEY,
            test_case_id INTEGER,
            status BOOLEAN,
            error TEXT,
            execution_time TIMESTAMP,
            FOREIGN KEY (test_case_id) REFERENCES TestCase(id)
        );
        """,
        """
        CREATE TABLE IF NOT EXISTS CoverageData (
            id INTEGER PRIMARY KEY,
            file_name TEXT,
            executed_lines INTEGER,
            missed_lines INTEGER,
            branch_coverage REAL,
            source_file_id INTEGER,
            FOREIGN KEY (source_file_id) REFERENCES SourceFile(id)
        );
        """,
    )

    conn = sqlite3.connect(db_name)
    try:
        cursor = conn.cursor()

        # Enable foreign key constraints
        cursor.execute("PRAGMA foreign_keys = ON;")

        # Create tables
        for statement in schema_statements:
            cursor.execute(statement)

        conn.commit()
    finally:
        # Release the connection even when a statement fails.
        conn.close()

    print(f"Database '{db_name}' created successfully with all tables.")


if __name__ == "__main__":
    create_database()
@@ -0,0 +1,219 @@
1
+ import os
2
+ import sqlite3
3
+ import json
4
+ import time
5
+ import ast
6
+ from datetime import datetime
7
+ from typing import List, Tuple
8
+
9
+ from testgen.models.test_case import TestCase
10
+ from testgen.sqlite.db import create_database
11
+
12
class DBService:
    """Thin data-access layer over the testgen SQLite database.

    One connection and cursor are opened on construction and reused by every
    method; each insert commits immediately.
    """

    def __init__(self, db_name="testgen.db"):
        """Initialize database service with connection to specified database."""
        self.db_name = db_name
        self.conn = None
        self.cursor = None
        self._connect()

    def _connect(self):
        """Establish connection to the database, creating it when the file is missing."""
        if not os.path.exists(self.db_name):
            create_database(self.db_name)

        self.conn = sqlite3.connect(self.db_name)
        self.conn.row_factory = sqlite3.Row
        self.cursor = self.conn.cursor()
        # Enable foreign keys
        self.cursor.execute("PRAGMA foreign_keys = ON;")

    def close(self):
        """Close the database connection."""
        if self.conn:
            self.conn.close()
            self.conn = None
            self.cursor = None

    @staticmethod
    def _timestamp():
        """Return the current local time as ISO text with a space separator.

        This is the same text the implicit sqlite3 ``datetime`` adapter
        produced; binding it explicitly avoids relying on that adapter, which
        is deprecated since Python 3.12.
        """
        return datetime.now().isoformat(sep=" ")

    def _insert(self, sql, params):
        """Execute one INSERT, commit it, and return the new row id."""
        self.cursor.execute(sql, params)
        self.conn.commit()
        return self.cursor.lastrowid

    def insert_test_suite(self, name: str) -> int:
        """Insert a test suite and return its ID."""
        return self._insert(
            "INSERT INTO TestSuite (name, creation_date) VALUES (?, ?)",
            (name, self._timestamp()),
        )

    def insert_source_file(self, path: str, lines_of_code: int) -> int:
        """Insert a source file and return its ID.

        If a row with the same ``path`` already exists its ID is returned and
        nothing is inserted or updated.
        """
        # Check if file already exists
        self.cursor.execute("SELECT id FROM SourceFile WHERE path = ?", (path,))
        existing = self.cursor.fetchone()
        if existing:
            return existing[0]

        return self._insert(
            "INSERT INTO SourceFile (path, lines_of_code, last_modified) VALUES (?, ?, ?)",
            (path, lines_of_code, self._timestamp()),
        )

    def insert_function(self, name: str, start_line: int, end_line: int, source_file_id: int) -> int:
        """Insert a function and return its ID.

        If a function with the same name already exists for the source file,
        its ID is returned and nothing is inserted.
        """
        num_lines = end_line - start_line + 1

        # Check if function already exists for this source file
        self.cursor.execute(
            "SELECT id FROM Function WHERE name = ? AND source_file_id = ?",
            (name, source_file_id),
        )
        existing = self.cursor.fetchone()
        if existing:
            return existing[0]

        return self._insert(
            "INSERT INTO Function (name, start_line, end_line, num_lines, source_file_id) VALUES (?, ?, ?, ?, ?)",
            (name, start_line, end_line, num_lines, source_file_id),
        )

    def insert_test_case(self, test_case: "TestCase", test_suite_id: int, function_id: int, test_method_type: int) -> int:
        """Insert a test case and return its ID."""
        # Convert inputs and expected output to JSON strings. Values JSON
        # cannot represent fall back to their str() form instead of raising
        # TypeError and aborting the whole save.
        inputs_json = json.dumps(test_case.inputs, default=str)
        expected_json = json.dumps(test_case.expected, default=str)

        return self._insert(
            "INSERT INTO TestCase (name, expected_output, input, test_function, last_run_time, test_method_type, test_suite_id, function_id) "
            "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
            (
                f"test_{test_case.func_name}",
                expected_json,
                inputs_json,
                test_case.func_name,
                self._timestamp(),
                test_method_type,
                test_suite_id,
                function_id,
            ),
        )

    def insert_test_result(self, test_case_id: int, status: bool, error: str = None) -> int:
        """Insert a test result and return its ID."""
        return self._insert(
            "INSERT INTO TestResult (test_case_id, status, error, execution_time) VALUES (?, ?, ?, ?)",
            (test_case_id, status, error, self._timestamp()),
        )

    def insert_coverage_data(self, file_name: str, executed_lines: int, missed_lines: int,
                             branch_coverage: float, source_file_id: int) -> int:
        """Insert coverage data and return its ID."""
        return self._insert(
            "INSERT INTO CoverageData (file_name, executed_lines, missed_lines, branch_coverage, source_file_id) "
            "VALUES (?, ?, ?, ?, ?)",
            (file_name, executed_lines, missed_lines, branch_coverage, source_file_id),
        )

    def save_test_generation_data(self, file_path: str, test_cases: "List[TestCase]",
                                  test_method_type: int, class_name: str = None) -> Tuple[int, List[int]]:
        """
        Save all data related to a test generation run.
        Returns the test suite ID and a list of test case IDs.

        ``class_name`` is accepted for interface compatibility but is not
        used by this method.
        """
        # Count lines in the source file. Only the count matters, so
        # undecodable bytes are replaced rather than allowed to raise.
        with open(file_path, 'r', encoding="utf-8", errors="replace") as f:
            lines_of_code = sum(1 for _ in f)

        # Create test suite
        strategy_names = {1: "AST", 2: "Fuzz", 3: "Random", 4: "Reinforcement"}
        suite_name = f"{strategy_names.get(test_method_type, 'Unknown')}_Suite_{int(time.time())}"
        test_suite_id = self.insert_test_suite(suite_name)

        # Insert source file
        source_file_id = self.insert_source_file(file_path, lines_of_code)

        # Process functions and test cases
        test_case_ids = []
        function_ids = {}  # Cache function IDs to avoid redundant queries

        for test_case in test_cases:
            func_name = test_case.func_name

            if func_name not in function_ids:
                start_line, end_line = self._get_function_line_numbers(file_path, func_name)
                function_ids[func_name] = self.insert_function(
                    func_name, start_line, end_line, source_file_id
                )

            test_case_ids.append(
                self.insert_test_case(
                    test_case,
                    test_suite_id,
                    function_ids[func_name],
                    test_method_type,
                )
            )

        return test_suite_id, test_case_ids

    def _get_function_line_numbers(self, file_path: str, function_name: str) -> Tuple[int, int]:
        """
        Extract the start and end line numbers for a function in a file.
        Returns a tuple of (start_line, end_line), or (0, 0) when the file
        cannot be read or parsed or no matching definition exists.
        """
        try:
            # Read bytes so ast.parse can honour the file's own encoding
            # declaration instead of the platform default.
            with open(file_path, 'rb') as f:
                tree = ast.parse(f.read())
        except (OSError, SyntaxError, ValueError, RecursionError) as e:
            print(f"Error getting function line numbers: {e}")
            return 0, 0

        # ast.walk visits nested nodes too, so class methods are found by the
        # same pass; async definitions are matched as well.
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == function_name:
                end_line = getattr(node, 'end_lineno', None)
                if end_line is None:
                    end_line = node.lineno + 5  # Estimate if end_lineno not available
                return node.lineno, end_line

        # If we reach here, the function wasn't found
        return 0, 0

    def get_test_suites(self):
        """Get all test suites from the database, newest first."""
        self.cursor.execute("SELECT * FROM TestSuite ORDER BY creation_date DESC")
        return self.cursor.fetchall()

    def get_test_cases_by_function(self, function_name):
        """Get all test cases for a specific function."""
        self.cursor.execute(
            "SELECT tc.* FROM TestCase tc JOIN Function f ON tc.function_id = f.id WHERE f.name = ?",
            (function_name,),
        )
        return self.cursor.fetchall()

    def get_coverage_by_file(self, file_path):
        """Get coverage data for a specific file."""
        self.cursor.execute(
            "SELECT cd.* FROM CoverageData cd JOIN SourceFile sf ON cd.source_file_id = sf.id WHERE sf.path = ?",
            (file_path,),
        )
        return self.cursor.fetchall()
File without changes