testgenie-py 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- testgen/__init__.py +0 -0
- testgen/analyzer/__init__.py +0 -0
- testgen/analyzer/ast_analyzer.py +149 -0
- testgen/analyzer/contracts/__init__.py +0 -0
- testgen/analyzer/contracts/contract.py +13 -0
- testgen/analyzer/contracts/no_exception_contract.py +16 -0
- testgen/analyzer/contracts/nonnull_contract.py +15 -0
- testgen/analyzer/fuzz_analyzer.py +106 -0
- testgen/analyzer/random_feedback_analyzer.py +291 -0
- testgen/analyzer/reinforcement_analyzer.py +75 -0
- testgen/analyzer/test_case_analyzer.py +46 -0
- testgen/analyzer/test_case_analyzer_context.py +58 -0
- testgen/controller/__init__.py +0 -0
- testgen/controller/cli_controller.py +194 -0
- testgen/controller/docker_controller.py +169 -0
- testgen/docker/Dockerfile +22 -0
- testgen/docker/poetry.lock +361 -0
- testgen/docker/pyproject.toml +22 -0
- testgen/generator/__init__.py +0 -0
- testgen/generator/code_generator.py +66 -0
- testgen/generator/doctest_generator.py +208 -0
- testgen/generator/generator.py +55 -0
- testgen/generator/pytest_generator.py +77 -0
- testgen/generator/test_generator.py +26 -0
- testgen/generator/unit_test_generator.py +84 -0
- testgen/inspector/__init__.py +0 -0
- testgen/inspector/inspector.py +61 -0
- testgen/main.py +13 -0
- testgen/models/__init__.py +0 -0
- testgen/models/analysis_context.py +56 -0
- testgen/models/function_metadata.py +61 -0
- testgen/models/generator_context.py +63 -0
- testgen/models/test_case.py +8 -0
- testgen/presentation/__init__.py +0 -0
- testgen/presentation/cli_view.py +12 -0
- testgen/q_table/global_q_table.json +1 -0
- testgen/reinforcement/__init__.py +0 -0
- testgen/reinforcement/abstract_state.py +7 -0
- testgen/reinforcement/agent.py +153 -0
- testgen/reinforcement/environment.py +215 -0
- testgen/reinforcement/statement_coverage_state.py +33 -0
- testgen/service/__init__.py +0 -0
- testgen/service/analysis_service.py +260 -0
- testgen/service/cfg_service.py +55 -0
- testgen/service/generator_service.py +169 -0
- testgen/service/service.py +389 -0
- testgen/sqlite/__init__.py +0 -0
- testgen/sqlite/db.py +84 -0
- testgen/sqlite/db_service.py +219 -0
- testgen/tree/__init__.py +0 -0
- testgen/tree/node.py +7 -0
- testgen/tree/tree_utils.py +79 -0
- testgen/util/__init__.py +0 -0
- testgen/util/coverage_utils.py +168 -0
- testgen/util/coverage_visualizer.py +154 -0
- testgen/util/file_utils.py +110 -0
- testgen/util/randomizer.py +122 -0
- testgen/util/utils.py +143 -0
- testgen/util/z3_utils/__init__.py +0 -0
- testgen/util/z3_utils/ast_to_z3.py +99 -0
- testgen/util/z3_utils/branch_condition.py +72 -0
- testgen/util/z3_utils/constraint_extractor.py +36 -0
- testgen/util/z3_utils/variable_finder.py +10 -0
- testgen/util/z3_utils/z3_test_case.py +94 -0
- testgenie_py-0.1.0.dist-info/METADATA +24 -0
- testgenie_py-0.1.0.dist-info/RECORD +68 -0
- testgenie_py-0.1.0.dist-info/WHEEL +4 -0
- testgenie_py-0.1.0.dist-info/entry_points.txt +3 -0
testgen/tree/node.py
ADDED
@@ -0,0 +1,79 @@
|
|
1
|
+
from collections import deque
|
2
|
+
from .node import *
|
3
|
+
import operator
|
4
|
+
|
5
|
+
|
6
|
+
def apply_operation(func, *args):
    """Left-fold ``func`` over ``args``.

    ``apply_operation(f, a, b, c)`` evaluates ``f(f(a, b), c)``. A single
    argument is returned unchanged, and a call with no arguments yields
    ``True``.
    """
    if len(args) == 0:
        # Empty fold: fall back to True rather than raising.
        return True
    remaining = iter(args)
    accumulator = next(remaining)
    for value in remaining:
        accumulator = func(accumulator, value)
    return accumulator
