testgenie-py 0.3.6__tar.gz → 0.3.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/PKG-INFO +1 -1
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/pyproject.toml +1 -1
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/analyzer/ast_analyzer.py +2 -11
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/analyzer/fuzz_analyzer.py +1 -6
- testgenie_py-0.3.8/testgen/analyzer/random_feedback_analyzer.py +248 -0
- testgenie_py-0.3.8/testgen/analyzer/reinforcement_analyzer.py +77 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/analyzer/test_case_analyzer_context.py +0 -6
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/controller/cli_controller.py +35 -29
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/controller/docker_controller.py +3 -2
- testgenie_py-0.3.8/testgen/db/dao.py +68 -0
- testgenie_py-0.3.8/testgen/db/dao_impl.py +226 -0
- {testgenie_py-0.3.6/testgen/sqlite → testgenie_py-0.3.8/testgen/db}/db.py +15 -6
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/generator/pytest_generator.py +2 -10
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/generator/unit_test_generator.py +2 -11
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/main.py +1 -3
- testgenie_py-0.3.8/testgen/models/coverage_data.py +56 -0
- testgenie_py-0.3.8/testgen/models/db_test_case.py +65 -0
- testgenie_py-0.3.8/testgen/models/function.py +56 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/models/function_metadata.py +11 -1
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/models/generator_context.py +32 -2
- testgenie_py-0.3.8/testgen/models/source_file.py +29 -0
- testgenie_py-0.3.8/testgen/models/test_result.py +38 -0
- testgenie_py-0.3.8/testgen/models/test_suite.py +20 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/reinforcement/agent.py +1 -27
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/reinforcement/environment.py +11 -93
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/reinforcement/statement_coverage_state.py +5 -4
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/service/analysis_service.py +31 -22
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/service/cfg_service.py +3 -1
- testgenie_py-0.3.8/testgen/service/coverage_service.py +115 -0
- testgenie_py-0.3.8/testgen/service/db_service.py +140 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/service/generator_service.py +77 -20
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/service/logging_service.py +2 -2
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/service/service.py +62 -231
- testgenie_py-0.3.8/testgen/service/test_executor_service.py +145 -0
- testgenie_py-0.3.8/testgen/util/coverage_utils.py +152 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/util/coverage_visualizer.py +10 -9
- testgenie_py-0.3.8/testgen/util/file_utils.py +85 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/util/randomizer.py +0 -26
- testgenie_py-0.3.8/testgen/util/utils.py +302 -0
- testgenie_py-0.3.6/testgen/analyzer/random_feedback_analyzer.py +0 -521
- testgenie_py-0.3.6/testgen/analyzer/reinforcement_analyzer.py +0 -75
- testgenie_py-0.3.6/testgen/inspector/inspector.py +0 -59
- testgenie_py-0.3.6/testgen/presentation/cli_view.py +0 -12
- testgenie_py-0.3.6/testgen/sqlite/db_service.py +0 -239
- testgenie_py-0.3.6/testgen/testgen.db +0 -0
- testgenie_py-0.3.6/testgen/util/__init__.py +0 -0
- testgenie_py-0.3.6/testgen/util/coverage_utils.py +0 -230
- testgenie_py-0.3.6/testgen/util/file_utils.py +0 -186
- testgenie_py-0.3.6/testgen/util/utils.py +0 -143
- testgenie_py-0.3.6/testgen/util/z3_utils/__init__.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/README.md +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/__init__.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/analyzer/__init__.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/analyzer/contracts/__init__.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/analyzer/contracts/contract.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/analyzer/contracts/no_exception_contract.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/analyzer/contracts/nonnull_contract.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/analyzer/test_case_analyzer.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/controller/__init__.py +0 -0
- {testgenie_py-0.3.6/testgen/generator → testgenie_py-0.3.8/testgen/db}/__init__.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/docker/Dockerfile +0 -0
- {testgenie_py-0.3.6/testgen/inspector → testgenie_py-0.3.8/testgen/generator}/__init__.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/generator/code_generator.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/generator/doctest_generator.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/generator/generator.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/generator/test_generator.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/models/__init__.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/models/analysis_context.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/models/test_case.py +0 -0
- {testgenie_py-0.3.6/testgen/presentation → testgenie_py-0.3.8/testgen/reinforcement}/__init__.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/reinforcement/abstract_state.py +0 -0
- {testgenie_py-0.3.6/testgen/reinforcement → testgenie_py-0.3.8/testgen/service}/__init__.py +0 -0
- {testgenie_py-0.3.6/testgen/service → testgenie_py-0.3.8/testgen/tree}/__init__.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/tree/node.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/tree/tree_utils.py +0 -0
- {testgenie_py-0.3.6/testgen/sqlite → testgenie_py-0.3.8/testgen/util}/__init__.py +0 -0
- {testgenie_py-0.3.6/testgen/tree → testgenie_py-0.3.8/testgen/util/z3_utils}/__init__.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/util/z3_utils/ast_to_z3.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/util/z3_utils/branch_condition.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/util/z3_utils/constraint_extractor.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/util/z3_utils/variable_finder.py +0 -0
- {testgenie_py-0.3.6 → testgenie_py-0.3.8}/testgen/util/z3_utils/z3_test_case.py +0 -0
@@ -53,17 +53,8 @@ class ASTAnalyzer(TestCaseAnalyzerStrategy, ABC):
|
|
53
53
|
|
54
54
|
if not input_exists:
|
55
55
|
try:
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
if function_metadata.class_name:
|
60
|
-
cls = getattr(module, function_metadata.class_name)
|
61
|
-
instance = cls()
|
62
|
-
func = getattr(instance, func_name)
|
63
|
-
output = func(*inputs)
|
64
|
-
else:
|
65
|
-
func = getattr(module, func_name)
|
66
|
-
output = func(*inputs)
|
56
|
+
func = function_metadata.func
|
57
|
+
output = func(*inputs)
|
67
58
|
|
68
59
|
except Exception as e:
|
69
60
|
print(f"Error executing function: {e}")
|
@@ -29,13 +29,8 @@ class FuzzAnalyzer(TestCaseAnalyzerStrategy, ABC):
|
|
29
29
|
else:
|
30
30
|
raise ValueError("Module not set in function metadata. Cannot perform fuzzing without a module.")
|
31
31
|
|
32
|
-
class_name = function_metadata.class_name if function_metadata.class_name else None
|
33
32
|
try:
|
34
|
-
|
35
|
-
cls = getattr(module, class_name, None)
|
36
|
-
func = getattr(cls(), function_metadata.function_name, None) if cls else None
|
37
|
-
else:
|
38
|
-
func = getattr(module, function_metadata.function_name, None)
|
33
|
+
func = function_metadata.func
|
39
34
|
if func:
|
40
35
|
return self.run_fuzzing(func, function_metadata.function_name, function_metadata.params, module, 10)
|
41
36
|
except Exception as e:
|
@@ -0,0 +1,248 @@
|
|
1
|
+
import random
|
2
|
+
import time
|
3
|
+
import traceback
|
4
|
+
from typing import List, Dict, Set
|
5
|
+
import testgen.util.coverage_utils as coverage_utils
|
6
|
+
from testgen.analyzer.contracts.contract import Contract
|
7
|
+
from testgen.analyzer.contracts.no_exception_contract import NoExceptionContract
|
8
|
+
from testgen.analyzer.contracts.nonnull_contract import NonNullContract
|
9
|
+
from testgen.models.test_case import TestCase
|
10
|
+
from testgen.analyzer.test_case_analyzer import TestCaseAnalyzerStrategy
|
11
|
+
from abc import ABC
|
12
|
+
|
13
|
+
from testgen.models.function_metadata import FunctionMetadata
|
14
|
+
|
15
|
+
|
16
|
+
# Citation in which this method and algorithm were taken from:
|
17
|
+
# C. Pacheco, S. K. Lahiri, M. D. Ernst and T. Ball, "Feedback-Directed Random Test Generation," 29th International
|
18
|
+
# Conference on Software Engineering (ICSE'07), Minneapolis, MN, USA, 2007, pp. 75-84, doi: 10.1109/ICSE.2007.37.
|
19
|
+
# keywords: {System testing;Contracts;Object oriented modeling;Law;Legal factors;Open source software;Software
|
20
|
+
# testing;Feedback;Filters;Error correction codes},
|
21
|
+
|
22
|
+
class RandomFeedbackAnalyzer(TestCaseAnalyzerStrategy, ABC):
    """Feedback-directed random test generation.

    Based on Pacheco et al., "Feedback-Directed Random Test Generation"
    (ICSE'07).  Random arguments are generated from each function's declared
    parameter types (``FunctionMetadata.params`` maps parameter name to a type
    name such as ``"int"``), the function is executed, and the result is
    recorded as a ``TestCase``.  Generation for a function stops once the
    statements covered by its collected test cases equal its executable
    statements, or when the time budget is exhausted.
    """

    def __init__(self, analysis_context=None):
        super().__init__(analysis_context)
        self.test_cases = []
        # Statement line numbers covered so far, keyed by function name.
        self.covered_lines: Dict[str, Set[int]] = {}
        self.covered_functions: Set[str] = set()

    def collect_test_cases(self, function_metadata: FunctionMetadata, time_limit: int = 5) -> List[TestCase]:
        """Randomly exercise one function until it is fully covered or time runs out.

        Args:
            function_metadata: The function under test. ``func`` is called and
                ``params`` drives random input generation.
            time_limit: Wall-clock budget in seconds.

        Returns:
            The non-duplicate test cases collected by this call. They are also
            kept in ``self.test_cases``, which is reset at the start of the call.
        """
        self.test_cases = []
        start_time = time.time()

        while (time.time() - start_time) < time_limit:
            try:
                func_name = function_metadata.function_name
                function = function_metadata.func
                param_values = self.generate_random_inputs(function_metadata.params)
                # Positional order follows the declaration order of ``params``.
                ordered_args = [param_values.get(name, None) for name in function_metadata.params.keys()]

                result = function(*ordered_args)
                test_case = TestCase(func_name, tuple(ordered_args), result)

                if self.is_duplicate_test_case(test_case):
                    self.logger.debug(f"Skipping duplicate test case: {func_name}{test_case.inputs}")
                    continue

                self.test_cases.append(test_case)
                if self.covered(function_metadata):
                    break

            except Exception as e:
                # Best effort: any failure (the function raising, or the
                # coverage check) is reported and generation keeps going.
                print(f"Error testing {function_metadata.function_name}: {e}")

        return self.test_cases

    def is_duplicate_test_case(self, new_test_case: TestCase) -> bool:
        """Return True if a test case with the same function name and inputs was already collected."""
        return any(
            existing.func_name == new_test_case.func_name and existing.inputs == new_test_case.inputs
            for existing in self.test_cases
        )

    def covered(self, func: FunctionMetadata) -> bool:
        """Return True once the collected test cases cover every executable statement of *func*.

        Coverage for every collected test case of *func* is (re)computed via
        ``coverage_utils`` and merged into ``self.covered_lines[func.function_name]``.
        That set is then compared for equality with the function's executable
        statements.
        """
        covered_lines = self.covered_lines.setdefault(func.function_name, set())
        test_cases = [tc for tc in self.test_cases if tc.func_name == func.function_name]

        for test_case in test_cases:
            analysis = coverage_utils.get_coverage_analysis(self._analysis_context.filepath, func, test_case.inputs)
            covered_lines.update(coverage_utils.get_list_of_covered_statements(analysis))
            self.logger.debug(f"Covered lines for {func.function_name}: {self.covered_lines[func.function_name]}")

        executable_statements = set(coverage_utils.get_all_executable_statements(self._analysis_context.filepath, func, test_cases))
        self.logger.debug(f"Executable statements for {func.function_name}: {executable_statements}")

        return covered_lines == executable_statements

    def execute_sequence(self, sequence, contracts: List[Contract]):
        """Execute a ``(func_name, args_dict)`` sequence and check contract violations.

        Args:
            sequence: Tuple of function name and a dict of argument values keyed
                by parameter name.
            contracts: Contracts whose pre- and postconditions are checked.

        Returns:
            ``(output, violated)``. ``output`` is ``None`` when a precondition
            failed or the call raised. ``violated`` is True when any
            precondition or postcondition check fails.
        """
        func_name, args_dict = sequence
        # Bound before the ``try`` so the postcondition check below never sees
        # an unbound name when the metadata lookup itself fails.
        ordered_args = []

        try:
            function_metadata = self.get_function_metadata(func_name)
            if function_metadata is None:
                raise LookupError(f"No function metadata found for {func_name}")
            func = function_metadata.func
            ordered_args = [args_dict.get(name, None) for name in function_metadata.params.keys()]

            # Check preconditions
            for contract in contracts:
                if not contract.check_preconditions(tuple(ordered_args)):
                    print(f"Preconditions failed for {func_name} with {tuple(ordered_args)}")
                    return None, True

            # Execute function with properly ordered arguments
            output = func(*ordered_args)
            exception = None

        except Exception as e:
            print(f"EXCEPTION IN RANDOM FEEDBACK: {e}")
            print(traceback.format_exc())
            output = None
            exception = e

        # Check postconditions
        for contract in contracts:
            if not contract.check_postconditions(tuple(ordered_args), output, exception):
                print(f"Postcondition failed for {func_name} with {tuple(ordered_args)}")
                return output, True

        return output, False

    def get_function_metadata(self, func_name: str) -> FunctionMetadata | None:
        """Return the first function in the analysis context named *func_name*, or None."""
        return next(
            (function_data for function_data in self._analysis_context.function_data
             if function_data.function_name == func_name),
            None,
        )

    # TODO: Currently only getting random vals of primitives, extend to sequences
    def random_seqs_and_vals(self, param_types, non_error_seqs=None):
        """Return random argument values for *param_types*. *non_error_seqs* is currently unused."""
        return self.generate_random_inputs(param_types)

    @staticmethod
    def generate_random_inputs(param_types):
        """Generate one random value per parameter based on its declared type name.

        Args:
            param_types: Mapping of parameter name to type name. ``"int"``,
                ``"bool"``, ``"float"`` and ``"str"`` are supported.

        Returns:
            Dict mapping each parameter name to a random value. Unknown types
            map to ``None``.
        """
        inputs = {}
        for param, param_type in param_types.items():
            if param_type == "int":
                inputs[param] = random.randint(-500, 500)  # Wider range for better edge cases
            elif param_type == "bool":
                inputs[param] = random.choice([True, False])
            elif param_type == "float":
                inputs[param] = random.uniform(-500.0, 500.0)  # Wider range for better edge cases
            elif param_type == "str":
                # Pick a string "shape" first so edge cases (empty, whitespace,
                # punctuation) appear regularly.
                string_type = random.choice([
                    "empty", "short", "medium", "long", "special", "numeric", "whitespace"
                ])

                if string_type == "empty":
                    inputs[param] = ""
                elif string_type == "short":
                    inputs[param] = ''.join(random.choices('abcdefghijklmnopqrstuvwxyz', k=random.randint(1, 3)))
                elif string_type == "medium":
                    inputs[param] = ''.join(random.choices('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ', k=random.randint(4, 10)))
                elif string_type == "long":
                    inputs[param] = ''.join(random.choices('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789', k=random.randint(11, 30)))
                elif string_type == "special":
                    inputs[param] = ''.join(random.choices('!@#$%^&*()_+-=[]{}|;:,./<>?', k=random.randint(1, 8)))
                elif string_type == "numeric":
                    inputs[param] = ''.join(random.choices('0123456789', k=random.randint(1, 10)))
                else:  # whitespace
                    inputs[param] = ' ' * random.randint(1, 5)
            else:
                # For unknown types, fall back to None
                inputs[param] = None

        return inputs

    @staticmethod
    def _already_executed(new_seq, *executed_lists) -> bool:
        """Return True if *new_seq* equals the sequence part of any recorded ``(sequence, output)`` entry."""
        return any(seq == new_seq for executed in executed_lists for seq, _ in executed)

    # Algorithm described in above article
    # Classes is the classes for which we want to generate sequences
    # Contracts express invariant properties that hold both at entry and exit from a call
    # Contract takes as input the current state of the system (runtime values created in the sequence so far, and any exception thrown by the last call), and returns satisfied or violated
    # Output is the runtime values and boolean flag violated
    # Filters determine which values of a sequence are extensible and should be used as inputs
    def generate_sequences(self, function_metadata: List[FunctionMetadata], classes=None, contracts: List[Contract] | None = None, filters=None, time_limit=20):
        """Randomly execute functions from the analysis context for *time_limit* seconds.

        ``function_metadata``, ``classes`` and ``filters`` are accepted for
        interface compatibility but not used. Functions are taken from
        ``self._analysis_context.function_data``.

        Args:
            contracts: Contracts to check. Defaults to
                ``[NonNullContract(), NoExceptionContract()]``.
            time_limit: Wall-clock budget in seconds.

        Returns:
            ``(error_seqs, non_error_seqs)``. Each entry is
            ``((func_name, args_dict), output)``.
        """
        if contracts is None:
            contracts = [NonNullContract(), NoExceptionContract()]
        error_seqs = []  # execution violates a contract
        non_error_seqs = []  # execution does not violate a contract

        functions = self._analysis_context.function_data
        if not functions:
            return error_seqs, non_error_seqs

        start_time = time.time()
        while (time.time() - start_time) < time_limit:
            # Get random function
            func = random.choice(functions)
            new_seq = (func.function_name, self.random_seqs_and_vals(func.params))
            if self._already_executed(new_seq, error_seqs, non_error_seqs):
                continue

            output, violated = self.execute_sequence(new_seq, contracts)
            # Create tuple of sequence ((func name, args), output)
            new_seq_out = (new_seq, output)
            if violated:
                error_seqs.append(new_seq_out)
            else:
                # Question: Should I use the failed contract to be the assertion in unit test??
                non_error_seqs.append(new_seq_out)
        return error_seqs, non_error_seqs

    def generate_sequences_new(self, contracts: List[Contract] | None = None, filters=None, time_limit=20):
        """Randomly execute functions until each is fully covered or time runs out.

        Every non-violating execution is also recorded as a ``TestCase`` in
        ``self.test_cases``. A function is dropped from the candidate pool once
        ``covered`` reports it fully covered. ``self.test_cases`` is sorted by
        function name before returning. ``filters`` is unused.

        Args:
            contracts: Contracts to check. Defaults to
                ``[NonNullContract(), NoExceptionContract()]``.
            time_limit: Wall-clock budget in seconds.

        Returns:
            ``(error_seqs, non_error_seqs)``. Each entry is
            ``((func_name, args_dict), output)``.
        """
        if contracts is None:
            contracts = [NonNullContract(), NoExceptionContract()]
        error_seqs = []  # execution violates a contract
        non_error_seqs = []  # execution does not violate a contract

        # Copy so fully covered functions can be removed without mutating the context.
        functions = self._analysis_context.function_data.copy()
        start_time = time.time()

        # Stop on the time budget or once every function has been removed as fully covered.
        while functions and (time.time() - start_time) < time_limit:
            # Get random function
            func = random.choice(functions)
            new_seq = (func.function_name, self.random_seqs_and_vals(func.params))

            if self._already_executed(new_seq, error_seqs, non_error_seqs):
                continue

            output, violated = self.execute_sequence(new_seq, contracts)
            # Create tuple of sequence ((func name, args), output)
            new_seq_out = (new_seq, output)

            if violated:
                error_seqs.append(new_seq_out)
                continue

            non_error_seqs.append(new_seq_out)
            self.test_cases.append(TestCase(func.function_name, tuple(new_seq[1].values()), output))

            if self.covered(func):
                print(f"Function {func.function_name} is fully covered")
                functions.remove(func)
                if not functions:
                    print("All functions covered")

        self.test_cases.sort(key=lambda tc: tc.func_name)
        return error_seqs, non_error_seqs
