testgenie-py 0.3.7__py3-none-any.whl → 0.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- testgen/analyzer/ast_analyzer.py +2 -11
- testgen/analyzer/fuzz_analyzer.py +1 -6
- testgen/analyzer/random_feedback_analyzer.py +20 -293
- testgen/analyzer/reinforcement_analyzer.py +59 -57
- testgen/analyzer/test_case_analyzer_context.py +0 -6
- testgen/controller/cli_controller.py +35 -29
- testgen/controller/docker_controller.py +1 -0
- testgen/db/dao.py +68 -0
- testgen/db/dao_impl.py +226 -0
- testgen/{sqlite → db}/db.py +15 -6
- testgen/generator/pytest_generator.py +2 -10
- testgen/generator/unit_test_generator.py +2 -11
- testgen/main.py +1 -3
- testgen/models/coverage_data.py +56 -0
- testgen/models/db_test_case.py +65 -0
- testgen/models/function.py +56 -0
- testgen/models/function_metadata.py +11 -1
- testgen/models/generator_context.py +30 -3
- testgen/models/source_file.py +29 -0
- testgen/models/test_result.py +38 -0
- testgen/models/test_suite.py +20 -0
- testgen/reinforcement/agent.py +1 -27
- testgen/reinforcement/environment.py +11 -93
- testgen/reinforcement/statement_coverage_state.py +5 -4
- testgen/service/analysis_service.py +31 -22
- testgen/service/cfg_service.py +3 -1
- testgen/service/coverage_service.py +115 -0
- testgen/service/db_service.py +140 -0
- testgen/service/generator_service.py +77 -20
- testgen/service/logging_service.py +2 -2
- testgen/service/service.py +62 -231
- testgen/service/test_executor_service.py +145 -0
- testgen/util/coverage_utils.py +38 -116
- testgen/util/coverage_visualizer.py +10 -9
- testgen/util/file_utils.py +10 -111
- testgen/util/randomizer.py +0 -26
- testgen/util/utils.py +197 -38
- {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.9.dist-info}/METADATA +1 -1
- testgenie_py-0.3.9.dist-info/RECORD +72 -0
- testgen/inspector/inspector.py +0 -59
- testgen/presentation/__init__.py +0 -0
- testgen/presentation/cli_view.py +0 -12
- testgen/sqlite/__init__.py +0 -0
- testgen/sqlite/db_service.py +0 -239
- testgen/testgen.db +0 -0
- testgenie_py-0.3.7.dist-info/RECORD +0 -67
- /testgen/{inspector → db}/__init__.py +0 -0
- {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.9.dist-info}/WHEEL +0 -0
- {testgenie_py-0.3.7.dist-info → testgenie_py-0.3.9.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,115 @@
|
|
1
|
+
import os
|
2
|
+
import subprocess
|
3
|
+
import time
|
4
|
+
|
5
|
+
from testgen.models.coverage_data import CoverageData
|
6
|
+
from testgen.service.logging_service import get_logger
|
7
|
+
from testgen.service.db_service import DBService
|
8
|
+
|
9
|
+
class CoverageService:
    """Measure how much of a source file a test file exercises, using coverage.py.

    The integer ``*_FORMAT`` class constants identify test formats
    (1 = unittest, 2 = pytest, 3 = doctest).
    """

    UNITTEST_FORMAT = 1
    PYTEST_FORMAT = 2
    DOCTEST_FORMAT = 3

    def __init__(self):
        self.logger = get_logger()

    def run_coverage(self, test_file: str, source_file: str) -> CoverageData:
        """Run ``test_file`` under coverage and return the data for ``source_file``.

        Args:
            test_file: Path of the test script to execute.
            source_file: Path of the source file whose coverage is reported.

        Returns:
            The ``CoverageData`` produced by :meth:`collect_coverage`.

        Raises:
            RuntimeError: If anything unexpected fails while collecting coverage;
                the original exception is chained as ``__cause__``.
        """
        self.logger.info(f"Running coverage on test file: {test_file}")
        self.logger.info(f"Source file to measure: {source_file}")

        try:
            return self.collect_coverage(test_file, source_file)
        except Exception as e:
            self.logger.error(f"Error running coverage: {str(e)}")
            # Chain the cause so the real traceback is not lost.
            raise RuntimeError(f"Error running coverage subprocess: {e}") from e

    def collect_coverage(self, test_file: str, source_file: str) -> CoverageData:
        """Execute the coverage.py CLI and parse its text report.

        Three sub-commands run in the current working directory:
        ``coverage run --source=. <test_file>``, ``coverage report <source_file>``
        (whose stdout is parsed) and ``coverage json`` (writes ``coverage.json``).

        A non-zero exit from ``coverage run`` does not abort collection. coverage.py
        propagates the test script's exit status, so it is non-zero whenever a test
        fails, and the coverage data gathered is still valid.

        Returns:
            Parsed coverage data. If ``coverage report`` or ``coverage json`` fails,
            an all-zero ``CoverageData`` with ``source_file_id=-1`` is returned.
        """
        import sys

        # Use the interpreter running this process so coverage sees the same
        # environment (virtualenv, installed packages). A bare "python" may
        # resolve to a different interpreter, or not exist at all.
        python = sys.executable or "python"
        try:
            run = subprocess.run(
                [python, "-m", "coverage", "run", "--source=.", test_file],
                check=False,
            )
            if run.returncode != 0:
                self.logger.info(
                    f"Test run exited with status {run.returncode}; "
                    "reporting the coverage that was collected"
                )
            result = subprocess.run(
                [python, "-m", "coverage", "report", source_file],
                check=True,
                capture_output=True,
                text=True,
            )
            coverage_output = result.stdout
            subprocess.run([python, "-m", "coverage", "json"], check=True)
            return self.parse_coverage_data(coverage_output, source_file)
        except subprocess.CalledProcessError as e:
            self.logger.error(f"Error collecting coverage: {str(e)}")
            return CoverageData(
                coverage_type="file",
                executed_lines=0,
                missed_lines=0,
                branch_coverage=0.0,
                source_file_id=-1,
                function_id=None,
            )

    @staticmethod
    def _row_matches_file(line: str, file_path: str) -> bool:
        """Return True if a ``coverage report`` row refers to ``file_path``."""
        if file_path in line:
            return True
        parts = line.split()
        if not parts:
            return False
        # coverage prints names relative to the CWD, so an absolute (or
        # differently spelled) file_path would miss the substring test above.
        row_path = os.path.normcase(os.path.abspath(parts[0]))
        return row_path == os.path.normcase(os.path.abspath(file_path))

    def parse_coverage_data(self, coverage_output: str, file_path: str) -> CoverageData:
        """Parse the text table printed by ``coverage report``.

        Assumes the default, non-branch layout ``Name  Stmts  Miss  Cover``.
        The first row that refers to ``file_path`` and parses cleanly supplies
        the numbers. The ``Cover`` percentage is stored in ``branch_coverage``
        as a fraction in the range 0.0-1.0.

        Args:
            coverage_output: stdout of ``coverage report``.
            file_path: Source file whose row should be extracted.

        Returns:
            ``CoverageData`` with ``source_file_id=-1``. All counts are zero
            when no row matched.
        """
        executed_lines = missed_lines = 0
        branch_coverage = 0.0

        for line in coverage_output.strip().split('\n'):
            if not self._row_matches_file(line, file_path):
                continue
            parts = line.split()
            if len(parts) < 4:
                continue
            try:
                total = int(parts[-3])
                missed = int(parts[-2])
                percent = float(parts[-1].strip('%'))
            except (ValueError, IndexError) as e:
                self.logger.error(f"Error parsing coverage data: {e}")
                continue
            # Commit only after the whole row parsed, so a malformed row
            # cannot leave a half-updated result behind.
            executed_lines = total - missed
            missed_lines = missed
            branch_coverage = percent / 100
            break

        # The source file is not looked up in the database here.
        source_file_id = -1

        return CoverageData(
            coverage_type="file",
            executed_lines=executed_lines,
            missed_lines=missed_lines,
            branch_coverage=branch_coverage,
            source_file_id=source_file_id,
            function_id=None,
        )

    # NOTE: save_coverage_data is disabled ("currently not working" per the
    # original author). It is kept here as a comment for reference rather than
    # as a stray string expression in the class body.
    #
    # def save_coverage_data(self, db_service: DBService, coverage_data: CoverageData, file_path: str) -> None:
    #     if db_service is None:
    #         self.logger.debug("Skipping database operations - no DB service provided")
    #         return
    #
    #     try:
    #         source_file_id = db_service.get_source_file_id_by_path(file_path)
    #         if source_file_id == -1:
    #             self.logger.error(f"Source file not found in database: {file_path}")
    #             return
    #
    #         db_service.insert_coverage_data(
    #             file_name=file_path,
    #             executed_lines=coverage_data.executed_lines,
    #             missed_lines=coverage_data.missed_lines,
    #             branch_coverage=coverage_data.branch_coverage,
    #             source_file_id=source_file_id
    #         )
    #     except Exception as e:
    #         self.logger.error(f"Error saving coverage data to database: {e}")

    @staticmethod
    def wait_for_file(file_path, retries=5, delay=1):
        """Wait for the generated file to appear.

        Polls up to ``retries`` times, sleeping ``delay`` seconds between checks.

        Raises:
            FileNotFoundError: If the file still does not exist afterwards.
        """
        while retries > 0 and not os.path.exists(file_path):
            time.sleep(delay)
            retries -= 1
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File '{file_path}' not found after waiting.")
|
@@ -0,0 +1,140 @@
|
|
1
|
+
import os
|
2
|
+
from datetime import datetime
|
3
|
+
from types import ModuleType
|
4
|
+
from typing import List, Dict, Any
|
5
|
+
import testgen.util.utils as utils
|
6
|
+
from testgen.models.db_test_case import DBTestCase
|
7
|
+
from testgen.models.function import Function
|
8
|
+
from testgen.models.source_file import SourceFile
|
9
|
+
from testgen.models.test_case import TestCase
|
10
|
+
|
11
|
+
from testgen.db.dao_impl import DaoImpl
|
12
|
+
|
13
|
+
class DBService:
    """Persistence facade over ``DaoImpl``.

    Most methods are one-line pass-throughs to the DAO. The higher-level
    helpers (``save_test_generation_data``, ``_get_test_results``) gather data
    from disk, test cases and execution results, and insert them in
    dependency order.
    """

    def __init__(self, db_name="testgen.db"):
        self.dao = DaoImpl(db_name)

    def close(self):
        """Close the database connection."""
        self.dao.close()

    def insert_test_suite(self, name: str) -> int:
        return self.dao.insert_test_suite(name)

    def insert_source_file(self, path: str, lines_of_code: int, last_modified) -> int:
        return self.dao.insert_source_file(path, lines_of_code, last_modified)

    def insert_function(self, name: str, params, start_line: int, end_line: int, source_file_id: int) -> int:
        return self.dao.insert_function(name, params, start_line, end_line, source_file_id)

    def insert_test_case(self, test_case: TestCase, test_suite_id: int, function_id: int, test_method_type: int) -> int:
        return self.dao.insert_test_case(test_case, test_suite_id, function_id, test_method_type)

    def insert_test_result(self, test_case_id: int, status: bool, error: str | None = None) -> int:
        return self.dao.insert_test_result(test_case_id, status, error)

    def insert_coverage_data(self, file_name: str, executed_lines: int, missed_lines: int,
                             branch_coverage: float, source_file_id: int, function_id=None) -> int:
        """Insert a coverage row via the DAO.

        ``function_id`` is now forwarded. It was previously replaced by ``None``
        unconditionally, so the argument was silently ignored.
        """
        return self.dao.insert_coverage_data(file_name, executed_lines, missed_lines, branch_coverage,
                                             source_file_id, function_id=function_id)

    def get_test_suites(self):
        return self.dao.get_test_suites()

    def get_source_file_id_by_path(self, filepath: str) -> int:
        return self.dao.get_source_file_id_by_path(filepath)

    def get_test_suite_id_by_name(self, name: str) -> int:
        return self.dao.get_test_suite_id_by_name(name)

    def get_functions_by_file(self, filepath: str) -> List[Function]:
        return self.dao.get_functions_by_file(filepath)

    def get_function_id_by_name_file_id_start(self, name: str, source_file_id: int, start_line: int) -> int:
        return self.dao.get_function_by_name_file_id_start(name, source_file_id, start_line)

    def get_test_cases_by_function(self, function_name):
        return self.dao.get_test_cases_by_function(function_name)

    def get_coverage_by_file(self, file_path):
        return self.dao.get_coverage_by_file(file_path)

    def get_test_file_data(self, file_path: str):
        return self.dao.get_test_file_data(file_path)

    def get_test_case_id_by_func_id_input_expected(self, function_id: int, inputs: str, expected: str) -> int:
        return self.dao.get_test_case_id_by_func_id_input_expected(function_id, inputs, expected)

    def save_test_generation_data(self, file_path: str, test_cases: list, test_strategy: int, module: ModuleType, class_name: str | None):
        """Save test generation data to the database.

        Inserts, in order: the source file, a test suite (named after
        ``class_name``, or the module name if none), every function found in
        ``file_path``, and one row per test case linked to its matched function.
        """
        source_file_data = self._get_source_file_data(file_path)
        source_file_id = self.insert_source_file(source_file_data.path, source_file_data.lines_of_code, source_file_data.last_modified)

        test_suite_name = class_name if class_name else module.__name__
        test_suite_id = self.insert_test_suite(test_suite_name)

        # Functions must be inserted before test cases, because matching a test
        # case to a function id queries the database.
        function_data = self._get_function_data(file_path)
        for function in function_data:
            self.insert_function(function.name, function.params, function.start_line, function.end_line, source_file_id)

        test_cases_data = self._get_test_cases_data(source_file_id, test_suite_id, function_data, test_cases, test_strategy)
        for test_case in test_cases_data:
            self.insert_test_case(TestCase(test_case.test_function, test_case.inputs, test_case.expected_output), test_case.test_suite_id, test_case.function_id, test_strategy)

    def _get_source_file_data(self, file_path: str) -> SourceFile:
        """Return path, line count and mtime of ``file_path`` as a ``SourceFile``."""
        # Close the handle deterministically (the original leaked it). Decode as
        # UTF-8 with replacement so a stray byte cannot abort line counting.
        with open(file_path, 'r', encoding='utf-8', errors='replace') as source:
            lines_of_code = sum(1 for _ in source)
        last_modified_time = os.path.getmtime(file_path)
        return SourceFile(file_path, lines_of_code, last_modified_time)

    def _get_function_data(self, file_path: str) -> List[Function]:
        return utils.get_list_of_functions(file_path)

    def _get_test_cases_data(self, source_file_id: int, test_suite_id: int, function_data: List[Function], test_cases: List[TestCase], test_strategy: int) -> List[DBTestCase]:
        """Convert ``TestCase`` objects into ``DBTestCase`` rows with resolved function ids."""
        db_test_cases = []
        for test_case in test_cases:
            function_id = self.match_test_case_to_function_for_id(source_file_id, test_case, function_data)
            db_test_case = DBTestCase(test_case.expected, test_case.inputs, test_case.func_name, datetime.now(), test_strategy, test_suite_id, function_id)
            db_test_cases.append(db_test_case)
        return db_test_cases

    @staticmethod
    def _count_params(params) -> int | None:
        """Return how many parameters ``params`` describes.

        ``params`` is either a dict or its ``str`` form as read back from the
        database. The string is parsed with ``ast`` instead of ``eval``, so
        stored text is never executed.

        Returns:
            The number of dict keys; 0 if ``params`` is not a dict (or a dict
            literal); ``None`` if the string is not valid Python syntax.
        """
        import ast

        if isinstance(params, str):
            try:
                node = ast.parse(params, mode="eval").body
            except (SyntaxError, ValueError):
                return None
            return len(node.keys) if isinstance(node, ast.Dict) else 0
        return len(params) if isinstance(params, dict) else 0

    def match_test_case_to_function_for_id(self, source_file_id: int, test_case: TestCase, functions: List[Function]) -> int:
        """Find the database id of the function a test case exercises.

        Candidates are chosen by name. Exact or class-qualified (``X.name``)
        matches are preferred; the looser suffix match is only a fallback. If
        several candidates remain, the first whose parameter count equals the
        input count wins.

        Returns:
            The function id from the DAO, or -1 if no function matches.
        """
        func_name = test_case.func_name
        if "." in func_name:
            func_name = func_name.split(".")[-1]

        # A bare suffix test also matches unrelated names (e.g. "safe_add" for
        # "add"), so try exact / qualified names first.
        candidate_functions = [
            f for f in functions if f.name == func_name or f.name.endswith("." + func_name)
        ]
        if not candidate_functions:
            candidate_functions = [f for f in functions if f.name.endswith(func_name)]

        if not candidate_functions:
            return -1

        if len(candidate_functions) == 1:
            only = candidate_functions[0]
            return self.get_function_id_by_name_file_id_start(only.name, source_file_id, only.start_line)

        # Match by parameter count.
        # NOTE(review): non-dict inputs (e.g. a tuple of positional args) count as
        # a single parameter - confirm this matches the shape of TestCase.inputs.
        input_param_count = len(test_case.inputs) if isinstance(test_case.inputs, dict) else 1

        for function in candidate_functions:
            param_count = self._count_params(function.params)
            if param_count is None:
                continue
            if param_count == input_param_count:
                return self.get_function_id_by_name_file_id_start(function.name, source_file_id, function.start_line)

        return -1

    def _get_test_results(self, filepath: str, execution_results, test_format: int):
        """Insert one test-result row per execution result for ``filepath``."""
        source_file_id = self.get_source_file_id_by_path(filepath)
        functions = self.get_functions_by_file(filepath)
        for result in execution_results:
            name = result.name
            test_case = utils.parse_test_case_from_result_name(name, test_format)
            function_id = self.match_test_case_to_function_for_id(source_file_id, test_case, functions)
            test_case_id = self.get_test_case_id_by_func_id_input_expected(function_id, str(test_case.inputs), test_case.expected)
            self.insert_test_result(test_case_id, result.status, result.error)
|
@@ -8,11 +8,11 @@ from testgen.generator.code_generator import CodeGenerator
|
|
8
8
|
from testgen.generator.doctest_generator import DocTestGenerator
|
9
9
|
from testgen.generator.pytest_generator import PyTestGenerator
|
10
10
|
from testgen.generator.unit_test_generator import UnitTestGenerator
|
11
|
-
from testgen.inspector.inspector import Inspector
|
12
11
|
from testgen.service.logging_service import get_logger
|
13
12
|
from testgen.tree.node import Node
|
14
13
|
from testgen.tree.tree_utils import build_binary_tree
|
15
14
|
from testgen.models.generator_context import GeneratorContext
|
15
|
+
from testgen.util.file_utils import find_project_root
|
16
16
|
|
17
17
|
# Constants for test formats
|
18
18
|
UNITTEST_FORMAT = 1
|
@@ -49,23 +49,9 @@ class GeneratorService:
|
|
49
49
|
filename = self.get_filename(self.filepath)
|
50
50
|
output_path = self.get_test_file_path(module.__name__, output_path)
|
51
51
|
|
52
|
-
|
53
|
-
actual_class_name = class_name
|
54
|
-
if 'generated_' in self.filepath and class_name:
|
55
|
-
# For generated classes, find the actual class name in the module
|
56
|
-
for name, obj in inspect.getmembers(module):
|
57
|
-
if inspect.isclass(obj):
|
58
|
-
actual_class_name = name
|
59
|
-
break
|
52
|
+
actual_class_name = self.resolve_class_name(module, class_name)
|
60
53
|
|
61
|
-
context =
|
62
|
-
filepath=self.filepath,
|
63
|
-
filename=filename,
|
64
|
-
class_name=actual_class_name, # Use the actual class name
|
65
|
-
module=module,
|
66
|
-
output_path=output_path,
|
67
|
-
test_cases=test_cases
|
68
|
-
)
|
54
|
+
context = self.get_generator_context(self.filepath, module, actual_class_name, test_cases, output_path)
|
69
55
|
|
70
56
|
self.test_generator.generator_context = context
|
71
57
|
|
@@ -73,7 +59,6 @@ class GeneratorService:
|
|
73
59
|
|
74
60
|
self.generate_function_tests(test_cases)
|
75
61
|
|
76
|
-
|
77
62
|
if self.test_format == DOCTEST_FORMAT:
|
78
63
|
self.logger.debug("SAVING DOCT TEST FILE")
|
79
64
|
self.test_generator.save_file()
|
@@ -172,17 +157,89 @@ class GeneratorService:
|
|
172
157
|
|
173
158
|
return self.generated_file_path
|
174
159
|
|
160
|
+
def get_generator_context(self, filepath: str, module: ModuleType, class_name: str | None, test_cases: List[TestCase], output_path: str) -> GeneratorContext:
    """Build the ``GeneratorContext`` handed to the active test generator.

    Besides the values passed in, it records whether ``filepath``'s directory
    contains an ``__init__.py`` (``is_package``). It also records the package
    name and dotted import path derived by ``_resolve_package_name_and_import_path``.

    Raises:
        ValueError: Propagated when ``filepath`` does not exist or is not a
            ``.py`` file.
    """
    file_dir = os.path.dirname(os.path.abspath(filepath))
    is_package = os.path.exists(os.path.join(file_dir, '__init__.py'))

    package_name, import_path = self._resolve_package_name_and_import_path(filepath)

    generator_context = GeneratorContext(
        filepath=filepath,
        filename=self.get_filename(filepath),
        class_name=class_name,
        module=module,
        output_path=output_path,
        test_cases=test_cases,
        is_package=is_package,
        package_name=package_name,
        import_path=import_path
    )

    # Diagnostics go through the service logger rather than a bare print(),
    # so they don't pollute the CLI's stdout.
    self.logger.debug(f"Generator Context: {generator_context.filepath} {generator_context.filename} {generator_context.class_name} {generator_context.is_package} {generator_context.package_name} {generator_context.import_path}")

    return generator_context