|
13
|
+
|
14
|
+
|
15
|
+
def build_binary_tree(node, level, max_level):
    """Grow a complete binary tree of boolean-valued nodes beneath ``node``.

    While ``level`` is below ``max_level``, ``node`` receives two children,
    ``Node(True)`` followed by ``Node(False)``. The same expansion is then
    applied to each child, True subtree first, one level deeper.
    """
    if level >= max_level:
        return

    branches = (Node(True), Node(False))
    for branch in branches:
        node.add_child(branch)
    for branch in branches:
        build_binary_tree(branch, level + 1, max_level)
|
27
|
+
|
28
|
+
|
29
|
+
def print_level_order_tree(node):
    """Print the tree rooted at ``node`` breadth-first, one line per node.

    Each line reads ``Level <depth>, <value> -> <children>``. The children
    part is the comma-separated child values, or ``None`` for a leaf. Nodes
    must expose ``value`` and ``children`` attributes. A ``None`` root prints
    nothing.
    """
    if node is None:
        return

    pending = deque()
    pending.append((node, 0))
    while len(pending) > 0:
        current, depth = pending.popleft()
        print(f"Level {depth}, {current.value} -> ", end="")
        kids = current.children
        summary = ", ".join(str(kid.value) for kid in kids) if kids else "None"
        print(summary)
        pending.extend((kid, depth + 1) for kid in kids)
|
47
|
+
|
48
|
+
|
49
|
+
def generate_boolean_function(parameters, operation):
    """Build the source text of ``boolean_function`` covering every truth assignment.

    For each parameter, the generated code branches on ``param == True``
    followed by ``elif param == False``. Each leaf returns the constant
    obtained by folding ``operation`` over the booleans chosen along that
    path (via ``apply_operation``). With no parameters the body is a single
    ``return True``. Inputs that are neither ``== True`` nor ``== False``
    fall through, so the generated function returns ``None`` for them.

    Args:
        parameters: parameter names (strings) of the generated function.
        operation: binary callable folded over the path values,
            e.g. ``operator.and_``.

    Returns:
        A string holding a syntactically valid ``def boolean_function(...)``
        definition with 4-space indentation and a trailing newline.
    """
    indent = "    "

    def evaluate_path(path_values):
        return apply_operation(operation, *path_values)

    def traverse(index, path_values, depth):
        """Return source lines for ``parameters[index:]`` nested ``depth`` levels deep."""
        pad = indent * depth
        if index == len(parameters):
            return [f"{pad}return {evaluate_path(path_values)}"]

        param = parameters[index]
        lines = [f"{pad}if {param} == True:"]
        lines.extend(traverse(index + 1, path_values + [True], depth + 1))
        # `elif` keeps the explicit False test while staying valid Python; an
        # `else:` followed by an un-indented `if` is a syntax error.
        lines.append(f"{pad}elif {param} == False:")
        lines.extend(traverse(index + 1, path_values + [False], depth + 1))
        return lines

    body = traverse(0, [], 1)
    return f"def boolean_function({', '.join(parameters)}):\n" + "\n".join(body) + "\n"
|
72
|
+
|
73
|
+
|
74
|
+
# TODO: module-level placeholder, not implemented yet.
def evaluate_path(path_values):
    """Placeholder: ignores ``path_values`` and returns ``None``.

    ``generate_boolean_function`` does not use this; it defines its own
    local ``evaluate_path``.
    """
    return None
|
77
|
+
|
78
|
+
# if __name__ == '__main__':
|
79
|
+
# build_bin_tree(["x", "y", "z"])
|
testgen/util/__init__.py
ADDED
File without changes
|
@@ -0,0 +1,168 @@
|
|
1
|
+
import ast
|
2
|
+
from typing import List
|
3
|
+
|
4
|
+
import coverage
|
5
|
+
|
6
|
+
from testgen.models.test_case import TestCase
|
7
|
+
from testgen.util.file_utils import load_and_parse_file_for_tree, load_module
|
8
|
+
from testgen.util.utils import get_function_boundaries
|
9
|
+
from testgen.util.z3_utils.constraint_extractor import extract_branch_conditions
|
10
|
+
|
11
|
+
|
12
|
+
def get_branch_coverage(file_name, func, *args) -> list:
    """Run ``func(*args)`` under branch coverage and return the arcs recorded for ``file_name``.

    Args:
        file_name: path of the source file whose arcs should be reported.
        func: callable to execute while coverage is active.
        *args: positional arguments forwarded to ``func``.

    Returns:
        List of ``(from_line, to_line)`` arc pairs recorded for ``file_name``,
        or an empty list if nothing was recorded for that file.

    Raises:
        Whatever ``func`` raises; coverage measurement is stopped first.
    """
    import os  # local import: needed only to canonicalise the measured file path

    cov = coverage.Coverage(branch=True)
    cov.start()
    try:
        func(*args)
    finally:
        # Always detach the tracer, even when func raises.
        cov.stop()
    cov.save()

    # analysis2() returns a plain tuple without arc data; arcs live in the
    # collected CoverageData. Presumably keyed by the real absolute path of
    # the measured file — TODO confirm on case-insensitive filesystems.
    arcs = cov.get_data().arcs(os.path.realpath(file_name))
    return list(arcs) if arcs else []