|
@@ -0,0 +1,77 @@
|
|
1
|
+
|
2
|
+
import ast
|
3
|
+
from abc import ABC
|
4
|
+
import random
|
5
|
+
from typing import List
|
6
|
+
|
7
|
+
import testgen.util.randomizer
|
8
|
+
from testgen.models.function_metadata import FunctionMetadata
|
9
|
+
from testgen.models.test_case import TestCase
|
10
|
+
from testgen.analyzer.test_case_analyzer import TestCaseAnalyzerStrategy
|
11
|
+
from testgen.reinforcement.environment import ReinforcementEnvironment
|
12
|
+
|
13
|
+
|
14
|
+
# Goal: Learn a policy to generate a set of test cases with optimal code coverage, and minimum number of test cases
|
15
|
+
# Environment: FUT (Function Under Test)
|
16
|
+
# Agent: System
|
17
|
+
# Actions: Create new test case, combine test cases, delete test cases
|
18
|
+
# Rewards:
|
19
|
+
|
20
|
+
from typing import List, Optional
|
21
|
+
from testgen.models.test_case import TestCase
|
22
|
+
from testgen.models.analysis_context import AnalysisContext
|
23
|
+
from testgen.analyzer.test_case_analyzer import TestCaseAnalyzerStrategy
|
24
|
+
from testgen.reinforcement.agent import ReinforcementAgent
|
25
|
+
from testgen.reinforcement.environment import ReinforcementEnvironment
|
26
|
+
from testgen.reinforcement.statement_coverage_state import StatementCoverageState
|
27
|
+
|
28
|
+
|
29
|
+
class ReinforcementAnalyzer(TestCaseAnalyzerStrategy):
    """Test-case analyzer driven by a Q-learning ``ReinforcementAgent``.

    In ``"train"`` mode the agent runs ``train_episodes`` rounds of
    ``do_q_learning`` and the Q-table is saved afterwards via
    ``AnalysisService._save_q_table``. In any other mode a single
    ``collect_test_cases`` pass is made with the loaded Q-table, and the
    table is not saved.
    """

    def __init__(self, analysis_context: AnalysisContext, mode: str = "train", train_episodes: int = 10):
        """
        Args:
            analysis_context: Context providing ``filepath`` of the module under test.
            mode: ``"train"`` to update the Q-table; anything else only uses it.
            train_episodes: Number of Q-learning episodes per function in train mode.

        Raises:
            ValueError: If ``train_episodes`` is less than 1.
        """
        super().__init__(analysis_context)
        if train_episodes < 1:
            raise ValueError(f"train_episodes must be >= 1, got {train_episodes}")
        self.analysis_context = analysis_context
        self.mode = mode
        self.train_episodes = train_episodes

    def collect_test_cases(self, function_metadata: FunctionMetadata):
        """Return the unique test cases produced by ``analyze`` for *function_metadata*."""
        return self.analyze(function_metadata)

    def analyze(self, function_metadata: FunctionMetadata) -> List[TestCase]:
        """Run the agent against one function and return de-duplicated test cases.

        Args:
            function_metadata: The function under test.

        Returns:
            Test cases unique by ``(func_name, inputs)``, in first-seen order.
        """
        # Imported lazily; presumably to avoid a circular import with the service layer.
        from testgen.service.analysis_service import AnalysisService

        training = self.mode == "train"
        q_table = AnalysisService._load_q_table()
        # Shared by the environment, the agent and this method.
        function_test_cases: List[TestCase] = []

        environment = ReinforcementEnvironment(
            self.analysis_context.filepath,
            function_metadata,
            function_test_cases,
            state=StatementCoverageState(None)
        )
        # The state needs the environment, so it is attached after construction.
        environment.state = StatementCoverageState(environment)
        agent = ReinforcementAgent(
            self.analysis_context.filepath,
            environment,
            function_test_cases,
            q_table
        )

        episodes = self.train_episodes if training else 1
        for _ in range(episodes):
            new_test_cases = agent.do_q_learning() if training else agent.collect_test_cases()
            function_test_cases.extend(new_test_cases)

        unique_test_cases = self._deduplicate(function_test_cases)

        # Only training is meant to update the Q-table; evaluation leaves it untouched.
        if training:
            AnalysisService._save_q_table(q_table)
        return unique_test_cases

    @staticmethod
    def _deduplicate(test_cases: List[TestCase]) -> List[TestCase]:
        """Keep the first test case for each ``(func_name, inputs)`` pair, preserving order."""
        seen = set()
        unique_test_cases = []
        for case in test_cases:
            # Lists are not hashable; normalise them to tuples for the key.
            case_inputs = tuple(case.inputs) if isinstance(case.inputs, list) else case.inputs
            case_key = (case.func_name, case_inputs)
            if case_key not in seen:
                seen.add(case_key)
                unique_test_cases.append(case)
        return unique_test_cases