|
181
|
+
|
182
|
+
@staticmethod
|
183
|
+
def _resolve_package_name_and_import_path(filepath: str) -> tuple:
|
184
|
+
if not os.path.exists(filepath) or not filepath.endswith('.py'):
|
185
|
+
raise ValueError(f"Invalid Python file: {filepath}")
|
186
|
+
|
187
|
+
# Get the directory and filename
|
188
|
+
file_dir = os.path.dirname(os.path.abspath(filepath))
|
189
|
+
module_name = os.path.splitext(os.path.basename(filepath))[0]
|
190
|
+
|
191
|
+
# Find the project root by looking for setup.py or a .git directory
|
192
|
+
project_root = find_project_root(file_dir)
|
193
|
+
|
194
|
+
# Build the import path based on the file's location relative to the project root
|
195
|
+
if project_root:
|
196
|
+
rel_path = os.path.relpath(file_dir, project_root)
|
197
|
+
if rel_path == '.':
|
198
|
+
# File is directly in the project root
|
199
|
+
import_path = module_name
|
200
|
+
package_name = ''
|
201
|
+
else:
|
202
|
+
# File is in a subdirectory
|
203
|
+
path_parts = rel_path.replace('\\', '/').split('/')
|
204
|
+
# Filter out any empty parts
|
205
|
+
path_parts = [part for part in path_parts if part]
|
206
|
+
|
207
|
+
if path_parts:
|
208
|
+
package_name = path_parts[0]
|
209
|
+
# Construct the full import path
|
210
|
+
import_path = '.'.join(path_parts) + '.' + module_name
|
211
|
+
else:
|
212
|
+
package_name = ''
|
213
|
+
import_path = module_name
|
214
|
+
else:
|
215
|
+
# Fallback if we can't find a project root
|
216
|
+
package_name = ''
|
217
|
+
import_path = module_name
|
218
|
+
|
219
|
+
return package_name, import_path
|
220
|
+
|
175
221
|
def build_func_trees(self, functions: list):
    """Build a binary tree for each function signature.

    Args:
        functions: Iterable of ``(name, callable)`` pairs; the name is unused.

    Returns:
        A list of ``(callable, root_node, param_names)`` tuples. ``param_names``
        lists the signature's parameters in order, excluding one named ``self``.
    """
    trees = []
    for _name, func in functions:
        sig = inspect.signature(func)
        param_names = [p for p in sig.parameters if p != 'self']
        root = Node(None)
        build_binary_tree(root, 0, len(param_names))
        trees.append((func, root, param_names))
    return trees
|
185
231
|
|
232
|
+
def resolve_class_name(self, module: ModuleType, class_name: str):
|
233
|
+
# Determine the actual class name used in the module
|
234
|
+
actual_class_name = class_name
|
235
|
+
if 'generated_' in self.filepath and class_name:
|
236
|
+
# For generated classes, find the actual class name in the module
|
237
|
+
for name, obj in inspect.getmembers(module):
|
238
|
+
if inspect.isclass(obj):
|
239
|
+
actual_class_name = name
|
240
|
+
break
|
241
|
+
return actual_class_name
|
242
|
+
|
186
243
|
@staticmethod
|
187
244
|
def get_filename(filepath: str) -> str:
|
188
245
|
"""Get filename from filepath."""
|
@@ -96,7 +96,7 @@ def get_logger():
|
|
96
96
|
logger = LoggingService.get_instance()
|
97
97
|
|
98
98
|
# If logger hasn't been initialized yet, set up a basic configuration
|
99
|
-
if not LoggingService._initialized:
|
100
|
-
logger.initialize(debug_mode=False, console_output=True)
|
99
|
+
"""if not LoggingService._initialized:
|
100
|
+
logger.initialize(debug_mode=False, console_output=True)"""
|
101
101
|
|
102
102
|
return logger
|