|
25
|
+
|
26
|
+
|
27
|
+
def _locate_top_level_function(tree, func_name):
    """Return ``(node, first_line, last_line)`` for the first top-level ``def func_name``.

    Returns ``(None, None, None)`` when no such function exists.
    """
    for node in tree.body:
        if isinstance(node, ast.FunctionDef) and node.name == func_name:
            last_line = getattr(node, "end_lineno", None)
            if last_line is None:
                # ASTs without end_lineno: fall back to the largest line number inside the function.
                last_line = max(
                    (child.lineno for child in ast.walk(node) if getattr(child, "lineno", None)),
                    default=node.lineno,
                )
            return node, node.lineno, last_line
    return None, None, None


def _branch_marker_lines(func_node):
    """Return the line of every ``if`` in ``func_node`` plus one line per non-empty ``orelse``.

    The ``orelse`` line is the line just above the first statement of the
    else/elif block, a heuristic for where ``else:`` sits. For an ``elif``,
    that is the last line of the preceding body.
    """
    lines = []
    for node in ast.walk(func_node):
        if isinstance(node, ast.If):
            lines.append(node.lineno)
            if node.orelse:
                lines.append(node.orelse[0].lineno - 1)
    return lines


def get_coverage_analysis(file_name, func_name, args) -> tuple:
    """Measure coverage of one call ``func_name(*args)``, restricted to that function.

    The module at ``file_name`` is loaded and the call executed under a
    branch-enabled ``coverage.Coverage``. The result is coverage.py's
    ``analysis2`` tuple with these adjustments:

    * index 1 (executable statements) and index 3 (missed statements) are
      filtered to the function's line span;
    * ``if`` lines, and the line just above each non-empty else/elif block,
      are added to the executable list;
    * the ``def`` line is always executable and never missed.

    Raises:
        ValueError: if ``func_name`` is not a top-level function in ``file_name``.
        Any exception raised while loading the module or running the call;
        coverage is stopped before it propagates.
    """
    tree = load_and_parse_file_for_tree(file_name)
    func_node, func_start, func_end = _locate_top_level_function(tree, func_name)
    if not func_node:
        raise ValueError(f"Function {func_name} not found in {file_name}")

    cov = coverage.Coverage(branch=True)
    cov.start()
    try:
        module = load_module(file_name)
        func = getattr(module, func_name)
        func(*args)
    finally:
        # Detach the tracer even if loading or the call fails.
        cov.stop()
    cov.save()

    analysis_list = list(cov.analysis2(file_name))

    def in_function(line):
        return func_start <= line <= func_end

    executable = [line for line in analysis_list[1] if in_function(line)]
    missed = [line for line in analysis_list[3] if in_function(line)]

    for line in _branch_marker_lines(func_node):
        if in_function(line) and line not in executable:
            executable.append(line)

    # The def line is always counted as executed.
    if func_start not in executable:
        executable.append(func_start)
    executable.sort()
    if func_start in missed:
        missed.remove(func_start)

    analysis_list[1] = executable
    analysis_list[3] = missed
    return tuple(analysis_list)
|
102
|
+
|
103
|
+
|
104
|
+
def get_coverage_percentage(analysis: tuple) -> float:
    """Return the percentage (0-100) of executable statements that were executed.

    ``analysis`` is an ``analysis2``-style tuple: index 1 lists executable
    statement lines and index 3 lists missed lines. An analysis with no
    executable statements yields ``0.0``.
    """
    total_statements = len(analysis[1])
    if total_statements == 0:
        return 0.0  # float, matching the annotated return type
    covered_statements = total_statements - len(analysis[3])
    return (covered_statements / total_statements) * 100
