testgenie-py 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- testgen/analyzer/random_feedback_analyzer.py +3 -3
- testgen/controller/cli_controller.py +119 -99
- testgen/controller/docker_controller.py +34 -19
- testgen/docker/Dockerfile +2 -2
- testgen/docker/pyproject.toml +9 -2
- testgen/reinforcement/environment.py +42 -10
- testgen/reinforcement/statement_coverage_state.py +1 -1
- testgen/service/analysis_service.py +8 -2
- testgen/service/cfg_service.py +1 -1
- testgen/service/generator_service.py +11 -3
- testgen/service/logging_service.py +100 -0
- testgen/service/service.py +81 -41
- testgen/testgen.db +0 -0
- testgen/util/coverage_utils.py +41 -14
- testgen/util/coverage_visualizer.py +2 -2
- testgen/util/file_utils.py +46 -0
- testgen/util/randomizer.py +27 -12
- testgen/util/z3_utils/z3_test_case.py +26 -11
- {testgenie_py-0.1.6.dist-info → testgenie_py-0.1.8.dist-info}/METADATA +1 -1
- {testgenie_py-0.1.6.dist-info → testgenie_py-0.1.8.dist-info}/RECORD +22 -22
- testgen/docker/poetry.lock +0 -361
- testgen/q_table/global_q_table.json +0 -1
- {testgenie_py-0.1.6.dist-info → testgenie_py-0.1.8.dist-info}/WHEEL +0 -0
- {testgenie_py-0.1.6.dist-info → testgenie_py-0.1.8.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,100 @@
|
|
1
|
+
import logging
|
2
|
+
import os
|
3
|
+
import sys
|
4
|
+
from enum import Enum
|
5
|
+
from typing import Optional
|
6
|
+
|
7
|
+
|
8
|
+
class LogLevel(Enum):
    """Severity levels understood by the testgen logger.

    Each member's value is the matching integer constant from the standard
    ``logging`` module, so ``LogLevel.X.value`` can be passed directly to
    ``Logger.setLevel`` / ``Handler.setLevel``.
    """

    # Least to most severe, mirroring the stdlib constants one-to-one.
    DEBUG = logging.DEBUG
    INFO = logging.INFO
    WARNING = logging.WARNING
    ERROR = logging.ERROR
    CRITICAL = logging.CRITICAL
|
14
|
+
|
15
|
+
|
16
|
+
class LoggingService:
    """Centralized logging service for testgen framework.

    Wraps the standard-library logger named ``'testgen'`` and attaches an
    optional stdout handler and an optional file handler, both sharing one
    formatter. Obtain the shared instance via :meth:`get_instance`.

    :meth:`initialize` only takes effect once per process, because the
    ``_initialized`` flag lives on the class. A later call that asks for
    ``debug_mode=True`` is therefore ignored; use :meth:`set_debug_mode` to
    change the level after initialization.
    """

    # Lazily created singleton returned by get_instance().
    _instance = None
    # Class-wide flag: once True, initialize() is a no-op for every instance.
    _initialized = False

    @classmethod
    def get_instance(cls):
        """Get singleton instance of LoggingService, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        # logging.getLogger returns the same object for the same name, so every
        # LoggingService instance shares this underlying logger.
        self.logger = logging.getLogger('testgen')
        self.formatter = logging.Formatter(
            '%(asctime)s [%(levelname)s] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        self.file_handler = None
        self.console_handler = None

    def initialize(self,
                   debug_mode: bool = False,
                   log_file: Optional[str] = None,
                   console_output: bool = True):
        """Configure the level and handlers of the 'testgen' logger (first call only).

        Args:
            debug_mode: Log at DEBUG level when True, otherwise at INFO.
            log_file: Optional path of a file that log records are appended to.
                Missing parent directories are created. A bare file name is
                resolved against the current working directory.
            console_output: Attach a handler writing to ``sys.stdout`` when True.

        Raises:
            OSError: If the log directory or the log file cannot be created/opened.
        """
        if LoggingService._initialized:
            return

        level = LogLevel.DEBUG.value if debug_mode else LogLevel.INFO.value

        # Open the file handler before touching the logger. A failure here then
        # leaves the logger unchanged, and a retry cannot stack a duplicate
        # console handler.
        file_handler = None
        if log_file:
            log_dir = os.path.dirname(log_file)
            # dirname() is '' for a bare file name, and os.makedirs('') raises.
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
            file_handler = logging.FileHandler(log_file, encoding='utf-8')
            file_handler.setFormatter(self.formatter)
            file_handler.setLevel(level)

        # Set the base logging level
        self.logger.setLevel(level)

        # Add console handler if requested
        if console_output:
            self.console_handler = logging.StreamHandler(sys.stdout)
            self.console_handler.setFormatter(self.formatter)
            self.console_handler.setLevel(level)
            self.logger.addHandler(self.console_handler)

        # Add file handler if path provided
        if file_handler is not None:
            self.file_handler = file_handler
            self.logger.addHandler(file_handler)

        LoggingService._initialized = True
        self.info(f"Logging initialized - Debug mode: {debug_mode}")

    def set_debug_mode(self, debug_mode: bool) -> None:
        """Switch the logger and any handlers created here between DEBUG and INFO.

        Unlike :meth:`initialize`, this works after initialization has already
        happened, for example after ``get_logger()`` set up INFO defaults.
        """
        level = LogLevel.DEBUG.value if debug_mode else LogLevel.INFO.value
        self.logger.setLevel(level)
        for handler in (self.console_handler, self.file_handler):
            if handler is not None:
                handler.setLevel(level)

    def debug(self, message: str):
        """Log debug message"""
        self.logger.debug(message)

    def info(self, message: str):
        """Log info message"""
        self.logger.info(message)

    def warning(self, message: str):
        """Log warning message"""
        self.logger.warning(message)

    def error(self, message: str):
        """Log error message"""
        self.logger.error(message)

    def critical(self, message: str):
        """Log critical message"""
        self.logger.critical(message)
|
89
|
+
|
90
|
+
|
91
|
+
# Global accessor function for easy import and use
|
92
|
+
def get_logger():
    """Return the process-wide LoggingService, applying default setup on first use.

    If nothing has initialized logging yet, the service is configured with
    INFO level and console (stdout) output before being returned.
    """
    service = LoggingService.get_instance()

    if LoggingService._initialized:
        return service

    # First use without explicit configuration: fall back to defaults.
    service.initialize(debug_mode=False, console_output=True)
    return service
|
testgen/service/service.py
CHANGED
@@ -14,7 +14,8 @@ from testgen.service.analysis_service import AnalysisService
|
|
14
14
|
from testgen.service.generator_service import GeneratorService
|
15
15
|
from testgen.sqlite.db_service import DBService
|
16
16
|
from testgen.models.analysis_context import AnalysisContext
|
17
|
-
from testgen.
|
17
|
+
from testgen.service.logging_service import get_logger
|
18
|
+
|
18
19
|
|
19
20
|
# Constants for test strategies
|
20
21
|
AST_STRAT = 1
|
@@ -29,26 +30,32 @@ DOCTEST_FORMAT = 3
|
|
29
30
|
|
30
31
|
class Service:
|
31
32
|
def __init__(self):
|
33
|
+
self.debug_mode: bool = False
|
32
34
|
self.test_strategy: int = 0
|
33
35
|
self.test_format: int = 0
|
34
36
|
self.file_path = None
|
35
37
|
self.generated_file_path = None
|
36
38
|
self.class_name = None
|
37
39
|
self.test_cases = []
|
40
|
+
self.logger = get_logger()
|
38
41
|
self.reinforcement_mode = "train"
|
39
42
|
|
40
43
|
# Initialize specialized services
|
41
44
|
self.analysis_service = AnalysisService()
|
42
45
|
self.generator_service = GeneratorService(None, None, None)
|
43
|
-
|
46
|
+
# Only initialize DB service if not running in Docker
|
47
|
+
if os.environ.get("RUNNING_IN_DOCKER") is None:
|
48
|
+
self.db_service = DBService()
|
49
|
+
else:
|
50
|
+
# Create a dummy DB service that doesn't do anything
|
51
|
+
self.db_service = None
|
44
52
|
|
45
53
|
def select_all_from_db(self) -> None:
|
46
54
|
rows = self.db_service.get_test_suites()
|
47
55
|
for row in rows:
|
48
56
|
print(repr(dict(row)))
|
49
57
|
|
50
|
-
def
|
51
|
-
"""Generate tests for a class or module."""
|
58
|
+
def generate_test_cases(self) -> List[TestCase] | None:
|
52
59
|
module = file_utils.load_module(self.file_path)
|
53
60
|
class_name = self.analysis_service.get_class_name(module)
|
54
61
|
|
@@ -65,24 +72,38 @@ class Service:
|
|
65
72
|
|
66
73
|
test_cases: List[TestCase] = []
|
67
74
|
if self.test_strategy == REINFORCE_STRAT:
|
68
|
-
test_cases = self.analysis_service.do_reinforcement_learning(self.file_path, self.reinforcement_mode)
|
75
|
+
test_cases = self.analysis_service.do_reinforcement_learning(self.file_path, class_name, self.reinforcement_mode)
|
69
76
|
else:
|
70
77
|
test_cases = self.analysis_service.generate_test_cases()
|
71
78
|
|
72
|
-
self.test_cases = test_cases
|
73
|
-
|
74
|
-
file_path_to_use = self.generated_file_path if self.test_strategy == AST_STRAT else self.file_path
|
75
|
-
self.db_service.save_test_generation_data(
|
76
|
-
file_path_to_use,
|
77
|
-
test_cases,
|
78
|
-
self.test_strategy,
|
79
|
-
class_name
|
80
|
-
)
|
81
|
-
|
82
79
|
if os.environ.get("RUNNING_IN_DOCKER") is not None:
|
83
|
-
|
80
|
+
self.debug(f"Serializing test cases {test_cases}")
|
84
81
|
self.serialize_test_cases(test_cases)
|
85
82
|
return None # Exit early in analysis-only mode
|
83
|
+
|
84
|
+
return test_cases
|
85
|
+
|
86
|
+
def generate_tests(self, output_path=None):
|
87
|
+
module = file_utils.load_module(self.file_path)
|
88
|
+
class_name = self.analysis_service.get_class_name(module)
|
89
|
+
|
90
|
+
test_cases = self.generate_test_cases()
|
91
|
+
|
92
|
+
# Only process if we have test cases
|
93
|
+
if test_cases is None:
|
94
|
+
return None
|
95
|
+
|
96
|
+
self.test_cases = test_cases
|
97
|
+
|
98
|
+
# Only save to DB if not running in Docker
|
99
|
+
if os.environ.get("RUNNING_IN_DOCKER") is None:
|
100
|
+
file_path_to_use = self.generated_file_path if self.test_strategy == AST_STRAT else self.file_path
|
101
|
+
self.db_service.save_test_generation_data(
|
102
|
+
file_path_to_use,
|
103
|
+
test_cases,
|
104
|
+
self.test_strategy,
|
105
|
+
class_name
|
106
|
+
)
|
86
107
|
|
87
108
|
test_file = self.generate_test_file(test_cases, output_path, module, class_name)
|
88
109
|
|
@@ -112,6 +133,8 @@ class Service:
|
|
112
133
|
output_path
|
113
134
|
)
|
114
135
|
|
136
|
+
print(f"Generated test file: {test_file}")
|
137
|
+
|
115
138
|
# Ensure the test file is ready
|
116
139
|
Service.wait_for_file(test_file)
|
117
140
|
return test_file
|
@@ -127,6 +150,7 @@ class Service:
|
|
127
150
|
"""Run coverage analysis on the generated tests."""
|
128
151
|
Service.wait_for_file(test_file)
|
129
152
|
file_path_to_use = self.generated_file_path if self.test_strategy == AST_STRAT else self.file_path
|
153
|
+
print(f"File path to use for coverage: {file_path_to_use}")
|
130
154
|
coverage_output = ""
|
131
155
|
|
132
156
|
try:
|
@@ -166,32 +190,40 @@ class Service:
|
|
166
190
|
|
167
191
|
def _save_coverage_data(self, coverage_output, file_path):
|
168
192
|
"""Parse coverage output and save to database."""
|
193
|
+
# Skip if running in Docker or DB service is None
|
194
|
+
if os.environ.get("RUNNING_IN_DOCKER") is not None or self.db_service is None:
|
195
|
+
self.debug("Skipping database operations in Docker container")
|
196
|
+
return
|
197
|
+
|
169
198
|
try:
|
170
199
|
lines = coverage_output.strip().split('\n')
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
200
|
+
if not lines:
|
201
|
+
raise ValueError("No coverage data found in the output.")
|
202
|
+
else:
|
203
|
+
for line in lines:
|
204
|
+
if file_path in line:
|
205
|
+
parts = line.split()
|
206
|
+
if len(parts) >= 4:
|
207
|
+
file_name = os.path.basename(file_path)
|
208
|
+
try:
|
209
|
+
total_lines = int(parts[-3])
|
210
|
+
missed_lines = int(parts[-2])
|
211
|
+
executed_lines = total_lines - missed_lines
|
212
|
+
coverage_str = parts[-1].strip('%')
|
213
|
+
branch_coverage = float(coverage_str) / 100
|
214
|
+
|
215
|
+
source_file_id = self._get_source_file_id(file_path)
|
216
|
+
|
217
|
+
self.db_service.insert_coverage_data(
|
218
|
+
file_name,
|
219
|
+
executed_lines,
|
220
|
+
missed_lines,
|
221
|
+
branch_coverage,
|
222
|
+
source_file_id
|
223
|
+
)
|
224
|
+
break
|
225
|
+
except (ValueError, IndexError) as e:
|
226
|
+
print(f"Error parsing coverage data: {e}")
|
195
227
|
except Exception as e:
|
196
228
|
print(f"Error saving coverage data: {e}")
|
197
229
|
|
@@ -261,6 +293,7 @@ class Service:
|
|
261
293
|
|
262
294
|
def set_file_path(self, path: str):
|
263
295
|
"""Set the file path for analysis and validate it."""
|
296
|
+
print(f"Setting file path: {path}")
|
264
297
|
if os.path.isfile(path) and path.endswith(".py"):
|
265
298
|
self.file_path = path
|
266
299
|
self.analysis_service.set_file_path(path)
|
@@ -386,4 +419,11 @@ class Service:
|
|
386
419
|
def set_reinforcement_mode(self, mode: str):
|
387
420
|
self.reinforcement_mode = mode
|
388
421
|
if hasattr(self ,'analysis_service'):
|
389
|
-
self.analysis_service.set_reinforcement_mode(mode)
|
422
|
+
self.analysis_service.set_reinforcement_mode(mode)
|
423
|
+
|
424
|
+
def set_debug_mode(self, debug: bool):
|
425
|
+
self.debug_mode = debug
|
426
|
+
|
427
|
+
def debug(self, message: str):
|
428
|
+
if self.debug_mode:
|
429
|
+
self.logger.debug(message)
|
testgen/testgen.db
ADDED
Binary file
|
testgen/util/coverage_utils.py
CHANGED
@@ -24,12 +24,39 @@ def get_branch_coverage(file_name, func, *args) -> list:
|
|
24
24
|
return branches
|
25
25
|
|
26
26
|
|
27
|
-
def get_coverage_analysis(file_name, func_name, args) -> tuple:
|
27
|
+
def get_coverage_analysis(file_name, class_name: str | None, func_name, args) -> tuple:
|
28
28
|
tree = load_and_parse_file_for_tree(file_name)
|
29
29
|
func_node = None
|
30
30
|
func_start = None
|
31
31
|
func_end = None
|
32
|
+
|
33
|
+
# Process tree body
|
32
34
|
for i, node in enumerate(tree.body):
|
35
|
+
# Handle class methods
|
36
|
+
if isinstance(node, ast.ClassDef) and class_name is not None:
|
37
|
+
# Search within class body with its own index
|
38
|
+
for j, class_node in enumerate(node.body):
|
39
|
+
if isinstance(class_node, ast.FunctionDef) and class_node.name == func_name:
|
40
|
+
func_node = class_node
|
41
|
+
func_start = class_node.lineno
|
42
|
+
|
43
|
+
# Now correctly check if this is the last method in the class
|
44
|
+
if j == len(node.body) - 1:
|
45
|
+
# Last method in class - find maximum line in method
|
46
|
+
max_lines = [line.lineno for line in ast.walk(class_node)
|
47
|
+
if hasattr(line, 'lineno') and line.lineno]
|
48
|
+
func_end = max(max_lines) if max_lines else func_start
|
49
|
+
else:
|
50
|
+
# Not last method - use next method's line number minus 1
|
51
|
+
next_node = node.body[j + 1] # Correct index now
|
52
|
+
if hasattr(next_node, 'lineno'):
|
53
|
+
func_end = next_node.lineno - 1
|
54
|
+
else:
|
55
|
+
# Fallback using max line in method
|
56
|
+
max_lines = [line.lineno for line in ast.walk(class_node)
|
57
|
+
if hasattr(line, 'lineno') and line.lineno]
|
58
|
+
func_end = max(max_lines) if max_lines else func_start
|
59
|
+
break
|
33
60
|
if isinstance(node, ast.FunctionDef) and node.name == func_name:
|
34
61
|
func_node = node
|
35
62
|
func_start = node.lineno
|
@@ -55,7 +82,12 @@ def get_coverage_analysis(file_name, func_name, args) -> tuple:
|
|
55
82
|
cov.start()
|
56
83
|
module = load_module(file_name)
|
57
84
|
|
58
|
-
|
85
|
+
if class_name is not None:
|
86
|
+
class_obj = getattr(module, class_name)
|
87
|
+
instance = class_obj()
|
88
|
+
func = getattr(instance, func_name)
|
89
|
+
else:
|
90
|
+
func = getattr(module, func_name)
|
59
91
|
|
60
92
|
func(*args)
|
61
93
|
|
@@ -64,7 +96,6 @@ def get_coverage_analysis(file_name, func_name, args) -> tuple:
|
|
64
96
|
|
65
97
|
analysis = cov.analysis2(file_name)
|
66
98
|
analysis_list = list(analysis)
|
67
|
-
|
68
99
|
# Filter executable and missed lines to function range
|
69
100
|
analysis_list[1] = [line for line in analysis_list[1] if func_start <= line <= func_end]
|
70
101
|
analysis_list[3] = [line for line in analysis_list[3] if func_start <= line <= func_end]
|
@@ -113,21 +144,12 @@ def get_list_of_missed_lines(analysis: tuple) -> list:
|
|
113
144
|
|
114
145
|
|
115
146
|
def get_list_of_covered_statements(analysis: tuple) -> list:
|
116
|
-
# analysis[1] == list of total executable statements
|
117
|
-
# analysis[3] == list of missed statements
|
118
147
|
return [x for x in analysis[1] if x not in analysis[3]]
|
119
|
-
# total_statements = len(analysis[1]) # list of executable statements
|
120
|
-
# missed_statements = len(analysis[3]) # list of missed line numbers
|
121
|
-
|
122
|
-
# covered_statements = total_statements - missed_statements
|
123
148
|
|
124
149
|
|
125
|
-
def get_uncovered_lines_for_func(file_name: str, func_node: ast.FunctionDef, test_cases: List[TestCase]) -> List[int]:
|
150
|
+
def get_uncovered_lines_for_func(file_name: str, class_name: str | None, func_node: ast.FunctionDef, test_cases: List[TestCase]) -> List[int]:
|
126
151
|
# Get normal uncovered lines
|
127
152
|
func_name = func_node.name
|
128
|
-
if not test_cases:
|
129
|
-
print(f"Warning: No test cases provided {func_name}.")
|
130
|
-
return []
|
131
153
|
|
132
154
|
function_test_cases = [tc for tc in test_cases if tc.func_name == func_name]
|
133
155
|
if not function_test_cases:
|
@@ -135,7 +157,12 @@ def get_uncovered_lines_for_func(file_name: str, func_node: ast.FunctionDef, tes
|
|
135
157
|
return []
|
136
158
|
|
137
159
|
module = load_module(file_name)
|
138
|
-
|
160
|
+
if class_name is not None:
|
161
|
+
class_obj = getattr(module, class_name)
|
162
|
+
instance = class_obj()
|
163
|
+
func = getattr(instance, func_name)
|
164
|
+
else:
|
165
|
+
func = getattr(module, func_name)
|
139
166
|
|
140
167
|
# Run coverage
|
141
168
|
cov = coverage.Coverage(branch=True) # Enable branch coverage
|
@@ -16,12 +16,12 @@ class CoverageVisualizer:
|
|
16
16
|
def set_service(self, service):
|
17
17
|
self.service = service
|
18
18
|
|
19
|
-
def get_covered_lines(self, file_path: str, func_def: ast.FunctionDef, test_cases: List[TestCase]):
|
19
|
+
def get_covered_lines(self, file_path: str, class_name: str | None, func_def: ast.FunctionDef, test_cases: List[TestCase]):
|
20
20
|
if func_def.name not in self.covered_lines:
|
21
21
|
self.covered_lines[func_def.name] = set()
|
22
22
|
|
23
23
|
for test_case in [tc for tc in test_cases if tc.func_name == func_def.name]:
|
24
|
-
analysis = testgen.util.coverage_utils.get_coverage_analysis(file_path, func_def.name, test_case.inputs)
|
24
|
+
analysis = testgen.util.coverage_utils.get_coverage_analysis(file_path, class_name, func_def.name, test_case.inputs)
|
25
25
|
covered = testgen.util.coverage_utils.get_list_of_covered_statements(analysis)
|
26
26
|
if covered:
|
27
27
|
self.covered_lines[func_def.name].update(covered)
|
testgen/util/file_utils.py
CHANGED
@@ -108,3 +108,49 @@ def load_and_parse_file_for_tree(file) -> Module:
|
|
108
108
|
code = f.read()
|
109
109
|
tree = ast.parse(code)
|
110
110
|
return tree
|
111
|
+
|
112
|
+
def adjust_file_path_for_docker(file_path) -> str:
    """Adjust file path for Docker environment to handle subdirectories and relative paths.

    Candidates are tried in order, and the first one that is an existing
    regular file is returned:

    1. ``file_path`` as given (absolute, or relative to the CWD);
    2. ``file_path`` joined onto ``/controller`` (the mount root);
    3. ``file_path`` joined onto ``/controller/testgen``;
    4. for a bare file name only, the same name inside
       ``/controller/code_to_test`` and ``/controller/testgen/code_to_test``.

    If nothing matches, same-named files found anywhere under ``/controller``
    are printed to help diagnose the mismatch. The original path is then
    returned unchanged.
    """
    print(f"Docker - adjusting path: {file_path}")

    # Try direct path first (maybe it's correct already)
    if os.path.isfile(file_path):
        print(f"Docker - found file at direct path: {file_path}")
        return file_path

    # Try relative to the mount root, then relative to /controller/testgen.
    # (os.path.join discards the base when file_path is absolute.)
    for base_dir in ('/controller', '/controller/testgen'):
        candidate = os.path.join(base_dir, file_path)
        if os.path.isfile(candidate):
            print(f"Docker - found file at: {candidate}")
            return candidate

    # If it's just a filename, also search the code_to_test directories.
    # /controller and /controller/testgen were already tried above.
    if os.path.basename(file_path) == file_path:
        for search_dir in ('/controller/code_to_test', '/controller/testgen/code_to_test'):
            candidate = os.path.join(search_dir, file_path)
            if os.path.isfile(candidate):
                print(f"Docker - found file at: {candidate}")
                return candidate

    # Diagnostics: list files with the same name anywhere under the mount.
    # This replaces a leftover `find ... | grep boolean` shell-out that was
    # hard-coded to one specific file.
    target_name = os.path.basename(file_path)
    print(f"Docker - files named '{target_name}' under /controller:")
    for root, _dirs, files in os.walk('/controller'):
        if target_name in files:
            print(f"  {os.path.join(root, target_name)}")

    # Return original path if we couldn't find a better match
    print(f"Docker - couldn't find file, returning original: {file_path}")
    return file_path
|
148
|
+
|
149
|
+
def get_project_root_in_docker(script_path) -> str:
    """Return the parent of the directory containing the running script.

    Each intermediate path (script, its directory, the project root) is
    printed as it is derived.

    NOTE(review): ``script_path`` is currently ignored. The path is always
    taken from ``sys.argv[0]``. Confirm what callers pass before making the
    argument take effect.
    """
    path = os.path.abspath(sys.argv[0])
    # Walk up two levels: script file -> its directory -> project root.
    for label in ("Script path", "Script directory"):
        print(f"{label}: {path}")
        path = os.path.dirname(path)
    print(f"Project root directory: {path}")
    return path
|
testgen/util/randomizer.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
import ast
|
2
2
|
import random
|
3
|
+
import string
|
3
4
|
from typing import List
|
4
5
|
|
5
6
|
import testgen.util.coverage_utils
|
@@ -13,12 +14,12 @@ except ImportError:
|
|
13
14
|
solve_branch_condition = None
|
14
15
|
from testgen.models.test_case import TestCase
|
15
16
|
|
16
|
-
def make_random_move(file_name: str, func_node: ast.FunctionDef, test_cases: List[TestCase]) -> List[TestCase]:
|
17
|
+
def make_random_move(file_name: str, class_name: str | None, func_node: ast.FunctionDef, test_cases: List[TestCase]) -> List[TestCase]:
|
17
18
|
random_choice = random.randint(1, 4)
|
18
19
|
func_name = func_node.name
|
19
20
|
# new random test case
|
20
21
|
if random_choice == 1:
|
21
|
-
test_cases.append(new_random_test_case(file_name, func_node))
|
22
|
+
test_cases.append(new_random_test_case(file_name, class_name, func_node))
|
22
23
|
# combine test cases
|
23
24
|
if random_choice == 2:
|
24
25
|
test_cases.append(combine_cases(test_cases))
|
@@ -31,7 +32,7 @@ def make_random_move(file_name: str, func_node: ast.FunctionDef, test_cases: Lis
|
|
31
32
|
function_test_cases = [tc for tc in test_cases if tc.func_name == func_name]
|
32
33
|
|
33
34
|
if function_test_cases:
|
34
|
-
uncovered_lines = testgen.util.coverage_utils.get_uncovered_lines_for_func(file_name, func_name)
|
35
|
+
uncovered_lines = testgen.util.coverage_utils.get_uncovered_lines_for_func(file_name, class_name, func_name)
|
35
36
|
|
36
37
|
if len(uncovered_lines) > 0:
|
37
38
|
z3_test_cases = solve_branch_condition(file_name, func_node, uncovered_lines)
|
@@ -39,18 +40,32 @@ def make_random_move(file_name: str, func_node: ast.FunctionDef, test_cases: Lis
|
|
39
40
|
|
40
41
|
return test_cases
|
41
42
|
|
42
|
-
def new_random_test_case(file_name: str, func_node: ast.FunctionDef) -> TestCase:
|
43
|
+
def new_random_test_case(file_name: str, class_name: str | None, func_node: ast.FunctionDef) -> TestCase:
|
43
44
|
func_name = func_node.name
|
44
45
|
param_types: dict = utils.extract_parameter_types(func_node)
|
46
|
+
|
47
|
+
if class_name is not None and 'self' in param_types:
|
48
|
+
del param_types['self']
|
49
|
+
|
45
50
|
inputs: dict = utils.generate_random_inputs(param_types)
|
46
|
-
args = inputs.values()
|
51
|
+
args = list(inputs.values())
|
47
52
|
|
48
53
|
module = testgen.util.file_utils.load_module(file_name)
|
49
|
-
func = getattr(module, func_name)
|
50
54
|
|
51
|
-
|
55
|
+
try:
|
56
|
+
if class_name is not None:
|
57
|
+
class_obj = getattr(module, class_name)
|
58
|
+
instance = class_obj()
|
59
|
+
func = getattr(instance, func_name)
|
60
|
+
else:
|
61
|
+
func = getattr(module, func_name)
|
62
|
+
|
63
|
+
output = func(*args)
|
52
64
|
|
53
|
-
|
65
|
+
return TestCase(func_name, tuple(args), output)
|
66
|
+
except Exception as e:
|
67
|
+
print(f"Error generating test case for {func_name}: {e}")
|
68
|
+
raise
|
54
69
|
|
55
70
|
# Should combining test cases preserve the parent cases or entirely replace them?
|
56
71
|
def combine_cases(test_cases: List[TestCase]) -> TestCase:
|
@@ -87,25 +102,25 @@ def mix_inputs(test_case1: TestCase, test_case2: TestCase) -> tuple:
|
|
87
102
|
|
88
103
|
return new_inputs
|
89
104
|
|
90
|
-
def get_z3_test_cases(file_name: str, func_node: ast.FunctionDef, test_cases: List[TestCase]) -> List[TestCase]:
|
105
|
+
def get_z3_test_cases(file_name: str, class_name: str | None, func_node: ast.FunctionDef, test_cases: List[TestCase]) -> List[TestCase]:
|
91
106
|
func_name = func_node.name
|
92
107
|
|
93
108
|
# Filter test cases for this specific function
|
94
109
|
function_test_cases = [tc for tc in test_cases if tc.func_name == func_name]
|
95
110
|
|
96
111
|
if not function_test_cases:
|
97
|
-
initial_case = new_random_test_case(file_name, func_node)
|
112
|
+
initial_case = new_random_test_case(file_name, class_name, func_node)
|
98
113
|
test_cases.append(initial_case)
|
99
114
|
function_test_cases = [initial_case]
|
100
115
|
|
101
116
|
try:
|
102
117
|
# Get uncovered lines
|
103
|
-
uncovered_lines = testgen.util.coverage_utils.get_uncovered_lines_for_func(file_name, func_node, function_test_cases)
|
118
|
+
uncovered_lines = testgen.util.coverage_utils.get_uncovered_lines_for_func(file_name, class_name, func_node, function_test_cases)
|
104
119
|
|
105
120
|
if uncovered_lines:
|
106
121
|
if solve_branch_condition:
|
107
122
|
# Call the Z3 solver with uncovered lines
|
108
|
-
z3_cases = solve_branch_condition(file_name, func_node, uncovered_lines)
|
123
|
+
z3_cases = solve_branch_condition(file_name, class_name, func_node, uncovered_lines)
|
109
124
|
if z3_cases:
|
110
125
|
test_cases.extend(z3_cases)
|
111
126
|
else:
|
@@ -9,7 +9,7 @@ from testgen.util.z3_utils import ast_to_z3
|
|
9
9
|
from testgen.util.z3_utils.constraint_extractor import extract_branch_conditions
|
10
10
|
|
11
11
|
|
12
|
-
def solve_branch_condition(file_name: str, func_node: ast.FunctionDef, uncovered_lines: List[int]) -> List[TestCase]:
|
12
|
+
def solve_branch_condition(file_name: str, class_name: str | None, func_node: ast.FunctionDef, uncovered_lines: List[int]) -> List[TestCase]:
|
13
13
|
branch_conditions, param_types = extract_branch_conditions(func_node)
|
14
14
|
uncovered_conditions = [bc for bc in branch_conditions if bc.line_number in uncovered_lines]
|
15
15
|
test_cases = []
|
@@ -25,6 +25,10 @@ def solve_branch_condition(file_name: str, func_node: ast.FunctionDef, uncovered
|
|
25
25
|
# Create default values for all parameters
|
26
26
|
param_values = {}
|
27
27
|
for param_name in param_types:
|
28
|
+
# Skip 'self' parameter for class methods
|
29
|
+
if param_name == 'self':
|
30
|
+
continue
|
31
|
+
|
28
32
|
# Set default values based on type
|
29
33
|
if param_types[param_name] == "int":
|
30
34
|
param_values[param_name] = 0
|
@@ -39,32 +43,38 @@ def solve_branch_condition(file_name: str, func_node: ast.FunctionDef, uncovered
|
|
39
43
|
|
40
44
|
# Update with model values where available
|
41
45
|
for var_name, z3_var in z3_vars.items():
|
42
|
-
if var_name in param_types
|
46
|
+
if var_name in param_types:
|
43
47
|
try:
|
48
|
+
|
49
|
+
model_value = model.evaluate(z3_var)
|
50
|
+
|
44
51
|
if param_types[var_name] == "int":
|
45
|
-
param_values[var_name] =
|
52
|
+
param_values[var_name] = model_value.as_long()
|
46
53
|
elif param_types[var_name] == "float":
|
47
|
-
param_values[var_name] = float(
|
54
|
+
param_values[var_name] = float(model_value.as_decimal(10))
|
48
55
|
elif param_types[var_name] == "bool":
|
49
|
-
param_values[var_name] = z3.is_true(
|
56
|
+
param_values[var_name] = z3.is_true(model_value)
|
50
57
|
elif param_types[var_name] == "str":
|
51
|
-
|
58
|
+
str_val = str(model_value)
|
59
|
+
if str_val.startswith('"') and str_val.endswith('"'):
|
60
|
+
str_val = str_val[1:-1]
|
61
|
+
param_values[var_name] = str_val
|
52
62
|
else:
|
53
63
|
param_values[var_name] = model[z3_var].as_long()
|
54
64
|
except Exception as e:
|
55
|
-
print(f"
|
65
|
+
print(f"Couldn't get {var_name} from model: {e}")
|
66
|
+
# Keep the default value we already set
|
56
67
|
|
57
68
|
# Ensure all parameters are included in correct order
|
58
69
|
ordered_params = []
|
59
70
|
for arg in func_node.args.args:
|
60
71
|
arg_name = arg.arg
|
61
|
-
if arg_name == 'self': # Skip self parameter
|
72
|
+
if arg_name == 'self': # Skip self parameter
|
62
73
|
continue
|
63
74
|
if arg_name in param_values:
|
64
75
|
ordered_params.append(param_values[arg_name])
|
65
76
|
else:
|
66
|
-
|
67
|
-
# Provide default values based on annotation if available
|
77
|
+
# Default value handling if parameter not in solution
|
68
78
|
if hasattr(arg, 'annotation') and arg.annotation:
|
69
79
|
if isinstance(arg.annotation, ast.Name):
|
70
80
|
if arg.annotation.id == 'int':
|
@@ -85,7 +95,12 @@ def solve_branch_condition(file_name: str, func_node: ast.FunctionDef, uncovered
|
|
85
95
|
func_name = func_node.name
|
86
96
|
try:
|
87
97
|
module = testgen.util.file_utils.load_module(file_name)
|
88
|
-
|
98
|
+
if class_name is not None:
|
99
|
+
class_obj = getattr(module, class_name)
|
100
|
+
instance = class_obj()
|
101
|
+
func = getattr(instance, func_name)
|
102
|
+
else:
|
103
|
+
func = getattr(module, func_name)
|
89
104
|
result = func(*ordered_params)
|
90
105
|
test_cases.append(TestCase(func_name, tuple(ordered_params), result))
|
91
106
|
except Exception as e:
|