|
@@ -11,12 +11,6 @@ class TestCaseAnalyzerContext:
|
|
11
11
|
self._test_case_analyzer = test_case_analyzer
|
12
12
|
self._analysis_context = analysis_context
|
13
13
|
self._test_cases = []
|
14
|
-
|
15
|
-
# TODO: GET RID OF THIS STUPID METHOD IT IS POINTLESS
|
16
|
-
# JUST CALL INSIDE ANALYZER_SERVICE
|
17
|
-
def do_logic(self) -> List[TestCase]:
|
18
|
-
"""Run the analysis process"""
|
19
|
-
self.do_strategy(20)
|
20
14
|
|
21
15
|
def do_strategy(self, time_limit=None) -> List[TestCase]:
|
22
16
|
"""Execute the analysis strategy for all functions with an optional time limit"""
|
@@ -1,5 +1,4 @@
|
|
1
1
|
import argparse
|
2
|
-
import inspect
|
3
2
|
import os
|
4
3
|
import sys
|
5
4
|
|
@@ -10,8 +9,7 @@ from testgen.service.logging_service import LoggingService, get_logger
|
|
10
9
|
from testgen.util.file_utils import adjust_file_path_for_docker, get_project_root_in_docker
|
11
10
|
from testgen.controller.docker_controller import DockerController
|
12
11
|
from testgen.service.service import Service
|
13
|
-
from testgen.
|
14
|
-
from testgen.sqlite.db_service import DBService
|
12
|
+
from testgen.service.db_service import DBService
|
15
13
|
|
16
14
|
AST_STRAT = 1
|
17
15
|
FUZZ_STRAT = 2
|
@@ -24,9 +22,9 @@ DOCTEST_FORMAT = 3
|
|
24
22
|
|
25
23
|
class CLIController:
|
26
24
|
#TODO: Possibly create a view 'interface' and use dependency injection to extend other views
|
27
|
-
def __init__(self, service: Service
|
25
|
+
def __init__(self, service: Service):
|
28
26
|
self.service = service
|
29
|
-
self.
|
27
|
+
self.logger = None
|
30
28
|
|
31
29
|
def run(self):
|
32
30
|
|
@@ -35,12 +33,16 @@ class CLIController:
|
|
35
33
|
args = parser.parse_args()
|
36
34
|
|
37
35
|
LoggingService.get_instance().initialize(
|
38
|
-
debug_mode=args.debug if
|
36
|
+
debug_mode=args.debug if args.debug else False,
|
39
37
|
log_file=args.log_file if hasattr(args, 'log_file') else None,
|
40
38
|
console_output=True
|
41
39
|
)
|
42
40
|
|
43
|
-
logger = get_logger()
|
41
|
+
self.logger = get_logger()
|
42
|
+
|
43
|
+
if args.functions:
|
44
|
+
self.service.get_all_functions(args.file_path)
|
45
|
+
return
|
44
46
|
|
45
47
|
if args.query:
|
46
48
|
print(f"Querying database for file: {args.file_path}")
|
@@ -48,7 +50,7 @@ class CLIController:
|
|
48
50
|
return
|
49
51
|
|
50
52
|
if args.coverage:
|
51
|
-
self.service.
|
53
|
+
self.service.run_coverage(args.file_path)
|
52
54
|
return
|
53
55
|
|
54
56
|
running_in_docker = os.environ.get("RUNNING_IN_DOCKER") is not None
|
@@ -59,7 +61,7 @@ class CLIController:
|
|
59
61
|
client = self.docker_available()
|
60
62
|
# Skip Docker-dependent operations if client is None
|
61
63
|
if client is None and args.safe:
|
62
|
-
self.
|
64
|
+
self.logger.debug("Running with --safe flag requires Docker. Continuing without safe mode.")
|
63
65
|
args.safe = False
|
64
66
|
self.execute_generation(args)
|
65
67
|
else:
|
@@ -69,15 +71,13 @@ class CLIController:
|
|
69
71
|
if not successful:
|
70
72
|
if hasattr(args, 'db') and args.db:
|
71
73
|
self.service.db_service = DBService(args.db)
|
72
|
-
self.
|
74
|
+
self.logger.debug(f"Using database: {args.db}")
|
73
75
|
self.execute_generation(args)
|
74
|
-
# Else successful, do nothing - we're done
|
75
76
|
else:
|
76
|
-
# Initialize database service with specified path
|
77
77
|
if hasattr(args, 'db') and args.db:
|
78
78
|
self.service.db_service = DBService(args.db)
|
79
|
-
self.
|
80
|
-
self.
|
79
|
+
self.logger.debug(f"Using database: {args.db}")
|
80
|
+
self.logger.debug("Running in local mode...")
|
81
81
|
self.execute_generation(args)
|
82
82
|
|
83
83
|
def execute_generation(self, args: argparse.Namespace, running_in_docker: bool = False):
|
@@ -85,22 +85,23 @@ class CLIController:
|
|
85
85
|
self.set_service_args(args)
|
86
86
|
|
87
87
|
if running_in_docker:
|
88
|
-
self.
|
88
|
+
self.logger.debug("Running in Docker mode...")
|
89
89
|
self.service.generate_test_cases()
|
90
90
|
|
91
91
|
else:
|
92
92
|
test_file = self.service.generate_tests(args.output)
|
93
|
-
self.
|
94
|
-
|
93
|
+
self.logger.debug(f"Unit tests saved to: {test_file}")
|
94
|
+
print("Executing tests...")
|
95
|
+
self.service.run_tests(test_file)
|
96
|
+
print("Running coverage...")
|
95
97
|
self.service.run_coverage(test_file)
|
96
|
-
self.
|
98
|
+
self.logger.debug("Tests and coverage data saved to database.")
|
97
99
|
|
98
100
|
if args.visualize:
|
99
101
|
self.service.visualize_test_coverage()
|
100
102
|
|
101
103
|
except Exception as e:
|
102
|
-
self.
|
103
|
-
# Make sure to close the DB connection on error
|
104
|
+
self.logger.error(f"An error occurred: {e}")
|
104
105
|
if hasattr(self.service, 'db_service'):
|
105
106
|
self.service.db_service.close()
|
106
107
|
|
@@ -168,6 +169,11 @@ class CLIController:
|
|
168
169
|
action="store_true",
|
169
170
|
help="Run coverage analysis on the generated tests"
|
170
171
|
)
|
172
|
+
parser.add_argument(
|
173
|
+
"-f", "--functions",
|
174
|
+
action="store_true",
|
175
|
+
help="List all functions in file"
|
176
|
+
)
|
171
177
|
return parser
|
172
178
|
|
173
179
|
def set_test_format(self, args: argparse.Namespace):
|
@@ -180,32 +186,32 @@ class CLIController:
|
|
180
186
|
|
181
187
|
def set_test_strategy(self, args: argparse.Namespace):
|
182
188
|
if args.test_mode == "random":
|
183
|
-
|
189
|
+
print("Using Random Feedback-Directed Test Generation Strategy.")
|
184
190
|
self.service.set_test_analysis_strategy(RANDOM_STRAT)
|
185
191
|
elif args.test_mode == "fuzz":
|
186
|
-
|
192
|
+
print("Using Fuzz Test Generation Strategy...")
|
187
193
|
self.service.set_test_analysis_strategy(FUZZ_STRAT)
|
188
194
|
elif args.test_mode == "reinforce":
|
189
|
-
|
195
|
+
print("Using Reinforcement Learning Test Generation Strategy...")
|
190
196
|
if args.reinforce_mode == "train":
|
191
|
-
|
197
|
+
print("Training mode enabled - will update Q-table")
|
192
198
|
else:
|
193
|
-
|
199
|
+
print("Training mode disabled - will use existing Q-table")
|
194
200
|
self.service.set_test_analysis_strategy(REINFORCE_STRAT)
|
195
201
|
self.service.set_reinforcement_mode(args.reinforce_mode)
|
196
202
|
else:
|
197
|
-
|
203
|
+
print("Generating function code using AST analysis...")
|
198
204
|
generated_file_path = self.service.generate_function_code()
|
199
|
-
|
205
|
+
print(f"Generated code saved to: {generated_file_path}")
|
200
206
|
if not args.generate_only:
|
201
|
-
|
207
|
+
print("Using Simple AST Traversal Test Generation Strategy...")
|
202
208
|
self.service.set_test_analysis_strategy(AST_STRAT)
|
203
209
|
|
204
210
|
def docker_available(self) -> DockerClient | None:
|
205
211
|
try:
|
206
212
|
client = docker.from_env()
|
207
213
|
client.ping()
|
208
|
-
|
214
|
+
print("Docker daemon is running and connected.")
|
209
215
|
return client
|
210
216
|
except docker.errors.DockerException as err:
|
211
217
|
print(f"Docker is not available: {err}")
|
@@ -59,10 +59,10 @@ class DockerController:
|
|
59
59
|
self.debug(f"project_root: {project_root}")
|
60
60
|
container = self.run_container(docker_client, image_name, docker_args, project_root)
|
61
61
|
|
62
|
-
self.clean_up(dest_path)
|
63
|
-
|
64
62
|
# Stream the logs to the console
|
65
63
|
logs_output = self.get_logs(container)
|
64
|
+
|
65
|
+
self.clean_up(dest_path)
|
66
66
|
self.debug(logs_output)
|
67
67
|
|
68
68
|
except Exception as e:
|
@@ -98,6 +98,7 @@ class DockerController:
|
|
98
98
|
|
99
99
|
if not args.generate_only:
|
100
100
|
print("Running coverage...")
|
101
|
+
self.service.run_tests(test_file)
|
101
102
|
self.service.run_coverage(test_file)
|
102
103
|
|
103
104
|
# Add explicit return True here
|
@@ -0,0 +1,68 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
from typing import List, Tuple, Any
|
3
|
+
|
4
|
+
from testgen.models.function import Function
|
5
|
+
|
6
|
+
|
7
|
+
class Dao(ABC):
    """Abstract data-access interface used by the test generator.

    Concrete subclasses implement every method below. The storage backend and
    the meaning of the returned ``int`` values are left to the implementation.
    They are presumably row ids; confirm against the concrete DAO.
    """

    @abstractmethod
    def insert_test_suite(self, name: str) -> int:
        """Insert a test suite named *name*; returns an int (presumably its id)."""

    @abstractmethod
    def insert_source_file(self, path: str, lines_of_code: int, last_modified) -> int:
        """Insert a source file record for *path*; returns an int (presumably its id)."""

    @abstractmethod
    def insert_function(self, name: str, params, start_line: int, end_line: int, source_file_id: int) -> int:
        """Insert a function spanning *start_line*..*end_line* of a source file; returns an int."""

    @abstractmethod
    def insert_test_case(self, test_case: Any, test_suite_id: int, function_id: int, test_method_type: int) -> int:
        """Insert *test_case* linked to a test suite and a function; returns an int."""

    @abstractmethod
    def insert_test_result(self, test_case_id: int, status: bool, error: str | None = None) -> int:
        """Insert the outcome of running a test case, with an optional *error* message; returns an int."""

    @abstractmethod
    def insert_coverage_data(self, file_name: str, executed_lines: int, missed_lines: int,
                             branch_coverage: float, source_file_id: int, function_id: int | None = None) -> int:
        """Insert coverage figures for a source file, optionally scoped to *function_id*; returns an int."""

    @abstractmethod
    def get_test_suites(self) -> List[Any]:
        """Return all stored test suites."""

    @abstractmethod
    def get_test_cases_by_function(self, function_name: str) -> List[Any]:
        """Return the stored test cases for the function named *function_name*."""

    @abstractmethod
    def get_source_file_id_by_path(self, filepath: str) -> int:
        """Return the id for the source file at *filepath* (missing-row behaviour is implementation-defined)."""

    @abstractmethod
    def get_coverage_by_file(self, file_path: str) -> List[Any]:
        """Return the stored coverage records for *file_path*."""

    @abstractmethod
    def get_test_file_data(self, file_path: str) -> List[Any]:
        """Return stored test data associated with *file_path*."""

    @abstractmethod
    def get_function_by_name_file_id_start(self, name: str, source_file_id: int, start_line: int) -> int:
        """Return the id of the function identified by name, source file and start line."""

    @abstractmethod
    def get_functions_by_file(self, filepath: str) -> List[Function]:
        """Return the functions stored for the source file at *filepath*."""

    @abstractmethod
    def get_test_suite_id_by_name(self, name: str) -> int:
        """Return the id of the test suite named *name*."""

    @abstractmethod
    def get_test_case_id_by_func_id_input_expected(self, function_id: int, inputs: str, expected: str) -> int:
        """Return the id of the test case matching *function_id*, serialized *inputs* and *expected* output."""
|
68
|
+
|