|
109
|
+
|
110
|
+
|
111
|
+
def get_list_of_missed_lines(analysis: tuple) -> list:
    """Return the missed-line list stored at index 3 of an ``analysis2``-style tuple (not a copy)."""
    missed_lines = analysis[3]
    return missed_lines
|
113
|
+
|
114
|
+
|
115
|
+
def get_list_of_covered_statements(analysis: tuple) -> list:
    """Return executable lines (index 1) that are not missed (index 3), keeping their order.

    ``analysis`` is an ``analysis2``-style tuple.
    """
    # Set lookup keeps this linear instead of rescanning the missed list per line.
    missed = set(analysis[3])
    return [line for line in analysis[1] if line not in missed]
|
123
|
+
|
124
|
+
|
125
|
+
def get_uncovered_lines_for_func(file_name: str, func_node: ast.FunctionDef, test_cases: List[TestCase]) -> List[int]:
    """Return branch-condition lines of ``func_node`` that the given test cases left uncovered.

    Every test case whose ``func_name`` matches is replayed against a freshly
    loaded copy of ``file_name`` under branch coverage. Exceptions from
    individual cases are printed and otherwise ignored. A condition line
    (from ``extract_branch_conditions``) is reported when it, or the line
    directly after it, is among the missed statements within the function's
    boundaries.

    Returns:
        Condition line numbers in ``extract_branch_conditions`` order. Returns
        an empty list, after printing a warning, when there are no test cases
        or none for this function.
    """
    func_name = func_node.name
    if not test_cases:
        print(f"Warning: No test cases provided {func_name}.")
        return []

    function_test_cases = [tc for tc in test_cases if tc.func_name == func_name]
    if not function_test_cases:
        print(f"Warning: No test cases found for function {func_name}.")
        return []

    module = load_module(file_name)
    func = getattr(module, func_name)

    cov = coverage.Coverage(branch=True)
    cov.start()
    try:
        for test_case in function_test_cases:
            try:
                func(*test_case.inputs)
            except Exception as e:
                # Best effort: a failing input still contributes the lines it reached.
                print(f"Warning: Test Case {test_case.inputs} failed with error: {e}")
    finally:
        # Detach the tracer even on BaseException (e.g. KeyboardInterrupt, SystemExit).
        cov.stop()

    analysis = cov.analysis2(file_name)

    branch_conditions, _ = extract_branch_conditions(func_node)
    func_start, func_end = get_function_boundaries(file_name, func_name)
    missed_lines = {line for line in analysis[3] if func_start <= line <= func_end}

    # A condition counts as uncovered if the condition line or the line after
    # it (likely the first line of the branch body) was never executed.
    return [
        bc.line_number
        for bc in branch_conditions
        if bc.line_number in missed_lines or (bc.line_number + 1) in missed_lines
    ]
|
@@ -0,0 +1,154 @@
|
|
1
|
+
import ast
|
2
|
+
import os
|
3
|
+
from typing import Dict, List, Set
|
4
|
+
import coverage
|
5
|
+
|
6
|
+
import testgen.util.coverage_utils
|
7
|
+
from testgen.models.test_case import TestCase
|
8
|
+
import pygraphviz as pgv
|
9
|
+
|
10
|
+
class CoverageVisualizer:
    """Draw a function's ``if``/``return`` skeleton as a Graphviz graph coloured by coverage.

    Usage:

    1. Call ``set_service`` with an object exposing ``file_path``.
    2. Call ``get_covered_lines`` to record which lines each function's test
       cases execute.
    3. Call ``generate_colored_cfg`` to render covered nodes green and
       uncovered nodes red.
    """

    _COVERED_COLOR = "#ddffdd"
    _UNCOVERED_COLOR = "#ffdddd"
    # Fill for the stand-in node of a branch that contains no drawn statement.
    _PLACEHOLDER_COLOR = "#eeeeee"

    def __init__(self):
        self.service = None
        # NOTE(review): not read by any method of this class; kept for external users.
        self.cov = coverage.Coverage(branch=True)
        # Maps function name to the line numbers executed by at least one of its test cases.
        self.covered_lines: Dict[str, Set[int]] = {}

    def set_service(self, service):
        """Attach the object whose ``file_path`` names the source file to visualise."""
        self.service = service

    def get_covered_lines(self, file_path: str, func_def: ast.FunctionDef, test_cases: List[TestCase]):
        """Add every line covered by a matching test case to ``self.covered_lines[func_def.name]``.

        Each test case whose ``func_name`` equals ``func_def.name`` is run
        separately through ``coverage_utils.get_coverage_analysis``. The
        accumulated set is printed, or a notice is printed if it is still
        empty.
        """
        covered_set = self.covered_lines.setdefault(func_def.name, set())

        for test_case in test_cases:
            if test_case.func_name != func_def.name:
                continue
            analysis = testgen.util.coverage_utils.get_coverage_analysis(file_path, func_def.name, test_case.inputs)
            covered_set.update(testgen.util.coverage_utils.get_list_of_covered_statements(analysis))

        # The key always exists after setdefault, so report on the contents instead.
        if covered_set:
            print(f"Covered lines for {func_def.name}: {covered_set}")
        else:
            print(f"No coverage data found for {func_def.name}")

    def generate_colored_cfg(self, function_name, output_path):
        """Render the coverage-coloured graph of ``function_name`` to ``output_path``.

        Returns:
            ``output_path`` on success. ``None`` if no service is set, the
            source file is missing or unreadable, or graph generation fails;
            the error (and traceback) is printed in each case.
        """
        if self.service is None:
            print("ERROR: No service set; call set_service() before generate_colored_cfg()")
            return None

        abs_source_file = os.path.abspath(self.service.file_path)
        if not os.path.exists(abs_source_file):
            print(f"ERROR: File does not exist: {abs_source_file}")
            return None

        try:
            # Python source files are UTF-8 by default (PEP 3120).
            with open(abs_source_file, 'r', encoding='utf-8') as f:
                source_code = f.read()
        except (OSError, UnicodeDecodeError) as e:
            print(f"Error reading source file: {e}")
            return None

        try:
            tree = ast.parse(source_code)
            ast_functions = [node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)]
            print(f"Functions found by AST: {ast_functions}")
            return self._create_basic_cfg(source_code, function_name, output_path)
        except Exception as e:
            # Top-level boundary: report and signal failure with None. Retrying
            # _create_basic_cfg here would only repeat the call that just failed.
            print(f"Error in CFG generation: {e}")
            import traceback
            traceback.print_exc()
            return None

    def _create_basic_cfg(self, source_code, function_name, output_path):
        """Draw the ``if``/``return`` skeleton of ``function_name`` and save it to ``output_path``.

        Only ``if`` and ``return`` statements become nodes; other statements
        are skipped. Nested ``if`` statements are followed, and branch edges
        are labelled ``True``/``False``. A branch with no drawn statement gets
        a ``...`` stand-in node. Nodes are green when their line is in
        ``self.covered_lines[function_name]`` and red otherwise; all are red
        if no coverage was recorded for the function.

        Returns:
            ``output_path``.

        Raises:
            ValueError: if ``source_code`` defines no function named ``function_name``.
        """
        tree = ast.parse(source_code)
        func_node = next(
            (node for node in ast.walk(tree)
             if isinstance(node, ast.FunctionDef) and node.name == function_name),
            None,
        )
        if not func_node:
            raise ValueError(f"Function {function_name} not found in AST")

        source_lines = source_code.split('\n')
        covered = self.covered_lines.get(function_name, set())
        graph = pgv.AGraph(directed=True)
        next_id = 0

        def add_node(label, fillcolor):
            """Create a filled node under a fresh integer id and return that id."""
            nonlocal next_id
            node_id = next_id
            next_id += 1
            graph.add_node(node_id, label=label, style="filled", fillcolor=fillcolor)
            return node_id

        def add_edge(src_id, dst_id, label=""):
            """Connect two existing nodes, attaching ``label`` only when non-empty."""
            if label:
                graph.add_edge(src_id, dst_id, label=label)
            else:
                graph.add_edge(src_id, dst_id)

        def add_statement(stmt, parent_id, edge_label):
            """Add a coverage-coloured node showing ``stmt``'s first source line, linked from ``parent_id``."""
            color = self._COVERED_COLOR if stmt.lineno in covered else self._UNCOVERED_COLOR
            node_id = add_node(source_lines[stmt.lineno - 1].strip(), color)
            add_edge(parent_id, node_id, edge_label)
            return node_id

        def process_node(ast_node, parent_id, edge_label=""):
            """Draw ``ast_node`` below ``parent_id``; return the id later statements attach to."""
            if isinstance(ast_node, ast.If):
                if_id = add_statement(ast_node, parent_id, edge_label)
                process_branch(ast_node.body, if_id, "True")
                if ast_node.orelse:
                    process_branch(ast_node.orelse, if_id, "False")
                # Following statements hang off the most recently drawn node.
                return next_id - 1
            if isinstance(ast_node, ast.Return):
                return add_statement(ast_node, parent_id, edge_label)
            return parent_id  # statement kinds without a node are skipped

        def process_block(nodes, parent_id, edge_label=""):
            """Draw statements in order; ``edge_label`` goes on the first edge actually drawn."""
            current_id = parent_id
            for stmt in nodes:
                new_id = process_node(stmt, current_id, edge_label)
                if new_id != current_id:
                    edge_label = ""
                current_id = new_id
            return current_id

        def process_branch(nodes, if_id, edge_label):
            """Draw one branch of an ``if``; add a ``...`` stand-in when it produced no node."""
            if process_block(nodes, if_id, edge_label) == if_id:
                add_edge(if_id, add_node("...", self._PLACEHOLDER_COLOR), edge_label)

        entry_id = add_node(f"def {function_name}()", self._COVERED_COLOR)
        process_block(func_node.body, entry_id)

        graph.layout(prog='dot')
        graph.draw(output_path)
        print(f"Enhanced basic CFG drawn to {output_path}")

        return output_path
|
@@ -0,0 +1,110 @@
|
|
1
|
+
import ast
|
2
|
+
import importlib.util
|
3
|
+
import os
|
4
|
+
import sys
|
5
|
+
from _ast import Module
|
6
|
+
from importlib import util
|
7
|
+
from types import ModuleType
|
8
|
+
from typing import Dict
|
9
|
+
|
10
|
+
def get_import_info(filepath: str) -> Dict[str, str]:
    """Describe how the Python file at ``filepath`` would be imported.

    The dictionary is printed before being returned. Its keys are:

    * ``module_name``: the file's stem.
    * ``package_name``: first directory below the project root, or ``''``.
    * ``import_path``: dotted path of the module relative to the project root.
    * ``is_package``: whether the file's directory contains ``__init__.py``.
    * ``project_root``: result of ``find_project_root`` (may be ``None``).
    * ``file_dir``: absolute directory containing the file.

    Raises:
        ValueError: if ``filepath`` does not exist or does not end in ``.py``.
    """
    if not (os.path.exists(filepath) and filepath.endswith('.py')):
        raise ValueError(f"Invalid Python file: {filepath}")

    file_dir = os.path.dirname(os.path.abspath(filepath))
    module_name = os.path.splitext(os.path.basename(filepath))[0]
    is_package = os.path.exists(os.path.join(file_dir, '__init__.py'))
    project_root = find_project_root(file_dir)

    # Defaults cover: no project root, file directly in the root, or no usable path parts.
    package_name = ''
    import_path = module_name
    if project_root:
        rel_path = os.path.relpath(file_dir, project_root)
        if rel_path != '.':
            # Normalise Windows separators, then drop empty segments.
            parts = [segment for segment in rel_path.replace('\\', '/').split('/') if segment]
            if parts:
                package_name = parts[0]
                import_path = '.'.join(parts + [module_name])

    info = {
        'module_name': module_name,
        'package_name': package_name,
        'import_path': import_path,
        'is_package': is_package,
        'project_root': project_root,
        'file_dir': file_dir,
    }
    print(f"INFO: {info}")
    return info
|
61
|
+
|
62
|
+
def find_project_root(start_dir: str) -> str | None:
|
63
|
+
current_dir = start_dir
|
64
|
+
|
65
|
+
# Walk up the directory tree
|
66
|
+
while current_dir:
|
67
|
+
# Check for common project root indicators
|
68
|
+
if (os.path.exists(os.path.join(current_dir, 'setup.py')) or
|
69
|
+
os.path.exists(os.path.join(current_dir, '.git')) or
|
70
|
+
os.path.exists(os.path.join(current_dir, 'pyproject.toml'))):
|
71
|
+
return current_dir
|
72
|
+
|
73
|
+
if os.path.exists(os.path.join(current_dir, '__main__.py')):
|
74
|
+
return current_dir
|
75
|
+
|
76
|
+
# Check for 'testgen' directory which is your project name
|
77
|
+
if os.path.basename(current_dir) == 'testgen':
|
78
|
+
parent_dir = os.path.dirname(current_dir)
|
79
|
+
return parent_dir
|
80
|
+
|
81
|
+
parent_dir = os.path.dirname(current_dir)
|
82
|
+
if parent_dir == current_dir: # Reached the root
|
83
|
+
break
|
84
|
+
current_dir = parent_dir
|
85
|
+
|
86
|
+
return None
|
87
|
+
|
88
|
+
|
89
|
+
def load_module(file_path: str) -> ModuleType:
    """Execute the Python source file at ``file_path`` and return it as a new module object.

    The module is named after the file's stem and re-executed on every call.
    It is not registered in ``sys.modules``.

    Raises:
        ValueError: if ``file_path`` is ``None``.
        ImportError: if no loader can handle ``file_path`` (e.g. a non-``.py`` suffix).
        Any exception raised while executing the module's top-level code.
    """
    if file_path is None:
        raise ValueError("File path not set! Use set_file_path() to specify the path of the file")

    module_name = os.path.splitext(os.path.basename(file_path))[0]
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None for unrecognised suffixes; fail
        # with a clear ImportError instead of an AttributeError on None.
        raise ImportError(f"Cannot load a Python module from {file_path!r}")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
|
99
|
+
|
100
|
+
|
101
|
+
def get_filename(filepath: str) -> str:
    """Return the last path component of ``filepath`` (empty when it ends in a separator)."""
    _head, tail = os.path.split(filepath)
    return tail
|
104
|
+
|
105
|
+
|
106
|
+
def load_and_parse_file_for_tree(file) -> Module:
    """Parse the Python source file at ``file`` into an AST module.

    The file is read as bytes so that ``ast.parse`` applies Python's own
    source-encoding rules (UTF-8 by default, PEP 263 coding declarations,
    BOM) rather than the platform's locale encoding. The path is passed as
    ``filename`` so that syntax errors name the offending file.

    Raises:
        OSError: if the file cannot be opened.
        SyntaxError: if the source is not valid Python.
    """
    with open(file, 'rb') as f:
        source = f.read()
    return ast.parse(source, filename=str(file))
|
@@ -0,0 +1,122 @@
|
|
1
|
+
import ast
|
2
|
+
import random
|
3
|
+
from typing import List
|
4
|
+
|
5
|
+
import testgen.util.coverage_utils
|
6
|
+
import testgen.util.file_utils
|
7
|
+
import testgen.util.utils as utils
|
8
|
+
# import testgen.util.z3_test_case
|
9
|
+
try:
|
10
|
+
from testgen.util.z3_utils.z3_test_case import solve_branch_condition
|
11
|
+
except ImportError:
|
12
|
+
print("ERROR IMPORTING Z3 TEST CASE")
|
13
|
+
solve_branch_condition = None
|
14
|
+
from testgen.models.test_case import TestCase
|
15
|
+
|
16
|
+
def make_random_move(file_name: str, func_node: ast.FunctionDef, test_cases: List[TestCase]) -> List[TestCase]:
    """Apply one randomly chosen mutation to ``test_cases`` and return the resulting list.

    The move is chosen uniformly via ``random.randint(1, 4)``:

    1. append a new random test case for ``func_node``;
    2. append a case combining the inputs of two existing cases;
    3. delete a random test case;
    4. ask the Z3 solver for cases reaching branch conditions that this
       function's existing cases leave uncovered.

    ``test_cases`` is modified in place. Moves 2 and 3 do nothing on an
    empty list. Move 4 does nothing when there is no case yet for this
    function or when the Z3 helper failed to import.
    """
    random_choice = random.randint(1, 4)
    func_name = func_node.name

    if random_choice == 1:
        test_cases.append(new_random_test_case(file_name, func_node))
    elif random_choice == 2:
        if test_cases:
            test_cases.append(combine_cases(test_cases))
    elif random_choice == 3:
        if test_cases:
            test_cases = remove_case(test_cases)
    elif random_choice == 4:
        function_test_cases = [tc for tc in test_cases if tc.func_name == func_name]
        # solve_branch_condition is None when the z3 import at module load failed.
        if function_test_cases and solve_branch_condition is not None:
            uncovered_lines = testgen.util.coverage_utils.get_uncovered_lines_for_func(
                file_name, func_node, function_test_cases
            )
            if uncovered_lines:
                z3_test_cases = solve_branch_condition(file_name, func_node, uncovered_lines)
                if z3_test_cases:
                    test_cases.extend(z3_test_cases)

    return test_cases
|
41
|
+
|
42
|
+
def new_random_test_case(file_name: str, func_node: ast.FunctionDef) -> TestCase:
    """Create a TestCase by calling the target function with randomly generated inputs.

    Parameter types are read from ``func_node``, and random values are
    generated for them with ``utils``. The function of the same name, loaded
    from ``file_name``, is invoked, and its actual return value becomes the
    test case's expected value. Exceptions raised by the call propagate.
    """
    name = func_node.name
    random_inputs: dict = utils.generate_random_inputs(utils.extract_parameter_types(func_node))
    call_args = random_inputs.values()

    target = getattr(testgen.util.file_utils.load_module(file_name), name)
    result = target(*call_args)

    return TestCase(name, tuple(call_args), result)
|
54
|
+
|
55
|
+
# Open design question: should combining test cases preserve the parent cases
# or entirely replace them? combine_cases itself only builds and returns the
# child; it never modifies the input list.
def combine_cases(test_cases: List[TestCase]) -> TestCase:
    """Return a new TestCase whose inputs splice together two cases of the same function.

    The first parent is drawn at random from ``test_cases``. The second is
    drawn at random from the cases for the same function, possibly the same
    case. The child takes the leading half of parent 1's inputs and the rest
    from parent 2 (see ``mix_inputs``). Its expected value is copied from
    parent 1 and is not recomputed for the mixed inputs.

    Raises:
        ValueError: if ``test_cases`` is empty, or (from ``mix_inputs``) if
            the two parents have different numbers of inputs.
    """
    # TODO: Research Genetic Algorithms
    if not test_cases:
        raise ValueError("combine_cases requires at least one test case")

    test_case1 = random.choice(test_cases)
    same_function_cases = [tc for tc in test_cases if tc.func_name == test_case1.func_name]
    # Draw from the filtered list itself so the pick is always in range.
    test_case2 = random.choice(same_function_cases)

    mixed_inputs = mix_inputs(test_case1, test_case2)

    # TODO: recompute the expected value for the mixed inputs.
    return TestCase(test_case1.func_name, mixed_inputs, test_case1.expected)
|
71
|
+
|
72
|
+
def remove_case(test_cases: List[TestCase]) -> List[TestCase]:
    """Delete one randomly chosen element of ``test_cases`` in place and return the same list.

    An empty list is returned unchanged instead of raising.
    """
    if test_cases:
        # randrange(n) draws exactly like the previous randint(0, n - 1).
        del test_cases[random.randrange(len(test_cases))]
    return test_cases
|
76
|
+
|
77
|
+
def mix_inputs(test_case1: TestCase, test_case2: TestCase) -> tuple:
    """Splice two input sequences: the leading half of ``test_case1`` plus the rest of ``test_case2``.

    The split index is ``len // 2``, so with an odd count the second case
    contributes the extra element.

    Raises:
        ValueError: if the two cases have different numbers of inputs.
    """
    first, second = test_case1.inputs, test_case2.inputs
    if len(first) != len(second):
        raise ValueError("Test cases must have the same number of inputs")
    split_at = len(first) // 2
    return first[:split_at] + second[split_at:]
|
89
|
+
|
90
|
+
def get_z3_test_cases(file_name: str, func_node: ast.FunctionDef, test_cases: List[TestCase]) -> List[TestCase]:
    """Extend ``test_cases`` in place with Z3-derived cases targeting uncovered branches.

    If no case exists yet for ``func_node``, one random case is generated and
    appended first. Coverage of this function's cases is then analysed, and
    the uncovered branch-condition lines are handed to the Z3 solver. Any
    cases it returns are appended. Progress and errors are printed; errors
    from the coverage/Z3 step are reported with a traceback and do not
    propagate. Returns the same ``test_cases`` list.
    """
    relevant_cases = [tc for tc in test_cases if tc.func_name == func_node.name]

    if not relevant_cases:
        seed_case = new_random_test_case(file_name, func_node)
        test_cases.append(seed_case)
        relevant_cases = [seed_case]

    try:
        uncovered = testgen.util.coverage_utils.get_uncovered_lines_for_func(file_name, func_node, relevant_cases)

        if not uncovered:
            print("No uncovered lines found for Z3 to solve")
        elif not solve_branch_condition:
            print("Z3 solver not available (solve_branch_condition is None)")
        else:
            solved_cases = solve_branch_condition(file_name, func_node, uncovered)
            if solved_cases:
                test_cases.extend(solved_cases)
            else:
                print("Z3 couldn't solve branch conditions")
    except Exception as e:
        print(f"Error in Z3 test generation: {e}")
        import traceback
        traceback.print_exc()

    return test_cases
|